diff --git a/dateutil/parser/__init__.py b/dateutil/parser/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d174b0e4dcc472999b75e55ebb88af320ae38081 --- /dev/null +++ b/dateutil/parser/__init__.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +from ._parser import parse, parser, parserinfo, ParserError +from ._parser import DEFAULTPARSER, DEFAULTTZPARSER +from ._parser import UnknownTimezoneWarning + +from ._parser import __doc__ + +from .isoparser import isoparser, isoparse + +__all__ = ['parse', 'parser', 'parserinfo', + 'isoparse', 'isoparser', + 'ParserError', + 'UnknownTimezoneWarning'] + + +### +# Deprecate portions of the private interface so that downstream code that +# is improperly relying on it is given *some* notice. + + +def __deprecated_private_func(f): + from functools import wraps + import warnings + + msg = ('{name} is a private function and may break without warning, ' + 'it will be moved and or renamed in future versions.') + msg = msg.format(name=f.__name__) + + @wraps(f) + def deprecated_func(*args, **kwargs): + warnings.warn(msg, DeprecationWarning) + return f(*args, **kwargs) + + return deprecated_func + +def __deprecate_private_class(c): + import warnings + + msg = ('{name} is a private class and may break without warning, ' + 'it will be moved and or renamed in future versions.') + msg = msg.format(name=c.__name__) + + class private_class(c): + __doc__ = c.__doc__ + + def __init__(self, *args, **kwargs): + warnings.warn(msg, DeprecationWarning) + super(private_class, self).__init__(*args, **kwargs) + + private_class.__name__ = c.__name__ + + return private_class + + +from ._parser import _timelex, _resultbase +from ._parser import _tzparser, _parsetz + +_timelex = __deprecate_private_class(_timelex) +_tzparser = __deprecate_private_class(_tzparser) +_resultbase = __deprecate_private_class(_resultbase) +_parsetz = __deprecated_private_func(_parsetz) diff --git a/dateutil/parser/_parser.py b/dateutil/parser/_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..37d1663b2f72447800d9a553929e3de932244289 --- /dev/null +++ b/dateutil/parser/_parser.py @@ -0,0 +1,1613 @@ +# -*- coding: utf-8 -*- +""" +This module offers a generic date/time string parser which is able to parse +most known formats to represent a date and/or time. + +This module attempts to be forgiving with regards to unlikely input formats, +returning a datetime object even for dates which are ambiguous. If an element +of a date/time stamp is omitted, the following rules are applied: + +- If AM or PM is left unspecified, a 24-hour clock is assumed, however, an hour + on a 12-hour clock (``0 <= hour <= 12``) *must* be specified if AM or PM is + specified. +- If a time zone is omitted, a timezone-naive datetime is returned. + +If any other elements are missing, they are taken from the +:class:`datetime.datetime` object passed to the parameter ``default``. If this +results in a day number exceeding the valid number of days per month, the +value falls back to the end of the month. + +Additional resources about date/time string formats can be found below: + +- `A summary of the international standard date and time notation + `_ +- `W3C Date and Time Formats `_ +- `Time Formats (Planetary Rings Node) `_ +- `CPAN ParseDate module + `_ +- `Java SimpleDateFormat Class + `_ +""" +from __future__ import unicode_literals + +import datetime +import re +import string +import time +import warnings + +from calendar import monthrange +from io import StringIO + +import six +from six import integer_types, text_type + +from decimal import Decimal + +from warnings import warn + +from .. import relativedelta +from .. import tz + +__all__ = ["parse", "parserinfo", "ParserError"] + + +# TODO: pandas.core.tools.datetimes imports this explicitly. Might be worth +# making public and/or figuring out if there is something we can +# take off their plate. +class _timelex(object): + # Fractional seconds are sometimes split by a comma + _split_decimal = re.compile("([.,])") + + def __init__(self, instream): + if isinstance(instream, (bytes, bytearray)): + instream = instream.decode() + + if isinstance(instream, text_type): + instream = StringIO(instream) + elif getattr(instream, 'read', None) is None: + raise TypeError('Parser must be a string or character stream, not ' + '{itype}'.format(itype=instream.__class__.__name__)) + + self.instream = instream + self.charstack = [] + self.tokenstack = [] + self.eof = False + + def get_token(self): + """ + This function breaks the time string into lexical units (tokens), which + can be parsed by the parser. Lexical units are demarcated by changes in + the character set, so any continuous string of letters is considered + one unit, any continuous string of numbers is considered one unit. + + The main complication arises from the fact that dots ('.') can be used + both as separators (e.g. "Sep.20.2009") or decimal points (e.g. + "4:30:21.447"). As such, it is necessary to read the full context of + any dot-separated strings before breaking it into tokens; as such, this + function maintains a "token stack", for when the ambiguous context + demands that multiple tokens be parsed at once. + """ + if self.tokenstack: + return self.tokenstack.pop(0) + + seenletters = False + token = None + state = None + + while not self.eof: + # We only realize that we've reached the end of a token when we + # find a character that's not part of the current token - since + # that character may be part of the next token, it's stored in the + # charstack. + if self.charstack: + nextchar = self.charstack.pop(0) + else: + nextchar = self.instream.read(1) + while nextchar == '\x00': + nextchar = self.instream.read(1) + + if not nextchar: + self.eof = True + break + elif not state: + # First character of the token - determines if we're starting + # to parse a word, a number or something else. + token = nextchar + if self.isword(nextchar): + state = 'a' + elif self.isnum(nextchar): + state = '0' + elif self.isspace(nextchar): + token = ' ' + break # emit token + else: + break # emit token + elif state == 'a': + # If we've already started reading a word, we keep reading + # letters until we find something that's not part of a word. + seenletters = True + if self.isword(nextchar): + token += nextchar + elif nextchar == '.': + token += nextchar + state = 'a.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == '0': + # If we've already started reading a number, we keep reading + # numbers until we find something that doesn't fit. + if self.isnum(nextchar): + token += nextchar + elif nextchar == '.' or (nextchar == ',' and len(token) >= 2): + token += nextchar + state = '0.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == 'a.': + # If we've seen some letters and a dot separator, continue + # parsing, and the tokens will be broken up later. + seenletters = True + if nextchar == '.' or self.isword(nextchar): + token += nextchar + elif self.isnum(nextchar) and token[-1] == '.': + token += nextchar + state = '0.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == '0.': + # If we've seen at least one dot separator, keep going, we'll + # break up the tokens later. + if nextchar == '.' or self.isnum(nextchar): + token += nextchar + elif self.isword(nextchar) and token[-1] == '.': + token += nextchar + state = 'a.' + else: + self.charstack.append(nextchar) + break # emit token + + if (state in ('a.', '0.') and (seenletters or token.count('.') > 1 or + token[-1] in '.,')): + l = self._split_decimal.split(token) + token = l[0] + for tok in l[1:]: + if tok: + self.tokenstack.append(tok) + + if state == '0.' and token.count('.') == 0: + token = token.replace(',', '.') + + return token + + def __iter__(self): + return self + + def __next__(self): + token = self.get_token() + if token is None: + raise StopIteration + + return token + + def next(self): + return self.__next__() # Python 2.x support + + @classmethod + def split(cls, s): + return list(cls(s)) + + @classmethod + def isword(cls, nextchar): + """ Whether or not the next character is part of a word """ + return nextchar.isalpha() + + @classmethod + def isnum(cls, nextchar): + """ Whether the next character is part of a number """ + return nextchar.isdigit() + + @classmethod + def isspace(cls, nextchar): + """ Whether the next character is whitespace """ + return nextchar.isspace() + + +class _resultbase(object): + + def __init__(self): + for attr in self.__slots__: + setattr(self, attr, None) + + def _repr(self, classname): + l = [] + for attr in self.__slots__: + value = getattr(self, attr) + if value is not None: + l.append("%s=%s" % (attr, repr(value))) + return "%s(%s)" % (classname, ", ".join(l)) + + def __len__(self): + return (sum(getattr(self, attr) is not None + for attr in self.__slots__)) + + def __repr__(self): + return self._repr(self.__class__.__name__) + + +class parserinfo(object): + """ + Class which handles what inputs are accepted. Subclass this to customize + the language and acceptable values for each parameter. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM + and YMD. Default is ``False``. + + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken + to be the year, otherwise the last number is taken to be the year. + Default is ``False``. + """ + + # m from a.m/p.m, t from ISO T separator + JUMP = [" ", ".", ",", ";", "-", "/", "'", + "at", "on", "and", "ad", "m", "t", "of", + "st", "nd", "rd", "th"] + + WEEKDAYS = [("Mon", "Monday"), + ("Tue", "Tuesday"), # TODO: "Tues" + ("Wed", "Wednesday"), + ("Thu", "Thursday"), # TODO: "Thurs" + ("Fri", "Friday"), + ("Sat", "Saturday"), + ("Sun", "Sunday")] + MONTHS = [("Jan", "January"), + ("Feb", "February"), # TODO: "Febr" + ("Mar", "March"), + ("Apr", "April"), + ("May", "May"), + ("Jun", "June"), + ("Jul", "July"), + ("Aug", "August"), + ("Sep", "Sept", "September"), + ("Oct", "October"), + ("Nov", "November"), + ("Dec", "December")] + HMS = [("h", "hour", "hours"), + ("m", "minute", "minutes"), + ("s", "second", "seconds")] + AMPM = [("am", "a"), + ("pm", "p")] + UTCZONE = ["UTC", "GMT", "Z", "z"] + PERTAIN = ["of"] + TZOFFSET = {} + # TODO: ERA = ["AD", "BC", "CE", "BCE", "Stardate", + # "Anno Domini", "Year of Our Lord"] + + def __init__(self, dayfirst=False, yearfirst=False): + self._jump = self._convert(self.JUMP) + self._weekdays = self._convert(self.WEEKDAYS) + self._months = self._convert(self.MONTHS) + self._hms = self._convert(self.HMS) + self._ampm = self._convert(self.AMPM) + self._utczone = self._convert(self.UTCZONE) + self._pertain = self._convert(self.PERTAIN) + + self.dayfirst = dayfirst + self.yearfirst = yearfirst + + self._year = time.localtime().tm_year + self._century = self._year // 100 * 100 + + def _convert(self, lst): + dct = {} + for i, v in enumerate(lst): + if isinstance(v, tuple): + for v in v: + dct[v.lower()] = i + else: + dct[v.lower()] = i + return dct + + def jump(self, name): + return name.lower() in self._jump + + def weekday(self, name): + try: + return self._weekdays[name.lower()] + except KeyError: + pass + return None + + def month(self, name): + try: + return self._months[name.lower()] + 1 + except KeyError: + pass + return None + + def hms(self, name): + try: + return self._hms[name.lower()] + except KeyError: + return None + + def ampm(self, name): + try: + return self._ampm[name.lower()] + except KeyError: + return None + + def pertain(self, name): + return name.lower() in self._pertain + + def utczone(self, name): + return name.lower() in self._utczone + + def tzoffset(self, name): + if name in self._utczone: + return 0 + + return self.TZOFFSET.get(name) + + def convertyear(self, year, century_specified=False): + """ + Converts two-digit years to year within [-50, 49] + range of self._year (current local time) + """ + + # Function contract is that the year is always positive + assert year >= 0 + + if year < 100 and not century_specified: + # assume current century to start + year += self._century + + if year >= self._year + 50: # if too far in future + year -= 100 + elif year < self._year - 50: # if too far in past + year += 100 + + return year + + def validate(self, res): + # move to info + if res.year is not None: + res.year = self.convertyear(res.year, res.century_specified) + + if ((res.tzoffset == 0 and not res.tzname) or + (res.tzname == 'Z' or res.tzname == 'z')): + res.tzname = "UTC" + res.tzoffset = 0 + elif res.tzoffset != 0 and res.tzname and self.utczone(res.tzname): + res.tzoffset = 0 + return True + + +class _ymd(list): + def __init__(self, *args, **kwargs): + super(self.__class__, self).__init__(*args, **kwargs) + self.century_specified = False + self.dstridx = None + self.mstridx = None + self.ystridx = None + + @property + def has_year(self): + return self.ystridx is not None + + @property + def has_month(self): + return self.mstridx is not None + + @property + def has_day(self): + return self.dstridx is not None + + def could_be_day(self, value): + if self.has_day: + return False + elif not self.has_month: + return 1 <= value <= 31 + elif not self.has_year: + # Be permissive, assume leap year + month = self[self.mstridx] + return 1 <= value <= monthrange(2000, month)[1] + else: + month = self[self.mstridx] + year = self[self.ystridx] + return 1 <= value <= monthrange(year, month)[1] + + def append(self, val, label=None): + if hasattr(val, '__len__'): + if val.isdigit() and len(val) > 2: + self.century_specified = True + if label not in [None, 'Y']: # pragma: no cover + raise ValueError(label) + label = 'Y' + elif val > 100: + self.century_specified = True + if label not in [None, 'Y']: # pragma: no cover + raise ValueError(label) + label = 'Y' + + super(self.__class__, self).append(int(val)) + + if label == 'M': + if self.has_month: + raise ValueError('Month is already set') + self.mstridx = len(self) - 1 + elif label == 'D': + if self.has_day: + raise ValueError('Day is already set') + self.dstridx = len(self) - 1 + elif label == 'Y': + if self.has_year: + raise ValueError('Year is already set') + self.ystridx = len(self) - 1 + + def _resolve_from_stridxs(self, strids): + """ + Try to resolve the identities of year/month/day elements using + ystridx, mstridx, and dstridx, if enough of these are specified. + """ + if len(self) == 3 and len(strids) == 2: + # we can back out the remaining stridx value + missing = [x for x in range(3) if x not in strids.values()] + key = [x for x in ['y', 'm', 'd'] if x not in strids] + assert len(missing) == len(key) == 1 + key = key[0] + val = missing[0] + strids[key] = val + + assert len(self) == len(strids) # otherwise this should not be called + out = {key: self[strids[key]] for key in strids} + return (out.get('y'), out.get('m'), out.get('d')) + + def resolve_ymd(self, yearfirst, dayfirst): + len_ymd = len(self) + year, month, day = (None, None, None) + + strids = (('y', self.ystridx), + ('m', self.mstridx), + ('d', self.dstridx)) + + strids = {key: val for key, val in strids if val is not None} + if (len(self) == len(strids) > 0 or + (len(self) == 3 and len(strids) == 2)): + return self._resolve_from_stridxs(strids) + + mstridx = self.mstridx + + if len_ymd > 3: + raise ValueError("More than three YMD values") + elif len_ymd == 1 or (mstridx is not None and len_ymd == 2): + # One member, or two members with a month string + if mstridx is not None: + month = self[mstridx] + # since mstridx is 0 or 1, self[mstridx-1] always + # looks up the other element + other = self[mstridx - 1] + else: + other = self[0] + + if len_ymd > 1 or mstridx is None: + if other > 31: + year = other + else: + day = other + + elif len_ymd == 2: + # Two members with numbers + if self[0] > 31: + # 99-01 + year, month = self + elif self[1] > 31: + # 01-99 + month, year = self + elif dayfirst and self[1] <= 12: + # 13-01 + day, month = self + else: + # 01-13 + month, day = self + + elif len_ymd == 3: + # Three members + if mstridx == 0: + if self[1] > 31: + # Apr-2003-25 + month, year, day = self + else: + month, day, year = self + elif mstridx == 1: + if self[0] > 31 or (yearfirst and self[2] <= 31): + # 99-Jan-01 + year, month, day = self + else: + # 01-Jan-01 + # Give precedence to day-first, since + # two-digit years is usually hand-written. + day, month, year = self + + elif mstridx == 2: + # WTF!? + if self[1] > 31: + # 01-99-Jan + day, year, month = self + else: + # 99-01-Jan + year, day, month = self + + else: + if (self[0] > 31 or + self.ystridx == 0 or + (yearfirst and self[1] <= 12 and self[2] <= 31)): + # 99-01-01 + if dayfirst and self[2] <= 12: + year, day, month = self + else: + year, month, day = self + elif self[0] > 12 or (dayfirst and self[1] <= 12): + # 13-01-01 + day, month, year = self + else: + # 01-13-01 + month, day, year = self + + return year, month, day + + +class parser(object): + def __init__(self, info=None): + self.info = info or parserinfo() + + def parse(self, timestr, default=None, + ignoretz=False, tzinfos=None, **kwargs): + """ + Parse the date/time string into a :class:`datetime.datetime` object. + + :param timestr: + Any date/time string using the supported formats. + + :param default: + The default datetime object, if this is a datetime object and not + ``None``, elements specified in ``timestr`` replace elements in the + default object. + + :param ignoretz: + If set ``True``, time zones in parsed strings are ignored and a + naive :class:`datetime.datetime` object is returned. + + :param tzinfos: + Additional time zone names / aliases which may be present in the + string. This argument maps time zone names (and optionally offsets + from those time zones) to time zones. This parameter can be a + dictionary with timezone aliases mapping time zone names to time + zones or a function taking two parameters (``tzname`` and + ``tzoffset``) and returning a time zone. + + The timezones to which the names are mapped can be an integer + offset from UTC in seconds or a :class:`tzinfo` object. + + .. doctest:: + :options: +NORMALIZE_WHITESPACE + + >>> from dateutil.parser import parse + >>> from dateutil.tz import gettz + >>> tzinfos = {"BRST": -7200, "CST": gettz("America/Chicago")} + >>> parse("2012-01-19 17:21:00 BRST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, tzinfo=tzoffset(u'BRST', -7200)) + >>> parse("2012-01-19 17:21:00 CST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, + tzinfo=tzfile('/usr/share/zoneinfo/America/Chicago')) + + This parameter is ignored if ``ignoretz`` is set. + + :param \\*\\*kwargs: + Keyword arguments as passed to ``_parse()``. + + :return: + Returns a :class:`datetime.datetime` object or, if the + ``fuzzy_with_tokens`` option is ``True``, returns a tuple, the + first element being a :class:`datetime.datetime` object, the second + a tuple containing the fuzzy tokens. + + :raises ParserError: + Raised for invalid or unknown string format, if the provided + :class:`tzinfo` is not in a valid format, or if an invalid date + would be created. + + :raises TypeError: + Raised for non-string or character stream input. + + :raises OverflowError: + Raised if the parsed date exceeds the largest valid C integer on + your system. + """ + + if default is None: + default = datetime.datetime.now().replace(hour=0, minute=0, + second=0, microsecond=0) + + res, skipped_tokens = self._parse(timestr, **kwargs) + + if res is None: + raise ParserError("Unknown string format: %s", timestr) + + if len(res) == 0: + raise ParserError("String does not contain a date: %s", timestr) + + try: + ret = self._build_naive(res, default) + except ValueError as e: + six.raise_from(ParserError(str(e) + ": %s", timestr), e) + + if not ignoretz: + ret = self._build_tzaware(ret, res, tzinfos) + + if kwargs.get('fuzzy_with_tokens', False): + return ret, skipped_tokens + else: + return ret + + class _result(_resultbase): + __slots__ = ["year", "month", "day", "weekday", + "hour", "minute", "second", "microsecond", + "tzname", "tzoffset", "ampm","any_unused_tokens"] + + def _parse(self, timestr, dayfirst=None, yearfirst=None, fuzzy=False, + fuzzy_with_tokens=False): + """ + Private method which performs the heavy lifting of parsing, called from + ``parse()``, which passes on its ``kwargs`` to this function. + + :param timestr: + The string to parse. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM + and YMD. If set to ``None``, this value is retrieved from the + current :class:`parserinfo` object (which itself defaults to + ``False``). + + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken + to be the year, otherwise the last number is taken to be the year. + If this is set to ``None``, the value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param fuzzy: + Whether to allow fuzzy parsing, allowing for string like "Today is + January 1, 2047 at 8:21:00AM". + + :param fuzzy_with_tokens: + If ``True``, ``fuzzy`` is automatically set to True, and the parser + will return a tuple where the first element is the parsed + :class:`datetime.datetime` datetimestamp and the second element is + a tuple containing the portions of the string which were ignored: + + .. doctest:: + + >>> from dateutil.parser import parse + >>> parse("Today is January 1, 2047 at 8:21:00AM", fuzzy_with_tokens=True) + (datetime.datetime(2047, 1, 1, 8, 21), (u'Today is ', u' ', u'at ')) + + """ + if fuzzy_with_tokens: + fuzzy = True + + info = self.info + + if dayfirst is None: + dayfirst = info.dayfirst + + if yearfirst is None: + yearfirst = info.yearfirst + + res = self._result() + l = _timelex.split(timestr) # Splits the timestr into tokens + + skipped_idxs = [] + + # year/month/day list + ymd = _ymd() + + len_l = len(l) + i = 0 + try: + while i < len_l: + + # Check if it's a number + value_repr = l[i] + try: + value = float(value_repr) + except ValueError: + value = None + + if value is not None: + # Numeric token + i = self._parse_numeric_token(l, i, info, ymd, res, fuzzy) + + # Check weekday + elif info.weekday(l[i]) is not None: + value = info.weekday(l[i]) + res.weekday = value + + # Check month name + elif info.month(l[i]) is not None: + value = info.month(l[i]) + ymd.append(value, 'M') + + if i + 1 < len_l: + if l[i + 1] in ('-', '/'): + # Jan-01[-99] + sep = l[i + 1] + ymd.append(l[i + 2]) + + if i + 3 < len_l and l[i + 3] == sep: + # Jan-01-99 + ymd.append(l[i + 4]) + i += 2 + + i += 2 + + elif (i + 4 < len_l and l[i + 1] == l[i + 3] == ' ' and + info.pertain(l[i + 2])): + # Jan of 01 + # In this case, 01 is clearly year + if l[i + 4].isdigit(): + # Convert it here to become unambiguous + value = int(l[i + 4]) + year = str(info.convertyear(value)) + ymd.append(year, 'Y') + else: + # Wrong guess + pass + # TODO: not hit in tests + i += 4 + + # Check am/pm + elif info.ampm(l[i]) is not None: + value = info.ampm(l[i]) + val_is_ampm = self._ampm_valid(res.hour, res.ampm, fuzzy) + + if val_is_ampm: + res.hour = self._adjust_ampm(res.hour, value) + res.ampm = value + + elif fuzzy: + skipped_idxs.append(i) + + # Check for a timezone name + elif self._could_be_tzname(res.hour, res.tzname, res.tzoffset, l[i]): + res.tzname = l[i] + res.tzoffset = info.tzoffset(res.tzname) + + # Check for something like GMT+3, or BRST+3. Notice + # that it doesn't mean "I am 3 hours after GMT", but + # "my time +3 is GMT". If found, we reverse the + # logic so that timezone parsing code will get it + # right. + if i + 1 < len_l and l[i + 1] in ('+', '-'): + l[i + 1] = ('+', '-')[l[i + 1] == '+'] + res.tzoffset = None + if info.utczone(res.tzname): + # With something like GMT+3, the timezone + # is *not* GMT. + res.tzname = None + + # Check for a numbered timezone + elif res.hour is not None and l[i] in ('+', '-'): + signal = (-1, 1)[l[i] == '+'] + len_li = len(l[i + 1]) + + # TODO: check that l[i + 1] is integer? + if len_li == 4: + # -0300 + hour_offset = int(l[i + 1][:2]) + min_offset = int(l[i + 1][2:]) + elif i + 2 < len_l and l[i + 2] == ':': + # -03:00 + hour_offset = int(l[i + 1]) + min_offset = int(l[i + 3]) # TODO: Check that l[i+3] is minute-like? + i += 2 + elif len_li <= 2: + # -[0]3 + hour_offset = int(l[i + 1][:2]) + min_offset = 0 + else: + raise ValueError(timestr) + + res.tzoffset = signal * (hour_offset * 3600 + min_offset * 60) + + # Look for a timezone name between parenthesis + if (i + 5 < len_l and + info.jump(l[i + 2]) and l[i + 3] == '(' and + l[i + 5] == ')' and + 3 <= len(l[i + 4]) and + self._could_be_tzname(res.hour, res.tzname, + None, l[i + 4])): + # -0300 (BRST) + res.tzname = l[i + 4] + i += 4 + + i += 1 + + # Check jumps + elif not (info.jump(l[i]) or fuzzy): + raise ValueError(timestr) + + else: + skipped_idxs.append(i) + i += 1 + + # Process year/month/day + year, month, day = ymd.resolve_ymd(yearfirst, dayfirst) + + res.century_specified = ymd.century_specified + res.year = year + res.month = month + res.day = day + + except (IndexError, ValueError): + return None, None + + if not info.validate(res): + return None, None + + if fuzzy_with_tokens: + skipped_tokens = self._recombine_skipped(l, skipped_idxs) + return res, tuple(skipped_tokens) + else: + return res, None + + def _parse_numeric_token(self, tokens, idx, info, ymd, res, fuzzy): + # Token is a number + value_repr = tokens[idx] + try: + value = self._to_decimal(value_repr) + except Exception as e: + six.raise_from(ValueError('Unknown numeric token'), e) + + len_li = len(value_repr) + + len_l = len(tokens) + + if (len(ymd) == 3 and len_li in (2, 4) and + res.hour is None and + (idx + 1 >= len_l or + (tokens[idx + 1] != ':' and + info.hms(tokens[idx + 1]) is None))): + # 19990101T23[59] + s = tokens[idx] + res.hour = int(s[:2]) + + if len_li == 4: + res.minute = int(s[2:]) + + elif len_li == 6 or (len_li > 6 and tokens[idx].find('.') == 6): + # YYMMDD or HHMMSS[.ss] + s = tokens[idx] + + if not ymd and '.' not in tokens[idx]: + ymd.append(s[:2]) + ymd.append(s[2:4]) + ymd.append(s[4:]) + else: + # 19990101T235959[.59] + + # TODO: Check if res attributes already set. + res.hour = int(s[:2]) + res.minute = int(s[2:4]) + res.second, res.microsecond = self._parsems(s[4:]) + + elif len_li in (8, 12, 14): + # YYYYMMDD + s = tokens[idx] + ymd.append(s[:4], 'Y') + ymd.append(s[4:6]) + ymd.append(s[6:8]) + + if len_li > 8: + res.hour = int(s[8:10]) + res.minute = int(s[10:12]) + + if len_li > 12: + res.second = int(s[12:]) + + elif self._find_hms_idx(idx, tokens, info, allow_jump=True) is not None: + # HH[ ]h or MM[ ]m or SS[.ss][ ]s + hms_idx = self._find_hms_idx(idx, tokens, info, allow_jump=True) + (idx, hms) = self._parse_hms(idx, tokens, info, hms_idx) + if hms is not None: + # TODO: checking that hour/minute/second are not + # already set? + self._assign_hms(res, value_repr, hms) + + elif idx + 2 < len_l and tokens[idx + 1] == ':': + # HH:MM[:SS[.ss]] + res.hour = int(value) + value = self._to_decimal(tokens[idx + 2]) # TODO: try/except for this? + (res.minute, res.second) = self._parse_min_sec(value) + + if idx + 4 < len_l and tokens[idx + 3] == ':': + res.second, res.microsecond = self._parsems(tokens[idx + 4]) + + idx += 2 + + idx += 2 + + elif idx + 1 < len_l and tokens[idx + 1] in ('-', '/', '.'): + sep = tokens[idx + 1] + ymd.append(value_repr) + + if idx + 2 < len_l and not info.jump(tokens[idx + 2]): + if tokens[idx + 2].isdigit(): + # 01-01[-01] + ymd.append(tokens[idx + 2]) + else: + # 01-Jan[-01] + value = info.month(tokens[idx + 2]) + + if value is not None: + ymd.append(value, 'M') + else: + raise ValueError() + + if idx + 3 < len_l and tokens[idx + 3] == sep: + # We have three members + value = info.month(tokens[idx + 4]) + + if value is not None: + ymd.append(value, 'M') + else: + ymd.append(tokens[idx + 4]) + idx += 2 + + idx += 1 + idx += 1 + + elif idx + 1 >= len_l or info.jump(tokens[idx + 1]): + if idx + 2 < len_l and info.ampm(tokens[idx + 2]) is not None: + # 12 am + hour = int(value) + res.hour = self._adjust_ampm(hour, info.ampm(tokens[idx + 2])) + idx += 1 + else: + # Year, month or day + ymd.append(value) + idx += 1 + + elif info.ampm(tokens[idx + 1]) is not None and (0 <= value < 24): + # 12am + hour = int(value) + res.hour = self._adjust_ampm(hour, info.ampm(tokens[idx + 1])) + idx += 1 + + elif ymd.could_be_day(value): + ymd.append(value) + + elif not fuzzy: + raise ValueError() + + return idx + + def _find_hms_idx(self, idx, tokens, info, allow_jump): + len_l = len(tokens) + + if idx+1 < len_l and info.hms(tokens[idx+1]) is not None: + # There is an "h", "m", or "s" label following this token. We take + # assign the upcoming label to the current token. + # e.g. the "12" in 12h" + hms_idx = idx + 1 + + elif (allow_jump and idx+2 < len_l and tokens[idx+1] == ' ' and + info.hms(tokens[idx+2]) is not None): + # There is a space and then an "h", "m", or "s" label. + # e.g. the "12" in "12 h" + hms_idx = idx + 2 + + elif idx > 0 and info.hms(tokens[idx-1]) is not None: + # There is a "h", "m", or "s" preceding this token. Since neither + # of the previous cases was hit, there is no label following this + # token, so we use the previous label. + # e.g. the "04" in "12h04" + hms_idx = idx-1 + + elif (1 < idx == len_l-1 and tokens[idx-1] == ' ' and + info.hms(tokens[idx-2]) is not None): + # If we are looking at the final token, we allow for a + # backward-looking check to skip over a space. + # TODO: Are we sure this is the right condition here? + hms_idx = idx - 2 + + else: + hms_idx = None + + return hms_idx + + def _assign_hms(self, res, value_repr, hms): + # See GH issue #427, fixing float rounding + value = self._to_decimal(value_repr) + + if hms == 0: + # Hour + res.hour = int(value) + if value % 1: + res.minute = int(60*(value % 1)) + + elif hms == 1: + (res.minute, res.second) = self._parse_min_sec(value) + + elif hms == 2: + (res.second, res.microsecond) = self._parsems(value_repr) + + def _could_be_tzname(self, hour, tzname, tzoffset, token): + return (hour is not None and + tzname is None and + tzoffset is None and + len(token) <= 5 and + (all(x in string.ascii_uppercase for x in token) + or token in self.info.UTCZONE)) + + def _ampm_valid(self, hour, ampm, fuzzy): + """ + For fuzzy parsing, 'a' or 'am' (both valid English words) + may erroneously trigger the AM/PM flag. Deal with that + here. + """ + val_is_ampm = True + + # If there's already an AM/PM flag, this one isn't one. + if fuzzy and ampm is not None: + val_is_ampm = False + + # If AM/PM is found and hour is not, raise a ValueError + if hour is None: + if fuzzy: + val_is_ampm = False + else: + raise ValueError('No hour specified with AM or PM flag.') + elif not 0 <= hour <= 12: + # If AM/PM is found, it's a 12 hour clock, so raise + # an error for invalid range + if fuzzy: + val_is_ampm = False + else: + raise ValueError('Invalid hour specified for 12-hour clock.') + + return val_is_ampm + + def _adjust_ampm(self, hour, ampm): + if hour < 12 and ampm == 1: + hour += 12 + elif hour == 12 and ampm == 0: + hour = 0 + return hour + + def _parse_min_sec(self, value): + # TODO: Every usage of this function sets res.second to the return + # value. Are there any cases where second will be returned as None and + # we *don't* want to set res.second = None? + minute = int(value) + second = None + + sec_remainder = value % 1 + if sec_remainder: + second = int(60 * sec_remainder) + return (minute, second) + + def _parse_hms(self, idx, tokens, info, hms_idx): + # TODO: Is this going to admit a lot of false-positives for when we + # just happen to have digits and "h", "m" or "s" characters in non-date + # text? I guess hex hashes won't have that problem, but there's plenty + # of random junk out there. + if hms_idx is None: + hms = None + new_idx = idx + elif hms_idx > idx: + hms = info.hms(tokens[hms_idx]) + new_idx = hms_idx + else: + # Looking backwards, increment one. + hms = info.hms(tokens[hms_idx]) + 1 + new_idx = idx + + return (new_idx, hms) + + # ------------------------------------------------------------------ + # Handling for individual tokens. These are kept as methods instead + # of functions for the sake of customizability via subclassing. + + def _parsems(self, value): + """Parse a I[.F] seconds value into (seconds, microseconds).""" + if "." not in value: + return int(value), 0 + else: + i, f = value.split(".") + return int(i), int(f.ljust(6, "0")[:6]) + + def _to_decimal(self, val): + try: + decimal_value = Decimal(val) + # See GH 662, edge case, infinite value should not be converted + # via `_to_decimal` + if not decimal_value.is_finite(): + raise ValueError("Converted decimal value is infinite or NaN") + except Exception as e: + msg = "Could not convert %s to decimal" % val + six.raise_from(ValueError(msg), e) + else: + return decimal_value + + # ------------------------------------------------------------------ + # Post-Parsing construction of datetime output. These are kept as + # methods instead of functions for the sake of customizability via + # subclassing. + + def _build_tzinfo(self, tzinfos, tzname, tzoffset): + if callable(tzinfos): + tzdata = tzinfos(tzname, tzoffset) + else: + tzdata = tzinfos.get(tzname) + # handle case where tzinfo is paased an options that returns None + # eg tzinfos = {'BRST' : None} + if isinstance(tzdata, datetime.tzinfo) or tzdata is None: + tzinfo = tzdata + elif isinstance(tzdata, text_type): + tzinfo = tz.tzstr(tzdata) + elif isinstance(tzdata, integer_types): + tzinfo = tz.tzoffset(tzname, tzdata) + else: + raise TypeError("Offset must be tzinfo subclass, tz string, " + "or int offset.") + return tzinfo + + def _build_tzaware(self, naive, res, tzinfos): + if (callable(tzinfos) or (tzinfos and res.tzname in tzinfos)): + tzinfo = self._build_tzinfo(tzinfos, res.tzname, res.tzoffset) + aware = naive.replace(tzinfo=tzinfo) + aware = self._assign_tzname(aware, res.tzname) + + elif res.tzname and res.tzname in time.tzname: + aware = naive.replace(tzinfo=tz.tzlocal()) + + # Handle ambiguous local datetime + aware = self._assign_tzname(aware, res.tzname) + + # This is mostly relevant for winter GMT zones parsed in the UK + if (aware.tzname() != res.tzname and + res.tzname in self.info.UTCZONE): + aware = aware.replace(tzinfo=tz.UTC) + + elif res.tzoffset == 0: + aware = naive.replace(tzinfo=tz.UTC) + + elif res.tzoffset: + aware = naive.replace(tzinfo=tz.tzoffset(res.tzname, res.tzoffset)) + + elif not res.tzname and not res.tzoffset: + # i.e. no timezone information was found. + aware = naive + + elif res.tzname: + # tz-like string was parsed but we don't know what to do + # with it + warnings.warn("tzname {tzname} identified but not understood. " + "Pass `tzinfos` argument in order to correctly " + "return a timezone-aware datetime. In a future " + "version, this will raise an " + "exception.".format(tzname=res.tzname), + category=UnknownTimezoneWarning) + aware = naive + + return aware + + def _build_naive(self, res, default): + repl = {} + for attr in ("year", "month", "day", "hour", + "minute", "second", "microsecond"): + value = getattr(res, attr) + if value is not None: + repl[attr] = value + + if 'day' not in repl: + # If the default day exceeds the last day of the month, fall back + # to the end of the month. + cyear = default.year if res.year is None else res.year + cmonth = default.month if res.month is None else res.month + cday = default.day if res.day is None else res.day + + if cday > monthrange(cyear, cmonth)[1]: + repl['day'] = monthrange(cyear, cmonth)[1] + + naive = default.replace(**repl) + + if res.weekday is not None and not res.day: + naive = naive + relativedelta.relativedelta(weekday=res.weekday) + + return naive + + def _assign_tzname(self, dt, tzname): + if dt.tzname() != tzname: + new_dt = tz.enfold(dt, fold=1) + if new_dt.tzname() == tzname: + return new_dt + + return dt + + def _recombine_skipped(self, tokens, skipped_idxs): + """ + >>> tokens = ["foo", " ", "bar", " ", "19June2000", "baz"] + >>> skipped_idxs = [0, 1, 2, 5] + >>> _recombine_skipped(tokens, skipped_idxs) + ["foo bar", "baz"] + """ + skipped_tokens = [] + for i, idx in enumerate(sorted(skipped_idxs)): + if i > 0 and idx - 1 == skipped_idxs[i - 1]: + skipped_tokens[-1] = skipped_tokens[-1] + tokens[idx] + else: + skipped_tokens.append(tokens[idx]) + + return skipped_tokens + + +DEFAULTPARSER = parser() + + +def parse(timestr, parserinfo=None, **kwargs): + """ + + Parse a string in one of the supported formats, using the + ``parserinfo`` parameters. + + :param timestr: + A string containing a date/time stamp. + + :param parserinfo: + A :class:`parserinfo` object containing parameters for the parser. + If ``None``, the default arguments to the :class:`parserinfo` + constructor are used. + + The ``**kwargs`` parameter takes the following keyword arguments: + + :param default: + The default datetime object, if this is a datetime object and not + ``None``, elements specified in ``timestr`` replace elements in the + default object. + + :param ignoretz: + If set ``True``, time zones in parsed strings are ignored and a naive + :class:`datetime` object is returned. + + :param tzinfos: + Additional time zone names / aliases which may be present in the + string. This argument maps time zone names (and optionally offsets + from those time zones) to time zones. This parameter can be a + dictionary with timezone aliases mapping time zone names to time + zones or a function taking two parameters (``tzname`` and + ``tzoffset``) and returning a time zone. + + The timezones to which the names are mapped can be an integer + offset from UTC in seconds or a :class:`tzinfo` object. + + .. doctest:: + :options: +NORMALIZE_WHITESPACE + + >>> from dateutil.parser import parse + >>> from dateutil.tz import gettz + >>> tzinfos = {"BRST": -7200, "CST": gettz("America/Chicago")} + >>> parse("2012-01-19 17:21:00 BRST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, tzinfo=tzoffset(u'BRST', -7200)) + >>> parse("2012-01-19 17:21:00 CST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, + tzinfo=tzfile('/usr/share/zoneinfo/America/Chicago')) + + This parameter is ignored if ``ignoretz`` is set. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM and + YMD. If set to ``None``, this value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken to + be the year, otherwise the last number is taken to be the year. If + this is set to ``None``, the value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param fuzzy: + Whether to allow fuzzy parsing, allowing for string like "Today is + January 1, 2047 at 8:21:00AM". + + :param fuzzy_with_tokens: + If ``True``, ``fuzzy`` is automatically set to True, and the parser + will return a tuple where the first element is the parsed + :class:`datetime.datetime` datetimestamp and the second element is + a tuple containing the portions of the string which were ignored: + + .. doctest:: + + >>> from dateutil.parser import parse + >>> parse("Today is January 1, 2047 at 8:21:00AM", fuzzy_with_tokens=True) + (datetime.datetime(2047, 1, 1, 8, 21), (u'Today is ', u' ', u'at ')) + + :return: + Returns a :class:`datetime.datetime` object or, if the + ``fuzzy_with_tokens`` option is ``True``, returns a tuple, the + first element being a :class:`datetime.datetime` object, the second + a tuple containing the fuzzy tokens. + + :raises ParserError: + Raised for invalid or unknown string formats, if the provided + :class:`tzinfo` is not in a valid format, or if an invalid date would + be created. + + :raises OverflowError: + Raised if the parsed date exceeds the largest valid C integer on + your system. + """ + if parserinfo: + return parser(parserinfo).parse(timestr, **kwargs) + else: + return DEFAULTPARSER.parse(timestr, **kwargs) + + +class _tzparser(object): + + class _result(_resultbase): + + __slots__ = ["stdabbr", "stdoffset", "dstabbr", "dstoffset", + "start", "end"] + + class _attr(_resultbase): + __slots__ = ["month", "week", "weekday", + "yday", "jyday", "day", "time"] + + def __repr__(self): + return self._repr("") + + def __init__(self): + _resultbase.__init__(self) + self.start = self._attr() + self.end = self._attr() + + def parse(self, tzstr): + res = self._result() + l = [x for x in re.split(r'([,:.]|[a-zA-Z]+|[0-9]+)',tzstr) if x] + used_idxs = list() + try: + + len_l = len(l) + + i = 0 + while i < len_l: + # BRST+3[BRDT[+2]] + j = i + while j < len_l and not [x for x in l[j] + if x in "0123456789:,-+"]: + j += 1 + if j != i: + if not res.stdabbr: + offattr = "stdoffset" + res.stdabbr = "".join(l[i:j]) + else: + offattr = "dstoffset" + res.dstabbr = "".join(l[i:j]) + + for ii in range(j): + used_idxs.append(ii) + i = j + if (i < len_l and (l[i] in ('+', '-') or l[i][0] in + "0123456789")): + if l[i] in ('+', '-'): + # Yes, that's right. See the TZ variable + # documentation. + signal = (1, -1)[l[i] == '+'] + used_idxs.append(i) + i += 1 + else: + signal = -1 + len_li = len(l[i]) + if len_li == 4: + # -0300 + setattr(res, offattr, (int(l[i][:2]) * 3600 + + int(l[i][2:]) * 60) * signal) + elif i + 1 < len_l and l[i + 1] == ':': + # -03:00 + setattr(res, offattr, + (int(l[i]) * 3600 + + int(l[i + 2]) * 60) * signal) + used_idxs.append(i) + i += 2 + elif len_li <= 2: + # -[0]3 + setattr(res, offattr, + int(l[i][:2]) * 3600 * signal) + else: + return None + used_idxs.append(i) + i += 1 + if res.dstabbr: + break + else: + break + + + if i < len_l: + for j in range(i, len_l): + if l[j] == ';': + l[j] = ',' + + assert l[i] == ',' + + i += 1 + + if i >= len_l: + pass + elif (8 <= l.count(',') <= 9 and + not [y for x in l[i:] if x != ',' + for y in x if y not in "0123456789+-"]): + # GMT0BST,3,0,30,3600,10,0,26,7200[,3600] + for x in (res.start, res.end): + x.month = int(l[i]) + used_idxs.append(i) + i += 2 + if l[i] == '-': + value = int(l[i + 1]) * -1 + used_idxs.append(i) + i += 1 + else: + value = int(l[i]) + used_idxs.append(i) + i += 2 + if value: + x.week = value + x.weekday = (int(l[i]) - 1) % 7 + else: + x.day = int(l[i]) + used_idxs.append(i) + i += 2 + x.time = int(l[i]) + used_idxs.append(i) + i += 2 + if i < len_l: + if l[i] in ('-', '+'): + signal = (-1, 1)[l[i] == "+"] + used_idxs.append(i) + i += 1 + else: + signal = 1 + used_idxs.append(i) + res.dstoffset = (res.stdoffset + int(l[i]) * signal) + + # This was a made-up format that is not in normal use + warn(('Parsed time zone "%s"' % tzstr) + + 'is in a non-standard dateutil-specific format, which ' + + 'is now deprecated; support for parsing this format ' + + 'will be removed in future versions. It is recommended ' + + 'that you switch to a standard format like the GNU ' + + 'TZ variable format.', tz.DeprecatedTzFormatWarning) + elif (l.count(',') == 2 and l[i:].count('/') <= 2 and + not [y for x in l[i:] if x not in (',', '/', 'J', 'M', + '.', '-', ':') + for y in x if y not in "0123456789"]): + for x in (res.start, res.end): + if l[i] == 'J': + # non-leap year day (1 based) + used_idxs.append(i) + i += 1 + x.jyday = int(l[i]) + elif l[i] == 'M': + # month[-.]week[-.]weekday + used_idxs.append(i) + i += 1 + x.month = int(l[i]) + used_idxs.append(i) + i += 1 + assert l[i] in ('-', '.') + used_idxs.append(i) + i += 1 + x.week = int(l[i]) + if x.week == 5: + x.week = -1 + used_idxs.append(i) + i += 1 + assert l[i] in ('-', '.') + used_idxs.append(i) + i += 1 + x.weekday = (int(l[i]) - 1) % 7 + else: + # year day (zero based) + x.yday = int(l[i]) + 1 + + used_idxs.append(i) + i += 1 + + if i < len_l and l[i] == '/': + used_idxs.append(i) + i += 1 + # start time + len_li = len(l[i]) + if len_li == 4: + # -0300 + x.time = (int(l[i][:2]) * 3600 + + int(l[i][2:]) * 60) + elif i + 1 < len_l and l[i + 1] == ':': + # -03:00 + x.time = int(l[i]) * 3600 + int(l[i + 2]) * 60 + used_idxs.append(i) + i += 2 + if i + 1 < len_l and l[i + 1] == ':': + used_idxs.append(i) + i += 2 + x.time += int(l[i]) + elif len_li <= 2: + # -[0]3 + x.time = (int(l[i][:2]) * 3600) + else: + return None + used_idxs.append(i) + i += 1 + + assert i == len_l or l[i] == ',' + + i += 1 + + assert i >= len_l + + except (IndexError, ValueError, AssertionError): + return None + + unused_idxs = set(range(len_l)).difference(used_idxs) + res.any_unused_tokens = not {l[n] for n in unused_idxs}.issubset({",",":"}) + return res + + +DEFAULTTZPARSER = _tzparser() + + +def _parsetz(tzstr): + return DEFAULTTZPARSER.parse(tzstr) + + +class ParserError(ValueError): + """Exception subclass used for any failure to parse a datetime string. + + This is a subclass of :py:exc:`ValueError`, and should be raised any time + earlier versions of ``dateutil`` would have raised ``ValueError``. + + .. versionadded:: 2.8.1 + """ + def __str__(self): + try: + return self.args[0] % self.args[1:] + except (TypeError, IndexError): + return super(ParserError, self).__str__() + + def __repr__(self): + args = ", ".join("'%s'" % arg for arg in self.args) + return "%s(%s)" % (self.__class__.__name__, args) + + +class UnknownTimezoneWarning(RuntimeWarning): + """Raised when the parser finds a timezone it cannot parse into a tzinfo. + + .. versionadded:: 2.7.0 + """ +# vim:ts=4:sw=4:et diff --git a/dateutil/parser/isoparser.py b/dateutil/parser/isoparser.py new file mode 100644 index 0000000000000000000000000000000000000000..7060087df4776a07347cbb60127a70db393e3a65 --- /dev/null +++ b/dateutil/parser/isoparser.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- +""" +This module offers a parser for ISO-8601 strings + +It is intended to support all valid date, time and datetime formats per the +ISO-8601 specification. + +..versionadded:: 2.7.0 +""" +from datetime import datetime, timedelta, time, date +import calendar +from dateutil import tz + +from functools import wraps + +import re +import six + +__all__ = ["isoparse", "isoparser"] + + +def _takes_ascii(f): + @wraps(f) + def func(self, str_in, *args, **kwargs): + # If it's a stream, read the whole thing + str_in = getattr(str_in, 'read', lambda: str_in)() + + # If it's unicode, turn it into bytes, since ISO-8601 only covers ASCII + if isinstance(str_in, six.text_type): + # ASCII is the same in UTF-8 + try: + str_in = str_in.encode('ascii') + except UnicodeEncodeError as e: + msg = 'ISO-8601 strings should contain only ASCII characters' + six.raise_from(ValueError(msg), e) + + return f(self, str_in, *args, **kwargs) + + return func + + +class isoparser(object): + def __init__(self, sep=None): + """ + :param sep: + A single character that separates date and time portions. If + ``None``, the parser will accept any single character. + For strict ISO-8601 adherence, pass ``'T'``. + """ + if sep is not None: + if (len(sep) != 1 or ord(sep) >= 128 or sep in '0123456789'): + raise ValueError('Separator must be a single, non-numeric ' + + 'ASCII character') + + sep = sep.encode('ascii') + + self._sep = sep + + @_takes_ascii + def isoparse(self, dt_str): + """ + Parse an ISO-8601 datetime string into a :class:`datetime.datetime`. + + An ISO-8601 datetime string consists of a date portion, followed + optionally by a time portion - the date and time portions are separated + by a single character separator, which is ``T`` in the official + standard. Incomplete date formats (such as ``YYYY-MM``) may *not* be + combined with a time portion. + + Supported date formats are: + + Common: + + - ``YYYY`` + - ``YYYY-MM`` + - ``YYYY-MM-DD`` or ``YYYYMMDD`` + + Uncommon: + + - ``YYYY-Www`` or ``YYYYWww`` - ISO week (day defaults to 0) + - ``YYYY-Www-D`` or ``YYYYWwwD`` - ISO week and day + + The ISO week and day numbering follows the same logic as + :func:`datetime.date.isocalendar`. + + Supported time formats are: + + - ``hh`` + - ``hh:mm`` or ``hhmm`` + - ``hh:mm:ss`` or ``hhmmss`` + - ``hh:mm:ss.ssssss`` (Up to 6 sub-second digits) + + Midnight is a special case for `hh`, as the standard supports both + 00:00 and 24:00 as a representation. The decimal separator can be + either a dot or a comma. + + + .. caution:: + + Support for fractional components other than seconds is part of the + ISO-8601 standard, but is not currently implemented in this parser. + + Supported time zone offset formats are: + + - `Z` (UTC) + - `±HH:MM` + - `±HHMM` + - `±HH` + + Offsets will be represented as :class:`dateutil.tz.tzoffset` objects, + with the exception of UTC, which will be represented as + :class:`dateutil.tz.tzutc`. Time zone offsets equivalent to UTC (such + as `+00:00`) will also be represented as :class:`dateutil.tz.tzutc`. + + :param dt_str: + A string or stream containing only an ISO-8601 datetime string + + :return: + Returns a :class:`datetime.datetime` representing the string. + Unspecified components default to their lowest value. + + .. warning:: + + As of version 2.7.0, the strictness of the parser should not be + considered a stable part of the contract. Any valid ISO-8601 string + that parses correctly with the default settings will continue to + parse correctly in future versions, but invalid strings that + currently fail (e.g. ``2017-01-01T00:00+00:00:00``) are not + guaranteed to continue failing in future versions if they encode + a valid date. + + .. versionadded:: 2.7.0 + """ + components, pos = self._parse_isodate(dt_str) + + if len(dt_str) > pos: + if self._sep is None or dt_str[pos:pos + 1] == self._sep: + components += self._parse_isotime(dt_str[pos + 1:]) + else: + raise ValueError('String contains unknown ISO components') + + if len(components) > 3 and components[3] == 24: + components[3] = 0 + return datetime(*components) + timedelta(days=1) + + return datetime(*components) + + @_takes_ascii + def parse_isodate(self, datestr): + """ + Parse the date portion of an ISO string. + + :param datestr: + The string portion of an ISO string, without a separator + + :return: + Returns a :class:`datetime.date` object + """ + components, pos = self._parse_isodate(datestr) + if pos < len(datestr): + raise ValueError('String contains unknown ISO ' + + 'components: {!r}'.format(datestr.decode('ascii'))) + return date(*components) + + @_takes_ascii + def parse_isotime(self, timestr): + """ + Parse the time portion of an ISO string. + + :param timestr: + The time portion of an ISO string, without a separator + + :return: + Returns a :class:`datetime.time` object + """ + components = self._parse_isotime(timestr) + if components[0] == 24: + components[0] = 0 + return time(*components) + + @_takes_ascii + def parse_tzstr(self, tzstr, zero_as_utc=True): + """ + Parse a valid ISO time zone string. + + See :func:`isoparser.isoparse` for details on supported formats. + + :param tzstr: + A string representing an ISO time zone offset + + :param zero_as_utc: + Whether to return :class:`dateutil.tz.tzutc` for zero-offset zones + + :return: + Returns :class:`dateutil.tz.tzoffset` for offsets and + :class:`dateutil.tz.tzutc` for ``Z`` and (if ``zero_as_utc`` is + specified) offsets equivalent to UTC. + """ + return self._parse_tzstr(tzstr, zero_as_utc=zero_as_utc) + + # Constants + _DATE_SEP = b'-' + _TIME_SEP = b':' + _FRACTION_REGEX = re.compile(b'[\\.,]([0-9]+)') + + def _parse_isodate(self, dt_str): + try: + return self._parse_isodate_common(dt_str) + except ValueError: + return self._parse_isodate_uncommon(dt_str) + + def _parse_isodate_common(self, dt_str): + len_str = len(dt_str) + components = [1, 1, 1] + + if len_str < 4: + raise ValueError('ISO string too short') + + # Year + components[0] = int(dt_str[0:4]) + pos = 4 + if pos >= len_str: + return components, pos + + has_sep = dt_str[pos:pos + 1] == self._DATE_SEP + if has_sep: + pos += 1 + + # Month + if len_str - pos < 2: + raise ValueError('Invalid common month') + + components[1] = int(dt_str[pos:pos + 2]) + pos += 2 + + if pos >= len_str: + if has_sep: + return components, pos + else: + raise ValueError('Invalid ISO format') + + if has_sep: + if dt_str[pos:pos + 1] != self._DATE_SEP: + raise ValueError('Invalid separator in ISO string') + pos += 1 + + # Day + if len_str - pos < 2: + raise ValueError('Invalid common day') + components[2] = int(dt_str[pos:pos + 2]) + return components, pos + 2 + + def _parse_isodate_uncommon(self, dt_str): + if len(dt_str) < 4: + raise ValueError('ISO string too short') + + # All ISO formats start with the year + year = int(dt_str[0:4]) + + has_sep = dt_str[4:5] == self._DATE_SEP + + pos = 4 + has_sep # Skip '-' if it's there + if dt_str[pos:pos + 1] == b'W': + # YYYY-?Www-?D? + pos += 1 + weekno = int(dt_str[pos:pos + 2]) + pos += 2 + + dayno = 1 + if len(dt_str) > pos: + if (dt_str[pos:pos + 1] == self._DATE_SEP) != has_sep: + raise ValueError('Inconsistent use of dash separator') + + pos += has_sep + + dayno = int(dt_str[pos:pos + 1]) + pos += 1 + + base_date = self._calculate_weekdate(year, weekno, dayno) + else: + # YYYYDDD or YYYY-DDD + if len(dt_str) - pos < 3: + raise ValueError('Invalid ordinal day') + + ordinal_day = int(dt_str[pos:pos + 3]) + pos += 3 + + if ordinal_day < 1 or ordinal_day > (365 + calendar.isleap(year)): + raise ValueError('Invalid ordinal day' + + ' {} for year {}'.format(ordinal_day, year)) + + base_date = date(year, 1, 1) + timedelta(days=ordinal_day - 1) + + components = [base_date.year, base_date.month, base_date.day] + return components, pos + + def _calculate_weekdate(self, year, week, day): + """ + Calculate the day of corresponding to the ISO year-week-day calendar. + + This function is effectively the inverse of + :func:`datetime.date.isocalendar`. + + :param year: + The year in the ISO calendar + + :param week: + The week in the ISO calendar - range is [1, 53] + + :param day: + The day in the ISO calendar - range is [1 (MON), 7 (SUN)] + + :return: + Returns a :class:`datetime.date` + """ + if not 0 < week < 54: + raise ValueError('Invalid week: {}'.format(week)) + + if not 0 < day < 8: # Range is 1-7 + raise ValueError('Invalid weekday: {}'.format(day)) + + # Get week 1 for the specific year: + jan_4 = date(year, 1, 4) # Week 1 always has January 4th in it + week_1 = jan_4 - timedelta(days=jan_4.isocalendar()[2] - 1) + + # Now add the specific number of weeks and days to get what we want + week_offset = (week - 1) * 7 + (day - 1) + return week_1 + timedelta(days=week_offset) + + def _parse_isotime(self, timestr): + len_str = len(timestr) + components = [0, 0, 0, 0, None] + pos = 0 + comp = -1 + + if len_str < 2: + raise ValueError('ISO time too short') + + has_sep = False + + while pos < len_str and comp < 5: + comp += 1 + + if timestr[pos:pos + 1] in b'-+Zz': + # Detect time zone boundary + components[-1] = self._parse_tzstr(timestr[pos:]) + pos = len_str + break + + if comp == 1 and timestr[pos:pos+1] == self._TIME_SEP: + has_sep = True + pos += 1 + elif comp == 2 and has_sep: + if timestr[pos:pos+1] != self._TIME_SEP: + raise ValueError('Inconsistent use of colon separator') + pos += 1 + + if comp < 3: + # Hour, minute, second + components[comp] = int(timestr[pos:pos + 2]) + pos += 2 + + if comp == 3: + # Fraction of a second + frac = self._FRACTION_REGEX.match(timestr[pos:]) + if not frac: + continue + + us_str = frac.group(1)[:6] # Truncate to microseconds + components[comp] = int(us_str) * 10**(6 - len(us_str)) + pos += len(frac.group()) + + if pos < len_str: + raise ValueError('Unused components in ISO string') + + if components[0] == 24: + # Standard supports 00:00 and 24:00 as representations of midnight + if any(component != 0 for component in components[1:4]): + raise ValueError('Hour may only be 24 at 24:00:00.000') + + return components + + def _parse_tzstr(self, tzstr, zero_as_utc=True): + if tzstr == b'Z' or tzstr == b'z': + return tz.UTC + + if len(tzstr) not in {3, 5, 6}: + raise ValueError('Time zone offset must be 1, 3, 5 or 6 characters') + + if tzstr[0:1] == b'-': + mult = -1 + elif tzstr[0:1] == b'+': + mult = 1 + else: + raise ValueError('Time zone offset requires sign') + + hours = int(tzstr[1:3]) + if len(tzstr) == 3: + minutes = 0 + else: + minutes = int(tzstr[(4 if tzstr[3:4] == self._TIME_SEP else 3):]) + + if zero_as_utc and hours == 0 and minutes == 0: + return tz.UTC + else: + if minutes > 59: + raise ValueError('Invalid minutes in time zone offset') + + if hours > 23: + raise ValueError('Invalid hours in time zone offset') + + return tz.tzoffset(None, mult * (hours * 60 + minutes) * 60) + + +DEFAULT_ISOPARSER = isoparser() +isoparse = DEFAULT_ISOPARSER.isoparse diff --git a/dateutil/tz/__init__.py b/dateutil/tz/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af1352c47292f4eebc5cae8da45641b5544558e3 --- /dev/null +++ b/dateutil/tz/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +from .tz import * +from .tz import __doc__ + +__all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange", + "tzstr", "tzical", "tzwin", "tzwinlocal", "gettz", + "enfold", "datetime_ambiguous", "datetime_exists", + "resolve_imaginary", "UTC", "DeprecatedTzFormatWarning"] + + +class DeprecatedTzFormatWarning(Warning): + """Warning raised when time zones are parsed from deprecated formats.""" diff --git a/dateutil/tz/_common.py b/dateutil/tz/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ac11831522b266114d5b68ee1da298e3aeb14a --- /dev/null +++ b/dateutil/tz/_common.py @@ -0,0 +1,419 @@ +from six import PY2 + +from functools import wraps + +from datetime import datetime, timedelta, tzinfo + + +ZERO = timedelta(0) + +__all__ = ['tzname_in_python2', 'enfold'] + + +def tzname_in_python2(namefunc): + """Change unicode output into bytestrings in Python 2 + + tzname() API changed in Python 3. It used to return bytes, but was changed + to unicode strings + """ + if PY2: + @wraps(namefunc) + def adjust_encoding(*args, **kwargs): + name = namefunc(*args, **kwargs) + if name is not None: + name = name.encode() + + return name + + return adjust_encoding + else: + return namefunc + + +# The following is adapted from Alexander Belopolsky's tz library +# https://github.com/abalkin/tz +if hasattr(datetime, 'fold'): + # This is the pre-python 3.6 fold situation + def enfold(dt, fold=1): + """ + Provides a unified interface for assigning the ``fold`` attribute to + datetimes both before and after the implementation of PEP-495. + + :param fold: + The value for the ``fold`` attribute in the returned datetime. This + should be either 0 or 1. + + :return: + Returns an object for which ``getattr(dt, 'fold', 0)`` returns + ``fold`` for all versions of Python. In versions prior to + Python 3.6, this is a ``_DatetimeWithFold`` object, which is a + subclass of :py:class:`datetime.datetime` with the ``fold`` + attribute added, if ``fold`` is 1. + + .. versionadded:: 2.6.0 + """ + return dt.replace(fold=fold) + +else: + class _DatetimeWithFold(datetime): + """ + This is a class designed to provide a PEP 495-compliant interface for + Python versions before 3.6. It is used only for dates in a fold, so + the ``fold`` attribute is fixed at ``1``. + + .. versionadded:: 2.6.0 + """ + __slots__ = () + + def replace(self, *args, **kwargs): + """ + Return a datetime with the same attributes, except for those + attributes given new values by whichever keyword arguments are + specified. Note that tzinfo=None can be specified to create a naive + datetime from an aware datetime with no conversion of date and time + data. + + This is reimplemented in ``_DatetimeWithFold`` because pypy3 will + return a ``datetime.datetime`` even if ``fold`` is unchanged. + """ + argnames = ( + 'year', 'month', 'day', 'hour', 'minute', 'second', + 'microsecond', 'tzinfo' + ) + + for arg, argname in zip(args, argnames): + if argname in kwargs: + raise TypeError('Duplicate argument: {}'.format(argname)) + + kwargs[argname] = arg + + for argname in argnames: + if argname not in kwargs: + kwargs[argname] = getattr(self, argname) + + dt_class = self.__class__ if kwargs.get('fold', 1) else datetime + + return dt_class(**kwargs) + + @property + def fold(self): + return 1 + + def enfold(dt, fold=1): + """ + Provides a unified interface for assigning the ``fold`` attribute to + datetimes both before and after the implementation of PEP-495. + + :param fold: + The value for the ``fold`` attribute in the returned datetime. This + should be either 0 or 1. + + :return: + Returns an object for which ``getattr(dt, 'fold', 0)`` returns + ``fold`` for all versions of Python. In versions prior to + Python 3.6, this is a ``_DatetimeWithFold`` object, which is a + subclass of :py:class:`datetime.datetime` with the ``fold`` + attribute added, if ``fold`` is 1. + + .. versionadded:: 2.6.0 + """ + if getattr(dt, 'fold', 0) == fold: + return dt + + args = dt.timetuple()[:6] + args += (dt.microsecond, dt.tzinfo) + + if fold: + return _DatetimeWithFold(*args) + else: + return datetime(*args) + + +def _validate_fromutc_inputs(f): + """ + The CPython version of ``fromutc`` checks that the input is a ``datetime`` + object and that ``self`` is attached as its ``tzinfo``. + """ + @wraps(f) + def fromutc(self, dt): + if not isinstance(dt, datetime): + raise TypeError("fromutc() requires a datetime argument") + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + return f(self, dt) + + return fromutc + + +class _tzinfo(tzinfo): + """ + Base class for all ``dateutil`` ``tzinfo`` objects. + """ + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + + dt = dt.replace(tzinfo=self) + + wall_0 = enfold(dt, fold=0) + wall_1 = enfold(dt, fold=1) + + same_offset = wall_0.utcoffset() == wall_1.utcoffset() + same_dt = wall_0.replace(tzinfo=None) == wall_1.replace(tzinfo=None) + + return same_dt and not same_offset + + def _fold_status(self, dt_utc, dt_wall): + """ + Determine the fold status of a "wall" datetime, given a representation + of the same datetime as a (naive) UTC datetime. This is calculated based + on the assumption that ``dt.utcoffset() - dt.dst()`` is constant for all + datetimes, and that this offset is the actual number of hours separating + ``dt_utc`` and ``dt_wall``. + + :param dt_utc: + Representation of the datetime as UTC + + :param dt_wall: + Representation of the datetime as "wall time". This parameter must + either have a `fold` attribute or have a fold-naive + :class:`datetime.tzinfo` attached, otherwise the calculation may + fail. + """ + if self.is_ambiguous(dt_wall): + delta_wall = dt_wall - dt_utc + _fold = int(delta_wall == (dt_utc.utcoffset() - dt_utc.dst())) + else: + _fold = 0 + + return _fold + + def _fold(self, dt): + return getattr(dt, 'fold', 0) + + def _fromutc(self, dt): + """ + Given a timezone-aware datetime in a given timezone, calculates a + timezone-aware datetime in a new timezone. + + Since this is the one time that we *know* we have an unambiguous + datetime object, we take this opportunity to determine whether the + datetime is ambiguous and in a "fold" state (e.g. if it's the first + occurrence, chronologically, of the ambiguous datetime). + + :param dt: + A timezone-aware :class:`datetime.datetime` object. + """ + + # Re-implement the algorithm from Python's datetime.py + dtoff = dt.utcoffset() + if dtoff is None: + raise ValueError("fromutc() requires a non-None utcoffset() " + "result") + + # The original datetime.py code assumes that `dst()` defaults to + # zero during ambiguous times. PEP 495 inverts this presumption, so + # for pre-PEP 495 versions of python, we need to tweak the algorithm. + dtdst = dt.dst() + if dtdst is None: + raise ValueError("fromutc() requires a non-None dst() result") + delta = dtoff - dtdst + + dt += delta + # Set fold=1 so we can default to being in the fold for + # ambiguous dates. + dtdst = enfold(dt, fold=1).dst() + if dtdst is None: + raise ValueError("fromutc(): dt.dst gave inconsistent " + "results; cannot convert") + return dt + dtdst + + @_validate_fromutc_inputs + def fromutc(self, dt): + """ + Given a timezone-aware datetime in a given timezone, calculates a + timezone-aware datetime in a new timezone. + + Since this is the one time that we *know* we have an unambiguous + datetime object, we take this opportunity to determine whether the + datetime is ambiguous and in a "fold" state (e.g. if it's the first + occurrence, chronologically, of the ambiguous datetime). + + :param dt: + A timezone-aware :class:`datetime.datetime` object. + """ + dt_wall = self._fromutc(dt) + + # Calculate the fold status given the two datetimes. + _fold = self._fold_status(dt, dt_wall) + + # Set the default fold value for ambiguous dates + return enfold(dt_wall, fold=_fold) + + +class tzrangebase(_tzinfo): + """ + This is an abstract base class for time zones represented by an annual + transition into and out of DST. Child classes should implement the following + methods: + + * ``__init__(self, *args, **kwargs)`` + * ``transitions(self, year)`` - this is expected to return a tuple of + datetimes representing the DST on and off transitions in standard + time. + + A fully initialized ``tzrangebase`` subclass should also provide the + following attributes: + * ``hasdst``: Boolean whether or not the zone uses DST. + * ``_dst_offset`` / ``_std_offset``: :class:`datetime.timedelta` objects + representing the respective UTC offsets. + * ``_dst_abbr`` / ``_std_abbr``: Strings representing the timezone short + abbreviations in DST and STD, respectively. + * ``_hasdst``: Whether or not the zone has DST. + + .. versionadded:: 2.6.0 + """ + def __init__(self): + raise NotImplementedError('tzrangebase is an abstract base class') + + def utcoffset(self, dt): + isdst = self._isdst(dt) + + if isdst is None: + return None + elif isdst: + return self._dst_offset + else: + return self._std_offset + + def dst(self, dt): + isdst = self._isdst(dt) + + if isdst is None: + return None + elif isdst: + return self._dst_base_offset + else: + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + if self._isdst(dt): + return self._dst_abbr + else: + return self._std_abbr + + def fromutc(self, dt): + """ Given a datetime in UTC, return local time """ + if not isinstance(dt, datetime): + raise TypeError("fromutc() requires a datetime argument") + + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + # Get transitions - if there are none, fixed offset + transitions = self.transitions(dt.year) + if transitions is None: + return dt + self.utcoffset(dt) + + # Get the transition times in UTC + dston, dstoff = transitions + + dston -= self._std_offset + dstoff -= self._std_offset + + utc_transitions = (dston, dstoff) + dt_utc = dt.replace(tzinfo=None) + + isdst = self._naive_isdst(dt_utc, utc_transitions) + + if isdst: + dt_wall = dt + self._dst_offset + else: + dt_wall = dt + self._std_offset + + _fold = int(not isdst and self.is_ambiguous(dt_wall)) + + return enfold(dt_wall, fold=_fold) + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + if not self.hasdst: + return False + + start, end = self.transitions(dt.year) + + dt = dt.replace(tzinfo=None) + return (end <= dt < end + self._dst_base_offset) + + def _isdst(self, dt): + if not self.hasdst: + return False + elif dt is None: + return None + + transitions = self.transitions(dt.year) + + if transitions is None: + return False + + dt = dt.replace(tzinfo=None) + + isdst = self._naive_isdst(dt, transitions) + + # Handle ambiguous dates + if not isdst and self.is_ambiguous(dt): + return not self._fold(dt) + else: + return isdst + + def _naive_isdst(self, dt, transitions): + dston, dstoff = transitions + + dt = dt.replace(tzinfo=None) + + if dston < dstoff: + isdst = dston <= dt < dstoff + else: + isdst = not dstoff <= dt < dston + + return isdst + + @property + def _dst_base_offset(self): + return self._dst_offset - self._std_offset + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(...)" % self.__class__.__name__ + + __reduce__ = object.__reduce__ diff --git a/dateutil/tz/_factories.py b/dateutil/tz/_factories.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a65891a023ebf9eb0c24d391ba67541b7133f1 --- /dev/null +++ b/dateutil/tz/_factories.py @@ -0,0 +1,80 @@ +from datetime import timedelta +import weakref +from collections import OrderedDict + +from six.moves import _thread + + +class _TzSingleton(type): + def __init__(cls, *args, **kwargs): + cls.__instance = None + super(_TzSingleton, cls).__init__(*args, **kwargs) + + def __call__(cls): + if cls.__instance is None: + cls.__instance = super(_TzSingleton, cls).__call__() + return cls.__instance + + +class _TzFactory(type): + def instance(cls, *args, **kwargs): + """Alternate constructor that returns a fresh instance""" + return type.__call__(cls, *args, **kwargs) + + +class _TzOffsetFactory(_TzFactory): + def __init__(cls, *args, **kwargs): + cls.__instances = weakref.WeakValueDictionary() + cls.__strong_cache = OrderedDict() + cls.__strong_cache_size = 8 + + cls._cache_lock = _thread.allocate_lock() + + def __call__(cls, name, offset): + if isinstance(offset, timedelta): + key = (name, offset.total_seconds()) + else: + key = (name, offset) + + instance = cls.__instances.get(key, None) + if instance is None: + instance = cls.__instances.setdefault(key, + cls.instance(name, offset)) + + # This lock may not be necessary in Python 3. See GH issue #901 + with cls._cache_lock: + cls.__strong_cache[key] = cls.__strong_cache.pop(key, instance) + + # Remove an item if the strong cache is overpopulated + if len(cls.__strong_cache) > cls.__strong_cache_size: + cls.__strong_cache.popitem(last=False) + + return instance + + +class _TzStrFactory(_TzFactory): + def __init__(cls, *args, **kwargs): + cls.__instances = weakref.WeakValueDictionary() + cls.__strong_cache = OrderedDict() + cls.__strong_cache_size = 8 + + cls.__cache_lock = _thread.allocate_lock() + + def __call__(cls, s, posix_offset=False): + key = (s, posix_offset) + instance = cls.__instances.get(key, None) + + if instance is None: + instance = cls.__instances.setdefault(key, + cls.instance(s, posix_offset)) + + # This lock may not be necessary in Python 3. See GH issue #901 + with cls.__cache_lock: + cls.__strong_cache[key] = cls.__strong_cache.pop(key, instance) + + # Remove an item if the strong cache is overpopulated + if len(cls.__strong_cache) > cls.__strong_cache_size: + cls.__strong_cache.popitem(last=False) + + return instance + diff --git a/dateutil/tz/tz.py b/dateutil/tz/tz.py new file mode 100644 index 0000000000000000000000000000000000000000..617591446bd92eb1cc7b7d67fa3f17435e691cdd --- /dev/null +++ b/dateutil/tz/tz.py @@ -0,0 +1,1849 @@ +# -*- coding: utf-8 -*- +""" +This module offers timezone implementations subclassing the abstract +:py:class:`datetime.tzinfo` type. There are classes to handle tzfile format +files (usually are in :file:`/etc/localtime`, :file:`/usr/share/zoneinfo`, +etc), TZ environment string (in all known formats), given ranges (with help +from relative deltas), local machine timezone, fixed offset timezone, and UTC +timezone. +""" +import datetime +import struct +import time +import sys +import os +import bisect +import weakref +from collections import OrderedDict + +import six +from six import string_types +from six.moves import _thread +from ._common import tzname_in_python2, _tzinfo +from ._common import tzrangebase, enfold +from ._common import _validate_fromutc_inputs + +from ._factories import _TzSingleton, _TzOffsetFactory +from ._factories import _TzStrFactory +try: + from .win import tzwin, tzwinlocal +except ImportError: + tzwin = tzwinlocal = None + +# For warning about rounding tzinfo +from warnings import warn + +ZERO = datetime.timedelta(0) +EPOCH = datetime.datetime(1970, 1, 1, 0, 0) +EPOCHORDINAL = EPOCH.toordinal() + + +@six.add_metaclass(_TzSingleton) +class tzutc(datetime.tzinfo): + """ + This is a tzinfo object that represents the UTC time zone. + + **Examples:** + + .. doctest:: + + >>> from datetime import * + >>> from dateutil.tz import * + + >>> datetime.now() + datetime.datetime(2003, 9, 27, 9, 40, 1, 521290) + + >>> datetime.now(tzutc()) + datetime.datetime(2003, 9, 27, 12, 40, 12, 156379, tzinfo=tzutc()) + + >>> datetime.now(tzutc()).tzname() + 'UTC' + + .. versionchanged:: 2.7.0 + ``tzutc()`` is now a singleton, so the result of ``tzutc()`` will + always return the same object. + + .. doctest:: + + >>> from dateutil.tz import tzutc, UTC + >>> tzutc() is tzutc() + True + >>> tzutc() is UTC + True + """ + def utcoffset(self, dt): + return ZERO + + def dst(self, dt): + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return "UTC" + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + return False + + @_validate_fromutc_inputs + def fromutc(self, dt): + """ + Fast track version of fromutc() returns the original ``dt`` object for + any valid :py:class:`datetime.datetime` object. + """ + return dt + + def __eq__(self, other): + if not isinstance(other, (tzutc, tzoffset)): + return NotImplemented + + return (isinstance(other, tzutc) or + (isinstance(other, tzoffset) and other._offset == ZERO)) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + __reduce__ = object.__reduce__ + + +#: Convenience constant providing a :class:`tzutc()` instance +#: +#: .. versionadded:: 2.7.0 +UTC = tzutc() + + +@six.add_metaclass(_TzOffsetFactory) +class tzoffset(datetime.tzinfo): + """ + A simple class for representing a fixed offset from UTC. + + :param name: + The timezone name, to be returned when ``tzname()`` is called. + :param offset: + The time zone offset in seconds, or (since version 2.6.0, represented + as a :py:class:`datetime.timedelta` object). + """ + def __init__(self, name, offset): + self._name = name + + try: + # Allow a timedelta + offset = offset.total_seconds() + except (TypeError, AttributeError): + pass + + self._offset = datetime.timedelta(seconds=_get_supported_offset(offset)) + + def utcoffset(self, dt): + return self._offset + + def dst(self, dt): + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return self._name + + @_validate_fromutc_inputs + def fromutc(self, dt): + return dt + self._offset + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + return False + + def __eq__(self, other): + if not isinstance(other, tzoffset): + return NotImplemented + + return self._offset == other._offset + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(%s, %s)" % (self.__class__.__name__, + repr(self._name), + int(self._offset.total_seconds())) + + __reduce__ = object.__reduce__ + + +class tzlocal(_tzinfo): + """ + A :class:`tzinfo` subclass built around the ``time`` timezone functions. + """ + def __init__(self): + super(tzlocal, self).__init__() + + self._std_offset = datetime.timedelta(seconds=-time.timezone) + if time.daylight: + self._dst_offset = datetime.timedelta(seconds=-time.altzone) + else: + self._dst_offset = self._std_offset + + self._dst_saved = self._dst_offset - self._std_offset + self._hasdst = bool(self._dst_saved) + self._tznames = tuple(time.tzname) + + def utcoffset(self, dt): + if dt is None and self._hasdst: + return None + + if self._isdst(dt): + return self._dst_offset + else: + return self._std_offset + + def dst(self, dt): + if dt is None and self._hasdst: + return None + + if self._isdst(dt): + return self._dst_offset - self._std_offset + else: + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return self._tznames[self._isdst(dt)] + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + naive_dst = self._naive_is_dst(dt) + return (not naive_dst and + (naive_dst != self._naive_is_dst(dt - self._dst_saved))) + + def _naive_is_dst(self, dt): + timestamp = _datetime_to_timestamp(dt) + return time.localtime(timestamp + time.timezone).tm_isdst + + def _isdst(self, dt, fold_naive=True): + # We can't use mktime here. It is unstable when deciding if + # the hour near to a change is DST or not. + # + # timestamp = time.mktime((dt.year, dt.month, dt.day, dt.hour, + # dt.minute, dt.second, dt.weekday(), 0, -1)) + # return time.localtime(timestamp).tm_isdst + # + # The code above yields the following result: + # + # >>> import tz, datetime + # >>> t = tz.tzlocal() + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRDT' + # >>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname() + # 'BRST' + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRST' + # >>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname() + # 'BRDT' + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRDT' + # + # Here is a more stable implementation: + # + if not self._hasdst: + return False + + # Check for ambiguous times: + dstval = self._naive_is_dst(dt) + fold = getattr(dt, 'fold', None) + + if self.is_ambiguous(dt): + if fold is not None: + return not self._fold(dt) + else: + return True + + return dstval + + def __eq__(self, other): + if isinstance(other, tzlocal): + return (self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset) + elif isinstance(other, tzutc): + return (not self._hasdst and + self._tznames[0] in {'UTC', 'GMT'} and + self._std_offset == ZERO) + elif isinstance(other, tzoffset): + return (not self._hasdst and + self._tznames[0] == other._name and + self._std_offset == other._offset) + else: + return NotImplemented + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + __reduce__ = object.__reduce__ + + +class _ttinfo(object): + __slots__ = ["offset", "delta", "isdst", "abbr", + "isstd", "isgmt", "dstoffset"] + + def __init__(self): + for attr in self.__slots__: + setattr(self, attr, None) + + def __repr__(self): + l = [] + for attr in self.__slots__: + value = getattr(self, attr) + if value is not None: + l.append("%s=%s" % (attr, repr(value))) + return "%s(%s)" % (self.__class__.__name__, ", ".join(l)) + + def __eq__(self, other): + if not isinstance(other, _ttinfo): + return NotImplemented + + return (self.offset == other.offset and + self.delta == other.delta and + self.isdst == other.isdst and + self.abbr == other.abbr and + self.isstd == other.isstd and + self.isgmt == other.isgmt and + self.dstoffset == other.dstoffset) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __getstate__(self): + state = {} + for name in self.__slots__: + state[name] = getattr(self, name, None) + return state + + def __setstate__(self, state): + for name in self.__slots__: + if name in state: + setattr(self, name, state[name]) + + +class _tzfile(object): + """ + Lightweight class for holding the relevant transition and time zone + information read from binary tzfiles. + """ + attrs = ['trans_list', 'trans_list_utc', 'trans_idx', 'ttinfo_list', + 'ttinfo_std', 'ttinfo_dst', 'ttinfo_before', 'ttinfo_first'] + + def __init__(self, **kwargs): + for attr in self.attrs: + setattr(self, attr, kwargs.get(attr, None)) + + +class tzfile(_tzinfo): + """ + This is a ``tzinfo`` subclass that allows one to use the ``tzfile(5)`` + format timezone files to extract current and historical zone information. + + :param fileobj: + This can be an opened file stream or a file name that the time zone + information can be read from. + + :param filename: + This is an optional parameter specifying the source of the time zone + information in the event that ``fileobj`` is a file object. If omitted + and ``fileobj`` is a file stream, this parameter will be set either to + ``fileobj``'s ``name`` attribute or to ``repr(fileobj)``. + + See `Sources for Time Zone and Daylight Saving Time Data + `_ for more information. + Time zone files can be compiled from the `IANA Time Zone database files + `_ with the `zic time zone compiler + `_ + + .. note:: + + Only construct a ``tzfile`` directly if you have a specific timezone + file on disk that you want to read into a Python ``tzinfo`` object. + If you want to get a ``tzfile`` representing a specific IANA zone, + (e.g. ``'America/New_York'``), you should call + :func:`dateutil.tz.gettz` with the zone identifier. + + + **Examples:** + + Using the US Eastern time zone as an example, we can see that a ``tzfile`` + provides time zone information for the standard Daylight Saving offsets: + + .. testsetup:: tzfile + + from dateutil.tz import gettz + from datetime import datetime + + .. doctest:: tzfile + + >>> NYC = gettz('America/New_York') + >>> NYC + tzfile('/usr/share/zoneinfo/America/New_York') + + >>> print(datetime(2016, 1, 3, tzinfo=NYC)) # EST + 2016-01-03 00:00:00-05:00 + + >>> print(datetime(2016, 7, 7, tzinfo=NYC)) # EDT + 2016-07-07 00:00:00-04:00 + + + The ``tzfile`` structure contains a fully history of the time zone, + so historical dates will also have the right offsets. For example, before + the adoption of the UTC standards, New York used local solar mean time: + + .. doctest:: tzfile + + >>> print(datetime(1901, 4, 12, tzinfo=NYC)) # LMT + 1901-04-12 00:00:00-04:56 + + And during World War II, New York was on "Eastern War Time", which was a + state of permanent daylight saving time: + + .. doctest:: tzfile + + >>> print(datetime(1944, 2, 7, tzinfo=NYC)) # EWT + 1944-02-07 00:00:00-04:00 + + """ + + def __init__(self, fileobj, filename=None): + super(tzfile, self).__init__() + + file_opened_here = False + if isinstance(fileobj, string_types): + self._filename = fileobj + fileobj = open(fileobj, 'rb') + file_opened_here = True + elif filename is not None: + self._filename = filename + elif hasattr(fileobj, "name"): + self._filename = fileobj.name + else: + self._filename = repr(fileobj) + + if fileobj is not None: + if not file_opened_here: + fileobj = _nullcontext(fileobj) + + with fileobj as file_stream: + tzobj = self._read_tzfile(file_stream) + + self._set_tzdata(tzobj) + + def _set_tzdata(self, tzobj): + """ Set the time zone data of this object from a _tzfile object """ + # Copy the relevant attributes over as private attributes + for attr in _tzfile.attrs: + setattr(self, '_' + attr, getattr(tzobj, attr)) + + def _read_tzfile(self, fileobj): + out = _tzfile() + + # From tzfile(5): + # + # The time zone information files used by tzset(3) + # begin with the magic characters "TZif" to identify + # them as time zone information files, followed by + # sixteen bytes reserved for future use, followed by + # six four-byte values of type long, written in a + # ``standard'' byte order (the high-order byte + # of the value is written first). + if fileobj.read(4).decode() != "TZif": + raise ValueError("magic not found") + + fileobj.read(16) + + ( + # The number of UTC/local indicators stored in the file. + ttisgmtcnt, + + # The number of standard/wall indicators stored in the file. + ttisstdcnt, + + # The number of leap seconds for which data is + # stored in the file. + leapcnt, + + # The number of "transition times" for which data + # is stored in the file. + timecnt, + + # The number of "local time types" for which data + # is stored in the file (must not be zero). + typecnt, + + # The number of characters of "time zone + # abbreviation strings" stored in the file. + charcnt, + + ) = struct.unpack(">6l", fileobj.read(24)) + + # The above header is followed by tzh_timecnt four-byte + # values of type long, sorted in ascending order. + # These values are written in ``standard'' byte order. + # Each is used as a transition time (as returned by + # time(2)) at which the rules for computing local time + # change. + + if timecnt: + out.trans_list_utc = list(struct.unpack(">%dl" % timecnt, + fileobj.read(timecnt*4))) + else: + out.trans_list_utc = [] + + # Next come tzh_timecnt one-byte values of type unsigned + # char; each one tells which of the different types of + # ``local time'' types described in the file is associated + # with the same-indexed transition time. These values + # serve as indices into an array of ttinfo structures that + # appears next in the file. + + if timecnt: + out.trans_idx = struct.unpack(">%dB" % timecnt, + fileobj.read(timecnt)) + else: + out.trans_idx = [] + + # Each ttinfo structure is written as a four-byte value + # for tt_gmtoff of type long, in a standard byte + # order, followed by a one-byte value for tt_isdst + # and a one-byte value for tt_abbrind. In each + # structure, tt_gmtoff gives the number of + # seconds to be added to UTC, tt_isdst tells whether + # tm_isdst should be set by localtime(3), and + # tt_abbrind serves as an index into the array of + # time zone abbreviation characters that follow the + # ttinfo structure(s) in the file. + + ttinfo = [] + + for i in range(typecnt): + ttinfo.append(struct.unpack(">lbb", fileobj.read(6))) + + abbr = fileobj.read(charcnt).decode() + + # Then there are tzh_leapcnt pairs of four-byte + # values, written in standard byte order; the + # first value of each pair gives the time (as + # returned by time(2)) at which a leap second + # occurs; the second gives the total number of + # leap seconds to be applied after the given time. + # The pairs of values are sorted in ascending order + # by time. + + # Not used, for now (but seek for correct file position) + if leapcnt: + fileobj.seek(leapcnt * 8, os.SEEK_CUR) + + # Then there are tzh_ttisstdcnt standard/wall + # indicators, each stored as a one-byte value; + # they tell whether the transition times associated + # with local time types were specified as standard + # time or wall clock time, and are used when + # a time zone file is used in handling POSIX-style + # time zone environment variables. + + if ttisstdcnt: + isstd = struct.unpack(">%db" % ttisstdcnt, + fileobj.read(ttisstdcnt)) + + # Finally, there are tzh_ttisgmtcnt UTC/local + # indicators, each stored as a one-byte value; + # they tell whether the transition times associated + # with local time types were specified as UTC or + # local time, and are used when a time zone file + # is used in handling POSIX-style time zone envi- + # ronment variables. + + if ttisgmtcnt: + isgmt = struct.unpack(">%db" % ttisgmtcnt, + fileobj.read(ttisgmtcnt)) + + # Build ttinfo list + out.ttinfo_list = [] + for i in range(typecnt): + gmtoff, isdst, abbrind = ttinfo[i] + gmtoff = _get_supported_offset(gmtoff) + tti = _ttinfo() + tti.offset = gmtoff + tti.dstoffset = datetime.timedelta(0) + tti.delta = datetime.timedelta(seconds=gmtoff) + tti.isdst = isdst + tti.abbr = abbr[abbrind:abbr.find('\x00', abbrind)] + tti.isstd = (ttisstdcnt > i and isstd[i] != 0) + tti.isgmt = (ttisgmtcnt > i and isgmt[i] != 0) + out.ttinfo_list.append(tti) + + # Replace ttinfo indexes for ttinfo objects. + out.trans_idx = [out.ttinfo_list[idx] for idx in out.trans_idx] + + # Set standard, dst, and before ttinfos. before will be + # used when a given time is before any transitions, + # and will be set to the first non-dst ttinfo, or to + # the first dst, if all of them are dst. + out.ttinfo_std = None + out.ttinfo_dst = None + out.ttinfo_before = None + if out.ttinfo_list: + if not out.trans_list_utc: + out.ttinfo_std = out.ttinfo_first = out.ttinfo_list[0] + else: + for i in range(timecnt-1, -1, -1): + tti = out.trans_idx[i] + if not out.ttinfo_std and not tti.isdst: + out.ttinfo_std = tti + elif not out.ttinfo_dst and tti.isdst: + out.ttinfo_dst = tti + + if out.ttinfo_std and out.ttinfo_dst: + break + else: + if out.ttinfo_dst and not out.ttinfo_std: + out.ttinfo_std = out.ttinfo_dst + + for tti in out.ttinfo_list: + if not tti.isdst: + out.ttinfo_before = tti + break + else: + out.ttinfo_before = out.ttinfo_list[0] + + # Now fix transition times to become relative to wall time. + # + # I'm not sure about this. In my tests, the tz source file + # is setup to wall time, and in the binary file isstd and + # isgmt are off, so it should be in wall time. OTOH, it's + # always in gmt time. Let me know if you have comments + # about this. + lastdst = None + lastoffset = None + lastdstoffset = None + lastbaseoffset = None + out.trans_list = [] + + for i, tti in enumerate(out.trans_idx): + offset = tti.offset + dstoffset = 0 + + if lastdst is not None: + if tti.isdst: + if not lastdst: + dstoffset = offset - lastoffset + + if not dstoffset and lastdstoffset: + dstoffset = lastdstoffset + + tti.dstoffset = datetime.timedelta(seconds=dstoffset) + lastdstoffset = dstoffset + + # If a time zone changes its base offset during a DST transition, + # then you need to adjust by the previous base offset to get the + # transition time in local time. Otherwise you use the current + # base offset. Ideally, I would have some mathematical proof of + # why this is true, but I haven't really thought about it enough. + baseoffset = offset - dstoffset + adjustment = baseoffset + if (lastbaseoffset is not None and baseoffset != lastbaseoffset + and tti.isdst != lastdst): + # The base DST has changed + adjustment = lastbaseoffset + + lastdst = tti.isdst + lastoffset = offset + lastbaseoffset = baseoffset + + out.trans_list.append(out.trans_list_utc[i] + adjustment) + + out.trans_idx = tuple(out.trans_idx) + out.trans_list = tuple(out.trans_list) + out.trans_list_utc = tuple(out.trans_list_utc) + + return out + + def _find_last_transition(self, dt, in_utc=False): + # If there's no list, there are no transitions to find + if not self._trans_list: + return None + + timestamp = _datetime_to_timestamp(dt) + + # Find where the timestamp fits in the transition list - if the + # timestamp is a transition time, it's part of the "after" period. + trans_list = self._trans_list_utc if in_utc else self._trans_list + idx = bisect.bisect_right(trans_list, timestamp) + + # We want to know when the previous transition was, so subtract off 1 + return idx - 1 + + def _get_ttinfo(self, idx): + # For no list or after the last transition, default to _ttinfo_std + if idx is None or (idx + 1) >= len(self._trans_list): + return self._ttinfo_std + + # If there is a list and the time is before it, return _ttinfo_before + if idx < 0: + return self._ttinfo_before + + return self._trans_idx[idx] + + def _find_ttinfo(self, dt): + idx = self._resolve_ambiguous_time(dt) + + return self._get_ttinfo(idx) + + def fromutc(self, dt): + """ + The ``tzfile`` implementation of :py:func:`datetime.tzinfo.fromutc`. + + :param dt: + A :py:class:`datetime.datetime` object. + + :raises TypeError: + Raised if ``dt`` is not a :py:class:`datetime.datetime` object. + + :raises ValueError: + Raised if this is called with a ``dt`` which does not have this + ``tzinfo`` attached. + + :return: + Returns a :py:class:`datetime.datetime` object representing the + wall time in ``self``'s time zone. + """ + # These isinstance checks are in datetime.tzinfo, so we'll preserve + # them, even if we don't care about duck typing. + if not isinstance(dt, datetime.datetime): + raise TypeError("fromutc() requires a datetime argument") + + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + # First treat UTC as wall time and get the transition we're in. + idx = self._find_last_transition(dt, in_utc=True) + tti = self._get_ttinfo(idx) + + dt_out = dt + datetime.timedelta(seconds=tti.offset) + + fold = self.is_ambiguous(dt_out, idx=idx) + + return enfold(dt_out, fold=int(fold)) + + def is_ambiguous(self, dt, idx=None): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + if idx is None: + idx = self._find_last_transition(dt) + + # Calculate the difference in offsets from current to previous + timestamp = _datetime_to_timestamp(dt) + tti = self._get_ttinfo(idx) + + if idx is None or idx <= 0: + return False + + od = self._get_ttinfo(idx - 1).offset - tti.offset + tt = self._trans_list[idx] # Transition time + + return timestamp < tt + od + + def _resolve_ambiguous_time(self, dt): + idx = self._find_last_transition(dt) + + # If we have no transitions, return the index + _fold = self._fold(dt) + if idx is None or idx == 0: + return idx + + # If it's ambiguous and we're in a fold, shift to a different index. + idx_offset = int(not _fold and self.is_ambiguous(dt, idx)) + + return idx - idx_offset + + def utcoffset(self, dt): + if dt is None: + return None + + if not self._ttinfo_std: + return ZERO + + return self._find_ttinfo(dt).delta + + def dst(self, dt): + if dt is None: + return None + + if not self._ttinfo_dst: + return ZERO + + tti = self._find_ttinfo(dt) + + if not tti.isdst: + return ZERO + + # The documentation says that utcoffset()-dst() must + # be constant for every dt. + return tti.dstoffset + + @tzname_in_python2 + def tzname(self, dt): + if not self._ttinfo_std or dt is None: + return None + return self._find_ttinfo(dt).abbr + + def __eq__(self, other): + if not isinstance(other, tzfile): + return NotImplemented + return (self._trans_list == other._trans_list and + self._trans_idx == other._trans_idx and + self._ttinfo_list == other._ttinfo_list) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._filename)) + + def __reduce__(self): + return self.__reduce_ex__(None) + + def __reduce_ex__(self, protocol): + return (self.__class__, (None, self._filename), self.__dict__) + + +class tzrange(tzrangebase): + """ + The ``tzrange`` object is a time zone specified by a set of offsets and + abbreviations, equivalent to the way the ``TZ`` variable can be specified + in POSIX-like systems, but using Python delta objects to specify DST + start, end and offsets. + + :param stdabbr: + The abbreviation for standard time (e.g. ``'EST'``). + + :param stdoffset: + An integer or :class:`datetime.timedelta` object or equivalent + specifying the base offset from UTC. + + If unspecified, +00:00 is used. + + :param dstabbr: + The abbreviation for DST / "Summer" time (e.g. ``'EDT'``). + + If specified, with no other DST information, DST is assumed to occur + and the default behavior or ``dstoffset``, ``start`` and ``end`` is + used. If unspecified and no other DST information is specified, it + is assumed that this zone has no DST. + + If this is unspecified and other DST information is *is* specified, + DST occurs in the zone but the time zone abbreviation is left + unchanged. + + :param dstoffset: + A an integer or :class:`datetime.timedelta` object or equivalent + specifying the UTC offset during DST. If unspecified and any other DST + information is specified, it is assumed to be the STD offset +1 hour. + + :param start: + A :class:`relativedelta.relativedelta` object or equivalent specifying + the time and time of year that daylight savings time starts. To + specify, for example, that DST starts at 2AM on the 2nd Sunday in + March, pass: + + ``relativedelta(hours=2, month=3, day=1, weekday=SU(+2))`` + + If unspecified and any other DST information is specified, the default + value is 2 AM on the first Sunday in April. + + :param end: + A :class:`relativedelta.relativedelta` object or equivalent + representing the time and time of year that daylight savings time + ends, with the same specification method as in ``start``. One note is + that this should point to the first time in the *standard* zone, so if + a transition occurs at 2AM in the DST zone and the clocks are set back + 1 hour to 1AM, set the ``hours`` parameter to +1. + + + **Examples:** + + .. testsetup:: tzrange + + from dateutil.tz import tzrange, tzstr + + .. doctest:: tzrange + + >>> tzstr('EST5EDT') == tzrange("EST", -18000, "EDT") + True + + >>> from dateutil.relativedelta import * + >>> range1 = tzrange("EST", -18000, "EDT") + >>> range2 = tzrange("EST", -18000, "EDT", -14400, + ... relativedelta(hours=+2, month=4, day=1, + ... weekday=SU(+1)), + ... relativedelta(hours=+1, month=10, day=31, + ... weekday=SU(-1))) + >>> tzstr('EST5EDT') == range1 == range2 + True + + """ + def __init__(self, stdabbr, stdoffset=None, + dstabbr=None, dstoffset=None, + start=None, end=None): + + global relativedelta + from dateutil import relativedelta + + self._std_abbr = stdabbr + self._dst_abbr = dstabbr + + try: + stdoffset = stdoffset.total_seconds() + except (TypeError, AttributeError): + pass + + try: + dstoffset = dstoffset.total_seconds() + except (TypeError, AttributeError): + pass + + if stdoffset is not None: + self._std_offset = datetime.timedelta(seconds=stdoffset) + else: + self._std_offset = ZERO + + if dstoffset is not None: + self._dst_offset = datetime.timedelta(seconds=dstoffset) + elif dstabbr and stdoffset is not None: + self._dst_offset = self._std_offset + datetime.timedelta(hours=+1) + else: + self._dst_offset = ZERO + + if dstabbr and start is None: + self._start_delta = relativedelta.relativedelta( + hours=+2, month=4, day=1, weekday=relativedelta.SU(+1)) + else: + self._start_delta = start + + if dstabbr and end is None: + self._end_delta = relativedelta.relativedelta( + hours=+1, month=10, day=31, weekday=relativedelta.SU(-1)) + else: + self._end_delta = end + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = bool(self._start_delta) + + def transitions(self, year): + """ + For a given year, get the DST on and off transition times, expressed + always on the standard time side. For zones with no transitions, this + function returns ``None``. + + :param year: + The year whose transitions you would like to query. + + :return: + Returns a :class:`tuple` of :class:`datetime.datetime` objects, + ``(dston, dstoff)`` for zones with an annual DST transition, or + ``None`` for fixed offset zones. + """ + if not self.hasdst: + return None + + base_year = datetime.datetime(year, 1, 1) + + start = base_year + self._start_delta + end = base_year + self._end_delta + + return (start, end) + + def __eq__(self, other): + if not isinstance(other, tzrange): + return NotImplemented + + return (self._std_abbr == other._std_abbr and + self._dst_abbr == other._dst_abbr and + self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset and + self._start_delta == other._start_delta and + self._end_delta == other._end_delta) + + @property + def _dst_base_offset(self): + return self._dst_base_offset_ + + +@six.add_metaclass(_TzStrFactory) +class tzstr(tzrange): + """ + ``tzstr`` objects are time zone objects specified by a time-zone string as + it would be passed to a ``TZ`` variable on POSIX-style systems (see + the `GNU C Library: TZ Variable`_ for more details). + + There is one notable exception, which is that POSIX-style time zones use an + inverted offset format, so normally ``GMT+3`` would be parsed as an offset + 3 hours *behind* GMT. The ``tzstr`` time zone object will parse this as an + offset 3 hours *ahead* of GMT. If you would like to maintain the POSIX + behavior, pass a ``True`` value to ``posix_offset``. + + The :class:`tzrange` object provides the same functionality, but is + specified using :class:`relativedelta.relativedelta` objects. rather than + strings. + + :param s: + A time zone string in ``TZ`` variable format. This can be a + :class:`bytes` (2.x: :class:`str`), :class:`str` (2.x: + :class:`unicode`) or a stream emitting unicode characters + (e.g. :class:`StringIO`). + + :param posix_offset: + Optional. If set to ``True``, interpret strings such as ``GMT+3`` or + ``UTC+3`` as being 3 hours *behind* UTC rather than ahead, per the + POSIX standard. + + .. caution:: + + Prior to version 2.7.0, this function also supported time zones + in the format: + + * ``EST5EDT,4,0,6,7200,10,0,26,7200,3600`` + * ``EST5EDT,4,1,0,7200,10,-1,0,7200,3600`` + + This format is non-standard and has been deprecated; this function + will raise a :class:`DeprecatedTZFormatWarning` until + support is removed in a future version. + + .. _`GNU C Library: TZ Variable`: + https://www.gnu.org/software/libc/manual/html_node/TZ-Variable.html + """ + def __init__(self, s, posix_offset=False): + global parser + from dateutil.parser import _parser as parser + + self._s = s + + res = parser._parsetz(s) + if res is None or res.any_unused_tokens: + raise ValueError("unknown string format") + + # Here we break the compatibility with the TZ variable handling. + # GMT-3 actually *means* the timezone -3. + if res.stdabbr in ("GMT", "UTC") and not posix_offset: + res.stdoffset *= -1 + + # We must initialize it first, since _delta() needs + # _std_offset and _dst_offset set. Use False in start/end + # to avoid building it two times. + tzrange.__init__(self, res.stdabbr, res.stdoffset, + res.dstabbr, res.dstoffset, + start=False, end=False) + + if not res.dstabbr: + self._start_delta = None + self._end_delta = None + else: + self._start_delta = self._delta(res.start) + if self._start_delta: + self._end_delta = self._delta(res.end, isend=1) + + self.hasdst = bool(self._start_delta) + + def _delta(self, x, isend=0): + from dateutil import relativedelta + kwargs = {} + if x.month is not None: + kwargs["month"] = x.month + if x.weekday is not None: + kwargs["weekday"] = relativedelta.weekday(x.weekday, x.week) + if x.week > 0: + kwargs["day"] = 1 + else: + kwargs["day"] = 31 + elif x.day: + kwargs["day"] = x.day + elif x.yday is not None: + kwargs["yearday"] = x.yday + elif x.jyday is not None: + kwargs["nlyearday"] = x.jyday + if not kwargs: + # Default is to start on first sunday of april, and end + # on last sunday of october. + if not isend: + kwargs["month"] = 4 + kwargs["day"] = 1 + kwargs["weekday"] = relativedelta.SU(+1) + else: + kwargs["month"] = 10 + kwargs["day"] = 31 + kwargs["weekday"] = relativedelta.SU(-1) + if x.time is not None: + kwargs["seconds"] = x.time + else: + # Default is 2AM. + kwargs["seconds"] = 7200 + if isend: + # Convert to standard time, to follow the documented way + # of working with the extra hour. See the documentation + # of the tzinfo class. + delta = self._dst_offset - self._std_offset + kwargs["seconds"] -= delta.seconds + delta.days * 86400 + return relativedelta.relativedelta(**kwargs) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._s)) + + +class _tzicalvtzcomp(object): + def __init__(self, tzoffsetfrom, tzoffsetto, isdst, + tzname=None, rrule=None): + self.tzoffsetfrom = datetime.timedelta(seconds=tzoffsetfrom) + self.tzoffsetto = datetime.timedelta(seconds=tzoffsetto) + self.tzoffsetdiff = self.tzoffsetto - self.tzoffsetfrom + self.isdst = isdst + self.tzname = tzname + self.rrule = rrule + + +class _tzicalvtz(_tzinfo): + def __init__(self, tzid, comps=[]): + super(_tzicalvtz, self).__init__() + + self._tzid = tzid + self._comps = comps + self._cachedate = [] + self._cachecomp = [] + self._cache_lock = _thread.allocate_lock() + + def _find_comp(self, dt): + if len(self._comps) == 1: + return self._comps[0] + + dt = dt.replace(tzinfo=None) + + try: + with self._cache_lock: + return self._cachecomp[self._cachedate.index( + (dt, self._fold(dt)))] + except ValueError: + pass + + lastcompdt = None + lastcomp = None + + for comp in self._comps: + compdt = self._find_compdt(comp, dt) + + if compdt and (not lastcompdt or lastcompdt < compdt): + lastcompdt = compdt + lastcomp = comp + + if not lastcomp: + # RFC says nothing about what to do when a given + # time is before the first onset date. We'll look for the + # first standard component, or the first component, if + # none is found. + for comp in self._comps: + if not comp.isdst: + lastcomp = comp + break + else: + lastcomp = comp[0] + + with self._cache_lock: + self._cachedate.insert(0, (dt, self._fold(dt))) + self._cachecomp.insert(0, lastcomp) + + if len(self._cachedate) > 10: + self._cachedate.pop() + self._cachecomp.pop() + + return lastcomp + + def _find_compdt(self, comp, dt): + if comp.tzoffsetdiff < ZERO and self._fold(dt): + dt -= comp.tzoffsetdiff + + compdt = comp.rrule.before(dt, inc=True) + + return compdt + + def utcoffset(self, dt): + if dt is None: + return None + + return self._find_comp(dt).tzoffsetto + + def dst(self, dt): + comp = self._find_comp(dt) + if comp.isdst: + return comp.tzoffsetdiff + else: + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return self._find_comp(dt).tzname + + def __repr__(self): + return "" % repr(self._tzid) + + __reduce__ = object.__reduce__ + + +class tzical(object): + """ + This object is designed to parse an iCalendar-style ``VTIMEZONE`` structure + as set out in `RFC 5545`_ Section 4.6.5 into one or more `tzinfo` objects. + + :param `fileobj`: + A file or stream in iCalendar format, which should be UTF-8 encoded + with CRLF endings. + + .. _`RFC 5545`: https://tools.ietf.org/html/rfc5545 + """ + def __init__(self, fileobj): + global rrule + from dateutil import rrule + + if isinstance(fileobj, string_types): + self._s = fileobj + # ical should be encoded in UTF-8 with CRLF + fileobj = open(fileobj, 'r') + else: + self._s = getattr(fileobj, 'name', repr(fileobj)) + fileobj = _nullcontext(fileobj) + + self._vtz = {} + + with fileobj as fobj: + self._parse_rfc(fobj.read()) + + def keys(self): + """ + Retrieves the available time zones as a list. + """ + return list(self._vtz.keys()) + + def get(self, tzid=None): + """ + Retrieve a :py:class:`datetime.tzinfo` object by its ``tzid``. + + :param tzid: + If there is exactly one time zone available, omitting ``tzid`` + or passing :py:const:`None` value returns it. Otherwise a valid + key (which can be retrieved from :func:`keys`) is required. + + :raises ValueError: + Raised if ``tzid`` is not specified but there are either more + or fewer than 1 zone defined. + + :returns: + Returns either a :py:class:`datetime.tzinfo` object representing + the relevant time zone or :py:const:`None` if the ``tzid`` was + not found. + """ + if tzid is None: + if len(self._vtz) == 0: + raise ValueError("no timezones defined") + elif len(self._vtz) > 1: + raise ValueError("more than one timezone available") + tzid = next(iter(self._vtz)) + + return self._vtz.get(tzid) + + def _parse_offset(self, s): + s = s.strip() + if not s: + raise ValueError("empty offset") + if s[0] in ('+', '-'): + signal = (-1, +1)[s[0] == '+'] + s = s[1:] + else: + signal = +1 + if len(s) == 4: + return (int(s[:2]) * 3600 + int(s[2:]) * 60) * signal + elif len(s) == 6: + return (int(s[:2]) * 3600 + int(s[2:4]) * 60 + int(s[4:])) * signal + else: + raise ValueError("invalid offset: " + s) + + def _parse_rfc(self, s): + lines = s.splitlines() + if not lines: + raise ValueError("empty string") + + # Unfold + i = 0 + while i < len(lines): + line = lines[i].rstrip() + if not line: + del lines[i] + elif i > 0 and line[0] == " ": + lines[i-1] += line[1:] + del lines[i] + else: + i += 1 + + tzid = None + comps = [] + invtz = False + comptype = None + for line in lines: + if not line: + continue + name, value = line.split(':', 1) + parms = name.split(';') + if not parms: + raise ValueError("empty property name") + name = parms[0].upper() + parms = parms[1:] + if invtz: + if name == "BEGIN": + if value in ("STANDARD", "DAYLIGHT"): + # Process component + pass + else: + raise ValueError("unknown component: "+value) + comptype = value + founddtstart = False + tzoffsetfrom = None + tzoffsetto = None + rrulelines = [] + tzname = None + elif name == "END": + if value == "VTIMEZONE": + if comptype: + raise ValueError("component not closed: "+comptype) + if not tzid: + raise ValueError("mandatory TZID not found") + if not comps: + raise ValueError( + "at least one component is needed") + # Process vtimezone + self._vtz[tzid] = _tzicalvtz(tzid, comps) + invtz = False + elif value == comptype: + if not founddtstart: + raise ValueError("mandatory DTSTART not found") + if tzoffsetfrom is None: + raise ValueError( + "mandatory TZOFFSETFROM not found") + if tzoffsetto is None: + raise ValueError( + "mandatory TZOFFSETFROM not found") + # Process component + rr = None + if rrulelines: + rr = rrule.rrulestr("\n".join(rrulelines), + compatible=True, + ignoretz=True, + cache=True) + comp = _tzicalvtzcomp(tzoffsetfrom, tzoffsetto, + (comptype == "DAYLIGHT"), + tzname, rr) + comps.append(comp) + comptype = None + else: + raise ValueError("invalid component end: "+value) + elif comptype: + if name == "DTSTART": + # DTSTART in VTIMEZONE takes a subset of valid RRULE + # values under RFC 5545. + for parm in parms: + if parm != 'VALUE=DATE-TIME': + msg = ('Unsupported DTSTART param in ' + + 'VTIMEZONE: ' + parm) + raise ValueError(msg) + rrulelines.append(line) + founddtstart = True + elif name in ("RRULE", "RDATE", "EXRULE", "EXDATE"): + rrulelines.append(line) + elif name == "TZOFFSETFROM": + if parms: + raise ValueError( + "unsupported %s parm: %s " % (name, parms[0])) + tzoffsetfrom = self._parse_offset(value) + elif name == "TZOFFSETTO": + if parms: + raise ValueError( + "unsupported TZOFFSETTO parm: "+parms[0]) + tzoffsetto = self._parse_offset(value) + elif name == "TZNAME": + if parms: + raise ValueError( + "unsupported TZNAME parm: "+parms[0]) + tzname = value + elif name == "COMMENT": + pass + else: + raise ValueError("unsupported property: "+name) + else: + if name == "TZID": + if parms: + raise ValueError( + "unsupported TZID parm: "+parms[0]) + tzid = value + elif name in ("TZURL", "LAST-MODIFIED", "COMMENT"): + pass + else: + raise ValueError("unsupported property: "+name) + elif name == "BEGIN" and value == "VTIMEZONE": + tzid = None + comps = [] + invtz = True + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._s)) + + +if sys.platform != "win32": + TZFILES = ["/etc/localtime", "localtime"] + TZPATHS = ["/usr/share/zoneinfo", + "/usr/lib/zoneinfo", + "/usr/share/lib/zoneinfo", + "/etc/zoneinfo"] +else: + TZFILES = [] + TZPATHS = [] + + +def __get_gettz(): + tzlocal_classes = (tzlocal,) + if tzwinlocal is not None: + tzlocal_classes += (tzwinlocal,) + + class GettzFunc(object): + """ + Retrieve a time zone object from a string representation + + This function is intended to retrieve the :py:class:`tzinfo` subclass + that best represents the time zone that would be used if a POSIX + `TZ variable`_ were set to the same value. + + If no argument or an empty string is passed to ``gettz``, local time + is returned: + + .. code-block:: python3 + + >>> gettz() + tzfile('/etc/localtime') + + This function is also the preferred way to map IANA tz database keys + to :class:`tzfile` objects: + + .. code-block:: python3 + + >>> gettz('Pacific/Kiritimati') + tzfile('/usr/share/zoneinfo/Pacific/Kiritimati') + + On Windows, the standard is extended to include the Windows-specific + zone names provided by the operating system: + + .. code-block:: python3 + + >>> gettz('Egypt Standard Time') + tzwin('Egypt Standard Time') + + Passing a GNU ``TZ`` style string time zone specification returns a + :class:`tzstr` object: + + .. code-block:: python3 + + >>> gettz('AEST-10AEDT-11,M10.1.0/2,M4.1.0/3') + tzstr('AEST-10AEDT-11,M10.1.0/2,M4.1.0/3') + + :param name: + A time zone name (IANA, or, on Windows, Windows keys), location of + a ``tzfile(5)`` zoneinfo file or ``TZ`` variable style time zone + specifier. An empty string, no argument or ``None`` is interpreted + as local time. + + :return: + Returns an instance of one of ``dateutil``'s :py:class:`tzinfo` + subclasses. + + .. versionchanged:: 2.7.0 + + After version 2.7.0, any two calls to ``gettz`` using the same + input strings will return the same object: + + .. code-block:: python3 + + >>> tz.gettz('America/Chicago') is tz.gettz('America/Chicago') + True + + In addition to improving performance, this ensures that + `"same zone" semantics`_ are used for datetimes in the same zone. + + + .. _`TZ variable`: + https://www.gnu.org/software/libc/manual/html_node/TZ-Variable.html + + .. _`"same zone" semantics`: + https://blog.ganssle.io/articles/2018/02/aware-datetime-arithmetic.html + """ + def __init__(self): + + self.__instances = weakref.WeakValueDictionary() + self.__strong_cache_size = 8 + self.__strong_cache = OrderedDict() + self._cache_lock = _thread.allocate_lock() + + def __call__(self, name=None): + with self._cache_lock: + rv = self.__instances.get(name, None) + + if rv is None: + rv = self.nocache(name=name) + if not (name is None + or isinstance(rv, tzlocal_classes) + or rv is None): + # tzlocal is slightly more complicated than the other + # time zone providers because it depends on environment + # at construction time, so don't cache that. + # + # We also cannot store weak references to None, so we + # will also not store that. + self.__instances[name] = rv + else: + # No need for strong caching, return immediately + return rv + + self.__strong_cache[name] = self.__strong_cache.pop(name, rv) + + if len(self.__strong_cache) > self.__strong_cache_size: + self.__strong_cache.popitem(last=False) + + return rv + + def set_cache_size(self, size): + with self._cache_lock: + self.__strong_cache_size = size + while len(self.__strong_cache) > size: + self.__strong_cache.popitem(last=False) + + def cache_clear(self): + with self._cache_lock: + self.__instances = weakref.WeakValueDictionary() + self.__strong_cache.clear() + + @staticmethod + def nocache(name=None): + """A non-cached version of gettz""" + tz = None + if not name: + try: + name = os.environ["TZ"] + except KeyError: + pass + if name is None or name in ("", ":"): + for filepath in TZFILES: + if not os.path.isabs(filepath): + filename = filepath + for path in TZPATHS: + filepath = os.path.join(path, filename) + if os.path.isfile(filepath): + break + else: + continue + if os.path.isfile(filepath): + try: + tz = tzfile(filepath) + break + except (IOError, OSError, ValueError): + pass + else: + tz = tzlocal() + else: + try: + if name.startswith(":"): + name = name[1:] + except TypeError as e: + if isinstance(name, bytes): + new_msg = "gettz argument should be str, not bytes" + six.raise_from(TypeError(new_msg), e) + else: + raise + if os.path.isabs(name): + if os.path.isfile(name): + tz = tzfile(name) + else: + tz = None + else: + for path in TZPATHS: + filepath = os.path.join(path, name) + if not os.path.isfile(filepath): + filepath = filepath.replace(' ', '_') + if not os.path.isfile(filepath): + continue + try: + tz = tzfile(filepath) + break + except (IOError, OSError, ValueError): + pass + else: + tz = None + if tzwin is not None: + try: + tz = tzwin(name) + except (WindowsError, UnicodeEncodeError): + # UnicodeEncodeError is for Python 2.7 compat + tz = None + + if not tz: + from dateutil.zoneinfo import get_zonefile_instance + tz = get_zonefile_instance().get(name) + + if not tz: + for c in name: + # name is not a tzstr unless it has at least + # one offset. For short values of "name", an + # explicit for loop seems to be the fastest way + # To determine if a string contains a digit + if c in "0123456789": + try: + tz = tzstr(name) + except ValueError: + pass + break + else: + if name in ("GMT", "UTC"): + tz = UTC + elif name in time.tzname: + tz = tzlocal() + return tz + + return GettzFunc() + + +gettz = __get_gettz() +del __get_gettz + + +def datetime_exists(dt, tz=None): + """ + Given a datetime and a time zone, determine whether or not a given datetime + would fall in a gap. + + :param dt: + A :class:`datetime.datetime` (whose time zone will be ignored if ``tz`` + is provided.) + + :param tz: + A :class:`datetime.tzinfo` with support for the ``fold`` attribute. If + ``None`` or not provided, the datetime's own time zone will be used. + + :return: + Returns a boolean value whether or not the "wall time" exists in + ``tz``. + + .. versionadded:: 2.7.0 + """ + if tz is None: + if dt.tzinfo is None: + raise ValueError('Datetime is naive and no time zone provided.') + tz = dt.tzinfo + + dt = dt.replace(tzinfo=None) + + # This is essentially a test of whether or not the datetime can survive + # a round trip to UTC. + dt_rt = dt.replace(tzinfo=tz).astimezone(UTC).astimezone(tz) + dt_rt = dt_rt.replace(tzinfo=None) + + return dt == dt_rt + + +def datetime_ambiguous(dt, tz=None): + """ + Given a datetime and a time zone, determine whether or not a given datetime + is ambiguous (i.e if there are two times differentiated only by their DST + status). + + :param dt: + A :class:`datetime.datetime` (whose time zone will be ignored if ``tz`` + is provided.) + + :param tz: + A :class:`datetime.tzinfo` with support for the ``fold`` attribute. If + ``None`` or not provided, the datetime's own time zone will be used. + + :return: + Returns a boolean value whether or not the "wall time" is ambiguous in + ``tz``. + + .. versionadded:: 2.6.0 + """ + if tz is None: + if dt.tzinfo is None: + raise ValueError('Datetime is naive and no time zone provided.') + + tz = dt.tzinfo + + # If a time zone defines its own "is_ambiguous" function, we'll use that. + is_ambiguous_fn = getattr(tz, 'is_ambiguous', None) + if is_ambiguous_fn is not None: + try: + return tz.is_ambiguous(dt) + except Exception: + pass + + # If it doesn't come out and tell us it's ambiguous, we'll just check if + # the fold attribute has any effect on this particular date and time. + dt = dt.replace(tzinfo=tz) + wall_0 = enfold(dt, fold=0) + wall_1 = enfold(dt, fold=1) + + same_offset = wall_0.utcoffset() == wall_1.utcoffset() + same_dst = wall_0.dst() == wall_1.dst() + + return not (same_offset and same_dst) + + +def resolve_imaginary(dt): + """ + Given a datetime that may be imaginary, return an existing datetime. + + This function assumes that an imaginary datetime represents what the + wall time would be in a zone had the offset transition not occurred, so + it will always fall forward by the transition's change in offset. + + .. doctest:: + + >>> from dateutil import tz + >>> from datetime import datetime + >>> NYC = tz.gettz('America/New_York') + >>> print(tz.resolve_imaginary(datetime(2017, 3, 12, 2, 30, tzinfo=NYC))) + 2017-03-12 03:30:00-04:00 + + >>> KIR = tz.gettz('Pacific/Kiritimati') + >>> print(tz.resolve_imaginary(datetime(1995, 1, 1, 12, 30, tzinfo=KIR))) + 1995-01-02 12:30:00+14:00 + + As a note, :func:`datetime.astimezone` is guaranteed to produce a valid, + existing datetime, so a round-trip to and from UTC is sufficient to get + an extant datetime, however, this generally "falls back" to an earlier time + rather than falling forward to the STD side (though no guarantees are made + about this behavior). + + :param dt: + A :class:`datetime.datetime` which may or may not exist. + + :return: + Returns an existing :class:`datetime.datetime`. If ``dt`` was not + imaginary, the datetime returned is guaranteed to be the same object + passed to the function. + + .. versionadded:: 2.7.0 + """ + if dt.tzinfo is not None and not datetime_exists(dt): + + curr_offset = (dt + datetime.timedelta(hours=24)).utcoffset() + old_offset = (dt - datetime.timedelta(hours=24)).utcoffset() + + dt += curr_offset - old_offset + + return dt + + +def _datetime_to_timestamp(dt): + """ + Convert a :class:`datetime.datetime` object to an epoch timestamp in + seconds since January 1, 1970, ignoring the time zone. + """ + return (dt.replace(tzinfo=None) - EPOCH).total_seconds() + + +if sys.version_info >= (3, 6): + def _get_supported_offset(second_offset): + return second_offset +else: + def _get_supported_offset(second_offset): + # For python pre-3.6, round to full-minutes if that's not the case. + # Python's datetime doesn't accept sub-minute timezones. Check + # http://python.org/sf/1447945 or https://bugs.python.org/issue5288 + # for some information. + old_offset = second_offset + calculated_offset = 60 * ((second_offset + 30) // 60) + return calculated_offset + + +try: + # Python 3.7 feature + from contextlib import nullcontext as _nullcontext +except ImportError: + class _nullcontext(object): + """ + Class for wrapping contexts so that they are passed through in a + with statement. + """ + def __init__(self, context): + self.context = context + + def __enter__(self): + return self.context + + def __exit__(*args, **kwargs): + pass + +# vim:ts=4:sw=4:et diff --git a/dateutil/tz/win.py b/dateutil/tz/win.py new file mode 100644 index 0000000000000000000000000000000000000000..cde07ba792c40903f0c334839140173b39fd8124 --- /dev/null +++ b/dateutil/tz/win.py @@ -0,0 +1,370 @@ +# -*- coding: utf-8 -*- +""" +This module provides an interface to the native time zone data on Windows, +including :py:class:`datetime.tzinfo` implementations. + +Attempting to import this module on a non-Windows platform will raise an +:py:obj:`ImportError`. +""" +# This code was originally contributed by Jeffrey Harris. +import datetime +import struct + +from six.moves import winreg +from six import text_type + +try: + import ctypes + from ctypes import wintypes +except ValueError: + # ValueError is raised on non-Windows systems for some horrible reason. + raise ImportError("Running tzwin on non-Windows system") + +from ._common import tzrangebase + +__all__ = ["tzwin", "tzwinlocal", "tzres"] + +ONEWEEK = datetime.timedelta(7) + +TZKEYNAMENT = r"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones" +TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones" +TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation" + + +def _settzkeyname(): + handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) + try: + winreg.OpenKey(handle, TZKEYNAMENT).Close() + TZKEYNAME = TZKEYNAMENT + except WindowsError: + TZKEYNAME = TZKEYNAME9X + handle.Close() + return TZKEYNAME + + +TZKEYNAME = _settzkeyname() + + +class tzres(object): + """ + Class for accessing ``tzres.dll``, which contains timezone name related + resources. + + .. versionadded:: 2.5.0 + """ + p_wchar = ctypes.POINTER(wintypes.WCHAR) # Pointer to a wide char + + def __init__(self, tzres_loc='tzres.dll'): + # Load the user32 DLL so we can load strings from tzres + user32 = ctypes.WinDLL('user32') + + # Specify the LoadStringW function + user32.LoadStringW.argtypes = (wintypes.HINSTANCE, + wintypes.UINT, + wintypes.LPWSTR, + ctypes.c_int) + + self.LoadStringW = user32.LoadStringW + self._tzres = ctypes.WinDLL(tzres_loc) + self.tzres_loc = tzres_loc + + def load_name(self, offset): + """ + Load a timezone name from a DLL offset (integer). + + >>> from dateutil.tzwin import tzres + >>> tzr = tzres() + >>> print(tzr.load_name(112)) + 'Eastern Standard Time' + + :param offset: + A positive integer value referring to a string from the tzres dll. + + .. note:: + + Offsets found in the registry are generally of the form + ``@tzres.dll,-114``. The offset in this case is 114, not -114. + + """ + resource = self.p_wchar() + lpBuffer = ctypes.cast(ctypes.byref(resource), wintypes.LPWSTR) + nchar = self.LoadStringW(self._tzres._handle, offset, lpBuffer, 0) + return resource[:nchar] + + def name_from_string(self, tzname_str): + """ + Parse strings as returned from the Windows registry into the time zone + name as defined in the registry. + + >>> from dateutil.tzwin import tzres + >>> tzr = tzres() + >>> print(tzr.name_from_string('@tzres.dll,-251')) + 'Dateline Daylight Time' + >>> print(tzr.name_from_string('Eastern Standard Time')) + 'Eastern Standard Time' + + :param tzname_str: + A timezone name string as returned from a Windows registry key. + + :return: + Returns the localized timezone string from tzres.dll if the string + is of the form `@tzres.dll,-offset`, else returns the input string. + """ + if not tzname_str.startswith('@'): + return tzname_str + + name_splt = tzname_str.split(',-') + try: + offset = int(name_splt[1]) + except: + raise ValueError("Malformed timezone string.") + + return self.load_name(offset) + + +class tzwinbase(tzrangebase): + """tzinfo class based on win32's timezones available in the registry.""" + def __init__(self): + raise NotImplementedError('tzwinbase is an abstract base class') + + def __eq__(self, other): + # Compare on all relevant dimensions, including name. + if not isinstance(other, tzwinbase): + return NotImplemented + + return (self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset and + self._stddayofweek == other._stddayofweek and + self._dstdayofweek == other._dstdayofweek and + self._stdweeknumber == other._stdweeknumber and + self._dstweeknumber == other._dstweeknumber and + self._stdhour == other._stdhour and + self._dsthour == other._dsthour and + self._stdminute == other._stdminute and + self._dstminute == other._dstminute and + self._std_abbr == other._std_abbr and + self._dst_abbr == other._dst_abbr) + + @staticmethod + def list(): + """Return a list of all time zones known to the system.""" + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + with winreg.OpenKey(handle, TZKEYNAME) as tzkey: + result = [winreg.EnumKey(tzkey, i) + for i in range(winreg.QueryInfoKey(tzkey)[0])] + return result + + def display(self): + """ + Return the display name of the time zone. + """ + return self._display + + def transitions(self, year): + """ + For a given year, get the DST on and off transition times, expressed + always on the standard time side. For zones with no transitions, this + function returns ``None``. + + :param year: + The year whose transitions you would like to query. + + :return: + Returns a :class:`tuple` of :class:`datetime.datetime` objects, + ``(dston, dstoff)`` for zones with an annual DST transition, or + ``None`` for fixed offset zones. + """ + + if not self.hasdst: + return None + + dston = picknthweekday(year, self._dstmonth, self._dstdayofweek, + self._dsthour, self._dstminute, + self._dstweeknumber) + + dstoff = picknthweekday(year, self._stdmonth, self._stddayofweek, + self._stdhour, self._stdminute, + self._stdweeknumber) + + # Ambiguous dates default to the STD side + dstoff -= self._dst_base_offset + + return dston, dstoff + + def _get_hasdst(self): + return self._dstmonth != 0 + + @property + def _dst_base_offset(self): + return self._dst_base_offset_ + + +class tzwin(tzwinbase): + """ + Time zone object created from the zone info in the Windows registry + + These are similar to :py:class:`dateutil.tz.tzrange` objects in that + the time zone data is provided in the format of a single offset rule + for either 0 or 2 time zone transitions per year. + + :param: name + The name of a Windows time zone key, e.g. "Eastern Standard Time". + The full list of keys can be retrieved with :func:`tzwin.list`. + """ + + def __init__(self, name): + self._name = name + + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + tzkeyname = text_type("{kn}\\{name}").format(kn=TZKEYNAME, name=name) + with winreg.OpenKey(handle, tzkeyname) as tzkey: + keydict = valuestodict(tzkey) + + self._std_abbr = keydict["Std"] + self._dst_abbr = keydict["Dlt"] + + self._display = keydict["Display"] + + # See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm + tup = struct.unpack("=3l16h", keydict["TZI"]) + stdoffset = -tup[0]-tup[1] # Bias + StandardBias * -1 + dstoffset = stdoffset-tup[2] # + DaylightBias * -1 + self._std_offset = datetime.timedelta(minutes=stdoffset) + self._dst_offset = datetime.timedelta(minutes=dstoffset) + + # for the meaning see the win32 TIME_ZONE_INFORMATION structure docs + # http://msdn.microsoft.com/en-us/library/windows/desktop/ms725481(v=vs.85).aspx + (self._stdmonth, + self._stddayofweek, # Sunday = 0 + self._stdweeknumber, # Last = 5 + self._stdhour, + self._stdminute) = tup[4:9] + + (self._dstmonth, + self._dstdayofweek, # Sunday = 0 + self._dstweeknumber, # Last = 5 + self._dsthour, + self._dstminute) = tup[12:17] + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = self._get_hasdst() + + def __repr__(self): + return "tzwin(%s)" % repr(self._name) + + def __reduce__(self): + return (self.__class__, (self._name,)) + + +class tzwinlocal(tzwinbase): + """ + Class representing the local time zone information in the Windows registry + + While :class:`dateutil.tz.tzlocal` makes system calls (via the :mod:`time` + module) to retrieve time zone information, ``tzwinlocal`` retrieves the + rules directly from the Windows registry and creates an object like + :class:`dateutil.tz.tzwin`. + + Because Windows does not have an equivalent of :func:`time.tzset`, on + Windows, :class:`dateutil.tz.tzlocal` instances will always reflect the + time zone settings *at the time that the process was started*, meaning + changes to the machine's time zone settings during the run of a program + on Windows will **not** be reflected by :class:`dateutil.tz.tzlocal`. + Because ``tzwinlocal`` reads the registry directly, it is unaffected by + this issue. + """ + def __init__(self): + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + with winreg.OpenKey(handle, TZLOCALKEYNAME) as tzlocalkey: + keydict = valuestodict(tzlocalkey) + + self._std_abbr = keydict["StandardName"] + self._dst_abbr = keydict["DaylightName"] + + try: + tzkeyname = text_type('{kn}\\{sn}').format(kn=TZKEYNAME, + sn=self._std_abbr) + with winreg.OpenKey(handle, tzkeyname) as tzkey: + _keydict = valuestodict(tzkey) + self._display = _keydict["Display"] + except OSError: + self._display = None + + stdoffset = -keydict["Bias"]-keydict["StandardBias"] + dstoffset = stdoffset-keydict["DaylightBias"] + + self._std_offset = datetime.timedelta(minutes=stdoffset) + self._dst_offset = datetime.timedelta(minutes=dstoffset) + + # For reasons unclear, in this particular key, the day of week has been + # moved to the END of the SYSTEMTIME structure. + tup = struct.unpack("=8h", keydict["StandardStart"]) + + (self._stdmonth, + self._stdweeknumber, # Last = 5 + self._stdhour, + self._stdminute) = tup[1:5] + + self._stddayofweek = tup[7] + + tup = struct.unpack("=8h", keydict["DaylightStart"]) + + (self._dstmonth, + self._dstweeknumber, # Last = 5 + self._dsthour, + self._dstminute) = tup[1:5] + + self._dstdayofweek = tup[7] + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = self._get_hasdst() + + def __repr__(self): + return "tzwinlocal()" + + def __str__(self): + # str will return the standard name, not the daylight name. + return "tzwinlocal(%s)" % repr(self._std_abbr) + + def __reduce__(self): + return (self.__class__, ()) + + +def picknthweekday(year, month, dayofweek, hour, minute, whichweek): + """ dayofweek == 0 means Sunday, whichweek 5 means last instance """ + first = datetime.datetime(year, month, 1, hour, minute) + + # This will work if dayofweek is ISO weekday (1-7) or Microsoft-style (0-6), + # Because 7 % 7 = 0 + weekdayone = first.replace(day=((dayofweek - first.isoweekday()) % 7) + 1) + wd = weekdayone + ((whichweek - 1) * ONEWEEK) + if (wd.month != month): + wd -= ONEWEEK + + return wd + + +def valuestodict(key): + """Convert a registry key's values to a dictionary.""" + dout = {} + size = winreg.QueryInfoKey(key)[1] + tz_res = None + + for i in range(size): + key_name, value, dtype = winreg.EnumValue(key, i) + if dtype == winreg.REG_DWORD or dtype == winreg.REG_DWORD_LITTLE_ENDIAN: + # If it's a DWORD (32-bit integer), it's stored as unsigned - convert + # that to a proper signed integer + if value & (1 << 31): + value = value - (1 << 32) + elif dtype == winreg.REG_SZ: + # If it's a reference to the tzres DLL, load the actual string + if value.startswith('@tzres'): + tz_res = tz_res or tzres() + value = tz_res.name_from_string(value) + + value = value.rstrip('\x00') # Remove trailing nulls + + dout[key_name] = value + + return dout diff --git a/dateutil/zoneinfo/__init__.py b/dateutil/zoneinfo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34f11ad66c88047f2c049a4cdcc937b4b78ea6d6 --- /dev/null +++ b/dateutil/zoneinfo/__init__.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +import warnings +import json + +from tarfile import TarFile +from pkgutil import get_data +from io import BytesIO + +from dateutil.tz import tzfile as _tzfile + +__all__ = ["get_zonefile_instance", "gettz", "gettz_db_metadata"] + +ZONEFILENAME = "dateutil-zoneinfo.tar.gz" +METADATA_FN = 'METADATA' + + +class tzfile(_tzfile): + def __reduce__(self): + return (gettz, (self._filename,)) + + +def getzoneinfofile_stream(): + try: + return BytesIO(get_data(__name__, ZONEFILENAME)) + except IOError as e: # TODO switch to FileNotFoundError? + warnings.warn("I/O error({0}): {1}".format(e.errno, e.strerror)) + return None + + +class ZoneInfoFile(object): + def __init__(self, zonefile_stream=None): + if zonefile_stream is not None: + with TarFile.open(fileobj=zonefile_stream) as tf: + self.zones = {zf.name: tzfile(tf.extractfile(zf), filename=zf.name) + for zf in tf.getmembers() + if zf.isfile() and zf.name != METADATA_FN} + # deal with links: They'll point to their parent object. Less + # waste of memory + links = {zl.name: self.zones[zl.linkname] + for zl in tf.getmembers() if + zl.islnk() or zl.issym()} + self.zones.update(links) + try: + metadata_json = tf.extractfile(tf.getmember(METADATA_FN)) + metadata_str = metadata_json.read().decode('UTF-8') + self.metadata = json.loads(metadata_str) + except KeyError: + # no metadata in tar file + self.metadata = None + else: + self.zones = {} + self.metadata = None + + def get(self, name, default=None): + """ + Wrapper for :func:`ZoneInfoFile.zones.get`. This is a convenience method + for retrieving zones from the zone dictionary. + + :param name: + The name of the zone to retrieve. (Generally IANA zone names) + + :param default: + The value to return in the event of a missing key. + + .. versionadded:: 2.6.0 + + """ + return self.zones.get(name, default) + + +# The current API has gettz as a module function, although in fact it taps into +# a stateful class. So as a workaround for now, without changing the API, we +# will create a new "global" class instance the first time a user requests a +# timezone. Ugly, but adheres to the api. +# +# TODO: Remove after deprecation period. +_CLASS_ZONE_INSTANCE = [] + + +def get_zonefile_instance(new_instance=False): + """ + This is a convenience function which provides a :class:`ZoneInfoFile` + instance using the data provided by the ``dateutil`` package. By default, it + caches a single instance of the ZoneInfoFile object and returns that. + + :param new_instance: + If ``True``, a new instance of :class:`ZoneInfoFile` is instantiated and + used as the cached instance for the next call. Otherwise, new instances + are created only as necessary. + + :return: + Returns a :class:`ZoneInfoFile` object. + + .. versionadded:: 2.6 + """ + if new_instance: + zif = None + else: + zif = getattr(get_zonefile_instance, '_cached_instance', None) + + if zif is None: + zif = ZoneInfoFile(getzoneinfofile_stream()) + + get_zonefile_instance._cached_instance = zif + + return zif + + +def gettz(name): + """ + This retrieves a time zone from the local zoneinfo tarball that is packaged + with dateutil. + + :param name: + An IANA-style time zone name, as found in the zoneinfo file. + + :return: + Returns a :class:`dateutil.tz.tzfile` time zone object. + + .. warning:: + It is generally inadvisable to use this function, and it is only + provided for API compatibility with earlier versions. This is *not* + equivalent to ``dateutil.tz.gettz()``, which selects an appropriate + time zone based on the inputs, favoring system zoneinfo. This is ONLY + for accessing the dateutil-specific zoneinfo (which may be out of + date compared to the system zoneinfo). + + .. deprecated:: 2.6 + If you need to use a specific zoneinfofile over the system zoneinfo, + instantiate a :class:`dateutil.zoneinfo.ZoneInfoFile` object and call + :func:`dateutil.zoneinfo.ZoneInfoFile.get(name)` instead. + + Use :func:`get_zonefile_instance` to retrieve an instance of the + dateutil-provided zoneinfo. + """ + warnings.warn("zoneinfo.gettz() will be removed in future versions, " + "to use the dateutil-provided zoneinfo files, instantiate a " + "ZoneInfoFile object and use ZoneInfoFile.zones.get() " + "instead. See the documentation for details.", + DeprecationWarning) + + if len(_CLASS_ZONE_INSTANCE) == 0: + _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream())) + return _CLASS_ZONE_INSTANCE[0].zones.get(name) + + +def gettz_db_metadata(): + """ Get the zonefile metadata + + See `zonefile_metadata`_ + + :returns: + A dictionary with the database metadata + + .. deprecated:: 2.6 + See deprecation warning in :func:`zoneinfo.gettz`. To get metadata, + query the attribute ``zoneinfo.ZoneInfoFile.metadata``. + """ + warnings.warn("zoneinfo.gettz_db_metadata() will be removed in future " + "versions, to use the dateutil-provided zoneinfo files, " + "ZoneInfoFile object and query the 'metadata' attribute " + "instead. See the documentation for details.", + DeprecationWarning) + + if len(_CLASS_ZONE_INSTANCE) == 0: + _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream())) + return _CLASS_ZONE_INSTANCE[0].metadata diff --git a/dateutil/zoneinfo/rebuild.py b/dateutil/zoneinfo/rebuild.py new file mode 100644 index 0000000000000000000000000000000000000000..684c6586f091350c347f2b6150935f5214ffec27 --- /dev/null +++ b/dateutil/zoneinfo/rebuild.py @@ -0,0 +1,75 @@ +import logging +import os +import tempfile +import shutil +import json +from subprocess import check_call, check_output +from tarfile import TarFile + +from dateutil.zoneinfo import METADATA_FN, ZONEFILENAME + + +def rebuild(filename, tag=None, format="gz", zonegroups=[], metadata=None): + """Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar* + + filename is the timezone tarball from ``ftp.iana.org/tz``. + + """ + tmpdir = tempfile.mkdtemp() + zonedir = os.path.join(tmpdir, "zoneinfo") + moduledir = os.path.dirname(__file__) + try: + with TarFile.open(filename) as tf: + for name in zonegroups: + tf.extract(name, tmpdir) + filepaths = [os.path.join(tmpdir, n) for n in zonegroups] + + _run_zic(zonedir, filepaths) + + # write metadata file + with open(os.path.join(zonedir, METADATA_FN), 'w') as f: + json.dump(metadata, f, indent=4, sort_keys=True) + target = os.path.join(moduledir, ZONEFILENAME) + with TarFile.open(target, "w:%s" % format) as tf: + for entry in os.listdir(zonedir): + entrypath = os.path.join(zonedir, entry) + tf.add(entrypath, entry) + finally: + shutil.rmtree(tmpdir) + + +def _run_zic(zonedir, filepaths): + """Calls the ``zic`` compiler in a compatible way to get a "fat" binary. + + Recent versions of ``zic`` default to ``-b slim``, while older versions + don't even have the ``-b`` option (but default to "fat" binaries). The + current version of dateutil does not support Version 2+ TZif files, which + causes problems when used in conjunction with "slim" binaries, so this + function is used to ensure that we always get a "fat" binary. + """ + + try: + help_text = check_output(["zic", "--help"]) + except OSError as e: + _print_on_nosuchfile(e) + raise + + if b"-b " in help_text: + bloat_args = ["-b", "fat"] + else: + bloat_args = [] + + check_call(["zic"] + bloat_args + ["-d", zonedir] + filepaths) + + +def _print_on_nosuchfile(e): + """Print helpful troubleshooting message + + e is an exception raised by subprocess.check_call() + + """ + if e.errno == 2: + logging.error( + "Could not find zic. Perhaps you need to install " + "libc-bin or some other package that provides it, " + "or it's not in your PATH?") diff --git a/pandas/_config/__init__.py b/pandas/_config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7fe239b8f8a8f9615ecee81caef501076e8224 --- /dev/null +++ b/pandas/_config/__init__.py @@ -0,0 +1,45 @@ +""" +pandas._config is considered explicitly upstream of everything else in pandas, +should have no intra-pandas dependencies. + +importing `dates` and `display` ensures that keys needed by _libs +are initialized. +""" + +__all__ = [ + "config", + "describe_option", + "detect_console_encoding", + "get_option", + "option_context", + "options", + "reset_option", + "set_option", +] +from pandas._config import config +from pandas._config import dates # pyright: ignore[reportUnusedImport] # noqa: F401 +from pandas._config.config import ( + _global_config, + describe_option, + get_option, + option_context, + options, + reset_option, + set_option, +) +from pandas._config.display import detect_console_encoding + + +def using_string_dtype() -> bool: + _mode_options = _global_config["future"] + return _mode_options["infer_string"] + + +def using_python_scalars() -> bool: + _mode_options = _global_config["future"] + return _mode_options["python_scalars"] + + +def is_nan_na() -> bool: + _mode_options = _global_config["future"] + return not _mode_options["distinguish_nan_and_na"] diff --git a/pandas/_config/config.py b/pandas/_config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..041596cf002a3c9a0d8a58e97a25de559d41f49d --- /dev/null +++ b/pandas/_config/config.py @@ -0,0 +1,954 @@ +""" +The config module holds package-wide configurables and provides +a uniform API for working with them. + +Overview +======== + +This module supports the following requirements: +- options are referenced using keys in dot.notation, e.g. "x.y.option - z". +- keys are case-insensitive. +- functions should accept partial/regex keys, when unambiguous. +- options can be registered by modules at import time. +- options can be registered at init-time (via core.config_init) +- options have a default value, and (optionally) a description and + validation function associated with them. +- options can be deprecated, in which case referencing them + should produce a warning. +- deprecated options can optionally be rerouted to a replacement + so that accessing a deprecated option reroutes to a differently + named option. +- options can be reset to their default value. +- all option can be reset to their default value at once. +- all options in a certain sub - namespace can be reset at once. +- the user can set / get / reset or ask for the description of an option. +- a developer can register and mark an option as deprecated. +- you can register a callback to be invoked when the option value + is set or reset. Changing the stored value is considered misuse, but + is not verboten. + +Implementation +============== + +- Data is stored using nested dictionaries, and should be accessed + through the provided API. + +- "Registered options" and "Deprecated options" have metadata associated + with them, which are stored in auxiliary dictionaries keyed on the + fully-qualified key, e.g. "x.y.z.option". + +- the config_init module is imported by the package's __init__.py file. + placing any register_option() calls there will ensure those options + are available as soon as pandas is loaded. If you use register_option + in a module, it will only be available after that module is imported, + which you should be aware of. + +- `config_prefix` is a context_manager (for use with the `with` keyword) + which can save developers some typing, see the docstring. + +""" + +from __future__ import annotations + +from contextlib import contextmanager +import re +from typing import ( + TYPE_CHECKING, + Any, + NamedTuple, + cast, +) +import warnings + +from pandas._typing import F +from pandas.util._exceptions import find_stack_level + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Generator, + Sequence, + ) + + +class DeprecatedOption(NamedTuple): + key: str + category: type[Warning] + msg: str | None + rkey: str | None + removal_ver: str | None + + +class RegisteredOption(NamedTuple): + key: str + defval: Any + doc: str + validator: Callable[[object], Any] | None + cb: Callable[[str], Any] | None + + +# holds deprecated option metadata +_deprecated_options: dict[str, DeprecatedOption] = {} + +# holds registered option metadata +_registered_options: dict[str, RegisteredOption] = {} + +# holds the current values for registered options +_global_config: dict[str, Any] = {} + +# keys which have a special meaning +_reserved_keys: list[str] = ["all"] + + +class OptionError(AttributeError, KeyError): + """ + Exception raised for pandas.options. + + Backwards compatible with KeyError checks. + + See Also + -------- + options : Access and modify global pandas settings. + + Examples + -------- + >>> pd.options.context + Traceback (most recent call last): + OptionError: No such option + """ + + __module__ = "pandas.errors" + + +# +# User API + + +def _get_single_key(pat: str) -> str: + keys = _select_options(pat) + if len(keys) == 0: + _warn_if_deprecated(pat) + raise OptionError(f"No such keys(s): {pat!r}") + if len(keys) > 1: + raise OptionError("Pattern matched multiple keys") + key = keys[0] + + _warn_if_deprecated(key) + + key = _translate_key(key) + + return key + + +def get_option(pat: str) -> Any: + """ + Retrieve the value of the specified option. + + This method allows users to query the current value of a given option + in the pandas configuration system. Options control various display, + performance, and behavior-related settings within pandas. + + Parameters + ---------- + pat : str + Regexp which should match a single option. + + .. warning:: + + Partial matches are supported for convenience, but unless you use the + full option name (e.g. x.y.z.option_name), your code may break in future + versions if new options with similar names are introduced. + + Returns + ------- + Any + The value of the option. + + Raises + ------ + OptionError : if no such option exists + + See Also + -------- + set_option : Set the value of the specified option or options. + reset_option : Reset one or more options to their default value. + describe_option : Print the description for one or more registered options. + + Notes + ----- + For all available options, please view the :ref:`User Guide ` + or use ``pandas.describe_option()``. + + Examples + -------- + >>> pd.get_option("display.max_columns") # doctest: +SKIP + 4 + """ + key = _get_single_key(pat) + + # walk the nested dict + root, k = _get_root(key) + return root[k] + + +def set_option(*args) -> None: + """ + Set the value of the specified option or options. + + This method allows fine-grained control over the behavior and display settings + of pandas. Options affect various functionalities such as output formatting, + display limits, and operational behavior. Settings can be modified at runtime + without requiring changes to global configurations or environment variables. + + Parameters + ---------- + *args : str | object | dict + Arguments provided in pairs, which will be interpreted as (pattern, value), + or as a single dictionary containing multiple option-value pairs. + pattern: str + Regexp which should match a single option + value: object + New value of option + + .. warning:: + + Partial pattern matches are supported for convenience, but unless you + use the full option name (e.g. x.y.z.option_name), your code may break in + future versions if new options with similar names are introduced. + + Returns + ------- + None + No return value. + + Raises + ------ + ValueError if odd numbers of non-keyword arguments are provided + TypeError if keyword arguments are provided + OptionError if no such option exists + + See Also + -------- + get_option : Retrieve the value of the specified option. + reset_option : Reset one or more options to their default value. + describe_option : Print the description for one or more registered options. + option_context : Context manager to temporarily set options in a ``with`` + statement. + + Notes + ----- + For all available options, please view the :ref:`User Guide ` + or use ``pandas.describe_option()``. + + Examples + -------- + Option-Value Pair Input: + + >>> pd.set_option("display.max_columns", 4) + >>> df = pd.DataFrame([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + >>> df + 0 1 ... 3 4 + 0 1 2 ... 4 5 + 1 6 7 ... 9 10 + [2 rows x 5 columns] + >>> pd.reset_option("display.max_columns") + + Dictionary Input: + + >>> pd.set_option({"display.max_columns": 4, "display.precision": 1}) + >>> df = pd.DataFrame([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + >>> df + 0 1 ... 3 4 + 0 1 2 ... 4 5 + 1 6 7 ... 9 10 + [2 rows x 5 columns] + >>> pd.reset_option("display.max_columns") + >>> pd.reset_option("display.precision") + """ + # Handle dictionary input + if len(args) == 1 and isinstance(args[0], dict): + args = tuple(kv for item in args[0].items() for kv in item) + + nargs = len(args) + if not nargs or nargs % 2 != 0: + raise ValueError("Must provide an even number of non-keyword arguments") + + for k, v in zip(args[::2], args[1::2], strict=True): + key = _get_single_key(k) + + opt = _get_registered_option(key) + if opt and opt.validator: + opt.validator(v) + + # walk the nested dict + root, k_root = _get_root(key) + root[k_root] = v + + if opt.cb: + opt.cb(key) + + +def describe_option(pat: str = "", _print_desc: bool = True) -> str | None: + """ + Print the description for one or more registered options. + + Call with no arguments to get a listing for all registered options. + + Parameters + ---------- + pat : str, default "" + String or string regexp pattern. + Empty string will return all options. + For regexp strings, all matching keys will have their description displayed. + _print_desc : bool, default True + If True (default) the description(s) will be printed to stdout. + Otherwise, the description(s) will be returned as a string + (for testing). + + Returns + ------- + None + If ``_print_desc=True``. + str + If the description(s) as a string if ``_print_desc=False``. + + See Also + -------- + get_option : Retrieve the value of the specified option. + set_option : Set the value of the specified option or options. + reset_option : Reset one or more options to their default value. + + Notes + ----- + For all available options, please view the + :ref:`User Guide `. + + Examples + -------- + >>> pd.describe_option("display.max_columns") # doctest: +SKIP + display.max_columns : int + If max_cols is exceeded, switch to truncate view... + """ + keys = _select_options(pat) + if len(keys) == 0: + raise OptionError(f"No such keys(s) for {pat=}") + + s = "\n".join([_build_option_description(k) for k in keys]) + + if _print_desc: + print(s) + return None + return s + + +def reset_option(pat: str) -> None: + """ + Reset one or more options to their default value. + + This method resets the specified pandas option(s) back to their default + values. It allows partial string matching for convenience, but users should + exercise caution to avoid unintended resets due to changes in option names + in future versions. + + Parameters + ---------- + pat : str/regex + If specified only options matching ``pat*`` will be reset. + Pass ``"all"`` as argument to reset all options. + + .. warning:: + + Partial matches are supported for convenience, but unless you + use the full option name (e.g. x.y.z.option_name), your code may break + in future versions if new options with similar names are introduced. + + Returns + ------- + None + No return value. + + See Also + -------- + get_option : Retrieve the value of the specified option. + set_option : Set the value of the specified option or options. + describe_option : Print the description for one or more registered options. + + Notes + ----- + For all available options, please view the + :ref:`User Guide `. + + Examples + -------- + >>> pd.reset_option("display.max_columns") # doctest: +SKIP + """ + keys = _select_options(pat) + + if len(keys) == 0: + raise OptionError(f"No such keys(s) for {pat=}") + + if len(keys) > 1 and len(pat) < 4 and pat != "all": + raise ValueError( + "You must specify at least 4 characters when " + "resetting multiple keys, use the special keyword " + '"all" to reset all the options to their default value' + ) + + for k in keys: + set_option(k, _registered_options[k].defval) + + +def get_default_val(pat: str): + key = _get_single_key(pat) + return _get_registered_option(key).defval + + +class DictWrapper: + """provide attribute-style access to a nested dict""" + + d: dict[str, Any] + + def __init__(self, d: dict[str, Any], prefix: str = "") -> None: + object.__setattr__(self, "d", d) + object.__setattr__(self, "prefix", prefix) + + def __setattr__(self, key: str, val: Any) -> None: + prefix = object.__getattribute__(self, "prefix") + if prefix: + prefix += "." + prefix += key + # you can't set new keys + # can you can't overwrite subtrees + if key in self.d and not isinstance(self.d[key], dict): + set_option(prefix, val) + else: + raise OptionError("You can only set the value of existing options") + + def __getattr__(self, key: str): + prefix = object.__getattribute__(self, "prefix") + if prefix: + prefix += "." + prefix += key + try: + v = object.__getattribute__(self, "d")[key] + except KeyError as err: + raise OptionError("No such option") from err + if isinstance(v, dict): + return DictWrapper(v, prefix) + else: + return get_option(prefix) + + def __dir__(self) -> list[str]: + return list(self.d.keys()) + + +options = DictWrapper(_global_config) +# DictWrapper defines a custom setattr +object.__setattr__(options, "__module__", "pandas") + +# +# Functions for use by pandas developers, in addition to User - api + + +@contextmanager +def option_context(*args) -> Generator[None]: + """ + Context manager to temporarily set options in a ``with`` statement. + + This method allows users to set one or more pandas options temporarily + within a controlled block. The previous options' values are restored + once the block is exited. This is useful when making temporary adjustments + to pandas' behavior without affecting the global state. + + Parameters + ---------- + *args : str | object | dict + An even amount of arguments provided in pairs which will be + interpreted as (pattern, value) pairs. Alternatively, a single + dictionary of {pattern: value} may be provided. + + Returns + ------- + None + No return value. + + Yields + ------ + None + No yield value. + + See Also + -------- + get_option : Retrieve the value of the specified option. + set_option : Set the value of the specified option. + reset_option : Reset one or more options to their default value. + describe_option : Print the description for one or more registered options. + + Notes + ----- + For all available options, please view the :ref:`User Guide ` + or use ``pandas.describe_option()``. + + Examples + -------- + >>> from pandas import option_context + >>> with option_context("display.max_rows", 10, "display.max_columns", 5): + ... pass + >>> with option_context({"display.max_rows": 10, "display.max_columns": 5}): + ... pass + """ + if len(args) == 1 and isinstance(args[0], dict): + args = tuple(kv for item in args[0].items() for kv in item) + + if len(args) % 2 != 0 or len(args) < 2: + raise ValueError( + "Provide an even amount of arguments as " + "option_context(pat, val, pat, val...)." + ) + + ops = tuple(zip(args[::2], args[1::2], strict=True)) + undo: tuple[tuple[Any, Any], ...] = () + try: + undo = tuple((pat, get_option(pat)) for pat, val in ops) + for pat, val in ops: + set_option(pat, val) + yield + finally: + for pat, val in undo: + set_option(pat, val) + + +def register_option( + key: str, + defval: object, + doc: str = "", + validator: Callable[[object], Any] | None = None, + cb: Callable[[str], Any] | None = None, +) -> None: + """ + Register an option in the package-wide pandas config object + + Parameters + ---------- + key : str + Fully-qualified key, e.g. "x.y.option - z". + defval : object + Default value of the option. + doc : str + Description of the option. + validator : Callable, optional + Function of a single argument, should raise `ValueError` if + called with a value which is not a legal value for the option. + cb + a function of a single argument "key", which is called + immediately after an option value is set/reset. key is + the full name of the option. + + Raises + ------ + ValueError if `validator` is specified and `defval` is not a valid value. + + """ + import keyword + import tokenize + + key = key.lower() + + if key in _registered_options: + raise OptionError(f"Option '{key}' has already been registered") + if key in _reserved_keys: + raise OptionError(f"Option '{key}' is a reserved key") + + # the default value should be legal + if validator: + validator(defval) + + # walk the nested dict, creating dicts as needed along the path + path = key.split(".") + + for k in path: + if not re.match("^" + tokenize.Name + "$", k): + raise ValueError(f"{k} is not a valid identifier") + if keyword.iskeyword(k): + raise ValueError(f"{k} is a python keyword") + + cursor = _global_config + msg = "Path prefix to option '{option}' is already an option" + + for i, p in enumerate(path[:-1]): + if not isinstance(cursor, dict): + raise OptionError(msg.format(option=".".join(path[:i]))) + if p not in cursor: + cursor[p] = {} + cursor = cursor[p] + + if not isinstance(cursor, dict): + raise OptionError(msg.format(option=".".join(path[:-1]))) + + cursor[path[-1]] = defval # initialize + + # save the option metadata + _registered_options[key] = RegisteredOption( + key=key, defval=defval, doc=doc, validator=validator, cb=cb + ) + + +def deprecate_option( + key: str, + category: type[Warning], + msg: str | None = None, + rkey: str | None = None, + removal_ver: str | None = None, +) -> None: + """ + Mark option `key` as deprecated, if code attempts to access this option, + a warning will be produced, using `msg` if given, or a default message + if not. + if `rkey` is given, any access to the key will be re-routed to `rkey`. + + Neither the existence of `key` nor that if `rkey` is checked. If they + do not exist, any subsequence access will fail as usual, after the + deprecation warning is given. + + Parameters + ---------- + key : str + Name of the option to be deprecated. + must be a fully-qualified option name (e.g "x.y.z.rkey"). + category : Warning + Warning class for the deprecation. + msg : str, optional + Warning message to output when the key is referenced. + if no message is given a default message will be emitted. + rkey : str, optional + Name of an option to reroute access to. + If specified, any referenced `key` will be + re-routed to `rkey` including set/get/reset. + rkey must be a fully-qualified option name (e.g "x.y.z.rkey"). + used by the default message if no `msg` is specified. + removal_ver : str, optional + Specifies the version in which this option will + be removed. used by the default message if no `msg` is specified. + + Raises + ------ + OptionError + If the specified key has already been deprecated. + """ + key = key.lower() + + if key in _deprecated_options: + raise OptionError(f"Option '{key}' has already been defined as deprecated.") + + _deprecated_options[key] = DeprecatedOption(key, category, msg, rkey, removal_ver) + + +# +# functions internal to the module + + +def _select_options(pat: str) -> list[str]: + """ + returns a list of keys matching `pat` + + if pat=="all", returns all registered options + """ + # short-circuit for exact key + if pat in _registered_options: + return [pat] + + # else look through all of them + keys = sorted(_registered_options.keys()) + if pat == "all": # reserved key + return keys + + return [k for k in keys if re.search(pat, k, re.I)] + + +def _get_root(key: str) -> tuple[dict[str, Any], str]: + path = key.split(".") + cursor = _global_config + for p in path[:-1]: + cursor = cursor[p] + return cursor, path[-1] + + +def _get_deprecated_option(key: str): + """ + Retrieves the metadata for a deprecated option, if `key` is deprecated. + + Returns + ------- + DeprecatedOption (namedtuple) if key is deprecated, None otherwise + """ + try: + d = _deprecated_options[key] + except KeyError: + return None + else: + return d + + +def _get_registered_option(key: str): + """ + Retrieves the option metadata if `key` is a registered option. + + Returns + ------- + RegisteredOption (namedtuple) if key is deprecated, None otherwise + """ + return _registered_options.get(key) + + +def _translate_key(key: str) -> str: + """ + if `key` is deprecated and a replacement key defined, will return the + replacement key, otherwise returns `key` as-is + """ + d = _get_deprecated_option(key) + if d: + return d.rkey or key + else: + return key + + +def _warn_if_deprecated(key: str) -> bool: + """ + Checks if `key` is a deprecated option and if so, prints a warning. + + Returns + ------- + bool - True if `key` is deprecated, False otherwise. + """ + d = _get_deprecated_option(key) + if d: + if d.msg: + warnings.warn( + d.msg, + d.category, + stacklevel=find_stack_level(), + ) + else: + msg = f"'{key}' is deprecated" + if d.removal_ver: + msg += f" and will be removed in {d.removal_ver}" + if d.rkey: + msg += f", please use '{d.rkey}' instead." + else: + msg += ", please refrain from using it." + + warnings.warn( + msg, + d.category, + stacklevel=find_stack_level(), + ) + return True + return False + + +def _build_option_description(k: str) -> str: + """Builds a formatted description of a registered option and prints it""" + o = _get_registered_option(k) + d = _get_deprecated_option(k) + + s = f"{k} " + + if o.doc: + s += "\n".join(o.doc.strip().split("\n")) + else: + s += "No description available." + + if o: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + warnings.simplefilter("ignore", DeprecationWarning) + s += f"\n [default: {o.defval}] [currently: {get_option(k)}]" + + if d: + rkey = d.rkey or "" + s += "\n (Deprecated" + s += f", use `{rkey}` instead." + s += ")" + + return s + + +# helpers + + +@contextmanager +def config_prefix(prefix: str) -> Generator[None]: + """ + contextmanager for multiple invocations of API with a common prefix + + supported API functions: (register / get / set )__option + + Warning: This is not thread - safe, and won't work properly if you import + the API functions into your module using the "from x import y" construct. + + Example + ------- + import pandas._config.config as cf + with cf.config_prefix("display.font"): + cf.register_option("color", "red") + cf.register_option("size", " 5 pt") + cf.set_option(size, " 6 pt") + cf.get_option(size) + ... + + etc' + + will register options "display.font.color", "display.font.size", set the + value of "display.font.size"... and so on. + """ + # Note: reset_option relies on set_option, and on key directly + # it does not fit in to this monkey-patching scheme + + global register_option, get_option, set_option + + def wrap(func: F) -> F: + def inner(key: str, *args, **kwds): + pkey = f"{prefix}.{key}" + return func(pkey, *args, **kwds) + + return cast(F, inner) + + _register_option = register_option + _get_option = get_option + _set_option = set_option + set_option = wrap(set_option) + get_option = wrap(get_option) + register_option = wrap(register_option) + try: + yield + finally: + set_option = _set_option + get_option = _get_option + register_option = _register_option + + +# These factories and methods are handy for use as the validator +# arg in register_option + + +def is_type_factory(_type: type[Any]) -> Callable[[Any], None]: + """ + + Parameters + ---------- + `_type` - a type to be compared against (e.g. type(x) == `_type`) + + Returns + ------- + validator - a function of a single argument x , which raises + ValueError if type(x) is not equal to `_type` + + """ + + def inner(x) -> None: + if type(x) != _type: + raise ValueError(f"Value must have type '{_type}'") + + return inner + + +def is_instance_factory(_type: type | tuple[type, ...]) -> Callable[[Any], None]: + """ + + Parameters + ---------- + `_type` - the type to be checked against + + Returns + ------- + validator - a function of a single argument x , which raises + ValueError if x is not an instance of `_type` + + """ + if isinstance(_type, tuple): + type_repr = "|".join(map(str, _type)) + else: + type_repr = f"'{_type}'" + + def inner(x) -> None: + if not isinstance(x, _type): + raise ValueError(f"Value must be an instance of {type_repr}") + + return inner + + +def is_one_of_factory(legal_values: Sequence) -> Callable[[Any], None]: + callables = [c for c in legal_values if callable(c)] + legal_values = [c for c in legal_values if not callable(c)] + + def inner(x) -> None: + if x not in legal_values: + if not any(c(x) for c in callables): + uvals = [str(lval) for lval in legal_values] + pp_values = "|".join(uvals) + msg = f"Value must be one of {pp_values}" + if len(callables): + msg += " or a callable" + raise ValueError(msg) + + return inner + + +def is_nonnegative_int(value: object) -> None: + """ + Verify that value is None or a positive int. + + Parameters + ---------- + value : None or int + The `value` to be checked. + + Raises + ------ + ValueError + When the value is not None or is a negative integer + """ + if value is None: + return + + elif isinstance(value, int): + if value >= 0: + return + + msg = "Value must be a nonnegative integer or None" + raise ValueError(msg) + + +# common type validators, for convenience +# usage: register_option(... , validator = is_int) +is_int = is_type_factory(int) +is_bool = is_type_factory(bool) +is_float = is_type_factory(float) +is_str = is_type_factory(str) +is_text = is_instance_factory((str, bytes)) + + +def is_callable(obj: object) -> bool: + """ + + Parameters + ---------- + `obj` - the object to be checked + + Returns + ------- + validator - returns True if object is callable + raises ValueError otherwise. + + """ + if not callable(obj): + raise ValueError("Value must be a callable") + return True + + +# import set_module here would cause circular import +get_option.__module__ = "pandas" +set_option.__module__ = "pandas" +describe_option.__module__ = "pandas" +reset_option.__module__ = "pandas" +option_context.__module__ = "pandas" diff --git a/pandas/_config/dates.py b/pandas/_config/dates.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9f5d390dc9c82d66c2f02c20f8d569085db177 --- /dev/null +++ b/pandas/_config/dates.py @@ -0,0 +1,26 @@ +""" +config for datetime formatting +""" + +from __future__ import annotations + +from pandas._config import config as cf + +pc_date_dayfirst_doc = """ +: boolean + When True, prints and parses dates with the day first, eg 20/01/2005 +""" + +pc_date_yearfirst_doc = """ +: boolean + When True, prints and parses dates with the year first, eg 2005/01/20 +""" + +with cf.config_prefix("display"): + # Needed upstream of `_libs` because these are used in tslibs.parsing + cf.register_option( + "date_dayfirst", False, pc_date_dayfirst_doc, validator=cf.is_bool + ) + cf.register_option( + "date_yearfirst", False, pc_date_yearfirst_doc, validator=cf.is_bool + ) diff --git a/pandas/_config/display.py b/pandas/_config/display.py new file mode 100644 index 0000000000000000000000000000000000000000..df2c3ad36c855d77c33d80c78c3d83ab3c09d5f9 --- /dev/null +++ b/pandas/_config/display.py @@ -0,0 +1,62 @@ +""" +Unopinionated display configuration. +""" + +from __future__ import annotations + +import locale +import sys + +from pandas._config import config as cf + +# ----------------------------------------------------------------------------- +# Global formatting options +_initial_defencoding: str | None = None + + +def detect_console_encoding() -> str: + """ + Try to find the most capable encoding supported by the console. + slightly modified from the way IPython handles the same issue. + """ + global _initial_defencoding + + encoding = None + try: + encoding = sys.stdout.encoding or sys.stdin.encoding + except (AttributeError, OSError): + pass + + # try again for something better + if not encoding or "ascii" in encoding.lower(): + try: + encoding = locale.getpreferredencoding() + except locale.Error: + # can be raised by locale.setlocale(), which is + # called by getpreferredencoding + # (on some systems, see stdlib locale docs) + pass + + # when all else fails. this will usually be "ascii" + if not encoding or "ascii" in encoding.lower(): + encoding = sys.getdefaultencoding() + + # GH#3360, save the reported defencoding at import time + # MPL backends may change it. Make available for debugging. + if not _initial_defencoding: + _initial_defencoding = sys.getdefaultencoding() + + return encoding + + +pc_encoding_doc = """ +: str/unicode + Defaults to the detected encoding of the console. + Specifies the encoding to be used for strings returned by to_string, + these are generally strings meant to be displayed on the console. +""" + +with cf.config_prefix("display"): + cf.register_option( + "encoding", detect_console_encoding(), pc_encoding_doc, validator=cf.is_text + ) diff --git a/pandas/_config/localization.py b/pandas/_config/localization.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2af78f68eb21201476bcc42afaba419f8e1a8a --- /dev/null +++ b/pandas/_config/localization.py @@ -0,0 +1,176 @@ +""" +Helpers for configuring locale settings. + +Name `localization` is chosen to avoid overlap with builtin `locale` module. +""" + +from __future__ import annotations + +from contextlib import contextmanager +import locale +import platform +import re +import subprocess +from typing import ( + TYPE_CHECKING, + cast, +) + +from pandas._config.config import options + +if TYPE_CHECKING: + from collections.abc import Generator + + +@contextmanager +def set_locale( + new_locale: str | tuple[str, str], lc_var: int = locale.LC_ALL +) -> Generator[str | tuple[str, str]]: + """ + Context manager for temporarily setting a locale. + + Parameters + ---------- + new_locale : str or tuple + A string of the form .. For example to set + the current locale to US English with a UTF8 encoding, you would pass + "en_US.UTF-8". + lc_var : int, default `locale.LC_ALL` + The category of the locale being set. + + Notes + ----- + This is useful when you want to run a particular block of code under a + particular locale, without globally setting the locale. This probably isn't + thread-safe. + """ + # getlocale is not always compliant with setlocale, use setlocale. GH#46595 + current_locale = locale.setlocale(lc_var) + + try: + locale.setlocale(lc_var, new_locale) + normalized_code, normalized_encoding = locale.getlocale() + if normalized_code is not None and normalized_encoding is not None: + yield f"{normalized_code}.{normalized_encoding}" + else: + yield new_locale + finally: + locale.setlocale(lc_var, current_locale) + + +def can_set_locale(lc: str, lc_var: int = locale.LC_ALL) -> bool: + """ + Check to see if we can set a locale, and subsequently get the locale, + without raising an Exception. + + Parameters + ---------- + lc : str + The locale to attempt to set. + lc_var : int, default `locale.LC_ALL` + The category of the locale being set. + + Returns + ------- + bool + Whether the passed locale can be set + """ + try: + with set_locale(lc, lc_var=lc_var): + pass + except (ValueError, locale.Error): + # horrible name for an Exception subclass + return False + else: + return True + + +def _valid_locales(locales: list[str] | str, normalize: bool) -> list[str]: + """ + Return a list of normalized locales that do not throw an ``Exception`` + when set. + + Parameters + ---------- + locales : str + A string where each locale is separated by a newline. + normalize : bool + Whether to call ``locale.normalize`` on each locale. + + Returns + ------- + valid_locales : list + A list of valid locales. + """ + return [ + loc + for loc in ( + locale.normalize(loc.strip()) if normalize else loc.strip() + for loc in locales + ) + if can_set_locale(loc) + ] + + +def get_locales( + prefix: str | None = None, + normalize: bool = True, +) -> list[str]: + """ + Get all the locales that are available on the system. + + Parameters + ---------- + prefix : str + If not ``None`` then return only those locales with the prefix + provided. For example to get all English language locales (those that + start with ``"en"``), pass ``prefix="en"``. + normalize : bool + Call ``locale.normalize`` on the resulting list of available locales. + If ``True``, only locales that can be set without throwing an + ``Exception`` are returned. + + Returns + ------- + locales : list of strings + A list of locale strings that can be set with ``locale.setlocale()``. + For example:: + + locale.setlocale(locale.LC_ALL, locale_string) + + On error will return an empty list (no locale available, e.g. Windows) + + """ + if platform.system() in ("Linux", "Darwin"): + raw_locales = subprocess.check_output(["locale", "-a"]) + else: + # Other platforms e.g. windows platforms don't define "locale -a" + # Note: is_platform_windows causes circular import here + return [] + + try: + # raw_locales is "\n" separated list of locales + # it may contain non-decodable parts, so split + # extract what we can and then rejoin. + split_raw_locales = raw_locales.split(b"\n") + out_locales = [] + for x in split_raw_locales: + try: + out_locales.append(str(x, encoding=cast(str, options.display.encoding))) + except UnicodeError: + # 'locale -a' is used to populated 'raw_locales' and on + # Redhat 7 Linux (and maybe others) prints locale names + # using windows-1252 encoding. Bug only triggered by + # a few special characters and when there is an + # extensive list of installed locales. + out_locales.append(str(x, encoding="windows-1252")) + + except TypeError: + pass + + if prefix is None: + return _valid_locales(out_locales, normalize) + + pattern = re.compile(f"{prefix}.*") + found = pattern.findall("\n".join(out_locales)) + return _valid_locales(found, normalize) diff --git a/pandas/_libs/__init__.py b/pandas/_libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d499f9a6cd75e53dc63a4a83c073449025caac94 --- /dev/null +++ b/pandas/_libs/__init__.py @@ -0,0 +1,27 @@ +__all__ = [ + "Interval", + "NaT", + "NaTType", + "OutOfBoundsDatetime", + "Period", + "Timedelta", + "Timestamp", + "iNaT", +] + + +# Below imports needs to happen first to ensure pandas top level +# module gets monkeypatched with the pandas_datetime_CAPI +# see pandas_datetime_exec in pd_datetime.c +import pandas._libs.pandas_parser # isort: skip # type: ignore[reportUnusedImport] +import pandas._libs.pandas_datetime # noqa: F401 # isort: skip # type: ignore[reportUnusedImport] +from pandas._libs.interval import Interval +from pandas._libs.tslibs import ( + NaT, + NaTType, + OutOfBoundsDatetime, + Period, + Timedelta, + Timestamp, + iNaT, +) diff --git a/pandas/_libs/algos.pyi b/pandas/_libs/algos.pyi new file mode 100644 index 0000000000000000000000000000000000000000..0a6be851e1efd0389eab3398462d3fa0e7e5946f --- /dev/null +++ b/pandas/_libs/algos.pyi @@ -0,0 +1,443 @@ +from typing import Any + +import numpy as np + +from pandas._typing import npt + +class Infinity: + def __eq__(self, other) -> bool: ... + def __ne__(self, other) -> bool: ... + def __lt__(self, other) -> bool: ... + def __le__(self, other) -> bool: ... + def __gt__(self, other) -> bool: ... + def __ge__(self, other) -> bool: ... + +class NegInfinity: + def __eq__(self, other) -> bool: ... + def __ne__(self, other) -> bool: ... + def __lt__(self, other) -> bool: ... + def __le__(self, other) -> bool: ... + def __gt__(self, other) -> bool: ... + def __ge__(self, other) -> bool: ... + +def unique_deltas( + arr: np.ndarray, # const int64_t[:] +) -> np.ndarray: ... # np.ndarray[np.int64, ndim=1] +def is_lexsorted(list_of_arrays: list[npt.NDArray[np.int64]]) -> bool: ... +def groupsort_indexer( + index: np.ndarray, # const int64_t[:] + ngroups: int, +) -> tuple[ + np.ndarray, # ndarray[int64_t, ndim=1] + np.ndarray, # ndarray[int64_t, ndim=1] +]: ... +def kth_smallest( + arr: np.ndarray, # numeric[:] + k: int, +) -> Any: ... # numeric + +# ---------------------------------------------------------------------- +# Pairwise correlation/covariance + +def nancorr( + mat: npt.NDArray[np.float64], # const float64_t[:, :] + cov: bool = ..., + minp: int | None = ..., +) -> npt.NDArray[np.float64]: ... # ndarray[float64_t, ndim=2] +def nancorr_spearman( + mat: npt.NDArray[np.float64], # ndarray[float64_t, ndim=2] + minp: int = ..., +) -> npt.NDArray[np.float64]: ... # ndarray[float64_t, ndim=2] + +# ---------------------------------------------------------------------- + +def validate_limit(nobs: int | None, limit=...) -> int: ... +def get_fill_indexer( + mask: npt.NDArray[np.bool_], + limit: int | None = None, +) -> npt.NDArray[np.intp]: ... +def pad( + old: np.ndarray, # ndarray[numeric_object_t] + new: np.ndarray, # ndarray[numeric_object_t] + limit=..., +) -> npt.NDArray[np.intp]: ... # np.ndarray[np.intp, ndim=1] +def pad_inplace( + values: np.ndarray, # numeric_object_t[:] + mask: np.ndarray, # uint8_t[:] + limit=..., +) -> None: ... +def pad_2d_inplace( + values: np.ndarray, # numeric_object_t[:, :] + mask: np.ndarray, # const uint8_t[:, :] + limit=..., +) -> None: ... +def backfill( + old: np.ndarray, # ndarray[numeric_object_t] + new: np.ndarray, # ndarray[numeric_object_t] + limit=..., +) -> npt.NDArray[np.intp]: ... # np.ndarray[np.intp, ndim=1] +def backfill_inplace( + values: np.ndarray, # numeric_object_t[:] + mask: np.ndarray, # uint8_t[:] + limit=..., +) -> None: ... +def backfill_2d_inplace( + values: np.ndarray, # numeric_object_t[:, :] + mask: np.ndarray, # const uint8_t[:, :] + limit=..., +) -> None: ... +def is_monotonic( + arr: np.ndarray, # ndarray[numeric_object_t, ndim=1] + timelike: bool, +) -> tuple[bool, bool, bool]: ... + +# ---------------------------------------------------------------------- +# rank_1d, rank_2d +# ---------------------------------------------------------------------- + +def rank_1d( + values: np.ndarray, # ndarray[numeric_object_t, ndim=1] + labels: np.ndarray | None = ..., # const int64_t[:]=None + is_datetimelike: bool = ..., + ties_method=..., + ascending: bool = ..., + pct: bool = ..., + na_option=..., + mask: npt.NDArray[np.bool_] | None = ..., +) -> np.ndarray: ... # np.ndarray[float64_t, ndim=1] +def rank_2d( + in_arr: np.ndarray, # ndarray[numeric_object_t, ndim=2] + axis: int = ..., + is_datetimelike: bool = ..., + ties_method=..., + ascending: bool = ..., + na_option=..., + pct: bool = ..., +) -> np.ndarray: ... # np.ndarray[float64_t, ndim=1] +def diff_2d( + arr: np.ndarray, # ndarray[diff_t, ndim=2] + out: np.ndarray, # ndarray[out_t, ndim=2] + periods: int, + axis: int, + datetimelike: bool = ..., +) -> None: ... +def ensure_platform_int(arr: object) -> npt.NDArray[np.intp]: ... +def ensure_object(arr: object) -> npt.NDArray[np.object_]: ... +def ensure_float64(arr: object) -> npt.NDArray[np.float64]: ... +def ensure_int8(arr: object) -> npt.NDArray[np.int8]: ... +def ensure_int16(arr: object) -> npt.NDArray[np.int16]: ... +def ensure_int32(arr: object) -> npt.NDArray[np.int32]: ... +def ensure_int64(arr: object) -> npt.NDArray[np.int64]: ... +def ensure_uint64(arr: object) -> npt.NDArray[np.uint64]: ... +def take_1d_int8_int8( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int8_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int8_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int8_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int16_int16( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int16_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int16_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int16_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int32_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int32_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int32_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int64_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_uint16_uint16( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_uint32_uint32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_uint64_uint64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_int64_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_float32_float32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_float32_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_float64_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_object_object( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_bool_bool( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_1d_bool_object( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int8_int8( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int8_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int8_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int8_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int16_int16( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int16_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int16_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int16_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int32_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int32_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int32_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int64_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_int64_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_uint16_uint16( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_uint32_uint32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_uint64_uint64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_float32_float32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_float32_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_float64_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_object_object( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_bool_bool( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis0_bool_object( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int8_int8( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int8_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int8_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int8_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int16_int16( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int16_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int16_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int16_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int32_int32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int32_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int32_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int64_int64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_uint16_uint16( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_uint32_uint32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_uint64_uint64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_int64_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_float32_float32( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_float32_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_float64_float64( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_object_object( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_bool_bool( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_axis1_bool_object( + values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=... +) -> None: ... +def take_2d_multi_int8_int8( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int8_int32( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int8_int64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int8_float64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int16_int16( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int16_int32( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int16_int64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int16_float64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int32_int32( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int32_int64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int32_float64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int64_float64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_float32_float32( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_float32_float64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_float64_float64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_object_object( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_bool_bool( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_bool_object( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... +def take_2d_multi_int64_int64( + values: np.ndarray, + indexer: tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]], + out: np.ndarray, + fill_value=..., +) -> None: ... diff --git a/pandas/_libs/arrays.pyi b/pandas/_libs/arrays.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7b373240952ca308391619c25193cc81627098b0 --- /dev/null +++ b/pandas/_libs/arrays.pyi @@ -0,0 +1,40 @@ +from collections.abc import Sequence +from typing import Self + +import numpy as np + +from pandas._typing import ( + AxisInt, + DtypeObj, + Shape, +) + +class NDArrayBacked: + _dtype: DtypeObj + _ndarray: np.ndarray + def __init__(self, values: np.ndarray, dtype: DtypeObj) -> None: ... + @classmethod + def _simple_new(cls, values: np.ndarray, dtype: DtypeObj) -> Self: ... + def _from_backing_data(self, values: np.ndarray) -> Self: ... + def __setstate__(self, state) -> None: ... + def __len__(self) -> int: ... + @property + def shape(self) -> Shape: ... + @property + def ndim(self) -> int: ... + @property + def size(self) -> int: ... + @property + def nbytes(self) -> int: ... + def copy(self, order=...) -> Self: ... + def delete(self, loc, axis=...) -> Self: ... + def swapaxes(self, axis1, axis2) -> Self: ... + def repeat(self, repeats: int | Sequence[int], axis: int | None = ...) -> Self: ... + def reshape(self, *args, **kwargs) -> Self: ... + def ravel(self, order=...) -> Self: ... + @property + def T(self) -> Self: ... + @classmethod + def _concat_same_type( + cls, to_concat: Sequence[Self], axis: AxisInt = ... + ) -> Self: ... diff --git a/pandas/_libs/byteswap.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/byteswap.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..b81736073b03355e3bb7d0c8af27c9bb0d8c9201 Binary files /dev/null and b/pandas/_libs/byteswap.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/byteswap.pyi b/pandas/_libs/byteswap.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bb0dbfc6a50b1bb7cd509dc5b3dfeed55ad70b09 --- /dev/null +++ b/pandas/_libs/byteswap.pyi @@ -0,0 +1,5 @@ +def read_float_with_byteswap(data: bytes, offset: int, byteswap: bool) -> float: ... +def read_double_with_byteswap(data: bytes, offset: int, byteswap: bool) -> float: ... +def read_uint16_with_byteswap(data: bytes, offset: int, byteswap: bool) -> int: ... +def read_uint32_with_byteswap(data: bytes, offset: int, byteswap: bool) -> int: ... +def read_uint64_with_byteswap(data: bytes, offset: int, byteswap: bool) -> int: ... diff --git a/pandas/_libs/groupby.pyi b/pandas/_libs/groupby.pyi new file mode 100644 index 0000000000000000000000000000000000000000..803c2cb0b0d19f53863a7e5fb8d976431f842f4a --- /dev/null +++ b/pandas/_libs/groupby.pyi @@ -0,0 +1,234 @@ +from typing import Literal + +import numpy as np + +from pandas._typing import npt + +def group_median_float64( + out: np.ndarray, # ndarray[float64_t, ndim=2] + counts: npt.NDArray[np.int64], + values: np.ndarray, # ndarray[float64_t, ndim=2] + labels: npt.NDArray[np.int64], + min_count: int = ..., # Py_ssize_t + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + is_datetimelike: bool = ..., # bint + skipna: bool = ..., +) -> None: ... +def group_cumprod( + out: np.ndarray, # float64_t[:, ::1] + values: np.ndarray, # const float64_t[:, :] + labels: np.ndarray, # const int64_t[:] + ngroups: int, + is_datetimelike: bool, + skipna: bool = ..., + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., +) -> None: ... +def group_cumsum( + out: np.ndarray, # int64float_t[:, ::1] + values: np.ndarray, # ndarray[int64float_t, ndim=2] + labels: np.ndarray, # const int64_t[:] + ngroups: int, + is_datetimelike: bool, + skipna: bool = ..., + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., +) -> None: ... +def group_shift_indexer( + out: np.ndarray, # int64_t[::1] + labels: np.ndarray, # const int64_t[:] + ngroups: int, + periods: int, +) -> None: ... +def group_fillna_indexer( + out: np.ndarray, # ndarray[intp_t] + labels: np.ndarray, # ndarray[int64_t] + mask: npt.NDArray[np.uint8], + limit: int, # int64_t + compute_ffill: bool, + ngroups: int, +) -> None: ... +def group_any_all( + out: np.ndarray, # uint8_t[::1] + values: np.ndarray, # const uint8_t[::1] + labels: np.ndarray, # const int64_t[:] + mask: np.ndarray, # const uint8_t[::1] + val_test: Literal["any", "all"], + skipna: bool, + result_mask: np.ndarray | None, +) -> None: ... +def group_sum( + out: np.ndarray, # complexfloatingintuint_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[complexfloatingintuint_t, ndim=2] + labels: np.ndarray, # const intp_t[:] + mask: np.ndarray | None, + result_mask: np.ndarray | None = ..., + min_count: int = ..., + is_datetimelike: bool = ..., + initial: object = ..., + skipna: bool = ..., +) -> None: ... +def group_prod( + out: np.ndarray, # int64float_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[int64float_t, ndim=2] + labels: np.ndarray, # const intp_t[:] + mask: np.ndarray | None, + result_mask: np.ndarray | None = ..., + min_count: int = ..., + skipna: bool = ..., +) -> None: ... +def group_var( + out: np.ndarray, # floating[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[floating, ndim=2] + labels: np.ndarray, # const intp_t[:] + min_count: int = ..., # Py_ssize_t + ddof: int = ..., # int64_t + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + is_datetimelike: bool = ..., + name: str = ..., + skipna: bool = ..., +) -> None: ... +def group_skew( + out: np.ndarray, # float64_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[float64_T, ndim=2] + labels: np.ndarray, # const intp_t[::1] + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + skipna: bool = ..., +) -> None: ... +def group_kurt( + out: np.ndarray, # float64_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[float64_T, ndim=2] + labels: np.ndarray, # const intp_t[::1] + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + skipna: bool = ..., +) -> None: ... +def group_mean( + out: np.ndarray, # floating[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[floating, ndim=2] + labels: np.ndarray, # const intp_t[:] + min_count: int = ..., # Py_ssize_t + is_datetimelike: bool = ..., # bint + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + skipna: bool = ..., +) -> None: ... +def group_ohlc( + out: np.ndarray, # floatingintuint_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[floatingintuint_t, ndim=2] + labels: np.ndarray, # const intp_t[:] + min_count: int = ..., + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., +) -> None: ... +def group_quantile( + out: npt.NDArray[np.float64], + values: np.ndarray, # ndarray[numeric, ndim=1] + labels: npt.NDArray[np.intp], + mask: npt.NDArray[np.uint8], + qs: npt.NDArray[np.float64], # const + starts: npt.NDArray[np.int64], + ends: npt.NDArray[np.int64], + interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"], + result_mask: np.ndarray | None, + is_datetimelike: bool, +) -> None: ... +def group_last( + out: np.ndarray, # rank_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[rank_t, ndim=2] + labels: np.ndarray, # const int64_t[:] + mask: npt.NDArray[np.bool_] | None, + result_mask: npt.NDArray[np.bool_] | None = ..., + min_count: int = ..., # Py_ssize_t + is_datetimelike: bool = ..., + skipna: bool = ..., +) -> None: ... +def group_nth( + out: np.ndarray, # rank_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[rank_t, ndim=2] + labels: np.ndarray, # const int64_t[:] + mask: npt.NDArray[np.bool_] | None, + result_mask: npt.NDArray[np.bool_] | None = ..., + min_count: int = ..., # int64_t + rank: int = ..., # int64_t + is_datetimelike: bool = ..., + skipna: bool = ..., +) -> None: ... +def group_rank( + out: np.ndarray, # float64_t[:, ::1] + values: np.ndarray, # ndarray[rank_t, ndim=2] + labels: np.ndarray, # const int64_t[:] + ngroups: int, + is_datetimelike: bool, + ties_method: Literal["average", "min", "max", "first", "dense"] = ..., + ascending: bool = ..., + pct: bool = ..., + na_option: Literal["keep", "top", "bottom"] = ..., + mask: npt.NDArray[np.bool_] | None = ..., +) -> None: ... +def group_max( + out: np.ndarray, # groupby_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[groupby_t, ndim=2] + labels: np.ndarray, # const int64_t[:] + min_count: int = ..., + is_datetimelike: bool = ..., + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + skipna: bool = ..., +) -> None: ... +def group_min( + out: np.ndarray, # groupby_t[:, ::1] + counts: np.ndarray, # int64_t[::1] + values: np.ndarray, # ndarray[groupby_t, ndim=2] + labels: np.ndarray, # const int64_t[:] + min_count: int = ..., + is_datetimelike: bool = ..., + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + skipna: bool = ..., +) -> None: ... +def group_idxmin_idxmax( + out: npt.NDArray[np.intp], + counts: npt.NDArray[np.int64], + values: np.ndarray, # ndarray[groupby_t, ndim=2] + labels: npt.NDArray[np.intp], + min_count: int = ..., + is_datetimelike: bool = ..., + mask: np.ndarray | None = ..., + name: str = ..., + skipna: bool = ..., + result_mask: np.ndarray | None = ..., +) -> None: ... +def group_cummin( + out: np.ndarray, # groupby_t[:, ::1] + values: np.ndarray, # ndarray[groupby_t, ndim=2] + labels: np.ndarray, # const int64_t[:] + ngroups: int, + is_datetimelike: bool, + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + skipna: bool = ..., +) -> None: ... +def group_cummax( + out: np.ndarray, # groupby_t[:, ::1] + values: np.ndarray, # ndarray[groupby_t, ndim=2] + labels: np.ndarray, # const int64_t[:] + ngroups: int, + is_datetimelike: bool, + mask: np.ndarray | None = ..., + result_mask: np.ndarray | None = ..., + skipna: bool = ..., +) -> None: ... diff --git a/pandas/_libs/hashing.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/hashing.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..28e6528ace127ef506c9d68897fefbb6562cb25a Binary files /dev/null and b/pandas/_libs/hashing.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/hashing.pyi b/pandas/_libs/hashing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..8361026e4a87d462e04c53f7f1f8aee8a7f6ffe0 --- /dev/null +++ b/pandas/_libs/hashing.pyi @@ -0,0 +1,9 @@ +import numpy as np + +from pandas._typing import npt + +def hash_object_array( + arr: npt.NDArray[np.object_], + key: str, + encoding: str = ..., +) -> npt.NDArray[np.uint64]: ... diff --git a/pandas/_libs/hashtable.pyi b/pandas/_libs/hashtable.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5ee359d84a6ed786dea6fc3ac9659c3adf7c8f70 --- /dev/null +++ b/pandas/_libs/hashtable.pyi @@ -0,0 +1,274 @@ +from collections.abc import Hashable +from typing import ( + Any, + Literal, + overload, +) + +import numpy as np + +from pandas._typing import npt + +def unique_label_indices( + labels: np.ndarray, # const int64_t[:] +) -> np.ndarray: ... + +class Factorizer: + count: int + uniques: Any + def __init__(self, size_hint: int, uses_mask: bool = False) -> None: ... + def get_count(self) -> int: ... + def factorize( + self, + values: np.ndarray, + na_sentinel=..., + na_value=..., + mask=..., + ) -> npt.NDArray[np.intp]: ... + def hash_inner_join( + self, values: np.ndarray, mask=... + ) -> tuple[np.ndarray, np.ndarray]: ... + +class ObjectFactorizer(Factorizer): + table: PyObjectHashTable + uniques: ObjectVector + +class Int64Factorizer(Factorizer): + table: Int64HashTable + uniques: Int64Vector + +class UInt64Factorizer(Factorizer): + table: UInt64HashTable + uniques: UInt64Vector + +class Int32Factorizer(Factorizer): + table: Int32HashTable + uniques: Int32Vector + +class UInt32Factorizer(Factorizer): + table: UInt32HashTable + uniques: UInt32Vector + +class Int16Factorizer(Factorizer): + table: Int16HashTable + uniques: Int16Vector + +class UInt16Factorizer(Factorizer): + table: UInt16HashTable + uniques: UInt16Vector + +class Int8Factorizer(Factorizer): + table: Int8HashTable + uniques: Int8Vector + +class UInt8Factorizer(Factorizer): + table: UInt8HashTable + uniques: UInt8Vector + +class Float64Factorizer(Factorizer): + table: Float64HashTable + uniques: Float64Vector + +class Float32Factorizer(Factorizer): + table: Float32HashTable + uniques: Float32Vector + +class Complex64Factorizer(Factorizer): + table: Complex64HashTable + uniques: Complex64Vector + +class Complex128Factorizer(Factorizer): + table: Complex128HashTable + uniques: Complex128Vector + +class Int64Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.int64]: ... + +class Int32Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.int32]: ... + +class Int16Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.int16]: ... + +class Int8Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.int8]: ... + +class UInt64Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.uint64]: ... + +class UInt32Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.uint32]: ... + +class UInt16Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.uint16]: ... + +class UInt8Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.uint8]: ... + +class Float64Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.float64]: ... + +class Float32Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.float32]: ... + +class Complex128Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.complex128]: ... + +class Complex64Vector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.complex64]: ... + +class StringVector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.object_]: ... + +class ObjectVector: + def __init__(self, *args) -> None: ... + def __len__(self) -> int: ... + def to_array(self) -> npt.NDArray[np.object_]: ... + +class HashTable: + # NB: The base HashTable class does _not_ actually have these methods; + # we are putting them here for the sake of mypy to avoid + # reproducing them in each subclass below. + def __init__(self, size_hint: int = ..., uses_mask: bool = ...) -> None: ... + def __len__(self) -> int: ... + def __contains__(self, key: Hashable) -> bool: ... + def sizeof(self, deep: bool = ...) -> int: ... + def get_state(self) -> dict[str, int]: ... + # TODO: `val/key` type is subclass-specific + def get_item(self, val): ... # TODO: return type? + def set_item(self, key, val) -> None: ... + def get_na(self): ... # TODO: return type? + def set_na(self, val) -> None: ... + def map_locations( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + mask: npt.NDArray[np.bool_] | None = ..., + ) -> None: ... + def lookup( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + mask: npt.NDArray[np.bool_] | None = ..., + ) -> npt.NDArray[np.intp]: ... + def get_labels( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + uniques, # SubclassTypeVector + count_prior: int = ..., + na_sentinel: int = ..., + na_value: object = ..., + mask=..., + ) -> npt.NDArray[np.intp]: ... + @overload + def unique( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + *, + return_inverse: Literal[False] = ..., + mask: None = ..., + ) -> np.ndarray: ... # np.ndarray[subclass-specific] + @overload + def unique( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + *, + return_inverse: Literal[True], + mask: None = ..., + ) -> tuple[np.ndarray, npt.NDArray[np.intp]]: ... # np.ndarray[subclass-specific] + @overload + def unique( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + *, + return_inverse: Literal[False] = ..., + mask: npt.NDArray[np.bool_], + ) -> tuple[ + np.ndarray, + npt.NDArray[np.bool_], + ]: ... # np.ndarray[subclass-specific] + def factorize( + self, + values: np.ndarray, # np.ndarray[subclass-specific] + na_sentinel: int = ..., + na_value: object = ..., + mask=..., + ignore_na: bool = True, + ) -> tuple[np.ndarray, npt.NDArray[np.intp]]: ... # np.ndarray[subclass-specific] + def hash_inner_join( + self, values: np.ndarray, mask=... + ) -> tuple[np.ndarray, np.ndarray]: ... + +class Complex128HashTable(HashTable): ... +class Complex64HashTable(HashTable): ... +class Float64HashTable(HashTable): ... +class Float32HashTable(HashTable): ... + +class Int64HashTable(HashTable): + # Only Int64HashTable has get_labels_groupby, map_keys_to_values + def get_labels_groupby( + self, + values: npt.NDArray[np.int64], # const int64_t[:] + ) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.int64]]: ... + def map_keys_to_values( + self, + keys: npt.NDArray[np.int64], + values: npt.NDArray[np.int64], # const int64_t[:] + ) -> None: ... + +class Int32HashTable(HashTable): ... +class Int16HashTable(HashTable): ... +class Int8HashTable(HashTable): ... +class UInt64HashTable(HashTable): ... +class UInt32HashTable(HashTable): ... +class UInt16HashTable(HashTable): ... +class UInt8HashTable(HashTable): ... +class StringHashTable(HashTable): ... +class PyObjectHashTable(HashTable): ... +class IntpHashTable(HashTable): ... + +def duplicated( + values: np.ndarray, + keep: Literal["last", "first", False] = ..., + mask: npt.NDArray[np.bool_] | None = ..., +) -> npt.NDArray[np.bool_]: ... +def mode( + values: np.ndarray, dropna: bool, mask: npt.NDArray[np.bool_] | None = ... +) -> np.ndarray: ... +def value_count( + values: np.ndarray, + dropna: bool, + mask: npt.NDArray[np.bool_] | None = ..., +) -> tuple[np.ndarray, npt.NDArray[np.int64], int]: ... # np.ndarray[same-as-values] + +# arr and values should have same dtype +def ismember( + arr: np.ndarray, + values: np.ndarray, +) -> npt.NDArray[np.bool_]: ... +def object_hash(obj) -> int: ... +def objects_are_equal(a, b) -> bool: ... diff --git a/pandas/_libs/index.pyi b/pandas/_libs/index.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3af2856d2fbbf36c1da27bb38fd66a62bc6ac3ea --- /dev/null +++ b/pandas/_libs/index.pyi @@ -0,0 +1,107 @@ +import numpy as np + +from pandas._typing import npt + +from pandas import ( + Index, + MultiIndex, +) +from pandas.core.arrays import ExtensionArray + +multiindex_nulls_shift: int + +class IndexEngine: + over_size_threshold: bool + def __init__(self, values: np.ndarray) -> None: ... + def __contains__(self, val: object) -> bool: ... + + # -> int | slice | np.ndarray[bool] + def get_loc(self, val: object) -> int | slice | np.ndarray: ... + def sizeof(self, deep: bool = ...) -> int: ... + def __sizeof__(self) -> int: ... + @property + def is_unique(self) -> bool: ... + @property + def is_monotonic_increasing(self) -> bool: ... + @property + def is_monotonic_decreasing(self) -> bool: ... + @property + def is_mapping_populated(self) -> bool: ... + def clear_mapping(self): ... + def get_indexer(self, values: np.ndarray) -> npt.NDArray[np.intp]: ... + def get_indexer_non_unique( + self, + targets: np.ndarray, + ) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... + +class MaskedIndexEngine(IndexEngine): + def __init__(self, values: object) -> None: ... + def get_indexer_non_unique( + self, targets: object + ) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... + +class Float64Engine(IndexEngine): ... +class Float32Engine(IndexEngine): ... +class Complex128Engine(IndexEngine): ... +class Complex64Engine(IndexEngine): ... +class Int64Engine(IndexEngine): ... +class Int32Engine(IndexEngine): ... +class Int16Engine(IndexEngine): ... +class Int8Engine(IndexEngine): ... +class UInt64Engine(IndexEngine): ... +class UInt32Engine(IndexEngine): ... +class UInt16Engine(IndexEngine): ... +class UInt8Engine(IndexEngine): ... +class ObjectEngine(IndexEngine): ... +class StringEngine(IndexEngine): ... +class DatetimeEngine(Int64Engine): ... +class TimedeltaEngine(DatetimeEngine): ... +class PeriodEngine(Int64Engine): ... +class BoolEngine(UInt8Engine): ... +class MaskedFloat64Engine(MaskedIndexEngine): ... +class MaskedFloat32Engine(MaskedIndexEngine): ... +class MaskedComplex128Engine(MaskedIndexEngine): ... +class MaskedComplex64Engine(MaskedIndexEngine): ... +class MaskedInt64Engine(MaskedIndexEngine): ... +class MaskedInt32Engine(MaskedIndexEngine): ... +class MaskedInt16Engine(MaskedIndexEngine): ... +class MaskedInt8Engine(MaskedIndexEngine): ... +class MaskedUInt64Engine(MaskedIndexEngine): ... +class MaskedUInt32Engine(MaskedIndexEngine): ... +class MaskedUInt16Engine(MaskedIndexEngine): ... +class MaskedUInt8Engine(MaskedIndexEngine): ... +class MaskedBoolEngine(MaskedUInt8Engine): ... + +class StringObjectEngine(ObjectEngine): + def __init__(self, values: object, na_value) -> None: ... + +class BaseMultiIndexCodesEngine: + levels: list[np.ndarray] + offsets: np.ndarray # np.ndarray[..., ndim=1] + + def __init__( + self, + levels: list[Index], # all entries hashable + labels: list[np.ndarray], # all entries integer-dtyped + offsets: np.ndarray, # np.ndarray[..., ndim=1] + ) -> None: ... + def get_indexer(self, target: npt.NDArray[np.object_]) -> npt.NDArray[np.intp]: ... + def _extract_level_codes(self, target: MultiIndex) -> np.ndarray: ... + +class ExtensionEngine: + def __init__(self, values: ExtensionArray) -> None: ... + def __contains__(self, val: object) -> bool: ... + def get_loc(self, val: object) -> int | slice | np.ndarray: ... + def get_indexer(self, values: np.ndarray) -> npt.NDArray[np.intp]: ... + def get_indexer_non_unique( + self, + targets: np.ndarray, + ) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... + @property + def is_unique(self) -> bool: ... + @property + def is_monotonic_increasing(self) -> bool: ... + @property + def is_monotonic_decreasing(self) -> bool: ... + def sizeof(self, deep: bool = ...) -> int: ... + def clear_mapping(self): ... diff --git a/pandas/_libs/indexing.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/indexing.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..065b9280a71c109580b05568e8595eb7fba9e7de Binary files /dev/null and b/pandas/_libs/indexing.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/indexing.pyi b/pandas/_libs/indexing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..3ae5c5044a2f75452fa57ba578af2c7b4c78ec96 --- /dev/null +++ b/pandas/_libs/indexing.pyi @@ -0,0 +1,17 @@ +from typing import ( + Generic, + TypeVar, +) + +from pandas.core.indexing import IndexingMixin + +_IndexingMixinT = TypeVar("_IndexingMixinT", bound=IndexingMixin) + +class NDFrameIndexerBase(Generic[_IndexingMixinT]): + name: str + # in practice obj is either a DataFrame or a Series + obj: _IndexingMixinT + + def __init__(self, name: str, obj: _IndexingMixinT) -> None: ... + @property + def ndim(self) -> int: ... diff --git a/pandas/_libs/internals.pyi b/pandas/_libs/internals.pyi new file mode 100644 index 0000000000000000000000000000000000000000..11d059ec53920e5f44911d8784b09332dbb4e797 --- /dev/null +++ b/pandas/_libs/internals.pyi @@ -0,0 +1,96 @@ +from collections.abc import ( + Iterator, + Sequence, +) +from typing import ( + Self, + final, + overload, +) +import weakref + +import numpy as np + +from pandas._typing import ( + ArrayLike, + npt, +) + +from pandas import Index +from pandas.core.internals.blocks import Block as B + +def slice_len(slc: slice, objlen: int = ...) -> int: ... +def get_concat_blkno_indexers( + blknos_list: list[npt.NDArray[np.intp]], +) -> list[tuple[npt.NDArray[np.intp], BlockPlacement]]: ... +def get_blkno_indexers( + blknos: np.ndarray, # int64_t[:] + group: bool = ..., +) -> list[tuple[int, slice | np.ndarray]]: ... +def get_blkno_placements( + blknos: np.ndarray, + group: bool = ..., +) -> Iterator[tuple[int, BlockPlacement]]: ... +def update_blklocs_and_blknos( + blklocs: npt.NDArray[np.intp], + blknos: npt.NDArray[np.intp], + loc: int, + nblocks: int, +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... +@final +class BlockPlacement: + def __init__(self, val: int | slice | np.ndarray) -> None: ... + @property + def indexer(self) -> np.ndarray | slice: ... + @property + def as_array(self) -> np.ndarray: ... + @property + def as_slice(self) -> slice: ... + @property + def is_slice_like(self) -> bool: ... + @overload + def __getitem__( + self, loc: slice | Sequence[int] | npt.NDArray[np.intp] + ) -> BlockPlacement: ... + @overload + def __getitem__(self, loc: int) -> int: ... + def __iter__(self) -> Iterator[int]: ... + def __len__(self) -> int: ... + def delete(self, loc) -> BlockPlacement: ... + def add(self, other) -> BlockPlacement: ... + def append(self, others: list[BlockPlacement]) -> BlockPlacement: ... + def tile_for_unstack(self, factor: int) -> npt.NDArray[np.intp]: ... + +class Block: + _mgr_locs: BlockPlacement + ndim: int + values: ArrayLike + refs: BlockValuesRefs + def __init__( + self, + values: ArrayLike, + placement: BlockPlacement, + ndim: int, + refs: BlockValuesRefs | None = ..., + ) -> None: ... + def slice_block_rows(self, slicer: slice) -> Self: ... + +class BlockManager: + blocks: tuple[B, ...] + axes: list[Index] + _known_consolidated: bool + _is_consolidated: bool + _blknos: np.ndarray + _blklocs: np.ndarray + def __init__( + self, blocks: tuple[B, ...], axes: list[Index], verify_integrity=... + ) -> None: ... + def get_slice(self, slobj: slice, axis: int = ...) -> Self: ... + def _rebuild_blknos_and_blklocs(self) -> None: ... + +class BlockValuesRefs: + referenced_blocks: list[weakref.ref] + def __init__(self, blk: Block | None = ...) -> None: ... + def add_reference(self, blk: Block) -> None: ... + def add_index_reference(self, index: Index) -> None: ... + def has_reference(self) -> bool: ... diff --git a/pandas/_libs/interval.pyi b/pandas/_libs/interval.pyi new file mode 100644 index 0000000000000000000000000000000000000000..587fdf84f2f85520713352bbcab29804c95621e5 --- /dev/null +++ b/pandas/_libs/interval.pyi @@ -0,0 +1,174 @@ +from typing import ( + Any, + Generic, + TypeVar, + overload, +) + +import numpy as np +import numpy.typing as npt + +from pandas._typing import ( + IntervalClosedType, + Timedelta, + Timestamp, +) + +VALID_CLOSED: frozenset[str] + +_OrderableScalarT = TypeVar("_OrderableScalarT", int, float) +_OrderableTimesT = TypeVar("_OrderableTimesT", Timestamp, Timedelta) +_OrderableT = TypeVar("_OrderableT", int, float, Timestamp, Timedelta) + +class _LengthDescriptor: + @overload + def __get__( + self, instance: Interval[_OrderableScalarT], owner: Any + ) -> _OrderableScalarT: ... + @overload + def __get__( + self, instance: Interval[_OrderableTimesT], owner: Any + ) -> Timedelta: ... + +class _MidDescriptor: + @overload + def __get__(self, instance: Interval[_OrderableScalarT], owner: Any) -> float: ... + @overload + def __get__( + self, instance: Interval[_OrderableTimesT], owner: Any + ) -> _OrderableTimesT: ... + +class IntervalMixin: + @property + def closed_left(self) -> bool: ... + @property + def closed_right(self) -> bool: ... + @property + def open_left(self) -> bool: ... + @property + def open_right(self) -> bool: ... + @property + def is_empty(self) -> bool: ... + def _check_closed_matches(self, other: IntervalMixin, name: str = ...) -> None: ... + +class Interval(IntervalMixin, Generic[_OrderableT]): + @property + def left(self: Interval[_OrderableT]) -> _OrderableT: ... + @property + def right(self: Interval[_OrderableT]) -> _OrderableT: ... + @property + def closed(self) -> IntervalClosedType: ... + mid: _MidDescriptor + length: _LengthDescriptor + def __init__( + self, + left: _OrderableT, + right: _OrderableT, + closed: IntervalClosedType = ..., + ) -> None: ... + def __hash__(self) -> int: ... + @overload + def __contains__( + self: Interval[Timedelta], key: Timedelta | Interval[Timedelta] + ) -> bool: ... + @overload + def __contains__( + self: Interval[Timestamp], key: Timestamp | Interval[Timestamp] + ) -> bool: ... + @overload + def __contains__( + self: Interval[_OrderableScalarT], + key: _OrderableScalarT | Interval[_OrderableScalarT], + ) -> bool: ... + @overload + def __add__( + self: Interval[_OrderableTimesT], y: Timedelta + ) -> Interval[_OrderableTimesT]: ... + @overload + def __add__( + self: Interval[int], y: _OrderableScalarT + ) -> Interval[_OrderableScalarT]: ... + @overload + def __add__(self: Interval[float], y: float) -> Interval[float]: ... + @overload + def __radd__( + self: Interval[_OrderableTimesT], y: Timedelta + ) -> Interval[_OrderableTimesT]: ... + @overload + def __radd__( + self: Interval[int], y: _OrderableScalarT + ) -> Interval[_OrderableScalarT]: ... + @overload + def __radd__(self: Interval[float], y: float) -> Interval[float]: ... + @overload + def __sub__( + self: Interval[_OrderableTimesT], y: Timedelta + ) -> Interval[_OrderableTimesT]: ... + @overload + def __sub__( + self: Interval[int], y: _OrderableScalarT + ) -> Interval[_OrderableScalarT]: ... + @overload + def __sub__(self: Interval[float], y: float) -> Interval[float]: ... + @overload + def __rsub__( + self: Interval[_OrderableTimesT], y: Timedelta + ) -> Interval[_OrderableTimesT]: ... + @overload + def __rsub__( + self: Interval[int], y: _OrderableScalarT + ) -> Interval[_OrderableScalarT]: ... + @overload + def __rsub__(self: Interval[float], y: float) -> Interval[float]: ... + @overload + def __mul__( + self: Interval[int], y: _OrderableScalarT + ) -> Interval[_OrderableScalarT]: ... + @overload + def __mul__(self: Interval[float], y: float) -> Interval[float]: ... + @overload + def __rmul__( + self: Interval[int], y: _OrderableScalarT + ) -> Interval[_OrderableScalarT]: ... + @overload + def __rmul__(self: Interval[float], y: float) -> Interval[float]: ... + @overload + def __truediv__( + self: Interval[int], y: _OrderableScalarT + ) -> Interval[_OrderableScalarT]: ... + @overload + def __truediv__(self: Interval[float], y: float) -> Interval[float]: ... + @overload + def __floordiv__( + self: Interval[int], y: _OrderableScalarT + ) -> Interval[_OrderableScalarT]: ... + @overload + def __floordiv__(self: Interval[float], y: float) -> Interval[float]: ... + def overlaps(self: Interval[_OrderableT], other: Interval[_OrderableT]) -> bool: ... + +def intervals_to_interval_bounds( + intervals: np.ndarray, validate_closed: bool = ... +) -> tuple[np.ndarray, np.ndarray, IntervalClosedType]: ... + +class IntervalTree(IntervalMixin): + def __init__( + self, + left: np.ndarray, + right: np.ndarray, + closed: IntervalClosedType = ..., + leaf_size: int = ..., + ) -> None: ... + @property + def mid(self) -> np.ndarray: ... + @property + def length(self) -> np.ndarray: ... + def get_indexer(self, target) -> npt.NDArray[np.intp]: ... + def get_indexer_non_unique( + self, target + ) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... + _na_count: int + @property + def is_overlapping(self) -> bool: ... + @property + def is_monotonic_increasing(self) -> bool: ... + def clear_mapping(self) -> None: ... diff --git a/pandas/_libs/join.pyi b/pandas/_libs/join.pyi new file mode 100644 index 0000000000000000000000000000000000000000..1d4e8c90bc5593eae650319e9ca1b58cbd7eed73 --- /dev/null +++ b/pandas/_libs/join.pyi @@ -0,0 +1,79 @@ +import numpy as np + +from pandas._typing import npt + +def inner_join( + left: np.ndarray, # const intp_t[:] + right: np.ndarray, # const intp_t[:] + max_groups: int, + sort: bool = ..., +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... +def left_outer_join( + left: np.ndarray, # const intp_t[:] + right: np.ndarray, # const intp_t[:] + max_groups: int, + sort: bool = ..., +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... +def full_outer_join( + left: np.ndarray, # const intp_t[:] + right: np.ndarray, # const intp_t[:] + max_groups: int, +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... +def ffill_indexer( + indexer: np.ndarray, # const intp_t[:] +) -> npt.NDArray[np.intp]: ... +def left_join_indexer_unique( + left: np.ndarray, # ndarray[join_t] + right: np.ndarray, # ndarray[join_t] +) -> npt.NDArray[np.intp]: ... +def left_join_indexer( + left: np.ndarray, # ndarray[join_t] + right: np.ndarray, # ndarray[join_t] +) -> tuple[ + np.ndarray, # np.ndarray[join_t] + npt.NDArray[np.intp], + npt.NDArray[np.intp], +]: ... +def inner_join_indexer( + left: np.ndarray, # ndarray[join_t] + right: np.ndarray, # ndarray[join_t] +) -> tuple[ + np.ndarray, # np.ndarray[join_t] + npt.NDArray[np.intp], + npt.NDArray[np.intp], +]: ... +def outer_join_indexer( + left: np.ndarray, # ndarray[join_t] + right: np.ndarray, # ndarray[join_t] +) -> tuple[ + np.ndarray, # np.ndarray[join_t] + npt.NDArray[np.intp], + npt.NDArray[np.intp], +]: ... +def asof_join_backward_on_X_by_Y( + left_values: np.ndarray, # ndarray[numeric_t] + right_values: np.ndarray, # ndarray[numeric_t] + left_by_values: np.ndarray, # const int64_t[:] + right_by_values: np.ndarray, # const int64_t[:] + allow_exact_matches: bool = ..., + tolerance: np.number | float | None = ..., + use_hashtable: bool = ..., +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... +def asof_join_forward_on_X_by_Y( + left_values: np.ndarray, # ndarray[numeric_t] + right_values: np.ndarray, # ndarray[numeric_t] + left_by_values: np.ndarray, # const int64_t[:] + right_by_values: np.ndarray, # const int64_t[:] + allow_exact_matches: bool = ..., + tolerance: np.number | float | None = ..., + use_hashtable: bool = ..., +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... +def asof_join_nearest_on_X_by_Y( + left_values: np.ndarray, # ndarray[numeric_t] + right_values: np.ndarray, # ndarray[numeric_t] + left_by_values: np.ndarray, # const int64_t[:] + right_by_values: np.ndarray, # const int64_t[:] + allow_exact_matches: bool = ..., + tolerance: np.number | float | None = ..., + use_hashtable: bool = ..., +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ... diff --git a/pandas/_libs/json.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/json.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..cba519b304383daf0c27a5bf66da87c31760b773 Binary files /dev/null and b/pandas/_libs/json.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/json.pyi b/pandas/_libs/json.pyi new file mode 100644 index 0000000000000000000000000000000000000000..349320d69d707a27b5ca75a5eabfa3c867fa25c5 --- /dev/null +++ b/pandas/_libs/json.pyi @@ -0,0 +1,23 @@ +from collections.abc import Callable +from typing import ( + Any, +) + +def ujson_dumps( + obj: Any, + ensure_ascii: bool = ..., + double_precision: int = ..., + indent: int = ..., + orient: str = ..., + date_unit: str = ..., + iso_dates: bool = ..., + default_handler: None + | Callable[[Any], str | float | bool | list | dict | None] = ..., +) -> str: ... +def ujson_loads( + s: str, + precise_float: bool = ..., + numpy: bool = ..., + dtype: None = ..., + labelled: bool = ..., +) -> Any: ... diff --git a/pandas/_libs/lib.pyi b/pandas/_libs/lib.pyi new file mode 100644 index 0000000000000000000000000000000000000000..e50b301c348688423b62a53e2b44e98e54ddc9ae --- /dev/null +++ b/pandas/_libs/lib.pyi @@ -0,0 +1,238 @@ +# TODO(npdtypes): Many types specified here can be made more specific/accurate; +# the more specific versions are specified in comments +from collections.abc import ( + Callable, + Generator, + Hashable, +) +from decimal import Decimal +from typing import ( + Any, + Final, + Literal, + TypeAlias, + TypeGuard, + overload, +) + +import numpy as np + +from pandas._typing import ( + ArrayLike, + DtypeObj, + npt, +) + +# placeholder until we can specify np.ndarray[object, ndim=2] +ndarray_obj_2d = np.ndarray + +from enum import Enum + +class _NoDefault(Enum): + no_default = ... + +no_default: Final = _NoDefault.no_default +NoDefault: TypeAlias = Literal[_NoDefault.no_default] + +i8max: int +u8max: int + +def is_np_dtype(dtype: object, kinds: str | None = ...) -> TypeGuard[np.dtype]: ... +def item_from_zerodim(val: object) -> object: ... +def infer_dtype(value: object, skipna: bool = ...) -> str: ... +def is_iterator(obj: object) -> bool: ... +def is_scalar(val: object) -> bool: ... +def is_list_like(obj: object, allow_sets: bool = ...) -> bool: ... +def is_pyarrow_array(obj: object) -> bool: ... +def is_decimal(obj: object) -> TypeGuard[Decimal]: ... +def is_complex(obj: object) -> TypeGuard[complex]: ... +def is_bool(obj: object) -> TypeGuard[bool | np.bool_]: ... +def is_integer(obj: object) -> TypeGuard[int | np.integer]: ... +def is_int_or_none(obj) -> bool: ... +def is_float(obj: object) -> TypeGuard[float]: ... +def is_interval_array(values: np.ndarray) -> bool: ... +def is_datetime64_array(values: np.ndarray, skipna: bool = True) -> bool: ... +def is_timedelta_or_timedelta64_array( + values: np.ndarray, skipna: bool = True +) -> bool: ... +def is_datetime_with_singletz_array(values: np.ndarray) -> bool: ... +def is_time_array(values: np.ndarray, skipna: bool = ...): ... +def is_date_array(values: np.ndarray, skipna: bool = ...): ... +def is_datetime_array(values: np.ndarray, skipna: bool = ...): ... +def is_string_array(values: np.ndarray, skipna: bool = ...): ... +def is_float_array(values: np.ndarray, skipna: bool = ...): ... +def is_integer_array(values: np.ndarray, skipna: bool = ...): ... +def is_bool_array(values: np.ndarray, skipna: bool = ...): ... +def fast_multiget( + mapping: dict, + keys: np.ndarray, # object[:] + default=..., +) -> ArrayLike: ... +def fast_unique_multiple_list_gen(gen: Generator, sort: bool = ...) -> list: ... +@overload +def map_infer( + arr: np.ndarray, + f: Callable[[Any], Any], + *, + convert: Literal[False], + ignore_na: bool = ..., +) -> np.ndarray: ... +@overload +def map_infer( + arr: np.ndarray, + f: Callable[[Any], Any], + *, + convert: bool = ..., + ignore_na: bool = ..., +) -> ArrayLike: ... +@overload +def maybe_convert_objects( + objects: npt.NDArray[np.object_], + *, + try_float: bool = ..., + safe: bool = ..., + convert_numeric: bool = ..., + convert_non_numeric: Literal[False] = ..., + convert_to_nullable_dtype: Literal[False] = ..., + dtype_if_all_nat: DtypeObj | None = ..., +) -> npt.NDArray[np.object_ | np.number]: ... +@overload +def maybe_convert_objects( + objects: npt.NDArray[np.object_], + *, + try_float: bool = ..., + safe: bool = ..., + convert_numeric: bool = ..., + convert_non_numeric: bool = ..., + convert_to_nullable_dtype: Literal[True] = ..., + dtype_if_all_nat: DtypeObj | None = ..., +) -> ArrayLike: ... +@overload +def maybe_convert_objects( + objects: npt.NDArray[np.object_], + *, + try_float: bool = ..., + safe: bool = ..., + convert_numeric: bool = ..., + convert_non_numeric: bool = ..., + convert_to_nullable_dtype: bool = ..., + dtype_if_all_nat: DtypeObj | None = ..., +) -> ArrayLike: ... +@overload +def maybe_convert_numeric( + values: npt.NDArray[np.object_], + na_values: set, + convert_empty: bool = ..., + coerce_numeric: bool = ..., + convert_to_masked_nullable: Literal[False] = ..., +) -> tuple[np.ndarray, None]: ... +@overload +def maybe_convert_numeric( + values: npt.NDArray[np.object_], + na_values: set, + convert_empty: bool = ..., + coerce_numeric: bool = ..., + *, + convert_to_masked_nullable: Literal[True], +) -> tuple[np.ndarray, np.ndarray]: ... + +# TODO: restrict `arr`? +def ensure_string_array( + arr, + na_value: object = ..., + convert_na_value: bool = ..., + copy: bool = ..., + skipna: bool = ..., +) -> npt.NDArray[np.object_]: ... +def convert_nans_to_NA( + arr: npt.NDArray[np.object_], +) -> npt.NDArray[np.object_]: ... +def fast_zip(ndarrays: list) -> npt.NDArray[np.object_]: ... + +# TODO: can we be more specific about rows? +def to_object_array_tuples(rows: object) -> ndarray_obj_2d: ... +def tuples_to_object_array( + tuples: npt.NDArray[np.object_], +) -> ndarray_obj_2d: ... + +# TODO: can we be more specific about rows? +def to_object_array(rows: object, min_width: int = ...) -> ndarray_obj_2d: ... +def dicts_to_array(dicts: list, columns: list) -> ndarray_obj_2d: ... +def maybe_booleans_to_slice( + mask: npt.NDArray[np.uint8], +) -> slice | npt.NDArray[np.uint8]: ... +def maybe_indices_to_slice( + indices: npt.NDArray[np.intp], + max_len: int, +) -> slice | npt.NDArray[np.intp]: ... +def is_all_arraylike(obj: list) -> bool: ... + +# ----------------------------------------------------------------- +# Functions which in reality take memoryviews + +def memory_usage_of_objects(arr: np.ndarray) -> int: ... # object[:] # np.int64 +@overload +def map_infer_mask( + arr: np.ndarray, + f: Callable[[Any], Any], + mask: np.ndarray, # const uint8_t[:] + *, + convert: Literal[False], + na_value: Any = ..., + dtype: np.dtype = ..., +) -> np.ndarray: ... +@overload +def map_infer_mask( + arr: np.ndarray, + f: Callable[[Any], Any], + mask: np.ndarray, # const uint8_t[:] + *, + convert: bool = ..., + na_value: Any = ..., + dtype: np.dtype = ..., +) -> ArrayLike: ... +def indices_fast( + index: npt.NDArray[np.intp], + labels: np.ndarray, # const int64_t[:] + keys: list, + sorted_labels: list[npt.NDArray[np.int64]], +) -> dict[Hashable, npt.NDArray[np.intp]]: ... +def generate_slices( + labels: np.ndarray, + ngroups: int, # const intp_t[:] +) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]: ... +def count_level_2d( + mask: np.ndarray, # ndarray[uint8_t, ndim=2, cast=True], + labels: np.ndarray, # const intp_t[:] + max_bin: int, +) -> np.ndarray: ... # np.ndarray[np.int64, ndim=2] +def get_level_sorter( + codes: np.ndarray, # const int64_t[:] + starts: np.ndarray, # const intp_t[:] +) -> np.ndarray: ... # np.ndarray[np.intp, ndim=1] +def generate_bins_dt64( + values: npt.NDArray[np.int64], + binner: np.ndarray, # const int64_t[:] + closed: object = ..., + hasnans: bool = ..., +) -> np.ndarray: ... # np.ndarray[np.int64, ndim=1] +def array_equivalent_object( + left: npt.NDArray[np.object_], + right: npt.NDArray[np.object_], +) -> bool: ... +def has_infs(arr: np.ndarray) -> bool: ... # const floating[:] +def has_only_ints_or_nan(arr: np.ndarray) -> bool: ... # const floating[:] +def get_reverse_indexer( + indexer: np.ndarray, # const intp_t[:] + length: int, +) -> npt.NDArray[np.intp]: ... +def is_bool_list(obj: list) -> bool: ... +def dtypes_all_equal(types: list[DtypeObj]) -> bool: ... +def is_range_indexer( + left: np.ndarray, + n: int, # np.ndarray[np.int64, ndim=1] +) -> bool: ... +def is_sequence_range( + sequence: np.ndarray, + step: int, # np.ndarray[np.int64, ndim=1] +) -> bool: ... diff --git a/pandas/_libs/missing.pyi b/pandas/_libs/missing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..64256ae4b36ad2fcf0aa59024957d01aa122fedf --- /dev/null +++ b/pandas/_libs/missing.pyi @@ -0,0 +1,17 @@ +import numpy as np +from numpy import typing as npt + +class NAType: + def __new__(cls, *args, **kwargs): ... + +NA: NAType + +def is_matching_na( + left: object, right: object, nan_matches_none: bool = ... +) -> bool: ... +def isposinf_scalar(val: object) -> bool: ... +def isneginf_scalar(val: object) -> bool: ... +def checknull(val: object) -> bool: ... +def isnaobj(arr: np.ndarray) -> npt.NDArray[np.bool_]: ... +def is_numeric_na(values: np.ndarray) -> npt.NDArray[np.bool_]: ... +def is_pdna_or_none(values: np.ndarray) -> npt.NDArray[np.bool_]: ... diff --git a/pandas/_libs/ops.pyi b/pandas/_libs/ops.pyi new file mode 100644 index 0000000000000000000000000000000000000000..81fe81930539d1dc009ba1242fff92a1afd4cdc4 --- /dev/null +++ b/pandas/_libs/ops.pyi @@ -0,0 +1,53 @@ +from collections.abc import ( + Callable, + Iterable, +) +from typing import ( + Any, + Literal, + TypeAlias, + overload, +) + +import numpy as np + +from pandas._typing import npt + +_BinOp: TypeAlias = Callable[[Any, Any], Any] +_BoolOp: TypeAlias = Callable[[Any, Any], bool] + +def scalar_compare( + values: np.ndarray, # object[:] + val: object, + op: _BoolOp, # {operator.eq, operator.ne, ...} +) -> npt.NDArray[np.bool_]: ... +def vec_compare( + left: npt.NDArray[np.object_], + right: npt.NDArray[np.object_], + op: _BoolOp, # {operator.eq, operator.ne, ...} +) -> npt.NDArray[np.bool_]: ... +def scalar_binop( + values: np.ndarray, # object[:] + val: object, + op: _BinOp, # binary operator +) -> np.ndarray: ... +def vec_binop( + left: np.ndarray, # object[:] + right: np.ndarray, # object[:] + op: _BinOp, # binary operator +) -> np.ndarray: ... +@overload +def maybe_convert_bool( + arr: npt.NDArray[np.object_], + true_values: Iterable | None = None, + false_values: Iterable | None = None, + convert_to_masked_nullable: Literal[False] = ..., +) -> tuple[np.ndarray, None]: ... +@overload +def maybe_convert_bool( + arr: npt.NDArray[np.object_], + true_values: Iterable = ..., + false_values: Iterable = ..., + *, + convert_to_masked_nullable: Literal[True], +) -> tuple[np.ndarray, np.ndarray]: ... diff --git a/pandas/_libs/ops_dispatch.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/ops_dispatch.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..07d0736acd0575c0996ad61540272976788c750a Binary files /dev/null and b/pandas/_libs/ops_dispatch.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/ops_dispatch.pyi b/pandas/_libs/ops_dispatch.pyi new file mode 100644 index 0000000000000000000000000000000000000000..91b5a4dbaaebc177191d3189f12e4e20d56ca0fa --- /dev/null +++ b/pandas/_libs/ops_dispatch.pyi @@ -0,0 +1,5 @@ +import numpy as np + +def maybe_dispatch_ufunc_to_dunder_op( + self, ufunc: np.ufunc, method: str, *inputs, **kwargs +): ... diff --git a/pandas/_libs/pandas_datetime.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/pandas_datetime.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..7f2307cd48672dafeb3fd2b8927698d446551882 Binary files /dev/null and b/pandas/_libs/pandas_datetime.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/pandas_parser.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/pandas_parser.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..18393c333cbc36ba233bb18d2025627e4669c54a Binary files /dev/null and b/pandas/_libs/pandas_parser.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/parsers.pyi b/pandas/_libs/parsers.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d18f54c54623236d47df31be70e8516202d4c86a --- /dev/null +++ b/pandas/_libs/parsers.pyi @@ -0,0 +1,77 @@ +from collections.abc import Hashable +from typing import ( + Literal, +) + +import numpy as np + +from pandas._typing import ( + ArrayLike, + Dtype, + npt, +) + +STR_NA_VALUES: set[str] +DEFAULT_BUFFER_HEURISTIC: int + +def sanitize_objects( + values: npt.NDArray[np.object_], + na_values: set, +) -> int: ... + +class TextReader: + unnamed_cols: set[str] + table_width: int # int64_t + leading_cols: int # int64_t + header: list[list[int]] # non-negative integers + def __init__( + self, + source, + delimiter: bytes | str = ..., # single-character only + header=..., + header_start: int = ..., # int64_t + header_end: int = ..., # uint64_t + index_col=..., + names=..., + tokenize_chunksize: int = ..., # int64_t + delim_whitespace: bool = ..., + converters=..., + skipinitialspace: bool = ..., + escapechar: bytes | str | None = ..., # single-character only + doublequote: bool = ..., + quotechar: str | bytes | None = ..., # at most 1 character + quoting: int = ..., + lineterminator: bytes | str | None = ..., # at most 1 character + comment=..., + decimal: bytes | str = ..., # single-character only + thousands: bytes | str | None = ..., # single-character only + dtype: Dtype | dict[Hashable, Dtype] = ..., + usecols=..., + error_bad_lines: bool = ..., + warn_bad_lines: bool = ..., + na_filter: bool = ..., + na_values=..., + na_fvalues=..., + keep_default_na: bool = ..., + true_values=..., + false_values=..., + allow_leading_cols: bool = ..., + skiprows=..., + skipfooter: int = ..., # int64_t + verbose: bool = ..., + float_precision: Literal["round_trip", "legacy", "high"] | None = ..., + skip_blank_lines: bool = ..., + encoding_errors: bytes | str = ..., + ) -> None: ... + def set_noconvert(self, i: int) -> None: ... + def remove_noconvert(self, i: int) -> None: ... + def close(self) -> None: ... + def read(self, rows: int | None = ...) -> dict[int, ArrayLike]: ... + def read_low_memory(self, rows: int | None) -> list[dict[int, ArrayLike]]: ... + +# _maybe_upcast, na_values are only exposed for testing +na_values: dict + +def _maybe_upcast( + arr, use_dtype_backend: bool = ..., dtype_backend: str = ... +) -> np.ndarray: ... diff --git a/pandas/_libs/properties.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/properties.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..2dfdcfefa6c9b1c264b74c4b1a80a6cdfbea9c71 Binary files /dev/null and b/pandas/_libs/properties.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/properties.pyi b/pandas/_libs/properties.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bbde6ec454202fe0c2f1d8f8da40280f13d1536c --- /dev/null +++ b/pandas/_libs/properties.pyi @@ -0,0 +1,27 @@ +from collections.abc import Sequence +from typing import ( + overload, +) + +from pandas._typing import ( + AnyArrayLike, + DataFrame, + Index, + Series, +) + +# note: this is a lie to make type checkers happy (they special +# case property). cache_readonly uses attribute names similar to +# property (fget) but it does not provide fset and fdel. +cache_readonly = property + +class AxisProperty: + axis: int + def __init__(self, axis: int = ..., doc: str = ...) -> None: ... + @overload + def __get__(self, obj: DataFrame | Series, type) -> Index: ... + @overload + def __get__(self, obj: None, type) -> AxisProperty: ... + def __set__( + self, obj: DataFrame | Series, value: AnyArrayLike | Sequence + ) -> None: ... diff --git a/pandas/_libs/reshape.pyi b/pandas/_libs/reshape.pyi new file mode 100644 index 0000000000000000000000000000000000000000..110687fcd0c313c45e8b025083fa5790fb9913b1 --- /dev/null +++ b/pandas/_libs/reshape.pyi @@ -0,0 +1,16 @@ +import numpy as np + +from pandas._typing import npt + +def unstack( + values: np.ndarray, # reshape_t[:, :] + mask: np.ndarray, # const uint8_t[:] + stride: int, + length: int, + width: int, + new_values: np.ndarray, # reshape_t[:, :] + new_mask: np.ndarray, # uint8_t[:, :] +) -> None: ... +def explode( + values: npt.NDArray[np.object_], +) -> tuple[npt.NDArray[np.object_], npt.NDArray[np.int64]]: ... diff --git a/pandas/_libs/sas.pyi b/pandas/_libs/sas.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5d65e2b56b5916ed1e76e1409e4f75c652ee8fc9 --- /dev/null +++ b/pandas/_libs/sas.pyi @@ -0,0 +1,7 @@ +from pandas.io.sas.sas7bdat import SAS7BDATReader + +class Parser: + def __init__(self, parser: SAS7BDATReader) -> None: ... + def read(self, nrows: int) -> None: ... + +def get_subheader_index(signature: bytes) -> int: ... diff --git a/pandas/_libs/sparse.pyi b/pandas/_libs/sparse.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f1f3efb4d3096a77ad9aaafe14647fb0d22cbac3 --- /dev/null +++ b/pandas/_libs/sparse.pyi @@ -0,0 +1,51 @@ +from typing import Self + +import numpy as np + +from pandas._typing import ( + TakeIndexer, + npt, +) + +class SparseIndex: + length: int + npoints: int + def __init__(self) -> None: ... + @property + def ngaps(self) -> int: ... + @property + def nbytes(self) -> int: ... + @property + def indices(self) -> npt.NDArray[np.int32]: ... + def equals(self, other) -> bool: ... + def lookup(self, index: int) -> np.int32: ... + def lookup_array(self, indexer: npt.NDArray[np.int32]) -> npt.NDArray[np.int32]: ... + def to_int_index(self) -> IntIndex: ... + def to_block_index(self) -> BlockIndex: ... + def intersect(self, y_: SparseIndex) -> Self: ... + def make_union(self, y_: SparseIndex) -> Self: ... + +class IntIndex(SparseIndex): + indices: npt.NDArray[np.int32] + def __init__( + self, length: int, indices: TakeIndexer, check_integrity: bool = ... + ) -> None: ... + +class BlockIndex(SparseIndex): + nblocks: int + blocs: np.ndarray + blengths: np.ndarray + def __init__( + self, length: int, blocs: np.ndarray, blengths: np.ndarray + ) -> None: ... + + # Override to have correct parameters + def intersect(self, other: SparseIndex) -> Self: ... + def make_union(self, y: SparseIndex) -> Self: ... + +def make_mask_object_ndarray( + arr: npt.NDArray[np.object_], fill_value +) -> npt.NDArray[np.bool_]: ... +def get_blocks( + indices: npt.NDArray[np.int32], +) -> tuple[npt.NDArray[np.int32], npt.NDArray[np.int32]]: ... diff --git a/pandas/_libs/testing.pyi b/pandas/_libs/testing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..4758483b3b5e7755669d4f171cb8f662f063a344 --- /dev/null +++ b/pandas/_libs/testing.pyi @@ -0,0 +1,14 @@ +from collections.abc import Mapping + +def assert_dict_equal(a: Mapping, b: Mapping, compare_keys: bool = ...) -> bool: ... +def assert_almost_equal( + a, + b, + rtol: float = ..., + atol: float = ..., + check_dtype: bool = ..., + obj=..., + lobj=..., + robj=..., + index_values=..., +) -> bool: ... diff --git a/pandas/_libs/tslib.pyi b/pandas/_libs/tslib.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7e3372a80db9db10cde474ab7a1b4f7b6a7c39b5 --- /dev/null +++ b/pandas/_libs/tslib.pyi @@ -0,0 +1,33 @@ +from datetime import tzinfo + +import numpy as np + +from pandas._typing import npt + +def format_array_from_datetime( + values: npt.NDArray[np.int64], + tz: tzinfo | None = ..., + format: str | None = ..., + na_rep: str | float = ..., + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.object_]: ... +def first_non_null(values: np.ndarray) -> int: ... +def array_to_datetime( + values: npt.NDArray[np.object_], + errors: str = ..., + dayfirst: bool = ..., + yearfirst: bool = ..., + utc: bool = ..., + creso: int = ..., + unit_for_numerics: str | None = ..., +) -> tuple[np.ndarray, tzinfo | None]: ... + +# returned ndarray may be object dtype or datetime64[ns] + +def array_to_datetime_with_tz( + values: npt.NDArray[np.object_], + tz: tzinfo, + dayfirst: bool, + yearfirst: bool, + creso: int, +) -> npt.NDArray[np.int64]: ... diff --git a/pandas/_libs/tslibs/__init__.py b/pandas/_libs/tslibs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6bbb87baa2c729ed11fc6773d5f13511d16b2d --- /dev/null +++ b/pandas/_libs/tslibs/__init__.py @@ -0,0 +1,89 @@ +__all__ = [ + "BaseOffset", + "Day", + "IncompatibleFrequency", + "NaT", + "NaTType", + "OutOfBoundsDatetime", + "OutOfBoundsTimedelta", + "Period", + "Resolution", + "Tick", + "Timedelta", + "Timestamp", + "add_overflowsafe", + "astype_overflowsafe", + "delta_to_nanoseconds", + "dt64arr_to_periodarr", + "dtypes", + "get_resolution", + "get_supported_dtype", + "get_unit_from_dtype", + "guess_datetime_format", + "iNaT", + "ints_to_pydatetime", + "ints_to_pytimedelta", + "is_date_array_normalized", + "is_supported_dtype", + "is_unitless", + "localize_pydatetime", + "nat_strings", + "normalize_i8_timestamps", + "periods_per_day", + "periods_per_second", + "to_offset", + "tz_compare", + "tz_convert_from_utc", + "tz_convert_from_utc_single", +] + +from pandas._libs.tslibs import dtypes +from pandas._libs.tslibs.conversion import localize_pydatetime +from pandas._libs.tslibs.dtypes import ( + Resolution, + periods_per_day, + periods_per_second, +) +from pandas._libs.tslibs.nattype import ( + NaT, + NaTType, + iNaT, + nat_strings, +) +from pandas._libs.tslibs.np_datetime import ( + OutOfBoundsDatetime, + OutOfBoundsTimedelta, + add_overflowsafe, + astype_overflowsafe, + get_supported_dtype, + is_supported_dtype, + is_unitless, + py_get_unit_from_dtype as get_unit_from_dtype, +) +from pandas._libs.tslibs.offsets import ( + BaseOffset, + Day, + Tick, + to_offset, +) +from pandas._libs.tslibs.parsing import guess_datetime_format +from pandas._libs.tslibs.period import ( + IncompatibleFrequency, + Period, +) +from pandas._libs.tslibs.timedeltas import ( + Timedelta, + delta_to_nanoseconds, + ints_to_pytimedelta, +) +from pandas._libs.tslibs.timestamps import Timestamp +from pandas._libs.tslibs.timezones import tz_compare +from pandas._libs.tslibs.tzconversion import tz_convert_from_utc_single +from pandas._libs.tslibs.vectorized import ( + dt64arr_to_periodarr, + get_resolution, + ints_to_pydatetime, + is_date_array_normalized, + normalize_i8_timestamps, + tz_convert_from_utc, +) diff --git a/pandas/_libs/tslibs/base.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/tslibs/base.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..73a0440d12f1621c220e7318e8d9c77a950579ef Binary files /dev/null and b/pandas/_libs/tslibs/base.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/tslibs/ccalendar.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/tslibs/ccalendar.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..9f79d5d210efe68161f87f488464ba0d2560b60b Binary files /dev/null and b/pandas/_libs/tslibs/ccalendar.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/tslibs/ccalendar.pyi b/pandas/_libs/tslibs/ccalendar.pyi new file mode 100644 index 0000000000000000000000000000000000000000..993f18a61d74aaa643e9790df70ede618b917223 --- /dev/null +++ b/pandas/_libs/tslibs/ccalendar.pyi @@ -0,0 +1,12 @@ +DAYS: list[str] +MONTH_ALIASES: dict[int, str] +MONTH_NUMBERS: dict[str, int] +MONTHS: list[str] +int_to_weekday: dict[int, str] + +def get_firstbday(year: int, month: int) -> int: ... +def get_lastbday(year: int, month: int) -> int: ... +def get_day_of_year(year: int, month: int, day: int) -> int: ... +def get_iso_calendar(year: int, month: int, day: int) -> tuple[int, int, int]: ... +def get_week_of_year(year: int, month: int, day: int) -> int: ... +def get_days_in_month(year: int, month: int) -> int: ... diff --git a/pandas/_libs/tslibs/conversion.pyi b/pandas/_libs/tslibs/conversion.pyi new file mode 100644 index 0000000000000000000000000000000000000000..26affae577f4d3f4ecda2ac9c1bc0cb748a35d4d --- /dev/null +++ b/pandas/_libs/tslibs/conversion.pyi @@ -0,0 +1,14 @@ +from datetime import ( + datetime, + tzinfo, +) + +import numpy as np + +DT64NS_DTYPE: np.dtype +TD64NS_DTYPE: np.dtype + +def localize_pydatetime(dt: datetime, tz: tzinfo | None) -> datetime: ... +def cast_from_unit_vectorized( + values: np.ndarray, unit: str, out_unit: str = ... +) -> np.ndarray: ... diff --git a/pandas/_libs/tslibs/dtypes.pyi b/pandas/_libs/tslibs/dtypes.pyi new file mode 100644 index 0000000000000000000000000000000000000000..821c46598620327616e8b152ab4123ef6f7bc6e1 --- /dev/null +++ b/pandas/_libs/tslibs/dtypes.pyi @@ -0,0 +1,86 @@ +from enum import Enum +from typing import Self + +OFFSET_TO_PERIOD_FREQSTR: dict[str, str] + +def periods_per_day(reso: int = ...) -> int: ... +def periods_per_second(reso: int) -> int: ... +def abbrev_to_npy_unit(abbrev: str | None) -> int: ... + +class PeriodDtypeBase: + _dtype_code: int # PeriodDtypeCode + _n: int + + # actually __cinit__ + def __new__(cls, code: int, n: int) -> Self: ... + @property + def _freq_group_code(self) -> int: ... + @property + def _resolution_obj(self) -> Resolution: ... + def _get_to_timestamp_base(self) -> int: ... + @property + def _freqstr(self) -> str: ... + def __hash__(self) -> int: ... + def _is_tick_like(self) -> bool: ... + @property + def _creso(self) -> int: ... + @property + def _td64_unit(self) -> str: ... + +class FreqGroup(Enum): + _value_: int + FR_ANN = ... + FR_QTR = ... + FR_MTH = ... + FR_WK = ... + FR_BUS = ... + FR_DAY = ... + FR_HR = ... + FR_MIN = ... + FR_SEC = ... + FR_MS = ... + FR_US = ... + FR_NS = ... + FR_UND = ... + @staticmethod + def from_period_dtype_code(code: int) -> FreqGroup: ... + +class Resolution(Enum): + _value_: int + RESO_NS = ... + RESO_US = ... + RESO_MS = ... + RESO_SEC = ... + RESO_MIN = ... + RESO_HR = ... + RESO_DAY = ... + RESO_MTH = ... + RESO_QTR = ... + RESO_YR = ... + def __lt__(self, other: Resolution) -> bool: ... + def __ge__(self, other: Resolution) -> bool: ... + @property + def attrname(self) -> str: ... + @classmethod + def from_attrname(cls, attrname: str) -> Resolution: ... + @classmethod + def get_reso_from_freqstr(cls, freq: str) -> Resolution: ... + @property + def attr_abbrev(self) -> str: ... + +class NpyDatetimeUnit(Enum): + _value_: int + NPY_FR_Y = ... + NPY_FR_M = ... + NPY_FR_W = ... + NPY_FR_D = ... + NPY_FR_h = ... + NPY_FR_m = ... + NPY_FR_s = ... + NPY_FR_ms = ... + NPY_FR_us = ... + NPY_FR_ns = ... + NPY_FR_ps = ... + NPY_FR_fs = ... + NPY_FR_as = ... + NPY_FR_GENERIC = ... diff --git a/pandas/_libs/tslibs/fields.pyi b/pandas/_libs/tslibs/fields.pyi new file mode 100644 index 0000000000000000000000000000000000000000..bc55e34f3d2088c24eabf0b9b3eebf8a9a775d2c --- /dev/null +++ b/pandas/_libs/tslibs/fields.pyi @@ -0,0 +1,62 @@ +import numpy as np + +from pandas._typing import npt + +def build_field_sarray( + dtindex: npt.NDArray[np.int64], # const int64_t[:] + reso: int, # NPY_DATETIMEUNIT +) -> np.ndarray: ... +def month_position_check(fields, weekdays) -> str | None: ... +def get_date_name_field( + dtindex: npt.NDArray[np.int64], # const int64_t[:] + field: str, + locale: str | None = ..., + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.object_]: ... +def get_start_end_field( + dtindex: npt.NDArray[np.int64], + field: str, + freq_name: str | None = ..., + month_kw: int = ..., + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.bool_]: ... +def get_date_field( + dtindex: npt.NDArray[np.int64], # const int64_t[:] + field: str, + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.int32]: ... +def get_timedelta_field( + tdindex: npt.NDArray[np.int64], # const int64_t[:] + field: str, + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.int32]: ... +def get_timedelta_days( + tdindex: npt.NDArray[np.int64], # const int64_t[:] + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.int64]: ... +def isleapyear_arr( + years: np.ndarray, +) -> npt.NDArray[np.bool_]: ... +def build_isocalendar_sarray( + dtindex: npt.NDArray[np.int64], # const int64_t[:] + reso: int, # NPY_DATETIMEUNIT +) -> np.ndarray: ... +def _get_locale_names(name_type: str, locale: str | None = ...): ... + +class RoundTo: + @property + def MINUS_INFTY(self) -> int: ... + @property + def PLUS_INFTY(self) -> int: ... + @property + def NEAREST_HALF_EVEN(self) -> int: ... + @property + def NEAREST_HALF_PLUS_INFTY(self) -> int: ... + @property + def NEAREST_HALF_MINUS_INFTY(self) -> int: ... + +def round_nsint64( + values: npt.NDArray[np.int64], + mode: RoundTo, + nanos: int, +) -> npt.NDArray[np.int64]: ... diff --git a/pandas/_libs/tslibs/nattype.pyi b/pandas/_libs/tslibs/nattype.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6d94fa6593b65731dcebe33d4b803b7035dd18a9 --- /dev/null +++ b/pandas/_libs/tslibs/nattype.pyi @@ -0,0 +1,184 @@ +from datetime import ( + datetime, + timedelta, + tzinfo as _tzinfo, +) +from typing import ( + Literal, + NoReturn, + Self, + TypeAlias, + overload, +) + +import numpy as np + +from pandas._libs.tslibs.period import Period +from pandas._typing import ( + Frequency, + TimestampNonexistent, + TimeUnit, +) + +NaT: NaTType +iNaT: int +nat_strings: set[str] + +_TimeLike: TypeAlias = datetime | timedelta | Period | np.datetime64 | np.timedelta64 +_TimeDelta: TypeAlias = timedelta | np.timedelta64 + +class NaTType: + _value: np.int64 + @property + def value(self) -> int: ... + @property + def asm8(self) -> np.datetime64: ... + def to_datetime64(self) -> np.datetime64: ... + def to_numpy( + self, dtype: np.dtype | str | None = ..., copy: bool = ... + ) -> np.datetime64 | np.timedelta64: ... + @property + def is_leap_year(self) -> bool: ... + @property + def is_month_start(self) -> bool: ... + @property + def is_quarter_start(self) -> bool: ... + @property + def is_year_start(self) -> bool: ... + @property + def is_month_end(self) -> bool: ... + @property + def is_quarter_end(self) -> bool: ... + @property + def is_year_end(self) -> bool: ... + @property + def day_of_year(self) -> float: ... + @property + def dayofyear(self) -> float: ... + @property + def days_in_month(self) -> float: ... + @property + def daysinmonth(self) -> float: ... + @property + def day_of_week(self) -> float: ... + @property + def dayofweek(self) -> float: ... + @property + def week(self) -> float: ... + @property + def weekofyear(self) -> float: ... + @property + def fold(self) -> int: ... + def day_name(self) -> float: ... + def month_name(self) -> float: ... + def weekday(self) -> float: ... + def isoweekday(self) -> float: ... + def isoformat(self, sep: str = ..., timespec: str = ...) -> str: ... + def strftime(self, format: str) -> NoReturn: ... + def total_seconds(self) -> float: ... + def today(self, *args, **kwargs) -> NaTType: ... + def now(self, *args, **kwargs) -> NaTType: ... + def to_pydatetime(self) -> NaTType: ... + def date(self) -> NaTType: ... + def round( + self, + freq: Frequency, + ambiguous: bool | Literal["raise"] | NaTType = ..., + nonexistent: TimestampNonexistent = ..., + ) -> NaTType: ... + def floor( + self, + freq: Frequency, + ambiguous: bool | Literal["raise"] | NaTType = ..., + nonexistent: TimestampNonexistent = ..., + ) -> NaTType: ... + def ceil( + self, + freq: Frequency, + ambiguous: bool | Literal["raise"] | NaTType = ..., + nonexistent: TimestampNonexistent = ..., + ) -> NaTType: ... + @property + def tzinfo(self) -> None: ... + @property + def tz(self) -> None: ... + def tz_convert(self, tz: _tzinfo | str | None) -> NaTType: ... + def tz_localize( + self, + tz: _tzinfo | str | None, + ambiguous: bool | Literal["raise"] | NaTType = ..., + nonexistent: TimestampNonexistent = ..., + ) -> NaTType: ... + def replace( + self, + year: int | None = ..., + month: int | None = ..., + day: int | None = ..., + hour: int | None = ..., + minute: int | None = ..., + second: int | None = ..., + microsecond: int | None = ..., + nanosecond: int | None = ..., + tzinfo: _tzinfo | None = ..., + fold: int | None = ..., + ) -> NaTType: ... + @property + def year(self) -> float: ... + @property + def quarter(self) -> float: ... + @property + def month(self) -> float: ... + @property + def day(self) -> float: ... + @property + def hour(self) -> float: ... + @property + def minute(self) -> float: ... + @property + def second(self) -> float: ... + @property + def millisecond(self) -> float: ... + @property + def microsecond(self) -> float: ... + @property + def nanosecond(self) -> float: ... + # inject Timedelta properties + @property + def days(self) -> float: ... + @property + def seconds(self) -> float: ... + @property + def microseconds(self) -> float: ... + @property + def nanoseconds(self) -> float: ... + # inject Period properties + @property + def qyear(self) -> float: ... + # comparisons + def __eq__(self, other: object, /) -> Literal[False]: ... + def __ne__(self, other: object, /) -> Literal[True]: ... + def __lt__(self, other: Self | _TimeLike, /) -> Literal[False]: ... + def __le__(self, other: Self | _TimeLike, /) -> Literal[False]: ... + def __gt__(self, other: Self | _TimeLike, /) -> Literal[False]: ... + def __ge__(self, other: Self | _TimeLike, /) -> Literal[False]: ... + # unary operators + def __pos__(self) -> Self: ... + def __neg__(self) -> Self: ... + # binary operators + def __sub__(self, other: Self | _TimeLike, /) -> Self: ... + def __rsub__(self, other: Self | _TimeLike, /) -> Self: ... + def __add__(self, other: Self | _TimeLike, /) -> Self: ... + def __radd__(self, other: Self | _TimeLike, /) -> Self: ... + def __mul__(self, other: float, /) -> Self: ... # analogous to timedelta + def __rmul__(self, other: float, /) -> Self: ... + @overload # analogous to timedelta + def __truediv__(self, other: Self | _TimeDelta, /) -> float: ... # Literal[NaN] + @overload + def __truediv__(self, other: float, /) -> Self: ... + @overload # analogous to timedelta + def __floordiv__(self, other: Self | _TimeDelta, /) -> float: ... # Literal[NaN] + @overload + def __floordiv__(self, other: float, /) -> Self: ... + # other + def __hash__(self) -> int: ... + def as_unit(self, unit: TimeUnit, round_ok: bool = ...) -> NaTType: ... diff --git a/pandas/_libs/tslibs/np_datetime.pyi b/pandas/_libs/tslibs/np_datetime.pyi new file mode 100644 index 0000000000000000000000000000000000000000..00ef35c50e53251d5ca6f6c6d5ad28a67a695a21 --- /dev/null +++ b/pandas/_libs/tslibs/np_datetime.pyi @@ -0,0 +1,27 @@ +import numpy as np + +from pandas._typing import npt + +class OutOfBoundsDatetime(ValueError): ... +class OutOfBoundsTimedelta(ValueError): ... + +# only exposed for testing +def py_get_unit_from_dtype(dtype: np.dtype): ... +def py_td64_to_tdstruct(td64: int, unit: int) -> dict: ... +def astype_overflowsafe( + values: np.ndarray, + dtype: np.dtype, + copy: bool = ..., + round_ok: bool = ..., + is_coerce: bool = ..., +) -> np.ndarray: ... +def is_unitless(dtype: np.dtype) -> bool: ... +def compare_mismatched_resolutions( + left: np.ndarray, right: np.ndarray, op +) -> npt.NDArray[np.bool_]: ... +def add_overflowsafe( + left: npt.NDArray[np.int64], + right: npt.NDArray[np.int64], +) -> npt.NDArray[np.int64]: ... +def get_supported_dtype(dtype: np.dtype) -> np.dtype: ... +def is_supported_dtype(dtype: np.dtype) -> bool: ... diff --git a/pandas/_libs/tslibs/offsets.pyi b/pandas/_libs/tslibs/offsets.pyi new file mode 100644 index 0000000000000000000000000000000000000000..eaee7f54b6ea8d8dd7ecedf45edeb2868723ef38 --- /dev/null +++ b/pandas/_libs/tslibs/offsets.pyi @@ -0,0 +1,308 @@ +from collections.abc import Collection +from datetime import ( + datetime, + time, + timedelta, +) +from typing import ( + Any, + Literal, + Self, + TypeVar, + overload, +) + +import numpy as np + +from pandas._libs.tslibs.nattype import NaTType +from pandas._typing import ( + OffsetCalendar, + npt, +) + +_BaseOffsetT = TypeVar("_BaseOffsetT", bound=BaseOffset) +_DatetimeT = TypeVar("_DatetimeT", bound=datetime) +_TimedeltaT = TypeVar("_TimedeltaT", bound=timedelta) + +_relativedelta_kwds: set[str] +prefix_mapping: dict[str, type] + +class ApplyTypeError(TypeError): ... + +class BaseOffset: + n: int + normalize: bool + def __init__(self, n: int = ..., normalize: bool = ...) -> None: ... + def __eq__(self, other) -> bool: ... + def __ne__(self, other) -> bool: ... + def __hash__(self) -> int: ... + @property + def kwds(self) -> dict: ... + @property + def base(self) -> BaseOffset: ... + @overload + def __add__(self, other: npt.NDArray[np.object_]) -> npt.NDArray[np.object_]: ... + @overload + def __add__(self, other: BaseOffset) -> Self: ... + @overload + def __add__(self, other: _DatetimeT) -> _DatetimeT: ... + @overload + def __add__(self, other: _TimedeltaT) -> _TimedeltaT: ... + @overload + def __radd__(self, other: npt.NDArray[np.object_]) -> npt.NDArray[np.object_]: ... + @overload + def __radd__(self, other: BaseOffset) -> Self: ... + @overload + def __radd__(self, other: _DatetimeT) -> _DatetimeT: ... + @overload + def __radd__(self, other: _TimedeltaT) -> _TimedeltaT: ... + @overload + def __radd__(self, other: NaTType) -> NaTType: ... + def __sub__(self, other: BaseOffset) -> Self: ... + @overload + def __rsub__(self, other: npt.NDArray[np.object_]) -> npt.NDArray[np.object_]: ... + @overload + def __rsub__(self, other: BaseOffset) -> Self: ... + @overload + def __rsub__(self, other: _DatetimeT) -> _DatetimeT: ... + @overload + def __rsub__(self, other: _TimedeltaT) -> _TimedeltaT: ... + @overload + def __mul__(self, other: np.ndarray) -> np.ndarray: ... + @overload + def __mul__(self, other: int) -> Self: ... + @overload + def __rmul__(self, other: np.ndarray) -> np.ndarray: ... + @overload + def __rmul__(self, other: int) -> Self: ... + def __neg__(self) -> Self: ... + def copy(self) -> Self: ... + @property + def name(self) -> str: ... + @property + def rule_code(self) -> str: ... + @property + def freqstr(self) -> str: ... + def _apply(self, other): ... + def _apply_array(self, dtarr: np.ndarray) -> np.ndarray: ... + def rollback(self, dt: datetime) -> datetime: ... + def rollforward(self, dt: datetime) -> datetime: ... + def is_on_offset(self, dt: datetime) -> bool: ... + def __setstate__(self, state) -> None: ... + def __getstate__(self): ... + @property + def nanos(self) -> int: ... + +def _get_offset(name: str) -> BaseOffset: ... + +class SingleConstructorOffset(BaseOffset): + @classmethod + def _from_name(cls, suffix: None = ...) -> Self: ... + def __reduce__(self): ... + +@overload +def to_offset(freq: None, is_period: bool = ...) -> None: ... +@overload +def to_offset(freq: _BaseOffsetT, is_period: bool = ...) -> _BaseOffsetT: ... +@overload +def to_offset(freq: timedelta | str, is_period: bool = ...) -> BaseOffset: ... + +class Tick(SingleConstructorOffset): + _creso: int + _prefix: str + def __init__(self, n: int = ..., normalize: bool = ...) -> None: ... + @property + def nanos(self) -> int: ... + +def delta_to_tick(delta: timedelta) -> Tick: ... + +class Day(BaseOffset): ... +class Hour(Tick): ... +class Minute(Tick): ... +class Second(Tick): ... +class Milli(Tick): ... +class Micro(Tick): ... +class Nano(Tick): ... + +class RelativeDeltaOffset(BaseOffset): + def __init__(self, n: int = ..., normalize: bool = ..., **kwds: Any) -> None: ... + +class BusinessMixin(SingleConstructorOffset): + def __init__( + self, n: int = ..., normalize: bool = ..., offset: timedelta = ... + ) -> None: ... + +class BusinessDay(BusinessMixin): ... + +class BusinessHour(BusinessMixin): + def __init__( + self, + n: int = ..., + normalize: bool = ..., + start: str | time | Collection[str | time] = ..., + end: str | time | Collection[str | time] = ..., + offset: timedelta = ..., + ) -> None: ... + +class WeekOfMonthMixin(SingleConstructorOffset): + def __init__( + self, n: int = ..., normalize: bool = ..., weekday: int = ... + ) -> None: ... + +class YearOffset(SingleConstructorOffset): + def __init__( + self, n: int = ..., normalize: bool = ..., month: int | None = ... + ) -> None: ... + @property + def month(self) -> int: ... + +class BYearEnd(YearOffset): ... +class BYearBegin(YearOffset): ... + +class YearEnd(YearOffset): + def __new__( + cls, n: int = ..., normalize: bool = ..., month: int | None = ... + ) -> Self: ... + +class YearBegin(YearOffset): ... + +class QuarterOffset(SingleConstructorOffset): + def __init__( + self, n: int = ..., normalize: bool = ..., startingMonth: int | None = ... + ) -> None: ... + +class BQuarterEnd(QuarterOffset): ... +class BQuarterBegin(QuarterOffset): ... +class QuarterEnd(QuarterOffset): ... +class QuarterBegin(QuarterOffset): ... + +class HalfYearOffset(SingleConstructorOffset): + def __init__( + self, n: int = ..., normalize: bool = ..., startingMonth: int | None = ... + ) -> None: ... + +class BHalfYearEnd(HalfYearOffset): ... +class BHalfYearBegin(HalfYearOffset): ... +class HalfYearEnd(HalfYearOffset): ... +class HalfYearBegin(HalfYearOffset): ... +class MonthOffset(SingleConstructorOffset): ... +class MonthEnd(MonthOffset): ... +class MonthBegin(MonthOffset): ... +class BusinessMonthEnd(MonthOffset): ... +class BusinessMonthBegin(MonthOffset): ... + +class SemiMonthOffset(SingleConstructorOffset): + def __init__( + self, n: int = ..., normalize: bool = ..., day_of_month: int | None = ... + ) -> None: ... + +class SemiMonthEnd(SemiMonthOffset): ... +class SemiMonthBegin(SemiMonthOffset): ... + +class Week(SingleConstructorOffset): + def __init__( + self, n: int = ..., normalize: bool = ..., weekday: int | None = ... + ) -> None: ... + +class WeekOfMonth(WeekOfMonthMixin): + def __init__( + self, n: int = ..., normalize: bool = ..., week: int = ..., weekday: int = ... + ) -> None: ... + +class LastWeekOfMonth(WeekOfMonthMixin): + def __init__( + self, n: int = ..., normalize: bool = ..., weekday: int = ... + ) -> None: ... + +class FY5253Mixin(SingleConstructorOffset): + def __init__( + self, + n: int = ..., + normalize: bool = ..., + weekday: int = ..., + startingMonth: int = ..., + variation: Literal["nearest", "last"] = ..., + ) -> None: ... + +class FY5253(FY5253Mixin): ... + +class FY5253Quarter(FY5253Mixin): + def __init__( + self, + n: int = ..., + normalize: bool = ..., + weekday: int = ..., + startingMonth: int = ..., + qtr_with_extra_week: int = ..., + variation: Literal["nearest", "last"] = ..., + ) -> None: ... + +class Easter(SingleConstructorOffset): + def __init__( + self, + n: int = ..., + normalize: bool = ..., + method: int = ..., + ) -> None: ... + +class _CustomBusinessMonth(BusinessMixin): + def __init__( + self, + n: int = ..., + normalize: bool = ..., + weekmask: str = ..., + holidays: list | None = ..., + calendar: OffsetCalendar | None = ..., + offset: timedelta = ..., + ) -> None: ... + +class CustomBusinessDay(BusinessDay): + def __init__( + self, + n: int = ..., + normalize: bool = ..., + weekmask: str = ..., + holidays: list | None = ..., + calendar: OffsetCalendar | None = ..., + offset: timedelta = ..., + ) -> None: ... + +class CustomBusinessHour(BusinessHour): + def __init__( + self, + n: int = ..., + normalize: bool = ..., + weekmask: str = ..., + holidays: list | None = ..., + calendar: OffsetCalendar | None = ..., + start: str | time | Collection[str | time] = ..., + end: str | time | Collection[str | time] = ..., + offset: timedelta = ..., + ) -> None: ... + +class CustomBusinessMonthEnd(_CustomBusinessMonth): ... +class CustomBusinessMonthBegin(_CustomBusinessMonth): ... +class OffsetMeta(type): ... +class DateOffset(RelativeDeltaOffset, metaclass=OffsetMeta): ... + +BDay = BusinessDay +BMonthEnd = BusinessMonthEnd +BMonthBegin = BusinessMonthBegin +CBMonthEnd = CustomBusinessMonthEnd +CBMonthBegin = CustomBusinessMonthBegin +CDay = CustomBusinessDay + +def roll_qtrday( + other: datetime, n: int, month: int, day_opt: str, modby: int +) -> int: ... + +INVALID_FREQ_ERR_MSG: Literal["Invalid frequency: {0}"] + +def shift_months( + dtindex: npt.NDArray[np.int64], + months: int, + day_opt: str | None = ..., + reso: int = ..., +) -> npt.NDArray[np.int64]: ... + +_offset_map: dict[str, BaseOffset] diff --git a/pandas/_libs/tslibs/parsing.pyi b/pandas/_libs/tslibs/parsing.pyi new file mode 100644 index 0000000000000000000000000000000000000000..845bd9a5a5635fcf87975dd0d2d5b383c6d4ae58 --- /dev/null +++ b/pandas/_libs/tslibs/parsing.pyi @@ -0,0 +1,30 @@ +from datetime import datetime + +import numpy as np + +from pandas._typing import npt + +class DateParseError(ValueError): ... + +def py_parse_datetime_string( + date_string: str, + dayfirst: bool = ..., + yearfirst: bool = ..., +) -> datetime: ... +def parse_datetime_string_with_reso( + date_string: str, + freq: str | None = ..., + dayfirst: bool | None = ..., + yearfirst: bool | None = ..., +) -> tuple[datetime, str]: ... +def _does_string_look_like_datetime(py_string: str) -> bool: ... +def quarter_to_myear(year: int, quarter: int, freq: str) -> tuple[int, int]: ... +def try_parse_dates( + values: npt.NDArray[np.object_], # object[:] + parser, +) -> npt.NDArray[np.object_]: ... +def guess_datetime_format( + dt_str: str, + dayfirst: bool | None = ..., +) -> str | None: ... +def get_rule_month(source: str) -> str: ... diff --git a/pandas/_libs/tslibs/period.pyi b/pandas/_libs/tslibs/period.pyi new file mode 100644 index 0000000000000000000000000000000000000000..5cb9f891b312a566545b51690f99343867547b89 --- /dev/null +++ b/pandas/_libs/tslibs/period.pyi @@ -0,0 +1,135 @@ +from datetime import timedelta +from typing import Literal + +import numpy as np + +from pandas._libs.tslibs.dtypes import PeriodDtypeBase +from pandas._libs.tslibs.nattype import NaTType +from pandas._libs.tslibs.offsets import BaseOffset +from pandas._libs.tslibs.timestamps import Timestamp +from pandas._typing import ( + Frequency, + npt, +) + +INVALID_FREQ_ERR_MSG: str +DIFFERENT_FREQ: str + +class IncompatibleFrequency(TypeError): ... + +def periodarr_to_dt64arr( + periodarr: npt.NDArray[np.int64], # const int64_t[:] + freq: int, +) -> npt.NDArray[np.int64]: ... +def period_asfreq_arr( + arr: npt.NDArray[np.int64], + freq1: int, + freq2: int, + end: bool, +) -> npt.NDArray[np.int64]: ... +def get_period_field_arr( + field: str, + arr: npt.NDArray[np.int64], # const int64_t[:] + freq: int, +) -> npt.NDArray[np.int64]: ... +def from_ordinals( + values: npt.NDArray[np.int64], # const int64_t[:] + freq: timedelta | BaseOffset | str, +) -> npt.NDArray[np.int64]: ... +def extract_ordinals( + values: npt.NDArray[np.object_], + freq: Frequency | int, +) -> npt.NDArray[np.int64]: ... +def extract_freq( + values: npt.NDArray[np.object_], +) -> BaseOffset: ... +def period_array_strftime( + values: npt.NDArray[np.int64], + dtype_code: int, + na_rep, + date_format: str | None, +) -> npt.NDArray[np.object_]: ... + +# exposed for tests +def period_asfreq(ordinal: int, freq1: int, freq2: int, end: bool) -> int: ... +def period_ordinal( + y: int, m: int, d: int, h: int, min: int, s: int, us: int, ps: int, freq: int +) -> int: ... +def freq_to_dtype_code(freq: BaseOffset) -> int: ... +def validate_end_alias(how: str) -> Literal["E", "S"]: ... + +class PeriodMixin: + @property + def end_time(self) -> Timestamp: ... + @property + def start_time(self) -> Timestamp: ... + def _require_matching_freq(self, other: BaseOffset, base: bool = ...) -> None: ... + +class Period(PeriodMixin): + ordinal: int # int64_t + freq: BaseOffset + _dtype: PeriodDtypeBase + + # error: "__new__" must return a class instance (got "Union[Period, NaTType]") + def __new__( # type: ignore[misc] + cls, + value=..., + freq: int | str | BaseOffset | None = ..., + ordinal: int | None = ..., + year: int | None = ..., + month: int | None = ..., + quarter: int | None = ..., + day: int | None = ..., + hour: int | None = ..., + minute: int | None = ..., + second: int | None = ..., + ) -> Period | NaTType: ... + @classmethod + def _maybe_convert_freq(cls, freq) -> BaseOffset: ... + @classmethod + def _from_ordinal(cls, ordinal: int, freq: BaseOffset) -> Period: ... + @classmethod + def now(cls, freq: Frequency) -> Period: ... + def strftime(self, fmt: str | None) -> str: ... + def to_timestamp( + self, + freq: str | BaseOffset | None = ..., + how: str = ..., + ) -> Timestamp: ... + def asfreq(self, freq: str | BaseOffset, how: str = ...) -> Period: ... + @property + def freqstr(self) -> str: ... + @property + def is_leap_year(self) -> bool: ... + @property + def daysinmonth(self) -> int: ... + @property + def days_in_month(self) -> int: ... + @property + def qyear(self) -> int: ... + @property + def quarter(self) -> int: ... + @property + def day_of_year(self) -> int: ... + @property + def weekday(self) -> int: ... + @property + def day_of_week(self) -> int: ... + @property + def week(self) -> int: ... + @property + def weekofyear(self) -> int: ... + @property + def second(self) -> int: ... + @property + def minute(self) -> int: ... + @property + def hour(self) -> int: ... + @property + def day(self) -> int: ... + @property + def month(self) -> int: ... + @property + def year(self) -> int: ... + def __sub__(self, other) -> Period | BaseOffset: ... + def __add__(self, other) -> Period: ... diff --git a/pandas/_libs/tslibs/strptime.pyi b/pandas/_libs/tslibs/strptime.pyi new file mode 100644 index 0000000000000000000000000000000000000000..0ec1a1e25a2b3cfe974baebfe32d686435f73e11 --- /dev/null +++ b/pandas/_libs/tslibs/strptime.pyi @@ -0,0 +1,14 @@ +import numpy as np + +from pandas._typing import npt + +def array_strptime( + values: npt.NDArray[np.object_], + fmt: str | None, + exact: bool = ..., + errors: str = ..., + utc: bool = ..., + creso: int = ..., # NPY_DATETIMEUNIT +) -> tuple[np.ndarray, np.ndarray]: ... + +# first ndarray is M8[ns], second is object ndarray of tzinfo | None diff --git a/pandas/_libs/tslibs/timedeltas.pyi b/pandas/_libs/tslibs/timedeltas.pyi new file mode 100644 index 0000000000000000000000000000000000000000..a04387eb09d6b94d0dee888dba14141b1ea128f2 --- /dev/null +++ b/pandas/_libs/tslibs/timedeltas.pyi @@ -0,0 +1,168 @@ +from datetime import timedelta +from typing import ( + ClassVar, + Literal, + Self, + TypeAlias, + overload, +) + +import numpy as np + +from pandas._libs.tslibs import ( + NaTType, + Tick, +) +from pandas._typing import ( + Frequency, + TimeUnit, + npt, +) + +# This should be kept consistent with the keys in the dict timedelta_abbrevs +# in pandas/_libs/tslibs/timedeltas.pyx +UnitChoices: TypeAlias = Literal[ + "Y", + "y", + "M", + "W", + "w", + "D", + "d", + "days", + "day", + "hours", + "hour", + "hr", + "h", + "m", + "minute", + "min", + "minutes", + "s", + "seconds", + "sec", + "second", + "ms", + "milliseconds", + "millisecond", + "milli", + "millis", + "us", + "microseconds", + "microsecond", + "µs", + "micro", + "micros", + "ns", + "nanoseconds", + "nano", + "nanos", + "nanosecond", +] + +def get_unit_for_round(freq, creso: int) -> int: ... +def disallow_ambiguous_unit(unit: str | None) -> None: ... +def ints_to_pytimedelta( + m8values: npt.NDArray[np.timedelta64], + box: bool = ..., +) -> npt.NDArray[np.object_]: ... +def array_to_timedelta64( + values: npt.NDArray[np.object_], + unit: str | None = ..., + errors: str = ..., + creso: int = ..., +) -> np.ndarray: ... # np.ndarray[m8ns] +def parse_timedelta_unit(unit: str | None) -> UnitChoices: ... +def delta_to_nanoseconds( + delta: np.timedelta64 | timedelta | Tick, + reso: int = ..., # NPY_DATETIMEUNIT + round_ok: bool = ..., +) -> int: ... +def floordiv_object_array( + left: np.ndarray, right: npt.NDArray[np.object_] +) -> np.ndarray: ... +def truediv_object_array( + left: np.ndarray, right: npt.NDArray[np.object_] +) -> np.ndarray: ... + +class Timedelta(timedelta): + _creso: int + min: ClassVar[Timedelta] + max: ClassVar[Timedelta] + resolution: ClassVar[Timedelta] + value: int # np.int64 + _value: int # np.int64 + # error: "__new__" must return a class instance (got "Union[Timestamp, NaTType]") + def __new__( # type: ignore[misc] + cls: type[Self], + value=..., + unit: str | None = ..., + **kwargs: float | np.integer | np.floating, + ) -> Self | NaTType: ... + @classmethod + def _from_value_and_reso(cls, value: np.int64, reso: int) -> Timedelta: ... + @property + def days(self) -> int: ... + @property + def seconds(self) -> int: ... + @property + def microseconds(self) -> int: ... + def total_seconds(self) -> float: ... + def to_pytimedelta(self) -> timedelta: ... + def to_timedelta64(self) -> np.timedelta64: ... + @property + def asm8(self) -> np.timedelta64: ... + # TODO: round/floor/ceil could return NaT? + def round(self, freq: Frequency) -> Self: ... + def floor(self, freq: Frequency) -> Self: ... + def ceil(self, freq: Frequency) -> Self: ... + @property + def resolution_string(self) -> str: ... + def __add__(self, other: timedelta) -> Timedelta: ... + def __radd__(self, other: timedelta) -> Timedelta: ... + def __sub__(self, other: timedelta) -> Timedelta: ... + def __rsub__(self, other: timedelta) -> Timedelta: ... + def __neg__(self) -> Timedelta: ... + def __pos__(self) -> Timedelta: ... + def __abs__(self) -> Timedelta: ... + def __mul__(self, other: float) -> Timedelta: ... + def __rmul__(self, other: float) -> Timedelta: ... + # error: Signature of "__floordiv__" incompatible with supertype "timedelta" + @overload # type: ignore[override] + def __floordiv__(self, other: timedelta) -> int: ... + @overload + def __floordiv__(self, other: float) -> Timedelta: ... + @overload + def __floordiv__( + self, other: npt.NDArray[np.timedelta64] + ) -> npt.NDArray[np.intp]: ... + @overload + def __floordiv__( + self, other: npt.NDArray[np.number] + ) -> npt.NDArray[np.timedelta64] | Timedelta: ... + @overload + def __rfloordiv__(self, other: timedelta | str) -> int: ... + @overload + def __rfloordiv__(self, other: None | NaTType) -> NaTType: ... + @overload + def __rfloordiv__(self, other: np.ndarray) -> npt.NDArray[np.timedelta64]: ... + @overload + def __truediv__(self, other: timedelta) -> float: ... + @overload + def __truediv__(self, other: float) -> Timedelta: ... + def __mod__(self, other: timedelta) -> Timedelta: ... + def __divmod__(self, other: timedelta) -> tuple[int, Timedelta]: ... + def __le__(self, other: timedelta) -> bool: ... + def __lt__(self, other: timedelta) -> bool: ... + def __ge__(self, other: timedelta) -> bool: ... + def __gt__(self, other: timedelta) -> bool: ... + def __hash__(self) -> int: ... + def isoformat(self) -> str: ... + def to_numpy( + self, dtype: npt.DTypeLike = ..., copy: bool = False + ) -> np.timedelta64: ... + def view(self, dtype: npt.DTypeLike) -> object: ... + @property + def unit(self) -> TimeUnit: ... + def as_unit(self, unit: TimeUnit, round_ok: bool = ...) -> Timedelta: ... diff --git a/pandas/_libs/tslibs/timestamps.pyi b/pandas/_libs/tslibs/timestamps.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d06c78b22626a325f01f4ab3466271256249d415 --- /dev/null +++ b/pandas/_libs/tslibs/timestamps.pyi @@ -0,0 +1,242 @@ +from datetime import ( + date as _date, + datetime, + time as _time, + timedelta, + tzinfo as _tzinfo, +) +from time import struct_time +from typing import ( + ClassVar, + Literal, + Self, + TypeAlias, + overload, +) + +import numpy as np + +from pandas._libs.tslibs import ( + BaseOffset, + NaTType, + Period, + Tick, + Timedelta, +) +from pandas._typing import ( + TimestampNonexistent, + TimeUnit, +) + +_TimeZones: TypeAlias = str | _tzinfo | None | int + +def integer_op_not_supported(obj: object) -> TypeError: ... + +class Timestamp(datetime): + _creso: int + min: ClassVar[Timestamp] + max: ClassVar[Timestamp] + + resolution: ClassVar[Timedelta] + _value: int # np.int64 + # error: "__new__" must return a class instance (got "Union[Timestamp, NaTType]") + def __new__( # type: ignore[misc] + cls: type[Self], + ts_input: np.integer | float | str | _date | datetime | np.datetime64 = ..., + year: int | None = ..., + month: int | None = ..., + day: int | None = ..., + hour: int | None = ..., + minute: int | None = ..., + second: int | None = ..., + microsecond: int | None = ..., + tzinfo: _tzinfo | None = ..., + *, + nanosecond: int | None = ..., + tz: _TimeZones = ..., + unit: str | int | None = ..., + fold: int | None = ..., + ) -> Self | NaTType: ... + @classmethod + def _from_value_and_reso( + cls, value: int, reso: int, tz: _TimeZones + ) -> Timestamp: ... + @property + def value(self) -> int: ... # np.int64 + @property + def year(self) -> int: ... + @property + def month(self) -> int: ... + @property + def day(self) -> int: ... + @property + def hour(self) -> int: ... + @property + def minute(self) -> int: ... + @property + def second(self) -> int: ... + @property + def microsecond(self) -> int: ... + @property + def nanosecond(self) -> int: ... + @property + def tzinfo(self) -> _tzinfo | None: ... + @property + def tz(self) -> _tzinfo | None: ... + @property + def fold(self) -> int: ... + @classmethod + def fromtimestamp(cls, ts: float, tz: _TimeZones = ...) -> Self: ... + @classmethod + def utcfromtimestamp(cls, ts: float) -> Self: ... + @classmethod + def today(cls, tz: _TimeZones = ...) -> Self: ... + @classmethod + def fromordinal( + cls, + ordinal: int, + tz: _TimeZones = ..., + ) -> Self: ... + @classmethod + def now(cls, tz: _TimeZones = ...) -> Self: ... + @classmethod + def utcnow(cls) -> Self: ... + # error: Signature of "combine" incompatible with supertype "datetime" + @classmethod + def combine( # type: ignore[override] + cls, date: _date, time: _time + ) -> datetime: ... + @classmethod + def fromisoformat(cls, date_string: str) -> Self: ... + def strftime(self, format: str) -> str: ... + def __format__(self, fmt: str) -> str: ... + def toordinal(self) -> int: ... + def timetuple(self) -> struct_time: ... + def timestamp(self) -> float: ... + def utctimetuple(self) -> struct_time: ... + def date(self) -> _date: ... + def time(self) -> _time: ... + def timetz(self) -> _time: ... + # LSP violation: nanosecond is not present in datetime.datetime.replace + # and has positional args following it + def replace( # type: ignore[override] + self, + year: int | None = ..., + month: int | None = ..., + day: int | None = ..., + hour: int | None = ..., + minute: int | None = ..., + second: int | None = ..., + microsecond: int | None = ..., + nanosecond: int | None = ..., + tzinfo: _tzinfo | type[object] | None = ..., + fold: int | None = ..., + ) -> Self: ... + # LSP violation: datetime.datetime.astimezone has a default value for tz + def astimezone(self, tz: _TimeZones) -> Self: ... # type: ignore[override] + def ctime(self) -> str: ... + def isoformat(self, sep: str = ..., timespec: str = ...) -> str: ... + @classmethod + def strptime( + # Note: strptime is actually disabled and raises NotImplementedError + cls, + date_string: str, + format: str, + ) -> Self: ... + def utcoffset(self) -> timedelta | None: ... + def tzname(self) -> str | None: ... + def dst(self) -> timedelta | None: ... + def __le__(self, other: datetime) -> bool: ... # type: ignore[override] + def __lt__(self, other: datetime) -> bool: ... # type: ignore[override] + def __ge__(self, other: datetime) -> bool: ... # type: ignore[override] + def __gt__(self, other: datetime) -> bool: ... # type: ignore[override] + # error: Signature of "__add__" incompatible with supertype "date"/"datetime" + @overload # type: ignore[override] + def __add__(self, other: np.ndarray) -> np.ndarray: ... + @overload + def __add__(self, other: timedelta | np.timedelta64 | Tick) -> Self: ... + def __radd__(self, other: timedelta) -> Self: ... + @overload # type: ignore[override] + def __sub__(self, other: datetime) -> Timedelta: ... + @overload + def __sub__(self, other: timedelta | np.timedelta64 | Tick) -> Self: ... + def __hash__(self) -> int: ... + def weekday(self) -> int: ... + def isoweekday(self) -> int: ... + # Return type "Tuple[int, int, int]" of "isocalendar" incompatible with return + # type "_IsoCalendarDate" in supertype "date" + def isocalendar(self) -> tuple[int, int, int]: ... # type: ignore[override] + @property + def is_leap_year(self) -> bool: ... + @property + def is_month_start(self) -> bool: ... + @property + def is_quarter_start(self) -> bool: ... + @property + def is_year_start(self) -> bool: ... + @property + def is_month_end(self) -> bool: ... + @property + def is_quarter_end(self) -> bool: ... + @property + def is_year_end(self) -> bool: ... + def to_pydatetime(self, warn: bool = ...) -> datetime: ... + def to_datetime64(self) -> np.datetime64: ... + def to_period(self, freq: BaseOffset | str | None = None) -> Period: ... + def to_julian_date(self) -> np.float64: ... + @property + def asm8(self) -> np.datetime64: ... + def tz_convert(self, tz: _TimeZones) -> Self: ... + # TODO: could return NaT? + def tz_localize( + self, + tz: _TimeZones, + ambiguous: bool | Literal["raise", "NaT"] = ..., + nonexistent: TimestampNonexistent = ..., + ) -> Self: ... + def normalize(self) -> Self: ... + # TODO: round/floor/ceil could return NaT? + def round( + self, + freq: str, + ambiguous: bool | Literal["raise", "NaT"] = ..., + nonexistent: TimestampNonexistent = ..., + ) -> Self: ... + def floor( + self, + freq: str, + ambiguous: bool | Literal["raise", "NaT"] = ..., + nonexistent: TimestampNonexistent = ..., + ) -> Self: ... + def ceil( + self, + freq: str, + ambiguous: bool | Literal["raise", "NaT"] = ..., + nonexistent: TimestampNonexistent = ..., + ) -> Self: ... + def day_name(self, locale: str | None = ...) -> str: ... + def month_name(self, locale: str | None = ...) -> str: ... + @property + def day_of_week(self) -> int: ... + @property + def dayofweek(self) -> int: ... + @property + def day_of_year(self) -> int: ... + @property + def dayofyear(self) -> int: ... + @property + def quarter(self) -> int: ... + @property + def week(self) -> int: ... + def to_numpy( + self, dtype: np.dtype | None = ..., copy: bool = ... + ) -> np.datetime64: ... + @property + def _date_repr(self) -> str: ... + @property + def days_in_month(self) -> int: ... + @property + def daysinmonth(self) -> int: ... + @property + def unit(self) -> TimeUnit: ... + def as_unit(self, unit: TimeUnit, round_ok: bool = ...) -> Timestamp: ... diff --git a/pandas/_libs/tslibs/timezones.pyi b/pandas/_libs/tslibs/timezones.pyi new file mode 100644 index 0000000000000000000000000000000000000000..26ffa568a848001b4b80bf0525d60527682a9be4 --- /dev/null +++ b/pandas/_libs/tslibs/timezones.pyi @@ -0,0 +1,21 @@ +from collections.abc import Callable +from datetime import ( + datetime, + tzinfo, +) + +import numpy as np + +# imported from dateutil.tz +dateutil_gettz: Callable[[str], tzinfo] + +def tz_standardize(tz: tzinfo) -> tzinfo: ... +def tz_compare(start: tzinfo | None, end: tzinfo | None) -> bool: ... +def infer_tzinfo( + start: datetime | None, + end: datetime | None, +) -> tzinfo | None: ... +def maybe_get_tz(tz: str | int | np.int64 | tzinfo | None) -> tzinfo | None: ... +def get_timezone(tz: tzinfo) -> tzinfo | str: ... +def is_utc(tz: tzinfo | None) -> bool: ... +def is_fixed_offset(tz: tzinfo) -> bool: ... diff --git a/pandas/_libs/tslibs/tzconversion.pyi b/pandas/_libs/tslibs/tzconversion.pyi new file mode 100644 index 0000000000000000000000000000000000000000..07ee46858577aeb3a373b30d292d0803c36c3d01 --- /dev/null +++ b/pandas/_libs/tslibs/tzconversion.pyi @@ -0,0 +1,21 @@ +from collections.abc import Iterable +from datetime import ( + timedelta, + tzinfo, +) + +import numpy as np + +from pandas._typing import npt + +# tz_convert_from_utc_single exposed for testing +def tz_convert_from_utc_single( + utc_val: np.int64, tz: tzinfo, creso: int = ... +) -> np.int64: ... +def tz_localize_to_utc( + vals: npt.NDArray[np.int64], + tz: tzinfo | None, + ambiguous: str | bool | Iterable[bool] | None = ..., + nonexistent: str | timedelta | np.timedelta64 | None = ..., + creso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.int64]: ... diff --git a/pandas/_libs/tslibs/vectorized.pyi b/pandas/_libs/tslibs/vectorized.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f377c2e26ab81e4cf767f6614a430a602fcfd8d9 --- /dev/null +++ b/pandas/_libs/tslibs/vectorized.pyi @@ -0,0 +1,41 @@ +# For cython types that cannot be represented precisely, closest-available +# python equivalents are used, and the precise types kept as adjacent comments. +from datetime import tzinfo + +import numpy as np + +from pandas._libs.tslibs.dtypes import Resolution +from pandas._typing import npt + +def dt64arr_to_periodarr( + stamps: npt.NDArray[np.int64], + freq: int, + tz: tzinfo | None, + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.int64]: ... +def is_date_array_normalized( + stamps: npt.NDArray[np.int64], + tz: tzinfo | None, + reso: int, # NPY_DATETIMEUNIT +) -> bool: ... +def normalize_i8_timestamps( + stamps: npt.NDArray[np.int64], + tz: tzinfo | None, + reso: int, # NPY_DATETIMEUNIT +) -> npt.NDArray[np.int64]: ... +def get_resolution( + stamps: npt.NDArray[np.int64], + tz: tzinfo | None = ..., + reso: int = ..., # NPY_DATETIMEUNIT +) -> Resolution: ... +def ints_to_pydatetime( + stamps: npt.NDArray[np.int64], + tz: tzinfo | None = ..., + box: str = ..., + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.object_]: ... +def tz_convert_from_utc( + stamps: npt.NDArray[np.int64], + tz: tzinfo | None, + reso: int = ..., # NPY_DATETIMEUNIT +) -> npt.NDArray[np.int64]: ... diff --git a/pandas/_libs/window/__init__.py b/pandas/_libs/window/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/_libs/window/aggregations.pyi b/pandas/_libs/window/aggregations.pyi new file mode 100644 index 0000000000000000000000000000000000000000..99413751cd5c2f88466556d05b12939fd8adb148 --- /dev/null +++ b/pandas/_libs/window/aggregations.pyi @@ -0,0 +1,145 @@ +from collections.abc import Callable +from typing import ( + Any, + Literal, +) + +import numpy as np + +from pandas._typing import ( + WindowingRankType, + npt, +) + +def roll_sum( + values: np.ndarray, # const float64_t[:] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_mean( + values: np.ndarray, # const float64_t[:] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_var( + values: np.ndarray, # const float64_t[:] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t + ddof: int = ..., +) -> np.ndarray: ... # np.ndarray[float] +def roll_skew( + values: np.ndarray, # np.ndarray[np.float64] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_kurt( + values: np.ndarray, # np.ndarray[np.float64] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_median_c( + values: np.ndarray, # np.ndarray[np.float64] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_max( + values: np.ndarray, # np.ndarray[np.float64] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_min( + values: np.ndarray, # np.ndarray[np.float64] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_first( + values: np.ndarray, # np.ndarray[np.float64] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_last( + values: np.ndarray, # np.ndarray[np.float64] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_quantile( + values: np.ndarray, # const float64_t[:] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t + quantile: float, # float64_t + interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"], +) -> np.ndarray: ... # np.ndarray[float] +def roll_rank( + values: np.ndarray, + start: np.ndarray, + end: np.ndarray, + minp: int, + percentile: bool, + method: WindowingRankType, + ascending: bool, +) -> np.ndarray: ... # np.ndarray[float] +def roll_nunique( + values: np.ndarray, # const float64_t[:] + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t +) -> np.ndarray: ... # np.ndarray[float] +def roll_apply( + obj: object, + start: np.ndarray, # np.ndarray[np.int64] + end: np.ndarray, # np.ndarray[np.int64] + minp: int, # int64_t + function: Callable[..., Any], + raw: bool, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> npt.NDArray[np.float64]: ... +def roll_weighted_sum( + values: np.ndarray, # const float64_t[:] + weights: np.ndarray, # const float64_t[:] + minp: int, +) -> np.ndarray: ... # np.ndarray[np.float64] +def roll_weighted_mean( + values: np.ndarray, # const float64_t[:] + weights: np.ndarray, # const float64_t[:] + minp: int, +) -> np.ndarray: ... # np.ndarray[np.float64] +def roll_weighted_var( + values: np.ndarray, # const float64_t[:] + weights: np.ndarray, # const float64_t[:] + minp: int, # int64_t + ddof: int, # unsigned int +) -> np.ndarray: ... # np.ndarray[np.float64] +def ewm( + vals: np.ndarray, # const float64_t[:] + start: np.ndarray, # const int64_t[:] + end: np.ndarray, # const int64_t[:] + minp: int, + com: float, # float64_t + adjust: bool, + ignore_na: bool, + deltas: np.ndarray | None = None, # const float64_t[:] + normalize: bool = True, +) -> np.ndarray: ... # np.ndarray[np.float64] +def ewmcov( + input_x: np.ndarray, # const float64_t[:] + start: np.ndarray, # const int64_t[:] + end: np.ndarray, # const int64_t[:] + minp: int, + input_y: np.ndarray, # const float64_t[:] + com: float, # float64_t + adjust: bool, + ignore_na: bool, + bias: bool, +) -> np.ndarray: ... # np.ndarray[np.float64] diff --git a/pandas/_libs/window/indexers.pyi b/pandas/_libs/window/indexers.pyi new file mode 100644 index 0000000000000000000000000000000000000000..c9bc64be34ac9a41d14fef33b0fc76bdf66527e9 --- /dev/null +++ b/pandas/_libs/window/indexers.pyi @@ -0,0 +1,12 @@ +import numpy as np + +from pandas._typing import npt + +def calculate_variable_window_bounds( + num_values: int, # int64_t + window_size: int, # int64_t + min_periods, + center: bool, + closed: str | None, + index: np.ndarray, # const int64_t[:] +) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]: ... diff --git a/pandas/_libs/writers.cpython-312-x86_64-linux-gnu.so b/pandas/_libs/writers.cpython-312-x86_64-linux-gnu.so new file mode 100644 index 0000000000000000000000000000000000000000..d4abc5f481e8c7c7f33af82423db15f9b1de8ee2 Binary files /dev/null and b/pandas/_libs/writers.cpython-312-x86_64-linux-gnu.so differ diff --git a/pandas/_libs/writers.pyi b/pandas/_libs/writers.pyi new file mode 100644 index 0000000000000000000000000000000000000000..7b41856525dadf79a2bf4b29c7ddebfedaa880db --- /dev/null +++ b/pandas/_libs/writers.pyi @@ -0,0 +1,20 @@ +import numpy as np + +from pandas._typing import ArrayLike + +def write_csv_rows( + data: list[ArrayLike], + data_index: np.ndarray, + nlevels: int, + cols: np.ndarray, + writer: object, # _csv.writer +) -> None: ... +def convert_json_to_lines(arr: str) -> str: ... +def max_len_string_array( + arr: np.ndarray, # pandas_string[:] +) -> int: ... +def word_len(val: object) -> int: ... +def string_array_replace_from_nan_rep( + arr: np.ndarray, # np.ndarray[object, ndim=1] + nan_rep: object, +) -> None: ... diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad323c3347b85ef636a9290ef1f35fd8fb45e51f --- /dev/null +++ b/pandas/_testing/__init__.py @@ -0,0 +1,645 @@ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from decimal import Decimal +import operator +import os +from sys import byteorder +import threading +from typing import ( + TYPE_CHECKING, + ContextManager, +) + +import numpy as np + +from pandas._config import using_string_dtype +from pandas._config.localization import ( + can_set_locale, + get_locales, + set_locale, +) + +from pandas.compat import HAS_PYARROW + +import pandas as pd +from pandas import ( + ArrowDtype, + DataFrame, + Index, + MultiIndex, + RangeIndex, + Series, +) +from pandas._testing._io import ( + round_trip_pathlib, + round_trip_pickle, + write_to_compressed, +) +from pandas._testing._warnings import ( + assert_produces_warning, + maybe_produces_warning, +) +from pandas._testing.asserters import ( + assert_almost_equal, + assert_attr_equal, + assert_categorical_equal, + assert_class_equal, + assert_contains_all, + assert_copy, + assert_datetime_array_equal, + assert_dict_equal, + assert_equal, + assert_extension_array_equal, + assert_frame_equal, + assert_index_equal, + assert_indexing_slices_equivalent, + assert_interval_array_equal, + assert_is_sorted, + assert_metadata_equivalent, + assert_numpy_array_equal, + assert_period_array_equal, + assert_series_equal, + assert_sp_array_equal, + assert_timedelta_array_equal, + raise_assert_detail, +) +from pandas._testing.compat import ( + get_dtype, + get_obj, +) +from pandas._testing.contexts import ( + decompress_file, + raises_chained_assignment_error, + set_timezone, + with_csv_dialect, +) +from pandas.core.arrays import ( + ArrowExtensionArray, + BaseMaskedArray, + NumpyExtensionArray, +) +from pandas.core.arrays._mixins import NDArrayBackedExtensionArray +from pandas.core.construction import extract_array + +if TYPE_CHECKING: + from collections.abc import Callable + + from pandas._typing import ( + Dtype, + NpDtype, + ) + + +UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"] +UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"] +SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"] +SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"] +ALL_INT_NUMPY_DTYPES = UNSIGNED_INT_NUMPY_DTYPES + SIGNED_INT_NUMPY_DTYPES +ALL_INT_EA_DTYPES = UNSIGNED_INT_EA_DTYPES + SIGNED_INT_EA_DTYPES +ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES] + +FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"] +FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"] +ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES] + +COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"] +if using_string_dtype(): + STRING_DTYPES: list[Dtype] = ["U"] +else: + STRING_DTYPES: list[Dtype] = [str, "str", "U"] # type: ignore[no-redef] +COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES] + +DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"] +TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"] + +BOOL_DTYPES: list[Dtype] = [bool, "bool"] +BYTES_DTYPES: list[Dtype] = [bytes, "bytes"] +OBJECT_DTYPES: list[Dtype] = [object, "object"] + +ALL_REAL_NUMPY_DTYPES = FLOAT_NUMPY_DTYPES + ALL_INT_NUMPY_DTYPES +ALL_REAL_EXTENSION_DTYPES = FLOAT_EA_DTYPES + ALL_INT_EA_DTYPES +ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES] +ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES] + +ALL_NUMPY_DTYPES = ( + ALL_REAL_NUMPY_DTYPES + + COMPLEX_DTYPES + + STRING_DTYPES + + DATETIME64_DTYPES + + TIMEDELTA64_DTYPES + + BOOL_DTYPES + + OBJECT_DTYPES + + BYTES_DTYPES +) + +NARROW_NP_DTYPES = [ + np.float16, + np.float32, + np.int8, + np.int16, + np.int32, + np.uint8, + np.uint16, + np.uint32, +] + +PYTHON_DATA_TYPES = [ + str, + int, + float, + complex, + list, + tuple, + range, + dict, + set, + frozenset, + bool, + bytes, + bytearray, + memoryview, +] + +ENDIAN = {"little": "<", "big": ">"}[byteorder] + +NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")] +NP_NAT_OBJECTS = [ + cls("NaT", unit) + for cls in [np.datetime64, np.timedelta64] + for unit in [ + "Y", + "M", + "W", + "D", + "h", + "m", + "s", + "ms", + "us", + "ns", + "ps", + "fs", + "as", + ] +] + +if HAS_PYARROW: + import pyarrow as pa + + UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()] + SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()] + ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES + ALL_INT_PYARROW_DTYPES_STR_REPR = [ + str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES + ] + + # pa.float16 doesn't seem supported + # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86 + FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()] + FLOAT_PYARROW_DTYPES_STR_REPR = [ + str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES + ] + DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)] + STRING_PYARROW_DTYPES = [pa.string()] + BINARY_PYARROW_DTYPES = [pa.binary()] + + TIME_PYARROW_DTYPES = [ + pa.time32("s"), + pa.time32("ms"), + pa.time64("us"), + pa.time64("ns"), + ] + DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()] + DATETIME_PYARROW_DTYPES = [ + pa.timestamp(unit=unit, tz=tz) + for unit in ["s", "ms", "us", "ns"] + for tz in [None, "UTC", "US/Pacific", "US/Eastern"] + ] + TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]] + + BOOL_PYARROW_DTYPES = [pa.bool_()] + + # TODO: Add container like pyarrow types: + # https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions + ALL_PYARROW_DTYPES = ( + ALL_INT_PYARROW_DTYPES + + FLOAT_PYARROW_DTYPES + + DECIMAL_PYARROW_DTYPES + + STRING_PYARROW_DTYPES + + BINARY_PYARROW_DTYPES + + TIME_PYARROW_DTYPES + + DATE_PYARROW_DTYPES + + DATETIME_PYARROW_DTYPES + + TIMEDELTA_PYARROW_DTYPES + + BOOL_PYARROW_DTYPES + ) + ALL_REAL_PYARROW_DTYPES_STR_REPR = ( + ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR + ) +else: + FLOAT_PYARROW_DTYPES_STR_REPR = [] + ALL_INT_PYARROW_DTYPES_STR_REPR = [] + ALL_PYARROW_DTYPES = [] + ALL_REAL_PYARROW_DTYPES_STR_REPR = [] + +ALL_REAL_NULLABLE_DTYPES = ( + FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR +) + +arithmetic_dunder_methods = [ + "__add__", + "__radd__", + "__sub__", + "__rsub__", + "__mul__", + "__rmul__", + "__floordiv__", + "__rfloordiv__", + "__truediv__", + "__rtruediv__", + "__pow__", + "__rpow__", + "__mod__", + "__rmod__", +] + +comparison_dunder_methods = ["__eq__", "__ne__", "__le__", "__lt__", "__ge__", "__gt__"] + + +# ----------------------------------------------------------------------------- +# Comparators + + +def box_expected(expected, box_cls, transpose: bool = True): + """ + Helper function to wrap the expected output of a test in a given box_class. + + Parameters + ---------- + expected : np.ndarray, Index, Series + box_cls : {Index, Series, DataFrame} + + Returns + ------- + subclass of box_cls + """ + if box_cls is pd.array: + if isinstance(expected, RangeIndex): + # pd.array would return an IntegerArray + expected = NumpyExtensionArray(np.asarray(expected._values)) + else: + expected = pd.array(expected, copy=False) + elif box_cls is Index: + expected = Index(expected, copy=False) + elif box_cls is Series: + expected = Series(expected) + elif box_cls is DataFrame: + expected = Series(expected).to_frame() + if transpose: + # for vector operations, we need a DataFrame to be a single-row, + # not a single-column, in order to operate against non-DataFrame + # vectors of the same length. But convert to two rows to avoid + # single-row special cases in datetime arithmetic + expected = expected.T + expected = pd.concat([expected] * 2, ignore_index=True) + elif box_cls is np.ndarray or box_cls is np.array: + expected = np.array(expected) + elif box_cls is to_array: + expected = to_array(expected) + else: + raise NotImplementedError(box_cls) + return expected + + +def to_array(obj): + """ + Similar to pd.array, but does not cast numpy dtypes to nullable dtypes. + """ + # temporary implementation until we get pd.array in place + dtype = getattr(obj, "dtype", None) + + if dtype is None: + return np.asarray(obj) + + return extract_array(obj, extract_numpy=True) + + +class SubclassedSeries(Series): + _metadata = ["testattr", "name"] + + @property + def _constructor(self): + # For testing, those properties return a generic callable, and not + # the actual class. In this case that is equivalent, but it is to + # ensure we don't rely on the property returning a class + # See https://github.com/pandas-dev/pandas/pull/46018 and + # https://github.com/pandas-dev/pandas/issues/32638 and linked issues + return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs) + + @property + def _constructor_expanddim(self): + return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs) + + +class SubclassedDataFrame(DataFrame): + _metadata = ["testattr"] + + @property + def _constructor(self): + return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs) + + # error: Cannot override writeable attribute with read-only property + @property + def _constructor_sliced(self): # type: ignore[override] + return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs) + + +def convert_rows_list_to_csv_str(rows_list: list[str]) -> str: + """ + Convert list of CSV rows to single CSV-formatted string for current OS. + + This method is used for creating expected value of to_csv() method. + + Parameters + ---------- + rows_list : List[str] + Each element represents the row of csv. + + Returns + ------- + str + Expected output of to_csv() in current OS. + """ + sep = os.linesep + return sep.join(rows_list) + sep + + +def external_error_raised(expected_exception: type[Exception]) -> ContextManager: + """ + Helper function to mark pytest.raises that have an external error message. + + Parameters + ---------- + expected_exception : Exception + Expected error to raise. + + Returns + ------- + Callable + Regular `pytest.raises` function with `match` equal to `None`. + """ + import pytest + + return pytest.raises(expected_exception, match=None) + + +def get_cython_table_params(ndframe, func_names_and_expected): + """ + Combine frame, functions from com._cython_table + keys and expected result. + + Parameters + ---------- + ndframe : DataFrame or Series + func_names_and_expected : Sequence of two items + The first item is a name of an NDFrame method ('sum', 'prod') etc. + The second item is the expected return value. + + Returns + ------- + list + List of three items (DataFrame, function, expected result) + """ + results = [] + for func_name, expected in func_names_and_expected: + results.append((ndframe, func_name, expected)) + return results + + +def get_op_from_name(op_name: str) -> Callable: + """ + The operator function for a given op name. + + Parameters + ---------- + op_name : str + The op name, in form of "add" or "__add__". + + Returns + ------- + function + A function performing the operation. + """ + short_opname = op_name.strip("_") + try: + op = getattr(operator, short_opname) + except AttributeError: + # Assume it is the reverse operator + rop = getattr(operator, short_opname[1:]) + op = lambda x, y: rop(y, x) + + return op + + +# ----------------------------------------------------------------------------- +# Indexing test helpers + + +def getitem(x): + return x + + +def setitem(x): + return x + + +def loc(x): + return x.loc + + +def iloc(x): + return x.iloc + + +def at(x): + return x.at + + +def iat(x): + return x.iat + + +# ----------------------------------------------------------------------------- + +_UNITS = ["s", "ms", "us", "ns"] + + +def get_finest_unit(left: str, right: str) -> str: + """ + Find the higher of two datetime64 units. + """ + if _UNITS.index(left) >= _UNITS.index(right): + return left + return right + + +def shares_memory(left, right) -> bool: + """ + Pandas-compat for np.shares_memory. + """ + if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): + return np.shares_memory(left, right) + elif isinstance(left, np.ndarray): + # Call with reversed args to get to unpacking logic below. + return shares_memory(right, left) + + if isinstance(left, RangeIndex): + return False + if isinstance(left, MultiIndex): + return shares_memory(left._codes, right) + if isinstance(left, (Index, Series)): + if isinstance(right, (Index, Series)): + return shares_memory(left._values, right._values) + return shares_memory(left._values, right) + + if isinstance(left, NDArrayBackedExtensionArray): + return shares_memory(left._ndarray, right) + if isinstance(left, pd.core.arrays.SparseArray): + return shares_memory(left.sp_values, right) + if isinstance(left, pd.core.arrays.IntervalArray): + return shares_memory(left._left, right) or shares_memory(left._right, right) + + if isinstance(left, ArrowExtensionArray): + if isinstance(right, ArrowExtensionArray): + # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669 + left_pa_data = left._pa_array + right_pa_data = right._pa_array + left_buf1 = left_pa_data.chunk(0).buffers()[1] + right_buf1 = right_pa_data.chunk(0).buffers()[1] + return left_buf1.address == right_buf1.address + else: + # if we have one one ArrowExtensionArray and one other array, assume + # they can only share memory if they share the same numpy buffer + return np.shares_memory(left, right) + + if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray): + # By convention, we'll say these share memory if they share *either* + # the _data or the _mask + return np.shares_memory(left._data, right._data) or np.shares_memory( + left._mask, right._mask + ) + + if isinstance(left, DataFrame) and len(left._mgr.blocks) == 1: + arr = left._mgr.blocks[0].values + return shares_memory(arr, right) + + raise NotImplementedError(type(left), type(right)) + + +def run_multithreaded(closure, max_workers, arguments=None, pass_barrier=False): + with ThreadPoolExecutor(max_workers=max_workers) as tpe: + if arguments is None: + arguments = [] + else: + arguments = list(arguments) + + if pass_barrier: + barrier = threading.Barrier(max_workers) + arguments.append(barrier) + + try: + futures = [] + for _ in range(max_workers): + futures.append(tpe.submit(closure, *arguments)) # noqa: PERF401 + except RuntimeError as e: + import pytest + + pytest.skip( + f"Spawning {max_workers} threads failed with " + f"error {e!r} (likely due to resource limits on the " + "system running the tests)" + ) + finally: + if len(futures) < max_workers and pass_barrier: + barrier.abort() + for f in futures: + f.result() + + +__all__ = [ + "ALL_INT_EA_DTYPES", + "ALL_INT_NUMPY_DTYPES", + "ALL_NUMPY_DTYPES", + "ALL_REAL_NUMPY_DTYPES", + "BOOL_DTYPES", + "BYTES_DTYPES", + "COMPLEX_DTYPES", + "DATETIME64_DTYPES", + "ENDIAN", + "FLOAT_EA_DTYPES", + "FLOAT_NUMPY_DTYPES", + "NARROW_NP_DTYPES", + "NP_NAT_OBJECTS", + "NULL_OBJECTS", + "OBJECT_DTYPES", + "SIGNED_INT_EA_DTYPES", + "SIGNED_INT_NUMPY_DTYPES", + "STRING_DTYPES", + "TIMEDELTA64_DTYPES", + "UNSIGNED_INT_EA_DTYPES", + "UNSIGNED_INT_NUMPY_DTYPES", + "SubclassedDataFrame", + "SubclassedSeries", + "assert_almost_equal", + "assert_attr_equal", + "assert_categorical_equal", + "assert_class_equal", + "assert_contains_all", + "assert_copy", + "assert_datetime_array_equal", + "assert_dict_equal", + "assert_equal", + "assert_extension_array_equal", + "assert_frame_equal", + "assert_index_equal", + "assert_indexing_slices_equivalent", + "assert_interval_array_equal", + "assert_is_sorted", + "assert_metadata_equivalent", + "assert_numpy_array_equal", + "assert_period_array_equal", + "assert_produces_warning", + "assert_series_equal", + "assert_sp_array_equal", + "assert_timedelta_array_equal", + "at", + "box_expected", + "can_set_locale", + "convert_rows_list_to_csv_str", + "decompress_file", + "external_error_raised", + "get_cython_table_params", + "get_dtype", + "get_finest_unit", + "get_locales", + "get_obj", + "get_op_from_name", + "getitem", + "iat", + "iloc", + "loc", + "maybe_produces_warning", + "raise_assert_detail", + "raises_chained_assignment_error", + "round_trip_pathlib", + "round_trip_pickle", + "run_multithreaded", + "set_locale", + "set_timezone", + "setitem", + "shares_memory", + "to_array", + "with_csv_dialect", + "write_to_compressed", +] diff --git a/pandas/_testing/_hypothesis.py b/pandas/_testing/_hypothesis.py new file mode 100644 index 0000000000000000000000000000000000000000..bbad21d8ab8d11b1590d7904090d0b528d24c744 --- /dev/null +++ b/pandas/_testing/_hypothesis.py @@ -0,0 +1,89 @@ +""" +Hypothesis data generator helpers. +""" + +from datetime import datetime + +from hypothesis import strategies as st +from hypothesis.extra.dateutil import timezones as dateutil_timezones + +from pandas.compat import is_platform_windows + +import pandas as pd + +from pandas.tseries.offsets import ( + BMonthBegin, + BMonthEnd, + BQuarterBegin, + BQuarterEnd, + BYearBegin, + BYearEnd, + MonthBegin, + MonthEnd, + QuarterBegin, + QuarterEnd, + YearBegin, + YearEnd, +) + +OPTIONAL_INTS = st.lists(st.one_of(st.integers(), st.none()), max_size=10, min_size=3) + +OPTIONAL_FLOATS = st.lists(st.one_of(st.floats(), st.none()), max_size=10, min_size=3) + +OPTIONAL_TEXT = st.lists(st.one_of(st.none(), st.text()), max_size=10, min_size=3) + +OPTIONAL_DICTS = st.lists( + st.one_of(st.none(), st.dictionaries(st.text(), st.integers())), + max_size=10, + min_size=3, +) + +OPTIONAL_LISTS = st.lists( + st.one_of(st.none(), st.lists(st.text(), max_size=10, min_size=3)), + max_size=10, + min_size=3, +) + +OPTIONAL_ONE_OF_ALL = st.one_of( + OPTIONAL_DICTS, OPTIONAL_FLOATS, OPTIONAL_INTS, OPTIONAL_LISTS, OPTIONAL_TEXT +) + +if is_platform_windows(): + DATETIME_NO_TZ = st.datetimes(min_value=datetime(1900, 1, 1)) +else: + DATETIME_NO_TZ = st.datetimes() + +DATETIME_JAN_1_1900_OPTIONAL_TZ = st.datetimes( + min_value=pd.Timestamp(1900, 1, 1).to_pydatetime(), # pyright: ignore[reportArgumentType] + max_value=pd.Timestamp(1900, 1, 1).to_pydatetime(), # pyright: ignore[reportArgumentType] + timezones=st.one_of(st.none(), dateutil_timezones(), st.timezones()), +) + +DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ = st.datetimes( + min_value=pd.Timestamp.min.to_pydatetime(warn=False), + max_value=pd.Timestamp.max.to_pydatetime(warn=False), +) + +INT_NEG_999_TO_POS_999 = st.integers(-999, 999) + +# The strategy for each type is registered in conftest.py, as they don't carry +# enough runtime information (e.g. type hints) to infer how to build them. +YQM_OFFSET = st.one_of( + *map( + st.from_type, + [ + MonthBegin, + MonthEnd, + BMonthBegin, + BMonthEnd, + QuarterBegin, + QuarterEnd, + BQuarterBegin, + BQuarterEnd, + YearBegin, + YearEnd, + BYearBegin, + BYearEnd, + ], + ) +) diff --git a/pandas/_testing/_io.py b/pandas/_testing/_io.py new file mode 100644 index 0000000000000000000000000000000000000000..78ed56bd59077abfa403286ee973300bb6826c66 --- /dev/null +++ b/pandas/_testing/_io.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import gzip +import io +import tarfile +from typing import ( + TYPE_CHECKING, + Any, +) +import zipfile + +from pandas.compat._optional import import_optional_dependency + +import pandas as pd + +if TYPE_CHECKING: + from collections.abc import Callable + from pathlib import Path + + from pandas import ( + DataFrame, + Series, + ) + +# ------------------------------------------------------------------ +# File-IO + + +def round_trip_pickle(obj: Any, tmp_path: Path) -> DataFrame | Series: + """ + Pickle an object and then read it again. + + Parameters + ---------- + obj : any object + The object to pickle and then re-read. + path : str, path object or file-like object, default None + The path where the pickled object is written and then read. + + Returns + ------- + pandas object + The original object that was pickled and then re-read. + """ + pd.to_pickle(obj, tmp_path) + return pd.read_pickle(tmp_path) + + +def round_trip_pathlib(writer, reader, tmp_path: Path): + """ + Write an object to file specified by a pathlib.Path and read it back + + Parameters + ---------- + writer : callable bound to pandas object + IO writing function (e.g. DataFrame.to_csv ) + reader : callable + IO reading function (e.g. pd.read_csv ) + path : str, default None + The path where the object is written and then read. + + Returns + ------- + pandas object + The original object that was serialized and then re-read. + """ + writer(tmp_path) + obj = reader(tmp_path) + return obj + + +def write_to_compressed(compression, path: str, data, dest: str = "test") -> None: + """ + Write data to a compressed file. + + Parameters + ---------- + compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd'} + The compression type to use. + path : str + The file path to write the data. + data : str + The data to write. + dest : str, default "test" + The destination file (for ZIP only) + + Raises + ------ + ValueError : An invalid compression value was passed in. + """ + args: tuple[Any, ...] = (data,) + mode = "wb" + method = "write" + compress_method: Callable + + if compression == "zip": + compress_method = zipfile.ZipFile + mode = "w" + args = (dest, data) + method = "writestr" + elif compression == "tar": + compress_method = tarfile.TarFile + mode = "w" + file = tarfile.TarInfo(name=dest) + bytes = io.BytesIO(data) + file.size = len(data) + args = (file, bytes) + method = "addfile" + elif compression == "gzip": + compress_method = gzip.GzipFile + elif compression == "bz2": + import bz2 + + compress_method = bz2.BZ2File + elif compression == "zstd": + compress_method = import_optional_dependency("zstandard").open + elif compression == "xz": + import lzma + + compress_method = lzma.LZMAFile + else: + raise ValueError(f"Unrecognized compression type: {compression}") + + # error: No overload variant of "ZipFile" matches argument types "str", "str" + # error: No overload variant of "BZ2File" matches argument types "str", "str" + # error: Argument "mode" to "TarFile" has incompatible type "str"; + # expected "Literal['r', 'a', 'w', 'x'] + with compress_method(path, mode=mode) as f: # type: ignore[call-overload, arg-type] + getattr(f, method)(*args) diff --git a/pandas/_testing/_warnings.py b/pandas/_testing/_warnings.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d1f5c0c273e74bba14b9c55d44bb8f89e7edc3 --- /dev/null +++ b/pandas/_testing/_warnings.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +from contextlib import ( + AbstractContextManager, + contextmanager, + nullcontext, +) +import inspect +import re +import sys +from typing import ( + TYPE_CHECKING, + Literal, + Union, + cast, +) +import warnings + +if TYPE_CHECKING: + from collections.abc import ( + Generator, + Sequence, + ) + + +@contextmanager +def assert_produces_warning( + expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None = Warning, + filter_level: Literal[ + "error", "ignore", "always", "default", "module", "once" + ] = "always", + check_stacklevel: bool = True, + raise_on_extra_warnings: bool = True, + match: str | tuple[str | None, ...] | None = None, + must_find_all_warnings: bool = True, +) -> Generator[list[warnings.WarningMessage]]: + """ + Context manager for running code expected to either raise a specific warning, + multiple specific warnings, or not raise any warnings. Verifies that the code + raises the expected warning(s), and that it does not raise any other unexpected + warnings. It is basically a wrapper around ``warnings.catch_warnings``. + + Parameters + ---------- + expected_warning : {Warning, False, tuple[Warning, ...], None}, default Warning + The type of Exception raised. ``exception.Warning`` is the base + class for all warnings. To raise multiple types of exceptions, + pass them as a tuple. To check that no warning is returned, + specify ``False`` or ``None``. + filter_level : str or None, default "always" + Specifies whether warnings are ignored, displayed, or turned + into errors. + Valid values are: + + * "error" - turns matching warnings into exceptions + * "ignore" - discard the warning + * "always" - always emit a warning + * "default" - print the warning the first time it is generated + from each location + * "module" - print the warning the first time it is generated + from each module + * "once" - print the warning the first time it is generated + + check_stacklevel : bool, default True + If True, displays the line that called the function containing + the warning to show were the function is called. Otherwise, the + line that implements the function is displayed. + raise_on_extra_warnings : bool, default True + Whether extra warnings not of the type `expected_warning` should + cause the test to fail. + match : {str, tuple[str, ...]}, optional + Match warning message. If it's a tuple, it has to be the size of + `expected_warning`. If additionally `must_find_all_warnings` is + True, each expected warning's message gets matched with a respective + match. Otherwise, multiple values get treated as an alternative. + must_find_all_warnings : bool, default True + If True and `expected_warning` is a tuple, each expected warning + type must get encountered. Otherwise, even one expected warning + results in success. + + Examples + -------- + >>> import warnings + >>> with assert_produces_warning(): + ... warnings.warn(UserWarning()) + >>> with assert_produces_warning(False): + ... warnings.warn(RuntimeWarning()) + Traceback (most recent call last): + ... + AssertionError: Caused unexpected warning(s): ['RuntimeWarning']. + >>> with assert_produces_warning(UserWarning): + ... warnings.warn(RuntimeWarning()) + Traceback (most recent call last): + ... + AssertionError: Did not see expected warning of class 'UserWarning'. + + ..warn:: This is *not* thread-safe. + """ + __tracebackhide__ = True + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter(filter_level) + try: + yield w + finally: + if expected_warning: + if isinstance(expected_warning, tuple) and must_find_all_warnings: + match = ( + match + if isinstance(match, tuple) + else (match,) * len(expected_warning) + ) + for warning_type, warning_match in zip( + expected_warning, match, strict=True + ): + _assert_caught_expected_warnings( + caught_warnings=w, + expected_warning=warning_type, + match=warning_match, + check_stacklevel=check_stacklevel, + ) + else: + expected_warning = cast( + Union[type[Warning], tuple[type[Warning], ...]], + expected_warning, + ) + match = ( + "|".join(m for m in match if m) + if isinstance(match, tuple) + else match + ) + _assert_caught_expected_warnings( + caught_warnings=w, + expected_warning=expected_warning, + match=match, + check_stacklevel=check_stacklevel, + ) + if raise_on_extra_warnings: + _assert_caught_no_extra_warnings( + caught_warnings=w, + expected_warning=expected_warning, + ) + + +def maybe_produces_warning( + warning: type[Warning], condition: bool, **kwargs +) -> AbstractContextManager: + """ + Return a context manager that possibly checks a warning based on the condition + """ + if condition: + return assert_produces_warning(warning, **kwargs) + else: + return nullcontext() + + +def _assert_caught_expected_warnings( + *, + caught_warnings: Sequence[warnings.WarningMessage], + expected_warning: type[Warning] | tuple[type[Warning], ...], + match: str | None, + check_stacklevel: bool, +) -> None: + """Assert that there was the expected warning among the caught warnings.""" + saw_warning = False + matched_message = False + unmatched_messages = [] + warning_name = ( + tuple(x.__name__ for x in expected_warning) + if isinstance(expected_warning, tuple) + else expected_warning.__name__ + ) + + for actual_warning in caught_warnings: + if issubclass(actual_warning.category, expected_warning): + saw_warning = True + + if check_stacklevel: + _assert_raised_with_correct_stacklevel(actual_warning) + + if match is not None: + if re.search(match, str(actual_warning.message)): + matched_message = True + else: + unmatched_messages.append(actual_warning.message) + + if not saw_warning: + raise AssertionError(f"Did not see expected warning of class {warning_name!r}") + + if match and not matched_message: + raise AssertionError( + f"Did not see warning {warning_name!r} " + f"matching '{match}'. The emitted warning messages are " + f"{unmatched_messages}" + ) + + +def _assert_caught_no_extra_warnings( + *, + caught_warnings: Sequence[warnings.WarningMessage], + expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None, +) -> None: + """Assert that no extra warnings apart from the expected ones are caught.""" + extra_warnings = [] + + for actual_warning in caught_warnings: + if _is_unexpected_warning(actual_warning, expected_warning): + # GH#38630 pytest.filterwarnings does not suppress these. + if actual_warning.category == ResourceWarning: + # GH 44732: Don't make the CI flaky by filtering SSL-related + # ResourceWarning from dependencies + if "unclosed bool: + """Check if the actual warning issued is unexpected.""" + if actual_warning and not expected_warning: + return True + expected_warning = cast(type[Warning], expected_warning) + return bool(not issubclass(actual_warning.category, expected_warning)) + + +def _assert_raised_with_correct_stacklevel( + actual_warning: warnings.WarningMessage, +) -> None: + # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow + frame = inspect.currentframe() + for _ in range(4): + frame = frame.f_back # type: ignore[union-attr] + try: + caller_filename = inspect.getfile(frame) # type: ignore[arg-type] + finally: + # See note in + # https://docs.python.org/3/library/inspect.html#inspect.Traceback + del frame + msg = ( + "Warning not set with correct stacklevel. " + f"File where warning is raised: {actual_warning.filename} != " + f"{caller_filename}. Warning message: {actual_warning.message}" + ) + assert actual_warning.filename == caller_filename, msg diff --git a/pandas/_testing/asserters.py b/pandas/_testing/asserters.py new file mode 100644 index 0000000000000000000000000000000000000000..3732fd9d2065561d4390fd4a167a31abab7a5bce --- /dev/null +++ b/pandas/_testing/asserters.py @@ -0,0 +1,1503 @@ +from __future__ import annotations + +import operator +from typing import ( + TYPE_CHECKING, + Literal, + NoReturn, + cast, +) +import warnings + +import numpy as np + +from pandas._libs import lib +from pandas._libs.missing import is_matching_na +from pandas._libs.sparse import SparseIndex +import pandas._libs.testing as _testing +from pandas._libs.tslibs.np_datetime import compare_mismatched_resolutions +from pandas.errors import Pandas4Warning +from pandas.util._decorators import ( + deprecate_kwarg, + set_module, +) + +from pandas.core.dtypes.common import ( + is_bool, + is_float_dtype, + is_integer_dtype, + is_number, + is_numeric_dtype, + needs_i8_conversion, +) +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + DatetimeTZDtype, + ExtensionDtype, + NumpyEADtype, +) +from pandas.core.dtypes.missing import array_equivalent + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DatetimeIndex, + Index, + IntervalDtype, + IntervalIndex, + MultiIndex, + PeriodIndex, + RangeIndex, + Series, + TimedeltaIndex, +) +from pandas.core.arrays import ( + DatetimeArray, + ExtensionArray, + IntervalArray, + PeriodArray, + TimedeltaArray, +) +from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin +from pandas.core.arrays.string_ import StringDtype +from pandas.core.indexes.api import safe_sort_index + +from pandas.io.formats.printing import pprint_thing + +if TYPE_CHECKING: + from pandas._typing import DtypeObj + + +def assert_almost_equal( + left, + right, + check_dtype: bool | Literal["equiv"] = "equiv", + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + **kwargs, +) -> None: + """ + Check that the left and right objects are approximately equal. + + By approximately equal, we refer to objects that are numbers or that + contain numbers which may be equivalent to specific levels of precision. + + Parameters + ---------- + left : object + right : object + check_dtype : bool or {'equiv'}, default 'equiv' + Check dtype if both a and b are the same type. If 'equiv' is passed in, + then `RangeIndex` and `Index` with int64 dtype are also considered + equivalent when doing type checking. + rtol : float, default 1e-5 + Relative tolerance. + atol : float, default 1e-8 + Absolute tolerance. + """ + if isinstance(left, Index): + assert_index_equal( + left, + right, + check_exact=False, + exact=check_dtype, + rtol=rtol, + atol=atol, + **kwargs, + ) + + elif isinstance(left, Series): + assert_series_equal( + left, + right, + check_exact=False, + check_dtype=check_dtype, + rtol=rtol, + atol=atol, + **kwargs, + ) + + elif isinstance(left, DataFrame): + assert_frame_equal( + left, + right, + check_exact=False, + check_dtype=check_dtype, + rtol=rtol, + atol=atol, + **kwargs, + ) + + else: + # Other sequences. + if check_dtype: + if is_number(left) and is_number(right): + # Do not compare numeric classes, like np.float64 and float. + pass + elif is_bool(left) and is_bool(right): + # Do not compare bool classes, like np.bool_ and bool. + pass + else: + if isinstance(left, np.ndarray) or isinstance(right, np.ndarray): + obj = "numpy array" + else: + obj = "Input" + assert_class_equal(left, right, obj=obj) + + # if we have "equiv", this becomes True + _testing.assert_almost_equal( + left, right, check_dtype=bool(check_dtype), rtol=rtol, atol=atol, **kwargs + ) + + +def _check_isinstance(left, right, cls) -> None: + """ + Helper method for our assert_* methods that ensures that + the two objects being compared have the right type before + proceeding with the comparison. + + Parameters + ---------- + left : The first object being compared. + right : The second object being compared. + cls : The class type to check against. + + Raises + ------ + AssertionError : Either `left` or `right` is not an instance of `cls`. + """ + cls_name = cls.__name__ + + if not isinstance(left, cls): + raise AssertionError( + f"{cls_name} Expected type {cls}, found {type(left)} instead" + ) + if not isinstance(right, cls): + raise AssertionError( + f"{cls_name} Expected type {cls}, found {type(right)} instead" + ) + + +def assert_dict_equal(left, right, compare_keys: bool = True) -> None: + _check_isinstance(left, right, dict) + _testing.assert_dict_equal(left, right, compare_keys=compare_keys) + + +@set_module("pandas.testing") +def assert_index_equal( + left: Index, + right: Index, + exact: bool | str = "equiv", + check_names: bool = True, + check_exact: bool = True, + check_categorical: bool = True, + check_order: bool = True, + rtol: float = 1.0e-5, + atol: float = 1.0e-8, + obj: str | None = None, +) -> None: + """ + Check that left and right Index are equal. + + Parameters + ---------- + left : Index + The first index to compare. + right : Index + The second index to compare. + exact : bool or {'equiv'}, default 'equiv' + Whether to check the Index class, dtype and inferred_type + are identical. If 'equiv', then RangeIndex can be substituted for + Index with an int64 dtype as well. + check_names : bool, default True + Whether to check the names attribute. + check_exact : bool, default True + Whether to compare number exactly. + check_categorical : bool, default True + Whether to compare internal Categorical exactly. + check_order : bool, default True + Whether to compare the order of index entries as well as their values. + If True, both indexes must contain the same elements, in the same order. + If False, both indexes must contain the same elements, but in any order. + rtol : float, default 1e-5 + Relative tolerance. Only used when check_exact is False. + atol : float, default 1e-8 + Absolute tolerance. Only used when check_exact is False. + obj : str, default 'Index' or 'MultiIndex' + Specify object name being compared, internally used to show appropriate + assertion message. + + See Also + -------- + testing.assert_series_equal : Check that two Series are equal. + testing.assert_frame_equal : Check that two DataFrames are equal. + + Examples + -------- + >>> from pandas import testing as tm + >>> a = pd.Index([1, 2, 3]) + >>> b = pd.Index([1, 2, 3]) + >>> tm.assert_index_equal(a, b) + """ + __tracebackhide__ = True + + if obj is None: + obj = "MultiIndex" if isinstance(left, MultiIndex) else "Index" + + def _check_types(left, right, obj: str = "Index") -> None: + if not exact: + return + + assert_class_equal(left, right, exact=exact, obj=obj) + assert_attr_equal("inferred_type", left, right, obj=obj) + + # Skip exact dtype checking when `check_categorical` is False + if isinstance(left.dtype, CategoricalDtype) and isinstance( + right.dtype, CategoricalDtype + ): + if check_categorical: + assert_attr_equal("dtype", left, right, obj=obj) + assert_index_equal(left.categories, right.categories, exact=exact) + return + + assert_attr_equal("dtype", left, right, obj=obj) + + # instance validation + _check_isinstance(left, right, Index) + + # class / dtype comparison + _check_types(left, right, obj=obj) + + # level comparison + if left.nlevels != right.nlevels: + msg1 = f"{obj} levels are different" + msg2 = f"{left.nlevels}, {left}" + msg3 = f"{right.nlevels}, {right}" + raise_assert_detail(obj, msg1, msg2, msg3) + + # length comparison + if len(left) != len(right): + msg1 = f"{obj} length are different" + msg2 = f"{len(left)}, {left}" + msg3 = f"{len(right)}, {right}" + raise_assert_detail(obj, msg1, msg2, msg3) + + # If order doesn't matter then sort the index entries + if not check_order: + left = safe_sort_index(left) + right = safe_sort_index(right) + + # MultiIndex special comparison for little-friendly error messages + if isinstance(left, MultiIndex): + right = cast(MultiIndex, right) + + for level in range(left.nlevels): + lobj = f"{obj} level [{level}]" + try: + # try comparison on levels/codes to avoid densifying MultiIndex + assert_index_equal( + left.levels[level], + right.levels[level], + exact=exact, + check_names=check_names, + check_exact=check_exact, + check_categorical=check_categorical, + rtol=rtol, + atol=atol, + obj=lobj, + ) + assert_numpy_array_equal(left.codes[level], right.codes[level]) + except AssertionError: + llevel = left.get_level_values(level) + rlevel = right.get_level_values(level) + + assert_index_equal( + llevel, + rlevel, + exact=exact, + check_names=check_names, + check_exact=check_exact, + check_categorical=check_categorical, + rtol=rtol, + atol=atol, + obj=lobj, + ) + # get_level_values may change dtype + _check_types(left.levels[level], right.levels[level], obj=lobj) + + # skip exact index checking when `check_categorical` is False + elif check_exact and check_categorical: + if not left.equals(right): + # _values compare can raise TypeError (non-comparable + # categoricals (GH#61935) + try: + mismatch = left._values != right._values + except TypeError: + raise_assert_detail( + obj, + "types are not comparable (non-matching categorical categories)", + left, + right, + ) + + if not isinstance(mismatch, np.ndarray): + mismatch = cast("ExtensionArray", mismatch).fillna(True) + + diff = np.sum(mismatch.astype(int)) * 100.0 / len(left) + msg = f"{obj} values are different ({np.round(diff, 5)} %)" + raise_assert_detail(obj, msg, left, right) + else: + # if we have "equiv", this becomes True + exact_bool = bool(exact) + _testing.assert_almost_equal( + left.values, + right.values, + rtol=rtol, + atol=atol, + check_dtype=exact_bool, + obj=obj, + lobj=left, + robj=right, + ) + + # metadata comparison + if check_names: + assert_attr_equal("names", left, right, obj=obj) + if isinstance(left, PeriodIndex) or isinstance(right, PeriodIndex): + assert_attr_equal("dtype", left, right, obj=obj) + if isinstance(left, IntervalIndex) or isinstance(right, IntervalIndex): + assert_interval_array_equal(left._values, right._values) + + if check_categorical: + if isinstance(left.dtype, CategoricalDtype) or isinstance( + right.dtype, CategoricalDtype + ): + assert_categorical_equal(left._values, right._values, obj=f"{obj} category") + + +def assert_class_equal( + left, right, exact: bool | str = True, obj: str = "Input" +) -> None: + """ + Checks classes are equal. + """ + __tracebackhide__ = True + + def repr_class(x): + if isinstance(x, Index): + # return Index as it is to include values in the error message + return x + + return type(x).__name__ + + def is_class_equiv(idx: Index) -> bool: + """Classes that are a RangeIndex (sub-)instance or exactly an `Index` . + + This only checks class equivalence. There is a separate check that the + dtype is int64. + """ + return type(idx) is Index or isinstance(idx, RangeIndex) + + if type(left) == type(right): + return + + if exact == "equiv": + if is_class_equiv(left) and is_class_equiv(right): + return + + msg = f"{obj} classes are different" + raise_assert_detail(obj, msg, repr_class(left), repr_class(right)) + + +def assert_attr_equal(attr: str, left, right, obj: str = "Attributes") -> None: + """ + Check attributes are equal. Both objects must have attribute. + + Parameters + ---------- + attr : str + Attribute name being compared. + left : object + right : object + obj : str, default 'Attributes' + Specify object name being compared, internally used to show appropriate + assertion message + """ + __tracebackhide__ = True + + left_attr = getattr(left, attr) + right_attr = getattr(right, attr) + + if left_attr is right_attr or is_matching_na(left_attr, right_attr): + # e.g. both np.nan, both NaT, both pd.NA, ... + return None + + try: + result = left_attr == right_attr + except TypeError: + # datetimetz on rhs may raise TypeError + result = False + if (left_attr is pd.NA) ^ (right_attr is pd.NA): + result = False + elif not isinstance(result, bool): + result = result.all() + + if not result: + msg = f'Attribute "{attr}" are different' + raise_assert_detail(obj, msg, left_attr, right_attr) + return None + + +def assert_is_sorted(seq) -> None: + """Assert that the sequence is sorted.""" + if isinstance(seq, (Index, Series)): + seq = seq.values + # sorting does not change precisions + if isinstance(seq, np.ndarray): + assert_numpy_array_equal(seq, np.sort(np.array(seq))) + else: + assert_extension_array_equal(seq, seq[seq.argsort()]) + + +def assert_categorical_equal( + left, + right, + check_dtype: bool = True, + check_category_order: bool = True, + obj: str = "Categorical", +) -> None: + """ + Test that Categoricals are equivalent. + + Parameters + ---------- + left : Categorical + right : Categorical + check_dtype : bool, default True + Check that integer dtype of the codes are the same. + check_category_order : bool, default True + Whether the order of the categories should be compared, which + implies identical integer codes. If False, only the resulting + values are compared. The ordered attribute is + checked regardless. + obj : str, default 'Categorical' + Specify object name being compared, internally used to show appropriate + assertion message. + """ + _check_isinstance(left, right, Categorical) + + exact: bool | str + if isinstance(left.categories, RangeIndex) or isinstance( + right.categories, RangeIndex + ): + exact = "equiv" + else: + # We still want to require exact matches for Index + exact = True + + if check_category_order: + assert_index_equal( + left.categories, right.categories, obj=f"{obj}.categories", exact=exact + ) + assert_numpy_array_equal( + left.codes, right.codes, check_dtype=check_dtype, obj=f"{obj}.codes" + ) + else: + try: + lc = left.categories.sort_values() + rc = right.categories.sort_values() + except TypeError: + # e.g. '<' not supported between instances of 'int' and 'str' + lc, rc = left.categories, right.categories + assert_index_equal(lc, rc, obj=f"{obj}.categories", exact=exact) + assert_index_equal( + left.categories.take(left.codes), + right.categories.take(right.codes), + obj=f"{obj}.values", + exact=exact, + ) + + assert_attr_equal("ordered", left, right, obj=obj) + + +def assert_interval_array_equal( + left, right, exact: bool | Literal["equiv"] = "equiv", obj: str = "IntervalArray" +) -> None: + """ + Test that two IntervalArrays are equivalent. + + Parameters + ---------- + left, right : IntervalArray + The IntervalArrays to compare. + exact : bool or {'equiv'}, default 'equiv' + Whether to check the Index class, dtype and inferred_type + are identical. If 'equiv', then RangeIndex can be substituted for + Index with an int64 dtype as well. + obj : str, default 'IntervalArray' + Specify object name being compared, internally used to show appropriate + assertion message + """ + _check_isinstance(left, right, IntervalArray) + + kwargs = {} + if left._left.dtype.kind in "mM": + # We have a DatetimeArray or TimedeltaArray + kwargs["check_freq"] = False + + assert_equal(left._left, right._left, obj=f"{obj}.left", **kwargs) + assert_equal(left._right, right._right, obj=f"{obj}.right", **kwargs) + + assert_attr_equal("closed", left, right, obj=obj) + + +def assert_period_array_equal(left, right, obj: str = "PeriodArray") -> None: + _check_isinstance(left, right, PeriodArray) + + assert_numpy_array_equal(left._ndarray, right._ndarray, obj=f"{obj}._ndarray") + assert_attr_equal("dtype", left, right, obj=obj) + + +def assert_datetime_array_equal( + left, right, obj: str = "DatetimeArray", check_freq: bool = True +) -> None: + __tracebackhide__ = True + _check_isinstance(left, right, DatetimeArray) + + assert_numpy_array_equal(left._ndarray, right._ndarray, obj=f"{obj}._ndarray") + if check_freq: + assert_attr_equal("freq", left, right, obj=obj) + assert_attr_equal("tz", left, right, obj=obj) + + +def assert_timedelta_array_equal( + left, right, obj: str = "TimedeltaArray", check_freq: bool = True +) -> None: + __tracebackhide__ = True + _check_isinstance(left, right, TimedeltaArray) + assert_numpy_array_equal(left._ndarray, right._ndarray, obj=f"{obj}._ndarray") + if check_freq: + assert_attr_equal("freq", left, right, obj=obj) + + +def raise_assert_detail( + obj, message, left, right, diff=None, first_diff=None, index_values=None +) -> NoReturn: + __tracebackhide__ = True + + msg = f"""{obj} are different + +{message}""" + + if isinstance(index_values, Index): + index_values = np.asarray(index_values) + + if isinstance(index_values, np.ndarray): + msg += f"\n[index]: {pprint_thing(index_values)}" + + if isinstance(left, np.ndarray): + left = pprint_thing(left) + elif isinstance(left, (CategoricalDtype, StringDtype, NumpyEADtype)): + left = repr(left) + + if isinstance(right, np.ndarray): + right = pprint_thing(right) + elif isinstance(right, (CategoricalDtype, StringDtype, NumpyEADtype)): + right = repr(right) + + msg += f""" +[left]: {left} +[right]: {right}""" + + if diff is not None: + msg += f"\n[diff]: {diff}" + + if first_diff is not None: + msg += f"\n{first_diff}" + + raise AssertionError(msg) + + +def assert_numpy_array_equal( + left, + right, + strict_nan: bool = False, + check_dtype: bool | Literal["equiv"] = True, + err_msg=None, + check_same=None, + obj: str = "numpy array", + index_values=None, +) -> None: + """ + Check that 'np.ndarray' is equivalent. + + Parameters + ---------- + left, right : numpy.ndarray or iterable + The two arrays to be compared. + strict_nan : bool, default False + If True, consider NaN and None to be different. + check_dtype : bool, default True + Check dtype if both a and b are np.ndarray. + err_msg : str, default None + If provided, used as assertion message. + check_same : None|'copy'|'same', default None + Ensure left and right refer/do not refer to the same memory area. + obj : str, default 'numpy array' + Specify object name being compared, internally used to show appropriate + assertion message. + index_values : Index | numpy.ndarray, default None + optional index (shared by both left and right), used in output. + """ + __tracebackhide__ = True + + # instance validation + # Show a detailed error message when classes are different + assert_class_equal(left, right, obj=obj) + # both classes must be an np.ndarray + _check_isinstance(left, right, np.ndarray) + + def _get_base(obj): + return obj.base if getattr(obj, "base", None) is not None else obj + + left_base = _get_base(left) + right_base = _get_base(right) + + if check_same == "same": + if left_base is not right_base: + raise AssertionError(f"{left_base!r} is not {right_base!r}") + elif check_same == "copy": + if left_base is right_base: + raise AssertionError(f"{left_base!r} is {right_base!r}") + + def _raise(left, right, err_msg) -> NoReturn: + if err_msg is None: + if left.shape != right.shape: + raise_assert_detail( + obj, f"{obj} shapes are different", left.shape, right.shape + ) + + diff = 0 + for left_arr, right_arr in zip(left, right, strict=True): + # count up differences + if not array_equivalent(left_arr, right_arr, strict_nan=strict_nan): + diff += 1 + + diff = diff * 100.0 / left.size + msg = f"{obj} values are different ({np.round(diff, 5)} %)" + raise_assert_detail(obj, msg, left, right, index_values=index_values) + + raise AssertionError(err_msg) + + # compare shape and values + if not array_equivalent(left, right, strict_nan=strict_nan): + _raise(left, right, err_msg) + + if check_dtype: + if isinstance(left, np.ndarray) and isinstance(right, np.ndarray): + assert_attr_equal("dtype", left, right, obj=obj) + + +@set_module("pandas.testing") +def assert_extension_array_equal( + left, + right, + check_dtype: bool | Literal["equiv"] = True, + index_values=None, + check_exact: bool | lib.NoDefault = lib.no_default, + rtol: float | lib.NoDefault = lib.no_default, + atol: float | lib.NoDefault = lib.no_default, + obj: str = "ExtensionArray", +) -> None: + """ + Check that left and right ExtensionArrays are equal. + + This method compares two ``ExtensionArray`` instances for equality, + including checks for missing values, the dtype of the arrays, and + the exactness of the comparison (or tolerance when comparing floats). + + Parameters + ---------- + left, right : ExtensionArray + The two arrays to compare. + check_dtype : bool, default True + Whether to check if the ExtensionArray dtypes are identical. + index_values : Index | numpy.ndarray, default None + Optional index (shared by both left and right), used in output. + check_exact : bool, default False + Whether to compare number exactly. + + .. versionchanged:: 2.2.0 + + Defaults to True for integer dtypes if none of + ``check_exact``, ``rtol`` and ``atol`` are specified. + rtol : float, default 1e-5 + Relative tolerance. Only used when check_exact is False. + atol : float, default 1e-8 + Absolute tolerance. Only used when check_exact is False. + obj : str, default 'ExtensionArray' + Specify object name being compared, internally used to show appropriate + assertion message. + + .. versionadded:: 2.0.0 + + See Also + -------- + testing.assert_series_equal : Check that left and right ``Series`` are equal. + testing.assert_frame_equal : Check that left and right ``DataFrame`` are equal. + testing.assert_index_equal : Check that left and right ``Index`` are equal. + + Notes + ----- + Missing values are checked separately from valid values. + A mask of missing values is computed for each and checked to match. + The remaining all-valid values are cast to object dtype and checked. + + Examples + -------- + >>> from pandas import testing as tm + >>> a = pd.Series([1, 2, 3, 4]) + >>> b, c = a.array, a.array + >>> tm.assert_extension_array_equal(b, c) + """ + if ( + check_exact is lib.no_default + and rtol is lib.no_default + and atol is lib.no_default + ): + check_exact = ( + is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype) + ) or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype)) + elif check_exact is lib.no_default: + check_exact = False + + rtol = rtol if rtol is not lib.no_default else 1.0e-5 + atol = atol if atol is not lib.no_default else 1.0e-8 + + assert isinstance(left, ExtensionArray), "left is not an ExtensionArray" + assert isinstance(right, ExtensionArray), "right is not an ExtensionArray" + if check_dtype: + assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}") + + if ( + isinstance(left, DatetimeLikeArrayMixin) + and isinstance(right, DatetimeLikeArrayMixin) + and type(right) == type(left) + ): + # GH 52449 + if not check_dtype and left.dtype.kind in "mM": + if not isinstance(left.dtype, np.dtype): + l_unit = cast(DatetimeTZDtype, left.dtype).unit + else: + l_unit = np.datetime_data(left.dtype)[0] + if not isinstance(right.dtype, np.dtype): + r_unit = cast(DatetimeTZDtype, right.dtype).unit + else: + r_unit = np.datetime_data(right.dtype)[0] + if ( + l_unit != r_unit + and compare_mismatched_resolutions( + left._ndarray, right._ndarray, operator.eq + ).all() + ): + return + # Avoid slow object-dtype comparisons + # np.asarray for case where we have an np.MaskedArray + assert_numpy_array_equal( + np.asarray(left.asi8), + np.asarray(right.asi8), + index_values=index_values, + obj=obj, + ) + return + + left_na = np.asarray(left.isna()) + right_na = np.asarray(right.isna()) + assert_numpy_array_equal( + left_na, right_na, obj=f"{obj} NA mask", index_values=index_values + ) + + # Specifically for StringArrayNumpySemantics, validate here we have a valid array + if ( + isinstance(left.dtype, StringDtype) + and left.dtype.storage == "python" + and left.dtype.na_value is np.nan + ): + assert np.all( + [np.isnan(val) for val in left._ndarray[left_na]] # type: ignore[attr-defined] + ), "wrong missing value sentinels" + if ( + isinstance(right.dtype, StringDtype) + and right.dtype.storage == "python" + and right.dtype.na_value is np.nan + ): + assert np.all( + [np.isnan(val) for val in right._ndarray[right_na]] # type: ignore[attr-defined] + ), "wrong missing value sentinels" + + left_valid = left[~left_na].to_numpy(dtype=object) + right_valid = right[~right_na].to_numpy(dtype=object) + if check_exact: + assert_numpy_array_equal( + left_valid, right_valid, obj=obj, index_values=index_values + ) + else: + _testing.assert_almost_equal( + left_valid, + right_valid, + check_dtype=bool(check_dtype), + rtol=rtol, + atol=atol, + obj=obj, + index_values=index_values, + ) + + +# This could be refactored to use the NDFrame.equals method +@set_module("pandas.testing") +@deprecate_kwarg(Pandas4Warning, "check_datetimelike_compat", new_arg_name=None) +def assert_series_equal( + left, + right, + check_dtype: bool | Literal["equiv"] = True, + check_index_type: bool | Literal["equiv"] = "equiv", + check_series_type: bool = True, + check_names: bool = True, + check_exact: bool | lib.NoDefault = lib.no_default, + check_datetimelike_compat: bool = False, + check_categorical: bool = True, + check_category_order: bool = True, + check_freq: bool = True, + check_flags: bool = True, + rtol: float | lib.NoDefault = lib.no_default, + atol: float | lib.NoDefault = lib.no_default, + obj: str = "Series", + *, + check_index: bool = True, + check_like: bool = False, +) -> None: + """ + Check that left and right Series are equal. + + Parameters + ---------- + left : Series + First Series to compare. + right : Series + Second Series to compare. + check_dtype : bool, default True + Whether to check the Series dtype is identical. + check_index_type : bool or {'equiv'}, default 'equiv' + Whether to check the Index class, dtype and inferred_type + are identical. + check_series_type : bool, default True + Whether to check the Series class is identical. + check_names : bool, default True + Whether to check the Series and Index names attribute. + check_exact : bool, default False + Whether to compare number exactly. This also applies when checking + Index equivalence. + + .. versionchanged:: 2.2.0 + + Defaults to True for integer dtypes if none of + ``check_exact``, ``rtol`` and ``atol`` are specified. + + .. versionchanged:: 3.0.0 + + check_exact for comparing the Indexes defaults to True by + checking if an Index is of integer dtypes. + + check_datetimelike_compat : bool, default False + Compare datetime-like which is comparable ignoring dtype. + + .. deprecated:: 3.0 + + check_categorical : bool, default True + Whether to compare internal Categorical exactly. + check_category_order : bool, default True + Whether to compare category order of internal Categoricals. + check_freq : bool, default True + Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex. + check_flags : bool, default True + Whether to check the `flags` attribute. + rtol : float, default 1e-5 + Relative tolerance. Only used when check_exact is False. + atol : float, default 1e-8 + Absolute tolerance. Only used when check_exact is False. + obj : str, default 'Series' + Specify object name being compared, internally used to show appropriate + assertion message. + check_index : bool, default True + Whether to check index equivalence. If False, then compare only values. + check_like : bool, default False + If True, ignore the order of the index. Must be False if check_index is False. + Note: same labels must be with the same data. + + See Also + -------- + testing.assert_index_equal : Check that two Indexes are equal. + testing.assert_frame_equal : Check that two DataFrames are equal. + + Examples + -------- + >>> from pandas import testing as tm + >>> a = pd.Series([1, 2, 3, 4]) + >>> b = pd.Series([1, 2, 3, 4]) + >>> tm.assert_series_equal(a, b) + """ + __tracebackhide__ = True + if ( + check_exact is lib.no_default + and rtol is lib.no_default + and atol is lib.no_default + ): + check_exact = ( + is_numeric_dtype(left.dtype) and not is_float_dtype(left.dtype) + ) or (is_numeric_dtype(right.dtype) and not is_float_dtype(right.dtype)) + left_index_dtypes = ( + [left.index.dtype] if left.index.nlevels == 1 else left.index.dtypes + ) + right_index_dtypes = ( + [right.index.dtype] if right.index.nlevels == 1 else right.index.dtypes + ) + check_exact_index = all( + dtype.kind in "iu" for dtype in left_index_dtypes + ) or all(dtype.kind in "iu" for dtype in right_index_dtypes) + elif check_exact is lib.no_default: + check_exact = False + check_exact_index = False + else: + check_exact_index = check_exact + + rtol = rtol if rtol is not lib.no_default else 1.0e-5 + atol = atol if atol is not lib.no_default else 1.0e-8 + + if not check_index and check_like: + raise ValueError("check_like must be False if check_index is False") + + # instance validation + _check_isinstance(left, right, Series) + + if check_series_type: + assert_class_equal(left, right, obj=obj) + + # length comparison + if len(left) != len(right): + msg1 = f"{len(left)}, {left.index}" + msg2 = f"{len(right)}, {right.index}" + raise_assert_detail(obj, "Series length are different", msg1, msg2) + + if check_flags: + assert left.flags == right.flags, f"{left.flags!r} != {right.flags!r}" + + if check_index: + # GH #38183 + assert_index_equal( + left.index, + right.index, + exact=check_index_type, + check_names=check_names, + check_exact=check_exact_index, + check_categorical=check_categorical, + check_order=not check_like, + rtol=rtol, + atol=atol, + obj=f"{obj}.index", + ) + + if check_like: + left = left.reindex_like(right) + + if check_freq and isinstance(left.index, (DatetimeIndex, TimedeltaIndex)): + lidx = left.index + ridx = right.index + assert lidx.freq == ridx.freq, (lidx.freq, ridx.freq) + + if check_dtype: + # We want to skip exact dtype checking when `check_categorical` + # is False. We'll still raise if only one is a `Categorical`, + # regardless of `check_categorical` + if ( + isinstance(left.dtype, CategoricalDtype) + and isinstance(right.dtype, CategoricalDtype) + and not check_categorical + ): + pass + else: + assert_attr_equal("dtype", left, right, obj=f"Attributes of {obj}") + if check_exact: + left_values = left._values + right_values = right._values + # Only check exact if dtype is numeric + if isinstance(left_values, ExtensionArray) and isinstance( + right_values, ExtensionArray + ): + assert_extension_array_equal( + left_values, + right_values, + check_dtype=check_dtype, + index_values=left.index, + obj=str(obj), + ) + else: + # convert both to NumPy if not, check_dtype would raise earlier + lv, rv = left_values, right_values + if isinstance(left_values, ExtensionArray): + lv = left_values.to_numpy() + if isinstance(right_values, ExtensionArray): + rv = right_values.to_numpy() + assert_numpy_array_equal( + lv, + rv, + check_dtype=check_dtype, + obj=str(obj), + index_values=left.index, + ) + elif check_datetimelike_compat and ( + needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype) + ): + # we want to check only if we have compat dtypes + # e.g. integer and M|m are NOT compat, but we can simply check + # the values in that case + + # datetimelike may have different objects (e.g. datetime.datetime + # vs Timestamp) but will compare equal + if not Index(left._values).equals(Index(right._values)): + msg = ( + f"[datetimelike_compat=True] {left._values} " + f"is not equal to {right._values}." + ) + raise AssertionError(msg) + elif isinstance(left.dtype, IntervalDtype) and isinstance( + right.dtype, IntervalDtype + ): + assert_interval_array_equal(left.array, right.array) + elif isinstance(left.dtype, CategoricalDtype) or isinstance( + right.dtype, CategoricalDtype + ): + _testing.assert_almost_equal( + left._values, + right._values, + rtol=rtol, + atol=atol, + check_dtype=bool(check_dtype), + obj=str(obj), + index_values=left.index, + ) + elif isinstance(left.dtype, ExtensionDtype) and isinstance( + right.dtype, ExtensionDtype + ): + assert_extension_array_equal( + left._values, + right._values, + rtol=rtol, + atol=atol, + check_dtype=check_dtype, + index_values=left.index, + obj=str(obj), + ) + elif is_extension_array_dtype_and_needs_i8_conversion( + left.dtype, right.dtype + ) or is_extension_array_dtype_and_needs_i8_conversion(right.dtype, left.dtype): + assert_extension_array_equal( + left._values, + right._values, + check_dtype=check_dtype, + index_values=left.index, + obj=str(obj), + ) + elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype): + # DatetimeArray or TimedeltaArray + assert_extension_array_equal( + left._values, + right._values, + check_dtype=check_dtype, + index_values=left.index, + obj=str(obj), + ) + else: + _testing.assert_almost_equal( + left._values, + right._values, + rtol=rtol, + atol=atol, + check_dtype=bool(check_dtype), + obj=str(obj), + index_values=left.index, + ) + + # metadata comparison + if check_names: + assert_attr_equal("name", left, right, obj=obj) + + if check_categorical: + if isinstance(left.dtype, CategoricalDtype) or isinstance( + right.dtype, CategoricalDtype + ): + assert_categorical_equal( + left._values, + right._values, + obj=f"{obj} category", + check_category_order=check_category_order, + ) + + +# This could be refactored to use the NDFrame.equals method +@set_module("pandas.testing") +@deprecate_kwarg(Pandas4Warning, "check_datetimelike_compat", new_arg_name=None) +def assert_frame_equal( + left, + right, + check_dtype: bool | Literal["equiv"] = True, + check_index_type: bool | Literal["equiv"] = "equiv", + check_column_type: bool | Literal["equiv"] = "equiv", + check_frame_type: bool = True, + check_names: bool = True, + by_blocks: bool = False, + check_exact: bool | lib.NoDefault = lib.no_default, + check_datetimelike_compat: bool = False, + check_categorical: bool = True, + check_like: bool = False, + check_freq: bool = True, + check_flags: bool = True, + rtol: float | lib.NoDefault = lib.no_default, + atol: float | lib.NoDefault = lib.no_default, + obj: str = "DataFrame", +) -> None: + """ + Check that left and right DataFrame are equal. + + This function is intended to compare two DataFrames and output any + differences. It is mostly intended for use in unit tests. + Additional parameters allow varying the strictness of the + equality checks performed. + + Parameters + ---------- + left : DataFrame + First DataFrame to compare. + right : DataFrame + Second DataFrame to compare. + check_dtype : bool, default True + Whether to check the DataFrame dtype is identical. + check_index_type : bool or {'equiv'}, default 'equiv' + Whether to check the Index class, dtype and inferred_type + are identical. + check_column_type : bool or {'equiv'}, default 'equiv' + Whether to check the columns class, dtype and inferred_type + are identical. Is passed as the ``exact`` argument of + :func:`assert_index_equal`. + check_frame_type : bool, default True + Whether to check the DataFrame class is identical. + check_names : bool, default True + Whether to check that the `names` attribute for both the `index` + and `column` attributes of the DataFrame is identical. + by_blocks : bool, default False + Specify how to compare internal data. If False, compare by columns. + If True, compare by blocks. + check_exact : bool, default False + Whether to compare number exactly. If False, the comparison uses the + relative tolerance (``rtol``) and absolute tolerance (``atol``) + parameters to determine if two values are considered close, + according to the formula: ``|a - b| <= (atol + rtol * |b|)``. + + .. versionchanged:: 2.2.0 + + Defaults to True for integer dtypes if none of + ``check_exact``, ``rtol`` and ``atol`` are specified. + check_datetimelike_compat : bool, default False + Compare datetime-like which is comparable ignoring dtype. + + .. deprecated:: 3.0 + + check_categorical : bool, default True + Whether to compare internal Categorical exactly. + check_like : bool, default False + If True, ignore the order of index & columns. + Note: index labels must match their respective rows + (same as in columns) - same labels must be with the same data. + check_freq : bool, default True + Whether to check the `freq` attribute on a DatetimeIndex or TimedeltaIndex. + check_flags : bool, default True + Whether to check the `flags` attribute. + rtol : float, default 1e-5 + Relative tolerance. Only used when check_exact is False. + atol : float, default 1e-8 + Absolute tolerance. Only used when check_exact is False. + obj : str, default 'DataFrame' + Specify object name being compared, internally used to show appropriate + assertion message. + + See Also + -------- + assert_series_equal : Equivalent method for asserting Series equality. + DataFrame.equals : Check DataFrame equality. + + Examples + -------- + This example shows comparing two DataFrames that are equal + but with columns of differing dtypes. + + >>> from pandas.testing import assert_frame_equal + >>> df1 = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + >>> df2 = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) + + df1 equals itself. + + >>> assert_frame_equal(df1, df1) + + df1 differs from df2 as column 'b' is of a different type. + + >>> assert_frame_equal(df1, df2) + Traceback (most recent call last): + ... + AssertionError: Attributes of DataFrame.iloc[:, 1] (column name="b") are different + + Attribute "dtype" are different + [left]: int64 + [right]: float64 + + Ignore differing dtypes in columns with check_dtype. + + >>> assert_frame_equal(df1, df2, check_dtype=False) + """ + __tracebackhide__ = True + _rtol = rtol if rtol is not lib.no_default else 1.0e-5 + _atol = atol if atol is not lib.no_default else 1.0e-8 + _check_exact = check_exact if check_exact is not lib.no_default else False + + # instance validation + _check_isinstance(left, right, DataFrame) + + if check_frame_type: + assert isinstance(left, type(right)) + # assert_class_equal(left, right, obj=obj) + + # shape comparison + if left.shape != right.shape: + raise_assert_detail( + obj, f"{obj} shape mismatch", f"{left.shape!r}", f"{right.shape!r}" + ) + + if check_flags: + assert left.flags == right.flags, f"{left.flags!r} != {right.flags!r}" + + # index comparison + assert_index_equal( + left.index, + right.index, + exact=check_index_type, + check_names=check_names, + check_exact=_check_exact, + check_categorical=check_categorical, + check_order=not check_like, + rtol=_rtol, + atol=_atol, + obj=f"{obj}.index", + ) + + # column comparison + assert_index_equal( + left.columns, + right.columns, + exact=check_column_type, + check_names=check_names, + check_exact=_check_exact, + check_categorical=check_categorical, + check_order=not check_like, + rtol=_rtol, + atol=_atol, + obj=f"{obj}.columns", + ) + + if check_like: + left = left.reindex_like(right) + + # compare by blocks + if by_blocks: + rblocks = right._to_dict_of_blocks() + lblocks = left._to_dict_of_blocks() + for dtype in list(set(list(lblocks.keys()) + list(rblocks.keys()))): + assert dtype in lblocks + assert dtype in rblocks + assert_frame_equal( + lblocks[dtype], rblocks[dtype], check_dtype=check_dtype, obj=obj + ) + + # compare by columns + else: + for i, col in enumerate(left.columns): + # We have already checked that columns match, so we can do + # fast location-based lookups + lcol = left._ixs(i, axis=1) + rcol = right._ixs(i, axis=1) + + # GH #38183 + # use check_index=False, because we do not want to run + # assert_index_equal for each column, + # as we already checked it for the whole dataframe before. + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="the 'check_datetimelike_compat' keyword", + category=Pandas4Warning, + ) + assert_series_equal( + lcol, + rcol, + check_dtype=check_dtype, + check_index_type=check_index_type, + check_exact=check_exact, + check_names=check_names, + check_datetimelike_compat=check_datetimelike_compat, + check_categorical=check_categorical, + check_freq=check_freq, + obj=f'{obj}.iloc[:, {i}] (column name="{col}")', + rtol=rtol, + atol=atol, + check_index=False, + check_flags=False, + ) + + +def assert_equal(left, right, **kwargs) -> None: + """ + Wrapper for tm.assert_*_equal to dispatch to the appropriate test function. + + Parameters + ---------- + left, right : Index, Series, DataFrame, ExtensionArray, or np.ndarray + The two items to be compared. + **kwargs + All keyword arguments are passed through to the underlying assert method. + """ + __tracebackhide__ = True + + if isinstance(left, Index): + assert_index_equal(left, right, **kwargs) + if isinstance(left, (DatetimeIndex, TimedeltaIndex)): + assert left.freq == right.freq, (left.freq, right.freq) + elif isinstance(left, Series): + assert_series_equal(left, right, **kwargs) + elif isinstance(left, DataFrame): + assert_frame_equal(left, right, **kwargs) + elif isinstance(left, IntervalArray): + assert_interval_array_equal(left, right, **kwargs) + elif isinstance(left, PeriodArray): + assert_period_array_equal(left, right, **kwargs) + elif isinstance(left, DatetimeArray): + assert_datetime_array_equal(left, right, **kwargs) + elif isinstance(left, TimedeltaArray): + assert_timedelta_array_equal(left, right, **kwargs) + elif isinstance(left, ExtensionArray): + assert_extension_array_equal(left, right, **kwargs) + elif isinstance(left, np.ndarray): + assert_numpy_array_equal(left, right, **kwargs) + elif isinstance(left, str): + assert kwargs == {} + assert left == right + else: + assert kwargs == {} + assert_almost_equal(left, right) + + +def assert_sp_array_equal(left, right) -> None: + """ + Check that the left and right SparseArray are equal. + + Parameters + ---------- + left : SparseArray + right : SparseArray + """ + _check_isinstance(left, right, pd.arrays.SparseArray) + + assert_numpy_array_equal(left.sp_values, right.sp_values) + + # SparseIndex comparison + assert isinstance(left.sp_index, SparseIndex) + assert isinstance(right.sp_index, SparseIndex) + + left_index = left.sp_index + right_index = right.sp_index + + if not left_index.equals(right_index): + raise_assert_detail( + "SparseArray.index", "index are not equal", left_index, right_index + ) + else: + # Just ensure a + pass + + assert_attr_equal("fill_value", left, right) + assert_attr_equal("dtype", left, right) + assert_numpy_array_equal(left.to_dense(), right.to_dense()) + + +def assert_contains_all(iterable, dic) -> None: + for k in iterable: + assert k in dic, f"Did not contain item: {k!r}" + + +def assert_copy(iter1, iter2, **eql_kwargs) -> None: + """ + iter1, iter2: iterables that produce elements + comparable with assert_almost_equal + + Checks that the elements are equal, but not + the same object. (Does not check that items + in sequences are also not the same object) + """ + for elem1, elem2 in zip(iter1, iter2, strict=True): + assert_almost_equal(elem1, elem2, **eql_kwargs) + msg = ( + f"Expected object {type(elem1)!r} and object {type(elem2)!r} to be " + "different objects, but they were the same object." + ) + assert elem1 is not elem2, msg + + +def is_extension_array_dtype_and_needs_i8_conversion( + left_dtype: DtypeObj, right_dtype: DtypeObj +) -> bool: + """ + Checks that we have the combination of an ExtensionArraydtype and + a dtype that should be converted to int64 + + Returns + ------- + bool + + Related to issue #37609 + """ + return isinstance(left_dtype, ExtensionDtype) and needs_i8_conversion(right_dtype) + + +def assert_indexing_slices_equivalent(ser: Series, l_slc: slice, i_slc: slice) -> None: + """ + Check that ser.iloc[i_slc] matches ser.loc[l_slc] and, if applicable, + ser[l_slc]. + """ + expected = ser.iloc[i_slc] + + assert_series_equal(ser.loc[l_slc], expected) + + if not is_integer_dtype(ser.index): + # For integer indices, .loc and plain getitem are position-based. + assert_series_equal(ser[l_slc], expected) + + +def assert_metadata_equivalent( + left: DataFrame | Series, right: DataFrame | Series | None = None +) -> None: + """ + Check that ._metadata attributes are equivalent. + """ + for attr in left._metadata: + val = getattr(left, attr, None) + if right is None: + assert val is None + else: + assert val == getattr(right, attr, None) diff --git a/pandas/_testing/compat.py b/pandas/_testing/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..722ba61a3227f88821c27e7b89bc27749cbb83fd --- /dev/null +++ b/pandas/_testing/compat.py @@ -0,0 +1,30 @@ +""" +Helpers for sharing tests between DataFrame/Series +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pandas import DataFrame + +if TYPE_CHECKING: + from pandas._typing import DtypeObj + + +def get_dtype(obj) -> DtypeObj: + if isinstance(obj, DataFrame): + # Note: we are assuming only one column + return obj.dtypes.iat[0] + else: + return obj.dtype + + +def get_obj(df: DataFrame, klass): + """ + For sharing tests using frame_or_series, either return the DataFrame + unchanged or return it's first column as a Series. + """ + if klass is DataFrame: + return df + return df._ixs(0, axis=1) diff --git a/pandas/_testing/contexts.py b/pandas/_testing/contexts.py new file mode 100644 index 0000000000000000000000000000000000000000..dcb20d00904cba798ae4c9a283c7d0eed44169c4 --- /dev/null +++ b/pandas/_testing/contexts.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from contextlib import contextmanager +import os +import sys +from typing import ( + IO, + TYPE_CHECKING, +) + +from pandas.compat import CHAINED_WARNING_DISABLED +from pandas.errors import ChainedAssignmentError + +from pandas.io.common import get_handle + +if TYPE_CHECKING: + from collections.abc import Generator + + from pandas._typing import ( + BaseBuffer, + CompressionOptions, + FilePath, + ) + + +@contextmanager +def decompress_file( + path: FilePath | BaseBuffer, compression: CompressionOptions +) -> Generator[IO[bytes]]: + """ + Open a compressed file and return a file object. + + Parameters + ---------- + path : str + The path where the file is read from. + + compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd', None} + Name of the decompression to use + + Returns + ------- + file object + """ + with get_handle(path, "rb", compression=compression, is_text=False) as handle: + yield handle.handle + + +@contextmanager +def set_timezone(tz: str) -> Generator[None]: + """ + Context manager for temporarily setting a timezone. + + Parameters + ---------- + tz : str + A string representing a valid timezone. + + Examples + -------- + >>> from datetime import datetime + >>> from dateutil.tz import tzlocal + >>> tzlocal().tzname(datetime(2021, 1, 1)) # doctest: +SKIP + 'IST' + + >>> with set_timezone("US/Eastern"): + ... tzlocal().tzname(datetime(2021, 1, 1)) + 'EST' + """ + import time + + def setTZ(tz) -> None: + if hasattr(time, "tzset"): + if tz is None: + try: + del os.environ["TZ"] + except KeyError: + pass + else: + os.environ["TZ"] = tz + # Next line allows typing checks to pass on Windows + if sys.platform != "win32": + time.tzset() + + orig_tz = os.environ.get("TZ") + setTZ(tz) + try: + yield + finally: + setTZ(orig_tz) + + +@contextmanager +def with_csv_dialect(name: str, **kwargs) -> Generator[None]: + """ + Context manager to temporarily register a CSV dialect for parsing CSV. + + Parameters + ---------- + name : str + The name of the dialect. + kwargs : mapping + The parameters for the dialect. + + Raises + ------ + ValueError : the name of the dialect conflicts with a builtin one. + + See Also + -------- + csv : Python's CSV library. + """ + import csv + + _BUILTIN_DIALECTS = {"excel", "excel-tab", "unix"} + + if name in _BUILTIN_DIALECTS: + raise ValueError("Cannot override builtin dialect.") + + csv.register_dialect(name, **kwargs) + try: + yield + finally: + csv.unregister_dialect(name) + + +def raises_chained_assignment_error(extra_warnings=(), extra_match=()): + from pandas._testing import assert_produces_warning + + if CHAINED_WARNING_DISABLED: + if not extra_warnings: + from contextlib import nullcontext + + return nullcontext() + else: + return assert_produces_warning( + extra_warnings, + match=extra_match, + ) + else: + warning = ChainedAssignmentError + match = ( + "A value is being set on a copy of a DataFrame or Series " + "through chained assignment" + ) + if extra_warnings: + warning = (warning, *extra_warnings) # type: ignore[assignment] + return assert_produces_warning( + warning, + match=(match, *extra_match), + ) diff --git a/pandas/api/__init__.py b/pandas/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a016e67a41360eadb5475dc961fba2c000a7f32b --- /dev/null +++ b/pandas/api/__init__.py @@ -0,0 +1,19 @@ +"""public toolkit API""" + +from pandas.api import ( + executors, + extensions, + indexers, + interchange, + types, + typing, +) + +__all__ = [ + "executors", + "extensions", + "indexers", + "interchange", + "types", + "typing", +] diff --git a/pandas/api/executors/__init__.py b/pandas/api/executors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..04c94ee688332ea042ae4b6511a91e1a2653880f --- /dev/null +++ b/pandas/api/executors/__init__.py @@ -0,0 +1,7 @@ +""" +Public API for function executor engines to be used with ``map`` and ``apply``. +""" + +from pandas.core.apply import BaseExecutionEngine + +__all__ = ["BaseExecutionEngine"] diff --git a/pandas/api/extensions/__init__.py b/pandas/api/extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c88c0d35b4d70270e2bb7b4a1a6e77f03555592 --- /dev/null +++ b/pandas/api/extensions/__init__.py @@ -0,0 +1,33 @@ +""" +Public API for extending pandas objects. +""" + +from pandas._libs.lib import no_default + +from pandas.core.dtypes.base import ( + ExtensionDtype, + register_extension_dtype, +) + +from pandas.core.accessor import ( + register_dataframe_accessor, + register_index_accessor, + register_series_accessor, +) +from pandas.core.algorithms import take +from pandas.core.arrays import ( + ExtensionArray, + ExtensionScalarOpsMixin, +) + +__all__ = [ + "ExtensionArray", + "ExtensionDtype", + "ExtensionScalarOpsMixin", + "no_default", + "register_dataframe_accessor", + "register_extension_dtype", + "register_index_accessor", + "register_series_accessor", + "take", +] diff --git a/pandas/api/indexers/__init__.py b/pandas/api/indexers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c6546218de4ce8ddc3b044fc26b6946065b07e --- /dev/null +++ b/pandas/api/indexers/__init__.py @@ -0,0 +1,17 @@ +""" +Public API for Rolling Window Indexers. +""" + +from pandas.core.indexers import check_array_indexer +from pandas.core.indexers.objects import ( + BaseIndexer, + FixedForwardWindowIndexer, + VariableOffsetWindowIndexer, +) + +__all__ = [ + "BaseIndexer", + "FixedForwardWindowIndexer", + "VariableOffsetWindowIndexer", + "check_array_indexer", +] diff --git a/pandas/api/interchange/__init__.py b/pandas/api/interchange/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aded37abc7224a6b24180beba2f52be60d7ae25d --- /dev/null +++ b/pandas/api/interchange/__init__.py @@ -0,0 +1,8 @@ +""" +Public API for DataFrame interchange protocol. +""" + +from pandas.core.interchange.dataframe_protocol import DataFrame +from pandas.core.interchange.from_dataframe import from_dataframe + +__all__ = ["DataFrame", "from_dataframe"] diff --git a/pandas/api/internals.py b/pandas/api/internals.py new file mode 100644 index 0000000000000000000000000000000000000000..03d8992a875758d3cecad32cd20f5feab9e14cca --- /dev/null +++ b/pandas/api/internals.py @@ -0,0 +1,62 @@ +import numpy as np + +from pandas._typing import ArrayLike + +from pandas import ( + DataFrame, + Index, +) +from pandas.core.internals.api import _make_block +from pandas.core.internals.managers import BlockManager as _BlockManager + + +def create_dataframe_from_blocks( + blocks: list[tuple[ArrayLike, np.ndarray]], index: Index, columns: Index +) -> DataFrame: + """ + Low-level function to create a DataFrame from arrays as they are + representing the block structure of the resulting DataFrame. + + Attention: this is an advanced, low-level function that should only be + used if you know that the below-mentioned assumptions are guaranteed. + If passing data that do not follow those assumptions, subsequent + subsequent operations on the resulting DataFrame might lead to strange + errors. + For almost all use cases, you should use the standard pd.DataFrame(..) + constructor instead. If you are planning to use this function, let us + know by opening an issue at https://github.com/pandas-dev/pandas/issues. + + Assumptions: + + - The block arrays are either a 2D numpy array or a pandas ExtensionArray + - In case of a numpy array, it is assumed to already be in the expected + shape for Blocks (2D, (cols, rows), i.e. transposed compared to the + DataFrame columns). + - All arrays are taken as is (no type inference) and expected to have the + correct size. + - The placement arrays have the correct length (equalling the number of + columns that its equivalent block array represents), and all placement + arrays together form a complete set of 0 to n_columns - 1. + + Parameters + ---------- + blocks : list of tuples of (block_array, block_placement) + This should be a list of tuples existing of (block_array, block_placement), + where: + + - block_array is a 2D numpy array or a 1D ExtensionArray, following the + requirements listed above. + - block_placement is a 1D integer numpy array + index : Index + The Index object for the `index` of the resulting DataFrame. + columns : Index + The Index object for the `columns` of the resulting DataFrame. + + Returns + ------- + DataFrame + """ + block_objs = [_make_block(*block) for block in blocks] + axes = [columns, index] + mgr = _BlockManager(block_objs, axes) + return DataFrame._from_mgr(mgr, mgr.axes) diff --git a/pandas/api/types/__init__.py b/pandas/api/types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5c742b1628b797e1015247b1a2b52e0bda4470 --- /dev/null +++ b/pandas/api/types/__init__.py @@ -0,0 +1,23 @@ +""" +Public toolkit API. +""" + +from pandas._libs.lib import infer_dtype + +from pandas.core.dtypes.api import * # noqa: F403 +from pandas.core.dtypes.concat import union_categoricals +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + DatetimeTZDtype, + IntervalDtype, + PeriodDtype, +) + +__all__ = [ + "CategoricalDtype", + "DatetimeTZDtype", + "IntervalDtype", + "PeriodDtype", + "infer_dtype", + "union_categoricals", +] diff --git a/pandas/api/typing/__init__.py b/pandas/api/typing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..de6657b58ee80337097bbeb08f1ed54bb47f42d7 --- /dev/null +++ b/pandas/api/typing/__init__.py @@ -0,0 +1,61 @@ +""" +Public API classes that store intermediate results useful for type-hinting. +""" + +from pandas._libs import NaTType +from pandas._libs.lib import NoDefault +from pandas._libs.missing import NAType + +from pandas.core.col import Expression +from pandas.core.groupby import ( + DataFrameGroupBy, + SeriesGroupBy, +) +from pandas.core.indexes.frozen import FrozenList +from pandas.core.resample import ( + DatetimeIndexResamplerGroupby, + PeriodIndexResamplerGroupby, + Resampler, + TimedeltaIndexResamplerGroupby, + TimeGrouper, +) +from pandas.core.window import ( + Expanding, + ExpandingGroupby, + ExponentialMovingWindow, + ExponentialMovingWindowGroupby, + Rolling, + RollingGroupby, + Window, +) + +# TODO: Can't import Styler without importing jinja2 +# from pandas.io.formats.style import Styler +from pandas.io.json._json import JsonReader +from pandas.io.sas.sasreader import SASReader +from pandas.io.stata import StataReader + +__all__ = [ + "DataFrameGroupBy", + "DatetimeIndexResamplerGroupby", + "Expanding", + "ExpandingGroupby", + "ExponentialMovingWindow", + "ExponentialMovingWindowGroupby", + "Expression", + "FrozenList", + "JsonReader", + "NAType", + "NaTType", + "NoDefault", + "PeriodIndexResamplerGroupby", + "Resampler", + "Rolling", + "RollingGroupby", + "SASReader", + "SeriesGroupBy", + "StataReader", + "TimeGrouper", + "TimedeltaIndexResamplerGroupby", + "Window", +] diff --git a/pandas/api/typing/aliases.py b/pandas/api/typing/aliases.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cad814393722a45e3e3fcbadb34e93cc968f31 --- /dev/null +++ b/pandas/api/typing/aliases.py @@ -0,0 +1,145 @@ +from pandas._typing import ( + AggFuncType, + AlignJoin, + AnyAll, + AnyArrayLike, + ArrayLike, + AstypeArg, + Axes, + Axis, + ColspaceArgType, + CompressionOptions, + CorrelationMethod, + CSVEngine, + DropKeep, + Dtype, + DtypeArg, + DtypeBackend, + DtypeObj, + ExcelWriterIfSheetExists, + ExcelWriterMergeCells, + FilePath, + FillnaOptions, + FloatFormatType, + FormattersType, + FromDictOrient, + HTMLFlavors, + IgnoreRaise, + IndexLabel, + InterpolateOptions, + IntervalClosedType, + IntervalLeftRight, + JoinHow, + JoinValidate, + JSONEngine, + JSONSerializable, + ListLike, + MergeHow, + MergeValidate, + NaPosition, + NsmallestNlargestKeep, + OpenFileErrors, + Ordered, + ParquetCompressionOptions, + QuantileInterpolation, + ReadBuffer, + ReadCsvBuffer, + ReadPickleBuffer, + ReindexMethod, + Scalar, + ScalarIndexer, + SequenceIndexer, + SequenceNotStr, + SliceType, + SortKind, + StorageOptions, + Suffixes, + TakeIndexer, + TimeAmbiguous, + TimedeltaConvertibleTypes, + TimeGrouperOrigin, + TimeNonexistent, + TimestampConvertibleTypes, + TimeUnit, + ToStataByteorder, + ToTimestampHow, + UpdateJoin, + UsecolsArgType, + WindowingRankType, + WriteBuffer, + WriteExcelBuffer, + XMLParsers, +) + +__all__ = [ + "AggFuncType", + "AlignJoin", + "AnyAll", + "AnyArrayLike", + "ArrayLike", + "AstypeArg", + "Axes", + "Axis", + "CSVEngine", + "ColspaceArgType", + "CompressionOptions", + "CorrelationMethod", + "DropKeep", + "Dtype", + "DtypeArg", + "DtypeBackend", + "DtypeObj", + "ExcelWriterIfSheetExists", + "ExcelWriterMergeCells", + "FilePath", + "FillnaOptions", + "FloatFormatType", + "FormattersType", + "FromDictOrient", + "HTMLFlavors", + "IgnoreRaise", + "IndexLabel", + "InterpolateOptions", + "IntervalClosedType", + "IntervalLeftRight", + "JSONEngine", + "JSONSerializable", + "JoinHow", + "JoinValidate", + "ListLike", + "MergeHow", + "MergeValidate", + "NaPosition", + "NsmallestNlargestKeep", + "OpenFileErrors", + "Ordered", + "ParquetCompressionOptions", + "QuantileInterpolation", + "ReadBuffer", + "ReadCsvBuffer", + "ReadPickleBuffer", + "ReindexMethod", + "Scalar", + "ScalarIndexer", + "SequenceIndexer", + "SequenceNotStr", + "SliceType", + "SortKind", + "StorageOptions", + "Suffixes", + "TakeIndexer", + "TimeAmbiguous", + "TimeGrouperOrigin", + "TimeNonexistent", + "TimeUnit", + "TimedeltaConvertibleTypes", + "TimestampConvertibleTypes", + "ToStataByteorder", + "ToTimestampHow", + "UpdateJoin", + "UsecolsArgType", + "WindowingRankType", + "WriteBuffer", + "WriteExcelBuffer", + "XMLParsers", +] diff --git a/pandas/arrays/__init__.py b/pandas/arrays/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5c1c98da1c785fdd51503ea6005c3eb014a8311 --- /dev/null +++ b/pandas/arrays/__init__.py @@ -0,0 +1,37 @@ +""" +All of pandas' ExtensionArrays. + +See :ref:`extending.extension-types` for more. +""" + +from pandas.core.arrays import ( + ArrowExtensionArray, + ArrowStringArray, + BooleanArray, + Categorical, + DatetimeArray, + FloatingArray, + IntegerArray, + IntervalArray, + NumpyExtensionArray, + PeriodArray, + SparseArray, + StringArray, + TimedeltaArray, +) + +__all__ = [ + "ArrowExtensionArray", + "ArrowStringArray", + "BooleanArray", + "Categorical", + "DatetimeArray", + "FloatingArray", + "IntegerArray", + "IntervalArray", + "NumpyExtensionArray", + "PeriodArray", + "SparseArray", + "StringArray", + "TimedeltaArray", +] diff --git a/pandas/compat/__init__.py b/pandas/compat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49b56c63961550b1d06a333324d2e70c70a5ab9c --- /dev/null +++ b/pandas/compat/__init__.py @@ -0,0 +1,173 @@ +""" +compat +====== + +Cross-compatible functions for different versions of Python. + +Other items: +* platform checker +""" + +from __future__ import annotations + +import os +import platform +import sys +from typing import TYPE_CHECKING + +from pandas.compat._constants import ( + CHAINED_WARNING_DISABLED, + IS64, + ISMUSL, + PY312, + PY314, + PYPY, + WASM, +) +from pandas.compat.numpy import is_numpy_dev +from pandas.compat.pyarrow import ( + HAS_PYARROW, + PYARROW_MIN_VERSION, + pa_version_under14p0, + pa_version_under14p1, + pa_version_under16p0, + pa_version_under17p0, + pa_version_under18p0, + pa_version_under19p0, + pa_version_under20p0, + pa_version_under21p0, +) + +if TYPE_CHECKING: + from pandas._typing import F + + +def set_function_name(f: F, name: str, cls: type) -> F: + """ + Bind the name/qualname attributes of the function. + """ + f.__name__ = name + f.__qualname__ = f"{cls.__name__}.{name}" + f.__module__ = cls.__module__ + return f + + +def is_platform_little_endian() -> bool: + """ + Checking if the running platform is little endian. + + Returns + ------- + bool + True if the running platform is little endian. + """ + return sys.byteorder == "little" + + +def is_platform_windows() -> bool: + """ + Checking if the running platform is windows. + + Returns + ------- + bool + True if the running platform is windows. + """ + return sys.platform in ["win32", "cygwin"] + + +def is_platform_linux() -> bool: + """ + Checking if the running platform is linux. + + Returns + ------- + bool + True if the running platform is linux. + """ + return sys.platform == "linux" + + +def is_platform_mac() -> bool: + """ + Checking if the running platform is mac. + + Returns + ------- + bool + True if the running platform is mac. + """ + return sys.platform == "darwin" + + +def is_platform_arm() -> bool: + """ + Checking if the running platform use ARM architecture. + + Returns + ------- + bool + True if the running platform uses ARM architecture. + """ + return platform.machine() in ("arm64", "aarch64") or platform.machine().startswith( + "armv" + ) + + +def is_platform_power() -> bool: + """ + Checking if the running platform use Power architecture. + + Returns + ------- + bool + True if the running platform uses ARM architecture. + """ + return platform.machine() in ("ppc64", "ppc64le") + + +def is_platform_riscv64() -> bool: + """ + Checking if the running platform use riscv64 architecture. + + Returns + ------- + bool + True if the running platform uses riscv64 architecture. + """ + return platform.machine() == "riscv64" + + +def is_ci_environment() -> bool: + """ + Checking if running in a continuous integration environment by checking + the PANDAS_CI environment variable. + + Returns + ------- + bool + True if the running in a continuous integration environment. + """ + return os.environ.get("PANDAS_CI", "0") == "1" + + +__all__ = [ + "CHAINED_WARNING_DISABLED", + "HAS_PYARROW", + "IS64", + "ISMUSL", + "PY312", + "PY314", + "PYARROW_MIN_VERSION", + "PYPY", + "WASM", + "is_numpy_dev", + "pa_version_under14p0", + "pa_version_under14p1", + "pa_version_under16p0", + "pa_version_under17p0", + "pa_version_under18p0", + "pa_version_under19p0", + "pa_version_under20p0", + "pa_version_under21p0", +] diff --git a/pandas/compat/_constants.py b/pandas/compat/_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad31e0725bd448c9a836009581508c6d3230bd6 --- /dev/null +++ b/pandas/compat/_constants.py @@ -0,0 +1,35 @@ +""" +_constants +====== + +Constants relevant for the Python implementation. +""" + +from __future__ import annotations + +import platform +import sys +import sysconfig + +IS64 = sys.maxsize > 2**32 + +PY312 = sys.version_info >= (3, 12) +PY314 = sys.version_info >= (3, 14) +PYPY = platform.python_implementation() == "PyPy" +WASM = (sys.platform == "emscripten") or (platform.machine() in ["wasm32", "wasm64"]) +ISMUSL = "musl" in (sysconfig.get_config_var("HOST_GNU_TYPE") or "") +# the refcount for self in a chained __setitem__/.(i)loc indexing/method call +REF_COUNT = 2 if PY314 else 3 +REF_COUNT_IDX = 2 +REF_COUNT_METHOD = 1 if PY314 else 2 +CHAINED_WARNING_DISABLED = PYPY + + +__all__ = [ + "IS64", + "ISMUSL", + "PY312", + "PY314", + "PYPY", + "WASM", +] diff --git a/pandas/compat/_optional.py b/pandas/compat/_optional.py new file mode 100644 index 0000000000000000000000000000000000000000..42bd965e88c86c989fd7d1f448fd8aed1bacad8e --- /dev/null +++ b/pandas/compat/_optional.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import importlib +import sys +from typing import ( + TYPE_CHECKING, + Literal, + overload, +) +import warnings + +from pandas.util._exceptions import find_stack_level + +from pandas.util.version import Version + +if TYPE_CHECKING: + import types + +# Update install.rst, actions-311-minimum_versions.yaml, +# deps_minimum.toml & pyproject.toml when updating versions! + +VERSIONS = { + "adbc-driver-postgresql": "1.2.0", + "adbc-driver-sqlite": "1.2.0", + "bs4": "4.12.3", + "bottleneck": "1.4.2", + "fastparquet": "2024.11.0", + "fsspec": "2024.10.0", + "html5lib": "1.1", + "hypothesis": "6.116.0", + "gcsfs": "2024.10.0", + "jinja2": "3.1.5", + "lxml.etree": "5.3.0", + "matplotlib": "3.9.3", + "numba": "0.60.0", + "numexpr": "2.10.2", + "odfpy": "1.4.1", + "openpyxl": "3.1.5", + "psycopg2": "2.9.10", # (dt dec pq3 ext lo64) + "pymysql": "1.1.1", + "pyarrow": "13.0.0", + "pyiceberg": "0.8.1", + "pyreadstat": "1.2.8", + "pytest": "8.3.4", + "python-calamine": "0.3.0", + "pytz": "2024.2", + "pyxlsb": "1.0.10", + "s3fs": "2024.10.0", + "scipy": "1.14.1", + "sqlalchemy": "2.0.36", + "tables": "3.10.1", + "tabulate": "0.9.0", + "xarray": "2024.10.0", + "xlrd": "2.0.1", + "xlsxwriter": "3.2.0", + "zstandard": "0.23.0", + "qtpy": "2.4.2", + "pyqt5": "5.15.9", +} + +# A mapping from import name to package name (on PyPI) for packages where +# these two names are different. + +INSTALL_MAPPING = { + "bs4": "beautifulsoup4", + "bottleneck": "Bottleneck", + "jinja2": "Jinja2", + "lxml.etree": "lxml", + "odf": "odfpy", + "python_calamine": "python-calamine", + "sqlalchemy": "SQLAlchemy", + "tables": "pytables", +} + + +def get_version(module: types.ModuleType) -> str: + version = getattr(module, "__version__", None) + + if version is None: + raise ImportError(f"Can't determine version for {module.__name__}") + if module.__name__ == "psycopg2": + # psycopg2 appends " (dt dec pq3 ext lo64)" to it's version + version = version.split()[0] + return version + + +@overload +def import_optional_dependency( + name: str, + extra: str = ..., + min_version: str | None = ..., + *, + errors: Literal["raise"] = ..., +) -> types.ModuleType: ... + + +@overload +def import_optional_dependency( + name: str, + extra: str = ..., + min_version: str | None = ..., + *, + errors: Literal["warn", "ignore"], +) -> types.ModuleType | None: ... + + +def import_optional_dependency( + name: str, + extra: str = "", + min_version: str | None = None, + *, + errors: Literal["raise", "warn", "ignore"] = "raise", +) -> types.ModuleType | None: + """ + Import an optional dependency. + + By default, if a dependency is missing an ImportError with a nice + message will be raised. If a dependency is present, but too old, + we raise. + + Parameters + ---------- + name : str + The module name. + extra : str + Additional text to include in the ImportError message. + errors : str {'raise', 'warn', 'ignore'} + What to do when a dependency is not found or its version is too old. + + * raise : Raise an ImportError + * warn : Only applicable when a module's version is to old. + Warns that the version is too old and returns None + * ignore: If the module is not installed, return None, otherwise, + return the module, even if the version is too old. + It's expected that users validate the version locally when + using ``errors="ignore"`` (see. ``io/html.py``) + min_version : str, default None + Specify a minimum version that is different from the global pandas + minimum version required. + Returns + ------- + maybe_module : Optional[ModuleType] + The imported module, when found and the version is correct. + None is returned when the package is not found and `errors` + is False, or when the package's version is too old and `errors` + is ``'warn'`` or ``'ignore'``. + """ + assert errors in {"warn", "raise", "ignore"} + + package_name = INSTALL_MAPPING.get(name) + install_name = package_name if package_name is not None else name + + msg = ( + f"`Import {install_name}` failed. {extra} " + f"Use pip or conda to install the {install_name} package." + ) + try: + module = importlib.import_module(name) + except ImportError as err: + if errors == "raise": + raise ImportError(msg) from err + return None + + # Handle submodules: if we have submodule, grab parent module from sys.modules + parent = name.split(".")[0] + if parent != name: + install_name = parent + module_to_get = sys.modules[install_name] + else: + module_to_get = module + minimum_version = min_version if min_version is not None else VERSIONS.get(parent) + if minimum_version: + version = get_version(module_to_get) + if version and Version(version) < Version(minimum_version): + msg = ( + f"Pandas requires version '{minimum_version}' or newer of '{parent}' " + f"(version '{version}' currently installed)." + ) + if errors == "warn": + warnings.warn( + msg, + UserWarning, + stacklevel=find_stack_level(), + ) + return None + elif errors == "raise": + raise ImportError(msg) + else: + return None + + return module diff --git a/pandas/compat/pickle_compat.py b/pandas/compat/pickle_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..beb4a69232b277b92c85a1f995269e0ef278a43f --- /dev/null +++ b/pandas/compat/pickle_compat.py @@ -0,0 +1,143 @@ +""" +Pickle compatibility to pandas version 1.0 +""" + +from __future__ import annotations + +import contextlib +import io +import pickle +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np + +from pandas._libs.arrays import NDArrayBacked +from pandas._libs.tslibs import BaseOffset + +from pandas.core.arrays import ( + DatetimeArray, + PeriodArray, + TimedeltaArray, +) +from pandas.core.internals import BlockManager + +if TYPE_CHECKING: + from collections.abc import Generator + + +# If classes are moved, provide compat here. +_class_locations_map = { + # Re-routing unpickle block logic to go through _unpickle_block instead + # for pandas <= 1.3.5 + ("pandas.core.internals.blocks", "new_block"): ( + "pandas._libs.internals", + "_unpickle_block", + ), + # Avoid Cython's warning "contradiction to Python 'class private name' rules" + ("pandas._libs.tslibs.nattype", "__nat_unpickle"): ( + "pandas._libs.tslibs.nattype", + "_nat_unpickle", + ), + # 50775, remove Int64Index, UInt64Index & Float64Index from codebase + ("pandas.core.indexes.numeric", "Int64Index"): ( + "pandas.core.indexes.base", + "Index", + ), + ("pandas.core.indexes.numeric", "UInt64Index"): ( + "pandas.core.indexes.base", + "Index", + ), + ("pandas.core.indexes.numeric", "Float64Index"): ( + "pandas.core.indexes.base", + "Index", + ), + ("pandas.core.arrays.sparse.dtype", "SparseDtype"): ( + "pandas.core.dtypes.dtypes", + "SparseDtype", + ), +} + + +# our Unpickler sub-class to override methods and some dispatcher +# functions for compat and uses a non-public class of the pickle module. +class Unpickler(pickle._Unpickler): + def find_class(self, module: str, name: str) -> Any: + key = (module, name) + module, name = _class_locations_map.get(key, key) + return super().find_class(module, name) + + dispatch = pickle._Unpickler.dispatch.copy() + + def load_reduce(self) -> None: + stack = self.stack # type: ignore[attr-defined] + args = stack.pop() + func = stack[-1] + + try: + stack[-1] = func(*args) + except TypeError: + # If we have a deprecated function, + # try to replace and try again. + if args and isinstance(args[0], type) and issubclass(args[0], BaseOffset): + # TypeError: object.__new__(Day) is not safe, use Day.__new__() + cls = args[0] + stack[-1] = cls.__new__(*args) + return + elif args and issubclass(args[0], PeriodArray): + cls = args[0] + stack[-1] = NDArrayBacked.__new__(*args) + return + raise + + dispatch[pickle.REDUCE[0]] = load_reduce # type: ignore[assignment] + + def load_newobj(self) -> None: + args = self.stack.pop() # type: ignore[attr-defined] + cls = self.stack.pop() # type: ignore[attr-defined] + + # compat + if issubclass(cls, DatetimeArray) and not args: + arr = np.array([], dtype="M8[ns]") + obj = cls.__new__(cls, arr, arr.dtype) + elif issubclass(cls, TimedeltaArray) and not args: + arr = np.array([], dtype="m8[ns]") + obj = cls.__new__(cls, arr, arr.dtype) + elif cls is BlockManager and not args: + obj = cls.__new__(cls, (), [], False) + else: + obj = cls.__new__(cls, *args) + self.append(obj) # type: ignore[attr-defined] + + dispatch[pickle.NEWOBJ[0]] = load_newobj # type: ignore[assignment] + + +def loads( + bytes_object: bytes, + *, + fix_imports: bool = True, + encoding: str = "ASCII", + errors: str = "strict", +) -> Any: + """ + Analogous to pickle._loads. + """ + fd = io.BytesIO(bytes_object) + return Unpickler( + fd, fix_imports=fix_imports, encoding=encoding, errors=errors + ).load() + + +@contextlib.contextmanager +def patch_pickle() -> Generator[None]: + """ + Temporarily patch pickle to use our unpickler. + """ + orig_loads = pickle.loads + try: + setattr(pickle, "loads", loads) + yield + finally: + setattr(pickle, "loads", orig_loads) diff --git a/pandas/compat/pyarrow.py b/pandas/compat/pyarrow.py new file mode 100644 index 0000000000000000000000000000000000000000..fe71e8a82cd936b683c4234e732dc8ee3097af27 --- /dev/null +++ b/pandas/compat/pyarrow.py @@ -0,0 +1,91 @@ +"""support pyarrow compatibility across versions""" + +from __future__ import annotations + +import sys +from typing import Any + +from pandas.util.version import Version + +PYARROW_MIN_VERSION = "13.0.0" +try: + import pyarrow as pa + + _palv = Version(Version(pa.__version__).base_version) + pa_version_under14p0 = _palv < Version("14.0.0") + pa_version_under14p1 = _palv < Version("14.0.1") + pa_version_under15p0 = _palv < Version("15.0.0") + pa_version_under16p0 = _palv < Version("16.0.0") + pa_version_under17p0 = _palv < Version("17.0.0") + pa_version_under18p0 = _palv < Version("18.0.0") + pa_version_under19p0 = _palv < Version("19.0.0") + pa_version_under20p0 = _palv < Version("20.0.0") + pa_version_under21p0 = _palv < Version("21.0.0") + pa_version_under22p0 = _palv < Version("22.0.0") + HAS_PYARROW = _palv >= Version(PYARROW_MIN_VERSION) +except ImportError: + pa_version_under14p0 = True + pa_version_under14p1 = True + pa_version_under15p0 = True + pa_version_under16p0 = True + pa_version_under17p0 = True + pa_version_under18p0 = True + pa_version_under19p0 = True + pa_version_under20p0 = True + pa_version_under21p0 = True + pa_version_under22p0 = True + HAS_PYARROW = False + + +def _safe_fill_null( + arr: pa.Array | pa.ChunkedArray, fill_value: Any +) -> pa.Array | pa.ChunkedArray: + """ + Safe wrapper for pyarrow.compute.fill_null with fallback for Windows + pyarrow 21. + + pyarrow 21.0.0 on Windows has a bug in fill_null that incorrectly fills null values. + This function uses a fallback implementation for that specific case, otherwise uses + the standard pyarrow.compute.fill_null. + + Parameters + ---------- + arr : pyarrow.Array | pyarrow.ChunkedArray + Input array with potential null values. + fill_value : Any + Value to fill nulls with. + + Returns + ------- + pyarrow.Array | pyarrow.ChunkedArray + Array with nulls filled with fill_value. + """ + import pyarrow.compute as pc + + is_windows = sys.platform in ["win32", "cygwin"] + use_fallback = ( + HAS_PYARROW and is_windows and not pa_version_under21p0 and pa_version_under22p0 + ) + if not use_fallback or isinstance(fill_value, (pa.Array, pa.ChunkedArray)): + return pc.fill_null(arr, fill_value) + + fill_scalar = pa.scalar(fill_value, type=arr.type) + + if pa.types.is_duration(arr.type): + + def fill_null_duration(arr: pa.Array, fill_scalar: pa.Scalar) -> pa.Array: + mask = pc.is_null(arr) + zero_duration = pa.scalar(0, type=arr.type) + arr_zeroed = pc.if_else(mask, zero_duration, arr) + return pc.if_else(mask, fill_scalar, arr_zeroed) + + if isinstance(arr, pa.ChunkedArray): + return pa.chunked_array( + [fill_null_duration(chunk, fill_scalar) for chunk in arr.chunks] + ) + return fill_null_duration(arr, fill_scalar) + + if isinstance(arr, pa.ChunkedArray): + return pa.chunked_array( + [pc.if_else(pc.is_null(chunk), fill_scalar, chunk) for chunk in arr.chunks] + ) + return pc.if_else(pc.is_null(arr), fill_scalar, arr) diff --git a/pandas/core/__init__.py b/pandas/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/core/accessor.py b/pandas/core/accessor.py new file mode 100644 index 0000000000000000000000000000000000000000..4163de0d2cf011ea5040215ae3892ba8b4541f91 --- /dev/null +++ b/pandas/core/accessor.py @@ -0,0 +1,588 @@ +""" + +accessor.py contains base classes for implementing accessor properties +that can be mixed into or pinned onto other pandas classes. + +""" + +from __future__ import annotations + +import functools +from typing import ( + TYPE_CHECKING, + final, +) +import warnings + +from pandas.util._decorators import ( + set_module, +) +from pandas.util._exceptions import find_stack_level + +if TYPE_CHECKING: + from collections.abc import Callable + + from pandas._typing import TypeT + + from pandas import Index + from pandas.core.generic import NDFrame + + +class DirNamesMixin: + _accessors: set[str] = set() + _hidden_attrs: frozenset[str] = frozenset() + + @final + def _dir_deletions(self) -> set[str]: + """ + Delete unwanted __dir__ for this object. + """ + return self._accessors | self._hidden_attrs + + def _dir_additions(self) -> set[str]: + """ + Add additional __dir__ for this object. + """ + return {accessor for accessor in self._accessors if hasattr(self, accessor)} + + def __dir__(self) -> list[str]: + """ + Provide method name lookup and completion. + + Notes + ----- + Only provide 'public' methods. + """ + rv = set(super().__dir__()) + rv = (rv - self._dir_deletions()) | self._dir_additions() + return sorted(rv) + + +class PandasDelegate: + """ + Abstract base class for delegating methods/properties. + """ + + def _delegate_property_get(self, name: str, *args, **kwargs): + raise TypeError(f"You cannot access the property {name}") + + def _delegate_property_set(self, name: str, value, *args, **kwargs) -> None: + raise TypeError(f"The property {name} cannot be set") + + def _delegate_method(self, name: str, *args, **kwargs): + raise TypeError(f"You cannot call method {name}") + + @classmethod + def _add_delegate_accessors( + cls, + delegate, + accessors: list[str], + typ: str, + overwrite: bool = False, + accessor_mapping: Callable[[str], str] = lambda x: x, + raise_on_missing: bool = True, + ) -> None: + """ + Add accessors to cls from the delegate class. + + Parameters + ---------- + cls + Class to add the methods/properties to. + delegate + Class to get methods/properties and docstrings. + accessors : list of str + List of accessors to add. + typ : {'property', 'method'} + overwrite : bool, default False + Overwrite the method/property in the target class if it exists. + accessor_mapping: Callable, default lambda x: x + Callable to map the delegate's function to the cls' function. + raise_on_missing: bool, default True + Raise if an accessor does not exist on delegate. + False skips the missing accessor. + """ + + def _create_delegator_property(name: str): + def _getter(self): + return self._delegate_property_get(name) + + def _setter(self, new_values): + return self._delegate_property_set(name, new_values) + + _getter.__name__ = name + _setter.__name__ = name + + return property( + fget=_getter, + fset=_setter, + doc=getattr(delegate, accessor_mapping(name)).__doc__, + ) + + def _create_delegator_method(name: str): + method = getattr(delegate, accessor_mapping(name)) + + @functools.wraps(method) + def f(self, *args, **kwargs): + return self._delegate_method(name, *args, **kwargs) + + return f + + for name in accessors: + if ( + not raise_on_missing + and getattr(delegate, accessor_mapping(name), None) is None + ): + continue + + if typ == "property": + f = _create_delegator_property(name) + else: + f = _create_delegator_method(name) + + # don't overwrite existing methods/properties + if overwrite or not hasattr(cls, name): + setattr(cls, name, f) + + +def delegate_names( + delegate, + accessors: list[str], + typ: str, + overwrite: bool = False, + accessor_mapping: Callable[[str], str] = lambda x: x, + raise_on_missing: bool = True, +): + """ + Add delegated names to a class using a class decorator. This provides + an alternative usage to directly calling `_add_delegate_accessors` + below a class definition. + + Parameters + ---------- + delegate : object + The class to get methods/properties & docstrings. + accessors : Sequence[str] + List of accessor to add. + typ : {'property', 'method'} + overwrite : bool, default False + Overwrite the method/property in the target class if it exists. + accessor_mapping: Callable, default lambda x: x + Callable to map the delegate's function to the cls' function. + raise_on_missing: bool, default True + Raise if an accessor does not exist on delegate. + False skips the missing accessor. + + Returns + ------- + callable + A class decorator. + + Examples + -------- + @delegate_names(Categorical, ["categories", "ordered"], "property") + class CategoricalAccessor(PandasDelegate): + [...] + """ + + def add_delegate_accessors(cls): + cls._add_delegate_accessors( + delegate, + accessors, + typ, + overwrite=overwrite, + accessor_mapping=accessor_mapping, + raise_on_missing=raise_on_missing, + ) + return cls + + return add_delegate_accessors + + +class Accessor: + """ + Custom property-like object. + + A descriptor for accessors. + + Parameters + ---------- + name : str + Namespace that will be accessed under, e.g. ``df.foo``. + accessor : cls + Class with the extension methods. + + Notes + ----- + For accessor, The class's __init__ method assumes that one of + ``Series``, ``DataFrame`` or ``Index`` as the + single argument ``data``. + """ + + def __init__(self, name: str, accessor) -> None: + self._name = name + self._accessor = accessor + + def __get__(self, obj, cls): + if obj is None: + # we're accessing the attribute of the class, i.e., Dataset.geo + return self._accessor + return self._accessor(obj) + + +# Alias kept for downstream libraries +# TODO: Deprecate as name is now misleading +CachedAccessor = Accessor + + +def _register_accessor( + name: str, cls: type[NDFrame | Index] +) -> Callable[[TypeT], TypeT]: + """ + Register a custom accessor on objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + Returns + ------- + callable + A class decorator. + + See Also + -------- + register_dataframe_accessor : Register a custom accessor on DataFrame objects. + register_series_accessor : Register a custom accessor on Series objects. + register_index_accessor : Register a custom accessor on Index objects. + + Notes + ----- + This function allows you to register a custom-defined accessor class + for pandas objects (DataFrame, Series, or Index). + The requirements for the accessor class are as follows: + + * Must contain an init method that: + + * accepts a single object + + * raises an AttributeError if the object does not have correctly + matching inputs for the accessor + + * Must contain a method for each access pattern. + + * The methods should be able to take any argument signature. + + * Accessible using the @property decorator if no additional arguments are + needed. + + """ + + def decorator(accessor: TypeT) -> TypeT: + if hasattr(cls, name): + warnings.warn( + f"registration of accessor {accessor!r} under name " + f"{name!r} for type {cls!r} is overriding a preexisting " + f"attribute with the same name.", + UserWarning, + stacklevel=find_stack_level(), + ) + setattr(cls, name, Accessor(name, accessor)) + cls._accessors.add(name) + return accessor + + return decorator + + +_register_df_examples = """ +An accessor that only accepts integers could +have a class defined like this: + +>>> @pd.api.extensions.register_dataframe_accessor("int_accessor") +... class IntAccessor: +... def __init__(self, pandas_obj): +... if not all(pandas_obj[col].dtype == 'int64' for col in pandas_obj.columns): +... raise AttributeError("All columns must contain integer values only") +... self._obj = pandas_obj +... +... def sum(self): +... return self._obj.sum() +... +>>> df = pd.DataFrame([[1, 2], ['x', 'y']]) +>>> df.int_accessor +Traceback (most recent call last): +... +AttributeError: All columns must contain integer values only. +>>> df = pd.DataFrame([[1, 2], [3, 4]]) +>>> df.int_accessor.sum() +0 4 +1 6 +dtype: int64""" + + +@set_module("pandas.api.extensions") +def register_dataframe_accessor(name: str) -> Callable[[TypeT], TypeT]: + """ + Register a custom accessor on DataFrame objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + Returns + ------- + callable + A class decorator. + + See Also + -------- + register_dataframe_accessor : Register a custom accessor on DataFrame objects. + register_series_accessor : Register a custom accessor on Series objects. + register_index_accessor : Register a custom accessor on Index objects. + + Notes + ----- + This function allows you to register a custom-defined accessor class for DataFrame. + The requirements for the accessor class are as follows: + + * Must contain an init method that: + + * accepts a single DataFrame object + + * raises an AttributeError if the DataFrame object does not have correctly + matching inputs for the accessor + + * Must contain a method for each access pattern. + + * The methods should be able to take any argument signature. + + * Accessible using the @property decorator if no additional arguments are + needed. + + Examples + -------- + An accessor that only accepts integers could + have a class defined like this: + + >>> @pd.api.extensions.register_dataframe_accessor("int_accessor") + ... class IntAccessor: + ... def __init__(self, pandas_obj): + ... if not all( + ... pandas_obj[col].dtype == "int64" for col in pandas_obj.columns + ... ): + ... raise AttributeError("All columns must contain integer values only") + ... self._obj = pandas_obj + ... + ... def sum(self): + ... return self._obj.sum() + >>> df = pd.DataFrame([[1, 2], ["x", "y"]]) + >>> df.int_accessor + Traceback (most recent call last): + ... + AttributeError: All columns must contain integer values only. + >>> df = pd.DataFrame([[1, 2], [3, 4]]) + >>> df.int_accessor.sum() + 0 4 + 1 6 + dtype: int64 + """ + from pandas import DataFrame + + return _register_accessor(name, DataFrame) + + +_register_series_examples = """ +An accessor that only accepts integers could +have a class defined like this: + +>>> @pd.api.extensions.register_series_accessor("int_accessor") +... class IntAccessor: +... def __init__(self, pandas_obj): +... if not pandas_obj.dtype == 'int64': +... raise AttributeError("The series must contain integer data only") +... self._obj = pandas_obj +... +... def sum(self): +... return self._obj.sum() +... +>>> df = pd.Series([1, 2, 'x']) +>>> df.int_accessor +Traceback (most recent call last): +... +AttributeError: The series must contain integer data only. +>>> df = pd.Series([1, 2, 3]) +>>> df.int_accessor.sum() +6""" + + +@set_module("pandas.api.extensions") +def register_series_accessor(name: str) -> Callable[[TypeT], TypeT]: + """ + Register a custom accessor on Series objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + Returns + ------- + callable + A class decorator. + + See Also + -------- + register_dataframe_accessor : Register a custom accessor on DataFrame objects. + register_series_accessor : Register a custom accessor on Series objects. + register_index_accessor : Register a custom accessor on Index objects. + + Notes + ----- + This function allows you to register a custom-defined accessor class for Series. + The requirements for the accessor class are as follows: + + * Must contain an init method that: + + * accepts a single Series object + + * raises an AttributeError if the Series object does not have correctly + matching inputs for the accessor + + * Must contain a method for each access pattern. + + * The methods should be able to take any argument signature. + + * Accessible using the @property decorator if no additional arguments are + needed. + + Examples + -------- + An accessor that only accepts integers could + have a class defined like this: + + >>> @pd.api.extensions.register_series_accessor("int_accessor") + ... class IntAccessor: + ... def __init__(self, pandas_obj): + ... if not pandas_obj.dtype == "int64": + ... raise AttributeError("The series must contain integer data only") + ... self._obj = pandas_obj + ... + ... def sum(self): + ... return self._obj.sum() + >>> df = pd.Series([1, 2, "x"]) + >>> df.int_accessor + Traceback (most recent call last): + ... + AttributeError: The series must contain integer data only. + >>> df = pd.Series([1, 2, 3]) + >>> df.int_accessor.sum() + 6 + """ + from pandas import Series + + return _register_accessor(name, Series) + + +_register_index_examples = """ +An accessor that only accepts integers could +have a class defined like this: + +>>> @pd.api.extensions.register_index_accessor("int_accessor") +... class IntAccessor: +... def __init__(self, pandas_obj): +... if not all(isinstance(x, int) for x in pandas_obj): +... raise AttributeError("The index must only be an integer value") +... self._obj = pandas_obj +... +... def even(self): +... return [x for x in self._obj if x % 2 == 0] +>>> df = pd.DataFrame.from_dict( +... {"row1": {"1": 1, "2": "a"}, "row2": {"1": 2, "2": "b"}}, orient="index" +... ) +>>> df.index.int_accessor +Traceback (most recent call last): +... +AttributeError: The index must only be an integer value. +>>> df = pd.DataFrame( +... {"col1": [1, 2, 3, 4], "col2": ["a", "b", "c", "d"]}, index=[1, 2, 5, 8] +... ) +>>> df.index.int_accessor.even() +[2, 8]""" + + +@set_module("pandas.api.extensions") +def register_index_accessor(name: str) -> Callable[[TypeT], TypeT]: + """ + Register a custom accessor on Index objects. + + Parameters + ---------- + name : str + Name under which the accessor should be registered. A warning is issued + if this name conflicts with a preexisting attribute. + + Returns + ------- + callable + A class decorator. + + See Also + -------- + register_dataframe_accessor : Register a custom accessor on DataFrame objects. + register_series_accessor : Register a custom accessor on Series objects. + register_index_accessor : Register a custom accessor on Index objects. + + Notes + ----- + This function allows you to register a custom-defined accessor class for Index. + The requirements for the accessor class are as follows: + + * Must contain an init method that: + + * accepts a single Index object + + * raises an AttributeError if the Index object does not have correctly + matching inputs for the accessor + + * Must contain a method for each access pattern. + + * The methods should be able to take any argument signature. + + * Accessible using the @property decorator if no additional arguments are + needed. + + Examples + -------- + An accessor that only accepts integers could + have a class defined like this: + + >>> @pd.api.extensions.register_index_accessor("int_accessor") + ... class IntAccessor: + ... def __init__(self, pandas_obj): + ... if not all(isinstance(x, int) for x in pandas_obj): + ... raise AttributeError("The index must only be an integer value") + ... self._obj = pandas_obj + ... + ... def even(self): + ... return [x for x in self._obj if x % 2 == 0] + >>> df = pd.DataFrame.from_dict( + ... {"row1": {"1": 1, "2": "a"}, "row2": {"1": 2, "2": "b"}}, orient="index" + ... ) + >>> df.index.int_accessor + Traceback (most recent call last): + ... + AttributeError: The index must only be an integer value. + >>> df = pd.DataFrame( + ... {"col1": [1, 2, 3, 4], "col2": ["a", "b", "c", "d"]}, index=[1, 2, 5, 8] + ... ) + >>> df.index.int_accessor.even() + [2, 8] + """ + from pandas import Index + + return _register_accessor(name, Index) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py new file mode 100644 index 0000000000000000000000000000000000000000..4683cf9fc744a910e5a5d6749e15942b37ae30af --- /dev/null +++ b/pandas/core/algorithms.py @@ -0,0 +1,1712 @@ +""" +Generic data algorithms. This module is experimental at the moment and not +intended for public consumption +""" + +from __future__ import annotations + +import decimal +import operator +from typing import ( + TYPE_CHECKING, + Literal, + TypeVar, + cast, + overload, +) +import warnings + +import numpy as np + +from pandas._libs import ( + algos, + hashtable as htable, + iNaT, + lib, +) +from pandas._libs.missing import NA +from pandas._typing import ( + AnyArrayLike, + ArrayLike, + ArrayLikeT, + AxisInt, + DtypeObj, + TakeIndexer, + npt, +) +from pandas.util._decorators import set_module +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.cast import ( + construct_1d_object_array_from_listlike, + np_find_common_type, +) +from pandas.core.dtypes.common import ( + ensure_float64, + ensure_object, + ensure_platform_int, + is_bool_dtype, + is_complex_dtype, + is_dict_like, + is_dtype_equal, + is_extension_array_dtype, + is_float, + is_float_dtype, + is_integer, + is_integer_dtype, + is_list_like, + is_object_dtype, + is_signed_integer_dtype, + needs_i8_conversion, +) +from pandas.core.dtypes.concat import concat_compat +from pandas.core.dtypes.dtypes import ( + BaseMaskedDtype, + CategoricalDtype, + ExtensionDtype, + NumpyEADtype, +) +from pandas.core.dtypes.generic import ( + ABCDatetimeArray, + ABCExtensionArray, + ABCIndex, + ABCMultiIndex, + ABCNumpyExtensionArray, + ABCSeries, + ABCTimedeltaArray, +) +from pandas.core.dtypes.missing import ( + isna, + na_value_for_dtype, +) + +from pandas.core.array_algos.take import take_nd +from pandas.core.construction import ( + array as pd_array, + ensure_wrapped_if_datetimelike, + extract_array, +) +from pandas.core.indexers import validate_indices + +if TYPE_CHECKING: + from pandas._typing import ( + ListLike, + NumpySorter, + NumpyValueArrayLike, + ) + + from pandas import ( + Categorical, + Index, + Series, + ) + from pandas.core.arrays import ( + BaseMaskedArray, + ExtensionArray, + ) + + T = TypeVar("T", bound=Index | Categorical | ExtensionArray) + + +# --------------- # +# dtype access # +# --------------- # +def _ensure_data(values: ArrayLike) -> np.ndarray: + """ + routine to ensure that our data is of the correct + input dtype for lower-level routines + + This will coerce: + - ints -> int64 + - uint -> uint64 + - bool -> uint8 + - datetimelike -> i8 + - datetime64tz -> i8 (in local tz) + - categorical -> codes + + Parameters + ---------- + values : np.ndarray or ExtensionArray + + Returns + ------- + np.ndarray + """ + + if not isinstance(values, ABCMultiIndex): + # extract_array would raise + values = extract_array(values, extract_numpy=True) + + if is_object_dtype(values.dtype): + return ensure_object(np.asarray(values)) + + elif isinstance(values.dtype, BaseMaskedDtype): + # i.e. BooleanArray, FloatingArray, IntegerArray + values = cast("BaseMaskedArray", values) + if not values._hasna: + # No pd.NAs -> We can avoid an object-dtype cast (and copy) GH#41816 + # recurse to avoid re-implementing logic for eg bool->uint8 + return _ensure_data(values._data) + return np.asarray(values) + + elif isinstance(values.dtype, CategoricalDtype): + # NB: cases that go through here should NOT be using _reconstruct_data + # on the back-end. + values = cast("Categorical", values) + return values.codes + + elif is_bool_dtype(values.dtype): + if isinstance(values, np.ndarray): + # i.e. actually dtype == np.dtype("bool") + return np.asarray(values).view("uint8") + else: + # e.g. Sparse[bool, False] # TODO: no test cases get here + return np.asarray(values).astype("uint8", copy=False) + + elif is_integer_dtype(values.dtype): + return np.asarray(values) + + elif is_float_dtype(values.dtype): + # Note: checking `values.dtype == "float128"` raises on Windows and 32bit + # error: Item "ExtensionDtype" of "Union[Any, ExtensionDtype, dtype[Any]]" + # has no attribute "itemsize" + if values.dtype.itemsize in [2, 12, 16]: # type: ignore[union-attr] + # we dont (yet) have float128 hashtable support + return ensure_float64(values) + return np.asarray(values) + + elif is_complex_dtype(values.dtype): + return cast(np.ndarray, values) + + # datetimelike + elif needs_i8_conversion(values.dtype): + npvalues = values.view("i8") + npvalues = cast(np.ndarray, npvalues) + return npvalues + + # we have failed, return object + values = np.asarray(values, dtype=object) + return ensure_object(values) + + +def _reconstruct_data( + values: ArrayLikeT, dtype: DtypeObj, original: AnyArrayLike +) -> ArrayLikeT: + """ + reverse of _ensure_data + + Parameters + ---------- + values : np.ndarray or ExtensionArray + dtype : np.dtype or ExtensionDtype + original : AnyArrayLike + + Returns + ------- + ExtensionArray or np.ndarray + """ + if isinstance(values, ABCExtensionArray) and values.dtype == dtype: + # Catch DatetimeArray/TimedeltaArray + return values + + if not isinstance(dtype, np.dtype): + # i.e. ExtensionDtype; note we have ruled out above the possibility + # that values.dtype == dtype + cls = dtype.construct_array_type() + + # error: Incompatible return value type + # (got "ExtensionArray", + # expected "ndarray[tuple[Any, ...], dtype[Any]]") + return cls._from_sequence(values, dtype=dtype) # type: ignore[return-value] + + # error: Incompatible return value type + # (got "ndarray[tuple[Any, ...], dtype[Any]]", + # expected "ExtensionArray") + return values.astype(dtype, copy=False) # type: ignore[return-value] + + +def _ensure_arraylike(values, func_name: str) -> ArrayLike: + """ + ensure that we are arraylike if not already + """ + if not isinstance( + values, + (ABCIndex, ABCSeries, ABCExtensionArray, np.ndarray, ABCNumpyExtensionArray), + ): + # GH#52986 + if func_name != "isin-targets": + # Make an exception for the comps argument in isin. + raise TypeError( + f"{func_name} requires a Series, Index, " + f"ExtensionArray, np.ndarray or NumpyExtensionArray " + f"got {type(values).__name__}." + ) + + inferred = lib.infer_dtype(values, skipna=False) + if inferred in ["mixed", "string", "mixed-integer"]: + # "mixed-integer" to ensure we do not cast ["ss", 42] to str GH#22160 + if isinstance(values, tuple): + values = list(values) + values = construct_1d_object_array_from_listlike(values) + else: + values = np.asarray(values) + return values + + +_hashtables = { + "complex128": htable.Complex128HashTable, + "complex64": htable.Complex64HashTable, + "float64": htable.Float64HashTable, + "float32": htable.Float32HashTable, + "uint64": htable.UInt64HashTable, + "uint32": htable.UInt32HashTable, + "uint16": htable.UInt16HashTable, + "uint8": htable.UInt8HashTable, + "int64": htable.Int64HashTable, + "int32": htable.Int32HashTable, + "int16": htable.Int16HashTable, + "int8": htable.Int8HashTable, + "string": htable.StringHashTable, + "object": htable.PyObjectHashTable, +} + + +def _get_hashtable_algo( + values: np.ndarray, +) -> tuple[type[htable.HashTable], np.ndarray]: + """ + Parameters + ---------- + values : np.ndarray + + Returns + ------- + htable : HashTable subclass + values : ndarray + """ + values = _ensure_data(values) + + ndtype = _check_object_for_strings(values) + hashtable = _hashtables[ndtype] + return hashtable, values + + +def _check_object_for_strings(values: np.ndarray) -> str: + """ + Check if we can use string hashtable instead of object hashtable. + + Parameters + ---------- + values : ndarray + + Returns + ------- + str + """ + ndtype = values.dtype.name + if ndtype == "object": + # it's cheaper to use a String Hash Table than Object; we infer + # including nulls because that is the only difference between + # StringHashTable and ObjectHashtable + if lib.is_string_array(values, skipna=False): + ndtype = "string" + return ndtype + + +# --------------- # +# top-level algos # +# --------------- # + + +@overload +def unique(values: T) -> T: ... +@overload +def unique(values: np.ndarray | Series) -> np.ndarray: ... + + +@set_module("pandas") +def unique(values): + """ + Return unique values based on a hash table. + + Uniques are returned in order of appearance. This does NOT sort. + + Significantly faster than numpy.unique for long enough sequences. + Includes NA values. + + Parameters + ---------- + values : 1d array-like + The input array-like object containing values from which to extract + unique values. + + Returns + ------- + numpy.ndarray, ExtensionArray or NumpyExtensionArray + + The return can be: + + * Index : when the input is an Index + * Categorical : when the input is a Categorical dtype + * ndarray : when the input is a Series/ndarray + + Return numpy.ndarray, ExtensionArray or NumpyExtensionArray. + + See Also + -------- + Index.unique : Return unique values from an Index. + Series.unique : Return unique values of Series object. + + Examples + -------- + >>> pd.unique(pd.Series([2, 1, 3, 3])) + array([2, 1, 3]) + + >>> pd.unique(pd.Series([2] + [1] * 5)) + array([2, 1]) + + >>> pd.unique(pd.Series([pd.Timestamp("20160101"), pd.Timestamp("20160101")])) + array(['2016-01-01T00:00:00.000000'], dtype='datetime64[us]') + + >>> pd.unique( + ... pd.Series( + ... [ + ... pd.Timestamp("20160101", tz="US/Eastern"), + ... pd.Timestamp("20160101", tz="US/Eastern"), + ... ], + ... dtype="M8[ns, US/Eastern]", + ... ) + ... ) + + ['2016-01-01 00:00:00-05:00'] + Length: 1, dtype: datetime64[ns, US/Eastern] + + >>> pd.unique( + ... pd.Index( + ... [ + ... pd.Timestamp("20160101", tz="US/Eastern"), + ... pd.Timestamp("20160101", tz="US/Eastern"), + ... ], + ... dtype="M8[ns, US/Eastern]", + ... ) + ... ) + DatetimeIndex(['2016-01-01 00:00:00-05:00'], + dtype='datetime64[ns, US/Eastern]', + freq=None) + + >>> pd.unique(np.array(list("baabc"), dtype="O")) + array(['b', 'a', 'c'], dtype=object) + + An unordered Categorical will return categories in the + order of appearance. + + >>> pd.unique(pd.Series(pd.Categorical(list("baabc")))) + ['b', 'a', 'c'] + Categories (3, str): ['a', 'b', 'c'] + + >>> pd.unique(pd.Series(pd.Categorical(list("baabc"), categories=list("abc")))) + ['b', 'a', 'c'] + Categories (3, str): ['a', 'b', 'c'] + + An ordered Categorical preserves the category ordering. + + >>> pd.unique( + ... pd.Series( + ... pd.Categorical(list("baabc"), categories=list("abc"), ordered=True) + ... ) + ... ) + ['b', 'a', 'c'] + Categories (3, str): ['a' < 'b' < 'c'] + + An array of tuples + + >>> pd.unique(pd.Series([("a", "b"), ("b", "a"), ("a", "c"), ("b", "a")]).values) + array([('a', 'b'), ('b', 'a'), ('a', 'c')], dtype=object) + + A NumpyExtensionArray of complex + + >>> pd.unique(pd.array([1 + 1j, 2, 3])) + + [(1+1j), (2+0j), (3+0j)] + Length: 3, dtype: complex128 + """ + return unique_with_mask(values) + + +def nunique_ints(values: ArrayLike) -> int: + """ + Return the number of unique values for integer array-likes. + + Significantly faster than pandas.unique for long enough sequences. + No checks are done to ensure input is integral. + + Parameters + ---------- + values : 1d array-like + + Returns + ------- + int : The number of unique values in ``values`` + """ + if len(values) == 0: + return 0 + values = _ensure_data(values) + # bincount requires intp + result = (np.bincount(values.ravel().astype("intp")) != 0).sum() + return result + + +def unique_with_mask(values, mask: npt.NDArray[np.bool_] | None = None): + """See algorithms.unique for docs. Takes a mask for masked arrays.""" + values = _ensure_arraylike(values, func_name="unique") + + if isinstance(values.dtype, ExtensionDtype): + # Dispatch to extension dtype's unique. + return values.unique() + + if isinstance(values, ABCIndex): + # Dispatch to Index's unique. + return values.unique() + + original = values + hashtable, values = _get_hashtable_algo(values) + + table = hashtable(len(values)) + if mask is None: + uniques = table.unique(values) + uniques = _reconstruct_data(uniques, original.dtype, original) + return uniques + + else: + uniques, mask = table.unique(values, mask=mask) + uniques = _reconstruct_data(uniques, original.dtype, original) + assert mask is not None # for mypy + return uniques, mask.astype("bool") + + +unique1d = unique + + +_MINIMUM_COMP_ARR_LEN = 1_000_000 + + +def isin(comps: ListLike, values: ListLike) -> npt.NDArray[np.bool_]: + """ + Compute the isin boolean array. + + Parameters + ---------- + comps : list-like + values : list-like + + Returns + ------- + ndarray[bool] + Same length as `comps`. + """ + if not is_list_like(comps): + raise TypeError( + "only list-like objects are allowed to be passed " + f"to isin(), you passed a `{type(comps).__name__}`" + ) + if not is_list_like(values): + raise TypeError( + "only list-like objects are allowed to be passed " + f"to isin(), you passed a `{type(values).__name__}`" + ) + + if not isinstance(values, (ABCIndex, ABCSeries, ABCExtensionArray, np.ndarray)): + orig_values = list(values) + values = _ensure_arraylike(orig_values, func_name="isin-targets") + + if ( + len(values) > 0 + and values.dtype.kind in "iufcb" + and not is_signed_integer_dtype(comps) + and not is_dtype_equal(values, comps) + ): + # GH#46485 Use object to avoid upcast to float64 later + # TODO: Share with _find_common_type_compat + values = construct_1d_object_array_from_listlike(orig_values) + + elif isinstance(values, ABCMultiIndex): + # Avoid raising in extract_array + values = np.array(values) + else: + values = extract_array(values, extract_numpy=True, extract_range=True) + + comps_array = _ensure_arraylike(comps, func_name="isin") + comps_array = extract_array(comps_array, extract_numpy=True) + if not isinstance(comps_array, np.ndarray): + # i.e. Extension Array + return comps_array.isin(values) + + elif needs_i8_conversion(comps_array.dtype): + # Dispatch to DatetimeLikeArrayMixin.isin + return pd_array(comps_array).isin(values) + elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps_array.dtype): + # e.g. comps_array are integers and values are datetime64s + return np.zeros(comps_array.shape, dtype=bool) + # TODO: not quite right ... Sparse/Categorical + elif needs_i8_conversion(values.dtype): + return isin(comps_array, values.astype(object)) + + elif isinstance(values.dtype, ExtensionDtype): + return isin(np.asarray(comps_array), np.asarray(values)) + + # GH16012 + # Ensure np.isin doesn't get object types or it *may* throw an exception + # Albeit hashmap has O(1) look-up (vs. O(logn) in sorted array), + # isin is faster for small sizes + + # GH60678 + # Ensure values don't contain , otherwise it throws exception with np.in1d + + if ( + len(comps_array) > _MINIMUM_COMP_ARR_LEN + and len(values) <= 26 + and comps_array.dtype != object + and not any(v is NA for v in values) + ): + # If the values include nan we need to check for nan explicitly + # since np.nan it not equal to np.nan + if isna(values).any(): + + def f(c, v): + return np.logical_or(np.isin(c, v).ravel(), np.isnan(c)) + + else: + f = lambda a, b: np.isin(a, b).ravel() + + else: + common = np_find_common_type(values.dtype, comps_array.dtype) + values = values.astype(common, copy=False) + comps_array = comps_array.astype(common, copy=False) + f = htable.ismember + + return f(comps_array, values) + + +def factorize_array( + values: np.ndarray, + use_na_sentinel: bool = True, + size_hint: int | None = None, + na_value: object = None, + mask: npt.NDArray[np.bool_] | None = None, +) -> tuple[npt.NDArray[np.intp], np.ndarray]: + """ + Factorize a numpy array to codes and uniques. + + This doesn't do any coercion of types or unboxing before factorization. + + Parameters + ---------- + values : ndarray + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values. If False, + NaN values will be encoded as non-negative integers and will not drop the + NaN from the uniques of the values. + size_hint : int, optional + Passed through to the hashtable's 'get_labels' method + na_value : object, optional + A value in `values` to consider missing. Note: only use this + parameter when you know that you don't have any values pandas would + consider missing in the array (NaN for float data, iNaT for + datetimes, etc.). + mask : ndarray[bool], optional + If not None, the mask is used as indicator for missing values + (True = missing, False = valid) instead of `na_value` or + condition "val != val". + + Returns + ------- + codes : ndarray[np.intp] + uniques : ndarray + """ + original = values + if values.dtype.kind in "mM": + # _get_hashtable_algo will cast dt64/td64 to i8 via _ensure_data, so we + # need to do the same to na_value. We are assuming here that the passed + # na_value is an appropriately-typed NaT. + # e.g. test_where_datetimelike_categorical + na_value = iNaT + + hash_klass, values = _get_hashtable_algo(values) + + table = hash_klass(size_hint or len(values)) + uniques, codes = table.factorize( + values, + na_sentinel=-1, + na_value=na_value, + mask=mask, + ignore_na=use_na_sentinel, + ) + + # re-cast e.g. i8->dt64/td64, uint8->bool + uniques = _reconstruct_data(uniques, original.dtype, original) + + codes = ensure_platform_int(codes) + return codes, uniques + + +@set_module("pandas") +def factorize( + values, + sort: bool = False, + use_na_sentinel: bool = True, + size_hint: int | None = None, +) -> tuple[np.ndarray, np.ndarray | Index]: + """ + Encode the object as an enumerated type or categorical variable. + + This method is useful for obtaining a numeric representation of an + array when all that matters is identifying distinct values. `factorize` + is available as both a top-level function :func:`pandas.factorize`, + and as a method :meth:`Series.factorize` and :meth:`Index.factorize`. + + Parameters + ---------- + values : sequence + A 1-D sequence. Sequences that aren't pandas objects are + coerced to ndarrays before factorization. + sort : bool, default False + Sort `uniques` and shuffle `codes` to maintain the + relationship. + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values. If False, + NaN values will be encoded as non-negative integers and will not drop the + NaN from the uniques of the values. + size_hint : int, optional + Hint to the hashtable sizer. + + Returns + ------- + codes : ndarray + An integer ndarray that's an indexer into `uniques`. + ``uniques.take(codes)`` will have the same values as `values`. + uniques : ndarray, Index, or Categorical + The unique valid values. When `values` is Categorical, `uniques` + is a Categorical. When `values` is some other pandas object, an + `Index` is returned. Otherwise, a 1-D ndarray is returned. + + .. note:: + + Even if there's a missing value in `values`, `uniques` will + *not* contain an entry for it. + + See Also + -------- + cut : Discretize continuous-valued array. + unique : Find the unique value in an array. + + Notes + ----- + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + These examples all show factorize as a top-level method like + ``pd.factorize(values)``. The results are identical for methods like + :meth:`Series.factorize`. + + >>> codes, uniques = pd.factorize(np.array(["b", "b", "a", "c", "b"], dtype="O")) + >>> codes + array([0, 0, 1, 2, 0]) + >>> uniques + array(['b', 'a', 'c'], dtype=object) + + With ``sort=True``, the `uniques` will be sorted, and `codes` will be + shuffled so that the relationship is the maintained. + + >>> codes, uniques = pd.factorize( + ... np.array(["b", "b", "a", "c", "b"], dtype="O"), sort=True + ... ) + >>> codes + array([1, 1, 0, 2, 1]) + >>> uniques + array(['a', 'b', 'c'], dtype=object) + + When ``use_na_sentinel=True`` (the default), missing values are indicated in + the `codes` with the sentinel value ``-1`` and missing values are not + included in `uniques`. + + >>> codes, uniques = pd.factorize(np.array(["b", None, "a", "c", "b"], dtype="O")) + >>> codes + array([ 0, -1, 1, 2, 0]) + >>> uniques + array(['b', 'a', 'c'], dtype=object) + + Thus far, we've only factorized lists (which are internally coerced to + NumPy arrays). When factorizing pandas objects, the type of `uniques` + will differ. For Categoricals, a `Categorical` is returned. + + >>> cat = pd.Categorical(["a", "a", "c"], categories=["a", "b", "c"]) + >>> codes, uniques = pd.factorize(cat) + >>> codes + array([0, 0, 1]) + >>> uniques + ['a', 'c'] + Categories (3, str): ['a', 'b', 'c'] + + Notice that ``'b'`` is in ``uniques.categories``, despite not being + present in ``cat.values``. + + For all other pandas objects, an Index of the appropriate type is + returned. + + >>> cat = pd.Series(["a", "a", "c"]) + >>> codes, uniques = pd.factorize(cat) + >>> codes + array([0, 0, 1]) + >>> uniques + Index(['a', 'c'], dtype='str') + + If NaN is in the values, and we want to include NaN in the uniques of the + values, it can be achieved by setting ``use_na_sentinel=False``. + + >>> values = np.array([1, 2, 1, np.nan]) + >>> codes, uniques = pd.factorize(values) # default: use_na_sentinel=True + >>> codes + array([ 0, 1, 0, -1]) + >>> uniques + array([1., 2.]) + + >>> codes, uniques = pd.factorize(values, use_na_sentinel=False) + >>> codes + array([0, 1, 0, 2]) + >>> uniques + array([ 1., 2., nan]) + """ + # Implementation notes: This method is responsible for 3 things + # 1.) coercing data to array-like (ndarray, Index, extension array) + # 2.) factorizing codes and uniques + # 3.) Maybe boxing the uniques in an Index + # + # Step 2 is dispatched to extension types (like Categorical). They are + # responsible only for factorization. All data coercion, sorting and boxing + # should happen here. + if isinstance(values, (ABCIndex, ABCSeries)): + return values.factorize(sort=sort, use_na_sentinel=use_na_sentinel) + + values = _ensure_arraylike(values, func_name="factorize") + original = values + + if ( + isinstance(values, (ABCDatetimeArray, ABCTimedeltaArray)) + and values.freq is not None + ): + # The presence of 'freq' means we can fast-path sorting and know there + # aren't NAs + codes, uniques = values.factorize(sort=sort) + return codes, uniques + + elif not isinstance(values, np.ndarray): + # i.e. ExtensionArray + codes, uniques = values.factorize(use_na_sentinel=use_na_sentinel) + + else: + values = np.asarray(values) # convert DTA/TDA/MultiIndex + + if not use_na_sentinel and values.dtype == object: + # factorize can now handle differentiating various types of null values. + # These can only occur when the array has object dtype. + # However, for backwards compatibility we only use the null for the + # provided dtype. This may be revisited in the future, see GH#48476. + null_mask = isna(values) + if null_mask.any(): + na_value = na_value_for_dtype(values.dtype, compat=False) + # Don't modify (potentially user-provided) array + values = np.where(null_mask, na_value, values) + + codes, uniques = factorize_array( + values, + use_na_sentinel=use_na_sentinel, + size_hint=size_hint, + ) + + if sort and len(uniques) > 0: + uniques, codes = safe_sort( + uniques, + codes, + use_na_sentinel=use_na_sentinel, + assume_unique=True, + verify=False, + ) + + uniques = _reconstruct_data(uniques, original.dtype, original) + + return codes, uniques + + +def value_counts_internal( + values, + sort: bool = True, + ascending: bool = False, + normalize: bool = False, + bins=None, + dropna: bool = True, +) -> Series: + from pandas import ( + DatetimeIndex, + Index, + Series, + TimedeltaIndex, + ) + + index_name = getattr(values, "name", None) + name = "proportion" if normalize else "count" + + if bins is not None: + from pandas.core.reshape.tile import cut + + if isinstance(values, Series): + values = values._values + + try: + ii = cut(values, bins, include_lowest=True) + except TypeError as err: + raise TypeError("bins argument only works with numeric data.") from err + + # count, remove nulls (from the index), and but the bins + result = ii.value_counts(dropna=dropna) + result.name = name + result = result[result.index.notna()] + result.index = result.index.astype("interval") + result = result.sort_index() + + # if we are dropna and we have NO values + if dropna and (result._values == 0).all(): + result = result.iloc[0:0] + + # normalizing is by len of all (regardless of dropna) + normalize_denominator = len(ii) + + else: + normalize_denominator = None + if is_extension_array_dtype(values): + # handle Categorical and sparse, + result = Series(values, copy=False)._values.value_counts(dropna=dropna) + result.name = name + result.index.name = index_name + + elif isinstance(values, ABCMultiIndex): + # GH49558 + levels = list(range(values.nlevels)) + result = ( + Series(index=values, name=name) + .groupby(level=levels, dropna=dropna) + .size() + ) + result.index.names = values.names + + else: + values = _ensure_arraylike(values, func_name="value_counts") + keys, counts, _ = value_counts_arraylike(values, dropna) + if keys.dtype == np.float16: + keys = keys.astype(np.float32) + + # Starting in 3.0, we no longer perform dtype inference on the + # Index object we construct here, xref GH#56161 + idx = Index(keys, dtype=keys.dtype, name=index_name, copy=False) + + if ( + not sort + and isinstance(values, (DatetimeIndex, TimedeltaIndex)) + and idx.equals(values) + and values.inferred_freq is not None + ): + # Preserve freq of original index + idx.freq = values.inferred_freq # type: ignore[attr-defined] + + result = Series(counts, index=idx, name=name, copy=False) + + if sort: + result = result.sort_values(ascending=ascending, kind="stable") + + if normalize: + if normalize_denominator is not None: + result = result / normalize_denominator + else: + result = result / result.sum() + + return result + + +# Called once from SparseArray, otherwise could be private +def value_counts_arraylike( + values: np.ndarray, dropna: bool, mask: npt.NDArray[np.bool_] | None = None +) -> tuple[ArrayLike, npt.NDArray[np.int64], int]: + """ + Parameters + ---------- + values : np.ndarray + dropna : bool + mask : np.ndarray[bool] or None, default None + + Returns + ------- + uniques : np.ndarray + counts : np.ndarray[np.int64] + """ + original = values + values = _ensure_data(values) + + keys, counts, na_counter = htable.value_count(values, dropna, mask=mask) + + if needs_i8_conversion(original.dtype): + # datetime, timedelta, or period + + if dropna: + mask = keys != iNaT + keys, counts = keys[mask], counts[mask] + + res_keys = _reconstruct_data(keys, original.dtype, original) + return res_keys, counts, na_counter + + +def duplicated( + values: ArrayLike, + keep: Literal["first", "last", False] = "first", + mask: npt.NDArray[np.bool_] | None = None, +) -> npt.NDArray[np.bool_]: + """ + Return boolean ndarray denoting duplicate values. + + Parameters + ---------- + values : np.ndarray or ExtensionArray + Array over which to check for duplicate values. + keep : {'first', 'last', False}, default 'first' + - ``first`` : Mark duplicates as ``True`` except for the first + occurrence. + - ``last`` : Mark duplicates as ``True`` except for the last + occurrence. + - False : Mark all duplicates as ``True``. + mask : ndarray[bool], optional + array indicating which elements to exclude from checking + + Returns + ------- + duplicated : ndarray[bool] + """ + values = _ensure_data(values) + return htable.duplicated(values, keep=keep, mask=mask) + + +def mode( + values: ArrayLike, dropna: bool = True, mask: npt.NDArray[np.bool_] | None = None +) -> tuple[np.ndarray, npt.NDArray[np.bool_]] | ExtensionArray: + """ + Returns the mode(s) of an array. + + Parameters + ---------- + values : array-like + Array over which to check for duplicate values. + dropna : bool, default True + Don't consider counts of NaN/NaT. + + Returns + ------- + Union[Tuple[np.ndarray, npt.NDArray[np.bool_]], ExtensionArray] + """ + values = _ensure_arraylike(values, func_name="mode") + original = values + + if needs_i8_conversion(values.dtype): + # Got here with ndarray; dispatch to DatetimeArray/TimedeltaArray. + values = ensure_wrapped_if_datetimelike(values) + values = cast("ExtensionArray", values) + return values._mode(dropna=dropna) + + values = _ensure_data(values) + + npresult, res_mask = htable.mode(values, dropna=dropna, mask=mask) + if res_mask is None: + res_mask = np.zeros(npresult.shape, dtype=np.bool_) + else: + return npresult, res_mask + + try: + npresult = safe_sort(npresult) + except TypeError as err: + warnings.warn( + f"Unable to sort modes: {err}", + stacklevel=find_stack_level(), + ) + + result = _reconstruct_data(npresult, original.dtype, original) + return result, res_mask + + +def rank( + values: ArrayLike, + axis: AxisInt = 0, + method: str = "average", + na_option: str = "keep", + ascending: bool = True, + pct: bool = False, +) -> npt.NDArray[np.float64]: + """ + Rank the values along a given axis. + + Parameters + ---------- + values : np.ndarray or ExtensionArray + Array whose values will be ranked. The number of dimensions in this + array must not exceed 2. + axis : int, default 0 + Axis over which to perform rankings. + method : {'average', 'min', 'max', 'first', 'dense'}, default 'average' + The method by which tiebreaks are broken during the ranking. + na_option : {'keep', 'top'}, default 'keep' + The method by which NaNs are placed in the ranking. + - ``keep``: rank each NaN value with a NaN ranking + - ``top``: replace each NaN with either +/- inf so that they + there are ranked at the top + ascending : bool, default True + Whether or not the elements should be ranked in ascending order. + pct : bool, default False + Whether or not to the display the returned rankings in integer form + (e.g. 1, 2, 3) or in percentile form (e.g. 0.333..., 0.666..., 1). + """ + is_datetimelike = needs_i8_conversion(values.dtype) + values = _ensure_data(values) + + if values.ndim == 1: + ranks = algos.rank_1d( + values, + is_datetimelike=is_datetimelike, + ties_method=method, + ascending=ascending, + na_option=na_option, + pct=pct, + ) + elif values.ndim == 2: + ranks = algos.rank_2d( + values, + axis=axis, + is_datetimelike=is_datetimelike, + ties_method=method, + ascending=ascending, + na_option=na_option, + pct=pct, + ) + else: + raise TypeError("Array with ndim > 2 are not supported.") + + return ranks + + +# ---- # +# take # +# ---- # + + +@set_module("pandas.api.extensions") +def take( + arr, + indices: TakeIndexer, + axis: AxisInt = 0, + allow_fill: bool = False, + fill_value=None, +): + """ + Take elements from an array. + + Parameters + ---------- + arr : numpy.ndarray, ExtensionArray, Index, or Series + Input array. + indices : sequence of int or one-dimensional np.ndarray of int + Indices to be taken. + axis : int, default 0 + The axis over which to select values. + allow_fill : bool, default False + How to handle negative values in `indices`. + + * False: negative values in `indices` indicate positional indices + from the right (the default). This is similar to :func:`numpy.take`. + + * True: negative values in `indices` indicate + missing values. These values are set to `fill_value`. Any other + negative values raise a ``ValueError``. + + fill_value : any, optional + Fill value to use for NA-indices when `allow_fill` is True. + This may be ``None``, in which case the default NA value for + the type (``self.dtype.na_value``) is used. + + For multi-dimensional `arr`, each *element* is filled with + `fill_value`. + + Returns + ------- + ndarray or ExtensionArray + Same type as the input. + + Raises + ------ + IndexError + When `indices` is out of bounds for the array. + ValueError + When the indexer contains negative values other than ``-1`` + and `allow_fill` is True. + + Notes + ----- + When `allow_fill` is False, `indices` may be whatever dimensionality + is accepted by NumPy for `arr`. + + When `allow_fill` is True, `indices` should be 1-D. + + See Also + -------- + numpy.take : Take elements from an array along an axis. + + Examples + -------- + >>> import pandas as pd + + With the default ``allow_fill=False``, negative numbers indicate + positional indices from the right. + + >>> pd.api.extensions.take(np.array([10, 20, 30]), [0, 0, -1]) + array([10, 10, 30]) + + Setting ``allow_fill=True`` will place `fill_value` in those positions. + + >>> pd.api.extensions.take(np.array([10, 20, 30]), [0, 0, -1], allow_fill=True) + array([10., 10., nan]) + + >>> pd.api.extensions.take( + ... np.array([10, 20, 30]), [0, 0, -1], allow_fill=True, fill_value=-10 + ... ) + array([ 10, 10, -10]) + """ + if not isinstance( + arr, + (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries, ABCNumpyExtensionArray), + ): + # GH#52981 + raise TypeError( + "pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, " + f"Index, Series, or NumpyExtensionArray got {type(arr).__name__}." + ) + + indices = ensure_platform_int(indices) + + if allow_fill: + # Pandas style, -1 means NA + validate_indices(indices, arr.shape[axis]) + # error: Argument 1 to "take_nd" has incompatible type + # "ndarray[Any, Any] | ExtensionArray | Index | Series"; expected + # "ndarray[Any, Any]" + result = take_nd( + arr, # type: ignore[arg-type] + indices, + axis=axis, + allow_fill=True, + fill_value=fill_value, + ) + else: + # NumPy style + # error: Unexpected keyword argument "axis" for "take" of "ExtensionArray" + result = arr.take(indices, axis=axis) # type: ignore[call-arg,assignment] + return result + + +# ------------ # +# searchsorted # +# ------------ # + + +def searchsorted( + arr: ArrayLike, + value: NumpyValueArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter | None = None, +) -> npt.NDArray[np.intp] | np.intp: + """ + Find indices where elements should be inserted to maintain order. + + Find the indices into a sorted array `arr` (a) such that, if the + corresponding elements in `value` were inserted before the indices, + the order of `arr` would be preserved. + + Assuming that `arr` is sorted: + + ====== ================================ + `side` returned index `i` satisfies + ====== ================================ + left ``arr[i-1] < value <= self[i]`` + right ``arr[i-1] <= value < self[i]`` + ====== ================================ + + Parameters + ---------- + arr: np.ndarray, ExtensionArray, Series + Input array. If `sorter` is None, then it must be sorted in + ascending order, otherwise `sorter` must be an array of indices + that sort it. + value : array-like or scalar + Values to insert into `arr`. + side : {'left', 'right'}, optional + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable + index, return either 0 or N (where N is the length of `self`). + sorter : 1-D array-like, optional + Optional array of integer indices that sort array a into ascending + order. They are typically the result of argsort. + + Returns + ------- + array of ints or int + If value is array-like, array of insertion points. + If value is scalar, a single integer. + + See Also + -------- + numpy.searchsorted : Similar method from NumPy. + """ + if sorter is not None: + sorter = ensure_platform_int(sorter) + + if ( + isinstance(arr, np.ndarray) + and arr.dtype.kind in "iu" + and (is_integer(value) or is_integer_dtype(value)) + ): + # if `arr` and `value` have different dtypes, `arr` would be + # recast by numpy, causing a slow search. + # Before searching below, we therefore try to give `value` the + # same dtype as `arr`, while guarding against integer overflows. + iinfo = np.iinfo(arr.dtype.type) + value_arr = np.array([value]) if is_integer(value) else np.array(value) + if (value_arr >= iinfo.min).all() and (value_arr <= iinfo.max).all(): + # value within bounds, so no overflow, so can convert value dtype + # to dtype of arr + dtype = arr.dtype + else: + dtype = value_arr.dtype + + if is_integer(value): + # We know that value is int + value = cast(int, dtype.type(value)) + else: + value = pd_array(cast(ArrayLike, value), dtype=dtype) + else: + # E.g. if `arr` is an array with dtype='datetime64[ns]' + # and `value` is a pd.Timestamp, we may need to convert value + arr = ensure_wrapped_if_datetimelike(arr) + + # Argument 1 to "searchsorted" of "ndarray" has incompatible type + # "Union[NumpyValueArrayLike, ExtensionArray]"; expected "NumpyValueArrayLike" + return arr.searchsorted(value, side=side, sorter=sorter) # type: ignore[arg-type] + + +# ---- # +# diff # +# ---- # + +_diff_special = {"float64", "float32", "int64", "int32", "int16", "int8"} + + +def diff(arr, n: int | float | np.integer | np.floating, axis: AxisInt = 0): + """ + difference of n between self, + analogous to s-s.shift(n) + + Parameters + ---------- + arr : ndarray or ExtensionArray + n : int + number of periods + axis : {0, 1} + axis to shift on + stacklevel : int, default 3 + The stacklevel for the lost dtype warning. + + Returns + ------- + shifted + """ + + # added a check on the integer value of period + # see https://github.com/pandas-dev/pandas/issues/56607 + if not lib.is_integer(n): + if not (is_float(n) and n.is_integer()): + raise ValueError("periods must be an integer") + n = int(n) + na = np.nan + dtype = arr.dtype + + is_bool = is_bool_dtype(dtype) + if is_bool: + op = operator.xor + else: + op = operator.sub + + if isinstance(dtype, NumpyEADtype): + # NumpyExtensionArray cannot necessarily hold shifted versions of itself. + arr = arr.to_numpy() + dtype = arr.dtype + + if not isinstance(arr, np.ndarray): + # i.e ExtensionArray + if hasattr(arr, f"__{op.__name__}__"): + if axis != 0: + raise ValueError(f"cannot diff {type(arr).__name__} on axis={axis}") + return op(arr, arr.shift(n)) + else: + raise TypeError( + f"{type(arr).__name__} has no 'diff' method. " + "Convert to a suitable dtype prior to calling 'diff'." + ) + + is_timedelta = False + if arr.dtype.kind in "mM": + dtype = np.int64 + arr = arr.view("i8") + na = iNaT + is_timedelta = True + + elif is_bool: + # We have to cast in order to be able to hold np.nan + dtype = np.object_ + + elif dtype.kind in "iu": + # We have to cast in order to be able to hold np.nan + + # int8, int16 are incompatible with float64, + # see https://github.com/cython/cython/issues/2646 + if arr.dtype.name in ["int8", "int16"]: + dtype = np.float32 + else: + dtype = np.float64 + + orig_ndim = arr.ndim + if orig_ndim == 1: + # reshape so we can always use algos.diff_2d + arr = arr.reshape(-1, 1) + # TODO: require axis == 0 + + dtype = np.dtype(dtype) + out_arr = np.empty(arr.shape, dtype=dtype) + + na_indexer = [slice(None)] * 2 + na_indexer[axis] = slice(None, n) if n >= 0 else slice(n, None) + out_arr[tuple(na_indexer)] = na + + if arr.dtype.name in _diff_special: + # TODO: can diff_2d dtype specialization troubles be fixed by defining + # out_arr inside diff_2d? + algos.diff_2d(arr, out_arr, int(n), axis, datetimelike=is_timedelta) + else: + # To keep mypy happy, _res_indexer is a list while res_indexer is + # a tuple, ditto for lag_indexer. + _res_indexer = [slice(None)] * 2 + _res_indexer[axis] = slice(n, None) if n >= 0 else slice(None, n) + res_indexer = tuple(_res_indexer) + + _lag_indexer = [slice(None)] * 2 + _lag_indexer[axis] = slice(None, -n) if n > 0 else slice(-n, None) + lag_indexer = tuple(_lag_indexer) + + out_arr[res_indexer] = op(arr[res_indexer], arr[lag_indexer]) + + if is_timedelta: + out_arr = out_arr.view("timedelta64[ns]") + + if orig_ndim == 1: + out_arr = out_arr[:, 0] + return out_arr + + +# -------------------------------------------------------------------- +# Helper functions + + +# Note: safe_sort is in algorithms.py instead of sorting.py because it is +# low-dependency, is used in this module, and used private methods from +# this module. +def safe_sort( + values: Index | ArrayLike, + codes: npt.NDArray[np.intp] | None = None, + use_na_sentinel: bool = True, + assume_unique: bool = False, + verify: bool = True, +) -> AnyArrayLike | tuple[AnyArrayLike, np.ndarray]: + """ + Sort ``values`` and reorder corresponding ``codes``. + + ``values`` should be unique if ``codes`` is not None. + Safe for use with mixed types (int, str), orders ints before strs. + + Parameters + ---------- + values : list-like + Sequence; must be unique if ``codes`` is not None. + codes : np.ndarray[intp] or None, default None + Indices to ``values``. All out of bound indices are treated as + "not found" and will be masked with ``-1``. + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values. If False, + NaN values will be encoded as non-negative integers and will not drop the + NaN from the uniques of the values. + assume_unique : bool, default False + When True, ``values`` are assumed to be unique, which can speed up + the calculation. Ignored when ``codes`` is None. + verify : bool, default True + Check if codes are out of bound for the values and put out of bound + codes equal to ``-1``. If ``verify=False``, it is assumed there + are no out of bound codes. Ignored when ``codes`` is None. + + Returns + ------- + ordered : AnyArrayLike + Sorted ``values`` + new_codes : ndarray + Reordered ``codes``; returned when ``codes`` is not None. + + Raises + ------ + TypeError + * If ``values`` is not list-like or if ``codes`` is neither None + nor list-like + * If ``values`` cannot be sorted + ValueError + * If ``codes`` is not None and ``values`` contain duplicates. + """ + if not isinstance(values, (np.ndarray, ABCExtensionArray, ABCIndex)): + raise TypeError( + "Only np.ndarray, ExtensionArray, and Index objects are allowed to " + "be passed to safe_sort as values" + ) + + sorter = None + ordered: AnyArrayLike + + if ( + not isinstance(values.dtype, ExtensionDtype) + and lib.infer_dtype(values, skipna=False) == "mixed-integer" + ): + ordered = _sort_mixed(values) + else: + try: + sorter = values.argsort() + ordered = values.take(sorter) + except (TypeError, decimal.InvalidOperation): + # Previous sorters failed or were not applicable, try `_sort_mixed` + # which would work, but which fails for special case of 1d arrays + # with tuples. + if values.size and isinstance(values[0], tuple): + # error: Argument 1 to "_sort_tuples" has incompatible type + # "Union[Index, ExtensionArray, ndarray[Any, Any]]"; expected + # "ndarray[Any, Any]" + ordered = _sort_tuples(values) # type: ignore[arg-type] + else: + ordered = _sort_mixed(values) + + # codes: + + if codes is None: + return ordered + + if not is_list_like(codes): + raise TypeError( + "Only list-like objects or None are allowed to " + "be passed to safe_sort as codes" + ) + codes = ensure_platform_int(np.asarray(codes)) + + if not assume_unique and not len(unique(values)) == len(values): + raise ValueError("values should be unique if codes is not None") + + if sorter is None: + # mixed types + # error: Argument 1 to "_get_hashtable_algo" has incompatible type + # "Union[Index, ExtensionArray, ndarray[Any, Any]]"; expected + # "ndarray[Any, Any]" + hash_klass, values = _get_hashtable_algo(values) # type: ignore[arg-type] + t = hash_klass(len(values)) + t.map_locations(values) + # error: Argument 1 to "lookup" of "HashTable" has incompatible type + # "ExtensionArray | ndarray[Any, Any] | Index | Series"; expected "ndarray" + sorter = ensure_platform_int(t.lookup(ordered)) # type: ignore[arg-type] + + if use_na_sentinel: + # take_nd is faster, but only works for na_sentinels of -1 + order2 = sorter.argsort() + if verify: + mask = (codes < -len(values)) | (codes >= len(values)) + codes[mask] = -1 + new_codes = take_nd(order2, codes, fill_value=-1) + else: + reverse_indexer = np.empty(len(sorter), dtype=int) + reverse_indexer.put(sorter, np.arange(len(sorter))) + # Out of bound indices will be masked with `-1` next, so we + # may deal with them here without performance loss using `mode='wrap'` + new_codes = reverse_indexer.take(codes, mode="wrap") + + return ordered, ensure_platform_int(new_codes) + + +def _sort_mixed(values) -> AnyArrayLike: + """order ints before strings before nulls in 1d arrays""" + str_pos = np.array([isinstance(x, str) for x in values], dtype=bool) + null_pos = np.array([isna(x) for x in values], dtype=bool) + num_pos = ~str_pos & ~null_pos + str_argsort = np.argsort(values[str_pos]) + num_argsort = np.argsort(values[num_pos]) + # convert boolean arrays to positional indices, then order by underlying values + str_locs = str_pos.nonzero()[0].take(str_argsort) + num_locs = num_pos.nonzero()[0].take(num_argsort) + null_locs = null_pos.nonzero()[0] + locs = np.concatenate([num_locs, str_locs, null_locs]) + return values.take(locs) + + +def _sort_tuples(values: np.ndarray) -> np.ndarray: + """ + Convert array of tuples (1d) to array of arrays (2d). + We need to keep the columns separately as they contain different types and + nans (can't use `np.sort` as it may fail when str and nan are mixed in a + column as types cannot be compared). + """ + from pandas.core.internals.construction import to_arrays + from pandas.core.sorting import lexsort_indexer + + arrays, _ = to_arrays(values, None) + indexer = lexsort_indexer(arrays, orders=True) + return values[indexer] + + +def union_with_duplicates( + lvals: ArrayLike | Index, rvals: ArrayLike | Index +) -> ArrayLike | Index: + """ + Extracts the union from lvals and rvals with respect to duplicates and nans in + both arrays. + + Parameters + ---------- + lvals: np.ndarray or ExtensionArray + left values which is ordered in front. + rvals: np.ndarray or ExtensionArray + right values ordered after lvals. + + Returns + ------- + np.ndarray or ExtensionArray + Containing the unsorted union of both arrays. + + Notes + ----- + Caller is responsible for ensuring lvals.dtype == rvals.dtype. + """ + from pandas import Series + + l_count = value_counts_internal(lvals, dropna=False) + r_count = value_counts_internal(rvals, dropna=False) + l_count, r_count = l_count.align(r_count, fill_value=0) + final_count = np.maximum(l_count.values, r_count.values) + final_count = Series(final_count, index=l_count.index, dtype="int", copy=False) + if isinstance(lvals, ABCMultiIndex) and isinstance(rvals, ABCMultiIndex): + unique_vals = lvals.append(rvals).unique() + else: + if isinstance(lvals, ABCIndex): + lvals = lvals._values + if isinstance(rvals, ABCIndex): + rvals = rvals._values + # error: List item 0 has incompatible type "Union[ExtensionArray, + # ndarray[Any, Any], Index]"; expected "Union[ExtensionArray, + # ndarray[Any, Any]]" + combined = concat_compat([lvals, rvals]) # type: ignore[list-item] + unique_vals = unique(combined) + unique_vals = ensure_wrapped_if_datetimelike(unique_vals) + repeats = final_count.reindex(unique_vals).values + return np.repeat(unique_vals, repeats) + + +def map_array( + arr: ArrayLike, + mapper, + na_action: Literal["ignore"] | None = None, +) -> np.ndarray | ExtensionArray | Index: + """ + Map values using an input mapping or function. + + Parameters + ---------- + mapper : function, dict, or Series + Mapping correspondence. + na_action : {None, 'ignore'}, default None + If 'ignore', propagate NA values, without passing them to the + mapping correspondence. + + Returns + ------- + Union[ndarray, Index, ExtensionArray] + The output of the mapping function applied to the array. + If the function returns a tuple with more than one element + a MultiIndex will be returned. + """ + from pandas import Index + + if na_action not in (None, "ignore"): + msg = f"na_action must either be 'ignore' or None, {na_action} was passed" + raise ValueError(msg) + + # we can fastpath dict/Series to an efficient map + # as we know that we are not going to have to yield + # python types + if is_dict_like(mapper): + if isinstance(mapper, dict) and hasattr(mapper, "__missing__"): + # If a dictionary subclass defines a default value method, + # convert mapper to a lookup function (GH #15999). + dict_with_default = mapper + mapper = lambda x: dict_with_default[ + np.nan if isinstance(x, float) and np.isnan(x) else x + ] + else: + # Dictionary does not have a default. Thus it's safe to + # convert to a Series for efficiency. + # we specify the keys here to handle the + # possibility that they are tuples + + # The return value of mapping with an empty mapper is + # expected to be pd.Series(np.nan, ...). As np.nan is + # of dtype float64 the return value of this method should + # be float64 as well + from pandas import Series + + if len(mapper) == 0: + mapper = Series(mapper, dtype=np.float64) + elif isinstance(mapper, dict): + mapper = Series( + mapper.values(), index=Index(mapper.keys(), tupleize_cols=False) + ) + else: + mapper = Series(mapper) + + if isinstance(mapper, ABCSeries): + if na_action == "ignore": + mapper = mapper[mapper.index.notna()] + + # Since values were input this means we came from either + # a dict or a series and mapper should be an index + indexer = mapper.index.get_indexer(arr) + new_values = take_nd(mapper._values, indexer) + + return new_values + + if not len(arr): + return arr.copy() + + # we must convert to python types + values = arr.astype(object, copy=False) + if na_action is None: + return lib.map_infer(values, mapper) + else: + return lib.map_infer_mask(values, mapper, mask=isna(values).view(np.uint8)) diff --git a/pandas/core/api.py b/pandas/core/api.py new file mode 100644 index 0000000000000000000000000000000000000000..ec12d543d8389afa38c7c84a658dcaeee960690c --- /dev/null +++ b/pandas/core/api.py @@ -0,0 +1,138 @@ +from pandas._libs import ( + NaT, + Period, + Timedelta, + Timestamp, +) +from pandas._libs.missing import NA + +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + CategoricalDtype, + DatetimeTZDtype, + IntervalDtype, + PeriodDtype, +) +from pandas.core.dtypes.missing import ( + isna, + isnull, + notna, + notnull, +) + +from pandas.core.algorithms import ( + factorize, + unique, +) +from pandas.core.arrays import Categorical +from pandas.core.arrays.boolean import BooleanDtype +from pandas.core.arrays.floating import ( + Float32Dtype, + Float64Dtype, +) +from pandas.core.arrays.integer import ( + Int8Dtype, + Int16Dtype, + Int32Dtype, + Int64Dtype, + UInt8Dtype, + UInt16Dtype, + UInt32Dtype, + UInt64Dtype, +) +from pandas.core.arrays.string_ import StringDtype +from pandas.core.construction import array # noqa: ICN001 +from pandas.core.flags import Flags +from pandas.core.groupby import ( + Grouper, + NamedAgg, +) +from pandas.core.indexes.api import ( + CategoricalIndex, + DatetimeIndex, + Index, + IntervalIndex, + MultiIndex, + PeriodIndex, + RangeIndex, + TimedeltaIndex, +) +from pandas.core.indexes.datetimes import ( + bdate_range, + date_range, +) +from pandas.core.indexes.interval import ( + Interval, + interval_range, +) +from pandas.core.indexes.period import period_range +from pandas.core.indexes.timedeltas import timedelta_range +from pandas.core.indexing import IndexSlice +from pandas.core.series import Series +from pandas.core.tools.datetimes import to_datetime +from pandas.core.tools.numeric import to_numeric +from pandas.core.tools.timedeltas import to_timedelta + +from pandas.io.formats.format import set_eng_float_format +from pandas.tseries.offsets import DateOffset + +# DataFrame needs to be imported after NamedAgg to avoid a circular import +from pandas.core.frame import DataFrame # isort:skip + +__all__ = [ + "NA", + "ArrowDtype", + "BooleanDtype", + "Categorical", + "CategoricalDtype", + "CategoricalIndex", + "DataFrame", + "DateOffset", + "DatetimeIndex", + "DatetimeTZDtype", + "Flags", + "Float32Dtype", + "Float64Dtype", + "Grouper", + "Index", + "IndexSlice", + "Int8Dtype", + "Int16Dtype", + "Int32Dtype", + "Int64Dtype", + "Interval", + "IntervalDtype", + "IntervalIndex", + "MultiIndex", + "NaT", + "NamedAgg", + "Period", + "PeriodDtype", + "PeriodIndex", + "RangeIndex", + "Series", + "StringDtype", + "Timedelta", + "TimedeltaIndex", + "Timestamp", + "UInt8Dtype", + "UInt16Dtype", + "UInt32Dtype", + "UInt64Dtype", + "array", + "bdate_range", + "date_range", + "factorize", + "interval_range", + "isna", + "isnull", + "notna", + "notnull", + "period_range", + "set_eng_float_format", + "timedelta_range", + "to_datetime", + "to_numeric", + "to_timedelta", + "unique", +] diff --git a/pandas/core/apply.py b/pandas/core/apply.py new file mode 100644 index 0000000000000000000000000000000000000000..3f218b3813149a2c6584e9919d11802dd27b7ce4 --- /dev/null +++ b/pandas/core/apply.py @@ -0,0 +1,2132 @@ +from __future__ import annotations + +import abc +from collections import defaultdict +from collections.abc import Callable +import functools +from functools import partial +import inspect +from typing import ( + TYPE_CHECKING, + Any, + Literal, + TypeAlias, + cast, +) + +import numpy as np + +from pandas._libs.internals import BlockValuesRefs +from pandas._typing import ( + AggFuncType, + AggFuncTypeBase, + AggFuncTypeDict, + AggObjType, + Axis, + AxisInt, + NDFrameT, + npt, +) +from pandas.compat._optional import import_optional_dependency +from pandas.errors import SpecificationError +from pandas.util._decorators import ( + cache_readonly, + set_module, +) + +from pandas.core.dtypes.cast import is_nested_object +from pandas.core.dtypes.common import ( + is_dict_like, + is_extension_array_dtype, + is_list_like, + is_numeric_dtype, + is_sequence, +) +from pandas.core.dtypes.dtypes import ExtensionDtype +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCNDFrame, + ABCSeries, +) + +from pandas.core._numba.executor import generate_apply_looper +import pandas.core.common as com +from pandas.core.construction import ensure_wrapped_if_datetimelike +from pandas.core.util.numba_ import ( + get_jit_arguments, + prepare_function_arguments, +) + +if TYPE_CHECKING: + from collections.abc import ( + Generator, + Hashable, + Iterable, + MutableMapping, + Sequence, + ) + + from pandas import ( + DataFrame, + Index, + Series, + ) + from pandas.core.groupby import GroupBy + from pandas.core.resample import Resampler + from pandas.core.window.rolling import BaseWindow + +ResType: TypeAlias = dict[int, Any] + + +@set_module("pandas.api.executors") +class BaseExecutionEngine(abc.ABC): + """ + Base class for execution engines for map and apply methods. + + An execution engine receives all the parameters of a call to + ``apply`` or ``map``, such as the data container, the function, + etc. and takes care of running the execution. + + Supporting different engines allows functions to be JIT compiled, + run in parallel, and others. Besides the default executor which + simply runs the code with the Python interpreter and pandas. + """ + + @staticmethod + @abc.abstractmethod + def map( + data: Series | DataFrame | np.ndarray, + func: AggFuncType, + args: tuple, + kwargs: dict[str, Any], + decorator: Callable | None, + skip_na: bool, + ): + """ + Executor method to run functions elementwise. + + In general, pandas uses ``map`` for running functions elementwise, + but ``Series.apply`` with the default ``by_row='compat'`` will also + call this executor function. + + Parameters + ---------- + data : Series, DataFrame or NumPy ndarray + The object to use for the data. Some methods implement a ``raw`` + parameter which will convert the original pandas object to a + NumPy array, which will then be passed here to the executor. + func : function or NumPy ufunc + The function to execute. + args : tuple + Positional arguments to be passed to ``func``. + kwargs : dict + Keyword arguments to be passed to ``func``. + decorator : function, optional + For JIT compilers and other engines that need to decorate the + function ``func``, this is the decorator to use. While the + executor may already know which is the decorator to use, this + is useful as for a single executor the user can specify for + example ``numba.jit`` or ``numba.njit(nogil=True)``, and this + decorator parameter will contain the exact decorator from the + executor the user wants to use. + skip_na : bool + Whether the function should be called for missing values or not. + This is specified by the pandas user as ``map(na_action=None)`` + or ``map(na_action='ignore')``. + """ + + @staticmethod + @abc.abstractmethod + def apply( + data: Series | DataFrame | np.ndarray, + func: AggFuncType, + args: tuple, + kwargs: dict[str, Any], + decorator: Callable, + axis: Axis, + ): + """ + Executor method to run functions by an axis. + + While we can see ``map`` as executing the function for each cell + in a ``DataFrame`` (or ``Series``), ``apply`` will execute the + function for each column (or row). + + Parameters + ---------- + data : Series, DataFrame or NumPy ndarray + The object to use for the data. Some methods implement a ``raw`` + parameter which will convert the original pandas object to a + NumPy array, which will then be passed here to the executor. + func : function or NumPy ufunc + The function to execute. + args : tuple + Positional arguments to be passed to ``func``. + kwargs : dict + Keyword arguments to be passed to ``func``. + decorator : function, optional + For JIT compilers and other engines that need to decorate the + function ``func``, this is the decorator to use. While the + executor may already know which is the decorator to use, this + is useful as for a single executor the user can specify for + example ``numba.jit`` or ``numba.njit(nogil=True)``, and this + decorator parameter will contain the exact decorator from the + executor the user wants to use. + axis : {0 or 'index', 1 or 'columns'} + 0 or 'index' should execute the function passing each column as + parameter. 1 or 'columns' should execute the function passing + each row as parameter. The default executor engine passes rows + as pandas ``Series``. Other executor engines should probably + expect functions to be implemented this way for compatibility. + But passing rows as other data structures is technically possible + as far as the function ``func`` is implemented accordingly. + """ + + +def frame_apply( + obj: DataFrame, + func: AggFuncType, + axis: Axis = 0, + raw: bool = False, + result_type: str | None = None, + by_row: Literal[False, "compat"] = "compat", + engine: str = "python", + engine_kwargs: dict[str, bool] | None = None, + args=None, + kwargs=None, +) -> FrameApply: + """construct and return a row or column based frame apply object""" + _, func, columns, _ = reconstruct_func(func, **kwargs) + + axis = obj._get_axis_number(axis) + klass: type[FrameApply] + if axis == 0: + klass = FrameRowApply + elif axis == 1: + if columns: + raise NotImplementedError( + f"Named aggregation is not supported when {axis=}." + ) + klass = FrameColumnApply + + return klass( + obj, + func, + raw=raw, + result_type=result_type, + by_row=by_row, + engine=engine, + engine_kwargs=engine_kwargs, + args=args, + kwargs=kwargs, + ) + + +class Apply(metaclass=abc.ABCMeta): + axis: AxisInt + + def __init__( + self, + obj: AggObjType, + func: AggFuncType, + raw: bool, + result_type: str | None, + *, + by_row: Literal[False, "compat", "_compat"] = "compat", + engine: str = "python", + engine_kwargs: dict[str, bool] | None = None, + args, + kwargs, + ) -> None: + self.obj = obj + self.raw = raw + + assert by_row is False or by_row in ["compat", "_compat"] + self.by_row = by_row + + self.args = args or () + self.kwargs = kwargs or {} + + self.engine = engine + self.engine_kwargs = {} if engine_kwargs is None else engine_kwargs + + if result_type not in [None, "reduce", "broadcast", "expand"]: + raise ValueError( + "invalid value for result_type, must be one " + "of {None, 'reduce', 'broadcast', 'expand'}" + ) + + self.result_type = result_type + + self.func = func + + @abc.abstractmethod + def apply(self) -> DataFrame | Series: + pass + + @abc.abstractmethod + def agg_or_apply_list_like( + self, op_name: Literal["agg", "apply"] + ) -> DataFrame | Series: + pass + + @abc.abstractmethod + def agg_or_apply_dict_like( + self, op_name: Literal["agg", "apply"] + ) -> DataFrame | Series: + pass + + def agg(self) -> DataFrame | Series | None: + """ + Provide an implementation for the aggregators. + + Returns + ------- + Result of aggregation, or None if agg cannot be performed by + this method. + """ + func = self.func + + if isinstance(func, str): + return self.apply_str() + + if is_dict_like(func): + return self.agg_dict_like() + elif is_list_like(func): + # we require a list, but not a 'str' + return self.agg_list_like() + + # caller can react + return None + + def transform(self) -> DataFrame | Series: + """ + Transform a DataFrame or Series. + + Returns + ------- + DataFrame or Series + Result of applying ``func`` along the given axis of the + Series or DataFrame. + + Raises + ------ + ValueError + If the transform function fails or does not transform. + """ + obj = self.obj + func = self.func + axis = self.axis + args = self.args + kwargs = self.kwargs + + is_series = obj.ndim == 1 + + if obj._get_axis_number(axis) == 1: + assert not is_series + return obj.T.transform(func, 0, *args, **kwargs).T + + if is_list_like(func) and not is_dict_like(func): + func = cast(list[AggFuncTypeBase], func) + # Convert func equivalent dict + if is_series: + func = {com.get_callable_name(v) or v: v for v in func} + else: + func = dict.fromkeys(obj, func) + + if is_dict_like(func): + func = cast(AggFuncTypeDict, func) + return self.transform_dict_like(func) + + # func is either str or callable + func = cast(AggFuncTypeBase, func) + try: + result = self.transform_str_or_callable(func) + except TypeError: + raise + except Exception as err: + raise ValueError("Transform function failed") from err + + # Functions that transform may return empty Series/DataFrame + # when the dtype is not appropriate + if ( + isinstance(result, (ABCSeries, ABCDataFrame)) + and result.empty + and not obj.empty + ): + raise ValueError("Transform function failed") + if not isinstance(result, (ABCSeries, ABCDataFrame)) or not result.index.equals( + obj.index + ): + raise ValueError("Function did not transform") + + return result + + def transform_dict_like(self, func) -> DataFrame: + """ + Compute transform in the case of a dict-like func + """ + from pandas.core.reshape.concat import concat + + obj = self.obj + args = self.args + kwargs = self.kwargs + + # transform is currently only for Series/DataFrame + assert isinstance(obj, ABCNDFrame) + + if len(func) == 0: + raise ValueError("No transform functions were provided") + + func = self.normalize_dictlike_arg("transform", obj, func) + + results: dict[Hashable, DataFrame | Series] = {} + for name, how in func.items(): + colg = obj._gotitem(name, ndim=1) + results[name] = colg.transform(how, 0, *args, **kwargs) + return concat(results, axis=1) + + def transform_str_or_callable(self, func) -> DataFrame | Series: + """ + Compute transform in the case of a string or callable func + """ + obj = self.obj + args = self.args + kwargs = self.kwargs + + if isinstance(func, str): + return self._apply_str(obj, func, *args, **kwargs) + + # Two possible ways to use a UDF - apply or call directly + try: + return obj.apply(func, args=args, **kwargs) + except Exception: + return func(obj, *args, **kwargs) + + def agg_list_like(self) -> DataFrame | Series: + """ + Compute aggregation in the case of a list-like argument. + + Returns + ------- + Result of aggregation. + """ + return self.agg_or_apply_list_like(op_name="agg") + + def compute_list_like( + self, + op_name: Literal["agg", "apply"], + selected_obj: Series | DataFrame, + kwargs: dict[str, Any], + ) -> tuple[list[Hashable] | Index, list[Any]]: + """ + Compute agg/apply results for like-like input. + + Parameters + ---------- + op_name : {"agg", "apply"} + Operation being performed. + selected_obj : Series or DataFrame + Data to perform operation on. + kwargs : dict + Keyword arguments to pass to the functions. + + Returns + ------- + keys : list[Hashable] or Index + Index labels for result. + results : list + Data for result. When aggregating with a Series, this can contain any + Python objects. + """ + func = cast(list[AggFuncTypeBase], self.func) + obj = self.obj + + results = [] + keys = [] + + # degenerate case + if selected_obj.ndim == 1: + for a in func: + colg = obj._gotitem(selected_obj.name, ndim=1, subset=selected_obj) + args = ( + [self.axis, *self.args] + if include_axis(op_name, colg) + else self.args + ) + new_res = getattr(colg, op_name)(a, *args, **kwargs) + results.append(new_res) + + # make sure we find a good name + name = com.get_callable_name(a) or a + keys.append(name) + + else: + indices = [] + for index, col in enumerate(selected_obj): + colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index]) + args = ( + [self.axis, *self.args] + if include_axis(op_name, colg) + else self.args + ) + new_res = getattr(colg, op_name)(func, *args, **kwargs) + results.append(new_res) + indices.append(index) + # error: Incompatible types in assignment (expression has type "Any | + # Index", variable has type "list[Any | Callable[..., Any] | str]") + keys = selected_obj.columns.take(indices) # type: ignore[assignment] + + return keys, results + + def wrap_results_list_like( + self, keys: Iterable[Hashable], results: list[Series | DataFrame] + ): + from pandas.core.reshape.concat import concat + + obj = self.obj + + try: + return concat(results, keys=keys, axis=1, sort=False) + except TypeError as err: + # we are concatting non-NDFrame objects, + # e.g. a list of scalars + from pandas import Series + + result = Series(results, index=keys, name=obj.name) + if is_nested_object(result): + raise ValueError( + "cannot combine transform and aggregation operations" + ) from err + return result + + def agg_dict_like(self) -> DataFrame | Series: + """ + Compute aggregation in the case of a dict-like argument. + + Returns + ------- + Result of aggregation. + """ + return self.agg_or_apply_dict_like(op_name="agg") + + def compute_dict_like( + self, + op_name: Literal["agg", "apply"], + selected_obj: Series | DataFrame, + selection: Hashable | Sequence[Hashable], + kwargs: dict[str, Any], + ) -> tuple[list[Hashable], list[Any]]: + """ + Compute agg/apply results for dict-like input. + + Parameters + ---------- + op_name : {"agg", "apply"} + Operation being performed. + selected_obj : Series or DataFrame + Data to perform operation on. + selection : hashable or sequence of hashables + Used by GroupBy, Window, and Resample if selection is applied to the object. + kwargs : dict + Keyword arguments to pass to the functions. + + Returns + ------- + keys : list[hashable] + Index labels for result. + results : list + Data for result. When aggregating with a Series, this can contain any + Python object. + """ + from pandas.core.groupby.generic import ( + DataFrameGroupBy, + SeriesGroupBy, + ) + + obj = self.obj + is_groupby = isinstance(obj, (DataFrameGroupBy, SeriesGroupBy)) + func = cast(AggFuncTypeDict, self.func) + func = self.normalize_dictlike_arg(op_name, selected_obj, func) + + is_non_unique_col = ( + selected_obj.ndim == 2 + and selected_obj.columns.nunique() < len(selected_obj.columns) + ) + + if selected_obj.ndim == 1: + # key only used for output + colg = obj._gotitem(selection, ndim=1) + results = [getattr(colg, op_name)(how, **kwargs) for _, how in func.items()] + keys = list(func.keys()) + elif not is_groupby and is_non_unique_col: + # key used for column selection and output + # GH#51099 + results = [] + keys = [] + for key, how in func.items(): + indices = selected_obj.columns.get_indexer_for([key]) + labels = selected_obj.columns.take(indices) + label_to_indices = defaultdict(list) + for index, label in zip(indices, labels, strict=True): + label_to_indices[label].append(index) + + key_data = [ + getattr(selected_obj._ixs(indice, axis=1), op_name)(how, **kwargs) + for label, indices in label_to_indices.items() + for indice in indices + ] + + keys += [key] * len(key_data) + results += key_data + elif is_groupby: + # key used for column selection and output + + df = selected_obj + results, keys = [], [] + for key, how in func.items(): + cols = df[key] + + if cols.ndim == 1: + series = obj._gotitem(key, ndim=1, subset=cols) + results.append(getattr(series, op_name)(how, **kwargs)) + keys.append(key) + else: + for _, col in cols.items(): + series = obj._gotitem(key, ndim=1, subset=col) + results.append(getattr(series, op_name)(how, **kwargs)) + keys.append(key) + else: + results = [ + getattr(obj._gotitem(key, ndim=1), op_name)(how, **kwargs) + for key, how in func.items() + ] + keys = list(func.keys()) + + return keys, results + + def wrap_results_dict_like( + self, + selected_obj: Series | DataFrame, + result_index: list[Hashable], + result_data: list, + ): + from pandas import Index + from pandas.core.reshape.concat import concat + + obj = self.obj + + # Avoid making two isinstance calls in all and any below + is_ndframe = [isinstance(r, ABCNDFrame) for r in result_data] + + if all(is_ndframe): + results = [result for result in result_data if not result.empty] + keys_to_use: Iterable[Hashable] + keys_to_use = [ + k for k, v in zip(result_index, result_data, strict=True) if not v.empty + ] + # Have to check, if at least one DataFrame is not empty. + if keys_to_use == []: + keys_to_use = result_index + results = result_data + + if selected_obj.ndim == 2: + # keys are columns, so we can preserve names + ktu = Index(keys_to_use) + ktu._set_names(selected_obj.columns.names) + keys_to_use = ktu + + axis: AxisInt = 0 if isinstance(obj, ABCSeries) else 1 + result = concat( + results, + axis=axis, + keys=keys_to_use, + sort=False, + ) + elif any(is_ndframe): + # There is a mix of NDFrames and scalars + raise ValueError( + "cannot perform both aggregation " + "and transformation operations " + "simultaneously" + ) + else: + from pandas import Series + + # we have a list of scalars + # GH 36212 use name only if obj is a series + if obj.ndim == 1: + obj = cast("Series", obj) + name = obj.name + else: + name = None + + result = Series(result_data, index=result_index, name=name) + + return result + + def apply_str(self) -> DataFrame | Series: + """ + Compute apply in case of a string. + + Returns + ------- + result: Series or DataFrame + """ + # Caller is responsible for checking isinstance(self.f, str) + func = cast(str, self.func) + + obj = self.obj + + from pandas.core.groupby.generic import ( + DataFrameGroupBy, + SeriesGroupBy, + ) + + # Support for `frame.transform('method')` + # Some methods (shift, etc.) require the axis argument, others + # don't, so inspect and insert if necessary. + method = getattr(obj, func, None) + if callable(method): + sig = inspect.getfullargspec(method) + arg_names = (*sig.args, *sig.kwonlyargs) + if self.axis != 0 and ( + "axis" not in arg_names or func in ("corrwith", "skew") + ): + raise ValueError(f"Operation {func} does not support axis=1") + if "axis" in arg_names and not isinstance( + obj, (SeriesGroupBy, DataFrameGroupBy) + ): + self.kwargs["axis"] = self.axis + return self._apply_str(obj, func, *self.args, **self.kwargs) + + def apply_list_or_dict_like(self) -> DataFrame | Series: + """ + Compute apply in case of a list-like or dict-like. + + Returns + ------- + result: Series, DataFrame, or None + Result when self.func is a list-like or dict-like, None otherwise. + """ + + if self.engine == "numba": + raise NotImplementedError( + "The 'numba' engine doesn't support list-like/" + "dict likes of callables yet." + ) + + if self.axis == 1 and isinstance(self.obj, ABCDataFrame): + return self.obj.T.apply(self.func, 0, args=self.args, **self.kwargs).T + + func = self.func + kwargs = self.kwargs + + if is_dict_like(func): + result = self.agg_or_apply_dict_like(op_name="apply") + else: + result = self.agg_or_apply_list_like(op_name="apply") + + result = reconstruct_and_relabel_result(result, func, **kwargs) + + return result + + def normalize_dictlike_arg( + self, how: str, obj: DataFrame | Series, func: AggFuncTypeDict + ) -> AggFuncTypeDict: + """ + Handler for dict-like argument. + + Ensures that necessary columns exist if obj is a DataFrame, and + that a nested renamer is not passed. Also normalizes to all lists + when values consists of a mix of list and non-lists. + """ + assert how in ("apply", "agg", "transform") + + # Can't use func.values(); wouldn't work for a Series + if ( + how == "agg" + and isinstance(obj, ABCSeries) + and any(is_list_like(v) for _, v in func.items()) + ) or (any(is_dict_like(v) for _, v in func.items())): + # GH 15931 - deprecation of renaming keys + raise SpecificationError("nested renamer is not supported") + + if obj.ndim != 1: + # Check for missing columns on a frame + from pandas import Index + + cols = Index(list(func.keys())).difference(obj.columns, sort=True) + if len(cols) > 0: + # GH 58474 + raise KeyError(f"Label(s) {list(cols)} do not exist") + + aggregator_types = (list, tuple, dict) + + # if we have a dict of any non-scalars + # eg. {'A' : ['mean']}, normalize all to + # be list-likes + # Cannot use func.values() because arg may be a Series + if any(isinstance(x, aggregator_types) for _, x in func.items()): + new_func: AggFuncTypeDict = {} + for k, v in func.items(): + if not isinstance(v, aggregator_types): + new_func[k] = [v] + else: + new_func[k] = v + func = new_func + return func + + def _apply_str(self, obj, func: str, *args, **kwargs): + """ + if arg is a string, then try to operate on it: + - try to find a function (or attribute) on obj + - try to find a numpy function + - raise + """ + assert isinstance(func, str) + + if hasattr(obj, func): + f = getattr(obj, func) + if callable(f): + return f(*args, **kwargs) + + # people may aggregate on a non-callable attribute + # but don't let them think they can pass args to it + assert len(args) == 0 + assert not any(kwarg == "axis" for kwarg in kwargs) + return f + elif hasattr(np, func) and hasattr(obj, "__array__"): + # in particular exclude Window + f = getattr(np, func) + return f(obj, *args, **kwargs) + else: + msg = f"'{func}' is not a valid function for '{type(obj).__name__}' object" + raise AttributeError(msg) + + +class NDFrameApply(Apply): + """ + Methods shared by FrameApply and SeriesApply but + not GroupByApply or ResamplerWindowApply + """ + + obj: DataFrame | Series + + @property + def index(self) -> Index: + return self.obj.index + + @property + def agg_axis(self) -> Index: + return self.obj._get_agg_axis(self.axis) + + def agg_or_apply_list_like( + self, op_name: Literal["agg", "apply"] + ) -> DataFrame | Series: + obj = self.obj + kwargs = self.kwargs + + if op_name == "apply": + if isinstance(self, FrameApply): + by_row = self.by_row + + elif isinstance(self, SeriesApply): + by_row = "_compat" if self.by_row else False + else: + by_row = False + kwargs = {**kwargs, "by_row": by_row} + + if getattr(obj, "axis", 0) == 1: + raise NotImplementedError("axis other than 0 is not supported") + + keys, results = self.compute_list_like(op_name, obj, kwargs) + result = self.wrap_results_list_like(keys, results) + return result + + def agg_or_apply_dict_like( + self, op_name: Literal["agg", "apply"] + ) -> DataFrame | Series: + assert op_name in ["agg", "apply"] + obj = self.obj + + kwargs = {} + if op_name == "apply": + by_row = "_compat" if self.by_row else False + kwargs.update({"by_row": by_row}) + + if getattr(obj, "axis", 0) == 1: + raise NotImplementedError("axis other than 0 is not supported") + + selection = None + result_index, result_data = self.compute_dict_like( + op_name, obj, selection, kwargs + ) + result = self.wrap_results_dict_like(obj, result_index, result_data) + return result + + +class FrameApply(NDFrameApply): + obj: DataFrame + + def __init__( + self, + obj: AggObjType, + func: AggFuncType, + raw: bool, + result_type: str | None, + *, + by_row: Literal[False, "compat"] = False, + engine: str = "python", + engine_kwargs: dict[str, bool] | None = None, + args, + kwargs, + ) -> None: + if by_row is not False and by_row != "compat": + raise ValueError(f"by_row={by_row} not allowed") + super().__init__( + obj, + func, + raw, + result_type, + by_row=by_row, + engine=engine, + engine_kwargs=engine_kwargs, + args=args, + kwargs=kwargs, + ) + + # --------------------------------------------------------------- + # Abstract Methods + + @property + @abc.abstractmethod + def result_index(self) -> Index: + pass + + @property + @abc.abstractmethod + def result_columns(self) -> Index: + pass + + @property + @abc.abstractmethod + def series_generator(self) -> Generator[Series]: + pass + + @staticmethod + @functools.cache + @abc.abstractmethod + def generate_numba_apply_func( + func, nogil=True, nopython=True, parallel=False + ) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]: + pass + + @abc.abstractmethod + def apply_with_numba(self): + pass + + def validate_values_for_numba(self) -> None: + # Validate column dtyps all OK + for colname, dtype in self.obj.dtypes.items(): + if not is_numeric_dtype(dtype): + raise ValueError( + f"Column {colname} must have a numeric dtype. " + f"Found '{dtype}' instead" + ) + if is_extension_array_dtype(dtype): + raise ValueError( + f"Column {colname} is backed by an extension array, " + f"which is not supported by the numba engine." + ) + + @abc.abstractmethod + def wrap_results_for_axis( + self, results: ResType, res_index: Index + ) -> DataFrame | Series: + pass + + # --------------------------------------------------------------- + + @property + def res_columns(self) -> Index: + return self.result_columns + + @property + def columns(self) -> Index: + return self.obj.columns + + @cache_readonly + def values(self): + return self.obj.values + + def apply(self) -> DataFrame | Series: + """compute the results""" + + # dispatch to handle list-like or dict-like + if is_list_like(self.func): + if self.engine == "numba": + raise NotImplementedError( + "the 'numba' engine doesn't support lists of callables yet" + ) + return self.apply_list_or_dict_like() + + # all empty + if len(self.columns) == 0 and len(self.index) == 0: + return self.apply_empty_result() + + # string dispatch + if isinstance(self.func, str): + if self.engine == "numba": + raise NotImplementedError( + "the 'numba' engine doesn't support using " + "a string as the callable function" + ) + return self.apply_str() + + # ufunc + elif isinstance(self.func, np.ufunc): + if self.engine == "numba": + raise NotImplementedError( + "the 'numba' engine doesn't support " + "using a numpy ufunc as the callable function" + ) + with np.errstate(all="ignore"): + results = self.obj._mgr.apply("apply", func=self.func) + # _constructor will retain self.index and self.columns + return self.obj._constructor_from_mgr(results, axes=results.axes) + + # broadcasting + if self.result_type == "broadcast": + if self.engine == "numba": + raise NotImplementedError( + "the 'numba' engine doesn't support result_type='broadcast'" + ) + return self.apply_broadcast(self.obj) + + # one axis empty + elif not all(self.obj.shape): + return self.apply_empty_result() + + # raw + elif self.raw: + return self.apply_raw(engine=self.engine, engine_kwargs=self.engine_kwargs) + + return self.apply_standard() + + def agg(self): + obj = self.obj + axis = self.axis + + # TODO: Avoid having to change state + self.obj = self.obj if self.axis == 0 else self.obj.T + self.axis = 0 + + result = None + try: + result = super().agg() + finally: + self.obj = obj + self.axis = axis + + if axis == 1: + result = result.T if result is not None else result + + if result is None: + result = self.obj.apply(self.func, axis, args=self.args, **self.kwargs) + + return result + + def apply_empty_result(self): + """ + we have an empty result; at least 1 axis is 0 + + we will try to apply the function to an empty + series in order to see if this is a reduction function + """ + assert callable(self.func) + + # we are not asked to reduce or infer reduction + # so just return a copy of the existing object + if self.result_type not in ["reduce", None]: + return self.obj.copy() + + # we may need to infer + should_reduce = self.result_type == "reduce" + + from pandas import Series + + if not should_reduce: + try: + if self.axis == 0: + r = self.func( + Series([], dtype=np.float64), *self.args, **self.kwargs + ) + else: + r = self.func( + Series(index=self.columns, dtype=np.float64), + *self.args, + **self.kwargs, + ) + except Exception: + pass + else: + should_reduce = not isinstance(r, Series) + + if should_reduce: + if len(self.agg_axis): + r = self.func(Series([], dtype=np.float64), *self.args, **self.kwargs) + else: + r = np.nan + + return self.obj._constructor_sliced(r, index=self.agg_axis) + else: + return self.obj.copy() + + def apply_raw(self, engine="python", engine_kwargs=None): + """apply to the values as a numpy array""" + + def wrap_function(func): + """ + Wrap user supplied function to work around numpy issue. + + see https://github.com/numpy/numpy/issues/8352 + """ + + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + if isinstance(result, str): + result = np.array(result, dtype=object) + return result + + return wrapper + + if engine == "numba": + args, kwargs = prepare_function_arguments( + self.func, # type: ignore[arg-type] + self.args, + self.kwargs, + num_required_args=1, + ) + # error: Argument 1 to "__call__" of "_lru_cache_wrapper" has + # incompatible type "Callable[..., Any] | str | list[Callable + # [..., Any] | str] | dict[Hashable,Callable[..., Any] | str | + # list[Callable[..., Any] | str]]"; expected "Hashable" + nb_looper = generate_apply_looper( + self.func, # type: ignore[arg-type] + **get_jit_arguments(engine_kwargs), + ) + result = nb_looper(self.values, self.axis, *args) + # If we made the result 2-D, squeeze it back to 1-D + result = np.squeeze(result) + else: + result = np.apply_along_axis( + wrap_function(self.func), + self.axis, + self.values, + *self.args, + **self.kwargs, + ) + + # TODO: mixed type case + if result.ndim == 2: + return self.obj._constructor(result, index=self.index, columns=self.columns) + else: + return self.obj._constructor_sliced(result, index=self.agg_axis) + + def apply_broadcast(self, target: DataFrame) -> DataFrame: + assert callable(self.func) + + result_values = np.empty_like(target.values) + + # axis which we want to compare compliance + result_compare = target.shape[0] + + for i, col in enumerate(target.columns): + res = self.func(target[col], *self.args, **self.kwargs) + ares = np.asarray(res).ndim + + # must be a scalar or 1d + if ares > 1: + raise ValueError("too many dims to broadcast") + if ares == 1: + # must match return dim + if result_compare != len(res): + raise ValueError("cannot broadcast result") + + result_values[:, i] = res + + # we *always* preserve the original index / columns + result = self.obj._constructor( + result_values, index=target.index, columns=target.columns + ) + return result + + def apply_standard(self): + if self.engine == "python": + results, res_index = self.apply_series_generator() + else: + results, res_index = self.apply_series_numba() + + # wrap results + return self.wrap_results(results, res_index) + + def apply_series_generator(self) -> tuple[ResType, Index]: + assert callable(self.func) + + series_gen = self.series_generator + res_index = self.result_index + + results = {} + + for i, v in enumerate(series_gen): + results[i] = self.func(v, *self.args, **self.kwargs) + if isinstance(results[i], ABCSeries): + # If we have a view on v, we need to make a copy because + # series_generator will swap out the underlying data + results[i] = results[i].copy(deep=False) + + return results, res_index + + def apply_series_numba(self): + if self.engine_kwargs.get("parallel", False): + raise NotImplementedError( + "Parallel apply is not supported when raw=False and engine='numba'" + ) + if not self.obj.index.is_unique or not self.columns.is_unique: + raise NotImplementedError( + "The index/columns must be unique when raw=False and engine='numba'" + ) + self.validate_values_for_numba() + results = self.apply_with_numba() + return results, self.result_index + + def wrap_results(self, results: ResType, res_index: Index) -> DataFrame | Series: + from pandas import Series + + # see if we can infer the results + if len(results) > 0 and 0 in results and is_sequence(results[0]): + return self.wrap_results_for_axis(results, res_index) + + # dict of scalars + + # the default dtype of an empty Series is `object`, but this + # code can be hit by df.mean() where the result should have dtype + # float64 even if it's an empty Series. + constructor_sliced = self.obj._constructor_sliced + if len(results) == 0 and constructor_sliced is Series: + result = constructor_sliced(results, dtype=np.float64) + else: + result = constructor_sliced(results) + result.index = res_index + + return result + + def apply_str(self) -> DataFrame | Series: + # Caller is responsible for checking isinstance(self.func, str) + # TODO: GH#39993 - Avoid special-casing by replacing with lambda + if self.func == "size": + # Special-cased because DataFrame.size returns a single scalar + obj = self.obj + value = obj.shape[self.axis] + return obj._constructor_sliced(value, index=self.agg_axis) + return super().apply_str() + + +class FrameRowApply(FrameApply): + axis: AxisInt = 0 + + @property + def series_generator(self) -> Generator[Series]: + return (self.obj._ixs(i, axis=1) for i in range(len(self.columns))) + + @staticmethod + @functools.cache + def generate_numba_apply_func( + func, nogil=True, nopython=True, parallel=False + ) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]: + numba = import_optional_dependency("numba") + from pandas import Series + + # Import helper from extensions to cast string object -> np strings + # Note: This also has the side effect of loading our numba extensions + from pandas.core._numba.extensions import maybe_cast_str + + jitted_udf = numba.extending.register_jitable(func) + + # Currently the parallel argument doesn't get passed through here + # (it's disabled) since the dicts in numba aren't thread-safe. + @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel) + def numba_func(values, col_names, df_index, *args): + results = {} + for j in range(values.shape[1]): + # Create the series + ser = Series( + values[:, j], index=df_index, name=maybe_cast_str(col_names[j]) + ) + results[j] = jitted_udf(ser, *args) + return results + + return numba_func + + def apply_with_numba(self) -> dict[int, Any]: + func = cast(Callable, self.func) + args, kwargs = prepare_function_arguments( + func, self.args, self.kwargs, num_required_args=1 + ) + nb_func = self.generate_numba_apply_func( + func, **get_jit_arguments(self.engine_kwargs) + ) + from pandas.core._numba.extensions import set_numba_data + + index = self.obj.index + columns = self.obj.columns + + # Convert from numba dict to regular dict + # Our isinstance checks in the df constructor don't pass for numbas typed dict + with set_numba_data(index) as index, set_numba_data(columns) as columns: + res = dict(nb_func(self.values, columns, index, *args)) + return res + + @property + def result_index(self) -> Index: + return self.columns + + @property + def result_columns(self) -> Index: + return self.index + + def wrap_results_for_axis( + self, results: ResType, res_index: Index + ) -> DataFrame | Series: + """return the results for the rows""" + + if self.result_type == "reduce": + # e.g. test_apply_dict GH#8735 + res = self.obj._constructor_sliced(results) + res.index = res_index + return res + + elif self.result_type is None and all( + isinstance(x, dict) for x in results.values() + ): + # Our operation was a to_dict op e.g. + # test_apply_dict GH#8735, test_apply_reduce_to_dict GH#25196 #37544 + res = self.obj._constructor_sliced(results) + res.index = res_index + return res + + try: + result = self.obj._constructor(data=results) + except ValueError as err: + if "All arrays must be of the same length" in str(err): + # e.g. result = [[2, 3], [1.5], ['foo', 'bar']] + # see test_agg_listlike_result GH#29587 + res = self.obj._constructor_sliced(results) + res.index = res_index + return res + else: + raise + + if not isinstance(results[0], ABCSeries): + if len(result.index) == len(self.res_columns): + result.index = self.res_columns + + if len(result.columns) == len(res_index): + result.columns = res_index + + return result + + +class FrameColumnApply(FrameApply): + axis: AxisInt = 1 + + def apply_broadcast(self, target: DataFrame) -> DataFrame: + result = super().apply_broadcast(target.T) + return result.T + + @property + def series_generator(self) -> Generator[Series]: + values = self.values + values = ensure_wrapped_if_datetimelike(values) + assert len(values) > 0 + + # We create one Series object, and will swap out the data inside + # of it. Kids: don't do this at home. + ser = self.obj._ixs(0, axis=0) + mgr = ser._mgr + + is_view = mgr.blocks[0].refs.has_reference() + + if isinstance(ser.dtype, ExtensionDtype): + # values will be incorrect for this block + # TODO(EA2D): special case would be unnecessary with 2D EAs + obj = self.obj + for i in range(len(obj)): + yield obj._ixs(i, axis=0) + + else: + for arr, name in zip(values, self.index, strict=True): + # GH#35462 re-pin mgr in case setitem changed it + ser._mgr = mgr + mgr.set_values(arr) + object.__setattr__(ser, "_name", name) + if not is_view: + # In apply_series_generator we store the a shallow copy of the + # result, which potentially increases the ref count of this reused + # `ser` object (depending on the result of the applied function) + # -> if that happened and `ser` is already a copy, then we reset + # the refs here to avoid triggering a unnecessary CoW inside the + # applied function (https://github.com/pandas-dev/pandas/pull/56212) + mgr.blocks[0].refs = BlockValuesRefs(mgr.blocks[0]) + yield ser + + @staticmethod + @functools.cache + def generate_numba_apply_func( + func, nogil=True, nopython=True, parallel=False + ) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]: + numba = import_optional_dependency("numba") + from pandas import Series + from pandas.core._numba.extensions import maybe_cast_str + + jitted_udf = numba.extending.register_jitable(func) + + @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel) + def numba_func(values, col_names_index, index, *args): + results = {} + # Currently the parallel argument doesn't get passed through here + # (it's disabled) since the dicts in numba aren't thread-safe. + for i in range(values.shape[0]): + # Create the series + # TODO: values corrupted without the copy + ser = Series( + values[i].copy(), + index=col_names_index, + name=maybe_cast_str(index[i]), + ) + results[i] = jitted_udf(ser, *args) + + return results + + return numba_func + + def apply_with_numba(self) -> dict[int, Any]: + func = cast(Callable, self.func) + args, kwargs = prepare_function_arguments( + func, self.args, self.kwargs, num_required_args=1 + ) + nb_func = self.generate_numba_apply_func( + func, **get_jit_arguments(self.engine_kwargs) + ) + + from pandas.core._numba.extensions import set_numba_data + + # Convert from numba dict to regular dict + # Our isinstance checks in the df constructor don't pass for numbas typed dict + with ( + set_numba_data(self.obj.index) as index, + set_numba_data(self.columns) as columns, + ): + res = dict(nb_func(self.values, columns, index, *args)) + + return res + + @property + def result_index(self) -> Index: + return self.index + + @property + def result_columns(self) -> Index: + return self.columns + + def wrap_results_for_axis( + self, results: ResType, res_index: Index + ) -> DataFrame | Series: + """return the results for the columns""" + result: DataFrame | Series + + # we have requested to expand + if self.result_type == "expand": + result = self.infer_to_same_shape(results, res_index) + + # we have a non-series and don't want inference + elif not isinstance(results[0], ABCSeries): + result = self.obj._constructor_sliced(results) + result.index = res_index + + # we may want to infer results + else: + result = self.infer_to_same_shape(results, res_index) + + return result + + def infer_to_same_shape(self, results: ResType, res_index: Index) -> DataFrame: + """infer the results to the same shape as the input object""" + result = self.obj._constructor(data=results) + result = result.T + + # set the index + result.index = res_index + + # infer dtypes + result = result.infer_objects() + + return result + + +class SeriesApply(NDFrameApply): + obj: Series + axis: AxisInt = 0 + by_row: Literal[False, "compat", "_compat"] # only relevant for apply() + + def __init__( + self, + obj: Series, + func: AggFuncType, + *, + by_row: Literal[False, "compat", "_compat"] = "compat", + args, + kwargs, + ) -> None: + super().__init__( + obj, + func, + raw=False, + result_type=None, + by_row=by_row, + args=args, + kwargs=kwargs, + ) + + def apply(self) -> DataFrame | Series: + obj = self.obj + + if len(obj) == 0: + return self.apply_empty_result() + + # dispatch to handle list-like or dict-like + if is_list_like(self.func): + return self.apply_list_or_dict_like() + + if isinstance(self.func, str): + # if we are a string, try to dispatch + return self.apply_str() + + if self.by_row == "_compat": + return self.apply_compat() + + # self.func is Callable + return self.apply_standard() + + def agg(self): + result = super().agg() + if result is None: + obj = self.obj + func = self.func + # string, list-like, and dict-like are entirely handled in super + assert callable(func) + result = func(obj, *self.args, **self.kwargs) + return result + + def apply_empty_result(self) -> Series: + obj = self.obj + return obj._constructor(dtype=obj.dtype, index=obj.index).__finalize__( + obj, method="apply" + ) + + def apply_compat(self): + """compat apply method for funcs in listlikes and dictlikes. + + Used for each callable when giving listlikes and dictlikes of callables to + apply. Needed for compatibility with Pandas < v2.1. + + .. versionadded:: 2.1.0 + """ + obj = self.obj + func = self.func + + if callable(func): + f = com.get_cython_func(func) + if f and not self.args and not self.kwargs: + return obj.apply(func, by_row=False) + + try: + result = obj.apply(func, by_row="compat") + except (ValueError, AttributeError, TypeError): + result = obj.apply(func, by_row=False) + return result + + def apply_standard(self) -> DataFrame | Series: + # caller is responsible for ensuring that f is Callable + func = cast(Callable, self.func) + obj = self.obj + + if isinstance(func, np.ufunc): + with np.errstate(all="ignore"): + return func(obj, *self.args, **self.kwargs) + elif not self.by_row: + return func(obj, *self.args, **self.kwargs) + + if self.args or self.kwargs: + # _map_values does not support args/kwargs + def curried(x): + return func(x, *self.args, **self.kwargs) + + else: + curried = func + mapped = obj._map_values(mapper=curried) + + if len(mapped) and isinstance(mapped[0], ABCSeries): + # GH#43986 Need to do list(mapped) in order to get treated as nested + # See also GH#25959 regarding EA support + return obj._constructor_expanddim(list(mapped), index=obj.index) + else: + return obj._constructor(mapped, index=obj.index).__finalize__( + obj, method="apply" + ) + + +class GroupByApply(Apply): + obj: GroupBy | Resampler | BaseWindow + + def __init__( + self, + obj: GroupBy[NDFrameT], + func: AggFuncType, + *, + args, + kwargs, + ) -> None: + kwargs = kwargs.copy() + self.axis = obj.obj._get_axis_number(kwargs.get("axis", 0)) + super().__init__( + obj, + func, + raw=False, + result_type=None, + args=args, + kwargs=kwargs, + ) + + def apply(self): + raise NotImplementedError + + def transform(self): + raise NotImplementedError + + def agg_or_apply_list_like( + self, op_name: Literal["agg", "apply"] + ) -> DataFrame | Series: + obj = self.obj + kwargs = self.kwargs + if op_name == "apply": + kwargs = {**kwargs, "by_row": False} + + if getattr(obj, "axis", 0) == 1: + raise NotImplementedError("axis other than 0 is not supported") + + if obj._selected_obj.ndim == 1: + # For SeriesGroupBy this matches _obj_with_exclusions + selected_obj = obj._selected_obj + else: + selected_obj = obj._obj_with_exclusions + + # Only set as_index=True on groupby objects, not Window or Resample + # that inherit from this class. + with com.temp_setattr( + obj, "as_index", True, condition=hasattr(obj, "as_index") + ): + keys, results = self.compute_list_like(op_name, selected_obj, kwargs) + result = self.wrap_results_list_like(keys, results) + return result + + def agg_or_apply_dict_like( + self, op_name: Literal["agg", "apply"] + ) -> DataFrame | Series: + from pandas.core.groupby.generic import ( + DataFrameGroupBy, + SeriesGroupBy, + ) + + assert op_name in ["agg", "apply"] + + obj = self.obj + kwargs: dict[str, Any] = {} + if op_name == "apply": + by_row = "_compat" if self.by_row else False + kwargs.update({"by_row": by_row}) + + if getattr(obj, "axis", 0) == 1: + raise NotImplementedError("axis other than 0 is not supported") + + selected_obj = obj._selected_obj + selection = obj._selection + + is_groupby = isinstance(obj, (DataFrameGroupBy, SeriesGroupBy)) + + # Numba Groupby engine/engine-kwargs passthrough + if is_groupby: + engine = self.kwargs.get("engine", None) + engine_kwargs = self.kwargs.get("engine_kwargs", None) + kwargs.update({"engine": engine, "engine_kwargs": engine_kwargs}) + + with com.temp_setattr( + obj, "as_index", True, condition=hasattr(obj, "as_index") + ): + result_index, result_data = self.compute_dict_like( + op_name, selected_obj, selection, kwargs + ) + result = self.wrap_results_dict_like(selected_obj, result_index, result_data) + return result + + +class ResamplerWindowApply(GroupByApply): + axis: AxisInt = 0 + obj: Resampler | BaseWindow + + def __init__( + self, + obj: Resampler | BaseWindow, + func: AggFuncType, + *, + args, + kwargs, + ) -> None: + super(GroupByApply, self).__init__( + obj, + func, + raw=False, + result_type=None, + args=args, + kwargs=kwargs, + ) + + def apply(self): + raise NotImplementedError + + def transform(self): + raise NotImplementedError + + +def reconstruct_func( + func: AggFuncType | None, **kwargs +) -> tuple[bool, AggFuncType, tuple[str, ...] | None, npt.NDArray[np.intp] | None]: + """ + This is the internal function to reconstruct func given if there is relabeling + or not and also normalize the keyword to get new order of columns. + + If named aggregation is applied, `func` will be None, and kwargs contains the + column and aggregation function information to be parsed; + If named aggregation is not applied, `func` is either string (e.g. 'min') or + Callable, or list of them (e.g. ['min', np.max]), or the dictionary of column name + and str/Callable/list of them (e.g. {'A': 'min'}, or {'A': [np.min, lambda x: x]}) + + If relabeling is True, will return relabeling, reconstructed func, column + names, and the reconstructed order of columns. + If relabeling is False, the columns and order will be None. + + Parameters + ---------- + func: agg function (e.g. 'min' or Callable) or list of agg functions + (e.g. ['min', np.max]) or dictionary (e.g. {'A': ['min', np.max]}). + **kwargs: dict, kwargs used in is_multi_agg_with_relabel and + normalize_keyword_aggregation function for relabelling + + Returns + ------- + relabelling: bool, if there is relabelling or not + func: normalized and mangled func + columns: tuple of column names + order: array of columns indices + + Examples + -------- + >>> reconstruct_func(None, **{"foo": ("col", "min")}) + (True, defaultdict(, {'col': ['min']}), ('foo',), array([0])) + + >>> reconstruct_func("min") + (False, 'min', None, None) + """ + from pandas.core.groupby.generic import NamedAgg + + relabeling = func is None and ( + is_multi_agg_with_relabel(**kwargs) + or any(isinstance(v, NamedAgg) for v in kwargs.values()) + ) + + columns: tuple[str, ...] | None = None + order: npt.NDArray[np.intp] | None = None + + if not relabeling: + if isinstance(func, list) and len(func) > len(set(func)): + # GH 28426 will raise error if duplicated function names are used and + # there is no reassigned name + raise SpecificationError( + "Function names must be unique if there is no new column names assigned" + ) + if func is None: + # nicer error message + raise TypeError("Must provide 'func' or tuples of '(column, aggfunc).") + + if relabeling: + # error: Incompatible types in assignment (expression has type + # "MutableMapping[Hashable, list[Callable[..., Any] | str]]", variable has type + # "Callable[..., Any] | str | list[Callable[..., Any] | str] | + # MutableMapping[Hashable, Callable[..., Any] | str | list[Callable[..., Any] | + # str]] | None") + converted_kwargs = {} + for key, val in kwargs.items(): + if isinstance(val, NamedAgg): + aggfunc = val.aggfunc + if val.args or val.kwargs: + aggfunc = lambda x, func=aggfunc, a=val.args, kw=val.kwargs: func( + x, *a, **kw + ) + converted_kwargs[key] = (val.column, aggfunc) + else: + converted_kwargs[key] = val + + func, columns, order = normalize_keyword_aggregation( # type: ignore[assignment] + converted_kwargs + ) + + assert func is not None + + return relabeling, func, columns, order + + +def is_multi_agg_with_relabel(**kwargs) -> bool: + """ + Check whether kwargs passed to .agg look like multi-agg with relabeling. + + Parameters + ---------- + **kwargs : dict + + Returns + ------- + bool + + Examples + -------- + >>> is_multi_agg_with_relabel(a="max") + False + >>> is_multi_agg_with_relabel(a_max=("a", "max"), a_min=("a", "min")) + True + >>> is_multi_agg_with_relabel() + False + """ + return all(isinstance(v, tuple) and len(v) == 2 for v in kwargs.values()) and ( + len(kwargs) > 0 + ) + + +def normalize_keyword_aggregation( + kwargs: dict, +) -> tuple[ + MutableMapping[Hashable, list[AggFuncTypeBase]], + tuple[str, ...], + npt.NDArray[np.intp], +]: + """ + Normalize user-provided "named aggregation" kwargs. + Transforms from the new ``Mapping[str, NamedAgg]`` style kwargs + to the old Dict[str, List[scalar]]]. + + Parameters + ---------- + kwargs : dict + + Returns + ------- + aggspec : dict + The transformed kwargs. + columns : tuple[str, ...] + The user-provided keys. + col_idx_order : List[int] + List of columns indices. + + Examples + -------- + >>> normalize_keyword_aggregation({"output": ("input", "sum")}) + (defaultdict(, {'input': ['sum']}), ('output',), array([0])) + """ + from pandas.core.indexes.base import Index + + # Normalize the aggregation functions as Mapping[column, List[func]], + # process normally, then fixup the names. + # TODO: aggspec type: typing.Dict[str, List[AggScalar]] + aggspec = defaultdict(list) + order = [] + columns = tuple(kwargs.keys()) + + for column, aggfunc in kwargs.values(): + aggspec[column].append(aggfunc) + order.append((column, com.get_callable_name(aggfunc) or aggfunc)) + + # uniquify aggfunc name if duplicated in order list + uniquified_order = _make_unique_kwarg_list(order) + + # GH 25719, due to aggspec will change the order of assigned columns in aggregation + # uniquified_aggspec will store uniquified order list and will compare it with order + # based on index + aggspec_order = [ + (column, com.get_callable_name(aggfunc) or aggfunc) + for column, aggfuncs in aggspec.items() + for aggfunc in aggfuncs + ] + uniquified_aggspec = _make_unique_kwarg_list(aggspec_order) + + # get the new index of columns by comparison + col_idx_order = Index(uniquified_aggspec).get_indexer(uniquified_order) + return aggspec, columns, col_idx_order + + +def _make_unique_kwarg_list( + seq: Sequence[tuple[Any, Any]], +) -> Sequence[tuple[Any, Any]]: + """ + Uniquify aggfunc name of the pairs in the order list + + Examples: + -------- + >>> kwarg_list = [("a", ""), ("a", ""), ("b", "")] + >>> _make_unique_kwarg_list(kwarg_list) + [('a', '_0'), ('a', '_1'), ('b', '')] + """ + return [ + (pair[0], f"{pair[1]}_{seq[:i].count(pair)}") if seq.count(pair) > 1 else pair + for i, pair in enumerate(seq) + ] + + +def relabel_result( + result: DataFrame | Series, + func: dict[str, list[Callable | str]], + columns: Iterable[Hashable], + order: Iterable[int], +) -> dict[Hashable, Series]: + """ + Internal function to reorder result if relabelling is True for + dataframe.agg, and return the reordered result in dict. + + Parameters: + ---------- + result: Result from aggregation + func: Dict of (column name, funcs) + columns: New columns name for relabelling + order: New order for relabelling + + Examples + -------- + >>> from pandas.core.apply import relabel_result + >>> result = pd.DataFrame( + ... {"A": [np.nan, 2, np.nan], "C": [6, np.nan, np.nan], "B": [np.nan, 4, 2.5]}, + ... index=["max", "mean", "min"], + ... ) + >>> funcs = {"A": ["max"], "C": ["max"], "B": ["mean", "min"]} + >>> columns = ("foo", "aab", "bar", "dat") + >>> order = [0, 1, 2, 3] + >>> result_in_dict = relabel_result(result, funcs, columns, order) + >>> pd.DataFrame(result_in_dict, index=columns) + A C B + foo 2.0 NaN NaN + aab NaN 6.0 NaN + bar NaN NaN 4.0 + dat NaN NaN 2.5 + """ + from pandas.core.indexes.base import Index + + reordered_indexes = [ + pair[0] for pair in sorted(zip(columns, order, strict=True), key=lambda t: t[1]) + ] + reordered_result_in_dict: dict[Hashable, Series] = {} + idx = 0 + + reorder_mask = not isinstance(result, ABCSeries) and len(result.columns) > 1 + for col, fun in func.items(): + s = result[col].dropna() + + # In the `_aggregate`, the callable names are obtained and used in `result`, and + # these names are ordered alphabetically. e.g. + # C2 C1 + # 1 NaN + # amax NaN 4.0 + # max NaN 4.0 + # sum 18.0 6.0 + # Therefore, the order of functions for each column could be shuffled + # accordingly so need to get the callable name if it is not parsed names, and + # reorder the aggregated result for each column. + # e.g. if df.agg(c1=("C2", sum), c2=("C2", lambda x: min(x))), correct order is + # [sum, ], but in `result`, it will be [, sum], and we need to + # reorder so that aggregated values map to their functions regarding the order. + + # However there is only one column being used for aggregation, not need to + # reorder since the index is not sorted, and keep as is in `funcs`, e.g. + # A + # min 1.0 + # mean 1.5 + # mean 1.5 + if reorder_mask: + fun = [ + com.get_callable_name(f) if not isinstance(f, str) else f for f in fun + ] + col_idx_order = Index(s.index, copy=False).get_indexer(fun) + valid_idx = col_idx_order != -1 + if valid_idx.any(): + s = s.iloc[col_idx_order[valid_idx]] + # assign the new user-provided "named aggregation" as index names, and reindex + # it based on the whole user-provided names. + if not s.empty: + s.index = reordered_indexes[idx : idx + len(fun)] + reordered_result_in_dict[col] = s.reindex(columns) + idx = idx + len(fun) + return reordered_result_in_dict + + +def reconstruct_and_relabel_result(result, func, **kwargs) -> DataFrame | Series: + from pandas import DataFrame + + relabeling, func, columns, order = reconstruct_func(func, **kwargs) + + if relabeling: + # This is to keep the order to columns occurrence unchanged, and also + # keep the order of new columns occurrence unchanged + + # For the return values of reconstruct_func, if relabeling is + # False, columns and order will be None. + assert columns is not None + assert order is not None + + result_in_dict = relabel_result(result, func, columns, order) + result = DataFrame(result_in_dict, index=columns) + + return result + + +# TODO: Can't use, because mypy doesn't like us setting __name__ +# error: "partial[Any]" has no attribute "__name__" +# the type is: +# typing.Sequence[Callable[..., ScalarResult]] +# -> typing.Sequence[Callable[..., ScalarResult]]: + + +def _managle_lambda_list(aggfuncs: Sequence[Any]) -> Sequence[Any]: + """ + Possibly mangle a list of aggfuncs. + + Parameters + ---------- + aggfuncs : Sequence + + Returns + ------- + mangled: list-like + A new AggSpec sequence, where lambdas have been converted + to have unique names. + + Notes + ----- + If just one aggfunc is passed, the name will not be mangled. + """ + if len(aggfuncs) <= 1: + # don't mangle for .agg([lambda x: .]) + return aggfuncs + i = 0 + mangled_aggfuncs = [] + for aggfunc in aggfuncs: + if com.get_callable_name(aggfunc) == "": + aggfunc = partial(aggfunc) + # error: "partial[Any]" has no attribute "__name__"; maybe "__new__"? + aggfunc.__name__ = f"" # type: ignore[attr-defined] + i += 1 + mangled_aggfuncs.append(aggfunc) + + return mangled_aggfuncs + + +def maybe_mangle_lambdas(agg_spec: Any) -> Any: + """ + Make new lambdas with unique names. + + Parameters + ---------- + agg_spec : Any + An argument to GroupBy.agg. + Non-dict-like `agg_spec` are pass through as is. + For dict-like `agg_spec` a new spec is returned + with name-mangled lambdas. + + Returns + ------- + mangled : Any + Same type as the input. + + Examples + -------- + >>> maybe_mangle_lambdas("sum") + 'sum' + >>> maybe_mangle_lambdas([lambda: 1, lambda: 2]) # doctest: +SKIP + [, + .f(*args, **kwargs)>] + """ + is_dict = is_dict_like(agg_spec) + if not (is_dict or is_list_like(agg_spec)): + return agg_spec + mangled_aggspec = type(agg_spec)() # dict or OrderedDict + + if is_dict: + for key, aggfuncs in agg_spec.items(): + if is_list_like(aggfuncs) and not is_dict_like(aggfuncs): + mangled_aggfuncs = _managle_lambda_list(aggfuncs) + else: + mangled_aggfuncs = aggfuncs + + mangled_aggspec[key] = mangled_aggfuncs + else: + mangled_aggspec = _managle_lambda_list(agg_spec) + + return mangled_aggspec + + +def validate_func_kwargs( + kwargs: dict, +) -> tuple[list[str], list[str | Callable[..., Any]]]: + """ + Validates types of user-provided "named aggregation" kwargs. + `TypeError` is raised if aggfunc is not `str` or callable. + + Parameters + ---------- + kwargs : dict + + Returns + ------- + columns : List[str] + List of user-provided keys. + func : List[Union[str, callable[...,Any]]] + List of user-provided aggfuncs + + Examples + -------- + >>> validate_func_kwargs({"one": "min", "two": "max"}) + (['one', 'two'], ['min', 'max']) + """ + tuple_given_message = "func is expected but received {} in **kwargs." + columns = list(kwargs) + func = [] + for col_func in kwargs.values(): + if not (isinstance(col_func, str) or callable(col_func)): + raise TypeError(tuple_given_message.format(type(col_func).__name__)) + func.append(col_func) + if not columns: + no_arg_message = "Must provide 'func' or named aggregation **kwargs." + raise TypeError(no_arg_message) + return columns, func + + +def include_axis(op_name: Literal["agg", "apply"], colg: Series | DataFrame) -> bool: + return isinstance(colg, ABCDataFrame) or ( + isinstance(colg, ABCSeries) and op_name == "agg" + ) diff --git a/pandas/core/arraylike.py b/pandas/core/arraylike.py new file mode 100644 index 0000000000000000000000000000000000000000..5244f86e47318b4e8895a9271161ee2aea50ee10 --- /dev/null +++ b/pandas/core/arraylike.py @@ -0,0 +1,534 @@ +""" +Methods that can be shared by many array-like classes or subclasses: + Series + Index + ExtensionArray +""" + +from __future__ import annotations + +import operator +from typing import Any + +import numpy as np + +from pandas._libs import lib +from pandas._libs.ops_dispatch import maybe_dispatch_ufunc_to_dunder_op + +from pandas.core.dtypes.cast import maybe_unbox_numpy_scalar +from pandas.core.dtypes.generic import ABCNDFrame + +from pandas.core import roperator +from pandas.core.construction import extract_array +from pandas.core.ops.common import unpack_zerodim_and_defer + +REDUCTION_ALIASES = { + "maximum": "max", + "minimum": "min", + "add": "sum", + "multiply": "prod", +} + + +class OpsMixin: + # ------------------------------------------------------------- + # Comparisons + + def _cmp_method(self, other, op): + return NotImplemented + + @unpack_zerodim_and_defer("__eq__") + def __eq__(self, other): + return self._cmp_method(other, operator.eq) + + @unpack_zerodim_and_defer("__ne__") + def __ne__(self, other): + return self._cmp_method(other, operator.ne) + + @unpack_zerodim_and_defer("__lt__") + def __lt__(self, other): + return self._cmp_method(other, operator.lt) + + @unpack_zerodim_and_defer("__le__") + def __le__(self, other): + return self._cmp_method(other, operator.le) + + @unpack_zerodim_and_defer("__gt__") + def __gt__(self, other): + return self._cmp_method(other, operator.gt) + + @unpack_zerodim_and_defer("__ge__") + def __ge__(self, other): + return self._cmp_method(other, operator.ge) + + # ------------------------------------------------------------- + # Logical Methods + + def _logical_method(self, other, op): + return NotImplemented + + @unpack_zerodim_and_defer("__and__") + def __and__(self, other): + return self._logical_method(other, operator.and_) + + @unpack_zerodim_and_defer("__rand__") + def __rand__(self, other): + return self._logical_method(other, roperator.rand_) + + @unpack_zerodim_and_defer("__or__") + def __or__(self, other): + return self._logical_method(other, operator.or_) + + @unpack_zerodim_and_defer("__ror__") + def __ror__(self, other): + return self._logical_method(other, roperator.ror_) + + @unpack_zerodim_and_defer("__xor__") + def __xor__(self, other): + return self._logical_method(other, operator.xor) + + @unpack_zerodim_and_defer("__rxor__") + def __rxor__(self, other): + return self._logical_method(other, roperator.rxor) + + # ------------------------------------------------------------- + # Arithmetic Methods + + def _arith_method(self, other, op): + return NotImplemented + + @unpack_zerodim_and_defer("__add__") + def __add__(self, other): + """ + Get Addition of DataFrame and other, column-wise. + + Equivalent to ``DataFrame.add(other)``. + + Parameters + ---------- + other : scalar, sequence, Series, dict or DataFrame + Object to be added to the DataFrame. + + Returns + ------- + DataFrame + The result of adding ``other`` to DataFrame. + + See Also + -------- + DataFrame.add : Add a DataFrame and another object, with option for index- + or column-oriented addition. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"height": [1.5, 2.6], "weight": [500, 800]}, index=["elk", "moose"] + ... ) + >>> df + height weight + elk 1.5 500 + moose 2.6 800 + + Adding a scalar affects all rows and columns. + + >>> df[["height", "weight"]] + 1.5 + height weight + elk 3.0 501.5 + moose 4.1 801.5 + + Each element of a list is added to a column of the DataFrame, in order. + + >>> df[["height", "weight"]] + [0.5, 1.5] + height weight + elk 2.0 501.5 + moose 3.1 801.5 + + Keys of a dictionary are aligned to the DataFrame, based on column names; + each value in the dictionary is added to the corresponding column. + + >>> df[["height", "weight"]] + {"height": 0.5, "weight": 1.5} + height weight + elk 2.0 501.5 + moose 3.1 801.5 + + When `other` is a :class:`Series`, the index of `other` is aligned with the + columns of the DataFrame. + + >>> s1 = pd.Series([0.5, 1.5], index=["weight", "height"]) + >>> df[["height", "weight"]] + s1 + height weight + elk 3.0 500.5 + moose 4.1 800.5 + + Even when the index of `other` is the same as the index of the DataFrame, + the :class:`Series` will not be reoriented. If index-wise alignment is desired, + :meth:`DataFrame.add` should be used with `axis='index'`. + + >>> s2 = pd.Series([0.5, 1.5], index=["elk", "moose"]) + >>> df[["height", "weight"]] + s2 + elk height moose weight + elk NaN NaN NaN NaN + moose NaN NaN NaN NaN + + >>> df[["height", "weight"]].add(s2, axis="index") + height weight + elk 2.0 500.5 + moose 4.1 801.5 + + When `other` is a :class:`DataFrame`, both columns names and the + index are aligned. + + >>> other = pd.DataFrame( + ... {"height": [0.2, 0.4, 0.6]}, index=["elk", "moose", "deer"] + ... ) + >>> df[["height", "weight"]] + other + height weight + deer NaN NaN + elk 1.7 NaN + moose 3.0 NaN + """ + return self._arith_method(other, operator.add) + + @unpack_zerodim_and_defer("__radd__") + def __radd__(self, other): + return self._arith_method(other, roperator.radd) + + @unpack_zerodim_and_defer("__sub__") + def __sub__(self, other): + return self._arith_method(other, operator.sub) + + @unpack_zerodim_and_defer("__rsub__") + def __rsub__(self, other): + return self._arith_method(other, roperator.rsub) + + @unpack_zerodim_and_defer("__mul__") + def __mul__(self, other): + return self._arith_method(other, operator.mul) + + @unpack_zerodim_and_defer("__rmul__") + def __rmul__(self, other): + return self._arith_method(other, roperator.rmul) + + @unpack_zerodim_and_defer("__truediv__") + def __truediv__(self, other): + return self._arith_method(other, operator.truediv) + + @unpack_zerodim_and_defer("__rtruediv__") + def __rtruediv__(self, other): + return self._arith_method(other, roperator.rtruediv) + + @unpack_zerodim_and_defer("__floordiv__") + def __floordiv__(self, other): + return self._arith_method(other, operator.floordiv) + + @unpack_zerodim_and_defer("__rfloordiv") + def __rfloordiv__(self, other): + return self._arith_method(other, roperator.rfloordiv) + + @unpack_zerodim_and_defer("__mod__") + def __mod__(self, other): + return self._arith_method(other, operator.mod) + + @unpack_zerodim_and_defer("__rmod__") + def __rmod__(self, other): + return self._arith_method(other, roperator.rmod) + + @unpack_zerodim_and_defer("__divmod__") + def __divmod__(self, other): + return self._arith_method(other, divmod) + + @unpack_zerodim_and_defer("__rdivmod__") + def __rdivmod__(self, other): + return self._arith_method(other, roperator.rdivmod) + + @unpack_zerodim_and_defer("__pow__") + def __pow__(self, other): + return self._arith_method(other, operator.pow) + + @unpack_zerodim_and_defer("__rpow__") + def __rpow__(self, other): + return self._arith_method(other, roperator.rpow) + + +# ----------------------------------------------------------------------------- +# Helpers to implement __array_ufunc__ + + +def array_ufunc(self, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any): + """ + Compatibility with numpy ufuncs. + + See also + -------- + numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__ + """ + from pandas.core.frame import ( + DataFrame, + Series, + ) + from pandas.core.generic import NDFrame + from pandas.core.internals import BlockManager + + cls = type(self) + + kwargs = _standardize_out_kwarg(**kwargs) + + # for binary ops, use our custom dunder methods + result = maybe_dispatch_ufunc_to_dunder_op(self, ufunc, method, *inputs, **kwargs) + if result is not NotImplemented: + return result + + # Determine if we should defer. + no_defer = ( + np.ndarray.__array_ufunc__, + cls.__array_ufunc__, + ) + + for item in inputs: + higher_priority = ( + hasattr(item, "__array_priority__") + and item.__array_priority__ > self.__array_priority__ + ) + has_array_ufunc = ( + hasattr(item, "__array_ufunc__") + and type(item).__array_ufunc__ not in no_defer + and not isinstance(item, self._HANDLED_TYPES) + ) + if higher_priority or has_array_ufunc: + return NotImplemented + + # align all the inputs. + types = tuple(type(x) for x in inputs) + alignable = [ + x for x, t in zip(inputs, types, strict=True) if issubclass(t, NDFrame) + ] + + if len(alignable) > 1: + # This triggers alignment. + # At the moment, there aren't any ufuncs with more than two inputs + # so this ends up just being x1.index | x2.index, but we write + # it to handle *args. + set_types = set(types) + if len(set_types) > 1 and {DataFrame, Series}.issubset(set_types): + # We currently don't handle ufunc(DataFrame, Series) + # well. Previously this raised an internal ValueError. We might + # support it someday, so raise a NotImplementedError. + raise NotImplementedError( + f"Cannot apply ufunc {ufunc} to mixed DataFrame and Series inputs." + ) + axes = self.axes + for obj in alignable[1:]: + # this relies on the fact that we aren't handling mixed + # series / frame ufuncs. + for i, (ax1, ax2) in enumerate(zip(axes, obj.axes, strict=True)): + axes[i] = ax1.union(ax2) + + reconstruct_axes = dict(zip(self._AXIS_ORDERS, axes, strict=True)) + inputs = tuple( + x.reindex(**reconstruct_axes) if issubclass(t, NDFrame) else x + for x, t in zip(inputs, types, strict=True) + ) + else: + reconstruct_axes = dict(zip(self._AXIS_ORDERS, self.axes, strict=True)) + + if self.ndim == 1: + names = {x.name for x in inputs if hasattr(x, "name")} + name = names.pop() if len(names) == 1 else None + reconstruct_kwargs = {"name": name} + else: + reconstruct_kwargs = {} + + def reconstruct(result): + if ufunc.nout > 1: + # np.modf, np.frexp, np.divmod + return tuple(_reconstruct(x) for x in result) + + return _reconstruct(result) + + def _reconstruct(result): + if lib.is_scalar(result): + return result + + if result.ndim != self.ndim: + if method == "outer": + raise NotImplementedError + return result + if isinstance(result, BlockManager): + # we went through BlockManager.apply e.g. np.sqrt + result = self._constructor_from_mgr(result, axes=result.axes) + else: + # we converted an array, lost our axes + result = self._constructor( + result, **reconstruct_axes, **reconstruct_kwargs, copy=False + ) + # TODO: When we support multiple values in __finalize__, this + # should pass alignable to `__finalize__` instead of self. + # Then `np.add(a, b)` would consider attrs from both a and b + # when a and b are NDFrames. + if len(alignable) == 1: + result = result.__finalize__(self) + return result + + if "out" in kwargs: + # e.g. test_multiindex_get_loc + result = dispatch_ufunc_with_out(self, ufunc, method, *inputs, **kwargs) + return reconstruct(result) + + if method == "reduce": + # e.g. test.series.test_ufunc.test_reduce + result = dispatch_reduction_ufunc(self, ufunc, method, *inputs, **kwargs) + if result is not NotImplemented: + return result + + # We still get here with kwargs `axis` for e.g. np.maximum.accumulate + # and `dtype` and `keepdims` for np.ptp + + if self.ndim > 1 and (len(inputs) > 1 or ufunc.nout > 1): + # Just give up on preserving types in the complex case. + # In theory we could preserve them for them. + # * nout>1 is doable if BlockManager.apply took nout and + # returned a Tuple[BlockManager]. + # * len(inputs) > 1 is doable when we know that we have + # aligned blocks / dtypes. + + # e.g. my_ufunc, modf, logaddexp, heaviside, subtract, add + inputs = tuple(np.asarray(x) for x in inputs) + # Note: we can't use default_array_ufunc here bc reindexing means + # that `self` may not be among `inputs` + result = getattr(ufunc, method)(*inputs, **kwargs) + elif self.ndim == 1: + # ufunc(series, ...) + inputs = tuple(extract_array(x, extract_numpy=True) for x in inputs) + result = getattr(ufunc, method)(*inputs, **kwargs) + # ufunc(dataframe) + elif method == "__call__" and not kwargs: + # for np.(..) calls + # kwargs cannot necessarily be handled block-by-block, so only + # take this path if there are no kwargs + mgr = inputs[0]._mgr # pyright: ignore[reportGeneralTypeIssues] + result = mgr.apply(getattr(ufunc, method)) + else: + # otherwise specific ufunc methods (eg np..accumulate(..)) + # Those can have an axis keyword and thus can't be called block-by-block + result = default_array_ufunc(inputs[0], ufunc, method, *inputs, **kwargs) # pyright: ignore[reportGeneralTypeIssues] + # e.g. np.negative (only one reached), with "where" and "out" in kwargs + + result = reconstruct(result) + return result + + +def _standardize_out_kwarg(**kwargs) -> dict: + """ + If kwargs contain "out1" and "out2", replace that with a tuple "out" + + np.divmod, np.modf, np.frexp can have either `out=(out1, out2)` or + `out1=out1, out2=out2)` + """ + if "out" not in kwargs and "out1" in kwargs and "out2" in kwargs: + out1 = kwargs.pop("out1") + out2 = kwargs.pop("out2") + out = (out1, out2) + kwargs["out"] = out + return kwargs + + +def dispatch_ufunc_with_out(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): + """ + If we have an `out` keyword, then call the ufunc without `out` and then + set the result into the given `out`. + """ + + # Note: we assume _standardize_out_kwarg has already been called. + out = kwargs.pop("out") + where = kwargs.pop("where", None) + + result = getattr(ufunc, method)(*inputs, **kwargs) + + if result is NotImplemented: + return NotImplemented + + if isinstance(result, tuple): + # i.e. np.divmod, np.modf, np.frexp + if not isinstance(out, tuple) or len(out) != len(result): + raise NotImplementedError + + for arr, res in zip(out, result, strict=True): + _assign_where(arr, res, where) + + return out + + if isinstance(out, tuple): + if len(out) == 1: + out = out[0] + else: + raise NotImplementedError + + _assign_where(out, result, where) + return out + + +def _assign_where(out, result, where) -> None: + """ + Set a ufunc result into 'out', masking with a 'where' argument if necessary. + """ + if where is None: + # no 'where' arg passed to ufunc + out[:] = result + else: + np.putmask(out, where, result) + + +def default_array_ufunc(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): + """ + Fallback to the behavior we would get if we did not define __array_ufunc__. + + Notes + ----- + We are assuming that `self` is among `inputs`. + """ + if not any(x is self for x in inputs): + raise NotImplementedError + + new_inputs = [x if x is not self else np.asarray(x) for x in inputs] + + return getattr(ufunc, method)(*new_inputs, **kwargs) + + +def dispatch_reduction_ufunc(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): + """ + Dispatch ufunc reductions to self's reduction methods. + """ + assert method == "reduce" + + if len(inputs) != 1 or inputs[0] is not self: + return NotImplemented + + if ufunc.__name__ not in REDUCTION_ALIASES: + return NotImplemented + + method_name = REDUCTION_ALIASES[ufunc.__name__] + + # NB: we are assuming that min/max represent minimum/maximum methods, + # which would not be accurate for e.g. Timestamp.min + if not hasattr(self, method_name): + return NotImplemented + + if self.ndim > 1: + if isinstance(self, ABCNDFrame): + # TODO: test cases where this doesn't hold, i.e. 2D DTA/TDA + kwargs["numeric_only"] = False + + if "axis" not in kwargs: + # For DataFrame reductions we don't want the default axis=0 + # Note: np.min is not a ufunc, but uses array_function_dispatch, + # so calls DataFrame.min (without ever getting here) with the np.min + # default of axis=None, which DataFrame.min catches and changes to axis=0. + # np.minimum.reduce(df) gets here bc axis is not in kwargs, + # so we set axis=0 to match the behavior of np.minimum.reduce(df.values) + kwargs["axis"] = 0 + + # By default, numpy's reductions do not skip NaNs, so we have to + # pass skipna=False + result = getattr(self, method_name)(skipna=False, **kwargs) + result = maybe_unbox_numpy_scalar(result) + return result diff --git a/pandas/core/base.py b/pandas/core/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f92558a34756fe59112ec978e53e364db2c15714 --- /dev/null +++ b/pandas/core/base.py @@ -0,0 +1,1653 @@ +""" +Base and utility classes for pandas objects. +""" + +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + Self, + cast, + final, + overload, +) + +import numpy as np + +from pandas._libs import lib +from pandas._typing import ( + AxisInt, + DtypeObj, + IndexLabel, + NDFrameT, + Shape, + npt, +) +from pandas.compat import PYPY +from pandas.compat.numpy import function as nv +from pandas.errors import AbstractMethodError +from pandas.util._decorators import cache_readonly + +from pandas.core.dtypes.cast import can_hold_element +from pandas.core.dtypes.common import ( + is_object_dtype, + is_scalar, +) +from pandas.core.dtypes.dtypes import ExtensionDtype +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCIndex, + ABCMultiIndex, + ABCSeries, +) +from pandas.core.dtypes.missing import ( + isna, + remove_na_arraylike, +) + +from pandas.core import ( + algorithms, + nanops, + ops, +) +from pandas.core.accessor import DirNamesMixin +from pandas.core.arraylike import OpsMixin +from pandas.core.arrays import ExtensionArray +from pandas.core.construction import ( + ensure_wrapped_if_datetimelike, + extract_array, +) + +if TYPE_CHECKING: + from collections.abc import ( + Hashable, + Iterator, + ) + + from pandas._typing import ( + DropKeep, + NumpySorter, + NumpyValueArrayLike, + ScalarLike_co, + ) + + from pandas import ( + DataFrame, + Index, + Series, + ) + + +class PandasObject(DirNamesMixin): + """ + Base class for various pandas objects. + """ + + # results from calls to methods decorated with cache_readonly get added to _cache + _cache: dict[str, Any] + + @property + def _constructor(self) -> type[Self]: + """ + Class constructor (for this class it's just `__class__`). + """ + return type(self) + + def __repr__(self) -> str: + """ + Return a string representation for a particular object. + """ + # Should be overwritten by base classes + return object.__repr__(self) + + def _reset_cache(self, key: str | None = None) -> None: + """ + Reset cached properties. If ``key`` is passed, only clears that key. + """ + if not hasattr(self, "_cache"): + return + if key is None: + self._cache.clear() + else: + self._cache.pop(key, None) + + def __sizeof__(self) -> int: + """ + Generates the total memory usage for an object that returns + either a value or Series of values + """ + memory_usage = getattr(self, "memory_usage", None) + if memory_usage: + mem = memory_usage(deep=True) + return int(mem if is_scalar(mem) else mem.sum()) + + # no memory_usage attribute, so fall back to object's 'sizeof' + return super().__sizeof__() + + +class NoNewAttributesMixin: + """ + Mixin which prevents adding new attributes. + + Prevents additional attributes via xxx.attribute = "something" after a + call to `self.__freeze()`. Mainly used to prevent the user from using + wrong attributes on an accessor (`Series.cat/.str/.dt`). + + If you really want to add a new attribute at a later time, you need to use + `object.__setattr__(self, key, value)`. + """ + + def _freeze(self) -> None: + """ + Prevents setting additional attributes. + """ + object.__setattr__(self, "__frozen", True) + + # prevent adding any attribute via s.xxx.new_attribute = ... + def __setattr__(self, key: str, value) -> None: + # _cache is used by a decorator + # We need to check both 1.) cls.__dict__ and 2.) getattr(self, key) + # because + # 1.) getattr is false for attributes that raise errors + # 2.) cls.__dict__ doesn't traverse into base classes + if getattr(self, "__frozen", False) and not ( + key == "_cache" + or key in type(self).__dict__ + or getattr(self, key, None) is not None + ): + raise AttributeError(f"You cannot add any new attribute '{key}'") + object.__setattr__(self, key, value) + + +class SelectionMixin(Generic[NDFrameT]): + """ + mixin implementing the selection & aggregation interface on a group-like + object sub-classes need to define: obj, exclusions + """ + + obj: NDFrameT + _selection: IndexLabel | None = None + exclusions: frozenset[Hashable] + _internal_names = ["_cache", "__setstate__"] + _internal_names_set = set(_internal_names) + + @final + @property + def _selection_list(self): + if not isinstance( + self._selection, (list, tuple, ABCSeries, ABCIndex, np.ndarray) + ): + return [self._selection] + return self._selection + + @cache_readonly + def _selected_obj(self): + if self._selection is None or isinstance(self.obj, ABCSeries): + return self.obj + else: + return self.obj[self._selection] + + @final + @cache_readonly + def ndim(self) -> int: + return self._selected_obj.ndim + + @final + @cache_readonly + def _obj_with_exclusions(self): + if isinstance(self.obj, ABCSeries): + return self.obj + + if self._selection is not None: + return self.obj[self._selection_list] + + if len(self.exclusions) > 0: + # equivalent to `self.obj.drop(self.exclusions, axis=1) + # but this avoids consolidating and making a copy + # TODO: following GH#45287 can we now use .drop directly without + # making a copy? + return self.obj._drop_axis(self.exclusions, axis=1, only_slice=True) + else: + return self.obj + + def __getitem__(self, key): + if self._selection is not None: + raise IndexError(f"Column(s) {self._selection} already selected") + + if isinstance(key, (list, tuple, ABCSeries, ABCIndex, np.ndarray)): + if len(self.obj.columns.intersection(key)) != len(set(key)): + bad_keys = list(set(key).difference(self.obj.columns)) + raise KeyError(f"Columns not found: {str(bad_keys)[1:-1]}") + return self._gotitem(list(key), ndim=2) + + else: + if key not in self.obj: + raise KeyError(f"Column not found: {key}") + ndim = self.obj[key].ndim + return self._gotitem(key, ndim=ndim) + + def _gotitem(self, key, ndim: int, subset=None): + """ + sub-classes to define + return a sliced object + + Parameters + ---------- + key : str / list of selections + ndim : {1, 2} + requested ndim of result + subset : object, default None + subset to act on + """ + raise AbstractMethodError(self) + + @final + def _infer_selection(self, key, subset: Series | DataFrame): + """ + Infer the `selection` to pass to our constructor in _gotitem. + """ + # Shared by Rolling and Resample + selection = None + if subset.ndim == 2 and ( + (lib.is_scalar(key) and key in subset) or lib.is_list_like(key) + ): + selection = key + elif subset.ndim == 1 and lib.is_scalar(key) and key == subset.name: + selection = key + return selection + + def aggregate(self, func, *args, **kwargs): + raise AbstractMethodError(self) + + agg = aggregate + + +class IndexOpsMixin(OpsMixin): + """ + Common ops mixin to support a unified interface / docs for Series / Index + """ + + # ndarray compatibility + __array_priority__ = 1000 + _hidden_attrs: frozenset[str] = frozenset( + ["tolist"] # tolist is not deprecated, just suppressed in the __dir__ + ) + + @property + def dtype(self) -> DtypeObj: + # must be defined here as a property for mypy + raise AbstractMethodError(self) + + @property + def _values(self) -> ExtensionArray | np.ndarray: + # must be defined here as a property for mypy + raise AbstractMethodError(self) + + @final + def transpose(self, *args, **kwargs) -> Self: + """ + Return the transpose, which is by definition self. + + Returns + ------- + %(klass)s + """ + nv.validate_transpose(args, kwargs) + return self + + T = property( + transpose, + doc=""" + Return the transpose, which is by definition self. + + See Also + -------- + Index : Immutable sequence used for indexing and alignment. + + Examples + -------- + For Series: + + >>> s = pd.Series(['Ant', 'Bear', 'Cow']) + >>> s + 0 Ant + 1 Bear + 2 Cow + dtype: str + >>> s.T + 0 Ant + 1 Bear + 2 Cow + dtype: str + + For Index: + + >>> idx = pd.Index([1, 2, 3]) + >>> idx.T + Index([1, 2, 3], dtype='int64') + """, + ) + + @property + def shape(self) -> Shape: + """ + Return a tuple of the shape of the underlying data. + + See Also + -------- + Series.ndim : Number of dimensions of the underlying data. + Series.size : Return the number of elements in the underlying data. + Series.nbytes : Return the number of bytes in the underlying data. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.shape + (3,) + """ + return self._values.shape + + def __len__(self) -> int: + # We need this defined here for mypy + raise AbstractMethodError(self) + + # Temporarily avoid using `-> Literal[1]:` because of an IPython (jedi) bug + # https://github.com/ipython/ipython/issues/14412 + # https://github.com/davidhalter/jedi/issues/1990 + @property + def ndim(self) -> int: + """ + Number of dimensions of the underlying data, by definition 1. + + See Also + -------- + Series.size: Return the number of elements in the underlying data. + Series.shape: Return a tuple of the shape of the underlying data. + Series.dtype: Return the dtype object of the underlying data. + Series.values: Return Series as ndarray or ndarray-like depending on the dtype. + + Examples + -------- + >>> s = pd.Series(["Ant", "Bear", "Cow"]) + >>> s + 0 Ant + 1 Bear + 2 Cow + dtype: str + >>> s.ndim + 1 + + For Index: + + >>> idx = pd.Index([1, 2, 3]) + >>> idx + Index([1, 2, 3], dtype='int64') + >>> idx.ndim + 1 + """ + return 1 + + @final + def item(self): + """ + Return the first element of the underlying data as a Python scalar. + + Returns + ------- + scalar + The first element of Series or Index. + + Raises + ------ + ValueError + If the data is not length = 1. + + See Also + -------- + Index.values : Returns an array representing the data in the Index. + Series.head : Returns the first `n` rows. + + Examples + -------- + >>> s = pd.Series([1]) + >>> s.item() + 1 + + For an index: + + >>> s = pd.Series([1], index=["a"]) + >>> s.index.item() + 'a' + """ + if len(self) == 1: + return next(iter(self)) + raise ValueError("can only convert an array of size 1 to a Python scalar") + + @property + def nbytes(self) -> int: + """ + Return the number of bytes in the underlying data. + + See Also + -------- + Series.ndim : Number of dimensions of the underlying data. + Series.size : Return the number of elements in the underlying data. + + Examples + -------- + For Series: + + >>> s = pd.Series(["Ant", "Bear", "Cow"]) + >>> s + 0 Ant + 1 Bear + 2 Cow + dtype: str + >>> s.nbytes + 34 + + For Index: + + >>> idx = pd.Index([1, 2, 3]) + >>> idx + Index([1, 2, 3], dtype='int64') + >>> idx.nbytes + 24 + """ + return self._values.nbytes + + @property + def size(self) -> int: + """ + Return the number of elements in the underlying data. + + See Also + -------- + Series.ndim: Number of dimensions of the underlying data, by definition 1. + Series.shape: Return a tuple of the shape of the underlying data. + Series.dtype: Return the dtype object of the underlying data. + Series.values: Return Series as ndarray or ndarray-like depending on the dtype. + + Examples + -------- + For Series: + + >>> s = pd.Series(["Ant", "Bear", "Cow"]) + >>> s + 0 Ant + 1 Bear + 2 Cow + dtype: str + >>> s.size + 3 + + For Index: + + >>> idx = pd.Index([1, 2, 3]) + >>> idx + Index([1, 2, 3], dtype='int64') + >>> idx.size + 3 + """ + return len(self._values) + + @property + def array(self) -> ExtensionArray: + """ + The ExtensionArray of the data backing this Series or Index. + + This property provides direct access to the underlying array data of a + Series or Index without requiring conversion to a NumPy array. It + returns an ExtensionArray, which is the native storage format for + pandas extension dtypes. + + Returns + ------- + ExtensionArray + An ExtensionArray of the values stored within. For extension + types, this is the actual array. For NumPy native types, this + is a thin (no copy) wrapper around :class:`numpy.ndarray`. + + ``.array`` differs from ``.values``, which may require converting + the data to a different form. + + See Also + -------- + Index.to_numpy : Similar method that always returns a NumPy array. + Series.to_numpy : Similar method that always returns a NumPy array. + + Notes + ----- + This table lays out the different array types for each extension + dtype within pandas. + + ================== ============================= + dtype array type + ================== ============================= + category Categorical + period PeriodArray + interval IntervalArray + IntegerNA IntegerArray + string StringArray + boolean BooleanArray + datetime64[ns, tz] DatetimeArray + ================== ============================= + + For any 3rd-party extension types, the array type will be an + ExtensionArray. + + For all remaining dtypes ``.array`` will be a + :class:`arrays.NumpyExtensionArray` wrapping the actual ndarray + stored within. If you absolutely need a NumPy array (possibly with + copying / coercing data), then use :meth:`Series.to_numpy` instead. + + Examples + -------- + For regular NumPy types like int, and float, a NumpyExtensionArray + is returned. + + >>> pd.Series([1, 2, 3]).array + + [1, 2, 3] + Length: 3, dtype: int64 + + For extension types, like Categorical, the actual ExtensionArray + is returned + + >>> ser = pd.Series(pd.Categorical(["a", "b", "a"])) + >>> ser.array + ['a', 'b', 'a'] + Categories (2, str): ['a', 'b'] + """ + raise AbstractMethodError(self) + + def to_numpy( + self, + dtype: npt.DTypeLike | None = None, + copy: bool = False, + na_value: object = lib.no_default, + **kwargs, + ) -> np.ndarray: + """ + A NumPy ndarray representing the values in this Series or Index. + + Parameters + ---------- + dtype : str or numpy.dtype, optional + The dtype to pass to :meth:`numpy.asarray`. + copy : bool, default False + Whether to ensure that the returned value is not a view on + another array. Note that ``copy=False`` does not *ensure* that + ``to_numpy()`` is no-copy. Rather, ``copy=True`` ensure that + a copy is made, even if not strictly necessary. + na_value : Any, optional + The value to use for missing values. The default value depends + on `dtype` and the type of the array. + **kwargs + Additional keywords passed through to the ``to_numpy`` method + of the underlying array (for extension arrays). + + Returns + ------- + numpy.ndarray + The NumPy ndarray holding the values from this Series or Index. + The dtype of the array may differ. See Notes. + + See Also + -------- + Series.array : Get the actual data stored within. + Index.array : Get the actual data stored within. + DataFrame.to_numpy : Similar method for DataFrame. + + Notes + ----- + The returned array will be the same up to equality (values equal + in `self` will be equal in the returned array; likewise for values + that are not equal). When `self` contains an ExtensionArray, the + dtype may be different. For example, for a category-dtype Series, + ``to_numpy()`` will return a NumPy array and the categorical dtype + will be lost. + + For NumPy dtypes, this will be a reference to the actual data stored + in this Series or Index (assuming ``copy=False``). Modifying the result + in place will modify the data stored in the Series or Index (not that + we recommend doing that). + + For extension types, ``to_numpy()`` *may* require copying data and + coercing the result to a NumPy type (possibly object), which may be + expensive. When you need a no-copy reference to the underlying data, + :attr:`Series.array` should be used instead. + + This table lays out the different dtypes and default return types of + ``to_numpy()`` for various dtypes within pandas. + + ================== ================================ + dtype array type + ================== ================================ + category[T] ndarray[T] (same dtype as input) + period ndarray[object] (Periods) + interval ndarray[object] (Intervals) + IntegerNA ndarray[object] + datetime64[ns] datetime64[ns] + datetime64[ns, tz] ndarray[object] (Timestamps) + ================== ================================ + + Examples + -------- + >>> ser = pd.Series(pd.Categorical(["a", "b", "a"])) + >>> ser.to_numpy() + array(['a', 'b', 'a'], dtype=object) + + Specify the `dtype` to control how datetime-aware data is represented. + Use ``dtype=object`` to return an ndarray of pandas :class:`Timestamp` + objects, each with the correct ``tz``. + + >>> ser = pd.Series(pd.date_range("2000", periods=2, tz="CET")) + >>> ser.to_numpy(dtype=object) + array([Timestamp('2000-01-01 00:00:00+0100', tz='CET'), + Timestamp('2000-01-02 00:00:00+0100', tz='CET')], + dtype=object) + + Or ``dtype='datetime64[ns]'`` to return an ndarray of native + datetime64 values. The values are converted to UTC and the timezone + info is dropped. + + >>> ser.to_numpy(dtype="datetime64[ns]") + ... # doctest: +ELLIPSIS + array(['1999-12-31T23:00:00.000000000', '2000-01-01T23:00:00...'], + dtype='datetime64[ns]') + """ + if isinstance(self.dtype, ExtensionDtype): + return self.array.to_numpy(dtype, copy=copy, na_value=na_value, **kwargs) + elif kwargs: + bad_keys = next(iter(kwargs.keys())) + raise TypeError( + f"to_numpy() got an unexpected keyword argument '{bad_keys}'" + ) + + fillna = ( + na_value is not lib.no_default + # no need to fillna with np.nan if we already have a float dtype + and not (na_value is np.nan and np.issubdtype(self.dtype, np.floating)) + ) + + values = self._values + if fillna and self.hasnans: + if not can_hold_element(values, na_value): + # if we can't hold the na_value asarray either makes a copy or we + # error before modifying values. The asarray later on thus won't make + # another copy + values = np.asarray(values, dtype=dtype) + else: + values = values.copy() + + values[np.asanyarray(isna(self))] = na_value + + result = np.asarray(values, dtype=dtype) + + if (copy and not fillna) or not copy: + if np.shares_memory(self._values[:2], result[:2]): + # Take slices to improve performance of check + if not copy: + result = result.view() + result.flags.writeable = False + else: + result = result.copy() + + return result + + @final + @property + def empty(self) -> bool: + """ + Indicator whether Index is empty. + + An Index is considered empty if it has no elements. This property can be + useful for quickly checking the state of an Index, especially in data + processing and analysis workflows where handling of empty datasets might + be required. + + Returns + ------- + bool + If Index is empty, return True, if not return False. + + See Also + -------- + Index.size : Return the number of elements in the underlying data. + + Examples + -------- + >>> idx = pd.Index([1, 2, 3]) + >>> idx + Index([1, 2, 3], dtype='int64') + >>> idx.empty + False + + >>> idx_empty = pd.Index([]) + >>> idx_empty + Index([], dtype='object') + >>> idx_empty.empty + True + + If we only have NaNs in our DataFrame, it is not considered empty! + + >>> idx = pd.Index([np.nan, np.nan]) + >>> idx + Index([nan, nan], dtype='float64') + >>> idx.empty + False + """ + return not self.size + + def argmax( + self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs + ) -> int: + """ + Return int position of the largest value in the Series. + + If the maximum is achieved in multiple locations, + the first row position is returned. + + Parameters + ---------- + axis : None + Unused. Parameter needed for compatibility with DataFrame. + skipna : bool, default True + Exclude NA/null values. If the entire Series is NA, or if ``skipna=False`` + and there is an NA value, this method will raise a ``ValueError``. + *args, **kwargs + Additional arguments and keywords for compatibility with NumPy. + + Returns + ------- + int + Row position of the maximum value. + + See Also + -------- + Series.argmax : Return position of the maximum value. + Series.argmin : Return position of the minimum value. + numpy.ndarray.argmax : Equivalent method for numpy arrays. + Series.idxmax : Return index label of the maximum values. + Series.idxmin : Return index label of the minimum values. + + Examples + -------- + Consider dataset containing cereal calories + + >>> s = pd.Series( + ... [100.0, 110.0, 120.0, 110.0], + ... index=[ + ... "Corn Flakes", + ... "Almond Delight", + ... "Cinnamon Toast Crunch", + ... "Cocoa Puff", + ... ], + ... ) + >>> s + Corn Flakes 100.0 + Almond Delight 110.0 + Cinnamon Toast Crunch 120.0 + Cocoa Puff 110.0 + dtype: float64 + + >>> s.argmax() + np.int64(2) + >>> s.argmin() + np.int64(0) + + The maximum cereal calories is the third element and + the minimum cereal calories is the first element, + since series is zero-indexed. + """ + delegate = self._values + nv.validate_minmax_axis(axis) + skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) + + if isinstance(delegate, ExtensionArray): + return delegate.argmax(skipna=skipna) + else: + result = nanops.nanargmax(delegate, skipna=skipna) + # error: Incompatible return value type (got "Union[int, ndarray]", expected + # "int") + return result # type: ignore[return-value] + + def argmin( + self, axis: AxisInt | None = None, skipna: bool = True, *args, **kwargs + ) -> int: + """ + Return int position of the smallest value in the Series. + + If the minimum is achieved in multiple locations, + the first row position is returned. + + Parameters + ---------- + axis : None + Unused. Parameter needed for compatibility with DataFrame. + skipna : bool, default True + Exclude NA/null values. If the entire Series is NA, or if ``skipna=False`` + and there is an NA value, this method will raise a ``ValueError``. + *args, **kwargs + Additional arguments and keywords for compatibility with NumPy. + + Returns + ------- + int + Row position of the minimum value. + + See Also + -------- + Series.argmin : Return position of the minimum value. + Series.argmax : Return position of the maximum value. + numpy.ndarray.argmin : Equivalent method for numpy arrays. + Series.idxmin : Return index label of the minimum values. + Series.idxmax : Return index label of the maximum values. + + Examples + -------- + Consider dataset containing cereal calories + + >>> s = pd.Series( + ... [100.0, 110.0, 120.0, 110.0], + ... index=[ + ... "Corn Flakes", + ... "Almond Delight", + ... "Cinnamon Toast Crunch", + ... "Cocoa Puff", + ... ], + ... ) + >>> s + Corn Flakes 100.0 + Almond Delight 110.0 + Cinnamon Toast Crunch 120.0 + Cocoa Puff 110.0 + dtype: float64 + + >>> s.argmax() + np.int64(2) + >>> s.argmin() + np.int64(0) + + The maximum cereal calories is the third element and + the minimum cereal calories is the first element, + since series is zero-indexed. + """ + delegate = self._values + nv.validate_minmax_axis(axis) + skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs) + + if isinstance(delegate, ExtensionArray): + return delegate.argmin(skipna=skipna) + else: + result = nanops.nanargmin(delegate, skipna=skipna) + # error: Incompatible return value type (got "Union[int, ndarray]", expected + # "int") + return result # type: ignore[return-value] + + def tolist(self) -> list: + """ + Return a list of the values. + + These are each a scalar type, which is a Python scalar + (for str, int, float) or a pandas scalar + (for Timestamp/Timedelta/Interval/Period) + + Returns + ------- + list + List containing the values as Python or pandas scalers. + + See Also + -------- + numpy.ndarray.tolist : Return the array as an a.ndim-levels deep + nested list of Python scalars. + + Examples + -------- + For Series + + >>> s = pd.Series([1, 2, 3]) + >>> s.to_list() + [1, 2, 3] + + For Index: + + >>> idx = pd.Index([1, 2, 3]) + >>> idx + Index([1, 2, 3], dtype='int64') + + >>> idx.to_list() + [1, 2, 3] + """ + return self._values.tolist() + + to_list = tolist + + def __iter__(self) -> Iterator: + """ + Return an iterator of the values. + + These are each a scalar type, which is a Python scalar + (for str, int, float) or a pandas scalar + (for Timestamp/Timedelta/Interval/Period) + + Returns + ------- + iterator + An iterator yielding scalar values from the Series. + + See Also + -------- + Series.items : Lazily iterate over (index, value) tuples. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> for x in s: + ... print(x) + 1 + 2 + 3 + """ + # We are explicitly making element iterators. + if not isinstance(self._values, np.ndarray): + # Check type instead of dtype to catch DTA/TDA + return iter(self._values) + else: + return map(self._values.item, range(self._values.size)) + + @cache_readonly + def hasnans(self) -> bool: + """ + Return True if there are any NaNs. + + Enables various performance speedups. + + Returns + ------- + bool + + See Also + -------- + Series.isna : Detect missing values. + Series.notna : Detect existing (non-missing) values. + + Examples + -------- + >>> s = pd.Series([1, 2, 3, None]) + >>> s + 0 1.0 + 1 2.0 + 2 3.0 + 3 NaN + dtype: float64 + >>> s.hasnans + True + """ + # error: Item "bool" of "Union[bool, ndarray[Any, dtype[bool_]], NDFrame]" + # has no attribute "any" + return bool(isna(self).any()) # type: ignore[union-attr] + + @final + def _map_values(self, mapper, na_action=None): + """ + An internal function that maps values using the input + correspondence (which can be a dict, Series, or function). + + Parameters + ---------- + mapper : function, dict, or Series + The input correspondence object + na_action : {None, 'ignore'} + If 'ignore', propagate NA values, without passing them to the + mapping function + + Returns + ------- + Union[Index, MultiIndex], inferred + The output of the mapping function applied to the index. + If the function returns a tuple with more than one element + a MultiIndex will be returned. + """ + arr = self._values + + if isinstance(arr, ExtensionArray): + return arr.map(mapper, na_action=na_action) + + return algorithms.map_array(arr, mapper, na_action=na_action) + + def value_counts( + self, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + bins=None, + dropna: bool = True, + ) -> Series: + """ + Return a Series containing counts of unique values. + + The resulting object will be in descending order so that the + first element is the most frequently-occurring element. + Excludes NA values by default. + + Parameters + ---------- + normalize : bool, default False + If True then the object returned will contain the relative + frequencies of the unique values. + sort : bool, default True + Stable sort by frequencies when True. Preserve the order of the data + when False. + + .. versionchanged:: 3.0.0 + + Prior to 3.0.0, the sort was unstable. + ascending : bool, default False + Sort in ascending order. + bins : int, optional + Rather than count values, group them into half-open bins, + a convenience for ``pd.cut``, only works with numeric data. + dropna : bool, default True + Don't include counts of NaN. + + Returns + ------- + Series + Series containing counts of unique values. + + See Also + -------- + Series.count: Number of non-NA elements in a Series. + DataFrame.count: Number of non-NA elements in a DataFrame. + DataFrame.value_counts: Equivalent method on DataFrames. + + Examples + -------- + >>> index = pd.Index([3, 1, 2, 3, 4, np.nan]) + >>> index.value_counts() + 3.0 2 + 1.0 1 + 2.0 1 + 4.0 1 + Name: count, dtype: int64 + + With `normalize` set to `True`, returns the relative frequency by + dividing all values by the sum of values. + + >>> s = pd.Series([3, 1, 2, 3, 4, np.nan]) + >>> s.value_counts(normalize=True) + 3.0 0.4 + 1.0 0.2 + 2.0 0.2 + 4.0 0.2 + Name: proportion, dtype: float64 + + **bins** + + Bins can be useful for going from a continuous variable to a + categorical variable; instead of counting unique + apparitions of values, divide the index in the specified + number of half-open bins. + + >>> s.value_counts(bins=3) + (0.996, 2.0] 2 + (2.0, 3.0] 2 + (3.0, 4.0] 1 + Name: count, dtype: int64 + + **dropna** + + With `dropna` set to `False` we can also see NaN index values. + + >>> s.value_counts(dropna=False) + 3.0 2 + 1.0 1 + 2.0 1 + 4.0 1 + NaN 1 + Name: count, dtype: int64 + + **Categorical Dtypes** + + Rows with categorical type will be counted as one group + if they have same categories and order. + In the example below, even though ``a``, ``c``, and ``d`` + all have the same data types of ``category``, + only ``c`` and ``d`` will be counted as one group + since ``a`` doesn't have the same categories. + + >>> df = pd.DataFrame({"a": [1], "b": ["2"], "c": [3], "d": [3]}) + >>> df = df.astype({"a": "category", "c": "category", "d": "category"}) + >>> df + a b c d + 0 1 2 3 3 + + >>> df.dtypes + a category + b str + c category + d category + dtype: object + + >>> df.dtypes.value_counts() + category 2 + category 1 + str 1 + Name: count, dtype: int64 + """ + return algorithms.value_counts_internal( + self, + sort=sort, + ascending=ascending, + normalize=normalize, + bins=bins, + dropna=dropna, + ) + + def unique(self): + values = self._values + if not isinstance(values, np.ndarray): + # i.e. ExtensionArray + result = values.unique() + else: + result = algorithms.unique1d(values) # type: ignore[assignment] + return result + + @final + def nunique(self, dropna: bool = True) -> int: + """ + Return number of unique elements in the object. + + Excludes NA values by default. + + Parameters + ---------- + dropna : bool, default True + Don't include NaN in the count. + + Returns + ------- + int + An integer indicating the number of unique elements in the object. + + See Also + -------- + DataFrame.nunique: Method nunique for DataFrame. + Series.count: Count non-NA/null observations in the Series. + + Examples + -------- + >>> s = pd.Series([1, 3, 5, 7, 7]) + >>> s + 0 1 + 1 3 + 2 5 + 3 7 + 4 7 + dtype: int64 + + >>> s.nunique() + 4 + """ + uniqs = self.unique() + if dropna: + uniqs = remove_na_arraylike(uniqs) + return len(uniqs) + + @property + def is_unique(self) -> bool: + """ + Return True if values in the object are unique. + + Returns + ------- + bool + + See Also + -------- + Series.unique : Return unique values of Series object. + Series.drop_duplicates : Return Series with duplicate values removed. + Series.duplicated : Indicate duplicate Series values. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.is_unique + True + + >>> s = pd.Series([1, 2, 3, 1]) + >>> s.is_unique + False + """ + return self.nunique(dropna=False) == len(self) + + @property + def is_monotonic_increasing(self) -> bool: + """ + Return True if values in the object are monotonically increasing. + + Returns + ------- + bool + + See Also + -------- + Series.is_monotonic_decreasing : Return boolean if values in the object are + monotonically decreasing. + + Examples + -------- + >>> s = pd.Series([1, 2, 2]) + >>> s.is_monotonic_increasing + True + + >>> s = pd.Series([3, 2, 1]) + >>> s.is_monotonic_increasing + False + """ + from pandas import Index + + return Index(self).is_monotonic_increasing + + @property + def is_monotonic_decreasing(self) -> bool: + """ + Return True if values in the object are monotonically decreasing. + + Returns + ------- + bool + + See Also + -------- + Series.is_monotonic_increasing : Return boolean if values in the object are + monotonically increasing. + + Examples + -------- + >>> s = pd.Series([3, 2, 2, 1]) + >>> s.is_monotonic_decreasing + True + + >>> s = pd.Series([1, 2, 3]) + >>> s.is_monotonic_decreasing + False + """ + from pandas import Index + + return Index(self).is_monotonic_decreasing + + @final + def _memory_usage(self, deep: bool = False) -> int: + """ + Memory usage of the values. + + Parameters + ---------- + deep : bool, default False + Introspect the data deeply, interrogate + `object` dtypes for system-level memory consumption. + + Returns + ------- + bytes used + Returns memory usage of the values in the Index in bytes. + + See Also + -------- + numpy.ndarray.nbytes : Total bytes consumed by the elements of the + array. + + Notes + ----- + Memory usage does not include memory consumed by elements that + are not components of the array if deep=False or if used on PyPy + + Examples + -------- + >>> idx = pd.Index([1, 2, 3]) + >>> idx.memory_usage() + 24 + """ + if hasattr(self.array, "memory_usage"): + return self.array.memory_usage( # pyright: ignore[reportAttributeAccessIssue] + deep=deep, + ) + + v = self.array.nbytes + if deep and is_object_dtype(self.dtype) and not PYPY: + values = cast(np.ndarray, self._values) + v += lib.memory_usage_of_objects(values) + return v + + def factorize( + self, + sort: bool = False, + use_na_sentinel: bool = True, + ) -> tuple[npt.NDArray[np.intp], Index]: + """ + Encode the object as an enumerated type or categorical variable. + + This method is useful for obtaining a numeric representation of an + array when all that matters is identifying distinct values. `factorize` + is available as both a top-level function :func:`pandas.factorize`, + and as a method :meth:`Series.factorize` and :meth:`Index.factorize`. + + Parameters + ---------- + sort : bool, default False + Sort `uniques` and shuffle `codes` to maintain the + relationship. + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values. If False, + NaN values will be encoded as non-negative integers and will not drop the + NaN from the uniques of the values. + + Returns + ------- + codes : ndarray + An integer ndarray that's an indexer into `uniques`. + ``uniques.take(codes)`` will have the same values as `values`. + uniques : ndarray, Index, or Categorical + The unique valid values. When `values` is Categorical, `uniques` + is a Categorical. When `values` is some other pandas object, an + `Index` is returned. Otherwise, a 1-D ndarray is returned. + + .. note:: + + Even if there's a missing value in `values`, `uniques` will + *not* contain an entry for it. + + See Also + -------- + cut : Discretize continuous-valued array. + unique : Find the unique value in an array. + + Notes + ----- + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + These examples all show factorize as a top-level method like + ``pd.factorize(values)``. The results are identical for methods like + :meth:`Series.factorize`. + + >>> codes, uniques = pd.factorize( + ... np.array(["b", "b", "a", "c", "b"], dtype="O") + ... ) + >>> codes + array([0, 0, 1, 2, 0]) + >>> uniques + array(['b', 'a', 'c'], dtype=object) + + With ``sort=True``, the `uniques` will be sorted, and `codes` will be + shuffled so that the relationship is the maintained. + + >>> codes, uniques = pd.factorize( + ... np.array(["b", "b", "a", "c", "b"], dtype="O"), sort=True + ... ) + >>> codes + array([1, 1, 0, 2, 1]) + >>> uniques + array(['a', 'b', 'c'], dtype=object) + + When ``use_na_sentinel=True`` (the default), missing values are indicated in + the `codes` with the sentinel value ``-1`` and missing values are not + included in `uniques`. + + >>> codes, uniques = pd.factorize( + ... np.array(["b", None, "a", "c", "b"], dtype="O") + ... ) + >>> codes + array([ 0, -1, 1, 2, 0]) + >>> uniques + array(['b', 'a', 'c'], dtype=object) + + Thus far, we've only factorized lists (which are internally coerced to + NumPy arrays). When factorizing pandas objects, the type of `uniques` + will differ. For Categoricals, a `Categorical` is returned. + + >>> cat = pd.Categorical(["a", "a", "c"], categories=["a", "b", "c"]) + >>> codes, uniques = pd.factorize(cat) + >>> codes + array([0, 0, 1]) + >>> uniques + ['a', 'c'] + Categories (3, str): ['a', 'b', 'c'] + + Notice that ``'b'`` is in ``uniques.categories``, despite not being + present in ``cat.values``. + + For all other pandas objects, an Index of the appropriate type is + returned. + + >>> cat = pd.Series(["a", "a", "c"]) + >>> codes, uniques = pd.factorize(cat) + >>> codes + array([0, 0, 1]) + >>> uniques + Index(['a', 'c'], dtype='str') + + If NaN is in the values, and we want to include NaN in the uniques of the + values, it can be achieved by setting ``use_na_sentinel=False``. + + >>> values = np.array([1, 2, 1, np.nan]) + >>> codes, uniques = pd.factorize(values) # default: use_na_sentinel=True + >>> codes + array([ 0, 1, 0, -1]) + >>> uniques + array([1., 2.]) + + >>> codes, uniques = pd.factorize(values, use_na_sentinel=False) + >>> codes + array([0, 1, 0, 2]) + >>> uniques + array([ 1., 2., nan]) + """ + codes, uniques = algorithms.factorize( + self._values, sort=sort, use_na_sentinel=use_na_sentinel + ) + if uniques.dtype == np.float16: + uniques = uniques.astype(np.float32) + + if isinstance(self, ABCMultiIndex): + # preserve MultiIndex + if len(self) == 0: + # GH#57517 + uniques = self[:0] + else: + uniques = self._constructor(uniques) + else: + from pandas import Index + + try: + uniques = Index(uniques, dtype=self.dtype, copy=False) + except NotImplementedError: + # not all dtypes are supported in Index that are allowed for Series + # e.g. float16 or bytes + uniques = Index(uniques, copy=False) + return codes, uniques + + # This overload is needed so that the call to searchsorted in + # pandas.core.resample.TimeGrouper._get_period_bins picks the correct result + + # error: Overloaded function signatures 1 and 2 overlap with incompatible + # return types + @overload + def searchsorted( # type: ignore[overload-overlap] + self, + value: ScalarLike_co, + side: Literal["left", "right"] = ..., + sorter: NumpySorter = ..., + ) -> np.intp: ... + + @overload + def searchsorted( + self, + value: npt.ArrayLike | ExtensionArray, + side: Literal["left", "right"] = ..., + sorter: NumpySorter = ..., + ) -> npt.NDArray[np.intp]: ... + + def searchsorted( + self, + value: NumpyValueArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter | None = None, + ) -> npt.NDArray[np.intp] | np.intp: + """ + Find indices where elements should be inserted to maintain order. + + Find the indices into a sorted Index `self` such that, if the + corresponding elements in `value` were inserted before the indices, + the order of `self` would be preserved. + + .. note:: + + The Index *must* be monotonically sorted, otherwise + wrong locations will likely be returned. Pandas does *not* + check this for you. + + Parameters + ---------- + value : array-like or scalar + Values to insert into `self`. + side : {{'left', 'right'}}, optional + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable + index, return either 0 or N (where N is the length of `self`). + sorter : 1-D array-like, optional + Optional array of integer indices that sort `self` into ascending + order. They are typically the result of ``np.argsort``. + + Returns + ------- + int or array of int + A scalar or array of insertion points with the + same shape as `value`. + + See Also + -------- + sort_values : Sort by the values along either axis. + numpy.searchsorted : Similar method from NumPy. + + Notes + ----- + Binary search is used to find the required insertion points. + + Examples + -------- + >>> ser = pd.Series([1, 2, 3]) + >>> ser + 0 1 + 1 2 + 2 3 + dtype: int64 + + >>> ser.searchsorted(4) + np.int64(3) + + >>> ser.searchsorted([0, 4]) + array([0, 3]) + + >>> ser.searchsorted([1, 3], side="left") + array([0, 2]) + + >>> ser.searchsorted([1, 3], side="right") + array([1, 3]) + + >>> ser = pd.Series(pd.to_datetime(["3/11/2000", "3/12/2000", "3/13/2000"])) + >>> ser + 0 2000-03-11 + 1 2000-03-12 + 2 2000-03-13 + dtype: datetime64[us] + + >>> ser.searchsorted("3/14/2000") + np.int64(3) + + >>> ser = pd.Categorical( + ... ["apple", "bread", "bread", "cheese", "milk"], ordered=True + ... ) + >>> ser + ['apple', 'bread', 'bread', 'cheese', 'milk'] + Categories (4, str): ['apple' < 'bread' < 'cheese' < 'milk'] + + >>> ser.searchsorted("bread") + np.int64(1) + + >>> ser.searchsorted(["bread"], side="right") + array([3]) + + If the values are not monotonically sorted, wrong locations + may be returned: + + >>> ser = pd.Series([2, 1, 3]) + >>> ser + 0 2 + 1 1 + 2 3 + dtype: int64 + + >>> ser.searchsorted(1) # doctest: +SKIP + 0 # wrong result, correct would be 1 + """ + if isinstance(value, ABCDataFrame): + msg = ( + "Value must be 1-D array-like or scalar, " + f"{type(value).__name__} is not supported" + ) + raise ValueError(msg) + + values = self._values + if not isinstance(values, np.ndarray): + # Going through EA.searchsorted directly improves performance GH#38083 + return values.searchsorted(value, side=side, sorter=sorter) + + return algorithms.searchsorted( + values, + value, + side=side, + sorter=sorter, + ) + + def drop_duplicates(self, *, keep: DropKeep = "first") -> Self: + duplicated = self._duplicated(keep=keep) + # error: Value of type "IndexOpsMixin" is not indexable + return self[~duplicated] # type: ignore[index] + + @final + def _duplicated(self, keep: DropKeep = "first") -> npt.NDArray[np.bool_]: + arr = self._values + if isinstance(arr, ExtensionArray): + return arr.duplicated(keep=keep) + return algorithms.duplicated(arr, keep=keep) + + def _arith_method(self, other, op): + res_name = ops.get_op_result_name(self, other) + + lvalues = self._values + rvalues = extract_array(other, extract_numpy=True, extract_range=True) + rvalues = ops.maybe_prepare_scalar_for_op(rvalues, lvalues.shape) + rvalues = ensure_wrapped_if_datetimelike(rvalues) + if isinstance(rvalues, range): + rvalues = np.arange(rvalues.start, rvalues.stop, rvalues.step) + + with np.errstate(all="ignore"): + result = ops.arithmetic_op(lvalues, rvalues, op) + + return self._construct_result(result, name=res_name, other=other) + + def _construct_result(self, result, name, other): + """ + Construct an appropriately-wrapped result from the ArrayLike result + of an arithmetic-like operation. + """ + raise AbstractMethodError(self) diff --git a/pandas/core/col.py b/pandas/core/col.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c064a17f2531d23958c1455247977c2895bbce --- /dev/null +++ b/pandas/core/col.py @@ -0,0 +1,374 @@ +from __future__ import annotations + +from collections.abc import ( + Callable, + Hashable, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +from pandas.util._decorators import set_module + +if TYPE_CHECKING: + from pandas import ( + DataFrame, + Series, + ) + + +# Used only for generating the str repr of expressions. +_OP_SYMBOLS = { + "__add__": "+", + "__radd__": "+", + "__sub__": "-", + "__rsub__": "-", + "__mul__": "*", + "__rmul__": "*", + "__truediv__": "/", + "__rtruediv__": "/", + "__floordiv__": "//", + "__rfloordiv__": "//", + "__mod__": "%", + "__rmod__": "%", + "__ge__": ">=", + "__gt__": ">", + "__le__": "<=", + "__lt__": "<", + "__eq__": "==", + "__ne__": "!=", + "__and__": "&", + "__rand__": "&", + "__or__": "|", + "__ror__": "|", + "__xor__": "^", + "__rxor__": "^", +} + + +def _parse_args(df: DataFrame, *args: Any) -> tuple[Series]: + # Parse `args`, evaluating any expressions we encounter. + return tuple( + [x._eval_expression(df) if isinstance(x, Expression) else x for x in args] + ) + + +def _parse_kwargs(df: DataFrame, **kwargs: Any) -> dict[str, Any]: + # Parse `kwargs`, evaluating any expressions we encounter. + return { + key: val._eval_expression(df) if isinstance(val, Expression) else val + for key, val in kwargs.items() + } + + +def _pretty_print_args_kwargs(*args: Any, **kwargs: Any) -> str: + inputs_repr = ", ".join(repr(arg) for arg in args) + kwargs_repr = ", ".join(f"{k}={v!r}" for k, v in kwargs.items()) + + all_args = [] + if inputs_repr: + all_args.append(inputs_repr) + if kwargs_repr: + all_args.append(kwargs_repr) + + return ", ".join(all_args) + + +@set_module("pandas.api.typing") +class Expression: + """ + Class representing a deferred column. + + This is not meant to be instantiated directly. Instead, use :meth:`pandas.col`. + """ + + def __init__( + self, + func: Callable[[DataFrame], Any], + repr_str: str, + needs_parenthese: bool = False, + ) -> None: + self._func = func + self._repr_str = repr_str + self._needs_parentheses = needs_parenthese + + def _eval_expression(self, df: DataFrame) -> Any: + return self._func(df) + + def _with_op( + self, op: str, other: Any, repr_str: str, needs_parentheses: bool = True + ) -> Expression: + if isinstance(other, Expression): + return Expression( + lambda df: getattr(self._eval_expression(df), op)( + other._eval_expression(df) + ), + repr_str, + needs_parenthese=needs_parentheses, + ) + else: + return Expression( + lambda df: getattr(self._eval_expression(df), op)(other), + repr_str, + needs_parenthese=needs_parentheses, + ) + + def _maybe_wrap_parentheses(self, other: Any) -> tuple[str, str]: + if self._needs_parentheses: + self_repr = f"({self!r})" + else: + self_repr = f"{self!r}" + if isinstance(other, Expression) and other._needs_parentheses: + other_repr = f"({other!r})" + else: + other_repr = f"{other!r}" + return self_repr, other_repr + + # Binary ops + def __add__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__add__", other, f"{self_repr} + {other_repr}") + + def __radd__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__radd__", other, f"{other_repr} + {self_repr}") + + def __sub__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__sub__", other, f"{self_repr} - {other_repr}") + + def __rsub__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__rsub__", other, f"{other_repr} - {self_repr}") + + def __mul__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__mul__", other, f"{self_repr} * {other_repr}") + + def __rmul__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__rmul__", other, f"{other_repr} * {self_repr}") + + def __truediv__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__truediv__", other, f"{self_repr} / {other_repr}") + + def __rtruediv__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__rtruediv__", other, f"{other_repr} / {self_repr}") + + def __floordiv__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__floordiv__", other, f"{self_repr} // {other_repr}") + + def __rfloordiv__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__rfloordiv__", other, f"{other_repr} // {self_repr}") + + def __ge__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__ge__", other, f"{self_repr} >= {other_repr}") + + def __gt__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__gt__", other, f"{self_repr} > {other_repr}") + + def __le__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__le__", other, f"{self_repr} <= {other_repr}") + + def __lt__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__lt__", other, f"{self_repr} < {other_repr}") + + def __eq__(self, other: object) -> Expression: # type: ignore[override] + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__eq__", other, f"{self_repr} == {other_repr}") + + def __ne__(self, other: object) -> Expression: # type: ignore[override] + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__ne__", other, f"{self_repr} != {other_repr}") + + def __mod__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__mod__", other, f"{self_repr} % {other_repr}") + + def __rmod__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__rmod__", other, f"{other_repr} % {self_repr}") + + # Logical ops + def __and__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__and__", other, f"{self_repr} & {other_repr}") + + def __rand__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__rand__", other, f"{other_repr} & {self_repr}") + + def __or__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__or__", other, f"{self_repr} | {other_repr}") + + def __ror__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__ror__", other, f"{other_repr} | {self_repr}") + + def __xor__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__xor__", other, f"{self_repr} ^ {other_repr}") + + def __rxor__(self, other: Any) -> Expression: + self_repr, other_repr = self._maybe_wrap_parentheses(other) + return self._with_op("__rxor__", other, f"{other_repr} ^ {self_repr}") + + def __invert__(self) -> Expression: + return Expression( + lambda df: ~self._eval_expression(df), + f"~{self._repr_str}", + needs_parenthese=True, + ) + + def __neg__(self) -> Expression: + if self._needs_parentheses: + repr_str = f"-({self._repr_str})" + else: + repr_str = f"-{self._repr_str}" + return Expression( + lambda df: -self._eval_expression(df), + repr_str, + needs_parenthese=True, + ) + + def __pos__(self) -> Expression: + if self._needs_parentheses: + repr_str = f"+({self._repr_str})" + else: + repr_str = f"+{self._repr_str}" + return Expression( + lambda df: +self._eval_expression(df), + repr_str, + needs_parenthese=True, + ) + + def __abs__(self) -> Expression: + return Expression( + lambda df: abs(self._eval_expression(df)), + f"abs({self._repr_str})", + needs_parenthese=True, + ) + + def __array_ufunc__( + self, ufunc: Callable[..., Any], method: str, *inputs: Any, **kwargs: Any + ) -> Expression: + def func(df: DataFrame) -> Any: + parsed_inputs = _parse_args(df, *inputs) + parsed_kwargs = _parse_kwargs(df, *kwargs) + return ufunc(*parsed_inputs, **parsed_kwargs) + + args_str = _pretty_print_args_kwargs(*inputs, **kwargs) + repr_str = f"{ufunc.__name__}({args_str})" + + return Expression(func, repr_str) + + def __getitem__(self, item: Any) -> Expression: + return self._with_op( + "__getitem__", item, f"{self!r}[{item!r}]", needs_parentheses=True + ) + + def _call_with_func(self, func: Callable, **kwargs: Any) -> Expression: + def wrapped(df: DataFrame) -> Any: + parsed_kwargs = _parse_kwargs(df, **kwargs) + return func(**parsed_kwargs) + + args_str = _pretty_print_args_kwargs(**kwargs) + repr_str = func.__name__ + "(" + args_str + ")" + + return Expression(wrapped, repr_str) + + def __call__(self, *args: Any, **kwargs: Any) -> Expression: + def func(df: DataFrame, *args: Any, **kwargs: Any) -> Any: + parsed_args = _parse_args(df, *args) + parsed_kwargs = _parse_kwargs(df, **kwargs) + return self._eval_expression(df)(*parsed_args, **parsed_kwargs) + + args_str = _pretty_print_args_kwargs(*args, **kwargs) + repr_str = f"{self._repr_str}({args_str})" + return Expression(lambda df: func(df, *args, **kwargs), repr_str) + + def __getattr__(self, name: str, /) -> Any: + repr_str = f"{self!r}" + if self._needs_parentheses: + repr_str = f"({repr_str})" + repr_str += f".{name}" + return Expression(lambda df: getattr(self._eval_expression(df), name), repr_str) + + def __repr__(self) -> str: + return self._repr_str or "Expr(...)" + + +@set_module("pandas") +def col(col_name: Hashable) -> Expression: + """ + Generate deferred object representing a column of a DataFrame. + + Any place which accepts ``lambda df: df[col_name]``, such as + :meth:`DataFrame.assign` or :meth:`DataFrame.loc`, can also accept + ``pd.col(col_name)``. + + .. versionadded:: 3.0.0 + + Parameters + ---------- + col_name : Hashable + Column name. + + Returns + ------- + `pandas.api.typing.Expression` + A deferred object representing a column of a DataFrame. + + See Also + -------- + DataFrame.query : Query columns of a dataframe using string expressions. + + Examples + -------- + + You can use `col` in `assign`. + + >>> df = pd.DataFrame({"name": ["beluga", "narwhal"], "speed": [100, 110]}) + >>> df.assign(name_titlecase=pd.col("name").str.title()) + name speed name_titlecase + 0 beluga 100 Beluga + 1 narwhal 110 Narwhal + + You can also use it for filtering. + + >>> df.loc[pd.col("speed") > 105] + name speed + 1 narwhal 110 + """ + if not isinstance(col_name, Hashable): + msg = f"Expected Hashable, got: {type(col_name)}" + raise TypeError(msg) + + def func(df: DataFrame) -> Series: + if col_name not in df.columns: + columns_str = str(df.columns.tolist()) + max_len = 90 + if len(columns_str) > max_len: + columns_str = columns_str[:max_len] + "...]" + + msg = ( + f"Column '{col_name}' not found in given DataFrame.\n\n" + f"Hint: did you mean one of {columns_str} instead?" + ) + raise ValueError(msg) + return df[col_name] + + return Expression(func, f"col({col_name!r})") + + +__all__ = ["Expression", "col"] diff --git a/pandas/core/common.py b/pandas/core/common.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca6586222ca13845e7313eeaf51ee036e5f9f9d --- /dev/null +++ b/pandas/core/common.py @@ -0,0 +1,685 @@ +""" +Misc tools for implementing data structures + +Note: pandas.core.common is *not* part of the public API. +""" + +from __future__ import annotations + +import builtins +from collections import ( + abc, + defaultdict, +) +from collections.abc import ( + Callable, + Collection, + Generator, + Hashable, + Iterable, + Sequence, +) +import contextlib +from functools import partial +import inspect +import sys +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + TypeVar, + cast, + overload, +) + +import numpy as np + +from pandas._libs import lib + +from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike +from pandas.core.dtypes.common import ( + is_bool_dtype, + is_integer, +) +from pandas.core.dtypes.generic import ( + ABCExtensionArray, + ABCIndex, + ABCMultiIndex, + ABCNumpyExtensionArray, + ABCSeries, +) +from pandas.core.dtypes.inference import iterable_not_string + +from pandas.core.col import Expression + +if TYPE_CHECKING: + from pandas._typing import ( + AnyArrayLike, + ArrayLike, + NpDtype, + P, + RandomState, + T, + ) + + from pandas import Index + + +def flatten(line): + """ + Flatten an arbitrarily nested sequence. + + Parameters + ---------- + line : sequence + The non string sequence to flatten + + Notes + ----- + This doesn't consider strings sequences. + + Returns + ------- + flattened : generator + """ + for element in line: + if iterable_not_string(element): + yield from flatten(element) + else: + yield element + + +def consensus_name_attr(objs): + name = objs[0].name + for obj in objs[1:]: + try: + if obj.name != name: + name = None + break + except ValueError: + name = None + break + return name + + +def is_bool_indexer(key: Any) -> bool: + """ + Check whether `key` is a valid boolean indexer. + + Parameters + ---------- + key : Any + Only list-likes may be considered boolean indexers. + All other types are not considered a boolean indexer. + For array-like input, boolean ndarrays or ExtensionArrays + with ``_is_boolean`` set are considered boolean indexers. + + Returns + ------- + bool + Whether `key` is a valid boolean indexer. + + Raises + ------ + ValueError + When the array is an object-dtype ndarray or ExtensionArray + and contains missing values. + + See Also + -------- + check_array_indexer : Check that `key` is a valid array to index, + and convert to an ndarray. + """ + if isinstance( + key, + (ABCSeries, np.ndarray, ABCIndex, ABCExtensionArray, ABCNumpyExtensionArray), + ) and not isinstance(key, ABCMultiIndex): + if key.dtype == np.object_: + key_array = np.asarray(key) + + if not lib.is_bool_array(key_array): + na_msg = "Cannot mask with non-boolean array containing NA / NaN values" + if lib.is_bool_array(key_array, skipna=True): + # Don't raise on e.g. ["A", "B", np.nan], see + # test_loc_getitem_list_of_labels_categoricalindex_with_na + raise ValueError(na_msg) + return False + return True + elif is_bool_dtype(key.dtype): + return True + elif isinstance(key, list): + # check if np.array(key).dtype would be bool + if len(key) > 0: + if type(key) is not list: + # GH#42461 cython will raise TypeError if we pass a subclass + key = list(key) + return lib.is_bool_list(key) + + return False + + +def cast_scalar_indexer(val): + """ + Disallow indexing with a float key, even if that key is a round number. + + Parameters + ---------- + val : scalar + + Returns + ------- + outval : scalar + """ + # assumes lib.is_scalar(val) + if lib.is_float(val) and val.is_integer(): + raise IndexError( + # GH#34193 + "Indexing with a float is no longer supported. Manually convert " + "to an integer key instead." + ) + return val + + +def not_none(*args): + """ + Returns a generator consisting of the arguments that are not None. + """ + return (arg for arg in args if arg is not None) + + +def any_none(*args) -> bool: + """ + Returns a boolean indicating if any argument is None. + """ + return any(arg is None for arg in args) + + +def all_none(*args) -> bool: + """ + Returns a boolean indicating if all arguments are None. + """ + return all(arg is None for arg in args) + + +def any_not_none(*args) -> bool: + """ + Returns a boolean indicating if any argument is not None. + """ + return any(arg is not None for arg in args) + + +def all_not_none(*args) -> bool: + """ + Returns a boolean indicating if all arguments are not None. + """ + return all(arg is not None for arg in args) + + +def count_not_none(*args) -> int: + """ + Returns the count of arguments that are not None. + """ + return sum(x is not None for x in args) + + +@overload +def asarray_tuplesafe( + values: ArrayLike | list | tuple | zip, dtype: NpDtype | None = ... +) -> np.ndarray: + # ExtensionArray can only be returned when values is an Index, all other iterables + # will return np.ndarray. Unfortunately "all other" cannot be encoded in a type + # signature, so instead we special-case some common types. + ... + + +@overload +def asarray_tuplesafe(values: Iterable, dtype: NpDtype | None = ...) -> ArrayLike: ... + + +def asarray_tuplesafe(values: Iterable, dtype: NpDtype | None = None) -> ArrayLike: + if not (isinstance(values, (list, tuple)) or hasattr(values, "__array__")): + values = list(values) + elif isinstance(values, ABCIndex): + return values._values + elif isinstance(values, ABCSeries): + return values._values + + if isinstance(values, list) and dtype in [np.object_, object]: + return construct_1d_object_array_from_listlike(values) + + try: + result = np.asarray(values, dtype=dtype) + except ValueError: + # Using try/except since it's more performant than checking is_list_like + # over each element + # error: Argument 1 to "construct_1d_object_array_from_listlike" + # has incompatible type "Iterable[Any]"; expected "Sized" + return construct_1d_object_array_from_listlike(values) # type: ignore[arg-type] + + if issubclass(result.dtype.type, str): + result = np.asarray(values, dtype=object) + + if result.ndim == 2: + # Avoid building an array of arrays: + values = [tuple(x) for x in values] + result = construct_1d_object_array_from_listlike(values) + + return result + + +def index_labels_to_array( + labels: np.ndarray | Iterable, dtype: NpDtype | None = None +) -> np.ndarray: + """ + Transform label or iterable of labels to array, for use in Index. + + Parameters + ---------- + dtype : dtype + If specified, use as dtype of the resulting array, otherwise infer. + + Returns + ------- + array + """ + if isinstance(labels, (str, tuple)): + labels = [labels] + + if not isinstance(labels, (list, np.ndarray)): + try: + labels = list(labels) + except TypeError: # non-iterable + labels = [labels] + + rlabels = asarray_tuplesafe(labels, dtype=dtype) + + return rlabels + + +def maybe_make_list(obj): + if obj is not None and not isinstance(obj, (tuple, list)): + return [obj] + return obj + + +def maybe_iterable_to_list(obj: Iterable[T] | T) -> Collection[T] | T: + """ + If obj is Iterable but not list-like, consume into list. + """ + if isinstance(obj, abc.Iterable) and not isinstance(obj, abc.Sized): + return list(obj) + obj = cast(Collection, obj) + return obj + + +def is_null_slice(obj) -> bool: + """ + We have a null slice. + """ + return ( + isinstance(obj, slice) + and obj.start is None + and obj.stop is None + and obj.step is None + ) + + +def is_empty_slice(obj) -> bool: + """ + We have an empty slice, e.g. no values are selected. + """ + return ( + isinstance(obj, slice) + and obj.start is not None + and obj.stop is not None + and obj.start == obj.stop + ) + + +def is_true_slices(line: abc.Iterable) -> abc.Generator[bool, None, None]: + """ + Find non-trivial slices in "line": yields a bool. + """ + for k in line: + yield isinstance(k, slice) and not is_null_slice(k) + + +# TODO: used only once in indexing; belongs elsewhere? +def is_full_slice(obj, line: int) -> bool: + """ + We have a full length slice. + """ + return ( + isinstance(obj, slice) + and obj.start == 0 + and obj.stop == line + and obj.step is None + ) + + +def get_callable_name(obj): + # typical case has name + if hasattr(obj, "__name__"): + return obj.__name__ + # some objects don't; could recurse + if isinstance(obj, partial): + return get_callable_name(obj.func) + # fall back to class name + if callable(obj): + return type(obj).__name__ + # everything failed (probably because the argument + # wasn't actually callable); we return None + # instead of the empty string in this case to allow + # distinguishing between no name and a name of '' + return None + + +def apply_if_callable(maybe_callable, obj, **kwargs): + """ + Evaluate possibly callable input using obj and kwargs if it is callable, + otherwise return as it is. + + Parameters + ---------- + maybe_callable : possibly a callable + obj : NDFrame + **kwargs + """ + if isinstance(maybe_callable, Expression): + return maybe_callable._eval_expression(obj, **kwargs) + elif callable(maybe_callable): + return maybe_callable(obj, **kwargs) + + return maybe_callable + + +def standardize_mapping(into): + """ + Helper function to standardize a supplied mapping. + + Parameters + ---------- + into : instance or subclass of collections.abc.Mapping + Must be a class, an initialized collections.defaultdict, + or an instance of a collections.abc.Mapping subclass. + + Returns + ------- + mapping : a collections.abc.Mapping subclass or other constructor + a callable object that can accept an iterator to create + the desired Mapping. + + See Also + -------- + DataFrame.to_dict + Series.to_dict + """ + if not inspect.isclass(into): + if isinstance(into, defaultdict): + return partial(defaultdict, into.default_factory) + into = type(into) + if not issubclass(into, abc.Mapping): + raise TypeError(f"unsupported type: {into}") + if into == defaultdict: + raise TypeError("to_dict() only accepts initialized defaultdicts") + return into + + +@overload +def random_state(state: np.random.Generator) -> np.random.Generator: ... + + +@overload +def random_state( + state: int | np.ndarray | np.random.BitGenerator | np.random.RandomState | None, +) -> np.random.RandomState: ... + + +def random_state(state: RandomState | None = None): + """ + Helper function for processing random_state arguments. + + Parameters + ---------- + state : int, array-like, BitGenerator, Generator, np.random.RandomState, None. + If receives an int, array-like, or BitGenerator, passes to + np.random.RandomState() as seed. + If receives an np.random RandomState or Generator, just returns that unchanged. + If receives `None`, returns np.random. + If receives anything else, raises an informative ValueError. + + Default None. + + Returns + ------- + np.random.RandomState or np.random.Generator. If state is None, returns np.random + + """ + if is_integer(state) or isinstance(state, (np.ndarray, np.random.BitGenerator)): + return np.random.RandomState(state) + elif isinstance(state, np.random.RandomState): + return state + elif isinstance(state, np.random.Generator): + return state + elif state is None: + return np.random + else: + raise ValueError( + "random_state must be an integer, array-like, a BitGenerator, Generator, " + "a numpy RandomState, or None" + ) + + +_T = TypeVar("_T") # Secondary TypeVar for use in pipe's type hints + + +@overload +def pipe( + obj: _T, + func: Callable[Concatenate[_T, P], T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: ... + + +@overload +def pipe( + obj: Any, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, +) -> T: ... + + +def pipe( + obj: _T, + func: Callable[Concatenate[_T, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, +) -> T: + """ + Apply a function ``func`` to object ``obj`` either by passing obj as the + first argument to the function or, in the case that the func is a tuple, + interpret the first element of the tuple as a function and pass the obj to + that function as a keyword argument whose key is the value of the second + element of the tuple. + + Parameters + ---------- + func : callable or tuple of (callable, str) + Function to apply to this object or, alternatively, a + ``(callable, data_keyword)`` tuple where ``data_keyword`` is a + string indicating the keyword of ``callable`` that expects the + object. + *args : iterable, optional + Positional arguments passed into ``func``. + **kwargs : dict, optional + A dictionary of keyword arguments passed into ``func``. + + Returns + ------- + object : the return type of ``func``. + """ + if isinstance(func, tuple): + # Assigning to func_ so pyright understands that it's a callable + func_, target = func + if target in kwargs: + msg = f"{target} is both the pipe target and a keyword argument" + raise ValueError(msg) + kwargs[target] = obj + return func_(*args, **kwargs) + else: + return func(obj, *args, **kwargs) + + +def get_rename_function(mapper): + """ + Returns a function that will map names/labels, dependent if mapper + is a dict, Series or just a function. + """ + + def f(x): + if x in mapper: + return mapper[x] + else: + return x + + return f if isinstance(mapper, (abc.Mapping, ABCSeries)) else mapper + + +def convert_to_list_like( + values: Hashable | Iterable | AnyArrayLike, +) -> list | AnyArrayLike: + """ + Convert list-like or scalar input to list-like. List, numpy and pandas array-like + inputs are returned unmodified whereas others are converted to list. + """ + if isinstance(values, (list, np.ndarray, ABCIndex, ABCSeries, ABCExtensionArray)): + return values + elif isinstance(values, abc.Iterable) and not isinstance(values, str): + return list(values) + + return [values] + + +@contextlib.contextmanager +def temp_setattr(obj, attr: str, value, condition: bool = True) -> Generator[None]: + """ + Temporarily set attribute on an object. + + Parameters + ---------- + obj : object + Object whose attribute will be modified. + attr : str + Attribute to modify. + value : Any + Value to temporarily set attribute to. + condition : bool, default True + Whether to set the attribute. Provided in order to not have to + conditionally use this context manager. + + Yields + ------ + object : obj with modified attribute. + """ + if condition: + old_value = getattr(obj, attr) + setattr(obj, attr, value) + try: + yield obj + finally: + if condition: + setattr(obj, attr, old_value) + + +def require_length_match(data, index: Index) -> None: + """ + Check the length of data matches the length of the index. + """ + if len(data) != len(index): + raise ValueError( + "Length of values " + f"({len(data)}) " + "does not match length of index " + f"({len(index)})" + ) + + +_cython_table = { + builtins.sum: "sum", + builtins.max: "max", + builtins.min: "min", + np.all: "all", + np.any: "any", + np.sum: "sum", + np.nansum: "sum", + np.mean: "mean", + np.nanmean: "mean", + np.prod: "prod", + np.nanprod: "prod", + np.std: "std", + np.nanstd: "std", + np.var: "var", + np.nanvar: "var", + np.median: "median", + np.nanmedian: "median", + np.max: "max", + np.nanmax: "max", + np.min: "min", + np.nanmin: "min", + np.cumprod: "cumprod", + np.nancumprod: "cumprod", + np.cumsum: "cumsum", + np.nancumsum: "cumsum", +} + + +def get_cython_func(arg: Callable) -> str | None: + """ + if we define an internal function for this argument, return it + """ + return _cython_table.get(arg) + + +def fill_missing_names(names: Sequence[Hashable | None]) -> list[Hashable]: + """ + If a name is missing then replace it by level_n, where n is the count + + Parameters + ---------- + names : list-like + list of column names or None values. + + Returns + ------- + list + list of column names with the None values replaced. + """ + return [f"level_{i}" if name is None else name for i, name in enumerate(names)] + + +def is_local_in_caller_frame(obj): + """ + Helper function used in detecting chained assignment. + + If the pandas object (DataFrame/Series) is a local variable + in the caller's frame, it should not be a case of chained + assignment or method call. + + For example: + + def test(): + df = pd.DataFrame(...) + df["a"] = 1 # not chained assignment + + Inside ``df.__setitem__``, we call this function to check whether `df` + (`self`) is a local variable in `test` frame (the frame calling setitem). If + so, we know it is not a case of chained assignment (even when the refcount + of `df` is below the threshold due to optimization of local variables). + """ + frame = sys._getframe(2) + for v in frame.f_locals.values(): + if v is obj: + return True + return False diff --git a/pandas/core/config_init.py b/pandas/core/config_init.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb7e1b9fff0ae92ca57f7bd7d669217f7d02cd4 --- /dev/null +++ b/pandas/core/config_init.py @@ -0,0 +1,923 @@ +""" +This module is imported from the pandas package __init__.py file +in order to ensure that the core.config options registered here will +be available as soon as the user loads the package. if register_option +is invoked inside specific modules, they will not be registered until that +module is imported, which may or may not be a problem. + +If you need to make sure options are available even before a certain +module is imported, register them here rather than in the module. + +""" + +from __future__ import annotations + +from collections.abc import Callable +import os +from typing import Any + +import pandas._config.config as cf +from pandas._config.config import ( + is_bool, + is_callable, + is_instance_factory, + is_int, + is_nonnegative_int, + is_one_of_factory, + is_str, + is_text, +) + +from pandas.errors import Pandas4Warning + +# compute + +use_bottleneck_doc = """ +: bool + Use the bottleneck library to accelerate if it is installed, + the default is True + Valid values: False,True +""" + + +def use_bottleneck_cb(key: str) -> None: + from pandas.core import nanops + + nanops.set_use_bottleneck(cf.get_option(key)) + + +use_numexpr_doc = """ +: bool + Use the numexpr library to accelerate computation if it is installed, + the default is True + Valid values: False,True +""" + + +def use_numexpr_cb(key: str) -> None: + from pandas.core.computation import expressions + + expressions.set_use_numexpr(cf.get_option(key)) + + +use_numba_doc = """ +: bool + Use the numba engine option for select operations if it is installed, + the default is False + Valid values: False,True +""" + + +def use_numba_cb(key: str) -> None: + from pandas.core.util import numba_ + + numba_.set_use_numba(cf.get_option(key)) + + +with cf.config_prefix("compute"): + cf.register_option( + "use_bottleneck", + True, + use_bottleneck_doc, + validator=is_bool, + cb=use_bottleneck_cb, + ) + cf.register_option( + "use_numexpr", True, use_numexpr_doc, validator=is_bool, cb=use_numexpr_cb + ) + cf.register_option( + "use_numba", False, use_numba_doc, validator=is_bool, cb=use_numba_cb + ) +# +# options from the "display" namespace + +pc_precision_doc = """ +: int + Floating point output precision in terms of number of places after the + decimal, for regular formatting as well as scientific notation. Similar + to ``precision`` in :meth:`numpy.set_printoptions`. +""" + +pc_max_rows_doc = """ +: int + If max_rows is exceeded, switch to truncate view. Depending on + `large_repr`, objects are either centrally truncated or printed as + a summary view. + + 'None' value means unlimited. Beware that printing a large number of rows + could cause your rendering environment (the browser, etc.) to crash. + + In case python/IPython is running in a terminal and `large_repr` + equals 'truncate' this can be set to 0 and pandas will auto-detect + the height of the terminal and print a truncated object which fits + the screen height. The IPython notebook, IPython qtconsole, or + IDLE do not run in a terminal and hence it is not possible to do + correct auto-detection. +""" + +pc_min_rows_doc = """ +: int + The numbers of rows to show in a truncated view (when `max_rows` is + exceeded). Ignored when `max_rows` is set to None or 0. When set to + None, follows the value of `max_rows`. +""" + +pc_max_cols_doc = """ +: int + If max_cols is exceeded, switch to truncate view. Depending on + `large_repr`, objects are either centrally truncated or printed as + a summary view. + + 'None' value means unlimited. Beware that printing a large number of + columns could cause your rendering environment (the browser, etc.) to + crash. + + In case python/IPython is running in a terminal and `large_repr` + equals 'truncate' this can be set to 0 or None and pandas will auto-detect + the width of the terminal and print a truncated object which fits + the screen width. The IPython notebook, IPython qtconsole, or IDLE + do not run in a terminal and hence it is not possible to do + correct auto-detection and defaults to 20. +""" + +pc_max_categories_doc = """ +: int + This sets the maximum number of categories pandas should output when + printing out a `Categorical` or a Series of dtype "category". +""" + +pc_max_info_cols_doc = """ +: int + max_info_columns is used in DataFrame.info method to decide if + per column information will be printed. +""" + +pc_nb_repr_h_doc = """ +: boolean + When True, IPython notebook will use html representation for + pandas objects (if it is available). +""" + +pc_pprint_nest_depth = """ +: int + Controls the number of nested levels to process when pretty-printing +""" + +pc_multi_sparse_doc = """ +: boolean + "sparsify" MultiIndex display (don't display repeated + elements in outer levels within groups) +""" + +float_format_doc = """ +: callable + The callable should accept a floating point number and return + a string with the desired format of the number. This is used + in some places like SeriesFormatter. + See formats.format.EngFormatter for an example. +""" + +max_colwidth_doc = """ +: int or None + The maximum width in characters of a column in the repr of + a pandas data structure. When the column overflows, a "..." + placeholder is embedded in the output. A 'None' value means unlimited. +""" + +colheader_justify_doc = """ +: 'left'/'right' + Controls the justification of column headers. used by DataFrameFormatter. +""" + +pc_expand_repr_doc = """ +: boolean + Whether to print out the full DataFrame repr for wide DataFrames across + multiple lines, `max_columns` is still respected, but the output will + wrap-around across multiple "pages" if its width exceeds `display.width`. +""" + +pc_show_dimensions_doc = """ +: boolean or 'truncate' + Whether to print out dimensions at the end of DataFrame repr. + If 'truncate' is specified, only print out the dimensions if the + frame is truncated (e.g. not display all rows and/or columns) +""" + +pc_east_asian_width_doc = """ +: boolean + Whether to use the Unicode East Asian Width to calculate the display text + width. + Enabling this may affect to the performance (default: False) +""" + + +pc_table_schema_doc = """ +: boolean + Whether to publish a Table Schema representation for frontends + that support it. + (default: False) +""" + +pc_html_border_doc = """ +: int + A ``border=value`` attribute is inserted in the ```` tag + for the DataFrame HTML repr. +""" + +pc_html_use_mathjax_doc = """\ +: boolean + When True, Jupyter notebook will process table contents using MathJax, + rendering mathematical expressions enclosed by the dollar symbol. + (default: True) +""" + +pc_max_dir_items = """\ +: int + The number of items that will be added to `dir(...)`. 'None' value means + unlimited. Because dir is cached, changing this option will not immediately + affect already existing dataframes until a column is deleted or added. + + This is for instance used to suggest columns from a dataframe to tab + completion. +""" + +pc_width_doc = """ +: int + Width of the display in characters. In case python/IPython is running in + a terminal this can be set to None and pandas will correctly auto-detect + the width. + Note that the IPython notebook, IPython qtconsole, or IDLE do not run in a + terminal and hence it is not possible to correctly detect the width. +""" + +pc_chop_threshold_doc = """ +: float or None + if set to a float value, all float values smaller than the given threshold + will be displayed as exactly 0 by repr and friends. +""" + +pc_max_seq_items = """ +: int or None + When pretty-printing a long sequence, no more then `max_seq_items` + will be printed. If items are omitted, they will be denoted by the + addition of "..." to the resulting string. + + If set to None, the number of items to be printed is unlimited. +""" + +pc_max_info_rows_doc = """ +: int + df.info() will usually show null-counts for each column. + For large frames this can be quite slow. max_info_rows and max_info_cols + limit this null check only to frames with smaller dimensions than + specified. +""" + +pc_large_repr_doc = """ +: 'truncate'/'info' + For DataFrames exceeding max_rows/max_cols, the repr (and HTML repr) can + show a truncated table, or switch to the view from + df.info() (the behaviour in earlier versions of pandas). +""" + +pc_memory_usage_doc = """ +: bool, string or None + This specifies if the memory usage of a DataFrame should be displayed when + df.info() is called. Valid values True,False,'deep' +""" + + +def table_schema_cb(key: str) -> None: + from pandas.io.formats.printing import enable_data_resource_formatter + + enable_data_resource_formatter(cf.get_option(key)) + + +def is_terminal() -> bool: + """ + Detect if Python is running in a terminal. + + Returns True if Python is running in a terminal or False if not. + """ + try: + # error: Name 'get_ipython' is not defined + ip = get_ipython() # type: ignore[name-defined] + except NameError: # assume standard Python interpreter in a terminal + return True + else: + if hasattr(ip, "kernel"): # IPython as a Jupyter kernel + return False + else: # IPython in a terminal + return True + + +with cf.config_prefix("display"): + cf.register_option("precision", 6, pc_precision_doc, validator=is_nonnegative_int) + cf.register_option( + "float_format", + None, + float_format_doc, + validator=is_one_of_factory([None, is_callable]), + ) + cf.register_option( + "max_info_rows", + 1690785, + pc_max_info_rows_doc, + validator=is_int, + ) + cf.register_option("max_rows", 60, pc_max_rows_doc, validator=is_nonnegative_int) + cf.register_option( + "min_rows", + 10, + pc_min_rows_doc, + validator=is_instance_factory((type(None), int)), + ) + cf.register_option("max_categories", 8, pc_max_categories_doc, validator=is_int) + + cf.register_option( + "max_colwidth", + 50, + max_colwidth_doc, + validator=is_nonnegative_int, + ) + if is_terminal(): + max_cols = 0 # automatically determine optimal number of columns + else: + max_cols = 20 # cannot determine optimal number of columns + cf.register_option( + "max_columns", max_cols, pc_max_cols_doc, validator=is_nonnegative_int + ) + cf.register_option( + "large_repr", + "truncate", + pc_large_repr_doc, + validator=is_one_of_factory(["truncate", "info"]), + ) + cf.register_option("max_info_columns", 100, pc_max_info_cols_doc, validator=is_int) + cf.register_option( + "colheader_justify", "right", colheader_justify_doc, validator=is_text + ) + cf.register_option("notebook_repr_html", True, pc_nb_repr_h_doc, validator=is_bool) + cf.register_option("pprint_nest_depth", 3, pc_pprint_nest_depth, validator=is_int) + cf.register_option("multi_sparse", True, pc_multi_sparse_doc, validator=is_bool) + cf.register_option("expand_frame_repr", True, pc_expand_repr_doc) + cf.register_option( + "show_dimensions", + "truncate", + pc_show_dimensions_doc, + validator=is_one_of_factory([True, False, "truncate"]), + ) + cf.register_option("chop_threshold", None, pc_chop_threshold_doc) + cf.register_option("max_seq_items", 100, pc_max_seq_items) + cf.register_option( + "width", 80, pc_width_doc, validator=is_instance_factory((type(None), int)) + ) + cf.register_option( + "memory_usage", + True, + pc_memory_usage_doc, + validator=is_one_of_factory([None, True, False, "deep"]), + ) + cf.register_option( + "unicode.east_asian_width", False, pc_east_asian_width_doc, validator=is_bool + ) + cf.register_option( + "unicode.ambiguous_as_wide", False, pc_east_asian_width_doc, validator=is_bool + ) + cf.register_option( + "html.table_schema", + False, + pc_table_schema_doc, + validator=is_bool, + cb=table_schema_cb, + ) + cf.register_option("html.border", 1, pc_html_border_doc, validator=is_int) + cf.register_option( + "html.use_mathjax", True, pc_html_use_mathjax_doc, validator=is_bool + ) + cf.register_option( + "max_dir_items", 100, pc_max_dir_items, validator=is_nonnegative_int + ) + +tc_sim_interactive_doc = """ +: boolean + Whether to simulate interactive mode for purposes of testing +""" + +with cf.config_prefix("mode"): + cf.register_option("sim_interactive", False, tc_sim_interactive_doc) + + +copy_on_write_doc = """ +: bool + Use new copy-view behaviour using Copy-on-Write. No longer used, + pandas now always uses Copy-on-Write behavior. This option will + be removed in pandas 4.0. +""" + + +with cf.config_prefix("mode"): + cf.register_option( + "copy_on_write", + # Get the default from an environment variable, if set, otherwise defaults + # to False. This environment variable can be set for testing. + "warn" + if os.environ.get("PANDAS_COPY_ON_WRITE", "0") == "warn" + else os.environ.get("PANDAS_COPY_ON_WRITE", "1") == "1", + copy_on_write_doc, + validator=is_one_of_factory([True, False, "warn"]), + ) + + +# user warnings +chained_assignment = """ +: string + Raise an exception, warn, or no action if trying to use chained assignment, + The default is warn +""" + +with cf.config_prefix("mode"): + cf.register_option( + "chained_assignment", + "warn", + chained_assignment, + validator=is_one_of_factory([None, "warn", "raise"]), + ) + +performance_warnings = """ +: boolean + Whether to show or hide PerformanceWarnings. +""" + +with cf.config_prefix("mode"): + cf.register_option( + "performance_warnings", + True, + performance_warnings, + validator=is_bool, + ) + + +string_storage_doc = """ +: string + The default storage for StringDtype. +""" + + +def is_valid_string_storage(value: Any) -> None: + legal_values = ["auto", "python", "pyarrow"] + if value not in legal_values: + msg = "Value must be one of python|pyarrow" + raise ValueError(msg) + + +with cf.config_prefix("mode"): + cf.register_option( + "string_storage", + "auto", + string_storage_doc, + # validator=is_one_of_factory(["python", "pyarrow"]), + validator=is_valid_string_storage, + ) + + +# Set up the io.excel specific reader configuration. +reader_engine_doc = """ +: string + The default Excel reader engine for '{ext}' files. Available options: + auto, {others}. +""" + +_xls_options = ["xlrd", "calamine"] +_xlsm_options = ["xlrd", "openpyxl", "calamine"] +_xlsx_options = ["xlrd", "openpyxl", "calamine"] +_ods_options = ["odf", "calamine"] +_xlsb_options = ["pyxlsb", "calamine"] + + +with cf.config_prefix("io.excel.xls"): + cf.register_option( + "reader", + "auto", + reader_engine_doc.format(ext="xls", others=", ".join(_xls_options)), + validator=is_one_of_factory([*_xls_options, "auto"]), + ) + +with cf.config_prefix("io.excel.xlsm"): + cf.register_option( + "reader", + "auto", + reader_engine_doc.format(ext="xlsm", others=", ".join(_xlsm_options)), + validator=is_one_of_factory([*_xlsm_options, "auto"]), + ) + + +with cf.config_prefix("io.excel.xlsx"): + cf.register_option( + "reader", + "auto", + reader_engine_doc.format(ext="xlsx", others=", ".join(_xlsx_options)), + validator=is_one_of_factory([*_xlsx_options, "auto"]), + ) + + +with cf.config_prefix("io.excel.ods"): + cf.register_option( + "reader", + "auto", + reader_engine_doc.format(ext="ods", others=", ".join(_ods_options)), + validator=is_one_of_factory([*_ods_options, "auto"]), + ) + +with cf.config_prefix("io.excel.xlsb"): + cf.register_option( + "reader", + "auto", + reader_engine_doc.format(ext="xlsb", others=", ".join(_xlsb_options)), + validator=is_one_of_factory([*_xlsb_options, "auto"]), + ) + +# Set up the io.excel specific writer configuration. +writer_engine_doc = """ +: string + The default Excel writer engine for '{ext}' files. Available options: + auto, {others}. +""" + +_xlsm_options = ["openpyxl"] +_xlsx_options = ["openpyxl", "xlsxwriter"] +_ods_options = ["odf"] + + +with cf.config_prefix("io.excel.xlsm"): + cf.register_option( + "writer", + "auto", + writer_engine_doc.format(ext="xlsm", others=", ".join(_xlsm_options)), + validator=str, + ) + + +with cf.config_prefix("io.excel.xlsx"): + cf.register_option( + "writer", + "auto", + writer_engine_doc.format(ext="xlsx", others=", ".join(_xlsx_options)), + validator=str, + ) + + +with cf.config_prefix("io.excel.ods"): + cf.register_option( + "writer", + "auto", + writer_engine_doc.format(ext="ods", others=", ".join(_ods_options)), + validator=str, + ) + + +# Set up the io.parquet specific configuration. +parquet_engine_doc = """ +: string + The default parquet reader/writer engine. Available options: + 'auto', 'pyarrow', 'fastparquet', the default is 'auto' +""" + +with cf.config_prefix("io.parquet"): + cf.register_option( + "engine", + "auto", + parquet_engine_doc, + validator=is_one_of_factory(["auto", "pyarrow", "fastparquet"]), + ) + + +# Set up the io.sql specific configuration. +sql_engine_doc = """ +: string + The default sql reader/writer engine. Available options: + 'auto', 'sqlalchemy', the default is 'auto' +""" + +with cf.config_prefix("io.sql"): + cf.register_option( + "engine", + "auto", + sql_engine_doc, + validator=is_one_of_factory(["auto", "sqlalchemy"]), + ) + +# -------- +# Plotting +# --------- + +plotting_backend_doc = """ +: str + The plotting backend to use. The default value is "matplotlib", the + backend provided with pandas. Other backends can be specified by + providing the name of the module that implements the backend. +""" + + +def register_plotting_backend_cb(key: str | None) -> None: + if key == "matplotlib": + # We defer matplotlib validation, since it's the default + return + from pandas.plotting._core import _get_plot_backend + + _get_plot_backend(key) + + +with cf.config_prefix("plotting"): + cf.register_option( + "backend", + defval="matplotlib", + doc=plotting_backend_doc, + validator=register_plotting_backend_cb, # type: ignore[arg-type] + ) + + +register_converter_doc = """ +: bool or 'auto'. + Whether to register converters with matplotlib's units registry for + dates, times, datetimes, and Periods. Toggling to False will remove + the converters, restoring any converters that pandas overwrote. +""" + + +def register_converter_cb(key: str) -> None: + from pandas.plotting import ( + deregister_matplotlib_converters, + register_matplotlib_converters, + ) + + if cf.get_option(key): + register_matplotlib_converters() + else: + deregister_matplotlib_converters() + + +with cf.config_prefix("plotting.matplotlib"): + cf.register_option( + "register_converters", + "auto", + register_converter_doc, + validator=is_one_of_factory(["auto", True, False]), + cb=register_converter_cb, + ) + +# ------ +# Styler +# ------ + +styler_sparse_index_doc = """ +: bool + Whether to sparsify the display of a hierarchical index. Setting to False will + display each explicit level element in a hierarchical key for each row. +""" + +styler_sparse_columns_doc = """ +: bool + Whether to sparsify the display of hierarchical columns. Setting to False will + display each explicit level element in a hierarchical key for each column. +""" + +styler_render_repr = """ +: str + Determine which output to use in Jupyter Notebook in {"html", "latex"}. +""" + +styler_max_elements = """ +: int + The maximum number of data-cell (
) elements that will be rendered before + trimming will occur over columns, rows or both if needed. +""" + +styler_max_rows = """ +: int, optional + The maximum number of rows that will be rendered. May still be reduced to + satisfy ``max_elements``, which takes precedence. +""" + +styler_max_columns = """ +: int, optional + The maximum number of columns that will be rendered. May still be reduced to + satisfy ``max_elements``, which takes precedence. +""" + +styler_precision = """ +: int + The precision for floats and complex numbers. +""" + +styler_decimal = """ +: str + The character representation for the decimal separator for floats and complex. +""" + +styler_thousands = """ +: str, optional + The character representation for thousands separator for floats, int and complex. +""" + +styler_na_rep = """ +: str, optional + The string representation for values identified as missing. +""" + +styler_escape = """ +: str, optional + Whether to escape certain characters according to the given context; html or latex. +""" + +styler_formatter = """ +: str, callable, dict, optional + A formatter object to be used as default within ``Styler.format``. +""" + +styler_multirow_align = """ +: {"c", "t", "b"} + The specifier for vertical alignment of sparsified LaTeX multirows. +""" + +styler_multicol_align = r""" +: {"r", "c", "l", "naive-l", "naive-r"} + The specifier for horizontal alignment of sparsified LaTeX multicolumns. Pipe + decorators can also be added to non-naive values to draw vertical + rules, e.g. "\|r" will draw a rule on the left side of right aligned merged cells. +""" + +styler_hrules = """ +: bool + Whether to add horizontal rules on top and bottom and below the headers. +""" + +styler_environment = """ +: str + The environment to replace ``\\begin{table}``. If "longtable" is used results + in a specific longtable environment format. +""" + +styler_encoding = """ +: str + The encoding used for output HTML and LaTeX files. +""" + +styler_mathjax = """ +: bool + If False will render special CSS classes to table attributes that indicate Mathjax + will not be used in Jupyter Notebook. +""" + +with cf.config_prefix("styler"): + cf.register_option("sparse.index", True, styler_sparse_index_doc, validator=is_bool) + + cf.register_option( + "sparse.columns", True, styler_sparse_columns_doc, validator=is_bool + ) + + cf.register_option( + "render.repr", + "html", + styler_render_repr, + validator=is_one_of_factory(["html", "latex"]), + ) + + cf.register_option( + "render.max_elements", + 2**18, + styler_max_elements, + validator=is_nonnegative_int, + ) + + cf.register_option( + "render.max_rows", + None, + styler_max_rows, + validator=is_nonnegative_int, + ) + + cf.register_option( + "render.max_columns", + None, + styler_max_columns, + validator=is_nonnegative_int, + ) + + cf.register_option("render.encoding", "utf-8", styler_encoding, validator=is_str) + + cf.register_option("format.decimal", ".", styler_decimal, validator=is_str) + + cf.register_option( + "format.precision", 6, styler_precision, validator=is_nonnegative_int + ) + + cf.register_option( + "format.thousands", + None, + styler_thousands, + validator=is_instance_factory((type(None), str)), + ) + + cf.register_option( + "format.na_rep", + None, + styler_na_rep, + validator=is_instance_factory((type(None), str)), + ) + + cf.register_option( + "format.escape", + None, + styler_escape, + validator=is_one_of_factory([None, "html", "latex", "latex-math"]), + ) + + # error: Argument 1 to "is_instance_factory" has incompatible type "tuple[ + # ..., , ...]"; expected "type | tuple[type, ...]" + cf.register_option( + "format.formatter", + None, + styler_formatter, + validator=is_instance_factory( + (type(None), dict, Callable, str) # type: ignore[arg-type] + ), + ) + + cf.register_option("html.mathjax", True, styler_mathjax, validator=is_bool) + + cf.register_option( + "latex.multirow_align", + "c", + styler_multirow_align, + validator=is_one_of_factory(["c", "t", "b", "naive"]), + ) + + val_mca = ["r", "|r|", "|r", "r|", "c", "|c|", "|c", "c|", "l", "|l|", "|l", "l|"] + val_mca += ["naive-l", "naive-r"] + cf.register_option( + "latex.multicol_align", + "r", + styler_multicol_align, + validator=is_one_of_factory(val_mca), + ) + + cf.register_option("latex.hrules", False, styler_hrules, validator=is_bool) + + cf.register_option( + "latex.environment", + None, + styler_environment, + validator=is_instance_factory((type(None), str)), + ) + + +with cf.config_prefix("future"): + cf.register_option( + "infer_string", + False if os.environ.get("PANDAS_FUTURE_INFER_STRING", "1") == "0" else True, + "Whether to infer sequence of str objects as pyarrow string " + "dtype, which will be the default in pandas 3.0 " + "(at which point this option will be deprecated).", + validator=is_one_of_factory([True, False]), + ) + + cf.register_option( + "no_silent_downcasting", + False, + "This option is deprecated and will be removed in a future version. " + "It has no effect.", + validator=is_one_of_factory([True, False]), + ) + + cf.register_option( + "distinguish_nan_and_na", + os.environ.get("PANDAS_FUTURE_DISTINGUISH_NAN_AND_NA", "0") == "1", + "Whether to treat NaN entries as distinct from pd.NA in " + "numpy-nullable and pyarrow float dtypes. By default treats both " + "interchangeable as missing values (NaN will be coerced to NA). " + "See discussion in " + "https://github.com/pandas-dev/pandas/issues/32265", + validator=is_one_of_factory([True, False]), + ) + + cf.register_option( + "python_scalars", + False if os.environ.get("PANDAS_FUTURE_PYTHON_SCALARS", "0") == "0" else True, + "Whether to return Python scalars instead of NumPy or PyArrow scalars. " + "Currently experimental, setting to True is not recommended for end users.", + validator=is_one_of_factory([True, False]), + ) + + +# GH#59502 +cf.deprecate_option("future.no_silent_downcasting", Pandas4Warning) +cf.deprecate_option( + "mode.copy_on_write", + Pandas4Warning, + msg=( + "The 'mode.copy_on_write' option is deprecated. Copy-on-Write can no longer " + "be disabled (it is always enabled with pandas >= 3.0), and setting the option " + "has no impact. This option will be removed in pandas 4.0." + ), +) diff --git a/pandas/core/construction.py b/pandas/core/construction.py new file mode 100644 index 0000000000000000000000000000000000000000..953309e03fac8b5c722663477af6db46d8be4f94 --- /dev/null +++ b/pandas/core/construction.py @@ -0,0 +1,852 @@ +""" +Constructor functions intended to be shared by pd.array, Series.__init__, +and Index.__new__. + +These should not depend on core.internals. +""" + +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + cast, + overload, +) + +import numpy as np +from numpy import ma + +from pandas._config import using_string_dtype + +from pandas._libs import lib +from pandas._libs.tslibs import ( + get_supported_dtype, + is_supported_dtype, +) +from pandas.util._decorators import set_module + +from pandas.core.dtypes.base import ExtensionDtype +from pandas.core.dtypes.cast import ( + construct_1d_arraylike_from_scalar, + construct_1d_object_array_from_listlike, + maybe_cast_to_datetime, + maybe_cast_to_integer_array, + maybe_convert_platform, + maybe_promote, +) +from pandas.core.dtypes.common import ( + ensure_object, + is_list_like, + is_object_dtype, + pandas_dtype, +) +from pandas.core.dtypes.dtypes import NumpyEADtype +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCExtensionArray, + ABCIndex, + ABCSeries, +) +from pandas.core.dtypes.missing import isna + +import pandas.core.common as com + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pandas._typing import ( + AnyArrayLike, + ArrayLike, + Dtype, + DtypeObj, + T, + ) + + from pandas import ( + Index, + Series, + ) + from pandas.core.arrays import ( + DatetimeArray, + ExtensionArray, + TimedeltaArray, + ) + + +@set_module("pandas") +def array( + data: Sequence[object] | AnyArrayLike, + dtype: Dtype | None = None, + copy: bool = True, +) -> ExtensionArray: + """ + Create an array. + + This method constructs an array using pandas extension types when possible. + If `dtype` is specified, it determines the type of array returned. Otherwise, + pandas attempts to infer the appropriate dtype based on `data`. + + Parameters + ---------- + data : Sequence of objects + The scalars inside `data` should be instances of the + scalar type for `dtype`. It's expected that `data` + represents a 1-dimensional array of data. + + When `data` is an Index or Series, the underlying array + will be extracted from `data`. + + dtype : str, np.dtype, or ExtensionDtype, optional + The dtype to use for the array. This may be a NumPy + dtype or an extension type registered with pandas using + :meth:`pandas.api.extensions.register_extension_dtype`. + + If not specified, there are two possibilities: + + 1. When `data` is a :class:`Series`, :class:`Index`, or + :class:`ExtensionArray`, the `dtype` will be taken + from the data. + 2. Otherwise, pandas will attempt to infer the `dtype` + from the data. + + Note that when `data` is a NumPy array, ``data.dtype`` is + *not* used for inferring the array type. This is because + NumPy cannot represent all the types of data that can be + held in extension arrays. + + Currently, pandas will infer an extension dtype for sequences of + + ============================== ======================================= + Scalar Type Array Type + ============================== ======================================= + :class:`pandas.Interval` :class:`pandas.arrays.IntervalArray` + :class:`pandas.Period` :class:`pandas.arrays.PeriodArray` + :class:`datetime.datetime` :class:`pandas.arrays.DatetimeArray` + :class:`datetime.timedelta` :class:`pandas.arrays.TimedeltaArray` + :class:`int` :class:`pandas.arrays.IntegerArray` + :class:`float` :class:`pandas.arrays.FloatingArray` + :class:`str` :class:`pandas.arrays.StringArray` or + :class:`pandas.arrays.ArrowStringArray` + :class:`bool` :class:`pandas.arrays.BooleanArray` + ============================== ======================================= + + The ExtensionArray created when the scalar type is :class:`str` is determined by + ``pd.options.mode.string_storage`` if the dtype is not explicitly given. + + For all other cases, NumPy's usual inference rules will be used. + copy : bool, default True + Whether to copy the data, even if not necessary. Depending + on the type of `data`, creating the new array may require + copying data, even if ``copy=False``. + + Returns + ------- + ExtensionArray + The newly created array. + + Raises + ------ + ValueError + When `data` is not 1-dimensional. + + See Also + -------- + numpy.array : Construct a NumPy array. + Series : Construct a pandas Series. + Index : Construct a pandas Index. + arrays.NumpyExtensionArray : ExtensionArray wrapping a NumPy array. + Series.array : Extract the array stored within a Series. + + Notes + ----- + Omitting the `dtype` argument means pandas will attempt to infer the + best array type from the values in the data. As new array types are + added by pandas and 3rd party libraries, the "best" array type may + change. We recommend specifying `dtype` to ensure that + + 1. the correct array type for the data is returned + 2. the returned array type doesn't change as new extension types + are added by pandas and third-party libraries + + Additionally, if the underlying memory representation of the returned + array matters, we recommend specifying the `dtype` as a concrete object + rather than a string alias or allowing it to be inferred. For example, + a future version of pandas or a 3rd-party library may include a + dedicated ExtensionArray for string data. In this event, the following + would no longer return a :class:`arrays.NumpyExtensionArray` backed by a + NumPy array. + + >>> pd.array(["a", "b"], dtype=str) + + ['a', 'b'] + Length: 2, dtype: str + + This would instead return the new ExtensionArray dedicated for string + data. If you really need the new array to be backed by a NumPy array, + specify that in the dtype. + + >>> pd.array(["a", "b"], dtype=np.dtype(" + ['a', 'b'] + Length: 2, dtype: str32 + + Finally, Pandas has arrays that mostly overlap with NumPy + + * :class:`arrays.DatetimeArray` + * :class:`arrays.TimedeltaArray` + + When data with a ``datetime64[ns]`` or ``timedelta64[ns]`` dtype is + passed, pandas will always return a ``DatetimeArray`` or ``TimedeltaArray`` + rather than a ``NumpyExtensionArray``. This is for symmetry with the case of + timezone-aware data, which NumPy does not natively support. + + >>> pd.array(["2015", "2016"], dtype="datetime64[ns]") + + ['2015-01-01 00:00:00', '2016-01-01 00:00:00'] + Length: 2, dtype: datetime64[ns] + + >>> pd.array(["1h", "2h"], dtype="timedelta64[ns]") + + ['0 days 01:00:00', '0 days 02:00:00'] + Length: 2, dtype: timedelta64[ns] + + Examples + -------- + If a dtype is not specified, pandas will infer the best dtype from the values. + See the description of `dtype` for the types pandas infers for. + + >>> pd.array([1, 2]) + + [1, 2] + Length: 2, dtype: Int64 + + >>> pd.array([1, 2, np.nan]) + + [1, 2, ] + Length: 3, dtype: Int64 + + >>> pd.array([1.1, 2.2]) + + [1.1, 2.2] + Length: 2, dtype: Float64 + + >>> pd.array(["a", None, "c"]) + + ['a', , 'c'] + Length: 3, dtype: string + + >>> with pd.option_context("string_storage", "python"): + ... arr = pd.array(["a", None, "c"]) + >>> arr + + ['a', , 'c'] + Length: 3, dtype: string + + >>> pd.array([pd.Period("2000", freq="D"), pd.Period("2000", freq="D")]) + + ['2000-01-01', '2000-01-01'] + Length: 2, dtype: period[D] + + You can use the string alias for `dtype` + + >>> pd.array(["a", "b", "a"], dtype="category") + ['a', 'b', 'a'] + Categories (2, str): ['a', 'b'] + + Or specify the actual dtype + + >>> pd.array( + ... ["a", "b", "a"], dtype=pd.CategoricalDtype(["a", "b", "c"], ordered=True) + ... ) + ['a', 'b', 'a'] + Categories (3, str): ['a' < 'b' < 'c'] + + If pandas does not infer a dedicated extension type a + :class:`arrays.NumpyExtensionArray` is returned. + + >>> pd.array([1 + 1j, 3 + 2j]) + + [(1+1j), (3+2j)] + Length: 2, dtype: complex128 + + As mentioned in the "Notes" section, new extension types may be added + in the future (by pandas or 3rd party libraries), causing the return + value to no longer be a :class:`arrays.NumpyExtensionArray`. Specify the + `dtype` as a NumPy dtype if you need to ensure there's no future change in + behavior. + + >>> pd.array([1, 2], dtype=np.dtype("int32")) + + [1, 2] + Length: 2, dtype: int32 + + `data` must be 1-dimensional. A ValueError is raised when the input + has the wrong dimensionality. + + >>> pd.array(1) + Traceback (most recent call last): + ... + ValueError: Cannot pass scalar '1' to 'pandas.array'. + """ + from pandas.core.arrays import ( + BooleanArray, + DatetimeArray, + ExtensionArray, + FloatingArray, + IntegerArray, + NumpyExtensionArray, + TimedeltaArray, + ) + from pandas.core.arrays.string_ import StringDtype + + if lib.is_scalar(data): + msg = f"Cannot pass scalar '{data}' to 'pandas.array'." + raise ValueError(msg) + elif isinstance(data, ABCDataFrame): + raise TypeError("Cannot pass DataFrame to 'pandas.array'") + + if dtype is None and isinstance(data, (ABCSeries, ABCIndex, ExtensionArray)): + # Note: we exclude np.ndarray here, will do type inference on it + dtype = data.dtype + + data = extract_array(data, extract_numpy=True) + + # this returns None for not-found dtypes. + if dtype is not None: + dtype = pandas_dtype(dtype) + + if isinstance(data, ExtensionArray) and (dtype is None or data.dtype == dtype): + # e.g. TimedeltaArray[s], avoid casting to NumpyExtensionArray + if copy: + return data.copy() + return data + + if isinstance(dtype, ExtensionDtype): + cls = dtype.construct_array_type() + return cls._from_sequence(data, dtype=dtype, copy=copy) + + if dtype is None: + was_ndarray = isinstance(data, np.ndarray) + # error: Item "Sequence[object]" of "Sequence[object] | ExtensionArray | + # ndarray[Any, Any]" has no attribute "dtype" + if not was_ndarray or data.dtype == object: # type: ignore[union-attr] + result = lib.maybe_convert_objects( + ensure_object(data), + convert_non_numeric=True, + convert_to_nullable_dtype=True, + dtype_if_all_nat=np.dtype("M8[s]"), + ) + result = ensure_wrapped_if_datetimelike(result) + if isinstance(result, np.ndarray): + if len(result) == 0 and not was_ndarray: + # e.g. empty list + return FloatingArray._from_sequence(data, dtype="Float64") + return NumpyExtensionArray._from_sequence( + data, dtype=result.dtype, copy=copy + ) + if result is data and copy: + return result.copy() + return result + + data = cast(np.ndarray, data) + result = ensure_wrapped_if_datetimelike(data) + if result is not data: + result = cast("DatetimeArray | TimedeltaArray", result) + if copy and result.dtype == data.dtype: + return result.copy() + return result + + if data.dtype.kind in "SU": + # StringArray/ArrowStringArray depending on pd.options.mode.string_storage + dtype = StringDtype() + cls = dtype.construct_array_type() + return cls._from_sequence(data, dtype=dtype, copy=copy) + + elif data.dtype.kind in "iu": + dtype = IntegerArray._dtype_cls._get_dtype_mapping()[data.dtype] + return IntegerArray._from_sequence(data, dtype=dtype, copy=copy) + elif data.dtype.kind == "f": + # GH#44715 Exclude np.float16 bc FloatingArray does not support it; + # we will fall back to NumpyExtensionArray. + if data.dtype == np.float16: + return NumpyExtensionArray._from_sequence( + data, dtype=data.dtype, copy=copy + ) + dtype = FloatingArray._dtype_cls._get_dtype_mapping()[data.dtype] + return FloatingArray._from_sequence(data, dtype=dtype, copy=copy) + + elif data.dtype.kind == "b": + return BooleanArray._from_sequence(data, dtype="boolean", copy=copy) + else: + # e.g. complex + return NumpyExtensionArray._from_sequence(data, dtype=data.dtype, copy=copy) + + # Pandas overrides NumPy for + # 1. datetime64[ns,us,ms,s] + # 2. timedelta64[ns,us,ms,s] + # so that a DatetimeArray is returned. + if lib.is_np_dtype(dtype, "M") and is_supported_dtype(dtype): + return DatetimeArray._from_sequence(data, dtype=dtype, copy=copy) + if lib.is_np_dtype(dtype, "m") and is_supported_dtype(dtype): + return TimedeltaArray._from_sequence(data, dtype=dtype, copy=copy) + + elif lib.is_np_dtype(dtype, "mM"): + raise ValueError( + # GH#53817 + r"datetime64 and timedelta64 dtype resolutions other than " + r"'s', 'ms', 'us', and 'ns' are no longer supported." + ) + + return NumpyExtensionArray._from_sequence(data, dtype=dtype, copy=copy) + + +_typs = frozenset( + { + "index", + "rangeindex", + "multiindex", + "datetimeindex", + "timedeltaindex", + "periodindex", + "categoricalindex", + "intervalindex", + "series", + } +) + + +@overload +def extract_array( + obj: Series | Index, extract_numpy: bool = ..., extract_range: bool = ... +) -> ArrayLike: ... + + +@overload +def extract_array( + obj: T, extract_numpy: bool = ..., extract_range: bool = ... +) -> T | ArrayLike: ... + + +def extract_array( + obj: T, extract_numpy: bool = False, extract_range: bool = False +) -> T | ArrayLike: + """ + Extract the ndarray or ExtensionArray from a Series or Index. + + For all other types, `obj` is just returned as is. + + Parameters + ---------- + obj : object + For Series / Index, the underlying ExtensionArray is unboxed. + + extract_numpy : bool, default False + Whether to extract the ndarray from a NumpyExtensionArray. + + extract_range : bool, default False + If we have a RangeIndex, return range._values if True + (which is a materialized integer ndarray), otherwise return unchanged. + + Returns + ------- + arr : object + + Examples + -------- + >>> extract_array(pd.Series(["a", "b", "c"], dtype="category")) + ['a', 'b', 'c'] + Categories (3, str): ['a', 'b', 'c'] + + Other objects like lists, arrays, and DataFrames are just passed through. + + >>> extract_array([1, 2, 3]) + [1, 2, 3] + + For an ndarray-backed Series / Index the ndarray is returned. + + >>> extract_array(pd.Series([1, 2, 3])) + array([1, 2, 3]) + + To extract all the way down to the ndarray, pass ``extract_numpy=True``. + + >>> extract_array(pd.Series([1, 2, 3]), extract_numpy=True) + array([1, 2, 3]) + """ + typ = getattr(obj, "_typ", None) + if typ in _typs: + # i.e. isinstance(obj, (ABCIndex, ABCSeries)) + if typ == "rangeindex": + if extract_range: + # error: "T" has no attribute "_values" + return obj._values # type: ignore[attr-defined] + return obj + + # error: "T" has no attribute "_values" + return obj._values # type: ignore[attr-defined] + + elif extract_numpy and typ == "npy_extension": + # i.e. isinstance(obj, ABCNumpyExtensionArray) + # error: "T" has no attribute "to_numpy" + return obj.to_numpy() # type: ignore[attr-defined] + + return obj + + +def ensure_wrapped_if_datetimelike(arr): + """ + Wrap datetime64 and timedelta64 ndarrays in DatetimeArray/TimedeltaArray. + """ + if isinstance(arr, np.ndarray): + if arr.dtype.kind == "M": + from pandas.core.arrays import DatetimeArray + + dtype = get_supported_dtype(arr.dtype) + return DatetimeArray._from_sequence(arr, dtype=dtype) + + elif arr.dtype.kind == "m": + from pandas.core.arrays import TimedeltaArray + + dtype = get_supported_dtype(arr.dtype) + return TimedeltaArray._from_sequence(arr, dtype=dtype) + + return arr + + +def sanitize_masked_array(data: ma.MaskedArray) -> np.ndarray: + """ + Convert numpy MaskedArray to ensure mask is softened. + """ + mask = ma.getmaskarray(data) + if mask.any(): + dtype, fill_value = maybe_promote(data.dtype, np.nan) + dtype = cast(np.dtype, dtype) + data = ma.asarray(data.astype(dtype, copy=True)) + data.soften_mask() # set hardmask False if it was True + data[mask] = fill_value + else: + data = data.copy() + return data + + +def sanitize_array( + data, + index: Index | None, + dtype: DtypeObj | None = None, + copy: bool = False, + *, + allow_2d: bool = False, +) -> ArrayLike: + """ + Sanitize input data to an ndarray or ExtensionArray, copy if specified, + coerce to the dtype if specified. + + Parameters + ---------- + data : Any + index : Index or None, default None + dtype : np.dtype, ExtensionDtype, or None, default None + copy : bool, default False + allow_2d : bool, default False + If False, raise if we have a 2D Arraylike. + + Returns + ------- + np.ndarray or ExtensionArray + """ + original_dtype = dtype + if isinstance(data, ma.MaskedArray): + data = sanitize_masked_array(data) + + if isinstance(dtype, NumpyEADtype): + # Avoid ending up with a NumpyExtensionArray + dtype = dtype.numpy_dtype + + infer_object = not isinstance(data, (ABCIndex, ABCSeries)) + + # extract ndarray or ExtensionArray, ensure we have no NumpyExtensionArray + data = extract_array(data, extract_numpy=True, extract_range=True) + + if isinstance(data, np.ndarray) and data.ndim == 0: + if dtype is None: + dtype = data.dtype + data = lib.item_from_zerodim(data) + elif isinstance(data, range): + # GH#16804 + data = range_to_ndarray(data) + copy = False + + if not is_list_like(data): + if index is None: + raise ValueError("index must be specified when data is not list-like") + if isinstance(data, str) and using_string_dtype() and original_dtype is None: + from pandas.core.arrays.string_ import StringDtype + + dtype = StringDtype(na_value=np.nan) + data = construct_1d_arraylike_from_scalar(data, len(index), dtype) + + return data + + elif isinstance(data, ABCExtensionArray): + # it is already ensured above this is not a NumpyExtensionArray + # Until GH#49309 is fixed this check needs to come before the + # ExtensionDtype check + if dtype is not None: + subarr = data.astype(dtype, copy=copy) + elif copy: + subarr = data.copy() + else: + subarr = data + + elif isinstance(dtype, ExtensionDtype): + # create an extension array from its dtype + _sanitize_non_ordered(data) + cls = dtype.construct_array_type() + if not hasattr(data, "__array__"): + data = list(data) + subarr = cls._from_sequence(data, dtype=dtype, copy=copy) + + # GH#846 + elif isinstance(data, np.ndarray): + if isinstance(data, np.matrix): + data = data.A + + if dtype is None: + subarr = data + if data.dtype == object and infer_object: + subarr = lib.maybe_convert_objects( + data, + # Here we do not convert numeric dtypes, as if we wanted that, + # numpy would have done it for us. + convert_numeric=False, + convert_non_numeric=True, + convert_to_nullable_dtype=False, + dtype_if_all_nat=np.dtype("M8[s]"), + ) + elif data.dtype.kind == "U" and using_string_dtype(): + from pandas.core.arrays.string_ import StringDtype + + dtype = StringDtype(na_value=np.nan) + subarr = dtype.construct_array_type()._from_sequence(data, dtype=dtype) + + if ( + subarr is data + or (subarr.dtype == "str" and subarr.dtype.storage == "python") # type: ignore[union-attr] + ) and copy: + subarr = subarr.copy() + + else: + # we will try to copy by-definition here + subarr = _try_cast(data, dtype, copy) + + elif hasattr(data, "__array__"): + # e.g. dask array GH#38645 + if not copy: + data = np.asarray(data) + else: + data = np.array(data, copy=copy) + return sanitize_array( + data, + index=index, + dtype=dtype, + copy=False, + allow_2d=allow_2d, + ) + + else: + _sanitize_non_ordered(data) + # materialize e.g. generators, convert e.g. tuples, abc.ValueView + data = list(data) + + if len(data) == 0 and dtype is None: + # We default to float64, matching numpy + subarr = np.array([], dtype=np.float64) + + elif dtype is not None: + subarr = _try_cast(data, dtype, copy) + + else: + subarr = maybe_convert_platform(data) + if subarr.dtype == object: + subarr = cast(np.ndarray, subarr) + subarr = lib.maybe_convert_objects( + subarr, + # Here we do not convert numeric dtypes, as if we wanted that, + # numpy would have done it for us. + convert_numeric=False, + convert_non_numeric=True, + convert_to_nullable_dtype=False, + dtype_if_all_nat=np.dtype("M8[s]"), + ) + + subarr = _sanitize_ndim(subarr, data, dtype, index, allow_2d=allow_2d) + + if isinstance(subarr, np.ndarray): + # at this point we should have dtype be None or subarr.dtype == dtype + dtype = cast(np.dtype, dtype) + subarr = _sanitize_str_dtypes(subarr, data, dtype, copy) + + return subarr + + +def range_to_ndarray(rng: range) -> np.ndarray: + """ + Cast a range object to ndarray. + """ + # GH#30171 perf avoid realizing range as a list in np.array + try: + arr = np.arange(rng.start, rng.stop, rng.step, dtype="int64") + except OverflowError: + # GH#30173 handling for ranges that overflow int64 + if (rng.start >= 0 and rng.step > 0) or (rng.step < 0 <= rng.stop): + try: + arr = np.arange(rng.start, rng.stop, rng.step, dtype="uint64") + except OverflowError: + arr = construct_1d_object_array_from_listlike(list(rng)) + else: + arr = construct_1d_object_array_from_listlike(list(rng)) + return arr + + +def _sanitize_non_ordered(data) -> None: + """ + Raise only for unordered sets, e.g., not for dict_keys + """ + if isinstance(data, (set, frozenset)): + raise TypeError(f"'{type(data).__name__}' type is unordered") + + +def _sanitize_ndim( + result: ArrayLike, + data, + dtype: DtypeObj | None, + index: Index | None, + *, + allow_2d: bool = False, +) -> ArrayLike: + """ + Ensure we have a 1-dimensional result array. + """ + if getattr(result, "ndim", 0) == 0: + raise ValueError("result should be arraylike with ndim > 0") + + if result.ndim == 1: + # the result that we want + result = _maybe_repeat(result, index) + + elif result.ndim > 1: + if isinstance(data, np.ndarray): + if allow_2d: + return result + raise ValueError( + f"Data must be 1-dimensional, got ndarray of shape {data.shape} instead" + ) + if is_object_dtype(dtype) and isinstance(dtype, ExtensionDtype): + # i.e. NumpyEADtype("O") + + result = com.asarray_tuplesafe(data, dtype=np.dtype("object")) + cls = dtype.construct_array_type() + result = cls._from_sequence(result, dtype=dtype) + else: + # error: Argument "dtype" to "asarray_tuplesafe" has incompatible type + # "Union[dtype[Any], ExtensionDtype, None]"; expected "Union[str, + # dtype[Any], None]" + result = com.asarray_tuplesafe(data, dtype=dtype) # type: ignore[arg-type] + return result + + +def _sanitize_str_dtypes( + result: np.ndarray, data, dtype: np.dtype | None, copy: bool +) -> np.ndarray: + """ + Ensure we have a dtype that is supported by pandas. + """ + + # This is to prevent mixed-type Series getting all casted to + # NumPy string type, e.g. NaN --> '-1#IND'. + if issubclass(result.dtype.type, str): + # GH#16605 + # If not empty convert the data to dtype + # GH#19853: If data is a scalar, result has already the result + if not lib.is_scalar(data): + if not np.all(isna(data)): + data = np.asarray(data, dtype=dtype) + if not copy: + result = np.asarray(data, dtype=object) + else: + result = np.array(data, dtype=object, copy=copy) + return result + + +def _maybe_repeat(arr: ArrayLike, index: Index | None) -> ArrayLike: + """ + If we have a length-1 array and an index describing how long we expect + the result to be, repeat the array. + """ + if index is not None: + if 1 == len(arr) != len(index): + arr = arr.repeat(len(index)) + return arr + + +def _try_cast( + arr: list | np.ndarray, + dtype: np.dtype, + copy: bool, +) -> ArrayLike: + """ + Convert input to numpy ndarray and optionally cast to a given dtype. + + Parameters + ---------- + arr : ndarray or list + Excludes: ExtensionArray, Series, Index. + dtype : np.dtype + copy : bool + If False, don't copy the data if not needed. + + Returns + ------- + np.ndarray or ExtensionArray + """ + is_ndarray = isinstance(arr, np.ndarray) + + if dtype == object: + if not is_ndarray: + subarr = construct_1d_object_array_from_listlike(arr) + return subarr + return ensure_wrapped_if_datetimelike(arr).astype(dtype, copy=copy) + + elif dtype.kind == "U": + # TODO: test cases with arr.dtype.kind in "mM" + if is_ndarray: + arr = cast(np.ndarray, arr) + shape = arr.shape + if arr.ndim > 1: + arr = arr.ravel() + else: + shape = (len(arr),) + return lib.ensure_string_array(arr, convert_na_value=False, copy=copy).reshape( + shape + ) + + elif dtype.kind in "mM": + if is_ndarray: + arr = cast(np.ndarray, arr) + if arr.ndim == 2 and arr.shape[1] == 1: + # GH#60081: DataFrame Constructor converts 1D data to array of + # shape (N, 1), but maybe_cast_to_datetime assumes 1D input + return maybe_cast_to_datetime(arr[:, 0], dtype).reshape(arr.shape) + return maybe_cast_to_datetime(arr, dtype) + + # GH#15832: Check if we are requesting a numeric dtype and + # that we can convert the data to the requested dtype. + elif dtype.kind in "iu": + # this will raise if we have e.g. floats + + subarr = maybe_cast_to_integer_array(arr, dtype) + elif not copy: + subarr = np.asarray(arr, dtype=dtype) + else: + subarr = np.array(arr, dtype=dtype, copy=copy) + + return subarr diff --git a/pandas/core/flags.py b/pandas/core/flags.py new file mode 100644 index 0000000000000000000000000000000000000000..f6088e3f40b1be470cf2e0ef138355d2f0239031 --- /dev/null +++ b/pandas/core/flags.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +import weakref + +from pandas.util._decorators import set_module + +if TYPE_CHECKING: + from pandas.core.generic import NDFrame + + +@set_module("pandas") +class Flags: + """ + Flags that apply to pandas objects. + + “Flags” differ from “metadata”. Flags reflect properties of the pandas + object (the Series or DataFrame). Metadata refer to properties of the + dataset, and should be stored in DataFrame.attrs. + + Parameters + ---------- + obj : Series or DataFrame + The object these flags are associated with. + allows_duplicate_labels : bool, default True + Whether to allow duplicate labels in this object. By default, + duplicate labels are permitted. Setting this to ``False`` will + cause an :class:`errors.DuplicateLabelError` to be raised when + `index` (or columns for DataFrame) is not unique, or any + subsequent operation on introduces duplicates. + See :ref:`duplicates.disallow` for more. + + .. warning:: + + This is an experimental feature. Currently, many methods fail to + propagate the ``allows_duplicate_labels`` value. In future versions + it is expected that every method taking or returning one or more + DataFrame or Series objects will propagate ``allows_duplicate_labels``. + + See Also + -------- + DataFrame.attrs : Dictionary of global attributes of this dataset. + Series.attrs : Dictionary of global attributes of this dataset. + + Examples + -------- + Attributes can be set in two ways: + + >>> df = pd.DataFrame() + >>> df.flags + + >>> df.flags.allows_duplicate_labels = False + >>> df.flags + + + >>> df.flags["allows_duplicate_labels"] = True + >>> df.flags + + """ + + _keys: set[str] = {"allows_duplicate_labels"} + + def __init__(self, obj: NDFrame, *, allows_duplicate_labels: bool) -> None: + self._allows_duplicate_labels = allows_duplicate_labels + self._obj = weakref.ref(obj) + + @property + def allows_duplicate_labels(self) -> bool: + """ + Whether this object allows duplicate labels. + + Setting ``allows_duplicate_labels=False`` ensures that the + index (and columns of a DataFrame) are unique. Most methods + that accept and return a Series or DataFrame will propagate + the value of ``allows_duplicate_labels``. + + See :ref:`duplicates` for more. + + See Also + -------- + DataFrame.attrs : Set global metadata on this object. + DataFrame.set_flags : Set global flags on this object. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2]}, index=["a", "a"]) + >>> df.flags.allows_duplicate_labels + True + >>> df.flags.allows_duplicate_labels = False + Traceback (most recent call last): + ... + pandas.errors.DuplicateLabelError: Index has duplicates. + positions + label + a [0, 1] + """ + return self._allows_duplicate_labels + + @allows_duplicate_labels.setter + def allows_duplicate_labels(self, value: bool) -> None: + value = bool(value) + obj = self._obj() + if obj is None: + raise ValueError("This flag's object has been deleted.") + + if not value: + for ax in obj.axes: + ax._maybe_check_unique() + + self._allows_duplicate_labels = value + + def __getitem__(self, key: str): + if key not in self._keys: + raise KeyError(key) + + return getattr(self, key) + + def __setitem__(self, key: str, value) -> None: + if key not in self._keys: + raise ValueError(f"Unknown flag {key}. Must be one of {self._keys}") + setattr(self, key, value) + + def __repr__(self) -> str: + return f"" + + def __eq__(self, other: object) -> bool: + if isinstance(other, type(self)): + return self.allows_duplicate_labels == other.allows_duplicate_labels + return False diff --git a/pandas/core/frame.py b/pandas/core/frame.py new file mode 100644 index 0000000000000000000000000000000000000000..79e3cf246fe262487f97cb8181dda466951d74fe --- /dev/null +++ b/pandas/core/frame.py @@ -0,0 +1,16710 @@ +""" +DataFrame +--------- +An efficient 2D container for potentially mixed-type time series or other +labeled data series. + +Similar to its R counterpart, data.frame, except providing automatic data +alignment and a host of useful data manipulation methods having to do with the +labeling information +""" + +from __future__ import annotations + +import collections +from collections import abc +from collections.abc import ( + Callable, + Hashable, + Iterable, + Iterator, + Mapping, + Sequence, +) +import functools +from io import StringIO +import itertools +import operator +import sys +from textwrap import dedent +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Self, + cast, + overload, +) +import warnings + +import numpy as np +from numpy import ma + +from pandas._config import get_option + +from pandas._libs import ( + algos as libalgos, + lib, + properties, +) +from pandas._libs.hashtable import duplicated +from pandas._libs.lib import is_range_indexer +from pandas.compat import CHAINED_WARNING_DISABLED +from pandas.compat._constants import ( + REF_COUNT, + REF_COUNT_METHOD, +) +from pandas.compat._optional import import_optional_dependency +from pandas.compat.numpy import function as nv +from pandas.errors import ( + ChainedAssignmentError, + InvalidIndexError, + Pandas4Warning, +) +from pandas.errors.cow import ( + _chained_assignment_method_update_msg, + _chained_assignment_msg, +) +from pandas.util._decorators import ( + Appender, + Substitution, + deprecate_nonkeyword_arguments, + set_module, +) +from pandas.util._exceptions import ( + find_stack_level, +) +from pandas.util._validators import ( + validate_ascending, + validate_bool_kwarg, + validate_percentile, +) + +from pandas.core.dtypes.cast import ( + LossySetitemError, + can_hold_element, + construct_1d_arraylike_from_scalar, + construct_2d_arraylike_from_scalar, + find_common_type, + infer_dtype_from_scalar, + invalidate_string_dtypes, + maybe_downcast_to_dtype, + maybe_unbox_numpy_scalar, +) +from pandas.core.dtypes.common import ( + infer_dtype_from_object, + is_1d_only_ea_dtype, + is_array_like, + is_bool_dtype, + is_dataclass, + is_dict_like, + is_float, + is_float_dtype, + is_hashable, + is_integer, + is_integer_dtype, + is_iterator, + is_list_like, + is_scalar, + is_sequence, + is_string_dtype, + needs_i8_conversion, + pandas_dtype, +) +from pandas.core.dtypes.concat import concat_compat +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + BaseMaskedDtype, + ExtensionDtype, +) +from pandas.core.dtypes.generic import ( + ABCIndex, + ABCSeries, +) +from pandas.core.dtypes.missing import ( + isna, + notna, +) + +from pandas.core import ( + algorithms, + common as com, + nanops, + ops, + roperator, +) +from pandas.core.accessor import Accessor +from pandas.core.apply import reconstruct_and_relabel_result +from pandas.core.array_algos.take import take_2d_multi +from pandas.core.arraylike import OpsMixin +from pandas.core.arrays import ( + BaseMaskedArray, + DatetimeArray, + ExtensionArray, + PeriodArray, + TimedeltaArray, +) +from pandas.core.arrays.sparse import SparseFrameAccessor +from pandas.core.arrays.string_ import StringDtype +from pandas.core.construction import ( + ensure_wrapped_if_datetimelike, + sanitize_array, + sanitize_masked_array, +) +from pandas.core.generic import NDFrame +from pandas.core.indexers import check_key_length +from pandas.core.indexes.api import ( + DatetimeIndex, + Index, + PeriodIndex, + default_index, + ensure_index, + ensure_index_from_sequences, +) +from pandas.core.indexes.multi import ( + MultiIndex, + maybe_droplevels, +) +from pandas.core.indexing import ( + check_bool_indexer, + check_dict_or_set_indexers, +) +from pandas.core.internals import BlockManager +from pandas.core.internals.construction import ( + arrays_to_mgr, + dataclasses_to_dicts, + dict_to_mgr, + ndarray_to_mgr, + nested_data_to_arrays, + rec_array_to_mgr, + reorder_arrays, + to_arrays, + treat_as_nested, +) +from pandas.core.methods import selectn +from pandas.core.reshape.melt import melt +from pandas.core.series import Series +from pandas.core.shared_docs import _shared_docs +from pandas.core.sorting import ( + get_group_index, + lexsort_indexer, + nargsort, +) + +from pandas.io.common import get_handle +from pandas.io.formats import ( + console, + format as fmt, +) +from pandas.io.formats.info import DataFrameInfo +import pandas.plotting + +if TYPE_CHECKING: + import datetime + + from pandas._libs.internals import BlockValuesRefs + from pandas._typing import ( + AggFuncType, + AnyAll, + AnyArrayLike, + ArrayLike, + ArrowArrayExportable, + ArrowStreamExportable, + Axes, + Axis, + AxisInt, + ColspaceArgType, + CompressionOptions, + CorrelationMethod, + DropKeep, + Dtype, + DtypeObj, + FilePath, + FloatFormatType, + FormattersType, + Frequency, + FromDictOrient, + HashableT, + HashableT2, + IgnoreRaise, + IndexKeyFunc, + IndexLabel, + JoinValidate, + Level, + ListLike, + MergeHow, + MergeValidate, + MutableMappingT, + NaPosition, + NsmallestNlargestKeep, + ParquetCompressionOptions, + PythonFuncType, + QuantileInterpolation, + ReadBuffer, + ReindexMethod, + Renamer, + Scalar, + SequenceNotStr, + SortKind, + StorageOptions, + Suffixes, + T, + ToStataByteorder, + ToTimestampHow, + UpdateJoin, + ValueKeyFunc, + WriteBuffer, + XMLParsers, + npt, + ) + + from pandas.core.groupby.generic import DataFrameGroupBy + from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg + from pandas.core.internals.managers import SingleBlockManager + + from pandas.io.formats.style import Styler + +# --------------------------------------------------------------------- +# Docstring templates + +_shared_doc_kwargs = { + "axes": "index, columns", + "klass": "DataFrame", + "axes_single_arg": "{0 or 'index', 1 or 'columns'}", + "axis": """axis : {0 or 'index', 1 or 'columns'}, default 0 + If 0 or 'index': apply function to each column. + If 1 or 'columns': apply function to each row.""", + "inplace": """ + inplace : bool, default False + Whether to modify the DataFrame rather than creating a new one.""", + "optional_by": """ +by : str or list of str + Name or list of names to sort by. + + - if `axis` is 0 or `'index'` then `by` may contain index + levels and/or column labels. + - if `axis` is 1 or `'columns'` then `by` may contain column + levels and/or index labels.""", + "optional_reindex": """ +labels : array-like, optional + New labels / index to conform the axis specified by 'axis' to. +index : array-like, optional + New labels for the index. Preferably an Index object to avoid + duplicating data. +columns : array-like, optional + New labels for the columns. Preferably an Index object to avoid + duplicating data. +axis : int or str, optional + Axis to target. Can be either the axis name ('index', 'columns') + or number (0, 1).""", +} + +_merge_doc = """ +Merge DataFrame or named Series objects with a database-style join. + +A named Series object is treated as a DataFrame with a single named column. + +The join is done on columns or indexes. If joining columns on +columns, the DataFrame indexes *will be ignored*. Otherwise if joining indexes +on indexes or indexes on a column or columns, the index will be passed on. +When performing a cross merge, no column specifications to merge on are +allowed. + +.. warning:: + + If both key columns contain rows where the key is a null value, those + rows will be matched against each other. This is different from usual SQL + join behaviour and can lead to unexpected results. + +Parameters +----------%s +right : DataFrame or named Series + Object to merge with. +how : {'left', 'right', 'outer', 'inner', 'cross', 'left_anti', 'right_anti'}, + default 'inner' + Type of merge to be performed. + + * left: use only keys from left frame, similar to a SQL left outer join; + preserve key order. + * right: use only keys from right frame, similar to a SQL right outer join; + preserve key order. + * outer: use union of keys from both frames, similar to a SQL full outer + join; sort keys lexicographically. + * inner: use intersection of keys from both frames, similar to a SQL inner + join; preserve the order of the left keys. + * cross: creates the cartesian product from both frames, preserves the order + of the left keys. + * left_anti: use only keys from left frame that are not in right frame, similar + to SQL left anti join; preserve key order. + + .. versionadded:: 3.0 + * right_anti: use only keys from right frame that are not in left frame, similar + to SQL right anti join; preserve key order. + + .. versionadded:: 3.0 +on : Hashable or a sequence of the previous + Column or index level names to join on. These must be found in both + DataFrames. If `on` is None and not merging on indexes then this defaults + to the intersection of the columns in both DataFrames. +left_on : Hashable or a sequence of the previous, or array-like + Column or index level names to join on in the left DataFrame. Can also + be an array or list of arrays of the length of the left DataFrame. + These arrays are treated as if they are columns. +right_on : Hashable or a sequence of the previous, or array-like + Column or index level names to join on in the right DataFrame. Can also + be an array or list of arrays of the length of the right DataFrame. + These arrays are treated as if they are columns. +left_index : bool, default False + Use the index from the left DataFrame as the join key(s). If it is a + MultiIndex, the number of keys in the other DataFrame (either the index + or a number of columns) must match the number of levels. +right_index : bool, default False + Use the index from the right DataFrame as the join key. Same caveats as + left_index. +sort : bool, default False + Sort the join keys lexicographically in the result DataFrame. If False, + the order of the join keys depends on the join type (how keyword). +suffixes : list-like, default is ("_x", "_y") + A length-2 sequence where each element is optionally a string + indicating the suffix to add to overlapping column names in + `left` and `right` respectively. Pass a value of `None` instead + of a string to indicate that the column name from `left` or + `right` should be left as-is, with no suffix. At least one of the + values must not be None. +copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + +indicator : bool or str, default False + If True, adds a column to the output DataFrame called "_merge" with + information on the source of each row. The column can be given a different + name by providing a string argument. The column will have a Categorical + type with the value of "left_only" for observations whose merge key only + appears in the left DataFrame, "right_only" for observations + whose merge key only appears in the right DataFrame, and "both" + if the observation's merge key is found in both DataFrames. + +validate : str, optional + If specified, checks if merge is of specified type. + + * "one_to_one" or "1:1": check if merge keys are unique in both + left and right datasets. + * "one_to_many" or "1:m": check if merge keys are unique in left + dataset. + * "many_to_one" or "m:1": check if merge keys are unique in right + dataset. + * "many_to_many" or "m:m": allowed, but does not result in checks. + +Returns +------- +DataFrame + A DataFrame of the two merged objects. + +See Also +-------- +merge_ordered : Merge with optional filling/interpolation. +merge_asof : Merge on nearest keys. +DataFrame.join : Similar method using indices. + +Examples +-------- +>>> df1 = pd.DataFrame({'lkey': ['foo', 'bar', 'baz', 'foo'], +... 'value': [1, 2, 3, 5]}) +>>> df2 = pd.DataFrame({'rkey': ['foo', 'bar', 'baz', 'foo'], +... 'value': [5, 6, 7, 8]}) +>>> df1 + lkey value +0 foo 1 +1 bar 2 +2 baz 3 +3 foo 5 +>>> df2 + rkey value +0 foo 5 +1 bar 6 +2 baz 7 +3 foo 8 + +Merge df1 and df2 on the lkey and rkey columns. The value columns have +the default suffixes, _x and _y, appended. + +>>> df1.merge(df2, left_on='lkey', right_on='rkey') + lkey value_x rkey value_y +0 foo 1 foo 5 +1 foo 1 foo 8 +2 bar 2 bar 6 +3 baz 3 baz 7 +4 foo 5 foo 5 +5 foo 5 foo 8 + +Merge DataFrames df1 and df2 with specified left and right suffixes +appended to any overlapping columns. + +>>> df1.merge(df2, left_on='lkey', right_on='rkey', +... suffixes=('_left', '_right')) + lkey value_left rkey value_right +0 foo 1 foo 5 +1 foo 1 foo 8 +2 bar 2 bar 6 +3 baz 3 baz 7 +4 foo 5 foo 5 +5 foo 5 foo 8 + +Merge DataFrames df1 and df2, but raise an exception if the DataFrames have +any overlapping columns. + +>>> df1.merge(df2, left_on='lkey', right_on='rkey', suffixes=(False, False)) +Traceback (most recent call last): +... +ValueError: columns overlap but no suffix specified: + Index(['value'], dtype='object') + +>>> df1 = pd.DataFrame({'a': ['foo', 'bar'], 'b': [1, 2]}) +>>> df2 = pd.DataFrame({'a': ['foo', 'baz'], 'c': [3, 4]}) +>>> df1 + a b +0 foo 1 +1 bar 2 +>>> df2 + a c +0 foo 3 +1 baz 4 + +>>> df1.merge(df2, how='inner', on='a') + a b c +0 foo 1 3 + +>>> df1.merge(df2, how='left', on='a') + a b c +0 foo 1 3.0 +1 bar 2 NaN + +>>> df1 = pd.DataFrame({'left': ['foo', 'bar']}) +>>> df2 = pd.DataFrame({'right': [7, 8]}) +>>> df1 + left +0 foo +1 bar +>>> df2 + right +0 7 +1 8 + +>>> df1.merge(df2, how='cross') + left right +0 foo 7 +1 foo 8 +2 bar 7 +3 bar 8 +""" + + +# ----------------------------------------------------------------------- +# DataFrame class + + +@set_module("pandas") +class DataFrame(NDFrame, OpsMixin): + """ + Two-dimensional, size-mutable, potentially heterogeneous tabular data. + + Data structure also contains labeled axes (rows and columns). + Arithmetic operations align on both row and column labels. Can be + thought of as a dict-like container for Series objects. The primary + pandas data structure. + + Parameters + ---------- + data : ndarray (structured or homogeneous), Iterable, dict, or DataFrame + Dict can contain Series, arrays, constants, dataclass or list-like objects. If + data is a dict, column order follows insertion-order. If a dict contains Series + which have an index defined, it is aligned by its index. This alignment also + occurs if data is a Series or a DataFrame itself. Alignment is done on + Series/DataFrame inputs. + + If data is a list of dicts, column order follows insertion-order. + + index : Index or array-like + Index to use for resulting frame. Will default to RangeIndex if + no indexing information part of input data and no index provided. + columns : Index or array-like + Column labels to use for resulting frame when data does not have them, + defaulting to RangeIndex(0, 1, 2, ..., n). If data contains column labels, + will perform column selection instead. + dtype : dtype, default None + Data type to force. Only a single dtype is allowed. If None, infer. + If ``data`` is DataFrame then is ignored. + copy : bool or None, default None + Copy data from inputs. + For dict data, the default of None behaves like ``copy=True``. For DataFrame + or 2d ndarray input, the default of None behaves like ``copy=False``. + If data is a dict containing one or more Series (possibly of different dtypes), + ``copy=False`` will ensure that these inputs are not copied. + + See Also + -------- + DataFrame.from_records : Constructor from tuples, also record arrays. + DataFrame.from_dict : From dicts of Series, arrays, or dicts. + read_csv : Read a comma-separated values (csv) file into DataFrame. + read_table : Read general delimited file into DataFrame. + read_clipboard : Read text from clipboard into DataFrame. + + Notes + ----- + Please reference the :ref:`User Guide ` for more information. + + Examples + -------- + Constructing DataFrame from a dictionary. + + >>> d = {"col1": [1, 2], "col2": [3, 4]} + >>> df = pd.DataFrame(data=d) + >>> df + col1 col2 + 0 1 3 + 1 2 4 + + Notice that the inferred dtype is int64. + + >>> df.dtypes + col1 int64 + col2 int64 + dtype: object + + To enforce a single dtype: + + >>> df = pd.DataFrame(data=d, dtype=np.int8) + >>> df.dtypes + col1 int8 + col2 int8 + dtype: object + + Constructing DataFrame from a dictionary including Series: + + >>> d = {"col1": [0, 1, 2, 3], "col2": pd.Series([2, 3], index=[2, 3])} + >>> pd.DataFrame(data=d, index=[0, 1, 2, 3]) + col1 col2 + 0 0 NaN + 1 1 NaN + 2 2 2.0 + 3 3 3.0 + + Constructing DataFrame from numpy ndarray: + + >>> df2 = pd.DataFrame( + ... np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), columns=["a", "b", "c"] + ... ) + >>> df2 + a b c + 0 1 2 3 + 1 4 5 6 + 2 7 8 9 + + Constructing DataFrame from a numpy ndarray that has labeled columns: + + >>> data = np.array( + ... [(1, 2, 3), (4, 5, 6), (7, 8, 9)], + ... dtype=[("a", "i4"), ("b", "i4"), ("c", "i4")], + ... ) + >>> df3 = pd.DataFrame(data, columns=["c", "a"]) + >>> df3 + c a + 0 3 1 + 1 6 4 + 2 9 7 + + Constructing DataFrame from dataclass: + + >>> from dataclasses import make_dataclass + >>> Point = make_dataclass("Point", [("x", int), ("y", int)]) + >>> pd.DataFrame([Point(0, 0), Point(0, 3), Point(2, 3)]) + x y + 0 0 0 + 1 0 3 + 2 2 3 + + Constructing DataFrame from Series/DataFrame: + + >>> ser = pd.Series([1, 2, 3], index=["a", "b", "c"]) + >>> df = pd.DataFrame(data=ser, index=["a", "c"]) + >>> df + 0 + a 1 + c 3 + + >>> df1 = pd.DataFrame([1, 2, 3], index=["a", "b", "c"], columns=["x"]) + >>> df2 = pd.DataFrame(data=df1, index=["a", "c"]) + >>> df2 + x + a 1 + c 3 + """ + + _internal_names_set = {"columns", "index"} | NDFrame._internal_names_set + _typ = "dataframe" + _HANDLED_TYPES = (Series, Index, ExtensionArray, np.ndarray) + _accessors: set[str] = {"sparse"} + _hidden_attrs: frozenset[str] = NDFrame._hidden_attrs | frozenset([]) + _mgr: BlockManager + + # similar to __array_priority__, positions DataFrame before Series, Index, + # and ExtensionArray. Should NOT be overridden by subclasses. + __pandas_priority__ = 4000 + + @property + def _constructor(self) -> type[DataFrame]: + return DataFrame + + def _constructor_from_mgr(self, mgr, axes) -> DataFrame: + df = DataFrame._from_mgr(mgr, axes=axes) + + if type(self) is DataFrame: + # This would also work `if self._constructor is DataFrame`, but + # this check is slightly faster, benefiting the most-common case. + return df + + elif type(self).__name__ == "GeoDataFrame": + # Shim until geopandas can override their _constructor_from_mgr + # bc they have different behavior for Managers than for DataFrames + return self._constructor(mgr) + + # We assume that the subclass __init__ knows how to handle a + # pd.DataFrame object. + return self._constructor(df) + + _constructor_sliced: Callable[..., Series] = Series + + def _constructor_sliced_from_mgr(self, mgr, axes) -> Series: + ser = Series._from_mgr(mgr, axes) + ser._name = None # caller is responsible for setting real name + + if type(self) is DataFrame: + # This would also work `if self._constructor_sliced is Series`, but + # this check is slightly faster, benefiting the most-common case. + return ser + + # We assume that the subclass __init__ knows how to handle a + # pd.Series object. + return self._constructor_sliced(ser) + + # ---------------------------------------------------------------------- + # Constructors + + def __init__( + self, + data=None, + index: Axes | None = None, + columns: Axes | None = None, + dtype: Dtype | None = None, + copy: bool | None = None, + ) -> None: + allow_mgr = False + if dtype is not None: + dtype = self._validate_dtype(dtype) + + if isinstance(data, DataFrame): + data = data._mgr + allow_mgr = True + if not copy: + # if not copying data, ensure to still return a shallow copy + # to avoid the result sharing the same Manager + data = data.copy(deep=False) + + if isinstance(data, BlockManager): + if not allow_mgr: + # GH#52419 + warnings.warn( + f"Passing a {type(data).__name__} to {type(self).__name__} " + "is deprecated and will raise in a future version. " + "Use public APIs instead.", + Pandas4Warning, + stacklevel=2, + ) + + data = data.copy(deep=False) + # first check if a Manager is passed without any other arguments + # -> use fastpath (without checking Manager type) + if index is None and columns is None and dtype is None and not copy: + # GH#33357 fastpath + NDFrame.__init__(self, data) + return + + # GH47215 + if isinstance(index, set): + raise ValueError("index cannot be a set") + if isinstance(columns, set): + raise ValueError("columns cannot be a set") + + if copy is None: + if isinstance(data, dict): + # retain pre-GH#38939 default behavior + copy = True + elif not isinstance(data, (Index, DataFrame, Series)): + copy = True + else: + copy = False + + if data is None: + index = index if index is not None else default_index(0) + columns = columns if columns is not None else default_index(0) + dtype = dtype if dtype is not None else pandas_dtype(object) + data = [] + + if isinstance(data, BlockManager): + mgr = self._init_mgr( + data, axes={"index": index, "columns": columns}, dtype=dtype, copy=copy + ) + + elif isinstance(data, dict): + # GH#38939 de facto copy defaults to False only in non-dict cases + mgr = dict_to_mgr(data, index, columns, dtype=dtype, copy=copy) + elif isinstance(data, ma.MaskedArray): + from numpy.ma import mrecords + + # masked recarray + if isinstance(data, mrecords.MaskedRecords): + raise TypeError( + "MaskedRecords are not supported. Pass " + "{name: data[name] for name in data.dtype.names} " + "instead" + ) + + # a masked array + data = sanitize_masked_array(data) + mgr = ndarray_to_mgr( + data, + index, + columns, + dtype=dtype, + copy=copy, + ) + + elif isinstance(data, (np.ndarray, Series, Index, ExtensionArray)): + if data.dtype.names: + # i.e. numpy structured array + data = cast(np.ndarray, data) + mgr = rec_array_to_mgr( + data, + index, + columns, + dtype, + copy, + ) + elif isinstance(data, (ABCSeries, ABCIndex)) and data.name is not None: + # i.e. Series/Index with non-None name + mgr = dict_to_mgr( + # error: Item "ndarray" of "Union[ndarray, Series, Index]" has no + # attribute "name" + {data.name: data}, + index, + columns, + dtype=dtype, + copy=copy, + ) + else: + mgr = ndarray_to_mgr( + data, + index, + columns, + dtype=dtype, + copy=copy, + ) + + # For data is list-like, or Iterable (will consume into list) + elif is_list_like(data): + if not isinstance(data, abc.Sequence): + if hasattr(data, "__array__"): + # GH#44616 big perf improvement for e.g. pytorch tensor + data = np.asarray(data) + else: + data = list(data) + if len(data) > 0: + if is_dataclass(data[0]): + data = dataclasses_to_dicts(data) + if not isinstance(data, np.ndarray) and treat_as_nested(data): + # exclude ndarray as we may have cast it a few lines above + if columns is not None: + columns = ensure_index(columns) + arrays, columns, index = nested_data_to_arrays( + # error: Argument 3 to "nested_data_to_arrays" has incompatible + # type "Optional[Collection[Any]]"; expected "Optional[Index]" + data, + columns, + index, # type: ignore[arg-type] + dtype, + ) + mgr = arrays_to_mgr( + arrays, + columns, + index, + dtype=dtype, + ) + else: + mgr = ndarray_to_mgr( + data, + index, + columns, + dtype=dtype, + copy=copy, + ) + else: + mgr = dict_to_mgr( + {}, + index, + columns if columns is not None else default_index(0), + dtype=dtype, + ) + # For data is scalar + else: + if index is None or columns is None: + raise ValueError("DataFrame constructor not properly called!") + + index = ensure_index(index) + columns = ensure_index(columns) + + if not dtype: + dtype, _ = infer_dtype_from_scalar(data) + + # For data is a scalar extension dtype + if isinstance(dtype, ExtensionDtype): + # TODO(EA2D): special case not needed with 2D EAs + + values = [ + construct_1d_arraylike_from_scalar(data, len(index), dtype) + for _ in range(len(columns)) + ] + mgr = arrays_to_mgr(values, columns, index, dtype=None) + else: + arr2d = construct_2d_arraylike_from_scalar( + data, + len(index), + len(columns), + dtype, + copy, + ) + + mgr = ndarray_to_mgr( + arr2d, + index, + columns, + dtype=arr2d.dtype, + copy=False, + ) + + NDFrame.__init__(self, mgr) + + # ---------------------------------------------------------------------- + + def __dataframe__( + self, nan_as_null: bool = False, allow_copy: bool = True + ) -> DataFrameXchg: + """ + Return the dataframe interchange object implementing the interchange protocol. + + .. deprecated:: 3.0.0 + + The Dataframe Interchange Protocol is deprecated. + For dataframe-agnostic code, you may want to look into: + + - `Arrow PyCapsule Interface `_ + - `Narwhals `_ + + .. note:: + + For new development, we highly recommend using the Arrow C Data Interface + alongside the Arrow PyCapsule Interface instead of the interchange protocol + + .. warning:: + + Due to severe implementation issues, we recommend only considering using the + interchange protocol in the following cases: + + - converting to pandas: for pandas >= 2.0.3 + - converting from pandas: for pandas >= 3.0.0 + + Parameters + ---------- + nan_as_null : bool, default False + `nan_as_null` is DEPRECATED and has no effect. Please avoid using + it; it will be removed in a future release. + allow_copy : bool, default True + Whether to allow memory copying when exporting. If set to False + it would cause non-zero-copy exports to fail. + + Returns + ------- + DataFrame interchange object + The object which consuming library can use to ingress the dataframe. + + See Also + -------- + DataFrame.from_records : Constructor from tuples, also record arrays. + DataFrame.from_dict : From dicts of Series, arrays, or dicts. + + Notes + ----- + Details on the interchange protocol: + https://data-apis.org/dataframe-protocol/latest/index.html + + Examples + -------- + >>> df_not_necessarily_pandas = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + >>> interchange_object = df_not_necessarily_pandas.__dataframe__() + >>> interchange_object.column_names() + Index(['A', 'B'], dtype='str') + >>> df_pandas = pd.api.interchange.from_dataframe( + ... interchange_object.select_columns_by_name(["A"]) + ... ) + >>> df_pandas + A + 0 1 + 1 2 + + These methods (``column_names``, ``select_columns_by_name``) should work + for any dataframe library which implements the interchange protocol. + """ + warnings.warn( + "The Dataframe Interchange Protocol is deprecated.\n" + "For dataframe-agnostic code, you may want to look into:\n" + "- Arrow PyCapsule Interface: https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html\n" + "- Narwhals: https://github.com/narwhals-dev/narwhals\n", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + from pandas.core.interchange.dataframe import PandasDataFrameXchg + + return PandasDataFrameXchg(self, allow_copy=allow_copy) + + def __arrow_c_stream__(self, requested_schema=None): + """ + Export the pandas DataFrame as an Arrow C stream PyCapsule. + + This relies on pyarrow to convert the pandas DataFrame to the Arrow + format (and follows the default behaviour of ``pyarrow.Table.from_pandas`` + in its handling of the index, i.e. store the index as a column except + for RangeIndex). + This conversion is not necessarily zero-copy. + + Parameters + ---------- + requested_schema : PyCapsule, default None + The schema to which the dataframe should be casted, passed as a + PyCapsule containing a C ArrowSchema representation of the + requested schema. + + Returns + ------- + PyCapsule + """ + pa = import_optional_dependency("pyarrow", min_version="14.0.0") + if requested_schema is not None: + requested_schema = pa.Schema._import_from_c_capsule(requested_schema) + table = pa.Table.from_pandas(self, schema=requested_schema) + return table.__arrow_c_stream__() + + # ---------------------------------------------------------------------- + + @property + def axes(self) -> list[Index]: + """ + Return a list representing the axes of the DataFrame. + + It has the row axis labels and column axis labels as the only members. + They are returned in that order. + + See Also + -------- + DataFrame.index: The index (row labels) of the DataFrame. + DataFrame.columns: The column labels of the DataFrame. + + Examples + -------- + >>> df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + >>> df.axes + [RangeIndex(start=0, stop=2, step=1), Index(['col1', 'col2'], dtype='str')] + """ + return [self.index, self.columns] + + @property + def shape(self) -> tuple[int, int]: + """ + Return a tuple representing the dimensionality of the DataFrame. + + Unlike the `len()` method, which only returns the number of rows, `shape` + provides both row and column counts, making it a more informative method for + understanding dataset size. + + See Also + -------- + numpy.ndarray.shape : Tuple of array dimensions. + + Examples + -------- + >>> df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + >>> df.shape + (2, 2) + + >>> df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 6]}) + >>> df.shape + (2, 3) + """ + return len(self.index), len(self.columns) + + @property + def _is_homogeneous_type(self) -> bool: + """ + Whether all the columns in a DataFrame have the same type. + + Returns + ------- + bool + + Examples + -------- + >>> DataFrame({"A": [1, 2], "B": [3, 4]})._is_homogeneous_type + True + >>> DataFrame({"A": [1, 2], "B": [3.0, 4.0]})._is_homogeneous_type + False + + Items with the same type but different sizes are considered + different types. + + >>> DataFrame( + ... { + ... "A": np.array([1, 2], dtype=np.int32), + ... "B": np.array([1, 2], dtype=np.int64), + ... } + ... )._is_homogeneous_type + False + """ + # The "<" part of "<=" here is for empty DataFrame cases + return len({block.values.dtype for block in self._mgr.blocks}) <= 1 + + @property + def _can_fast_transpose(self) -> bool: + """ + Can we transpose this DataFrame without creating any new array objects. + """ + blocks = self._mgr.blocks + if len(blocks) != 1: + return False + + dtype = blocks[0].dtype + # TODO(EA2D) special case would be unnecessary with 2D EAs + return not is_1d_only_ea_dtype(dtype) + + @property + def _values(self) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray: + """ + Analogue to ._values that may return a 2D ExtensionArray. + """ + mgr = self._mgr + + blocks = mgr.blocks + if len(blocks) != 1: + return ensure_wrapped_if_datetimelike(self.values) + + arr = blocks[0].values + if arr.ndim == 1: + # non-2D ExtensionArray + return self.values + + # more generally, whatever we allow in NDArrayBackedExtensionBlock + arr = cast("np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray", arr) + return arr.T + + # ---------------------------------------------------------------------- + # Rendering Methods + + def _repr_fits_vertical_(self) -> bool: + """ + Check length against max_rows. + """ + max_rows = get_option("display.max_rows") + return len(self) <= max_rows + + def _repr_fits_horizontal_(self) -> bool: + """ + Check if full repr fits in horizontal boundaries imposed by the display + options width and max_columns. + """ + width, height = console.get_console_size() + max_columns = get_option("display.max_columns") + nb_columns = len(self.columns) + + # exceed max columns + if (max_columns and nb_columns > max_columns) or ( + width and nb_columns > (width // 2) + ): + return False + + # used by repr_html under IPython notebook or scripts ignore terminal + # dims + if width is None or not console.in_interactive_session(): + return True + + if get_option("display.width") is not None or console.in_ipython_frontend(): + # check at least the column row for excessive width + max_rows = 1 + else: + max_rows = get_option("display.max_rows") + + # when auto-detecting, so width=None and not in ipython front end + # check whether repr fits horizontal by actually checking + # the width of the rendered repr + buf = StringIO() + + # only care about the stuff we'll actually print out + # and to_string on entire frame may be expensive + d = self + + if max_rows is not None: # unlimited rows + # min of two, where one may be None + d = d.iloc[: min(max_rows, len(d))] + else: + return True + + d.to_string(buf=buf) + value = buf.getvalue() + repr_width = max(len(line) for line in value.split("\n")) + + return repr_width < width + + def _info_repr(self) -> bool: + """ + True if the repr should show the info view. + """ + info_repr_option = get_option("display.large_repr") == "info" + return info_repr_option and not ( + self._repr_fits_horizontal_() and self._repr_fits_vertical_() + ) + + def __repr__(self) -> str: + """ + Return a string representation for a particular DataFrame. + """ + if self._info_repr(): + buf = StringIO() + self.info(buf=buf) + return buf.getvalue() + + repr_params = fmt.get_dataframe_repr_params() + return self.to_string(**repr_params) + + def _repr_html_(self) -> str | None: + """ + Return a html representation for a particular DataFrame. + + Mainly for IPython notebook. + """ + if self._info_repr(): + buf = StringIO() + self.info(buf=buf) + # need to escape the , should be the first line. + val = buf.getvalue().replace("<", r"<", 1) + val = val.replace(">", r">", 1) + return f"
{val}
" + + if get_option("display.notebook_repr_html"): + max_rows = get_option("display.max_rows") + min_rows = get_option("display.min_rows") + max_cols = get_option("display.max_columns") + show_dimensions = get_option("display.show_dimensions") + show_floats = get_option("display.float_format") + + formatter = fmt.DataFrameFormatter( + self, + columns=None, + col_space=None, + na_rep="NaN", + formatters=None, + float_format=show_floats, + sparsify=None, + justify=None, + index_names=True, + header=True, + index=True, + bold_rows=True, + escape=True, + max_rows=max_rows, + min_rows=min_rows, + max_cols=max_cols, + show_dimensions=show_dimensions, + decimal=".", + ) + return fmt.DataFrameRenderer(formatter).to_html(notebook=True) + else: + return None + + @overload + def to_string( + self, + buf: None = ..., + *, + columns: Axes | None = ..., + col_space: int | list[int] | dict[Hashable, int] | None = ..., + header: bool | SequenceNotStr[str] = ..., + index: bool = ..., + na_rep: str = ..., + formatters: fmt.FormattersType | None = ..., + float_format: fmt.FloatFormatType | None = ..., + sparsify: bool | None = ..., + index_names: bool = ..., + justify: str | None = ..., + max_rows: int | None = ..., + max_cols: int | None = ..., + show_dimensions: bool = ..., + decimal: str = ..., + line_width: int | None = ..., + min_rows: int | None = ..., + max_colwidth: int | None = ..., + encoding: str | None = ..., + ) -> str: ... + + @overload + def to_string( + self, + buf: FilePath | WriteBuffer[str], + *, + columns: Axes | None = ..., + col_space: int | list[int] | dict[Hashable, int] | None = ..., + header: bool | SequenceNotStr[str] = ..., + index: bool = ..., + na_rep: str = ..., + formatters: fmt.FormattersType | None = ..., + float_format: fmt.FloatFormatType | None = ..., + sparsify: bool | None = ..., + index_names: bool = ..., + justify: str | None = ..., + max_rows: int | None = ..., + max_cols: int | None = ..., + show_dimensions: bool = ..., + decimal: str = ..., + line_width: int | None = ..., + min_rows: int | None = ..., + max_colwidth: int | None = ..., + encoding: str | None = ..., + ) -> None: ... + + @Substitution( + header_type="bool or list of str", + header="Write out the column names. If a list of columns " + "is given, it is assumed to be aliases for the " + "column names", + col_space_type="int, list or dict of int", + col_space="The minimum width of each column. If a list of ints is given " + "every integers corresponds with one column. If a dict is given, the key " + "references the column, while the value defines the space to use.", + ) + @Substitution(shared_params=fmt.common_docstring, returns=fmt.return_docstring) + def to_string( + self, + buf: FilePath | WriteBuffer[str] | None = None, + *, + columns: Axes | None = None, + col_space: int | list[int] | dict[Hashable, int] | None = None, + header: bool | SequenceNotStr[str] = True, + index: bool = True, + na_rep: str = "NaN", + formatters: fmt.FormattersType | None = None, + float_format: fmt.FloatFormatType | None = None, + sparsify: bool | None = None, + index_names: bool = True, + justify: str | None = None, + max_rows: int | None = None, + max_cols: int | None = None, + show_dimensions: bool = False, + decimal: str = ".", + line_width: int | None = None, + min_rows: int | None = None, + max_colwidth: int | None = None, + encoding: str | None = None, + ) -> str | None: + """ + Render a DataFrame to a console-friendly tabular output. + %(shared_params)s + line_width : int, optional + Width to wrap a line in characters. + min_rows : int, optional + The number of rows to display in the console in a truncated repr + (when number of rows is above `max_rows`). + max_colwidth : int, optional + Max width to truncate each column in characters. By default, no limit. + encoding : str, default "utf-8" + Set character encoding. + %(returns)s + See Also + -------- + to_html : Convert DataFrame to HTML. + + Examples + -------- + >>> d = {"col1": [1, 2, 3], "col2": [4, 5, 6]} + >>> df = pd.DataFrame(d) + >>> print(df.to_string()) + col1 col2 + 0 1 4 + 1 2 5 + 2 3 6 + """ + from pandas import option_context + + with option_context("display.max_colwidth", max_colwidth): + formatter = fmt.DataFrameFormatter( + self, + columns=columns, + col_space=col_space, + na_rep=na_rep, + formatters=formatters, + float_format=float_format, + sparsify=sparsify, + justify=justify, + index_names=index_names, + header=header, + index=index, + min_rows=min_rows, + max_rows=max_rows, + max_cols=max_cols, + show_dimensions=show_dimensions, + decimal=decimal, + ) + return fmt.DataFrameRenderer(formatter).to_string( + buf=buf, + encoding=encoding, + line_width=line_width, + ) + + def _get_values_for_csv( + self, + *, + float_format: FloatFormatType | None, + date_format: str | None, + decimal: str, + na_rep: str, + quoting, # int csv.QUOTE_FOO from stdlib + ) -> DataFrame: + # helper used by to_csv + mgr = self._mgr.get_values_for_csv( + float_format=float_format, + date_format=date_format, + decimal=decimal, + na_rep=na_rep, + quoting=quoting, + ) + return self._constructor_from_mgr(mgr, axes=mgr.axes) + + # ---------------------------------------------------------------------- + + @property + def style(self) -> Styler: + """ + Returns a Styler object. + + Contains methods for building a styled HTML representation of the DataFrame. + + See Also + -------- + io.formats.style.Styler : Helps style a DataFrame or Series according to the + data with HTML and CSS. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2, 3]}) + >>> df.style # doctest: +SKIP + + Please see + `Table Visualization <../../user_guide/style.ipynb>`_ for more examples. + """ + # Raise AttributeError so that inspect works even if jinja2 is not installed. + has_jinja2 = import_optional_dependency("jinja2", errors="ignore") + if not has_jinja2: + raise AttributeError("The '.style' accessor requires jinja2") + + from pandas.io.formats.style import Styler + + return Styler(self) + + _shared_docs["items"] = r""" + Iterate over (column name, Series) pairs. + + Iterates over the DataFrame columns, returning a tuple with + the column name and the content as a Series. + + Yields + ------ + label : object + The column names for the DataFrame being iterated over. + content : Series + The column entries belonging to each label, as a Series. + + See Also + -------- + DataFrame.iterrows : Iterate over DataFrame rows as + (index, Series) pairs. + DataFrame.itertuples : Iterate over DataFrame rows as namedtuples + of the values. + + Examples + -------- + >>> df = pd.DataFrame({'species': ['bear', 'bear', 'marsupial'], + ... 'population': [1864, 22000, 80000]}, + ... index=['panda', 'polar', 'koala']) + >>> df + species population + panda bear 1864 + polar bear 22000 + koala marsupial 80000 + >>> for label, content in df.items(): + ... print(f'label: {label}') + ... print(f'content: {content}', sep='\n') + ... + label: species + content: + panda bear + polar bear + koala marsupial + Name: species, dtype: str + label: population + content: + panda 1864 + polar 22000 + koala 80000 + Name: population, dtype: int64 + """ + + def items(self) -> Iterable[tuple[Hashable, Series]]: + r""" + Iterate over (column name, Series) pairs. + + Iterates over the DataFrame columns, returning a tuple with + the column name and the content as a Series. + + Yields + ------ + label : object + The column names for the DataFrame being iterated over. + content : Series + The column entries belonging to each label, as a Series. + + See Also + -------- + DataFrame.iterrows : Iterate over DataFrame rows as + (index, Series) pairs. + DataFrame.itertuples : Iterate over DataFrame rows as namedtuples + of the values. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "species": ["bear", "bear", "marsupial"], + ... "population": [1864, 22000, 80000], + ... }, + ... index=["panda", "polar", "koala"], + ... ) + >>> df + species population + panda bear 1864 + polar bear 22000 + koala marsupial 80000 + >>> for label, content in df.items(): + ... print(f"label: {label}") + ... print(f"content: {content}", sep="\n") + label: species + content: + panda bear + polar bear + koala marsupial + Name: species, dtype: str + label: population + content: + panda 1864 + polar 22000 + koala 80000 + Name: population, dtype: int64 + """ + for i, k in enumerate(self.columns): + yield k, self._ixs(i, axis=1) + + def iterrows(self) -> Iterable[tuple[Hashable, Series]]: + """ + Iterate over DataFrame rows as (index, Series) pairs. + + Yields + ------ + index : label or tuple of label + The index of the row. A tuple for a `MultiIndex`. + data : Series + The data of the row as a Series. + + See Also + -------- + DataFrame.itertuples : Iterate over DataFrame rows as namedtuples of the values. + DataFrame.items : Iterate over (column name, Series) pairs. + + Notes + ----- + 1. Because ``iterrows`` returns a Series for each row, + it does **not** preserve dtypes across the rows (dtypes are + preserved across columns for DataFrames). + + To preserve dtypes while iterating over the rows, it is better + to use :meth:`itertuples` which returns namedtuples of the values + and which is generally faster than ``iterrows``. + + 2. You should **never modify** something you are iterating over. + This is not guaranteed to work in all cases. Depending on the + data types, the iterator returns a copy and not a view, and writing + to it will have no effect. + + Examples + -------- + + >>> df = pd.DataFrame([[1, 1.5]], columns=["int", "float"]) + >>> row = next(df.iterrows())[1] + >>> row + int 1.0 + float 1.5 + Name: 0, dtype: float64 + >>> print(row["int"].dtype) + float64 + >>> print(df["int"].dtype) + int64 + """ + columns = self.columns + klass = self._constructor_sliced + for k, v in zip(self.index, self.values, strict=True): + s = klass(v, index=columns, name=k).__finalize__(self) + if self._mgr.is_single_block: + s._mgr.add_references(self._mgr) + yield k, s + + def itertuples( + self, index: bool = True, name: str | None = "Pandas" + ) -> Iterable[tuple[Any, ...]]: + """ + Iterate over DataFrame rows as namedtuples. + + Parameters + ---------- + index : bool, default True + If True, return the index as the first element of the tuple. + name : str or None, default "Pandas" + The name of the returned namedtuples or None to return regular + tuples. + + Returns + ------- + iterator + An object to iterate over namedtuples for each row in the + DataFrame with the first field possibly being the index and + following fields being the column values. + + See Also + -------- + DataFrame.iterrows : Iterate over DataFrame rows as (index, Series) + pairs. + DataFrame.items : Iterate over (column name, Series) pairs. + + Notes + ----- + The column names will be renamed to positional names if they are + invalid Python identifiers, repeated, or start with an underscore. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"num_legs": [4, 2], "num_wings": [0, 2]}, index=["dog", "hawk"] + ... ) + >>> df + num_legs num_wings + dog 4 0 + hawk 2 2 + >>> for row in df.itertuples(): + ... print(row) + Pandas(Index='dog', num_legs=4, num_wings=0) + Pandas(Index='hawk', num_legs=2, num_wings=2) + + By setting the `index` parameter to False we can remove the index + as the first element of the tuple: + + >>> for row in df.itertuples(index=False): + ... print(row) + Pandas(num_legs=4, num_wings=0) + Pandas(num_legs=2, num_wings=2) + + With the `name` parameter set we set a custom name for the yielded + namedtuples: + + >>> for row in df.itertuples(name="Animal"): + ... print(row) + Animal(Index='dog', num_legs=4, num_wings=0) + Animal(Index='hawk', num_legs=2, num_wings=2) + """ + arrays = [] + fields = list(self.columns) + if index: + arrays.append(self.index) + fields.insert(0, "Index") + + # use integer indexing because of possible duplicate column names + arrays.extend(self.iloc[:, k] for k in range(len(self.columns))) + + if name is not None: + # https://github.com/python/mypy/issues/9046 + # error: namedtuple() expects a string literal as the first argument + itertuple = collections.namedtuple( # type: ignore[misc] + name, fields, rename=True + ) + return map(itertuple._make, zip(*arrays, strict=True)) + + # fallback to regular tuples + return zip(*arrays, strict=True) + + def __len__(self) -> int: + """ + Returns length of info axis, but here we use the index. + """ + return len(self.index) + + @overload + def dot(self, other: Series) -> Series: ... + + @overload + def dot(self, other: DataFrame | Index | ArrayLike) -> DataFrame: ... + + def dot(self, other: AnyArrayLike | DataFrame) -> DataFrame | Series: + """ + Compute the matrix multiplication between the DataFrame and other. + + This method computes the matrix product between the DataFrame and the + values of an other Series, DataFrame or a numpy array. + + It can also be called using ``self @ other``. + + Parameters + ---------- + other : Series, DataFrame or array-like + The other object to compute the matrix product with. + + Returns + ------- + Series or DataFrame + If other is a Series, return the matrix product between self and + other as a Series. If other is a DataFrame or a numpy.array, return + the matrix product of self and other in a DataFrame of a np.array. + + See Also + -------- + Series.dot: Similar method for Series. + + Notes + ----- + The dimensions of DataFrame and other must be compatible in order to + compute the matrix multiplication. In addition, the column names of + DataFrame and the index of other must contain the same values, as they + will be aligned prior to the multiplication. + + The dot method for Series computes the inner product, instead of the + matrix product here. + + Examples + -------- + Here we multiply a DataFrame with a Series. + + >>> df = pd.DataFrame([[0, 1, -2, -1], [1, 1, 1, 1]]) + >>> s = pd.Series([1, 1, 2, 1]) + >>> df.dot(s) + 0 -4 + 1 5 + dtype: int64 + + Here we multiply a DataFrame with another DataFrame. + + >>> other = pd.DataFrame([[0, 1], [1, 2], [-1, -1], [2, 0]]) + >>> df.dot(other) + 0 1 + 0 1 4 + 1 2 2 + + Note that the dot method give the same result as @ + + >>> df @ other + 0 1 + 0 1 4 + 1 2 2 + + The dot method works also if other is an np.array. + + >>> arr = np.array([[0, 1], [1, 2], [-1, -1], [2, 0]]) + >>> df.dot(arr) + 0 1 + 0 1 4 + 1 2 2 + + Note how shuffling of the objects does not change the result. + + >>> s2 = s.reindex([1, 0, 2, 3]) + >>> df.dot(s2) + 0 -4 + 1 5 + dtype: int64 + """ + if isinstance(other, (Series, DataFrame)): + common = self.columns.union(other.index) + if len(common) > len(self.columns) or len(common) > len(other.index): + raise ValueError("matrices are not aligned") + + left = self.reindex(columns=common) + right = other.reindex(index=common) + lvals = left.values + rvals = right._values + else: + left = self + lvals = self.values + rvals = np.asarray(other) + if lvals.shape[1] != rvals.shape[0]: + raise ValueError( + f"Dot product shape mismatch, {lvals.shape} vs {rvals.shape}" + ) + + if isinstance(other, DataFrame): + common_type = find_common_type(list(self.dtypes) + list(other.dtypes)) + return self._constructor( + np.dot(lvals, rvals), + index=left.index, + columns=other.columns, + copy=False, + dtype=common_type, + ) + elif isinstance(other, Series): + common_type = find_common_type([*list(self.dtypes), other.dtypes]) + return self._constructor_sliced( + np.dot(lvals, rvals), index=left.index, copy=False, dtype=common_type + ) + elif isinstance(rvals, (np.ndarray, Index)): + result = np.dot(lvals, rvals) + if result.ndim == 2: + return self._constructor(result, index=left.index, copy=False) + else: + return self._constructor_sliced(result, index=left.index, copy=False) + else: # pragma: no cover + raise TypeError(f"unsupported type: {type(other)}") + + @overload + def __matmul__(self, other: Series) -> Series: ... + + @overload + def __matmul__(self, other: AnyArrayLike | DataFrame) -> DataFrame | Series: ... + + def __matmul__(self, other: AnyArrayLike | DataFrame) -> DataFrame | Series: + """ + Matrix multiplication using binary `@` operator. + """ + return self.dot(other) + + def __rmatmul__(self, other) -> DataFrame: + """ + Matrix multiplication using binary `@` operator. + """ + try: + return self.T.dot(np.transpose(other)).T + except ValueError as err: + if "shape mismatch" not in str(err): + raise + # GH#21581 give exception message for original shapes + msg = f"shapes {np.shape(other)} and {self.shape} not aligned" + raise ValueError(msg) from err + + # ---------------------------------------------------------------------- + # IO methods (to / from other formats) + + @classmethod + def from_arrow( + cls, data: ArrowArrayExportable | ArrowStreamExportable + ) -> DataFrame: + """ + Construct a DataFrame from a tabular Arrow object. + + This function accepts any Arrow-compatible tabular object implementing + the `Arrow PyCapsule Protocol`_ (i.e. having an ``__arrow_c_array__`` + or ``__arrow_c_stream__`` method). + + This function currently relies on ``pyarrow`` to convert the tabular + object in Arrow format to pandas. + + .. _Arrow PyCapsule Protocol: https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + + .. versionadded:: 3.0 + + Parameters + ---------- + data : pyarrow.Table or Arrow-compatible table + Any tabular object implementing the Arrow PyCapsule Protocol + (i.e. has an ``__arrow_c_array__`` or ``__arrow_c_stream__`` + method). + + Returns + ------- + DataFrame + + See Also + -------- + Series.from_arrow : Construct a Series from an Arrow object. + + Examples + -------- + >>> import pyarrow as pa + >>> table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + >>> pd.DataFrame.from_arrow(table) + a b + 0 1 x + 1 2 y + 2 3 z + """ + pa = import_optional_dependency("pyarrow", min_version="14.0.0") + if not isinstance(data, pa.Table): + if not ( + hasattr(data, "__arrow_c_array__") + or hasattr(data, "__arrow_c_stream__") + ): + # explicitly test this, because otherwise we would accept variour other + # input types through the pa.table(..) call + raise TypeError( + "Expected an Arrow-compatible tabular object (i.e. having an " + "'_arrow_c_array__' or '__arrow_c_stream__' method), got " + f"'{type(data).__name__}' instead." + ) + pa_table = pa.table(data) + else: + pa_table = data + + df = pa_table.to_pandas() + return df + + @classmethod + def from_dict( + cls, + data: dict, + orient: FromDictOrient = "columns", + dtype: Dtype | None = None, + columns: Axes | None = None, + ) -> DataFrame: + """ + Construct DataFrame from dict of array-like or dicts. + + Creates DataFrame object from dictionary by columns or by index + allowing dtype specification. + + Parameters + ---------- + data : dict + Of the form {field : array-like} or {field : dict}. + orient : {'columns', 'index', 'tight'}, default 'columns' + The "orientation" of the data. If the keys of the passed dict + should be the columns of the resulting DataFrame, pass 'columns' + (default). Otherwise if the keys should be rows, pass 'index'. + If 'tight', assume a dict with keys ['index', 'columns', 'data', + 'index_names', 'column_names']. + + dtype : dtype, default None + Data type to force after DataFrame construction, otherwise infer. + columns : list, default None + Column labels to use when ``orient='index'``. Raises a ValueError + if used with ``orient='columns'`` or ``orient='tight'``. + + Returns + ------- + DataFrame + + See Also + -------- + DataFrame.from_records : DataFrame from structured ndarray, sequence + of tuples or dicts, or DataFrame. + DataFrame : DataFrame object creation using constructor. + DataFrame.to_dict : Convert the DataFrame to a dictionary. + + Examples + -------- + By default the keys of the dict become the DataFrame columns: + + >>> data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]} + >>> pd.DataFrame.from_dict(data) + col_1 col_2 + 0 3 a + 1 2 b + 2 1 c + 3 0 d + + Specify ``orient='index'`` to create the DataFrame using dictionary + keys as rows: + + >>> data = {"row_1": [3, 2, 1, 0], "row_2": ["a", "b", "c", "d"]} + >>> pd.DataFrame.from_dict(data, orient="index") + 0 1 2 3 + row_1 3 2 1 0 + row_2 a b c d + + When using the 'index' orientation, the column names can be + specified manually: + + >>> pd.DataFrame.from_dict(data, orient="index", columns=["A", "B", "C", "D"]) + A B C D + row_1 3 2 1 0 + row_2 a b c d + + Specify ``orient='tight'`` to create the DataFrame using a 'tight' + format: + + >>> data = { + ... "index": [("a", "b"), ("a", "c")], + ... "columns": [("x", 1), ("y", 2)], + ... "data": [[1, 3], [2, 4]], + ... "index_names": ["n1", "n2"], + ... "column_names": ["z1", "z2"], + ... } + >>> pd.DataFrame.from_dict(data, orient="tight") + z1 x y + z2 1 2 + n1 n2 + a b 1 3 + c 2 4 + """ + index: list | Index | None = None + orient = orient.lower() # type: ignore[assignment] + if orient == "index": + if len(data) > 0: + # TODO speed up Series case + if isinstance(next(iter(data.values())), (Series, dict)): + data = _from_nested_dict(data) + else: + index = list(data.keys()) + # error: Incompatible types in assignment (expression has type + # "List[Any]", variable has type "Dict[Any, Any]") + data = list(data.values()) # type: ignore[assignment] + elif orient in ("columns", "tight"): + if columns is not None: + raise ValueError(f"cannot use columns parameter with orient='{orient}'") + else: # pragma: no cover + raise ValueError( + f"Expected 'index', 'columns' or 'tight' for orient parameter. " + f"Got '{orient}' instead" + ) + + if orient != "tight": + return cls(data, index=index, columns=columns, dtype=dtype) + else: + realdata = data["data"] + + def create_index(indexlist, namelist) -> Index: + index: Index + if len(namelist) > 1: + index = MultiIndex.from_tuples(indexlist, names=namelist) + else: + index = Index(indexlist, name=namelist[0]) + return index + + index = create_index(data["index"], data["index_names"]) + columns = create_index(data["columns"], data["column_names"]) + return cls(realdata, index=index, columns=columns, dtype=dtype) + + def to_numpy( + self, + dtype: npt.DTypeLike | None = None, + copy: bool = False, + na_value: object = lib.no_default, + ) -> np.ndarray: + """ + Convert the DataFrame to a NumPy array. + + By default, the dtype of the returned array will be the common NumPy + dtype of all types in the DataFrame. For example, if the dtypes are + ``float16`` and ``float32``, the results dtype will be ``float32``. + This may require copying data and coercing values, which may be + expensive. + + Parameters + ---------- + dtype : str or numpy.dtype, optional + The dtype to pass to :meth:`numpy.asarray`. + copy : bool, default False + Whether to ensure that the returned value is not a view on + another array. Note that ``copy=False`` does not *ensure* that + ``to_numpy()`` is no-copy. Rather, ``copy=True`` ensure that + a copy is made, even if not strictly necessary. + na_value : Any, optional + The value to use for missing values. The default value depends + on `dtype` and the dtypes of the DataFrame columns. + + Returns + ------- + numpy.ndarray + The NumPy array representing the values in the DataFrame. + + See Also + -------- + Series.to_numpy : Similar method for Series. + + Examples + -------- + >>> pd.DataFrame({"A": [1, 2], "B": [3, 4]}).to_numpy() + array([[1, 3], + [2, 4]]) + + With heterogeneous data, the lowest common type will have to + be used. + + >>> df = pd.DataFrame({"A": [1, 2], "B": [3.0, 4.5]}) + >>> df.to_numpy() + array([[1. , 3. ], + [2. , 4.5]]) + + For a mix of numeric and non-numeric types, the output array will + have object dtype. + + >>> df["C"] = pd.date_range("2000", periods=2) + >>> df.to_numpy() + array([[1, 3.0, Timestamp('2000-01-01 00:00:00')], + [2, 4.5, Timestamp('2000-01-02 00:00:00')]], dtype=object) + """ + if dtype is not None: + dtype = np.dtype(dtype) + result = self._mgr.as_array(dtype=dtype, copy=copy, na_value=na_value) + if result.dtype is not dtype: + result = np.asarray(result, dtype=dtype) + + return result + + @overload + def to_dict( + self, + orient: Literal["dict", "list", "series", "split", "tight", "index"] = ..., + *, + into: type[MutableMappingT] | MutableMappingT, + index: bool = ..., + ) -> MutableMappingT: ... + + @overload + def to_dict( + self, + orient: Literal["records"], + *, + into: type[MutableMappingT] | MutableMappingT, + index: bool = ..., + ) -> list[MutableMappingT]: ... + + @overload + def to_dict( + self, + orient: Literal["dict", "list", "series", "split", "tight", "index"] = ..., + *, + into: type[dict] = ..., + index: bool = ..., + ) -> dict: ... + + @overload + def to_dict( + self, + orient: Literal["records"], + *, + into: type[dict] = ..., + index: bool = ..., + ) -> list[dict]: ... + + # error: Incompatible default for argument "into" (default has type "type + # [dict[Any, Any]]", argument has type "type[MutableMappingT] | MutableMappingT") + def to_dict( + self, + orient: Literal[ + "dict", "list", "series", "split", "tight", "records", "index" + ] = "dict", + *, + into: type[MutableMappingT] | MutableMappingT = dict, # type: ignore[assignment] + index: bool = True, + ) -> MutableMappingT | list[MutableMappingT]: + """ + Convert the DataFrame to a dictionary. + + The type of the key-value pairs can be customized with the parameters + (see below). + + Parameters + ---------- + orient : str {'dict', 'list', 'series', 'split', 'tight', 'records', 'index'} + Determines the type of the values of the dictionary. + + - 'dict' (default) : dict like {column -> {index -> value}} + - 'list' : dict like {column -> [values]} + - 'series' : dict like {column -> Series(values)} + - 'split' : dict like + {'index' -> [index], 'columns' -> [columns], 'data' -> [values]} + - 'tight' : dict like + {'index' -> [index], 'columns' -> [columns], 'data' -> [values], + 'index_names' -> [index.names], 'column_names' -> [column.names]} + - 'records' : list like + [{column -> value}, ... , {column -> value}] + - 'index' : dict like {index -> {column -> value}} + + into : class, default dict + The collections.abc.MutableMapping subclass used for all Mappings + in the return value. Can be the actual class or an empty + instance of the mapping type you want. If you want a + collections.defaultdict, you must pass it initialized. + + index : bool, default True + Whether to include the index item (and index_names item if `orient` + is 'tight') in the returned dictionary. Can only be ``False`` + when `orient` is 'split' or 'tight'. Note that when `orient` is + 'records', this parameter does not take effect (index item always + not included). + + .. versionadded:: 2.0.0 + + Returns + ------- + dict, list or collections.abc.MutableMapping + Return a collections.abc.MutableMapping object representing the + DataFrame. The resulting transformation depends on the `orient` + parameter. + + See Also + -------- + DataFrame.from_dict: Create a DataFrame from a dictionary. + DataFrame.to_json: Convert a DataFrame to JSON format. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"col1": [1, 2], "col2": [0.5, 0.75]}, index=["row1", "row2"] + ... ) + >>> df + col1 col2 + row1 1 0.50 + row2 2 0.75 + >>> df.to_dict() + {'col1': {'row1': 1, 'row2': 2}, 'col2': {'row1': 0.5, 'row2': 0.75}} + + You can specify the return orientation. + + >>> df.to_dict("series") + {'col1': row1 1 + row2 2 + Name: col1, dtype: int64, + 'col2': row1 0.50 + row2 0.75 + Name: col2, dtype: float64} + + >>> df.to_dict("split") + {'index': ['row1', 'row2'], 'columns': ['col1', 'col2'], + 'data': [[1, 0.5], [2, 0.75]]} + + >>> df.to_dict("records") + [{'col1': 1, 'col2': 0.5}, {'col1': 2, 'col2': 0.75}] + + >>> df.to_dict("index") + {'row1': {'col1': 1, 'col2': 0.5}, 'row2': {'col1': 2, 'col2': 0.75}} + + >>> df.to_dict("tight") + {'index': ['row1', 'row2'], 'columns': ['col1', 'col2'], + 'data': [[1, 0.5], [2, 0.75]], 'index_names': [None], 'column_names': [None]} + + You can also specify the mapping type. + + >>> from collections import OrderedDict, defaultdict + >>> df.to_dict(into=OrderedDict) + OrderedDict([('col1', OrderedDict([('row1', 1), ('row2', 2)])), + ('col2', OrderedDict([('row1', 0.5), ('row2', 0.75)]))]) + + If you want a `defaultdict`, you need to initialize it: + + >>> dd = defaultdict(list) + >>> df.to_dict("records", into=dd) + [defaultdict(, {'col1': 1, 'col2': 0.5}), + defaultdict(, {'col1': 2, 'col2': 0.75})] + """ + from pandas.core.methods.to_dict import to_dict + + return to_dict(self, orient, into=into, index=index) + + @classmethod + def from_records( + cls, + data, + index=None, + exclude=None, + columns=None, + coerce_float: bool = False, + nrows: int | None = None, + ) -> DataFrame: + """ + Convert structured or record ndarray to DataFrame. + + Creates a DataFrame object from a structured ndarray, or iterable of + tuples or dicts. + + Parameters + ---------- + data : structured ndarray, iterable of tuples or dicts + Structured input data. + index : str, list of fields, array-like + Field of array to use as the index, alternately a specific set of + input labels to use. + exclude : sequence, default None + Columns or fields to exclude. + columns : sequence, default None + Column names to use. If the passed data do not have names + associated with them, this argument provides names for the + columns. Otherwise, this argument indicates the order of the columns + in the result (any names not found in the data will become all-NA + columns) and limits the data to these columns if not all column names + are provided. + coerce_float : bool, default False + Attempt to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point, useful for SQL result sets. + nrows : int, default None + Number of rows to read if data is an iterator. + + Returns + ------- + DataFrame + + See Also + -------- + DataFrame.from_dict : DataFrame from dict of array-like or dicts. + DataFrame : DataFrame object creation using constructor. + + Examples + -------- + Data can be provided as a structured ndarray: + + >>> data = np.array( + ... [(3, "a"), (2, "b"), (1, "c"), (0, "d")], + ... dtype=[("col_1", "i4"), ("col_2", "U1")], + ... ) + >>> pd.DataFrame.from_records(data) + col_1 col_2 + 0 3 a + 1 2 b + 2 1 c + 3 0 d + + Data can be provided as a list of dicts: + + >>> data = [ + ... {"col_1": 3, "col_2": "a"}, + ... {"col_1": 2, "col_2": "b"}, + ... {"col_1": 1, "col_2": "c"}, + ... {"col_1": 0, "col_2": "d"}, + ... ] + >>> pd.DataFrame.from_records(data) + col_1 col_2 + 0 3 a + 1 2 b + 2 1 c + 3 0 d + + Data can be provided as a list of tuples with corresponding columns: + + >>> data = [(3, "a"), (2, "b"), (1, "c"), (0, "d")] + >>> pd.DataFrame.from_records(data, columns=["col_1", "col_2"]) + col_1 col_2 + 0 3 a + 1 2 b + 2 1 c + 3 0 d + """ + if isinstance(data, DataFrame): + raise TypeError( + "Passing a DataFrame to DataFrame.from_records is not supported. Use " + "set_index and/or drop to modify the DataFrame instead.", + ) + + result_index = None + + # Make a copy of the input columns so we can modify it + if columns is not None: + columns = ensure_index(columns) + + def maybe_reorder( + arrays: list[ArrayLike], arr_columns: Index, columns: Index, index + ) -> tuple[list[ArrayLike], Index, Index | None]: + """ + If our desired 'columns' do not match the data's pre-existing 'arr_columns', + we re-order our arrays. This is like a preemptive (cheap) reindex. + """ + if len(arrays): + length = len(arrays[0]) + else: + length = 0 + + result_index = None + if len(arrays) == 0 and index is None and length == 0: + result_index = default_index(0) + + arrays, arr_columns = reorder_arrays(arrays, arr_columns, columns, length) + return arrays, arr_columns, result_index + + if is_iterator(data): + if nrows == 0: + return cls(index=index, columns=columns) + + try: + first_row = next(data) + except StopIteration: + return cls(index=index, columns=columns) + + dtype = None + if hasattr(first_row, "dtype") and first_row.dtype.names: + dtype = first_row.dtype + + values = [first_row] + + if nrows is None: + values += data + else: + values.extend(itertools.islice(data, nrows - 1)) + + if dtype is not None: + data = np.array(values, dtype=dtype) + else: + data = values + + if isinstance(data, dict): + if columns is None: + columns = arr_columns = ensure_index(sorted(data)) + arrays = [data[k] for k in columns] + else: + arrays = [] + arr_columns_list = [] + for k, v in data.items(): + if k in columns: + arr_columns_list.append(k) + arrays.append(v) + + arr_columns = Index(arr_columns_list) + arrays, arr_columns, result_index = maybe_reorder( + arrays, arr_columns, columns, index + ) + + elif isinstance(data, np.ndarray): + arrays, columns = to_arrays(data, columns) + arr_columns = columns + else: + arrays, arr_columns = to_arrays(data, columns) + if coerce_float: + for i, arr in enumerate(arrays): + if arr.dtype == object: + # error: Argument 1 to "maybe_convert_objects" has + # incompatible type "Union[ExtensionArray, ndarray]"; + # expected "ndarray" + arrays[i] = lib.maybe_convert_objects( + arr, # type: ignore[arg-type] + try_float=True, + ) + + arr_columns = ensure_index(arr_columns) + if columns is None: + columns = arr_columns + else: + arrays, arr_columns, result_index = maybe_reorder( + arrays, arr_columns, columns, index + ) + + if exclude is None: + exclude = set() + else: + exclude = set(exclude) + + if index is not None: + if isinstance(index, str) or not hasattr(index, "__iter__"): + i = columns.get_loc(index) + exclude.add(index) + if len(arrays) > 0: + result_index = Index(arrays[i], name=index) + else: + result_index = Index([], name=index) + else: + try: + index_data = [arrays[arr_columns.get_loc(field)] for field in index] + except (KeyError, TypeError): + # raised by get_loc, see GH#29258 + result_index = index + else: + result_index = ensure_index_from_sequences(index_data, names=index) + exclude.update(index) + + if any(exclude): + arr_exclude = (x for x in exclude if x in arr_columns) + to_remove = {arr_columns.get_loc(col) for col in arr_exclude} # pyright: ignore[reportUnhashable] + arrays = [v for i, v in enumerate(arrays) if i not in to_remove] + + columns = columns.drop(exclude) + + mgr = arrays_to_mgr(arrays, columns, result_index) + df = DataFrame._from_mgr(mgr, axes=mgr.axes) + if cls is not DataFrame: + return cls(df, copy=False) + return df + + def to_records( + self, index: bool = True, column_dtypes=None, index_dtypes=None + ) -> np.rec.recarray: + """ + Convert DataFrame to a NumPy record array. + + Index will be included as the first field of the record array if + requested. + + Parameters + ---------- + index : bool, default True + Include index in resulting record array, stored in 'index' + field or using the index label, if set. + column_dtypes : str, type, dict, default None + If a string or type, the data type to store all columns. If + a dictionary, a mapping of column names and indices (zero-indexed) + to specific data types. + index_dtypes : str, type, dict, default None + If a string or type, the data type to store all index levels. If + a dictionary, a mapping of index level names and indices + (zero-indexed) to specific data types. + + This mapping is applied only if `index=True`. + + Returns + ------- + numpy.rec.recarray + NumPy ndarray with the DataFrame labels as fields and each row + of the DataFrame as entries. + + See Also + -------- + DataFrame.from_records: Convert structured or record ndarray + to DataFrame. + numpy.rec.recarray: An ndarray that allows field access using + attributes, analogous to typed columns in a + spreadsheet. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2], "B": [0.5, 0.75]}, index=["a", "b"]) + >>> df + A B + a 1 0.50 + b 2 0.75 + >>> df.to_records() + rec.array([('a', 1, 0.5 ), ('b', 2, 0.75)], + dtype=[('index', 'O'), ('A', '>> df.index = df.index.rename("I") + >>> df.to_records() + rec.array([('a', 1, 0.5 ), ('b', 2, 0.75)], + dtype=[('I', 'O'), ('A', '>> df.to_records(index=False) + rec.array([(1, 0.5 ), (2, 0.75)], + dtype=[('A', '>> df.to_records(column_dtypes={"A": "int32"}) + rec.array([('a', 1, 0.5 ), ('b', 2, 0.75)], + dtype=[('I', 'O'), ('A', '>> df.to_records(index_dtypes=">> index_dtypes = f">> df.to_records(index_dtypes=index_dtypes) + rec.array([(b'a', 1, 0.5 ), (b'b', 2, 0.75)], + dtype=[('I', 'S1'), ('A', ' Self: + """ + Create DataFrame from a list of arrays corresponding to the columns. + + Parameters + ---------- + arrays : list-like of arrays + Each array in the list corresponds to one column, in order. + columns : list-like, Index + The column names for the resulting DataFrame. + index : list-like, Index + The rows labels for the resulting DataFrame. + dtype : dtype, optional + Optional dtype to enforce for all arrays. + verify_integrity : bool, default True + Validate and homogenize all input. If set to False, it is assumed + that all elements of `arrays` are actual arrays how they will be + stored in a block (numpy ndarray or ExtensionArray), have the same + length as and are aligned with the index, and that `columns` and + `index` are ensured to be an Index object. + + Returns + ------- + DataFrame + """ + if dtype is not None: + dtype = pandas_dtype(dtype) + + columns = ensure_index(columns) + if len(columns) != len(arrays): + raise ValueError("len(columns) must match len(arrays)") + mgr = arrays_to_mgr( + arrays, + columns, + index, + dtype=dtype, + verify_integrity=verify_integrity, + ) + return cls._from_mgr(mgr, axes=mgr.axes) + + def to_stata( + self, + path: FilePath | WriteBuffer[bytes], + *, + convert_dates: dict[Hashable, str] | None = None, + write_index: bool = True, + byteorder: ToStataByteorder | None = None, + time_stamp: datetime.datetime | None = None, + data_label: str | None = None, + variable_labels: dict[Hashable, str] | None = None, + version: int | None = 114, + convert_strl: Sequence[Hashable] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, + value_labels: dict[Hashable, dict[float, str]] | None = None, + ) -> None: + """ + Export DataFrame object to Stata dta format. + + Writes the DataFrame to a Stata dataset file. + "dta" files contain a Stata dataset. + + Parameters + ---------- + path : str, path object, or buffer + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``write()`` function. + + convert_dates : dict + Dictionary mapping columns containing datetime types to stata + internal format to use when writing the dates. Options are 'tc', + 'td', 'tm', 'tw', 'th', 'tq', 'ty'. Column can be either an integer + or a name. Datetime columns that do not have a conversion type + specified will be converted to 'tc'. Raises NotImplementedError if + a datetime column has timezone information. + write_index : bool + Write the index to Stata dataset. + byteorder : str + Can be ">", "<", "little", or "big". default is `sys.byteorder`. + time_stamp : datetime + A datetime to use as file creation date. Default is the current + time. + data_label : str, optional + A label for the data set. Must be 80 characters or smaller. + variable_labels : dict + Dictionary containing columns as keys and variable labels as + values. Each label must be 80 characters or smaller. + version : {{114, 117, 118, 119, None}}, default 114 + Version to use in the output dta file. Set to None to let pandas + decide between 118 or 119 formats depending on the number of + columns in the frame. Version 114 can be read by Stata 10 and + later. Version 117 can be read by Stata 13 or later. Version 118 + is supported in Stata 14 and later. Version 119 is supported in + Stata 15 and later. Version 114 limits string variables to 244 + characters or fewer while versions 117 and later allow strings + with lengths up to 2,000,000 characters. Versions 118 and 119 + support Unicode characters, and version 119 supports more than + 32,767 variables. + + Version 119 should usually only be used when the number of + variables exceeds the capacity of dta format 118. Exporting + smaller datasets in format 119 may have unintended consequences, + and, as of November 2020, Stata SE cannot read version 119 files. + + convert_strl : list, optional + List of column names to convert to string columns to Stata StrL + format. Only available if version is 117. Storing strings in the + StrL format can produce smaller dta files if strings have more than + 8 characters and values are repeated. + + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and 'path' is + path-like, then detect compression from the following extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set to one of + {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} and + other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and + to create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + value_labels : dict of dicts + Dictionary containing columns as keys and dictionaries of column value + to labels as values. Labels for a single variable must be 32,000 + characters or smaller. + + Raises + ------ + NotImplementedError + * If datetimes contain timezone information + * Column dtype is not representable in Stata + ValueError + * Columns listed in convert_dates are neither datetime64[ns] + or datetime.datetime + * Column listed in convert_dates is not in DataFrame + * Categorical label contains more than 32,000 characters + + See Also + -------- + read_stata : Import Stata data files. + io.stata.StataWriter : Low-level writer for Stata data files. + io.stata.StataWriter117 : Low-level writer for version 117 files. + + Examples + -------- + >>> df = pd.DataFrame( + ... [["falcon", 350], ["parrot", 18]], columns=["animal", "parrot"] + ... ) + >>> df.to_stata("animals.dta") # doctest: +SKIP + """ + if version not in (114, 117, 118, 119, None): + raise ValueError("Only formats 114, 117, 118 and 119 are supported.") + if version == 114: + if convert_strl is not None: + raise ValueError("strl is not supported in format 114") + from pandas.io.stata import StataWriter as statawriter + elif version == 117: + # Incompatible import of "statawriter" (imported name has type + # "Type[StataWriter117]", local name has type "Type[StataWriter]") + from pandas.io.stata import ( # type: ignore[assignment] + StataWriter117 as statawriter, + ) + else: # versions 118 and 119 + # Incompatible import of "statawriter" (imported name has type + # "Type[StataWriter117]", local name has type "Type[StataWriter]") + from pandas.io.stata import ( # type: ignore[assignment] + StataWriterUTF8 as statawriter, + ) + + kwargs: dict[str, Any] = {} + if version is None or version >= 117: + # strl conversion is only supported >= 117 + kwargs["convert_strl"] = convert_strl + if version is None or version >= 118: + # Specifying the version is only supported for UTF8 (118 or 119) + kwargs["version"] = version + + writer = statawriter( + path, + self, + convert_dates=convert_dates, + byteorder=byteorder, + time_stamp=time_stamp, + data_label=data_label, + write_index=write_index, + variable_labels=variable_labels, + compression=compression, + storage_options=storage_options, + value_labels=value_labels, + **kwargs, + ) + writer.write_file() + + def to_feather(self, path: FilePath | WriteBuffer[bytes], **kwargs) -> None: + """ + Write a DataFrame to the binary Feather format. + + Parameters + ---------- + path : str, path object, file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``write()`` function. If a string or a path, + it will be used as Root Directory path when writing a partitioned dataset. + **kwargs : + Additional keywords passed to :func:`pyarrow.feather.write_feather`. + This includes the `compression`, `compression_level`, `chunksize` + and `version` keywords. + + See Also + -------- + DataFrame.to_parquet : Write a DataFrame to the binary parquet format. + DataFrame.to_excel : Write object to an Excel sheet. + DataFrame.to_sql : Write to a sql table. + DataFrame.to_csv : Write a csv file. + DataFrame.to_json : Convert the object to a JSON string. + DataFrame.to_html : Render a DataFrame as an HTML table. + DataFrame.to_string : Convert DataFrame to a string. + + Notes + ----- + This function writes the dataframe as a `feather file + `_. Requires a default + index. For saving the DataFrame with your custom index use a method that + supports custom indices e.g. `to_parquet`. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2, 3], [4, 5, 6]]) + >>> df.to_feather("file.feather") # doctest: +SKIP + """ + from pandas.io.feather_format import to_feather + + to_feather(self, path, **kwargs) + + @overload + def to_markdown( + self, + buf: None = ..., + *, + mode: str = ..., + index: bool = ..., + storage_options: StorageOptions | None = ..., + **kwargs, + ) -> str: ... + + @overload + def to_markdown( + self, + buf: FilePath | WriteBuffer[str], + *, + mode: str = ..., + index: bool = ..., + storage_options: StorageOptions | None = ..., + **kwargs, + ) -> None: ... + + @overload + def to_markdown( + self, + buf: FilePath | WriteBuffer[str] | None, + *, + mode: str = ..., + index: bool = ..., + storage_options: StorageOptions | None = ..., + **kwargs, + ) -> str | None: ... + + def to_markdown( + self, + buf: FilePath | WriteBuffer[str] | None = None, + *, + mode: str = "wt", + index: bool = True, + storage_options: StorageOptions | None = None, + **kwargs, + ) -> str | None: + """ + Print DataFrame in Markdown-friendly format. + + Parameters + ---------- + buf : str, Path or StringIO-like, optional, default None + Buffer to write to. If None, the output is returned as a string. + mode : str, optional + Mode in which file is opened, "wt" by default. + index : bool, optional, default True + Add index (row) labels. + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + **kwargs + These parameters will be passed to `tabulate `_. + + Returns + ------- + str + DataFrame in Markdown-friendly format. + + See Also + -------- + DataFrame.to_html : Render DataFrame to HTML-formatted table. + DataFrame.to_latex : Render DataFrame to LaTeX-formatted table. + + Notes + ----- + Requires the `tabulate `_ package. + + Examples + -------- + >>> df = pd.DataFrame( + ... data={"animal_1": ["elk", "pig"], "animal_2": ["dog", "quetzal"]} + ... ) + >>> print(df.to_markdown()) + | | animal_1 | animal_2 | + |---:|:-----------|:-----------| + | 0 | elk | dog | + | 1 | pig | quetzal | + + Output markdown with a tabulate option. + + >>> print(df.to_markdown(tablefmt="grid")) + +----+------------+------------+ + | | animal_1 | animal_2 | + +====+============+============+ + | 0 | elk | dog | + +----+------------+------------+ + | 1 | pig | quetzal | + +----+------------+------------+ + """ + if "showindex" in kwargs: + raise ValueError("Pass 'index' instead of 'showindex") + + kwargs.setdefault("headers", "keys") + kwargs.setdefault("tablefmt", "pipe") + kwargs.setdefault("showindex", index) + tabulate = import_optional_dependency("tabulate") + result = tabulate.tabulate(self, **kwargs) + if buf is None: + return result + + with get_handle(buf, mode, storage_options=storage_options) as handles: + handles.handle.write(result) + return None + + @overload + def to_parquet( + self, + path: None = ..., + *, + engine: Literal["auto", "pyarrow", "fastparquet"] = ..., + compression: ParquetCompressionOptions = ..., + index: bool | None = ..., + partition_cols: list[str] | None = ..., + storage_options: StorageOptions = ..., + filesystem: Any = ..., + **kwargs, + ) -> bytes: ... + + @overload + def to_parquet( + self, + path: FilePath | WriteBuffer[bytes], + *, + engine: Literal["auto", "pyarrow", "fastparquet"] = ..., + compression: ParquetCompressionOptions = ..., + index: bool | None = ..., + partition_cols: list[str] | None = ..., + storage_options: StorageOptions = ..., + filesystem: Any = ..., + **kwargs, + ) -> None: ... + + def to_parquet( + self, + path: FilePath | WriteBuffer[bytes] | None = None, + *, + engine: Literal["auto", "pyarrow", "fastparquet"] = "auto", + compression: ParquetCompressionOptions = "snappy", + index: bool | None = None, + partition_cols: list[str] | None = None, + storage_options: StorageOptions | None = None, + filesystem: Any = None, + **kwargs, + ) -> bytes | None: + """ + Write a DataFrame to the binary parquet format. + + This function writes the dataframe as a `parquet file + `_. You can choose different parquet + backends, and have the option of compression. See + :ref:`the user guide ` for more details. + + Parameters + ---------- + path : str, path object, file-like object, or None, default None + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``write()`` function. If None, the result is + returned as bytes. If a string or path, it will be used as Root Directory + path when writing a partitioned dataset. + engine : {{'auto', 'pyarrow', 'fastparquet'}}, default 'auto' + Parquet library to use. If 'auto', then the option + ``io.parquet.engine`` is used. The default ``io.parquet.engine`` + behavior is to try 'pyarrow', falling back to 'fastparquet' if + 'pyarrow' is unavailable. + compression : str or None, default 'snappy' + Name of the compression to use. Use ``None`` for no compression. + Supported options: 'snappy', 'gzip', 'brotli', 'lz4', 'zstd'. + index : bool, default None + If ``True``, include the dataframe's index(es) in the file output. + If ``False``, they will not be written to the file. + If ``None``, similar to ``True`` the dataframe's index(es) + will be saved. However, instead of being saved as values, + the RangeIndex will be stored as a range in the metadata so it + doesn't require much space and is faster. Other indexes will + be included as columns in the file output. + partition_cols : list, optional, default None + Column names by which to partition the dataset. + Columns are partitioned in the order they are given. + Must be None if path is not a string. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + filesystem : fsspec or pyarrow filesystem, default None + Filesystem object to use when reading the parquet file. Only implemented + for ``engine="pyarrow"``. + + .. versionadded:: 2.1.0 + + **kwargs + Additional arguments passed to the parquet library. See + :ref:`pandas io ` for more details. + + Returns + ------- + bytes if no path argument is provided else None + Returns the DataFrame converted to the binary parquet format as bytes if no + path argument. Returns None and writes the DataFrame to the specified + location in the Parquet format if the path argument is provided. + + See Also + -------- + read_parquet : Read a parquet file. + DataFrame.to_orc : Write an orc file. + DataFrame.to_csv : Write a csv file. + DataFrame.to_sql : Write to a sql table. + DataFrame.to_hdf : Write to hdf. + + Notes + ----- + * This function requires either the `fastparquet + `_ or `pyarrow + `_ library. + * When saving a DataFrame with categorical columns to parquet, + the file size may increase due to the inclusion of all possible + categories, not just those present in the data. This behavior + is expected and consistent with pandas' handling of categorical data. + To manage file size and ensure a more predictable roundtrip process, + consider using :meth:`Categorical.remove_unused_categories` on the + DataFrame before saving. + + Examples + -------- + >>> df = pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}) + >>> df.to_parquet("df.parquet.gzip", compression="gzip") # doctest: +SKIP + >>> pd.read_parquet("df.parquet.gzip") # doctest: +SKIP + col1 col2 + 0 1 3 + 1 2 4 + + If you want to get a buffer to the parquet content you can use a io.BytesIO + object, as long as you don't use partition_cols, which creates multiple files. + + >>> import io + >>> f = io.BytesIO() + >>> df.to_parquet(f) + >>> f.seek(0) + 0 + >>> content = f.read() + """ + from pandas.io.parquet import to_parquet + + return to_parquet( + self, + path, + engine, + compression=compression, + index=index, + partition_cols=partition_cols, + storage_options=storage_options, + filesystem=filesystem, + **kwargs, + ) + + @overload + def to_orc( + self, + path: None = ..., + *, + engine: Literal["pyarrow"] = ..., + index: bool | None = ..., + engine_kwargs: dict[str, Any] | None = ..., + ) -> bytes: ... + + @overload + def to_orc( + self, + path: FilePath | WriteBuffer[bytes], + *, + engine: Literal["pyarrow"] = ..., + index: bool | None = ..., + engine_kwargs: dict[str, Any] | None = ..., + ) -> None: ... + + @overload + def to_orc( + self, + path: FilePath | WriteBuffer[bytes] | None, + *, + engine: Literal["pyarrow"] = ..., + index: bool | None = ..., + engine_kwargs: dict[str, Any] | None = ..., + ) -> bytes | None: ... + + def to_orc( + self, + path: FilePath | WriteBuffer[bytes] | None = None, + *, + engine: Literal["pyarrow"] = "pyarrow", + index: bool | None = None, + engine_kwargs: dict[str, Any] | None = None, + ) -> bytes | None: + """ + Write a DataFrame to the Optimized Row Columnar (ORC) format. + + Parameters + ---------- + path : str, file-like object or None, default None + If a string, it will be used as Root Directory path + when writing a partitioned dataset. By file-like object, + we refer to objects with a write() method, such as a file handle + (e.g. via builtin open function). If path is None, + a bytes object is returned. + engine : {'pyarrow'}, default 'pyarrow' + ORC library to use. + index : bool, optional + If ``True``, include the dataframe's index(es) in the file output. + If ``False``, they will not be written to the file. + If ``None``, similar to ``infer`` the dataframe's index(es) + will be saved. However, instead of being saved as values, + the RangeIndex will be stored as a range in the metadata so it + doesn't require much space and is faster. Other indexes will + be included as columns in the file output. + engine_kwargs : dict[str, Any] or None, default None + Additional keyword arguments passed to :func:`pyarrow.orc.write_table`. + + Returns + ------- + bytes if no ``path`` argument is provided else None + Bytes object with DataFrame data if ``path`` is not specified else None. + + Raises + ------ + NotImplementedError + Dtype of one or more columns is category, unsigned integers, interval, + period or sparse. + ValueError + engine is not pyarrow. + + See Also + -------- + read_orc : Read a ORC file. + DataFrame.to_parquet : Write a parquet file. + DataFrame.to_csv : Write a csv file. + DataFrame.to_sql : Write to a sql table. + DataFrame.to_hdf : Write to hdf. + + Notes + ----- + * Find more information on ORC + `here `__. + * Before using this function you should read the :ref:`user guide about + ORC ` and :ref:`install optional dependencies `. + * This function requires `pyarrow `_ + library. + * For supported dtypes please refer to `supported ORC features in Arrow + `__. + * Currently timezones in datetime columns are not preserved when a + dataframe is converted into ORC files. + + Examples + -------- + >>> df = pd.DataFrame(data={"col1": [1, 2], "col2": [4, 3]}) + >>> df.to_orc("df.orc") # doctest: +SKIP + >>> pd.read_orc("df.orc") # doctest: +SKIP + col1 col2 + 0 1 4 + 1 2 3 + + If you want to get a buffer to the orc content you can write it to io.BytesIO + + >>> import io + >>> b = io.BytesIO(df.to_orc()) # doctest: +SKIP + >>> b.seek(0) # doctest: +SKIP + 0 + >>> content = b.read() # doctest: +SKIP + """ + from pandas.io.orc import to_orc + + return to_orc( + self, path, engine=engine, index=index, engine_kwargs=engine_kwargs + ) + + @overload + def to_html( + self, + buf: FilePath | WriteBuffer[str], + *, + columns: Axes | None = ..., + col_space: ColspaceArgType | None = ..., + header: bool = ..., + index: bool = ..., + na_rep: str = ..., + formatters: FormattersType | None = ..., + float_format: FloatFormatType | None = ..., + sparsify: bool | None = ..., + index_names: bool = ..., + justify: str | None = ..., + max_rows: int | None = ..., + max_cols: int | None = ..., + show_dimensions: bool | str = ..., + decimal: str = ..., + bold_rows: bool = ..., + classes: str | list | tuple | None = ..., + escape: bool = ..., + notebook: bool = ..., + border: int | bool | None = ..., + table_id: str | None = ..., + render_links: bool = ..., + encoding: str | None = ..., + ) -> None: ... + + @overload + def to_html( + self, + buf: None = ..., + *, + columns: Axes | None = ..., + col_space: ColspaceArgType | None = ..., + header: bool = ..., + index: bool = ..., + na_rep: str = ..., + formatters: FormattersType | None = ..., + float_format: FloatFormatType | None = ..., + sparsify: bool | None = ..., + index_names: bool = ..., + justify: str | None = ..., + max_rows: int | None = ..., + max_cols: int | None = ..., + show_dimensions: bool | str = ..., + decimal: str = ..., + bold_rows: bool = ..., + classes: str | list | tuple | None = ..., + escape: bool = ..., + notebook: bool = ..., + border: int | bool | None = ..., + table_id: str | None = ..., + render_links: bool = ..., + encoding: str | None = ..., + ) -> str: ... + + @Substitution( + header_type="bool", + header="Whether to print column labels, default True", + col_space_type="str or int, list or dict of int or str", + col_space="The minimum width of each column in CSS length " + "units. An int is assumed to be px units.", + ) + @Substitution(shared_params=fmt.common_docstring, returns=fmt.return_docstring) + def to_html( + self, + buf: FilePath | WriteBuffer[str] | None = None, + *, + columns: Axes | None = None, + col_space: ColspaceArgType | None = None, + header: bool = True, + index: bool = True, + na_rep: str = "NaN", + formatters: FormattersType | None = None, + float_format: FloatFormatType | None = None, + sparsify: bool | None = None, + index_names: bool = True, + justify: str | None = None, + max_rows: int | None = None, + max_cols: int | None = None, + show_dimensions: bool | str = False, + decimal: str = ".", + bold_rows: bool = True, + classes: str | list | tuple | None = None, + escape: bool = True, + notebook: bool = False, + border: int | bool | None = None, + table_id: str | None = None, + render_links: bool = False, + encoding: str | None = None, + ) -> str | None: + """ + Render a DataFrame as an HTML table. + %(shared_params)s + bold_rows : bool, default True + Make the row labels bold in the output. + classes : str or list or tuple, default None + CSS class(es) to apply to the resulting html table. + escape : bool, default True + Convert the characters <, >, and & to HTML-safe sequences. + notebook : {True, False}, default False + Whether the generated HTML is for IPython Notebook. + border : int or bool + When an integer value is provided, it sets the border attribute in + the opening tag, specifying the thickness of the border. + If ``False`` or ``0`` is passed, the border attribute will not + be present in the ```` tag. + The default value for this parameter is governed by + ``pd.options.display.html.border``. + table_id : str, optional + A css id is included in the opening `
` tag if specified. + render_links : bool, default False + Convert URLs to HTML links. + encoding : str, default "utf-8" + Set character encoding. + %(returns)s + See Also + -------- + to_string : Convert DataFrame to a string. + + Examples + -------- + >>> df = pd.DataFrame(data={"col1": [1, 2], "col2": [4, 3]}) + >>> html_string = df.to_html() + >>> print(html_string) +
+ + + + + + + + + + + + + + + + + + + +
col1col2
014
123
+ + HTML output + + +----+-----+-----+ + | |col1 |col2 | + +====+=====+=====+ + |0 |1 |4 | + +----+-----+-----+ + |1 |2 |3 | + +----+-----+-----+ + + >>> df = pd.DataFrame(data={"col1": [1, 2], "col2": [4, 3]}) + >>> html_string = df.to_html(index=False) + >>> print(html_string) + + + + + + + + + + + + + + + + + +
col1col2
14
23
+ + HTML output + + +-----+-----+ + |col1 |col2 | + +=====+=====+ + |1 |4 | + +-----+-----+ + |2 |3 | + +-----+-----+ + """ + if justify is not None and justify not in fmt.VALID_JUSTIFY_PARAMETERS: + raise ValueError("Invalid value for justify parameter") + + formatter = fmt.DataFrameFormatter( + self, + columns=columns, + col_space=col_space, + na_rep=na_rep, + header=header, + index=index, + formatters=formatters, + float_format=float_format, + bold_rows=bold_rows, + sparsify=sparsify, + justify=justify, + index_names=index_names, + escape=escape, + decimal=decimal, + max_rows=max_rows, + max_cols=max_cols, + show_dimensions=show_dimensions, + ) + # TODO: a generic formatter wld b in DataFrameFormatter + return fmt.DataFrameRenderer(formatter).to_html( + buf=buf, + classes=classes, + notebook=notebook, + border=border, + encoding=encoding, + table_id=table_id, + render_links=render_links, + ) + + @overload + def to_xml( + self, + path_or_buffer: None = ..., + *, + index: bool = ..., + root_name: str | None = ..., + row_name: str | None = ..., + na_rep: str | None = ..., + attr_cols: list[str] | None = ..., + elem_cols: list[str] | None = ..., + namespaces: dict[str | None, str] | None = ..., + prefix: str | None = ..., + encoding: str = ..., + xml_declaration: bool | None = ..., + pretty_print: bool | None = ..., + parser: XMLParsers | None = ..., + stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = ..., + compression: CompressionOptions = ..., + storage_options: StorageOptions | None = ..., + ) -> str: ... + + @overload + def to_xml( + self, + path_or_buffer: FilePath | WriteBuffer[bytes] | WriteBuffer[str], + *, + index: bool = ..., + root_name: str | None = ..., + row_name: str | None = ..., + na_rep: str | None = ..., + attr_cols: list[str] | None = ..., + elem_cols: list[str] | None = ..., + namespaces: dict[str | None, str] | None = ..., + prefix: str | None = ..., + encoding: str = ..., + xml_declaration: bool | None = ..., + pretty_print: bool | None = ..., + parser: XMLParsers | None = ..., + stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = ..., + compression: CompressionOptions = ..., + storage_options: StorageOptions | None = ..., + ) -> None: ... + + def to_xml( + self, + path_or_buffer: FilePath | WriteBuffer[bytes] | WriteBuffer[str] | None = None, + *, + index: bool = True, + root_name: str | None = "data", + row_name: str | None = "row", + na_rep: str | None = None, + attr_cols: list[str] | None = None, + elem_cols: list[str] | None = None, + namespaces: dict[str | None, str] | None = None, + prefix: str | None = None, + encoding: str = "utf-8", + xml_declaration: bool | None = True, + pretty_print: bool | None = True, + parser: XMLParsers | None = "lxml", + stylesheet: FilePath | ReadBuffer[str] | ReadBuffer[bytes] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, + ) -> str | None: + """ + Render a DataFrame to an XML document. + + Parameters + ---------- + path_or_buffer : str, path object, file-like object, or None, default None + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a ``write()`` function. If None, the result is returned + as a string. + index : bool, default True + Whether to include index in XML document. + root_name : str, default 'data' + The name of root element in XML document. + row_name : str, default 'row' + The name of row element in XML document. + na_rep : str, optional + Missing data representation. + attr_cols : list-like, optional + List of columns to write as attributes in row element. + Hierarchical columns will be flattened with underscore + delimiting the different levels. + elem_cols : list-like, optional + List of columns to write as children in row element. By default, + all columns output as children of row element. Hierarchical + columns will be flattened with underscore delimiting the + different levels. + namespaces : dict, optional + All namespaces to be defined in root element. Keys of dict + should be prefix names and values of dict corresponding URIs. + Default namespaces should be given empty string key. For + example, :: + + namespaces = {{"": "https://example.com"}} + + prefix : str, optional + Namespace prefix to be used for every element and/or attribute + in document. This should be one of the keys in ``namespaces`` + dict. + encoding : str, default 'utf-8' + Encoding of the resulting document. + xml_declaration : bool, default True + Whether to include the XML declaration at start of document. + pretty_print : bool, default True + Whether output should be pretty printed with indentation and + line breaks. + parser : {{'lxml','etree'}}, default 'lxml' + Parser module to use for building of tree. Only 'lxml' and + 'etree' are supported. With 'lxml', the ability to use XSLT + stylesheet is supported. + stylesheet : str, path object or file-like object, optional + A URL, file-like object, or a raw string containing an XSLT + script used to transform the raw XML output. Script should use + layout of elements and attributes from original output. This + argument requires ``lxml`` to be installed. Only XSLT 1.0 + scripts and not later versions is currently supported. + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and + 'path_or_buffer' is + path-like, then detect compression from the following extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set to one of + {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} and + other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and + to create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + Returns + ------- + None or str + If ``io`` is None, returns the resulting XML format as a + string. Otherwise returns None. + + See Also + -------- + to_json : Convert the pandas object to a JSON string. + to_html : Convert DataFrame to a html. + + Examples + -------- + >>> df = pd.DataFrame( + ... [["square", 360, 4], ["circle", 360, np.nan], ["triangle", 180, 3]], + ... columns=["shape", "degrees", "sides"], + ... ) + + >>> df.to_xml() # doctest: +SKIP + + + + 0 + square + 360 + 4.0 + + + 1 + circle + 360 + + + + 2 + triangle + 180 + 3.0 + + + + >>> df.to_xml( + ... attr_cols=["index", "shape", "degrees", "sides"] + ... ) # doctest: +SKIP + + + + + + + + >>> df.to_xml( + ... namespaces={{"doc": "https://example.com"}}, prefix="doc" + ... ) # doctest: +SKIP + + + + 0 + square + 360 + 4.0 + + + 1 + circle + 360 + + + + 2 + triangle + 180 + 3.0 + + + """ + + from pandas.io.formats.xml import ( + EtreeXMLFormatter, + LxmlXMLFormatter, + ) + + lxml = import_optional_dependency("lxml.etree", errors="ignore") + + TreeBuilder: type[EtreeXMLFormatter | LxmlXMLFormatter] + + if parser == "lxml": + if lxml is not None: + TreeBuilder = LxmlXMLFormatter + else: + raise ImportError( + "lxml not found, please install or use the etree parser." + ) + + elif parser == "etree": + TreeBuilder = EtreeXMLFormatter + + else: + raise ValueError("Values for parser can only be lxml or etree.") + + xml_formatter = TreeBuilder( + self, + path_or_buffer=path_or_buffer, + index=index, + root_name=root_name, + row_name=row_name, + na_rep=na_rep, + attr_cols=attr_cols, + elem_cols=elem_cols, + namespaces=namespaces, + prefix=prefix, + encoding=encoding, + xml_declaration=xml_declaration, + pretty_print=pretty_print, + stylesheet=stylesheet, + compression=compression, + storage_options=storage_options, + ) + + return xml_formatter.write_output() + + def to_iceberg( + self, + table_identifier: str, + catalog_name: str | None = None, + *, + catalog_properties: dict[str, Any] | None = None, + location: str | None = None, + append: bool = False, + snapshot_properties: dict[str, str] | None = None, + ) -> None: + """ + Write a DataFrame to an Apache Iceberg table. + + .. versionadded:: 3.0.0 + + .. warning:: + + to_iceberg is experimental and may change without warning. + + Parameters + ---------- + table_identifier : str + Table identifier. + catalog_name : str, optional + The name of the catalog. + catalog_properties : dict of {str: str}, optional + The properties that are used next to the catalog configuration. + location : str, optional + Location for the table. + append : bool, default False + If ``True``, append data to the table, instead of replacing the content. + snapshot_properties : dict of {str: str}, optional + Custom properties to be added to the snapshot summary + + See Also + -------- + read_iceberg : Read an Apache Iceberg table. + DataFrame.to_parquet : Write a DataFrame in Parquet format. + + Examples + -------- + >>> df = pd.DataFrame(data={"col1": [1, 2], "col2": [4, 3]}) + >>> df.to_iceberg("my_table", catalog_name="my_catalog") # doctest: +SKIP + """ + from pandas.io.iceberg import to_iceberg + + to_iceberg( + self, + table_identifier, + catalog_name, + catalog_properties=catalog_properties, + location=location, + append=append, + snapshot_properties=snapshot_properties, + ) + + # ---------------------------------------------------------------------- + def info( + self, + verbose: bool | None = None, + buf: WriteBuffer[str] | None = None, + max_cols: int | None = None, + memory_usage: bool | str | None = None, + show_counts: bool | None = None, + ) -> None: + """ + Print a concise summary of a DataFrame. + + This method prints information about a DataFrame including + the index dtype and columns, non-NA values and memory usage. + + Parameters + ---------- + verbose : bool, optional + Whether to print the full summary. By default, the setting in + ``pandas.options.display.max_info_columns`` is followed. + buf : writable buffer, defaults to sys.stdout + Where to send the output. By default, the output is printed to + sys.stdout. Pass a writable buffer if you need to further process + the output. + max_cols : int, optional + When to switch from the verbose to the truncated output. If the + DataFrame has more than `max_cols` columns, the truncated output + is used. By default, the setting in + ``pandas.options.display.max_info_columns`` is used. + memory_usage : bool, str, optional + Specifies whether total memory usage of the DataFrame + elements (including the index) should be displayed. By default, + this follows the ``pandas.options.display.memory_usage`` setting. + + True always show memory usage. False never shows memory usage. + A value of 'deep' is equivalent to "True with deep introspection". + Memory usage is shown in human-readable units (base-2 + representation). Without deep introspection a memory estimation is + made based in column dtype and number of rows assuming values + consume the same memory amount for corresponding dtypes. With deep + memory introspection, a real memory usage calculation is performed + at the cost of computational resources. See the + :ref:`Frequently Asked Questions ` for more + details. + show_counts : bool, optional + Whether to show the non-null counts. By default, this is shown + only if the DataFrame is smaller than + ``pandas.options.display.max_info_rows`` and + ``pandas.options.display.max_info_columns``. A value of True always + shows the counts, and False never shows the counts. + + Returns + ------- + None + This method prints a summary of a DataFrame and returns None. + + See Also + -------- + DataFrame.describe: Generate descriptive statistics of DataFrame + columns. + DataFrame.memory_usage: Memory usage of DataFrame columns. + + Examples + -------- + >>> int_values = [1, 2, 3, 4, 5] + >>> text_values = ["alpha", "beta", "gamma", "delta", "epsilon"] + >>> float_values = [0.0, 0.25, 0.5, 0.75, 1.0] + >>> df = pd.DataFrame( + ... { + ... "int_col": int_values, + ... "text_col": text_values, + ... "float_col": float_values, + ... } + ... ) + >>> df + int_col text_col float_col + 0 1 alpha 0.00 + 1 2 beta 0.25 + 2 3 gamma 0.50 + 3 4 delta 0.75 + 4 5 epsilon 1.00 + + Prints information of all columns: + + >>> df.info(verbose=True) + + RangeIndex: 5 entries, 0 to 4 + Data columns (total 3 columns): + # Column Non-Null Count Dtype + --- ------ -------------- ----- + 0 int_col 5 non-null int64 + 1 text_col 5 non-null str + 2 float_col 5 non-null float64 + dtypes: float64(1), int64(1), str(1) + memory usage: 278.0 bytes + + Prints a summary of columns count and its dtypes but not per column + information: + + >>> df.info(verbose=False) + + RangeIndex: 5 entries, 0 to 4 + Columns: 3 entries, int_col to float_col + dtypes: float64(1), int64(1), str(1) + memory usage: 278.0 bytes + + Pipe output of DataFrame.info to buffer instead of sys.stdout, get + buffer content and writes to a text file: + + >>> import io + >>> buffer = io.StringIO() + >>> df.info(buf=buffer) + >>> s = buffer.getvalue() + >>> with open("df_info.txt", "w", encoding="utf-8") as f: # doctest: +SKIP + ... f.write(s) + 260 + + The `memory_usage` parameter allows deep introspection mode, specially + useful for big DataFrames and fine-tune memory optimization: + + >>> random_strings_array = np.random.choice(["a", "b", "c"], 10**6) + >>> df = pd.DataFrame( + ... { + ... "column_1": np.random.choice(["a", "b", "c"], 10**6), + ... "column_2": np.random.choice(["a", "b", "c"], 10**6), + ... "column_3": np.random.choice(["a", "b", "c"], 10**6), + ... } + ... ) + >>> df.info() + + RangeIndex: 1000000 entries, 0 to 999999 + Data columns (total 3 columns): + # Column Non-Null Count Dtype + --- ------ -------------- ----- + 0 column_1 1000000 non-null str + 1 column_2 1000000 non-null str + 2 column_3 1000000 non-null str + dtypes: str(3) + memory usage: 25.7 MB + + >>> df.info(memory_usage="deep") + + RangeIndex: 1000000 entries, 0 to 999999 + Data columns (total 3 columns): + # Column Non-Null Count Dtype + --- ------ -------------- ----- + 0 column_1 1000000 non-null str + 1 column_2 1000000 non-null str + 2 column_3 1000000 non-null str + dtypes: str(3) + memory usage: 25.7 MB + """ + info = DataFrameInfo( + data=self, + memory_usage=memory_usage, + ) + info.render( + buf=buf, + max_cols=max_cols, + verbose=verbose, + show_counts=show_counts, + ) + + def memory_usage(self, index: bool = True, deep: bool = False) -> Series: + """ + Return the memory usage of each column in bytes. + + The memory usage can optionally include the contribution of + the index and elements of `object` dtype. + + This value is displayed in `DataFrame.info` by default. This can be + suppressed by setting ``pandas.options.display.memory_usage`` to False. + + Parameters + ---------- + index : bool, default True + Specifies whether to include the memory usage of the DataFrame's + index in returned Series. If ``index=True``, the memory usage of + the index is the first item in the output. + deep : bool, default False + If True, introspect the data deeply by interrogating + `object` dtypes for system-level memory consumption, and include + it in the returned values. + + Returns + ------- + Series + A Series whose index is the original column names and whose values + is the memory usage of each column in bytes. + + See Also + -------- + numpy.ndarray.nbytes : Total bytes consumed by the elements of an + ndarray. + Series.memory_usage : Bytes consumed by a Series. + Categorical : Memory-efficient array for string values with + many repeated values. + DataFrame.info : Concise summary of a DataFrame. + + Notes + ----- + See the :ref:`Frequently Asked Questions ` for more + details. + + Examples + -------- + >>> dtypes = ["int64", "float64", "complex128", "object", "bool"] + >>> data = dict([(t, np.ones(shape=5000, dtype=int).astype(t)) for t in dtypes]) + >>> df = pd.DataFrame(data) + >>> df.head() + int64 float64 complex128 object bool + 0 1 1.0 1.0+0.0j 1 True + 1 1 1.0 1.0+0.0j 1 True + 2 1 1.0 1.0+0.0j 1 True + 3 1 1.0 1.0+0.0j 1 True + 4 1 1.0 1.0+0.0j 1 True + + >>> df.memory_usage() + Index 132 + int64 40000 + float64 40000 + complex128 80000 + object 40000 + bool 5000 + dtype: int64 + + >>> df.memory_usage(index=False) + int64 40000 + float64 40000 + complex128 80000 + object 40000 + bool 5000 + dtype: int64 + + The memory footprint of `object` dtype columns is ignored by default: + + >>> df.memory_usage(deep=True) + Index 132 + int64 40000 + float64 40000 + complex128 80000 + object 180000 + bool 5000 + dtype: int64 + + Use a Categorical for efficient storage of an object-dtype column with + many repeated values. + + >>> df["object"].astype("category").memory_usage(deep=True) + 5140 + """ + result = self._constructor_sliced( + [c.memory_usage(index=False, deep=deep) for col, c in self.items()], + index=self.columns, + dtype=np.intp, + ) + if index: + index_memory_usage = self._constructor_sliced( + self.index.memory_usage(deep=deep), index=["Index"] + ) + result = index_memory_usage._append_internal(result) + return result + + def transpose( + self, + *args, + copy: bool | lib.NoDefault = lib.no_default, + ) -> DataFrame: + """ + Transpose index and columns. + + Reflect the DataFrame over its main diagonal by writing rows as columns + and vice-versa. The property :attr:`.T` is an accessor to the method + :meth:`transpose`. + + Parameters + ---------- + *args : tuple, optional + Accepted for compatibility with NumPy. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + Note that a copy is always required for mixed dtype DataFrames, + or for DataFrames with any extension types. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + DataFrame + The transposed DataFrame. + + See Also + -------- + numpy.transpose : Permute the dimensions of a given array. + + Notes + ----- + Transposing a DataFrame with mixed dtypes will result in a homogeneous + DataFrame with the `object` dtype. In such a case, a copy of the data + is always made. + + Examples + -------- + **Square DataFrame with homogeneous dtype** + + >>> d1 = {"col1": [1, 2], "col2": [3, 4]} + >>> df1 = pd.DataFrame(data=d1) + >>> df1 + col1 col2 + 0 1 3 + 1 2 4 + + >>> df1_transposed = df1.T # or df1.transpose() + >>> df1_transposed + 0 1 + col1 1 2 + col2 3 4 + + When the dtype is homogeneous in the original DataFrame, we get a + transposed DataFrame with the same dtype: + + >>> df1.dtypes + col1 int64 + col2 int64 + dtype: object + >>> df1_transposed.dtypes + 0 int64 + 1 int64 + dtype: object + + **Non-square DataFrame with mixed dtypes** + + >>> d2 = { + ... "name": ["Alice", "Bob"], + ... "score": [9.5, 8], + ... "employed": [False, True], + ... "kids": [0, 0], + ... } + >>> df2 = pd.DataFrame(data=d2) + >>> df2 + name score employed kids + 0 Alice 9.5 False 0 + 1 Bob 8.0 True 0 + + >>> df2_transposed = df2.T # or df2.transpose() + >>> df2_transposed + 0 1 + name Alice Bob + score 9.5 8.0 + employed False True + kids 0 0 + + When the DataFrame has mixed dtypes, we get a transposed DataFrame with + the `object` dtype: + + >>> df2.dtypes + name str + score float64 + employed bool + kids int64 + dtype: object + >>> df2_transposed.dtypes + 0 object + 1 object + dtype: object + """ + self._check_copy_deprecation(copy) + nv.validate_transpose(args, {}) + # construct the args + + first_dtype = self.dtypes.iloc[0] if len(self.columns) else None + + if self._can_fast_transpose: + # Note: tests pass without this, but this improves perf quite a bit. + new_vals = self._values.T + + result = self._constructor( + new_vals, + index=self.columns, + columns=self.index, + copy=False, + dtype=new_vals.dtype, + ) + if len(self) > 0: + result._mgr.add_references(self._mgr) + + elif ( + self._is_homogeneous_type + and first_dtype is not None + and isinstance(first_dtype, ExtensionDtype) + ): + new_values: list + if isinstance(first_dtype, BaseMaskedDtype): + # We have masked arrays with the same dtype. We can transpose faster. + from pandas.core.arrays.masked import ( + transpose_homogeneous_masked_arrays, + ) + + new_values = transpose_homogeneous_masked_arrays( + cast(Sequence[BaseMaskedArray], self._iter_column_arrays()) + ) + elif isinstance(first_dtype, ArrowDtype): + # We have arrow EAs with the same dtype. We can transpose faster. + from pandas.core.arrays.arrow.array import ( + ArrowExtensionArray, + transpose_homogeneous_pyarrow, + ) + + new_values = transpose_homogeneous_pyarrow( + cast(Sequence[ArrowExtensionArray], self._iter_column_arrays()) + ) + else: + # We have other EAs with the same dtype. We preserve dtype in transpose. + arr_typ = first_dtype.construct_array_type() + values = self.values + new_values = [ + arr_typ._from_sequence(row, dtype=first_dtype) for row in values + ] + + result = type(self)._from_arrays( + new_values, + index=self.columns, + columns=self.index, + verify_integrity=False, + ) + + else: + new_arr = self.values.T + result = self._constructor( + new_arr, + index=self.columns, + columns=self.index, + dtype=new_arr.dtype, + # We already made a copy (more than one block) + copy=False, + ) + + return result.__finalize__(self, method="transpose") + + @property + def T(self) -> DataFrame: + """ + The transpose of the DataFrame. + + Returns + ------- + DataFrame + The transposed DataFrame. + + See Also + -------- + DataFrame.transpose : Transpose index and columns. + + Examples + -------- + >>> df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + >>> df + col1 col2 + 0 1 3 + 1 2 4 + + >>> df.T + 0 1 + col1 1 2 + col2 3 4 + """ + return self.transpose() + + # ---------------------------------------------------------------------- + # Indexing Methods + + def _ixs(self, i: int, axis: AxisInt = 0) -> Series: + """ + Parameters + ---------- + i : int + axis : int + + Returns + ------- + Series + """ + # irow + if axis == 0: + new_mgr = self._mgr.fast_xs(i) + + result = self._constructor_sliced_from_mgr(new_mgr, axes=new_mgr.axes) + result._name = self.index[i] + return result.__finalize__(self) + + # icol + else: + col_mgr = self._mgr.iget(i) + return self._box_col_values(col_mgr, i) + + def _get_column_array(self, i: int) -> ArrayLike: + """ + Get the values of the i'th column (ndarray or ExtensionArray, as stored + in the Block) + + Warning! The returned array is a view but doesn't handle Copy-on-Write, + so this should be used with caution (for read-only purposes). + """ + return self._mgr.iget_values(i) + + def _iter_column_arrays(self) -> Iterator[ArrayLike]: + """ + Iterate over the arrays of all columns in order. + This returns the values as stored in the Block (ndarray or ExtensionArray). + + Warning! The returned array is a view but doesn't handle Copy-on-Write, + so this should be used with caution (for read-only purposes). + """ + for i in range(len(self.columns)): + yield self._get_column_array(i) + + def __getitem__(self, key): + check_dict_or_set_indexers(key) + key = lib.item_from_zerodim(key) + key = com.apply_if_callable(key, self) + + if is_hashable(key, allow_slice=False) and not is_iterator(key): + # is_iterator to exclude generator e.g. test_getitem_listlike + # As of Python 3.12, slice is hashable which breaks MultiIndex (GH#57500) + + # shortcut if the key is in columns + is_mi = isinstance(self.columns, MultiIndex) + # GH#45316 Return view if key is not duplicated + # Only use drop_duplicates with duplicates for performance + if not is_mi and ( + (self.columns.is_unique and key in self.columns) + or key in self.columns.drop_duplicates(keep=False) + ): + return self._get_item(key) + + elif is_mi and self.columns.is_unique and key in self.columns: + return self._getitem_multilevel(key) + + # Do we have a slicer (on rows)? + if isinstance(key, slice): + return self._getitem_slice(key) + + # Do we have a (boolean) DataFrame? + if isinstance(key, DataFrame): + return self.where(key) + + # Do we have a (boolean) 1d indexer? + if com.is_bool_indexer(key): + return self._getitem_bool_array(key) + + # We are left with two options: a single key, and a collection of keys, + # We interpret tuples as collections only for non-MultiIndex + is_single_key = isinstance(key, tuple) or not is_list_like(key) + + if is_single_key: + if self.columns.nlevels > 1: + return self._getitem_multilevel(key) + indexer = self.columns.get_loc(key) + if is_integer(indexer): + indexer = [indexer] + else: + if is_iterator(key): + key = list(key) + indexer = self.columns._get_indexer_strict(key, "columns")[1] + + # take() does not accept boolean indexers + if getattr(indexer, "dtype", None) == bool: + indexer = np.where(indexer)[0] + + if isinstance(indexer, slice): + return self._slice(indexer, axis=1) + + data = self.take(indexer, axis=1) + + if is_single_key: + # What does looking for a single key in a non-unique index return? + # The behavior is inconsistent. It returns a Series, except when + # - the key itself is repeated (test on data.shape, #9519), or + # - we have a MultiIndex on columns (test on self.columns, #21309) + if data.shape[1] == 1 and not isinstance(self.columns, MultiIndex): + # GH#26490 using data[key] can cause RecursionError + return data._get_item(key) + + return data + + def _getitem_bool_array(self, key): + # also raises Exception if object array with NA values + # warning here just in case -- previously __setitem__ was + # reindexing but __getitem__ was not; it seems more reasonable to + # go with the __setitem__ behavior since that is more consistent + # with all other indexing behavior + if isinstance(key, Series) and not key.index.equals(self.index): + warnings.warn( + "Boolean Series key will be reindexed to match DataFrame index.", + UserWarning, + stacklevel=find_stack_level(), + ) + elif len(key) != len(self.index): + raise ValueError( + f"Item wrong length {len(key)} instead of {len(self.index)}." + ) + + # check_bool_indexer will throw exception if Series key cannot + # be reindexed to match DataFrame rows + key = check_bool_indexer(self.index, key) + + if key.all(): + return self.copy(deep=False) + + indexer = key.nonzero()[0] + return self.take(indexer, axis=0) + + def _getitem_multilevel(self, key): + # self.columns is a MultiIndex + loc = self.columns.get_loc(key) + if isinstance(loc, (slice, np.ndarray)): + new_columns = self.columns[loc] + result_columns = maybe_droplevels(new_columns, key) + result = self.iloc[:, loc] + result.columns = result_columns + + # If there is only one column being returned, and its name is + # either an empty string, or a tuple with an empty string as its + # first element, then treat the empty string as a placeholder + # and return the column as if the user had provided that empty + # string in the key. If the result is a Series, exclude the + # implied empty string from its name. + if len(result.columns) == 1: + # e.g. test_frame_getitem_multicolumn_empty_level, + # test_frame_mixed_depth_get, test_loc_setitem_single_column_slice + top = result.columns[0] + if isinstance(top, tuple): + top = top[0] + if top == "": + result = result[""] + if isinstance(result, Series): + result = self._constructor_sliced( + result, index=self.index, name=key + ) + + return result + else: + # loc is neither a slice nor ndarray, so must be an int + return self._ixs(loc, axis=1) + + def _get_value(self, index, col, takeable: bool = False) -> Scalar: + """ + Quickly retrieve single value at passed column and index. + + Parameters + ---------- + index : row label + col : column label + takeable : interpret the index/col as indexers, default False + + Returns + ------- + scalar + + Notes + ----- + Assumes that both `self.index._index_as_unique` and + `self.columns._index_as_unique`; Caller is responsible for checking. + """ + if takeable: + series = self._ixs(col, axis=1) + return series._values[index] + + series = self._get_item(col) + + if not isinstance(self.index, MultiIndex): + # CategoricalIndex: Trying to use the engine fastpath may give incorrect + # results if our categories are integers that dont match our codes + # IntervalIndex: IntervalTree has no get_loc + row = self.index.get_loc(index) + return series._values[row] + + # For MultiIndex going through engine effectively restricts us to + # same-length tuples; see test_get_set_value_no_partial_indexing + loc = self.index._engine.get_loc(index) + return series._values[loc] + + def isetitem(self, loc, value) -> None: + """ + Set the given value in the column with position `loc`. + + This is a positional analogue to ``__setitem__``. + + Parameters + ---------- + loc : int or sequence of ints + Index position for the column. + value : scalar or arraylike + Value(s) for the column. + + See Also + -------- + DataFrame.iloc : Purely integer-location based indexing for selection by + position. + + Notes + ----- + ``frame.isetitem(loc, value)`` is an in-place method as it will + modify the DataFrame in place (not returning a new object). In contrast to + ``frame.iloc[:, i] = value`` which will try to update the existing values in + place, ``frame.isetitem(loc, value)`` will not update the values of the column + itself in place, it will instead insert a new array. + + In cases where ``frame.columns`` is unique, this is equivalent to + ``frame[frame.columns[i]] = value``. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + >>> df.isetitem(1, [5, 6]) + >>> df + A B + 0 1 5 + 1 2 6 + """ + if isinstance(value, DataFrame): + if is_integer(loc): + loc = [loc] + + if len(loc) != len(value.columns): + raise ValueError( + f"Got {len(loc)} positions but value has {len(value.columns)} " + f"columns." + ) + + for i, idx in enumerate(loc): + arraylike, refs = self._sanitize_column(value.iloc[:, i]) + self._iset_item_mgr(idx, arraylike, inplace=False, refs=refs) + return + + arraylike, refs = self._sanitize_column(value) + self._iset_item_mgr(loc, arraylike, inplace=False, refs=refs) + + def __setitem__(self, key, value) -> None: + """ + Set item(s) in DataFrame by key. + + This method allows you to set the values of one or more columns in the + DataFrame using a key. If the key does not exist, a new + column will be created. + + Parameters + ---------- + key : The object(s) in the index which are to be assigned to + Column label(s) to set. Can be a single column name, list of column names, + or tuple for MultiIndex columns. + value : scalar, array-like, Series, or DataFrame + Value(s) to set for the specified key(s). + + Returns + ------- + None + This method does not return a value. + + See Also + -------- + DataFrame.loc : Access and set values by label-based indexing. + DataFrame.iloc : Access and set values by position-based indexing. + DataFrame.assign : Assign new columns to a DataFrame. + + Notes + ----- + When assigning a Series to a DataFrame column, pandas aligns the Series + by index labels, not by position. This means: + + * Values from the Series are matched to DataFrame rows by index label + * If a Series index label doesn't exist in the DataFrame index, it's ignored + * If a DataFrame index label doesn't exist in the Series index, NaN is assigned + * The order of values in the Series doesn't matter; only the index labels matter + + Examples + -------- + Basic column assignment: + + >>> df = pd.DataFrame({"A": [1, 2, 3]}) + >>> df["B"] = [4, 5, 6] # Assigns by position + >>> df + A B + 0 1 4 + 1 2 5 + 2 3 6 + + Series assignment with index alignment: + + >>> df = pd.DataFrame({"A": [1, 2, 3]}, index=[0, 1, 2]) + >>> s = pd.Series([10, 20], index=[1, 3]) # Note: index 3 doesn't exist in df + >>> df["B"] = s # Assigns by index label, not position + >>> df + A B + 0 1 NaN + 1 2 10.0 + 2 3 NaN + + Series assignment with partial index match: + + >>> df = pd.DataFrame({"A": [1, 2, 3, 4]}, index=["a", "b", "c", "d"]) + >>> s = pd.Series([100, 200], index=["b", "d"]) + >>> df["B"] = s + >>> df + A B + a 1 NaN + b 2 100.0 + c 3 NaN + d 4 200.0 + + Series index labels NOT in DataFrame, ignored: + + >>> df = pd.DataFrame({"A": [1, 2, 3]}, index=["x", "y", "z"]) + >>> s = pd.Series([10, 20, 30, 40, 50], index=["x", "y", "a", "b", "z"]) + >>> df["B"] = s + >>> df + A B + x 1 10 + y 2 20 + z 3 50 + """ + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount(self) <= REF_COUNT and not com.is_local_in_caller_frame( + self + ): + warnings.warn( + _chained_assignment_msg, ChainedAssignmentError, stacklevel=2 + ) + + key = com.apply_if_callable(key, self) + + # see if we can slice the rows + if isinstance(key, slice): + slc = self.index._convert_slice_indexer(key, kind="getitem") + return self._setitem_slice(slc, value) + + if isinstance(key, DataFrame) or getattr(key, "ndim", None) == 2: + self._setitem_frame(key, value) + elif isinstance(key, (Series, np.ndarray, list, Index)): + self._setitem_array(key, value) + elif isinstance(value, DataFrame): + self._set_item_frame_value(key, value) + elif ( + is_list_like(value) + and not self.columns.is_unique + and 1 < len(self.columns.get_indexer_for([key])) == len(value) + ): + # Column to set is duplicated + self._setitem_array([key], value) + else: + # set column + self._set_item(key, value) + + def _setitem_slice(self, key: slice, value) -> None: + # NB: we can't just use self.loc[key] = value because that + # operates on labels and we need to operate positional for + # backwards-compat, xref GH#31469 + self.iloc[key] = value + + def _setitem_array(self, key, value) -> None: + # also raises Exception if object array with NA values + if com.is_bool_indexer(key): + # bool indexer is indexing along rows + if len(key) != len(self.index): + raise ValueError( + f"Item wrong length {len(key)} instead of {len(self.index)}!" + ) + key = check_bool_indexer(self.index, key) + indexer = key.nonzero()[0] + if isinstance(value, DataFrame): + # GH#39931 reindex since iloc does not align + value = value.reindex(self.index.take(indexer)) + self.iloc[indexer] = value + + # Note: unlike self.iloc[:, indexer] = value, this will + # never try to overwrite values inplace + + elif isinstance(value, DataFrame): + check_key_length(self.columns, key, value) + for k1, k2 in zip(key, value.columns, strict=False): + self[k1] = value[k2] + + elif not is_list_like(value): + for col in key: + self[col] = value + + elif isinstance(value, np.ndarray) and value.ndim == 2: + self._iset_not_inplace(key, value) + + elif np.ndim(value) > 1: + # list of lists + value = DataFrame(value).values + self._setitem_array(key, value) + + else: + self._iset_not_inplace(key, value) + + def _iset_not_inplace(self, key, value) -> None: + # GH#39510 when setting with df[key] = obj with a list-like key and + # list-like value, we iterate over those listlikes and set columns + # one at a time. This is different from dispatching to + # `self.loc[:, key]= value` because loc.__setitem__ may overwrite + # data inplace, whereas this will insert new arrays. + + def igetitem(obj, i: int): + # Note: we catch DataFrame obj before getting here, but + # hypothetically would return obj.iloc[:, i] + if isinstance(obj, np.ndarray): + return obj[..., i] + else: + return obj[i] + + if self.columns.is_unique: + if np.shape(value)[-1] != len(key): + raise ValueError("Columns must be same length as key") + + for i, col in enumerate(key): + self[col] = igetitem(value, i) + + else: + ilocs = self.columns.get_indexer_non_unique(key)[0] + if (ilocs < 0).any(): + # key entries not in self.columns + raise NotImplementedError + + if np.shape(value)[-1] != len(ilocs): + raise ValueError("Columns must be same length as key") + + assert np.ndim(value) <= 2 + + orig_columns = self.columns + + # Using self.iloc[:, i] = ... may set values inplace, which + # by convention we do not do in __setitem__ + try: + self.columns = Index(range(len(self.columns))) + for i, iloc in enumerate(ilocs): + self[iloc] = igetitem(value, i) + finally: + self.columns = orig_columns + + def _setitem_frame(self, key, value) -> None: + # support boolean setting with DataFrame input, e.g. + # df[df > df2] = 0 + if isinstance(key, np.ndarray): + if key.shape != self.shape: + raise ValueError("Array conditional must be same shape as self") + key = self._constructor(key, **self._construct_axes_dict(), copy=False) + + if key.size and not all(is_bool_dtype(blk.dtype) for blk in key._mgr.blocks): + raise TypeError( + "Must pass DataFrame or 2-d ndarray with boolean values only" + ) + + self._where(-key, value, inplace=True) + + def _set_item_frame_value(self, key, value: DataFrame) -> None: + self._ensure_valid_index(value) + + # align columns + if key in self.columns: + loc = self.columns.get_loc(key) + cols = self.columns[loc] + len_cols = 1 if is_scalar(cols) or isinstance(cols, tuple) else len(cols) + if len_cols != len(value.columns): + raise ValueError("Columns must be same length as key") + + # align right-hand-side columns if self.columns + # is multi-index and self[key] is a sub-frame + if isinstance(self.columns, MultiIndex) and isinstance( + loc, (slice, Series, np.ndarray, Index) + ): + cols_droplevel = maybe_droplevels(cols, key) + if ( + not isinstance(cols_droplevel, MultiIndex) + and is_string_dtype(cols_droplevel.dtype) + and not cols_droplevel.any() + ): + # if cols_droplevel contains only empty strings, + # value.reindex(cols_droplevel, axis=1) would be full of NaNs + # see GH#62518 and GH#61841 + return + if len(cols_droplevel) and not cols_droplevel.equals(value.columns): + value = value.reindex(cols_droplevel, axis=1) + + for col, col_droplevel in zip(cols, cols_droplevel, strict=True): + self[col] = value[col_droplevel] + return + + if is_scalar(cols): + self[cols] = value[value.columns[0]] + return + + locs: np.ndarray | list + if isinstance(loc, slice): + locs = np.arange(loc.start, loc.stop, loc.step) + elif is_scalar(loc): + locs = [loc] + else: + locs = loc.nonzero()[0] + + return self.isetitem(locs, value) + + if len(value.columns) > 1: + raise ValueError( + "Cannot set a DataFrame with multiple columns to the single " + f"column {key}" + ) + elif len(value.columns) == 0: + raise ValueError( + f"Cannot set a DataFrame without columns to the column {key}" + ) + + self[key] = value[value.columns[0]] + + def _iset_item_mgr( + self, + loc: int | slice | np.ndarray, + value, + inplace: bool = False, + refs: BlockValuesRefs | None = None, + ) -> None: + # when called from _set_item_mgr loc can be anything returned from get_loc + self._mgr.iset(loc, value, inplace=inplace, refs=refs) + + def _set_item_mgr( + self, key, value: ArrayLike, refs: BlockValuesRefs | None = None + ) -> None: + try: + loc = self._info_axis.get_loc(key) + except KeyError: + # This item wasn't present, just insert at end + self._mgr.insert(len(self._info_axis), key, value, refs) + else: + self._iset_item_mgr(loc, value, refs=refs) + + def _iset_item(self, loc: int, value: Series, inplace: bool = True) -> None: + # We are only called from _replace_columnwise which guarantees that + # no reindex is necessary + self._iset_item_mgr(loc, value._values, inplace=inplace, refs=value._references) + + def _set_item(self, key, value) -> None: + """ + Add series to DataFrame in specified column. + + If series is a numpy-array (not a Series/TimeSeries), it must be the + same length as the DataFrames index or an error will be thrown. + + Series/TimeSeries will be conformed to the DataFrames index to + ensure homogeneity. + """ + value, refs = self._sanitize_column(value) + + if ( + key in self.columns + and value.ndim == 1 + and not isinstance(value.dtype, ExtensionDtype) + ): + # broadcast across multiple columns if necessary + if not self.columns.is_unique or isinstance(self.columns, MultiIndex): + existing_piece = self[key] + if isinstance(existing_piece, DataFrame): + value = np.tile(value, (len(existing_piece.columns), 1)).T + refs = None + + self._set_item_mgr(key, value, refs) + + def _set_value( + self, index: IndexLabel, col, value: Scalar, takeable: bool = False + ) -> None: + """ + Put single value at passed column and index. + + Parameters + ---------- + index : Label + row label + col : Label + column label + value : scalar + takeable : bool, default False + Sets whether or not index/col interpreted as indexers + """ + try: + if takeable: + icol = col + iindex = cast(int, index) + else: + icol = self.columns.get_loc(col) + iindex = self.index.get_loc(index) + self._mgr.column_setitem(icol, iindex, value, inplace_only=True) + + except (KeyError, TypeError, ValueError, LossySetitemError): + # get_loc might raise a KeyError for missing labels (falling back + # to (i)loc will do expansion of the index) + # column_setitem will do validation that may raise TypeError, + # ValueError, or LossySetitemError + # set using a non-recursive method & reset the cache + if takeable: + self.iloc[index, col] = value + else: + self.loc[index, col] = value + + except InvalidIndexError as ii_err: + # GH48729: Seems like you are trying to assign a value to a + # row when only scalar options are permitted + raise InvalidIndexError( + f"You can only assign a scalar value not a {type(value)}" + ) from ii_err + + def _ensure_valid_index(self, value) -> None: + """ + Ensure that if we don't have an index, that we can create one from the + passed value. + """ + # GH5632, make sure that we are a Series convertible + if not len(self.index) and is_list_like(value) and len(value): + if not isinstance(value, DataFrame): + try: + value = Series(value) + except (ValueError, NotImplementedError, TypeError) as err: + raise ValueError( + "Cannot set a frame with no defined index " + "and a value that cannot be converted to a Series" + ) from err + + # GH31368 preserve name of index + index_copy = value.index.copy() + if self.index.name is not None: + index_copy.name = self.index.name + + self._mgr = self._mgr.reindex_axis(index_copy, axis=1, fill_value=np.nan) + + def _box_col_values(self, values: SingleBlockManager, loc: int) -> Series: + """ + Provide boxed values for a column. + """ + # Lookup in columns so that if e.g. a str datetime was passed + # we attach the Timestamp object as the name. + name = self.columns[loc] + # We get index=self.index bc values is a SingleBlockManager + obj = self._constructor_sliced_from_mgr(values, axes=values.axes) + obj._name = name + return obj.__finalize__(self) + + def _get_item(self, item: Hashable) -> Series: + loc = self.columns.get_loc(item) + return self._ixs(loc, axis=1) + + # ---------------------------------------------------------------------- + # Unsorted + + @overload + def query( + self, + expr: str, + *, + parser: Literal["pandas", "python"] = ..., + engine: Literal["python", "numexpr"] | None = ..., + local_dict: dict[str, Any] | None = ..., + global_dict: dict[str, Any] | None = ..., + resolvers: list[Mapping] | None = ..., + level: int = ..., + inplace: Literal[False] = ..., + ) -> DataFrame: ... + + @overload + def query( + self, + expr: str, + *, + parser: Literal["pandas", "python"] = ..., + engine: Literal["python", "numexpr"] | None = ..., + local_dict: dict[str, Any] | None = ..., + global_dict: dict[str, Any] | None = ..., + resolvers: list[Mapping] | None = ..., + level: int = ..., + inplace: Literal[True], + ) -> None: ... + + @overload + def query( + self, + expr: str, + *, + parser: Literal["pandas", "python"] = ..., + engine: Literal["python", "numexpr"] | None = ..., + local_dict: dict[str, Any] | None = ..., + global_dict: dict[str, Any] | None = ..., + resolvers: list[Mapping] | None = ..., + level: int = ..., + inplace: bool = ..., + ) -> DataFrame | None: ... + + def query( + self, + expr: str, + *, + parser: Literal["pandas", "python"] = "pandas", + engine: Literal["python", "numexpr"] | None = None, + local_dict: dict[str, Any] | None = None, + global_dict: dict[str, Any] | None = None, + resolvers: list[Mapping] | None = None, + level: int = 0, + inplace: bool = False, + ) -> DataFrame | None: + """ + Query the columns of a DataFrame with a boolean expression. + + .. warning:: + + This method can run arbitrary code which can make you vulnerable to code + injection if you pass user input to this function. + + Parameters + ---------- + expr : str + The query string to evaluate. + + See the documentation for :func:`eval` for details of + supported operations and functions in the query string. + + See the documentation for :meth:`DataFrame.eval` for details on + referring to column names and variables in the query string. + parser : {'pandas', 'python'}, default 'pandas' + The parser to use to construct the syntax tree from the expression. The + default of ``'pandas'`` parses code slightly different than standard + Python. Alternatively, you can parse an expression using the + ``'python'`` parser to retain strict Python semantics. See the + :ref:`enhancing performance ` documentation for + more details. + engine : {'python', 'numexpr'}, default 'numexpr' + + The engine used to evaluate the expression. Supported engines are + + - None : tries to use ``numexpr``, falls back to ``python`` + - ``'numexpr'`` : This default engine evaluates pandas objects using + numexpr for large speed ups in complex expressions with large frames. + - ``'python'`` : Performs operations as if you had ``eval``'d in top + level python. This engine is generally not that useful. + + More backends may be available in the future. + local_dict : dict or None, optional + A dictionary of local variables, taken from locals() by default. + global_dict : dict or None, optional + A dictionary of global variables, taken from globals() by default. + resolvers : list of dict-like or None, optional + A list of objects implementing the ``__getitem__`` special method that + you can use to inject an additional collection of namespaces to use for + variable lookup. For example, this is used in the + :meth:`~DataFrame.query` method to inject the + ``DataFrame.index`` and ``DataFrame.columns`` + variables that refer to their respective :class:`~pandas.DataFrame` + instance attributes. + level : int, optional + The number of prior stack frames to traverse and add to the current + scope. Most users will **not** need to change this parameter. + inplace : bool + Whether to modify the DataFrame rather than creating a new one. + + Returns + ------- + DataFrame or None + DataFrame resulting from the provided query expression or + None if ``inplace=True``. + + See Also + -------- + eval : Evaluate a string describing operations on + DataFrame columns. + DataFrame.eval : Evaluate a string describing operations on + DataFrame columns. + + Notes + ----- + The result of the evaluation of this expression is first passed to + :attr:`DataFrame.loc` and if that fails because of a + multidimensional key (e.g., a DataFrame) then the result will be passed + to :meth:`DataFrame.__getitem__`. + + This method uses the top-level :func:`eval` function to + evaluate the passed query. + + The :meth:`~pandas.DataFrame.query` method uses a slightly + modified Python syntax by default. For example, the ``&`` and ``|`` + (bitwise) operators have the precedence of their boolean cousins, + :keyword:`and` and :keyword:`or`. This *is* syntactically valid Python, + however the semantics are different. + + You can change the semantics of the expression by passing the keyword + argument ``parser='python'``. This enforces the same semantics as + evaluation in Python space. Likewise, you can pass ``engine='python'`` + to evaluate an expression using Python itself as a backend. This is not + recommended as it is inefficient compared to using ``numexpr`` as the + engine. + + The :attr:`DataFrame.index` and + :attr:`DataFrame.columns` attributes of the + :class:`~pandas.DataFrame` instance are placed in the query namespace + by default, which allows you to treat both the index and columns of the + frame as a column in the frame. + The identifier ``index`` is used for the frame index; you can also + use the name of the index to identify it in a query. Please note that + Python keywords may not be used as identifiers. + + For further details and examples see the ``query`` documentation in + :ref:`indexing `. + + *Backtick quoted variables* + + Backtick quoted variables are parsed as literal Python code and + are converted internally to a Python valid identifier. + This can lead to the following problems. + + During parsing a number of disallowed characters inside the backtick + quoted string are replaced by strings that are allowed as a Python identifier. + These characters include all operators in Python, the space character, the + question mark, the exclamation mark, the dollar sign, and the euro sign. + + A backtick can be escaped by double backticks. + + See also the `Python documentation about lexical analysis + `__ + in combination with the source code in :mod:`pandas.core.computation.parsing`. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"A": range(1, 6), "B": range(10, 0, -2), "C&C": range(10, 5, -1)} + ... ) + >>> df + A B C&C + 0 1 10 10 + 1 2 8 9 + 2 3 6 8 + 3 4 4 7 + 4 5 2 6 + >>> df.query("A > B") + A B C&C + 4 5 2 6 + + The previous expression is equivalent to + + >>> df[df.A > df.B] + A B C&C + 4 5 2 6 + + For columns with spaces in their name, you can use backtick quoting. + + >>> df.query("B == `C&C`") + A B C&C + 0 1 10 10 + + The previous expression is equivalent to + + >>> df[df.B == df["C&C"]] + A B C&C + 0 1 10 10 + + Using local variable: + + >>> local_var = 2 + >>> df.query("A <= @local_var") + A B C&C + 0 1 10 10 + 1 2 8 9 + """ + inplace = validate_bool_kwarg(inplace, "inplace") + if not isinstance(expr, str): + msg = f"expr must be a string to be evaluated, {type(expr)} given" + raise ValueError(msg) + + res = self.eval( + expr, + level=level + 1, + parser=parser, + target=None, + engine=engine, + local_dict=local_dict, + global_dict=global_dict, + resolvers=resolvers or (), + ) + + try: + result = self.loc[res] + except ValueError: + # when res is multi-dimensional loc raises, but this is sometimes a + # valid query + result = self[res] + + if inplace: + self._update_inplace(result) + return None + else: + return result + + @overload + def eval(self, expr: str, *, inplace: Literal[False] = ..., **kwargs) -> Any: ... + + @overload + def eval(self, expr: str, *, inplace: Literal[True], **kwargs) -> None: ... + + def eval(self, expr: str, *, inplace: bool = False, **kwargs) -> Any | None: + """ + Evaluate a string describing operations on DataFrame columns. + + .. warning:: + + This method can run arbitrary code which can make you vulnerable to code + injection if you pass user input to this function. + + Operates on columns only, not specific rows or elements. This allows + `eval` to run arbitrary code, which can make you vulnerable to code + injection if you pass user input to this function. + + Parameters + ---------- + expr : str + The expression string to evaluate. + + You can refer to variables + in the environment by prefixing them with an '@' character like + ``@a + b``. + + You can refer to column names that are not valid Python variable names + by surrounding them in backticks. Thus, column names containing spaces + or punctuation (besides underscores) or starting with digits must be + surrounded by backticks. (For example, a column named "Area (cm^2)" would + be referenced as ```Area (cm^2)```). Column names which are Python keywords + (like "if", "for", "import", etc) cannot be used. + + For example, if one of your columns is called ``a a`` and you want + to sum it with ``b``, your query should be ```a a` + b``. + + See the documentation for :func:`eval` for full details of + supported operations and functions in the expression string. + inplace : bool, default False + If the expression contains an assignment, whether to perform the + operation inplace and mutate the existing DataFrame. Otherwise, + a new DataFrame is returned. + **kwargs + See the documentation for :func:`eval` for complete details + on the keyword arguments accepted by + :meth:`~pandas.DataFrame.eval`. + + Returns + ------- + ndarray, scalar, pandas object, or None + The result of the evaluation or None if ``inplace=True``. + + See Also + -------- + DataFrame.query : Evaluates a boolean expression to query the columns + of a frame. + DataFrame.assign : Can evaluate an expression or function to create new + values for a column. + eval : Evaluate a Python expression as a string using various + backends. + + Notes + ----- + For more details see the API documentation for :func:`~eval`. + For detailed examples see :ref:`enhancing performance with eval + `. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"A": range(1, 6), "B": range(10, 0, -2), "C&C": range(10, 5, -1)} + ... ) + >>> df + A B C&C + 0 1 10 10 + 1 2 8 9 + 2 3 6 8 + 3 4 4 7 + 4 5 2 6 + >>> df.eval("A + B") + 0 11 + 1 10 + 2 9 + 3 8 + 4 7 + dtype: int64 + + Assignment is allowed though by default the original DataFrame is not + modified. + + >>> df.eval("D = A + B") + A B C&C D + 0 1 10 10 11 + 1 2 8 9 10 + 2 3 6 8 9 + 3 4 4 7 8 + 4 5 2 6 7 + >>> df + A B C&C + 0 1 10 10 + 1 2 8 9 + 2 3 6 8 + 3 4 4 7 + 4 5 2 6 + + Multiple columns can be assigned to using multi-line expressions: + + >>> df.eval( + ... ''' + ... D = A + B + ... E = A - B + ... ''' + ... ) + A B C&C D E + 0 1 10 10 11 -9 + 1 2 8 9 10 -6 + 2 3 6 8 9 -3 + 3 4 4 7 8 0 + 4 5 2 6 7 3 + + For columns with spaces or other disallowed characters in their name, you can + use backtick quoting. + + >>> df.eval("B * `C&C`") + 0 100 + 1 72 + 2 48 + 3 28 + 4 12 + dtype: int64 + + Local variables shall be explicitly referenced using ``@`` + character in front of the name: + + >>> local_var = 2 + >>> df.eval("@local_var * A") + 0 2 + 1 4 + 2 6 + 3 8 + 4 10 + Name: A, dtype: int64 + """ + from pandas.core.computation.eval import eval as _eval + + inplace = validate_bool_kwarg(inplace, "inplace") + kwargs["level"] = kwargs.pop("level", 0) + 1 + index_resolvers = self._get_index_resolvers() + column_resolvers = self._get_cleaned_column_resolvers() + resolvers = column_resolvers, index_resolvers + if "target" not in kwargs: + kwargs["target"] = self + kwargs["resolvers"] = tuple(kwargs.get("resolvers", ())) + resolvers + + return _eval(expr, inplace=inplace, **kwargs) + + def select_dtypes(self, include=None, exclude=None) -> DataFrame: + """ + Return a subset of the DataFrame's columns based on the column dtypes. + + This method allows for filtering columns based on their data types. + It is useful when working with heterogeneous DataFrames where operations + need to be performed on a specific subset of data types. + + Parameters + ---------- + include, exclude : scalar or list-like + A selection of dtypes or strings to be included/excluded. At least + one of these parameters must be supplied. + + Returns + ------- + DataFrame + The subset of the frame including the dtypes in ``include`` and + excluding the dtypes in ``exclude``. + + Raises + ------ + ValueError + * If both of ``include`` and ``exclude`` are empty + * If ``include`` and ``exclude`` have overlapping elements + TypeError + * If any kind of string dtype is passed in. + + See Also + -------- + DataFrame.dtypes: Return Series with the data type of each column. + + Notes + ----- + * To select all *numeric* types, use ``np.number`` or ``'number'`` + * To select strings you must use the ``object`` dtype, but note that + this will return *all* object dtype columns. With + ``pd.options.future.infer_string`` enabled, using ``"str"`` will + work to select all string columns. + * See the `numpy dtype hierarchy + `__ + * To select datetimes, use ``np.datetime64``, ``'datetime'`` or + ``'datetime64'`` + * To select timedeltas, use ``np.timedelta64``, ``'timedelta'`` or + ``'timedelta64'`` + * To select Pandas categorical dtypes, use ``'category'`` + * To select Pandas datetimetz dtypes, use ``'datetimetz'`` + or ``'datetime64[ns, tz]'`` + + Examples + -------- + >>> df = pd.DataFrame( + ... {"a": [1, 2] * 3, "b": [True, False] * 3, "c": [1.0, 2.0] * 3} + ... ) + >>> df + a b c + 0 1 True 1.0 + 1 2 False 2.0 + 2 1 True 1.0 + 3 2 False 2.0 + 4 1 True 1.0 + 5 2 False 2.0 + + >>> df.select_dtypes(include="bool") + b + 0 True + 1 False + 2 True + 3 False + 4 True + 5 False + + >>> df.select_dtypes(include=["float64"]) + c + 0 1.0 + 1 2.0 + 2 1.0 + 3 2.0 + 4 1.0 + 5 2.0 + + >>> df.select_dtypes(exclude=["int64"]) + b c + 0 True 1.0 + 1 False 2.0 + 2 True 1.0 + 3 False 2.0 + 4 True 1.0 + 5 False 2.0 + """ + if not is_list_like(include): + include = (include,) if include is not None else () + if not is_list_like(exclude): + exclude = (exclude,) if exclude is not None else () + + selection = (frozenset(include), frozenset(exclude)) + + if not any(selection): + raise ValueError("at least one of include or exclude must be nonempty") + + # convert the myriad valid dtypes object to a single representation + def check_int_infer_dtype(dtypes): + converted_dtypes: list[type] = [] + for dtype in dtypes: + # Numpy maps int to different types (int32, in64) on Windows and Linux + # see https://github.com/numpy/numpy/issues/9464 + if (isinstance(dtype, str) and dtype == "int") or (dtype is int): + converted_dtypes.append(np.int32) + converted_dtypes.append(np.int64) + elif dtype == "float" or dtype is float: + # GH#42452 : np.dtype("float") coerces to np.float64 from Numpy 1.20 + converted_dtypes.extend([np.float64, np.float32]) + else: + converted_dtypes.append(infer_dtype_from_object(dtype)) + return frozenset(converted_dtypes) + + include = check_int_infer_dtype(include) + exclude = check_int_infer_dtype(exclude) + + for dtypes in (include, exclude): + invalidate_string_dtypes(dtypes) + + # can't both include AND exclude! + if not include.isdisjoint(exclude): + raise ValueError(f"include and exclude overlap on {(include & exclude)}") + + def dtype_predicate(dtype: DtypeObj, dtypes_set) -> bool: + # GH 46870: BooleanDtype._is_numeric == True but should be excluded + dtype = dtype if not isinstance(dtype, ArrowDtype) else dtype.numpy_dtype + return ( + issubclass(dtype.type, tuple(dtypes_set)) + or ( + np.number in dtypes_set + and getattr(dtype, "_is_numeric", False) + and not is_bool_dtype(dtype) + ) + # backwards compat for the default `str` dtype being selected by object + or ( + isinstance(dtype, StringDtype) + and dtype.na_value is np.nan + and np.object_ in dtypes_set + ) + ) + + def predicate(arr: ArrayLike) -> bool: + dtype = arr.dtype + if include: + if not dtype_predicate(dtype, include): + return False + + if exclude: + if dtype_predicate(dtype, exclude): + return False + + return True + + blk_dtypes = [blk.dtype for blk in self._mgr.blocks] + if ( + np.object_ in include + and str not in include + and str not in exclude + and any( + isinstance(dtype, StringDtype) and dtype.na_value is np.nan + for dtype in blk_dtypes + ) + ): + # GH#61916 + warnings.warn( + "For backward compatibility, 'str' dtypes are included by " + "select_dtypes when 'object' dtype is specified. " + "This behavior is deprecated and will be removed in a future " + "version. Explicitly pass 'str' to `include` to select them, " + "or to `exclude` to remove them and silence this warning.\nSee " + "https://pandas.pydata.org/docs/user_guide/migration-3-strings.html" + "#string-migration-select-dtypes for details on how to write code " + "that works with pandas 2 and 3.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + + mgr = self._mgr._get_data_subset(predicate).copy(deep=False) + return self._constructor_from_mgr(mgr, axes=mgr.axes).__finalize__(self) + + def insert( + self, + loc: int, + column: Hashable, + value: object, + allow_duplicates: bool | lib.NoDefault = lib.no_default, + ) -> None: + """ + Insert column into DataFrame at specified location. + + Raises a ValueError if `column` is already contained in the DataFrame, + unless `allow_duplicates` is set to True. + + Parameters + ---------- + loc : int + Insertion index. Must verify 0 <= loc <= len(columns). + column : str, number, or hashable object + Label of the inserted column. + value : Scalar, Series, or array-like + Content of the inserted column. + allow_duplicates : bool, optional, default lib.no_default + Allow duplicate column labels to be created. + + See Also + -------- + Index.insert : Insert new item by index. + + Examples + -------- + >>> df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + >>> df + col1 col2 + 0 1 3 + 1 2 4 + >>> df.insert(1, "newcol", [99, 99]) + >>> df + col1 newcol col2 + 0 1 99 3 + 1 2 99 4 + >>> df.insert(0, "col1", [100, 100], allow_duplicates=True) + >>> df + col1 col1 newcol col2 + 0 100 1 99 3 + 1 100 2 99 4 + + Notice that pandas uses index alignment in case of `value` from type `Series`: + + >>> df.insert(0, "col0", pd.Series([5, 6], index=[1, 2])) + >>> df + col0 col1 col1 newcol col2 + 0 NaN 100 1 99 3 + 1 5.0 100 2 99 4 + """ + if allow_duplicates is lib.no_default: + allow_duplicates = False + if allow_duplicates and not self.flags.allows_duplicate_labels: + raise ValueError( + "Cannot specify 'allow_duplicates=True' when " + "'self.flags.allows_duplicate_labels' is False." + ) + if not allow_duplicates and column in self.columns: + # Should this be a different kind of error?? + raise ValueError(f"cannot insert {column}, already exists") + if not is_integer(loc): + raise TypeError("loc must be int") + # convert non stdlib ints to satisfy typing checks + loc = int(loc) + if isinstance(value, DataFrame) and len(value.columns) > 1: + raise ValueError( + f"Expected a one-dimensional object, got a DataFrame with " + f"{len(value.columns)} columns instead." + ) + elif isinstance(value, DataFrame): + value = value.iloc[:, 0] + + value, refs = self._sanitize_column(value) + self._mgr.insert(loc, column, value, refs=refs) + + def assign(self, **kwargs) -> DataFrame: + r""" + Assign new columns to a DataFrame. + + Returns a new object with all original columns in addition to new ones. + Existing columns that are re-assigned will be overwritten. + + Parameters + ---------- + **kwargs : callable or Series + The column names are keywords. If the values are + callable, they are computed on the DataFrame and + assigned to the new columns. The callable must not + change input DataFrame (though pandas doesn't check it). + If the values are not callable, (e.g. a Series, scalar, or array), + they are simply assigned. + + Returns + ------- + DataFrame + A new DataFrame with the new columns in addition to + all the existing columns. + + See Also + -------- + DataFrame.loc : Select a subset of a DataFrame by labels. + DataFrame.iloc : Select a subset of a DataFrame by positions. + + Notes + ----- + Assigning multiple columns within the same ``assign`` is possible. + Later items in '\*\*kwargs' may refer to newly created or modified + columns in 'df'; items are computed and assigned into 'df' in order. + + Examples + -------- + >>> df = pd.DataFrame({"temp_c": [17.0, 25.0]}, index=["Portland", "Berkeley"]) + >>> df + temp_c + Portland 17.0 + Berkeley 25.0 + + Where the value is a callable, evaluated on `df`: + + >>> df.assign(temp_f=lambda x: x.temp_c * 9 / 5 + 32) + temp_c temp_f + Portland 17.0 62.6 + Berkeley 25.0 77.0 + + Alternatively, the same behavior can be achieved by directly + referencing an existing Series or sequence: + + >>> df.assign(temp_f=df["temp_c"] * 9 / 5 + 32) + temp_c temp_f + Portland 17.0 62.6 + Berkeley 25.0 77.0 + + or by using :meth:`pandas.col`: + + >>> df.assign(temp_f=pd.col("temp_c") * 9 / 5 + 32) + temp_c temp_f + Portland 17.0 62.6 + Berkeley 25.0 77.0 + + You can create multiple columns within the same assign where one + of the columns depends on another one defined within the same assign: + + >>> df.assign( + ... temp_f=lambda x: x["temp_c"] * 9 / 5 + 32, + ... temp_k=lambda x: (x["temp_f"] + 459.67) * 5 / 9, + ... ) + temp_c temp_f temp_k + Portland 17.0 62.6 290.15 + Berkeley 25.0 77.0 298.15 + """ + data = self.copy(deep=False) + + for k, v in kwargs.items(): + data[k] = com.apply_if_callable(v, data) + return data + + def _sanitize_column(self, value) -> tuple[ArrayLike, BlockValuesRefs | None]: + """ + Ensures new columns (which go into the BlockManager as new blocks) are + always copied (or a reference is being tracked to them under CoW) + and converted into an array. + + Parameters + ---------- + value : scalar, Series, or array-like + + Returns + ------- + tuple of numpy.ndarray or ExtensionArray and optional BlockValuesRefs + """ + self._ensure_valid_index(value) + + # Using a DataFrame would mean coercing values to one dtype + assert not isinstance(value, DataFrame) + if is_dict_like(value): + if not isinstance(value, Series): + value = Series(value) + return _reindex_for_setitem(value, self.index) + + if is_list_like(value): + com.require_length_match(value, self.index) + return sanitize_array(value, self.index, copy=True, allow_2d=True), None + + @property + def _series(self): + return {item: self._ixs(idx, axis=1) for idx, item in enumerate(self.columns)} + + # ---------------------------------------------------------------------- + # Reindexing and alignment + + def _reindex_multi(self, axes: dict[str, Index], fill_value) -> DataFrame: + """ + We are guaranteed non-Nones in the axes. + """ + + new_index, row_indexer = self.index.reindex(axes["index"]) + new_columns, col_indexer = self.columns.reindex(axes["columns"]) + + if row_indexer is not None and col_indexer is not None: + # Fastpath. By doing two 'take's at once we avoid making an + # unnecessary copy. + # We only get here with `self._can_fast_transpose`, which (almost) + # ensures that self.values is cheap. It may be worth making this + # condition more specific. + indexer = row_indexer, col_indexer + new_values = take_2d_multi(self.values, indexer, fill_value=fill_value) + return self._constructor( + new_values, index=new_index, columns=new_columns, copy=False + ) + else: + return self._reindex_with_indexers( + {0: [new_index, row_indexer], 1: [new_columns, col_indexer]}, + fill_value=fill_value, + ) + + def set_axis( + self, + labels, + *, + axis: Axis = 0, + copy: bool | lib.NoDefault = lib.no_default, + ) -> DataFrame: + """ + Assign desired index to given axis. + + Indexes for column or row labels can be changed by assigning + a list-like or Index. + + Parameters + ---------- + labels : list-like, Index + The values for the new index. + + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to update. The value 0 identifies the rows. For `Series` + this parameter is unused and defaults to 0. + + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + DataFrame + An object of type DataFrame. + + See Also + -------- + DataFrame.rename_axis : Alter the name of the index or columns. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + Change the row labels. + + >>> df.set_axis(["a", "b", "c"], axis="index") + A B + a 1 4 + b 2 5 + c 3 6 + + Change the column labels. + + >>> df.set_axis(["I", "II"], axis="columns") + I II + 0 1 4 + 1 2 5 + 2 3 6 + """ + return super().set_axis(labels, axis=axis, copy=copy) + + def reindex( + self, + labels=None, + *, + index=None, + columns=None, + axis: Axis | None = None, + method: ReindexMethod | None = None, + copy: bool | lib.NoDefault = lib.no_default, + level: Level | None = None, + fill_value: Scalar | None = np.nan, + limit: int | None = None, + tolerance=None, + ) -> DataFrame: + """ + Conform DataFrame to new index with optional filling logic. + + Places NA/NaN in locations having no value in the previous index. A new object + is produced unless the new index is equivalent to the current one and + ``copy=False``. + + Parameters + ---------- + + labels : array-like, optional + New labels / index to conform the axis specified by 'axis' to. + index : array-like, optional + New labels for the index. Preferably an Index object to avoid + duplicating data. + columns : array-like, optional + New labels for the columns. Preferably an Index object to avoid + duplicating data. + axis : int or str, optional + Axis to target. Can be either the axis name ('index', 'columns') + or number (0, 1). + method : {None, 'backfill'/'bfill', 'pad'/'ffill', 'nearest'} + Method to use for filling holes in reindexed DataFrame. + Please note: this is only applicable to DataFrames/Series with a + monotonically increasing/decreasing index. + + * None (default): don't fill gaps + * pad / ffill: Propagate last valid observation forward to next + valid. + * backfill / bfill: Use next valid observation to fill gap. + * nearest: Use nearest valid observations to fill gap. + + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : scalar, default np.nan + Value to use for missing values. Defaults to NaN, but can be any + "compatible" value. + limit : int, default None + Maximum number of consecutive elements to forward or backward fill. + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations most + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + + Tolerance may be a scalar value, which applies the same tolerance + to all values, or list-like, which applies variable tolerance per + element. List-like includes list, tuple, array, Series, and must be + the same size as the index and its dtype must exactly match the + index's type. + + Returns + ------- + DataFrame + DataFrame with changed index. + + See Also + -------- + DataFrame.set_index : Set row labels. + DataFrame.reset_index : Remove row labels or move them to new columns. + DataFrame.reindex_like : Change to same indices as other DataFrame. + + Examples + -------- + ``DataFrame.reindex`` supports two calling conventions + + * ``(index=index_labels, columns=column_labels, ...)`` + * ``(labels, axis={'index', 'columns'}, ...)`` + + We *highly* recommend using keyword arguments to clarify your + intent. + + Create a DataFrame with some fictional data. + + >>> index = ["Firefox", "Chrome", "Safari", "IE10", "Konqueror"] + >>> columns = ["http_status", "response_time"] + >>> df = pd.DataFrame( + ... [[200, 0.04], [200, 0.02], [404, 0.07], [404, 0.08], [301, 1.0]], + ... columns=columns, + ... index=index, + ... ) + >>> df + http_status response_time + Firefox 200 0.04 + Chrome 200 0.02 + Safari 404 0.07 + IE10 404 0.08 + Konqueror 301 1.00 + + Create a new index and reindex the DataFrame. By default + values in the new index that do not have corresponding + records in the DataFrame are assigned ``NaN``. + + >>> new_index = ["Safari", "Iceweasel", "Comodo Dragon", "IE10", "Chrome"] + >>> df.reindex(new_index) + http_status response_time + Safari 404.0 0.07 + Iceweasel NaN NaN + Comodo Dragon NaN NaN + IE10 404.0 0.08 + Chrome 200.0 0.02 + + We can fill in the missing values by passing a value to + the keyword ``fill_value``. Because the index is not monotonically + increasing or decreasing, we cannot use arguments to the keyword + ``method`` to fill the ``NaN`` values. + + >>> df.reindex(new_index, fill_value=0) + http_status response_time + Safari 404 0.07 + Iceweasel 0 0.00 + Comodo Dragon 0 0.00 + IE10 404 0.08 + Chrome 200 0.02 + + >>> df.reindex(new_index, fill_value="missing") + http_status response_time + Safari 404 0.07 + Iceweasel missing missing + Comodo Dragon missing missing + IE10 404 0.08 + Chrome 200 0.02 + + We can also reindex the columns. + + >>> df.reindex(columns=["http_status", "user_agent"]) + http_status user_agent + Firefox 200 NaN + Chrome 200 NaN + Safari 404 NaN + IE10 404 NaN + Konqueror 301 NaN + + Or we can use "axis-style" keyword arguments + + >>> df.reindex(["http_status", "user_agent"], axis="columns") + http_status user_agent + Firefox 200 NaN + Chrome 200 NaN + Safari 404 NaN + IE10 404 NaN + Konqueror 301 NaN + + To further illustrate the filling functionality in + ``reindex``, we will create a DataFrame with a + monotonically increasing index (for example, a sequence + of dates). + + >>> date_index = pd.date_range("1/1/2010", periods=6, freq="D") + >>> df2 = pd.DataFrame( + ... {"prices": [100, 101, np.nan, 100, 89, 88]}, index=date_index + ... ) + >>> df2 + prices + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + + Suppose we decide to expand the DataFrame to cover a wider + date range. + + >>> date_index2 = pd.date_range("12/29/2009", periods=10, freq="D") + >>> df2.reindex(date_index2) + prices + 2009-12-29 NaN + 2009-12-30 NaN + 2009-12-31 NaN + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + 2010-01-07 NaN + + The index entries that did not have a value in the original data frame + (for example, '2009-12-29') are by default filled with ``NaN``. + If desired, we can fill in the missing values using one of several + options. + + For example, to back-propagate the last valid value to fill the ``NaN`` + values, pass ``bfill`` as an argument to the ``method`` keyword. + + >>> df2.reindex(date_index2, method="bfill") + prices + 2009-12-29 100.0 + 2009-12-30 100.0 + 2009-12-31 100.0 + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + 2010-01-07 NaN + + Please note that the ``NaN`` value present in the original DataFrame + (at index value 2010-01-03) will not be filled by any of the + value propagation schemes. This is because filling while reindexing + does not look at DataFrame values, but only compares the original and + desired indexes. If you do want to fill in the ``NaN`` values present + in the original DataFrame, use the ``fillna()`` method. + + See the :ref:`user guide ` for more. + """ + return super().reindex( + labels=labels, + index=index, + columns=columns, + axis=axis, + method=method, + level=level, + fill_value=fill_value, + limit=limit, + tolerance=tolerance, + copy=copy, + ) + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level = ..., + inplace: Literal[True], + errors: IgnoreRaise = ..., + ) -> None: ... + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level = ..., + inplace: Literal[False] = ..., + errors: IgnoreRaise = ..., + ) -> DataFrame: ... + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level = ..., + inplace: bool = ..., + errors: IgnoreRaise = ..., + ) -> DataFrame | None: ... + + def drop( + self, + labels: IndexLabel | ListLike = None, + *, + axis: Axis = 0, + index: IndexLabel | ListLike = None, + columns: IndexLabel | ListLike = None, + level: Level | None = None, + inplace: bool = False, + errors: IgnoreRaise = "raise", + ) -> DataFrame | None: + """ + Drop specified labels from rows or columns. + + Remove rows or columns by specifying label names and corresponding + axis, or by directly specifying index or column names. When using a + multi-index, labels on different levels can be removed by specifying + the level. See the :ref:`user guide ` + for more information about the now unused levels. + + Parameters + ---------- + labels : single label or iterable of labels + Index or column labels to drop. A tuple will be used as a single + label and not treated as an iterable. + axis : {0 or 'index', 1 or 'columns'}, default 0 + Whether to drop labels from the index (0 or 'index') or + columns (1 or 'columns'). + index : single label or iterable of labels + Alternative to specifying axis (``labels, axis=0`` + is equivalent to ``index=labels``). + columns : single label or iterable of labels + Alternative to specifying axis (``labels, axis=1`` + is equivalent to ``columns=labels``). + level : int or level name, optional + For MultiIndex, level from which the labels will be removed. + inplace : bool, default False + If False, return a copy. Otherwise, do operation + in place and return None. + errors : {'ignore', 'raise'}, default 'raise' + If 'ignore', suppress error and only existing labels are + dropped. + + Returns + ------- + DataFrame or None + Returns DataFrame or None DataFrame with the specified + index or column labels removed or None if inplace=True. + + Raises + ------ + KeyError + If any of the labels is not found in the selected axis. + + See Also + -------- + DataFrame.loc : Label-location based indexer for selection by label. + DataFrame.dropna : Return DataFrame with labels on given axis omitted + where (all or any) data are missing. + DataFrame.drop_duplicates : Return DataFrame with duplicate rows + removed, optionally only considering certain columns. + Series.drop : Return Series with specified index labels removed. + + Examples + -------- + >>> df = pd.DataFrame(np.arange(12).reshape(3, 4), columns=["A", "B", "C", "D"]) + >>> df + A B C D + 0 0 1 2 3 + 1 4 5 6 7 + 2 8 9 10 11 + + Drop columns + + >>> df.drop(["B", "C"], axis=1) + A D + 0 0 3 + 1 4 7 + 2 8 11 + + >>> df.drop(columns=["B", "C"]) + A D + 0 0 3 + 1 4 7 + 2 8 11 + + Drop a row by index + + >>> df.drop([0, 1]) + A B C D + 2 8 9 10 11 + + Drop columns and/or rows of MultiIndex DataFrame + + >>> midx = pd.MultiIndex( + ... levels=[["llama", "cow", "falcon"], ["speed", "weight", "length"]], + ... codes=[[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]], + ... ) + >>> df = pd.DataFrame( + ... index=midx, + ... columns=["big", "small"], + ... data=[ + ... [45, 30], + ... [200, 100], + ... [1.5, 1], + ... [30, 20], + ... [250, 150], + ... [1.5, 0.8], + ... [320, 250], + ... [1, 0.8], + ... [0.3, 0.2], + ... ], + ... ) + >>> df + big small + llama speed 45.0 30.0 + weight 200.0 100.0 + length 1.5 1.0 + cow speed 30.0 20.0 + weight 250.0 150.0 + length 1.5 0.8 + falcon speed 320.0 250.0 + weight 1.0 0.8 + length 0.3 0.2 + + Drop a specific index combination from the MultiIndex + DataFrame, i.e., drop the combination ``'falcon'`` and + ``'weight'``, which deletes only the corresponding row + + >>> df.drop(index=("falcon", "weight")) + big small + llama speed 45.0 30.0 + weight 200.0 100.0 + length 1.5 1.0 + cow speed 30.0 20.0 + weight 250.0 150.0 + length 1.5 0.8 + falcon speed 320.0 250.0 + length 0.3 0.2 + + >>> df.drop(index="cow", columns="small") + big + llama speed 45.0 + weight 200.0 + length 1.5 + falcon speed 320.0 + weight 1.0 + length 0.3 + + >>> df.drop(index="length", level=1) + big small + llama speed 45.0 30.0 + weight 200.0 100.0 + cow speed 30.0 20.0 + weight 250.0 150.0 + falcon speed 320.0 250.0 + weight 1.0 0.8 + """ + return super().drop( + labels=labels, + axis=axis, + index=index, + columns=columns, + level=level, + inplace=inplace, + errors=errors, + ) + + @overload + def rename( + self, + mapper: Renamer | None = ..., + *, + index: Renamer | None = ..., + columns: Renamer | None = ..., + axis: Axis | None = ..., + copy: bool | lib.NoDefault = lib.no_default, + inplace: Literal[True], + level: Level = ..., + errors: IgnoreRaise = ..., + ) -> None: ... + + @overload + def rename( + self, + mapper: Renamer | None = ..., + *, + index: Renamer | None = ..., + columns: Renamer | None = ..., + axis: Axis | None = ..., + copy: bool | lib.NoDefault = lib.no_default, + inplace: Literal[False] = ..., + level: Level = ..., + errors: IgnoreRaise = ..., + ) -> DataFrame: ... + + @overload + def rename( + self, + mapper: Renamer | None = ..., + *, + index: Renamer | None = ..., + columns: Renamer | None = ..., + axis: Axis | None = ..., + copy: bool | lib.NoDefault = lib.no_default, + inplace: bool = ..., + level: Level = ..., + errors: IgnoreRaise = ..., + ) -> DataFrame | None: ... + + def rename( + self, + mapper: Renamer | None = None, + *, + index: Renamer | None = None, + columns: Renamer | None = None, + axis: Axis | None = None, + copy: bool | lib.NoDefault = lib.no_default, + inplace: bool = False, + level: Level | None = None, + errors: IgnoreRaise = "ignore", + ) -> DataFrame | None: + """ + Rename columns or index labels. + + Function / dict values must be unique (1-to-1). Labels not contained in + a dict / Series will be left as-is. Extra labels listed don't throw an + error. + + See the :ref:`user guide ` for more. + + Parameters + ---------- + mapper : dict-like or function + Dict-like or function transformations to apply to + that axis' values. Use either ``mapper`` and ``axis`` to + specify the axis to target with ``mapper``, or ``index`` and + ``columns``. + index : dict-like or function + Alternative to specifying axis (``mapper, axis=0`` + is equivalent to ``index=mapper``). + columns : dict-like or function + Alternative to specifying axis (``mapper, axis=1`` + is equivalent to ``columns=mapper``). + axis : {0 or 'index', 1 or 'columns'}, default 0 + Axis to target with ``mapper``. Can be either the axis name + ('index', 'columns') or number (0, 1). The default is 'index'. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + inplace : bool, default False + Whether to modify the DataFrame rather than creating a new one. + If True then value of copy is ignored. + level : int or level name, default None + In case of a MultiIndex, only rename labels in the specified + level. + errors : {'ignore', 'raise'}, default 'ignore' + If 'raise', raise a `KeyError` when a dict-like `mapper`, `index`, + or `columns` contains labels that are not present in the Index + being transformed. + If 'ignore', existing keys will be renamed and extra keys will be + ignored. + + Returns + ------- + DataFrame or None + DataFrame with the renamed axis labels or None if ``inplace=True``. + + Raises + ------ + KeyError + If any of the labels is not found in the selected axis and + "errors='raise'". + + See Also + -------- + DataFrame.rename_axis : Set the name of the axis. + + Examples + -------- + ``DataFrame.rename`` supports two calling conventions + + * ``(index=index_mapper, columns=columns_mapper, ...)`` + * ``(mapper, axis={'index', 'columns'}, ...)`` + + We *highly* recommend using keyword arguments to clarify your + intent. + + Rename columns using a mapping: + + >>> df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + >>> df.rename(columns={"A": "a", "B": "c"}) + a c + 0 1 4 + 1 2 5 + 2 3 6 + + Rename index using a mapping: + + >>> df.rename(index={0: "x", 1: "y", 2: "z"}) + A B + x 1 4 + y 2 5 + z 3 6 + + Cast index labels to a different type: + + >>> df.index + RangeIndex(start=0, stop=3, step=1) + >>> df.rename(index=str).index + Index(['0', '1', '2'], dtype='str') + + >>> df.rename(columns={"A": "a", "B": "b", "C": "c"}, errors="raise") + Traceback (most recent call last): + KeyError: ['C'] not found in axis + + Using axis-style parameters: + + >>> df.rename(str.lower, axis="columns") + a b + 0 1 4 + 1 2 5 + 2 3 6 + + >>> df.rename({1: 2, 2: 4}, axis="index") + A B + 0 1 4 + 2 2 5 + 4 3 6 + """ + self._check_copy_deprecation(copy) + return super()._rename( + mapper=mapper, + index=index, + columns=columns, + axis=axis, + inplace=inplace, + level=level, + errors=errors, + ) + + def pop(self, item: Hashable) -> Series: + """ + Return item and drop it from DataFrame. Raise KeyError if not found. + + Parameters + ---------- + item : label + Label of column to be popped. + + Returns + ------- + Series + Series representing the item that is dropped. + + See Also + -------- + DataFrame.drop: Drop specified labels from rows or columns. + DataFrame.drop_duplicates: Return DataFrame with duplicate rows removed. + + Examples + -------- + >>> df = pd.DataFrame( + ... [ + ... ("falcon", "bird", 389.0), + ... ("parrot", "bird", 24.0), + ... ("lion", "mammal", 80.5), + ... ("monkey", "mammal", np.nan), + ... ], + ... columns=("name", "class", "max_speed"), + ... ) + >>> df + name class max_speed + 0 falcon bird 389.0 + 1 parrot bird 24.0 + 2 lion mammal 80.5 + 3 monkey mammal NaN + + >>> df.pop("class") + 0 bird + 1 bird + 2 mammal + 3 mammal + Name: class, dtype: str + + >>> df + name max_speed + 0 falcon 389.0 + 1 parrot 24.0 + 2 lion 80.5 + 3 monkey NaN + """ + return super().pop(item=item) + + def _replace_columnwise( + self, mapping: dict[Hashable, tuple[Any, Any]], inplace: bool, regex + ) -> Self: + """ + Dispatch to Series.replace column-wise. + + Parameters + ---------- + mapping : dict + of the form {col: (target, value)} + inplace : bool + regex : bool or same types as `to_replace` in DataFrame.replace + + Returns + ------- + DataFrame + """ + # Operate column-wise + res = self if inplace else self.copy(deep=False) + ax = self.columns + + for i, ax_value in enumerate(ax): + if ax_value in mapping: + ser = self.iloc[:, i] + + target, value = mapping[ax_value] + newobj = ser.replace(target, value, regex=regex) + + res._iset_item(i, newobj, inplace=inplace) + + return res if inplace else res.__finalize__(self) + + def shift( + self, + periods: int | Sequence[int] = 1, + freq: Frequency | None = None, + axis: Axis = 0, + fill_value: Hashable = lib.no_default, + suffix: str | None = None, + ) -> DataFrame: + """ + Shift index by desired number of periods with an optional time `freq`. + + When `freq` is not passed, shift the index without realigning the data. + If `freq` is passed (in this case, the index must be date or datetime, + or it will raise a `NotImplementedError`), the index will be + increased using the periods and the `freq`. `freq` can be inferred + when specified as "infer" as long as either freq or inferred_freq + attribute is set in the index. + + Parameters + ---------- + periods : int or Sequence + Number of periods to shift. Can be positive or negative. + If an iterable of ints, the data will be shifted once by each int. + This is equivalent to shifting by one value at a time and + concatenating all resulting frames. The resulting columns will have + the shift suffixed to their column names. For multiple periods, + axis must not be 1. + freq : DateOffset, tseries.offsets, timedelta, or str, optional + Offset to use from the tseries module or time rule (e.g. 'EOM'). + If `freq` is specified then the index values are shifted but the + data is not realigned. That is, use `freq` if you would like to + extend the index when shifting and preserve the original data. + If `freq` is specified as "infer" then it will be inferred from + the freq or inferred_freq attributes of the index. If neither of + those attributes exist, a ValueError is thrown. + axis : {0 or 'index', 1 or 'columns', None}, default None + Shift direction. For `Series` this parameter is unused and defaults to 0. + fill_value : object, optional + The scalar value to use for newly introduced missing values. + the default depends on the dtype of `self`. + For Boolean and numeric NumPy data types, ``np.nan`` is used. + For datetime, timedelta, or period data, etc. :attr:`NaT` is used. + For extension dtypes, ``self.dtype.na_value`` is used. + suffix : str, optional + If str and periods is an iterable, this is added after the column + name and before the shift value for each shifted column name. + For `Series` this parameter is unused and defaults to `None`. + + Returns + ------- + DataFrame + Copy of input object, shifted. + + See Also + -------- + Index.shift : Shift values of Index. + DatetimeIndex.shift : Shift values of DatetimeIndex. + PeriodIndex.shift : Shift values of PeriodIndex. + + Examples + -------- + >>> df = pd.DataFrame( + ... [[10, 13, 17], [20, 23, 27], [15, 18, 22], [30, 33, 37], [45, 48, 52]], + ... columns=["Col1", "Col2", "Col3"], + ... index=pd.date_range("2020-01-01", "2020-01-05"), + ... ) + >>> df + Col1 Col2 Col3 + 2020-01-01 10 13 17 + 2020-01-02 20 23 27 + 2020-01-03 15 18 22 + 2020-01-04 30 33 37 + 2020-01-05 45 48 52 + + >>> df.shift(periods=3) + Col1 Col2 Col3 + 2020-01-01 NaN NaN NaN + 2020-01-02 NaN NaN NaN + 2020-01-03 NaN NaN NaN + 2020-01-04 10.0 13.0 17.0 + 2020-01-05 20.0 23.0 27.0 + + >>> df.shift(periods=1, axis="columns") + Col1 Col2 Col3 + 2020-01-01 NaN 10 13 + 2020-01-02 NaN 20 23 + 2020-01-03 NaN 15 18 + 2020-01-04 NaN 30 33 + 2020-01-05 NaN 45 48 + + >>> df.shift(periods=3, fill_value=0) + Col1 Col2 Col3 + 2020-01-01 0 0 0 + 2020-01-02 0 0 0 + 2020-01-03 0 0 0 + 2020-01-04 10 13 17 + 2020-01-05 20 23 27 + + >>> df.shift(periods=3, freq="D") + Col1 Col2 Col3 + 2020-01-04 10 13 17 + 2020-01-05 20 23 27 + 2020-01-06 15 18 22 + 2020-01-07 30 33 37 + 2020-01-08 45 48 52 + + >>> df.shift(periods=3, freq="infer") + Col1 Col2 Col3 + 2020-01-04 10 13 17 + 2020-01-05 20 23 27 + 2020-01-06 15 18 22 + 2020-01-07 30 33 37 + 2020-01-08 45 48 52 + + >>> df["Col1"].shift(periods=[0, 1, 2]) + Col1_0 Col1_1 Col1_2 + 2020-01-01 10 NaN NaN + 2020-01-02 20 10.0 NaN + 2020-01-03 15 20.0 10.0 + 2020-01-04 30 15.0 20.0 + 2020-01-05 45 30.0 15.0 + """ + if freq is not None and fill_value is not lib.no_default: + # GH#53832 + raise ValueError( + "Passing a 'freq' together with a 'fill_value' is not allowed." + ) + + if self.empty and freq is None: + return self.copy() + + axis = self._get_axis_number(axis) + + if is_list_like(periods): + periods = cast(Sequence, periods) + if axis == 1: + raise ValueError( + "If `periods` contains multiple shifts, `axis` cannot be 1." + ) + if len(periods) == 0: + raise ValueError("If `periods` is an iterable, it cannot be empty.") + from pandas.core.reshape.concat import concat + + shifted_dataframes = [] + for period in periods: + if not is_integer(period): + raise TypeError( + f"Periods must be integer, but {period} is {type(period)}." + ) + period = cast(int, period) + shifted_dataframes.append( + super() + .shift(periods=period, freq=freq, axis=axis, fill_value=fill_value) + .add_suffix(f"{suffix}_{period}" if suffix else f"_{period}") + ) + return concat(shifted_dataframes, axis=1, sort=False) + elif suffix: + raise ValueError("Cannot specify `suffix` if `periods` is an int.") + periods = cast(int, periods) + + ncols = len(self.columns) + if axis == 1 and periods != 0 and ncols > 0 and freq is None: + if fill_value is lib.no_default: + # We will infer fill_value to match the closest column + + # Use a column that we know is valid for our column's dtype GH#38434 + label = self.columns[0] + + if periods > 0: + result = self.iloc[:, :-periods] + for col in range(min(ncols, abs(periods))): + # TODO(EA2D): doing this in a loop unnecessary with 2D EAs + # Define filler inside loop so we get a copy + filler = self.iloc[:, 0].shift(len(self)) + result.insert(0, label, filler, allow_duplicates=True) + else: + result = self.iloc[:, -periods:] + for col in range(min(ncols, abs(periods))): + # Define filler inside loop so we get a copy + filler = self.iloc[:, -1].shift(len(self)) + result.insert( + len(result.columns), label, filler, allow_duplicates=True + ) + + result.columns = self.columns.copy() + return result + elif len(self._mgr.blocks) > 1 or ( + # If we only have one block and we know that we can't + # keep the same dtype (i.e. the _can_hold_element check) + # then we can go through the reindex_indexer path + # (and avoid casting logic in the Block method). + not can_hold_element(self._mgr.blocks[0].values, fill_value) + ): + # GH#35488 we need to watch out for multi-block cases + # We only get here with fill_value not-lib.no_default + nper = abs(periods) + nper = min(nper, ncols) + if periods > 0: + indexer = np.array( + [-1] * nper + list(range(ncols - periods)), dtype=np.intp + ) + else: + indexer = np.array( + list(range(nper, ncols)) + [-1] * nper, dtype=np.intp + ) + mgr = self._mgr.reindex_indexer( + self.columns, + indexer, + axis=0, + fill_value=fill_value, + allow_dups=True, + ) + res_df = self._constructor_from_mgr(mgr, axes=mgr.axes) + return res_df.__finalize__(self, method="shift") + else: + return self.T.shift(periods=periods, fill_value=fill_value).T + + return super().shift( + periods=periods, freq=freq, axis=axis, fill_value=fill_value + ) + + @overload + def set_index( + self, + keys, + *, + drop: bool = ..., + append: bool = ..., + inplace: Literal[False] = ..., + verify_integrity: bool | lib.NoDefault = ..., + ) -> DataFrame: ... + + @overload + def set_index( + self, + keys, + *, + drop: bool = ..., + append: bool = ..., + inplace: Literal[True], + verify_integrity: bool | lib.NoDefault = ..., + ) -> None: ... + + def set_index( + self, + keys, + *, + drop: bool = True, + append: bool = False, + inplace: bool = False, + verify_integrity: bool | lib.NoDefault = lib.no_default, + ) -> DataFrame | None: + """ + Set the DataFrame index using existing columns. + + Set the DataFrame index (row labels) using one or more existing + columns or arrays (of the correct length). The index can replace the + existing index or expand on it. + + Parameters + ---------- + keys : label or array-like or list of labels/arrays + This parameter can be either a single column key, a single array of + the same length as the calling DataFrame, or a list containing an + arbitrary combination of column keys and arrays. Here, "array" + encompasses :class:`Series`, :class:`Index`, ``np.ndarray``, and + instances of :class:`~collections.abc.Iterator`. + drop : bool, default True + Delete columns to be used as the new index. + append : bool, default False + Whether to append columns to existing index. + Setting to True will add the new columns to existing index. + When set to False, the current index will be dropped from the DataFrame. + inplace : bool, default False + Whether to modify the DataFrame rather than creating a new one. + verify_integrity : bool, default False + Check the new index for duplicates. Otherwise defer the check until + necessary. Setting to False will improve the performance of this + method. + + .. deprecated:: 3.0.0 + + Returns + ------- + DataFrame or None + Changed row labels or None if ``inplace=True``. + + See Also + -------- + DataFrame.reset_index : Opposite of set_index. + DataFrame.reindex : Change to new indices or expand indices. + DataFrame.reindex_like : Change to same indices as other DataFrame. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "month": [1, 4, 7, 10], + ... "year": [2012, 2014, 2013, 2014], + ... "sale": [55, 40, 84, 31], + ... } + ... ) + >>> df + month year sale + 0 1 2012 55 + 1 4 2014 40 + 2 7 2013 84 + 3 10 2014 31 + + Set the index to become the 'month' column: + + >>> df.set_index("month") + year sale + month + 1 2012 55 + 4 2014 40 + 7 2013 84 + 10 2014 31 + + Create a MultiIndex using columns 'year' and 'month': + + >>> df.set_index(["year", "month"]) + sale + year month + 2012 1 55 + 2014 4 40 + 2013 7 84 + 2014 10 31 + + Create a MultiIndex using an Index and a column: + + >>> df.set_index([pd.Index([1, 2, 3, 4]), "year"]) + month sale + year + 1 2012 1 55 + 2 2014 4 40 + 3 2013 7 84 + 4 2014 10 31 + + Create a MultiIndex using two Series: + + >>> s = pd.Series([1, 2, 3, 4]) + >>> df.set_index([s, s**2]) + month year sale + 1 1 1 2012 55 + 2 4 4 2014 40 + 3 9 7 2013 84 + 4 16 10 2014 31 + + Append a column to the existing index: + + >>> df = df.set_index("month") + >>> df.set_index("year", append=True) + sale + month year + 1 2012 55 + 4 2014 40 + 7 2013 84 + 10 2014 31 + + >>> df.set_index("year", append=False) + sale + year + 2012 55 + 2014 40 + 2013 84 + 2014 31 + """ + if verify_integrity is not lib.no_default: + # GH#62919 + warnings.warn( + "The 'verify_integrity' keyword in DataFrame.set_index is " + "deprecated and will be removed in a future version. " + "Directly check the result.index.is_unique instead.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + else: + verify_integrity = False + + inplace = validate_bool_kwarg(inplace, "inplace") + self._check_inplace_and_allows_duplicate_labels(inplace) + if not isinstance(keys, list): + keys = [keys] + + err_msg = ( + 'The parameter "keys" may be a column key, one-dimensional ' + "array, or a list containing only valid column keys and " + "one-dimensional arrays." + ) + + missing: list[Hashable] = [] + for col in keys: + if isinstance(col, (Index, Series, np.ndarray, list, abc.Iterator)): + # arrays are fine as long as they are one-dimensional + # iterators get converted to list below + if getattr(col, "ndim", 1) != 1: + raise ValueError(err_msg) + else: + # everything else gets tried as a key; see GH 24969 + try: + found = col in self.columns + except TypeError as err: + raise TypeError( + f"{err_msg}. Received column of type {type(col)}" + ) from err + else: + if not found: + missing.append(col) + + if missing: + raise KeyError(f"None of {missing} are in the columns") + + if inplace: + frame = self + else: + frame = self.copy(deep=False) + + arrays: list[Index] = [] + names: list[Hashable] = [] + if append: + names = list(self.index.names) + if isinstance(self.index, MultiIndex): + arrays.extend( + self.index._get_level_values(i) for i in range(self.index.nlevels) + ) + else: + arrays.append(self.index) + + to_remove: set[Hashable] = set() + for col in keys: + if isinstance(col, MultiIndex): + arrays.extend(col._get_level_values(n) for n in range(col.nlevels)) + names.extend(col.names) + elif isinstance(col, (Index, Series)): + # if Index then not MultiIndex (treated above) + + # error: Argument 1 to "append" of "list" has incompatible type + # "Union[Index, Series]"; expected "Index" + arrays.append(col) # type: ignore[arg-type] + names.append(col.name) + elif isinstance(col, (list, np.ndarray)): + # error: Argument 1 to "append" of "list" has incompatible type + # "Union[List[Any], ndarray]"; expected "Index" + arrays.append(col) # type: ignore[arg-type] + names.append(None) + elif isinstance(col, abc.Iterator): + # error: Argument 1 to "append" of "list" has incompatible type + # "List[Any]"; expected "Index" + arrays.append(list(col)) # type: ignore[arg-type] + names.append(None) + # from here, col can only be a column label + else: + arrays.append(frame[col]) + names.append(col) + if drop: + to_remove.add(col) + + if len(arrays[-1]) != len(self): + # check newest element against length of calling frame, since + # ensure_index_from_sequences would not raise for append=False. + raise ValueError( + f"Length mismatch: Expected {len(self)} rows, " + f"received array of length {len(arrays[-1])}" + ) + + index = ensure_index_from_sequences(arrays, names) + + if verify_integrity and not index.is_unique: + duplicates = index[index.duplicated()].unique() + raise ValueError(f"Index has duplicate keys: {duplicates}") + + # use set to handle duplicate column names gracefully in case of drop + for c in to_remove: + del frame[c] + + # clear up memory usage + index._cleanup() + + frame.index = index + + if not inplace: + return frame + return None + + @overload + def reset_index( + self, + level: IndexLabel = ..., + *, + drop: bool = ..., + inplace: Literal[False] = ..., + col_level: Hashable = ..., + col_fill: Hashable = ..., + allow_duplicates: bool | lib.NoDefault = ..., + names: Hashable | Sequence[Hashable] | None = None, + ) -> DataFrame: ... + + @overload + def reset_index( + self, + level: IndexLabel = ..., + *, + drop: bool = ..., + inplace: Literal[True], + col_level: Hashable = ..., + col_fill: Hashable = ..., + allow_duplicates: bool | lib.NoDefault = ..., + names: Hashable | Sequence[Hashable] | None = None, + ) -> None: ... + + @overload + def reset_index( + self, + level: IndexLabel = ..., + *, + drop: bool = ..., + inplace: bool = ..., + col_level: Hashable = ..., + col_fill: Hashable = ..., + allow_duplicates: bool | lib.NoDefault = ..., + names: Hashable | Sequence[Hashable] | None = None, + ) -> DataFrame | None: ... + + def reset_index( + self, + level: IndexLabel | None = None, + *, + drop: bool = False, + inplace: bool = False, + col_level: Hashable = 0, + col_fill: Hashable = "", + allow_duplicates: bool | lib.NoDefault = lib.no_default, + names: Hashable | Sequence[Hashable] | None = None, + ) -> DataFrame | None: + """ + Reset the index, or a level of it. + + Reset the index of the DataFrame, and use the default one instead. + If the DataFrame has a MultiIndex, this method can remove one or more + levels. + + Parameters + ---------- + level : int, str, tuple, or list, default None + Only remove the given levels from the index. Removes all levels by + default. + drop : bool, default False + Do not try to insert index into dataframe columns. This resets + the index to the default integer index. + inplace : bool, default False + Whether to modify the DataFrame rather than creating a new one. + col_level : int or str, default 0 + If the columns have multiple levels, determines which level the + labels are inserted into. By default it is inserted into the first + level. + col_fill : object, default '' + If the columns have multiple levels, determines how the other + levels are named. If None then the index name is repeated. + allow_duplicates : bool, optional, default lib.no_default + Allow duplicate column labels to be created. + names : int, str or 1-dimensional list, default None + Using the given string, rename the DataFrame column which contains the + index data. If the DataFrame has a MultiIndex, this has to be a list + with length equal to the number of levels. + + Returns + ------- + DataFrame or None + DataFrame with the new index or None if ``inplace=True``. + + See Also + -------- + DataFrame.set_index : Opposite of reset_index. + DataFrame.reindex : Change to new indices or expand indices. + DataFrame.reindex_like : Change to same indices as other DataFrame. + + Examples + -------- + >>> df = pd.DataFrame( + ... [("bird", 389.0), ("bird", 24.0), ("mammal", 80.5), ("mammal", np.nan)], + ... index=["falcon", "parrot", "lion", "monkey"], + ... columns=("class", "max_speed"), + ... ) + >>> df + class max_speed + falcon bird 389.0 + parrot bird 24.0 + lion mammal 80.5 + monkey mammal NaN + + When we reset the index, the old index is added as a column, and a + new sequential index is used: + + >>> df.reset_index() + index class max_speed + 0 falcon bird 389.0 + 1 parrot bird 24.0 + 2 lion mammal 80.5 + 3 monkey mammal NaN + + We can use the `drop` parameter to avoid the old index being added as + a column: + + >>> df.reset_index(drop=True) + class max_speed + 0 bird 389.0 + 1 bird 24.0 + 2 mammal 80.5 + 3 mammal NaN + + You can also use `reset_index` with `MultiIndex`. + + >>> index = pd.MultiIndex.from_tuples( + ... [ + ... ("bird", "falcon"), + ... ("bird", "parrot"), + ... ("mammal", "lion"), + ... ("mammal", "monkey"), + ... ], + ... names=["class", "name"], + ... ) + >>> columns = pd.MultiIndex.from_tuples([("speed", "max"), ("species", "type")]) + >>> df = pd.DataFrame( + ... [(389.0, "fly"), (24.0, "fly"), (80.5, "run"), (np.nan, "jump")], + ... index=index, + ... columns=columns, + ... ) + >>> df + speed species + max type + class name + bird falcon 389.0 fly + parrot 24.0 fly + mammal lion 80.5 run + monkey NaN jump + + Using the `names` parameter, choose a name for the index column: + + >>> df.reset_index(names=["classes", "names"]) + classes names speed species + max type + 0 bird falcon 389.0 fly + 1 bird parrot 24.0 fly + 2 mammal lion 80.5 run + 3 mammal monkey NaN jump + + If the index has multiple levels, we can reset a subset of them: + + >>> df.reset_index(level="class") + class speed species + max type + name + falcon bird 389.0 fly + parrot bird 24.0 fly + lion mammal 80.5 run + monkey mammal NaN jump + + If we are not dropping the index, by default, it is placed in the top + level. We can place it in another level: + + >>> df.reset_index(level="class", col_level=1) + speed species + class max type + name + falcon bird 389.0 fly + parrot bird 24.0 fly + lion mammal 80.5 run + monkey mammal NaN jump + + When the index is inserted under another level, we can specify under + which one with the parameter `col_fill`: + + >>> df.reset_index(level="class", col_level=1, col_fill="species") + species speed species + class max type + name + falcon bird 389.0 fly + parrot bird 24.0 fly + lion mammal 80.5 run + monkey mammal NaN jump + + If we specify a nonexistent level for `col_fill`, it is created: + + >>> df.reset_index(level="class", col_level=1, col_fill="genus") + genus speed species + class max type + name + falcon bird 389.0 fly + parrot bird 24.0 fly + lion mammal 80.5 run + monkey mammal NaN jump + """ + inplace = validate_bool_kwarg(inplace, "inplace") + self._check_inplace_and_allows_duplicate_labels(inplace) + if inplace: + new_obj = self + else: + new_obj = self.copy(deep=False) + if allow_duplicates is not lib.no_default: + allow_duplicates = validate_bool_kwarg(allow_duplicates, "allow_duplicates") + + new_index = default_index(len(new_obj)) + if level is not None: + if not isinstance(level, (tuple, list)): + level = [level] + level = [self.index._get_level_number(lev) for lev in level] + if len(level) < self.index.nlevels: + new_index = self.index.droplevel(level) + + if not drop: + to_insert: Iterable[tuple[Any, Any | None]] + + default = "index" if "index" not in self else "level_0" + names = self.index._get_default_index_names(names, default) + + if isinstance(self.index, MultiIndex): + to_insert = zip( + reversed(self.index.levels), + reversed(self.index.codes), + strict=True, + ) + else: + to_insert = ((self.index, None),) + + multi_col = isinstance(self.columns, MultiIndex) + for j, (lev, lab) in enumerate(to_insert, start=1): + i = self.index.nlevels - j + if level is not None and i not in level: + continue + name = names[i] + if multi_col: + col_name = list(name) if isinstance(name, tuple) else [name] + if col_fill is None: + if len(col_name) not in (1, self.columns.nlevels): + raise ValueError( + "col_fill=None is incompatible " + f"with incomplete column name {name}" + ) + col_fill = col_name[0] + + lev_num = self.columns._get_level_number(col_level) + name_lst = [col_fill] * lev_num + col_name + missing = self.columns.nlevels - len(name_lst) + name_lst += [col_fill] * missing + name = tuple(name_lst) + + # to ndarray and maybe infer different dtype + level_values = lev._values + if level_values.dtype == np.object_: + level_values = lib.maybe_convert_objects(level_values) + + if lab is not None: + # if we have the codes, extract the values with a mask + level_values = algorithms.take( + level_values, lab, allow_fill=True, fill_value=lev._na_value + ) + + new_obj.insert( + 0, + name, + level_values, + allow_duplicates=allow_duplicates, + ) + + new_obj.index = new_index + if not inplace: + return new_obj + + return None + + # ---------------------------------------------------------------------- + # Reindex-based selection methods + + def isna(self) -> DataFrame: + """ + Detect missing values. + + Return a boolean same-sized object indicating if the values are NA. + NA values, such as None or :attr:`numpy.NaN`, gets mapped to True + values. + Everything else gets mapped to False values. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + + Returns + ------- + Series/DataFrame + Mask of bool values for each element in Series/DataFrame + that indicates whether an element is an NA value. + + See Also + -------- + Series.isnull : Alias of isna. + DataFrame.isnull : Alias of isna. + Series.notna : Boolean inverse of isna. + DataFrame.notna : Boolean inverse of isna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + isna : Top-level isna. + + Examples + -------- + Show which entries in a DataFrame are NA. + + >>> df = pd.DataFrame( + ... dict( + ... age=[5, 6, np.nan], + ... born=[ + ... pd.NaT, + ... pd.Timestamp("1939-05-27"), + ... pd.Timestamp("1940-04-25"), + ... ], + ... name=["Alfred", "Batman", ""], + ... toy=[None, "Batmobile", "Joker"], + ... ) + ... ) + >>> df + age born name toy + 0 5.0 NaT Alfred NaN + 1 6.0 1939-05-27 Batman Batmobile + 2 NaN 1940-04-25 Joker + + >>> df.isna() + age born name toy + 0 False True False True + 1 False False False False + 2 True False False False + + Show which entries in a Series are NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + + >>> ser.isna() + 0 False + 1 False + 2 True + dtype: bool + """ + res_mgr = self._mgr.isna(func=isna) + result = self._constructor_from_mgr(res_mgr, axes=res_mgr.axes) + return result.__finalize__(self, method="isna") + + def isnull(self) -> DataFrame: + """ + DataFrame.isnull is an alias for DataFrame.isna. + + Detect missing values. + + Return a boolean same-sized object indicating if the values are NA. + NA values, such as None or :attr:`numpy.NaN`, gets mapped to True + values. + Everything else gets mapped to False values. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + + Returns + ------- + Series/DataFrame + Mask of bool values for each element in Series/DataFrame + that indicates whether an element is an NA value. + + See Also + -------- + Series.isnull : Alias of isna. + DataFrame.isnull : Alias of isna. + Series.notna : Boolean inverse of isna. + DataFrame.notna : Boolean inverse of isna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + isna : Top-level isna. + + Examples + -------- + Show which entries in a DataFrame are NA. + + >>> df = pd.DataFrame( + ... dict( + ... age=[5, 6, np.nan], + ... born=[ + ... pd.NaT, + ... pd.Timestamp("1939-05-27"), + ... pd.Timestamp("1940-04-25"), + ... ], + ... name=["Alfred", "Batman", ""], + ... toy=[None, "Batmobile", "Joker"], + ... ) + ... ) + >>> df + age born name toy + 0 5.0 NaT Alfred NaN + 1 6.0 1939-05-27 Batman Batmobile + 2 NaN 1940-04-25 Joker + + >>> df.isna() + age born name toy + 0 False True False True + 1 False False False False + 2 True False False False + + Show which entries in a Series are NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + + >>> ser.isna() + 0 False + 1 False + 2 True + dtype: bool + """ + return self.isna() + + def notna(self) -> DataFrame: + """ + Detect existing (non-missing) values. + + Return a boolean same-sized object indicating if the values are not NA. + Non-missing values get mapped to True. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + NA values, such as None or :attr:`numpy.NaN`, get mapped to False + values. + + Returns + ------- + Series/DataFrame + Mask of bool values for each element in Series/DataFrame + that indicates whether an element is not an NA value. + + See Also + -------- + Series.notnull : Alias of notna. + DataFrame.notnull : Alias of notna. + Series.isna : Boolean inverse of notna. + DataFrame.isna : Boolean inverse of notna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + notna : Top-level notna. + + Examples + -------- + Show which entries in a DataFrame are not NA. + + >>> df = pd.DataFrame( + ... dict( + ... age=[5, 6, np.nan], + ... born=[ + ... pd.NaT, + ... pd.Timestamp("1939-05-27"), + ... pd.Timestamp("1940-04-25"), + ... ], + ... name=["Alfred", "Batman", ""], + ... toy=[None, "Batmobile", "Joker"], + ... ) + ... ) + >>> df + age born name toy + 0 5.0 NaT Alfred NaN + 1 6.0 1939-05-27 Batman Batmobile + 2 NaN 1940-04-25 Joker + + >>> df.notna() + age born name toy + 0 True False True False + 1 True True True True + 2 False True True True + + Show which entries in a Series are not NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + + >>> ser.notna() + 0 True + 1 True + 2 False + dtype: bool + """ + return ~self.isna() + + def notnull(self) -> DataFrame: + """ + DataFrame.notnull is an alias for DataFrame.notna. + + Detect existing (non-missing) values. + + Return a boolean same-sized object indicating if the values are not NA. + Non-missing values get mapped to True. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + NA values, such as None or :attr:`numpy.NaN`, get mapped to False + values. + + Returns + ------- + Series/DataFrame + Mask of bool values for each element in Series/DataFrame + that indicates whether an element is not an NA value. + + See Also + -------- + Series.notnull : Alias of notna. + DataFrame.notnull : Alias of notna. + Series.isna : Boolean inverse of notna. + DataFrame.isna : Boolean inverse of notna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + notna : Top-level notna. + + Examples + -------- + Show which entries in a DataFrame are not NA. + + >>> df = pd.DataFrame( + ... dict( + ... age=[5, 6, np.nan], + ... born=[ + ... pd.NaT, + ... pd.Timestamp("1939-05-27"), + ... pd.Timestamp("1940-04-25"), + ... ], + ... name=["Alfred", "Batman", ""], + ... toy=[None, "Batmobile", "Joker"], + ... ) + ... ) + >>> df + age born name toy + 0 5.0 NaT Alfred NaN + 1 6.0 1939-05-27 Batman Batmobile + 2 NaN 1940-04-25 Joker + + >>> df.notnull() + age born name toy + 0 True False True False + 1 True True True True + 2 False True True True + + Show which entries in a Series are not NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + + >>> ser.notnull() + 0 True + 1 True + 2 False + dtype: bool + """ + return ~self.isna() + + @overload + def dropna( + self, + *, + axis: Axis = ..., + how: AnyAll | lib.NoDefault = ..., + thresh: int | lib.NoDefault = ..., + subset: IndexLabel = ..., + inplace: Literal[False] = ..., + ignore_index: bool = ..., + ) -> DataFrame: ... + + @overload + def dropna( + self, + *, + axis: Axis = ..., + how: AnyAll | lib.NoDefault = ..., + thresh: int | lib.NoDefault = ..., + subset: IndexLabel = ..., + inplace: Literal[True], + ignore_index: bool = ..., + ) -> None: ... + + def dropna( + self, + *, + axis: Axis = 0, + how: AnyAll | lib.NoDefault = lib.no_default, + thresh: int | lib.NoDefault = lib.no_default, + subset: IndexLabel | AnyArrayLike | None = None, + inplace: bool = False, + ignore_index: bool = False, + ) -> DataFrame | None: + """ + Remove missing values. + + See the :ref:`User Guide ` for more on which values are + considered missing, and how to work with missing data. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + Determine if rows or columns which contain missing values are + removed. + + * 0, or 'index' : Drop rows which contain missing values. + * 1, or 'columns' : Drop columns which contain missing value. + + Only a single axis is allowed. + + how : {'any', 'all'}, default 'any' + Determine if row or column is removed from DataFrame, when we have + at least one NA or all NA. + + * 'any' : If any NA values are present, drop that row or column. + * 'all' : If all values are NA, drop that row or column. + + thresh : int, optional + Require that many non-NA values. Cannot be combined with how. + subset : column label or iterable of labels, optional + Labels along other axis to consider, e.g. if you are dropping rows + these would be a list of columns to include. + inplace : bool, default False + Whether to modify the DataFrame rather than creating a new one. + ignore_index : bool, default ``False`` + If ``True``, the resulting axis will be labeled 0, 1, …, n - 1. + + .. versionadded:: 2.0.0 + + Returns + ------- + DataFrame or None + DataFrame with NA entries dropped from it or None if ``inplace=True``. + + See Also + -------- + DataFrame.isna: Indicate missing values. + DataFrame.notna : Indicate existing (non-missing) values. + DataFrame.fillna : Replace missing values. + Series.dropna : Drop missing values. + Index.dropna : Drop missing indices. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "name": ["Alfred", "Batman", "Catwoman"], + ... "toy": [np.nan, "Batmobile", "Bullwhip"], + ... "born": [pd.NaT, pd.Timestamp("1940-04-25"), pd.NaT], + ... } + ... ) + >>> df + name toy born + 0 Alfred NaN NaT + 1 Batman Batmobile 1940-04-25 + 2 Catwoman Bullwhip NaT + + Drop the rows where at least one element is missing. + + >>> df.dropna() + name toy born + 1 Batman Batmobile 1940-04-25 + + Drop the columns where at least one element is missing. + + >>> df.dropna(axis="columns") + name + 0 Alfred + 1 Batman + 2 Catwoman + + Drop the rows where all elements are missing. + + >>> df.dropna(how="all") + name toy born + 0 Alfred NaN NaT + 1 Batman Batmobile 1940-04-25 + 2 Catwoman Bullwhip NaT + + Keep only the rows with at least 2 non-NA values. + + >>> df.dropna(thresh=2) + name toy born + 1 Batman Batmobile 1940-04-25 + 2 Catwoman Bullwhip NaT + + Define in which columns to look for missing values. + + >>> df.dropna(subset=["name", "toy"]) + name toy born + 1 Batman Batmobile 1940-04-25 + 2 Catwoman Bullwhip NaT + """ + if (how is not lib.no_default) and (thresh is not lib.no_default): + raise TypeError( + "You cannot set both the how and thresh arguments at the same time." + ) + + if how is lib.no_default: + how = "any" + + inplace = validate_bool_kwarg(inplace, "inplace") + if isinstance(axis, (tuple, list)): + # GH20987 + raise TypeError("supplying multiple axes to axis is no longer supported.") + + axis = self._get_axis_number(axis) + agg_axis = 1 - axis + + agg_obj = self + if subset is not None: + # subset needs to be list + if not is_list_like(subset): + subset = [cast(Hashable, subset)] + ax = self._get_axis(agg_axis) + indices = ax.get_indexer_for(subset) + check = indices == -1 + if check.any(): + raise KeyError(np.array(subset)[check].tolist()) + agg_obj = self.take(indices, axis=agg_axis) + + if thresh is not lib.no_default: + count = agg_obj.count(axis=agg_axis) + mask = count >= thresh + elif how == "any": + # faster equivalent to 'agg_obj.count(agg_axis) == self.shape[agg_axis]' + mask = notna(agg_obj).all(axis=agg_axis, bool_only=False) + elif how == "all": + # faster equivalent to 'agg_obj.count(agg_axis) > 0' + mask = notna(agg_obj).any(axis=agg_axis, bool_only=False) + else: + raise ValueError(f"invalid how option: {how}") + + if np.all(mask): + result = self.copy(deep=False) + else: + result = self.loc(axis=axis)[mask] + + if ignore_index: + result.index = default_index(len(result)) + + if not inplace: + return result + self._update_inplace(result) + return None + + @overload + def drop_duplicates( + self, + subset: Hashable | Iterable[Hashable] | None = ..., + *, + keep: DropKeep = ..., + inplace: Literal[True], + ignore_index: bool = ..., + ) -> None: ... + + @overload + def drop_duplicates( + self, + subset: Hashable | Iterable[Hashable] | None = ..., + *, + keep: DropKeep = ..., + inplace: Literal[False] = ..., + ignore_index: bool = ..., + ) -> DataFrame: ... + + @overload + def drop_duplicates( + self, + subset: Hashable | Iterable[Hashable] | None = ..., + *, + keep: DropKeep = ..., + inplace: bool = ..., + ignore_index: bool = ..., + ) -> DataFrame | None: ... + + def drop_duplicates( + self, + subset: Hashable | Iterable[Hashable] | None = None, + *, + keep: DropKeep = "first", + inplace: bool = False, + ignore_index: bool = False, + ) -> DataFrame | None: + """ + Return DataFrame with duplicate rows removed. + + Considering certain columns is optional. Indexes, including time indexes + are ignored. + + Parameters + ---------- + subset : column label or iterable of labels, optional + Only consider certain columns for identifying duplicates, by + default use all of the columns. + keep : {'first', 'last', ``False``}, default 'first' + Determines which duplicates (if any) to keep. + + - 'first' : Drop duplicates except for the first occurrence. + - 'last' : Drop duplicates except for the last occurrence. + - ``False`` : Drop all duplicates. + + inplace : bool, default ``False`` + Whether to modify the DataFrame rather than creating a new one. + ignore_index : bool, default ``False`` + If ``True``, the resulting axis will be labeled 0, 1, …, n - 1. + + Returns + ------- + DataFrame or None + DataFrame with duplicates removed or None if ``inplace=True``. + + See Also + -------- + DataFrame.value_counts: Count unique combinations of columns. + + Notes + ----- + This method requires columns specified by ``subset`` to be of hashable type. + Passing unhashable columns will raise a ``TypeError``. + + Examples + -------- + Consider dataset containing ramen rating. + + >>> df = pd.DataFrame( + ... { + ... "brand": ["Yum Yum", "Yum Yum", "Indomie", "Indomie", "Indomie"], + ... "style": ["cup", "cup", "cup", "pack", "pack"], + ... "rating": [4, 4, 3.5, 15, 5], + ... } + ... ) + >>> df + brand style rating + 0 Yum Yum cup 4.0 + 1 Yum Yum cup 4.0 + 2 Indomie cup 3.5 + 3 Indomie pack 15.0 + 4 Indomie pack 5.0 + + By default, it removes duplicate rows based on all columns. + + >>> df.drop_duplicates() + brand style rating + 0 Yum Yum cup 4.0 + 2 Indomie cup 3.5 + 3 Indomie pack 15.0 + 4 Indomie pack 5.0 + + To remove duplicates on specific column(s), use ``subset``. + + >>> df.drop_duplicates(subset=["brand"]) + brand style rating + 0 Yum Yum cup 4.0 + 2 Indomie cup 3.5 + + To remove duplicates and keep last occurrences, use ``keep``. + + >>> df.drop_duplicates(subset=["brand", "style"], keep="last") + brand style rating + 1 Yum Yum cup 4.0 + 2 Indomie cup 3.5 + 4 Indomie pack 5.0 + """ + if self.empty: + return self.copy(deep=False) + + inplace = validate_bool_kwarg(inplace, "inplace") + ignore_index = validate_bool_kwarg(ignore_index, "ignore_index") + + result = self[-self.duplicated(subset, keep=keep)] + if ignore_index: + result.index = default_index(len(result)) + + if inplace: + self._update_inplace(result) + return None + else: + return result + + def duplicated( + self, + subset: Hashable | Iterable[Hashable] | None = None, + keep: DropKeep = "first", + ) -> Series: + """ + Return boolean Series denoting duplicate rows. + + Considering certain columns is optional. + + Parameters + ---------- + subset : column label or iterable of labels, optional + Only consider certain columns for identifying duplicates, by + default use all of the columns. + keep : {'first', 'last', False}, default 'first' + Determines which duplicates (if any) to mark. + + - ``first`` : Mark duplicates as ``True`` except for the first occurrence. + - ``last`` : Mark duplicates as ``True`` except for the last occurrence. + - False : Mark all duplicates as ``True``. + + Returns + ------- + Series + Boolean series for each duplicated rows. + + See Also + -------- + Index.duplicated : Equivalent method on index. + Series.duplicated : Equivalent method on Series. + Series.drop_duplicates : Remove duplicate values from Series. + DataFrame.drop_duplicates : Remove duplicate values from DataFrame. + + Examples + -------- + Consider dataset containing ramen rating. + + >>> df = pd.DataFrame( + ... { + ... "brand": ["Yum Yum", "Yum Yum", "Indomie", "Indomie", "Indomie"], + ... "style": ["cup", "cup", "cup", "pack", "pack"], + ... "rating": [4, 4, 3.5, 15, 5], + ... } + ... ) + >>> df + brand style rating + 0 Yum Yum cup 4.0 + 1 Yum Yum cup 4.0 + 2 Indomie cup 3.5 + 3 Indomie pack 15.0 + 4 Indomie pack 5.0 + + By default, for each set of duplicated values, the first occurrence + is set on False and all others on True. + + >>> df.duplicated() + 0 False + 1 True + 2 False + 3 False + 4 False + dtype: bool + + By using 'last', the last occurrence of each set of duplicated values + is set on False and all others on True. + + >>> df.duplicated(keep="last") + 0 True + 1 False + 2 False + 3 False + 4 False + dtype: bool + + By setting ``keep`` on False, all duplicates are True. + + >>> df.duplicated(keep=False) + 0 True + 1 True + 2 False + 3 False + 4 False + dtype: bool + + To find duplicates on specific column(s), use ``subset``. + + >>> df.duplicated(subset=["brand"]) + 0 False + 1 True + 2 False + 3 True + 4 True + dtype: bool + """ + + if self.empty: + return self._constructor_sliced(dtype=bool) + + def f(vals) -> tuple[np.ndarray, int]: + labels, shape = algorithms.factorize(vals, size_hint=len(self)) + return labels.astype("i8"), len(shape) + + if subset is None: + subset = self.columns + elif ( + not np.iterable(subset) + or isinstance(subset, str) + or (isinstance(subset, tuple) and subset in self.columns) + ): + subset = (subset,) + + # needed for mypy since can't narrow types using np.iterable + subset = cast(Sequence, subset) + + # Verify all columns in subset exist in the queried dataframe + # Otherwise, raise a KeyError, same as if you try to __getitem__ with a + # key that doesn't exist. + diff = set(subset) - set(self.columns) + if diff: + raise KeyError(Index(diff)) + + if len(subset) == 1 and self.columns.is_unique: + # GH#45236 This is faster than get_group_index below + result = self[next(iter(subset))].duplicated(keep) + result.name = None + else: + vals = (col.values for name, col in self.items() if name in subset) + labels, shape = map(list, zip(*map(f, vals), strict=True)) + + ids = get_group_index(labels, tuple(shape), sort=False, xnull=False) + result = self._constructor_sliced(duplicated(ids, keep), index=self.index) + return result.__finalize__(self, method="duplicated") + + # ---------------------------------------------------------------------- + # Sorting + # error: Signature of "sort_values" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def sort_values( + self, + by: IndexLabel, + *, + axis: Axis = ..., + ascending=..., + inplace: Literal[False] = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + ignore_index: bool = ..., + key: ValueKeyFunc = ..., + ) -> DataFrame: ... + + @overload + def sort_values( + self, + by: IndexLabel, + *, + axis: Axis = ..., + ascending=..., + inplace: Literal[True], + kind: SortKind = ..., + na_position: str = ..., + ignore_index: bool = ..., + key: ValueKeyFunc = ..., + ) -> None: ... + + def sort_values( + self, + by: IndexLabel, + *, + axis: Axis = 0, + ascending: bool | list[bool] | tuple[bool, ...] = True, + inplace: bool = False, + kind: SortKind = "quicksort", + na_position: str = "last", + ignore_index: bool = False, + key: ValueKeyFunc | None = None, + ) -> DataFrame | None: + """ + Sort by the values along either axis. + + Parameters + ---------- + by : str or list of str + Name or list of names to sort by. + + - if `axis` is 0 or `'index'` then `by` may contain index + levels and/or column labels. + - if `axis` is 1 or `'columns'` then `by` may contain column + levels and/or index labels. + axis : "{0 or 'index', 1 or 'columns'}", default 0 + Axis to be sorted. + ascending : bool or list of bool, default True + Sort ascending vs. descending. Specify list for multiple sort + orders. If this is a list of bools, must match the length of + the by. + inplace : bool, default False + If True, perform operation in-place. + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort' + Choice of sorting algorithm. See also :func:`numpy.sort` for more + information. `mergesort` and `stable` are the only stable algorithms. For + DataFrames, this option is only applied when sorting on a single + column or label. + na_position : {'first', 'last'}, default 'last' + Puts NaNs at the beginning if `first`; `last` puts NaNs at the + end. + ignore_index : bool, default False + If True, the resulting axis will be labeled 0, 1, …, n - 1. + key : callable, optional + Apply the key function to the values + before sorting. This is similar to the `key` argument in the + builtin :meth:`sorted` function, with the notable difference that + this `key` function should be *vectorized*. It should expect a + ``Series`` and return a Series with the same shape as the input. + It will be applied to each column in `by` independently. The values in the + returned Series will be used as the keys for sorting. + + Returns + ------- + DataFrame or None + DataFrame with sorted values or None if ``inplace=True``. + + See Also + -------- + DataFrame.sort_index : Sort a DataFrame by the index. + Series.sort_values : Similar method for a Series. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "col1": ["A", "A", "B", np.nan, "D", "C"], + ... "col2": [2, 1, 9, 8, 7, 4], + ... "col3": [0, 1, 9, 4, 2, 3], + ... "col4": ["a", "B", "c", "D", "e", "F"], + ... } + ... ) + >>> df + col1 col2 col3 col4 + 0 A 2 0 a + 1 A 1 1 B + 2 B 9 9 c + 3 NaN 8 4 D + 4 D 7 2 e + 5 C 4 3 F + + **Sort by a single column** + + In this case, we are sorting the rows according to values in ``col1``: + + >>> df.sort_values(by=["col1"]) + col1 col2 col3 col4 + 0 A 2 0 a + 1 A 1 1 B + 2 B 9 9 c + 5 C 4 3 F + 4 D 7 2 e + 3 NaN 8 4 D + + **Sort by multiple columns** + + You can also provide multiple columns to ``by`` argument, as shown below. + In this example, the rows are first sorted according to ``col1``, and then + the rows that have an identical value in ``col1`` are sorted according + to ``col2``. + + >>> df.sort_values(by=["col1", "col2"]) + col1 col2 col3 col4 + 1 A 1 1 B + 0 A 2 0 a + 2 B 9 9 c + 5 C 4 3 F + 4 D 7 2 e + 3 NaN 8 4 D + + **Sort in a descending order** + + The sort order can be reversed using ``ascending`` argument, as shown below: + + >>> df.sort_values(by="col1", ascending=False) + col1 col2 col3 col4 + 4 D 7 2 e + 5 C 4 3 F + 2 B 9 9 c + 0 A 2 0 a + 1 A 1 1 B + 3 NaN 8 4 D + + **Placing any** ``NA`` **first** + + Note that in the above example, the rows that contain an ``NA`` value in their + ``col1`` are placed at the end of the dataframe. This behavior can be modified + via ``na_position`` argument, as shown below: + + >>> df.sort_values(by="col1", ascending=False, na_position="first") + col1 col2 col3 col4 + 3 NaN 8 4 D + 4 D 7 2 e + 5 C 4 3 F + 2 B 9 9 c + 0 A 2 0 a + 1 A 1 1 B + + **Customized sort order** + + The ``key`` argument allows for a further customization of sorting behaviour. + For example, you may want + to ignore the `letter's case `__ + when sorting strings: + + >>> df.sort_values(by="col4", key=lambda col: col.str.lower()) + col1 col2 col3 col4 + 0 A 2 0 a + 1 A 1 1 B + 2 B 9 9 c + 3 NaN 8 4 D + 4 D 7 2 e + 5 C 4 3 F + + Another typical example is + `natural sorting `__. + This can be done using + ``natsort`` `package `__, + which provides a function to generate a key + to sort data in their natural order: + + >>> df = pd.DataFrame( + ... { + ... "hours": ["0hr", "128hr", "0hr", "64hr", "64hr", "128hr"], + ... "mins": [ + ... "10mins", + ... "40mins", + ... "40mins", + ... "40mins", + ... "10mins", + ... "10mins", + ... ], + ... "value": [10, 20, 30, 40, 50, 60], + ... } + ... ) + >>> df + hours mins value + 0 0hr 10mins 10 + 1 128hr 40mins 20 + 2 0hr 40mins 30 + 3 64hr 40mins 40 + 4 64hr 10mins 50 + 5 128hr 10mins 60 + >>> from natsort import natsort_keygen + >>> df.sort_values( + ... by=["hours", "mins"], + ... key=natsort_keygen(), + ... ) + hours mins value + 0 0hr 10mins 10 + 2 0hr 40mins 30 + 4 64hr 10mins 50 + 3 64hr 40mins 40 + 5 128hr 10mins 60 + 1 128hr 40mins 20 + """ + inplace = validate_bool_kwarg(inplace, "inplace") + axis = self._get_axis_number(axis) + ascending = validate_ascending(ascending) + if not isinstance(by, list): + by = [by] + # error: Argument 1 to "len" has incompatible type "Union[bool, List[bool]]"; + # expected "Sized" + if is_sequence(ascending) and ( + len(by) != len(ascending) # type: ignore[arg-type] + ): + # error: Argument 1 to "len" has incompatible type "Union[bool, + # List[bool]]"; expected "Sized" + raise ValueError( + f"Length of ascending ({len(ascending)})" # type: ignore[arg-type] + f" != length of by ({len(by)})" + ) + if len(by) > 1: + keys = (self._get_label_or_level_values(x, axis=axis) for x in by) + + # need to rewrap columns in Series to apply key function + if key is not None: + keys_data = [ + Series(k, name=name) for (k, name) in zip(keys, by, strict=True) + ] + else: + # error: Argument 1 to "list" has incompatible type + # "Generator[ExtensionArray | ndarray[Any, Any], None, None]"; + # expected "Iterable[Series]" + keys_data = list(keys) # type: ignore[arg-type] + + indexer = lexsort_indexer( + keys_data, orders=ascending, na_position=na_position, key=key + ) + elif by: + # len(by) == 1 + + k = self._get_label_or_level_values(by[0], axis=axis) + + # need to rewrap column in Series to apply key function + if key is not None: + # error: Incompatible types in assignment (expression has type + # "Series", variable has type "ndarray") + k = Series(k, name=by[0]) # type: ignore[assignment] + + if isinstance(ascending, (tuple, list)): + ascending = ascending[0] + + indexer = nargsort( + k, kind=kind, ascending=ascending, na_position=na_position, key=key + ) + elif inplace: + return self._update_inplace(self) + else: + return self.copy(deep=False) + + if is_range_indexer(indexer, len(indexer)): + result = self.copy(deep=False) + if ignore_index: + result.index = default_index(len(result)) + + if inplace: + return self._update_inplace(result) + else: + return result + + new_data = self._mgr.take( + indexer, axis=self._get_block_manager_axis(axis), verify=False + ) + + if ignore_index: + new_data.set_axis( + self._get_block_manager_axis(axis), default_index(len(indexer)) + ) + + result = self._constructor_from_mgr(new_data, axes=new_data.axes) + if inplace: + return self._update_inplace(result) + else: + return result.__finalize__(self, method="sort_values") + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[True], + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> None: ... + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[False] = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> DataFrame: ... + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: bool = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> DataFrame | None: ... + + def sort_index( + self, + *, + axis: Axis = 0, + level: IndexLabel | None = None, + ascending: bool | Sequence[bool] = True, + inplace: bool = False, + kind: SortKind = "quicksort", + na_position: NaPosition = "last", + sort_remaining: bool = True, + ignore_index: bool = False, + key: IndexKeyFunc | None = None, + ) -> DataFrame | None: + """ + Sort object by labels (along an axis). + + Returns a new DataFrame sorted by label if `inplace` argument is + ``False``, otherwise updates the original DataFrame and returns None. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis along which to sort. The value 0 identifies the rows, + and 1 identifies the columns. + level : int or level name or list of ints or list of level names + If not None, sort on values in specified index level(s). + ascending : bool or list-like of bools, default True + Sort ascending vs. descending. When the index is a MultiIndex the + sort direction can be controlled for each level individually. + inplace : bool, default False + Whether to modify the DataFrame rather than creating a new one. + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort' + Choice of sorting algorithm. See also :func:`numpy.sort` for more + information. `mergesort` and `stable` are the only stable algorithms. For + DataFrames, this option is only applied when sorting on a single + column or label. + na_position : {'first', 'last'}, default 'last' + Puts NaNs at the beginning if `first`; `last` puts NaNs at the end. + Not implemented for MultiIndex. + sort_remaining : bool, default True + If True and sorting by level and index is multilevel, sort by other + levels too (in order) after sorting by specified level. + ignore_index : bool, default False + If True, the resulting axis will be labeled 0, 1, …, n - 1. + key : callable, optional + If not None, apply the key function to the index values + before sorting. This is similar to the `key` argument in the + builtin :meth:`sorted` function, with the notable difference that + this `key` function should be *vectorized*. It should expect an + ``Index`` and return an ``Index`` of the same shape. For MultiIndex + inputs, the key is applied *per level*. + + Returns + ------- + DataFrame or None + The original DataFrame sorted by the labels or None if ``inplace=True``. + + See Also + -------- + Series.sort_index : Sort Series by the index. + DataFrame.sort_values : Sort DataFrame by the value. + Series.sort_values : Sort Series by the value. + + Examples + -------- + >>> df = pd.DataFrame( + ... [1, 2, 3, 4, 5], index=[100, 29, 234, 1, 150], columns=["A"] + ... ) + >>> df.sort_index() + A + 1 4 + 29 2 + 100 1 + 150 5 + 234 3 + + By default, it sorts in ascending order, to sort in descending order, + use ``ascending=False`` + + >>> df.sort_index(ascending=False) + A + 234 3 + 150 5 + 100 1 + 29 2 + 1 4 + + A key function can be specified which is applied to the index before + sorting. For a ``MultiIndex`` this is applied to each level separately. + + >>> df = pd.DataFrame({"a": [1, 2, 3, 4]}, index=["A", "b", "C", "d"]) + >>> df.sort_index(key=lambda x: x.str.lower()) + a + A 1 + b 2 + C 3 + d 4 + """ + return super().sort_index( + axis=axis, + level=level, + ascending=ascending, + inplace=inplace, + kind=kind, + na_position=na_position, + sort_remaining=sort_remaining, + ignore_index=ignore_index, + key=key, + ) + + def value_counts( + self, + subset: IndexLabel | None = None, + normalize: bool = False, + sort: bool = True, + ascending: bool = False, + dropna: bool = True, + ) -> Series: + """ + Return a Series containing the frequency of each distinct row in the DataFrame. + + Parameters + ---------- + subset : Hashable or a sequence of the previous, optional + Columns to use when counting unique combinations. + normalize : bool, default False + Return proportions rather than frequencies. + sort : bool, default True + Stable sort by frequencies when True. Preserve the order of the data + when False. + + .. versionchanged:: 3.0.0 + + Prior to 3.0.0, ``sort=False`` would sort by the columns values. + + .. versionchanged:: 3.0.0 + + Prior to 3.0.0, the sort was unstable. + ascending : bool, default False + Sort in ascending order. + dropna : bool, default True + Do not include counts of rows that contain NA values. + + Returns + ------- + Series + Series containing the frequency of each distinct row in the DataFrame. + + See Also + -------- + Series.value_counts: Equivalent method on Series. + + Notes + ----- + The returned Series will have a MultiIndex with one level per input + column but an Index (non-multi) for a single label. By default, rows + that contain any NA values are omitted from the result. By default, + the resulting Series will be sorted by frequencies in descending order so that + the first element is the most frequently-occurring row. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"num_legs": [2, 4, 4, 6], "num_wings": [2, 0, 0, 0]}, + ... index=["falcon", "dog", "cat", "ant"], + ... ) + >>> df + num_legs num_wings + falcon 2 2 + dog 4 0 + cat 4 0 + ant 6 0 + + >>> df.value_counts() + num_legs num_wings + 4 0 2 + 2 2 1 + 6 0 1 + Name: count, dtype: int64 + + >>> df.value_counts(sort=False) + num_legs num_wings + 2 2 1 + 4 0 2 + 6 0 1 + Name: count, dtype: int64 + + >>> df.value_counts(ascending=True) + num_legs num_wings + 2 2 1 + 6 0 1 + 4 0 2 + Name: count, dtype: int64 + + >>> df.value_counts(normalize=True) + num_legs num_wings + 4 0 0.50 + 2 2 0.25 + 6 0 0.25 + Name: proportion, dtype: float64 + + With `dropna` set to `False` we can also count rows with NA values. + + >>> df = pd.DataFrame( + ... { + ... "first_name": ["John", "Anne", "John", "Beth"], + ... "middle_name": ["Smith", pd.NA, pd.NA, "Louise"], + ... } + ... ) + >>> df + first_name middle_name + 0 John Smith + 1 Anne NaN + 2 John NaN + 3 Beth Louise + + >>> df.value_counts() + first_name middle_name + John Smith 1 + Beth Louise 1 + Name: count, dtype: int64 + + >>> df.value_counts(dropna=False) + first_name middle_name + John Smith 1 + Anne NaN 1 + John NaN 1 + Beth Louise 1 + Name: count, dtype: int64 + + >>> df.value_counts("first_name") + first_name + John 2 + Anne 1 + Beth 1 + Name: count, dtype: int64 + """ + if subset is None: + subset = self.columns.tolist() + + name = "proportion" if normalize else "count" + counts = self.groupby( + subset, sort=False, dropna=dropna, observed=False + )._grouper.size() + counts.name = name + + if sort: + counts = counts.sort_values(ascending=ascending, kind="stable") + if normalize: + counts /= counts.sum() + + # Force MultiIndex for a list_like subset with a single column + if is_list_like(subset) and len(subset) == 1: # type: ignore[arg-type] + counts.index = MultiIndex.from_arrays( + [counts.index], names=[counts.index.name] + ) + + return counts + + def nlargest( + self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first" + ) -> DataFrame: + """ + Return the first `n` rows ordered by `columns` in descending order. + + Return the first `n` rows with the largest values in `columns`, in + descending order. The columns that are not specified are returned as + well, but not used for ordering. + + This method is equivalent to + ``df.sort_values(columns, ascending=False).head(n)``, but more + performant. + + Parameters + ---------- + n : int + Number of rows to return. + columns : Hashable or a sequence of the previous + Column label(s) to order by. + keep : {'first', 'last', 'all'}, default 'first' + Where there are duplicate values: + + - ``first`` : prioritize the first occurrence(s) + - ``last`` : prioritize the last occurrence(s) + - ``all`` : keep all the ties of the smallest item even if it means + selecting more than ``n`` items. + + Returns + ------- + DataFrame + The first `n` rows ordered by the given columns in descending + order. + + See Also + -------- + DataFrame.nsmallest : Return the first `n` rows ordered by `columns` in + ascending order. + DataFrame.sort_values : Sort DataFrame by the values. + DataFrame.head : Return the first `n` rows without re-ordering. + + Notes + ----- + This function cannot be used with all column types. For example, when + specifying columns with `object` or `category` dtypes, ``TypeError`` is + raised. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "population": [ + ... 59000000, + ... 65000000, + ... 434000, + ... 434000, + ... 434000, + ... 337000, + ... 11300, + ... 11300, + ... 11300, + ... ], + ... "GDP": [1937894, 2583560, 12011, 4520, 12128, 17036, 182, 38, 311], + ... "alpha-2": ["IT", "FR", "MT", "MV", "BN", "IS", "NR", "TV", "AI"], + ... }, + ... index=[ + ... "Italy", + ... "France", + ... "Malta", + ... "Maldives", + ... "Brunei", + ... "Iceland", + ... "Nauru", + ... "Tuvalu", + ... "Anguilla", + ... ], + ... ) + >>> df + population GDP alpha-2 + Italy 59000000 1937894 IT + France 65000000 2583560 FR + Malta 434000 12011 MT + Maldives 434000 4520 MV + Brunei 434000 12128 BN + Iceland 337000 17036 IS + Nauru 11300 182 NR + Tuvalu 11300 38 TV + Anguilla 11300 311 AI + + In the following example, we will use ``nlargest`` to select the three + rows having the largest values in column "population". + + >>> df.nlargest(3, "population") + population GDP alpha-2 + France 65000000 2583560 FR + Italy 59000000 1937894 IT + Malta 434000 12011 MT + + When using ``keep='last'``, ties are resolved in reverse order: + + >>> df.nlargest(3, "population", keep="last") + population GDP alpha-2 + France 65000000 2583560 FR + Italy 59000000 1937894 IT + Brunei 434000 12128 BN + + When using ``keep='all'``, the number of element kept can go beyond ``n`` + if there are duplicate values for the smallest element, all the + ties are kept: + + >>> df.nlargest(3, "population", keep="all") + population GDP alpha-2 + France 65000000 2583560 FR + Italy 59000000 1937894 IT + Malta 434000 12011 MT + Maldives 434000 4520 MV + Brunei 434000 12128 BN + + However, ``nlargest`` does not keep ``n`` distinct largest elements: + + >>> df.nlargest(5, "population", keep="all") + population GDP alpha-2 + France 65000000 2583560 FR + Italy 59000000 1937894 IT + Malta 434000 12011 MT + Maldives 434000 4520 MV + Brunei 434000 12128 BN + + To order by the largest values in column "population" and then "GDP", + we can specify multiple columns like in the next example. + + >>> df.nlargest(3, ["population", "GDP"]) + population GDP alpha-2 + France 65000000 2583560 FR + Italy 59000000 1937894 IT + Brunei 434000 12128 BN + """ + return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nlargest() + + def nsmallest( + self, n: int, columns: IndexLabel, keep: NsmallestNlargestKeep = "first" + ) -> DataFrame: + """ + Return the first `n` rows ordered by `columns` in ascending order. + + Return the first `n` rows with the smallest values in `columns`, in + ascending order. The columns that are not specified are returned as + well, but not used for ordering. + + This method is equivalent to + ``df.sort_values(columns, ascending=True).head(n)``, but more + performant. + + Parameters + ---------- + n : int + Number of items to retrieve. + columns : list or str + Column name or names to order by. + keep : {'first', 'last', 'all'}, default 'first' + Where there are duplicate values: + + - ``first`` : take the first occurrence. + - ``last`` : take the last occurrence. + - ``all`` : keep all the ties of the largest item even if it means + selecting more than ``n`` items. + + Returns + ------- + DataFrame + DataFrame with the first `n` rows ordered by `columns` in ascending order. + + See Also + -------- + DataFrame.nlargest : Return the first `n` rows ordered by `columns` in + descending order. + DataFrame.sort_values : Sort DataFrame by the values. + DataFrame.head : Return the first `n` rows without re-ordering. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "population": [ + ... 59000000, + ... 65000000, + ... 434000, + ... 434000, + ... 434000, + ... 337000, + ... 337000, + ... 11300, + ... 11300, + ... ], + ... "GDP": [1937894, 2583560, 12011, 4520, 12128, 17036, 182, 38, 311], + ... "alpha-2": ["IT", "FR", "MT", "MV", "BN", "IS", "NR", "TV", "AI"], + ... }, + ... index=[ + ... "Italy", + ... "France", + ... "Malta", + ... "Maldives", + ... "Brunei", + ... "Iceland", + ... "Nauru", + ... "Tuvalu", + ... "Anguilla", + ... ], + ... ) + >>> df + population GDP alpha-2 + Italy 59000000 1937894 IT + France 65000000 2583560 FR + Malta 434000 12011 MT + Maldives 434000 4520 MV + Brunei 434000 12128 BN + Iceland 337000 17036 IS + Nauru 337000 182 NR + Tuvalu 11300 38 TV + Anguilla 11300 311 AI + + In the following example, we will use ``nsmallest`` to select the + three rows having the smallest values in column "population". + + >>> df.nsmallest(3, "population") + population GDP alpha-2 + Tuvalu 11300 38 TV + Anguilla 11300 311 AI + Iceland 337000 17036 IS + + When using ``keep='last'``, ties are resolved in reverse order: + + >>> df.nsmallest(3, "population", keep="last") + population GDP alpha-2 + Anguilla 11300 311 AI + Tuvalu 11300 38 TV + Nauru 337000 182 NR + + When using ``keep='all'``, the number of element kept can go beyond ``n`` + if there are duplicate values for the largest element, all the + ties are kept. + + >>> df.nsmallest(3, "population", keep="all") + population GDP alpha-2 + Tuvalu 11300 38 TV + Anguilla 11300 311 AI + Iceland 337000 17036 IS + Nauru 337000 182 NR + + However, ``nsmallest`` does not keep ``n`` distinct + smallest elements: + + >>> df.nsmallest(4, "population", keep="all") + population GDP alpha-2 + Tuvalu 11300 38 TV + Anguilla 11300 311 AI + Iceland 337000 17036 IS + Nauru 337000 182 NR + + To order by the smallest values in column "population" and then "GDP", we can + specify multiple columns like in the next example. + + >>> df.nsmallest(3, ["population", "GDP"]) + population GDP alpha-2 + Tuvalu 11300 38 TV + Anguilla 11300 311 AI + Nauru 337000 182 NR + """ + return selectn.SelectNFrame(self, n=n, keep=keep, columns=columns).nsmallest() + + def swaplevel(self, i: Axis = -2, j: Axis = -1, axis: Axis = 0) -> DataFrame: + """ + Swap levels i and j in a :class:`MultiIndex`. + + Default is to swap the two innermost levels of the index. + + Parameters + ---------- + i, j : int or str + Levels of the indices to be swapped. Can pass level name as string. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to swap levels on. 0 or 'index' for row-wise, 1 or + 'columns' for column-wise. + + Returns + ------- + DataFrame + DataFrame with levels swapped in MultiIndex. + + See Also + -------- + DataFrame.reorder_levels: Reorder levels of MultiIndex. + DataFrame.sort_index: Sort MultiIndex. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"Grade": ["A", "B", "A", "C"]}, + ... index=[ + ... ["Final exam", "Final exam", "Coursework", "Coursework"], + ... ["History", "Geography", "History", "Geography"], + ... ["January", "February", "March", "April"], + ... ], + ... ) + >>> df + Grade + Final exam History January A + Geography February B + Coursework History March A + Geography April C + + In the following example, we will swap the levels of the indices. + Here, we will swap the levels column-wise, but levels can be swapped row-wise + in a similar manner. Note that column-wise is the default behaviour. + By not supplying any arguments for i and j, we swap the last and second to + last indices. + + >>> df.swaplevel() + Grade + Final exam January History A + February Geography B + Coursework March History A + April Geography C + + By supplying one argument, we can choose which index to swap the last + index with. We can for example swap the first index with the last one as + follows. + + >>> df.swaplevel(0) + Grade + January History Final exam A + February Geography Final exam B + March History Coursework A + April Geography Coursework C + + We can also define explicitly which indices we want to swap by supplying values + for both i and j. Here, we for example swap the first and second indices. + + >>> df.swaplevel(0, 1) + Grade + History Final exam January A + Geography Final exam February B + History Coursework March A + Geography Coursework April C + """ + result = self.copy(deep=False) + + axis = self._get_axis_number(axis) + + if not isinstance(result._get_axis(axis), MultiIndex): # pragma: no cover + raise TypeError("Can only swap levels on a hierarchical axis.") + + if axis == 0: + assert isinstance(result.index, MultiIndex) + result.index = result.index.swaplevel(i, j) + else: + assert isinstance(result.columns, MultiIndex) + result.columns = result.columns.swaplevel(i, j) + return result + + def reorder_levels(self, order: Sequence[int | str], axis: Axis = 0) -> DataFrame: + """ + Rearrange index or column levels using input ``order``. + + May not drop or duplicate levels. + + Parameters + ---------- + order : list of int or list of str + List representing new level order. Reference level by number + (position) or by key (label). + axis : {0 or 'index', 1 or 'columns'}, default 0 + Where to reorder levels. + + Returns + ------- + DataFrame + DataFrame with indices or columns with reordered levels. + + See Also + -------- + DataFrame.swaplevel : Swap levels i and j in a MultiIndex. + + Examples + -------- + >>> data = { + ... "class": ["Mammals", "Mammals", "Reptiles"], + ... "diet": ["Omnivore", "Carnivore", "Carnivore"], + ... "species": ["Humans", "Dogs", "Snakes"], + ... } + >>> df = pd.DataFrame(data, columns=["class", "diet", "species"]) + >>> df = df.set_index(["class", "diet"]) + >>> df + species + class diet + Mammals Omnivore Humans + Carnivore Dogs + Reptiles Carnivore Snakes + + Let's reorder the levels of the index: + + >>> df.reorder_levels(["diet", "class"]) + species + diet class + Omnivore Mammals Humans + Carnivore Mammals Dogs + Reptiles Snakes + """ + axis = self._get_axis_number(axis) + if not isinstance(self._get_axis(axis), MultiIndex): # pragma: no cover + raise TypeError("Can only reorder levels on a hierarchical axis.") + + result = self.copy(deep=False) + + if axis == 0: + assert isinstance(result.index, MultiIndex) + result.index = result.index.reorder_levels(order) + else: + assert isinstance(result.columns, MultiIndex) + result.columns = result.columns.reorder_levels(order) + return result + + # ---------------------------------------------------------------------- + # Arithmetic Methods + + def _cmp_method(self, other, op): + axis: Literal[1] = 1 # only relevant for Series other case + + self, other = self._align_for_op(other, axis, flex=False, level=None) + + # See GH#4537 for discussion of scalar op behavior + new_data = self._dispatch_frame_op(other, op, axis=axis) + return self._construct_result(new_data, other=other) + + def _arith_method(self, other, op): + if self._should_reindex_frame_op(other, op, 1, None, None): + return self._arith_method_with_reindex(other, op) + + axis: Literal[1] = 1 # only relevant for Series other case + other = ops.maybe_prepare_scalar_for_op(other, (self.shape[axis],)) + + self, other = self._align_for_op(other, axis, flex=True, level=None) + + with np.errstate(all="ignore"): + new_data = self._dispatch_frame_op(other, op, axis=axis) + return self._construct_result(new_data, other=other) + + _logical_method = _arith_method + + def _dispatch_frame_op( + self, right, func: Callable, axis: AxisInt | None = None + ) -> DataFrame: + """ + Evaluate the frame operation func(left, right) by evaluating + column-by-column, dispatching to the Series implementation. + + Parameters + ---------- + right : scalar, Series, or DataFrame + func : arithmetic or comparison operator + axis : {None, 0, 1} + + Returns + ------- + DataFrame + + Notes + ----- + Caller is responsible for setting np.errstate where relevant. + """ + # Get the appropriate array-op to apply to each column/block's values. + array_op = ops.get_array_op(func) + + right = lib.item_from_zerodim(right) + if not is_list_like(right): + # i.e. scalar, faster than checking np.ndim(right) == 0 + bm = self._mgr.apply(array_op, right=right) + return self._constructor_from_mgr(bm, axes=bm.axes) + + elif isinstance(right, DataFrame): + assert self.index.equals(right.index) + assert self.columns.equals(right.columns) + # TODO: The previous assertion `assert right._indexed_same(self)` + # fails in cases with empty columns reached via + # _frame_arith_method_with_reindex + + # TODO operate_blockwise expects a manager of the same type + bm = self._mgr.operate_blockwise( + right._mgr, + array_op, + ) + return self._constructor_from_mgr(bm, axes=bm.axes) + + elif isinstance(right, Series) and axis == 1: + # axis=1 means we want to operate row-by-row + assert right.index.equals(self.columns) + + right = right._values + # maybe_align_as_frame ensures we do not have an ndarray here + assert not isinstance(right, np.ndarray) + + arrays = [ + array_op(_left, _right) + for _left, _right in zip(self._iter_column_arrays(), right, strict=True) + ] + + elif isinstance(right, Series): + assert right.index.equals(self.index) + right = right._values + + arrays = [array_op(left, right) for left in self._iter_column_arrays()] + + else: + raise NotImplementedError(right) + + return type(self)._from_arrays( + arrays, self.columns, self.index, verify_integrity=False + ) + + def _combine_frame(self, other: DataFrame, func, fill_value=None): + # at this point we have `self._indexed_same(other)` + + if fill_value is None: + # since _arith_op may be called in a loop, avoid function call + # overhead if possible by doing this check once + _arith_op = func + + else: + + def _arith_op(left, right): + # for the mixed_type case where we iterate over columns, + # _arith_op(left, right) is equivalent to + # left._binop(right, func, fill_value=fill_value) + left, right = ops.fill_binop(left, right, fill_value) + return func(left, right) + + new_data = self._dispatch_frame_op(other, _arith_op) + return new_data + + def _arith_method_with_reindex(self, right: DataFrame, op) -> DataFrame: + """ + For DataFrame-with-DataFrame operations that require reindexing, + operate only on shared columns, then reindex. + + Parameters + ---------- + right : DataFrame + op : binary operator + + Returns + ------- + DataFrame + """ + left = self + + # GH#31623, only operate on shared columns + cols, lcol_indexer, rcol_indexer = left.columns.join( + right.columns, how="inner", return_indexers=True + ) + + new_left = left if lcol_indexer is None else left.iloc[:, lcol_indexer] + new_right = right if rcol_indexer is None else right.iloc[:, rcol_indexer] + + # GH#60498 For MultiIndex column alignment + if isinstance(cols, MultiIndex): + # When overwriting column names, make a shallow copy so as to not modify + # the input DFs + new_left = new_left.copy(deep=False) + new_right = new_right.copy(deep=False) + new_left.columns = cols + new_right.columns = cols + + result = op(new_left, new_right) + + # Do the join on the columns instead of using left._align_for_op + # to avoid constructing two potentially large/sparse DataFrames + join_columns = left.columns.join(right.columns, how="outer") + + if result.columns.has_duplicates: + # Avoid reindexing with a duplicate axis. + # https://github.com/pandas-dev/pandas/issues/35194 + indexer, _ = result.columns.get_indexer_non_unique(join_columns) + indexer = algorithms.unique1d(indexer) + result = result._reindex_with_indexers( + {1: [join_columns, indexer]}, allow_dups=True + ) + else: + result = result.reindex(join_columns, axis=1) + + return result + + def _should_reindex_frame_op(self, right, op, axis: int, fill_value, level) -> bool: + """ + Check if this is an operation between DataFrames that will need to reindex. + """ + if op is operator.pow or op is roperator.rpow: + # GH#32685 pow has special semantics for operating with null values + return False + + if not isinstance(right, DataFrame): + return False + + if ( + ( + isinstance(self.columns, MultiIndex) + or isinstance(right.columns, MultiIndex) + ) + and not self.columns.equals(right.columns) + and fill_value is None + ): + # GH#60498 Reindex if MultiIndexe columns are not matching + # GH#60903 Don't reindex if fill_value is provided + return True + + if fill_value is None and level is None and axis == 1: + # TODO: any other cases we should handle here? + + # Intersection is always unique so we have to check the unique columns + left_uniques = self.columns.unique() + right_uniques = right.columns.unique() + cols = left_uniques.intersection(right_uniques) + if len(cols) and not ( + len(cols) == len(left_uniques) and len(cols) == len(right_uniques) + ): + # TODO: is there a shortcut available when len(cols) == 0? + return True + + return False + + def _align_for_op( + self, + other, + axis: AxisInt, + flex: bool | None = False, + level: Level | None = None, + ): + """ + Convert rhs to meet lhs dims if input is list, tuple or np.ndarray. + + Parameters + ---------- + other : Any + axis : int + flex : bool or None, default False + Whether this is a flex op, in which case we reindex. + None indicates not to check for alignment. + level : int or level name, default None + + Returns + ------- + left : DataFrame + right : Any + """ + left, right = self, other + + def to_series(right): + msg = ( + "Unable to coerce to Series, " + "length must be {req_len}: given {given_len}" + ) + + # pass dtype to avoid doing inference, which would break consistency + # with Index/Series ops + dtype = None + if getattr(right, "dtype", None) == object: + # can't pass right.dtype unconditionally as that would break on e.g. + # datetime64[h] ndarray + dtype = object + + if axis == 0: + if len(left.index) != len(right): + raise ValueError( + msg.format(req_len=len(left.index), given_len=len(right)) + ) + right = left._constructor_sliced(right, index=left.index, dtype=dtype) + else: + if len(left.columns) != len(right): + raise ValueError( + msg.format(req_len=len(left.columns), given_len=len(right)) + ) + right = left._constructor_sliced(right, index=left.columns, dtype=dtype) + return right + + if isinstance(right, np.ndarray): + if right.ndim == 1: + right = to_series(right) + + elif right.ndim == 2: + # We need to pass dtype=right.dtype to retain object dtype + # otherwise we lose consistency with Index and array ops + dtype = None + if right.dtype == object: + # can't pass right.dtype unconditionally as that would break on e.g. + # datetime64[h] ndarray + dtype = object + + if right.shape == left.shape: + right = left._constructor( + right, index=left.index, columns=left.columns, dtype=dtype + ) + + elif right.shape[0] == left.shape[0] and right.shape[1] == 1: + # Broadcast across columns + right = np.broadcast_to(right, left.shape) + right = left._constructor( + right, index=left.index, columns=left.columns, dtype=dtype + ) + + elif right.shape[1] == left.shape[1] and right.shape[0] == 1: + # Broadcast along rows + right = to_series(right[0, :]) + + else: + raise ValueError( + "Unable to coerce to DataFrame, shape " + f"must be {left.shape}: given {right.shape}" + ) + + elif right.ndim > 2: + raise ValueError( + "Unable to coerce to Series/DataFrame, " + f"dimension must be <= 2: {right.shape}" + ) + + elif is_list_like(right) and not isinstance(right, (Series, DataFrame)): + # GH#36702. Raise when attempting arithmetic with list of array-like. + if any(is_array_like(el) for el in right): + raise ValueError( + f"Unable to coerce list of {type(right[0])} to Series/DataFrame" + ) + # GH#17901 + right = to_series(right) + + if flex is not None and isinstance(right, DataFrame): + if not left._indexed_same(right): + if flex: + left, right = left.align(right, join="outer", level=level) + else: + raise ValueError( + "Can only compare identically-labeled (both index and columns) " + "DataFrame objects" + ) + elif isinstance(right, Series): + # axis=1 is default for DataFrame-with-Series op + axis = axis if axis is not None else 1 + if not flex: + if not left.axes[axis].equals(right.index): + raise ValueError( + "Operands are not aligned. Do " + "`left, right = left.align(right, axis=1)` " + "before operating." + ) + + left, right = left.align( + right, + join="outer", + axis=axis, + level=level, + ) + right = left._maybe_align_series_as_frame(right, axis) + return left, right + + def _maybe_align_series_as_frame(self, series: Series, axis: AxisInt): + """ + If the Series operand is not EA-dtype, we can broadcast to 2D and operate + blockwise. + """ + rvalues = series._values + if not isinstance(rvalues, np.ndarray): + # TODO(EA2D): no need to special-case with 2D EAs + if lib.is_np_dtype(rvalues.dtype, "mM"): + # i.e. DatetimeArray[tznaive] or TimedeltaArray + # We can losslessly+cheaply cast to ndarray + rvalues = np.asarray(rvalues) + else: + return series + + if axis == 0: + rvalues = rvalues.reshape(-1, 1) + else: + rvalues = rvalues.reshape(1, -1) + + rvalues = np.broadcast_to(rvalues, self.shape) + # pass dtype to avoid doing inference + return self._constructor( + rvalues, + index=self.index, + columns=self.columns, + dtype=rvalues.dtype, + ).__finalize__(series) + + def _flex_arith_method( + self, other, op, *, axis: Axis = "columns", level=None, fill_value=None + ): + axis = self._get_axis_number(axis) if axis is not None else 1 + + if self._should_reindex_frame_op(other, op, axis, fill_value, level): + return self._arith_method_with_reindex(other, op) + + if isinstance(other, Series) and fill_value is not None: + # TODO: We could allow this in cases where we end up going + # through the DataFrame path + raise NotImplementedError(f"fill_value {fill_value} not supported.") + + other = ops.maybe_prepare_scalar_for_op(other, self.shape) + self, other = self._align_for_op(other, axis, flex=True, level=level) + + with np.errstate(all="ignore"): + if isinstance(other, DataFrame): + # Another DataFrame + new_data = self._combine_frame(other, op, fill_value) + + elif isinstance(other, Series): + new_data = self._dispatch_frame_op(other, op, axis=axis) + else: + # in this case we always have `np.ndim(other) == 0` + if fill_value is not None: + self = self.fillna(fill_value) + + new_data = self._dispatch_frame_op(other, op) + + return self._construct_result(new_data, other=other) + + def _construct_result(self, result, other) -> DataFrame: + """ + Wrap the result of an arithmetic, comparison, or logical operation. + + Parameters + ---------- + result : DataFrame + + Returns + ------- + DataFrame + """ + out = self._constructor(result, copy=False).__finalize__(self) + # Pin columns instead of passing to constructor for compat with + # non-unique columns case + out.columns = self.columns + out.index = self.index + out = out.__finalize__(other) + return out + + def __divmod__(self, other) -> tuple[DataFrame, DataFrame]: + # Naive implementation, room for optimization + div = self // other + mod = self - div * other + return div, mod + + def __rdivmod__(self, other) -> tuple[DataFrame, DataFrame]: + # Naive implementation, room for optimization + div = other // self + mod = other - div * self + return div, mod + + def _flex_cmp_method(self, other, op, *, axis: Axis = "columns", level=None): + axis = self._get_axis_number(axis) if axis is not None else 1 + + self, other = self._align_for_op(other, axis, flex=True, level=level) + + new_data = self._dispatch_frame_op(other, op, axis=axis) + return self._construct_result(new_data, other=other) + + def eq(self, other, axis: Axis = "columns", level=None) -> DataFrame: + """ + Get Not equal to of dataframe and other, element-wise (binary operator `eq`). + + Among flexible wrappers (`eq`, `ne`, `le`, `lt`, `ge`, `gt`) to comparison + operators. + + Equivalent to `==`, `!=`, `<=`, `<`, `>=`, `>` with support to choose axis + (rows or columns) and level for comparison. + + Parameters + ---------- + other : scalar, sequence, Series, or DataFrame + Any single or multiple element data structure, or list-like object. + axis : {0 or 'index', 1 or 'columns'}, default 'columns' + Whether to compare by the index (0 or 'index') or columns + (1 or 'columns'). + level : int or label + Broadcast across a level, matching Index values on the passed + MultiIndex level. + + Returns + ------- + DataFrame of bool + Result of the comparison. + + See Also + -------- + DataFrame.eq : Compare DataFrames for equality elementwise. + DataFrame.ne : Compare DataFrames for inequality elementwise. + DataFrame.le : Compare DataFrames for less than inequality + or equality elementwise. + DataFrame.lt : Compare DataFrames for strictly less than + inequality elementwise. + DataFrame.ge : Compare DataFrames for greater than inequality + or equality elementwise. + DataFrame.gt : Compare DataFrames for strictly greater than + inequality elementwise. + + Notes + ----- + Mismatched indices will be unioned together. + `NaN` values are considered different (i.e. `NaN` != `NaN`). + + Examples + -------- + >>> df = pd.DataFrame( + ... {"cost": [250, 150, 100], "revenue": [100, 250, 300]}, + ... index=["A", "B", "C"], + ... ) + >>> df + cost revenue + A 250 100 + B 150 250 + C 100 300 + + Comparison with a scalar, using either the operator or method: + + >>> df == 100 + cost revenue + A False True + B False False + C True False + + >>> df.eq(100) + cost revenue + A False True + B False False + C True False + + When `other` is a :class:`Series`, the columns of a DataFrame are aligned + with the index of `other` and broadcast: + + >>> df != pd.Series([100, 250], index=["cost", "revenue"]) + cost revenue + A True True + B True False + C False True + + Use the method to control the broadcast axis: + + >>> df.ne(pd.Series([100, 300], index=["A", "D"]), axis="index") + cost revenue + A True False + B True True + C True True + D True True + + When comparing to an arbitrary sequence, the number of columns must + match the number elements in `other`: + + >>> df == [250, 100] + cost revenue + A True True + B False False + C False False + + Use the method to control the axis: + + >>> df.eq([250, 250, 100], axis="index") + cost revenue + A True False + B False True + C True False + + Compare to a DataFrame of different shape. + + >>> other = pd.DataFrame( + ... {"revenue": [300, 250, 100, 150]}, index=["A", "B", "C", "D"] + ... ) + >>> other + revenue + A 300 + B 250 + C 100 + D 150 + + >>> df.gt(other) + cost revenue + A False False + B False False + C False True + D False False + + Compare to a MultiIndex by level. + + >>> df_multindex = pd.DataFrame( + ... { + ... "cost": [250, 150, 100, 150, 300, 220], + ... "revenue": [100, 250, 300, 200, 175, 225], + ... }, + ... index=[ + ... ["Q1", "Q1", "Q1", "Q2", "Q2", "Q2"], + ... ["A", "B", "C", "A", "B", "C"], + ... ], + ... ) + >>> df_multindex + cost revenue + Q1 A 250 100 + B 150 250 + C 100 300 + Q2 A 150 200 + B 300 175 + C 220 225 + + >>> df.le(df_multindex, level=1) + cost revenue + Q1 A True True + B True True + C True True + Q2 A False True + B True False + C True False + """ + return self._flex_cmp_method(other, operator.eq, axis=axis, level=level) + + def ne(self, other, axis: Axis = "columns", level=None) -> DataFrame: + """ + Get Not equal to of dataframe and other, element-wise (binary operator `ne`). + + Among flexible wrappers (`eq`, `ne`, `le`, `lt`, `ge`, `gt`) to comparison + operators. + + Equivalent to `==`, `!=`, `<=`, `<`, `>=`, `>` with support to choose axis + (rows or columns) and level for comparison. + + Parameters + ---------- + other : scalar, sequence, Series, or DataFrame + Any single or multiple element data structure, or list-like object. + axis : {0 or 'index', 1 or 'columns'}, default 'columns' + Whether to compare by the index (0 or 'index') or columns + (1 or 'columns'). + level : int or label + Broadcast across a level, matching Index values on the passed + MultiIndex level. + + Returns + ------- + DataFrame of bool + Result of the comparison. + + See Also + -------- + DataFrame.eq : Compare DataFrames for equality elementwise. + DataFrame.ne : Compare DataFrames for inequality elementwise. + DataFrame.le : Compare DataFrames for less than inequality + or equality elementwise. + DataFrame.lt : Compare DataFrames for strictly less than + inequality elementwise. + DataFrame.ge : Compare DataFrames for greater than inequality + or equality elementwise. + DataFrame.gt : Compare DataFrames for strictly greater than + inequality elementwise. + + Notes + ----- + Mismatched indices will be unioned together. + `NaN` values are considered different (i.e. `NaN` != `NaN`). + + Examples + -------- + >>> df = pd.DataFrame( + ... {"cost": [250, 150, 100], "revenue": [100, 250, 300]}, + ... index=["A", "B", "C"], + ... ) + >>> df + cost revenue + A 250 100 + B 150 250 + C 100 300 + + Comparison with a scalar, using either the operator or method: + + >>> df == 100 + cost revenue + A False True + B False False + C True False + + >>> df.eq(100) + cost revenue + A False True + B False False + C True False + + When `other` is a :class:`Series`, the columns of a DataFrame are aligned + with the index of `other` and broadcast: + + >>> df != pd.Series([100, 250], index=["cost", "revenue"]) + cost revenue + A True True + B True False + C False True + + Use the method to control the broadcast axis: + + >>> df.ne(pd.Series([100, 300], index=["A", "D"]), axis="index") + cost revenue + A True False + B True True + C True True + D True True + + When comparing to an arbitrary sequence, the number of columns must + match the number elements in `other`: + + >>> df == [250, 100] + cost revenue + A True True + B False False + C False False + + Use the method to control the axis: + + >>> df.eq([250, 250, 100], axis="index") + cost revenue + A True False + B False True + C True False + + Compare to a DataFrame of different shape. + + >>> other = pd.DataFrame( + ... {"revenue": [300, 250, 100, 150]}, index=["A", "B", "C", "D"] + ... ) + >>> other + revenue + A 300 + B 250 + C 100 + D 150 + + >>> df.gt(other) + cost revenue + A False False + B False False + C False True + D False False + + Compare to a MultiIndex by level. + + >>> df_multindex = pd.DataFrame( + ... { + ... "cost": [250, 150, 100, 150, 300, 220], + ... "revenue": [100, 250, 300, 200, 175, 225], + ... }, + ... index=[ + ... ["Q1", "Q1", "Q1", "Q2", "Q2", "Q2"], + ... ["A", "B", "C", "A", "B", "C"], + ... ], + ... ) + >>> df_multindex + cost revenue + Q1 A 250 100 + B 150 250 + C 100 300 + Q2 A 150 200 + B 300 175 + C 220 225 + + >>> df.le(df_multindex, level=1) + cost revenue + Q1 A True True + B True True + C True True + Q2 A False True + B True False + C True False + """ + return self._flex_cmp_method(other, operator.ne, axis=axis, level=level) + + @Appender(ops.make_flex_doc("le", "dataframe")) + def le(self, other, axis: Axis = "columns", level=None) -> DataFrame: + return self._flex_cmp_method(other, operator.le, axis=axis, level=level) + + @Appender(ops.make_flex_doc("lt", "dataframe")) + def lt(self, other, axis: Axis = "columns", level=None) -> DataFrame: + return self._flex_cmp_method(other, operator.lt, axis=axis, level=level) + + @Appender(ops.make_flex_doc("ge", "dataframe")) + def ge(self, other, axis: Axis = "columns", level=None) -> DataFrame: + return self._flex_cmp_method(other, operator.ge, axis=axis, level=level) + + @Appender(ops.make_flex_doc("gt", "dataframe")) + def gt(self, other, axis: Axis = "columns", level=None) -> DataFrame: + return self._flex_cmp_method(other, operator.gt, axis=axis, level=level) + + @Appender(ops.make_flex_doc("add", "dataframe")) + def add( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, operator.add, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("radd", "dataframe")) + def radd( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, roperator.radd, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("sub", "dataframe")) + def sub( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, operator.sub, level=level, fill_value=fill_value, axis=axis + ) + + subtract = sub + + @Appender(ops.make_flex_doc("rsub", "dataframe")) + def rsub( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, roperator.rsub, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("mul", "dataframe")) + def mul( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, operator.mul, level=level, fill_value=fill_value, axis=axis + ) + + multiply = mul + + @Appender(ops.make_flex_doc("rmul", "dataframe")) + def rmul( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, roperator.rmul, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("truediv", "dataframe")) + def truediv( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, operator.truediv, level=level, fill_value=fill_value, axis=axis + ) + + div = truediv + divide = truediv + + @Appender(ops.make_flex_doc("rtruediv", "dataframe")) + def rtruediv( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, roperator.rtruediv, level=level, fill_value=fill_value, axis=axis + ) + + rdiv = rtruediv + + @Appender(ops.make_flex_doc("floordiv", "dataframe")) + def floordiv( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, operator.floordiv, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("rfloordiv", "dataframe")) + def rfloordiv( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, roperator.rfloordiv, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("mod", "dataframe")) + def mod( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, operator.mod, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("rmod", "dataframe")) + def rmod( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, roperator.rmod, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("pow", "dataframe")) + def pow( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, operator.pow, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("rpow", "dataframe")) + def rpow( + self, other, axis: Axis = "columns", level=None, fill_value=None + ) -> DataFrame: + return self._flex_arith_method( + other, roperator.rpow, level=level, fill_value=fill_value, axis=axis + ) + + # ---------------------------------------------------------------------- + # Combination-Related + + def compare( + self, + other: DataFrame, + align_axis: Axis = 1, + keep_shape: bool = False, + keep_equal: bool = False, + result_names: Suffixes = ("self", "other"), + ) -> DataFrame: + """ + Compare to another DataFrame and show the differences. + + Parameters + ---------- + other : DataFrame + Object to compare with. + + align_axis : {0 or 'index', 1 or 'columns'}, default 1 + Determine which axis to align the comparison on. + + * 0, or 'index' : Resulting differences are stacked vertically + with rows drawn alternately from self and other. + * 1, or 'columns' : Resulting differences are aligned horizontally + with columns drawn alternately from self and other. + + keep_shape : bool, default False + If true, all rows and columns are kept. + Otherwise, only the ones with different values are kept. + + keep_equal : bool, default False + If true, the result keeps values that are equal. + Otherwise, equal values are shown as NaNs. + + result_names : tuple, default ('self', 'other') + Set the dataframes names in the comparison. + + Returns + ------- + DataFrame + DataFrame that shows the differences stacked side by side. + + The resulting index will be a MultiIndex with 'self' and 'other' + stacked alternately at the inner level. + + Raises + ------ + ValueError + When the two DataFrames don't have identical labels or shape. + + See Also + -------- + Series.compare : Compare with another Series and show differences. + DataFrame.equals : Test whether two objects contain the same elements. + + Notes + ----- + Matching NaNs will not appear as a difference. + + Can only compare identically-labeled + (i.e. same shape, identical row and column labels) DataFrames + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "col1": ["a", "a", "b", "b", "a"], + ... "col2": [1.0, 2.0, 3.0, np.nan, 5.0], + ... "col3": [1.0, 2.0, 3.0, 4.0, 5.0], + ... }, + ... columns=["col1", "col2", "col3"], + ... ) + >>> df + col1 col2 col3 + 0 a 1.0 1.0 + 1 a 2.0 2.0 + 2 b 3.0 3.0 + 3 b NaN 4.0 + 4 a 5.0 5.0 + + >>> df2 = df.copy() + >>> df2.loc[0, "col1"] = "c" + >>> df2.loc[2, "col3"] = 4.0 + >>> df2 + col1 col2 col3 + 0 c 1.0 1.0 + 1 a 2.0 2.0 + 2 b 3.0 4.0 + 3 b NaN 4.0 + 4 a 5.0 5.0 + + Align the differences on columns + + >>> df.compare(df2) + col1 col3 + self other self other + 0 a c NaN NaN + 2 NaN NaN 3.0 4.0 + + Assign result_names + + >>> df.compare(df2, result_names=("left", "right")) + col1 col3 + left right left right + 0 a c NaN NaN + 2 NaN NaN 3.0 4.0 + + Stack the differences on rows + + >>> df.compare(df2, align_axis=0) + col1 col3 + 0 self a NaN + other c NaN + 2 self NaN 3.0 + other NaN 4.0 + + Keep the equal values + + >>> df.compare(df2, keep_equal=True) + col1 col3 + self other self other + 0 a c 1.0 1.0 + 2 b b 3.0 4.0 + + Keep all original rows and columns + + >>> df.compare(df2, keep_shape=True) + col1 col2 col3 + self other self other self other + 0 a c NaN NaN NaN NaN + 1 NaN NaN NaN NaN NaN NaN + 2 NaN NaN NaN NaN 3.0 4.0 + 3 NaN NaN NaN NaN NaN NaN + 4 NaN NaN NaN NaN NaN NaN + + Keep all original rows and columns and also all original values + + >>> df.compare(df2, keep_shape=True, keep_equal=True) + col1 col2 col3 + self other self other self other + 0 a c 1.0 1.0 1.0 1.0 + 1 a a 2.0 2.0 2.0 2.0 + 2 b b 3.0 3.0 3.0 4.0 + 3 b b NaN NaN 4.0 4.0 + 4 a a 5.0 5.0 5.0 5.0 + """ + return super().compare( + other=other, + align_axis=align_axis, + keep_shape=keep_shape, + keep_equal=keep_equal, + result_names=result_names, + ) + + def combine( + self, + other: DataFrame, + func: Callable[[Series, Series], Series | Hashable], + fill_value=None, + overwrite: bool = True, + ) -> DataFrame: + """ + Perform column-wise combine with another DataFrame. + + Combines a DataFrame with `other` DataFrame using `func` + to element-wise combine columns. The row and column indexes of the + resulting DataFrame will be the union of the two. + + Parameters + ---------- + other : DataFrame + The DataFrame to merge column-wise. + func : function + Function that takes two series as inputs and return a Series or a + scalar. Used to merge the two dataframes column by columns. + fill_value : scalar value, default None + The value to fill NaNs with prior to passing any column to the + merge func. + overwrite : bool, default True + If True, columns in `self` that do not exist in `other` will be + overwritten with NaNs. + + Returns + ------- + DataFrame + Combination of the provided DataFrames. + + See Also + -------- + DataFrame.combine_first : Combine two DataFrame objects and default to + non-null values in frame calling the method. + + Examples + -------- + Combine using a simple function that chooses the smaller column. + + >>> df1 = pd.DataFrame({"A": [0, 0], "B": [4, 4]}) + >>> df2 = pd.DataFrame({"A": [1, 1], "B": [3, 3]}) + >>> take_smaller = lambda s1, s2: s1 if s1.sum() < s2.sum() else s2 + >>> df1.combine(df2, take_smaller) + A B + 0 0 3 + 1 0 3 + + Example using a true element-wise combine function. + + >>> df1 = pd.DataFrame({"A": [5, 0], "B": [2, 4]}) + >>> df2 = pd.DataFrame({"A": [1, 1], "B": [3, 3]}) + >>> df1.combine(df2, np.minimum) + A B + 0 1 2 + 1 0 3 + + Using `fill_value` fills Nones prior to passing the column to the + merge function. + + >>> df1 = pd.DataFrame({"A": [0, 0], "B": [None, 4]}) + >>> df2 = pd.DataFrame({"A": [1, 1], "B": [3, 3]}) + >>> df1.combine(df2, take_smaller, fill_value=-5) + A B + 0 0 -5.0 + 1 0 4.0 + + Example that demonstrates the use of `overwrite` and behavior when + the axis differ between the dataframes. + + >>> df1 = pd.DataFrame({"A": [0, 0], "B": [4, 4]}) + >>> df2 = pd.DataFrame( + ... { + ... "B": [3, 3], + ... "C": [-10, 1], + ... }, + ... index=[1, 2], + ... ) + >>> df1.combine(df2, take_smaller) + A B C + 0 NaN NaN NaN + 1 NaN 3.0 -10.0 + 2 NaN 3.0 1.0 + + >>> df1.combine(df2, take_smaller, overwrite=False) + A B C + 0 0.0 NaN NaN + 1 0.0 3.0 -10.0 + 2 NaN 3.0 1.0 + + Demonstrating the preference of the passed in dataframe. + + >>> df2 = pd.DataFrame( + ... { + ... "B": [3, 3], + ... "C": [1, 1], + ... }, + ... index=[1, 2], + ... ) + >>> df2.combine(df1, take_smaller) + B C A + 0 NaN NaN 0.0 + 1 3.0 NaN 0.0 + 2 3.0 NaN NaN + + >>> df2.combine(df1, take_smaller, overwrite=False) + B C A + 0 NaN NaN 0.0 + 1 3.0 1.0 0.0 + 2 3.0 1.0 NaN + """ + other_idxlen = len(other.index) # save for compare + other_columns = other.columns + + this, other = self.align(other) + new_index = this.index + + if other.empty and len(new_index) == len(self.index): + return self.copy() + + if self.empty and len(other) == other_idxlen: + return other.copy() + + # preserve column order + new_columns = self.columns.union(other_columns, sort=False) + this = this.reindex(new_columns, axis=1) + other = other.reindex(new_columns, axis=1) + + do_fill = fill_value is not None + result = {} + for i in range(this.shape[1]): + series = this.iloc[:, i] + other_series = other.iloc[:, i] + + this_dtype = series.dtype + other_dtype = other_series.dtype + + this_mask = isna(series) + other_mask = isna(other_series) + + # don't overwrite columns unnecessarily + # DO propagate if this column is not in the intersection + if not overwrite and other_mask.all(): + result[i] = series.copy() + continue + + if do_fill: + series = series.copy() + other_series = other_series.copy() + series[this_mask] = fill_value + other_series[other_mask] = fill_value + + if new_columns[i] not in self.columns: + # If self DataFrame does not have col in other DataFrame, + # try to promote series, which is all NaN, as other_dtype. + new_dtype = other_dtype + try: + series = series.astype(new_dtype) + except ValueError: + # e.g. new_dtype is integer types + pass + else: + # if we have different dtypes, possibly promote + new_dtype = find_common_type([this_dtype, other_dtype]) + series = series.astype(new_dtype) + other_series = other_series.astype(new_dtype) + + arr = func(series, other_series) + if isinstance(new_dtype, np.dtype): + # if new_dtype is an EA Dtype, then `func` is expected to return + # the correct dtype without any additional casting + # error: No overload variant of "maybe_downcast_to_dtype" matches + # argument types "Union[Series, Hashable]", "dtype[Any]" + arr = maybe_downcast_to_dtype( # type: ignore[call-overload] + arr, new_dtype + ) + + result[i] = arr + + frame_result = self._constructor(result, index=new_index) + frame_result.columns = new_columns + return frame_result.__finalize__(self, method="combine") + + def combine_first(self, other: DataFrame) -> DataFrame: + """ + Update null elements with value in the same location in `other`. + + Combine two DataFrame objects by filling null values in one DataFrame + with non-null values from other DataFrame. The row and column indexes + of the resulting DataFrame will be the union of the two. The resulting + dataframe contains the 'first' dataframe values and overrides the + second one values where both first.loc[index, col] and + second.loc[index, col] are not missing values, upon calling + first.combine_first(second). + + Parameters + ---------- + other : DataFrame + Provided DataFrame to use to fill null values. + + Returns + ------- + DataFrame + The result of combining the provided DataFrame with the other object. + + See Also + -------- + DataFrame.combine : Perform series-wise operation on two DataFrames + using a given function. + + Examples + -------- + >>> df1 = pd.DataFrame({"A": [None, 0], "B": [None, 4]}) + >>> df2 = pd.DataFrame({"A": [1, 1], "B": [3, 3]}) + >>> df1.combine_first(df2) + A B + 0 1.0 3.0 + 1 0.0 4.0 + + Null values still persist if the location of that null value + does not exist in `other` + + >>> df1 = pd.DataFrame({"A": [None, 0], "B": [4, None]}) + >>> df2 = pd.DataFrame({"B": [3, 3], "C": [1, 1]}, index=[1, 2]) + >>> df1.combine_first(df2) + A B C + 0 NaN 4.0 NaN + 1 0.0 3.0 1.0 + 2 NaN 3.0 1.0 + """ + + def combiner(x: Series, y: Series): + # GH#60128 The combiner is supposed to preserve EA Dtypes. + return y if y.name not in self.columns else y.where(x.isna(), x) + + if len(other) == 0: + combined = self.reindex( + self.columns.append(other.columns.difference(self.columns)), axis=1 + ) + combined = combined.astype(other.dtypes) + else: + combined = self.combine(other, combiner, overwrite=False) + + dtypes = { + # Check for isinstance(..., (np.dtype, ExtensionDtype)) + # to prevent raising on non-unique columns see GH#29135. + # Note we will just not-cast in these cases. + col: find_common_type([self.dtypes[col], other.dtypes[col]]) + for col in self.columns.intersection(other.columns) + if isinstance(combined.dtypes[col], (np.dtype, ExtensionDtype)) + and isinstance(self.dtypes[col], (np.dtype, ExtensionDtype)) + and combined.dtypes[col] != self.dtypes[col] + } + + if dtypes: + combined = combined.astype(dtypes) + + return combined.__finalize__(self, method="combine_first") + + def update( + self, + other, + join: UpdateJoin = "left", + overwrite: bool = True, + filter_func=None, + errors: IgnoreRaise = "ignore", + ) -> None: + """ + Modify in place using non-NA values from another DataFrame. + + Aligns on indices. There is no return value. + + Parameters + ---------- + other : DataFrame, or object coercible into a DataFrame + Should have at least one matching index/column label + with the original DataFrame. If a Series is passed, + its name attribute must be set, and that will be + used as the column name to align with the original DataFrame. + join : {'left'}, default 'left' + Only left join is implemented, keeping the index and columns of the + original object. + overwrite : bool, default True + How to handle non-NA values for overlapping keys: + + * True: overwrite original DataFrame's values + with values from `other`. + * False: only update values that are NA in + the original DataFrame. + + filter_func : callable(1d-array) -> bool 1d-array, optional + Can choose to replace values other than NA. Return True for values + that should be updated. + errors : {'raise', 'ignore'}, default 'ignore' + If 'raise', will raise a ValueError if the DataFrame and `other` + both contain non-NA data in the same place. + + Returns + ------- + None + This method directly changes calling object. + + Raises + ------ + ValueError + * When `errors='raise'` and there's overlapping non-NA data. + * When `errors` is not either `'ignore'` or `'raise'` + NotImplementedError + * If `join != 'left'` + + See Also + -------- + dict.update : Similar method for dictionaries. + DataFrame.merge : For column(s)-on-column(s) operations. + + Notes + ----- + 1. Duplicate indices on `other` are not supported and raises `ValueError`. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2, 3], "B": [400, 500, 600]}) + >>> new_df = pd.DataFrame({"B": [4, 5, 6], "C": [7, 8, 9]}) + >>> df.update(new_df) + >>> df + A B + 0 1 4 + 1 2 5 + 2 3 6 + + The DataFrame's length does not increase as a result of the update, + only values at matching index/column labels are updated. + + >>> df = pd.DataFrame({"A": ["a", "b", "c"], "B": ["x", "y", "z"]}) + >>> new_df = pd.DataFrame({"B": ["d", "e", "f", "g", "h", "i"]}) + >>> df.update(new_df) + >>> df + A B + 0 a d + 1 b e + 2 c f + + >>> df = pd.DataFrame({"A": ["a", "b", "c"], "B": ["x", "y", "z"]}) + >>> new_df = pd.DataFrame({"B": ["d", "f"]}, index=[0, 2]) + >>> df.update(new_df) + >>> df + A B + 0 a d + 1 b y + 2 c f + + For Series, its name attribute must be set. + + >>> df = pd.DataFrame({"A": ["a", "b", "c"], "B": ["x", "y", "z"]}) + >>> new_column = pd.Series(["d", "e", "f"], name="B") + >>> df.update(new_column) + >>> df + A B + 0 a d + 1 b e + 2 c f + + If `other` contains NaNs the corresponding values are not updated + in the original dataframe. + + >>> df = pd.DataFrame({"A": [1, 2, 3], "B": [400.0, 500.0, 600.0]}) + >>> new_df = pd.DataFrame({"B": [4, np.nan, 6]}) + >>> df.update(new_df) + >>> df + A B + 0 1 4.0 + 1 2 500.0 + 2 3 6.0 + """ + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not com.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_update_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + # TODO: Support other joins + if join != "left": # pragma: no cover + raise NotImplementedError("Only left join is supported") + if errors not in ["ignore", "raise"]: + raise ValueError("The parameter errors must be either 'ignore' or 'raise'") + + if not isinstance(other, DataFrame): + other = DataFrame(other) + + if other.index.has_duplicates: + raise ValueError("Update not allowed with duplicate indexes on other.") + + index_intersection = other.index.intersection(self.index) + if index_intersection.empty: + return + other = other.reindex(index_intersection) + this_data = self.loc[index_intersection] + + for col in self.columns.intersection(other.columns): + this = this_data[col] + that = other[col] + + if filter_func is not None: + mask = ~filter_func(this) | isna(that) + else: + if errors == "raise": + mask_this = notna(that) + mask_that = notna(this) + if any(mask_this & mask_that): + raise ValueError("Data overlaps.") + + if overwrite: + mask = isna(that) + else: + mask = notna(this) + + # don't overwrite columns unnecessarily + if mask.all(): + continue + + self.loc[index_intersection, col] = this.where(mask, that) + + # ---------------------------------------------------------------------- + # Data reshaping + @deprecate_nonkeyword_arguments( + Pandas4Warning, allowed_args=["self", "by", "level"], name="groupby" + ) + def groupby( + self, + by=None, + level: IndexLabel | None = None, + as_index: bool = True, + sort: bool = True, + group_keys: bool = True, + observed: bool = True, + dropna: bool = True, + ) -> DataFrameGroupBy: + """ + Group DataFrame using a mapper or by a Series of columns. + + A groupby operation involves some combination of splitting the + object, applying a function, and combining the results. This can be + used to group large amounts of data and compute operations on these + groups. + + Parameters + ---------- + by : mapping, function, label, pd.Grouper or list of such + Used to determine the groups for the groupby. + If ``by`` is a function, it's called on each value of the object's + index. If a dict or Series is passed, the Series or dict VALUES + will be used to determine the groups (the Series' values are first + aligned; see ``.align()`` method). If a list or ndarray of length + equal to the number of rows is passed (see the `groupby user guide + `_), + the values are used as-is to determine the groups. A label or list + of labels may be passed to group by the columns in ``self``. + Notice that a tuple is interpreted as a (single) key. + level : int, level name, or sequence of such, default None + If the axis is a MultiIndex (hierarchical), group by a particular + level or levels. Do not specify both ``by`` and ``level``. + as_index : bool, default True + Return object with group labels as the + index. Only relevant for DataFrame input. as_index=False is + effectively "SQL-style" grouped output. This argument has no effect + on filtrations (see the `filtrations in the user guide + `_), + such as ``head()``, ``tail()``, ``nth()`` and in transformations + (see the `transformations in the user guide + `_). + sort : bool, default True + Sort group keys. Get better performance by turning this off. + Note this does not influence the order of observations within each + group. Groupby preserves the order of rows within each group. If False, + the groups will appear in the same order as they did in the original + DataFrame. + This argument has no effect on filtrations (see the `filtrations + in the user guide + `_), + such as ``head()``, ``tail()``, ``nth()`` and in transformations + (see the `transformations in the user guide + `_). + + .. versionchanged:: 2.0.0 + + Specifying ``sort=False`` with an ordered categorical grouper will no + longer sort the values. + + group_keys : bool, default True + When calling apply and the ``by`` argument produces a like-indexed + (i.e. :ref:`a transform `) result, add group keys to + index to identify pieces. By default group keys are not included + when the result's index (and column) labels match the inputs, and + are included otherwise. + + .. versionchanged:: 2.0.0 + + ``group_keys`` now defaults to ``True``. + + observed : bool, default True + This only applies if any of the groupers are Categoricals. + If True: only show observed values for categorical groupers. + If False: show all values for categorical groupers. + + .. versionchanged:: 3.0.0 + + The default value is now ``True``. + + dropna : bool, default True + If True, and if group keys contain NA values, NA values together + with row/column will be dropped. + If False, NA values will also be treated as the key in groups. + + Returns + ------- + pandas.api.typing.DataFrameGroupBy + Returns a groupby object that contains information about the groups. + + See Also + -------- + resample : Convenience method for frequency conversion and resampling + of time series. + + Notes + ----- + See the `user guide + `__ for more + detailed usage and examples, including splitting an object into groups, + iterating through groups, selecting a group, aggregation, and more. + + The implementation of groupby is hash-based, meaning in particular that + objects that compare as equal will be considered to be in the same group. + An exception to this is that pandas has special handling of NA values: + any NA values will be collapsed to a single group, regardless of how + they compare. See the user guide linked above for more details. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "Animal": ["Falcon", "Falcon", "Parrot", "Parrot"], + ... "Max Speed": [380.0, 370.0, 24.0, 26.0], + ... } + ... ) + >>> df + Animal Max Speed + 0 Falcon 380.0 + 1 Falcon 370.0 + 2 Parrot 24.0 + 3 Parrot 26.0 + >>> df.groupby(["Animal"]).mean() + Max Speed + Animal + Falcon 375.0 + Parrot 25.0 + + **Hierarchical Indexes** + + We can groupby different levels of a hierarchical index + using the `level` parameter: + + >>> arrays = [ + ... ["Falcon", "Falcon", "Parrot", "Parrot"], + ... ["Captive", "Wild", "Captive", "Wild"], + ... ] + >>> index = pd.MultiIndex.from_arrays(arrays, names=("Animal", "Type")) + >>> df = pd.DataFrame({"Max Speed": [390.0, 350.0, 30.0, 20.0]}, index=index) + >>> df + Max Speed + Animal Type + Falcon Captive 390.0 + Wild 350.0 + Parrot Captive 30.0 + Wild 20.0 + >>> df.groupby(level=0).mean() + Max Speed + Animal + Falcon 370.0 + Parrot 25.0 + >>> df.groupby(level="Type").mean() + Max Speed + Type + Captive 210.0 + Wild 185.0 + + We can also choose to include NA in group keys or not by setting + `dropna` parameter, the default setting is `True`. + + >>> arr = [[1, 2, 3], [1, None, 4], [2, 1, 3], [1, 2, 2]] + >>> df = pd.DataFrame(arr, columns=["a", "b", "c"]) + + >>> df.groupby(by=["b"]).sum() + a c + b + 1.0 2 3 + 2.0 2 5 + + >>> df.groupby(by=["b"], dropna=False).sum() + a c + b + 1.0 2 3 + 2.0 2 5 + NaN 1 4 + + >>> arr = [["a", 12, 12], [None, 12.3, 33.0], ["b", 12.3, 123], ["a", 1, 1]] + >>> df = pd.DataFrame(arr, columns=["a", "b", "c"]) + + >>> df.groupby(by="a").sum() + b c + a + a 13.0 13.0 + b 12.3 123.0 + + >>> df.groupby(by="a", dropna=False).sum() + b c + a + a 13.0 13.0 + b 12.3 123.0 + NaN 12.3 33.0 + + When using ``.apply()``, use ``group_keys`` to include or exclude the + group keys. The ``group_keys`` argument defaults to ``True`` (include). + + >>> df = pd.DataFrame( + ... { + ... "Animal": ["Falcon", "Falcon", "Parrot", "Parrot"], + ... "Max Speed": [380.0, 370.0, 24.0, 26.0], + ... } + ... ) + >>> df.groupby("Animal", group_keys=True)[["Max Speed"]].apply(lambda x: x) + Max Speed + Animal + Falcon 0 380.0 + 1 370.0 + Parrot 2 24.0 + 3 26.0 + + >>> df.groupby("Animal", group_keys=False)[["Max Speed"]].apply(lambda x: x) + Max Speed + 0 380.0 + 1 370.0 + 2 24.0 + 3 26.0 + """ + from pandas.core.groupby.generic import DataFrameGroupBy + + if level is None and by is None: + raise TypeError("You have to supply one of 'by' and 'level'") + + return DataFrameGroupBy( + obj=self, + keys=by, + level=level, + as_index=as_index, + sort=sort, + group_keys=group_keys, + observed=observed, + dropna=dropna, + ) + + _shared_docs["pivot"] = """ + Return reshaped DataFrame organized by given index / column values. + + Reshape data (produce a "pivot" table) based on column values. Uses + unique values from specified `index` / `columns` to form axes of the + resulting DataFrame. This function does not support data + aggregation, multiple values will result in a MultiIndex in the + columns. See the :ref:`User Guide ` for more on reshaping. + + Parameters + ----------%s + columns : Hashable or a sequence of the previous + Column to use to make new frame's columns. + index : Hashable or a sequence of the previous, optional + Column to use to make new frame's index. If not given, uses existing index. + values : Hashable or a sequence of the previous, optional + Column(s) to use for populating new frame's values. If not + specified, all remaining columns will be used and the result will + have hierarchically indexed columns. + + Returns + ------- + DataFrame + Returns reshaped DataFrame. + + Raises + ------ + ValueError: + When there are any `index`, `columns` combinations with multiple + values. `DataFrame.pivot_table` when you need to aggregate. + + See Also + -------- + DataFrame.pivot_table : Generalization of pivot that can handle + duplicate values for one index/column pair. + DataFrame.unstack : Pivot based on the index values instead of a + column. + wide_to_long : Wide panel to long format. Less flexible but more + user-friendly than melt. + + Notes + ----- + For finer-tuned control, see hierarchical indexing documentation along + with the related stack/unstack methods. + + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + >>> df = pd.DataFrame({'foo': ['one', 'one', 'one', 'two', 'two', + ... 'two'], + ... 'bar': ['A', 'B', 'C', 'A', 'B', 'C'], + ... 'baz': [1, 2, 3, 4, 5, 6], + ... 'zoo': ['x', 'y', 'z', 'q', 'w', 't']}) + >>> df + foo bar baz zoo + 0 one A 1 x + 1 one B 2 y + 2 one C 3 z + 3 two A 4 q + 4 two B 5 w + 5 two C 6 t + + >>> df.pivot(index='foo', columns='bar', values='baz') + bar A B C + foo + one 1 2 3 + two 4 5 6 + + >>> df.pivot(index='foo', columns='bar')['baz'] + bar A B C + foo + one 1 2 3 + two 4 5 6 + + >>> df.pivot(index='foo', columns='bar', values=['baz', 'zoo']) + baz zoo + bar A B C A B C + foo + one 1 2 3 x y z + two 4 5 6 q w t + + You could also assign a list of column names or a list of index names. + + >>> df = pd.DataFrame({ + ... "lev1": [1, 1, 1, 2, 2, 2], + ... "lev2": [1, 1, 2, 1, 1, 2], + ... "lev3": [1, 2, 1, 2, 1, 2], + ... "lev4": [1, 2, 3, 4, 5, 6], + ... "values": [0, 1, 2, 3, 4, 5]}) + >>> df + lev1 lev2 lev3 lev4 values + 0 1 1 1 1 0 + 1 1 1 2 2 1 + 2 1 2 1 3 2 + 3 2 1 2 4 3 + 4 2 1 1 5 4 + 5 2 2 2 6 5 + + >>> df.pivot(index="lev1", columns=["lev2", "lev3"], values="values") + lev2 1 2 + lev3 1 2 1 2 + lev1 + 1 0.0 1.0 2.0 NaN + 2 4.0 3.0 NaN 5.0 + + >>> df.pivot(index=["lev1", "lev2"], columns=["lev3"], values="values") + lev3 1 2 + lev1 lev2 + 1 1 0.0 1.0 + 2 2.0 NaN + 2 1 4.0 3.0 + 2 NaN 5.0 + + A ValueError is raised if there are any duplicates. + + >>> df = pd.DataFrame({"foo": ['one', 'one', 'two', 'two'], + ... "bar": ['A', 'A', 'B', 'C'], + ... "baz": [1, 2, 3, 4]}) + >>> df + foo bar baz + 0 one A 1 + 1 one A 2 + 2 two B 3 + 3 two C 4 + + Notice that the first two rows are the same for our `index` + and `columns` arguments. + + >>> df.pivot(index='foo', columns='bar', values='baz') + Traceback (most recent call last): + ... + ValueError: Index contains duplicate entries, cannot reshape + """ + + @Substitution("") + @Appender(_shared_docs["pivot"]) + def pivot( + self, *, columns, index=lib.no_default, values=lib.no_default + ) -> DataFrame: + from pandas.core.reshape.pivot import pivot + + return pivot(self, index=index, columns=columns, values=values) + + _shared_docs["pivot_table"] = """ + Create a spreadsheet-style pivot table as a DataFrame. + + The levels in the pivot table will be stored in MultiIndex objects + (hierarchical indexes) on the index and columns of the result DataFrame. + + Parameters + ----------%s + values : list-like or scalar, optional + Column or columns to aggregate. + index : column, Grouper, array, or sequence of the previous + Keys to group by on the pivot table index. If a list is passed, + it can contain any of the other types (except list). If an array is + passed, it must be the same length as the data and will be used in + the same manner as column values. + columns : column, Grouper, array, or sequence of the previous + Keys to group by on the pivot table column. If a list is passed, + it can contain any of the other types (except list). If an array is + passed, it must be the same length as the data and will be used in + the same manner as column values. + aggfunc : function, list of functions, dict, default "mean" + If a list of functions is passed, the resulting pivot table will have + hierarchical columns whose top level are the function names + (inferred from the function objects themselves). + If a dict is passed, the key is column to aggregate and the value is + function or list of functions. If ``margin=True``, aggfunc will be + used to calculate the partial aggregates. + fill_value : scalar, default None + Value to replace missing values with (in the resulting pivot table, + after aggregation). + margins : bool, default False + If ``margins=True``, special ``All`` columns and rows + will be added with partial group aggregates across the categories + on the rows and columns. + dropna : bool, default True + Do not include columns whose entries are all NaN. If True, + + * rows with an NA value in any column will be omitted before computing + margins, + * index/column keys containing NA values will be dropped (see ``dropna`` + parameter in :meth:`DataFrame.groupby`). + + margins_name : str, default 'All' + Name of the row / column that will contain the totals + when margins is True. + observed : bool, default False + This only applies if any of the groupers are Categoricals. + If True: only show observed values for categorical groupers. + If False: show all values for categorical groupers. + + .. versionchanged:: 3.0.0 + + The default value is now ``True``. + + sort : bool, default True + Specifies if the result should be sorted. + + **kwargs : dict + Optional keyword arguments to pass to ``aggfunc``. + + Returns + ------- + DataFrame + An Excel style pivot table. + + See Also + -------- + DataFrame.pivot : Pivot without aggregation that can handle + non-numeric data. + DataFrame.melt: Unpivot a DataFrame from wide to long format, + optionally leaving identifiers set. + wide_to_long : Wide panel to long format. Less flexible but more + user-friendly than melt. + + Notes + ----- + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + >>> df = pd.DataFrame({"A": ["foo", "foo", "foo", "foo", "foo", + ... "bar", "bar", "bar", "bar"], + ... "B": ["one", "one", "one", "two", "two", + ... "one", "one", "two", "two"], + ... "C": ["small", "large", "large", "small", + ... "small", "large", "small", "small", + ... "large"], + ... "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + ... "E": [2, 4, 5, 5, 6, 6, 8, 9, 9]}) + >>> df + A B C D E + 0 foo one small 1 2 + 1 foo one large 2 4 + 2 foo one large 2 5 + 3 foo two small 3 5 + 4 foo two small 3 6 + 5 bar one large 4 6 + 6 bar one small 5 8 + 7 bar two small 6 9 + 8 bar two large 7 9 + + This first example aggregates values by taking the sum. + + >>> table = pd.pivot_table(df, values='D', index=['A', 'B'], + ... columns=['C'], aggfunc="sum") + >>> table + C large small + A B + bar one 4.0 5.0 + two 7.0 6.0 + foo one 4.0 1.0 + two NaN 6.0 + + We can also fill missing values using the `fill_value` parameter. + + >>> table = pd.pivot_table(df, values='D', index=['A', 'B'], + ... columns=['C'], aggfunc="sum", fill_value=0) + >>> table + C large small + A B + bar one 4 5 + two 7 6 + foo one 4 1 + two 0 6 + + The next example aggregates by taking the mean across multiple columns. + + >>> table = pd.pivot_table(df, values=['D', 'E'], index=['A', 'C'], + ... aggfunc={'D': "mean", 'E': "mean"}) + >>> table + D E + A C + bar large 5.500000 7.500000 + small 5.500000 8.500000 + foo large 2.000000 4.500000 + small 2.333333 4.333333 + + We can also calculate multiple types of aggregations for any given + value column. + + >>> table = pd.pivot_table(df, values=['D', 'E'], index=['A', 'C'], + ... aggfunc={'D': "mean", + ... 'E': ["min", "max", "mean"]}) + >>> table + D E + mean max mean min + A C + bar large 5.500000 9 7.500000 6 + small 5.500000 9 8.500000 8 + foo large 2.000000 5 4.500000 4 + small 2.333333 6 4.333333 2 + """ + + @Substitution("") + @Appender(_shared_docs["pivot_table"]) + def pivot_table( + self, + values=None, + index=None, + columns=None, + aggfunc: AggFuncType = "mean", + fill_value=None, + margins: bool = False, + dropna: bool = True, + margins_name: Level = "All", + observed: bool = True, + sort: bool = True, + **kwargs, + ) -> DataFrame: + from pandas.core.reshape.pivot import pivot_table + + return pivot_table( + self, + values=values, + index=index, + columns=columns, + aggfunc=aggfunc, + fill_value=fill_value, + margins=margins, + dropna=dropna, + margins_name=margins_name, + observed=observed, + sort=sort, + **kwargs, + ) + + def stack( + self, + level: IndexLabel = -1, + dropna: bool | lib.NoDefault = lib.no_default, + sort: bool | lib.NoDefault = lib.no_default, + future_stack: bool = True, + ): + """ + Stack the prescribed level(s) from columns to index. + + Return a reshaped DataFrame or Series having a multi-level + index with one or more new inner-most levels compared to the current + DataFrame. The new inner-most levels are created by pivoting the + columns of the current dataframe: + + - if the columns have a single level, the output is a Series; + - if the columns have multiple levels, the new index level(s) is (are) + taken from the prescribed level(s) and the output is a DataFrame. + + Parameters + ---------- + level : int, str, list, default -1 + Level(s) to stack from the column axis onto the index + axis, defined as one index or label, or a list of indices + or labels. + dropna : bool, default True + Whether to drop rows in the resulting Frame/Series with + missing values. Stacking a column level onto the index + axis can create combinations of index and column values + that are missing from the original dataframe. See Examples + section. + sort : bool, default True + Whether to sort the levels of the resulting MultiIndex. + future_stack : bool, default True + Whether to use the new implementation that will replace the current + implementation in pandas 3.0. When True, dropna and sort have no impact + on the result and must remain unspecified. See :ref:`pandas 2.1.0 Release + notes ` for more details. + + Returns + ------- + DataFrame or Series + Stacked dataframe or series. + + See Also + -------- + DataFrame.unstack : Unstack prescribed level(s) from index axis + onto column axis. + DataFrame.pivot : Reshape dataframe from long format to wide + format. + DataFrame.pivot_table : Create a spreadsheet-style pivot table + as a DataFrame. + + Notes + ----- + The function is named by analogy with a collection of books + being reorganized from being side by side on a horizontal + position (the columns of the dataframe) to being stacked + vertically on top of each other (in the index of the + dataframe). + + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + **Single level columns** + + >>> df_single_level_cols = pd.DataFrame( + ... [[0, 1], [2, 3]], index=["cat", "dog"], columns=["weight", "height"] + ... ) + + Stacking a dataframe with a single level column axis returns a Series: + + >>> df_single_level_cols + weight height + cat 0 1 + dog 2 3 + >>> df_single_level_cols.stack() + cat weight 0 + height 1 + dog weight 2 + height 3 + dtype: int64 + + **Multi level columns: simple case** + + >>> multicol1 = pd.MultiIndex.from_tuples( + ... [("weight", "kg"), ("weight", "pounds")] + ... ) + >>> df_multi_level_cols1 = pd.DataFrame( + ... [[1, 2], [2, 4]], index=["cat", "dog"], columns=multicol1 + ... ) + + Stacking a dataframe with a multi-level column axis: + + >>> df_multi_level_cols1 + weight + kg pounds + cat 1 2 + dog 2 4 + >>> df_multi_level_cols1.stack() + weight + cat kg 1 + pounds 2 + dog kg 2 + pounds 4 + + **Missing values** + + >>> multicol2 = pd.MultiIndex.from_tuples([("weight", "kg"), ("height", "m")]) + >>> df_multi_level_cols2 = pd.DataFrame( + ... [[1.0, 2.0], [3.0, 4.0]], index=["cat", "dog"], columns=multicol2 + ... ) + + It is common to have missing values when stacking a dataframe + with multi-level columns, as the stacked dataframe typically + has more values than the original dataframe. Missing values + are filled with NaNs: + + >>> df_multi_level_cols2 + weight height + kg m + cat 1.0 2.0 + dog 3.0 4.0 + >>> df_multi_level_cols2.stack() + weight height + cat kg 1.0 NaN + m NaN 2.0 + dog kg 3.0 NaN + m NaN 4.0 + + **Prescribing the level(s) to be stacked** + + The first parameter controls which level or levels are stacked: + + >>> df_multi_level_cols2.stack(0) + kg m + cat weight 1.0 NaN + height NaN 2.0 + dog weight 3.0 NaN + height NaN 4.0 + >>> df_multi_level_cols2.stack([0, 1]) + cat weight kg 1.0 + height m 2.0 + dog weight kg 3.0 + height m 4.0 + dtype: float64 + """ + if not future_stack: + from pandas.core.reshape.reshape import ( + stack, + stack_multiple, + ) + + warnings.warn( + "The previous implementation of stack is deprecated and will be " + "removed in a future version of pandas. See the What's New notes " + "for pandas 2.1.0 for details. Do not specify the future_stack " + "argument to adopt the new implementation and silence this warning.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + + if dropna is lib.no_default: + dropna = True + if sort is lib.no_default: + sort = True + + if isinstance(level, (tuple, list)): + result = stack_multiple(self, level, dropna=dropna, sort=sort) + else: + result = stack(self, level, dropna=dropna, sort=sort) + else: + from pandas.core.reshape.reshape import stack_v3 + + if dropna is not lib.no_default: + raise ValueError( + "dropna must be unspecified as the new " + "implementation does not introduce rows of NA values. This " + "argument will be removed in a future version of pandas." + ) + + if sort is not lib.no_default: + raise ValueError( + "Cannot specify sort, this argument will be " + "removed in a future version of pandas. Sort the result using " + ".sort_index instead." + ) + + if ( + isinstance(level, (tuple, list)) + and not all(lev in self.columns.names for lev in level) + and not all(isinstance(lev, int) for lev in level) + ): + raise ValueError( + "level should contain all level names or all level " + "numbers, not a mixture of the two." + ) + + if not isinstance(level, (tuple, list)): + level = [level] + level = [self.columns._get_level_number(lev) for lev in level] + result = stack_v3(self, level) + + return result.__finalize__(self, method="stack") + + def explode( + self, + column: IndexLabel, + ignore_index: bool = False, + ) -> DataFrame: + """ + Transform each element of a list-like to a row, replicating index values. + + Parameters + ---------- + column : IndexLabel + Column(s) to explode. + For multiple columns, specify a non-empty list with each element + be str or tuple, and all specified columns their list-like data + on same row of the frame must have matching length. + + ignore_index : bool, default False + If True, the resulting index will be labeled 0, 1, …, n - 1. + + Returns + ------- + DataFrame + Exploded lists to rows of the subset columns; + index will be duplicated for these rows. + + Raises + ------ + ValueError : + * If columns of the frame are not unique. + * If specified columns to explode is empty list. + * If specified columns to explode have not matching count of + elements rowwise in the frame. + + See Also + -------- + DataFrame.unstack : Pivot a level of the (necessarily hierarchical) + index labels. + DataFrame.melt : Unpivot a DataFrame from wide format to long format. + Series.explode : Explode a DataFrame from list-like columns to long format. + + Notes + ----- + This routine will explode list-likes including lists, tuples, sets, + Series, and np.ndarray. The result dtype of the subset rows will + be object. Scalars will be returned unchanged, and empty list-likes will + result in a np.nan for that row. In addition, the ordering of rows in the + output will be non-deterministic when exploding sets. + + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "A": [[0, 1, 2], "foo", [], [3, 4]], + ... "B": 1, + ... "C": [["a", "b", "c"], np.nan, [], ["d", "e"]], + ... } + ... ) + >>> df + A B C + 0 [0, 1, 2] 1 [a, b, c] + 1 foo 1 NaN + 2 [] 1 [] + 3 [3, 4] 1 [d, e] + + Single-column explode. + + >>> df.explode("A") + A B C + 0 0 1 [a, b, c] + 0 1 1 [a, b, c] + 0 2 1 [a, b, c] + 1 foo 1 NaN + 2 NaN 1 [] + 3 3 1 [d, e] + 3 4 1 [d, e] + + Multi-column explode. + + >>> df.explode(list("AC")) + A B C + 0 0 1 a + 0 1 1 b + 0 2 1 c + 1 foo 1 NaN + 2 NaN 1 NaN + 3 3 1 d + 3 4 1 e + """ + if not self.columns.is_unique: + duplicate_cols = self.columns[self.columns.duplicated()].tolist() + raise ValueError( + f"DataFrame columns must be unique. Duplicate columns: {duplicate_cols}" + ) + + columns: list[Hashable] + if is_scalar(column) or isinstance(column, tuple): + columns = [column] + elif isinstance(column, list) and all( + is_scalar(c) or isinstance(c, tuple) for c in column + ): + if not column: + raise ValueError("column must be nonempty") + if len(column) > len(set(column)): + raise ValueError("column must be unique") + columns = column + else: + raise ValueError("column must be a scalar, tuple, or list thereof") + + df = self.reset_index(drop=True) + if len(columns) == 1: + result = df[columns[0]].explode() + else: + mylen = lambda x: len(x) if (is_list_like(x) and len(x) > 0) else 1 + counts0 = self[columns[0]].apply(mylen) + for c in columns[1:]: + if not all(counts0 == self[c].apply(mylen)): + raise ValueError("columns must have matching element counts") + result = DataFrame({c: df[c].explode() for c in columns}) + result = df.drop(columns, axis=1).join(result) + if ignore_index: + result.index = default_index(len(result)) + else: + result.index = self.index.take(result.index) + result = result.reindex(columns=self.columns) + + return result.__finalize__(self, method="explode") + + def unstack( + self, level: IndexLabel = -1, fill_value=None, sort: bool = True + ) -> DataFrame | Series: + """ + Pivot a level of the (necessarily hierarchical) index labels. + + Returns a DataFrame having a new level of column labels whose inner-most level + consists of the pivoted index labels. + + If the index is not a MultiIndex, the output will be a Series + (the analogue of stack when the columns are not a MultiIndex). + + Parameters + ---------- + level : int, str, or list of these, default -1 (last level) + Level(s) of index to unstack, can pass level name. + fill_value : scalar + Replace NaN with this value if the unstack produces missing values. + sort : bool, default True + Sort the level(s) in the resulting MultiIndex columns. + + Returns + ------- + Series or DataFrame + If index is a MultiIndex: DataFrame with pivoted index labels as new + inner-most level column labels, else Series. + + See Also + -------- + DataFrame.pivot : Pivot a table based on column values. + DataFrame.stack : Pivot a level of the column labels (inverse operation + from `unstack`). + + Notes + ----- + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + >>> index = pd.MultiIndex.from_tuples( + ... [("one", "a"), ("one", "b"), ("two", "a"), ("two", "b")] + ... ) + >>> s = pd.Series(np.arange(1.0, 5.0), index=index) + >>> s + one a 1.0 + b 2.0 + two a 3.0 + b 4.0 + dtype: float64 + + >>> s.unstack(level=-1) + a b + one 1.0 2.0 + two 3.0 4.0 + + >>> s.unstack(level=0) + one two + a 1.0 3.0 + b 2.0 4.0 + + >>> df = s.unstack(level=0) + >>> df.unstack() + one a 1.0 + b 2.0 + two a 3.0 + b 4.0 + dtype: float64 + """ + from pandas.core.reshape.reshape import unstack + + result = unstack(self, level, fill_value, sort) + + return result.__finalize__(self, method="unstack") + + def melt( + self, + id_vars=None, + value_vars=None, + var_name=None, + value_name: Hashable = "value", + col_level: Level | None = None, + ignore_index: bool = True, + ) -> DataFrame: + """ + Unpivot DataFrame from wide to long format, optionally leaving identifiers set. + + This function is useful to massage a DataFrame into a format where one + or more columns are identifier variables (`id_vars`), while all other + columns, considered measured variables (`value_vars`), are "unpivoted" to + the row axis, leaving just two non-identifier columns, 'variable' and + 'value'. + + Parameters + ---------- + id_vars : scalar, tuple, list, or ndarray, optional + Column(s) to use as identifier variables. + value_vars : scalar, tuple, list, or ndarray, optional + Column(s) to unpivot. If not specified, uses all columns that + are not set as `id_vars`. + var_name : scalar, default None + Name to use for the 'variable' column. If None it uses + ``frame.columns.name`` or 'variable'. + value_name : scalar, default 'value' + Name to use for the 'value' column, can't be an existing column label. + col_level : scalar, optional + If columns are a MultiIndex then use this level to melt. + ignore_index : bool, default True + If True, original index is ignored. If False, original index is retained. + Index labels will be repeated as necessary. + + Returns + ------- + DataFrame + Unpivoted DataFrame. + + See Also + -------- + melt : Identical method. + pivot_table : Create a spreadsheet-style pivot table as a DataFrame. + DataFrame.pivot : Return reshaped DataFrame organized + by given index / column values. + DataFrame.explode : Explode a DataFrame from list-like + columns to long format. + + Notes + ----- + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "A": {0: "a", 1: "b", 2: "c"}, + ... "B": {0: 1, 1: 3, 2: 5}, + ... "C": {0: 2, 1: 4, 2: 6}, + ... } + ... ) + >>> df + A B C + 0 a 1 2 + 1 b 3 4 + 2 c 5 6 + + >>> df.melt(id_vars=["A"], value_vars=["B"]) + A variable value + 0 a B 1 + 1 b B 3 + 2 c B 5 + + >>> df.melt(id_vars=["A"], value_vars=["B", "C"]) + A variable value + 0 a B 1 + 1 b B 3 + 2 c B 5 + 3 a C 2 + 4 b C 4 + 5 c C 6 + + The names of 'variable' and 'value' columns can be customized: + + >>> df.melt( + ... id_vars=["A"], + ... value_vars=["B"], + ... var_name="myVarname", + ... value_name="myValname", + ... ) + A myVarname myValname + 0 a B 1 + 1 b B 3 + 2 c B 5 + + Original index values can be kept around: + + >>> df.melt(id_vars=["A"], value_vars=["B", "C"], ignore_index=False) + A variable value + 0 a B 1 + 1 b B 3 + 2 c B 5 + 0 a C 2 + 1 b C 4 + 2 c C 6 + + If you have multi-index columns: + + >>> df.columns = [list("ABC"), list("DEF")] + >>> df + A B C + D E F + 0 a 1 2 + 1 b 3 4 + 2 c 5 6 + + >>> df.melt(col_level=0, id_vars=["A"], value_vars=["B"]) + A variable value + 0 a B 1 + 1 b B 3 + 2 c B 5 + + >>> df.melt(id_vars=[("A", "D")], value_vars=[("B", "E")]) + (A, D) variable_0 variable_1 value + 0 a B E 1 + 1 b B E 3 + 2 c B E 5 + """ + return melt( + self, + id_vars=id_vars, + value_vars=value_vars, + var_name=var_name, + value_name=value_name, + col_level=col_level, + ignore_index=ignore_index, + ).__finalize__(self, method="melt") + + # ---------------------------------------------------------------------- + # Time series-related + + def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame: + """ + First discrete difference of element. + + Calculates the difference of a DataFrame element compared with another + element in the DataFrame (default is element in previous row). + + Parameters + ---------- + periods : int, default 1 + Periods to shift for calculating difference, accepts negative + values. + axis : {0 or 'index', 1 or 'columns'}, default 0 + Take difference over rows (0) or columns (1). + + Returns + ------- + DataFrame + First differences of the Series. + + See Also + -------- + DataFrame.pct_change: Percent change over given number of periods. + DataFrame.shift: Shift index by desired number of periods with an + optional time freq. + Series.diff: First discrete difference of object. + + Notes + ----- + For boolean dtypes, this uses :meth:`operator.xor` rather than + :meth:`operator.sub`. + The result is calculated according to current dtype in DataFrame, + however dtype of the result is always float64. + + Examples + -------- + + Difference with previous row + + >>> df = pd.DataFrame( + ... { + ... "a": [1, 2, 3, 4, 5, 6], + ... "b": [1, 1, 2, 3, 5, 8], + ... "c": [1, 4, 9, 16, 25, 36], + ... } + ... ) + >>> df + a b c + 0 1 1 1 + 1 2 1 4 + 2 3 2 9 + 3 4 3 16 + 4 5 5 25 + 5 6 8 36 + >>> df.diff() + a b c + 0 NaN NaN NaN + 1 1.0 0.0 3.0 + 2 1.0 1.0 5.0 + 3 1.0 1.0 7.0 + 4 1.0 2.0 9.0 + 5 1.0 3.0 11.0 + + Difference with previous column + + >>> df.diff(axis=1) + a b c + 0 NaN 0 0 + 1 NaN -1 3 + 2 NaN -1 7 + 3 NaN -1 13 + 4 NaN 0 20 + 5 NaN 2 28 + + Difference with 3rd previous row + + >>> df.diff(periods=3) + a b c + 0 NaN NaN NaN + 1 NaN NaN NaN + 2 NaN NaN NaN + 3 3.0 2.0 15.0 + 4 3.0 4.0 21.0 + 5 3.0 6.0 27.0 + + Difference with following row + + >>> df.diff(periods=-1) + a b c + 0 -1.0 0.0 -3.0 + 1 -1.0 -1.0 -5.0 + 2 -1.0 -1.0 -7.0 + 3 -1.0 -2.0 -9.0 + 4 -1.0 -3.0 -11.0 + 5 NaN NaN NaN + + Overflow in input dtype + + >>> df = pd.DataFrame({"a": [1, 0]}, dtype=np.uint8) + >>> df.diff() + a + 0 NaN + 1 255.0 + """ + if not lib.is_integer(periods): + if not (is_float(periods) and periods.is_integer()): + raise ValueError("periods must be an integer") + periods = int(periods) + + axis = self._get_axis_number(axis) + if axis == 1: + if periods != 0: + # in the periods == 0 case, this is equivalent diff of 0 periods + # along axis=0, and the Manager method may be somewhat more + # performant, so we dispatch in that case. + return self - self.shift(periods, axis=axis) + # With periods=0 this is equivalent to a diff with axis=0 + axis = 0 + + new_data = self._mgr.diff(n=periods) + res_df = self._constructor_from_mgr(new_data, axes=new_data.axes) + return res_df.__finalize__(self, "diff") + + # ---------------------------------------------------------------------- + # Function application + + def _gotitem( + self, + key: IndexLabel, + ndim: int, + subset: DataFrame | Series | None = None, + ) -> DataFrame | Series: + """ + Sub-classes to define. Return a sliced object. + + Parameters + ---------- + key : string / list of selections + ndim : {1, 2} + requested ndim of result + subset : object, default None + subset to act on + """ + if subset is None: + subset = self + elif subset.ndim == 1: # is Series + return subset + + # TODO: _shallow_copy(subset)? + return subset[key] + + _agg_see_also_doc = dedent( + """ + See Also + -------- + DataFrame.apply : Perform any type of operations. + DataFrame.transform : Perform transformation type operations. + DataFrame.groupby : Perform operations over groups. + DataFrame.resample : Perform operations over resampled bins. + DataFrame.rolling : Perform operations over rolling window. + DataFrame.expanding : Perform operations over expanding window. + core.window.ewm.ExponentialMovingWindow : Perform operation over exponential + weighted window. + """ + ) + + _agg_examples_doc = dedent( + """ + Examples + -------- + >>> df = pd.DataFrame([[1, 2, 3], + ... [4, 5, 6], + ... [7, 8, 9], + ... [np.nan, np.nan, np.nan]], + ... columns=['A', 'B', 'C']) + + Aggregate these functions over the rows. + + >>> df.agg(['sum', 'min']) + A B C + sum 12.0 15.0 18.0 + min 1.0 2.0 3.0 + + Different aggregations per column. + + >>> df.agg({'A' : ['sum', 'min'], 'B' : ['min', 'max']}) + A B + sum 12.0 NaN + min 1.0 2.0 + max NaN 8.0 + + Aggregate different functions over the columns and rename the index + of the resulting DataFrame. + + >>> df.agg(x=('A', 'max'), y=('B', 'min'), z=('C', 'mean')) + A B C + x 7.0 NaN NaN + y NaN 2.0 NaN + z NaN NaN 6.0 + + Aggregate over the columns. + + >>> df.agg("mean", axis="columns") + 0 2.0 + 1 5.0 + 2 8.0 + 3 NaN + dtype: float64 + """ + ) + + def aggregate(self, func=None, axis: Axis = 0, *args, **kwargs): + """ + Aggregate using one or more operations over the specified axis. + + Parameters + ---------- + func : function, str, list or dict + Function to use for aggregating the data. If a function, must either + work when passed a DataFrame or when passed to DataFrame.apply. + + Accepted combinations are: + + - function + - string function name + - list of functions and/or function names, e.g. ``[np.sum, 'mean']`` + - dict of axis labels -> functions, function names or list of such. + axis : {0 or 'index', 1 or 'columns'}, default 0 + If 0 or 'index': apply function to each column. + If 1 or 'columns': apply function to each row. + *args + Positional arguments to pass to `func`. + **kwargs + Keyword arguments to pass to `func`. + + Returns + ------- + scalar, Series or DataFrame + + The return can be: + + * scalar : when Series.agg is called with single function + * Series : when DataFrame.agg is called with a single function + * DataFrame : when DataFrame.agg is called with several functions + + See Also + -------- + DataFrame.apply : Perform any type of operations. + DataFrame.transform : Perform transformation type operations. + DataFrame.groupby : Perform operations over groups. + DataFrame.resample : Perform operations over resampled bins. + DataFrame.rolling : Perform operations over rolling window. + DataFrame.expanding : Perform operations over expanding window. + core.window.ewm.ExponentialMovingWindow : Perform operation over exponential + weighted window. + + Notes + ----- + The aggregation operations are always performed over an axis, either the + index (default) or the column axis. This behavior is different from + `numpy` aggregation functions (`mean`, `median`, `prod`, `sum`, `std`, + `var`), where the default is to compute the aggregation of the flattened + array, e.g., ``numpy.mean(arr_2d)`` as opposed to + ``numpy.mean(arr_2d, axis=0)``. + + `agg` is an alias for `aggregate`. Use the alias. + + Functions that mutate the passed object can produce unexpected + behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` + for more details. + + A passed user-defined-function will be passed a Series for evaluation. + + If ``func`` defines an index relabeling, ``axis`` must be ``0`` or ``index``. + + Examples + -------- + >>> df = pd.DataFrame( + ... [[1, 2, 3], [4, 5, 6], [7, 8, 9], [np.nan, np.nan, np.nan]], + ... columns=["A", "B", "C"], + ... ) + + Aggregate these functions over the rows. + + >>> df.agg(["sum", "min"]) + A B C + sum 12.0 15.0 18.0 + min 1.0 2.0 3.0 + + Different aggregations per column. + + >>> df.agg({"A": ["sum", "min"], "B": ["min", "max"]}) + A B + sum 12.0 NaN + min 1.0 2.0 + max NaN 8.0 + + Aggregate different functions over the columns and rename the index of + the resulting DataFrame. + + >>> df.agg(x=("A", "max"), y=("B", "min"), z=("C", "mean")) + A B C + x 7.0 NaN NaN + y NaN 2.0 NaN + z NaN NaN 6.0 + + Aggregate over the columns. + + >>> df.agg("mean", axis="columns") + 0 2.0 + 1 5.0 + 2 8.0 + 3 NaN + dtype: float64 + """ + from pandas.core.apply import frame_apply + + axis = self._get_axis_number(axis) + + op = frame_apply(self, func=func, axis=axis, args=args, kwargs=kwargs) + result = op.agg() + result = reconstruct_and_relabel_result(result, func, **kwargs) + return result + + agg = aggregate + + def transform( + self, func: AggFuncType, axis: Axis = 0, *args, **kwargs + ) -> DataFrame: + """ + Call ``func`` on self producing a DataFrame with the same axis shape as self. + + Parameters + ---------- + func : function, str, list-like or dict-like + Function to use for transforming the data. If a function, must either + work when passed a DataFrame or when passed to DataFrame.apply. If func + is both list-like and dict-like, dict-like behavior takes precedence. + + Accepted combinations are: + + - function + - string function name + - list-like of functions and/or function names, e.g. ``[np.exp, 'sqrt']`` + - dict-like of axis labels -> functions, function names or list-like + of such. + axis : {0 or 'index', 1 or 'columns'}, default 0 + If 0 or 'index': apply function to each column. + If 1 or 'columns': apply function to each row. + *args + Positional arguments to pass to `func`. + **kwargs + Keyword arguments to pass to `func`. + + Returns + ------- + DataFrame + A DataFrame that must have the same length as self. + + Raises + ------ + ValueError : If the returned DataFrame has a different length than self. + + See Also + -------- + DataFrame.agg : Only perform aggregating type operations. + DataFrame.apply : Invoke function on a DataFrame. + + Notes + ----- + Functions that mutate the passed object can produce unexpected + behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` + for more details. + + Examples + -------- + >>> df = pd.DataFrame({"A": range(3), "B": range(1, 4)}) + >>> df + A B + 0 0 1 + 1 1 2 + 2 2 3 + >>> df.transform(lambda x: x + 1) + A B + 0 1 2 + 1 2 3 + 2 3 4 + + Even though the resulting DataFrame must have the same length as the + input DataFrame, it is possible to provide several input functions: + + >>> s = pd.Series(range(3)) + >>> s + 0 0 + 1 1 + 2 2 + dtype: int64 + >>> s.transform([np.sqrt, np.exp]) + sqrt exp + 0 0.000000 1.000000 + 1 1.000000 2.718282 + 2 1.414214 7.389056 + + You can call transform on a GroupBy object: + + >>> df = pd.DataFrame( + ... { + ... "Date": [ + ... "2015-05-08", + ... "2015-05-07", + ... "2015-05-06", + ... "2015-05-05", + ... "2015-05-08", + ... "2015-05-07", + ... "2015-05-06", + ... "2015-05-05", + ... ], + ... "Data": [5, 8, 6, 1, 50, 100, 60, 120], + ... } + ... ) + >>> df + Date Data + 0 2015-05-08 5 + 1 2015-05-07 8 + 2 2015-05-06 6 + 3 2015-05-05 1 + 4 2015-05-08 50 + 5 2015-05-07 100 + 6 2015-05-06 60 + 7 2015-05-05 120 + >>> df.groupby("Date")["Data"].transform("sum") + 0 55 + 1 108 + 2 66 + 3 121 + 4 55 + 5 108 + 6 66 + 7 121 + Name: Data, dtype: int64 + + >>> df = pd.DataFrame( + ... { + ... "c": [1, 1, 1, 2, 2, 2, 2], + ... "type": ["m", "n", "o", "m", "m", "n", "n"], + ... } + ... ) + >>> df + c type + 0 1 m + 1 1 n + 2 1 o + 3 2 m + 4 2 m + 5 2 n + 6 2 n + >>> df["size"] = df.groupby("c")["type"].transform(len) + >>> df + c type size + 0 1 m 3 + 1 1 n 3 + 2 1 o 3 + 3 2 m 4 + 4 2 m 4 + 5 2 n 4 + 6 2 n 4 + """ + from pandas.core.apply import frame_apply + + op = frame_apply(self, func=func, axis=axis, args=args, kwargs=kwargs) + result = op.transform() + assert isinstance(result, DataFrame) + return result + + def apply( + self, + func: AggFuncType, + axis: Axis = 0, + raw: bool = False, + result_type: Literal["expand", "reduce", "broadcast"] | None = None, + args=(), + by_row: Literal[False, "compat"] = "compat", + engine: Callable | None | Literal["python", "numba"] = None, + engine_kwargs: dict[str, bool] | None = None, + **kwargs, + ): + """ + Apply a function along an axis of the DataFrame. + + Objects passed to the function are Series objects whose index is + either the DataFrame's index (``axis=0``) or the DataFrame's columns + (``axis=1``). By default (``result_type=None``), the final return type + is inferred from the return type of the applied function. Otherwise, + it depends on the `result_type` argument. The return type of the applied + function is inferred based on the first computed result obtained after + applying the function to a Series object. + + Parameters + ---------- + func : function + Function to apply to each column or row. + axis : {0 or 'index', 1 or 'columns'}, default 0 + Axis along which the function is applied: + + * 0 or 'index': apply function to each column. + * 1 or 'columns': apply function to each row. + + raw : bool, default False + Determines if row or column is passed as a Series or ndarray object: + + * ``False`` : passes each row or column as a Series to the + function. + * ``True`` : the passed function will receive ndarray objects + instead. + If you are just applying a NumPy reduction function this will + achieve much better performance. + + .. note:: + + When ``raw=True``, the result dtype is inferred from the **first** + returned value. + + result_type : {'expand', 'reduce', 'broadcast', None}, default None + These only act when ``axis=1`` (columns): + + * 'expand' : list-like results will be turned into columns. + * 'reduce' : returns a Series if possible rather than expanding + list-like results. This is the opposite of 'expand'. + * 'broadcast' : results will be broadcast to the original shape + of the DataFrame, the original index and columns will be + retained. + + The default behaviour (None) depends on the return value of the + applied function: list-like results will be returned as a Series + of those. However if the apply function returns a Series these + are expanded to columns. + args : tuple + Positional arguments to pass to `func` in addition to the + array/series. + by_row : False or "compat", default "compat" + Only has an effect when ``func`` is a listlike or dictlike of funcs + and the func isn't a string. + If "compat", will if possible first translate the func into pandas + methods (e.g. ``Series().apply(np.sum)`` will be translated to + ``Series().sum()``). If that doesn't work, will try call to apply again with + ``by_row=True`` and if that fails, will call apply again with + ``by_row=False`` (backward compatible). + If False, the funcs will be passed the whole Series at once. + + .. versionadded:: 2.1.0 + + engine : decorator or {'python', 'numba'}, optional + Choose the execution engine to use. If not provided the function + will be executed by the regular Python interpreter. + + Other options include JIT compilers such Numba and Bodo, which in some + cases can speed up the execution. To use an executor you can provide + the decorators ``numba.jit``, ``numba.njit`` or ``bodo.jit``. You can + also provide the decorator with parameters, like ``numba.jit(nogit=True)``. + + Not all functions can be executed with all execution engines. In general, + JIT compilers will require type stability in the function (no variable + should change data type during the execution). And not all pandas and + NumPy APIs are supported. Check the engine documentation [1]_ and [2]_ + for limitations. + + .. warning:: + + String parameters will stop being supported in a future pandas version. + + .. versionadded:: 2.2.0 + + engine_kwargs : dict + Pass keyword arguments to the engine. + This is currently only used by the numba engine, + see the documentation for the engine argument for more information. + + **kwargs + Additional keyword arguments to pass as keywords arguments to + `func`. + + Returns + ------- + Series or DataFrame + Result of applying ``func`` along the given axis of the + DataFrame. + + See Also + -------- + DataFrame.map: For elementwise operations. + DataFrame.aggregate: Only perform aggregating type operations. + DataFrame.transform: Only perform transforming type operations. + + Notes + ----- + Functions that mutate the passed object can produce unexpected + behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` + for more details. + + References + ---------- + .. [1] `Numba documentation + `_ + .. [2] `Bodo documentation + `/ + + Examples + -------- + >>> df = pd.DataFrame([[4, 9]] * 3, columns=["A", "B"]) + >>> df + A B + 0 4 9 + 1 4 9 + 2 4 9 + + Using a numpy universal function (in this case the same as + ``np.sqrt(df)``): + + >>> df.apply(np.sqrt) + A B + 0 2.0 3.0 + 1 2.0 3.0 + 2 2.0 3.0 + + Using a reducing function on either axis + + >>> df.apply(np.sum, axis=0) + A 12 + B 27 + dtype: int64 + + >>> df.apply(np.sum, axis=1) + 0 13 + 1 13 + 2 13 + dtype: int64 + + Returning a list-like will result in a Series + + >>> df.apply(lambda x: [1, 2], axis=1) + 0 [1, 2] + 1 [1, 2] + 2 [1, 2] + dtype: object + + Passing ``result_type='expand'`` will expand list-like results + to columns of a Dataframe + + >>> df.apply(lambda x: [1, 2], axis=1, result_type="expand") + 0 1 + 0 1 2 + 1 1 2 + 2 1 2 + + Returning a Series inside the function is similar to passing + ``result_type='expand'``. The resulting column names + will be the Series index. + + >>> df.apply(lambda x: pd.Series([1, 2], index=["foo", "bar"]), axis=1) + foo bar + 0 1 2 + 1 1 2 + 2 1 2 + + Passing ``result_type='broadcast'`` will ensure the same shape + result, whether list-like or scalar is returned by the function, + and broadcast it along the axis. The resulting column names will + be the originals. + + >>> df.apply(lambda x: [1, 2], axis=1, result_type="broadcast") + A B + 0 1 2 + 1 1 2 + 2 1 2 + + Advanced users can speed up their code by using a Just-in-time (JIT) compiler + with ``apply``. The main JIT compilers available for pandas are Numba and Bodo. + In general, JIT compilation is only possible when the function passed to + ``apply`` has type stability (variables in the function do not change their + type during the execution). + + >>> import bodo # doctest: +SKIP + >>> df.apply(lambda x: x.A + x.B, axis=1, engine=bodo.jit) # doctest: +SKIP + + Note that JIT compilation is only recommended for functions that take a + significant amount of time to run. Fast functions are unlikely to run faster + with JIT compilation. + """ + if engine is None or isinstance(engine, str): + from pandas.core.apply import frame_apply + + if engine is None: + engine = "python" + + if engine not in ["python", "numba"]: + raise ValueError(f"Unknown engine '{engine}'") + + op = frame_apply( + self, + func=func, + axis=axis, + raw=raw, + result_type=result_type, + by_row=by_row, + engine=engine, + engine_kwargs=engine_kwargs, + args=args, + kwargs=kwargs, + ) + return op.apply().__finalize__(self, method="apply") + elif hasattr(engine, "__pandas_udf__"): + if result_type is not None: + raise NotImplementedError( + f"{result_type=} only implemented for the default engine" + ) + + agg_axis = self._get_agg_axis(self._get_axis_number(axis)) + + # one axis is empty + if not all(self.shape): + func = cast(Callable, func) + try: + if axis == 0: + r = func(Series([], dtype=np.float64), *args, **kwargs) + else: + r = func( + Series(index=self.columns, dtype=np.float64), + *args, + **kwargs, + ) + except Exception: + pass + else: + if not isinstance(r, Series): + if len(agg_axis): + r = func(Series([], dtype=np.float64), *args, **kwargs) + else: + r = np.nan + + return self._constructor_sliced(r, index=agg_axis) + return self.copy() + + data: DataFrame | np.ndarray = self + if raw: + # This will upcast the whole DataFrame to the same type, + # and likely result in an object 2D array. + # We should probably pass a list of 1D arrays instead, at + # lest for ``axis=0`` + data = self.values + result = engine.__pandas_udf__.apply( + data=data, + func=func, + args=args, + kwargs=kwargs, + decorator=engine, + axis=axis, + ) + if raw: + if result.ndim == 2: + return self._constructor( + result, index=self.index, columns=self.columns + ) + else: + return self._constructor_sliced(result, index=agg_axis) + return result + else: + raise ValueError(f"Unknown engine {engine}") + + def map( + self, func: PythonFuncType, na_action: Literal["ignore"] | None = None, **kwargs + ) -> DataFrame: + """ + Apply a function to a Dataframe elementwise. + + .. versionadded:: 2.1.0 + + DataFrame.applymap was deprecated and renamed to DataFrame.map. + + This method applies a function that accepts and returns a scalar + to every element of a DataFrame. + + Parameters + ---------- + func : callable + Python function, returns a single value from a single value. + na_action : {None, 'ignore'}, default None + If 'ignore', propagate NaN values, without passing them to func. + **kwargs + Additional keyword arguments to pass as keywords arguments to + `func`. + + Returns + ------- + DataFrame + Transformed DataFrame. + + See Also + -------- + DataFrame.apply : Apply a function along input axis of DataFrame. + DataFrame.replace: Replace values given in `to_replace` with `value`. + Series.map : Apply a function elementwise on a Series. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2.12], [3.356, 4.567]]) + >>> df + 0 1 + 0 1.000 2.120 + 1 3.356 4.567 + + >>> df.map(lambda x: len(str(x))) + 0 1 + 0 3 4 + 1 5 5 + + Like Series.map, NA values can be ignored: + + >>> df_copy = df.copy() + >>> df_copy.iloc[0, 0] = pd.NA + >>> df_copy.map(lambda x: len(str(x)), na_action="ignore") + 0 1 + 0 NaN 4 + 1 5.0 5 + + It is also possible to use `map` with functions that are not + `lambda` functions: + + >>> df.map(round, ndigits=1) + 0 1 + 0 1.0 2.1 + 1 3.4 4.6 + + Note that a vectorized version of `func` often exists, which will + be much faster. You could square each number elementwise. + + >>> df.map(lambda x: x**2) + 0 1 + 0 1.000000 4.494400 + 1 11.262736 20.857489 + + But it's better to avoid map in that case. + + >>> df**2 + 0 1 + 0 1.000000 4.494400 + 1 11.262736 20.857489 + """ + if na_action not in {"ignore", None}: + raise ValueError(f"na_action must be 'ignore' or None. Got {na_action!r}") + + if self.empty: + return self.copy() + + func = functools.partial(func, **kwargs) + + def infer(x): + return x._map_values(func, na_action=na_action) + + return self.apply(infer).__finalize__(self, "map") + + # ---------------------------------------------------------------------- + # Merging / joining methods + + def _append_internal( + self, + other: Series, + ignore_index: bool = False, + ) -> DataFrame: + assert isinstance(other, Series), type(other) + + if other.name is None and not ignore_index: + raise TypeError( + "Can only append a Series if ignore_index=True " + "or if the Series has a name" + ) + + index = Index( + [other.name], + name=( + self.index.names + if isinstance(self.index, MultiIndex) + else self.index.name + ), + ) + + row_df = other.to_frame().T + if isinstance(self.index.dtype, ExtensionDtype): + # GH#41626 retain e.g. CategoricalDtype if reached via + # df.loc[key] = item + row_df.index = self.index.array._cast_pointwise_result(row_df.index._values) + + # infer_objects is needed for + # test_append_empty_frame_to_series_with_dateutil_tz + row_df = row_df.infer_objects().rename_axis(index.names) + + from pandas.core.reshape.concat import concat + + result = concat( + [self, row_df], + ignore_index=ignore_index, + ) + return result.__finalize__(self, method="append") + + def join( + self, + other: DataFrame | Series | Iterable[DataFrame | Series], + on: IndexLabel | None = None, + how: MergeHow = "left", + lsuffix: str = "", + rsuffix: str = "", + sort: bool = False, + validate: JoinValidate | None = None, + ) -> DataFrame: + """ + Join columns of another DataFrame. + + Join columns with `other` DataFrame either on index or on a key + column. Efficiently join multiple DataFrame objects by index at once by + passing a list. + + Parameters + ---------- + other : DataFrame, Series, or a list containing any combination of them + Index should be similar to one of the columns in this one. If a + Series is passed, its name attribute must be set, and that will be + used as the column name in the resulting joined DataFrame. + on : str, list of str, or array-like, optional + Column or index level name(s) in the caller to join on the index + in `other`, otherwise joins index-on-index. If multiple + values given, the `other` DataFrame must have a MultiIndex. Can + pass an array as the join key if it is not already contained in + the calling DataFrame. Like an Excel VLOOKUP operation. + how : {'left', 'right', 'outer', 'inner', 'cross', 'left_anti', 'right_anti'}, + default 'left' + How to handle the operation of the two objects. + + * left: use calling frame's index (or column if on is specified) + * right: use `other`'s index. + * outer: form union of calling frame's index (or column if on is + specified) with `other`'s index, and sort it lexicographically. + * inner: form intersection of calling frame's index (or column if + on is specified) with `other`'s index, preserving the order + of the calling's one. + * cross: creates the cartesian product from both frames, preserves the order + of the left keys. + * left_anti: use set difference of calling frame's index and `other`'s + index. + * right_anti: use set difference of `other`'s index and calling frame's + index. + lsuffix : str, default '' + Suffix to use from left frame's overlapping columns. + rsuffix : str, default '' + Suffix to use from right frame's overlapping columns. + sort : bool, default False + Order result DataFrame lexicographically by the join key. If False, + the order of the join key depends on the join type (how keyword). + validate : str, optional + If specified, checks if join is of specified type. + + * "one_to_one" or "1:1": check if join keys are unique in both left + and right datasets. + * "one_to_many" or "1:m": check if join keys are unique in left dataset. + * "many_to_one" or "m:1": check if join keys are unique in right dataset. + * "many_to_many" or "m:m": allowed, but does not result in checks. + + Returns + ------- + DataFrame + A dataframe containing columns from both the caller and `other`. + + See Also + -------- + DataFrame.merge : For column(s)-on-column(s) operations. + + Notes + ----- + Parameters `on`, `lsuffix`, and `rsuffix` are not supported when + passing a list of `DataFrame` objects. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "key": ["K0", "K1", "K2", "K3", "K4", "K5"], + ... "A": ["A0", "A1", "A2", "A3", "A4", "A5"], + ... } + ... ) + + >>> df + key A + 0 K0 A0 + 1 K1 A1 + 2 K2 A2 + 3 K3 A3 + 4 K4 A4 + 5 K5 A5 + + >>> other = pd.DataFrame({"key": ["K0", "K1", "K2"], "B": ["B0", "B1", "B2"]}) + + >>> other + key B + 0 K0 B0 + 1 K1 B1 + 2 K2 B2 + + Join DataFrames using their indexes. + + >>> df.join(other, lsuffix="_caller", rsuffix="_other") + key_caller A key_other B + 0 K0 A0 K0 B0 + 1 K1 A1 K1 B1 + 2 K2 A2 K2 B2 + 3 K3 A3 NaN NaN + 4 K4 A4 NaN NaN + 5 K5 A5 NaN NaN + + If we want to join using the key columns, we need to set key to be + the index in both `df` and `other`. The joined DataFrame will have + key as its index. + + >>> df.set_index("key").join(other.set_index("key")) + A B + key + K0 A0 B0 + K1 A1 B1 + K2 A2 B2 + K3 A3 NaN + K4 A4 NaN + K5 A5 NaN + + Another option to join using the key columns is to use the `on` + parameter. DataFrame.join always uses `other`'s index but we can use + any column in `df`. This method preserves the original DataFrame's + index in the result. + + >>> df.join(other.set_index("key"), on="key") + key A B + 0 K0 A0 B0 + 1 K1 A1 B1 + 2 K2 A2 B2 + 3 K3 A3 NaN + 4 K4 A4 NaN + 5 K5 A5 NaN + + Using non-unique key values shows how they are matched. + + >>> df = pd.DataFrame( + ... { + ... "key": ["K0", "K1", "K1", "K3", "K0", "K1"], + ... "A": ["A0", "A1", "A2", "A3", "A4", "A5"], + ... } + ... ) + + >>> df + key A + 0 K0 A0 + 1 K1 A1 + 2 K1 A2 + 3 K3 A3 + 4 K0 A4 + 5 K1 A5 + + >>> df.join(other.set_index("key"), on="key", validate="m:1") + key A B + 0 K0 A0 B0 + 1 K1 A1 B1 + 2 K1 A2 B1 + 3 K3 A3 NaN + 4 K0 A4 B0 + 5 K1 A5 B1 + """ + from pandas.core.reshape.concat import concat + from pandas.core.reshape.merge import merge + + if isinstance(other, Series): + if other.name is None: + raise ValueError("Other Series must have a name") + other = DataFrame({other.name: other}) + + if isinstance(other, DataFrame): + if how == "cross": + return merge( + self, + other, + how=how, + on=on, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + return merge( + self, + other, + left_on=on, + how=how, + left_index=on is None, + right_index=True, + suffixes=(lsuffix, rsuffix), + sort=sort, + validate=validate, + ) + else: + if on is not None: + raise ValueError( + "Joining multiple DataFrames only supported for joining on index" + ) + + if rsuffix or lsuffix: + raise ValueError( + "Suffixes not supported when joining multiple DataFrames" + ) + + # Mypy thinks the RHS is a + # "Union[DataFrame, Series, Iterable[Union[DataFrame, Series]]]" whereas + # the LHS is an "Iterable[DataFrame]", but in reality both types are + # "Iterable[Union[DataFrame, Series]]" due to the if statements + frames = [cast("DataFrame | Series", self), *list(other)] + + can_concat = all(df.index.is_unique for df in frames) + + # join indexes only using concat + if can_concat: + if how in {"left", "right"}: + res = concat( + frames, axis=1, join="outer", verify_integrity=True, sort=sort + ) + index = self.index if how == "left" else frames[-1].index + if sort: + index = index.sort_values() + result = res.reindex(index) + return result + else: + if how == "outer": + sort = True + return concat( + frames, axis=1, join=how, verify_integrity=True, sort=sort + ) + + joined = frames[0] + + for frame in frames[1:]: + joined = merge( + joined, + frame, + sort=sort, + how=how, + left_index=True, + right_index=True, + validate=validate, + ) + + return joined + + @Substitution("") + @Appender(_merge_doc, indents=2) + def merge( + self, + right: DataFrame | Series, + how: MergeHow = "inner", + on: IndexLabel | AnyArrayLike | None = None, + left_on: IndexLabel | AnyArrayLike | None = None, + right_on: IndexLabel | AnyArrayLike | None = None, + left_index: bool = False, + right_index: bool = False, + sort: bool = False, + suffixes: Suffixes = ("_x", "_y"), + copy: bool | lib.NoDefault = lib.no_default, + indicator: str | bool = False, + validate: MergeValidate | None = None, + ) -> DataFrame: + self._check_copy_deprecation(copy) + + from pandas.core.reshape.merge import merge + + return merge( + self, + right, + how=how, + on=on, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + sort=sort, + suffixes=suffixes, + indicator=indicator, + validate=validate, + ) + + def round( + self, decimals: int | dict[IndexLabel, int] | Series = 0, *args, **kwargs + ) -> DataFrame: + """ + Round numeric columns in a DataFrame to a variable number of decimal places. + + Parameters + ---------- + decimals : int, dict, Series + Number of decimal places to round each column to. If an int is + given, round each column to the same number of places. + Otherwise dict and Series round to variable numbers of places. + Column names should be in the keys if `decimals` is a + dict-like, or in the index if `decimals` is a Series. Any + columns not included in `decimals` will be left as is. Elements + of `decimals` which are not columns of the input will be + ignored. + *args + Additional keywords have no effect but might be accepted for + compatibility with numpy. + **kwargs + Additional keywords have no effect but might be accepted for + compatibility with numpy. + + Returns + ------- + DataFrame + A DataFrame with the affected columns rounded to the specified + number of decimal places. + + See Also + -------- + numpy.around : Round a numpy array to the given number of decimals. + Series.round : Round a Series to the given number of decimals. + + Notes + ----- + For values exactly halfway between rounded decimal values, pandas rounds + to the nearest even value (e.g. -0.5 and 0.5 round to 0.0, 1.5 and 2.5 + round to 2.0, etc.). + + Examples + -------- + >>> df = pd.DataFrame( + ... [(0.21, 0.32), (0.01, 0.67), (0.66, 0.03), (0.21, 0.18)], + ... columns=["dogs", "cats"], + ... ) + >>> df + dogs cats + 0 0.21 0.32 + 1 0.01 0.67 + 2 0.66 0.03 + 3 0.21 0.18 + + By providing an integer each column is rounded to the same number + of decimal places + + >>> df.round(1) + dogs cats + 0 0.2 0.3 + 1 0.0 0.7 + 2 0.7 0.0 + 3 0.2 0.2 + + With a dict, the number of places for specific columns can be + specified with the column names as key and the number of decimal + places as value + + >>> df.round({"dogs": 1, "cats": 0}) + dogs cats + 0 0.2 0.0 + 1 0.0 1.0 + 2 0.7 0.0 + 3 0.2 0.0 + + Using a Series, the number of places for specific columns can be + specified with the column names as index and the number of + decimal places as value + + >>> decimals = pd.Series([0, 1], index=["cats", "dogs"]) + >>> df.round(decimals) + dogs cats + 0 0.2 0.0 + 1 0.0 1.0 + 2 0.7 0.0 + 3 0.2 0.0 + """ + from pandas.core.reshape.concat import concat + + def _dict_round(df: DataFrame, decimals) -> Iterator[Series]: + for col, vals in df.items(): + try: + yield _series_round(vals, decimals[col]) + except KeyError: + yield vals + + def _series_round(ser: Series, decimals: int) -> Series: + if is_integer_dtype(ser.dtype) or is_float_dtype(ser.dtype): + return ser.round(decimals) + elif isinstance(ser._values, (DatetimeArray, TimedeltaArray, PeriodArray)): + # GH#57781 + # TODO: also the ArrowDtype analogues? + warnings.warn( + "obj.round has no effect with datetime, timedelta, " + "or period dtypes. Use obj.dt.round(...) instead.", + UserWarning, + stacklevel=find_stack_level(), + ) + return ser + + nv.validate_round(args, kwargs) + + if isinstance(decimals, (dict, Series)): + if isinstance(decimals, Series) and not decimals.index.is_unique: + raise ValueError("Index of decimals must be unique") + if is_dict_like(decimals) and not all( + is_integer(value) for _, value in decimals.items() + ): + raise TypeError("Values in decimals must be integers") + new_cols = list(_dict_round(self, decimals)) + elif is_integer(decimals): + # Dispatch to Block.round + # Argument "decimals" to "round" of "BaseBlockManager" has incompatible + # type "Union[int, integer[Any]]"; expected "int" + new_mgr = self._mgr.round( + decimals=decimals, # type: ignore[arg-type] + ) + return self._constructor_from_mgr(new_mgr, axes=new_mgr.axes).__finalize__( + self, method="round" + ) + else: + raise TypeError("decimals must be an integer, a dict-like or a Series") + + if new_cols is not None and len(new_cols) > 0: + return self._constructor( + concat(new_cols, axis=1), index=self.index, columns=self.columns + ).__finalize__(self, method="round") + else: + return self.copy(deep=False) + + # ---------------------------------------------------------------------- + # Statistical methods, etc. + + def corr( + self, + method: CorrelationMethod = "pearson", + min_periods: int = 1, + numeric_only: bool = False, + ) -> DataFrame: + """ + Compute pairwise correlation of columns, excluding NA/null values. + + Parameters + ---------- + method : {'pearson', 'kendall', 'spearman'} or callable + Method of correlation: + + * pearson : standard correlation coefficient + * kendall : Kendall Tau correlation coefficient + * spearman : Spearman rank correlation + * callable: callable with input two 1d ndarrays + and returning a float. Note that the returned matrix from corr + will have 1 along the diagonals and will be symmetric + regardless of the callable's behavior. + min_periods : int, optional + Minimum number of observations required per pair of columns + to have a valid result. Currently only available for Pearson + and Spearman correlation. + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + .. versionchanged:: 2.0.0 + The default value of ``numeric_only`` is now ``False``. + + Returns + ------- + DataFrame + Correlation matrix. + + See Also + -------- + DataFrame.corrwith : Compute pairwise correlation with another + DataFrame or Series. + Series.corr : Compute the correlation between two Series. + + Notes + ----- + Pearson, Kendall and Spearman correlation are currently computed using pairwise complete observations. + + * `Pearson correlation coefficient `_ + * `Kendall rank correlation coefficient `_ + * `Spearman's rank correlation coefficient `_ + + Examples + -------- + >>> def histogram_intersection(a, b): + ... v = np.minimum(a, b).sum().round(decimals=1) + ... return v + >>> df = pd.DataFrame( + ... [(0.2, 0.3), (0.0, 0.6), (0.6, 0.0), (0.2, 0.1)], + ... columns=["dogs", "cats"], + ... ) + >>> df.corr(method=histogram_intersection) + dogs cats + dogs 1.0 0.3 + cats 0.3 1.0 + + >>> df = pd.DataFrame( + ... [(1, 1), (2, np.nan), (np.nan, 3), (4, 4)], columns=["dogs", "cats"] + ... ) + >>> df.corr(min_periods=3) + dogs cats + dogs 1.0 NaN + cats NaN 1.0 + """ # noqa: E501 + data = self._get_numeric_data() if numeric_only else self + cols = data.columns + idx = cols.copy() + mat = data.to_numpy(dtype=float, na_value=np.nan, copy=False) + + if method == "pearson": + correl = libalgos.nancorr(mat, minp=min_periods) + elif method == "spearman": + correl = libalgos.nancorr_spearman(mat, minp=min_periods) + elif method == "kendall" or callable(method): + if min_periods is None: + min_periods = 1 + mat = mat.T + corrf = nanops.get_corr_func(method) + K = len(cols) + correl = np.empty((K, K), dtype=float) + mask = np.isfinite(mat) + for i, ac in enumerate(mat): + for j, bc in enumerate(mat): + if i > j: + continue + + valid = mask[i] & mask[j] + if valid.sum() < min_periods: + c = np.nan + elif i == j: + c = 1.0 + elif not valid.all(): + c = corrf(ac[valid], bc[valid]) + else: + c = corrf(ac, bc) + correl[i, j] = c + correl[j, i] = c + else: + raise ValueError( + "method must be either 'pearson', " + "'spearman', 'kendall', or a callable, " + f"'{method}' was supplied" + ) + + result = self._constructor(correl, index=idx, columns=cols, copy=False) + return result.__finalize__(self, method="corr") + + def cov( + self, + min_periods: int | None = None, + ddof: int | None = 1, + numeric_only: bool = False, + ) -> DataFrame: + """ + Compute pairwise covariance of columns, excluding NA/null values. + + Compute the pairwise covariance among the series of a DataFrame. + The returned data frame is the `covariance matrix + `__ of the columns + of the DataFrame. + + Both NA and null values are automatically excluded from the + calculation. (See the note below about bias from missing values.) + A threshold can be set for the minimum number of + observations for each value created. Comparisons with observations + below this threshold will be returned as ``NaN``. + + This method is generally used for the analysis of time series data to + understand the relationship between different measures + across time. + + Parameters + ---------- + min_periods : int, optional + Minimum number of observations required per pair of columns + to have a valid result. + + ddof : int, default 1 + Delta degrees of freedom. The divisor used in calculations + is ``N - ddof``, where ``N`` represents the number of elements. + This argument is applicable only when no ``nan`` is in the dataframe. + + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + .. versionchanged:: 2.0.0 + The default value of ``numeric_only`` is now ``False``. + + Returns + ------- + DataFrame + The covariance matrix of the series of the DataFrame. + + See Also + -------- + Series.cov : Compute covariance with another Series. + core.window.ewm.ExponentialMovingWindow.cov : Exponential weighted sample + covariance. + core.window.expanding.Expanding.cov : Expanding sample covariance. + core.window.rolling.Rolling.cov : Rolling sample covariance. + + Notes + ----- + Returns the covariance matrix of the DataFrame's time series. + The covariance is normalized by N-ddof. + + For DataFrames that have Series that are missing data (assuming that + data is `missing at random + `__) + the returned covariance matrix will be an unbiased estimate + of the variance and covariance between the member Series. + + However, for many applications this estimate may not be acceptable + because the estimate covariance matrix is not guaranteed to be positive + semi-definite. This could lead to estimate correlations having + absolute values which are greater than one, and/or a non-invertible + covariance matrix. See `Estimation of covariance matrices + `__ for more details. + + Examples + -------- + >>> df = pd.DataFrame( + ... [(1, 2), (0, 3), (2, 0), (1, 1)], columns=["dogs", "cats"] + ... ) + >>> df.cov() + dogs cats + dogs 0.666667 -1.000000 + cats -1.000000 1.666667 + + >>> np.random.seed(42) + >>> df = pd.DataFrame( + ... np.random.randn(1000, 5), columns=["a", "b", "c", "d", "e"] + ... ) + >>> df.cov() + a b c d e + a 0.998438 -0.020161 0.059277 -0.008943 0.014144 + b -0.020161 1.059352 -0.008543 -0.024738 0.009826 + c 0.059277 -0.008543 1.010670 -0.001486 -0.000271 + d -0.008943 -0.024738 -0.001486 0.921297 -0.013692 + e 0.014144 0.009826 -0.000271 -0.013692 0.977795 + + **Minimum number of periods** + + This method also supports an optional ``min_periods`` keyword + that specifies the required minimum number of non-NA observations for + each column pair in order to have a valid result: + + >>> np.random.seed(42) + >>> df = pd.DataFrame(np.random.randn(20, 3), columns=["a", "b", "c"]) + >>> df.loc[df.index[:5], "a"] = np.nan + >>> df.loc[df.index[5:10], "b"] = np.nan + >>> df.cov(min_periods=12) + a b c + a 0.316741 NaN -0.150812 + b NaN 1.248003 0.191417 + c -0.150812 0.191417 0.895202 + """ + data = self._get_numeric_data() if numeric_only else self + if any(blk.dtype.kind in "mM" for blk in self._mgr.blocks): + msg = ( + "DataFrame contains columns with dtype datetime64 " + "or timedelta64, which are not supported for cov." + ) + raise TypeError(msg) + cols = data.columns + idx = cols.copy() + mat = data.to_numpy(dtype=float, na_value=np.nan, copy=False) + + if notna(mat).all(): + if min_periods is not None and min_periods > len(mat): + base_cov = np.empty((mat.shape[1], mat.shape[1])) + base_cov.fill(np.nan) + else: + base_cov = np.cov(mat.T, ddof=ddof) + base_cov = base_cov.reshape((len(cols), len(cols))) + else: + base_cov = libalgos.nancorr(mat, cov=True, minp=min_periods) + + result = self._constructor(base_cov, index=idx, columns=cols, copy=False) + return result.__finalize__(self, method="cov") + + def corrwith( + self, + other: DataFrame | Series, + axis: Axis = 0, + drop: bool = False, + method: CorrelationMethod = "pearson", + numeric_only: bool = False, + min_periods: int | None = None, + ) -> Series: + """ + Compute pairwise correlation. + + Pairwise correlation is computed between rows or columns of + DataFrame with rows or columns of Series or DataFrame. DataFrames + are first aligned along both axes before computing the + correlations. + + Parameters + ---------- + other : DataFrame, Series + Object with which to compute correlations. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to use. 0 or 'index' to compute row-wise, 1 or 'columns' for + column-wise. + drop : bool, default False + Drop missing indices from result. + method : {'pearson', 'kendall', 'spearman'} or callable + Method of correlation: + + * pearson : standard correlation coefficient + * kendall : Kendall Tau correlation coefficient + * spearman : Spearman rank correlation + * callable: callable with input two 1d ndarrays + and returning a float. + + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + min_periods : int, optional + Minimum number of observations needed to have a valid result. + + .. versionchanged:: 2.0.0 + The default value of ``numeric_only`` is now ``False``. + + Returns + ------- + Series + Pairwise correlations. + + See Also + -------- + DataFrame.corr : Compute pairwise correlation of columns. + + Examples + -------- + >>> index = ["a", "b", "c", "d", "e"] + >>> columns = ["one", "two", "three", "four"] + >>> df1 = pd.DataFrame( + ... np.arange(20).reshape(5, 4), index=index, columns=columns + ... ) + >>> df2 = pd.DataFrame( + ... np.arange(16).reshape(4, 4), index=index[:4], columns=columns + ... ) + >>> df1.corrwith(df2) + one 1.0 + two 1.0 + three 1.0 + four 1.0 + dtype: float64 + + >>> df2.corrwith(df1, axis=1) + a 1.0 + b 1.0 + c 1.0 + d 1.0 + e NaN + dtype: float64 + """ + axis = self._get_axis_number(axis) + this = self._get_numeric_data() if numeric_only else self + + if isinstance(other, Series): + return this.apply( + lambda x: other.corr(x, method=method, min_periods=min_periods), + axis=axis, + ) + + if numeric_only: + other = other._get_numeric_data() + left, right = this.align(other, join="inner") + + if axis == 1: + left = left.T + right = right.T + + if method == "pearson": + # mask missing values + left = left + right * 0 + right = right + left * 0 + + # demeaned data + ldem = left - left.mean(numeric_only=numeric_only) + rdem = right - right.mean(numeric_only=numeric_only) + + num = (ldem * rdem).sum() + dom = ( + (left.count() - 1) + * left.std(numeric_only=numeric_only) + * right.std(numeric_only=numeric_only) + ) + + correl = num / dom + + elif method in ["kendall", "spearman"] or callable(method): + + def c(x): + return nanops.nancorr(x[0], x[1], method=method) + + correl = self._constructor_sliced( + map(c, zip(left.values.T, right.values.T, strict=True)), + index=left.columns, + copy=False, + ) + + else: + raise ValueError( + f"Invalid method {method} was passed, " + "valid methods are: 'pearson', 'kendall', " + "'spearman', or callable" + ) + + if not drop: + # Find non-matching labels along the given axis + # and append missing correlations (GH 22375) + raxis: AxisInt = 1 if axis == 0 else 0 + result_index = this._get_axis(raxis).union(other._get_axis(raxis)) + idx_diff = result_index.difference(correl.index) + + if len(idx_diff) > 0: + correl = correl._append_internal( + Series([np.nan] * len(idx_diff), index=idx_diff) + ) + + return correl + + # ---------------------------------------------------------------------- + # ndarray-like stats methods + + def count(self, axis: Axis = 0, numeric_only: bool = False) -> Series: + """ + Count non-NA cells for each column or row. + + The values `None`, `NaN`, `NaT`, ``pandas.NA`` are considered NA. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + If 0 or 'index' counts are generated for each column. + If 1 or 'columns' counts are generated for each row. + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + Returns + ------- + Series + For each column/row the number of non-NA/null entries. + + See Also + -------- + Series.count: Number of non-NA elements in a Series. + DataFrame.value_counts: Count unique combinations of columns. + DataFrame.shape: Number of DataFrame rows and columns (including NA + elements). + DataFrame.isna: Boolean same-sized DataFrame showing places of NA + elements. + + Examples + -------- + Constructing DataFrame from a dictionary: + + >>> df = pd.DataFrame( + ... { + ... "Person": ["John", "Myla", "Lewis", "John", "Myla"], + ... "Age": [24.0, np.nan, 21.0, 33, 26], + ... "Single": [False, True, True, True, False], + ... } + ... ) + >>> df + Person Age Single + 0 John 24.0 False + 1 Myla NaN True + 2 Lewis 21.0 True + 3 John 33.0 True + 4 Myla 26.0 False + + Notice the uncounted NA values: + + >>> df.count() + Person 5 + Age 4 + Single 5 + dtype: int64 + + Counts for each **row**: + + >>> df.count(axis="columns") + 0 3 + 1 2 + 2 3 + 3 3 + 4 3 + dtype: int64 + """ + axis = self._get_axis_number(axis) + + if numeric_only: + frame = self._get_numeric_data() + else: + frame = self + + # GH #423 + if len(frame._get_axis(axis)) == 0: + result = self._constructor_sliced(0, index=frame._get_agg_axis(axis)) + else: + result = notna(frame).sum(axis=axis) + + return result.astype("int64").__finalize__(self, method="count") + + def _reduce( + self, + op, + name: str, + *, + axis: Axis = 0, + skipna: bool = True, + numeric_only: bool = False, + filter_type=None, + **kwds, + ): + assert filter_type is None or filter_type == "bool", filter_type + out_dtype = "bool" if filter_type == "bool" else None + + if axis is not None: + axis = self._get_axis_number(axis) + + def func(values: np.ndarray): + # We only use this in the case that operates on self.values + return op(values, axis=axis, skipna=skipna, **kwds) + + def blk_func(values, axis: Axis = 1): + if isinstance(values, ExtensionArray): + if not is_1d_only_ea_dtype(values.dtype): + return values._reduce(name, axis=1, skipna=skipna, **kwds) + return values._reduce(name, skipna=skipna, keepdims=True, **kwds) + else: + return op(values, axis=axis, skipna=skipna, **kwds) + + def _get_data() -> DataFrame: + if filter_type is None: + data = self._get_numeric_data() + else: + # GH#25101, GH#24434 + assert filter_type == "bool" + data = self._get_bool_data() + return data + + # Case with EAs see GH#35881 + df = self + if numeric_only: + df = _get_data() + if axis is None: + dtype = find_common_type([block.values.dtype for block in df._mgr.blocks]) + if isinstance(dtype, ExtensionDtype): + df = df.astype(dtype) + arr = concat_compat(list(df._iter_column_arrays())) + return arr._reduce(name, skipna=skipna, keepdims=False, **kwds) + return maybe_unbox_numpy_scalar(func(df.values)) + elif axis == 1: + if len(df.index) == 0: + # Taking a transpose would result in no columns, losing the dtype. + # In the empty case, reducing along axis 0 or 1 gives the same + # result dtype, so reduce with axis=0 and ignore values + result = df._reduce( + op, + name, + axis=0, + skipna=skipna, + numeric_only=False, + filter_type=filter_type, + **kwds, + ).iloc[:0] + result.index = df.index + return result + + if df.shape[1]: + dtype = find_common_type( + [block.values.dtype for block in df._mgr.blocks] + ) + if isinstance(dtype, ExtensionDtype): + # GH 54341: fastpath for EA-backed axis=1 reductions + # This flattens the frame into a single 1D array while keeping + # track of the row and column indices of the original frame. Once + # flattened, grouping by the row indices and aggregating should + # be equivalent to transposing the original frame and aggregating + # with axis=0. + name = {"argmax": "idxmax", "argmin": "idxmin"}.get(name, name) + df = df.astype(dtype) + arr = concat_compat(list(df._iter_column_arrays())) + nrows, ncols = df.shape + row_index = np.tile(np.arange(nrows), ncols) + col_index = np.repeat(np.arange(ncols), nrows) + ser = Series(arr, index=col_index, copy=False) + if name == "all": + # Behavior here appears incorrect; preserving + # for backwards compatibility for now. + # See https://github.com/pandas-dev/pandas/issues/57171 + skipna = True + result = ser.groupby(row_index).agg(name, **kwds, skipna=skipna) + result.index = df.index + return result + + df = df.T + + # After possibly _get_data and transposing, we are now in the + # simple case where we can use BlockManager.reduce + res = df._mgr.reduce(blk_func) + out = df._constructor_from_mgr(res, axes=res.axes).iloc[0] + out.name = None + if out_dtype is not None and out.dtype != "boolean": + out = out.astype(out_dtype) + elif (df._mgr.get_dtypes() == object).any() and name not in ["any", "all"]: + out = out.astype(object) + elif len(self) == 0 and out.dtype == object and name in ("sum", "prod"): + # Even if we are object dtype, follow numpy and return + # float64, see test_apply_funcs_over_empty + out = out.astype(np.float64) + + return out + + def _reduce_axis1(self, name: str, func, skipna: bool) -> Series: + """ + Special case for _reduce to try to avoid a potentially-expensive transpose. + + Apply the reduction block-wise along axis=1 and then reduce the resulting + 1D arrays. + """ + if name == "all": + result = np.ones(len(self), dtype=bool) + ufunc = np.logical_and + elif name == "any": + result = np.zeros(len(self), dtype=bool) + # error: Incompatible types in assignment + # (expression has type "_UFunc_Nin2_Nout1[Literal['logical_or'], + # Literal[20], Literal[False]]", variable has type + # "_UFunc_Nin2_Nout1[Literal['logical_and'], Literal[20], + # Literal[True]]") + ufunc = np.logical_or # type: ignore[assignment] + else: + raise NotImplementedError(name) + + for blocks in self._mgr.blocks: + middle = func(blocks.values, axis=0, skipna=skipna) + result = ufunc(result, middle) + + res_ser = self._constructor_sliced(result, index=self.index, copy=False) + return res_ser + + # error: Signature of "any" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def any( + self, + *, + axis: Axis = ..., + bool_only: bool = ..., + skipna: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def any( + self, + *, + axis: None, + bool_only: bool = ..., + skipna: bool = ..., + **kwargs, + ) -> bool: ... + + @overload + def any( + self, + *, + axis: Axis | None, + bool_only: bool = ..., + skipna: bool = ..., + **kwargs, + ) -> Series | bool: ... + + def any( + self, + *, + axis: Axis | None = 0, + bool_only: bool = False, + skipna: bool = True, + **kwargs, + ) -> Series | bool: + """ + Return whether any element is True, potentially over an axis. + + Returns False unless there is at least one element within a series or + along a Dataframe axis that is True or equivalent (e.g. non-zero or + non-empty). + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns', None}, default 0 + Indicate which axis or axes should be reduced. For `Series` this parameter + is unused and defaults to 0. + + * 0 / 'index' : reduce the index, return a Series whose index is the + original column labels. + * 1 / 'columns' : reduce the columns, return a Series whose index is the + original index. + * None : reduce all axes, return a scalar. + + bool_only : bool, default False + Include only boolean columns. Not implemented for Series. + skipna : bool, default True + Exclude NA/null values. If the entire row/column is NA and skipna is + True, then the result will be False, as for an empty row/column. + If skipna is False, then NA are treated as True, because these are not + equal to zero. + **kwargs : any, default None + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series or scalar + If axis=None, then a scalar boolean is returned. + Otherwise a Series is returned with index matching the index argument. + + See Also + -------- + numpy.any : Numpy version of this method. + Series.any : Return whether any element is True. + Series.all : Return whether all elements are True. + DataFrame.any : Return whether any element is True over requested axis. + DataFrame.all : Return whether all elements are True over requested axis. + + Examples + -------- + **Series** + + For Series input, the output is a scalar indicating whether any element + is True. + + >>> pd.Series([False, False]).any() + False + >>> pd.Series([True, False]).any() + True + >>> pd.Series([], dtype="float64").any() + False + >>> pd.Series([np.nan]).any() + False + >>> pd.Series([np.nan]).any(skipna=False) + True + + **DataFrame** + + Whether each column contains at least one True element (the default). + + >>> df = pd.DataFrame({"A": [1, 2], "B": [0, 2], "C": [0, 0]}) + >>> df + A B C + 0 1 0 0 + 1 2 2 0 + + >>> df.any() + A True + B True + C False + dtype: bool + + Aggregating over the columns. + + >>> df = pd.DataFrame({"A": [True, False], "B": [1, 2]}) + >>> df + A B + 0 True 1 + 1 False 2 + + >>> df.any(axis="columns") + 0 True + 1 True + dtype: bool + + >>> df = pd.DataFrame({"A": [True, False], "B": [1, 0]}) + >>> df + A B + 0 True 1 + 1 False 0 + + >>> df.any(axis="columns") + 0 True + 1 False + dtype: bool + + Aggregating over the entire DataFrame with ``axis=None``. + + >>> df.any(axis=None) + True + + `any` for an empty DataFrame is an empty Series. + + >>> pd.DataFrame([]).any() + Series([], dtype: bool) + """ + result = self._logical_func( + "any", nanops.nanany, axis, bool_only, skipna, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="any") + return result + + @overload + def all( + self, + *, + axis: Axis = ..., + bool_only: bool = ..., + skipna: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def all( + self, + *, + axis: None, + bool_only: bool = ..., + skipna: bool = ..., + **kwargs, + ) -> bool: ... + + @overload + def all( + self, + *, + axis: Axis | None, + bool_only: bool = ..., + skipna: bool = ..., + **kwargs, + ) -> Series | bool: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="all") + def all( + self, + axis: Axis | None = 0, + bool_only: bool = False, + skipna: bool = True, + **kwargs, + ) -> Series | bool: + """ + Return whether all elements are True, potentially over an axis. + + Returns True unless there at least one element within a series or + along a Dataframe axis that is False or equivalent (e.g. zero or + empty). + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns', None}, default 0 + Indicate which axis or axes should be reduced. For `Series` this parameter + is unused and defaults to 0. + + * 0 / 'index' : reduce the index, return a Series whose index is the + original column labels. + * 1 / 'columns' : reduce the columns, return a Series whose index is the + original index. + * None : reduce all axes, return a scalar. + + bool_only : bool, default False + Include only boolean columns. Not implemented for Series. + skipna : bool, default True + Exclude NA/null values. If the entire row/column is NA and skipna is + True, then the result will be True, as for an empty row/column. + If skipna is False, then NA are treated as True, because these are not + equal to zero. + **kwargs : any, default None + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series or scalar + If axis=None, then a scalar boolean is returned. + Otherwise a Series is returned with index matching the index argument. + + See Also + -------- + Series.all : Return True if all elements are True. + DataFrame.any : Return True if one (or more) elements are True. + + Examples + -------- + **Series** + + >>> pd.Series([True, True]).all() + True + >>> pd.Series([True, False]).all() + False + >>> pd.Series([], dtype="float64").all() + True + >>> pd.Series([np.nan]).all() + True + >>> pd.Series([np.nan]).all(skipna=False) + True + + **DataFrames** + + Create a DataFrame from a dictionary. + + >>> df = pd.DataFrame({"col1": [True, True], "col2": [True, False]}) + >>> df + col1 col2 + 0 True True + 1 True False + + Default behaviour checks if values in each column all return True. + + >>> df.all() + col1 True + col2 False + dtype: bool + + Specify ``axis='columns'`` to check if values in each row all return True. + + >>> df.all(axis="columns") + 0 True + 1 False + dtype: bool + + Or ``axis=None`` for whether every value is True. + + >>> df.all(axis=None) + False + """ + result = self._logical_func( + "all", nanops.nanall, axis, bool_only, skipna, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="all") + return result + + # error: Signature of "min" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def min( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def min( + self, + *, + axis: None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def min( + self, + *, + axis: Axis | None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="min") + def min( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return the minimum of the values over the requested axis. + + If you want the *index* of the minimum, use ``idxmin``. + This is the equivalent of the ``numpy.ndarray`` method ``argmin``. + + Parameters + ---------- + axis : {index (0), columns (1)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + Value containing the calculation referenced in the description. + + See Also + -------- + Series.sum : Return the sum. + Series.min : Return the minimum. + Series.max : Return the maximum. + Series.idxmin : Return the index of the minimum. + Series.idxmax : Return the index of the maximum. + DataFrame.sum : Return the sum over the requested axis. + DataFrame.min : Return the minimum over the requested axis. + DataFrame.max : Return the maximum over the requested axis. + DataFrame.idxmin : Return the index of the minimum over the requested axis. + DataFrame.idxmax : Return the index of the maximum over the requested axis. + + Examples + -------- + >>> idx = pd.MultiIndex.from_arrays( + ... [["warm", "warm", "cold", "cold"], ["dog", "falcon", "fish", "spider"]], + ... names=["blooded", "animal"], + ... ) + >>> s = pd.Series([4, 2, 0, 8], name="legs", index=idx) + >>> s + blooded animal + warm dog 4 + falcon 2 + cold fish 0 + spider 8 + Name: legs, dtype: int64 + + >>> s.min() + 0 + """ + result = super().min( + axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="min") + return result + + # error: Signature of "max" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def max( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def max( + self, + *, + axis: None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def max( + self, + *, + axis: Axis | None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="max") + def max( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return the maximum of the values over the requested axis. + + If you want the *index* of the maximum, use ``idxmax``. + This is the equivalent of the ``numpy.ndarray`` method ``argmax``. + + Parameters + ---------- + axis : {index (0), columns (1)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + Value containing the calculation referenced in the description. + + See Also + -------- + Series.sum : Return the sum. + Series.min : Return the minimum. + Series.max : Return the maximum. + Series.idxmin : Return the index of the minimum. + Series.idxmax : Return the index of the maximum. + DataFrame.sum : Return the sum over the requested axis. + DataFrame.min : Return the minimum over the requested axis. + DataFrame.max : Return the maximum over the requested axis. + DataFrame.idxmin : Return the index of the minimum over the requested axis. + DataFrame.idxmax : Return the index of the maximum over the requested axis. + + Examples + -------- + >>> idx = pd.MultiIndex.from_arrays( + ... [["warm", "warm", "cold", "cold"], ["dog", "falcon", "fish", "spider"]], + ... names=["blooded", "animal"], + ... ) + >>> s = pd.Series([4, 2, 0, 8], name="legs", index=idx) + >>> s + blooded animal + warm dog 4 + falcon 2 + cold fish 0 + spider 8 + Name: legs, dtype: int64 + + >>> s.max() + 8 + """ + result = super().max( + axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="max") + return result + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="sum") + def sum( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs, + ) -> Series: + """ + Return the sum of the values over the requested axis. + + This is equivalent to the method ``numpy.sum``. + + Parameters + ---------- + axis : {index (0), columns (1)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.sum with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer than + ``min_count`` non-NA values are present the result will be NA. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + Sum over requested axis. + + See Also + -------- + Series.sum : Return the sum over Series values. + DataFrame.mean : Return the mean of the values over the requested axis. + DataFrame.median : Return the median of the values over the requested axis. + DataFrame.mode : Get the mode(s) of each element along the requested axis. + DataFrame.std : Return the standard deviation of the values over the + requested axis. + + Examples + -------- + >>> idx = pd.MultiIndex.from_arrays( + ... [["warm", "warm", "cold", "cold"], ["dog", "falcon", "fish", "spider"]], + ... names=["blooded", "animal"], + ... ) + >>> s = pd.Series([4, 2, 0, 8], name="legs", index=idx) + >>> s + blooded animal + warm dog 4 + falcon 2 + cold fish 0 + spider 8 + Name: legs, dtype: int64 + + >>> s.sum() + 14 + + By default, the sum of an empty or all-NA Series is ``0``. + + >>> pd.Series([], dtype="float64").sum() # min_count=0 is the default + 0.0 + + This can be controlled with the ``min_count`` parameter. For example, if + you'd like the sum of an empty series to be NaN, pass ``min_count=1``. + + >>> pd.Series([], dtype="float64").sum(min_count=1) + nan + + Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and + empty series identically. + + >>> pd.Series([np.nan]).sum() + 0.0 + + >>> pd.Series([np.nan]).sum(min_count=1) + nan + """ + result = super().sum( + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + min_count=min_count, + **kwargs, + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="sum") + return result + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="prod") + def prod( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs, + ) -> Series: + """ + Return the product of the values over the requested axis. + + Parameters + ---------- + axis : {index (0), columns (1)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.prod with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer than + ``min_count`` non-NA values are present the result will be NA. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + The product of the values over the requested axis. + + See Also + -------- + Series.sum : Return the sum. + Series.min : Return the minimum. + Series.max : Return the maximum. + Series.idxmin : Return the index of the minimum. + Series.idxmax : Return the index of the maximum. + DataFrame.sum : Return the sum over the requested axis. + DataFrame.min : Return the minimum over the requested axis. + DataFrame.max : Return the maximum over the requested axis. + DataFrame.idxmin : Return the index of the minimum over the requested axis. + DataFrame.idxmax : Return the index of the maximum over the requested axis. + + Examples + -------- + By default, the product of an empty or all-NA Series is ``1`` + + >>> pd.Series([], dtype="float64").prod() + 1.0 + + This can be controlled with the ``min_count`` parameter + + >>> pd.Series([], dtype="float64").prod(min_count=1) + nan + + Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and + empty series identically. + + >>> pd.Series([np.nan]).prod() + 1.0 + + >>> pd.Series([np.nan]).prod(min_count=1) + nan + """ + result = super().prod( + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + min_count=min_count, + **kwargs, + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="prod") + return result + + # error: Signature of "mean" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def mean( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def mean( + self, + *, + axis: None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def mean( + self, + *, + axis: Axis | None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="mean") + def mean( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return the mean of the values over the requested axis. + + Parameters + ---------- + axis : {index (0), columns (1)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + Value containing the calculation referenced in the description. + + See Also + -------- + Series.sum : Return the sum. + Series.min : Return the minimum. + Series.max : Return the maximum. + Series.idxmin : Return the index of the minimum. + Series.idxmax : Return the index of the maximum. + DataFrame.sum : Return the sum over the requested axis. + DataFrame.min : Return the minimum over the requested axis. + DataFrame.max : Return the maximum over the requested axis. + DataFrame.idxmin : Return the index of the minimum over the requested axis. + DataFrame.idxmax : Return the index of the maximum over the requested axis. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.mean() + 2.0 + + With a DataFrame + + >>> df = pd.DataFrame({"a": [1, 2], "b": [2, 3]}, index=["tiger", "zebra"]) + >>> df + a b + tiger 1 2 + zebra 2 3 + >>> df.mean() + a 1.5 + b 2.5 + dtype: float64 + + Using axis=1 + + >>> df.mean(axis=1) + tiger 1.5 + zebra 2.5 + dtype: float64 + + In this case, `numeric_only` should be set to `True` to avoid + getting an error. + + >>> df = pd.DataFrame({"a": [1, 2], "b": ["T", "Z"]}, index=["tiger", "zebra"]) + >>> df.mean(numeric_only=True) + a 1.5 + dtype: float64 + """ + result = super().mean( + axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="mean") + return result + + # error: Signature of "median" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def median( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def median( + self, + *, + axis: None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def median( + self, + *, + axis: Axis | None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments( + Pandas4Warning, allowed_args=["self"], name="median" + ) + def median( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return the median of the values over the requested axis. + + Parameters + ---------- + axis : {index (0), columns (1)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + Value containing the calculation referenced in the description. + + See Also + -------- + Series.sum : Return the sum. + Series.min : Return the minimum. + Series.max : Return the maximum. + Series.idxmin : Return the index of the minimum. + Series.idxmax : Return the index of the maximum. + DataFrame.sum : Return the sum over the requested axis. + DataFrame.min : Return the minimum over the requested axis. + DataFrame.max : Return the maximum over the requested axis. + DataFrame.idxmin : Return the index of the minimum over the requested axis. + DataFrame.idxmax : Return the index of the maximum over the requested axis. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.median() + 2.0 + + With a DataFrame + + >>> df = pd.DataFrame({"a": [1, 2], "b": [2, 3]}, index=["tiger", "zebra"]) + >>> df + a b + tiger 1 2 + zebra 2 3 + >>> df.median() + a 1.5 + b 2.5 + dtype: float64 + + Using axis=1 + + >>> df.median(axis=1) + tiger 1.5 + zebra 2.5 + dtype: float64 + + In this case, `numeric_only` should be set to `True` + to avoid getting an error. + + >>> df = pd.DataFrame({"a": [1, 2], "b": ["T", "Z"]}, index=["tiger", "zebra"]) + >>> df.median(numeric_only=True) + a 1.5 + dtype: float64 + """ + result = super().median( + axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="median") + return result + + # error: Signature of "sem" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def sem( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def sem( + self, + *, + axis: None, + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def sem( + self, + *, + axis: Axis | None, + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="sem") + def sem( + self, + axis: Axis | None = 0, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return unbiased standard error of the mean over requested axis. + + Normalized by N-1 by default. This can be changed using the ddof argument + + Parameters + ---------- + axis : {index (0), columns (1)} + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.sem with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + **kwargs : + Additional keywords passed. + + Returns + ------- + Series or DataFrame (if level specified) + Unbiased standard error of the mean over requested axis. + + See Also + -------- + DataFrame.var : Return unbiased variance over requested axis. + DataFrame.std : Returns sample standard deviation over requested axis. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> round(s.sem(), 6) + 0.57735 + + With a DataFrame + + >>> df = pd.DataFrame({"a": [1, 2], "b": [2, 3]}, index=["tiger", "zebra"]) + >>> df + a b + tiger 1 2 + zebra 2 3 + >>> df.sem() + a 0.5 + b 0.5 + dtype: float64 + + Using axis=1 + + >>> df.sem(axis=1) + tiger 0.5 + zebra 0.5 + dtype: float64 + + In this case, `numeric_only` should be set to `True` + to avoid getting an error. + + >>> df = pd.DataFrame({"a": [1, 2], "b": ["T", "Z"]}, index=["tiger", "zebra"]) + >>> df.sem(numeric_only=True) + a 0.5 + dtype: float64 + """ + result = super().sem( + axis=axis, skipna=skipna, ddof=ddof, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="sem") + return result + + # error: Signature of "var" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def var( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def var( + self, + *, + axis: None, + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def var( + self, + *, + axis: Axis | None, + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="var") + def var( + self, + axis: Axis | None = 0, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return unbiased variance over requested axis. + + Normalized by N-1 by default. This can be changed using the ddof argument. + + Parameters + ---------- + axis : {index (0), columns (1)} + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.var with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + **kwargs : + Additional keywords passed. + + Returns + ------- + Series or scalaer + Unbiased variance over requested axis. + + See Also + -------- + numpy.var : Equivalent function in NumPy. + Series.var : Return unbiased variance over Series values. + Series.std : Return standard deviation over Series values. + DataFrame.std : Return standard deviation of the values over + the requested axis. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "person_id": [0, 1, 2, 3], + ... "age": [21, 25, 62, 43], + ... "height": [1.61, 1.87, 1.49, 2.01], + ... } + ... ).set_index("person_id") + >>> df + age height + person_id + 0 21 1.61 + 1 25 1.87 + 2 62 1.49 + 3 43 2.01 + + >>> df.var() + age 352.916667 + height 0.056367 + dtype: float64 + + Alternatively, ``ddof=0`` can be set to normalize by N instead of N-1: + + >>> df.var(ddof=0) + age 264.687500 + height 0.042275 + dtype: float64 + """ + result = super().var( + axis=axis, skipna=skipna, ddof=ddof, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="var") + return result + + # error: Signature of "std" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def std( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def std( + self, + *, + axis: None, + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def std( + self, + *, + axis: Axis | None, + skipna: bool = ..., + ddof: int = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="std") + def std( + self, + axis: Axis | None = 0, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return sample standard deviation over requested axis. + + Normalized by N-1 by default. This can be changed using the ddof argument. + + Parameters + ---------- + axis : {index (0), columns (1)} + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.std with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + **kwargs : dict + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + Standard deviation over requested axis. + + See Also + -------- + Series.std : Return standard deviation over Series values. + DataFrame.mean : Return the mean of the values over the requested axis. + DataFrame.median : Return the median of the values over the requested axis. + DataFrame.mode : Get the mode(s) of each element along the requested axis. + DataFrame.sum : Return the sum of the values over the requested axis. + + Notes + ----- + To have the same behaviour as `numpy.std`, use `ddof=0` (instead of the + default `ddof=1`) + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "person_id": [0, 1, 2, 3], + ... "age": [21, 25, 62, 43], + ... "height": [1.61, 1.87, 1.49, 2.01], + ... } + ... ).set_index("person_id") + >>> df + age height + person_id + 0 21 1.61 + 1 25 1.87 + 2 62 1.49 + 3 43 2.01 + + The standard deviation of the columns can be found as follows: + + >>> df.std() + age 18.786076 + height 0.237417 + dtype: float64 + + Alternatively, `ddof=0` can be set to normalize by N instead of N-1: + + >>> df.std(ddof=0) + age 16.269219 + height 0.205609 + dtype: float64 + """ + result = super().std( + axis=axis, skipna=skipna, ddof=ddof, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="std") + return result + + # error: Signature of "skew" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def skew( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def skew( + self, + *, + axis: None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def skew( + self, + *, + axis: Axis | None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="skew") + def skew( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return unbiased skew over requested axis. + + Normalized by N-1. + + Parameters + ---------- + axis : {index (0), columns (1)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + Unbiased skew over requested axis. + + See Also + -------- + Dataframe.kurt : Returns unbiased kurtosis over requested axis. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.skew() + 0.0 + + With a DataFrame + + >>> df = pd.DataFrame( + ... {"a": [1, 2, 3], "b": [2, 3, 4], "c": [1, 3, 5]}, + ... index=["tiger", "zebra", "cow"], + ... ) + >>> df + a b c + tiger 1 2 1 + zebra 2 3 3 + cow 3 4 5 + >>> df.skew() + a 0.0 + b 0.0 + c 0.0 + dtype: float64 + + Using axis=1 + + >>> df.skew(axis=1) + tiger 1.732051 + zebra -1.732051 + cow 0.000000 + dtype: float64 + + In this case, `numeric_only` should be set to `True` to avoid + getting an error. + + >>> df = pd.DataFrame( + ... {"a": [1, 2, 3], "b": ["T", "Z", "X"]}, index=["tiger", "zebra", "cow"] + ... ) + >>> df.skew(numeric_only=True) + a 0.0 + dtype: float64 + """ + result = super().skew( + axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="skew") + return result + + # error: Signature of "kurt" incompatible with supertype "NDFrame" + @overload # type: ignore[override] + def kurt( + self, + *, + axis: Axis = ..., + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series: ... + + @overload + def kurt( + self, + *, + axis: None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Any: ... + + @overload + def kurt( + self, + *, + axis: Axis | None, + skipna: bool = ..., + numeric_only: bool = ..., + **kwargs, + ) -> Series | Any: ... + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="kurt") + def kurt( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | Any: + """ + Return unbiased kurtosis over requested axis. + + Kurtosis obtained using Fisher's definition of + kurtosis (kurtosis of normal == 0.0). Normalized by N-1. + + Parameters + ---------- + axis : {index (0), columns (1)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + Series or scalar + Unbiased kurtosis over requested axis. + + See Also + -------- + Dataframe.kurtosis : Returns unbiased kurtosis over requested axis. + + Examples + -------- + >>> s = pd.Series([1, 2, 2, 3], index=["cat", "dog", "dog", "mouse"]) + >>> s + cat 1 + dog 2 + dog 2 + mouse 3 + dtype: int64 + >>> s.kurt() + 1.5 + + With a DataFrame + + >>> df = pd.DataFrame( + ... {"a": [1, 2, 2, 3], "b": [3, 4, 4, 4]}, + ... index=["cat", "dog", "dog", "mouse"], + ... ) + >>> df + a b + cat 1 3 + dog 2 4 + dog 2 4 + mouse 3 4 + >>> df.kurt() + a 1.5 + b 4.0 + dtype: float64 + + With axis=None + + >>> df.kurt(axis=None) + -0.9886927196984727 + + Using axis=1 + + >>> df = pd.DataFrame( + ... {"a": [1, 2], "b": [3, 4], "c": [3, 4], "d": [1, 2]}, + ... index=["cat", "dog"], + ... ) + >>> df.kurt(axis=1) + cat -6.0 + dog -6.0 + dtype: float64 + """ + result = super().kurt( + axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + if isinstance(result, Series): + result = result.__finalize__(self, method="kurt") + return result + + # error: Incompatible types in assignment + kurtosis = kurt # type: ignore[assignment] + product = prod + + def cummin( + self, + axis: Axis = 0, + skipna: bool = True, + numeric_only: bool = False, + *args, + **kwargs, + ) -> Self: + """ + Return cumulative minimum over a DataFrame or Series axis. + + Returns a DataFrame or Series of the same size containing the cumulative + minimum. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + The index or the name of the axis. 0 is equivalent to None or 'index'. + For `Series` this parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + numeric_only : bool, default False + Include only float, int, boolean columns. + *args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series or DataFrame + Return cumulative minimum of Series or DataFrame. + + See Also + -------- + core.window.expanding.Expanding.min : Similar functionality + but ignores ``NaN`` values. + DataFrame.min : Return the minimum over + DataFrame axis. + DataFrame.cummax : Return cumulative maximum over DataFrame axis. + DataFrame.cummin : Return cumulative minimum over DataFrame axis. + DataFrame.cumsum : Return cumulative sum over DataFrame axis. + DataFrame.cumprod : Return cumulative product over DataFrame axis. + + Examples + -------- + **Series** + + >>> s = pd.Series([2, np.nan, 5, -1, 0]) + >>> s + 0 2.0 + 1 NaN + 2 5.0 + 3 -1.0 + 4 0.0 + dtype: float64 + + By default, NA values are ignored. + + >>> s.cummin() + 0 2.0 + 1 NaN + 2 2.0 + 3 -1.0 + 4 -1.0 + dtype: float64 + + To include NA values in the operation, use ``skipna=False`` + + >>> s.cummin(skipna=False) + 0 2.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + + **DataFrame** + + >>> df = pd.DataFrame( + ... [[2.0, 1.0], [3.0, np.nan], [1.0, 0.0]], columns=list("AB") + ... ) + >>> df + A B + 0 2.0 1.0 + 1 3.0 NaN + 2 1.0 0.0 + + By default, iterates over rows and finds the minimum + in each column. This is equivalent to ``axis=None`` or ``axis='index'``. + + >>> df.cummin() + A B + 0 2.0 1.0 + 1 2.0 NaN + 2 1.0 0.0 + + To iterate over columns and find the minimum in each row, + use ``axis=1`` + + >>> df.cummin(axis=1) + A B + 0 2.0 1.0 + 1 3.0 NaN + 2 1.0 0.0 + """ + data = self._get_numeric_data() if numeric_only else self + return NDFrame.cummin(data, axis, skipna, *args, **kwargs) + + def cummax( + self, + axis: Axis = 0, + skipna: bool = True, + numeric_only: bool = False, + *args, + **kwargs, + ) -> Self: + """ + Return cumulative maximum over a DataFrame or Series axis. + + Returns a DataFrame or Series of the same size containing the cumulative + maximum. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + The index or the name of the axis. 0 is equivalent to None or 'index'. + For `Series` this parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + numeric_only : bool, default False + Include only float, int, boolean columns. + *args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series or DataFrame + Return cumulative maximum of Series or DataFrame. + + See Also + -------- + core.window.expanding.Expanding.max : Similar functionality + but ignores ``NaN`` values. + DataFrame.max : Return the maximum over + DataFrame axis. + DataFrame.cummax : Return cumulative maximum over DataFrame axis. + DataFrame.cummin : Return cumulative minimum over DataFrame axis. + DataFrame.cumsum : Return cumulative sum over DataFrame axis. + DataFrame.cumprod : Return cumulative product over DataFrame axis. + + Examples + -------- + **Series** + + >>> s = pd.Series([2, np.nan, 5, -1, 0]) + >>> s + 0 2.0 + 1 NaN + 2 5.0 + 3 -1.0 + 4 0.0 + dtype: float64 + + By default, NA values are ignored. + + >>> s.cummax() + 0 2.0 + 1 NaN + 2 5.0 + 3 5.0 + 4 5.0 + dtype: float64 + + To include NA values in the operation, use ``skipna=False`` + + >>> s.cummax(skipna=False) + 0 2.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + + **DataFrame** + + >>> df = pd.DataFrame( + ... [[2.0, 1.0], [3.0, np.nan], [1.0, 0.0]], columns=list("AB") + ... ) + >>> df + A B + 0 2.0 1.0 + 1 3.0 NaN + 2 1.0 0.0 + + By default, iterates over rows and finds the maximum + in each column. This is equivalent to ``axis=None`` or ``axis='index'``. + + >>> df.cummax() + A B + 0 2.0 1.0 + 1 3.0 NaN + 2 3.0 1.0 + + To iterate over columns and find the maximum in each row, + use ``axis=1`` + + >>> df.cummax(axis=1) + A B + 0 2.0 2.0 + 1 3.0 NaN + 2 1.0 1.0 + """ + data = self._get_numeric_data() if numeric_only else self + return NDFrame.cummax(data, axis, skipna, *args, **kwargs) + + def cumsum( + self, + axis: Axis = 0, + skipna: bool = True, + numeric_only: bool = False, + *args, + **kwargs, + ) -> Self: + """ + Return cumulative sum over a DataFrame or Series axis. + + Returns a DataFrame or Series of the same size containing the cumulative + sum. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + The index or the name of the axis. 0 is equivalent to None or 'index'. + For `Series` this parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + numeric_only : bool, default False + Include only float, int, boolean columns. + *args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series or DataFrame + Return cumulative sum of Series or DataFrame. + + See Also + -------- + core.window.expanding.Expanding.sum : Similar functionality + but ignores ``NaN`` values. + DataFrame.sum : Return the sum over + DataFrame axis. + DataFrame.cummax : Return cumulative maximum over DataFrame axis. + DataFrame.cummin : Return cumulative minimum over DataFrame axis. + DataFrame.cumsum : Return cumulative sum over DataFrame axis. + DataFrame.cumprod : Return cumulative product over DataFrame axis. + + Examples + -------- + **Series** + + >>> s = pd.Series([2, np.nan, 5, -1, 0]) + >>> s + 0 2.0 + 1 NaN + 2 5.0 + 3 -1.0 + 4 0.0 + dtype: float64 + + By default, NA values are ignored. + + >>> s.cumsum() + 0 2.0 + 1 NaN + 2 7.0 + 3 6.0 + 4 6.0 + dtype: float64 + + To include NA values in the operation, use ``skipna=False`` + + >>> s.cumsum(skipna=False) + 0 2.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + + **DataFrame** + + >>> df = pd.DataFrame( + ... [[2.0, 1.0], [3.0, np.nan], [1.0, 0.0]], columns=list("AB") + ... ) + >>> df + A B + 0 2.0 1.0 + 1 3.0 NaN + 2 1.0 0.0 + + By default, iterates over rows and finds the sum + in each column. This is equivalent to ``axis=None`` or ``axis='index'``. + + >>> df.cumsum() + A B + 0 2.0 1.0 + 1 5.0 NaN + 2 6.0 1.0 + + To iterate over columns and find the sum in each row, + use ``axis=1`` + + >>> df.cumsum(axis=1) + A B + 0 2.0 3.0 + 1 3.0 NaN + 2 1.0 1.0 + """ + data = self._get_numeric_data() if numeric_only else self + return NDFrame.cumsum(data, axis, skipna, *args, **kwargs) + + def cumprod( + self, + axis: Axis = 0, + skipna: bool = True, + numeric_only: bool = False, + *args, + **kwargs, + ) -> Self: + """ + Return cumulative product over a DataFrame or Series axis. + + Returns a DataFrame or Series of the same size containing the cumulative + product. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + The index or the name of the axis. 0 is equivalent to None or 'index'. + For `Series` this parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + numeric_only : bool, default False + Include only float, int, boolean columns. + *args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series or DataFrame + Return cumulative product of Series or DataFrame. + + See Also + -------- + core.window.expanding.Expanding.prod : Similar functionality + but ignores ``NaN`` values. + DataFrame.prod : Return the product over + DataFrame axis. + DataFrame.cummax : Return cumulative maximum over DataFrame axis. + DataFrame.cummin : Return cumulative minimum over DataFrame axis. + DataFrame.cumsum : Return cumulative sum over DataFrame axis. + DataFrame.cumprod : Return cumulative product over DataFrame axis. + + Examples + -------- + **Series** + + >>> s = pd.Series([2, np.nan, 5, -1, 0]) + >>> s + 0 2.0 + 1 NaN + 2 5.0 + 3 -1.0 + 4 0.0 + dtype: float64 + + By default, NA values are ignored. + + >>> s.cumprod() + 0 2.0 + 1 NaN + 2 10.0 + 3 -10.0 + 4 -0.0 + dtype: float64 + + To include NA values in the operation, use ``skipna=False`` + + >>> s.cumprod(skipna=False) + 0 2.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + + **DataFrame** + + >>> df = pd.DataFrame( + ... [[2.0, 1.0], [3.0, np.nan], [1.0, 0.0]], columns=list("AB") + ... ) + >>> df + A B + 0 2.0 1.0 + 1 3.0 NaN + 2 1.0 0.0 + + By default, iterates over rows and finds the product + in each column. This is equivalent to ``axis=None`` or ``axis='index'``. + + >>> df.cumprod() + A B + 0 2.0 1.0 + 1 6.0 NaN + 2 6.0 0.0 + + To iterate over columns and find the product in each row, + use ``axis=1`` + + >>> df.cumprod(axis=1) + A B + 0 2.0 2.0 + 1 3.0 NaN + 2 1.0 0.0 + """ + data = self._get_numeric_data() if numeric_only else self + return NDFrame.cumprod(data, axis, skipna, *args, **kwargs) + + def nunique(self, axis: Axis = 0, dropna: bool = True) -> Series: + """ + Count number of distinct elements in specified axis. + + Return Series with number of distinct elements. Can ignore NaN + values. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for + column-wise. + dropna : bool, default True + Don't include NaN in the counts. + + Returns + ------- + Series + Series with counts of unique values per row or column, depending on `axis`. + + See Also + -------- + Series.nunique: Method nunique for Series. + DataFrame.count: Count non-NA cells for each column or row. + + Examples + -------- + >>> df = pd.DataFrame({"A": [4, 5, 6], "B": [4, 1, 1]}) + >>> df.nunique() + A 3 + B 2 + dtype: int64 + + >>> df.nunique(axis=1) + 0 1 + 1 2 + 2 2 + dtype: int64 + """ + return self.apply(Series.nunique, axis=axis, dropna=dropna) + + def idxmin( + self, axis: Axis = 0, skipna: bool = True, numeric_only: bool = False + ) -> Series: + """ + Return index of first occurrence of minimum over requested axis. + + NA/null values are excluded. + + Parameters + ---------- + axis : {{0 or 'index', 1 or 'columns'}}, default 0 + The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise. + skipna : bool, default True + Exclude NA/null values. If the entire DataFrame is NA, + or if ``skipna=False`` and there is an NA value, this method + will raise a ``ValueError``. + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + Returns + ------- + Series + Indexes of minima along the specified axis. + + Raises + ------ + ValueError + * If the row/column is empty + + See Also + -------- + Series.idxmin : Return index of the minimum element. + + Notes + ----- + This method is the DataFrame version of ``ndarray.argmin``. + + Examples + -------- + Consider a dataset containing food consumption in Argentina. + + >>> df = pd.DataFrame( + ... { + ... "consumption": [10.51, 103.11, 55.48], + ... "co2_emissions": [37.2, 19.66, 1712], + ... }, + ... index=["Pork", "Wheat Products", "Beef"], + ... ) + + >>> df + consumption co2_emissions + Pork 10.51 37.20 + Wheat Products 103.11 19.66 + Beef 55.48 1712.00 + + By default, it returns the index for the minimum value in each column. + + >>> df.idxmin() + consumption Pork + co2_emissions Wheat Products + dtype: str + + To return the index for the minimum value in each row, use ``axis="columns"``. + + >>> df.idxmin(axis="columns") + Pork consumption + Wheat Products co2_emissions + Beef consumption + dtype: str + """ + axis = self._get_axis_number(axis) + + if self.empty and len(self.axes[axis]): + axis_dtype = self.axes[axis].dtype + return self._constructor_sliced(dtype=axis_dtype) + + if numeric_only: + data = self._get_numeric_data() + else: + data = self + + res = data._reduce( + nanops.nanargmin, "argmin", axis=axis, skipna=skipna, numeric_only=False + ) + indices = res._values + # indices will always be np.ndarray since axis is not N + + if (indices == -1).any(): + if skipna: + msg = "Encountered all NA values" + else: + msg = "Encountered an NA values with skipna=False" + raise ValueError(msg) + + index = data._get_axis(axis) + result = algorithms.take( + index._values, indices, allow_fill=True, fill_value=index._na_value + ) + final_result = data._constructor_sliced(result, index=data._get_agg_axis(axis)) + return final_result.__finalize__(self, method="idxmin") + + def idxmax( + self, axis: Axis = 0, skipna: bool = True, numeric_only: bool = False + ) -> Series: + """ + Return index of first occurrence of maximum over requested axis. + + NA/null values are excluded. + + Parameters + ---------- + axis : {{0 or 'index', 1 or 'columns'}}, default 0 + The axis to use. 0 or 'index' for row-wise, 1 or 'columns' for column-wise. + skipna : bool, default True + Exclude NA/null values. If the entire DataFrame is NA, + or if ``skipna=False`` and there is an NA value, this method + will raise a ``ValueError``. + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + Returns + ------- + Series + Indexes of maxima along the specified axis. + + Raises + ------ + ValueError + * If the row/column is empty + + See Also + -------- + Series.idxmax : Return index of the maximum element. + + Notes + ----- + This method is the DataFrame version of ``ndarray.argmax``. + + Examples + -------- + Consider a dataset containing food consumption in Argentina. + + >>> df = pd.DataFrame( + ... { + ... "consumption": [10.51, 103.11, 55.48], + ... "co2_emissions": [37.2, 19.66, 1712], + ... }, + ... index=["Pork", "Wheat Products", "Beef"], + ... ) + + >>> df + consumption co2_emissions + Pork 10.51 37.20 + Wheat Products 103.11 19.66 + Beef 55.48 1712.00 + + By default, it returns the index for the maximum value in each column. + + >>> df.idxmax() + consumption Wheat Products + co2_emissions Beef + dtype: str + + To return the index for the maximum value in each row, use ``axis="columns"``. + + >>> df.idxmax(axis="columns") + Pork co2_emissions + Wheat Products consumption + Beef co2_emissions + dtype: str + """ + axis = self._get_axis_number(axis) + + if self.empty and len(self.axes[axis]): + axis_dtype = self.axes[axis].dtype + return self._constructor_sliced(dtype=axis_dtype) + + if numeric_only: + data = self._get_numeric_data() + else: + data = self + + res = data._reduce( + nanops.nanargmax, "argmax", axis=axis, skipna=skipna, numeric_only=False + ) + indices = res._values + # indices will always be 1d array since axis is not None + + if (indices == -1).any(): + if skipna: + msg = "Encountered all NA values" + else: + msg = "Encountered an NA values with skipna=False" + raise ValueError(msg) + + index = data._get_axis(axis) + result = algorithms.take( + index._values, indices, allow_fill=True, fill_value=index._na_value + ) + final_result = data._constructor_sliced(result, index=data._get_agg_axis(axis)) + return final_result.__finalize__(self, method="idxmax") + + def _get_agg_axis(self, axis_num: int) -> Index: + """ + Let's be explicit about this. + """ + if axis_num == 0: + return self.columns + elif axis_num == 1: + return self.index + else: + raise ValueError(f"Axis must be 0 or 1 (got {axis_num!r})") + + def mode( + self, axis: Axis = 0, numeric_only: bool = False, dropna: bool = True + ) -> DataFrame: + """ + Get the mode(s) of each element along the selected axis. + + The mode of a set of values is the value that appears most often. + It can be multiple values. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to iterate over while searching for the mode: + + * 0 or 'index' : get mode of each column + * 1 or 'columns' : get mode of each row. + + numeric_only : bool, default False + If True, only apply to numeric columns. + dropna : bool, default True + Don't consider counts of NaN/NaT. + + Returns + ------- + DataFrame + The modes of each column or row. + + See Also + -------- + Series.mode : Return the highest frequency value in a Series. + Series.value_counts : Return the counts of values in a Series. + + Examples + -------- + >>> df = pd.DataFrame( + ... [ + ... ("bird", 2, 2), + ... ("mammal", 4, np.nan), + ... ("arthropod", 8, 0), + ... ("bird", 2, np.nan), + ... ], + ... index=("falcon", "horse", "spider", "ostrich"), + ... columns=("species", "legs", "wings"), + ... ) + >>> df + species legs wings + falcon bird 2 2.0 + horse mammal 4 NaN + spider arthropod 8 0.0 + ostrich bird 2 NaN + + By default, missing values are not considered, and the mode of wings + are both 0 and 2. Because the resulting DataFrame has two rows, + the second row of ``species`` and ``legs`` contains ``NaN``. + + >>> df.mode() + species legs wings + 0 bird 2.0 0.0 + 1 NaN NaN 2.0 + + Setting ``dropna=False`` ``NaN`` values are considered and they can be + the mode (like for wings). + + >>> df.mode(dropna=False) + species legs wings + 0 bird 2 NaN + + Setting ``numeric_only=True``, only the mode of numeric columns is + computed, and columns of other types are ignored. + + >>> df.mode(numeric_only=True) + legs wings + 0 2.0 0.0 + 1 NaN 2.0 + + To compute the mode over columns and not rows, use the axis parameter: + + >>> df.mode(axis="columns", numeric_only=True) + 0 1 + falcon 2.0 NaN + horse 4.0 NaN + spider 0.0 8.0 + ostrich 2.0 NaN + """ + data = self if not numeric_only else self._get_numeric_data() + + def f(s): + return s.mode(dropna=dropna) + + data = data.apply(f, axis=axis) + # Ensure index is type stable (should always use int index) + if data.empty: + data.index = default_index(0) + + return data + + @overload + def quantile( + self, + q: float = ..., + axis: Axis = ..., + numeric_only: bool = ..., + interpolation: QuantileInterpolation = ..., + method: Literal["single", "table"] = ..., + ) -> Series: ... + + @overload + def quantile( + self, + q: AnyArrayLike | Sequence[float], + axis: Axis = ..., + numeric_only: bool = ..., + interpolation: QuantileInterpolation = ..., + method: Literal["single", "table"] = ..., + ) -> Series | DataFrame: ... + + @overload + def quantile( + self, + q: float | AnyArrayLike | Sequence[float] = ..., + axis: Axis = ..., + numeric_only: bool = ..., + interpolation: QuantileInterpolation = ..., + method: Literal["single", "table"] = ..., + ) -> Series | DataFrame: ... + + def quantile( + self, + q: float | AnyArrayLike | Sequence[float] = 0.5, + axis: Axis = 0, + numeric_only: bool = False, + interpolation: QuantileInterpolation = "linear", + method: Literal["single", "table"] = "single", + ) -> Series | DataFrame: + """ + Return values at the given quantile over requested axis. + + Parameters + ---------- + q : float or array-like, default 0.5 (50% quantile) + Value between 0 <= q <= 1, the quantile(s) to compute. + axis : {0 or 'index', 1 or 'columns'}, default 0 + Equals 0 or 'index' for row-wise, 1 or 'columns' for column-wise. + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + .. versionchanged:: 2.0.0 + The default value of ``numeric_only`` is now ``False``. + + interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + This optional parameter specifies the interpolation method to use, + when the desired quantile lies between two data points `i` and `j`: + + * linear: `i + (j - i) * fraction`, where `fraction` is the + fractional part of the index surrounded by `i` and `j`. + * lower: `i`. + * higher: `j`. + * nearest: `i` or `j` whichever is nearest. + * midpoint: (`i` + `j`) / 2. + method : {'single', 'table'}, default 'single' + Whether to compute quantiles per-column ('single') or over all columns + ('table'). When 'table', the only allowed interpolation methods are + 'nearest', 'lower', and 'higher'. + + Returns + ------- + Series or DataFrame + + If ``q`` is an array, a DataFrame will be returned where the + index is ``q``, the columns are the columns of self, and the + values are the quantiles. + If ``q`` is a float, a Series will be returned where the + index is the columns of self and the values are the quantiles. + + See Also + -------- + core.window.rolling.Rolling.quantile: Rolling quantile. + numpy.percentile: Numpy function to compute the percentile. + + Examples + -------- + >>> df = pd.DataFrame( + ... np.array([[1, 1], [2, 10], [3, 100], [4, 100]]), columns=["a", "b"] + ... ) + >>> df.quantile(0.1) + a 1.3 + b 3.7 + Name: 0.1, dtype: float64 + >>> df.quantile([0.1, 0.5]) + a b + 0.1 1.3 3.7 + 0.5 2.5 55.0 + + Specifying `method='table'` will compute the quantile over all columns. + + >>> df.quantile(0.1, method="table", interpolation="nearest") + a 1 + b 1 + Name: 0.1, dtype: int64 + >>> df.quantile([0.1, 0.5], method="table", interpolation="nearest") + a b + 0.1 1 1 + 0.5 3 100 + + Specifying `numeric_only=False` will compute the quantiles for all + columns. + + >>> df = pd.DataFrame( + ... { + ... "A": [1, 2], + ... "B": [pd.Timestamp("2010"), pd.Timestamp("2011")], + ... "C": [pd.Timedelta("1 days"), pd.Timedelta("2 days")], + ... } + ... ) + >>> df.quantile(0.5, numeric_only=False) + A 1.5 + B 2010-07-02 12:00:00 + C 1 days 12:00:00 + Name: 0.5, dtype: object + """ + validate_percentile(q) + axis = self._get_axis_number(axis) + + if not is_list_like(q): + # BlockManager.quantile expects listlike, so we wrap and unwrap here + # error: List item 0 has incompatible type "float | ExtensionArray | + # ndarray[Any, Any] | Index | Series | Sequence[float]"; expected "float" + res_df = self.quantile( + [q], # type: ignore[list-item] + axis=axis, + numeric_only=numeric_only, + interpolation=interpolation, + method=method, + ) + if method == "single": + res = res_df.iloc[0] + else: + # cannot directly iloc over sparse arrays + res = res_df.T.iloc[:, 0] + if axis == 1 and len(self) == 0: + # GH#41544 try to get an appropriate dtype + dtype = find_common_type(list(self.dtypes)) + if needs_i8_conversion(dtype): + return res.astype(dtype) + return res + + q = Index(q, dtype=np.float64) + data = self._get_numeric_data() if numeric_only else self + + if axis == 1: + data = data.T + + if len(data.columns) == 0: + # GH#23925 _get_numeric_data may have dropped all columns + cols = self.columns[:0] + + dtype = np.float64 + if axis == 1: + # GH#41544 try to get an appropriate dtype + cdtype = find_common_type(list(self.dtypes)) + if needs_i8_conversion(cdtype): + dtype = cdtype + + res = self._constructor([], index=q, columns=cols, dtype=dtype) + return res.__finalize__(self, method="quantile") + + valid_method = {"single", "table"} + if method not in valid_method: + raise ValueError( + f"Invalid method: {method}. Method must be in {valid_method}." + ) + if method == "single": + res = data._mgr.quantile(qs=q, interpolation=interpolation) + elif method == "table": + valid_interpolation = {"nearest", "lower", "higher"} + if interpolation not in valid_interpolation: + raise ValueError( + f"Invalid interpolation: {interpolation}. " + f"Interpolation must be in {valid_interpolation}" + ) + # handle degenerate case + if len(data) == 0: + if data.ndim == 2: + dtype = find_common_type(list(self.dtypes)) + else: + dtype = self.dtype + return self._constructor([], index=q, columns=data.columns, dtype=dtype) + + q_idx = np.quantile(np.arange(len(data)), q, method=interpolation) + + by = data.columns + if len(by) > 1: + keys = [data._get_label_or_level_values(x) for x in by] + indexer = lexsort_indexer(keys) + else: + k = data._get_label_or_level_values(by[0]) + indexer = nargsort(k) + + res = data._mgr.take(indexer[q_idx], verify=False) + res.axes[1] = q + + result = self._constructor_from_mgr(res, axes=res.axes) + return result.__finalize__(self, method="quantile") + + def to_timestamp( + self, + freq: Frequency | None = None, + how: ToTimestampHow = "start", + axis: Axis = 0, + copy: bool | lib.NoDefault = lib.no_default, + ) -> DataFrame: + """ + Cast PeriodIndex to DatetimeIndex of timestamps, at *beginning* of period. + + This can be changed to the *end* of the period, by specifying `how="e"`. + + Parameters + ---------- + freq : str, default frequency of PeriodIndex + Desired frequency. + how : {'s', 'e', 'start', 'end'} + Convention for converting period to timestamp; start of period + vs. end. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to convert (the index by default). + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + DataFrame with DatetimeIndex + DataFrame with the PeriodIndex cast to DatetimeIndex. + + See Also + -------- + DataFrame.to_period: Inverse method to cast DatetimeIndex to PeriodIndex. + Series.to_timestamp: Equivalent method for Series. + + Examples + -------- + >>> idx = pd.PeriodIndex(["2023", "2024"], freq="Y") + >>> d = {"col1": [1, 2], "col2": [3, 4]} + >>> df1 = pd.DataFrame(data=d, index=idx) + >>> df1 + col1 col2 + 2023 1 3 + 2024 2 4 + + The resulting timestamps will be at the beginning of the year in this case + + >>> df1 = df1.to_timestamp() + >>> df1 + col1 col2 + 2023-01-01 1 3 + 2024-01-01 2 4 + >>> df1.index + DatetimeIndex(['2023-01-01', '2024-01-01'], dtype='datetime64[us]', freq=None) + + Using `freq` which is the offset that the Timestamps will have + + >>> df2 = pd.DataFrame(data=d, index=idx) + >>> df2 = df2.to_timestamp(freq="M") + >>> df2 + col1 col2 + 2023-01-31 1 3 + 2024-01-31 2 4 + >>> df2.index + DatetimeIndex(['2023-01-31', '2024-01-31'], dtype='datetime64[us]', freq=None) + """ + self._check_copy_deprecation(copy) + new_obj = self.copy(deep=False) + + axis_name = self._get_axis_name(axis) + old_ax = getattr(self, axis_name) + if not isinstance(old_ax, PeriodIndex): + raise TypeError(f"unsupported Type {type(old_ax).__name__}") + + new_ax = old_ax.to_timestamp(freq=freq, how=how) + + setattr(new_obj, axis_name, new_ax) + return new_obj + + def to_period( + self, + freq: Frequency | None = None, + axis: Axis = 0, + copy: bool | lib.NoDefault = lib.no_default, + ) -> DataFrame: + """ + Convert DataFrame from DatetimeIndex to PeriodIndex. + + Convert DataFrame from DatetimeIndex to PeriodIndex with desired + frequency (inferred from index if not passed). Either index of columns can be + converted, depending on `axis` argument. + + Parameters + ---------- + freq : str, default + Frequency of the PeriodIndex. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to convert (the index by default). + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + DataFrame + The DataFrame with the converted PeriodIndex. + + See Also + -------- + Series.to_period: Equivalent method for Series. + Series.dt.to_period: Convert DateTime column values. + + Examples + -------- + >>> idx = pd.to_datetime( + ... [ + ... "2001-03-31 00:00:00", + ... "2002-05-31 00:00:00", + ... "2003-08-31 00:00:00", + ... ] + ... ) + + >>> idx + DatetimeIndex(['2001-03-31', '2002-05-31', '2003-08-31'], + dtype='datetime64[us]', freq=None) + + >>> idx.to_period("M") + PeriodIndex(['2001-03', '2002-05', '2003-08'], dtype='period[M]') + + For the yearly frequency + + >>> idx.to_period("Y") + PeriodIndex(['2001', '2002', '2003'], dtype='period[Y-DEC]') + """ + self._check_copy_deprecation(copy) + new_obj = self.copy(deep=False) + + axis_name = self._get_axis_name(axis) + old_ax = getattr(self, axis_name) + if not isinstance(old_ax, DatetimeIndex): + raise TypeError(f"unsupported Type {type(old_ax).__name__}") + + new_ax = old_ax.to_period(freq=freq) + + setattr(new_obj, axis_name, new_ax) + return new_obj + + def isin(self, values: Series | DataFrame | Sequence | Mapping) -> DataFrame: + """ + Whether each element in the DataFrame is contained in values. + + Parameters + ---------- + values : iterable, Series, DataFrame or dict + The result will only be true at a location if all the + labels match. If `values` is a Series, that's the index. If + `values` is a dict, the keys must be the column names, + which must match. If `values` is a DataFrame, + then both the index and column labels must match. + + Returns + ------- + DataFrame + DataFrame of booleans showing whether each element in the DataFrame + is contained in values. + + See Also + -------- + DataFrame.eq: Equality test for DataFrame. + Series.isin: Equivalent method on Series. + Series.str.contains: Test if pattern or regex is contained within a + string of a Series or Index. + + Notes + ----- + ``__iter__`` is used (and not ``__contains__``) to iterate over values + when checking if it contains the elements in DataFrame. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"num_legs": [2, 4], "num_wings": [2, 0]}, index=["falcon", "dog"] + ... ) + >>> df + num_legs num_wings + falcon 2 2 + dog 4 0 + + When ``values`` is a list check whether every value in the DataFrame + is present in the list (which animals have 0 or 2 legs or wings) + + >>> df.isin([0, 2]) + num_legs num_wings + falcon True True + dog False True + + To check if ``values`` is *not* in the DataFrame, use the ``~`` operator: + + >>> ~df.isin([0, 2]) + num_legs num_wings + falcon False False + dog True False + + When ``values`` is a dict, we can pass values to check for each + column separately: + + >>> df.isin({"num_wings": [0, 3]}) + num_legs num_wings + falcon False False + dog False True + + When ``values`` is a Series or DataFrame the index and column must + match. Note that 'falcon' does not match based on the number of legs + in other. + + >>> other = pd.DataFrame( + ... {"num_legs": [8, 3], "num_wings": [0, 2]}, index=["spider", "falcon"] + ... ) + >>> df.isin(other) + num_legs num_wings + falcon False True + dog False False + """ + if isinstance(values, dict): + from pandas.core.reshape.concat import concat + + values = collections.defaultdict(list, values) + result = concat( + ( + self.iloc[:, [i]].isin(values[col]) + for i, col in enumerate(self.columns) + ), + axis=1, + ) + elif isinstance(values, Series): + if not values.index.is_unique: + raise ValueError("cannot compute isin with a duplicate axis.") + result = self.eq(values.reindex_like(self), axis="index") + elif isinstance(values, DataFrame): + if not (values.columns.is_unique and values.index.is_unique): + raise ValueError("cannot compute isin with a duplicate axis.") + result = self.eq(values.reindex_like(self)) + else: + if not is_list_like(values): + raise TypeError( + "only list-like or dict-like objects are allowed " + "to be passed to DataFrame.isin(), " + f"you passed a '{type(values).__name__}'" + ) + + def isin_(x): + # error: Argument 2 to "isin" has incompatible type "Union[Series, + # DataFrame, Sequence[Any], Mapping[Any, Any]]"; expected + # "Union[Union[Union[ExtensionArray, ndarray[Any, Any]], Index, + # Series], List[Any], range]" + result = algorithms.isin( + x.ravel(), + values, # type: ignore[arg-type] + ) + return result.reshape(x.shape) + + res_mgr = self._mgr.apply(isin_) + result = self._constructor_from_mgr( + res_mgr, + axes=res_mgr.axes, + ) + return result.__finalize__(self, method="isin") + + # ---------------------------------------------------------------------- + # Add index and columns + _AXIS_ORDERS: list[Literal["index", "columns"]] = ["index", "columns"] + _AXIS_TO_AXIS_NUMBER: dict[Axis, int] = { + **NDFrame._AXIS_TO_AXIS_NUMBER, + 1: 1, + "columns": 1, + } + _AXIS_LEN = len(_AXIS_ORDERS) + _info_axis_number: Literal[1] = 1 + _info_axis_name: Literal["columns"] = "columns" + + index = properties.AxisProperty( + axis=1, + doc=""" + The index (row labels) of the DataFrame. + + The index of a DataFrame is a series of labels that identify each row. + The labels can be integers, strings, or any other hashable type. The index + is used for label-based access and alignment, and can be accessed or + modified using this attribute. + + Returns + ------- + pandas.Index + The index labels of the DataFrame. + + See Also + -------- + DataFrame.columns : The column labels of the DataFrame. + DataFrame.to_numpy : Convert the DataFrame to a NumPy array. + + Examples + -------- + >>> df = pd.DataFrame({'Name': ['Alice', 'Bob', 'Aritra'], + ... 'Age': [25, 30, 35], + ... 'Location': ['Seattle', 'New York', 'Kona']}, + ... index=([10, 20, 30])) + >>> df.index + Index([10, 20, 30], dtype='int64') + + In this example, we create a DataFrame with 3 rows and 3 columns, + including Name, Age, and Location information. We set the index labels to + be the integers 10, 20, and 30. We then access the `index` attribute of the + DataFrame, which returns an `Index` object containing the index labels. + + >>> df.index = [100, 200, 300] + >>> df + Name Age Location + 100 Alice 25 Seattle + 200 Bob 30 New York + 300 Aritra 35 Kona + + In this example, we modify the index labels of the DataFrame by assigning + a new list of labels to the `index` attribute. The DataFrame is then + updated with the new labels, and the output shows the modified DataFrame. + """, + ) + columns = properties.AxisProperty( + axis=0, + doc=""" + The column labels of the DataFrame. + + This property holds the column names as a pandas ``Index`` object. + It provides an immutable sequence of column labels that can be + used for data selection, renaming, and alignment in DataFrame operations. + + Returns + ------- + pandas.Index + The column labels of the DataFrame. + + See Also + -------- + DataFrame.index: The index (row labels) of the DataFrame. + DataFrame.axes: Return a list representing the axes of the DataFrame. + + Examples + -------- + >>> df = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}) + >>> df + A B + 0 1 3 + 1 2 4 + >>> df.columns + Index(['A', 'B'], dtype='str') + """, + ) + + # ---------------------------------------------------------------------- + # Add plotting methods to DataFrame + plot = Accessor("plot", pandas.plotting.PlotAccessor) + hist = pandas.plotting.hist_frame + boxplot = pandas.plotting.boxplot_frame + sparse = Accessor("sparse", SparseFrameAccessor) + + # ---------------------------------------------------------------------- + # Internal Interface Methods + + def _to_dict_of_blocks(self): + """ + Return a dict of dtype -> Constructor Types that + each is a homogeneous dtype. + + Internal ONLY. + """ + mgr = self._mgr + return { + k: self._constructor_from_mgr(v, axes=v.axes).__finalize__(self) + for k, v in mgr.to_iter_dict() + } + + @property + def values(self) -> np.ndarray: + """ + Return a Numpy representation of the DataFrame. + + .. warning:: + + We recommend using :meth:`DataFrame.to_numpy` instead. + + Only the values in the DataFrame will be returned, the axes labels + will be removed. + + Returns + ------- + numpy.ndarray + The values of the DataFrame. + + See Also + -------- + DataFrame.to_numpy : Recommended alternative to this method. + DataFrame.index : Retrieve the index labels. + DataFrame.columns : Retrieving the column names. + + Notes + ----- + The dtype will be a lower-common-denominator dtype (implicit + upcasting); that is to say if the dtypes (even of numeric types) + are mixed, the one that accommodates all will be chosen. Use this + with care if you are not dealing with the blocks. + + e.g. If the dtypes are float16 and float32, dtype will be upcast to + float32. If dtypes are int32 and uint8, dtype will be upcast to + int32. By :func:`numpy.find_common_type` convention, mixing int64 + and uint64 will result in a float64 dtype. + + Examples + -------- + A DataFrame where all columns are the same type (e.g., int64) results + in an array of the same type. + + >>> df = pd.DataFrame( + ... {"age": [3, 29], "height": [94, 170], "weight": [31, 115]} + ... ) + >>> df + age height weight + 0 3 94 31 + 1 29 170 115 + >>> df.dtypes + age int64 + height int64 + weight int64 + dtype: object + >>> df.values + array([[ 3, 94, 31], + [ 29, 170, 115]]) + + A DataFrame with mixed type columns(e.g., str/object, int64, float32) + results in an ndarray of the broadest type that accommodates these + mixed types (e.g., object). + + >>> df2 = pd.DataFrame( + ... [ + ... ("parrot", 24.0, "second"), + ... ("lion", 80.5, 1), + ... ("monkey", np.nan, None), + ... ], + ... columns=("name", "max_speed", "rank"), + ... ) + >>> df2.dtypes + name str + max_speed float64 + rank object + dtype: object + >>> df2.values + array([['parrot', 24.0, 'second'], + ['lion', 80.5, 1], + ['monkey', nan, None]], dtype=object) + """ + return self._mgr.as_array() + + +def _from_nested_dict( + data: Mapping[HashableT, Mapping[HashableT2, T]], +) -> collections.defaultdict[HashableT2, dict[HashableT, T]]: + new_data: collections.defaultdict[HashableT2, dict[HashableT, T]] = ( + collections.defaultdict(dict) + ) + for index, s in data.items(): + for col, v in s.items(): + new_data[col][index] = v + return new_data + + +def _reindex_for_setitem( + value: DataFrame | Series, index: Index +) -> tuple[ArrayLike, BlockValuesRefs | None]: + # reindex if necessary + + if value.index.equals(index) or not len(index): + if isinstance(value, Series): + return value._values, value._references + return value._values.copy(), None + + # GH#4107 + try: + reindexed_value = value.reindex(index)._values + except ValueError as err: + # raised in MultiIndex.from_tuples, see test_insert_error_msmgs + if not value.index.is_unique: + # duplicate axis + raise err + + raise TypeError( + "incompatible index of inserted column with frame index" + ) from err + return reindexed_value, None diff --git a/pandas/core/generic.py b/pandas/core/generic.py new file mode 100644 index 0000000000000000000000000000000000000000..8a861ca8aeed7fae6eca5d772c53c91b4804b19d --- /dev/null +++ b/pandas/core/generic.py @@ -0,0 +1,13769 @@ +# pyright: reportPropertyTypeMismatch=false +from __future__ import annotations + +import collections +from copy import deepcopy +import datetime as dt +from functools import partial +from json import loads +import operator +import pickle +import re +import sys +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Concatenate, + Literal, + NoReturn, + Self, + cast, + final, + overload, +) +import warnings + +import numpy as np + +from pandas._config import config + +from pandas._libs import lib +from pandas._libs.lib import is_range_indexer +from pandas._libs.tslibs import ( + Period, + Timestamp, + to_offset, +) +from pandas._typing import ( + AlignJoin, + AnyArrayLike, + ArrayLike, + Axes, + Axis, + AxisInt, + CompressionOptions, + DtypeArg, + DtypeBackend, + DtypeObj, + FilePath, + FillnaOptions, + FloatFormatType, + FormattersType, + Frequency, + IgnoreRaise, + IndexKeyFunc, + IndexLabel, + InterpolateOptions, + IntervalClosedType, + JSONSerializable, + Level, + ListLike, + Manager, + NaPosition, + NDFrameT, + OpenFileErrors, + RandomState, + ReindexMethod, + Renamer, + Scalar, + SequenceNotStr, + SortKind, + StorageOptions, + Suffixes, + T, + TimeAmbiguous, + TimedeltaConvertibleTypes, + TimeNonexistent, + TimestampConvertibleTypes, + TimeUnit, + ValueKeyFunc, + WriteBuffer, + WriteExcelBuffer, + npt, +) +from pandas.compat import CHAINED_WARNING_DISABLED +from pandas.compat._constants import ( + REF_COUNT_METHOD, +) +from pandas.compat._optional import import_optional_dependency +from pandas.compat.numpy import function as nv +from pandas.errors import ( + AbstractMethodError, + ChainedAssignmentError, + InvalidIndexError, + Pandas4Warning, +) +from pandas.errors.cow import _chained_assignment_method_msg +from pandas.util._decorators import ( + deprecate_kwarg, + doc, +) +from pandas.util._exceptions import find_stack_level +from pandas.util._validators import ( + check_dtype_backend, + validate_ascending, + validate_bool_kwarg, + validate_inclusive, +) + +from pandas.core.dtypes.astype import astype_is_view +from pandas.core.dtypes.cast import can_hold_element +from pandas.core.dtypes.common import ( + ensure_object, + ensure_platform_int, + ensure_str, + is_bool, + is_bool_dtype, + is_dict_like, + is_extension_array_dtype, + is_list_like, + is_number, + is_numeric_dtype, + is_re_compilable, + is_scalar, + pandas_dtype, +) +from pandas.core.dtypes.dtypes import ( + DatetimeTZDtype, + ExtensionDtype, + PeriodDtype, +) +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCSeries, +) +from pandas.core.dtypes.inference import ( + is_hashable, + is_nested_list_like, +) +from pandas.core.dtypes.missing import ( + isna, + notna, +) + +from pandas.core import ( + algorithms as algos, + arraylike, + common, + indexing, + missing, + nanops, + sample, +) +from pandas.core.array_algos.replace import should_use_regex +from pandas.core.arrays import ExtensionArray +from pandas.core.base import PandasObject +from pandas.core.construction import extract_array +from pandas.core.flags import Flags +from pandas.core.indexes.api import ( + DatetimeIndex, + Index, + MultiIndex, + PeriodIndex, + default_index, + ensure_index, +) +from pandas.core.internals import BlockManager +from pandas.core.methods.describe import describe_ndframe +from pandas.core.missing import ( + clean_fill_method, + clean_reindex_fill_method, + find_valid_index, +) +from pandas.core.reshape.concat import concat +from pandas.core.shared_docs import _shared_docs +from pandas.core.sorting import get_indexer_indexer +from pandas.core.window import ( + Expanding, + ExponentialMovingWindow, + Rolling, + Window, +) + +from pandas.io.formats.format import ( + DataFrameFormatter, + DataFrameRenderer, +) +from pandas.io.formats.printing import pprint_thing + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Hashable, + Iterator, + Mapping, + Sequence, + ) + + from pandas._libs.tslibs import BaseOffset + from pandas._typing import P + + from pandas import ( + DataFrame, + ExcelWriter, + HDFStore, + Series, + ) + from pandas.core.indexers.objects import BaseIndexer + from pandas.core.resample import Resampler + + +# goal is to be able to define the docs close to function, while still being +# able to share +_shared_docs = {**_shared_docs} +_shared_doc_kwargs = { + "axes": "keywords for axes", + "klass": "Series/DataFrame", + "axes_single_arg": "{0 or 'index'} for Series, {0 or 'index', 1 or 'columns'} for DataFrame", # noqa: E501 + "inplace": """ + inplace : bool, default False + If True, performs operation inplace.""", + "optional_by": """ + by : str or list of str + Name or list of names to sort by""", +} + + +class NDFrame(PandasObject, indexing.IndexingMixin): + """ + N-dimensional analogue of DataFrame. Store multi-dimensional in a + size-mutable, labeled data structure + + Parameters + ---------- + data : BlockManager + axes : list + copy : bool, default False + """ + + _internal_names: list[str] = [ + "_mgr", + "_cache", + "_name", + "_metadata", + "_flags", + ] + _internal_names_set: set[str] = set(_internal_names) + _accessors: set[str] = set() + _hidden_attrs: frozenset[str] = frozenset([]) + _metadata: list[str] = [] + _mgr: Manager + _attrs: dict[Hashable, Any] + _typ: str + + # ---------------------------------------------------------------------- + # Constructors + + def __init__(self, data: Manager) -> None: + object.__setattr__(self, "_mgr", data) + object.__setattr__(self, "_attrs", {}) + object.__setattr__(self, "_flags", Flags(self, allows_duplicate_labels=True)) + + @final + @classmethod + def _init_mgr( + cls, + mgr: Manager, + axes: dict[Literal["index", "columns"], Axes | None], + dtype: DtypeObj | None = None, + copy: bool = False, + ) -> Manager: + """passed a manager and a axes dict""" + for a, axe in axes.items(): + if axe is not None: + axe = ensure_index(axe) + bm_axis = cls._get_block_manager_axis(a) + mgr = mgr.reindex_axis(axe, axis=bm_axis) + + # make a copy if explicitly requested + if copy: + mgr = mgr.copy(deep=True) + if dtype is not None: + # avoid further copies if we can + if ( + isinstance(mgr, BlockManager) + and len(mgr.blocks) == 1 + and mgr.blocks[0].values.dtype == dtype + ): + pass + else: + mgr = mgr.astype(dtype=dtype) + return mgr + + @final + @classmethod + def _from_mgr(cls, mgr: Manager, axes: list[Index]) -> Self: + """ + Construct a new object of this type from a Manager object and axes. + + Parameters + ---------- + mgr : Manager + Must have the same ndim as cls. + axes : list[Index] + + Notes + ----- + The axes must match mgr.axes, but are required for future-proofing + in the event that axes are refactored out of the Manager objects. + """ + obj = cls.__new__(cls) + NDFrame.__init__(obj, mgr) + return obj + + # ---------------------------------------------------------------------- + # attrs and flags + + @property + def attrs(self) -> dict[Hashable, Any]: + """ + Dictionary of global attributes of this dataset. + + .. warning:: + + attrs is experimental and may change without warning. + + See Also + -------- + DataFrame.flags : Global flags applying to this object. + + Notes + ----- + Many operations that create new datasets will copy ``attrs``. Copies + are always deep so that changing ``attrs`` will only affect the + present dataset. :func:`pandas.concat` and :func:`pandas.merge` will + only copy ``attrs`` if all input datasets have the same ``attrs``. + + Examples + -------- + For Series: + + >>> ser = pd.Series([1, 2, 3]) + >>> ser.attrs = {"A": [10, 20, 30]} + >>> ser.attrs + {'A': [10, 20, 30]} + + For DataFrame: + + >>> df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + >>> df.attrs = {"A": [10, 20, 30]} + >>> df.attrs + {'A': [10, 20, 30]} + """ + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping[Hashable, Any]) -> None: + self._attrs = dict(value) + + @final + @property + def flags(self) -> Flags: + """ + Get the properties associated with this pandas object. + + The available flags are + + * :attr:`Flags.allows_duplicate_labels` + + See Also + -------- + Flags : Flags that apply to pandas objects. + DataFrame.attrs : Global metadata applying to this dataset. + + Notes + ----- + "Flags" differ from "metadata". Flags reflect properties of the + pandas object (the Series or DataFrame). Metadata refer to properties + of the dataset, and should be stored in :attr:`DataFrame.attrs`. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2]}) + >>> df.flags + + + Flags can be get or set using ``.`` + + >>> df.flags.allows_duplicate_labels + True + >>> df.flags.allows_duplicate_labels = False + + Or by slicing with a key + + >>> df.flags["allows_duplicate_labels"] + False + >>> df.flags["allows_duplicate_labels"] = True + """ + return self._flags + + @final + def set_flags( + self, + *, + copy: bool | lib.NoDefault = lib.no_default, + allows_duplicate_labels: bool | None = None, + ) -> Self: + """ + Return a new object with updated flags. + + This method creates a shallow copy of the original object, preserving its + underlying data while modifying its global flags. In particular, it allows + you to update properties such as whether duplicate labels are permitted. This + behavior is especially useful in method chains, where one wishes to + adjust DataFrame or Series characteristics without altering the original object. + + Parameters + ---------- + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + allows_duplicate_labels : bool, optional + Whether the returned object allows duplicate labels. + + Returns + ------- + Series or DataFrame + The same type as the caller. + + See Also + -------- + DataFrame.attrs : Global metadata applying to this dataset. + DataFrame.flags : Global flags applying to this object. + + Notes + ----- + This method returns a new object that's a view on the same data + as the input. Mutating the input or the output values will be reflected + in the other. + + This method is intended to be used in method chains. + + "Flags" differ from "metadata". Flags reflect properties of the + pandas object (the Series or DataFrame). Metadata refer to properties + of the dataset, and should be stored in :attr:`DataFrame.attrs`. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2]}) + >>> df.flags.allows_duplicate_labels + True + >>> df2 = df.set_flags(allows_duplicate_labels=False) + >>> df2.flags.allows_duplicate_labels + False + """ + self._check_copy_deprecation(copy) + df = self.copy(deep=False) + if allows_duplicate_labels is not None: + df.flags["allows_duplicate_labels"] = allows_duplicate_labels + return df + + @final + @classmethod + def _validate_dtype(cls, dtype) -> DtypeObj | None: + """validate the passed dtype""" + if dtype is not None: + dtype = pandas_dtype(dtype) + + # a compound dtype + if dtype.kind == "V" and not isinstance(dtype, ExtensionDtype): + raise NotImplementedError( + "compound dtypes are not implemented " + f"in the {cls.__name__} constructor" + ) + + return dtype + + # ---------------------------------------------------------------------- + # Construction + + # error: Signature of "_constructor" incompatible with supertype "PandasObject" + @property + def _constructor(self) -> Callable[..., Self]: # type: ignore[override] + """ + Used when a manipulation result has the same dimensions as the + original. + """ + raise AbstractMethodError(self) + + # ---------------------------------------------------------------------- + # Axis + _AXIS_ORDERS: list[Literal["index", "columns"]] + _AXIS_TO_AXIS_NUMBER: dict[Axis, AxisInt] = {0: 0, "index": 0, "rows": 0} + _info_axis_number: int + _info_axis_name: Literal["index", "columns"] + _AXIS_LEN: int + + @final + def _construct_axes_dict( + self, axes: Sequence[Axis] | None = None, **kwargs: AxisInt + ) -> dict: + """Return an axes dictionary for myself.""" + d = {a: self._get_axis(a) for a in (axes or self._AXIS_ORDERS)} + # error: Argument 1 to "update" of "MutableMapping" has incompatible type + # "Dict[str, Any]"; expected "SupportsKeysAndGetItem[Union[int, str], Any]" + d.update(kwargs) # type: ignore[arg-type] + return d + + @final + @classmethod + def _get_axis_number(cls, axis: Axis) -> AxisInt: + try: + return cls._AXIS_TO_AXIS_NUMBER[axis] + except KeyError as err: + raise ValueError( + f"No axis named {axis} for object type {cls.__name__}" + ) from err + + @final + @classmethod + def _get_axis_name(cls, axis: Axis) -> Literal["index", "columns"]: + axis_number = cls._get_axis_number(axis) + return cls._AXIS_ORDERS[axis_number] + + @final + def _get_axis(self, axis: Axis) -> Index: + axis_number = self._get_axis_number(axis) + assert axis_number in {0, 1} + return self.index if axis_number == 0 else self.columns + + @final + @classmethod + def _get_block_manager_axis(cls, axis: Axis) -> AxisInt: + """Map the axis to the block_manager axis.""" + axis = cls._get_axis_number(axis) + ndim = cls._AXIS_LEN + if ndim == 2: + # i.e. DataFrame + return 1 - axis + return axis + + @final + def _get_axis_resolvers(self, axis: str) -> dict[str, Series | MultiIndex]: + # index or columns + axis_index = getattr(self, axis) + d = {} + prefix = axis[0] + + for i, name in enumerate(axis_index.names): + if name is not None: + key = level = name + else: + # prefix with 'i' or 'c' depending on the input axis + # e.g., you must do ilevel_0 for the 0th level of an unnamed + # multiiindex + key = f"{prefix}level_{i}" + level = i + + level_values = axis_index.get_level_values(level) + s = level_values.to_series() + s.index = axis_index + d[key] = s + + # put the index/columns itself in the dict + if isinstance(axis_index, MultiIndex): + dindex = axis_index + else: + dindex = axis_index.to_series() + + d[axis] = dindex + return d + + @final + def _get_index_resolvers(self) -> dict[Hashable, Series | MultiIndex]: + from pandas.core.computation.parsing import clean_column_name + + d: dict[str, Series | MultiIndex] = {} + for axis_name in self._AXIS_ORDERS: + d.update(self._get_axis_resolvers(axis_name)) + + return {clean_column_name(k): v for k, v in d.items() if not isinstance(k, int)} + + @final + def _get_cleaned_column_resolvers(self) -> dict[Hashable, Series]: + """ + Return the special character free column resolvers of a DataFrame. + + Column names with special characters are 'cleaned up' so that they can + be referred to by backtick quoting. + Used in :meth:`DataFrame.eval`. + """ + from pandas.core.computation.parsing import clean_column_name + from pandas.core.series import Series + + if isinstance(self, ABCSeries): + return {clean_column_name(self.name): self} + + dtypes = self.dtypes + return { + clean_column_name(k): Series( + v, copy=False, index=self.index, name=k, dtype=dtype + ).__finalize__(self) + for k, v, dtype in zip( + self.columns, + self._iter_column_arrays(), + dtypes, + strict=True, + ) + } + + @final + @property + def _info_axis(self) -> Index: + return getattr(self, self._info_axis_name) + + @property + def shape(self) -> tuple[int, ...]: + """ + Return a tuple of axis dimensions + """ + return tuple(len(self._get_axis(a)) for a in self._AXIS_ORDERS) + + @property + def axes(self) -> list[Index]: + """ + Return index label(s) of the internal NDFrame + """ + # we do it this way because if we have reversed axes, then + # the block manager shows then reversed + return [self._get_axis(a) for a in self._AXIS_ORDERS] + + @final + @property + def ndim(self) -> int: + """ + Return an int representing the number of axes / array dimensions. + + Return 1 if Series. Otherwise return 2 if DataFrame. + + See Also + -------- + numpy.ndarray.ndim : Number of array dimensions. + + Examples + -------- + >>> s = pd.Series({"a": 1, "b": 2, "c": 3}) + >>> s.ndim + 1 + + >>> df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + >>> df.ndim + 2 + """ + return self._mgr.ndim + + @final + @property + def size(self) -> int: + """ + Return an int representing the number of elements in this object. + + Return the number of rows if Series. Otherwise return the number of + rows times number of columns if DataFrame. + + See Also + -------- + numpy.ndarray.size : Number of elements in the array. + + Examples + -------- + >>> s = pd.Series({"a": 1, "b": 2, "c": 3}) + >>> s.size + 3 + + >>> df = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + >>> df.size + 4 + """ + + return int(np.prod(self.shape)) + + def set_axis( + self, + labels, + *, + axis: Axis = 0, + copy: bool | lib.NoDefault = lib.no_default, + ) -> Self: + """ + Assign desired index to given axis. + + Indexes for%(extended_summary_sub)s row labels can be changed by assigning + a list-like or Index. + + Parameters + ---------- + labels : list-like, Index + The values for the new index. + + axis : %(axes_single_arg)s, default 0 + The axis to update. The value 0 identifies the rows. For `Series` + this parameter is unused and defaults to 0. + + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + %(klass)s + An object of type %(klass)s. + + See Also + -------- + %(klass)s.rename_axis : Alter the name of the index%(see_also_sub)s. + """ + self._check_copy_deprecation(copy) + return self._set_axis_nocheck(labels, axis, inplace=False) + + @overload + def _set_axis_nocheck( + self, labels, axis: Axis, inplace: Literal[False] + ) -> Self: ... + + @overload + def _set_axis_nocheck(self, labels, axis: Axis, inplace: Literal[True]) -> None: ... + + @overload + def _set_axis_nocheck(self, labels, axis: Axis, inplace: bool) -> Self | None: ... + + @final + def _set_axis_nocheck(self, labels, axis: Axis, inplace: bool) -> Self | None: + if inplace: + setattr(self, self._get_axis_name(axis), labels) + return None + obj = self.copy(deep=False) + setattr(obj, obj._get_axis_name(axis), labels) + return obj + + @final + def _set_axis(self, axis: AxisInt, labels: AnyArrayLike | list) -> None: + """ + This is called from the cython code when we set the `index` attribute + directly, e.g. `series.index = [1, 2, 3]`. + """ + labels = ensure_index(labels) + self._mgr.set_axis(axis, labels) + + @final + def droplevel(self, level: IndexLabel, axis: Axis = 0) -> Self: + """ + Return Series/DataFrame with requested index / column level(s) removed. + + Parameters + ---------- + level : int, str, or list-like + If a string is given, must be the name of a level + If list-like, elements must be names or positional indexes + of levels. + + axis : {{0 or 'index', 1 or 'columns'}}, default 0 + Axis along which the level(s) is removed: + + * 0 or 'index': remove level(s) in column. + * 1 or 'columns': remove level(s) in row. + + For `Series` this parameter is unused and defaults to 0. + + Returns + ------- + Series/DataFrame + Series/DataFrame with requested index / column level(s) removed. + + See Also + -------- + DataFrame.replace : Replace values given in `to_replace` with `value`. + DataFrame.pivot : Return reshaped DataFrame organized by given + index / column values. + + Examples + -------- + >>> df = ( + ... pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + ... .set_index([0, 1]) + ... .rename_axis(["a", "b"]) + ... ) + + >>> df.columns = pd.MultiIndex.from_tuples( + ... [("c", "e"), ("d", "f")], names=["level_1", "level_2"] + ... ) + + >>> df + level_1 c d + level_2 e f + a b + 1 2 3 4 + 5 6 7 8 + 9 10 11 12 + + >>> df.droplevel("a") + level_1 c d + level_2 e f + b + 2 3 4 + 6 7 8 + 10 11 12 + + >>> df.droplevel("level_2", axis=1) + level_1 c d + a b + 1 2 3 4 + 5 6 7 8 + 9 10 11 12 + """ + labels = self._get_axis(axis) + new_labels = labels.droplevel(level) + return self.set_axis(new_labels, axis=axis) + + def pop(self, item: Hashable) -> Series | Any: + result = self[item] + del self[item] + + return result + + @final + def squeeze(self, axis: Axis | None = None) -> Scalar | Series | DataFrame: + """ + Squeeze 1 dimensional axis objects into scalars. + + Series or DataFrames with a single element are squeezed to a scalar. + DataFrames with a single column or a single row are squeezed to a + Series. Otherwise the object is unchanged. + + This method is most useful when you don't know if your + object is a Series or DataFrame, but you do know it has just a single + column. In that case you can safely call `squeeze` to ensure you have a + Series. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns', None}, default None + A specific axis to squeeze. By default, all length-1 axes are + squeezed. For `Series` this parameter is unused and defaults to `None`. + + Returns + ------- + DataFrame, Series, or scalar + The projection after squeezing `axis` or all the axes. + + See Also + -------- + Series.iloc : Integer-location based indexing for selecting scalars. + DataFrame.iloc : Integer-location based indexing for selecting Series. + Series.to_frame : Inverse of DataFrame.squeeze for a + single-column DataFrame. + + Examples + -------- + >>> primes = pd.Series([2, 3, 5, 7]) + + Slicing might produce a Series with a single value: + + >>> even_primes = primes[primes % 2 == 0] + >>> even_primes + 0 2 + dtype: int64 + + >>> even_primes.squeeze() + np.int64(2) + + Squeezing objects with more than one value in every axis does nothing: + + >>> odd_primes = primes[primes % 2 == 1] + >>> odd_primes + 1 3 + 2 5 + 3 7 + dtype: int64 + + >>> odd_primes.squeeze() + 1 3 + 2 5 + 3 7 + dtype: int64 + + Squeezing is even more effective when used with DataFrames. + + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"]) + >>> df + a b + 0 1 2 + 1 3 4 + + Slicing a single column will produce a DataFrame with the columns + having only one value: + + >>> df_a = df[["a"]] + >>> df_a + a + 0 1 + 1 3 + + So the columns can be squeezed down, resulting in a Series: + + >>> df_a.squeeze("columns") + 0 1 + 1 3 + Name: a, dtype: int64 + + Slicing a single row from a single column will produce a single + scalar DataFrame: + + >>> df_0a = df.loc[df.index < 1, ["a"]] + >>> df_0a + a + 0 1 + + Squeezing the rows produces a single scalar Series: + + >>> df_0a.squeeze("rows") + a 1 + Name: 0, dtype: int64 + + Squeezing all axes will project directly into a scalar: + + >>> df_0a.squeeze() + np.int64(1) + """ + axes = range(self._AXIS_LEN) if axis is None else (self._get_axis_number(axis),) + result = self.iloc[ + tuple( + 0 if i in axes and len(a) == 1 else slice(None) + for i, a in enumerate(self.axes) + ) + ] + if isinstance(result, NDFrame): + result = result.__finalize__(self, method="squeeze") + return result + + # ---------------------------------------------------------------------- + # Rename + + @overload + def _rename( + self, + mapper: Renamer | None = ..., + *, + index: Renamer | None = ..., + columns: Renamer | None = ..., + axis: Axis | None = ..., + inplace: Literal[False] = ..., + level: Level | None = ..., + errors: str = ..., + ) -> Self: ... + + @overload + def _rename( + self, + mapper: Renamer | None = ..., + *, + index: Renamer | None = ..., + columns: Renamer | None = ..., + axis: Axis | None = ..., + inplace: Literal[True], + level: Level | None = ..., + errors: str = ..., + ) -> None: ... + + @overload + def _rename( + self, + mapper: Renamer | None = ..., + *, + index: Renamer | None = ..., + columns: Renamer | None = ..., + axis: Axis | None = ..., + inplace: bool, + level: Level | None = ..., + errors: str = ..., + ) -> Self | None: ... + + @final + def _rename( + self, + mapper: Renamer | None = None, + *, + index: Renamer | None = None, + columns: Renamer | None = None, + axis: Axis | None = None, + inplace: bool = False, + level: Level | None = None, + errors: str = "ignore", + ) -> Self | None: + # called by Series.rename and DataFrame.rename + + if mapper is None and index is None and columns is None: + raise TypeError("must pass an index to rename") + + if index is not None or columns is not None: + if axis is not None: + raise TypeError( + "Cannot specify both 'axis' and any of 'index' or 'columns'" + ) + if mapper is not None: + raise TypeError( + "Cannot specify both 'mapper' and any of 'index' or 'columns'" + ) + # use the mapper argument + elif axis and self._get_axis_number(axis) == 1: + columns = mapper + else: + index = mapper + + self._check_inplace_and_allows_duplicate_labels(inplace) + result = self if inplace else self.copy(deep=False) + + for axis_no, replacements in enumerate((index, columns)): + if replacements is None: + continue + + ax = self._get_axis(axis_no) + f = common.get_rename_function(replacements) + + if level is not None: + level = ax._get_level_number(level) + + if isinstance(replacements, ABCSeries) and not replacements.index.is_unique: + # GH#58621 + raise ValueError("Cannot rename with a Series with non-unique index.") + + # GH 13473 + if not callable(replacements): + if ax._is_multi and level is not None: + indexer = ax.get_level_values(level).get_indexer_for(replacements) + else: + indexer = ax.get_indexer_for(replacements) + + if errors == "raise" and len(indexer[indexer == -1]): + missing_labels = [ + label + for index, label in enumerate(replacements) + if indexer[index] == -1 + ] + raise KeyError(f"{missing_labels} not found in axis") + + new_index = ax._transform_index(f, level=level) + result._set_axis_nocheck(new_index, axis=axis_no, inplace=True) + + if inplace: + self._update_inplace(result) + return None + else: + return result.__finalize__(self, method="rename") + + @overload + def rename_axis( + self, + mapper: IndexLabel | lib.NoDefault = ..., + *, + index=..., + columns=..., + axis: Axis = ..., + copy: bool | lib.NoDefault = lib.no_default, + inplace: Literal[False] = ..., + ) -> Self: ... + + @overload + def rename_axis( + self, + mapper: IndexLabel | lib.NoDefault = ..., + *, + index=..., + columns=..., + axis: Axis = ..., + copy: bool | lib.NoDefault = lib.no_default, + inplace: Literal[True], + ) -> None: ... + + @overload + def rename_axis( + self, + mapper: IndexLabel | lib.NoDefault = ..., + *, + index=..., + columns=..., + axis: Axis = ..., + copy: bool | lib.NoDefault = lib.no_default, + inplace: bool = ..., + ) -> Self | None: ... + + def rename_axis( + self, + mapper: IndexLabel | lib.NoDefault = lib.no_default, + *, + index=lib.no_default, + columns=lib.no_default, + axis: Axis = 0, + copy: bool | lib.NoDefault = lib.no_default, + inplace: bool = False, + ) -> Self | None: + """ + Set the name of the axis for the index or columns. + + Parameters + ---------- + mapper : scalar, list-like, optional + Value to set the axis name attribute. + + Use either ``mapper`` and ``axis`` to + specify the axis to target with ``mapper``, or ``index`` + and/or ``columns``. + index : scalar, list-like, dict-like or function, optional + A scalar, list-like, dict-like or functions transformations to + apply to that axis' values. + columns : scalar, list-like, dict-like or function, optional + A scalar, list-like, dict-like or functions transformations to + apply to that axis' values. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to rename. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + inplace : bool, default False + Modifies the object directly, instead of creating a new Series + or DataFrame. + + Returns + ------- + DataFrame, or None + The same type as the caller or None if ``inplace=True``. + + See Also + -------- + Series.rename : Alter Series index labels or name. + DataFrame.rename : Alter DataFrame index labels or name. + Index.rename : Set new names on index. + + Notes + ----- + ``DataFrame.rename_axis`` supports two calling conventions + + * ``(index=index_mapper, columns=columns_mapper, ...)`` + * ``(mapper, axis={'index', 'columns'}, ...)`` + + The first calling convention will only modify the names of + the index and/or the names of the Index object that is the columns. + In this case, the parameter ``copy`` is ignored. + + The second calling convention will modify the names of the + corresponding index if mapper is a list or a scalar. + However, if mapper is dict-like or a function, it will use the + deprecated behavior of modifying the axis *labels*. + + We *highly* recommend using keyword arguments to clarify your + intent. + + Examples + -------- + **DataFrame** + + >>> df = pd.DataFrame( + ... {"num_legs": [4, 4, 2], "num_arms": [0, 0, 2]}, ["dog", "cat", "monkey"] + ... ) + >>> df + num_legs num_arms + dog 4 0 + cat 4 0 + monkey 2 2 + >>> df = df.rename_axis("animal") + >>> df + num_legs num_arms + animal + dog 4 0 + cat 4 0 + monkey 2 2 + >>> df = df.rename_axis("limbs", axis="columns") + >>> df + limbs num_legs num_arms + animal + dog 4 0 + cat 4 0 + monkey 2 2 + + **MultiIndex** + + >>> df.index = pd.MultiIndex.from_product( + ... [["mammal"], ["dog", "cat", "monkey"]], names=["type", "name"] + ... ) + >>> df + limbs num_legs num_arms + type name + mammal dog 4 0 + cat 4 0 + monkey 2 2 + + >>> df.rename_axis(index={"type": "class"}) + limbs num_legs num_arms + class name + mammal dog 4 0 + cat 4 0 + monkey 2 2 + + >>> df.rename_axis(columns=str.upper) + LIMBS num_legs num_arms + type name + mammal dog 4 0 + cat 4 0 + monkey 2 2 + """ + self._check_copy_deprecation(copy) + axes = {"index": index, "columns": columns} + + if axis is not None: + axis = self._get_axis_number(axis) + + inplace = validate_bool_kwarg(inplace, "inplace") + + if mapper is not lib.no_default: + # Use v0.23 behavior if a scalar or list + non_mapper = is_scalar(mapper) or ( + is_list_like(mapper) and not is_dict_like(mapper) + ) + if non_mapper: + return self._set_axis_name(mapper, axis=axis, inplace=inplace) + else: + raise ValueError("Use `.rename` to alter labels with a mapper.") + else: + # Use new behavior. Means that index and/or columns + # is specified + result = self if inplace else self.copy(deep=False) + + for axis in range(self._AXIS_LEN): + v = axes.get(self._get_axis_name(axis)) + if v is lib.no_default: + continue + non_mapper = is_scalar(v) or (is_list_like(v) and not is_dict_like(v)) + if non_mapper: + newnames = v + else: + f = common.get_rename_function(v) + curnames = self._get_axis(axis).names + newnames = [f(name) for name in curnames] + result._set_axis_name(newnames, axis=axis, inplace=True) + if not inplace: + return result + return None + + @overload + def _set_axis_name( + self, name, axis: Axis = ..., *, inplace: Literal[False] = ... + ) -> Self: ... + + @overload + def _set_axis_name( + self, name, axis: Axis = ..., *, inplace: Literal[True] + ) -> None: ... + + @overload + def _set_axis_name( + self, name, axis: Axis = ..., *, inplace: bool + ) -> Self | None: ... + + @final + def _set_axis_name( + self, name, axis: Axis = 0, *, inplace: bool = False + ) -> Self | None: + """ + Set the name(s) of the axis. + + Parameters + ---------- + name : str or list of str + Name(s) to set. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis to set the label. The value 0 or 'index' specifies index, + and the value 1 or 'columns' specifies columns. + inplace : bool, default False + If `True`, do operation inplace and return None. + + Returns + ------- + Series, DataFrame, or None + The same type as the caller or `None` if `inplace` is `True`. + + See Also + -------- + DataFrame.rename : Alter the axis labels of :class:`DataFrame`. + Series.rename : Alter the index labels or set the index name + of :class:`Series`. + Index.rename : Set the name of :class:`Index` or :class:`MultiIndex`. + + Examples + -------- + >>> df = pd.DataFrame({"num_legs": [4, 4, 2]}, ["dog", "cat", "monkey"]) + >>> df + num_legs + dog 4 + cat 4 + monkey 2 + >>> df._set_axis_name("animal") + num_legs + animal + dog 4 + cat 4 + monkey 2 + >>> df.index = pd.MultiIndex.from_product( + ... [["mammal"], ["dog", "cat", "monkey"]] + ... ) + >>> df._set_axis_name(["type", "name"]) + num_legs + type name + mammal dog 4 + cat 4 + monkey 2 + """ + axis = self._get_axis_number(axis) + idx = self._get_axis(axis).set_names(name) + + inplace = validate_bool_kwarg(inplace, "inplace") + renamed = self if inplace else self.copy(deep=False) + if axis == 0: + renamed.index = idx + else: + renamed.columns = idx + + if not inplace: + return renamed + return None + + # ---------------------------------------------------------------------- + # Comparison Methods + + @final + def _indexed_same(self, other) -> bool: + return all( + self._get_axis(a).equals(other._get_axis(a)) for a in self._AXIS_ORDERS + ) + + @final + def equals(self, other: object) -> bool: + """ + Test whether two objects contain the same elements. + + This function allows two Series or DataFrames to be compared against + each other to see if they have the same shape and elements. NaNs in + the same location are considered equal. + + The row/column index do not need to have the same type, as long + as the values are considered equal. Corresponding columns and + index must be of the same dtype. + + Parameters + ---------- + other : Series or DataFrame + The other Series or DataFrame to be compared with the first. + + Returns + ------- + bool + True if all elements are the same in both objects, False + otherwise. + + See Also + -------- + Series.eq : Compare two Series objects of the same length + and return a Series where each element is True if the element + in each Series is equal, False otherwise. + DataFrame.eq : Compare two DataFrame objects of the same shape and + return a DataFrame where each element is True if the respective + element in each DataFrame is equal, False otherwise. + testing.assert_series_equal : Raises an AssertionError if left and + right are not equal. Provides an easy interface to ignore + inequality in dtypes, indexes and precision among others. + testing.assert_frame_equal : Like assert_series_equal, but targets + DataFrames. + numpy.array_equal : Return True if two arrays have the same shape + and elements, False otherwise. + + Examples + -------- + >>> df = pd.DataFrame({1: [10], 2: [20]}) + >>> df + 1 2 + 0 10 20 + + DataFrames df and exactly_equal have the same types and values for + their elements and column labels, which will return True. + + >>> exactly_equal = pd.DataFrame({1: [10], 2: [20]}) + >>> exactly_equal + 1 2 + 0 10 20 + >>> df.equals(exactly_equal) + True + + DataFrames df and different_column_type have the same element + types and values, but have different types for the column labels, + which will still return True. + + >>> different_column_type = pd.DataFrame({1.0: [10], 2.0: [20]}) + >>> different_column_type + 1.0 2.0 + 0 10 20 + >>> df.equals(different_column_type) + True + + DataFrames df and different_data_type have different types for the + same values for their elements, and will return False even though + their column labels are the same values and types. + + >>> different_data_type = pd.DataFrame({1: [10.0], 2: [20.0]}) + >>> different_data_type + 1 2 + 0 10.0 20.0 + >>> df.equals(different_data_type) + False + + DataFrames with NaN in the same locations compare equal. + + >>> df_nan1 = pd.DataFrame({"a": [1, np.nan], "b": [3, np.nan]}) + >>> df_nan2 = pd.DataFrame({"a": [1, np.nan], "b": [3, np.nan]}) + >>> df_nan1.equals(df_nan2) + True + + If the NaN values are not in the same locations, they compare unequal. + + >>> df_nan3 = pd.DataFrame({"a": [1, np.nan], "b": [3, 4]}) + >>> df_nan1.equals(df_nan3) + False + """ + if not (isinstance(other, type(self)) or isinstance(self, type(other))): + return False + other = cast(NDFrame, other) + return self._mgr.equals(other._mgr) + + # ------------------------------------------------------------------------- + # Unary Methods + + @final + def __neg__(self) -> Self: + def blk_func(values: ArrayLike): + if is_bool_dtype(values.dtype): + # error: Argument 1 to "inv" has incompatible type "Union + # [ExtensionArray, ndarray[Any, Any]]"; expected + # "_SupportsInversion[ndarray[Any, dtype[bool_]]]" + return operator.inv(values) # type: ignore[arg-type] + else: + # error: Argument 1 to "neg" has incompatible type "Union + # [ExtensionArray, ndarray[Any, Any]]"; expected + # "_SupportsNeg[ndarray[Any, dtype[Any]]]" + return operator.neg(values) # type: ignore[arg-type] + + new_data = self._mgr.apply(blk_func) + res = self._constructor_from_mgr(new_data, axes=new_data.axes) + return res.__finalize__(self, method="__neg__") + + @final + def __pos__(self) -> Self: + def blk_func(values: ArrayLike): + if is_bool_dtype(values.dtype): + return values.copy() + else: + # error: Argument 1 to "pos" has incompatible type "Union + # [ExtensionArray, ndarray[Any, Any]]"; expected + # "_SupportsPos[ndarray[Any, dtype[Any]]]" + return operator.pos(values) # type: ignore[arg-type] + + new_data = self._mgr.apply(blk_func) + res = self._constructor_from_mgr(new_data, axes=new_data.axes) + return res.__finalize__(self, method="__pos__") + + @final + def __invert__(self) -> Self: + if not self.size: + # inv fails with 0 len + return self.copy(deep=False) + + new_data = self._mgr.apply(operator.invert) + res = self._constructor_from_mgr(new_data, axes=new_data.axes) + return res.__finalize__(self, method="__invert__") + + @final + def __bool__(self) -> NoReturn: + raise ValueError( + f"The truth value of a {type(self).__name__} is ambiguous. " + "Use a.empty, a.bool(), a.item(), a.any() or a.all()." + ) + + @final + def abs(self) -> Self: + """ + Return a Series/DataFrame with absolute numeric value of each element. + + This function only applies to elements that are all numeric. + + Returns + ------- + abs + Series/DataFrame containing the absolute value of each element. + + See Also + -------- + numpy.absolute : Calculate the absolute value element-wise. + + Notes + ----- + For ``complex`` inputs, ``1.2 + 1j``, the absolute value is + :math:`\\sqrt{ a^2 + b^2 }`. + + Examples + -------- + Absolute numeric values in a Series. + + >>> s = pd.Series([-1.10, 2, -3.33, 4]) + >>> s.abs() + 0 1.10 + 1 2.00 + 2 3.33 + 3 4.00 + dtype: float64 + + Absolute numeric values in a Series with complex numbers. + + >>> s = pd.Series([1.2 + 1j]) + >>> s.abs() + 0 1.56205 + dtype: float64 + + Absolute numeric values in a Series with a Timedelta element. + + >>> s = pd.Series([pd.Timedelta("1 days")]) + >>> s.abs() + 0 1 days + dtype: timedelta64[us] + + Select rows with data closest to certain value using argsort (from + `StackOverflow `__). + + >>> df = pd.DataFrame( + ... {"a": [4, 5, 6, 7], "b": [10, 20, 30, 40], "c": [100, 50, -30, -50]} + ... ) + >>> df + a b c + 0 4 10 100 + 1 5 20 50 + 2 6 30 -30 + 3 7 40 -50 + >>> df.loc[(df.c - 43).abs().argsort()] + a b c + 1 5 20 50 + 0 4 10 100 + 2 6 30 -30 + 3 7 40 -50 + """ + res_mgr = self._mgr.apply(np.abs) + return self._constructor_from_mgr(res_mgr, axes=res_mgr.axes).__finalize__( + self, name="abs" + ) + + @final + def __abs__(self) -> Self: + return self.abs() + + @final + def __round__(self, decimals: int = 0) -> Self: + return self.round(decimals).__finalize__(self, method="__round__") + + # ------------------------------------------------------------------------- + # Label or Level Combination Helpers + # + # A collection of helper methods for DataFrame/Series operations that + # accept a combination of column/index labels and levels. All such + # operations should utilize/extend these methods when possible so that we + # have consistent precedence and validation logic throughout the library. + + @final + def _is_level_reference(self, key: Level, axis: Axis = 0) -> bool: + """ + Test whether a key is a level reference for a given axis. + + To be considered a level reference, `key` must be a string that: + - (axis=0): Matches the name of an index level and does NOT match + a column label. + - (axis=1): Matches the name of a column level and does NOT match + an index label. + + Parameters + ---------- + key : Hashable + Potential level name for the given axis + axis : int, default 0 + Axis that levels are associated with (0 for index, 1 for columns) + + Returns + ------- + is_level : bool + """ + axis_int = self._get_axis_number(axis) + + return ( + key is not None + and is_hashable(key) + and key in self.axes[axis_int].names + and not self._is_label_reference(key, axis=axis_int) + ) + + @final + def _is_label_reference(self, key: Level, axis: Axis = 0) -> bool: + """ + Test whether a key is a label reference for a given axis. + + To be considered a label reference, `key` must be a string that: + - (axis=0): Matches a column label + - (axis=1): Matches an index label + + Parameters + ---------- + key : Hashable + Potential label name, i.e. Index entry. + axis : int, default 0 + Axis perpendicular to the axis that labels are associated with + (0 means search for column labels, 1 means search for index labels) + + Returns + ------- + is_label: bool + """ + axis_int = self._get_axis_number(axis) + other_axes = (ax for ax in range(self._AXIS_LEN) if ax != axis_int) + + return is_hashable(key) and any(key in self.axes[ax] for ax in other_axes) + + @final + def _is_label_or_level_reference(self, key: Level, axis: AxisInt = 0) -> bool: + """ + Test whether a key is a label or level reference for a given axis. + + To be considered either a label or a level reference, `key` must be a + string that: + - (axis=0): Matches a column label or an index level + - (axis=1): Matches an index label or a column level + + Parameters + ---------- + key : Hashable + Potential label or level name + axis : int, default 0 + Axis that levels are associated with (0 for index, 1 for columns) + + Returns + ------- + bool + """ + return self._is_level_reference(key, axis=axis) or self._is_label_reference( + key, axis=axis + ) + + @final + def _check_label_or_level_ambiguity(self, key: Level, axis: Axis = 0) -> None: + """ + Check whether `key` is ambiguous. + + By ambiguous, we mean that it matches both a level of the input + `axis` and a label of the other axis. + + Parameters + ---------- + key : Hashable + Label or level name. + axis : int, default 0 + Axis that levels are associated with (0 for index, 1 for columns). + + Raises + ------ + ValueError: `key` is ambiguous + """ + + axis_int = self._get_axis_number(axis) + other_axes = (ax for ax in range(self._AXIS_LEN) if ax != axis_int) + + if ( + key is not None + and is_hashable(key) + and key in self.axes[axis_int].names + and any(key in self.axes[ax] for ax in other_axes) + ): + # Build an informative and grammatical warning + level_article, level_type = ( + ("an", "index") if axis_int == 0 else ("a", "column") + ) + + label_article, label_type = ( + ("a", "column") if axis_int == 0 else ("an", "index") + ) + + msg = ( + f"'{key}' is both {level_article} {level_type} level and " + f"{label_article} {label_type} label, which is ambiguous." + ) + raise ValueError(msg) + + @final + def _get_label_or_level_values(self, key: Level, axis: AxisInt = 0) -> ArrayLike: + """ + Return a 1-D array of values associated with `key`, a label or level + from the given `axis`. + + Retrieval logic: + - (axis=0): Return column values if `key` matches a column label. + Otherwise return index level values if `key` matches an index + level. + - (axis=1): Return row values if `key` matches an index label. + Otherwise return column level values if 'key' matches a column + level + + Parameters + ---------- + key : Hashable + Label or level name. + axis : int, default 0 + Axis that levels are associated with (0 for index, 1 for columns) + + Returns + ------- + np.ndarray or ExtensionArray + + Raises + ------ + KeyError + if `key` matches neither a label nor a level + ValueError + if `key` matches multiple labels + """ + axis = self._get_axis_number(axis) + first_other_axes = next( + (ax for ax in range(self._AXIS_LEN) if ax != axis), None + ) + + if self._is_label_reference(key, axis=axis): + self._check_label_or_level_ambiguity(key, axis=axis) + if first_other_axes is None: + raise ValueError("axis matched all axes") + values = self.xs(key, axis=first_other_axes)._values + elif self._is_level_reference(key, axis=axis): + values = self.axes[axis].get_level_values(key)._values + else: + raise KeyError(key) + + # Check for duplicates + if values.ndim > 1: + if first_other_axes is not None and isinstance( + self._get_axis(first_other_axes), MultiIndex + ): + multi_message = ( + "\n" + "For a multi-index, the label must be a " + "tuple with elements corresponding to each level." + ) + else: + multi_message = "" + + label_axis_name = "column" if axis == 0 else "index" + raise ValueError( + f"The {label_axis_name} label '{key}' is not unique.{multi_message}" + ) + + return values + + @final + def _drop_labels_or_levels(self, keys, axis: AxisInt = 0): + """ + Drop labels and/or levels for the given `axis`. + + For each key in `keys`: + - (axis=0): If key matches a column label then drop the column. + Otherwise if key matches an index level then drop the level. + - (axis=1): If key matches an index label then drop the row. + Otherwise if key matches a column level then drop the level. + + Parameters + ---------- + keys : str or list of str + labels or levels to drop + axis : int, default 0 + Axis that levels are associated with (0 for index, 1 for columns) + + Returns + ------- + dropped: DataFrame + + Raises + ------ + ValueError + if any `keys` match neither a label nor a level + """ + axis = self._get_axis_number(axis) + + # Validate keys + keys = common.maybe_make_list(keys) + invalid_keys = [ + k for k in keys if not self._is_label_or_level_reference(k, axis=axis) + ] + + if invalid_keys: + raise ValueError( + "The following keys are not valid labels or " + f"levels for axis {axis}: {invalid_keys}" + ) + + # Compute levels and labels to drop + levels_to_drop = [k for k in keys if self._is_level_reference(k, axis=axis)] + + labels_to_drop = [k for k in keys if not self._is_level_reference(k, axis=axis)] + + # Perform copy upfront and then use inplace operations below. + # This ensures that we always perform exactly one copy. + # ``copy`` and/or ``inplace`` options could be added in the future. + dropped = self.copy(deep=False) + + if axis == 0: + # Handle dropping index levels + if levels_to_drop: + dropped.reset_index(levels_to_drop, drop=True, inplace=True) + + # Handle dropping columns labels + if labels_to_drop: + dropped.drop(labels_to_drop, axis=1, inplace=True) + else: + # Handle dropping column levels + if levels_to_drop: + if isinstance(dropped.columns, MultiIndex): + # Drop the specified levels from the MultiIndex + dropped.columns = dropped.columns.droplevel(levels_to_drop) + else: + # Drop the last level of Index by replacing with + # a RangeIndex + dropped.columns = default_index(dropped.columns.size) + + # Handle dropping index labels + if labels_to_drop: + dropped.drop(labels_to_drop, axis=0, inplace=True) + + return dropped + + # ---------------------------------------------------------------------- + # Iteration + + # https://github.com/python/typeshed/issues/2148#issuecomment-520783318 + # Incompatible types in assignment (expression has type "None", base class + # "object" defined the type as "Callable[[object], int]") + __hash__: ClassVar[None] # type: ignore[assignment] + + def __iter__(self) -> Iterator: + """ + Iterate over info axis. + + Returns + ------- + iterator + Info axis as iterator. + + See Also + -------- + DataFrame.items : Iterate over (column name, Series) pairs. + DataFrame.itertuples : Iterate over DataFrame rows as namedtuples. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + >>> for x in df: + ... print(x) + A + B + """ + return iter(self._info_axis) + + # can we get a better explanation of this? + def keys(self) -> Index: + """ + Get the 'info axis' (see Indexing for more). + + This is index for Series, columns for DataFrame. + + Returns + ------- + Index + Info axis. + + See Also + -------- + DataFrame.index : The index (row labels) of the DataFrame. + DataFrame.columns: The column labels of the DataFrame. + + Examples + -------- + >>> d = pd.DataFrame( + ... data={"A": [1, 2, 3], "B": [0, 4, 8]}, index=["a", "b", "c"] + ... ) + >>> d + A B + a 1 0 + b 2 4 + c 3 8 + >>> d.keys() + Index(['A', 'B'], dtype='str') + """ + return self._info_axis + + def items(self): + """ + Iterate over (label, values) on info axis + + This is index for Series and columns for DataFrame. + + Returns + ------- + Generator + """ + for h in self._info_axis: + yield h, self[h] + + def __len__(self) -> int: + """Returns length of info axis""" + return len(self._info_axis) + + @final + def __contains__(self, key) -> bool: + """True if the key is in the info axis""" + return key in self._info_axis + + @property + def empty(self) -> bool: + """ + Indicator whether Series/DataFrame is empty. + + True if Series/DataFrame is entirely empty (no items), meaning any of the + axes are of length 0. + + Returns + ------- + bool + If Series/DataFrame is empty, return True, if not return False. + + See Also + -------- + Series.dropna : Return series without null values. + DataFrame.dropna : Return DataFrame with labels on given axis omitted + where (all or any) data are missing. + + Notes + ----- + If Series/DataFrame contains only NaNs, it is still not considered empty. See + the example below. + + Examples + -------- + An example of an actual empty DataFrame. Notice the index is empty: + + >>> df_empty = pd.DataFrame({"A": []}) + >>> df_empty + Empty DataFrame + Columns: [A] + Index: [] + >>> df_empty.empty + True + + If we only have NaNs in our DataFrame, it is not considered empty! We + will need to drop the NaNs to make the DataFrame empty: + + >>> df = pd.DataFrame({"A": [np.nan]}) + >>> df + A + 0 NaN + >>> df.empty + False + >>> df.dropna().empty + True + + >>> ser_empty = pd.Series({"A": []}) + >>> ser_empty + A [] + dtype: object + >>> ser_empty.empty + False + >>> ser_empty = pd.Series() + >>> ser_empty.empty + True + """ + return any(len(self._get_axis(a)) == 0 for a in self._AXIS_ORDERS) + + # ---------------------------------------------------------------------- + # Array Interface + + # This is also set in IndexOpsMixin + # GH#23114 Ensure ndarray.__op__(DataFrame) returns NotImplemented + __array_priority__: int = 1000 + + def __array__( + self, dtype: npt.DTypeLike | None = None, copy: bool | None = None + ) -> np.ndarray: + if copy is False and not self._mgr.is_single_block and not self.empty: + # check this manually, otherwise ._values will already return a copy + # and np.array(values, copy=False) will not raise an error + raise ValueError( + "Unable to avoid copy while creating an array as requested." + ) + values = self._values + if copy is None: + # Note: branch avoids `copy=None` for NumPy 1.x support + arr = np.asarray(values, dtype=dtype) + else: + arr = np.array(values, dtype=dtype, copy=copy) + + if ( + copy is not True + and astype_is_view(values.dtype, arr.dtype) + and self._mgr.is_single_block + ): + # Check if both conversions can be done without a copy + if astype_is_view(self.dtypes.iloc[0], values.dtype) and astype_is_view( + values.dtype, arr.dtype + ): + arr = arr.view() + arr.flags.writeable = False + return arr + + @final + def __array_ufunc__( + self, ufunc: np.ufunc, method: str, *inputs: Any, **kwargs: Any + ): + return arraylike.array_ufunc(self, ufunc, method, *inputs, **kwargs) + + # ---------------------------------------------------------------------- + # Picklability + + @final + def __getstate__(self) -> dict[str, Any]: + meta = {k: getattr(self, k, None) for k in self._metadata} + return { + "_mgr": self._mgr, + "_typ": self._typ, + "_metadata": self._metadata, + "attrs": self.attrs, + "_flags": {k: self.flags[k] for k in self.flags._keys}, + **meta, + } + + @final + def __setstate__(self, state) -> None: + if isinstance(state, BlockManager): + self._mgr = state + elif isinstance(state, dict): + if "_data" in state and "_mgr" not in state: + # compat for older pickles + state["_mgr"] = state.pop("_data") + typ = state.get("_typ") + if typ is not None: + attrs = state.get("_attrs", {}) + if attrs is None: # should not happen, but better be on the safe side + attrs = {} + object.__setattr__(self, "_attrs", attrs) + flags = state.get("_flags", {"allows_duplicate_labels": True}) + object.__setattr__(self, "_flags", Flags(self, **flags)) + + # set in the order of internal names + # to avoid definitional recursion + # e.g. say fill_value needing _mgr to be + # defined + meta = set(self._internal_names + self._metadata) + for k in meta: + if k in state and k != "_flags": + v = state[k] + object.__setattr__(self, k, v) + + for k, v in state.items(): + if k not in meta: + object.__setattr__(self, k, v) + + else: + raise NotImplementedError("Pre-0.12 pickles are no longer supported") + elif len(state) == 2: + raise NotImplementedError("Pre-0.12 pickles are no longer supported") + + # ---------------------------------------------------------------------- + # Rendering Methods + + def __repr__(self) -> str: + # string representation based upon iterating over self + # (since, by definition, `PandasContainers` are iterable) + prepr = f"[{','.join(map(pprint_thing, self))}]" + return f"{type(self).__name__}({prepr})" + + @final + def _repr_latex_(self): + """ + Returns a LaTeX representation for a particular object. + Mainly for use with nbconvert (jupyter notebook conversion to pdf). + """ + if config.get_option("styler.render.repr") == "latex": + return self.to_latex() + else: + return None + + @final + def _repr_data_resource_(self): + """ + Not a real Jupyter special repr method, but we use the same + naming convention. + """ + if config.get_option("display.html.table_schema"): + data = self.head(config.get_option("display.max_rows")) + + as_json = data.to_json(orient="table") + as_json = cast(str, as_json) + return loads(as_json, object_pairs_hook=collections.OrderedDict) + + # ---------------------------------------------------------------------- + # I/O Methods + + @final + def to_excel( + self, + excel_writer: FilePath | WriteExcelBuffer | ExcelWriter, + *, + sheet_name: str = "Sheet1", + na_rep: str = "", + float_format: str | None = None, + columns: Sequence[Hashable] | None = None, + header: Sequence[Hashable] | bool = True, + index: bool = True, + index_label: IndexLabel | None = None, + startrow: int = 0, + startcol: int = 0, + engine: Literal["openpyxl", "xlsxwriter"] | None = None, + merge_cells: bool = True, + inf_rep: str = "inf", + freeze_panes: tuple[int, int] | None = None, + storage_options: StorageOptions | None = None, + engine_kwargs: dict[str, Any] | None = None, + autofilter: bool = False, + ) -> None: + """ + Write object to an Excel sheet. + + To write a single object to an Excel .xlsx file it is only necessary to + specify a target file name. To write to multiple sheets it is necessary to + create an `ExcelWriter` object with a target file name, and specify a sheet + in the file to write to. + + Multiple sheets may be written to by specifying unique `sheet_name`. + With all data written to the file it is necessary to save the changes. + Note that creating an `ExcelWriter` object with a file name that already exists + will overwrite the existing file because the default mode is write. + + Parameters + ---------- + excel_writer : path-like, file-like, or ExcelWriter object + File path or existing ExcelWriter. + sheet_name : str, default 'Sheet1' + Name of sheet which will contain DataFrame. + na_rep : str, default '' + Missing data representation. + float_format : str, optional + Format string for floating point numbers. For example + ``float_format="%.2f"`` will format 0.1234 to 0.12. + columns : sequence or list of str, optional + Columns to write. + header : bool or list of str, default True + Write out the column names. If a list of string is given it is + assumed to be aliases for the column names. + index : bool, default True + Write row names (index). + index_label : str or sequence, optional + Column label for index column(s) if desired. If not specified, and + `header` and `index` are True, then the index names are used. A + sequence should be given if the DataFrame uses MultiIndex. + startrow : int, default 0 + Upper left cell row to dump data frame. + startcol : int, default 0 + Upper left cell column to dump data frame. + engine : str, optional + Write engine to use, 'openpyxl' or 'xlsxwriter'. You can also set this + via the options ``io.excel.xlsx.writer`` or + ``io.excel.xlsm.writer``. + merge_cells : bool or 'columns', default False + If True, write MultiIndex index and columns as merged cells. + If 'columns', merge MultiIndex column cells only. + inf_rep : str, default 'inf' + Representation for infinity (there is no native representation for + infinity in Excel). + freeze_panes : tuple of int (length 2), optional + Specifies the one-based bottommost row and rightmost column that + is to be frozen. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + engine_kwargs : dict, optional + Arbitrary keyword arguments passed to excel engine. + autofilter : bool, default False + If True, add automatic filters to all columns. + + See Also + -------- + to_csv : Write DataFrame to a comma-separated values (csv) file. + ExcelWriter : Class for writing DataFrame objects into excel sheets. + read_excel : Read an Excel file into a pandas DataFrame. + read_csv : Read a comma-separated values (csv) file into DataFrame. + io.formats.style.Styler.to_excel : Add styles to Excel sheet. + + Notes + ----- + For compatibility with :meth:`~DataFrame.to_csv`, + to_excel serializes lists and dicts to strings before writing. + + Once a workbook has been saved it is not possible to write further + data without rewriting the whole workbook. + + pandas will check the number of rows, columns, + and cell character count does not exceed Excel's limitations. + All other limitations must be checked by the user. + + Examples + -------- + + Create, write to and save a workbook: + + >>> df1 = pd.DataFrame( + ... [["a", "b"], ["c", "d"]], + ... index=["row 1", "row 2"], + ... columns=["col 1", "col 2"], + ... ) + >>> df1.to_excel("output.xlsx") # doctest: +SKIP + + To specify the sheet name: + + >>> df1.to_excel("output.xlsx", sheet_name="Sheet_name_1") # doctest: +SKIP + + If you wish to write to more than one sheet in the workbook, it is + necessary to specify an ExcelWriter object: + + >>> df2 = df1.copy() + >>> with pd.ExcelWriter("output.xlsx") as writer: # doctest: +SKIP + ... df1.to_excel(writer, sheet_name="Sheet_name_1") + ... df2.to_excel(writer, sheet_name="Sheet_name_2") + + ExcelWriter can also be used to append to an existing Excel file: + + >>> with pd.ExcelWriter("output.xlsx", mode="a") as writer: # doctest: +SKIP + ... df1.to_excel(writer, sheet_name="Sheet_name_3") + + To set the library that is used to write the Excel file, + you can pass the `engine` keyword (the default engine is + automatically chosen depending on the file extension): + + >>> df1.to_excel("output1.xlsx", engine="xlsxwriter") # doctest: +SKIP + """ + if engine_kwargs is None: + engine_kwargs = {} + + df = self if isinstance(self, ABCDataFrame) else self.to_frame() + + from pandas.io.formats.excel import ExcelFormatter + + formatter = ExcelFormatter( + df, + na_rep=na_rep, + cols=columns, + header=header, + float_format=float_format, + index=index, + index_label=index_label, + merge_cells=merge_cells, + inf_rep=inf_rep, + autofilter=autofilter, + ) + formatter.write( + excel_writer, + sheet_name=sheet_name, + startrow=startrow, + startcol=startcol, + freeze_panes=freeze_panes, + engine=engine, + storage_options=storage_options, + engine_kwargs=engine_kwargs, + ) + + @final + def to_json( + self, + path_or_buf: FilePath | WriteBuffer[bytes] | WriteBuffer[str] | None = None, + *, + orient: Literal["split", "records", "index", "table", "columns", "values"] + | None = None, + date_format: str | None = None, + double_precision: int = 10, + force_ascii: bool = True, + date_unit: TimeUnit = "ms", + default_handler: Callable[[Any], JSONSerializable] | None = None, + lines: bool = False, + compression: CompressionOptions = "infer", + index: bool | None = None, + indent: int | None = None, + storage_options: StorageOptions | None = None, + mode: Literal["a", "w"] = "w", + ) -> str | None: + """ + Convert the object to a JSON string. + + Note NaN's and None will be converted to null and datetime objects + will be converted to UNIX timestamps. + + Parameters + ---------- + path_or_buf : str, path object, file-like object, or None, default None + String, path object (implementing os.PathLike[str]), or file-like + object implementing a write() function. If None, the result is + returned as a string. + orient : str + Indication of expected JSON string format. + + * Series: + + - default is 'index' + - allowed values are: {{'split', 'records', 'index', 'table'}}. + + * DataFrame: + + - default is 'columns' + - allowed values are: {{'split', 'records', 'index', 'columns', + 'values', 'table'}}. + + * The format of the JSON string: + + - 'split' : dict like {{'index' -> [index], 'columns' -> [columns], + 'data' -> [values]}} + - 'records' : list like [{{column -> value}}, ... , {{column -> value}}] + - 'index' : dict like {{index -> {{column -> value}}}} + - 'columns' : dict like {{column -> {{index -> value}}}} + - 'values' : just the values array + - 'table' : dict like {{'schema': {{schema}}, 'data': {{data}}}} + + Describing the data, where data component is like ``orient='records'``. + + date_format : {{None, 'epoch', 'iso'}} + Type of date conversion. 'epoch' = epoch milliseconds, + 'iso' = ISO8601. The default depends on the `orient`. For + ``orient='table'``, the default is 'iso'. For all other orients, + the default is 'epoch'. + + .. deprecated:: 3.0.0 + 'epoch' date format is deprecated and will be removed in a future + version, please use 'iso' instead. + + double_precision : int, default 10 + The number of decimal places to use when encoding + floating point values. The possible maximal value is 15. + Passing double_precision greater than 15 will raise a ValueError. + force_ascii : bool, default True + Force encoded string to be ASCII. + date_unit : str, default 'ms' (milliseconds) + The time unit to encode to, governs timestamp and ISO8601 + precision. One of 's', 'ms', 'us', 'ns' for second, millisecond, + microsecond, and nanosecond respectively. + default_handler : callable, default None + Handler to call if object cannot otherwise be converted to a + suitable format for JSON. Should receive a single argument which is + the object to convert and return a serialisable object. + lines : bool, default False + If 'orient' is 'records' write out line-delimited json format. Will + throw ValueError if incorrect 'orient' since others are not + list-like. + + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and + 'path_or_buf' is path-like, then detect compression from the following + extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set to one of + {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} and + other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and + to create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + + index : bool or None, default None + The index is only used when 'orient' is 'split', 'index', 'column', + or 'table'. Of these, 'index' and 'column' do not support + `index=False`. The string 'index' as a column name with empty :class:`Index` + or if it is 'index' will raise a ``ValueError``. + + indent : int, optional + Length of whitespace used to indent each record. + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + mode : str, default 'w' (writing) + Specify the IO mode for output when supplying a path_or_buf. + Accepted args are 'w' (writing) and 'a' (append) only. + mode='a' is only supported when lines is True and orient is 'records'. + + Returns + ------- + None or str + If path_or_buf is None, returns the resulting json format as a + string. Otherwise returns None. + + See Also + -------- + read_json : Convert a JSON string to pandas object. + + Notes + ----- + The behavior of ``indent=0`` varies from the stdlib, which does not + indent the output but does insert newlines. Currently, ``indent=0`` + and the default ``indent=None`` are equivalent in pandas, though this + may change in a future release. + + ``orient='table'`` contains a 'pandas_version' field under 'schema'. + This stores the version of `pandas` used in the latest revision of the + schema. + + Examples + -------- + >>> from json import loads, dumps + >>> df = pd.DataFrame( + ... [["a", "b"], ["c", "d"]], + ... index=["row 1", "row 2"], + ... columns=["col 1", "col 2"], + ... ) + + >>> result = df.to_json(orient="split") + >>> parsed = loads(result) + >>> dumps(parsed, indent=4) # doctest: +SKIP + {{ + "columns": [ + "col 1", + "col 2" + ], + "index": [ + "row 1", + "row 2" + ], + "data": [ + [ + "a", + "b" + ], + [ + "c", + "d" + ] + ] + }} + + Encoding/decoding a Dataframe using ``'records'`` formatted JSON. + Note that index labels are not preserved with this encoding. + + >>> result = df.to_json(orient="records") + >>> parsed = loads(result) + >>> dumps(parsed, indent=4) # doctest: +SKIP + [ + {{ + "col 1": "a", + "col 2": "b" + }}, + {{ + "col 1": "c", + "col 2": "d" + }} + ] + + Encoding/decoding a Dataframe using ``'index'`` formatted JSON: + + >>> result = df.to_json(orient="index") + >>> parsed = loads(result) + >>> dumps(parsed, indent=4) # doctest: +SKIP + {{ + "row 1": {{ + "col 1": "a", + "col 2": "b" + }}, + "row 2": {{ + "col 1": "c", + "col 2": "d" + }} + }} + + Encoding/decoding a Dataframe using ``'columns'`` formatted JSON: + + >>> result = df.to_json(orient="columns") + >>> parsed = loads(result) + >>> dumps(parsed, indent=4) # doctest: +SKIP + {{ + "col 1": {{ + "row 1": "a", + "row 2": "c" + }}, + "col 2": {{ + "row 1": "b", + "row 2": "d" + }} + }} + + Encoding/decoding a Dataframe using ``'values'`` formatted JSON: + + >>> result = df.to_json(orient="values") + >>> parsed = loads(result) + >>> dumps(parsed, indent=4) # doctest: +SKIP + [ + [ + "a", + "b" + ], + [ + "c", + "d" + ] + ] + + Encoding with Table Schema: + + >>> result = df.to_json(orient="table") + >>> parsed = loads(result) + >>> dumps(parsed, indent=4) # doctest: +SKIP + {{ + "schema": {{ + "fields": [ + {{ + "name": "index", + "type": "string" + }}, + {{ + "name": "col 1", + "type": "string" + }}, + {{ + "name": "col 2", + "type": "string" + }} + ], + "primaryKey": [ + "index" + ], + "pandas_version": "1.4.0" + }}, + "data": [ + {{ + "index": "row 1", + "col 1": "a", + "col 2": "b" + }}, + {{ + "index": "row 2", + "col 1": "c", + "col 2": "d" + }} + ] + }} + """ + from pandas.io import json + + if date_format is None and orient == "table": + date_format = "iso" + elif date_format is None: + date_format = "epoch" + dtypes = self.dtypes if self.ndim == 2 else [self.dtype] + if any(dtype.kind in "mM" for dtype in dtypes): + warnings.warn( + "The default 'epoch' date format is deprecated and will be removed " + "in a future version, please use 'iso' date format instead.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + elif date_format == "epoch": + # GH#57063 + warnings.warn( + "'epoch' date format is deprecated and will be removed in a future " + "version, please use 'iso' date format instead.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + + config.is_nonnegative_int(indent) + indent = indent or 0 + + return json.to_json( + path_or_buf=path_or_buf, + obj=self, + orient=orient, + date_format=date_format, + double_precision=double_precision, + force_ascii=force_ascii, + date_unit=date_unit, + default_handler=default_handler, + lines=lines, + compression=compression, + index=index, + indent=indent, + storage_options=storage_options, + mode=mode, + ) + + @final + def to_hdf( + self, + path_or_buf: FilePath | HDFStore, + *, + key: str, + mode: Literal["a", "w", "r+"] = "a", + complevel: int | None = None, + complib: Literal["zlib", "lzo", "bzip2", "blosc"] | None = None, + append: bool = False, + format: Literal["fixed", "table"] | None = None, + index: bool = True, + min_itemsize: int | dict[str, int] | None = None, + nan_rep=None, + dropna: bool | None = None, + data_columns: Literal[True] | list[str] | None = None, + errors: OpenFileErrors = "strict", + encoding: str = "UTF-8", + ) -> None: + """ + Write the contained data to an HDF5 file using HDFStore. + + Hierarchical Data Format (HDF) is self-describing, allowing an + application to interpret the structure and contents of a file with + no outside information. One HDF file can hold a mix of related objects + which can be accessed as a group or as individual objects. + + In order to add another DataFrame or Series to an existing HDF file + please use append mode and a different a key. + + .. warning:: + + One can store a subclass of ``DataFrame`` or ``Series`` to HDF5, + but the type of the subclass is lost upon storing. + + For more information see the :ref:`user guide `. + + Parameters + ---------- + path_or_buf : str or pandas.HDFStore + File path or HDFStore object. + key : str + Identifier for the group in the store. + mode : {'a', 'w', 'r+'}, default 'a' + Mode to open file: + + - 'w': write, a new file is created (an existing file with + the same name would be deleted). + - 'a': append, an existing file is opened for reading and + writing, and if the file does not exist it is created. + - 'r+': similar to 'a', but the file must already exist. + complevel : {0-9}, default None + Specifies a compression level for data. + A value of 0 or None disables compression. + complib : {'zlib', 'lzo', 'bzip2', 'blosc'}, default 'zlib' + Specifies the compression library to be used. + These additional compressors for Blosc are supported + (default if no compressor specified: 'blosc:blosclz'): + {'blosc:blosclz', 'blosc:lz4', 'blosc:lz4hc', 'blosc:snappy', + 'blosc:zlib', 'blosc:zstd'}. + Specifying a compression library which is not available issues + a ValueError. + append : bool, default False + For Table formats, append the input data to the existing. + format : {'fixed', 'table', None}, default 'fixed' + Possible values: + + - 'fixed': Fixed format. Fast writing/reading. Not-appendable, + nor searchable. + - 'table': Table format. Write as a PyTables Table structure + which may perform worse but allow more flexible operations + like searching / selecting subsets of the data. + - If None, pd.get_option('io.hdf.default_format') is checked, + followed by fallback to "fixed". + index : bool, default True + Write DataFrame index as a column. + min_itemsize : dict or int, optional + Map column names to minimum string sizes for columns. + nan_rep : Any, optional + How to represent null values as str. + Not allowed with append=True. + dropna : bool, default False, optional + Remove missing values. + data_columns : list of columns or True, optional + List of columns to create as indexed data columns for on-disk + queries, or True to use all columns. By default only the axes + of the object are indexed. See + :ref:`Query via data columns`. for + more information. + Applicable only to format='table'. + errors : str, default 'strict' + Specifies how encoding and decoding errors are to be handled. + See the errors argument for :func:`open` for a full list + of options. + encoding : str, default "UTF-8" + Set character encoding. + + See Also + -------- + read_hdf : Read from HDF file. + DataFrame.to_orc : Write a DataFrame to the binary orc format. + DataFrame.to_parquet : Write a DataFrame to the binary parquet format. + DataFrame.to_sql : Write to a SQL table. + DataFrame.to_feather : Write out feather-format for DataFrames. + DataFrame.to_csv : Write out to a csv file. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"] + ... ) # doctest: +SKIP + >>> df.to_hdf("data.h5", key="df", mode="w") # doctest: +SKIP + + We can add another object to the same file: + + >>> s = pd.Series([1, 2, 3, 4]) # doctest: +SKIP + >>> s.to_hdf("data.h5", key="s") # doctest: +SKIP + + Reading from HDF file: + + >>> pd.read_hdf("data.h5", "df") # doctest: +SKIP + A B + a 1 4 + b 2 5 + c 3 6 + >>> pd.read_hdf("data.h5", "s") # doctest: +SKIP + 0 1 + 1 2 + 2 3 + 3 4 + dtype: int64 + """ + from pandas.io import pytables + + # Argument 3 to "to_hdf" has incompatible type "NDFrame"; expected + # "Union[DataFrame, Series]" [arg-type] + pytables.to_hdf( + path_or_buf, + key, + self, # type: ignore[arg-type] + mode=mode, + complevel=complevel, + complib=complib, + append=append, + format=format, + index=index, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + dropna=dropna, + data_columns=data_columns, + errors=errors, + encoding=encoding, + ) + + @final + def to_sql( + self, + name: str, + con, + *, + schema: str | None = None, + if_exists: Literal["fail", "replace", "append", "delete_rows"] = "fail", + index: bool = True, + index_label: IndexLabel | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + ) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + Databases supported by SQLAlchemy [1]_ are supported. Tables can be + newly created, appended to, or overwritten. + + .. warning:: + The pandas library does not attempt to sanitize inputs provided via a to_sql call. + Please refer to the documentation for the underlying database driver to see if it + will properly prevent injection, or alternatively be advised of a security risk when + executing arbitrary commands in a to_sql call. + + Parameters + ---------- + name : str + Name of SQL table. + con : ADBC connection, sqlalchemy.engine.(Engine or Connection) or sqlite3.Connection + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library. Legacy support is provided for sqlite3.Connection objects. The user + is responsible for engine disposal and connection closure for the SQLAlchemy + connectable. See `here \ + `_. + If passing a sqlalchemy.engine.Connection which is already in a transaction, + the transaction will not be committed. If passing a sqlite3.Connection, + it will not be possible to roll back the record insertion. + + schema : str, optional + Specify the schema (if database flavor supports this). If None, use + default schema. + if_exists : {'fail', 'replace', 'append', 'delete_rows'}, default 'fail' + How to behave if the table already exists. + + * fail: Raise a ValueError. + * replace: Drop the table before inserting new values. + * append: Insert new values to the existing table. + * delete_rows: If a table exists, delete all records and insert data. + + index : bool, default True + Write DataFrame index as a column. Uses `index_label` as the column + name in the table. Creates a table index for this column. + index_label : str or sequence, default None + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + chunksize : int, optional + Specify the number of rows in each batch to be written to the database connection at a time. + By default, all rows will be written at once. Also see the method keyword. + dtype : dict or scalar, optional + Specifying the datatype for columns. If a dictionary is used, the + keys should be the column names and the values should be the + SQLAlchemy types or strings for the sqlite3 legacy mode. If a + scalar is provided, it will be applied to all columns. + method : {None, 'multi', callable}, optional + Controls the SQL insertion clause used: + + * None : Uses standard SQL ``INSERT`` clause (one per row). + * 'multi': Pass multiple values in a single ``INSERT`` clause. + * callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method `. + + Returns + ------- + None or int + Number of rows affected by to_sql. None is returned if the callable + passed into ``method`` does not return an integer number of rows. + + The number of returned rows affected is the sum of the ``rowcount`` + attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not + reflect the exact number of written rows as stipulated in the + `sqlite3 `__ or + `SQLAlchemy `__. + + Raises + ------ + ValueError + When the table already exists and `if_exists` is 'fail' (the + default). + + See Also + -------- + read_sql : Read a DataFrame from a table. + + Notes + ----- + Timezone aware datetime columns will be written as + ``Timestamp with timezone`` type with SQLAlchemy if supported by the + database. Otherwise, the datetimes will be stored as timezone unaware + timestamps local to the original timezone. + + Not all datastores support ``method="multi"``. Oracle, for example, + does not support multi-value insert. + + References + ---------- + .. [1] https://docs.sqlalchemy.org + .. [2] https://www.python.org/dev/peps/pep-0249/ + + Examples + -------- + Create an in-memory SQLite database. + + >>> from sqlalchemy import create_engine + >>> engine = create_engine('sqlite://', echo=False) + + Create a table from scratch with 3 rows. + + >>> df = pd.DataFrame({'name' : ['User 1', 'User 2', 'User 3']}) + >>> df + name + 0 User 1 + 1 User 2 + 2 User 3 + + >>> df.to_sql(name='users', con=engine) + 3 + >>> from sqlalchemy import text + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM users")).fetchall() + [(0, 'User 1'), (1, 'User 2'), (2, 'User 3')] + + An `sqlalchemy.engine.Connection` can also be passed to `con`: + + >>> with engine.begin() as connection: + ... df1 = pd.DataFrame({'name' : ['User 4', 'User 5']}) + ... df1.to_sql(name='users', con=connection, if_exists='append') + 2 + + This is allowed to support operations that require that the same + DBAPI connection is used for the entire operation. + + >>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']}) + >>> df2.to_sql(name='users', con=engine, if_exists='append') + 2 + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM users")).fetchall() + [(0, 'User 1'), (1, 'User 2'), (2, 'User 3'), + (0, 'User 4'), (1, 'User 5'), (0, 'User 6'), + (1, 'User 7')] + + Overwrite the table with just ``df2``. + + >>> df2.to_sql(name='users', con=engine, if_exists='replace', + ... index_label='id') + 2 + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM users")).fetchall() + [(0, 'User 6'), (1, 'User 7')] + + Delete all rows before inserting new records with ``df3`` + + >>> df3 = pd.DataFrame({"name": ['User 8', 'User 9']}) + >>> df3.to_sql(name='users', con=engine, if_exists='delete_rows', + ... index_label='id') + 2 + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM users")).fetchall() + [(0, 'User 8'), (1, 'User 9')] + + Use ``method`` to define a callable insertion method to do nothing + if there's a primary key conflict on a table in a PostgreSQL database. + + >>> from sqlalchemy.dialects.postgresql import insert + >>> def insert_on_conflict_nothing(table, conn, keys, data_iter): + ... # "a" is the primary key in "conflict_table" + ... data = [dict(zip(keys, row)) for row in data_iter] + ... stmt = insert(table.table).values(data).on_conflict_do_nothing(index_elements=["a"]) + ... result = conn.execute(stmt) + ... return result.rowcount + >>> df_conflict.to_sql(name="conflict_table", con=conn, if_exists="append", # noqa: F821 + ... method=insert_on_conflict_nothing) # doctest: +SKIP + 0 + + For MySQL, a callable to update columns ``b`` and ``c`` if there's a conflict + on a primary key. + + >>> from sqlalchemy.dialects.mysql import insert # noqa: F811 + >>> def insert_on_conflict_update(table, conn, keys, data_iter): + ... # update columns "b" and "c" on primary key conflict + ... data = [dict(zip(keys, row)) for row in data_iter] + ... stmt = ( + ... insert(table.table) + ... .values(data) + ... ) + ... stmt = stmt.on_duplicate_key_update(b=stmt.inserted.b, c=stmt.inserted.c) + ... result = conn.execute(stmt) + ... return result.rowcount + >>> df_conflict.to_sql(name="conflict_table", con=conn, if_exists="append", # noqa: F821 + ... method=insert_on_conflict_update) # doctest: +SKIP + 2 + + Specify the dtype (especially useful for integers with missing values). + Notice that while pandas is forced to store the data as floating point, + the database supports nullable integers. When fetching the data with + Python, we get back integer scalars. + + >>> df = pd.DataFrame({"A": [1, None, 2]}) + >>> df + A + 0 1.0 + 1 NaN + 2 2.0 + + >>> from sqlalchemy.types import Integer + >>> df.to_sql(name='integers', con=engine, index=False, + ... dtype={"A": Integer()}) + 3 + + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM integers")).fetchall() + [(1,), (None,), (2,)] + + .. versionadded:: 2.2.0 + + pandas now supports writing via ADBC drivers + + >>> df = pd.DataFrame({'name' : ['User 10', 'User 11', 'User 12']}) + >>> df + name + 0 User 10 + 1 User 11 + 2 User 12 + + >>> from adbc_driver_sqlite import dbapi # doctest:+SKIP + >>> with dbapi.connect("sqlite://") as conn: # doctest:+SKIP + ... df.to_sql(name="users", con=conn) + 3 + """ # noqa: E501 + from pandas.io import sql + + return sql.to_sql( + self, + name, + con, + schema=schema, + if_exists=if_exists, + index=index, + index_label=index_label, + chunksize=chunksize, + dtype=dtype, + method=method, + ) + + @final + def to_pickle( + self, + path: FilePath | WriteBuffer[bytes], + *, + compression: CompressionOptions = "infer", + protocol: int = pickle.HIGHEST_PROTOCOL, + storage_options: StorageOptions | None = None, + ) -> None: + """ + Pickle (serialize) object to file. + + Parameters + ---------- + path : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``write()`` function. File path where + the pickled object will be stored. + + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and + 'path_or_buf' is path-like, then detect compression from the following + extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set to one of + {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} and + other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and + to create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + + protocol : int + Int which indicates which protocol should be used by the pickler, + default HIGHEST_PROTOCOL (see [1]_ paragraph 12.1.2). The possible + values are 0, 1, 2, 3, 4, 5. A negative value for the protocol + parameter is equivalent to setting its value to HIGHEST_PROTOCOL. + + .. [1] https://docs.python.org/3/library/pickle.html. + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + See Also + -------- + read_pickle : Load pickled pandas object (or any object) from file. + DataFrame.to_hdf : Write DataFrame to an HDF5 file. + DataFrame.to_sql : Write DataFrame to a SQL database. + DataFrame.to_parquet : Write a DataFrame to the binary parquet format. + + Examples + -------- + >>> original_df = pd.DataFrame( + ... {{"foo": range(5), "bar": range(5, 10)}} + ... ) # doctest: +SKIP + >>> original_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + >>> original_df.to_pickle("./dummy.pkl") # doctest: +SKIP + + >>> unpickled_df = pd.read_pickle("./dummy.pkl") # doctest: +SKIP + >>> unpickled_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + """ + from pandas.io.pickle import to_pickle + + to_pickle( + self, + path, + compression=compression, + protocol=protocol, + storage_options=storage_options, + ) + + @final + def to_clipboard( + self, *, excel: bool = True, sep: str | None = None, **kwargs + ) -> None: + r""" + Copy object to the system clipboard. + + Write a text representation of object to the system clipboard. + This can be pasted into Excel, for example. + + Parameters + ---------- + excel : bool, default True + Produce output in a csv format for easy pasting into excel. + + - True, use the provided separator for csv pasting. + - False, write a string representation of the object to the clipboard. + + sep : str, default ``'\t'`` + Field delimiter. + **kwargs + These parameters will be passed to DataFrame.to_csv. + + See Also + -------- + DataFrame.to_csv : Write a DataFrame to a comma-separated values + (csv) file. + read_clipboard : Read text from clipboard and pass to read_csv. + + Notes + ----- + Requirements for your platform. + + - Linux : `xclip`, or `xsel` (with `PyQt4` modules) + - Windows : none + - macOS : none + + This method uses the processes developed for the package `pyperclip`. A + solution to render any output string format is given in the examples. + + Examples + -------- + Copy the contents of a DataFrame to the clipboard. + + >>> df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["A", "B", "C"]) + + >>> df.to_clipboard(sep=",") # doctest: +SKIP + ... # Wrote the following to the system clipboard: + ... # ,A,B,C + ... # 0,1,2,3 + ... # 1,4,5,6 + + We can omit the index by passing the keyword `index` and setting + it to false. + + >>> df.to_clipboard(sep=",", index=False) # doctest: +SKIP + ... # Wrote the following to the system clipboard: + ... # A,B,C + ... # 1,2,3 + ... # 4,5,6 + + Using the original `pyperclip` package for any string output format. + + .. code-block:: python + + import pyperclip + + html = df.style.to_html() + pyperclip.copy(html) + """ + from pandas.io import clipboards + + clipboards.to_clipboard(self, excel=excel, sep=sep, **kwargs) + + @final + def to_xarray(self): + """ + Return an xarray object from the pandas object. + + Returns + ------- + xarray.DataArray or xarray.Dataset + Data in the pandas structure converted to Dataset if the object is + a DataFrame, or a DataArray if the object is a Series. + + See Also + -------- + DataFrame.to_hdf : Write DataFrame to an HDF5 file. + DataFrame.to_parquet : Write a DataFrame to the binary parquet format. + + Notes + ----- + See the `xarray docs `__ + + Examples + -------- + >>> df = pd.DataFrame( + ... [ + ... ("falcon", "bird", 389.0, 2), + ... ("parrot", "bird", 24.0, 2), + ... ("lion", "mammal", 80.5, 4), + ... ("monkey", "mammal", np.nan, 4), + ... ], + ... columns=["name", "class", "max_speed", "num_legs"], + ... ) + >>> df + name class max_speed num_legs + 0 falcon bird 389.0 2 + 1 parrot bird 24.0 2 + 2 lion mammal 80.5 4 + 3 monkey mammal NaN 4 + + >>> df.to_xarray() # doctest: +SKIP + + Dimensions: (index: 4) + Coordinates: + * index (index) int64 32B 0 1 2 3 + Data variables: + name (index) object 32B 'falcon' 'parrot' 'lion' 'monkey' + class (index) object 32B 'bird' 'bird' 'mammal' 'mammal' + max_speed (index) float64 32B 389.0 24.0 80.5 nan + num_legs (index) int64 32B 2 2 4 4 + + >>> df["max_speed"].to_xarray() # doctest: +SKIP + + array([389. , 24. , 80.5, nan]) + Coordinates: + * index (index) int64 0 1 2 3 + + >>> dates = pd.to_datetime( + ... ["2018-01-01", "2018-01-01", "2018-01-02", "2018-01-02"] + ... ) + >>> df_multiindex = pd.DataFrame( + ... { + ... "date": dates, + ... "animal": ["falcon", "parrot", "falcon", "parrot"], + ... "speed": [350, 18, 361, 15], + ... } + ... ) + >>> df_multiindex = df_multiindex.set_index(["date", "animal"]) + + >>> df_multiindex + speed + date animal + 2018-01-01 falcon 350 + parrot 18 + 2018-01-02 falcon 361 + parrot 15 + + >>> df_multiindex.to_xarray() # doctest: +SKIP + + Dimensions: (date: 2, animal: 2) + Coordinates: + * date (date) datetime64[s] 2018-01-01 2018-01-02 + * animal (animal) object 'falcon' 'parrot' + Data variables: + speed (date, animal) int64 350 18 361 15 + """ + xarray = import_optional_dependency("xarray") + + if self.ndim == 1: + return xarray.DataArray.from_series(self) + else: + return xarray.Dataset.from_dataframe(self) + + @overload + def to_latex( + self, + buf: None = ..., + *, + columns: Sequence[Hashable] | None = ..., + header: bool | SequenceNotStr[str] = ..., + index: bool = ..., + na_rep: str = ..., + formatters: FormattersType | None = ..., + float_format: FloatFormatType | None = ..., + sparsify: bool | None = ..., + index_names: bool = ..., + bold_rows: bool = ..., + column_format: str | None = ..., + longtable: bool | None = ..., + escape: bool | None = ..., + encoding: str | None = ..., + decimal: str = ..., + multicolumn: bool | None = ..., + multicolumn_format: str | None = ..., + multirow: bool | None = ..., + caption: str | tuple[str, str] | None = ..., + label: str | None = ..., + position: str | None = ..., + ) -> str: ... + + @overload + def to_latex( + self, + buf: FilePath | WriteBuffer[str], + *, + columns: Sequence[Hashable] | None = ..., + header: bool | SequenceNotStr[str] = ..., + index: bool = ..., + na_rep: str = ..., + formatters: FormattersType | None = ..., + float_format: FloatFormatType | None = ..., + sparsify: bool | None = ..., + index_names: bool = ..., + bold_rows: bool = ..., + column_format: str | None = ..., + longtable: bool | None = ..., + escape: bool | None = ..., + encoding: str | None = ..., + decimal: str = ..., + multicolumn: bool | None = ..., + multicolumn_format: str | None = ..., + multirow: bool | None = ..., + caption: str | tuple[str, str] | None = ..., + label: str | None = ..., + position: str | None = ..., + ) -> None: ... + + @final + def to_latex( + self, + buf: FilePath | WriteBuffer[str] | None = None, + *, + columns: Sequence[Hashable] | None = None, + header: bool | SequenceNotStr[str] = True, + index: bool = True, + na_rep: str = "NaN", + formatters: FormattersType | None = None, + float_format: FloatFormatType | None = None, + sparsify: bool | None = None, + index_names: bool = True, + bold_rows: bool = False, + column_format: str | None = None, + longtable: bool | None = None, + escape: bool | None = None, + encoding: str | None = None, + decimal: str = ".", + multicolumn: bool | None = None, + multicolumn_format: str | None = None, + multirow: bool | None = None, + caption: str | tuple[str, str] | None = None, + label: str | None = None, + position: str | None = None, + ) -> str | None: + r""" + Render object to a LaTeX tabular, longtable, or nested table. + + Requires ``\usepackage{booktabs}``. The output can be copy/pasted + into a main LaTeX document or read from an external file + with ``\input{table.tex}``. + + .. versionchanged:: 2.0.0 + Refactored to use the Styler implementation via jinja2 templating. + + Parameters + ---------- + buf : str, Path or StringIO-like, optional, default None + Buffer to write to. If None, the output is returned as a string. + columns : list of label, optional + The subset of columns to write. Writes all columns by default. + header : bool or list of str, default True + Write out the column names. If a list of strings is given, + it is assumed to be aliases for the column names. Braces must be escaped. + index : bool, default True + Write row names (index). + na_rep : str, default 'NaN' + Missing data representation. + formatters : list of functions or dict of {str: function}, optional + Formatter functions to apply to columns' elements by position or + name. The result of each function must be a unicode string. + List must be of length equal to the number of columns. + float_format : one-parameter function or str, optional, default None + Formatter for floating point numbers. For example + ``float_format="%.2f"`` and ``float_format="{:0.2f}".format`` will + both result in 0.1234 being formatted as 0.12. + sparsify : bool, optional + Set to False for a DataFrame with a hierarchical index to print + every multiindex key at each row. By default, the value will be + read from the config module. + index_names : bool, default True + Prints the names of the indexes. + bold_rows : bool, default False + Make the row labels bold in the output. + column_format : str, optional + The columns format as specified in `LaTeX table format + `__ e.g. 'rcl' for 3 + columns. By default, 'l' will be used for all columns except + columns of numbers, which default to 'r'. + longtable : bool, optional + Use a longtable environment instead of tabular. Requires + adding a \usepackage{longtable} to your LaTeX preamble. + By default, the value will be read from the pandas config + module, and set to `True` if the option ``styler.latex.environment`` is + `"longtable"`. + + .. versionchanged:: 2.0.0 + The pandas option affecting this argument has changed. + escape : bool, optional + By default, the value will be read from the pandas config + module and set to `True` if the option ``styler.format.escape`` is + `"latex"`. When set to False prevents from escaping latex special + characters in column names. + + .. versionchanged:: 2.0.0 + The pandas option affecting this argument has changed, as has the + default value to `False`. + encoding : str, optional + A string representing the encoding to use in the output file, + defaults to 'utf-8'. + decimal : str, default '.' + Character recognized as decimal separator, e.g. ',' in Europe. + multicolumn : bool, default True + Use \multicolumn to enhance MultiIndex columns. + The default will be read from the config module, and is set + as the option ``styler.sparse.columns``. + + .. versionchanged:: 2.0.0 + The pandas option affecting this argument has changed. + multicolumn_format : str, default 'r' + The alignment for multicolumns, similar to `column_format` + The default will be read from the config module, and is set as the option + ``styler.latex.multicol_align``. + + .. versionchanged:: 2.0.0 + The pandas option affecting this argument has changed, as has the + default value to "r". + multirow : bool, default True + Use \multirow to enhance MultiIndex rows. Requires adding a + \usepackage{multirow} to your LaTeX preamble. Will print + centered labels (instead of top-aligned) across the contained + rows, separating groups via clines. The default will be read + from the pandas config module, and is set as the option + ``styler.sparse.index``. + + .. versionchanged:: 2.0.0 + The pandas option affecting this argument has changed, as has the + default value to `True`. + caption : str or tuple, optional + Tuple (full_caption, short_caption), + which results in ``\caption[short_caption]{full_caption}``; + if a single string is passed, no short caption will be set. + label : str, optional + The LaTeX label to be placed inside ``\label{}`` in the output. + This is used with ``\ref{}`` in the main ``.tex`` file. + + position : str, optional + The LaTeX positional argument for tables, to be placed after + ``\begin{}`` in the output. + + Returns + ------- + str or None + If buf is None, returns the result as a string. Otherwise returns None. + + See Also + -------- + io.formats.style.Styler.to_latex : Render a DataFrame to LaTeX + with conditional formatting. + DataFrame.to_string : Render a DataFrame to a console-friendly + tabular output. + DataFrame.to_html : Render a DataFrame as an HTML table. + + Notes + ----- + As of v2.0.0 this method has changed to use the Styler implementation as + part of :meth:`.Styler.to_latex` via ``jinja2`` templating. This means + that ``jinja2`` is a requirement, and needs to be installed, for this method + to function. It is advised that users switch to using Styler, since that + implementation is more frequently updated and contains much more + flexibility with the output. + + Examples + -------- + Convert a general DataFrame to LaTeX with formatting: + + >>> df = pd.DataFrame(dict(name=['Raphael', 'Donatello'], + ... age=[26, 45], + ... height=[181.23, 177.65])) + >>> print(df.to_latex(index=False, + ... formatters={"name": str.upper}, + ... float_format="{:.1f}".format, + ... )) # doctest: +SKIP + \begin{tabular}{lrr} + \toprule + name & age & height \\ + \midrule + RAPHAEL & 26 & 181.2 \\ + DONATELLO & 45 & 177.7 \\ + \bottomrule + \end{tabular} + """ + # Get defaults from the pandas config + if self.ndim == 1: + self = self.to_frame() + if longtable is None: + longtable = config.get_option("styler.latex.environment") == "longtable" + if escape is None: + escape = config.get_option("styler.format.escape") == "latex" + if multicolumn is None: + multicolumn = config.get_option("styler.sparse.columns") + if multicolumn_format is None: + multicolumn_format = config.get_option("styler.latex.multicol_align") + if multirow is None: + multirow = config.get_option("styler.sparse.index") + + if column_format is not None and not isinstance(column_format, str): + raise ValueError("`column_format` must be str or unicode") + length = len(self.columns) if columns is None else len(columns) + if isinstance(header, (list, tuple)) and len(header) != length: + raise ValueError(f"Writing {length} cols but got {len(header)} aliases") + + # Refactor formatters/float_format/decimal/na_rep/escape to Styler structure + base_format_ = { + "na_rep": na_rep, + "escape": "latex" if escape else None, + "decimal": decimal, + } + index_format_: dict[str, Any] = {"axis": 0, **base_format_} + column_format_: dict[str, Any] = {"axis": 1, **base_format_} + + if isinstance(float_format, str): + float_format_: Callable | None = lambda x: float_format % x + else: + float_format_ = float_format + + def _wrap(x, alt_format_): + if isinstance(x, (float, complex)) and float_format_ is not None: + return float_format_(x) + else: + return alt_format_(x) + + formatters_: list | tuple | dict | Callable | None = None + if isinstance(formatters, list): + formatters_ = { + c: partial(_wrap, alt_format_=formatters[i]) + for i, c in enumerate(self.columns) + } + elif isinstance(formatters, dict): + index_formatter = formatters.pop("__index__", None) + column_formatter = formatters.pop("__columns__", None) + if index_formatter is not None: + index_format_.update({"formatter": index_formatter}) + if column_formatter is not None: + column_format_.update({"formatter": column_formatter}) + + formatters_ = formatters + float_columns = self.select_dtypes(include="float").columns + for col in float_columns: + if col not in formatters.keys(): + formatters_.update({col: float_format_}) + elif formatters is None and float_format is not None: + formatters_ = partial(_wrap, alt_format_=lambda v: v) + format_index_ = [index_format_, column_format_] + format_index_names_ = [index_format_, column_format_] + + # Deal with hiding indexes and relabelling column names + hide_: list[dict] = [] + relabel_index_: list[dict] = [] + if columns: + hide_.append( + { + "subset": [c for c in self.columns if c not in columns], + "axis": "columns", + } + ) + if header is False: + hide_.append({"axis": "columns"}) + elif isinstance(header, (list, tuple)): + relabel_index_.append({"labels": header, "axis": "columns"}) + format_index_ = [index_format_] # column_format is overwritten + + if index is False: + hide_.append({"axis": "index"}) + if index_names is False: + hide_.append({"names": True, "axis": "index"}) + + render_kwargs_ = { + "hrules": True, + "sparse_index": sparsify, + "sparse_columns": sparsify, + "environment": "longtable" if longtable else None, + "multicol_align": multicolumn_format + if multicolumn + else f"naive-{multicolumn_format}", + "multirow_align": "t" if multirow else "naive", + "encoding": encoding, + "caption": caption, + "label": label, + "position": position, + "column_format": column_format, + "clines": "skip-last;data" + if (multirow and isinstance(self.index, MultiIndex)) + else None, + "bold_rows": bold_rows, + } + + return self._to_latex_via_styler( + buf, + hide=hide_, + relabel_index=relabel_index_, + format={"formatter": formatters_, **base_format_}, + format_index=format_index_, + format_index_names=format_index_names_, + render_kwargs=render_kwargs_, + ) + + @final + def _to_latex_via_styler( + self, + buf=None, + *, + hide: dict | list[dict] | None = None, + relabel_index: dict | list[dict] | None = None, + format: dict | list[dict] | None = None, + format_index: dict | list[dict] | None = None, + format_index_names: dict | list[dict] | None = None, + render_kwargs: dict | None = None, + ): + """ + Render object to a LaTeX tabular, longtable, or nested table. + + Uses the ``Styler`` implementation with the following, ordered, method chaining: + + .. code-block:: python + styler = Styler(DataFrame) + styler.hide(**hide) + styler.relabel_index(**relabel_index) + styler.format(**format) + styler.format_index(**format_index) + styler.to_latex(buf=buf, **render_kwargs) + + Parameters + ---------- + buf : str, Path or StringIO-like, optional, default None + Buffer to write to. If None, the output is returned as a string. + hide : dict, list of dict + Keyword args to pass to the method call of ``Styler.hide``. If a list will + call the method numerous times. + relabel_index : dict, list of dict + Keyword args to pass to the method of ``Styler.relabel_index``. If a list + will call the method numerous times. + format : dict, list of dict + Keyword args to pass to the method call of ``Styler.format``. If a list will + call the method numerous times. + format_index : dict, list of dict + Keyword args to pass to the method call of ``Styler.format_index``. If a + list will call the method numerous times. + render_kwargs : dict + Keyword args to pass to the method call of ``Styler.to_latex``. + + Returns + ------- + str or None + If buf is None, returns the result as a string. Otherwise returns None. + """ + from pandas.io.formats.style import Styler + + self = cast("DataFrame", self) + styler = Styler(self, uuid="") + + for kw_name in [ + "hide", + "relabel_index", + "format", + "format_index", + "format_index_names", + ]: + kw = vars()[kw_name] + if isinstance(kw, dict): + getattr(styler, kw_name)(**kw) + elif isinstance(kw, list): + for sub_kw in kw: + getattr(styler, kw_name)(**sub_kw) + + # bold_rows is not a direct kwarg of Styler.to_latex + render_kwargs = {} if render_kwargs is None else render_kwargs + if render_kwargs.pop("bold_rows"): + styler.map_index(lambda v: "textbf:--rwrap;") + + return styler.to_latex(buf=buf, **render_kwargs) + + @overload + def to_csv( + self, + path_or_buf: None = ..., + *, + sep: str = ..., + na_rep: str = ..., + float_format: str | Callable | None = ..., + columns: Sequence[Hashable] | None = ..., + header: bool | list[str] = ..., + index: bool = ..., + index_label: IndexLabel | None = ..., + mode: str = ..., + encoding: str | None = ..., + compression: CompressionOptions = ..., + quoting: int | None = ..., + quotechar: str = ..., + lineterminator: str | None = ..., + chunksize: int | None = ..., + date_format: str | None = ..., + doublequote: bool = ..., + escapechar: str | None = ..., + decimal: str = ..., + errors: OpenFileErrors = ..., + storage_options: StorageOptions = ..., + ) -> str: ... + + @overload + def to_csv( + self, + path_or_buf: FilePath | WriteBuffer[bytes] | WriteBuffer[str], + *, + sep: str = ..., + na_rep: str = ..., + float_format: str | Callable | None = ..., + columns: Sequence[Hashable] | None = ..., + header: bool | list[str] = ..., + index: bool = ..., + index_label: IndexLabel | None = ..., + mode: str = ..., + encoding: str | None = ..., + compression: CompressionOptions = ..., + quoting: int | None = ..., + quotechar: str = ..., + lineterminator: str | None = ..., + chunksize: int | None = ..., + date_format: str | None = ..., + doublequote: bool = ..., + escapechar: str | None = ..., + decimal: str = ..., + errors: OpenFileErrors = ..., + storage_options: StorageOptions = ..., + ) -> None: ... + + @final + def to_csv( + self, + path_or_buf: FilePath | WriteBuffer[bytes] | WriteBuffer[str] | None = None, + *, + sep: str = ",", + na_rep: str = "", + float_format: str | Callable | None = None, + columns: Sequence[Hashable] | None = None, + header: bool | list[str] = True, + index: bool = True, + index_label: IndexLabel | None = None, + mode: str = "w", + encoding: str | None = None, + compression: CompressionOptions = "infer", + quoting: int | None = None, + quotechar: str = '"', + lineterminator: str | None = None, + chunksize: int | None = None, + date_format: str | None = None, + doublequote: bool = True, + escapechar: str | None = None, + decimal: str = ".", + errors: OpenFileErrors = "strict", + storage_options: StorageOptions | None = None, + ) -> str | None: + r""" + Write object to a comma-separated values (csv) file. + + Parameters + ---------- + path_or_buf : str, path object, file-like object, or None, default None + String, path object (implementing os.PathLike[str]), or file-like + object implementing a write() function. If None, the result is + returned as a string. If a non-binary file object is passed, it should + be opened with `newline=''`, disabling universal newlines. If a binary + file object is passed, `mode` might need to contain a `'b'`. + sep : str, default ',' + String of length 1. Field delimiter for the output file. + na_rep : str, default '' + Missing data representation. + float_format : str, Callable, default None + Format string for floating point numbers. If a Callable is given, it takes + precedence over other numeric formatting parameters, like decimal. + columns : sequence, optional + Columns to write. + header : bool or list of str, default True + Write out the column names. If a list of strings is given it is + assumed to be aliases for the column names. + index : bool, default True + Write row names (index). + index_label : str or sequence, or False, default None + Column label for index column(s) if desired. If None is given, and + `header` and `index` are True, then the index names are used. A + sequence should be given if the object uses MultiIndex. If + False do not print fields for index names. Use index_label=False + for easier importing in R. + mode : {{'w', 'x', 'a'}}, default 'w' + Forwarded to either `open(mode=)` or `fsspec.open(mode=)` to control + the file opening. Typical values include: + + - 'w', truncate the file first. + - 'x', exclusive creation, failing if the file already exists. + - 'a', append to the end of file if it exists. + + encoding : str, optional + A string representing the encoding to use in the output file, + defaults to 'utf-8'. `encoding` is not supported if `path_or_buf` + is a non-binary file object. + + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and + 'path_or_buf' is path-like, then detect compression from the following + extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set to one of + {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} and + other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and + to create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + + May be a dict with key 'method' as compression mode + and other entries as additional compression options if + compression mode is 'zip'. + + Passing compression options as keys in dict is + supported for compression modes 'gzip', 'bz2', 'zstd', and 'zip'. + quoting : optional constant from csv module + Defaults to csv.QUOTE_MINIMAL. If you have set a `float_format` + then floats are converted to strings and thus csv.QUOTE_NONNUMERIC + will treat them as non-numeric. + quotechar : str, default '\"' + String of length 1. Character used to quote fields. + lineterminator : str, optional + The newline character or character sequence to use in the output + file. Defaults to `os.linesep`, which depends on the OS in which + this method is called ('\\n' for linux, '\\r\\n' for Windows, i.e.). + chunksize : int or None + Rows to write at a time. + date_format : str, default None + Format string for datetime objects. + doublequote : bool, default True + Control quoting of `quotechar` inside a field. + escapechar : str, default None + String of length 1. Character used to escape `sep` and `quotechar` + when appropriate. + decimal : str, default '.' + Character recognized as decimal separator. E.g. use ',' for + European data. + errors : str, default 'strict' + Specifies how encoding and decoding errors are to be handled. + See the errors argument for :func:`open` for a full list + of options. + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + Returns + ------- + None or str + If path_or_buf is None, returns the resulting csv format as a + string. Otherwise returns None. + + See Also + -------- + read_csv : Load a CSV file into a DataFrame. + to_excel : Write DataFrame to an Excel file. + + Examples + -------- + Create 'out.csv' containing 'df' without indices + + >>> df = pd.DataFrame( + ... [["Raphael", "red", "sai"], ["Donatello", "purple", "bo staff"]], + ... columns=["name", "mask", "weapon"], + ... ) + >>> df.to_csv("out.csv", index=False) # doctest: +SKIP + + Create 'out.zip' containing 'out.csv' + + >>> df.to_csv(index=False) + 'name,mask,weapon\nRaphael,red,sai\nDonatello,purple,bo staff\n' + >>> compression_opts = dict( + ... method="zip", archive_name="out.csv" + ... ) # doctest: +SKIP + >>> df.to_csv( + ... "out.zip", index=False, compression=compression_opts + ... ) # doctest: +SKIP + + To write a csv file to a new folder or nested folder you will first + need to create it using either Pathlib or os: + + >>> from pathlib import Path # doctest: +SKIP + >>> filepath = Path("folder/subfolder/out.csv") # doctest: +SKIP + >>> filepath.parent.mkdir(parents=True, exist_ok=True) # doctest: +SKIP + >>> df.to_csv(filepath) # doctest: +SKIP + + >>> import os # doctest: +SKIP + >>> os.makedirs("folder/subfolder", exist_ok=True) # doctest: +SKIP + >>> df.to_csv("folder/subfolder/out.csv") # doctest: +SKIP + + Format floats to two decimal places: + + >>> df.to_csv("out1.csv", float_format="%.2f") # doctest: +SKIP + + Format floats using scientific notation: + + >>> df.to_csv("out2.csv", float_format="{{:.2e}}".format) # doctest: +SKIP + """ + df = self if isinstance(self, ABCDataFrame) else self.to_frame() + + formatter = DataFrameFormatter( + frame=df, + header=header, + index=index, + na_rep=na_rep, + float_format=float_format, + decimal=decimal, + ) + + return DataFrameRenderer(formatter).to_csv( + path_or_buf, + lineterminator=lineterminator, + sep=sep, + encoding=encoding, + errors=errors, + compression=compression, + quoting=quoting, + columns=columns, + index_label=index_label, + mode=mode, + chunksize=chunksize, + quotechar=quotechar, + date_format=date_format, + doublequote=doublequote, + escapechar=escapechar, + storage_options=storage_options, + ) + + # ---------------------------------------------------------------------- + # Indexing Methods + + @final + def take(self, indices, axis: Axis = 0, **kwargs) -> Self: + """ + Return the elements in the given *positional* indices along an axis. + + This means that we are not indexing according to actual values in + the index attribute of the object. We are indexing according to the + actual position of the element in the object. + + Parameters + ---------- + indices : array-like + An array of ints indicating which positions to take. + axis : {0 or 'index', 1 or 'columns'}, default 0 + The axis on which to select elements. ``0`` means that we are + selecting rows, ``1`` means that we are selecting columns. + For `Series` this parameter is unused and defaults to 0. + **kwargs + For compatibility with :meth:`numpy.take`. Has no effect on the + output. + + Returns + ------- + same type as caller + An array-like containing the elements taken from the object. + + See Also + -------- + DataFrame.loc : Select a subset of a DataFrame by labels. + DataFrame.iloc : Select a subset of a DataFrame by positions. + numpy.take : Take elements from an array along an axis. + + Examples + -------- + >>> df = pd.DataFrame( + ... [ + ... ("falcon", "bird", 389.0), + ... ("parrot", "bird", 24.0), + ... ("lion", "mammal", 80.5), + ... ("monkey", "mammal", np.nan), + ... ], + ... columns=["name", "class", "max_speed"], + ... index=[0, 2, 3, 1], + ... ) + >>> df + name class max_speed + 0 falcon bird 389.0 + 2 parrot bird 24.0 + 3 lion mammal 80.5 + 1 monkey mammal NaN + + Take elements at positions 0 and 3 along the axis 0 (default). + + Note how the actual indices selected (0 and 1) do not correspond to + our selected indices 0 and 3. That's because we are selecting the 0th + and 3rd rows, not rows whose indices equal 0 and 3. + + >>> df.take([0, 3]) + name class max_speed + 0 falcon bird 389.0 + 1 monkey mammal NaN + + Take elements at indices 1 and 2 along the axis 1 (column selection). + + >>> df.take([1, 2], axis=1) + class max_speed + 0 bird 389.0 + 2 bird 24.0 + 3 mammal 80.5 + 1 mammal NaN + + We may take elements using negative integers for positive indices, + starting from the end of the object, just like with Python lists. + + >>> df.take([-1, -2]) + name class max_speed + 1 monkey mammal NaN + 3 lion mammal 80.5 + """ + + nv.validate_take((), kwargs) + + if isinstance(indices, slice): + raise TypeError( + f"{type(self).__name__}.take requires a sequence of integers, " + "not slice." + ) + indices = np.asarray(indices, dtype=np.intp) + if axis == 0 and indices.ndim == 1 and is_range_indexer(indices, len(self)): + return self.copy(deep=False) + + new_data = self._mgr.take( + indices, + axis=self._get_block_manager_axis(axis), + verify=True, + ) + return self._constructor_from_mgr(new_data, axes=new_data.axes).__finalize__( + self, method="take" + ) + + @final + def xs( + self, + key: IndexLabel, + axis: Axis = 0, + level: IndexLabel | None = None, + drop_level: bool = True, + ) -> Self: + """ + Return cross-section from the Series/DataFrame. + + This method takes a `key` argument to select data at a particular + level of a MultiIndex. + + Parameters + ---------- + key : label or tuple of label + Label contained in the index, or partially in a MultiIndex. + axis : {0 or 'index', 1 or 'columns'}, default 0 + Axis to retrieve cross-section on. + level : object, defaults to first n levels (n=1 or len(key)) + In case of a key partially contained in a MultiIndex, indicate + which levels are used. Levels can be referred by label or position. + drop_level : bool, default True + If False, returns object with same levels as self. + + Returns + ------- + Series or DataFrame + Cross-section from the original Series or DataFrame + corresponding to the selected index levels. + + See Also + -------- + DataFrame.loc : Access a group of rows and columns + by label(s) or a boolean array. + DataFrame.iloc : Purely integer-location based indexing + for selection by position. + + Notes + ----- + `xs` can not be used to set values. + + MultiIndex Slicers is a generic way to get/set values on + any level or levels. + It is a superset of `xs` functionality, see + :ref:`MultiIndex Slicers `. + + Examples + -------- + >>> d = { + ... "num_legs": [4, 4, 2, 2], + ... "num_wings": [0, 0, 2, 2], + ... "class": ["mammal", "mammal", "mammal", "bird"], + ... "animal": ["cat", "dog", "bat", "penguin"], + ... "locomotion": ["walks", "walks", "flies", "walks"], + ... } + >>> df = pd.DataFrame(data=d) + >>> df = df.set_index(["class", "animal", "locomotion"]) + >>> df + num_legs num_wings + class animal locomotion + mammal cat walks 4 0 + dog walks 4 0 + bat flies 2 2 + bird penguin walks 2 2 + + Get values at specified index + + >>> df.xs("mammal") + num_legs num_wings + animal locomotion + cat walks 4 0 + dog walks 4 0 + bat flies 2 2 + + Get values at several indexes + + >>> df.xs(("mammal", "dog", "walks")) + num_legs 4 + num_wings 0 + Name: (mammal, dog, walks), dtype: int64 + + Get values at specified index and level + + >>> df.xs("cat", level=1) + num_legs num_wings + class locomotion + mammal walks 4 0 + + Get values at several indexes and levels + + >>> df.xs(("bird", "walks"), level=[0, "locomotion"]) + num_legs num_wings + animal + penguin 2 2 + + Get values at specified column and axis + + >>> df.xs("num_wings", axis=1) + class animal locomotion + mammal cat walks 0 + dog walks 0 + bat flies 2 + bird penguin walks 2 + Name: num_wings, dtype: int64 + """ + axis = self._get_axis_number(axis) + labels = self._get_axis(axis) + + if isinstance(key, list): + raise TypeError("list keys are not supported in xs, pass a tuple instead") + + if level is not None: + if not isinstance(labels, MultiIndex): + raise TypeError("Index must be a MultiIndex") + loc, new_ax = labels.get_loc_level(key, level=level, drop_level=drop_level) + + # create the tuple of the indexer + _indexer = [slice(None)] * self.ndim + _indexer[axis] = loc + indexer = tuple(_indexer) + + result = self.iloc[indexer] + setattr(result, result._get_axis_name(axis), new_ax) + return result + + if axis == 1: + if drop_level: + return self[key] + index = self.columns + else: + index = self.index + + if isinstance(index, MultiIndex): + loc, new_index = index._get_loc_level(key, level=0) + if not drop_level: + if lib.is_integer(loc): + # Slice index must be an integer or None + new_index = index[loc : loc + 1] + else: + new_index = index[loc] + else: + loc = index.get_loc(key) + + if isinstance(loc, np.ndarray): + if loc.dtype == np.bool_: + (inds,) = loc.nonzero() + return self.take(inds, axis=axis) + else: + return self.take(loc, axis=axis) + + if not is_scalar(loc): + new_index = index[loc] + + if is_scalar(loc) and axis == 0: + # In this case loc should be an integer + if self.ndim == 1: + # if we encounter an array-like and we only have 1 dim + # that means that their are list/ndarrays inside the Series! + # so just return them (GH 6394) + return self._values[loc] + + new_mgr = self._mgr.fast_xs(loc) + + result = self._constructor_sliced_from_mgr(new_mgr, axes=new_mgr.axes) + result._name = self.index[loc] + result = result.__finalize__(self) + elif is_scalar(loc): + result = self.iloc[:, slice(loc, loc + 1)] + elif axis == 1: + result = self.iloc[:, loc] + else: + result = self.iloc[loc] + result.index = new_index + + return result + + def __getitem__(self, item): + raise AbstractMethodError(self) + + @final + def _getitem_slice(self, key: slice) -> Self: + """ + __getitem__ for the case where the key is a slice object. + """ + # _convert_slice_indexer to determine if this slice is positional + # or label based, and if the latter, convert to positional + slobj = self.index._convert_slice_indexer(key, kind="getitem") + if isinstance(slobj, np.ndarray): + # reachable with DatetimeIndex + indexer = lib.maybe_indices_to_slice(slobj.astype(np.intp), len(self)) + if isinstance(indexer, np.ndarray): + # GH#43223 If we can not convert, use take + return self.take(indexer, axis=0) + slobj = indexer + return self._slice(slobj) + + def _slice(self, slobj: slice, axis: AxisInt = 0) -> Self: + """ + Construct a slice of this container. + + Slicing with this method is *always* positional. + """ + assert isinstance(slobj, slice), type(slobj) + axis = self._get_block_manager_axis(axis) + new_mgr = self._mgr.get_slice(slobj, axis=axis) + result = self._constructor_from_mgr(new_mgr, axes=new_mgr.axes) + result = result.__finalize__(self) + return result + + @final + def __delitem__(self, key) -> None: + """ + Delete item + """ + deleted = False + + maybe_shortcut = False + if self.ndim == 2 and isinstance(self.columns, MultiIndex): + try: + # By using engine's __contains__ we effectively + # restrict to same-length tuples + maybe_shortcut = key not in self.columns._engine + except TypeError: + pass + + if maybe_shortcut: + # Allow shorthand to delete all columns whose first len(key) + # elements match key: + if not isinstance(key, tuple): + key = (key,) + for col in self.columns: + if isinstance(col, tuple) and col[: len(key)] == key: + del self[col] + deleted = True + if not deleted: + # If the above loop ran and didn't delete anything because + # there was no match, this call should raise the appropriate + # exception: + loc = self.axes[-1].get_loc(key) + self._mgr = self._mgr.idelete(loc) + + # ---------------------------------------------------------------------- + # Unsorted + + @final + def _check_inplace_and_allows_duplicate_labels(self, inplace: bool) -> None: + if inplace and not self.flags.allows_duplicate_labels: + raise ValueError( + "Cannot specify 'inplace=True' when " + "'self.flags.allows_duplicate_labels' is False." + ) + + @final + def get(self, key, default=None): + """ + Get item from object for given key (ex: DataFrame column). + + Returns ``default`` value if not found. + + Parameters + ---------- + key : object + Key for which item should be returned. + default : object, default None + Default value to return if key is not found. + + Returns + ------- + same type as items contained in object + Item for given key or ``default`` value, if key is not found. + + See Also + -------- + DataFrame.get : Get item from object for given key (ex: DataFrame column). + Series.get : Get item from object for given key (ex: DataFrame column). + + Examples + -------- + >>> df = pd.DataFrame( + ... [ + ... [24.3, 75.7, "high"], + ... [31, 87.8, "high"], + ... [22, 71.6, "medium"], + ... [35, 95, "medium"], + ... ], + ... columns=["temp_celsius", "temp_fahrenheit", "windspeed"], + ... index=pd.date_range(start="2014-02-12", end="2014-02-15", freq="D"), + ... ) + + >>> df + temp_celsius temp_fahrenheit windspeed + 2014-02-12 24.3 75.7 high + 2014-02-13 31.0 87.8 high + 2014-02-14 22.0 71.6 medium + 2014-02-15 35.0 95.0 medium + + >>> df.get(["temp_celsius", "windspeed"]) + temp_celsius windspeed + 2014-02-12 24.3 high + 2014-02-13 31.0 high + 2014-02-14 22.0 medium + 2014-02-15 35.0 medium + + >>> ser = df["windspeed"] + >>> ser.get("2014-02-13") + 'high' + + If the key isn't found, the default value will be used. + + >>> df.get(["temp_celsius", "temp_kelvin"], default="default_value") + 'default_value' + + >>> ser.get("2014-02-10", "[unknown]") + '[unknown]' + """ + try: + return self[key] + except (KeyError, ValueError, IndexError): + return default + + @staticmethod + def _check_copy_deprecation(copy): + if copy is not lib.no_default: + warnings.warn( + "The copy keyword is deprecated and will be removed in a future " + "version. Copy-on-Write is active in pandas since 3.0 which utilizes " + "a lazy copy mechanism that defers copies until necessary. Use " + ".copy() to make an eager copy if necessary.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + + # issue 58667 + @deprecate_kwarg(Pandas4Warning, "method", new_arg_name=None) + @final + def reindex_like( + self, + other, + method: Literal["backfill", "bfill", "pad", "ffill", "nearest"] | None = None, + copy: bool | lib.NoDefault = lib.no_default, + limit: int | None = None, + tolerance=None, + ) -> Self: + """ + Return an object with matching indices as other object. + + Conform the object to the same index on all axes. Optional + filling logic, placing NaN in locations having no value + in the previous index. A new object is produced unless the + new index is equivalent to the current one and copy=False. + + Parameters + ---------- + other : Object of the same data type + Its row and column indices are used to define the new indices + of this object. + method : {None, 'backfill'/'bfill', 'pad'/'ffill', 'nearest'} + Method to use for filling holes in reindexed DataFrame. + Please note: this is only applicable to DataFrames/Series with a + monotonically increasing/decreasing index. + + .. deprecated:: 3.0.0 + + * None (default): don't fill gaps + * pad / ffill: propagate last valid observation forward to next + valid + * backfill / bfill: use next valid observation to fill gap + * nearest: use nearest valid observations to fill gap. + + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + limit : int, default None + Maximum number of consecutive labels to fill for inexact matches. + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + + Tolerance may be a scalar value, which applies the same tolerance + to all values, or list-like, which applies variable tolerance per + element. List-like includes list, tuple, array, Series, and must be + the same size as the index and its dtype must exactly match the + index's type. + + Returns + ------- + Series or DataFrame + Same type as caller, but with changed indices on each axis. + + See Also + -------- + DataFrame.set_index : Set row labels. + DataFrame.reset_index : Remove row labels or move them to new columns. + DataFrame.reindex : Change to new indices or expand indices. + + Notes + ----- + Same as calling + ``.reindex(index=other.index, columns=other.columns,...)``. + + Examples + -------- + >>> df1 = pd.DataFrame( + ... [ + ... [24.3, 75.7, "high"], + ... [31, 87.8, "high"], + ... [22, 71.6, "medium"], + ... [35, 95, "medium"], + ... ], + ... columns=["temp_celsius", "temp_fahrenheit", "windspeed"], + ... index=pd.date_range(start="2014-02-12", end="2014-02-15", freq="D"), + ... ) + + >>> df1 + temp_celsius temp_fahrenheit windspeed + 2014-02-12 24.3 75.7 high + 2014-02-13 31.0 87.8 high + 2014-02-14 22.0 71.6 medium + 2014-02-15 35.0 95.0 medium + + >>> df2 = pd.DataFrame( + ... [[28, "low"], [30, "low"], [35.1, "medium"]], + ... columns=["temp_celsius", "windspeed"], + ... index=pd.DatetimeIndex(["2014-02-12", "2014-02-13", "2014-02-15"]), + ... ) + + >>> df2 + temp_celsius windspeed + 2014-02-12 28.0 low + 2014-02-13 30.0 low + 2014-02-15 35.1 medium + + >>> df2.reindex_like(df1) + temp_celsius temp_fahrenheit windspeed + 2014-02-12 28.0 NaN low + 2014-02-13 30.0 NaN low + 2014-02-14 NaN NaN NaN + 2014-02-15 35.1 NaN medium + """ + self._check_copy_deprecation(copy) + d = other._construct_axes_dict( + axes=self._AXIS_ORDERS, + method=method, + limit=limit, + tolerance=tolerance, + ) + + return self.reindex(**d) + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level | None = ..., + inplace: Literal[True], + errors: IgnoreRaise = ..., + ) -> None: ... + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level | None = ..., + inplace: Literal[False] = ..., + errors: IgnoreRaise = ..., + ) -> Self: ... + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level | None = ..., + inplace: bool = ..., + errors: IgnoreRaise = ..., + ) -> Self | None: ... + + def drop( + self, + labels: IndexLabel | ListLike = None, + *, + axis: Axis = 0, + index: IndexLabel | ListLike = None, + columns: IndexLabel | ListLike = None, + level: Level | None = None, + inplace: bool = False, + errors: IgnoreRaise = "raise", + ) -> Self | None: + inplace = validate_bool_kwarg(inplace, "inplace") + + if labels is not None: + if index is not None or columns is not None: + raise ValueError("Cannot specify both 'labels' and 'index'/'columns'") + axis_name = self._get_axis_name(axis) + axes = {axis_name: labels} + elif index is not None or columns is not None: + if axis == 1: + raise ValueError("Cannot specify both 'axis' and 'index'/'columns'") + axes = {"index": index} + if self.ndim == 2: + axes["columns"] = columns + else: + raise ValueError( + "Need to specify at least one of 'labels', 'index' or 'columns'" + ) + + obj = self + + for axis, labels in axes.items(): + if labels is not None: + obj = obj._drop_axis(labels, axis, level=level, errors=errors) + + if inplace: + self._update_inplace(obj) + return None + else: + return obj + + @final + def _drop_axis( + self, + labels, + axis, + level=None, + errors: IgnoreRaise = "raise", + only_slice: bool = False, + ) -> Self: + """ + Drop labels from specified axis. Used in the ``drop`` method + internally. + + Parameters + ---------- + labels : single label or list-like + axis : int or axis name + level : int or level name, default None + For MultiIndex + errors : {'ignore', 'raise'}, default 'raise' + If 'ignore', suppress error and existing labels are dropped. + only_slice : bool, default False + Whether indexing along columns should be view-only. + + """ + axis_num = self._get_axis_number(axis) + axis = self._get_axis(axis) + + if axis.is_unique: + if level is not None: + if not isinstance(axis, MultiIndex): + raise AssertionError("axis must be a MultiIndex") + new_axis = axis.drop(labels, level=level, errors=errors) + else: + new_axis = axis.drop(labels, errors=errors) + indexer = axis.get_indexer(new_axis) + + # Case for non-unique axis + else: + is_tuple_labels = is_nested_list_like(labels) or isinstance(labels, tuple) + labels = ensure_object(common.index_labels_to_array(labels)) + if level is not None: + if not isinstance(axis, MultiIndex): + raise AssertionError("axis must be a MultiIndex") + mask = ~axis.get_level_values(level).isin(labels) + + # GH 18561 MultiIndex.drop should raise if label is absent + if errors == "raise" and mask.all(): + raise KeyError(f"{labels} not found in axis") + elif ( + isinstance(axis, MultiIndex) + and labels.dtype == "object" + and not is_tuple_labels + ): + # Set level to zero in case of MultiIndex and label is string, + # because isin can't handle strings for MultiIndexes GH#36293 + # In case of tuples we get dtype object but have to use isin GH#42771 + mask = ~axis.get_level_values(0).isin(labels) + else: + mask = ~axis.isin(labels) + # Check if label doesn't exist along axis + labels_missing = (axis.get_indexer_for(labels) == -1).any() + if errors == "raise" and labels_missing: + raise KeyError(f"{labels} not found in axis") + + if isinstance(mask.dtype, ExtensionDtype): + # GH#45860 + mask = mask.to_numpy(dtype=bool) + + indexer = mask.nonzero()[0] + new_axis = axis.take(indexer) + + bm_axis = self.ndim - axis_num - 1 + new_mgr = self._mgr.reindex_indexer( + new_axis, + indexer, + axis=bm_axis, + allow_dups=True, + only_slice=only_slice, + ) + result = self._constructor_from_mgr(new_mgr, axes=new_mgr.axes) + if self.ndim == 1: + result._name = self.name + + return result.__finalize__(self) + + @final + def _update_inplace(self, result) -> None: + """ + Replace self internals with result. + + Parameters + ---------- + result : same type as self + """ + # NOTE: This does *not* call __finalize__ and that's an explicit + # decision that we may revisit in the future. + self._mgr = result._mgr + + @final + def add_prefix(self, prefix: str, axis: Axis | None = None) -> Self: + """ + Prefix labels with string `prefix`. + + For Series, the row labels are prefixed. + For DataFrame, the column labels are prefixed. + + Parameters + ---------- + prefix : str + The string to add before each label. + axis : {0 or 'index', 1 or 'columns', None}, default None + Axis to add prefix on + + .. versionadded:: 2.0.0 + + Returns + ------- + Series or DataFrame + New Series or DataFrame with updated labels. + + See Also + -------- + Series.add_suffix: Suffix row labels with string `suffix`. + DataFrame.add_suffix: Suffix column labels with string `suffix`. + + Examples + -------- + >>> s = pd.Series([1, 2, 3, 4]) + >>> s + 0 1 + 1 2 + 2 3 + 3 4 + dtype: int64 + + >>> s.add_prefix("item_") + item_0 1 + item_1 2 + item_2 3 + item_3 4 + dtype: int64 + + >>> df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}) + >>> df + A B + 0 1 3 + 1 2 4 + 2 3 5 + 3 4 6 + + >>> df.add_prefix("col_") + col_A col_B + 0 1 3 + 1 2 4 + 2 3 5 + 3 4 6 + """ + f = lambda x: f"{prefix}{x}" + + axis_name = self._info_axis_name + if axis is not None: + axis_name = self._get_axis_name(axis) + + mapper = {axis_name: f} + + # error: Keywords must be strings + # error: No overload variant of "_rename" of "NDFrame" matches + # argument type "dict[Literal['index', 'columns'], Callable[[Any], str]]" + return self._rename(**mapper) # type: ignore[call-overload, misc] + + @final + def add_suffix(self, suffix: str, axis: Axis | None = None) -> Self: + """ + Suffix labels with string `suffix`. + + For Series, the row labels are suffixed. + For DataFrame, the column labels are suffixed. + + Parameters + ---------- + suffix : str + The string to add after each label. + axis : {0 or 'index', 1 or 'columns', None}, default None + Axis to add suffix on + + .. versionadded:: 2.0.0 + + Returns + ------- + Series or DataFrame + New Series or DataFrame with updated labels. + + See Also + -------- + Series.add_prefix: Prefix row labels with string `prefix`. + DataFrame.add_prefix: Prefix column labels with string `prefix`. + + Examples + -------- + >>> s = pd.Series([1, 2, 3, 4]) + >>> s + 0 1 + 1 2 + 2 3 + 3 4 + dtype: int64 + + >>> s.add_suffix("_item") + 0_item 1 + 1_item 2 + 2_item 3 + 3_item 4 + dtype: int64 + + >>> df = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}) + >>> df + A B + 0 1 3 + 1 2 4 + 2 3 5 + 3 4 6 + + >>> df.add_suffix("_col") + A_col B_col + 0 1 3 + 1 2 4 + 2 3 5 + 3 4 6 + """ + f = lambda x: f"{x}{suffix}" + + axis_name = self._info_axis_name + if axis is not None: + axis_name = self._get_axis_name(axis) + + mapper = {axis_name: f} + # error: Keywords must be strings + # error: No overload variant of "_rename" of "NDFrame" matches argument + # type "dict[Literal['index', 'columns'], Callable[[Any], str]]" + return self._rename(**mapper) # type: ignore[call-overload, misc] + + @overload + def sort_values( + self, + *, + axis: Axis = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[False] = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + ignore_index: bool = ..., + key: ValueKeyFunc = ..., + ) -> Self: ... + + @overload + def sort_values( + self, + *, + axis: Axis = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[True], + kind: SortKind = ..., + na_position: NaPosition = ..., + ignore_index: bool = ..., + key: ValueKeyFunc = ..., + ) -> None: ... + + @overload + def sort_values( + self, + *, + axis: Axis = ..., + ascending: bool | Sequence[bool] = ..., + inplace: bool = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + ignore_index: bool = ..., + key: ValueKeyFunc = ..., + ) -> Self | None: ... + + def sort_values( + self, + *, + axis: Axis = 0, + ascending: bool | Sequence[bool] = True, + inplace: bool = False, + kind: SortKind = "quicksort", + na_position: NaPosition = "last", + ignore_index: bool = False, + key: ValueKeyFunc | None = None, + ) -> Self | None: + """ + Sort by the values along either axis. + + Parameters + ----------%(optional_by)s + axis : %(axes_single_arg)s, default 0 + Axis to be sorted. + ascending : bool or list of bool, default True + Sort ascending vs. descending. Specify list for multiple sort + orders. If this is a list of bools, must match the length of + the by. + inplace : bool, default False + If True, perform operation in-place. + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort' + Choice of sorting algorithm. See also :func:`numpy.sort` for more + information. `mergesort` and `stable` are the only stable algorithms. For + DataFrames, this option is only applied when sorting on a single + column or label. + na_position : {'first', 'last'}, default 'last' + Puts NaNs at the beginning if `first`; `last` puts NaNs at the + end. + ignore_index : bool, default False + If True, the resulting axis will be labeled 0, 1, …, n - 1. + key : callable, optional + Apply the key function to the values + before sorting. This is similar to the `key` argument in the + builtin :meth:`sorted` function, with the notable difference that + this `key` function should be *vectorized*. It should expect a + ``Series`` and return a Series with the same shape as the input. + It will be applied to each column in `by` independently. The values in the + returned Series will be used as the keys for sorting. + + Returns + ------- + DataFrame or None + DataFrame with sorted values or None if ``inplace=True``. + + See Also + -------- + DataFrame.sort_index : Sort a DataFrame by the index. + Series.sort_values : Similar method for a Series. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "col1": ["A", "A", "B", np.nan, "D", "C"], + ... "col2": [2, 1, 9, 8, 7, 4], + ... "col3": [0, 1, 9, 4, 2, 3], + ... "col4": ["a", "B", "c", "D", "e", "F"], + ... } + ... ) + >>> df + col1 col2 col3 col4 + 0 A 2 0 a + 1 A 1 1 B + 2 B 9 9 c + 3 NaN 8 4 D + 4 D 7 2 e + 5 C 4 3 F + + Sort by col1 + + >>> df.sort_values(by=["col1"]) + col1 col2 col3 col4 + 0 A 2 0 a + 1 A 1 1 B + 2 B 9 9 c + 5 C 4 3 F + 4 D 7 2 e + 3 NaN 8 4 D + + Sort by multiple columns + + >>> df.sort_values(by=["col1", "col2"]) + col1 col2 col3 col4 + 1 A 1 1 B + 0 A 2 0 a + 2 B 9 9 c + 5 C 4 3 F + 4 D 7 2 e + 3 NaN 8 4 D + + Sort Descending + + >>> df.sort_values(by="col1", ascending=False) + col1 col2 col3 col4 + 4 D 7 2 e + 5 C 4 3 F + 2 B 9 9 c + 0 A 2 0 a + 1 A 1 1 B + 3 NaN 8 4 D + + Putting NAs first + + >>> df.sort_values(by="col1", ascending=False, na_position="first") + col1 col2 col3 col4 + 3 NaN 8 4 D + 4 D 7 2 e + 5 C 4 3 F + 2 B 9 9 c + 0 A 2 0 a + 1 A 1 1 B + + Sorting with a key function + + >>> df.sort_values(by="col4", key=lambda col: col.str.lower()) + col1 col2 col3 col4 + 0 A 2 0 a + 1 A 1 1 B + 2 B 9 9 c + 3 NaN 8 4 D + 4 D 7 2 e + 5 C 4 3 F + + Natural sort with the key argument, + using the `natsort ` package. + + >>> df = pd.DataFrame( + ... { + ... "hours": ["0hr", "128hr", "0hr", "64hr", "64hr", "128hr"], + ... "mins": [ + ... "10mins", + ... "40mins", + ... "40mins", + ... "40mins", + ... "10mins", + ... "10mins", + ... ], + ... "value": [10, 20, 30, 40, 50, 60], + ... } + ... ) + >>> df + hours mins value + 0 0hr 10mins 10 + 1 128hr 40mins 20 + 2 0hr 40mins 30 + 3 64hr 40mins 40 + 4 64hr 10mins 50 + 5 128hr 10mins 60 + >>> from natsort import natsort_keygen + >>> df.sort_values( + ... by=["hours", "mins"], + ... key=natsort_keygen(), + ... ) + hours mins value + 0 0hr 10mins 10 + 2 0hr 40mins 30 + 4 64hr 10mins 50 + 3 64hr 40mins 40 + 5 128hr 10mins 60 + 1 128hr 40mins 20 + """ + raise AbstractMethodError(self) + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[True], + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> None: ... + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[False] = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> Self: ... + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: bool = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> Self | None: ... + + def sort_index( + self, + *, + axis: Axis = 0, + level: IndexLabel | None = None, + ascending: bool | Sequence[bool] = True, + inplace: bool = False, + kind: SortKind = "quicksort", + na_position: NaPosition = "last", + sort_remaining: bool = True, + ignore_index: bool = False, + key: IndexKeyFunc | None = None, + ) -> Self | None: + inplace = validate_bool_kwarg(inplace, "inplace") + axis = self._get_axis_number(axis) + ascending = validate_ascending(ascending) + + target = self._get_axis(axis) + + indexer = get_indexer_indexer( + target, level, ascending, kind, na_position, sort_remaining, key + ) + + if indexer is None: + if inplace: + result = self + else: + result = self.copy(deep=False) + + if ignore_index: + if axis == 1: + result.columns = default_index(len(self.columns)) + else: + result.index = default_index(len(self)) + if inplace: + return None + else: + return result + + baxis = self._get_block_manager_axis(axis) + new_data = self._mgr.take(indexer, axis=baxis, verify=False) + + # reconstruct axis if needed + if not ignore_index: + new_axis = new_data.axes[baxis]._sort_levels_monotonic() + else: + new_axis = default_index(len(indexer)) + new_data.set_axis(baxis, new_axis) + + result = self._constructor_from_mgr(new_data, axes=new_data.axes) + + if inplace: + return self._update_inplace(result) + else: + return result.__finalize__(self, method="sort_index") + + def reindex( + self, + labels=None, + *, + index=None, + columns=None, + axis: Axis | None = None, + method: ReindexMethod | None = None, + copy: bool | lib.NoDefault = lib.no_default, + level: Level | None = None, + fill_value: Scalar | None = np.nan, + limit: int | None = None, + tolerance=None, + ) -> Self: + """ + Conform Series/DataFrame to new index with optional filling logic. + + Places NA/NaN in locations having no value in the previous index. A new object + is produced unless the new index is equivalent to the current one and + ``copy=False``. + + Parameters + ---------- + method : {{None, 'backfill'/'bfill', 'pad'/'ffill', 'nearest'}} + Method to use for filling holes in reindexed DataFrame. + Please note: this is only applicable to DataFrames/Series with a + monotonically increasing/decreasing index. + + * None (default): don't fill gaps + * pad / ffill: Propagate last valid observation forward to next + valid. + * backfill / bfill: Use next valid observation to fill gap. + * nearest: Use nearest valid observations to fill gap. + + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : scalar, default np.nan + Value to use for missing values. Defaults to NaN, but can be any + "compatible" value. + limit : int, default None + Maximum number of consecutive elements to forward or backward fill. + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations most + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + + Tolerance may be a scalar value, which applies the same tolerance + to all values, or list-like, which applies variable tolerance per + element. List-like includes list, tuple, array, Series, and must be + the same size as the index and its dtype must exactly match the + index's type. + + Returns + ------- + Series/DataFrame + Series/DataFrame with changed index. + + See Also + -------- + DataFrame.set_index : Set row labels. + DataFrame.reset_index : Remove row labels or move them to new columns. + DataFrame.reindex_like : Change to same indices as other DataFrame. + + Examples + -------- + ``DataFrame.reindex`` supports two calling conventions + + * ``(index=index_labels, columns=column_labels, ...)`` + * ``(labels, axis={{'index', 'columns'}}, ...)`` + + We *highly* recommend using keyword arguments to clarify your + intent. + + Create a DataFrame with some fictional data. + + >>> index = ["Firefox", "Chrome", "Safari", "IE10", "Konqueror"] + >>> columns = ["http_status", "response_time"] + >>> df = pd.DataFrame( + ... [[200, 0.04], [200, 0.02], [404, 0.07], [404, 0.08], [301, 1.0]], + ... columns=columns, + ... index=index, + ... ) + >>> df + http_status response_time + Firefox 200 0.04 + Chrome 200 0.02 + Safari 404 0.07 + IE10 404 0.08 + Konqueror 301 1.00 + + Create a new index and reindex the DataFrame. By default + values in the new index that do not have corresponding + records in the DataFrame are assigned ``NaN``. + + >>> new_index = ["Safari", "Iceweasel", "Comodo Dragon", "IE10", "Chrome"] + >>> df.reindex(new_index) + http_status response_time + Safari 404.0 0.07 + Iceweasel NaN NaN + Comodo Dragon NaN NaN + IE10 404.0 0.08 + Chrome 200.0 0.02 + + We can fill in the missing values by passing a value to + the keyword ``fill_value``. Because the index is not monotonically + increasing or decreasing, we cannot use arguments to the keyword + ``method`` to fill the ``NaN`` values. + + >>> df.reindex(new_index, fill_value=0) + http_status response_time + Safari 404 0.07 + Iceweasel 0 0.00 + Comodo Dragon 0 0.00 + IE10 404 0.08 + Chrome 200 0.02 + + >>> df.reindex(new_index, fill_value="missing") + http_status response_time + Safari 404 0.07 + Iceweasel missing missing + Comodo Dragon missing missing + IE10 404 0.08 + Chrome 200 0.02 + + We can also reindex the columns. + + >>> df.reindex(columns=["http_status", "user_agent"]) + http_status user_agent + Firefox 200 NaN + Chrome 200 NaN + Safari 404 NaN + IE10 404 NaN + Konqueror 301 NaN + + Or we can use "axis-style" keyword arguments + + >>> df.reindex(["http_status", "user_agent"], axis="columns") + http_status user_agent + Firefox 200 NaN + Chrome 200 NaN + Safari 404 NaN + IE10 404 NaN + Konqueror 301 NaN + + To further illustrate the filling functionality in + ``reindex``, we will create a DataFrame with a + monotonically increasing index (for example, a sequence + of dates). + + >>> date_index = pd.date_range("1/1/2010", periods=6, freq="D") + >>> df2 = pd.DataFrame( + ... {"prices": [100, 101, np.nan, 100, 89, 88]}, index=date_index + ... ) + >>> df2 + prices + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + + Suppose we decide to expand the DataFrame to cover a wider + date range. + + >>> date_index2 = pd.date_range("12/29/2009", periods=10, freq="D") + >>> df2.reindex(date_index2) + prices + 2009-12-29 NaN + 2009-12-30 NaN + 2009-12-31 NaN + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + 2010-01-07 NaN + + The index entries that did not have a value in the original data frame + (for example, '2009-12-29') are by default filled with ``NaN``. + If desired, we can fill in the missing values using one of several + options. + + For example, to back-propagate the last valid value to fill the ``NaN`` + values, pass ``bfill`` as an argument to the ``method`` keyword. + + >>> df2.reindex(date_index2, method="bfill") + prices + 2009-12-29 100.0 + 2009-12-30 100.0 + 2009-12-31 100.0 + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + 2010-01-07 NaN + + Please note that the ``NaN`` value present in the original DataFrame + (at index value 2010-01-03) will not be filled by any of the + value propagation schemes. This is because filling while reindexing + does not look at DataFrame values, but only compares the original and + desired indexes. If you do want to fill in the ``NaN`` values present + in the original DataFrame, use the ``fillna()`` method. + + See the :ref:`user guide ` for more. + """ + # TODO: Decide if we care about having different examples for different + # kinds + + # Automatically detect matching level when reindexing from Index to MultiIndex. + # This prevents values from being incorrectly set to NaN when the source index + # name matches a index name in the target MultiIndex + if ( + level is None + and index is not None + and isinstance(index, MultiIndex) + and not isinstance(self.index, MultiIndex) + and self.index.name in index.names + ): + level = self.index.name + self._check_copy_deprecation(copy) + + if index is not None and columns is not None and labels is not None: + raise TypeError("Cannot specify all of 'labels', 'index', 'columns'.") + elif index is not None or columns is not None: + if axis is not None: + raise TypeError( + "Cannot specify both 'axis' and any of 'index' or 'columns'" + ) + if labels is not None: + if index is not None: + columns = labels + else: + index = labels + elif axis and self._get_axis_number(axis) == 1: + columns = labels + else: + index = labels + axes: dict[Literal["index", "columns"], Any] = { + "index": index, + "columns": columns, + } + method = clean_reindex_fill_method(method) + + # if all axes that are requested to reindex are equal, then only copy + # if indicated must have index names equal here as well as values + if all( + self._get_axis(axis_name).identical(ax) + for axis_name, ax in axes.items() + if ax is not None + ): + return self.copy(deep=False) + + # check if we are a multi reindex + if self._needs_reindex_multi(axes, method, level): + return self._reindex_multi(axes, fill_value) + + # perform the reindex on the axes + return self._reindex_axes( + axes, level, limit, tolerance, method, fill_value + ).__finalize__(self, method="reindex") + + @final + def _reindex_axes( + self, + axes, + level: Level | None, + limit: int | None, + tolerance, + method, + fill_value: Scalar | None, + ) -> Self: + """Perform the reindex for all the axes.""" + obj = self + for a in self._AXIS_ORDERS: + labels = axes[a] + if labels is None: + continue + + ax = self._get_axis(a) + new_index, indexer = ax.reindex( + labels, level=level, limit=limit, tolerance=tolerance, method=method + ) + + axis = self._get_axis_number(a) + obj = obj._reindex_with_indexers( + {axis: [new_index, indexer]}, + fill_value=fill_value, + allow_dups=False, + ) + + return obj + + def _needs_reindex_multi(self, axes, method, level: Level | None) -> bool: + """Check if we do need a multi reindex.""" + return ( + (common.count_not_none(*axes.values()) == self._AXIS_LEN) + and method is None + and level is None + # reindex_multi calls self.values, so we only want to go + # down that path when doing so is cheap. + and self._can_fast_transpose + ) + + def _reindex_multi(self, axes, fill_value): + raise AbstractMethodError(self) + + @final + def _reindex_with_indexers( + self, + reindexers, + fill_value=None, + allow_dups: bool = False, + ) -> Self: + """allow_dups indicates an internal call here""" + # reindex doing multiple operations on different axes if indicated + new_data = self._mgr + for axis in sorted(reindexers.keys()): + index, indexer = reindexers[axis] + baxis = self._get_block_manager_axis(axis) + + if index is None: + continue + + index = ensure_index(index) + if indexer is not None: + indexer = ensure_platform_int(indexer) + + # TODO: speed up on homogeneous DataFrame objects (see _reindex_multi) + new_data = new_data.reindex_indexer( + index, + indexer, + axis=baxis, + fill_value=fill_value, + allow_dups=allow_dups, + ) + + if new_data is self._mgr: + new_data = new_data.copy(deep=False) + + return self._constructor_from_mgr(new_data, axes=new_data.axes).__finalize__( + self + ) + + def filter( + self, + items=None, + like: str | None = None, + regex: str | None = None, + axis: Axis | None = None, + ) -> Self: + """ + Subset the DataFrame or Series according to the specified index labels. + + For DataFrame, filter rows or columns depending on ``axis`` argument. + Note that this routine does not filter based on content. + The filter is applied to the labels of the index. + + Parameters + ---------- + items : list-like + Keep labels from axis which are in items. + like : str + Keep labels from axis for which "like in label == True". + regex : str (regular expression) + Keep labels from axis for which re.search(regex, label) == True. + axis : {0 or 'index', 1 or 'columns', None}, default None + The axis to filter on, expressed either as an index (int) + or axis name (str). By default this is the info axis, 'columns' for + ``DataFrame``. For ``Series`` this parameter is unused and defaults to + ``None``. + + Returns + ------- + Same type as caller + The filtered subset of the DataFrame or Series. + + See Also + -------- + DataFrame.loc : Access a group of rows and columns + by label(s) or a boolean array. + + Notes + ----- + The ``items``, ``like``, and ``regex`` parameters are + enforced to be mutually exclusive. + + ``axis`` defaults to the info axis that is used when indexing + with ``[]``. + + Examples + -------- + >>> df = pd.DataFrame( + ... np.array(([1, 2, 3], [4, 5, 6])), + ... index=["mouse", "rabbit"], + ... columns=["one", "two", "three"], + ... ) + >>> df + one two three + mouse 1 2 3 + rabbit 4 5 6 + + >>> # select columns by name + >>> df.filter(items=["one", "three"]) + one three + mouse 1 3 + rabbit 4 6 + + >>> # select columns by regular expression + >>> df.filter(regex="e$", axis=1) + one three + mouse 1 3 + rabbit 4 6 + + >>> # select rows containing 'bbi' + >>> df.filter(like="bbi", axis=0) + one two three + rabbit 4 5 6 + """ + nkw = common.count_not_none(items, like, regex) + if nkw > 1: + raise TypeError( + "Keyword arguments `items`, `like`, or `regex` are mutually exclusive" + ) + + if axis is None: + axis = self._info_axis_name + labels = self._get_axis(axis) + + if items is not None: + name = self._get_axis_name(axis) + items = Index(items).intersection(labels) + if len(items) == 0: + # Keep the dtype of labels when we are empty + items = items.astype(labels.dtype) + # error: Keywords must be strings + return self.reindex(**{name: items}) # type: ignore[misc] + elif like: + + def f(x) -> bool: + assert like is not None # needed for mypy + return like in ensure_str(x) + + values = labels.map(f) + return self.loc(axis=axis)[values] + elif regex: + + def f(x) -> bool: + return matcher.search(ensure_str(x)) is not None + + matcher = re.compile(regex) + values = labels.map(f) + return self.loc(axis=axis)[values] + else: + raise TypeError("Must pass either `items`, `like`, or `regex`") + + @final + def head(self, n: int = 5) -> Self: + """ + Return the first `n` rows. + + This function exhibits the same behavior as ``df[:n]``, returning the + first ``n`` rows based on position. It is useful for quickly checking + if your object has the right type of data in it. + + When ``n`` is positive, it returns the first ``n`` rows. For ``n`` equal to 0, + it returns an empty object. When ``n`` is negative, it returns + all rows except the last ``|n|`` rows, mirroring the behavior of ``df[:n]``. + + If ``n`` is larger than the number of rows, this function returns all rows. + + Parameters + ---------- + n : int, default 5 + Number of rows to select. + + Returns + ------- + same type as caller + The first `n` rows of the caller object. + + See Also + -------- + DataFrame.tail: Returns the last `n` rows. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "animal": [ + ... "alligator", + ... "bee", + ... "falcon", + ... "lion", + ... "monkey", + ... "parrot", + ... "shark", + ... "whale", + ... "zebra", + ... ] + ... } + ... ) + >>> df + animal + 0 alligator + 1 bee + 2 falcon + 3 lion + 4 monkey + 5 parrot + 6 shark + 7 whale + 8 zebra + + Viewing the first 5 lines + + >>> df.head() + animal + 0 alligator + 1 bee + 2 falcon + 3 lion + 4 monkey + + Viewing the first `n` lines (three in this case) + + >>> df.head(3) + animal + 0 alligator + 1 bee + 2 falcon + + For negative values of `n` + + >>> df.head(-3) + animal + 0 alligator + 1 bee + 2 falcon + 3 lion + 4 monkey + 5 parrot + """ + return self.iloc[:n].copy() + + @final + def tail(self, n: int = 5) -> Self: + """ + Return the last `n` rows. + + This function returns last `n` rows from the object based on + position. It is useful for quickly verifying data, for example, + after sorting or appending rows. + + For negative values of `n`, this function returns all rows except + the first `|n|` rows, equivalent to ``df[|n|:]``. + + If ``n`` is larger than the number of rows, this function returns all rows. + + Parameters + ---------- + n : int, default 5 + Number of rows to select. + + Returns + ------- + type of caller + The last `n` rows of the caller object. + + See Also + -------- + DataFrame.head : The first `n` rows of the caller object. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "animal": [ + ... "alligator", + ... "bee", + ... "falcon", + ... "lion", + ... "monkey", + ... "parrot", + ... "shark", + ... "whale", + ... "zebra", + ... ] + ... } + ... ) + >>> df + animal + 0 alligator + 1 bee + 2 falcon + 3 lion + 4 monkey + 5 parrot + 6 shark + 7 whale + 8 zebra + + Viewing the last 5 lines + + >>> df.tail() + animal + 4 monkey + 5 parrot + 6 shark + 7 whale + 8 zebra + + Viewing the last `n` lines (three in this case) + + >>> df.tail(3) + animal + 6 shark + 7 whale + 8 zebra + + For negative values of `n` + + >>> df.tail(-3) + animal + 3 lion + 4 monkey + 5 parrot + 6 shark + 7 whale + 8 zebra + """ + if n == 0: + return self.iloc[0:0].copy() + return self.iloc[-n:].copy() + + @final + def sample( + self, + n: int | None = None, + frac: float | None = None, + replace: bool = False, + weights=None, + random_state: RandomState | None = None, + axis: Axis | None = None, + ignore_index: bool = False, + ) -> Self: + """ + Return a random sample of items from an axis of object. + + You can use `random_state` for reproducibility. + + Parameters + ---------- + n : int, optional + Number of items from axis to return. Cannot be used with `frac`. + Default = 1 if `frac` = None. + frac : float, optional + Fraction of axis items to return. Cannot be used with `n`. + replace : bool, default False + Allow or disallow sampling of the same row more than once. + weights : str or ndarray-like, optional + Default ``None`` results in equal probability weighting. + If passed a Series, will align with target object on index. Index + values in weights not found in sampled object will be ignored and + index values in sampled object not in weights will be assigned + weights of zero. + If called on a DataFrame, will accept the name of a column + when axis = 0. + Unless weights are a Series, weights must be same length as axis + being sampled. + If weights do not sum to 1, they will be normalized to sum to 1. + Missing values in the weights column will be treated as zero. + Infinite values not allowed. + When replace = False will not allow ``(n * max(weights) / sum(weights)) > 1`` + in order to avoid biased results. See the Notes below for more details. + random_state : int, array-like, BitGenerator, np.random.RandomState, np.random.Generator, optional + If int, array-like, or BitGenerator, seed for random number generator. + If np.random.RandomState or np.random.Generator, use as given. + Default ``None`` results in sampling with the current state of np.random. + axis : {0 or 'index', 1 or 'columns', None}, default None + Axis to sample. Accepts axis number or name. Default is stat axis + for given data type. For `Series` this parameter is unused and defaults to `None`. + ignore_index : bool, default False + If True, the resulting index will be labeled 0, 1, …, n - 1. + + Returns + ------- + Series or DataFrame + A new object of same type as caller containing `n` items randomly + sampled from the caller object. + + See Also + -------- + DataFrameGroupBy.sample: Generates random samples from each group of a + DataFrame object. + SeriesGroupBy.sample: Generates random samples from each group of a + Series object. + numpy.random.choice: Generates a random sample from a given 1-D numpy + array. + + Notes + ----- + If `frac` > 1, `replacement` should be set to `True`. + + When replace = False will not allow ``(n * max(weights) / sum(weights)) > 1``, + since that would cause results to be biased. E.g. sampling 2 items without replacement + with weights [100, 1, 1] would yield two last items in 1/2 of cases, instead of 1/102. + This is similar to specifying `n=4` without replacement on a Series with 3 elements. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "num_legs": [2, 4, 8, 0], + ... "num_wings": [2, 0, 0, 0], + ... "num_specimen_seen": [10, 2, 1, 8], + ... }, + ... index=["falcon", "dog", "spider", "fish"], + ... ) + >>> df + num_legs num_wings num_specimen_seen + falcon 2 2 10 + dog 4 0 2 + spider 8 0 1 + fish 0 0 8 + + Extract 3 random elements from the ``Series`` ``df['num_legs']``: + Note that we use `random_state` to ensure the reproducibility of + the examples. + + >>> df["num_legs"].sample(n=3, random_state=1) + fish 0 + spider 8 + falcon 2 + Name: num_legs, dtype: int64 + + A random 50% sample of the ``DataFrame`` with replacement: + + >>> df.sample(frac=0.5, replace=True, random_state=1) + num_legs num_wings num_specimen_seen + dog 4 0 2 + fish 0 0 8 + + An upsample sample of the ``DataFrame`` with replacement: + Note that `replace` parameter has to be `True` for `frac` parameter > 1. + + >>> df.sample(frac=2, replace=True, random_state=1) + num_legs num_wings num_specimen_seen + dog 4 0 2 + fish 0 0 8 + falcon 2 2 10 + falcon 2 2 10 + fish 0 0 8 + dog 4 0 2 + fish 0 0 8 + dog 4 0 2 + + Using a DataFrame column as weights. Rows with larger value in the + `num_specimen_seen` column are more likely to be sampled. + + >>> df.sample(n=2, weights="num_specimen_seen", random_state=1) + num_legs num_wings num_specimen_seen + falcon 2 2 10 + fish 0 0 8 + """ # noqa: E501 + if axis is None: + axis = 0 + + axis = self._get_axis_number(axis) + obj_len = self.shape[axis] + + # Process random_state argument + rs = common.random_state(random_state) + + size = sample.process_sampling_size(n, frac, replace) + if size is None: + assert frac is not None + size = round(frac * obj_len) + + if weights is not None: + weights = sample.preprocess_weights(self, weights, axis) + + sampled_indices = sample.sample(obj_len, size, replace, weights, rs) + result = self.take(sampled_indices, axis=axis) + + if ignore_index: + result.index = default_index(len(result)) + + return result + + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: ... + + @overload + def pipe( + self, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: ... + + @final + def pipe( + self, + func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: + r""" + Apply chainable functions that expect Series or DataFrames. + + Parameters + ---------- + func : function + Function to apply to the Series/DataFrame. + ``args``, and ``kwargs`` are passed into ``func``. + Alternatively a ``(callable, data_keyword)`` tuple where + ``data_keyword`` is a string indicating the keyword of + ``callable`` that expects the Series/DataFrame. + *args : iterable, optional + Positional arguments passed into ``func``. + **kwargs : mapping, optional + A dictionary of keyword arguments passed into ``func``. + + Returns + ------- + The return type of ``func``. + The result of applying ``func`` to the Series or DataFrame. + + See Also + -------- + DataFrame.apply : Apply a function along input axis of DataFrame. + DataFrame.map : Apply a function elementwise on a whole DataFrame. + Series.map : Apply a mapping correspondence on a + :class:`~pandas.Series`. + + Notes + ----- + Use ``.pipe`` when chaining together functions that expect + Series, DataFrames or GroupBy objects. + + Examples + -------- + Constructing an income DataFrame from a dictionary. + + >>> data = [[8000, 1000], [9500, np.nan], [5000, 2000]] + >>> df = pd.DataFrame(data, columns=["Salary", "Others"]) + >>> df + Salary Others + 0 8000 1000.0 + 1 9500 NaN + 2 5000 2000.0 + + Functions that perform tax reductions on an income DataFrame. + + >>> def subtract_federal_tax(df): + ... return df * 0.9 + >>> def subtract_state_tax(df, rate): + ... return df * (1 - rate) + >>> def subtract_national_insurance(df, rate, rate_increase): + ... new_rate = rate + rate_increase + ... return df * (1 - new_rate) + + Instead of writing + + >>> subtract_national_insurance( + ... subtract_state_tax(subtract_federal_tax(df), rate=0.12), + ... rate=0.05, + ... rate_increase=0.02, + ... ) # doctest: +SKIP + + You can write + + >>> ( + ... df.pipe(subtract_federal_tax) + ... .pipe(subtract_state_tax, rate=0.12) + ... .pipe(subtract_national_insurance, rate=0.05, rate_increase=0.02) + ... ) + Salary Others + 0 5892.48 736.56 + 1 6997.32 NaN + 2 3682.80 1473.12 + + If you have a function that takes the data as (say) the second + argument, pass a tuple indicating which keyword expects the + data. For example, suppose ``national_insurance`` takes its data as ``df`` + in the second argument: + + >>> def subtract_national_insurance(rate, df, rate_increase): + ... new_rate = rate + rate_increase + ... return df * (1 - new_rate) + >>> ( + ... df.pipe(subtract_federal_tax) + ... .pipe(subtract_state_tax, rate=0.12) + ... .pipe( + ... (subtract_national_insurance, "df"), rate=0.05, rate_increase=0.02 + ... ) + ... ) + Salary Others + 0 5892.48 736.56 + 1 6997.32 NaN + 2 3682.80 1473.12 + """ + return common.pipe(self.copy(deep=False), func, *args, **kwargs) + + # ---------------------------------------------------------------------- + # Attribute access + + @final + def __finalize__(self, other, method: str | None = None, **kwargs) -> Self: + """ + Propagate metadata from other to self. + + This is the default implementation. Subclasses may override this method to + implement their own metadata handling. + + Parameters + ---------- + other : the object from which to get the attributes that we are going + to propagate. If ``other`` has an ``input_objs`` attribute, then + this attribute must contain an iterable of objects, each with an + ``attrs`` attribute. + method : str, optional + A passed method name providing context on where ``__finalize__`` + was called. + + .. warning:: + + The value passed as `method` are not currently considered + stable across pandas releases. + + Notes + ----- + In case ``other`` has an ``input_objs`` attribute, this method only + propagates its metadata if each object in ``input_objs`` has the exact + same metadata as the others. + """ + if isinstance(other, NDFrame): + if other.attrs: + # We want attrs propagation to have minimal performance + # impact if attrs are not used; i.e. attrs is an empty dict. + # One could make the deepcopy unconditionally, but a deepcopy + # of an empty dict is 50x more expensive than the empty check. + self.attrs = deepcopy(other.attrs) + self.flags.allows_duplicate_labels = ( + self.flags.allows_duplicate_labels + and other.flags.allows_duplicate_labels + ) + # For subclasses using _metadata. + for name in set(self._metadata) & set(other._metadata): + assert isinstance(name, str) + object.__setattr__(self, name, getattr(other, name, None)) + + elif hasattr(other, "input_objs"): + objs = other.input_objs + # propagate attrs only if all inputs have the same attrs + if all(bool(obj.attrs) for obj in objs): + # all inputs have non-empty attrs + attrs = objs[0].attrs + have_same_attrs = all(obj.attrs == attrs for obj in objs[1:]) + if have_same_attrs: + self.attrs = deepcopy(attrs) + + allows_duplicate_labels = all(x.flags.allows_duplicate_labels for x in objs) + self.flags.allows_duplicate_labels = allows_duplicate_labels + + return self + + @final + def __getattr__(self, name: str): + """ + After regular attribute access, try looking up the name + This allows simpler access to columns for interactive use. + """ + # Note: obj.x will always call obj.__getattribute__('x') prior to + # calling obj.__getattr__('x'). + if ( + name not in self._internal_names_set + and name not in self._metadata + and name not in self._accessors + and self._info_axis._can_hold_identifiers_and_holds_name(name) + ): + return self[name] + return object.__getattribute__(self, name) + + @final + def __setattr__(self, name: str, value) -> None: + """ + After regular attribute access, try setting the name + This allows simpler access to columns for interactive use. + """ + # first try regular attribute access via __getattribute__, so that + # e.g. ``obj.x`` and ``obj.x = 4`` will always reference/modify + # the same attribute. + + try: + object.__getattribute__(self, name) + return object.__setattr__(self, name, value) + except AttributeError: + pass + + # if this fails, go on to more involved attribute setting + # (note that this matches __getattr__, above). + if name in self._internal_names_set: + object.__setattr__(self, name, value) + elif name in self._metadata: + object.__setattr__(self, name, value) + else: + try: + existing = getattr(self, name) + if isinstance(existing, Index): + object.__setattr__(self, name, value) + elif name in self._info_axis: + self[name] = value + else: + object.__setattr__(self, name, value) + except (AttributeError, TypeError): + if isinstance(self, ABCDataFrame) and (is_list_like(value)): + warnings.warn( + "Pandas doesn't allow columns to be " + "created via a new attribute name - see " + "https://pandas.pydata.org/pandas-docs/" + "stable/indexing.html#attribute-access", + stacklevel=find_stack_level(), + ) + object.__setattr__(self, name, value) + + @final + def _dir_additions(self) -> set[str]: + """ + add the string-like attributes from the info_axis. + If info_axis is a MultiIndex, its first level values are used. + """ + additions = super()._dir_additions() + if self._info_axis._can_hold_strings: + additions.update(self._info_axis._dir_additions_for_owner) + return additions + + # ---------------------------------------------------------------------- + # Consolidation of internals + + @final + def _consolidate_inplace(self) -> None: + """Consolidate data in place and return None""" + + self._mgr = self._mgr.consolidate() + + @final + def _consolidate(self): + """ + Compute NDFrame with "consolidated" internals (data of each dtype + grouped together in a single ndarray). + + Returns + ------- + consolidated : same type as caller + """ + cons_data = self._mgr.consolidate() + return self._constructor_from_mgr(cons_data, axes=cons_data.axes).__finalize__( + self + ) + + @final + @property + def _is_mixed_type(self) -> bool: + if self._mgr.is_single_block: + # Includes all Series cases + return False + + if self._mgr.any_extension_types: + # Even if they have the same dtype, we can't consolidate them, + # so we pretend this is "mixed'" + return True + + return self.dtypes.nunique() > 1 + + @final + def _get_numeric_data(self) -> Self: + new_mgr = self._mgr.get_numeric_data() + return self._constructor_from_mgr(new_mgr, axes=new_mgr.axes).__finalize__(self) + + @final + def _get_bool_data(self): + new_mgr = self._mgr.get_bool_data() + return self._constructor_from_mgr(new_mgr, axes=new_mgr.axes).__finalize__(self) + + # ---------------------------------------------------------------------- + # Internal Interface Methods + + @property + def values(self): + raise AbstractMethodError(self) + + @property + def _values(self) -> ArrayLike: + """internal implementation""" + raise AbstractMethodError(self) + + @property + def dtypes(self): + """ + Return the dtypes in the DataFrame. + + This returns a Series with the data type of each column. + The result's index is the original DataFrame's columns. Columns + with mixed types are stored with the ``object`` dtype. See + :ref:`the User Guide ` for more. + + Returns + ------- + pandas.Series + The data type of each column. + + See Also + -------- + Series.dtypes : Return the dtype object of the underlying data. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "float": [1.0], + ... "int": [1], + ... "datetime": [pd.Timestamp("20180310")], + ... "string": ["foo"], + ... } + ... ) + >>> df.dtypes + float float64 + int int64 + datetime datetime64[us] + string str + dtype: object + """ + data = self._mgr.get_dtypes() + return self._constructor_sliced(data, index=self._info_axis, dtype=np.object_) + + @final + def astype( + self, + dtype, + copy: bool | lib.NoDefault = lib.no_default, + errors: IgnoreRaise = "raise", + ) -> Self: + """ + Cast a pandas object to a specified dtype ``dtype``. + + This method allows the conversion of the data types of pandas objects, + including DataFrames and Series, to the specified dtype. It supports casting + entire objects to a single data type or applying different data types to + individual columns using a mapping. + + Parameters + ---------- + dtype : str, data type, Series or Mapping of column name -> data type + Use a str, numpy.dtype, pandas.ExtensionDtype or Python type to + cast entire pandas object to the same type. Alternatively, use a + mapping, e.g. {col: dtype, ...}, where col is a column label and dtype is + a numpy.dtype or Python type to cast one or more of the DataFrame's + columns to column-specific types. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + errors : {'raise', 'ignore'}, default 'raise' + Control raising of exceptions on invalid data for provided dtype. + + - ``raise`` : allow exceptions to be raised + - ``ignore`` : suppress exceptions. On error return original object. + + Returns + ------- + same type as caller + The pandas object casted to the specified ``dtype``. + + See Also + -------- + to_datetime : Convert argument to datetime. + to_timedelta : Convert argument to timedelta. + to_numeric : Convert argument to a numeric type. + numpy.ndarray.astype : Cast a numpy array to a specified type. + + Notes + ----- + .. versionchanged:: 2.0.0 + + Using ``astype`` to convert from timezone-naive dtype to + timezone-aware dtype will raise an exception. + Use :meth:`Series.dt.tz_localize` instead. + + Examples + -------- + Create a DataFrame: + + >>> d = {"col1": [1, 2], "col2": [3, 4]} + >>> df = pd.DataFrame(data=d) + >>> df.dtypes + col1 int64 + col2 int64 + dtype: object + + Cast all columns to int32: + + >>> df.astype("int32").dtypes + col1 int32 + col2 int32 + dtype: object + + Cast col1 to int32 using a dictionary: + + >>> df.astype({"col1": "int32"}).dtypes + col1 int32 + col2 int64 + dtype: object + + Create a series: + + >>> ser = pd.Series([1, 2], dtype="int32") + >>> ser + 0 1 + 1 2 + dtype: int32 + >>> ser.astype("int64") + 0 1 + 1 2 + dtype: int64 + + Convert to categorical type: + + >>> ser.astype("category") + 0 1 + 1 2 + dtype: category + Categories (2, int32): [1, 2] + + Convert to ordered categorical type with custom ordering: + + >>> from pandas.api.types import CategoricalDtype + >>> cat_dtype = CategoricalDtype(categories=[2, 1], ordered=True) + >>> ser.astype(cat_dtype) + 0 1 + 1 2 + dtype: category + Categories (2, int64): [2 < 1] + + Create a series of dates: + + >>> ser_date = pd.Series(pd.date_range("20200101", periods=3)) + >>> ser_date + 0 2020-01-01 + 1 2020-01-02 + 2 2020-01-03 + dtype: datetime64[us] + """ + self._check_copy_deprecation(copy) + if is_dict_like(dtype): + if self.ndim == 1: # i.e. Series + if len(dtype) > 1 or self.name not in dtype: + raise KeyError( + "Only the Series name can be used for " + "the key in Series dtype mappings." + ) + new_type = dtype[self.name] + return self.astype(new_type, errors=errors) + + # GH#44417 cast to Series so we can use .iat below, which will be + # robust in case we + from pandas import Series + + dtype_ser = Series(dtype, dtype=object) + + for col_name in dtype_ser.index: + if col_name not in self: + raise KeyError( + "Only a column name can be used for the " + "key in a dtype mappings argument. " + f"'{col_name}' not found in columns." + ) + + dtype_ser = dtype_ser.reindex(self.columns, fill_value=None) + + results = [] + for i, (col_name, col) in enumerate(self.items()): + cdt = dtype_ser.iat[i] + if isna(cdt): + res_col = col.copy(deep=False) + else: + try: + res_col = col.astype(dtype=cdt, errors=errors) + except ValueError as ex: + ex.args = ( + f"{ex}: Error while type casting for column '{col_name}'", + ) + raise + results.append(res_col) + + elif is_extension_array_dtype(dtype) and self.ndim > 1: + # TODO(EA2D): special case not needed with 2D EAs + dtype = pandas_dtype(dtype) + if isinstance(dtype, ExtensionDtype) and all( + block.values.dtype == dtype for block in self._mgr.blocks + ): + return self.copy(deep=False) + # GH 18099/22869: columnwise conversion to extension dtype + # GH 24704: self.items handles duplicate column names + results = [ser.astype(dtype, errors=errors) for _, ser in self.items()] + + else: + # else, only a single dtype is given + new_data = self._mgr.astype(dtype=dtype, errors=errors) + res = self._constructor_from_mgr(new_data, axes=new_data.axes) + return res.__finalize__(self, method="astype") + + # GH 33113: handle empty frame or series + if not results: + return self.copy(deep=False) + + # GH 19920: retain column metadata after concat + result = concat(results, axis=1) + # GH#40810 retain subclass + # error: Incompatible types in assignment + # (expression has type "Self", variable has type "DataFrame") + result = self._constructor(result) # type: ignore[assignment] + result.columns = self.columns + result = result.__finalize__(self, method="astype") + # https://github.com/python/mypy/issues/8354 + return cast(Self, result) + + @final + def copy(self, deep: bool = True) -> Self: + """ + Make a copy of this object's indices and data. + + When ``deep=True`` (default), a new object will be created with a + copy of the calling object's data and indices. Modifications to + the data or indices of the copy will not be reflected in the + original object (see notes below). + + When ``deep=False``, a new object will be created without copying + the calling object's data or index (only references to the data + and index are copied). With Copy-on-Write, changes to the original + will *not* be reflected in the shallow copy (and vice versa). The + shallow copy uses a lazy (deferred) copy mechanism that copies the + data only when any changes to the original or shallow copy are made, + ensuring memory efficiency while maintaining data integrity. + + .. note:: + In pandas versions prior to 3.0, the default behavior without + Copy-on-Write was different: changes to the original *were* reflected + in the shallow copy (and vice versa). See the :ref:`Copy-on-Write + user guide ` for more information. + + Parameters + ---------- + deep : bool, default True + Make a deep copy, including a copy of the data and the indices. + With ``deep=False`` neither the indices nor the data are copied. + + Returns + ------- + Series or DataFrame + Object type matches caller. + + See Also + -------- + copy.copy : Return a shallow copy of an object. + copy.deepcopy : Return a deep copy of an object. + + Notes + ----- + When ``deep=True``, data is copied but actual Python objects + will not be copied recursively, only the reference to the object. + This is in contrast to `copy.deepcopy` in the Standard Library, + which recursively copies object data (see examples below). + + While ``Index`` objects are copied when ``deep=True``, the underlying + numpy array is not copied for performance reasons. Since ``Index`` is + immutable, the underlying data can be safely shared and a copy + is not needed. + + Since pandas is not thread safe, see the + :ref:`gotchas ` when copying in a threading + environment. + + Copy-on-Write protects shallow copies against accidental modifications. + This means that any changes to the copied data would make a new copy + of the data upon write (and vice versa). Changes made to either the + original or copied variable would not be reflected in the counterpart. + See :ref:`Copy_on_Write ` for more information. + + Examples + -------- + >>> s = pd.Series([1, 2], index=["a", "b"]) + >>> s + a 1 + b 2 + dtype: int64 + + >>> s_copy = s.copy(deep=True) + >>> s_copy + a 1 + b 2 + dtype: int64 + + Due to Copy-on-Write, shallow copies still protect data modifications. + Note shallow does not get modified below. + + >>> s = pd.Series([1, 2], index=["a", "b"]) + >>> shallow = s.copy(deep=False) + >>> s.iloc[1] = 200 + >>> shallow + a 1 + b 2 + dtype: int64 + + When the data has object dtype, even a deep copy does not copy the + underlying Python objects. Updating a nested data object will be + reflected in the deep copy. + + >>> s = pd.Series([[1, 2], [3, 4]]) + >>> deep = s.copy() + >>> s[0][0] = 10 + >>> s + 0 [10, 2] + 1 [3, 4] + dtype: object + >>> deep + 0 [10, 2] + 1 [3, 4] + dtype: object + """ + data = self._mgr.copy(deep=deep) + return self._constructor_from_mgr(data, axes=data.axes).__finalize__( + self, method="copy" + ) + + @final + def __copy__(self) -> Self: + return self.copy(deep=False) + + @final + def __deepcopy__(self, memo=None) -> Self: + """ + Parameters + ---------- + memo, default None + Standard signature. Unused + """ + return self.copy(deep=True) + + @final + def infer_objects(self, copy: bool | lib.NoDefault = lib.no_default) -> Self: + """ + Attempt to infer better dtypes for object columns. + + Attempts soft conversion of object-dtyped + columns, leaving non-object and unconvertible + columns unchanged. The inference rules are the + same as during normal Series/DataFrame construction. + + Parameters + ---------- + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + same type as input object + Returns an object of the same type as the input object. + + See Also + -------- + to_datetime : Convert argument to datetime. + to_timedelta : Convert argument to timedelta. + to_numeric : Convert argument to numeric type. + convert_dtypes : Convert argument to best possible dtype. + + Examples + -------- + >>> df = pd.DataFrame({"A": ["a", 1, 2, 3]}) + >>> df = df.iloc[1:] + >>> df + A + 1 1 + 2 2 + 3 3 + + >>> df.dtypes + A object + dtype: object + + >>> df.infer_objects().dtypes + A int64 + dtype: object + """ + self._check_copy_deprecation(copy) + new_mgr = self._mgr.convert() + res = self._constructor_from_mgr(new_mgr, axes=new_mgr.axes) + return res.__finalize__(self, method="infer_objects") + + @final + def convert_dtypes( + self, + infer_objects: bool = True, + convert_string: bool = True, + convert_integer: bool = True, + convert_boolean: bool = True, + convert_floating: bool = True, + dtype_backend: DtypeBackend = "numpy_nullable", + ) -> Self: + """ + Convert columns from numpy dtypes to the best dtypes that support ``pd.NA``. + + Parameters + ---------- + infer_objects : bool, default True + Whether object dtypes should be converted to the best possible types. + convert_string : bool, default True + Whether object dtypes should be converted to ``StringDtype()``. + convert_integer : bool, default True + Whether, if possible, conversion can be done to integer extension types. + convert_boolean : bool, defaults True + Whether object dtypes should be converted to ``BooleanDtypes()``. + convert_floating : bool, defaults True + Whether, if possible, conversion can be done to floating extension types. + If `convert_integer` is also True, preference will be give to integer + dtypes if the floats can be faithfully casted to integers. + dtype_backend : {'numpy_nullable', 'pyarrow'}, default 'numpy_nullable' + Back-end data type applied to the resultant :class:`DataFrame` or + :class:`Series` (still experimental). Behaviour is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed + :class:`DataFrame` or :class:`Serires`. + * ``"pyarrow"``: returns pyarrow-backed nullable :class:`ArrowDtype` + :class:`DataFrame` or :class:`Series`. + + .. versionadded:: 2.0 + + Returns + ------- + Series or DataFrame + Copy of input object with new dtype. + + See Also + -------- + infer_objects : Infer dtypes of objects. + to_datetime : Convert argument to datetime. + to_timedelta : Convert argument to timedelta. + to_numeric : Convert argument to a numeric type. + + Notes + ----- + By default, ``convert_dtypes`` will attempt to convert a Series (or each + Series in a DataFrame) to dtypes that support ``pd.NA``. By using the options + ``convert_string``, ``convert_integer``, ``convert_boolean`` and + ``convert_floating``, it is possible to turn off individual conversions + to ``StringDtype``, the integer extension types, ``BooleanDtype`` + or floating extension types, respectively. + + For object-dtyped columns, if ``infer_objects`` is ``True``, use the inference + rules as during normal Series/DataFrame construction. Then, if possible, + convert to ``StringDtype``, ``BooleanDtype`` or an appropriate integer + or floating extension type, otherwise leave as ``object``. + + If the dtype is integer, convert to an appropriate integer extension type. + + If the dtype is numeric, and consists of all integers, convert to an + appropriate integer extension type. Otherwise, convert to an + appropriate floating extension type. + + In the future, as new dtypes are added that support ``pd.NA``, the results + of this method will change to support those new dtypes. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "a": pd.Series([1, 2, 3], dtype=np.dtype("int32")), + ... "b": pd.Series(["x", "y", "z"], dtype=np.dtype("O")), + ... "c": pd.Series([True, False, np.nan], dtype=np.dtype("O")), + ... "d": pd.Series(["h", "i", np.nan], dtype=np.dtype("O")), + ... "e": pd.Series([10, np.nan, 20], dtype=np.dtype("float")), + ... "f": pd.Series([np.nan, 100.5, 200], dtype=np.dtype("float")), + ... } + ... ) + + Start with a DataFrame with default dtypes. + + >>> df + a b c d e f + 0 1 x True h 10.0 NaN + 1 2 y False i NaN 100.5 + 2 3 z NaN NaN 20.0 200.0 + + >>> df.dtypes + a int32 + b object + c object + d object + e float64 + f float64 + dtype: object + + Convert the DataFrame to use best possible dtypes. + + >>> dfn = df.convert_dtypes() + >>> dfn + a b c d e f + 0 1 x True h 10 + 1 2 y False i 100.5 + 2 3 z 20 200.0 + + >>> dfn.dtypes + a Int32 + b string + c boolean + d string + e Int64 + f Float64 + dtype: object + + Start with a Series of strings and missing data represented by ``np.nan``. + + >>> s = pd.Series(["a", "b", np.nan]) + >>> s + 0 a + 1 b + 2 NaN + dtype: str + + Obtain a Series with dtype ``StringDtype``. + + >>> s.convert_dtypes() + 0 a + 1 b + 2 + dtype: string + """ + check_dtype_backend(dtype_backend) + new_mgr = self._mgr.convert_dtypes( + infer_objects=infer_objects, + convert_string=convert_string, + convert_integer=convert_integer, + convert_boolean=convert_boolean, + convert_floating=convert_floating, + dtype_backend=dtype_backend, + ) + res = self._constructor_from_mgr(new_mgr, axes=new_mgr.axes) + return res.__finalize__(self, method="convert_dtypes") + + # ---------------------------------------------------------------------- + # Filling NA's + + @final + def _pad_or_backfill( + self, + method: Literal["ffill", "bfill", "pad", "backfill"], + *, + axis: None | Axis = None, + inplace: bool = False, + limit: None | int = None, + limit_area: Literal["inside", "outside"] | None = None, + ): + if axis is None: + axis = 0 + axis = self._get_axis_number(axis) + method = clean_fill_method(method) + + if axis == 1: + if not self._mgr.is_single_block and inplace: + raise NotImplementedError + # e.g. test_align_fill_method + result = self.T._pad_or_backfill( + method=method, limit=limit, limit_area=limit_area + ).T + + return result + + new_mgr = self._mgr.pad_or_backfill( + method=method, + limit=limit, + limit_area=limit_area, + inplace=inplace, + ) + result = self._constructor_from_mgr(new_mgr, axes=new_mgr.axes) + if inplace: + self._update_inplace(result) + return self + else: + return result.__finalize__(self, method="fillna") + + @final + def fillna( + self, + value: Hashable | Mapping | Series | DataFrame, + *, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + ) -> Self: + """ + Fill NA/NaN values with `value`. + + Parameters + ---------- + value : scalar, dict, Series, or DataFrame + Value to use to fill holes (e.g. 0), alternately a + dict/Series/DataFrame of values specifying which value to use for + each index (for a Series) or column (for a DataFrame). Values not + in the dict/Series/DataFrame will not be filled. This value cannot + be a list. + axis : {0 or 'index'} for Series, {0 or 'index', 1 or 'columns'} for DataFrame + Axis along which to fill missing values. For `Series` + this parameter is unused and defaults to 0. + inplace : bool, default False + If True, fill in-place. Note: this will modify any + other views on this object (e.g., a no-copy slice for a column in a + DataFrame). + limit : int, default None + This is the maximum number of entries along the entire axis + where NaNs will be filled. Must be greater than 0 if not None. + + Returns + ------- + Series/DataFrame + Object with missing values filled. + + See Also + -------- + ffill : Fill values by propagating the last valid observation to next valid. + bfill : Fill values by using the next valid observation to fill the gap. + interpolate : Fill NaN values using interpolation. + reindex : Conform object to new index. + asfreq : Convert TimeSeries to specified frequency. + + Notes + ----- + For non-object dtype, ``value=None`` will use the NA value of the dtype. + See more details in the :ref:`Filling missing data` + section. + + Examples + -------- + >>> df = pd.DataFrame( + ... [ + ... [np.nan, 2, np.nan, 0], + ... [3, 4, np.nan, 1], + ... [np.nan, np.nan, np.nan, np.nan], + ... [np.nan, 3, np.nan, 4], + ... ], + ... columns=list("ABCD"), + ... ) + >>> df + A B C D + 0 NaN 2.0 NaN 0.0 + 1 3.0 4.0 NaN 1.0 + 2 NaN NaN NaN NaN + 3 NaN 3.0 NaN 4.0 + + Replace all NaN elements with 0s. + + >>> df.fillna(0) + A B C D + 0 0.0 2.0 0.0 0.0 + 1 3.0 4.0 0.0 1.0 + 2 0.0 0.0 0.0 0.0 + 3 0.0 3.0 0.0 4.0 + + Replace all NaN elements in column 'A', 'B', 'C', and 'D', with 0, 1, + 2, and 3 respectively. + + >>> values = {"A": 0, "B": 1, "C": 2, "D": 3} + >>> df.fillna(value=values) + A B C D + 0 0.0 2.0 2.0 0.0 + 1 3.0 4.0 2.0 1.0 + 2 0.0 1.0 2.0 3.0 + 3 0.0 3.0 2.0 4.0 + + Only replace the first NaN element. + + >>> df.fillna(value=values, limit=1) + A B C D + 0 0.0 2.0 2.0 0.0 + 1 3.0 4.0 NaN 1.0 + 2 NaN 1.0 NaN 3.0 + 3 NaN 3.0 NaN 4.0 + + When filling using a DataFrame, replacement happens along + the same column names and same indices + + >>> df2 = pd.DataFrame(np.zeros((4, 4)), columns=list("ABCE")) + >>> df.fillna(df2) + A B C D + 0 0.0 2.0 0.0 0.0 + 1 3.0 4.0 0.0 1.0 + 2 0.0 0.0 0.0 NaN + 3 0.0 3.0 0.0 4.0 + + Note that column D is not affected since it is not present in df2. + """ + inplace = validate_bool_kwarg(inplace, "inplace") + if inplace: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not common.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + if isinstance(value, (list, tuple)): + raise TypeError( + '"value" parameter must be a scalar or dict, but ' + f'you passed a "{type(value).__name__}"' + ) + + # set the default here, so functions examining the signature + # can detect if something was set (e.g. in groupby) (GH9221) + if axis is None: + axis = 0 + axis = self._get_axis_number(axis) + + if self.ndim == 1: + if isinstance(value, (dict, ABCSeries)): + if not len(value): + # test_fillna_nonscalar + return self if inplace else self.copy(deep=False) + from pandas import Series + + value = Series(value) + value = value.reindex(self.index) + value = value._values + elif not is_list_like(value): + pass + else: + raise TypeError( + '"value" parameter must be a scalar, dict ' + "or Series, but you passed a " + f'"{type(value).__name__}"' + ) + + new_data = self._mgr.fillna(value=value, limit=limit, inplace=inplace) + + elif isinstance(value, (dict, ABCSeries)): + result = self if inplace else self.copy(deep=False) + if axis == 1: + # Check that all columns in result have the same dtype + # otherwise don't bother with fillna and losing accurate dtypes + unique_dtypes = algos.unique(self._mgr.get_dtypes()) + if len(unique_dtypes) > 1: + raise ValueError( + "All columns must have the same dtype, but got dtypes: " + f"{list(unique_dtypes)}" + ) + # Use the first column, which we have already validated has the + # same dtypes as the other columns. + if not can_hold_element(result.iloc[:, 0], value): + frame_dtype = unique_dtypes.item() + raise ValueError( + f"{value} not a suitable type to fill into {frame_dtype}" + ) + result = result.T.fillna(value=value).T + if inplace: + self._update_inplace(result) + result = self + else: + for k, v in value.items(): + if k not in result: + continue + + res_k = result[k].fillna(v, limit=limit) + + if not inplace: + result[k] = res_k + # We can write into our existing column(s) iff dtype + # was preserved. + elif isinstance(res_k, ABCSeries): + # i.e. 'k' only shows up once in self.columns + if res_k.dtype == result[k].dtype: + result.loc[:, k] = res_k + else: + # Different dtype -> no way to do inplace. + result[k] = res_k + else: + # see test_fillna_dict_inplace_nonunique_columns + locs = result.columns.get_loc(k) + if isinstance(locs, slice): + locs = range(self.shape[1])[locs] + elif isinstance(locs, np.ndarray) and locs.dtype.kind == "b": + locs = locs.nonzero()[0] + elif not ( + isinstance(locs, np.ndarray) and locs.dtype.kind == "i" + ): + # Should never be reached, but let's cover our bases + raise NotImplementedError( + "Unexpected get_loc result, please report a bug at " + "https://github.com/pandas-dev/pandas" + ) + + for i, loc in enumerate(locs): + res_loc = res_k.iloc[:, i] + target = self.iloc[:, loc] + + if res_loc.dtype == target.dtype: + result.iloc[:, loc] = res_loc + else: + result.isetitem(loc, res_loc) + return result + + elif not is_list_like(value): + if axis == 1: + result = self.T.fillna(value=value, limit=limit).T + new_data = result._mgr + else: + new_data = self._mgr.fillna(value=value, limit=limit, inplace=inplace) + elif isinstance(value, ABCDataFrame) and self.ndim == 2: + new_data = self.where(self.notna(), value)._mgr + else: + raise ValueError(f"invalid fill value with a {type(value)}") + + result = self._constructor_from_mgr(new_data, axes=new_data.axes) + if inplace: + self._update_inplace(result) + return self + else: + return result.__finalize__(self, method="fillna") + + @final + def ffill( + self, + *, + axis: None | Axis = None, + inplace: bool = False, + limit: None | int = None, + limit_area: Literal["inside", "outside"] | None = None, + ) -> Self: + """ + Fill NA/NaN values by propagating the last valid observation to next valid. + + Parameters + ---------- + axis : {0 or 'index'} for Series, {0 or 'index', 1 or 'columns'} for DataFrame + Axis along which to fill missing values. For `Series` + this parameter is unused and defaults to 0. + inplace : bool, default False + If True, fill in-place. Note: this will modify any + other views on this object (e.g., a no-copy slice for a column in a + DataFrame). + limit : int, default None + If method is specified, this is the maximum number of consecutive + NaN values to forward/backward fill. In other words, if there is + a gap with more than this number of consecutive NaNs, it will only + be partially filled. If method is not specified, this is the + maximum number of entries along the entire axis where NaNs will be + filled. Must be greater than 0 if not None. + limit_area : {{`None`, 'inside', 'outside'}}, default None + If limit is specified, consecutive NaNs will be filled with this + restriction. + + * ``None``: No fill restriction. + * 'inside': Only fill NaNs surrounded by valid values + (interpolate). + * 'outside': Only fill NaNs outside valid values (extrapolate). + + .. versionadded:: 2.2.0 + + Returns + ------- + Series/DataFrame + Object with missing values filled. + + See Also + -------- + DataFrame.bfill : Fill NA/NaN values by using the next valid observation + to fill the gap. + + Examples + -------- + >>> df = pd.DataFrame( + ... [ + ... [np.nan, 2, np.nan, 0], + ... [3, 4, np.nan, 1], + ... [np.nan, np.nan, np.nan, np.nan], + ... [np.nan, 3, np.nan, 4], + ... ], + ... columns=list("ABCD"), + ... ) + >>> df + A B C D + 0 NaN 2.0 NaN 0.0 + 1 3.0 4.0 NaN 1.0 + 2 NaN NaN NaN NaN + 3 NaN 3.0 NaN 4.0 + + >>> df.ffill() + A B C D + 0 NaN 2.0 NaN 0.0 + 1 3.0 4.0 NaN 1.0 + 2 3.0 4.0 NaN 1.0 + 3 3.0 3.0 NaN 4.0 + + >>> ser = pd.Series([1, np.nan, 2, 3]) + >>> ser.ffill() + 0 1.0 + 1 1.0 + 2 2.0 + 3 3.0 + dtype: float64 + """ + inplace = validate_bool_kwarg(inplace, "inplace") + if inplace: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not common.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + return self._pad_or_backfill( + "ffill", + axis=axis, + inplace=inplace, + limit=limit, + limit_area=limit_area, + ) + + @final + def bfill( + self, + *, + axis: None | Axis = None, + inplace: bool = False, + limit: None | int = None, + limit_area: Literal["inside", "outside"] | None = None, + ) -> Self: + """ + Fill NA/NaN values by using the next valid observation to fill the gap. + + This method fills missing values in a backward direction along the + specified axis, propagating non-null values from later positions to + earlier positions containing NaN. + + Parameters + ---------- + axis : {0 or 'index'} for Series, {0 or 'index', 1 or 'columns'} for DataFrame + Axis along which to fill missing values. For `Series` + this parameter is unused and defaults to 0. + inplace : bool, default False + If True, fill in-place. Note: this will modify any + other views on this object (e.g., a no-copy slice for a column in a + DataFrame). + limit : int, default None + If method is specified, this is the maximum number of consecutive + NaN values to forward/backward fill. In other words, if there is + a gap with more than this number of consecutive NaNs, it will only + be partially filled. If method is not specified, this is the + maximum number of entries along the entire axis where NaNs will be + filled. Must be greater than 0 if not None. + limit_area : {{`None`, 'inside', 'outside'}}, default None + If limit is specified, consecutive NaNs will be filled with this + restriction. + + * ``None``: No fill restriction. + * 'inside': Only fill NaNs surrounded by valid values + (interpolate). + * 'outside': Only fill NaNs outside valid values (extrapolate). + + .. versionadded:: 2.2.0 + + Returns + ------- + Series/DataFrame + Object with missing values filled. + + See Also + -------- + DataFrame.ffill : Fill NA/NaN values by propagating the last valid + observation to next valid. + + Examples + -------- + For Series: + + >>> s = pd.Series([1, None, None, 2]) + >>> s.bfill() + 0 1.0 + 1 2.0 + 2 2.0 + 3 2.0 + dtype: float64 + >>> s.bfill(limit=1) + 0 1.0 + 1 NaN + 2 2.0 + 3 2.0 + dtype: float64 + + With DataFrame: + + >>> df = pd.DataFrame({"A": [1, None, None, 4], "B": [None, 5, None, 7]}) + >>> df + A B + 0 1.0 NaN + 1 NaN 5.0 + 2 NaN NaN + 3 4.0 7.0 + >>> df.bfill() + A B + 0 1.0 5.0 + 1 4.0 5.0 + 2 4.0 7.0 + 3 4.0 7.0 + >>> df.bfill(limit=1) + A B + 0 1.0 5.0 + 1 NaN 5.0 + 2 4.0 7.0 + 3 4.0 7.0 + """ + inplace = validate_bool_kwarg(inplace, "inplace") + if inplace: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not common.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + return self._pad_or_backfill( + "bfill", + axis=axis, + inplace=inplace, + limit=limit, + limit_area=limit_area, + ) + + @final + def replace( + self, + to_replace=None, + value=lib.no_default, + *, + inplace: bool = False, + regex: bool = False, + ) -> Self: + """ + Replace values given in `to_replace` with `value`. + + Values of the Series/DataFrame are replaced with other values dynamically. + This differs from updating with ``.loc`` or ``.iloc``, which require + you to specify a location to update with some value. + + Parameters + ---------- + to_replace : str, regex, list, dict, Series, int, float, or None + How to find the values that will be replaced. + + * numeric, str or regex: + + - numeric: numeric values equal to `to_replace` will be + replaced with `value` + - str: string exactly matching `to_replace` will be replaced + with `value` + - regex: regexes matching `to_replace` will be replaced with + `value` + + * list of str, regex, or numeric: + + - First, if `to_replace` and `value` are both lists, they + **must** be the same length. + - Second, if ``regex=True`` then all of the strings in **both** + lists will be interpreted as regexes otherwise they will match + directly. This doesn't matter much for `value` since there + are only a few possible substitution regexes you can use. + - str, regex and numeric rules apply as above. + + * dict: + + - Dicts can be used to specify different replacement values + for different existing values. For example, + ``{'a': 'b', 'y': 'z'}`` replaces the value 'a' with 'b' and + 'y' with 'z'. To use a dict in this way, the optional `value` + parameter should not be given. + - For a DataFrame a dict can specify that different values + should be replaced in different columns. For example, + ``{'a': 1, 'b': 'z'}`` looks for the value 1 in column 'a' + and the value 'z' in column 'b' and replaces these values + with whatever is specified in `value`. The `value` parameter + should not be ``None`` in this case. You can treat this as a + special case of passing two lists except that you are + specifying the column to search in. + - For a DataFrame nested dictionaries, e.g., + ``{'a': {'b': np.nan}}``, are read as follows: look in column + 'a' for the value 'b' and replace it with NaN. The optional `value` + parameter should not be specified to use a nested dict in this + way. You can nest regular expressions as well. Note that + column names (the top-level dictionary keys in a nested + dictionary) **cannot** be regular expressions. + + * None: + + - This means that the `regex` argument must be a string, + compiled regular expression, or list, dict, ndarray or + Series of such elements. If `value` is also ``None`` then + this **must** be a nested dictionary or Series. + + See the examples section for examples of each of these. + value : scalar, dict, list, str, regex, default None + Value to replace any values matching `to_replace` with. + For a DataFrame a dict of values can be used to specify which + value to use for each column (columns not in the dict will not be + filled). Regular expressions, strings and lists or dicts of such + objects are also allowed. + + inplace : bool, default False + If True, performs operation inplace. + regex : bool or same types as `to_replace`, default False + Whether to interpret `to_replace` and/or `value` as regular + expressions. Alternatively, this could be a regular expression or a + list, dict, or array of regular expressions in which case + `to_replace` must be ``None``. + + Returns + ------- + Series/DataFrame + Object after replacement. + + Raises + ------ + AssertionError + * If `regex` is not a ``bool`` and `to_replace` is not + ``None``. + + TypeError + * If `to_replace` is not a scalar, array-like, ``dict``, or ``None`` + * If `to_replace` is a ``dict`` and `value` is not a ``list``, + ``dict``, ``ndarray``, or ``Series`` + * If `to_replace` is ``None`` and `regex` is not compilable + into a regular expression or is a list, dict, ndarray, or + Series. + * When replacing multiple ``bool`` or ``datetime64`` objects and + the arguments to `to_replace` does not match the type of the + value being replaced + + ValueError + * If a ``list`` or an ``ndarray`` is passed to `to_replace` and + `value` but they are not the same length. + + See Also + -------- + Series.fillna : Fill NA values. + DataFrame.fillna : Fill NA values. + Series.where : Replace values based on boolean condition. + DataFrame.where : Replace values based on boolean condition. + DataFrame.map: Apply a function to a Dataframe elementwise. + Series.map: Map values of Series according to an input mapping or function. + Series.str.replace : Simple string replacement. + + Notes + ----- + * Regex substitution is performed under the hood with ``re.sub``. The + rules for substitution for ``re.sub`` are the same. + * Regular expressions will only substitute on strings, meaning you + cannot provide, for example, a regular expression matching floating + point numbers and expect the columns in your frame that have a + numeric dtype to be matched. However, if those floating point + numbers *are* strings, then you can do this. + * This method has *a lot* of options. You are encouraged to experiment + and play with this method to gain intuition about how it works. + * When dict is used as the `to_replace` value, it is like + key(s) in the dict are the to_replace part and + value(s) in the dict are the value parameter. + + Examples + -------- + + **Scalar `to_replace` and `value`** + + >>> s = pd.Series([1, 2, 3, 4, 5]) + >>> s.replace(1, 5) + 0 5 + 1 2 + 2 3 + 3 4 + 4 5 + dtype: int64 + + >>> df = pd.DataFrame( + ... { + ... "A": [0, 1, 2, 3, 4], + ... "B": [5, 6, 7, 8, 9], + ... "C": ["a", "b", "c", "d", "e"], + ... } + ... ) + >>> df.replace(0, 5) + A B C + 0 5 5 a + 1 1 6 b + 2 2 7 c + 3 3 8 d + 4 4 9 e + + **List-like `to_replace`** + + >>> df.replace([0, 1, 2, 3], 4) + A B C + 0 4 5 a + 1 4 6 b + 2 4 7 c + 3 4 8 d + 4 4 9 e + + >>> df.replace([0, 1, 2, 3], [4, 3, 2, 1]) + A B C + 0 4 5 a + 1 3 6 b + 2 2 7 c + 3 1 8 d + 4 4 9 e + + **dict-like `to_replace`** + + >>> df.replace({0: 10, 1: 100}) + A B C + 0 10 5 a + 1 100 6 b + 2 2 7 c + 3 3 8 d + 4 4 9 e + + >>> df.replace({"A": 0, "B": 5}, 100) + A B C + 0 100 100 a + 1 1 6 b + 2 2 7 c + 3 3 8 d + 4 4 9 e + + >>> df.replace({"A": {0: 100, 4: 400}}) + A B C + 0 100 5 a + 1 1 6 b + 2 2 7 c + 3 3 8 d + 4 400 9 e + + **Regular expression `to_replace`** + + >>> df = pd.DataFrame({"A": ["bat", "foo", "bait"], "B": ["abc", "bar", "xyz"]}) + >>> df.replace(to_replace=r"^ba.$", value="new", regex=True) + A B + 0 new abc + 1 foo new + 2 bait xyz + + >>> df.replace({"A": r"^ba.$"}, {"A": "new"}, regex=True) + A B + 0 new abc + 1 foo bar + 2 bait xyz + + >>> df.replace(regex=r"^ba.$", value="new") + A B + 0 new abc + 1 foo new + 2 bait xyz + + >>> df.replace(regex={r"^ba.$": "new", "foo": "xyz"}) + A B + 0 new abc + 1 xyz new + 2 bait xyz + + >>> df.replace(regex=[r"^ba.$", "foo"], value="new") + A B + 0 new abc + 1 new new + 2 bait xyz + + Compare the behavior of ``s.replace({'a': None})`` and + ``s.replace('a', None)`` to understand the peculiarities + of the `to_replace` parameter: + + >>> s = pd.Series([10, "a", "a", "b", "a"]) + + When one uses a dict as the `to_replace` value, it is like the + value(s) in the dict are equal to the `value` parameter. + ``s.replace({'a': None})`` is equivalent to + ``s.replace(to_replace={'a': None}, value=None)``: + + >>> s.replace({"a": None}) + 0 10 + 1 None + 2 None + 3 b + 4 None + dtype: object + + If ``None`` is explicitly passed for ``value``, it will be respected: + + >>> s.replace("a", None) + 0 10 + 1 None + 2 None + 3 b + 4 None + dtype: object + + When ``regex=True``, ``value`` is not ``None`` and `to_replace` is a string, + the replacement will be applied in all columns of the DataFrame. + + >>> df = pd.DataFrame( + ... { + ... "A": [0, 1, 2, 3, 4], + ... "B": ["a", "b", "c", "d", "e"], + ... "C": ["f", "g", "h", "i", "j"], + ... } + ... ) + + >>> df.replace(to_replace="^[a-g]", value="e", regex=True) + A B C + 0 0 e e + 1 1 e e + 2 2 e h + 3 3 e i + 4 4 e j + + If ``value`` is not ``None`` and `to_replace` is a dictionary, the dictionary + keys will be the DataFrame columns that the replacement will be applied. + + >>> df.replace(to_replace={"B": "^[a-c]", "C": "^[h-j]"}, value="e", regex=True) + A B C + 0 0 e f + 1 1 e g + 2 2 e e + 3 3 d e + 4 4 e e + """ + if not is_bool(regex) and to_replace is not None: + raise ValueError("'to_replace' must be 'None' if 'regex' is not a bool") + + if not ( + is_scalar(to_replace) + or is_re_compilable(to_replace) + or is_list_like(to_replace) + ): + raise TypeError( + "Expecting 'to_replace' to be either a scalar, array-like, " + "dict or None, got invalid type " + f"{type(to_replace).__name__!r}" + ) + + if value is lib.no_default and not ( + is_dict_like(to_replace) or is_dict_like(regex) + ): + raise ValueError( + # GH#33302 + f"{type(self).__name__}.replace must specify either 'value', " + "a dict-like 'to_replace', or dict-like 'regex'." + ) + + inplace = validate_bool_kwarg(inplace, "inplace") + if inplace: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not common.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + if value is lib.no_default: + if not is_dict_like(to_replace): + # In this case we have checked above that + # 1) regex is dict-like and 2) to_replace is None + to_replace = regex + regex = True + + items = list(to_replace.items()) + if items: + keys, values = zip(*items, strict=True) + else: + keys, values = ([], []) # type: ignore[assignment] + + are_mappings = [is_dict_like(v) for v in values] + + if any(are_mappings): + if not all(are_mappings): + raise TypeError( + "If a nested mapping is passed, all values " + "of the top level mapping must be mappings" + ) + # passed a nested dict/Series + to_rep_dict = {} + value_dict = {} + + for k, v in items: + # error: Incompatible types in assignment (expression has type + # "list[Never]", variable has type "tuple[Any, ...]") + keys, values = list(zip(*v.items(), strict=True)) or ( # type: ignore[assignment] + [], + [], + ) + + to_rep_dict[k] = list(keys) + value_dict[k] = list(values) + + to_replace, value = to_rep_dict, value_dict + else: + to_replace, value = keys, values + + return self.replace(to_replace, value, inplace=inplace, regex=regex) + else: + # need a non-zero len on all axes + if not self.size: + return self if inplace else self.copy(deep=False) + if is_dict_like(to_replace): + if is_dict_like(value): # {'A' : NA} -> {'A' : 0} + if isinstance(self, ABCSeries): + raise ValueError( + "to_replace and value cannot be dict-like for " + "Series.replace" + ) + # Note: Checking below for `in foo.keys()` instead of + # `in foo` is needed for when we have a Series and not dict + mapping = { + col: (to_replace[col], value[col]) + for col in to_replace.keys() + if col in value.keys() and col in self + } + return self._replace_columnwise(mapping, inplace, regex) + + # {'A': NA} -> 0 + elif not is_list_like(value): + # Operate column-wise + if self.ndim == 1: + raise ValueError( + "Series.replace cannot specify both a dict-like " + "'to_replace' and a 'value'" + ) + mapping = { + col: (to_rep, value) for col, to_rep in to_replace.items() + } + return self._replace_columnwise(mapping, inplace, regex) + else: + raise TypeError("value argument must be scalar, dict, or Series") + + elif is_list_like(to_replace): + if not is_list_like(value): + # e.g. to_replace = [NA, ''] and value is 0, + # so we replace NA with 0 and then replace '' with 0 + value = [value] * len(to_replace) + + # e.g. we have to_replace = [NA, ''] and value = [0, 'missing'] + if len(to_replace) != len(value): + raise ValueError( + f"Replacement lists must match in length. " + f"Expecting {len(to_replace)} got {len(value)} " + ) + new_data = self._mgr.replace_list( + src_list=to_replace, + dest_list=value, + inplace=inplace, + regex=regex, + ) + + elif to_replace is None: + if not ( + is_re_compilable(regex) + or is_list_like(regex) + or is_dict_like(regex) + ): + raise TypeError( + f"'regex' must be a string or a compiled regular expression " + f"or a list or dict of strings or regular expressions, " + f"you passed a {type(regex).__name__!r}" + ) + return self.replace(regex, value, inplace=inplace, regex=True) + # dest iterable dict-like + elif is_dict_like(value): # NA -> {'A' : 0, 'B' : -1} + # Operate column-wise + if self.ndim == 1: + raise ValueError( + "Series.replace cannot use dict-value and non-None to_replace" + ) + mapping = {col: (to_replace, val) for col, val in value.items()} + return self._replace_columnwise(mapping, inplace, regex) + + elif not is_list_like(value): # NA -> 0 + regex = should_use_regex(regex, to_replace) + if regex: + new_data = self._mgr.replace_regex( + to_replace=to_replace, + value=value, + inplace=inplace, + ) + else: + new_data = self._mgr.replace( + to_replace=to_replace, value=value, inplace=inplace + ) + else: + raise TypeError( + f'Invalid "to_replace" type: {type(to_replace).__name__!r}' + ) + + result = self._constructor_from_mgr(new_data, axes=new_data.axes) + if inplace: + self._update_inplace(result) + return self + else: + return result.__finalize__(self, method="replace") + + @final + def interpolate( + self, + method: InterpolateOptions = "linear", + *, + axis: Axis = 0, + limit: int | None = None, + inplace: bool = False, + limit_direction: Literal["forward", "backward", "both"] | None = None, + limit_area: Literal["inside", "outside"] | None = None, + **kwargs, + ) -> Self: + """ + Fill NaN values using an interpolation method. + + Please note that only ``method='linear'`` is supported for + DataFrame/Series with a MultiIndex. + + Parameters + ---------- + method : str, default 'linear' + Interpolation technique to use. One of: + + * 'linear': Ignore the index and treat the values as equally + spaced. This is the only method supported on MultiIndexes. + * 'time': Works on daily and higher resolution data to interpolate + given length of interval. This interpolates values based on + time interval between observations. + * 'index': The interpolation uses the numerical values + of the DataFrame's index to linearly calculate missing values. + * 'values': Interpolation based on the numerical values + in the DataFrame, treating them as equally spaced along the index. + * 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', + 'barycentric', 'polynomial': Passed to + `scipy.interpolate.interp1d`, whereas 'spline' is passed to + `scipy.interpolate.UnivariateSpline`. These methods use the numerical + values of the index. Both 'polynomial' and 'spline' require that + you also specify an `order` (int), e.g. + ``df.interpolate(method='polynomial', order=5)``. Note that, + `slinear` method in Pandas refers to the Scipy first order `spline` + instead of Pandas first order `spline`. + * 'krogh', 'piecewise_polynomial', 'spline', 'pchip', 'akima', + 'cubicspline': Wrappers around the SciPy interpolation methods of + similar names. See `Notes`. + * 'from_derivatives': Refers to + `scipy.interpolate.BPoly.from_derivatives`. + + axis : {{0 or 'index', 1 or 'columns', None}}, default None + Axis to interpolate along. For `Series` this parameter is unused + and defaults to 0. + limit : int, optional + Maximum number of consecutive NaNs to fill. Must be greater than + 0. + inplace : bool, default False + Update the data in place if possible. + limit_direction : {{'forward', 'backward', 'both'}}, optional, default 'forward' + Consecutive NaNs will be filled in this direction. + + limit_area : {{`None`, 'inside', 'outside'}}, default None + If limit is specified, consecutive NaNs will be filled with this + restriction. + + * ``None``: No fill restriction. + * 'inside': Only fill NaNs surrounded by valid values + (interpolate). + * 'outside': Only fill NaNs outside valid values (extrapolate). + + **kwargs : optional + Keyword arguments to pass on to the interpolating function. + + Returns + ------- + Series or DataFrame + Returns the same object type as the caller, interpolated at + some or all ``NaN`` values. + + See Also + -------- + fillna : Fill missing values using different methods. + scipy.interpolate.Akima1DInterpolator : Piecewise cubic polynomials + (Akima interpolator). + scipy.interpolate.BPoly.from_derivatives : Piecewise polynomial in the + Bernstein basis. + scipy.interpolate.interp1d : Interpolate a 1-D function. + scipy.interpolate.KroghInterpolator : Interpolate polynomial (Krogh + interpolator). + scipy.interpolate.PchipInterpolator : PCHIP 1-d monotonic cubic + interpolation. + scipy.interpolate.CubicSpline : Cubic spline data interpolator. + + Notes + ----- + The 'krogh', 'piecewise_polynomial', 'spline', 'pchip' and 'akima' + methods are wrappers around the respective SciPy implementations of + similar names. These use the actual numerical values of the index. + For more information on their behavior, see the + `SciPy documentation + `__. + + Examples + -------- + Filling in ``NaN`` in a :class:`~pandas.Series` via linear + interpolation. + + >>> s = pd.Series([0, 1, np.nan, 3]) + >>> s + 0 0.0 + 1 1.0 + 2 NaN + 3 3.0 + dtype: float64 + >>> s.interpolate() + 0 0.0 + 1 1.0 + 2 2.0 + 3 3.0 + dtype: float64 + + Filling in ``NaN`` in a Series via polynomial interpolation or splines: + Both 'polynomial' and 'spline' methods require that you also specify + an ``order`` (int). + + >>> s = pd.Series([0, 2, np.nan, 8]) + >>> s.interpolate(method="polynomial", order=2) + 0 0.000000 + 1 2.000000 + 2 4.666667 + 3 8.000000 + dtype: float64 + + Fill the DataFrame forward (that is, going down) along each column + using linear interpolation. + + Note how the last entry in column 'a' is interpolated differently, + because there is no entry after it to use for interpolation. + Note how the first entry in column 'b' remains ``NaN``, because there + is no entry before it to use for interpolation. + + >>> df = pd.DataFrame( + ... [ + ... (0.0, np.nan, -1.0, 1.0), + ... (np.nan, 2.0, np.nan, np.nan), + ... (2.0, 3.0, np.nan, 9.0), + ... (np.nan, 4.0, -4.0, 16.0), + ... ], + ... columns=list("abcd"), + ... ) + >>> df + a b c d + 0 0.0 NaN -1.0 1.0 + 1 NaN 2.0 NaN NaN + 2 2.0 3.0 NaN 9.0 + 3 NaN 4.0 -4.0 16.0 + >>> df.interpolate(method="linear", limit_direction="forward", axis=0) + a b c d + 0 0.0 NaN -1.0 1.0 + 1 1.0 2.0 -2.0 5.0 + 2 2.0 3.0 -3.0 9.0 + 3 2.0 4.0 -4.0 16.0 + + Using polynomial interpolation. + + >>> df["d"].interpolate(method="polynomial", order=2) + 0 1.0 + 1 4.0 + 2 9.0 + 3 16.0 + Name: d, dtype: float64 + """ + inplace = validate_bool_kwarg(inplace, "inplace") + + if inplace: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not common.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + axis = self._get_axis_number(axis) + + if self.empty: + return self if inplace else self.copy() + + if not isinstance(method, str): + raise ValueError("'method' should be a string, not None.") + + obj, should_transpose = (self.T, True) if axis == 1 else (self, False) + + if isinstance(obj.index, MultiIndex) and method != "linear": + raise ValueError( + "Only `method=linear` interpolation is supported on MultiIndexes." + ) + + limit_direction = missing.infer_limit_direction(limit_direction, method) + + index = missing.get_interp_index(method, obj.index) + new_data = obj._mgr.interpolate( + method=method, + index=index, + limit=limit, + limit_direction=limit_direction, + limit_area=limit_area, + inplace=inplace, + **kwargs, + ) + + result = self._constructor_from_mgr(new_data, axes=new_data.axes) + if should_transpose: + result = result.T + if inplace: + self._update_inplace(result) + return self + else: + return result.__finalize__(self, method="interpolate") + + # ---------------------------------------------------------------------- + # Timeseries methods Methods + + @final + def asof(self, where, subset=None): + """ + Return the last row(s) without any NaNs before `where`. + + The last row (for each element in `where`, if list) without any + NaN is taken. + In case of a :class:`~pandas.DataFrame`, the last row without NaN + considering only the subset of columns (if not `None`) + + If there is no good value, NaN is returned for a Series or + a Series of NaN values for a DataFrame + + Parameters + ---------- + where : date or array-like of dates + Date(s) before which the last row(s) are returned. + subset : str or array-like of str, default `None` + For DataFrame, if not `None`, only use these columns to + check for NaNs. + + Returns + ------- + scalar, Series, or DataFrame + + The return can be: + + * scalar : when `self` is a Series and `where` is a scalar + * Series: when `self` is a Series and `where` is an array-like, + or when `self` is a DataFrame and `where` is a scalar + * DataFrame : when `self` is a DataFrame and `where` is an + array-like + + See Also + -------- + merge_asof : Perform an asof merge. Similar to left join. + + Notes + ----- + Dates are assumed to be sorted. Raises if this is not the case. + + Examples + -------- + A Series and a scalar `where`. + + >>> s = pd.Series([1, 2, np.nan, 4], index=[10, 20, 30, 40]) + >>> s + 10 1.0 + 20 2.0 + 30 NaN + 40 4.0 + dtype: float64 + + >>> s.asof(20) + np.float64(2.0) + + For a sequence `where`, a Series is returned. The first value is + NaN, because the first element of `where` is before the first + index value. + + >>> s.asof([5, 20]) + 5 NaN + 20 2.0 + dtype: float64 + + Missing values are not considered. The following is ``2.0``, not + NaN, even though NaN is at the index location for ``30``. + + >>> s.asof(30) + np.float64(2.0) + + Take all columns into consideration + + >>> df = pd.DataFrame( + ... { + ... "a": [10.0, 20.0, 30.0, 40.0, 50.0], + ... "b": [None, None, None, None, 500], + ... }, + ... index=pd.DatetimeIndex( + ... [ + ... "2018-02-27 09:01:00", + ... "2018-02-27 09:02:00", + ... "2018-02-27 09:03:00", + ... "2018-02-27 09:04:00", + ... "2018-02-27 09:05:00", + ... ] + ... ), + ... ) + >>> df.asof(pd.DatetimeIndex(["2018-02-27 09:03:30", "2018-02-27 09:04:30"])) + a b + 2018-02-27 09:03:30 NaN NaN + 2018-02-27 09:04:30 NaN NaN + + Take a single column into consideration + + >>> df.asof( + ... pd.DatetimeIndex(["2018-02-27 09:03:30", "2018-02-27 09:04:30"]), + ... subset=["a"], + ... ) + a b + 2018-02-27 09:03:30 30.0 NaN + 2018-02-27 09:04:30 40.0 NaN + """ + if isinstance(where, str): + where = Timestamp(where) + + if not self.index.is_monotonic_increasing: + raise ValueError("asof requires a sorted index") + + is_series = isinstance(self, ABCSeries) + if is_series: + if subset is not None: + raise ValueError("subset is not valid for Series") + else: + if subset is None: + subset = self.columns + if not is_list_like(subset): + subset = [subset] + + is_list = is_list_like(where) + if not is_list: + start = self.index[0] + if isinstance(self.index, PeriodIndex): + where = Period(where, freq=self.index.freq) + + if where < start: + if not is_series: + return self._constructor_sliced( + index=self.columns, name=where, dtype=np.float64 + ) + return np.nan + + # It's always much faster to use a *while* loop here for + # Series than pre-computing all the NAs. However a + # *while* loop is extremely expensive for DataFrame + # so we later pre-compute all the NAs and use the same + # code path whether *where* is a scalar or list. + # See PR: https://github.com/pandas-dev/pandas/pull/14476 + if is_series: + loc = self.index.searchsorted(where, side="right") + if loc > 0: + loc -= 1 + + values = self._values + while loc > 0 and isna(values[loc]): + loc -= 1 + return values[loc] + + if not isinstance(where, Index): + where = Index(where) if is_list else Index([where]) + + nulls = self.isna() if is_series else self[subset].isna().any(axis=1) + if nulls.all(): + if is_series: + self = cast("Series", self) + return self._constructor(np.nan, index=where, name=self.name) + elif is_list: + self = cast("DataFrame", self) + return self._constructor(np.nan, index=where, columns=self.columns) + else: + self = cast("DataFrame", self) + return self._constructor_sliced( + np.nan, index=self.columns, name=where[0] + ) + + # error: Unsupported operand type for + # ~ ("ExtensionArray | ndarray[Any, Any] | Any") + locs = self.index.asof_locs(where, ~nulls._values) # type: ignore[operator] + + # mask the missing + mask = locs == -1 + data = self.take(locs) + data.index = where + if mask.any(): + # GH#16063 only do this setting when necessary, otherwise + # we'd cast e.g. bools to floats + data.loc[mask] = np.nan + return data if is_list else data.iloc[-1] + + # ---------------------------------------------------------------------- + # Action Methods + + def isna(self) -> Self: + """ + Detect missing values. + + Return a boolean same-sized object indicating if the values are NA. + NA values, such as None or :attr:`numpy.NaN`, gets mapped to True + values. + Everything else gets mapped to False values. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + + Returns + ------- + Series/DataFrame + Mask of bool values for each element in Series/DataFrame + that indicates whether an element is an NA value. + + See Also + -------- + Series.isnull : Alias of isna. + DataFrame.isnull : Alias of isna. + Series.notna : Boolean inverse of isna. + DataFrame.notna : Boolean inverse of isna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + isna : Top-level isna. + + Examples + -------- + Show which entries in a DataFrame are NA. + + >>> df = pd.DataFrame( + ... dict( + ... age=[5, 6, np.nan], + ... born=[ + ... pd.NaT, + ... pd.Timestamp("1939-05-27"), + ... pd.Timestamp("1940-04-25"), + ... ], + ... name=["Alfred", "Batman", ""], + ... toy=[None, "Batmobile", "Joker"], + ... ) + ... ) + >>> df + age born name toy + 0 5.0 NaT Alfred NaN + 1 6.0 1939-05-27 Batman Batmobile + 2 NaN 1940-04-25 Joker + + >>> df.isna() + age born name toy + 0 False True False True + 1 False False False False + 2 True False False False + + Show which entries in a Series are NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + + >>> ser.isna() + 0 False + 1 False + 2 True + dtype: bool + """ + return isna(self).__finalize__(self, method="isna") + + def isnull(self) -> Self: + """ + Detect missing values. + + Return a boolean same-sized object indicating if the values are NA. + NA values, such as None or :attr:`numpy.NaN`, gets mapped to True + values. + Everything else gets mapped to False values. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + + Returns + ------- + Series/DataFrame + Mask of bool values for each element in Series/DataFrame + that indicates whether an element is an NA value. + + See Also + -------- + Series.isna : Alias of isnull. + DataFrame.isna : Alias of isnull. + Series.notna : Boolean inverse of isnull. + DataFrame.notna : Boolean inverse of isnull. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + isna : Top-level isna. + + Examples + -------- + Show which entries in a DataFrame are NA. + + >>> df = pd.DataFrame( + ... dict( + ... age=[5, 6, np.nan], + ... born=[ + ... pd.NaT, + ... pd.Timestamp("1939-05-27"), + ... pd.Timestamp("1940-04-25"), + ... ], + ... name=["Alfred", "Batman", ""], + ... toy=[None, "Batmobile", "Joker"], + ... ) + ... ) + >>> df + age born name toy + 0 5.0 NaT Alfred NaN + 1 6.0 1939-05-27 Batman Batmobile + 2 NaN 1940-04-25 Joker + + >>> df.isna() + age born name toy + 0 False True False True + 1 False False False False + 2 True False False False + + Show which entries in a Series are NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + + >>> ser.isna() + 0 False + 1 False + 2 True + dtype: bool + """ + return isna(self).__finalize__(self, method="isnull") + + def notna(self) -> Self: + """ + Detect existing (non-missing) values. + + Return a boolean same-sized object indicating if the values are not NA. + Non-missing values get mapped to True. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + NA values, such as None or :attr:`numpy.NaN`, get mapped to False + values. + + Returns + ------- + Series/DataFrame + Mask of bool values for each element in Series/DataFrame + that indicates whether an element is not an NA value. + + See Also + -------- + Series.notnull : Alias of notna. + DataFrame.notnull : Alias of notna. + Series.isna : Boolean inverse of notna. + DataFrame.isna : Boolean inverse of notna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + notna : Top-level notna. + + Examples + -------- + Show which entries in a DataFrame are not NA. + + >>> df = pd.DataFrame( + ... dict( + ... age=[5, 6, np.nan], + ... born=[ + ... pd.NaT, + ... pd.Timestamp("1939-05-27"), + ... pd.Timestamp("1940-04-25"), + ... ], + ... name=["Alfred", "Batman", ""], + ... toy=[None, "Batmobile", "Joker"], + ... ) + ... ) + >>> df + age born name toy + 0 5.0 NaT Alfred NaN + 1 6.0 1939-05-27 Batman Batmobile + 2 NaN 1940-04-25 Joker + + >>> df.notna() + age born name toy + 0 True False True False + 1 True True True True + 2 False True True True + + Show which entries in a Series are not NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + + >>> ser.notna() + 0 True + 1 True + 2 False + dtype: bool + """ + return notna(self).__finalize__(self, method="notna") + + def notnull(self) -> Self: + """ + Detect existing (non-missing) values. + + Return a boolean same-sized object indicating if the values are not NA. + Non-missing values get mapped to True. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + NA values, such as None or :attr:`numpy.NaN`, get mapped to False + values. + + Returns + ------- + Series/DataFrame + Mask of bool values for each element in Series/DataFrame + that indicates whether an element is not an NA value. + + See Also + -------- + Series.notnull : Alias of notna. + DataFrame.notnull : Alias of notna. + Series.isna : Boolean inverse of notna. + DataFrame.isna : Boolean inverse of notna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + notna : Top-level notna. + + Examples + -------- + Show which entries in a DataFrame are not NA. + + >>> df = pd.DataFrame( + ... dict( + ... age=[5, 6, np.nan], + ... born=[ + ... pd.NaT, + ... pd.Timestamp("1939-05-27"), + ... pd.Timestamp("1940-04-25"), + ... ], + ... name=["Alfred", "Batman", ""], + ... toy=[None, "Batmobile", "Joker"], + ... ) + ... ) + >>> df + age born name toy + 0 5.0 NaT Alfred NaN + 1 6.0 1939-05-27 Batman Batmobile + 2 NaN 1940-04-25 Joker + + >>> df.notna() + age born name toy + 0 True False True False + 1 True True True True + 2 False True True True + + Show which entries in a Series are not NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + + >>> ser.notna() + 0 True + 1 True + 2 False + dtype: bool + """ + return notna(self).__finalize__(self, method="notnull") + + @final + def _clip_with_scalar(self, lower, upper, inplace: bool = False): + if (lower is not None and np.any(isna(lower))) or ( + upper is not None and np.any(isna(upper)) + ): + raise ValueError("Cannot use an NA value as a clip threshold") + + result = self + mask = self.isna() + + if lower is not None: + cond = mask | (self >= lower) + result = result.where(cond, lower, inplace=inplace) + if upper is not None: + cond = mask | (self <= upper) + result = result.where(cond, upper, inplace=inplace) + + return result + + @final + def _clip_with_one_bound(self, threshold, method, axis, inplace): + if axis is not None: + axis = self._get_axis_number(axis) + + # method is self.le for upper bound and self.ge for lower bound + if is_scalar(threshold) and is_number(threshold): + if method.__name__ == "le": + return self._clip_with_scalar(None, threshold, inplace=inplace) + return self._clip_with_scalar(threshold, None, inplace=inplace) + + # GH #15390 + # In order for where method to work, the threshold must + # be transformed to NDFrame from other array like structure. + if (not isinstance(threshold, ABCSeries)) and is_list_like(threshold): + if isinstance(self, ABCSeries): + threshold = self._constructor(threshold, index=self.index) + else: + threshold = self._align_for_op(threshold, axis, flex=None)[1] + + # GH 40420 + # Treat missing thresholds as no bounds, not clipping the values + if is_list_like(threshold): + fill_value = np.inf if method.__name__ == "le" else -np.inf + threshold_inf = threshold.fillna(fill_value) + else: + threshold_inf = threshold + + subset = method(threshold_inf, axis=axis) | isna(self) + + # GH 40420 + return self.where(subset, threshold, axis=axis, inplace=inplace) + + @final + def clip( + self, + lower=None, + upper=None, + *, + axis: Axis | None = None, + inplace: bool = False, + **kwargs, + ) -> Self: + """ + Trim values at input threshold(s). + + Assigns values outside boundary to boundary values. Thresholds + can be singular values or array like, and in the latter case + the clipping is performed element-wise in the specified axis. + + Parameters + ---------- + lower : float or array-like, default None + Minimum threshold value. All values below this + threshold will be set to it. A missing + threshold (e.g `NA`) will not clip the value. + upper : float or array-like, default None + Maximum threshold value. All values above this + threshold will be set to it. A missing + threshold (e.g `NA`) will not clip the value. + axis : {{0 or 'index', 1 or 'columns', None}}, default None + Align object with lower and upper along the given axis. + For `Series` this parameter is unused and defaults to `None`. + inplace : bool, default False + Whether to perform the operation in place on the data. + **kwargs + Additional keywords have no effect but might be accepted + for compatibility with numpy. + + Returns + ------- + Series or DataFrame + Same type as calling object with the values outside the + clip boundaries replaced. + + See Also + -------- + Series.clip : Trim values at input threshold in series. + DataFrame.clip : Trim values at input threshold in DataFrame. + numpy.clip : Clip (limit) the values in an array. + + Examples + -------- + >>> data = {"col_0": [9, -3, 0, -1, 5], "col_1": [-2, -7, 6, 8, -5]} + >>> df = pd.DataFrame(data) + >>> df + col_0 col_1 + 0 9 -2 + 1 -3 -7 + 2 0 6 + 3 -1 8 + 4 5 -5 + + Clips per column using lower and upper thresholds: + + >>> df.clip(-4, 6) + col_0 col_1 + 0 6 -2 + 1 -3 -4 + 2 0 6 + 3 -1 6 + 4 5 -4 + + Clips using specific lower and upper thresholds per column: + + >>> df.clip([-2, -1], [4, 5]) + col_0 col_1 + 0 4 -1 + 1 -2 -1 + 2 0 5 + 3 -1 5 + 4 4 -1 + + Clips using specific lower and upper thresholds per column element: + + >>> t = pd.Series([2, -4, -1, 6, 3]) + >>> t + 0 2 + 1 -4 + 2 -1 + 3 6 + 4 3 + dtype: int64 + + >>> df.clip(t, t + 4, axis=0) + col_0 col_1 + 0 6 2 + 1 -3 -4 + 2 0 3 + 3 6 8 + 4 5 3 + + Clips using specific lower threshold per column element, with missing values: + + >>> t = pd.Series([2, -4, np.nan, 6, 3]) + >>> t + 0 2.0 + 1 -4.0 + 2 NaN + 3 6.0 + 4 3.0 + dtype: float64 + + >>> df.clip(t, axis=0) + col_0 col_1 + 0 9.0 2.0 + 1 -3.0 -4.0 + 2 0.0 6.0 + 3 6.0 8.0 + 4 5.0 3.0 + """ + inplace = validate_bool_kwarg(inplace, "inplace") + + if inplace: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not common.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + axis = nv.validate_clip_with_axis(axis, (), kwargs) + if axis is not None: + axis = self._get_axis_number(axis) + + # GH 17276 + # numpy doesn't like NaN as a clip value + # so ignore + # GH 19992 + # numpy doesn't drop a list-like bound containing NaN + isna_lower = isna(lower) + if not is_list_like(lower): + if np.any(isna_lower): + lower = None + elif np.all(isna_lower): + lower = None + isna_upper = isna(upper) + if not is_list_like(upper): + if np.any(isna_upper): + upper = None + elif np.all(isna_upper): + upper = None + + # GH 2747 (arguments were reversed) + if ( + lower is not None + and upper is not None + and is_scalar(lower) + and is_scalar(upper) + ): + lower, upper = min(lower, upper), max(lower, upper) + + # fast-path for scalars + if (lower is None or is_number(lower)) and (upper is None or is_number(upper)): + return self._clip_with_scalar(lower, upper, inplace=inplace) + + result = self + if lower is not None: + result = result._clip_with_one_bound( + lower, method=self.ge, axis=axis, inplace=inplace + ) + if upper is not None: + if inplace: + result = self + result = result._clip_with_one_bound( + upper, method=self.le, axis=axis, inplace=inplace + ) + + return result + + @final + def asfreq( + self, + freq: Frequency, + method: FillnaOptions | None = None, + how: Literal["start", "end"] | None = None, + normalize: bool = False, + fill_value: Hashable | None = None, + ) -> Self: + """ + Convert time series to specified frequency. + + Returns the original data conformed to a new index with the specified + frequency. + + If the index of this Series/DataFrame is a :class:`~pandas.PeriodIndex`, the + new index is the result of transforming the original index with + :meth:`PeriodIndex.asfreq ` (so the original index + will map one-to-one to the new index). + + Otherwise, the new index will be equivalent to ``pd.date_range(start, end, + freq=freq)`` where ``start`` and ``end`` are, respectively, the min and + max entries in the original index (see :func:`pandas.date_range`). The + values corresponding to any timesteps in the new index which were not present + in the original index will be null (``NaN``), unless a method for filling + such unknowns is provided (see the ``method`` parameter below). + + The :meth:`resample` method is more appropriate if an operation on each group of + timesteps (such as an aggregate) is necessary to represent the data at the new + frequency. + + Parameters + ---------- + freq : DateOffset or str + Frequency DateOffset or string. + method : {{'backfill'/'bfill', 'pad'/'ffill'}}, default None + Method to use for filling holes in reindexed Series (note this + does not fill NaNs that already were present): + + * 'pad' / 'ffill': propagate last valid observation forward to next + valid based on the order of the index + * 'backfill' / 'bfill': use NEXT valid observation to fill. + how : {{'start', 'end'}}, default end + For PeriodIndex only (see PeriodIndex.asfreq). + normalize : bool, default False + Whether to reset output index to midnight. + fill_value : scalar, optional + Value to use for missing values, applied during upsampling (note + this does not fill NaNs that already were present). + + Returns + ------- + Series/DataFrame + Series/DataFrame object reindexed to the specified frequency. + + See Also + -------- + reindex : Conform DataFrame to new index with optional filling logic. + + Notes + ----- + To learn more about the frequency strings, please see + :ref:`this link`. + + Examples + -------- + Start by creating a series with 4 one minute timestamps. + + >>> index = pd.date_range("1/1/2000", periods=4, freq="min") + >>> series = pd.Series([0.0, None, 2.0, 3.0], index=index) + >>> df = pd.DataFrame({"s": series}) + >>> df + s + 2000-01-01 00:00:00 0.0 + 2000-01-01 00:01:00 NaN + 2000-01-01 00:02:00 2.0 + 2000-01-01 00:03:00 3.0 + + Upsample the series into 30 second bins. + + >>> df.asfreq(freq="30s") + s + 2000-01-01 00:00:00 0.0 + 2000-01-01 00:00:30 NaN + 2000-01-01 00:01:00 NaN + 2000-01-01 00:01:30 NaN + 2000-01-01 00:02:00 2.0 + 2000-01-01 00:02:30 NaN + 2000-01-01 00:03:00 3.0 + + Upsample again, providing a ``fill value``. + + >>> df.asfreq(freq="30s", fill_value=9.0) + s + 2000-01-01 00:00:00 0.0 + 2000-01-01 00:00:30 9.0 + 2000-01-01 00:01:00 NaN + 2000-01-01 00:01:30 9.0 + 2000-01-01 00:02:00 2.0 + 2000-01-01 00:02:30 9.0 + 2000-01-01 00:03:00 3.0 + + Upsample again, providing a ``method``. + + >>> df.asfreq(freq="30s", method="bfill") + s + 2000-01-01 00:00:00 0.0 + 2000-01-01 00:00:30 NaN + 2000-01-01 00:01:00 NaN + 2000-01-01 00:01:30 2.0 + 2000-01-01 00:02:00 2.0 + 2000-01-01 00:02:30 3.0 + 2000-01-01 00:03:00 3.0 + """ + from pandas.core.resample import asfreq + + return asfreq( + self, + freq, + method=method, + how=how, + normalize=normalize, + fill_value=fill_value, + ) + + @final + def at_time(self, time, asof: bool = False, axis: Axis | None = None) -> Self: + """ + Select values at particular time of day (e.g., 9:30AM). + + Parameters + ---------- + time : datetime.time or str + The values to select. + asof : bool, default False + This parameter is currently not supported. + axis : {0 or 'index', 1 or 'columns'}, default 0 + For `Series` this parameter is unused and defaults to 0. + + Returns + ------- + Series or DataFrame + The values with the specified time. + + Raises + ------ + TypeError + If the index is not a :class:`DatetimeIndex` + + See Also + -------- + between_time : Select values between particular times of the day. + first : Select initial periods of time series based on a date offset. + last : Select final periods of time series based on a date offset. + DatetimeIndex.indexer_at_time : Get just the index locations for + values at particular time of the day. + + Examples + -------- + >>> i = pd.date_range("2018-04-09", periods=4, freq="12h") + >>> ts = pd.DataFrame({"A": [1, 2, 3, 4]}, index=i) + >>> ts + A + 2018-04-09 00:00:00 1 + 2018-04-09 12:00:00 2 + 2018-04-10 00:00:00 3 + 2018-04-10 12:00:00 4 + + >>> ts.at_time("12:00") + A + 2018-04-09 12:00:00 2 + 2018-04-10 12:00:00 4 + """ + if axis is None: + axis = 0 + axis = self._get_axis_number(axis) + + index = self._get_axis(axis) + + if not isinstance(index, DatetimeIndex): + raise TypeError("Index must be DatetimeIndex") + + indexer = index.indexer_at_time(time, asof=asof) + return self.take(indexer, axis=axis) + + @final + def between_time( + self, + start_time, + end_time, + inclusive: IntervalClosedType = "both", + axis: Axis | None = None, + ) -> Self: + """ + Select values between particular times of the day (e.g., 9:00-9:30 AM). + + By setting ``start_time`` to be later than ``end_time``, + you can get the times that are *not* between the two times. + + Parameters + ---------- + start_time : datetime.time or str + Initial time as a time filter limit. + end_time : datetime.time or str + End time as a time filter limit. + inclusive : {"both", "neither", "left", "right"}, default "both" + Include boundaries; whether to set each bound as closed or open. + axis : {0 or 'index', 1 or 'columns'}, default 0 + Determine range time on index or columns value. + For `Series` this parameter is unused and defaults to 0. + + Returns + ------- + Series or DataFrame + Data from the original object filtered to the specified dates range. + + Raises + ------ + TypeError + If the index is not a :class:`DatetimeIndex` + + See Also + -------- + at_time : Select values at a particular time of the day. + first : Select initial periods of time series based on a date offset. + last : Select final periods of time series based on a date offset. + DatetimeIndex.indexer_between_time : Get just the index locations for + values between particular times of the day. + + Examples + -------- + >>> i = pd.date_range("2018-04-09", periods=4, freq="1D20min") + >>> ts = pd.DataFrame({"A": [1, 2, 3, 4]}, index=i) + >>> ts + A + 2018-04-09 00:00:00 1 + 2018-04-10 00:20:00 2 + 2018-04-11 00:40:00 3 + 2018-04-12 01:00:00 4 + + >>> ts.between_time("0:15", "0:45") + A + 2018-04-10 00:20:00 2 + 2018-04-11 00:40:00 3 + + You get the times that are *not* between two times by setting + ``start_time`` later than ``end_time``: + + >>> ts.between_time("0:45", "0:15") + A + 2018-04-09 00:00:00 1 + 2018-04-12 01:00:00 4 + """ + if axis is None: + axis = 0 + axis = self._get_axis_number(axis) + + index = self._get_axis(axis) + if not isinstance(index, DatetimeIndex): + raise TypeError("Index must be DatetimeIndex") + + left_inclusive, right_inclusive = validate_inclusive(inclusive) + indexer = index.indexer_between_time( + start_time, + end_time, + include_start=left_inclusive, + include_end=right_inclusive, + ) + return self.take(indexer, axis=axis) + + @final + def resample( + self, + rule, + closed: Literal["right", "left"] | None = None, + label: Literal["right", "left"] | None = None, + convention: Literal["start", "end", "s", "e"] = "start", + on: Level | None = None, + level: Level | None = None, + origin: str | TimestampConvertibleTypes = "start_day", + offset: TimedeltaConvertibleTypes | None = None, + group_keys: bool = False, + ) -> Resampler: + """ + Resample time-series data. + + Convenience method for frequency conversion and resampling of time series. + The object must have a datetime-like index (`DatetimeIndex`, `PeriodIndex`, + or `TimedeltaIndex`), or the caller must pass the label of a datetime-like + series/index to the ``on``/``level`` keyword parameter. + + Parameters + ---------- + rule : DateOffset, Timedelta or str + The offset string or object representing target conversion. + closed : {{'right', 'left'}}, default None + Which side of bin interval is closed. The default is 'left' + for all frequency offsets except for 'ME', 'YE', 'QE', 'BME', + 'BA', 'BQE', and 'W' which all have a default of 'right'. + label : {{'right', 'left'}}, default None + Which bin edge label to label bucket with. The default is 'left' + for all frequency offsets except for 'ME', 'YE', 'QE', 'BME', + 'BA', 'BQE', and 'W' which all have a default of 'right'. + convention : {{'start', 'end', 's', 'e'}}, default 'start' + For `PeriodIndex` only, controls whether to use the start or + end of `rule`. + on : str, optional + For a DataFrame, column to use instead of index for resampling. + Column must be datetime-like. + level : str or int, optional + For a MultiIndex, level (name or number) to use for + resampling. `level` must be datetime-like. + origin : Timestamp or str, default 'start_day' + The timestamp on which to adjust the grouping. The timezone of origin + must match the timezone of the index. + If string, must be Timestamp convertible or one of the following: + + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + + - 'end': `origin` is the last value of the timeseries + - 'end_day': `origin` is the ceiling midnight of the last day + + .. note:: + + Only takes effect for Tick-frequencies (i.e. fixed frequencies like + days, hours, and minutes, rather than months or quarters). + offset : Timedelta or str, default is None + An offset timedelta added to the origin. + + group_keys : bool, default False + Whether to include the group keys in the result index when using + ``.apply()`` on the resampled object. + + .. versionchanged:: 2.0.0 + + ``group_keys`` now defaults to ``False``. + + Returns + ------- + pandas.api.typing.Resampler + :class:`~pandas.core.Resampler` object. + + See Also + -------- + Series.resample : Resample a Series. + DataFrame.resample : Resample a DataFrame. + groupby : Group Series/DataFrame by mapping, function, label, or list of labels. + asfreq : Reindex a Series/DataFrame with the given frequency without grouping. + + Notes + ----- + See the `user guide + `__ + for more. + + To learn more about the offset strings, please see `this link + `__. + + Examples + -------- + Start by creating a series with 9 one minute timestamps. + + >>> index = pd.date_range("1/1/2000", periods=9, freq="min") + >>> series = pd.Series(range(9), index=index) + >>> series + 2000-01-01 00:00:00 0 + 2000-01-01 00:01:00 1 + 2000-01-01 00:02:00 2 + 2000-01-01 00:03:00 3 + 2000-01-01 00:04:00 4 + 2000-01-01 00:05:00 5 + 2000-01-01 00:06:00 6 + 2000-01-01 00:07:00 7 + 2000-01-01 00:08:00 8 + Freq: min, dtype: int64 + + Downsample the series into 3 minute bins and sum the values + of the timestamps falling into a bin. + + >>> series.resample("3min").sum() + 2000-01-01 00:00:00 3 + 2000-01-01 00:03:00 12 + 2000-01-01 00:06:00 21 + Freq: 3min, dtype: int64 + + Downsample the series into 3 minute bins as above, but label each + bin using the right edge instead of the left. Please note that the + value in the bucket used as the label is not included in the bucket, + which it labels. For example, in the original series the + bucket ``2000-01-01 00:03:00`` contains the value 3, but the summed + value in the resampled bucket with the label ``2000-01-01 00:03:00`` + does not include 3 (if it did, the summed value would be 6, not 3). + + >>> series.resample("3min", label="right").sum() + 2000-01-01 00:03:00 3 + 2000-01-01 00:06:00 12 + 2000-01-01 00:09:00 21 + Freq: 3min, dtype: int64 + + To include this value close the right side of the bin interval, + as shown below. + + >>> series.resample("3min", label="right", closed="right").sum() + 2000-01-01 00:00:00 0 + 2000-01-01 00:03:00 6 + 2000-01-01 00:06:00 15 + 2000-01-01 00:09:00 15 + Freq: 3min, dtype: int64 + + Upsample the series into 30 second bins. + + >>> series.resample("30s").asfreq()[0:5] # Select first 5 rows + 2000-01-01 00:00:00 0.0 + 2000-01-01 00:00:30 NaN + 2000-01-01 00:01:00 1.0 + 2000-01-01 00:01:30 NaN + 2000-01-01 00:02:00 2.0 + Freq: 30s, dtype: float64 + + Upsample the series into 30 second bins and fill the ``NaN`` + values using the ``ffill`` method. + + >>> series.resample("30s").ffill()[0:5] + 2000-01-01 00:00:00 0 + 2000-01-01 00:00:30 0 + 2000-01-01 00:01:00 1 + 2000-01-01 00:01:30 1 + 2000-01-01 00:02:00 2 + Freq: 30s, dtype: int64 + + Upsample the series into 30 second bins and fill the + ``NaN`` values using the ``bfill`` method. + + >>> series.resample("30s").bfill()[0:5] + 2000-01-01 00:00:00 0 + 2000-01-01 00:00:30 1 + 2000-01-01 00:01:00 1 + 2000-01-01 00:01:30 2 + 2000-01-01 00:02:00 2 + Freq: 30s, dtype: int64 + + Pass a custom function via ``apply`` + + >>> def custom_resampler(arraylike): + ... return np.sum(arraylike) + 5 + >>> series.resample("3min").apply(custom_resampler) + 2000-01-01 00:00:00 8 + 2000-01-01 00:03:00 17 + 2000-01-01 00:06:00 26 + Freq: 3min, dtype: int64 + + For a Series with a PeriodIndex, the keyword `convention` can be + used to control whether to use the start or end of `rule`. + + Resample a year by quarter using 'start' `convention`. Values are + assigned to the first quarter of the period. + + >>> s = pd.Series( + ... [1, 2], index=pd.period_range("2012-01-01", freq="Y", periods=2) + ... ) + >>> s + 2012 1 + 2013 2 + Freq: Y-DEC, dtype: int64 + >>> s.resample("Q", convention="start").asfreq() + 2012Q1 1.0 + 2012Q2 NaN + 2012Q3 NaN + 2012Q4 NaN + 2013Q1 2.0 + 2013Q2 NaN + 2013Q3 NaN + 2013Q4 NaN + Freq: Q-DEC, dtype: float64 + + Resample quarters by month using 'end' `convention`. Values are + assigned to the last month of the period. + + >>> q = pd.Series( + ... [1, 2, 3, 4], index=pd.period_range("2018-01-01", freq="Q", periods=4) + ... ) + >>> q + 2018Q1 1 + 2018Q2 2 + 2018Q3 3 + 2018Q4 4 + Freq: Q-DEC, dtype: int64 + >>> q.resample("M", convention="end").asfreq() + 2018-03 1.0 + 2018-04 NaN + 2018-05 NaN + 2018-06 2.0 + 2018-07 NaN + 2018-08 NaN + 2018-09 3.0 + 2018-10 NaN + 2018-11 NaN + 2018-12 4.0 + Freq: M, dtype: float64 + + For DataFrame objects, the keyword `on` can be used to specify the + column instead of the index for resampling. + + >>> df = pd.DataFrame([10, 11, 9, 13, 14, 18, 17, 19], columns=["price"]) + >>> df["volume"] = [50, 60, 40, 100, 50, 100, 40, 50] + >>> df["week_starting"] = pd.date_range("01/01/2018", periods=8, freq="W") + >>> df + price volume week_starting + 0 10 50 2018-01-07 + 1 11 60 2018-01-14 + 2 9 40 2018-01-21 + 3 13 100 2018-01-28 + 4 14 50 2018-02-04 + 5 18 100 2018-02-11 + 6 17 40 2018-02-18 + 7 19 50 2018-02-25 + >>> df.resample("ME", on="week_starting").mean() + price volume + week_starting + 2018-01-31 10.75 62.5 + 2018-02-28 17.00 60.0 + + For a DataFrame with MultiIndex, the keyword `level` can be used to + specify on which level the resampling needs to take place. + + >>> days = pd.date_range("1/1/2000", periods=4, freq="D") + >>> df2 = pd.DataFrame( + ... [ + ... [10, 50], + ... [11, 60], + ... [9, 40], + ... [13, 100], + ... [14, 50], + ... [18, 100], + ... [17, 40], + ... [19, 50], + ... ], + ... columns=["price", "volume"], + ... index=pd.MultiIndex.from_product([days, ["morning", "afternoon"]]), + ... ) + >>> df2 + price volume + 2000-01-01 morning 10 50 + afternoon 11 60 + 2000-01-02 morning 9 40 + afternoon 13 100 + 2000-01-03 morning 14 50 + afternoon 18 100 + 2000-01-04 morning 17 40 + afternoon 19 50 + >>> df2.resample("D", level=0).sum() + price volume + 2000-01-01 21 110 + 2000-01-02 22 140 + 2000-01-03 32 150 + 2000-01-04 36 90 + + If you want to adjust the start of the bins based on a fixed timestamp: + + >>> start, end = "2000-10-01 23:30:00", "2000-10-02 00:30:00" + >>> rng = pd.date_range(start, end, freq="7min") + >>> ts = pd.Series(np.arange(len(rng)) * 3, index=rng) + >>> ts + 2000-10-01 23:30:00 0 + 2000-10-01 23:37:00 3 + 2000-10-01 23:44:00 6 + 2000-10-01 23:51:00 9 + 2000-10-01 23:58:00 12 + 2000-10-02 00:05:00 15 + 2000-10-02 00:12:00 18 + 2000-10-02 00:19:00 21 + 2000-10-02 00:26:00 24 + Freq: 7min, dtype: int64 + + >>> ts.resample("17min").sum() + 2000-10-01 23:14:00 0 + 2000-10-01 23:31:00 9 + 2000-10-01 23:48:00 21 + 2000-10-02 00:05:00 54 + 2000-10-02 00:22:00 24 + Freq: 17min, dtype: int64 + + >>> ts.resample("17min", origin="epoch").sum() + 2000-10-01 23:18:00 0 + 2000-10-01 23:35:00 18 + 2000-10-01 23:52:00 27 + 2000-10-02 00:09:00 39 + 2000-10-02 00:26:00 24 + Freq: 17min, dtype: int64 + + >>> ts.resample("17min", origin="2000-01-01").sum() + 2000-10-01 23:24:00 3 + 2000-10-01 23:41:00 15 + 2000-10-01 23:58:00 45 + 2000-10-02 00:15:00 45 + Freq: 17min, dtype: int64 + + If you want to adjust the start of the bins with an `offset` Timedelta, the two + following lines are equivalent: + + >>> ts.resample("17min", origin="start").sum() + 2000-10-01 23:30:00 9 + 2000-10-01 23:47:00 21 + 2000-10-02 00:04:00 54 + 2000-10-02 00:21:00 24 + Freq: 17min, dtype: int64 + + >>> ts.resample("17min", offset="23h30min").sum() + 2000-10-01 23:30:00 9 + 2000-10-01 23:47:00 21 + 2000-10-02 00:04:00 54 + 2000-10-02 00:21:00 24 + Freq: 17min, dtype: int64 + + If you want to take the largest Timestamp as the end of the bins: + + >>> ts.resample("17min", origin="end").sum() + 2000-10-01 23:35:00 0 + 2000-10-01 23:52:00 18 + 2000-10-02 00:09:00 27 + 2000-10-02 00:26:00 63 + Freq: 17min, dtype: int64 + + In contrast with the `start_day`, you can use `end_day` to take the ceiling + midnight of the largest Timestamp as the end of the bins and drop the bins + not containing data: + + >>> ts.resample("17min", origin="end_day").sum() + 2000-10-01 23:38:00 3 + 2000-10-01 23:55:00 15 + 2000-10-02 00:12:00 45 + 2000-10-02 00:29:00 45 + Freq: 17min, dtype: int64 + """ + from pandas.core.resample import get_resampler + + return get_resampler( + cast("Series | DataFrame", self), + freq=rule, + label=label, + closed=closed, + convention=convention, + key=on, + level=level, + origin=origin, + offset=offset, + group_keys=group_keys, + ) + + @final + def rank( + self, + axis: Axis = 0, + method: Literal["average", "min", "max", "first", "dense"] = "average", + numeric_only: bool = False, + na_option: Literal["keep", "top", "bottom"] = "keep", + ascending: bool = True, + pct: bool = False, + ) -> Self: + """ + Compute numerical data ranks (1 through n) along axis. + + By default, equal values are assigned a rank that is the average of the + ranks of those values. + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns'}, default 0 + Index to direct ranking. + For `Series` this parameter is unused and defaults to 0. + method : {'average', 'min', 'max', 'first', 'dense'}, default 'average' + How to rank the group of records that have the same value (i.e. ties): + + * average: average rank of the group + * min: lowest rank in the group + * max: highest rank in the group + * first: ranks assigned in order they appear in the array + * dense: like 'min', but rank always increases by 1 between groups. + + numeric_only : bool, default False + For DataFrame objects, rank only numeric columns if set to True. + + .. versionchanged:: 2.0.0 + The default value of ``numeric_only`` is now ``False``. + + na_option : {'keep', 'top', 'bottom'}, default 'keep' + How to rank NaN values: + + * keep: assign NaN rank to NaN values + * top: assign lowest rank to NaN values + * bottom: assign highest rank to NaN values + + ascending : bool, default True + Whether or not the elements should be ranked in ascending order. + pct : bool, default False + Whether or not to display the returned rankings in percentile + form. + + Returns + ------- + same type as caller + Return a Series or DataFrame with data ranks as values. + + See Also + -------- + core.groupby.DataFrameGroupBy.rank : Rank of values within each group. + core.groupby.SeriesGroupBy.rank : Rank of values within each group. + + Examples + -------- + >>> df = pd.DataFrame( + ... data={ + ... "Animal": ["cat", "penguin", "dog", "spider", "snake"], + ... "Number_legs": [4, 2, 4, 8, np.nan], + ... } + ... ) + >>> df + Animal Number_legs + 0 cat 4.0 + 1 penguin 2.0 + 2 dog 4.0 + 3 spider 8.0 + 4 snake NaN + + Ties are assigned the mean of the ranks (by default) for the group. + + >>> s = pd.Series(range(5), index=list("abcde")) + >>> s["d"] = s["b"] + >>> s.rank() + a 1.0 + b 2.5 + c 4.0 + d 2.5 + e 5.0 + dtype: float64 + + The following example shows how the method behaves with the above + parameters: + + * default_rank: this is the default behaviour obtained without using + any parameter. + * max_rank: setting ``method = 'max'`` the records that have the + same values are ranked using the highest rank (e.g.: since 'cat' + and 'dog' are both in the 2nd and 3rd position, rank 3 is assigned.) + * NA_bottom: choosing ``na_option = 'bottom'``, if there are records + with NaN values they are placed at the bottom of the ranking. + * pct_rank: when setting ``pct = True``, the ranking is expressed as + percentile rank. + + >>> df["default_rank"] = df["Number_legs"].rank() + >>> df["max_rank"] = df["Number_legs"].rank(method="max") + >>> df["NA_bottom"] = df["Number_legs"].rank(na_option="bottom") + >>> df["pct_rank"] = df["Number_legs"].rank(pct=True) + >>> df + Animal Number_legs default_rank max_rank NA_bottom pct_rank + 0 cat 4.0 2.5 3.0 2.5 0.625 + 1 penguin 2.0 1.0 1.0 1.0 0.250 + 2 dog 4.0 2.5 3.0 2.5 0.625 + 3 spider 8.0 4.0 4.0 4.0 1.000 + 4 snake NaN NaN NaN 5.0 NaN + """ + axis_int = self._get_axis_number(axis) + + if na_option not in {"keep", "top", "bottom"}: + msg = "na_option must be one of 'keep', 'top', or 'bottom'" + raise ValueError(msg) + + def ranker(data): + if data.ndim == 2: + # i.e. DataFrame, we cast to ndarray + values = data.values + else: + # i.e. Series, can dispatch to EA + values = data._values + + if isinstance(values, ExtensionArray): + ranks = values._rank( + axis=axis_int, + method=method, + ascending=ascending, + na_option=na_option, + pct=pct, + ) + else: + ranks = algos.rank( + values, + axis=axis_int, + method=method, + ascending=ascending, + na_option=na_option, + pct=pct, + ) + + ranks_obj = self._constructor(ranks, **data._construct_axes_dict()) + return ranks_obj.__finalize__(self, method="rank") + + if numeric_only: + if self.ndim == 1 and not is_numeric_dtype(self.dtype): + # GH#47500 + raise TypeError( + "Series.rank does not allow numeric_only=True with " + "non-numeric dtype." + ) + data = self._get_numeric_data() + else: + data = self + + return ranker(data) + + def compare( + self, + other: Self, + align_axis: Axis = 1, + keep_shape: bool = False, + keep_equal: bool = False, + result_names: Suffixes = ("self", "other"), + ): + """ + Compare to another Series/DataFrame and show the differences. + + Parameters + ---------- + other : Series/DataFrame + Object to compare with. + + align_axis : {0 or 'index', 1 or 'columns'}, default 1 + Determine which axis to align the comparison on. + + * 0, or 'index' : Resulting differences are stacked vertically + with rows drawn alternately from self and other. + * 1, or 'columns' : Resulting differences are aligned horizontally + with columns drawn alternately from self and other. + + keep_shape : bool, default False + If true, all rows and columns are kept. + Otherwise, only the ones with different values are kept. + + keep_equal : bool, default False + If true, the result keeps values that are equal. + Otherwise, equal values are shown as NaNs. + + result_names : tuple, default ('self', 'other') + Set the dataframes names in the comparison. + """ + if type(self) is not type(other): + cls_self, cls_other = type(self).__name__, type(other).__name__ + raise TypeError( + f"can only compare '{cls_self}' (not '{cls_other}') with '{cls_self}'" + ) + + # error: Unsupported left operand type for & ("Self") + mask = ~((self == other) | (self.isna() & other.isna())) # type: ignore[operator] + mask.fillna(True, inplace=True) + + if not keep_equal: + self = self.where(mask) + other = other.where(mask) + + if not keep_shape: + if isinstance(self, ABCDataFrame): + cmask = mask.any() + rmask = mask.any(axis=1) + self = self.loc[rmask, cmask] + other = other.loc[rmask, cmask] + else: + self = self[mask] + other = other[mask] + if not isinstance(result_names, tuple): + raise TypeError( + f"Passing 'result_names' as a {type(result_names)} is not " + "supported. Provide 'result_names' as a tuple instead." + ) + + if align_axis in (1, "columns"): # This is needed for Series + axis = 1 + else: + axis = self._get_axis_number(align_axis) + + # error: List item 0 has incompatible type "NDFrame"; expected + # "Union[Series, DataFrame]" + diff = concat( + [self, other], # type: ignore[list-item] + axis=axis, + keys=result_names, + ) + + if axis >= self.ndim: + # No need to reorganize data if stacking on new axis + # This currently applies for stacking two Series on columns + return diff + + ax = diff._get_axis(axis) + ax_names = np.array(ax.names) + + # set index names to positions to avoid confusion + ax.names = np.arange(len(ax_names)) + + # bring self-other to inner level + order = [*range(1, ax.nlevels), 0] + if isinstance(diff, ABCDataFrame): + diff = diff.reorder_levels(order, axis=axis) + else: + diff = diff.reorder_levels(order) + + # restore the index names in order + diff._get_axis(axis=axis).names = ax_names[order] + + # reorder axis to keep things organized + indices = ( + np.arange(diff.shape[axis]) + .reshape([2, diff.shape[axis] // 2]) + .T.reshape(-1) + ) + diff = diff.take(indices, axis=axis) + + return diff + + @final + def align( + self, + other: NDFrameT, + join: AlignJoin = "outer", + axis: Axis | None = None, + level: Level | None = None, + copy: bool | lib.NoDefault = lib.no_default, + fill_value: Hashable | None = None, + ) -> tuple[Self, NDFrameT]: + """ + Align two objects on their axes with the specified join method. + + Join method is specified for each axis Index. + + Parameters + ---------- + other : DataFrame or Series + The object to align with. + join : {{'outer', 'inner', 'left', 'right'}}, default 'outer' + Type of alignment to be performed. + + * left: use only keys from left frame, preserve key order. + * right: use only keys from right frame, preserve key order. + * outer: use union of keys from both frames, sort keys lexicographically. + * inner: use intersection of keys from both frames, + preserve the order of the left keys. + + axis : allowed axis of the other object, default None + Align on index (0), columns (1), or both (None). + level : int or level name, default None + Broadcast across a level, matching Index values on the + passed MultiIndex level. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + fill_value : scalar, default np.nan + Value to use for missing values. Defaults to NaN, but can be any + "compatible" value. + + Returns + ------- + tuple of (Series/DataFrame, type of other) + Aligned objects. + + See Also + -------- + Series.align : Align two objects on their axes with specified join method. + DataFrame.align : Align two objects on their axes with specified join method. + + Examples + -------- + >>> df = pd.DataFrame( + ... [[1, 2, 3, 4], [6, 7, 8, 9]], columns=["D", "B", "E", "A"], index=[1, 2] + ... ) + >>> other = pd.DataFrame( + ... [[10, 20, 30, 40], [60, 70, 80, 90], [600, 700, 800, 900]], + ... columns=["A", "B", "C", "D"], + ... index=[2, 3, 4], + ... ) + >>> df + D B E A + 1 1 2 3 4 + 2 6 7 8 9 + >>> other + A B C D + 2 10 20 30 40 + 3 60 70 80 90 + 4 600 700 800 900 + + Align on columns: + + >>> left, right = df.align(other, join="outer", axis=1) + >>> left + A B C D E + 1 4 2 NaN 1 3 + 2 9 7 NaN 6 8 + >>> right + A B C D E + 2 10 20 30 40 NaN + 3 60 70 80 90 NaN + 4 600 700 800 900 NaN + + We can also align on the index: + + >>> left, right = df.align(other, join="outer", axis=0) + >>> left + D B E A + 1 1.0 2.0 3.0 4.0 + 2 6.0 7.0 8.0 9.0 + 3 NaN NaN NaN NaN + 4 NaN NaN NaN NaN + >>> right + A B C D + 1 NaN NaN NaN NaN + 2 10.0 20.0 30.0 40.0 + 3 60.0 70.0 80.0 90.0 + 4 600.0 700.0 800.0 900.0 + + Finally, the default `axis=None` will align on both index and columns: + + >>> left, right = df.align(other, join="outer", axis=None) + >>> left + A B C D E + 1 4.0 2.0 NaN 1.0 3.0 + 2 9.0 7.0 NaN 6.0 8.0 + 3 NaN NaN NaN NaN NaN + 4 NaN NaN NaN NaN NaN + >>> right + A B C D E + 1 NaN NaN NaN NaN NaN + 2 10.0 20.0 30.0 40.0 NaN + 3 60.0 70.0 80.0 90.0 NaN + 4 600.0 700.0 800.0 900.0 NaN + """ + self._check_copy_deprecation(copy) + + _right: DataFrame | Series + if axis is not None: + axis = self._get_axis_number(axis) + if isinstance(other, ABCDataFrame): + left, _right, join_index = self._align_frame( + other, + join=join, + axis=axis, + level=level, + fill_value=fill_value, + ) + + elif isinstance(other, ABCSeries): + left, _right, join_index = self._align_series( + other, + join=join, + axis=axis, + level=level, + fill_value=fill_value, + ) + else: # pragma: no cover + raise TypeError(f"unsupported type: {type(other)}") + + right = cast(NDFrameT, _right) + if self.ndim == 1 or axis == 0: + # If we are aligning timezone-aware DatetimeIndexes and the timezones + # do not match, convert both to UTC. + if isinstance(left.index.dtype, DatetimeTZDtype): + if left.index.tz != right.index.tz: + if join_index is not None: + # GH#33671 copy to ensure we don't change the index on + # our original Series + left = left.copy(deep=False) + right = right.copy(deep=False) + left.index = join_index + right.index = join_index + + left = left.__finalize__(self) + right = right.__finalize__(other) + return left, right + + @final + def _align_frame( + self, + other: DataFrame, + join: AlignJoin = "outer", + axis: Axis | None = None, + level=None, + fill_value=None, + ) -> tuple[Self, DataFrame, Index | None]: + # defaults + join_index, join_columns = None, None + ilidx, iridx = None, None + clidx, cridx = None, None + + is_series = isinstance(self, ABCSeries) + + if (axis is None or axis == 0) and not self.index.equals(other.index): + join_index, ilidx, iridx = self.index.join( + other.index, how=join, level=level, return_indexers=True + ) + + if ( + (axis is None or axis == 1) + and not is_series + and not self.columns.equals(other.columns) + ): + join_columns, clidx, cridx = self.columns.join( + other.columns, how=join, level=level, return_indexers=True + ) + + if is_series: + reindexers = {0: [join_index, ilidx]} + else: + reindexers = {0: [join_index, ilidx], 1: [join_columns, clidx]} + + left = self._reindex_with_indexers( + reindexers, fill_value=fill_value, allow_dups=True + ) + # other must be always DataFrame + right = other._reindex_with_indexers( + {0: [join_index, iridx], 1: [join_columns, cridx]}, + fill_value=fill_value, + allow_dups=True, + ) + return left, right, join_index + + @final + def _align_series( + self, + other: Series, + join: AlignJoin = "outer", + axis: Axis | None = None, + level=None, + fill_value=None, + ) -> tuple[Self, Series, Index | None]: + is_series = isinstance(self, ABCSeries) + + if (not is_series and axis is None) or axis not in [None, 0, 1]: + raise ValueError("Must specify axis=0 or 1") + + if is_series and axis == 1: + raise ValueError("cannot align series to a series other than axis 0") + + # series/series compat, other must always be a Series + if not axis: + # equal + if self.index.equals(other.index): + join_index, lidx, ridx = None, None, None + else: + join_index, lidx, ridx = self.index.join( + other.index, how=join, level=level, return_indexers=True + ) + + if is_series: + left = self._reindex_indexer(join_index, lidx) + elif lidx is None or join_index is None: + left = self.copy(deep=False) + else: + new_mgr = self._mgr.reindex_indexer(join_index, lidx, axis=1) + left = self._constructor_from_mgr(new_mgr, axes=new_mgr.axes) + + right = other._reindex_indexer(join_index, ridx) + + else: + # one has > 1 ndim + fdata = self._mgr + join_index = self.axes[1] + lidx, ridx = None, None + if not join_index.equals(other.index): + join_index, lidx, ridx = join_index.join( + other.index, how=join, level=level, return_indexers=True + ) + + if lidx is not None: + bm_axis = self._get_block_manager_axis(1) + fdata = fdata.reindex_indexer(join_index, lidx, axis=bm_axis) + + left = self._constructor_from_mgr(fdata, axes=fdata.axes) + + right = other._reindex_indexer(join_index, ridx) + + # fill + fill_na = notna(fill_value) + if fill_na: + left = left.fillna(fill_value) + right = right.fillna(fill_value) + + return left, right, join_index + + @final + def _where( + self, + cond, + other=lib.no_default, + *, + inplace: bool = False, + axis: Axis | None = None, + level=None, + ) -> Self: + """ + Equivalent to public method `where`, except that `other` is not + applied as a function even if callable. Used in __setitem__. + """ + inplace = validate_bool_kwarg(inplace, "inplace") + + if axis is not None: + axis = self._get_axis_number(axis) + + # align the cond to same shape as myself + cond = common.apply_if_callable(cond, self) + if isinstance(cond, NDFrame): + # CoW: Make sure reference is not kept alive + if cond.ndim == 1 and self.ndim == 2: + cond = cond._constructor_expanddim( + dict.fromkeys(range(len(self.columns)), cond), + copy=False, + ) + cond.columns = self.columns + cond = cond.align(self, join="right")[0] + else: + if not hasattr(cond, "shape"): + cond = np.asanyarray(cond) + if cond.shape != self.shape: + raise ValueError("Array conditional must be same shape as self") + cond = self._constructor(cond, **self._construct_axes_dict(), copy=False) + + # make sure we are boolean + fill_value = bool(inplace) + cond = cond.fillna(fill_value) + cond = cond.infer_objects() + + msg = "Boolean array expected for the condition, not {dtype}" + + if not cond.empty: + if not isinstance(cond, ABCDataFrame): + # This is a single-dimensional object. + if not is_bool_dtype(cond): + raise TypeError(msg.format(dtype=cond.dtype)) + else: + for block in cond._mgr.blocks: + if not is_bool_dtype(block.dtype): + raise TypeError(msg.format(dtype=block.dtype)) + if cond._mgr.any_extension_types: + # GH51574: avoid object ndarray conversion later on + cond = cond._constructor( + cond.to_numpy(dtype=bool, na_value=fill_value), + **cond._construct_axes_dict(), + ) + else: + # GH#21947 we have an empty DataFrame/Series, could be object-dtype + cond = cond.astype(bool) + + cond = -cond if inplace else cond + cond = cond.reindex(self._info_axis, axis=self._info_axis_number) + + # try to align with other + if isinstance(other, NDFrame): + # align with me + if other.ndim <= self.ndim: + # CoW: Make sure reference is not kept alive + other = self.align( + other, + join="left", + axis=axis, + level=level, + fill_value=None, + )[1] + + # if we are NOT aligned, raise as we cannot where index + if axis is None and not other._indexed_same(self): + raise InvalidIndexError + + if other.ndim < self.ndim: + other = other._values + if isinstance(other, np.ndarray): + # TODO(EA2D): could also do this for NDArrayBackedEA cases? + if axis == 0: + other = np.reshape(other, (-1, 1)) + elif axis == 1: + other = np.reshape(other, (1, -1)) + + other = np.broadcast_to(other, self.shape) + else: + # GH#38729, GH#62038 avoid lossy casting or object-casting + if axis == 0: + res_cols = [ + self.iloc[:, i]._where( + cond.iloc[:, i], + other, + ) + for i in range(self.shape[1]) + ] + elif axis == 1: + # TODO: can we use a zero-copy alternative to "repeat"? + res_cols = [ + self.iloc[:, i]._where( + cond.iloc[:, i], + other[i : i + 1].repeat(len(self)), + ) + for i in range(self.shape[1]) + ] + res = self._constructor(dict(enumerate(res_cols))) + res.index = self.index + res.columns = self.columns + if inplace: + self._update_inplace(res) + return self + return res.__finalize__(self) + + # slice me out of the other + else: + raise NotImplementedError( + "cannot align with a higher dimensional NDFrame" + ) + + elif not isinstance(other, (MultiIndex, NDFrame)): + # mainly just catching Index here + other = extract_array(other, extract_numpy=True) + + if isinstance(other, (np.ndarray, ExtensionArray)): + if other.shape != self.shape: + if self.ndim != 1: + # In the ndim == 1 case we may have + # other length 1, which we treat as scalar (GH#2745, GH#4192) + # or len(other) == icond.sum(), which we treat like + # __setitem__ (GH#3235) + raise ValueError( + "other must be the same shape as self when an ndarray" + ) + + # we are the same shape, so create an actual object for alignment + else: + other = self._constructor( + other, **self._construct_axes_dict(), copy=False + ) + + if axis is None: + axis = 0 + + if self.ndim == getattr(other, "ndim", 0): + align = True + else: + align = self._get_axis_number(axis) == 1 + + if inplace: + # we may have different type blocks come out of putmask, so + # reconstruct the block manager + + new_data = self._mgr.putmask(mask=cond, new=other, align=align) + result = self._constructor_from_mgr(new_data, axes=new_data.axes) + self._update_inplace(result) + return self + + else: + new_data = self._mgr.where( + other=other, + cond=cond, + align=align, + ) + result = self._constructor_from_mgr(new_data, axes=new_data.axes) + return result.__finalize__(self) + + @final + def where( + self, + cond, + other=lib.no_default, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, + ) -> Self: + """ + Replace values where the condition is False. + + This method allows conditional replacement of values. Where the + condition evaluates to True, the original values are retained; where + it evaluates to False, values are replaced with corresponding entries + from ``other``. + + Parameters + ---------- + cond : bool Series/DataFrame, array-like, or callable + Where `cond` is True, keep the original value. Where + False, replace with corresponding value from `other`. + If `cond` is callable, it is computed on the Series/DataFrame and + should return boolean Series/DataFrame or array. The callable must + not change input Series/DataFrame (though pandas doesn't check it). + other : scalar, Series/DataFrame, or callable + Entries where `cond` is False are replaced with + corresponding value from `other`. + If other is callable, it is computed on the Series/DataFrame and + should return scalar or Series/DataFrame. The callable must not + change input Series/DataFrame (though pandas doesn't check it). + If not specified, entries will be filled with the corresponding + NULL value (``np.nan`` for numpy dtypes, ``pd.NA`` for extension + dtypes). + inplace : bool, default False + Whether to perform the operation in place on the data. + axis : int, default None + Alignment axis if needed. For `Series` this parameter is + unused and defaults to 0. + level : int, default None + Alignment level if needed. + + Returns + ------- + Series or DataFrame + When applied to a Series, the function will return a Series, + and when applied to a DataFrame, it will return a DataFrame. + + See Also + -------- + :func:`DataFrame.mask` : Return an object of same shape as caller. + :func:`Series.mask` : Return an object of same shape as caller. + + Notes + ----- + The where method is an application of the if-then idiom. For each + element in the caller, if ``cond`` is ``True`` the + element is used; otherwise the corresponding element from + ``other`` is used. If the axis of ``other`` does not align with axis of + ``cond`` Series/DataFrame, the values of ``cond`` on misaligned index positions + will be filled with False. + + The signature for :func:`Series.where` or + :func:`DataFrame.where` differs from :func:`numpy.where`. + Roughly ``df1.where(m, df2)`` is equivalent to ``np.where(m, df1, df2)``. + + For further details and examples see the ``where`` documentation in + :ref:`indexing `. + + The dtype of the object takes precedence. The fill value is casted to + the object's dtype, if this can be done losslessly. + + Examples + -------- + >>> s = pd.Series(range(5)) + >>> s.where(s > 0) + 0 NaN + 1 1.0 + 2 2.0 + 3 3.0 + 4 4.0 + dtype: float64 + >>> s.mask(s > 0) + 0 0.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + + >>> s = pd.Series(range(5)) + >>> t = pd.Series([True, False]) + >>> s.where(t, 99) + 0 0 + 1 99 + 2 99 + 3 99 + 4 99 + dtype: int64 + >>> s.mask(t, 99) + 0 99 + 1 1 + 2 99 + 3 99 + 4 99 + dtype: int64 + + >>> s.where(s > 1, 10) + 0 10 + 1 10 + 2 2 + 3 3 + 4 4 + dtype: int64 + >>> s.mask(s > 1, 10) + 0 0 + 1 1 + 2 10 + 3 10 + 4 10 + dtype: int64 + + >>> df = pd.DataFrame(np.arange(10).reshape(-1, 2), columns=["A", "B"]) + >>> df + A B + 0 0 1 + 1 2 3 + 2 4 5 + 3 6 7 + 4 8 9 + >>> m = df % 3 == 0 + >>> df.where(m, -df) + A B + 0 0 -1 + 1 -2 3 + 2 -4 -5 + 3 6 -7 + 4 -8 9 + >>> df.where(m, -df) == np.where(m, df, -df) + A B + 0 True True + 1 True True + 2 True True + 3 True True + 4 True True + >>> df.where(m, -df) == df.mask(~m, -df) + A B + 0 True True + 1 True True + 2 True True + 3 True True + 4 True True + """ + inplace = validate_bool_kwarg(inplace, "inplace") + if inplace: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not common.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + other = common.apply_if_callable(other, self) + return self._where(cond, other, inplace=inplace, axis=axis, level=level) + + @final + def mask( + self, + cond, + other=lib.no_default, + *, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, + ) -> Self: + """ + Replace values where the condition is True. + + Parameters + ---------- + cond : bool Series/DataFrame, array-like, or callable + Where `cond` is False, keep the original value. Where + True, replace with corresponding value from `other`. + If `cond` is callable, it is computed on the Series/DataFrame and + should return boolean Series/DataFrame or array. The callable must + not change input Series/DataFrame (though pandas doesn't check it). + other : scalar, Series/DataFrame, or callable + Entries where `cond` is True are replaced with + corresponding value from `other`. + If other is callable, it is computed on the Series/DataFrame and + should return scalar or Series/DataFrame. The callable must not + change input Series/DataFrame (though pandas doesn't check it). + If not specified, entries will be filled with the corresponding + NULL value (``np.nan`` for numpy dtypes, ``pd.NA`` for extension + dtypes). + inplace : bool, default False + Whether to perform the operation in place on the data. + axis : int, default None + Alignment axis if needed. For `Series` this parameter is + unused and defaults to 0. + level : int, default None + Alignment level if needed. + + Returns + ------- + Series or DataFrame + When applied to a Series, the function will return a Series, + and when applied to a DataFrame, it will return a DataFrame. + + See Also + -------- + :func:`DataFrame.where` : Return an object of same shape as caller. + :func:`Series.where` : Return an object of same shape as caller. + + Notes + ----- + The mask method is an application of the if-then idiom. For each + element in the caller, if ``cond`` is ``False`` the + element is used; otherwise the corresponding element from + ``other`` is used. If the axis of ``other`` does not align with axis of + ``cond`` Series/DataFrame, the values of ``cond`` on misaligned index positions + will be filled with True. + + The signature for :func:`Series.where` or + :func:`DataFrame.where` differs from :func:`numpy.where`. + Roughly ``df1.where(m, df2)`` is equivalent to ``np.where(m, df1, df2)``. + + For further details and examples see the ``mask`` documentation in + :ref:`indexing `. + + The dtype of the object takes precedence. The fill value is casted to + the object's dtype, if this can be done losslessly. + + Examples + -------- + >>> s = pd.Series(range(5)) + >>> s.where(s > 0) + 0 NaN + 1 1.0 + 2 2.0 + 3 3.0 + 4 4.0 + dtype: float64 + >>> s.mask(s > 0) + 0 0.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + + >>> s = pd.Series(range(5)) + >>> t = pd.Series([True, False]) + >>> s.where(t, 99) + 0 0 + 1 99 + 2 99 + 3 99 + 4 99 + dtype: int64 + >>> s.mask(t, 99) + 0 99 + 1 1 + 2 99 + 3 99 + 4 99 + dtype: int64 + + >>> s.where(s > 1, 10) + 0 10 + 1 10 + 2 2 + 3 3 + 4 4 + dtype: int64 + >>> s.mask(s > 1, 10) + 0 0 + 1 1 + 2 10 + 3 10 + 4 10 + dtype: int64 + + >>> df = pd.DataFrame(np.arange(10).reshape(-1, 2), columns=["A", "B"]) + >>> df + A B + 0 0 1 + 1 2 3 + 2 4 5 + 3 6 7 + 4 8 9 + >>> m = df % 3 == 0 + >>> df.where(m, -df) + A B + 0 0 -1 + 1 -2 3 + 2 -4 -5 + 3 6 -7 + 4 -8 9 + >>> df.where(m, -df) == np.where(m, df, -df) + A B + 0 True True + 1 True True + 2 True True + 3 True True + 4 True True + >>> df.where(m, -df) == df.mask(~m, -df) + A B + 0 True True + 1 True True + 2 True True + 3 True True + 4 True True + """ + inplace = validate_bool_kwarg(inplace, "inplace") + if inplace: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not common.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + cond = common.apply_if_callable(cond, self) + other = common.apply_if_callable(other, self) + + # see gh-21891 + if not hasattr(cond, "__invert__"): + cond = np.array(cond) + + return self._where( + ~cond, + other=other, + inplace=inplace, + axis=axis, + level=level, + ) + + def shift( + self, + periods: int | Sequence[int] = 1, + freq=None, + axis: Axis = 0, + fill_value: Hashable = lib.no_default, + suffix: str | None = None, + ) -> Self | DataFrame: + """ + Shift index by desired number of periods with an optional time `freq`. + + When `freq` is not passed, shift the index without realigning the data. + If `freq` is passed (in this case, the index must be date or datetime, + or it will raise a `NotImplementedError`), the index will be + increased using the periods and the `freq`. `freq` can be inferred + when specified as "infer" as long as either freq or inferred_freq + attribute is set in the index. + + Parameters + ---------- + periods : int or Sequence + Number of periods to shift. Can be positive or negative. + If an iterable of ints, the data will be shifted once by each int. + This is equivalent to shifting by one value at a time and + concatenating all resulting frames. The resulting columns will have + the shift suffixed to their column names. For multiple periods, + axis must not be 1. + freq : DateOffset, tseries.offsets, timedelta, or str, optional + Offset to use from the tseries module or time rule (e.g. 'EOM'). + If `freq` is specified then the index values are shifted but the + data is not realigned. That is, use `freq` if you would like to + extend the index when shifting and preserve the original data. + If `freq` is specified as "infer" then it will be inferred from + the freq or inferred_freq attributes of the index. If neither of + those attributes exist, a ValueError is thrown. + axis : {{0 or 'index', 1 or 'columns', None}}, default None + Shift direction. For `Series` this parameter is unused and defaults to 0. + fill_value : object, optional + The scalar value to use for newly introduced missing values. + the default depends on the dtype of `self`. + For Boolean and numeric NumPy data types, ``np.nan`` is used. + For datetime, timedelta, or period data, etc. :attr:`NaT` is used. + For extension dtypes, ``self.dtype.na_value`` is used. + suffix : str, optional + If str and periods is an iterable, this is added after the column + name and before the shift value for each shifted column name. + For `Series` this parameter is unused and defaults to `None`. + + Returns + ------- + Series/DataFrame + Copy of input object, shifted. + + See Also + -------- + Index.shift : Shift values of Index. + DatetimeIndex.shift : Shift values of DatetimeIndex. + PeriodIndex.shift : Shift values of PeriodIndex. + + Examples + -------- + >>> df = pd.DataFrame( + ... [[10, 13, 17], [20, 23, 27], [15, 18, 22], [30, 33, 37], [45, 48, 52]], + ... columns=["Col1", "Col2", "Col3"], + ... index=pd.date_range("2020-01-01", "2020-01-05"), + ... ) + >>> df + Col1 Col2 Col3 + 2020-01-01 10 13 17 + 2020-01-02 20 23 27 + 2020-01-03 15 18 22 + 2020-01-04 30 33 37 + 2020-01-05 45 48 52 + + >>> df.shift(periods=3) + Col1 Col2 Col3 + 2020-01-01 NaN NaN NaN + 2020-01-02 NaN NaN NaN + 2020-01-03 NaN NaN NaN + 2020-01-04 10.0 13.0 17.0 + 2020-01-05 20.0 23.0 27.0 + + >>> df.shift(periods=1, axis="columns") + Col1 Col2 Col3 + 2020-01-01 NaN 10 13 + 2020-01-02 NaN 20 23 + 2020-01-03 NaN 15 18 + 2020-01-04 NaN 30 33 + 2020-01-05 NaN 45 48 + + >>> df.shift(periods=3, fill_value=0) + Col1 Col2 Col3 + 2020-01-01 0 0 0 + 2020-01-02 0 0 0 + 2020-01-03 0 0 0 + 2020-01-04 10 13 17 + 2020-01-05 20 23 27 + + >>> df.shift(periods=3, freq="D") + Col1 Col2 Col3 + 2020-01-04 10 13 17 + 2020-01-05 20 23 27 + 2020-01-06 15 18 22 + 2020-01-07 30 33 37 + 2020-01-08 45 48 52 + + >>> df.shift(periods=3, freq="infer") + Col1 Col2 Col3 + 2020-01-04 10 13 17 + 2020-01-05 20 23 27 + 2020-01-06 15 18 22 + 2020-01-07 30 33 37 + 2020-01-08 45 48 52 + + >>> df["Col1"].shift(periods=[0, 1, 2]) + Col1_0 Col1_1 Col1_2 + 2020-01-01 10 NaN NaN + 2020-01-02 20 10.0 NaN + 2020-01-03 15 20.0 10.0 + 2020-01-04 30 15.0 20.0 + 2020-01-05 45 30.0 15.0 + """ + axis = self._get_axis_number(axis) + + if freq is not None and fill_value is not lib.no_default: + # GH#53832 + raise ValueError( + "Passing a 'freq' together with a 'fill_value' is not allowed." + ) + + if periods == 0: + return self.copy(deep=False) + + if is_list_like(periods) and isinstance(self, ABCSeries): + return self.to_frame().shift( + periods=periods, freq=freq, axis=axis, fill_value=fill_value + ) + periods = cast(int, periods) + + if freq is None: + # when freq is None, data is shifted, index is not + axis = self._get_axis_number(axis) + assert axis == 0 # axis == 1 cases handled in DataFrame.shift + new_data = self._mgr.shift(periods=periods, fill_value=fill_value) + return self._constructor_from_mgr( + new_data, axes=new_data.axes + ).__finalize__(self, method="shift") + + return self._shift_with_freq(periods, axis, freq) + + @final + def _shift_with_freq(self, periods: int, axis: int, freq) -> Self: + # see shift.__doc__ + # when freq is given, index is shifted, data is not + index = self._get_axis(axis) + + if freq == "infer": + freq = getattr(index, "freq", None) + + if freq is None: + freq = getattr(index, "inferred_freq", None) + + if freq is None: + msg = "Freq was not set in the index hence cannot be inferred" + raise ValueError(msg) + + elif isinstance(freq, str): + is_period = isinstance(index, PeriodIndex) + freq = to_offset(freq, is_period=is_period) + + if isinstance(index, PeriodIndex): + orig_freq = to_offset(index.freq) + if freq != orig_freq: + assert orig_freq is not None # for mypy + raise ValueError( + f"Given freq {PeriodDtype(freq)._freqstr} " + f"does not match PeriodIndex freq " + f"{PeriodDtype(orig_freq)._freqstr}" + ) + new_ax: Index = index.shift(periods) + else: + new_ax = index.shift(periods, freq) + + result = self.set_axis(new_ax, axis=axis) + return result.__finalize__(self, method="shift") + + @final + def truncate( + self, + before=None, + after=None, + axis: Axis | None = None, + copy: bool | lib.NoDefault = lib.no_default, + ) -> Self: + """ + Truncate a Series or DataFrame before and after some index value. + + This is a useful shorthand for boolean indexing based on index + values above or below certain thresholds. + + Parameters + ---------- + before : date, str, int + Truncate all rows before this index value. + after : date, str, int + Truncate all rows after this index value. + axis : {0 or 'index', 1 or 'columns'}, optional + Axis to truncate. Truncates the index (rows) by default. + For `Series` this parameter is unused and defaults to 0. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + type of caller + The truncated Series or DataFrame. + + See Also + -------- + DataFrame.loc : Select a subset of a DataFrame by label. + DataFrame.iloc : Select a subset of a DataFrame by position. + + Notes + ----- + If the index being truncated contains only datetime values, + `before` and `after` may be specified as strings instead of + Timestamps. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "A": ["a", "b", "c", "d", "e"], + ... "B": ["f", "g", "h", "i", "j"], + ... "C": ["k", "l", "m", "n", "o"], + ... }, + ... index=[1, 2, 3, 4, 5], + ... ) + >>> df + A B C + 1 a f k + 2 b g l + 3 c h m + 4 d i n + 5 e j o + + >>> df.truncate(before=2, after=4) + A B C + 2 b g l + 3 c h m + 4 d i n + + The columns of a DataFrame can be truncated. + + >>> df.truncate(before="A", after="B", axis="columns") + A B + 1 a f + 2 b g + 3 c h + 4 d i + 5 e j + + For Series, only rows can be truncated. + + >>> df["A"].truncate(before=2, after=4) + 2 b + 3 c + 4 d + Name: A, dtype: str + + The index values in ``truncate`` can be datetimes or string + dates. + + >>> dates = pd.date_range("2016-01-01", "2016-02-01", freq="s") + >>> df = pd.DataFrame(index=dates, data={"A": 1}) + >>> df.tail() + A + 2016-01-31 23:59:56 1 + 2016-01-31 23:59:57 1 + 2016-01-31 23:59:58 1 + 2016-01-31 23:59:59 1 + 2016-02-01 00:00:00 1 + + >>> df.truncate( + ... before=pd.Timestamp("2016-01-05"), after=pd.Timestamp("2016-01-10") + ... ).tail() + A + 2016-01-09 23:59:56 1 + 2016-01-09 23:59:57 1 + 2016-01-09 23:59:58 1 + 2016-01-09 23:59:59 1 + 2016-01-10 00:00:00 1 + + Because the index is a DatetimeIndex containing only dates, we can + specify `before` and `after` as strings. They will be coerced to + Timestamps before truncation. + + >>> df.truncate("2016-01-05", "2016-01-10").tail() + A + 2016-01-09 23:59:56 1 + 2016-01-09 23:59:57 1 + 2016-01-09 23:59:58 1 + 2016-01-09 23:59:59 1 + 2016-01-10 00:00:00 1 + + Note that ``truncate`` assumes a 0 value for any unspecified time + component (midnight). This differs from partial string slicing, which + returns any partially matching dates. + + >>> df.loc["2016-01-05":"2016-01-10", :].tail() + A + 2016-01-10 23:59:55 1 + 2016-01-10 23:59:56 1 + 2016-01-10 23:59:57 1 + 2016-01-10 23:59:58 1 + 2016-01-10 23:59:59 1 + """ + self._check_copy_deprecation(copy) + + if axis is None: + axis = 0 + axis = self._get_axis_number(axis) + ax = self._get_axis(axis) + + # GH 17935 + # Check that index is sorted + if not ax.is_monotonic_increasing and not ax.is_monotonic_decreasing: + raise ValueError("truncate requires a sorted index") + + # if we have a date index, convert to dates, otherwise + # treat like a slice + if ax._is_all_dates: + from pandas.core.tools.datetimes import to_datetime + + if before is not None: + # Avoid converting to NaT + before = to_datetime(before) + if after is not None: + # Avoid converting to NaT + after = to_datetime(after) + + if before is not None and after is not None and before > after: + raise ValueError(f"Truncate: {after} must be after {before}") + + if len(ax) > 1 and ax.is_monotonic_decreasing and ax.nunique() > 1: + before, after = after, before + + slicer = [slice(None, None)] * self._AXIS_LEN + slicer[axis] = slice(before, after) + result = self.loc[tuple(slicer)] + + if isinstance(ax, MultiIndex): + setattr(result, self._get_axis_name(axis), ax.truncate(before, after)) + + result = result.copy(deep=False) + + return result + + @final + def tz_convert( + self, + tz, + axis: Axis = 0, + level=None, + copy: bool | lib.NoDefault = lib.no_default, + ) -> Self: + """ + Convert tz-aware axis to target time zone. + + Parameters + ---------- + tz : str or tzinfo object or None + Target time zone. Passing ``None`` will convert to + UTC and remove the timezone information. + axis : {{0 or 'index', 1 or 'columns'}}, default 0 + The axis to convert + level : int, str, default None + If axis is a MultiIndex, convert a specific level. Otherwise + must be None. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + Series/DataFrame + Object with time zone converted axis. + + Raises + ------ + TypeError + If the axis is tz-naive. + + See Also + -------- + DataFrame.tz_localize: Localize tz-naive index of DataFrame to target time zone. + Series.tz_localize: Localize tz-naive index of Series to target time zone. + + Examples + -------- + Change to another time zone: + + >>> s = pd.Series( + ... [1], + ... index=pd.DatetimeIndex(["2018-09-15 01:30:00+02:00"]), + ... ) + >>> s.tz_convert("Asia/Shanghai") + 2018-09-15 07:30:00+08:00 1 + dtype: int64 + + Pass None to convert to UTC and get a tz-naive index: + + >>> s = pd.Series([1], index=pd.DatetimeIndex(["2018-09-15 01:30:00+02:00"])) + >>> s.tz_convert(None) + 2018-09-14 23:30:00 1 + dtype: int64 + """ + self._check_copy_deprecation(copy) + axis = self._get_axis_number(axis) + ax = self._get_axis(axis) + + def _tz_convert(ax, tz): + if not hasattr(ax, "tz_convert"): + if len(ax) > 0: + ax_name = self._get_axis_name(axis) + raise TypeError( + f"{ax_name} is not a valid DatetimeIndex or PeriodIndex" + ) + ax = DatetimeIndex([], tz=tz) + else: + ax = ax.tz_convert(tz) + return ax + + # if a level is given it must be a MultiIndex level or + # equivalent to the axis name + if isinstance(ax, MultiIndex): + level = ax._get_level_number(level) + new_level = _tz_convert(ax.levels[level], tz) + ax = ax.set_levels(new_level, level=level) + else: + if level not in (None, 0, ax.name): + raise ValueError(f"The level {level} is not valid") + ax = _tz_convert(ax, tz) + + result = self.copy(deep=False) + result = result.set_axis(ax, axis=axis) + return result.__finalize__(self, method="tz_convert") + + @final + def tz_localize( + self, + tz, + axis: Axis = 0, + level=None, + copy: bool | lib.NoDefault = lib.no_default, + ambiguous: TimeAmbiguous = "raise", + nonexistent: TimeNonexistent = "raise", + ) -> Self: + """ + Localize time zone naive index of a Series or DataFrame to target time zone. + + This operation localizes the Index. To localize the values in a + time zone naive Series, use :meth:`Series.dt.tz_localize`. + + Parameters + ---------- + tz : str or tzinfo or None + Time zone to localize. Passing ``None`` will remove the + time zone information and preserve local time. + axis : {{0 or 'index', 1 or 'columns'}}, default 0 + The axis to localize + level : int, str, default None + If axis ia a MultiIndex, localize a specific level. Otherwise + must be None. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + ambiguous : 'infer', bool, bool-ndarray, 'NaT', default 'raise' + When clocks moved backward due to DST, ambiguous times may arise. + For example in Central European Time (UTC+01), when going from + 03:00 DST to 02:00 non-DST, 02:30:00 local time occurs both at + 00:30:00 UTC and at 01:30:00 UTC. In such a situation, the + `ambiguous` parameter dictates how ambiguous times should be + handled. + + - 'infer' will attempt to infer fall dst-transition hours based on + order + - bool (or bool-ndarray) where True signifies a DST time, False designates + a non-DST time (note that this flag is only applicable for + ambiguous times) + - 'NaT' will return NaT where there are ambiguous times + - 'raise' will raise a ValueError if there are ambiguous + times. + nonexistent : str, default 'raise' + A nonexistent time does not exist in a particular timezone + where clocks moved forward due to DST. Valid values are: + + - 'shift_forward' will shift the nonexistent time forward to the + closest existing time + - 'shift_backward' will shift the nonexistent time backward to the + closest existing time + - 'NaT' will return NaT where there are nonexistent times + - timedelta objects will shift nonexistent times by the timedelta + - 'raise' will raise a ValueError if there are + nonexistent times. + + Returns + ------- + Series/DataFrame + Same type as the input, with time zone naive or aware index, depending on + ``tz``. + + Raises + ------ + TypeError + If the TimeSeries is tz-aware and tz is not None. + + See Also + -------- + Series.dt.tz_localize: Localize the values in a time zone naive Series. + Timestamp.tz_localize: Localize the Timestamp to a timezone. + + Examples + -------- + Localize local times: + + >>> s = pd.Series( + ... [1], + ... index=pd.DatetimeIndex(["2018-09-15 01:30:00"]), + ... ) + >>> s.tz_localize("CET") + 2018-09-15 01:30:00+02:00 1 + dtype: int64 + + Pass None to convert to tz-naive index and preserve local time: + + >>> s = pd.Series([1], index=pd.DatetimeIndex(["2018-09-15 01:30:00+02:00"])) + >>> s.tz_localize(None) + 2018-09-15 01:30:00 1 + dtype: int64 + + Be careful with DST changes. When there is sequential data, pandas + can infer the DST time: + + >>> s = pd.Series( + ... range(7), + ... index=pd.DatetimeIndex( + ... [ + ... "2018-10-28 01:30:00", + ... "2018-10-28 02:00:00", + ... "2018-10-28 02:30:00", + ... "2018-10-28 02:00:00", + ... "2018-10-28 02:30:00", + ... "2018-10-28 03:00:00", + ... "2018-10-28 03:30:00", + ... ] + ... ), + ... ) + >>> s.tz_localize("CET", ambiguous="infer") + 2018-10-28 01:30:00+02:00 0 + 2018-10-28 02:00:00+02:00 1 + 2018-10-28 02:30:00+02:00 2 + 2018-10-28 02:00:00+01:00 3 + 2018-10-28 02:30:00+01:00 4 + 2018-10-28 03:00:00+01:00 5 + 2018-10-28 03:30:00+01:00 6 + dtype: int64 + + In some cases, inferring the DST is impossible. In such cases, you can + pass an ndarray to the ambiguous parameter to set the DST explicitly + + >>> s = pd.Series( + ... range(3), + ... index=pd.DatetimeIndex( + ... [ + ... "2018-10-28 01:20:00", + ... "2018-10-28 02:36:00", + ... "2018-10-28 03:46:00", + ... ] + ... ), + ... ) + >>> s.tz_localize("CET", ambiguous=np.array([True, True, False])) + 2018-10-28 01:20:00+02:00 0 + 2018-10-28 02:36:00+02:00 1 + 2018-10-28 03:46:00+01:00 2 + dtype: int64 + + If the DST transition causes nonexistent times, you can shift these + dates forward or backward with a timedelta object or `'shift_forward'` + or `'shift_backward'`. + + >>> dti = pd.DatetimeIndex( + ... ["2015-03-29 02:30:00", "2015-03-29 03:30:00"], dtype="M8[ns]" + ... ) + >>> s = pd.Series(range(2), index=dti) + >>> s.tz_localize("Europe/Warsaw", nonexistent="shift_forward") + 2015-03-29 03:00:00+02:00 0 + 2015-03-29 03:30:00+02:00 1 + dtype: int64 + >>> s.tz_localize("Europe/Warsaw", nonexistent="shift_backward") + 2015-03-29 01:59:59.999999999+01:00 0 + 2015-03-29 03:30:00+02:00 1 + dtype: int64 + >>> s.tz_localize("Europe/Warsaw", nonexistent=pd.Timedelta("1h")) + 2015-03-29 03:30:00+02:00 0 + 2015-03-29 03:30:00+02:00 1 + dtype: int64 + """ + self._check_copy_deprecation(copy) + nonexistent_options = ("raise", "NaT", "shift_forward", "shift_backward") + if nonexistent not in nonexistent_options and not isinstance( + nonexistent, dt.timedelta + ): + raise ValueError( + "The nonexistent argument must be one of 'raise', " + "'NaT', 'shift_forward', 'shift_backward' or " + "a timedelta object" + ) + + axis = self._get_axis_number(axis) + ax = self._get_axis(axis) + + def _tz_localize(ax, tz, ambiguous, nonexistent): + if not hasattr(ax, "tz_localize"): + if len(ax) > 0: + ax_name = self._get_axis_name(axis) + raise TypeError( + f"{ax_name} is not a valid DatetimeIndex or PeriodIndex" + ) + ax = DatetimeIndex([], tz=tz) + else: + ax = ax.tz_localize(tz, ambiguous=ambiguous, nonexistent=nonexistent) + return ax + + # if a level is given it must be a MultiIndex level or + # equivalent to the axis name + if isinstance(ax, MultiIndex): + level = ax._get_level_number(level) + new_level = _tz_localize(ax.levels[level], tz, ambiguous, nonexistent) + ax = ax.set_levels(new_level, level=level) + else: + if level not in (None, 0, ax.name): + raise ValueError(f"The level {level} is not valid") + ax = _tz_localize(ax, tz, ambiguous, nonexistent) + + result = self.copy(deep=False) + result = result.set_axis(ax, axis=axis) + return result.__finalize__(self, method="tz_localize") + + # ---------------------------------------------------------------------- + # Numeric Methods + + @final + def describe( + self, + percentiles=None, + include=None, + exclude=None, + ) -> Self: + """ + Generate descriptive statistics. + + Descriptive statistics include those that summarize the central + tendency, dispersion and shape of a + dataset's distribution, excluding ``NaN`` values. + + Analyzes both numeric and object series, as well + as ``DataFrame`` column sets of mixed data types. The output + will vary depending on what is provided. Refer to the notes + below for more detail. + + Parameters + ---------- + percentiles : list-like of numbers, optional + The percentiles to include in the output. All should + fall between 0 and 1. The default, ``None``, will automatically + return the 25th, 50th, and 75th percentiles. + include : 'all', list-like of dtypes or None (default), optional + A white list of data types to include in the result. Ignored + for ``Series``. Here are the options: + + - 'all' : All columns of the input will be included in the output. + - A list-like of dtypes : Limits the results to the + provided data types. + To limit the result to numeric types submit + ``numpy.number``. To limit it instead to object columns submit + the ``numpy.object`` data type. Strings + can also be used in the style of + ``select_dtypes`` (e.g. ``df.describe(include=['O'])``). To + select pandas categorical columns, use ``'category'`` + - None (default) : The result will include all numeric columns. + exclude : list-like of dtypes or None (default), optional, + A black list of data types to omit from the result. Ignored + for ``Series``. Here are the options: + + - A list-like of dtypes : Excludes the provided data types + from the result. To exclude numeric types submit + ``numpy.number``. To exclude object columns submit the data + type ``numpy.object``. Strings can also be used in the style of + ``select_dtypes`` (e.g. ``df.describe(exclude=['O'])``). To + exclude pandas categorical columns, use ``'category'`` + - None (default) : The result will exclude nothing. + + Returns + ------- + Series or DataFrame + Summary statistics of the Series or Dataframe provided. + + See Also + -------- + DataFrame.count: Count number of non-NA/null observations. + DataFrame.max: Maximum of the values in the object. + DataFrame.min: Minimum of the values in the object. + DataFrame.mean: Mean of the values. + DataFrame.std: Standard deviation of the observations. + DataFrame.select_dtypes: Subset of a DataFrame including/excluding + columns based on their dtype. + + Notes + ----- + For numeric data, the result's index will include ``count``, + ``mean``, ``std``, ``min``, ``max`` as well as lower, ``50`` and + upper percentiles. By default the lower percentile is ``25`` and the + upper percentile is ``75``. The ``50`` percentile is the + same as the median. + + For object data (e.g. strings), the result's index + will include ``count``, ``unique``, ``top``, and ``freq``. The ``top`` + is the most common value. The ``freq`` is the most common value's + frequency. + + If multiple object values have the highest count, then the + ``count`` and ``top`` results will be arbitrarily chosen from + among those with the highest count. + + For mixed data types provided via a ``DataFrame``, the default is to + return only an analysis of numeric columns. If the DataFrame consists + only of object and categorical data without any numeric columns, the + default is to return an analysis of both the object and categorical + columns. If ``include='all'`` is provided as an option, the result + will include a union of attributes of each type. + + The `include` and `exclude` parameters can be used to limit + which columns in a ``DataFrame`` are analyzed for the output. + The parameters are ignored when analyzing a ``Series``. + + Examples + -------- + Describing a numeric ``Series``. + + >>> s = pd.Series([1, 2, 3]) + >>> s.describe() + count 3.0 + mean 2.0 + std 1.0 + min 1.0 + 25% 1.5 + 50% 2.0 + 75% 2.5 + max 3.0 + dtype: float64 + + Describing a categorical ``Series``. + + >>> s = pd.Series(["a", "a", "b", "c"]) + >>> s.describe() + count 4 + unique 3 + top a + freq 2 + dtype: object + + Describing a timestamp ``Series``. + + >>> s = pd.Series( + ... [ + ... np.datetime64("2000-01-01"), + ... np.datetime64("2010-01-01"), + ... np.datetime64("2010-01-01"), + ... ] + ... ) + >>> s.describe() + count 3 + mean 2006-09-01 08:00:00 + min 2000-01-01 00:00:00 + 25% 2004-12-31 12:00:00 + 50% 2010-01-01 00:00:00 + 75% 2010-01-01 00:00:00 + max 2010-01-01 00:00:00 + dtype: object + + Describing a ``DataFrame``. By default only numeric fields + are returned. + + >>> df = pd.DataFrame( + ... { + ... "categorical": pd.Categorical(["d", "e", "f"]), + ... "numeric": [1, 2, 3], + ... "object": ["a", "b", "c"], + ... } + ... ) + >>> df.describe() + numeric + count 3.0 + mean 2.0 + std 1.0 + min 1.0 + 25% 1.5 + 50% 2.0 + 75% 2.5 + max 3.0 + + Describing all columns of a ``DataFrame`` regardless of data type. + + >>> df.describe(include="all") # doctest: +SKIP + categorical numeric object + count 3 3.0 3 + unique 3 NaN 3 + top f NaN a + freq 1 NaN 1 + mean NaN 2.0 NaN + std NaN 1.0 NaN + min NaN 1.0 NaN + 25% NaN 1.5 NaN + 50% NaN 2.0 NaN + 75% NaN 2.5 NaN + max NaN 3.0 NaN + + Describing a column from a ``DataFrame`` by accessing it as + an attribute. + + >>> df.numeric.describe() + count 3.0 + mean 2.0 + std 1.0 + min 1.0 + 25% 1.5 + 50% 2.0 + 75% 2.5 + max 3.0 + Name: numeric, dtype: float64 + + Including only numeric columns in a ``DataFrame`` description. + + >>> df.describe(include=[np.number]) + numeric + count 3.0 + mean 2.0 + std 1.0 + min 1.0 + 25% 1.5 + 50% 2.0 + 75% 2.5 + max 3.0 + + Including only string columns in a ``DataFrame`` description. + + >>> df.describe(include=[object]) # doctest: +SKIP + object + count 3 + unique 3 + top a + freq 1 + + Including only categorical columns from a ``DataFrame`` description. + + >>> df.describe(include=["category"]) + categorical + count 3 + unique 3 + top d + freq 1 + + Excluding numeric columns from a ``DataFrame`` description. + + >>> df.describe(exclude=[np.number]) # doctest: +SKIP + categorical object + count 3 3 + unique 3 3 + top f a + freq 1 1 + + Excluding object columns from a ``DataFrame`` description. + + >>> df.describe(exclude=[object]) # doctest: +SKIP + categorical numeric + count 3 3.0 + unique 3 NaN + top f NaN + freq 1 NaN + mean NaN 2.0 + std NaN 1.0 + min NaN 1.0 + 25% NaN 1.5 + 50% NaN 2.0 + 75% NaN 2.5 + max NaN 3.0 + """ + return describe_ndframe( + obj=self, + include=include, + exclude=exclude, + percentiles=percentiles, + ).__finalize__(self, method="describe") + + @final + def pct_change( + self, + periods: int = 1, + fill_method: None = None, + freq=None, + **kwargs, + ) -> Self: + """ + Fractional change between the current and a prior element. + + Computes the fractional change from the immediately previous row by + default. This is useful in comparing the fraction of change in a time + series of elements. + + .. note:: + + Despite the name of this method, it calculates fractional change + (also known as per unit change or relative change) and not + percentage change. If you need the percentage change, multiply + these values by 100. + + Parameters + ---------- + periods : int, default 1 + Periods to shift for forming percent change. + fill_method : None + Must be None. This argument will be removed in a future version of pandas. + freq : DateOffset, timedelta, or str, optional + Increment to use from time series API (e.g. 'ME' or BDay()). + **kwargs + Additional keyword arguments are passed into + `DataFrame.shift` or `Series.shift`. + + Returns + ------- + Series or DataFrame + The same type as the calling object. + + See Also + -------- + Series.diff : Compute the difference of two elements in a Series. + DataFrame.diff : Compute the difference of two elements in a DataFrame. + Series.shift : Shift the index by some number of periods. + DataFrame.shift : Shift the index by some number of periods. + + Examples + -------- + **Series** + + >>> s = pd.Series([90, 91, 85]) + >>> s + 0 90 + 1 91 + 2 85 + dtype: int64 + + >>> s.pct_change() + 0 NaN + 1 0.011111 + 2 -0.065934 + dtype: float64 + + >>> s.pct_change(periods=2) + 0 NaN + 1 NaN + 2 -0.055556 + dtype: float64 + + See the percentage change in a Series where filling NAs with last + valid observation forward to next valid. + + >>> s = pd.Series([90, 91, None, 85]) + >>> s + 0 90.0 + 1 91.0 + 2 NaN + 3 85.0 + dtype: float64 + + >>> s.ffill().pct_change() + 0 NaN + 1 0.011111 + 2 0.000000 + 3 -0.065934 + dtype: float64 + + **DataFrame** + + Percentage change in French franc, Deutsche Mark, and Italian lira from + 1980-01-01 to 1980-03-01. + + >>> df = pd.DataFrame( + ... { + ... "FR": [4.0405, 4.0963, 4.3149], + ... "GR": [1.7246, 1.7482, 1.8519], + ... "IT": [804.74, 810.01, 860.13], + ... }, + ... index=["1980-01-01", "1980-02-01", "1980-03-01"], + ... ) + >>> df + FR GR IT + 1980-01-01 4.0405 1.7246 804.74 + 1980-02-01 4.0963 1.7482 810.01 + 1980-03-01 4.3149 1.8519 860.13 + + >>> df.pct_change() + FR GR IT + 1980-01-01 NaN NaN NaN + 1980-02-01 0.013810 0.013684 0.006549 + 1980-03-01 0.053365 0.059318 0.061876 + + Percentage of change in GOOG and APPL stock volume. Shows computing + the percentage change between columns. + + >>> df = pd.DataFrame( + ... { + ... "2016": [1769950, 30586265], + ... "2015": [1500923, 40912316], + ... "2014": [1371819, 41403351], + ... }, + ... index=["GOOG", "APPL"], + ... ) + >>> df + 2016 2015 2014 + GOOG 1769950 1500923 1371819 + APPL 30586265 40912316 41403351 + + >>> df.pct_change(axis="columns", periods=-1) + 2016 2015 2014 + GOOG 0.179241 0.094112 NaN + APPL -0.252395 -0.011860 NaN + """ + # GH#53491 + if fill_method is not None: + raise ValueError(f"fill_method must be None; got {fill_method=}.") + + axis = self._get_axis_number(kwargs.pop("axis", "index")) + shifted = self.shift(periods=periods, freq=freq, axis=axis, **kwargs) + # Unsupported left operand type for / ("Self") + rs = self / shifted - 1 # type: ignore[operator] + if freq is not None: + # Shift method is implemented differently when freq is not None + # We want to restore the original index + rs = rs.loc[~rs.index.duplicated()] + rs = rs.reindex_like(self) + return rs.__finalize__(self, method="pct_change") + + @final + def _logical_func( + self, + name: str, + func, + axis: Axis | None = 0, + bool_only: bool = False, + skipna: bool = True, + **kwargs, + ) -> Series | bool: + nv.validate_logical_func((), kwargs, fname=name) + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + + if self.ndim > 1 and axis is None: + # Reduce along one dimension then the other, to simplify DataFrame._reduce + res = self._logical_func( + name, func, axis=0, bool_only=bool_only, skipna=skipna, **kwargs + ) + # error: Item "bool" of "Series | bool" has no attribute "_logical_func" + return res._logical_func( # type: ignore[union-attr] + name, func, skipna=skipna, **kwargs + ) + elif axis is None: + axis = 0 + + if ( + self.ndim > 1 + and axis == 1 + and len(self._mgr.blocks) > 1 + # TODO(EA2D): special-case not needed + and all(block.values.ndim == 2 for block in self._mgr.blocks) + and not kwargs + ): + # Fastpath avoiding potentially expensive transpose + obj = self + if bool_only: + obj = self._get_bool_data() + return obj._reduce_axis1(name, func, skipna=skipna) + + return self._reduce( + func, + name=name, + axis=axis, + skipna=skipna, + numeric_only=bool_only, + filter_type="bool", + ) + + def any( + self, + *, + axis: Axis | None = 0, + bool_only: bool = False, + skipna: bool = True, + **kwargs, + ) -> Series | bool: + return self._logical_func( + "any", nanops.nanany, axis, bool_only, skipna, **kwargs + ) + + def all( + self, + *, + axis: Axis = 0, + bool_only: bool = False, + skipna: bool = True, + **kwargs, + ) -> Series | bool: + return self._logical_func( + "all", nanops.nanall, axis, bool_only, skipna, **kwargs + ) + + @final + def _accum_func( + self, + name: str, + func, + axis: Axis | None = None, + skipna: bool = True, + *args, + **kwargs, + ): + skipna = nv.validate_cum_func_with_skipna(skipna, args, kwargs, name) + if axis is None: + axis = 0 + else: + axis = self._get_axis_number(axis) + + if axis == 1: + return self.T._accum_func( + name, + func, + axis=0, + skipna=skipna, + *args, # noqa: B026 + **kwargs, + ).T + + def block_accum_func(blk_values): + values = blk_values.T if hasattr(blk_values, "T") else blk_values + + result: np.ndarray | ExtensionArray + if isinstance(values, ExtensionArray): + result = values._accumulate(name, skipna=skipna, **kwargs) + else: + result = nanops.na_accum_func(values, func, skipna=skipna) + + result = result.T if hasattr(result, "T") else result + return result + + result = self._mgr.apply(block_accum_func) + + return self._constructor_from_mgr(result, axes=result.axes).__finalize__( + self, method=name + ) + + def cummax(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Self: + return self._accum_func( + "cummax", np.maximum.accumulate, axis, skipna, *args, **kwargs + ) + + def cummin(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Self: + return self._accum_func( + "cummin", np.minimum.accumulate, axis, skipna, *args, **kwargs + ) + + def cumsum(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Self: + return self._accum_func("cumsum", np.cumsum, axis, skipna, *args, **kwargs) + + def cumprod(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Self: + return self._accum_func("cumprod", np.cumprod, axis, skipna, *args, **kwargs) + + @final + def _stat_function_ddof( + self, + name: str, + func, + axis: Axis | None = 0, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ) -> Series | float: + nv.validate_stat_ddof_func((), kwargs, fname=name) + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + + return self._reduce( + func, name, axis=axis, numeric_only=numeric_only, skipna=skipna, ddof=ddof + ) + + def sem( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ) -> Series | float: + return self._stat_function_ddof( + "sem", nanops.nansem, axis, skipna, ddof, numeric_only, **kwargs + ) + + def var( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ) -> Series | float: + return self._stat_function_ddof( + "var", nanops.nanvar, axis, skipna, ddof, numeric_only, **kwargs + ) + + def std( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ) -> Series | float: + return self._stat_function_ddof( + "std", nanops.nanstd, axis, skipna, ddof, numeric_only, **kwargs + ) + + @final + def _stat_function( + self, + name: str, + func, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ): + assert name in ["median", "mean", "min", "max", "kurt", "skew"], name + nv.validate_func(name, (), kwargs) + + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + + return self._reduce( + func, name=name, axis=axis, skipna=skipna, numeric_only=numeric_only + ) + + def min( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ): + return self._stat_function( + "min", + nanops.nanmin, + axis, + skipna, + numeric_only, + **kwargs, + ) + + def max( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ): + return self._stat_function( + "max", + nanops.nanmax, + axis, + skipna, + numeric_only, + **kwargs, + ) + + def mean( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | float: + return self._stat_function( + "mean", nanops.nanmean, axis, skipna, numeric_only, **kwargs + ) + + def median( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | float: + return self._stat_function( + "median", nanops.nanmedian, axis, skipna, numeric_only, **kwargs + ) + + def skew( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | float: + return self._stat_function( + "skew", nanops.nanskew, axis, skipna, numeric_only, **kwargs + ) + + def kurt( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Series | float: + return self._stat_function( + "kurt", nanops.nankurt, axis, skipna, numeric_only, **kwargs + ) + + kurtosis = kurt + + @final + def _min_count_stat_function( + self, + name: str, + func, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs, + ): + assert name in ["sum", "prod"], name + nv.validate_func(name, (), kwargs) + + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + + return self._reduce( + func, + name=name, + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + min_count=min_count, + ) + + def sum( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs, + ): + return self._min_count_stat_function( + "sum", nanops.nansum, axis, skipna, numeric_only, min_count, **kwargs + ) + + def prod( + self, + *, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs, + ): + return self._min_count_stat_function( + "prod", + nanops.nanprod, + axis, + skipna, + numeric_only, + min_count, + **kwargs, + ) + + product = prod + + @final + def rolling( + self, + window: int | dt.timedelta | str | BaseOffset | BaseIndexer, + min_periods: int | None = None, + center: bool = False, + win_type: str | None = None, + on: str | None = None, + closed: IntervalClosedType | None = None, + step: int | None = None, + method: str = "single", + ) -> Window | Rolling: + """ + Provide rolling window calculations. + + Parameters + ---------- + window : int, timedelta, str, offset, or BaseIndexer subclass + Interval of the moving window. + + If an integer, the delta between the start and end of each window. + The number of points in the window depends on the ``closed`` argument. + + If a timedelta, str, or offset, the time period of each window. Each + window will be a variable sized based on the observations included in + the time-period. This is only valid for datetimelike indexes. + To learn more about the offsets & frequency strings, please see + :ref:`this link`. + + If a BaseIndexer subclass, the window boundaries + based on the defined ``get_window_bounds`` method. Additional rolling + keyword arguments, namely ``min_periods``, ``center``, ``closed`` and + ``step`` will be passed to ``get_window_bounds``. + + min_periods : int, default None + Minimum number of observations in window required to have a value; + otherwise, result is ``np.nan``. + + For a window that is specified by an offset, ``min_periods`` will default + to 1. + + For a window that is specified by an integer, ``min_periods`` will default + to the size of the window. + + center : bool, default False + If False, set the window labels as the right edge of the window index. + + If True, set the window labels as the center of the window index. + + win_type : str, default None + If ``None``, all points are evenly weighted. + + If a string, it must be a valid `scipy.signal window function + `__. + + Certain Scipy window types require additional parameters to be passed + in the aggregation function. The additional parameters must match + the keywords specified in the Scipy window type method signature. + + on : str, optional + For a DataFrame, a column label or Index level on which + to calculate the rolling window, rather than the DataFrame's index. + + Provided integer column is ignored and excluded from result since + an integer index is not used to calculate the rolling window. + + closed : str, default None + Determines the inclusivity of points in the window + + If ``'right'``, uses the window (first, last] meaning the last point + is included in the calculations. + + If ``'left'``, uses the window [first, last) meaning the first point + is included in the calculations. + + If ``'both'``, uses the window [first, last] meaning all points in + the window are included in the calculations. + + If ``'neither'``, uses the window (first, last) meaning the first + and last points in the window are excluded from calculations. + + () and [] are referencing open and closed set + notation respetively. + + Default ``None`` (``'right'``). + + step : int, default None + Evaluate the window at every ``step`` result, equivalent to slicing as + ``[::step]``. ``window`` must be an integer. Using a step argument other + than None or 1 will produce a result with a different shape than the input. + + method : str {'single', 'table'}, default 'single' + + Execute the rolling operation per single column or row (``'single'``) + or over the entire object (``'table'``). + + This argument is only implemented when specifying ``engine='numba'`` + in the method call. + + Returns + ------- + pandas.api.typing.Window or pandas.api.typing.Rolling + An instance of Window is returned if ``win_type`` is passed. Otherwise, + an instance of Rolling is returned. + + See Also + -------- + expanding : Provides expanding transformations. + ewm : Provides exponential weighted functions. + + Notes + ----- + See :ref:`Windowing Operations ` for further usage details + and examples. + + Examples + -------- + >>> df = pd.DataFrame({"B": [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + + **window** + + Rolling sum with a window length of 2 observations. + + >>> df.rolling(2).sum() + B + 0 NaN + 1 1.0 + 2 3.0 + 3 NaN + 4 NaN + + Rolling sum with a window span of 2 seconds. + + >>> df_time = pd.DataFrame( + ... {"B": [0, 1, 2, np.nan, 4]}, + ... index=[ + ... pd.Timestamp("20130101 09:00:00"), + ... pd.Timestamp("20130101 09:00:02"), + ... pd.Timestamp("20130101 09:00:03"), + ... pd.Timestamp("20130101 09:00:05"), + ... pd.Timestamp("20130101 09:00:06"), + ... ], + ... ) + + >>> df_time + B + 2013-01-01 09:00:00 0.0 + 2013-01-01 09:00:02 1.0 + 2013-01-01 09:00:03 2.0 + 2013-01-01 09:00:05 NaN + 2013-01-01 09:00:06 4.0 + + >>> df_time.rolling("2s").sum() + B + 2013-01-01 09:00:00 0.0 + 2013-01-01 09:00:02 1.0 + 2013-01-01 09:00:03 3.0 + 2013-01-01 09:00:05 NaN + 2013-01-01 09:00:06 4.0 + + Rolling sum with forward looking windows with 2 observations. + + >>> indexer = pd.api.indexers.FixedForwardWindowIndexer(window_size=2) + >>> df.rolling(window=indexer, min_periods=1).sum() + B + 0 1.0 + 1 3.0 + 2 2.0 + 3 4.0 + 4 4.0 + + **min_periods** + + Rolling sum with a window length of 2 observations, but only needs a minimum + of 1 observation to calculate a value. + + >>> df.rolling(2, min_periods=1).sum() + B + 0 0.0 + 1 1.0 + 2 3.0 + 3 2.0 + 4 4.0 + + **center** + + Rolling sum with the result assigned to the center of the window index. + + >>> df.rolling(3, min_periods=1, center=True).sum() + B + 0 1.0 + 1 3.0 + 2 3.0 + 3 6.0 + 4 4.0 + + >>> df.rolling(3, min_periods=1, center=False).sum() + B + 0 0.0 + 1 1.0 + 2 3.0 + 3 3.0 + 4 6.0 + + **step** + + Rolling sum with a window length of 2 observations, minimum of 1 observation to + calculate a value, and a step of 2. + + >>> df.rolling(2, min_periods=1, step=2).sum() + B + 0 0.0 + 2 3.0 + 4 4.0 + + **win_type** + + Rolling sum with a window length of 2, using the Scipy ``'gaussian'`` + window type. ``std`` is required in the aggregation function. + + >>> df.rolling(2, win_type="gaussian").sum(std=3) + B + 0 NaN + 1 0.986207 + 2 2.958621 + 3 NaN + 4 NaN + + **on** + + Rolling sum with a window length of 2 days. + + >>> df = pd.DataFrame( + ... { + ... "A": [ + ... pd.to_datetime("2020-01-01"), + ... pd.to_datetime("2020-01-01"), + ... pd.to_datetime("2020-01-02"), + ... ], + ... "B": [1, 2, 3], + ... }, + ... index=pd.date_range("2020", periods=3), + ... ) + + >>> df + A B + 2020-01-01 2020-01-01 1 + 2020-01-02 2020-01-01 2 + 2020-01-03 2020-01-02 3 + + >>> df.rolling("2D", on="A").sum() + A B + 2020-01-01 2020-01-01 1.0 + 2020-01-02 2020-01-01 3.0 + 2020-01-03 2020-01-02 6.0 + """ + if win_type is not None: + return Window( + self, + window=window, + min_periods=min_periods, + center=center, + win_type=win_type, + on=on, + closed=closed, + step=step, + method=method, + ) + + return Rolling( + self, + window=window, + min_periods=min_periods, + center=center, + win_type=win_type, + on=on, + closed=closed, + step=step, + method=method, + ) + + @final + def expanding( + self, + min_periods: int = 1, + method: Literal["single", "table"] = "single", + ) -> Expanding: + """ + Provide expanding window calculations. + + An expanding window yields the value of an aggregation statistic with all + the data available up to that point in time. + + Parameters + ---------- + min_periods : int, default 1 + Minimum number of observations in window required to have a value; + otherwise, result is ``np.nan``. + + method : str {'single', 'table'}, default 'single' + Execute the rolling operation per single column or row (``'single'``) + or over the entire object (``'table'``). + + This argument is only implemented when specifying ``engine='numba'`` + in the method call. + + Returns + ------- + pandas.api.typing.Expanding + An instance of Expanding for further expanding window calculations, + e.g. using the ``sum`` method. + + See Also + -------- + rolling : Provides rolling window calculations. + ewm : Provides exponential weighted functions. + + Notes + ----- + See :ref:`Windowing Operations ` for further usage details + and examples. + + Examples + -------- + >>> df = pd.DataFrame({"B": [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + + **min_periods** + + Expanding sum with 1 vs 3 observations needed to calculate a value. + + >>> df.expanding(1).sum() + B + 0 0.0 + 1 1.0 + 2 3.0 + 3 3.0 + 4 7.0 + >>> df.expanding(3).sum() + B + 0 NaN + 1 NaN + 2 3.0 + 3 3.0 + 4 7.0 + """ + return Expanding(self, min_periods=min_periods, method=method) + + @final + @doc(ExponentialMovingWindow) + def ewm( + self, + com: float | None = None, + span: float | None = None, + halflife: float | TimedeltaConvertibleTypes | None = None, + alpha: float | None = None, + min_periods: int | None = 0, + adjust: bool = True, + ignore_na: bool = False, + times: np.ndarray | DataFrame | Series | None = None, + method: Literal["single", "table"] = "single", + ) -> ExponentialMovingWindow: + return ExponentialMovingWindow( + self, + com=com, + span=span, + halflife=halflife, + alpha=alpha, + min_periods=min_periods, + adjust=adjust, + ignore_na=ignore_na, + times=times, + method=method, + ) + + # ---------------------------------------------------------------------- + # Arithmetic Methods + + @final + def _inplace_method(self, other, op) -> Self: + """ + Wrap arithmetic method to operate inplace. + """ + result = op(self, other) + + # this makes sure that we are aligned like the input + # we are updating inplace + self._update_inplace(result.reindex_like(self)) + return self + + @final + def __iadd__(self, other) -> Self: + # error: Unsupported left operand type for + ("Type[NDFrame]") + return self._inplace_method(other, type(self).__add__) # type: ignore[operator] + + @final + def __isub__(self, other) -> Self: + # error: Unsupported left operand type for - ("Type[NDFrame]") + return self._inplace_method(other, type(self).__sub__) # type: ignore[operator] + + @final + def __imul__(self, other) -> Self: + # error: Unsupported left operand type for * ("Type[NDFrame]") + return self._inplace_method(other, type(self).__mul__) # type: ignore[operator] + + @final + def __itruediv__(self, other) -> Self: + # error: Unsupported left operand type for / ("Type[NDFrame]") + return self._inplace_method( + other, + type(self).__truediv__, # type: ignore[operator] + ) + + @final + def __ifloordiv__(self, other) -> Self: + # error: Unsupported left operand type for // ("Type[NDFrame]") + return self._inplace_method( + other, + type(self).__floordiv__, # type: ignore[operator] + ) + + @final + def __imod__(self, other) -> Self: + # error: Unsupported left operand type for % ("Type[NDFrame]") + return self._inplace_method(other, type(self).__mod__) # type: ignore[operator] + + @final + def __ipow__(self, other) -> Self: + # error: Unsupported left operand type for ** ("Type[NDFrame]") + return self._inplace_method(other, type(self).__pow__) # type: ignore[operator] + + @final + def __iand__(self, other) -> Self: + # error: Unsupported left operand type for & ("Type[NDFrame]") + return self._inplace_method(other, type(self).__and__) # type: ignore[operator] + + @final + def __ior__(self, other) -> Self: + return self._inplace_method(other, type(self).__or__) + + @final + def __ixor__(self, other) -> Self: + # error: Unsupported left operand type for ^ ("Type[NDFrame]") + return self._inplace_method(other, type(self).__xor__) # type: ignore[operator] + + # ---------------------------------------------------------------------- + # Misc methods + + @final + def _find_valid_index(self, *, how: str) -> Hashable: + """ + Retrieves the index of the first valid value. + + Parameters + ---------- + how : {'first', 'last'} + Use this parameter to change between the first or last valid index. + + Returns + ------- + idx_first_valid : type of index + """ + is_valid = self.notna().values + idxpos = find_valid_index(how=how, is_valid=is_valid) + if idxpos is None: + return None + return self.index[idxpos] + + @final + def first_valid_index(self) -> Hashable: + """ + Return index for first non-missing value or None, if no value is found. + + See the :ref:`User Guide ` for more information + on which values are considered missing. + + Returns + ------- + type of index + Index of first non-missing value. + + See Also + -------- + DataFrame.last_valid_index : Return index for last non-NA value or None, if + no non-NA value is found. + Series.last_valid_index : Return index for last non-NA value or None, if no + non-NA value is found. + DataFrame.isna : Detect missing values. + + Examples + -------- + For Series: + + >>> s = pd.Series([None, 3, 4]) + >>> s.first_valid_index() + 1 + >>> s.last_valid_index() + 2 + + >>> s = pd.Series([None, None]) + >>> print(s.first_valid_index()) + None + >>> print(s.last_valid_index()) + None + + If all elements in Series are NA/null, returns None. + + >>> s = pd.Series() + >>> print(s.first_valid_index()) + None + >>> print(s.last_valid_index()) + None + + If Series is empty, returns None. + + For DataFrame: + + >>> df = pd.DataFrame({"A": [None, None, 2], "B": [None, 3, 4]}) + >>> df + A B + 0 NaN NaN + 1 NaN 3.0 + 2 2.0 4.0 + >>> df.first_valid_index() + 1 + >>> df.last_valid_index() + 2 + + >>> df = pd.DataFrame({"A": [None, None, None], "B": [None, None, None]}) + >>> df + A B + 0 None None + 1 None None + 2 None None + >>> print(df.first_valid_index()) + None + >>> print(df.last_valid_index()) + None + + If all elements in DataFrame are NA/null, returns None. + + >>> df = pd.DataFrame() + >>> df + Empty DataFrame + Columns: [] + Index: [] + >>> print(df.first_valid_index()) + None + >>> print(df.last_valid_index()) + None + + If DataFrame is empty, returns None. + """ + return self._find_valid_index(how="first") + + @final + def last_valid_index(self) -> Hashable: + """ + Return index for last non-missing value or None, if no value is found. + + See the :ref:`User Guide ` for more information + on which values are considered missing. + + Returns + ------- + type of index + Index of last non-missing value. + + See Also + -------- + DataFrame.first_valid_index : Return index for first non-NA value or None, if + no non-NA value is found. + Series.first_valid_index : Return index for first non-NA value or None, if no + non-NA value is found. + DataFrame.isna : Detect missing values. + + Examples + -------- + For Series: + + >>> s = pd.Series([None, 3, 4]) + >>> s.first_valid_index() + 1 + >>> s.last_valid_index() + 2 + + >>> s = pd.Series([None, None]) + >>> print(s.first_valid_index()) + None + >>> print(s.last_valid_index()) + None + + If all elements in Series are NA/null, returns None. + + >>> s = pd.Series() + >>> print(s.first_valid_index()) + None + >>> print(s.last_valid_index()) + None + + If Series is empty, returns None. + + For DataFrame: + + >>> df = pd.DataFrame({"A": [None, None, 2], "B": [None, 3, 4]}) + >>> df + A B + 0 NaN NaN + 1 NaN 3.0 + 2 2.0 4.0 + >>> df.first_valid_index() + 1 + >>> df.last_valid_index() + 2 + + >>> df = pd.DataFrame({"A": [None, None, None], "B": [None, None, None]}) + >>> df + A B + 0 None None + 1 None None + 2 None None + >>> print(df.first_valid_index()) + None + >>> print(df.last_valid_index()) + None + + If all elements in DataFrame are NA/null, returns None. + + >>> df = pd.DataFrame() + >>> df + Empty DataFrame + Columns: [] + Index: [] + >>> print(df.first_valid_index()) + None + >>> print(df.last_valid_index()) + None + + If DataFrame is empty, returns None. + """ + return self._find_valid_index(how="last") + + +_num_doc = """ +{desc} + +Parameters +---------- +axis : {axis_descr} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + +skipna : bool, default True + Exclude NA/null values when computing the result. +numeric_only : bool, default False + Include only float, int, boolean columns. + +{min_count}\ +**kwargs + Additional keyword arguments to be passed to the function. + +Returns +------- +{name1} or scalar\ + + Value containing the calculation referenced in the description.\ +{see_also}\ +{examples} +""" + +_sum_prod_doc = """ +{desc} + +Parameters +---------- +axis : {axis_descr} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.{name} with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + .. versionadded:: 2.0.0 + +skipna : bool, default True + Exclude NA/null values when computing the result. +numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + +{min_count}\ +**kwargs + Additional keyword arguments to be passed to the function. + +Returns +------- +{name1} or scalar\ + + Value containing the calculation referenced in the description.\ +{see_also}\ +{examples} +""" + +_num_ddof_doc = """ +{desc} + +Parameters +---------- +axis : {axis_descr} + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.{name} with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + +skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. +ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. +numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. +**kwargs : + Additional keywords have no effect but might be accepted + for compatibility with NumPy. + +Returns +------- +{name1} or {name2} (if level specified) + {return_desc} + +See Also +-------- +{see_also}\ +{notes}\ +{examples} +""" + +_sem_see_also = """\ +scipy.stats.sem : Compute standard error of the mean. +{name2}.std : Return sample standard deviation over requested axis. +{name2}.var : Return unbiased variance over requested axis. +{name2}.mean : Return the mean of the values over the requested axis. +{name2}.median : Return the median of the values over the requested axis. +{name2}.mode : Return the mode(s) of the Series.""" + +_sem_return_desc = """\ +Unbiased standard error of the mean over requested axis.""" + +_std_see_also = """\ +numpy.std : Compute the standard deviation along the specified axis. +{name2}.var : Return unbiased variance over requested axis. +{name2}.sem : Return unbiased standard error of the mean over requested axis. +{name2}.mean : Return the mean of the values over the requested axis. +{name2}.median : Return the median of the values over the requested axis. +{name2}.mode : Return the mode(s) of the Series.""" + +_std_return_desc = """\ +Standard deviation over requested axis.""" + +_std_notes = """ + +Notes +----- +To have the same behaviour as `numpy.std`, use `ddof=0` (instead of the +default `ddof=1`)""" + +_std_examples = """ + +Examples +-------- +>>> df = pd.DataFrame({'person_id': [0, 1, 2, 3], +... 'age': [21, 25, 62, 43], +... 'height': [1.61, 1.87, 1.49, 2.01]} +... ).set_index('person_id') +>>> df + age height +person_id +0 21 1.61 +1 25 1.87 +2 62 1.49 +3 43 2.01 + +The standard deviation of the columns can be found as follows: + +>>> df.std() +age 18.786076 +height 0.237417 +dtype: float64 + +Alternatively, `ddof=0` can be set to normalize by N instead of N-1: + +>>> df.std(ddof=0) +age 16.269219 +height 0.205609 +dtype: float64""" + +_var_examples = """ + +Examples +-------- +>>> df = pd.DataFrame({'person_id': [0, 1, 2, 3], +... 'age': [21, 25, 62, 43], +... 'height': [1.61, 1.87, 1.49, 2.01]} +... ).set_index('person_id') +>>> df + age height +person_id +0 21 1.61 +1 25 1.87 +2 62 1.49 +3 43 2.01 + +>>> df.var() +age 352.916667 +height 0.056367 +dtype: float64 + +Alternatively, ``ddof=0`` can be set to normalize by N instead of N-1: + +>>> df.var(ddof=0) +age 264.687500 +height 0.042275 +dtype: float64""" + +_bool_doc = """ +{desc} + +Parameters +---------- +axis : {{0 or 'index', 1 or 'columns', None}}, default 0 + Indicate which axis or axes should be reduced. For `Series` this parameter + is unused and defaults to 0. + + * 0 / 'index' : reduce the index, return a Series whose index is the + original column labels. + * 1 / 'columns' : reduce the columns, return a Series whose index is the + original index. + * None : reduce all axes, return a scalar. + +bool_only : bool, default False + Include only boolean columns. Not implemented for Series. +skipna : bool, default True + Exclude NA/null values. If the entire row/column is NA and skipna is + True, then the result will be {empty_value}, as for an empty row/column. + If skipna is False, then NA are treated as True, because these are not + equal to zero. +**kwargs : any, default None + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + +Returns +------- +{name2} or {name1} + If axis=None, then a scalar boolean is returned. + Otherwise a Series is returned with index matching the index argument. + +{see_also} +{examples}""" + +_all_desc = """\ +Return whether all elements are True, potentially over an axis. + +Returns True unless there at least one element within a series or +along a Dataframe axis that is False or equivalent (e.g. zero or +empty).""" + +_all_examples = """\ +Examples +-------- +**Series** + +>>> pd.Series([True, True]).all() +True +>>> pd.Series([True, False]).all() +False +>>> pd.Series([], dtype="float64").all() +True +>>> pd.Series([np.nan]).all() +True +>>> pd.Series([np.nan]).all(skipna=False) +True + +**DataFrames** + +Create a DataFrame from a dictionary. + +>>> df = pd.DataFrame({'col1': [True, True], 'col2': [True, False]}) +>>> df + col1 col2 +0 True True +1 True False + +Default behaviour checks if values in each column all return True. + +>>> df.all() +col1 True +col2 False +dtype: bool + +Specify ``axis='columns'`` to check if values in each row all return True. + +>>> df.all(axis='columns') +0 True +1 False +dtype: bool + +Or ``axis=None`` for whether every value is True. + +>>> df.all(axis=None) +False +""" + +_all_see_also = """\ +See Also +-------- +Series.all : Return True if all elements are True. +DataFrame.any : Return True if one (or more) elements are True. +""" + +_cnum_pd_doc = """ +Return cumulative {desc} over a DataFrame or Series axis. + +Returns a DataFrame or Series of the same size containing the cumulative +{desc}. + +Parameters +---------- +axis : {{0 or 'index', 1 or 'columns'}}, default 0 + The index or the name of the axis. 0 is equivalent to None or 'index'. + For `Series` this parameter is unused and defaults to 0. +skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. +numeric_only : bool, default False + Include only float, int, boolean columns. +*args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + +Returns +------- +{name1} or {name2} + Return cumulative {desc} of {name1} or {name2}. + +See Also +-------- +core.window.expanding.Expanding.{accum_func_name} : Similar functionality + but ignores ``NaN`` values. +{name2}.{accum_func_name} : Return the {desc} over + {name2} axis. +{name2}.cummax : Return cumulative maximum over {name2} axis. +{name2}.cummin : Return cumulative minimum over {name2} axis. +{name2}.cumsum : Return cumulative sum over {name2} axis. +{name2}.cumprod : Return cumulative product over {name2} axis. + +{examples}""" + +_cnum_series_doc = """ +Return cumulative {desc} over a DataFrame or Series axis. + +Returns a DataFrame or Series of the same size containing the cumulative +{desc}. + +Parameters +---------- +axis : {{0 or 'index', 1 or 'columns'}}, default 0 + The index or the name of the axis. 0 is equivalent to None or 'index'. + For `Series` this parameter is unused and defaults to 0. +skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. +*args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + +Returns +------- +{name1} or {name2} + Return cumulative {desc} of {name1} or {name2}. + +See Also +-------- +core.window.expanding.Expanding.{accum_func_name} : Similar functionality + but ignores ``NaN`` values. +{name2}.{accum_func_name} : Return the {desc} over + {name2} axis. +{name2}.cummax : Return cumulative maximum over {name2} axis. +{name2}.cummin : Return cumulative minimum over {name2} axis. +{name2}.cumsum : Return cumulative sum over {name2} axis. +{name2}.cumprod : Return cumulative product over {name2} axis. + +{examples}""" + +_cummin_examples = """\ +Examples +-------- +**Series** + +>>> s = pd.Series([2, np.nan, 5, -1, 0]) +>>> s +0 2.0 +1 NaN +2 5.0 +3 -1.0 +4 0.0 +dtype: float64 + +By default, NA values are ignored. + +>>> s.cummin() +0 2.0 +1 NaN +2 2.0 +3 -1.0 +4 -1.0 +dtype: float64 + +To include NA values in the operation, use ``skipna=False`` + +>>> s.cummin(skipna=False) +0 2.0 +1 NaN +2 NaN +3 NaN +4 NaN +dtype: float64 + +**DataFrame** + +>>> df = pd.DataFrame([[2.0, 1.0], +... [3.0, np.nan], +... [1.0, 0.0]], +... columns=list('AB')) +>>> df + A B +0 2.0 1.0 +1 3.0 NaN +2 1.0 0.0 + +By default, iterates over rows and finds the minimum +in each column. This is equivalent to ``axis=None`` or ``axis='index'``. + +>>> df.cummin() + A B +0 2.0 1.0 +1 2.0 NaN +2 1.0 0.0 + +To iterate over columns and find the minimum in each row, +use ``axis=1`` + +>>> df.cummin(axis=1) + A B +0 2.0 1.0 +1 3.0 NaN +2 1.0 0.0 +""" + +_cumsum_examples = """\ +Examples +-------- +**Series** + +>>> s = pd.Series([2, np.nan, 5, -1, 0]) +>>> s +0 2.0 +1 NaN +2 5.0 +3 -1.0 +4 0.0 +dtype: float64 + +By default, NA values are ignored. + +>>> s.cumsum() +0 2.0 +1 NaN +2 7.0 +3 6.0 +4 6.0 +dtype: float64 + +To include NA values in the operation, use ``skipna=False`` + +>>> s.cumsum(skipna=False) +0 2.0 +1 NaN +2 NaN +3 NaN +4 NaN +dtype: float64 + +**DataFrame** + +>>> df = pd.DataFrame([[2.0, 1.0], +... [3.0, np.nan], +... [1.0, 0.0]], +... columns=list('AB')) +>>> df + A B +0 2.0 1.0 +1 3.0 NaN +2 1.0 0.0 + +By default, iterates over rows and finds the sum +in each column. This is equivalent to ``axis=None`` or ``axis='index'``. + +>>> df.cumsum() + A B +0 2.0 1.0 +1 5.0 NaN +2 6.0 1.0 + +To iterate over columns and find the sum in each row, +use ``axis=1`` + +>>> df.cumsum(axis=1) + A B +0 2.0 3.0 +1 3.0 NaN +2 1.0 1.0 +""" + +_cumprod_examples = """\ +Examples +-------- +**Series** + +>>> s = pd.Series([2, np.nan, 5, -1, 0]) +>>> s +0 2.0 +1 NaN +2 5.0 +3 -1.0 +4 0.0 +dtype: float64 + +By default, NA values are ignored. + +>>> s.cumprod() +0 2.0 +1 NaN +2 10.0 +3 -10.0 +4 -0.0 +dtype: float64 + +To include NA values in the operation, use ``skipna=False`` + +>>> s.cumprod(skipna=False) +0 2.0 +1 NaN +2 NaN +3 NaN +4 NaN +dtype: float64 + +**DataFrame** + +>>> df = pd.DataFrame([[2.0, 1.0], +... [3.0, np.nan], +... [1.0, 0.0]], +... columns=list('AB')) +>>> df + A B +0 2.0 1.0 +1 3.0 NaN +2 1.0 0.0 + +By default, iterates over rows and finds the product +in each column. This is equivalent to ``axis=None`` or ``axis='index'``. + +>>> df.cumprod() + A B +0 2.0 1.0 +1 6.0 NaN +2 6.0 0.0 + +To iterate over columns and find the product in each row, +use ``axis=1`` + +>>> df.cumprod(axis=1) + A B +0 2.0 2.0 +1 3.0 NaN +2 1.0 0.0 +""" + +_cummax_examples = """\ +Examples +-------- +**Series** + +>>> s = pd.Series([2, np.nan, 5, -1, 0]) +>>> s +0 2.0 +1 NaN +2 5.0 +3 -1.0 +4 0.0 +dtype: float64 + +By default, NA values are ignored. + +>>> s.cummax() +0 2.0 +1 NaN +2 5.0 +3 5.0 +4 5.0 +dtype: float64 + +To include NA values in the operation, use ``skipna=False`` + +>>> s.cummax(skipna=False) +0 2.0 +1 NaN +2 NaN +3 NaN +4 NaN +dtype: float64 + +**DataFrame** + +>>> df = pd.DataFrame([[2.0, 1.0], +... [3.0, np.nan], +... [1.0, 0.0]], +... columns=list('AB')) +>>> df + A B +0 2.0 1.0 +1 3.0 NaN +2 1.0 0.0 + +By default, iterates over rows and finds the maximum +in each column. This is equivalent to ``axis=None`` or ``axis='index'``. + +>>> df.cummax() + A B +0 2.0 1.0 +1 3.0 NaN +2 3.0 1.0 + +To iterate over columns and find the maximum in each row, +use ``axis=1`` + +>>> df.cummax(axis=1) + A B +0 2.0 2.0 +1 3.0 NaN +2 1.0 1.0 +""" + +_any_see_also = """\ +See Also +-------- +numpy.any : Numpy version of this method. +Series.any : Return whether any element is True. +Series.all : Return whether all elements are True. +DataFrame.any : Return whether any element is True over requested axis. +DataFrame.all : Return whether all elements are True over requested axis. +""" + +_any_desc = """\ +Return whether any element is True, potentially over an axis. + +Returns False unless there is at least one element within a series or +along a Dataframe axis that is True or equivalent (e.g. non-zero or +non-empty).""" + +_any_examples = """\ +Examples +-------- +**Series** + +For Series input, the output is a scalar indicating whether any element +is True. + +>>> pd.Series([False, False]).any() +False +>>> pd.Series([True, False]).any() +True +>>> pd.Series([], dtype="float64").any() +False +>>> pd.Series([np.nan]).any() +False +>>> pd.Series([np.nan]).any(skipna=False) +True + +**DataFrame** + +Whether each column contains at least one True element (the default). + +>>> df = pd.DataFrame({"A": [1, 2], "B": [0, 2], "C": [0, 0]}) +>>> df + A B C +0 1 0 0 +1 2 2 0 + +>>> df.any() +A True +B True +C False +dtype: bool + +Aggregating over the columns. + +>>> df = pd.DataFrame({"A": [True, False], "B": [1, 2]}) +>>> df + A B +0 True 1 +1 False 2 + +>>> df.any(axis='columns') +0 True +1 True +dtype: bool + +>>> df = pd.DataFrame({"A": [True, False], "B": [1, 0]}) +>>> df + A B +0 True 1 +1 False 0 + +>>> df.any(axis='columns') +0 True +1 False +dtype: bool + +Aggregating over the entire DataFrame with ``axis=None``. + +>>> df.any(axis=None) +True + +`any` for an empty DataFrame is an empty Series. + +>>> pd.DataFrame([]).any() +Series([], dtype: bool) +""" + +_shared_docs["stat_func_example"] = """ + +Examples +-------- +>>> idx = pd.MultiIndex.from_arrays([ +... ['warm', 'warm', 'cold', 'cold'], +... ['dog', 'falcon', 'fish', 'spider']], +... names=['blooded', 'animal']) +>>> s = pd.Series([4, 2, 0, 8], name='legs', index=idx) +>>> s +blooded animal +warm dog 4 + falcon 2 +cold fish 0 + spider 8 +Name: legs, dtype: int64 + +>>> s.{stat_func}() +{default_output}""" + +_sum_examples = _shared_docs["stat_func_example"].format( + stat_func="sum", verb="Sum", default_output=14, level_output_0=6, level_output_1=8 +) + +_sum_examples += """ + +By default, the sum of an empty or all-NA Series is ``0``. + +>>> pd.Series([], dtype="float64").sum() # min_count=0 is the default +0.0 + +This can be controlled with the ``min_count`` parameter. For example, if +you'd like the sum of an empty series to be NaN, pass ``min_count=1``. + +>>> pd.Series([], dtype="float64").sum(min_count=1) +nan + +Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and +empty series identically. + +>>> pd.Series([np.nan]).sum() +0.0 + +>>> pd.Series([np.nan]).sum(min_count=1) +nan""" + +_max_examples: str = _shared_docs["stat_func_example"].format( + stat_func="max", verb="Max", default_output=8, level_output_0=4, level_output_1=8 +) + +_min_examples: str = _shared_docs["stat_func_example"].format( + stat_func="min", verb="Min", default_output=0, level_output_0=2, level_output_1=0 +) + +_skew_see_also = """ + +See Also +-------- +Series.skew : Return unbiased skew over requested axis. +Series.var : Return unbiased variance over requested axis. +Series.std : Return unbiased standard deviation over requested axis.""" + +_stat_func_see_also = """ + +See Also +-------- +Series.sum : Return the sum. +Series.min : Return the minimum. +Series.max : Return the maximum. +Series.idxmin : Return the index of the minimum. +Series.idxmax : Return the index of the maximum. +DataFrame.sum : Return the sum over the requested axis. +DataFrame.min : Return the minimum over the requested axis. +DataFrame.max : Return the maximum over the requested axis. +DataFrame.idxmin : Return the index of the minimum over the requested axis. +DataFrame.idxmax : Return the index of the maximum over the requested axis.""" + +_prod_examples = """ + +Examples +-------- +By default, the product of an empty or all-NA Series is ``1`` + +>>> pd.Series([], dtype="float64").prod() +1.0 + +This can be controlled with the ``min_count`` parameter + +>>> pd.Series([], dtype="float64").prod(min_count=1) +nan + +Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and +empty series identically. + +>>> pd.Series([np.nan]).prod() +1.0 + +>>> pd.Series([np.nan]).prod(min_count=1) +nan""" + +_min_count_stub = """\ +min_count : int, default 0 + The required number of valid values to perform the operation. If fewer than + ``min_count`` non-NA values are present the result will be NA. +""" + + +def make_doc(name: str, ndim: int) -> str: + """ + Generate the docstring for a Series/DataFrame reduction. + """ + if ndim == 1: + name1 = "scalar" + name2 = "Series" + axis_descr = "{index (0)}" + else: + name1 = "Series" + name2 = "DataFrame" + axis_descr = "{index (0), columns (1)}" + + if name == "any": + base_doc = _bool_doc + desc = _any_desc + see_also = _any_see_also + examples = _any_examples + kwargs = {"empty_value": "False"} + elif name == "all": + base_doc = _bool_doc + desc = _all_desc + see_also = _all_see_also + examples = _all_examples + kwargs = {"empty_value": "True"} + elif name == "min": + base_doc = _num_doc + desc = ( + "Return the minimum of the values over the requested axis.\n\n" + "If you want the *index* of the minimum, use ``idxmin``. This is " + "the equivalent of the ``numpy.ndarray`` method ``argmin``." + ) + see_also = _stat_func_see_also + examples = _min_examples + kwargs = {"min_count": ""} + elif name == "max": + base_doc = _num_doc + desc = ( + "Return the maximum of the values over the requested axis.\n\n" + "If you want the *index* of the maximum, use ``idxmax``. This is " + "the equivalent of the ``numpy.ndarray`` method ``argmax``." + ) + see_also = _stat_func_see_also + examples = _max_examples + kwargs = {"min_count": ""} + + elif name == "sum": + base_doc = _sum_prod_doc + desc = ( + "Return the sum of the values over the requested axis.\n\n" + "This is equivalent to the method ``numpy.sum``." + ) + see_also = _stat_func_see_also + examples = _sum_examples + kwargs = {"min_count": _min_count_stub} + + elif name == "prod": + base_doc = _sum_prod_doc + desc = "Return the product of the values over the requested axis." + see_also = _stat_func_see_also + examples = _prod_examples + kwargs = {"min_count": _min_count_stub} + + elif name == "median": + base_doc = _num_doc + desc = "Return the median of the values over the requested axis." + see_also = _stat_func_see_also + examples = """ + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.median() + 2.0 + + With a DataFrame + + >>> df = pd.DataFrame({'a': [1, 2], 'b': [2, 3]}, index=['tiger', 'zebra']) + >>> df + a b + tiger 1 2 + zebra 2 3 + >>> df.median() + a 1.5 + b 2.5 + dtype: float64 + + Using axis=1 + + >>> df.median(axis=1) + tiger 1.5 + zebra 2.5 + dtype: float64 + + In this case, `numeric_only` should be set to `True` + to avoid getting an error. + + >>> df = pd.DataFrame({'a': [1, 2], 'b': ['T', 'Z']}, + ... index=['tiger', 'zebra']) + >>> df.median(numeric_only=True) + a 1.5 + dtype: float64""" + kwargs = {"min_count": ""} + + elif name == "mean": + base_doc = _num_doc + desc = "Return the mean of the values over the requested axis." + see_also = _stat_func_see_also + examples = """ + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.mean() + 2.0 + + With a DataFrame + + >>> df = pd.DataFrame({'a': [1, 2], 'b': [2, 3]}, index=['tiger', 'zebra']) + >>> df + a b + tiger 1 2 + zebra 2 3 + >>> df.mean() + a 1.5 + b 2.5 + dtype: float64 + + Using axis=1 + + >>> df.mean(axis=1) + tiger 1.5 + zebra 2.5 + dtype: float64 + + In this case, `numeric_only` should be set to `True` to avoid + getting an error. + + >>> df = pd.DataFrame({'a': [1, 2], 'b': ['T', 'Z']}, + ... index=['tiger', 'zebra']) + >>> df.mean(numeric_only=True) + a 1.5 + dtype: float64""" + kwargs = {"min_count": ""} + + elif name == "var": + base_doc = _num_ddof_doc + desc = ( + "Return unbiased variance over requested axis.\n\nNormalized by " + "N-1 by default. This can be changed using the ddof argument." + ) + examples = _var_examples + see_also = "" + kwargs = {"notes": ""} + + elif name == "std": + base_doc = _num_ddof_doc + desc = ( + "Return sample standard deviation over requested axis." + "\n\nNormalized by N-1 by default. This can be changed using the " + "ddof argument." + ) + examples = _std_examples + see_also = _std_see_also.format(name2=name2) + kwargs = {"notes": "", "return_desc": _std_return_desc} + + elif name == "sem": + base_doc = _num_ddof_doc + desc = ( + "Return unbiased standard error of the mean over requested " + "axis.\n\nNormalized by N-1 by default. This can be changed " + "using the ddof argument" + ) + examples = """ + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> round(s.sem(), 6) + 0.57735 + + With a DataFrame + + >>> df = pd.DataFrame({'a': [1, 2], 'b': [2, 3]}, index=['tiger', 'zebra']) + >>> df + a b + tiger 1 2 + zebra 2 3 + >>> df.sem() + a 0.5 + b 0.5 + dtype: float64 + + Using axis=1 + + >>> df.sem(axis=1) + tiger 0.5 + zebra 0.5 + dtype: float64 + + In this case, `numeric_only` should be set to `True` + to avoid getting an error. + + >>> df = pd.DataFrame({'a': [1, 2], 'b': ['T', 'Z']}, + ... index=['tiger', 'zebra']) + >>> df.sem(numeric_only=True) + a 0.5 + dtype: float64""" + see_also = _sem_see_also.format(name2=name2) + kwargs = {"notes": "", "return_desc": _sem_return_desc} + + elif name == "skew": + base_doc = _num_doc + desc = "Return unbiased skew over requested axis.\n\nNormalized by N-1." + see_also = _skew_see_also + examples = """ + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.skew() + 0.0 + + With a DataFrame + + >>> df = pd.DataFrame({'a': [1, 2, 3], 'b': [2, 3, 4], 'c': [1, 3, 5]}, + ... index=['tiger', 'zebra', 'cow']) + >>> df + a b c + tiger 1 2 1 + zebra 2 3 3 + cow 3 4 5 + >>> df.skew() + a 0.0 + b 0.0 + c 0.0 + dtype: float64 + + Using axis=1 + + >>> df.skew(axis=1) + tiger 1.732051 + zebra -1.732051 + cow 0.000000 + dtype: float64 + + In this case, `numeric_only` should be set to `True` to avoid + getting an error. + + >>> df = pd.DataFrame({'a': [1, 2, 3], 'b': ['T', 'Z', 'X']}, + ... index=['tiger', 'zebra', 'cow']) + >>> df.skew(numeric_only=True) + a 0.0 + dtype: float64""" + kwargs = {"min_count": ""} + + elif name == "kurt": + base_doc = _num_doc + desc = ( + "Return unbiased kurtosis over requested axis.\n\n" + "Kurtosis obtained using Fisher's definition of\n" + "kurtosis (kurtosis of normal == 0.0). Normalized " + "by N-1." + ) + see_also = "" + examples = """ + + Examples + -------- + >>> s = pd.Series([1, 2, 2, 3], index=['cat', 'dog', 'dog', 'mouse']) + >>> s + cat 1 + dog 2 + dog 2 + mouse 3 + dtype: int64 + >>> s.kurt() + 1.5 + + With a DataFrame + + >>> df = pd.DataFrame({'a': [1, 2, 2, 3], 'b': [3, 4, 4, 4]}, + ... index=['cat', 'dog', 'dog', 'mouse']) + >>> df + a b + cat 1 3 + dog 2 4 + dog 2 4 + mouse 3 4 + >>> df.kurt() + a 1.5 + b 4.0 + dtype: float64 + + With axis=None + + >>> df.kurt(axis=None) + -0.9886927196984727 + + Using axis=1 + + >>> df = pd.DataFrame({'a': [1, 2], 'b': [3, 4], 'c': [3, 4], 'd': [1, 2]}, + ... index=['cat', 'dog']) + >>> df.kurt(axis=1) + cat -6.0 + dog -6.0 + dtype: float64""" + kwargs = {"min_count": ""} + + elif name == "cumsum": + if ndim == 1: + base_doc = _cnum_series_doc + else: + base_doc = _cnum_pd_doc + + desc = "sum" + see_also = "" + examples = _cumsum_examples + kwargs = {"accum_func_name": "sum"} + + elif name == "cumprod": + if ndim == 1: + base_doc = _cnum_series_doc + else: + base_doc = _cnum_pd_doc + + desc = "product" + see_also = "" + examples = _cumprod_examples + kwargs = {"accum_func_name": "prod"} + + elif name == "cummin": + if ndim == 1: + base_doc = _cnum_series_doc + else: + base_doc = _cnum_pd_doc + + desc = "minimum" + see_also = "" + examples = _cummin_examples + kwargs = {"accum_func_name": "min"} + + elif name == "cummax": + if ndim == 1: + base_doc = _cnum_series_doc + else: + base_doc = _cnum_pd_doc + + desc = "maximum" + see_also = "" + examples = _cummax_examples + kwargs = {"accum_func_name": "max"} + + else: + raise NotImplementedError + + docstr = base_doc.format( + desc=desc, + name=name, + name1=name1, + name2=name2, + axis_descr=axis_descr, + see_also=see_also, + examples=examples, + **kwargs, + ) + return docstr diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a92cea53615e763c35198e5361bd78ae2f8fb4 --- /dev/null +++ b/pandas/core/indexing.py @@ -0,0 +1,2796 @@ +from __future__ import annotations + +from contextlib import suppress +import sys +from typing import ( + TYPE_CHECKING, + Any, + Self, + cast, + final, +) +import warnings + +import numpy as np + +from pandas._libs.indexing import NDFrameIndexerBase +from pandas._libs.lib import item_from_zerodim +from pandas.compat import CHAINED_WARNING_DISABLED +from pandas.compat._constants import REF_COUNT_IDX +from pandas.errors import ( + AbstractMethodError, + ChainedAssignmentError, + IndexingError, + InvalidIndexError, + LossySetitemError, +) +from pandas.errors.cow import _chained_assignment_msg +from pandas.util._decorators import ( + doc, +) + +from pandas.core.dtypes.cast import ( + can_hold_element, + maybe_promote, +) +from pandas.core.dtypes.common import ( + is_array_like, + is_bool_dtype, + is_hashable, + is_integer, + is_iterator, + is_list_like, + is_numeric_dtype, + is_object_dtype, + is_scalar, + is_sequence, +) +from pandas.core.dtypes.concat import concat_compat +from pandas.core.dtypes.dtypes import ExtensionDtype +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCSeries, +) +from pandas.core.dtypes.missing import ( + construct_1d_array_from_inferred_fill_value, + infer_fill_value, + is_valid_na_for_dtype, + isna, + na_value_for_dtype, +) + +from pandas.core import algorithms as algos +import pandas.core.common as com +from pandas.core.construction import ( + array as pd_array, + extract_array, +) +from pandas.core.indexers import ( + check_array_indexer, + is_list_like_indexer, + is_scalar_indexer, + length_of_indexer, +) +from pandas.core.indexes.api import ( + Index, + MultiIndex, +) + +if TYPE_CHECKING: + from collections.abc import ( + Hashable, + Sequence, + ) + + from pandas._typing import ( + Axis, + AxisInt, + T, + npt, + ) + + from pandas import ( + DataFrame, + Series, + ) + +# "null slice" +_NS = slice(None, None) +_one_ellipsis_message = "indexer may only contain one '...' entry" + + +# the public IndexSlicerMaker +class _IndexSlice: + """ + Create an object to more easily perform multi-index slicing. + + See Also + -------- + MultiIndex.remove_unused_levels : New MultiIndex with no unused levels. + + Notes + ----- + See :ref:`Defined Levels ` + for further info on slicing a MultiIndex. + + Examples + -------- + >>> midx = pd.MultiIndex.from_product([["A0", "A1"], ["B0", "B1", "B2", "B3"]]) + >>> columns = ["foo", "bar"] + >>> dfmi = pd.DataFrame( + ... np.arange(16).reshape((len(midx), len(columns))), + ... index=midx, + ... columns=columns, + ... ) + + Using the default slice command: + + >>> dfmi.loc[(slice(None), slice("B0", "B1")), :] + foo bar + A0 B0 0 1 + B1 2 3 + A1 B0 8 9 + B1 10 11 + + Using the IndexSlice class for a more intuitive command: + + >>> idx = pd.IndexSlice + >>> dfmi.loc[idx[:, "B0":"B1"], :] + foo bar + A0 B0 0 1 + B1 2 3 + A1 B0 8 9 + B1 10 11 + """ + + def __getitem__(self, arg): + return arg + + +IndexSlice = _IndexSlice() +IndexSlice.__module__ = "pandas" + + +class IndexingMixin: + """ + Mixin for adding .loc/.iloc/.at/.iat to Dataframes and Series. + """ + + @property + def iloc(self) -> _iLocIndexer: + """ + Purely integer-location based indexing for selection by position. + + .. versionchanged:: 3.0 + + Callables which return a tuple are deprecated as input. + + ``.iloc[]`` is primarily integer position based (from ``0`` to + ``length-1`` of the axis), but may also be used with a boolean + array. + + Allowed inputs are: + + - An integer, e.g. ``5``. + - A list or array of integers, e.g. ``[4, 3, 0]``. + - A slice object with ints, e.g. ``1:7``. + - A boolean array. + - A ``callable`` function with one argument (the calling Series or + DataFrame) and that returns valid output for indexing (one of the above). + This is useful in method chains, when you don't have a reference to the + calling object, but would like to base your selection on + some value. + - A tuple of row and column indexes. The tuple elements consist of one of the + above inputs, e.g. ``(0, 1)``. + + ``.iloc`` will raise ``IndexError`` if a requested indexer is + out-of-bounds, except *slice* indexers which allow out-of-bounds + indexing (this conforms with python/numpy *slice* semantics). + + See more at :ref:`Selection by Position `. + + See Also + -------- + DataFrame.iat : Fast integer location scalar accessor. + DataFrame.loc : Purely label-location based indexer for selection by label. + Series.iloc : Purely integer-location based indexing for + selection by position. + + Examples + -------- + >>> mydict = [ + ... {"a": 1, "b": 2, "c": 3, "d": 4}, + ... {"a": 100, "b": 200, "c": 300, "d": 400}, + ... {"a": 1000, "b": 2000, "c": 3000, "d": 4000}, + ... ] + >>> df = pd.DataFrame(mydict) + >>> df + a b c d + 0 1 2 3 4 + 1 100 200 300 400 + 2 1000 2000 3000 4000 + + **Indexing just the rows** + + With a scalar integer. + + >>> type(df.iloc[0]) + + >>> df.iloc[0] + a 1 + b 2 + c 3 + d 4 + Name: 0, dtype: int64 + + With a list of integers. + + >>> df.iloc[[0]] + a b c d + 0 1 2 3 4 + >>> type(df.iloc[[0]]) + + + >>> df.iloc[[0, 1]] + a b c d + 0 1 2 3 4 + 1 100 200 300 400 + + With a `slice` object. + + >>> df.iloc[:3] + a b c d + 0 1 2 3 4 + 1 100 200 300 400 + 2 1000 2000 3000 4000 + + With a boolean mask the same length as the index. + + >>> df.iloc[[True, False, True]] + a b c d + 0 1 2 3 4 + 2 1000 2000 3000 4000 + + With a callable, useful in method chains. The `x` passed + to the ``lambda`` is the DataFrame being sliced. This selects + the rows whose index label even. + + >>> df.iloc[lambda x: x.index % 2 == 0] + a b c d + 0 1 2 3 4 + 2 1000 2000 3000 4000 + + **Indexing both axes** + + You can mix the indexer types for the index and columns. Use ``:`` to + select the entire axis. + + With scalar integers. + + >>> df.iloc[0, 1] + np.int64(2) + + With lists of integers. + + >>> df.iloc[[0, 2], [1, 3]] + b d + 0 2 4 + 2 2000 4000 + + With `slice` objects. + + >>> df.iloc[1:3, 0:3] + a b c + 1 100 200 300 + 2 1000 2000 3000 + + With a boolean array whose length matches the columns. + + >>> df.iloc[:, [True, False, True, False]] + a c + 0 1 3 + 1 100 300 + 2 1000 3000 + + With a callable function that expects the Series or DataFrame. + + >>> df.iloc[:, lambda df: [0, 2]] + a c + 0 1 3 + 1 100 300 + 2 1000 3000 + """ + return _iLocIndexer("iloc", self) + + @property + def loc(self) -> _LocIndexer: + """ + Access a group of rows and columns by label(s) or a boolean array. + + ``.loc[]`` is primarily label based, but may also be used with a + boolean array. + + Allowed inputs are: + + - A single label, e.g. ``5`` or ``'a'``, (note that ``5`` is + interpreted as a *label* of the index, and **never** as an + integer position along the index). + - A list or array of labels, e.g. ``['a', 'b', 'c']``. + - A slice object with labels, e.g. ``'a':'f'``. + + .. warning:: Note that contrary to usual python slices, **both** the + start and the stop are included + + - A boolean array of the same length as the axis being sliced, + e.g. ``[True, False, True]``. + - An alignable boolean Series. The index of the key will be aligned before + masking. + - An alignable Index. The Index of the returned selection will be the input. + - A ``callable`` function with one argument (the calling Series or + DataFrame) and that returns valid output for indexing (one of the above) + + See more at :ref:`Selection by Label `. + + Raises + ------ + KeyError + If any items are not found. + IndexingError + If an indexed key is passed and its index is unalignable to the frame index. + + See Also + -------- + DataFrame.at : Access a single value for a row/column label pair. + DataFrame.iloc : Access group of rows and columns by integer position(s). + DataFrame.xs : Returns a cross-section (row(s) or column(s)) from the + Series/DataFrame. + Series.loc : Access group of values using labels. + + Examples + -------- + **Getting values** + + >>> df = pd.DataFrame( + ... [[1, 2], [4, 5], [7, 8]], + ... index=["cobra", "viper", "sidewinder"], + ... columns=["max_speed", "shield"], + ... ) + >>> df + max_speed shield + cobra 1 2 + viper 4 5 + sidewinder 7 8 + + Single label. Note this returns the row as a Series. + + >>> df.loc["viper"] + max_speed 4 + shield 5 + Name: viper, dtype: int64 + + List of labels. Note using ``[[]]`` returns a DataFrame. + + >>> df.loc[["viper", "sidewinder"]] + max_speed shield + viper 4 5 + sidewinder 7 8 + + Single label for row and column + + >>> df.loc["cobra", "shield"] + np.int64(2) + + Slice with labels for row and single label for column. As mentioned + above, note that both the start and stop of the slice are included. + + >>> df.loc["cobra":"viper", "max_speed"] + cobra 1 + viper 4 + Name: max_speed, dtype: int64 + + Boolean list with the same length as the row axis + + >>> df.loc[[False, False, True]] + max_speed shield + sidewinder 7 8 + + Alignable boolean Series: + + >>> df.loc[ + ... pd.Series([False, True, False], index=["viper", "sidewinder", "cobra"]) + ... ] + max_speed shield + sidewinder 7 8 + + Index (same behavior as ``df.reindex``) + + >>> df.loc[pd.Index(["cobra", "viper"], name="foo")] + max_speed shield + foo + cobra 1 2 + viper 4 5 + + Conditional that returns a boolean Series + + >>> df.loc[df["shield"] > 6] + max_speed shield + sidewinder 7 8 + + Conditional that returns a boolean Series with column labels specified + + >>> df.loc[df["shield"] > 6, ["max_speed"]] + max_speed + sidewinder 7 + + Multiple conditional using ``&`` that returns a boolean Series + + >>> df.loc[(df["max_speed"] > 1) & (df["shield"] < 8)] + max_speed shield + viper 4 5 + + Multiple conditional using ``|`` that returns a boolean Series + + >>> df.loc[(df["max_speed"] > 4) | (df["shield"] < 5)] + max_speed shield + cobra 1 2 + sidewinder 7 8 + + Please ensure that each condition is wrapped in parentheses ``()``. + See the :ref:`user guide` + for more details and explanations of Boolean indexing. + + .. note:: + If you find yourself using 3 or more conditionals in ``.loc[]``, + consider using :ref:`advanced indexing`. + + See below for using ``.loc[]`` on MultiIndex DataFrames. + + Callable that returns a boolean Series + + >>> df.loc[lambda df: df["shield"] == 8] + max_speed shield + sidewinder 7 8 + + **Setting values** + + Set value for all items matching the list of labels + + >>> df.loc[["viper", "sidewinder"], ["shield"]] = 50 + >>> df + max_speed shield + cobra 1 2 + viper 4 50 + sidewinder 7 50 + + Set value for an entire row + + >>> df.loc["cobra"] = 10 + >>> df + max_speed shield + cobra 10 10 + viper 4 50 + sidewinder 7 50 + + Set value for an entire column + + >>> df.loc[:, "max_speed"] = 30 + >>> df + max_speed shield + cobra 30 10 + viper 30 50 + sidewinder 30 50 + + Set value for rows matching callable condition + + >>> df.loc[df["shield"] > 35] = 0 + >>> df + max_speed shield + cobra 30 10 + viper 0 0 + sidewinder 0 0 + + Add value matching location + + >>> df.loc["viper", "shield"] += 5 + >>> df + max_speed shield + cobra 30 10 + viper 0 5 + sidewinder 0 0 + + Setting using a ``Series`` or a ``DataFrame`` sets the values matching the + index labels, not the index positions. + + >>> shuffled_df = df.loc[["viper", "cobra", "sidewinder"]] + >>> df.loc[:] += shuffled_df + >>> df + max_speed shield + cobra 60 20 + viper 0 10 + sidewinder 0 0 + + **Getting values on a DataFrame with an index that has integer labels** + + Another example using integers for the index + + >>> df = pd.DataFrame( + ... [[1, 2], [4, 5], [7, 8]], + ... index=[7, 8, 9], + ... columns=["max_speed", "shield"], + ... ) + >>> df + max_speed shield + 7 1 2 + 8 4 5 + 9 7 8 + + Slice with integer labels for rows. As mentioned above, note that both + the start and stop of the slice are included. + + >>> df.loc[7:9] + max_speed shield + 7 1 2 + 8 4 5 + 9 7 8 + + **Getting values with a MultiIndex** + + A number of examples using a DataFrame with a MultiIndex + + >>> tuples = [ + ... ("cobra", "mark i"), + ... ("cobra", "mark ii"), + ... ("sidewinder", "mark i"), + ... ("sidewinder", "mark ii"), + ... ("viper", "mark ii"), + ... ("viper", "mark iii"), + ... ] + >>> index = pd.MultiIndex.from_tuples(tuples) + >>> values = [[12, 2], [0, 4], [10, 20], [1, 4], [7, 1], [16, 36]] + >>> df = pd.DataFrame(values, columns=["max_speed", "shield"], index=index) + >>> df + max_speed shield + cobra mark i 12 2 + mark ii 0 4 + sidewinder mark i 10 20 + mark ii 1 4 + viper mark ii 7 1 + mark iii 16 36 + + Single label. Note this returns a DataFrame with a single index. + + >>> df.loc["cobra"] + max_speed shield + mark i 12 2 + mark ii 0 4 + + Single index tuple. Note this returns a Series. + + >>> df.loc[("cobra", "mark ii")] + max_speed 0 + shield 4 + Name: (cobra, mark ii), dtype: int64 + + Single label for row and column. Similar to passing in a tuple, this + returns a Series. + + >>> df.loc["cobra", "mark i"] + max_speed 12 + shield 2 + Name: (cobra, mark i), dtype: int64 + + Single tuple. Note using ``[[]]`` returns a DataFrame. + + >>> df.loc[[("cobra", "mark ii")]] + max_speed shield + cobra mark ii 0 4 + + Single tuple for the index with a single label for the column + + >>> df.loc[("cobra", "mark i"), "shield"] + np.int64(2) + + Slice from index tuple to single label + + >>> df.loc[("cobra", "mark i") : "viper"] + max_speed shield + cobra mark i 12 2 + mark ii 0 4 + sidewinder mark i 10 20 + mark ii 1 4 + viper mark ii 7 1 + mark iii 16 36 + + Slice from index tuple to index tuple + + >>> df.loc[("cobra", "mark i") : ("viper", "mark ii")] + max_speed shield + cobra mark i 12 2 + mark ii 0 4 + sidewinder mark i 10 20 + mark ii 1 4 + viper mark ii 7 1 + + Please see the :ref:`user guide` + for more details and explanations of advanced indexing. + + **Assignment with Series** + + When assigning a Series to .loc[row_indexer, col_indexer], pandas aligns + the Series by index labels, not by order or position. + + Series assignment with .loc and index alignment: + + >>> df = pd.DataFrame({"A": [1, 2, 3]}, index=[0, 1, 2]) + >>> s = pd.Series([10, 20], index=[1, 0]) # Note reversed order + >>> df.loc[:, "B"] = s # Aligns by index, not order + >>> df + A B + 0 1 20.0 + 1 2 10.0 + 2 3 NaN + """ + return _LocIndexer("loc", self) + + @property + def at(self) -> _AtIndexer: + """ + Access a single value for a row/column label pair. + + Similar to ``loc``, in that both provide label-based lookups. Use + ``at`` if you only need to get or set a single value in a DataFrame + or Series. + + Raises + ------ + KeyError + If getting a value and 'label' does not exist in a DataFrame or Series. + + ValueError + If row/column label pair is not a tuple or if any label + from the pair is not a scalar for DataFrame. + If label is list-like (*excluding* NamedTuple) for Series. + + See Also + -------- + DataFrame.at : Access a single value for a row/column pair by label. + DataFrame.iat : Access a single value for a row/column pair by integer + position. + DataFrame.loc : Access a group of rows and columns by label(s). + DataFrame.iloc : Access a group of rows and columns by integer + position(s). + Series.at : Access a single value by label. + Series.iat : Access a single value by integer position. + Series.loc : Access a group of rows by label(s). + Series.iloc : Access a group of rows by integer position(s). + + Notes + ----- + See :ref:`Fast scalar value getting and setting ` + for more details. + + Examples + -------- + >>> df = pd.DataFrame( + ... [[0, 2, 3], [0, 4, 1], [10, 20, 30]], + ... index=[4, 5, 6], + ... columns=["A", "B", "C"], + ... ) + >>> df + A B C + 4 0 2 3 + 5 0 4 1 + 6 10 20 30 + + Get value at specified row/column pair + + >>> df.at[4, "B"] + np.int64(2) + + Set value at specified row/column pair + + >>> df.at[4, "B"] = 10 + >>> df.at[4, "B"] + np.int64(10) + + Get value within a Series + + >>> df.loc[5].at["B"] + np.int64(4) + """ + return _AtIndexer("at", self) + + @property + def iat(self) -> _iAtIndexer: + """ + Access a single value for a row/column pair by integer position. + + Similar to ``iloc``, in that both provide integer-based lookups. Use + ``iat`` if you only need to get or set a single value in a DataFrame + or Series. + + Raises + ------ + IndexError + When integer position is out of bounds. + + See Also + -------- + DataFrame.at : Access a single value for a row/column label pair. + DataFrame.loc : Access a group of rows and columns by label(s). + DataFrame.iloc : Access a group of rows and columns by integer position(s). + + Examples + -------- + >>> df = pd.DataFrame( + ... [[0, 2, 3], [0, 4, 1], [10, 20, 30]], columns=["A", "B", "C"] + ... ) + >>> df + A B C + 0 0 2 3 + 1 0 4 1 + 2 10 20 30 + + Get value at specified row/column pair + + >>> df.iat[1, 2] + np.int64(1) + + Set value at specified row/column pair + + >>> df.iat[1, 2] = 10 + >>> df.iat[1, 2] + np.int64(10) + + Get value within a series + + >>> df.loc[0].iat[1] + np.int64(2) + """ + return _iAtIndexer("iat", self) + + +class _LocationIndexer(NDFrameIndexerBase): + _valid_types: str + axis: AxisInt | None = None + + # sub-classes need to set _takeable + _takeable: bool + + @final + def __call__(self, axis: Axis | None = None) -> Self: + # we need to return a copy of ourselves + new_self = type(self)(self.name, self.obj) + + if axis is not None: + axis_int_none = self.obj._get_axis_number(axis) + else: + axis_int_none = axis + new_self.axis = axis_int_none + return new_self + + def _get_setitem_indexer(self, key): + """ + Convert a potentially-label-based key into a positional indexer. + """ + if self.name == "loc": + # always holds here bc iloc overrides _get_setitem_indexer + self._ensure_listlike_indexer(key, axis=self.axis) + + if isinstance(key, tuple): + for x in key: + check_dict_or_set_indexers(x) + + if self.axis is not None: + key = _tupleize_axis_indexer(self.ndim, self.axis, key) + + ax = self.obj._get_axis(0) + + if ( + isinstance(ax, MultiIndex) + and self.name != "iloc" + and is_hashable(key, allow_slice=False) + ): + with suppress(KeyError, InvalidIndexError): + # TypeError e.g. passed a bool + return ax.get_loc(key) + + if isinstance(key, tuple): + with suppress(IndexingError): + # suppress "Too many indexers" + return self._convert_tuple(key) + + if isinstance(key, range): + # GH#45479 test_loc_setitem_range_key + key = list(key) + + return self._convert_to_indexer(key, axis=0) + + @final + def _maybe_mask_setitem_value(self, indexer, value): + """ + If we have obj.iloc[mask] = series_or_frame and series_or_frame has the + same length as obj, we treat this as obj.iloc[mask] = series_or_frame[mask], + similar to Series.__setitem__. + + Note this is only for loc, not iloc. + """ + + if ( + isinstance(indexer, tuple) + and len(indexer) == 2 + and isinstance(value, (ABCSeries, ABCDataFrame)) + ): + pi, icols = indexer + ndim = value.ndim + if com.is_bool_indexer(pi) and len(value) == len(pi): + newkey = pi.nonzero()[0] + + if is_scalar_indexer(icols, self.ndim - 1) and ndim == 1: + # e.g. test_loc_setitem_boolean_mask_allfalse + if len(newkey) == 0: + value = value.iloc[:0] + else: + # test_loc_setitem_ndframe_values_alignment + value = self.obj.iloc._align_series(indexer, value) + indexer = (newkey, icols) + + elif ( + isinstance(icols, np.ndarray) + and icols.dtype.kind == "i" + and len(icols) == 1 + ): + if ndim == 1: + # We implicitly broadcast, though numpy does not, see + # github.com/pandas-dev/pandas/pull/45501#discussion_r789071825 + # test_loc_setitem_ndframe_values_alignment + value = self.obj.iloc._align_series(indexer, value) + indexer = (newkey, icols) + + elif ndim == 2 and value.shape[1] == 1: + if len(newkey) == 0: + value = value.iloc[:0] + else: + # test_loc_setitem_ndframe_values_alignment + value = self.obj.iloc._align_frame(indexer, value) + indexer = (newkey, icols) + elif com.is_bool_indexer(indexer): + indexer = indexer.nonzero()[0] + + return indexer, value + + @final + def _ensure_listlike_indexer(self, key, axis=None, value=None) -> None: + """ + Ensure that a list-like of column labels are all present by adding them if + they do not already exist. + + Parameters + ---------- + key : list-like of column labels + Target labels. + axis : key axis if known + """ + column_axis = 1 + + # column only exists in 2-dimensional DataFrame + if self.ndim != 2: + return + + if isinstance(key, tuple) and len(key) > 1: + # key may be a tuple if we are .loc + # if length of key is > 1 set key to column part + # unless axis is already specified, then go with that + if axis is None: + axis = column_axis + key = key[axis] + + if ( + axis == column_axis + and not isinstance(self.obj.columns, MultiIndex) + and is_list_like_indexer(key) + and not com.is_bool_indexer(key) + and all(is_hashable(k) for k in key) + ): + # GH#38148 + keys = self.obj.columns.union(key, sort=False) + diff = Index(key, copy=False).difference(self.obj.columns, sort=False) + + if len(diff): + # e.g. if we are doing df.loc[:, ["A", "B"]] = 7 and "B" + # is a new column, add the new columns with dtype=np.void + # so that later when we go through setitem_single_column + # we will use isetitem. Without this, the reindex_axis + # below would create float64 columns in this example, which + # would successfully hold 7, so we would end up with the wrong + # dtype. + indexer = np.arange(len(keys), dtype=np.intp) + indexer[len(self.obj.columns) :] = -1 + new_mgr = self.obj._mgr.reindex_indexer( + keys, indexer=indexer, axis=0, only_slice=True, use_na_proxy=True + ) + self.obj._mgr = new_mgr + return + + self.obj._mgr = self.obj._mgr.reindex_axis(keys, axis=0, only_slice=True) + + @final + def __setitem__(self, key, value) -> None: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount(self.obj) <= REF_COUNT_IDX: + warnings.warn( + _chained_assignment_msg, ChainedAssignmentError, stacklevel=2 + ) + + check_dict_or_set_indexers(key) + if isinstance(key, tuple): + key = (list(x) if is_iterator(x) else x for x in key) + key = tuple(com.apply_if_callable(x, self.obj) for x in key) + else: + maybe_callable = com.apply_if_callable(key, self.obj) + key = self._raise_callable_usage(key, maybe_callable) + indexer = self._get_setitem_indexer(key) + self._has_valid_setitem_indexer(key) + + iloc: _iLocIndexer = ( + cast("_iLocIndexer", self) if self.name == "iloc" else self.obj.iloc + ) + iloc._setitem_with_indexer(indexer, value, self.name) + + def _validate_key(self, key, axis: AxisInt) -> None: + """ + Ensure that key is valid for current indexer. + + Parameters + ---------- + key : scalar, slice or list-like + Key requested. + axis : int + Dimension on which the indexing is being made. + + Raises + ------ + TypeError + If the key (or some element of it) has wrong type. + IndexError + If the key (or some element of it) is out of bounds. + KeyError + If the key was not found. + """ + raise AbstractMethodError(self) + + @final + def _expand_ellipsis(self, tup: tuple) -> tuple: + """ + If a tuple key includes an Ellipsis, replace it with an appropriate + number of null slices. + """ + if any(x is Ellipsis for x in tup): + if tup.count(Ellipsis) > 1: + raise IndexingError(_one_ellipsis_message) + + if len(tup) == self.ndim: + # It is unambiguous what axis this Ellipsis is indexing, + # treat as a single null slice. + i = tup.index(Ellipsis) + # FIXME: this assumes only one Ellipsis + new_key = (*tup[:i], _NS, *tup[i + 1 :]) + return new_key + + # TODO: other cases? only one test gets here, and that is covered + # by _validate_key_length + return tup + + @final + def _validate_tuple_indexer(self, key: tuple) -> tuple: + """ + Check the key for valid keys across my indexer. + """ + key = self._validate_key_length(key) + key = self._expand_ellipsis(key) + for i, k in enumerate(key): + try: + self._validate_key(k, i) + except ValueError as err: + raise ValueError( + f"Location based indexing can only have [{self._valid_types}] types" + ) from err + return key + + @final + def _is_nested_tuple_indexer(self, tup: tuple) -> bool: + """ + Returns + ------- + bool + """ + if any(isinstance(ax, MultiIndex) for ax in self.obj.axes): + return any(is_nested_tuple(tup, ax) for ax in self.obj.axes) + return False + + @final + def _convert_tuple(self, key: tuple) -> tuple: + # Note: we assume _tupleize_axis_indexer has been called, if necessary. + self._validate_key_length(key) + keyidx = [self._convert_to_indexer(k, axis=i) for i, k in enumerate(key)] + return tuple(keyidx) + + @final + def _validate_key_length(self, key: tuple) -> tuple: + if len(key) > self.ndim: + if key[0] is Ellipsis: + # e.g. Series.iloc[..., 3] reduces to just Series.iloc[3] + key = key[1:] + if Ellipsis in key: + raise IndexingError(_one_ellipsis_message) + return self._validate_key_length(key) + raise IndexingError("Too many indexers") + return key + + @final + def _getitem_tuple_same_dim(self, tup: tuple): + """ + Index with indexers that should return an object of the same dimension + as self.obj. + + This is only called after a failed call to _getitem_lowerdim. + """ + retval = self.obj + # Selecting columns before rows is significantly faster + start_val = (self.ndim - len(tup)) + 1 + for i, key in enumerate(reversed(tup)): + i = self.ndim - i - start_val + if com.is_null_slice(key): + continue + + retval = getattr(retval, self.name)._getitem_axis(key, axis=i) + # We should never have retval.ndim < self.ndim, as that should + # be handled by the _getitem_lowerdim call above. + assert retval.ndim == self.ndim + + if retval is self.obj: + # if all axes were a null slice (`df.loc[:, :]`), ensure we still + # return a new object (https://github.com/pandas-dev/pandas/pull/49469) + retval = retval.copy(deep=False) + + return retval + + @final + def _getitem_lowerdim(self, tup: tuple): + # we can directly get the axis result since the axis is specified + if self.axis is not None: + axis = self.obj._get_axis_number(self.axis) + return self._getitem_axis(tup, axis=axis) + + # we may have a nested tuples indexer here + if self._is_nested_tuple_indexer(tup): + return self._getitem_nested_tuple(tup) + + # we maybe be using a tuple to represent multiple dimensions here + ax0 = self.obj._get_axis(0) + # ...but iloc should handle the tuple as simple integer-location + # instead of checking it as multiindex representation (GH 13797) + if ( + isinstance(ax0, MultiIndex) + and self.name != "iloc" + and not any(isinstance(x, slice) for x in tup) + ): + # Note: in all extant test cases, replacing the slice condition with + # `all(is_hashable(x) or com.is_null_slice(x) for x in tup)` + # is equivalent. + # (see the other place where we call _handle_lowerdim_multi_index_axis0) + with suppress(IndexingError): + return cast(_LocIndexer, self)._handle_lowerdim_multi_index_axis0(tup) + + tup = self._validate_key_length(tup) + + # Reverse tuple so that we are indexing along columns before rows + # and avoid unintended dtype inference. # GH60600 + for i, key in zip(range(len(tup) - 1, -1, -1), reversed(tup), strict=True): + if is_label_like(key) or is_list_like(key): + # We don't need to check for tuples here because those are + # caught by the _is_nested_tuple_indexer check above. + section = self._getitem_axis(key, axis=i) + + # We should never have a scalar section here, because + # _getitem_lowerdim is only called after a check for + # is_scalar_access, which that would be. + if section.ndim == self.ndim: + # we're in the middle of slicing through a MultiIndex + # revise the key wrt to `section` by inserting an _NS + new_key = (*tup[:i], _NS, *tup[i + 1 :]) + + else: + # Note: the section.ndim == self.ndim check above + # rules out having DataFrame here, so we dont need to worry + # about transposing. + new_key = tup[:i] + tup[i + 1 :] + + if len(new_key) == 1: + new_key = new_key[0] + + # Slices should return views, but calling iloc/loc with a null + # slice returns a new object. + if com.is_null_slice(new_key): + return section + # This is an elided recursive call to iloc/loc + return getattr(section, self.name)[new_key] + + raise IndexingError("not applicable") + + @final + def _getitem_nested_tuple(self, tup: tuple): + # we have a nested tuple so have at least 1 multi-index level + # we should be able to match up the dimensionality here + + for key in tup: + check_dict_or_set_indexers(key) + + # we have too many indexers for our dim, but have at least 1 + # multi-index dimension, try to see if we have something like + # a tuple passed to a series with a multi-index + if len(tup) > self.ndim: + if self.name != "loc": + # This should never be reached, but let's be explicit about it + raise ValueError("Too many indices") # pragma: no cover + if all( + is_hashable(x, allow_slice=False) or com.is_null_slice(x) for x in tup + ): + # GH#10521 Series should reduce MultiIndex dimensions instead of + # DataFrame, IndexingError is not raised when slice(None,None,None) + # with one row. + with suppress(IndexingError): + return cast(_LocIndexer, self)._handle_lowerdim_multi_index_axis0( + tup + ) + elif isinstance(self.obj, ABCSeries) and any( + isinstance(k, tuple) for k in tup + ): + # GH#35349 Raise if tuple in tuple for series + # Do this after the all-hashable-or-null-slice check so that + # we are only getting non-hashable tuples, in particular ones + # that themselves contain a slice entry + # See test_loc_series_getitem_too_many_dimensions + raise IndexingError("Too many indexers") + + # this is a series with a multi-index specified a tuple of + # selectors + axis = self.axis or 0 + return self._getitem_axis(tup, axis=axis) + + # handle the multi-axis by taking sections and reducing + # this is iterative + obj = self.obj + # GH#41369 Loop in reverse order ensures indexing along columns before rows + # which selects only necessary blocks which avoids dtype conversion if possible + axis = len(tup) - 1 + for key in reversed(tup): + if com.is_null_slice(key): + axis -= 1 + continue + + obj = getattr(obj, self.name)._getitem_axis(key, axis=axis) + axis -= 1 + + # if we have a scalar, we are done + if is_scalar(obj) or not hasattr(obj, "ndim"): + break + + return obj + + def _convert_to_indexer(self, key, axis: AxisInt): + raise AbstractMethodError(self) + + def _raise_callable_usage(self, key: Any, maybe_callable: T) -> T: + # GH53533 + if self.name == "iloc" and callable(key) and isinstance(maybe_callable, tuple): + raise ValueError( + "Returning a tuple from a callable with iloc is not allowed.", + ) + return maybe_callable + + @final + def __getitem__(self, key): + check_dict_or_set_indexers(key) + if type(key) is tuple: + key = (list(x) if is_iterator(x) else x for x in key) + key = tuple(com.apply_if_callable(x, self.obj) for x in key) + if self._is_scalar_access(key): + return self.obj._get_value(*key, takeable=self._takeable) + return self._getitem_tuple(key) + else: + # we by definition only have the 0th axis + axis = self.axis or 0 + + maybe_callable = com.apply_if_callable(key, self.obj) + maybe_callable = self._raise_callable_usage(key, maybe_callable) + return self._getitem_axis(maybe_callable, axis=axis) + + def _is_scalar_access(self, key: tuple): + raise NotImplementedError + + def _getitem_tuple(self, tup: tuple): + raise AbstractMethodError(self) + + def _getitem_axis(self, key, axis: AxisInt): + raise NotImplementedError + + def _has_valid_setitem_indexer(self, indexer) -> bool: + raise AbstractMethodError(self) + + @final + def _getbool_axis(self, key, axis: AxisInt): + # caller is responsible for ensuring non-None axis + labels = self.obj._get_axis(axis) + key = check_bool_indexer(labels, key) + inds = key.nonzero()[0] + return self.obj.take(inds, axis=axis) + + +@doc(IndexingMixin.loc) +class _LocIndexer(_LocationIndexer): + _takeable: bool = False + _valid_types = ( + "labels (MUST BE IN THE INDEX), slices of labels (BOTH " + "endpoints included! Can be slices of integers if the " + "index is integers), listlike of labels, boolean" + ) + + # ------------------------------------------------------------------- + # Key Checks + + @doc(_LocationIndexer._validate_key) + def _validate_key(self, key, axis: Axis) -> None: + # valid for a collection of labels (we check their presence later) + # slice of labels (where start-end in labels) + # slice of integers (only if in the labels) + # boolean not in slice and with boolean index + ax = self.obj._get_axis(axis) + if isinstance(key, bool) and not ( + is_bool_dtype(ax.dtype) + or ax.dtype.name == "boolean" + or ( + isinstance(ax, MultiIndex) + and is_bool_dtype(ax.get_level_values(0).dtype) + ) + ): + raise KeyError( + f"{key}: boolean label can not be used without a boolean index" + ) + + if isinstance(key, slice) and ( + isinstance(key.start, bool) or isinstance(key.stop, bool) + ): + raise TypeError(f"{key}: boolean values can not be used in a slice") + + def _has_valid_setitem_indexer(self, indexer) -> bool: + return True + + def _is_scalar_access(self, key: tuple) -> bool: + """ + Returns + ------- + bool + """ + # this is a shortcut accessor to both .loc and .iloc + # that provide the equivalent access of .at and .iat + # a) avoid getting things via sections and (to minimize dtype changes) + # b) provide a performant path + if len(key) != self.ndim: + return False + + for i, k in enumerate(key): + if not is_scalar(k): + return False + + ax = self.obj.axes[i] + if isinstance(ax, MultiIndex): + return False + + if isinstance(k, str) and ax._supports_partial_string_indexing: + # partial string indexing, df.loc['2000', 'A'] + # should not be considered scalar + return False + + if not ax._index_as_unique: + return False + + return True + + # ------------------------------------------------------------------- + # MultiIndex Handling + + def _multi_take_opportunity(self, tup: tuple) -> bool: + """ + Check whether there is the possibility to use ``_multi_take``. + + Currently the limit is that all axes being indexed, must be indexed with + list-likes. + + Parameters + ---------- + tup : tuple + Tuple of indexers, one per axis. + + Returns + ------- + bool + Whether the current indexing, + can be passed through `_multi_take`. + """ + if not all(is_list_like_indexer(x) for x in tup): + return False + + # just too complicated + return not any(com.is_bool_indexer(x) for x in tup) + + def _multi_take(self, tup: tuple): + """ + Create the indexers for the passed tuple of keys, and + executes the take operation. This allows the take operation to be + executed all at once, rather than once for each dimension. + Improving efficiency. + + Parameters + ---------- + tup : tuple + Tuple of indexers, one per axis. + + Returns + ------- + values: same type as the object being indexed + """ + # GH 836 + d = { + axis: self._get_listlike_indexer(key, axis) + for (key, axis) in zip(tup, self.obj._AXIS_ORDERS, strict=True) + } + return self.obj._reindex_with_indexers(d, allow_dups=True) + + # ------------------------------------------------------------------- + + def _getitem_iterable(self, key, axis: AxisInt): + """ + Index current object with an iterable collection of keys. + + Parameters + ---------- + key : iterable + Targeted labels. + axis : int + Dimension on which the indexing is being made. + + Raises + ------ + KeyError + If no key was found. Will change in the future to raise if not all + keys were found. + + Returns + ------- + scalar, DataFrame, or Series: indexed value(s). + """ + # we assume that not com.is_bool_indexer(key), as that is + # handled before we get here. + self._validate_key(key, axis) + + # A collection of keys + keyarr, indexer = self._get_listlike_indexer(key, axis) + return self.obj._reindex_with_indexers( + {axis: [keyarr, indexer]}, allow_dups=True + ) + + def _getitem_tuple(self, tup: tuple): + with suppress(IndexingError): + tup = self._expand_ellipsis(tup) + return self._getitem_lowerdim(tup) + + # no multi-index, so validate all of the indexers + tup = self._validate_tuple_indexer(tup) + + # ugly hack for GH #836 + if self._multi_take_opportunity(tup): + return self._multi_take(tup) + + return self._getitem_tuple_same_dim(tup) + + def _get_label(self, label, axis: AxisInt): + # GH#5567 this will fail if the label is not present in the axis. + return self.obj.xs(label, axis=axis) + + def _handle_lowerdim_multi_index_axis0(self, tup: tuple): + # we have an axis0 multi-index, handle or raise + axis = self.axis or 0 + try: + # fast path for series or for tup devoid of slices + return self._get_label(tup, axis=axis) + + except KeyError as ek: + # raise KeyError if number of indexers match + # else IndexingError will be raised + if self.ndim < len(tup) <= self.obj.index.nlevels: + raise ek + raise IndexingError("No label returned") from ek + + def _getitem_axis(self, key, axis: AxisInt): + key = item_from_zerodim(key) + if is_iterator(key): + key = list(key) + if key is Ellipsis: + key = slice(None) + + labels = self.obj._get_axis(axis) + + if isinstance(key, tuple) and isinstance(labels, MultiIndex): + key = tuple(key) + + if isinstance(key, slice): + self._validate_key(key, axis) + return self._get_slice_axis(key, axis=axis) + elif com.is_bool_indexer(key): + return self._getbool_axis(key, axis=axis) + elif is_list_like_indexer(key): + # an iterable multi-selection + if not (isinstance(key, tuple) and isinstance(labels, MultiIndex)): + if hasattr(key, "ndim") and key.ndim > 1: + raise ValueError("Cannot index with multidimensional key") + + return self._getitem_iterable(key, axis=axis) + + # nested tuple slicing + if is_nested_tuple(key, labels): + locs = labels.get_locs(key) + indexer: list[slice | npt.NDArray[np.intp]] = [slice(None)] * self.ndim + indexer[axis] = locs + return self.obj.iloc[tuple(indexer)] + + # fall thru to straight lookup + self._validate_key(key, axis) + return self._get_label(key, axis=axis) + + def _get_slice_axis(self, slice_obj: slice, axis: AxisInt): + """ + This is pretty simple as we just have to deal with labels. + """ + # caller is responsible for ensuring non-None axis + obj = self.obj + if not need_slice(slice_obj): + return obj.copy(deep=False) + + labels = obj._get_axis(axis) + indexer = labels.slice_indexer(slice_obj.start, slice_obj.stop, slice_obj.step) + + if isinstance(indexer, slice): + return self.obj._slice(indexer, axis=axis) + else: + # DatetimeIndex overrides Index.slice_indexer and may + # return a DatetimeIndex instead of a slice object. + return self.obj.take(indexer, axis=axis) + + def _convert_to_indexer(self, key, axis: AxisInt): + """ + Convert indexing key into something we can use to do actual fancy + indexing on an ndarray. + + Examples + ix[:5] -> slice(0, 5) + ix[[1,2,3]] -> [1,2,3] + ix[['foo', 'bar', 'baz']] -> [i, j, k] (indices of foo, bar, baz) + + Going by Zen of Python? + 'In the face of ambiguity, refuse the temptation to guess.' + raise AmbiguousIndexError with integer labels? + - No, prefer label-based indexing + """ + labels = self.obj._get_axis(axis) + + if isinstance(key, slice): + return labels._convert_slice_indexer(key, kind="loc") + + if ( + isinstance(key, tuple) + and not isinstance(labels, MultiIndex) + and self.ndim < 2 + and len(key) > 1 + ): + raise IndexingError("Too many indexers") + + # Slices are not valid keys passed in by the user, + # even though they are hashable in Python 3.12 + contains_slice = False + if isinstance(key, tuple): + contains_slice = any(isinstance(v, slice) for v in key) + + if is_scalar(key) or ( + isinstance(labels, MultiIndex) and is_hashable(key) and not contains_slice + ): + # Otherwise get_loc will raise InvalidIndexError + + # if we are a label return me + try: + return labels.get_loc(key) + except LookupError: + if isinstance(key, tuple) and isinstance(labels, MultiIndex): + if len(key) == labels.nlevels: + return {"key": key} + raise + except InvalidIndexError: + # GH35015, using datetime as column indices raises exception + if not isinstance(labels, MultiIndex): + raise + except ValueError: + if not is_integer(key): + raise + return {"key": key} + + if is_nested_tuple(key, labels): + if self.ndim == 1 and any(isinstance(k, tuple) for k in key): + # GH#35349 Raise if tuple in tuple for series + raise IndexingError("Too many indexers") + return labels.get_locs(key) + + elif is_list_like_indexer(key): + if is_iterator(key): + key = list(key) + + if com.is_bool_indexer(key): + key = check_bool_indexer(labels, key) + return key + else: + return self._get_listlike_indexer(key, axis)[1] + else: + try: + return labels.get_loc(key) + except LookupError: + # allow a not found key only if we are a setter + if not is_list_like_indexer(key): + return {"key": key} + raise + + def _get_listlike_indexer(self, key, axis: AxisInt): + """ + Transform a list-like of keys into a new index and an indexer. + + Parameters + ---------- + key : list-like + Targeted labels. + axis: int + Dimension on which the indexing is being made. + + Raises + ------ + KeyError + If at least one key was requested but none was found. + + Returns + ------- + keyarr: Index + New index (coinciding with 'key' if the axis is unique). + values : array-like + Indexer for the return object, -1 denotes keys not found. + """ + ax = self.obj._get_axis(axis) + axis_name = self.obj._get_axis_name(axis) + + keyarr, indexer = ax._get_indexer_strict(key, axis_name) + + return keyarr, indexer + + +@doc(IndexingMixin.iloc) +class _iLocIndexer(_LocationIndexer): + _valid_types = ( + "integer, integer slice (START point is INCLUDED, END " + "point is EXCLUDED), listlike of integers, boolean array" + ) + _takeable = True + + # ------------------------------------------------------------------- + # Key Checks + + def _validate_key(self, key, axis: AxisInt) -> None: + if com.is_bool_indexer(key): + if hasattr(key, "index") and isinstance(key.index, Index): + if key.index.inferred_type == "integer": + return + raise ValueError( + "iLocation based boolean indexing cannot use an indexable as a mask" + ) + return + + if isinstance(key, slice): + return + elif is_integer(key): + self._validate_integer(key, axis) + elif isinstance(key, tuple): + # a tuple should already have been caught by this point + # so don't treat a tuple as a valid indexer + raise IndexingError("Too many indexers") + elif is_list_like_indexer(key): + if isinstance(key, ABCSeries): + arr = key._values + elif is_array_like(key): + arr = key + else: + arr = np.array(key) + len_axis = len(self.obj._get_axis(axis)) + + # check that the key has a numeric dtype + if not is_numeric_dtype(arr.dtype): + raise IndexError(f".iloc requires numeric indexers, got {arr}") + + if len(arr): + if isinstance(arr.dtype, ExtensionDtype): + arr_max = arr._reduce("max") + arr_min = arr._reduce("min") + else: + arr_max = np.max(arr) + arr_min = np.min(arr) + + # check that the key does not exceed the maximum size + if arr_max >= len_axis or arr_min < -len_axis: + raise IndexError("positional indexers are out-of-bounds") + else: + raise ValueError(f"Can only index by location with a [{self._valid_types}]") + + def _has_valid_setitem_indexer(self, indexer) -> bool: + """ + Validate that a positional indexer cannot enlarge its target + will raise if needed, does not modify the indexer externally. + + Returns + ------- + bool + """ + if isinstance(indexer, dict): + raise IndexError("iloc cannot enlarge its target object") + + if isinstance(indexer, ABCDataFrame): + raise TypeError( + "DataFrame indexer for .iloc is not supported. " + "Consider using .loc with a DataFrame indexer for automatic alignment.", + ) + + if not isinstance(indexer, tuple): + indexer = _tuplify(self.ndim, indexer) + + for ax, i in zip(self.obj.axes, indexer, strict=False): + if isinstance(i, slice): + # should check the stop slice? + pass + elif is_list_like_indexer(i): + # should check the elements? + pass + elif is_integer(i): + if i >= len(ax): + raise IndexError("iloc cannot enlarge its target object") + elif isinstance(i, dict): + raise IndexError("iloc cannot enlarge its target object") + + return True + + def _is_scalar_access(self, key: tuple) -> bool: + """ + Returns + ------- + bool + """ + # this is a shortcut accessor to both .loc and .iloc + # that provide the equivalent access of .at and .iat + # a) avoid getting things via sections and (to minimize dtype changes) + # b) provide a performant path + if len(key) != self.ndim: + return False + + return all(is_integer(k) for k in key) + + def _validate_integer(self, key: int | np.integer, axis: AxisInt) -> None: + """ + Check that 'key' is a valid position in the desired axis. + + Parameters + ---------- + key : int + Requested position. + axis : int + Desired axis. + + Raises + ------ + IndexError + If 'key' is not a valid position in axis 'axis'. + """ + len_axis = len(self.obj._get_axis(axis)) + if key >= len_axis or key < -len_axis: + raise IndexError("single positional indexer is out-of-bounds") + + # ------------------------------------------------------------------- + + def _getitem_tuple(self, tup: tuple): + tup = self._validate_tuple_indexer(tup) + with suppress(IndexingError): + return self._getitem_lowerdim(tup) + + return self._getitem_tuple_same_dim(tup) + + def _get_list_axis(self, key, axis: AxisInt): + """ + Return Series values by list or array of integers. + + Parameters + ---------- + key : list-like positional indexer + axis : int + + Returns + ------- + Series object + + Notes + ----- + `axis` can only be zero. + """ + try: + return self.obj.take(key, axis=axis) + except IndexError as err: + # re-raise with different error message, e.g. test_getitem_ndarray_3d + raise IndexError("positional indexers are out-of-bounds") from err + + def _getitem_axis(self, key, axis: AxisInt): + if key is Ellipsis: + key = slice(None) + elif isinstance(key, ABCDataFrame): + raise IndexError( + "DataFrame indexer is not allowed for .iloc\n" + "Consider using .loc for automatic alignment." + ) + + if isinstance(key, slice): + return self._get_slice_axis(key, axis=axis) + + if is_iterator(key): + key = list(key) + + if isinstance(key, list): + key = np.asarray(key) + + if com.is_bool_indexer(key): + self._validate_key(key, axis) + return self._getbool_axis(key, axis=axis) + + # a list of integers + elif is_list_like_indexer(key): + return self._get_list_axis(key, axis=axis) + + # a single integer + else: + key = item_from_zerodim(key) + if not is_integer(key): + raise TypeError("Cannot index by location index with a non-integer key") + + # validate the location + self._validate_integer(key, axis) + + return self.obj._ixs(key, axis=axis) + + def _get_slice_axis(self, slice_obj: slice, axis: AxisInt): + # caller is responsible for ensuring non-None axis + obj = self.obj + + if not need_slice(slice_obj): + return obj.copy(deep=False) + + labels = obj._get_axis(axis) + labels._validate_positional_slice(slice_obj) + return self.obj._slice(slice_obj, axis=axis) + + def _convert_to_indexer(self, key: T, axis: AxisInt) -> T: + """ + Much simpler as we only have to deal with our valid types. + """ + return key + + def _get_setitem_indexer(self, key): + # GH#32257 Fall through to let numpy do validation + if is_iterator(key): + key = list(key) + + if self.axis is not None: + key = _tupleize_axis_indexer(self.ndim, self.axis, key) + + return key + + # ------------------------------------------------------------------- + + def _decide_split_path(self, indexer, value) -> bool: + """ + Decide whether we will take a block-by-block path. + """ + take_split_path = not self.obj._mgr.is_single_block + + if not take_split_path and isinstance(value, ABCDataFrame): + # Avoid cast of values + take_split_path = not value._mgr.is_single_block + + # if there is only one block/type, still have to take split path + # unless the block is one-dimensional or it can hold the value + if not take_split_path and len(self.obj._mgr.blocks) and self.ndim > 1: + # in case of dict, keys are indices + val = list(value.values()) if isinstance(value, dict) else value + arr = self.obj._mgr.blocks[0].values + take_split_path = not can_hold_element( + arr, extract_array(val, extract_numpy=True) + ) + + # if we have any multi-indexes that have non-trivial slices + # (not null slices) then we must take the split path, xref + # GH 10360, GH 27841 + if isinstance(indexer, tuple) and len(indexer) == len(self.obj.axes): + for i, ax in zip(indexer, self.obj.axes, strict=True): + if isinstance(ax, MultiIndex) and not ( + is_integer(i) or com.is_null_slice(i) + ): + take_split_path = True + break + + return take_split_path + + def _setitem_new_column(self, indexer, key, value, name: str) -> None: + """ + _setitem_with_indexer cases that can go through DataFrame.__setitem__. + """ + # add the new item, and set the value + # must have all defined axes if we have a scalar + # or a list-like on the non-info axes if we have a + # list-like + if not len(self.obj): + if not is_list_like_indexer(value): + raise ValueError( + "cannot set a frame with no defined index and a scalar" + ) + self.obj[key] = value + return + + # add a new item with the dtype setup + if com.is_null_slice(indexer[0]): + # We are setting an entire column + self.obj[key] = value + return + elif is_array_like(value): + # GH#42099 + arr = extract_array(value, extract_numpy=True) + taker = -1 * np.ones(len(self.obj), dtype=np.intp) + empty_value = algos.take_nd(arr, taker) + if not isinstance(value, ABCSeries): + # if not Series (in which case we need to align), + # we can short-circuit + if isinstance(arr, np.ndarray) and arr.ndim == 1 and len(arr) == 1: + # NumPy 1.25 deprecation: https://github.com/numpy/numpy/pull/10615 + arr = arr[0, ...] + empty_value[indexer[0]] = arr + self.obj[key] = empty_value + return + + self.obj[key] = empty_value + elif not is_list_like(value): + self.obj[key] = construct_1d_array_from_inferred_fill_value( + value, len(self.obj) + ) + else: + # FIXME: GH#42099#issuecomment-864326014 + self.obj[key] = infer_fill_value(value) + + new_indexer = convert_from_missing_indexer_tuple(indexer, self.obj.axes) + self._setitem_with_indexer(new_indexer, value, name) + + return + + def _setitem_with_indexer(self, indexer, value, name: str = "iloc") -> None: + """ + _setitem_with_indexer is for setting values on a Series/DataFrame + using positional indexers. + + If the relevant keys are not present, the Series/DataFrame may be + expanded. + """ + info_axis = self.obj._info_axis_number + take_split_path = self._decide_split_path(indexer, value) + + if isinstance(indexer, tuple): + nindexer = [] + for i, idx in enumerate(indexer): + idx, missing = convert_missing_indexer(idx) + if missing: + # reindex the axis to the new value + # and set inplace + key = idx + + # if this is the items axes, then take the main missing + # path first + # this correctly sets the dtype + # essentially this separates out the block that is needed + # to possibly be modified + if self.ndim > 1 and i == info_axis: + self._setitem_new_column(indexer, key, value, name=name) + return + + # reindex the axis + index = self.obj._get_axis(i) + labels = index.insert(len(index), key) + + # We are expanding the Series/DataFrame values to match + # the length of the new index `labels`. GH#40096 ensure + # this is valid even if the index has duplicates. + taker = np.arange(len(index) + 1, dtype=np.intp) + taker[-1] = -1 + reindexers = {i: (labels, taker)} + new_obj = self.obj._reindex_with_indexers( + reindexers, allow_dups=True + ) + self.obj._mgr = new_obj._mgr + + nindexer.append(labels.get_loc(key)) + + else: + nindexer.append(idx) + + indexer = tuple(nindexer) + else: + indexer, missing = convert_missing_indexer(indexer) + + if missing: + self._setitem_with_indexer_missing(indexer, value) + return + + if name == "loc": + # must come after setting of missing + indexer, value = self._maybe_mask_setitem_value(indexer, value) + + # align and set the values + if take_split_path: + # We have to operate column-wise + self._setitem_with_indexer_split_path(indexer, value, name) + else: + self._setitem_single_block(indexer, value, name) + + def _setitem_with_indexer_split_path(self, indexer, value, name: str): + """ + Setitem column-wise. + """ + # Above we only set take_split_path to True for 2D cases + assert self.ndim == 2 + + if not isinstance(indexer, tuple): + indexer = _tuplify(self.ndim, indexer) + if len(indexer) > self.ndim: + raise IndexError("too many indices for array") + if isinstance(indexer[0], np.ndarray) and indexer[0].ndim > 2: + raise ValueError(r"Cannot set values with ndim > 2") + + if (isinstance(value, ABCSeries) and name != "iloc") or isinstance(value, dict): + from pandas import Series + + value = self._align_series(indexer, Series(value)) + + # Ensure we have something we can iterate over + info_axis = indexer[1] + ilocs = self._ensure_iterable_column_indexer(info_axis) + + pi = indexer[0] + lplane_indexer = length_of_indexer(pi, self.obj.index) + # lplane_indexer gives the expected length of obj[indexer[0]] + + # we need an iterable, with an ndim of at least 1 + # eg. don't pass through np.array(0) + if is_list_like_indexer(value) and getattr(value, "ndim", 1) > 0: + if isinstance(value, ABCDataFrame): + self._setitem_with_indexer_frame_value(indexer, value, name) + + elif np.ndim(value) == 2: + # TODO: avoid np.ndim call in case it isn't an ndarray, since + # that will construct an ndarray, which will be wasteful + self._setitem_with_indexer_2d_value(indexer, value) + + elif len(ilocs) == 1 and lplane_indexer == len(value) and not is_scalar(pi): + # We are setting multiple rows in a single column. + self._setitem_single_column(ilocs[0], value, pi) + + elif len(ilocs) == 1 and 0 != lplane_indexer != len(value): + # We are trying to set N values into M entries of a single + # column, which is invalid for N != M + # Exclude zero-len for e.g. boolean masking that is all-false + + if len(value) == 1 and not is_integer(info_axis): + # This is a case like df.iloc[:3, [1]] = [0] + # where we treat as df.iloc[:3, 1] = 0 + return self._setitem_with_indexer((pi, info_axis[0]), value[0]) + + raise ValueError( + "Must have equal len keys and value when setting with an iterable" + ) + + elif lplane_indexer == 0 and len(value) == len(self.obj.index): + # We get here in one case via .loc with an all-False mask + pass + + elif self._is_scalar_access(indexer) and is_object_dtype( + self.obj.dtypes._values[ilocs[0]] + ): + # We are setting nested data, only possible for object dtype data + self._setitem_single_column(indexer[1], value, pi) + + elif len(ilocs) == len(value): + # We are setting multiple columns in a single row. + for loc, v in zip(ilocs, value, strict=True): + self._setitem_single_column(loc, v, pi) + + elif len(ilocs) == 1 and com.is_null_slice(pi) and len(self.obj) == 0: + # This is a setitem-with-expansion, see + # test_loc_setitem_empty_append_expands_rows_mixed_dtype + # e.g. df = DataFrame(columns=["x", "y"]) + # df["x"] = df["x"].astype(np.int64) + # df.loc[:, "x"] = [1, 2, 3] + self._setitem_single_column(ilocs[0], value, pi) + + else: + raise ValueError( + "Must have equal len keys and value when setting with an iterable" + ) + + else: + # scalar value + for loc in ilocs: + self._setitem_single_column(loc, value, pi) + + def _setitem_with_indexer_2d_value(self, indexer, value) -> None: + # We get here with np.ndim(value) == 2, excluding DataFrame, + # which goes through _setitem_with_indexer_frame_value + pi = indexer[0] + + ilocs = self._ensure_iterable_column_indexer(indexer[1]) + + if not is_array_like(value): + # cast lists to array + value = np.array(value, dtype=object) + if len(ilocs) != value.shape[1]: + raise ValueError( + "Must have equal len keys and value when setting with an ndarray" + ) + + for i, loc in enumerate(ilocs): + value_col = value[:, i] + if is_object_dtype(value_col.dtype): + # casting to list so that we do type inference in setitem_single_column + value_col = value_col.tolist() + self._setitem_single_column(loc, value_col, pi) + + def _setitem_with_indexer_frame_value( + self, indexer, value: DataFrame, name: str + ) -> None: + ilocs = self._ensure_iterable_column_indexer(indexer[1]) + + sub_indexer = list(indexer) + pi = indexer[0] + + multiindex_indexer = isinstance(self.obj.columns, MultiIndex) + + unique_cols = value.columns.is_unique + + # We do not want to align the value in case of iloc GH#37728 + if name == "iloc": + for i, loc in enumerate(ilocs): + val = value.iloc[:, i] + self._setitem_single_column(loc, val, pi) + + elif not unique_cols and value.columns.equals(self.obj.columns): + # We assume we are already aligned, see + # test_iloc_setitem_frame_duplicate_columns_multiple_blocks + for loc in ilocs: + item = self.obj.columns[loc] + if item in value: + sub_indexer[1] = item + val = self._align_series( + tuple(sub_indexer), + value.iloc[:, loc], + multiindex_indexer, + ) + else: + val = np.nan + + self._setitem_single_column(loc, val, pi) + + elif not unique_cols: + raise ValueError("Setting with non-unique columns is not allowed.") + + else: + for loc in ilocs: + item = self.obj.columns[loc] + if item in value: + sub_indexer[1] = item + val = self._align_series( + tuple(sub_indexer), + value[item], + multiindex_indexer, + using_cow=True, + ) + else: + val = np.nan + + self._setitem_single_column(loc, val, pi) + + def _setitem_single_column(self, loc: int, value, plane_indexer) -> None: + """ + + Parameters + ---------- + loc : int + Indexer for column position + plane_indexer : int, slice, listlike[int] + The indexer we use for setitem along axis=0. + """ + pi = plane_indexer + + is_full_setter = com.is_null_slice(pi) or com.is_full_slice(pi, len(self.obj)) + + is_null_setter = com.is_empty_slice(pi) or (is_array_like(pi) and len(pi) == 0) + + if is_null_setter: + # no-op, don't cast dtype later + return + + elif is_full_setter: + try: + self.obj._mgr.column_setitem( + loc, plane_indexer, value, inplace_only=True + ) + except (ValueError, TypeError, LossySetitemError) as exc: + # If we're setting an entire column and we can't do it inplace, + # then we can use value's dtype (or inferred dtype) + # instead of object + dtype = self.obj.dtypes.iloc[loc] + if dtype not in (np.void, object) and not self.obj.empty: + # - Exclude np.void, as that is a special case for expansion. + # We want to raise for + # df = pd.DataFrame({'a': [1, 2]}) + # df.loc[:, 'a'] = .3 + # but not for + # df = pd.DataFrame({'a': [1, 2]}) + # df.loc[:, 'b'] = .3 + # - Exclude `object`, as then no upcasting happens. + # - Exclude empty initial object with enlargement, + # as then there's nothing to be inconsistent with. + raise TypeError( + f"Invalid value '{value}' for dtype '{dtype}'" + ) from exc + self.obj.isetitem(loc, value) + else: + # set value into the column (first attempting to operate inplace, then + # falling back to casting if necessary) + dtype = self.obj.dtypes.iloc[loc] + if dtype == np.void: + # This means we're expanding, with multiple columns, e.g. + # df = pd.DataFrame({'A': [1,2,3], 'B': [4,5,6]}) + # df.loc[df.index <= 2, ['F', 'G']] = (1, 'abc') + # Columns F and G will initially be set to np.void. + # Here, we replace those temporary `np.void` columns with + # columns of the appropriate dtype, based on `value`. + self.obj.iloc[:, loc] = construct_1d_array_from_inferred_fill_value( + value, len(self.obj) + ) + self.obj._mgr.column_setitem(loc, plane_indexer, value) + + def _setitem_single_block(self, indexer, value, name: str) -> None: + """ + _setitem_with_indexer for the case when we have a single Block. + """ + from pandas import Series + + if (isinstance(value, ABCSeries) and name != "iloc") or isinstance(value, dict): + # TODO(EA): ExtensionBlock.setitem this causes issues with + # setting for extensionarrays that store dicts. Need to decide + # if it's worth supporting that. + value = self._align_series(indexer, Series(value)) + + info_axis = self.obj._info_axis_number + item_labels = self.obj._get_axis(info_axis) + if isinstance(indexer, tuple): + # if we are setting on the info axis ONLY + # set using those methods to avoid block-splitting + # logic here + if ( + self.ndim == len(indexer) == 2 + and is_integer(indexer[1]) + and com.is_null_slice(indexer[0]) + ): + col = item_labels[indexer[info_axis]] + if len(item_labels.get_indexer_for([col])) == 1: + # e.g. test_loc_setitem_empty_append_expands_rows + loc = item_labels.get_loc(col) + self._setitem_single_column(loc, value, indexer[0]) + return + + indexer = maybe_convert_ix(*indexer) # e.g. test_setitem_frame_align + + if isinstance(value, ABCDataFrame) and name != "iloc": + value = self._align_frame(indexer, value)._values + + # actually do the set + self.obj._mgr = self.obj._mgr.setitem(indexer=indexer, value=value) + + def _setitem_with_indexer_missing(self, indexer, value): + """ + Insert new row(s) or column(s) into the Series or DataFrame. + """ + from pandas import Series + + # reindex the axis to the new value + # and set inplace + if self.ndim == 1: + index = self.obj.index + new_index = index.insert(len(index), indexer) + + # we have a coerced indexer, e.g. a float + # that matches in an int64 Index, so + # we will not create a duplicate index, rather + # index to that element + # e.g. 0.0 -> 0 + # GH#12246 + if index.is_unique: + # pass new_index[-1:] instead if [new_index[-1]] + # so that we retain dtype + new_indexer = index.get_indexer(new_index[-1:]) + if (new_indexer != -1).any(): + # We get only here with loc, so can hard code + return self._setitem_with_indexer(new_indexer, value, "loc") + + # this preserves dtype of the value and of the object + if not is_scalar(value): + new_dtype = None + + elif is_valid_na_for_dtype(value, self.obj.dtype): + if not is_object_dtype(self.obj.dtype): + # Every NA value is suitable for object, no conversion needed + value = na_value_for_dtype(self.obj.dtype, compat=False) + + new_dtype = maybe_promote(self.obj.dtype, value)[0] + + elif isna(value): + new_dtype = None + elif not self.obj.empty and not is_object_dtype(self.obj.dtype): + # We should not cast, if we have object dtype because we can + # set timedeltas into object series + curr_dtype = self.obj.dtype + curr_dtype = getattr(curr_dtype, "numpy_dtype", curr_dtype) + new_dtype = maybe_promote(curr_dtype, value)[0] + else: + new_dtype = None + + new_values = Series([value], dtype=new_dtype)._values + + if len(self.obj._values): + # GH#22717 handle casting compatibility that np.concatenate + # does incorrectly + new_values = concat_compat([self.obj._values, new_values]) + self.obj._mgr = self.obj._constructor( + new_values, index=new_index, name=self.obj.name + )._mgr + + elif self.ndim == 2: + if not len(self.obj.columns): + # no columns and scalar + raise ValueError("cannot set a frame with no defined columns") + + has_dtype = hasattr(value, "dtype") + if isinstance(value, ABCSeries): + # append a Series + value = value.reindex(index=self.obj.columns) + value.name = indexer + elif isinstance(value, dict): + value = Series( + value, index=self.obj.columns, name=indexer, dtype=object + ) + else: + # a list-list + if is_list_like_indexer(value): + # must have conforming columns + if len(value) != len(self.obj.columns): + raise ValueError("cannot set a row with mismatched columns") + + value = Series(value, index=self.obj.columns, name=indexer) + + if not len(self.obj): + # We will ignore the existing dtypes instead of using + # internals.concat logic + df = value.to_frame().T + + idx = self.obj.index + if isinstance(idx, MultiIndex): + name = idx.names + else: + name = idx.name + + df.index = Index([indexer], name=name) + if not has_dtype: + # i.e. if we already had a Series or ndarray, keep that + # dtype. But if we had a list or dict, then do inference + df = df.infer_objects() + self.obj._mgr = df._mgr + else: + self.obj._mgr = self.obj._append_internal(value)._mgr + + def _ensure_iterable_column_indexer(self, column_indexer): + """ + Ensure that our column indexer is something that can be iterated over. + """ + ilocs: Sequence[int | np.integer] | np.ndarray | range + if is_integer(column_indexer): + ilocs = [column_indexer] + elif isinstance(column_indexer, slice): + ilocs = range(len(self.obj.columns))[column_indexer] + elif ( + isinstance(column_indexer, np.ndarray) and column_indexer.dtype.kind == "b" + ): + ilocs = np.arange(len(column_indexer))[column_indexer] + else: + ilocs = column_indexer + return ilocs + + def _align_series( + self, + indexer, + ser: Series, + multiindex_indexer: bool = False, + using_cow: bool = False, + ): + """ + Parameters + ---------- + indexer : tuple, slice, scalar + Indexer used to get the locations that will be set to `ser`. + ser : pd.Series + Values to assign to the locations specified by `indexer`. + multiindex_indexer : bool, optional + Defaults to False. Should be set to True if `indexer` was from + a `pd.MultiIndex`, to avoid unnecessary broadcasting. + + Returns + ------- + `np.array` of `ser` broadcast to the appropriate shape for assignment + to the locations selected by `indexer` + """ + if isinstance(indexer, (slice, np.ndarray, list, Index)): + indexer = (indexer,) + + if isinstance(indexer, tuple): + # flatten np.ndarray indexers + if ( + len(indexer) == 2 + and isinstance(indexer[1], np.ndarray) + and indexer[1].dtype == np.bool_ + ): + indexer = (indexer[0], np.where(indexer[1])[0]) + + def ravel(i): + return i.ravel() if isinstance(i, np.ndarray) else i + + indexer = tuple(map(ravel, indexer)) + aligners = [not com.is_null_slice(idx) for idx in indexer] + sum_aligners = sum(aligners) + single_aligner = sum_aligners == 1 + is_frame = self.ndim == 2 + obj = self.obj + + # are we a single alignable value on a non-primary + # dim (e.g. panel: 1,2, or frame: 0) ? + # hence need to align to a single axis dimension + # rather that find all valid dims + + # frame + if is_frame: + single_aligner = single_aligner and aligners[0] + + # we have a frame, with multiple indexers on both axes; and a + # series, so need to broadcast (see GH5206) + if all(is_sequence(_) or isinstance(_, slice) for _ in indexer): + ser_values = ser.reindex(obj.axes[0][indexer[0]])._values + + # single indexer + if len(indexer) > 1 and not multiindex_indexer: + if isinstance(indexer[1], slice): + len_indexer = len(obj.axes[1][indexer[1]]) + else: + len_indexer = len(indexer[1]) + ser_values = ( + np.tile(ser_values, len_indexer).reshape(len_indexer, -1).T + ) + + return ser_values + + for i, idx in enumerate(indexer): + ax = obj.axes[i] + + # multiple aligners (or null slices) + if is_sequence(idx) or isinstance(idx, slice): + if single_aligner and com.is_null_slice(idx): + continue + new_ix = ax[idx] + if not is_list_like_indexer(new_ix): + new_ix = Index([new_ix]) + else: + new_ix = Index(new_ix) + if not len(new_ix) or ser.index.equals(new_ix): + if using_cow: + return ser + return ser._values.copy() + + return ser.reindex(new_ix)._values + + # 2 dims + elif single_aligner: + # reindex along index + ax = self.obj.axes[1] + if ser.index.equals(ax) or not len(ax): + return ser._values.copy() + return ser.reindex(ax)._values + + elif is_integer(indexer) and self.ndim == 1: + if is_object_dtype(self.obj.dtype): + return ser + ax = self.obj._get_axis(0) + + if ser.index.equals(ax): + return ser._values.copy() + + return ser.reindex(ax)._values[indexer] + + elif is_integer(indexer): + ax = self.obj._get_axis(1) + + if ser.index.equals(ax): + return ser._values.copy() + + return ser.reindex(ax)._values + + raise ValueError("Incompatible indexer with Series") + + def _align_frame(self, indexer, df: DataFrame) -> DataFrame: + is_frame = self.ndim == 2 + + if isinstance(indexer, tuple): + idx, cols = None, None + sindexers = [] + for i, ix in enumerate(indexer): + ax = self.obj.axes[i] + if is_sequence(ix) or isinstance(ix, slice): + if isinstance(ix, np.ndarray): + ix = ix.reshape(-1) + if idx is None: + idx = ax[ix] + elif cols is None: + cols = ax[ix] + else: + break + else: + sindexers.append(i) + + if idx is not None and cols is not None: + if df.index.equals(idx) and df.columns.equals(cols): + val = df.copy() + else: + val = df.reindex(idx, columns=cols) + return val + + elif (isinstance(indexer, slice) or is_list_like_indexer(indexer)) and is_frame: + ax = self.obj.index[indexer] + if df.index.equals(ax): + val = df.copy() + else: + # we have a multi-index and are trying to align + # with a particular, level GH3738 + if ( + isinstance(ax, MultiIndex) + and isinstance(df.index, MultiIndex) + and ax.nlevels != df.index.nlevels + ): + raise TypeError( + "cannot align on a multi-index with out " + "specifying the join levels" + ) + + val = df.reindex(index=ax) + return val + + raise ValueError("Incompatible indexer with DataFrame") + + +class _ScalarAccessIndexer(NDFrameIndexerBase): + """ + Access scalars quickly. + """ + + # sub-classes need to set _takeable + _takeable: bool + + def _convert_key(self, key): + raise AbstractMethodError(self) + + def __getitem__(self, key): + if not isinstance(key, tuple): + # we could have a convertible item here (e.g. Timestamp) + if not is_list_like_indexer(key): + key = (key,) + else: + raise ValueError("Invalid call for scalar access (getting)!") + + key = self._convert_key(key) + return self.obj._get_value(*key, takeable=self._takeable) + + def __setitem__(self, key, value) -> None: + if isinstance(key, tuple): + key = tuple(com.apply_if_callable(x, self.obj) for x in key) + else: + # scalar callable may return tuple + key = com.apply_if_callable(key, self.obj) + + if not isinstance(key, tuple): + key = _tuplify(self.ndim, key) + key = list(self._convert_key(key)) + if len(key) != self.ndim: + raise ValueError("Not enough indexers for scalar access (setting)!") + + self.obj._set_value(*key, value=value, takeable=self._takeable) + + +@doc(IndexingMixin.at) +class _AtIndexer(_ScalarAccessIndexer): + _takeable = False + + def _convert_key(self, key): + """ + Require they keys to be the same type as the index. (so we don't + fallback) + """ + # GH 26989 + # For series, unpacking key needs to result in the label. + # This is already the case for len(key) == 1; e.g. (1,) + if self.ndim == 1 and len(key) > 1: + key = (key,) + + return key + + @property + def _axes_are_unique(self) -> bool: + # Only relevant for self.ndim == 2 + assert self.ndim == 2 + return self.obj.index.is_unique and self.obj.columns.is_unique + + def __getitem__(self, key): + if self.ndim == 2 and not self._axes_are_unique: + # GH#33041 fall back to .loc + if not isinstance(key, tuple) or not all(is_scalar(x) for x in key): + raise ValueError("Invalid call for scalar access (getting)!") + return self.obj.loc[key] + + return super().__getitem__(key) + + def __setitem__(self, key, value) -> None: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount(self.obj) <= REF_COUNT_IDX: + warnings.warn( + _chained_assignment_msg, ChainedAssignmentError, stacklevel=2 + ) + + if self.ndim == 2 and not self._axes_are_unique: + # GH#33041 fall back to .loc + if not isinstance(key, tuple) or not all(is_scalar(x) for x in key): + raise ValueError("Invalid call for scalar access (setting)!") + + self.obj.loc[key] = value + return + + return super().__setitem__(key, value) + + +@doc(IndexingMixin.iat) +class _iAtIndexer(_ScalarAccessIndexer): + _takeable = True + + def _convert_key(self, key): + """ + Require integer args. (and convert to label arguments) + """ + for i in key: + if not is_integer(i): + raise ValueError("iAt based indexing can only have integer indexers") + return key + + def __setitem__(self, key, value) -> None: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount(self.obj) <= REF_COUNT_IDX: + warnings.warn( + _chained_assignment_msg, ChainedAssignmentError, stacklevel=2 + ) + + return super().__setitem__(key, value) + + +def _tuplify(ndim: int, loc: Hashable) -> tuple[Hashable | slice, ...]: + """ + Given an indexer for the first dimension, create an equivalent tuple + for indexing over all dimensions. + + Parameters + ---------- + ndim : int + loc : object + + Returns + ------- + tuple + """ + _tup: list[Hashable | slice] + _tup = [slice(None, None) for _ in range(ndim)] + _tup[0] = loc + return tuple(_tup) + + +def _tupleize_axis_indexer(ndim: int, axis: AxisInt, key) -> tuple: + """ + If we have an axis, adapt the given key to be axis-independent. + """ + new_key = [slice(None)] * ndim + new_key[axis] = key + return tuple(new_key) + + +def check_bool_indexer(index: Index, key) -> np.ndarray: + """ + Check if key is a valid boolean indexer for an object with such index and + perform reindexing or conversion if needed. + + This function assumes that is_bool_indexer(key) == True. + + Parameters + ---------- + index : Index + Index of the object on which the indexing is done. + key : list-like + Boolean indexer to check. + + Returns + ------- + np.array + Resulting key. + + Raises + ------ + IndexError + If the key does not have the same length as index. + IndexingError + If the index of the key is unalignable to index. + """ + result = key + if isinstance(key, ABCSeries) and not key.index.equals(index): + indexer = result.index.get_indexer_for(index) + if -1 in indexer: + raise IndexingError( + "Unalignable boolean Series provided as " + "indexer (index of the boolean Series and of " + "the indexed object do not match)." + ) + + result = result.take(indexer) + + # fall through for boolean + if not isinstance(result.dtype, ExtensionDtype): + return result.astype(bool)._values + + if is_object_dtype(key): + # key might be object-dtype bool, check_array_indexer needs bool array + result = np.asarray(result, dtype=bool) + elif not is_array_like(result): + # GH 33924 + # key may contain nan elements, check_array_indexer needs bool array + result = pd_array(result, dtype=bool) + return check_array_indexer(index, result) + + +def convert_missing_indexer(indexer): + """ + Reverse convert a missing indexer, which is a dict + return the scalar indexer and a boolean indicating if we converted + """ + if isinstance(indexer, dict): + # a missing key (but not a tuple indexer) + indexer = indexer["key"] + + if isinstance(indexer, bool): + raise KeyError("cannot use a single bool to index into setitem") + return indexer, True + + return indexer, False + + +def convert_from_missing_indexer_tuple(indexer: tuple, axes: list[Index]) -> tuple: + """ + Create a filtered indexer that doesn't have any missing indexers. + """ + + def get_indexer(_i, _idx): + return axes[_i].get_loc(_idx["key"]) if isinstance(_idx, dict) else _idx + + return tuple(get_indexer(_i, _idx) for _i, _idx in enumerate(indexer)) + + +def maybe_convert_ix(*args): + """ + We likely want to take the cross-product. + """ + for arg in args: + if not isinstance(arg, (np.ndarray, list, ABCSeries, Index)): + return args + return np.ix_(*args) + + +def is_nested_tuple(tup, labels) -> bool: + """ + Returns + ------- + bool + """ + # check for a compatible nested tuple and multiindexes among the axes + if not isinstance(tup, tuple): + return False + + for k in tup: + if is_list_like(k) or isinstance(k, slice): + return isinstance(labels, MultiIndex) + + return False + + +def is_label_like(key) -> bool: + """ + Returns + ------- + bool + """ + # select a label or row + return ( + not isinstance(key, slice) + and not is_list_like_indexer(key) + and key is not Ellipsis + ) + + +def need_slice(obj: slice) -> bool: + """ + Returns + ------- + bool + """ + return ( + obj.start is not None + or obj.stop is not None + or (obj.step is not None and obj.step != 1) + ) + + +def check_dict_or_set_indexers(key) -> None: + """ + Check if the indexer is or contains a dict or set, which is no longer allowed. + """ + if isinstance(key, set) or ( + isinstance(key, tuple) and any(isinstance(x, set) for x in key) + ): + raise TypeError( + "Passing a set as an indexer is not supported. Use a list instead." + ) + + if isinstance(key, dict) or ( + isinstance(key, tuple) and any(isinstance(x, dict) for x in key) + ): + raise TypeError( + "Passing a dict as an indexer is not supported. Use a list instead." + ) diff --git a/pandas/core/missing.py b/pandas/core/missing.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d789e9cb263205e639fcf7713e641c4d36dad5 --- /dev/null +++ b/pandas/core/missing.py @@ -0,0 +1,1103 @@ +""" +Routines for filling missing data. +""" + +from __future__ import annotations + +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Literal, + cast, + overload, +) + +import numpy as np + +from pandas._config import is_nan_na + +from pandas._libs import ( + NaT, + algos, + lib, +) +from pandas._typing import ( + ArrayLike, + AxisInt, + F, + ReindexMethod, + npt, +) +from pandas.compat._optional import import_optional_dependency + +from pandas.core.dtypes.cast import infer_dtype_from +from pandas.core.dtypes.common import ( + is_array_like, + is_bool_dtype, + is_numeric_dtype, + is_object_dtype, + needs_i8_conversion, +) +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + BaseMaskedDtype, + DatetimeTZDtype, +) +from pandas.core.dtypes.missing import ( + is_valid_na_for_dtype, + isna, + na_value_for_dtype, +) + +if TYPE_CHECKING: + from collections.abc import Callable + from typing import TypeAlias + + from pandas import Index + + _CubicBC: TypeAlias = Literal["not-a-knot", "clamped", "natural", "periodic"] + + +def check_value_size(value, mask: npt.NDArray[np.bool_], length: int): + """ + Validate the size of the values passed to ExtensionArray.fillna. + """ + if is_array_like(value): + if len(value) != length: + raise ValueError( + f"Length of 'value' does not match. Got ({len(value)}) " + f" expected {length}" + ) + value = value[mask] + + return value + + +def mask_missing(arr: ArrayLike, value) -> npt.NDArray[np.bool_]: + """ + Return a masking array of same size/shape as arr + with entries equaling value set to True. + + Parameters + ---------- + arr : ArrayLike + value : scalar-like + Caller has ensured `not is_list_like(value)` and that it can be held + by `arr`. + + Returns + ------- + np.ndarray[bool] + """ + dtype, value = infer_dtype_from(value) + + if ( + isinstance(arr.dtype, (BaseMaskedDtype, ArrowDtype)) + and lib.is_float(value) + and np.isnan(value) + and not is_nan_na() + ): + # TODO: this should be done in an EA method? + if arr.dtype.kind == "f": + # GH#55127 + if isinstance(arr.dtype, BaseMaskedDtype): + # error: "ExtensionArray" has no attribute "_data" [attr-defined] + mask = np.isnan(arr._data) & ~arr.isna() # type: ignore[attr-defined,operator] + return mask + else: + # error: "ExtensionArray" has no attribute "_pa_array" [attr-defined] + import pyarrow.compute as pc + + mask = pc.is_nan(arr._pa_array).fill_null(False).to_numpy() # type: ignore[attr-defined] + return mask + + elif arr.dtype.kind in "iu": + # GH#51237 + mask = np.zeros(arr.shape, dtype=bool) + return mask + + if isna(value): + return isna(arr) + + # GH 21977 + mask = np.zeros(arr.shape, dtype=bool) + if ( + is_numeric_dtype(arr.dtype) + and not is_bool_dtype(arr.dtype) + and lib.is_bool(value) + ): + # e.g. test_replace_ea_float_with_bool, see GH#62048 + pass + elif ( + is_bool_dtype(arr.dtype) and is_numeric_dtype(dtype) and not lib.is_bool(value) + ): + # e.g. test_replace_ea_float_with_bool, see GH#62048 + pass + elif is_numeric_dtype(arr.dtype) and isinstance(value, str): + # GH#29553 prevent numpy deprecation warnings + pass + elif is_object_dtype(arr.dtype): + # pre-compute mask to avoid comparison to NA + # e.g. test_replace_na_in_obj_column + arr_mask = ~isna(arr) + mask[arr_mask] = arr[arr_mask] == value + else: + new_mask = arr == value + + if not isinstance(new_mask, np.ndarray): + # usually BooleanArray + new_mask = new_mask.to_numpy(dtype=bool, na_value=False) + mask = new_mask + + return mask + + +@overload +def clean_fill_method( + method: Literal["ffill", "pad", "bfill", "backfill"], + *, + allow_nearest: Literal[False] = ..., +) -> Literal["pad", "backfill"]: ... + + +@overload +def clean_fill_method( + method: Literal["ffill", "pad", "bfill", "backfill", "nearest"], + *, + allow_nearest: Literal[True], +) -> Literal["pad", "backfill", "nearest"]: ... + + +def clean_fill_method( + method: Literal["ffill", "pad", "bfill", "backfill", "nearest"], + *, + allow_nearest: bool = False, +) -> Literal["pad", "backfill", "nearest"]: + if isinstance(method, str): + # error: Incompatible types in assignment (expression has type "str", variable + # has type "Literal['ffill', 'pad', 'bfill', 'backfill', 'nearest']") + method = method.lower() # type: ignore[assignment] + if method == "ffill": + method = "pad" + elif method == "bfill": + method = "backfill" + + valid_methods = ["pad", "backfill"] + expecting = "pad (ffill) or backfill (bfill)" + if allow_nearest: + valid_methods.append("nearest") + expecting = "pad (ffill), backfill (bfill) or nearest" + if method not in valid_methods: + raise ValueError(f"Invalid fill method. Expecting {expecting}. Got {method}") + return method + + +# interpolation methods that dispatch to np.interp + +NP_METHODS = ["linear", "time", "index", "values"] + +# interpolation methods that dispatch to _interpolate_scipy_wrapper + +SP_METHODS = [ + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + "barycentric", + "krogh", + "spline", + "polynomial", + "from_derivatives", + "piecewise_polynomial", + "pchip", + "akima", + "cubicspline", +] + + +def clean_interp_method(method: str, index: Index, **kwargs) -> str: + order = kwargs.get("order") + + if method in ("spline", "polynomial") and order is None: + raise ValueError("You must specify the order of the spline or polynomial.") + + valid = NP_METHODS + SP_METHODS + if method not in valid: + raise ValueError(f"method must be one of {valid}. Got '{method}' instead.") + + if method in ("krogh", "piecewise_polynomial", "pchip"): + if not index.is_monotonic_increasing: + raise ValueError( + f"{method} interpolation requires that the index be monotonic." + ) + + return method + + +def find_valid_index(how: str, is_valid: npt.NDArray[np.bool_]) -> int | None: + """ + Retrieves the positional index of the first valid value. + + Parameters + ---------- + how : {'first', 'last'} + Use this parameter to change between the first or last valid index. + is_valid: np.ndarray + Mask to find na_values. + + Returns + ------- + int or None + """ + assert how in ["first", "last"] + + if len(is_valid) == 0: # early stop + return None + + if is_valid.ndim == 2: + # reduce axis 1 + is_valid = is_valid.any(axis=1) # type: ignore[assignment] + + if how == "first": + idxpos = is_valid[::].argmax() + + elif how == "last": + idxpos = len(is_valid) - 1 - is_valid[::-1].argmax() + + chk_notna = is_valid[idxpos] + + if not chk_notna: + return None + # Incompatible return value type (got "signedinteger[Any]", + # expected "Optional[int]") + return idxpos # type: ignore[return-value] + + +def validate_limit_direction( + limit_direction: str, +) -> Literal["forward", "backward", "both"]: + valid_limit_directions = ["forward", "backward", "both"] + limit_direction = limit_direction.lower() + if limit_direction not in valid_limit_directions: + raise ValueError( + "Invalid limit_direction: expecting one of " + f"{valid_limit_directions}, got '{limit_direction}'." + ) + # error: Incompatible return value type (got "str", expected + # "Literal['forward', 'backward', 'both']") + return limit_direction # type: ignore[return-value] + + +def validate_limit_area(limit_area: str | None) -> Literal["inside", "outside"] | None: + if limit_area is not None: + valid_limit_areas = ["inside", "outside"] + limit_area = limit_area.lower() + if limit_area not in valid_limit_areas: + raise ValueError( + f"Invalid limit_area: expecting one of {valid_limit_areas}, got " + f"{limit_area}." + ) + # error: Incompatible return value type (got "Optional[str]", expected + # "Optional[Literal['inside', 'outside']]") + return limit_area # type: ignore[return-value] + + +def infer_limit_direction( + limit_direction: Literal["backward", "forward", "both"] | None, method: str +) -> Literal["backward", "forward", "both"]: + # Set `limit_direction` depending on `method` + if limit_direction is None: + if method in ("backfill", "bfill"): + limit_direction = "backward" + else: + limit_direction = "forward" + else: + if method in ("pad", "ffill") and limit_direction != "forward": + raise ValueError( + f"`limit_direction` must be 'forward' for method `{method}`" + ) + if method in ("backfill", "bfill") and limit_direction != "backward": + raise ValueError( + f"`limit_direction` must be 'backward' for method `{method}`" + ) + return limit_direction + + +def get_interp_index(method, index: Index) -> Index: + # create/use the index + if method == "linear": + # prior default + from pandas import RangeIndex + + index = RangeIndex(len(index)) + else: + methods = {"index", "values", "nearest", "time"} + is_numeric_or_datetime = ( + is_numeric_dtype(index.dtype) + or isinstance(index.dtype, DatetimeTZDtype) + or lib.is_np_dtype(index.dtype, "mM") + ) + valid = NP_METHODS + SP_METHODS + if method in valid: + if method not in methods and not is_numeric_or_datetime: + raise ValueError( + "Index column must be numeric or datetime type when " + f"using {method} method other than linear. " + "Try setting a numeric or datetime index column before " + "interpolating." + ) + else: + raise ValueError(f"Can not interpolate with method={method}.") + + if isna(index).any(): + raise NotImplementedError( + "Interpolation with NaNs in the index " + "has not been implemented. Try filling " + "those NaNs before interpolating." + ) + return index + + +def interpolate_2d_inplace( + data: np.ndarray, # floating dtype + index: Index, + axis: AxisInt, + method: str = "linear", + limit: int | None = None, + limit_direction: str = "forward", + limit_area: str | None = None, + fill_value: Any | None = None, + mask=None, + **kwargs, +) -> None: + """ + Column-wise application of _interpolate_1d. + + Notes + ----- + Alters 'data' in-place. + + The signature does differ from _interpolate_1d because it only + includes what is needed for Block.interpolate. + """ + # validate the interp method + clean_interp_method(method, index, **kwargs) + + if is_valid_na_for_dtype(fill_value, data.dtype): + fill_value = na_value_for_dtype(data.dtype, compat=False) + + if method == "time": + if not needs_i8_conversion(index.dtype): + raise ValueError( + "time-weighted interpolation only works " + "on Series or DataFrames with a " + "DatetimeIndex" + ) + method = "values" + + limit_direction = validate_limit_direction(limit_direction) + limit_area_validated = validate_limit_area(limit_area) + + # default limit is unlimited GH #16282 + limit = algos.validate_limit(nobs=None, limit=limit) + + indices = _index_to_interp_indices(index, method) + + def func(yvalues: np.ndarray) -> None: + # process 1-d slices in the axis direction + + _interpolate_1d( + indices=indices, + yvalues=yvalues, + method=method, + limit=limit, + limit_direction=limit_direction, + limit_area=limit_area_validated, + fill_value=fill_value, + bounds_error=False, + mask=mask, + **kwargs, + ) + + np.apply_along_axis(func, axis, data) + + +def _index_to_interp_indices(index: Index, method: str) -> np.ndarray: + """ + Convert Index to ndarray of indices to pass to NumPy/SciPy. + """ + xarr = index._values + if needs_i8_conversion(xarr.dtype): + # GH#1646 for dt64tz + xarr = xarr.view("i8") + + if method == "linear": + inds = xarr + inds = cast(np.ndarray, inds) + else: + inds = np.asarray(xarr) + + if method in ("values", "index"): + if inds.dtype == np.object_: + inds = lib.maybe_convert_objects(inds) + + return inds + + +def _interpolate_1d( + indices: np.ndarray, + yvalues: np.ndarray, + method: str = "linear", + limit: int | None = None, + limit_direction: str = "forward", + limit_area: Literal["inside", "outside"] | None = None, + fill_value: Any | None = None, + bounds_error: bool = False, + order: int | None = None, + mask=None, + **kwargs, +) -> None: + """ + Logic for the 1-d interpolation. The input + indices and yvalues will each be 1-d arrays of the same length. + + Bounds_error is currently hardcoded to False since non-scipy ones don't + take it as an argument. + + Notes + ----- + Fills 'yvalues' in-place. + """ + if mask is not None: + invalid = mask + else: + invalid = isna(yvalues) + valid = ~invalid + + if not valid.any(): + return + + if valid.all(): + return + + # These index pointers to invalid values... i.e. {0, 1, etc... + all_nans = np.flatnonzero(invalid) + + first_valid_index = find_valid_index(how="first", is_valid=valid) + if first_valid_index is None: # no nan found in start + first_valid_index = 0 + start_nans = np.arange(first_valid_index) + + last_valid_index = find_valid_index(how="last", is_valid=valid) + if last_valid_index is None: # no nan found in end + last_valid_index = len(yvalues) + end_nans = np.arange(1 + last_valid_index, len(valid)) + + # preserve_nans contains indices of invalid values, + # but in this case, it is the final set of indices that need to be + # preserved as NaN after the interpolation. + + # For example if limit_direction='forward' then preserve_nans will + # contain indices of NaNs at the beginning of the series, and NaNs that + # are more than 'limit' away from the prior non-NaN. + + # set preserve_nans based on direction using _interp_limit + if limit_direction == "forward": + preserve_nans = np.union1d(start_nans, _interp_limit(invalid, limit, 0)) + elif limit_direction == "backward": + preserve_nans = np.union1d(end_nans, _interp_limit(invalid, 0, limit)) + else: + # both directions... just use _interp_limit + preserve_nans = np.unique(_interp_limit(invalid, limit, limit)) + + # if limit_area is set, add either mid or outside indices + # to preserve_nans GH #16284 + if limit_area == "inside": + # preserve NaNs on the outside + preserve_nans = np.union1d(preserve_nans, start_nans) + preserve_nans = np.union1d(preserve_nans, end_nans) + elif limit_area == "outside": + # preserve NaNs on the inside + mid_nans = np.setdiff1d(all_nans, start_nans, assume_unique=True) + mid_nans = np.setdiff1d(mid_nans, end_nans, assume_unique=True) + preserve_nans = np.union1d(preserve_nans, mid_nans) + + is_datetimelike = yvalues.dtype.kind in "mM" + + if is_datetimelike: + yvalues = yvalues.view("i8") + + if method in NP_METHODS: + # np.interp requires sorted X values, #21037 + + indexer = np.argsort(indices[valid]) + yvalues[invalid] = np.interp( + indices[invalid], indices[valid][indexer], yvalues[valid][indexer] + ) + else: + yvalues[invalid] = _interpolate_scipy_wrapper( + indices[valid], + yvalues[valid], + indices[invalid], + method=method, + fill_value=fill_value, + bounds_error=bounds_error, + order=order, + **kwargs, + ) + + if mask is not None: + mask[:] = False + mask[preserve_nans] = True + elif is_datetimelike: + yvalues[preserve_nans] = NaT.value + else: + yvalues[preserve_nans] = np.nan + return + + +def _interpolate_scipy_wrapper( + x: np.ndarray, + y: np.ndarray, + new_x: np.ndarray, + method: str, + fill_value=None, + bounds_error: bool = False, + order=None, + **kwargs, +): + """ + Passed off to scipy.interpolate.interp1d. method is scipy's kind. + Returns an array interpolated at new_x. Add any new methods to + the list in _clean_interp_method. + """ + extra = f"{method} interpolation requires SciPy." + import_optional_dependency("scipy", extra=extra) + from scipy import interpolate + + new_x = np.asarray(new_x) + + # ignores some kwargs that could be passed along. + alt_methods: dict[str, Callable[..., np.ndarray]] = { + "barycentric": interpolate.barycentric_interpolate, + "krogh": interpolate.krogh_interpolate, + "from_derivatives": _from_derivatives, + "piecewise_polynomial": _from_derivatives, + "cubicspline": _cubicspline_interpolate, + "akima": _akima_interpolate, + "pchip": interpolate.pchip_interpolate, + } + + interp1d_methods = [ + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + "polynomial", + ] + terp: Callable[..., np.ndarray] | None + if method in interp1d_methods: + if method == "polynomial": + kind = order + else: + kind = method + terp = interpolate.interp1d( + x, y, kind=kind, fill_value=fill_value, bounds_error=bounds_error + ) + new_y = terp(new_x) + elif method == "spline": + # GH #10633, #24014 + if isna(order) or (order <= 0): + raise ValueError( + f"order needs to be specified and greater than 0; got order: {order}" + ) + terp = interpolate.UnivariateSpline(x, y, k=order, **kwargs) + new_y = terp(new_x) + else: + # GH 7295: need to be able to write for some reason + # in some circumstances: check all three + if not x.flags.writeable: + x = x.copy() + if not y.flags.writeable: + y = y.copy() + if not new_x.flags.writeable: + new_x = new_x.copy() + terp = alt_methods.get(method, None) + if terp is None: + raise ValueError(f"Can not interpolate with method={method}.") + + # Make sure downcast is not in kwargs for alt methods + kwargs.pop("downcast", None) + new_y = terp(x, y, new_x, **kwargs) + return new_y + + +def _from_derivatives( + xi: np.ndarray, + yi: np.ndarray, + x: np.ndarray, + order=None, + der: int | list[int] | None = 0, + extrapolate: bool = False, +): + """ + Convenience function for interpolate.BPoly.from_derivatives. + + Construct a piecewise polynomial in the Bernstein basis, compatible + with the specified values and derivatives at breakpoints. + + Parameters + ---------- + xi : array-like + sorted 1D array of x-coordinates + yi : array-like or list of array-likes + yi[i][j] is the j-th derivative known at xi[i] + order: None or int or array-like of ints. Default: None. + Specifies the degree of local polynomials. If not None, some + derivatives are ignored. + der : int or list + How many derivatives to extract; None for all potentially nonzero + derivatives (that is a number equal to the number of points), or a + list of derivatives to extract. This number includes the function + value as 0th derivative. + extrapolate : bool, optional + Whether to extrapolate to ouf-of-bounds points based on first and last + intervals, or to return NaNs. Default: True. + + See Also + -------- + scipy.interpolate.BPoly.from_derivatives + + Returns + ------- + y : scalar or array-like + The result, of length R or length M or M by R. + """ + from scipy import interpolate + + # return the method for compat with scipy version & backwards compat + method = interpolate.BPoly.from_derivatives + m = method(xi, yi.reshape(-1, 1), orders=order, extrapolate=extrapolate) + + return m(x) + + +def _akima_interpolate( + xi: np.ndarray, + yi: np.ndarray, + x: np.ndarray, + der: int = 0, + axis: AxisInt = 0, +): + """ + Convenience function for akima interpolation. + xi and yi are arrays of values used to approximate some function f, + with ``yi = f(xi)``. + + See `Akima1DInterpolator` for details. + + Parameters + ---------- + xi : np.ndarray + A sorted list of x-coordinates, of length N. + yi : np.ndarray + A 1-D array of real values. `yi`'s length along the interpolation + axis must be equal to the length of `xi`. If N-D array, use axis + parameter to select correct axis. + x : np.ndarray + Of length M. + der : int, optional + How many derivatives to extract. This number includes the function + value as 0th derivative. + axis : int, optional + Axis in the yi array corresponding to the x-coordinate values. + + See Also + -------- + scipy.interpolate.Akima1DInterpolator + + Returns + ------- + y : scalar or array-like + The result, of length R or length M or M by R, + + """ + from scipy import interpolate + + P = interpolate.Akima1DInterpolator(xi, yi, axis=axis) + + return P(x, nu=der) + + +def _cubicspline_interpolate( + xi: np.ndarray, + yi: np.ndarray, + x: np.ndarray, + axis: AxisInt = 0, + bc_type: _CubicBC | tuple[Any, Any] = "not-a-knot", + extrapolate: Literal["periodic"] | bool | None = None, +) -> np.ndarray: + """ + Convenience function for cubic spline data interpolator. + + See `scipy.interpolate.CubicSpline` for details. + + Parameters + ---------- + xi : np.ndarray, shape (n,) + 1-d array containing values of the independent variable. + Values must be real, finite and in strictly increasing order. + yi : np.ndarray + Array containing values of the dependent variable. It can have + arbitrary number of dimensions, but the length along ``axis`` + (see below) must match the length of ``x``. Values must be finite. + x : np.ndarray, shape (m,) + axis : int, optional + Axis along which `y` is assumed to be varying. Meaning that for + ``x[i]`` the corresponding values are ``np.take(y, i, axis=axis)``. + Default is 0. + bc_type : string or 2-tuple, optional + Boundary condition type. Two additional equations, given by the + boundary conditions, are required to determine all coefficients of + polynomials on each segment [2]_. + If `bc_type` is a string, then the specified condition will be applied + at both ends of a spline. Available conditions are: + * 'not-a-knot' (default): The first and second segment at a curve end + are the same polynomial. It is a good default when there is no + information on boundary conditions. + * 'periodic': The interpolated functions is assumed to be periodic + of period ``x[-1] - x[0]``. The first and last value of `y` must be + identical: ``y[0] == y[-1]``. This boundary condition will result in + ``y'[0] == y'[-1]`` and ``y''[0] == y''[-1]``. + * 'clamped': The first derivative at curves ends are zero. Assuming + a 1D `y`, ``bc_type=((1, 0.0), (1, 0.0))`` is the same condition. + * 'natural': The second derivative at curve ends are zero. Assuming + a 1D `y`, ``bc_type=((2, 0.0), (2, 0.0))`` is the same condition. + If `bc_type` is a 2-tuple, the first and the second value will be + applied at the curve start and end respectively. The tuple values can + be one of the previously mentioned strings (except 'periodic') or a + tuple `(order, deriv_values)` allowing to specify arbitrary + derivatives at curve ends: + * `order`: the derivative order, 1 or 2. + * `deriv_value`: array-like containing derivative values, shape must + be the same as `y`, excluding ``axis`` dimension. For example, if + `y` is 1D, then `deriv_value` must be a scalar. If `y` is 3D with + the shape (n0, n1, n2) and axis=2, then `deriv_value` must be 2D + and have the shape (n0, n1). + extrapolate : {bool, 'periodic', None}, optional + If bool, determines whether to extrapolate to out-of-bounds points + based on first and last intervals, or to return NaNs. If 'periodic', + periodic extrapolation is used. If None (default), ``extrapolate`` is + set to 'periodic' for ``bc_type='periodic'`` and to True otherwise. + + See Also + -------- + scipy.interpolate.CubicHermiteSpline + + Returns + ------- + y : scalar or array-like + The result, of shape (m,) + + References + ---------- + .. [1] `Cubic Spline Interpolation + `_ + on Wikiversity. + .. [2] Carl de Boor, "A Practical Guide to Splines", Springer-Verlag, 1978. + """ + from scipy import interpolate + + P = interpolate.CubicSpline( + xi, yi, axis=axis, bc_type=bc_type, extrapolate=extrapolate + ) + + return P(x) + + +def pad_or_backfill_inplace( + values: np.ndarray, + method: Literal["pad", "backfill"] = "pad", + axis: AxisInt = 0, + limit: int | None = None, + limit_area: Literal["inside", "outside"] | None = None, +) -> None: + """ + Perform an actual interpolation of values, values will be make 2-d if + needed fills inplace, returns the result. + + Parameters + ---------- + values: np.ndarray + Input array. + method: str, default "pad" + Interpolation method. Could be "bfill" or "pad" + axis: 0 or 1 + Interpolation axis + limit: int, optional + Index limit on interpolation. + limit_area: str, optional + Limit area for interpolation. Can be "inside" or "outside" + + Notes + ----- + Modifies values in-place. + """ + transf = (lambda x: x) if axis == 0 else (lambda x: x.T) + + # reshape a 1 dim if needed + if values.ndim == 1: + if axis != 0: # pragma: no cover + raise AssertionError("cannot interpolate on an ndim == 1 with axis != 0") + values = values.reshape((1, *values.shape)) + + method = clean_fill_method(method) + tvalues = transf(values) + + func = get_fill_func(method, ndim=2) + # _pad_2d and _backfill_2d both modify tvalues inplace + func(tvalues, limit=limit, limit_area=limit_area) + + +def _fillna_prep( + values, mask: npt.NDArray[np.bool_] | None = None +) -> npt.NDArray[np.bool_]: + # boilerplate for _pad_1d, _backfill_1d, _pad_2d, _backfill_2d + + if mask is None: + mask = isna(values) + + return mask + + +def _datetimelike_compat(func: F) -> F: + """ + Wrapper to handle datetime64 and timedelta64 dtypes. + """ + + @wraps(func) + def new_func( + values, + limit: int | None = None, + limit_area: Literal["inside", "outside"] | None = None, + mask=None, + ): + if needs_i8_conversion(values.dtype): + if mask is None: + # This needs to occur before casting to int64 + mask = isna(values) + + result, mask = func( + values.view("i8"), limit=limit, limit_area=limit_area, mask=mask + ) + return result.view(values.dtype), mask + + return func(values, limit=limit, limit_area=limit_area, mask=mask) + + return cast(F, new_func) + + +@_datetimelike_compat +def _pad_1d( + values: np.ndarray, + limit: int | None = None, + limit_area: Literal["inside", "outside"] | None = None, + mask: npt.NDArray[np.bool_] | None = None, +) -> tuple[np.ndarray, npt.NDArray[np.bool_]]: + mask = _fillna_prep(values, mask) + if limit_area is not None and not mask.all(): + _fill_limit_area_1d(mask, limit_area) + algos.pad_inplace(values, mask, limit=limit) + return values, mask + + +@_datetimelike_compat +def _backfill_1d( + values: np.ndarray, + limit: int | None = None, + limit_area: Literal["inside", "outside"] | None = None, + mask: npt.NDArray[np.bool_] | None = None, +) -> tuple[np.ndarray, npt.NDArray[np.bool_]]: + mask = _fillna_prep(values, mask) + if limit_area is not None and not mask.all(): + _fill_limit_area_1d(mask, limit_area) + algos.backfill_inplace(values, mask, limit=limit) + return values, mask + + +@_datetimelike_compat +def _pad_2d( + values: np.ndarray, + limit: int | None = None, + limit_area: Literal["inside", "outside"] | None = None, + mask: npt.NDArray[np.bool_] | None = None, +) -> tuple[np.ndarray, npt.NDArray[np.bool_]]: + mask = _fillna_prep(values, mask) + if limit_area is not None: + _fill_limit_area_2d(mask, limit_area) + + if values.size: + algos.pad_2d_inplace(values, mask, limit=limit) + return values, mask + + +@_datetimelike_compat +def _backfill_2d( + values, + limit: int | None = None, + limit_area: Literal["inside", "outside"] | None = None, + mask: npt.NDArray[np.bool_] | None = None, +): + mask = _fillna_prep(values, mask) + if limit_area is not None: + _fill_limit_area_2d(mask, limit_area) + + if values.size: + algos.backfill_2d_inplace(values, mask, limit=limit) + else: + # for test coverage + pass + return values, mask + + +def _fill_limit_area_1d( + mask: npt.NDArray[np.bool_], limit_area: Literal["outside", "inside"] +) -> None: + """Prepare 1d mask for ffill/bfill with limit_area. + + Caller is responsible for checking at least one value of mask is False. + When called, mask will no longer faithfully represent when + the corresponding are NA or not. + + Parameters + ---------- + mask : np.ndarray[bool, ndim=1] + Mask representing NA values when filling. + limit_area : { "outside", "inside" } + Whether to limit filling to outside or inside the outer most non-NA value. + """ + neg_mask = ~mask + first = neg_mask.argmax() + last = len(neg_mask) - neg_mask[::-1].argmax() - 1 + if limit_area == "inside": + mask[:first] = False + mask[last + 1 :] = False + elif limit_area == "outside": + mask[first + 1 : last] = False + + +def _fill_limit_area_2d( + mask: npt.NDArray[np.bool_], limit_area: Literal["outside", "inside"] +) -> None: + """Prepare 2d mask for ffill/bfill with limit_area. + + When called, mask will no longer faithfully represent when + the corresponding are NA or not. + + Parameters + ---------- + mask : np.ndarray[bool, ndim=1] + Mask representing NA values when filling. + limit_area : { "outside", "inside" } + Whether to limit filling to outside or inside the outer most non-NA value. + """ + neg_mask = ~mask.T + if limit_area == "outside": + # Identify inside + la_mask = ( + np.maximum.accumulate(neg_mask, axis=0) + & np.maximum.accumulate(neg_mask[::-1], axis=0)[::-1] + ) + else: + # Identify outside + la_mask = ( + ~np.maximum.accumulate(neg_mask, axis=0) + | ~np.maximum.accumulate(neg_mask[::-1], axis=0)[::-1] + ) + mask[la_mask.T] = False + + +_fill_methods = {"pad": _pad_1d, "backfill": _backfill_1d} + + +def get_fill_func(method, ndim: int = 1): + method = clean_fill_method(method) + if ndim == 1: + return _fill_methods[method] + return {"pad": _pad_2d, "backfill": _backfill_2d}[method] + + +def clean_reindex_fill_method(method) -> ReindexMethod | None: + if method is None: + return None + return clean_fill_method(method, allow_nearest=True) + + +def _interp_limit( + invalid: npt.NDArray[np.bool_], fw_limit: int | None, bw_limit: int | None +) -> np.ndarray: + """ + Get indexers of values that won't be filled + because they exceed the limits. + + Parameters + ---------- + invalid : np.ndarray[bool] + fw_limit : int or None + forward limit to index + bw_limit : int or None + backward limit to index + + Returns + ------- + set of indexers + + Notes + ----- + This is equivalent to the more readable, but slower + + .. code-block:: python + + def _interp_limit(invalid, fw_limit, bw_limit): + for x in np.where(invalid)[0]: + if invalid[max(0, x - fw_limit) : x + bw_limit + 1].all(): + yield x + """ + # handle forward first; the backward direction is the same except + # 1. operate on the reversed array + # 2. subtract the returned indices from N - 1 + N = len(invalid) + f_idx = np.array([], dtype=np.int64) + b_idx = np.array([], dtype=np.int64) + assume_unique = True + + def inner(invalid, limit: int): + limit = min(limit, N) + windowed = np.lib.stride_tricks.sliding_window_view(invalid, limit + 1).all(1) + idx = np.union1d( + np.where(windowed)[0] + limit, + np.where((~invalid[: limit + 1]).cumsum() == 0)[0], + ) + return idx + + if fw_limit is not None: + if fw_limit == 0: + f_idx = np.where(invalid)[0] + assume_unique = False + else: + f_idx = inner(invalid, fw_limit) + + if bw_limit is not None: + if bw_limit == 0: + # then we don't even need to care about backwards + # just use forwards + return f_idx + else: + b_idx = N - 1 - inner(invalid[::-1], bw_limit) + if fw_limit == 0: + return b_idx + + return np.intersect1d(f_idx, b_idx, assume_unique=assume_unique) diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py new file mode 100644 index 0000000000000000000000000000000000000000..9743ce10758441e54b62cfe91d6fce6479f4be66 --- /dev/null +++ b/pandas/core/nanops.py @@ -0,0 +1,1777 @@ +from __future__ import annotations + +import functools +import itertools +from typing import ( + TYPE_CHECKING, + Any, + cast, +) +import warnings + +import numpy as np + +from pandas._config import get_option + +from pandas._libs import ( + NaT, + NaTType, + iNaT, + lib, +) +from pandas._typing import ( + ArrayLike, + AxisInt, + CorrelationMethod, + Dtype, + DtypeObj, + F, + Scalar, + Shape, + npt, +) +from pandas.compat._optional import import_optional_dependency + +from pandas.core.dtypes.common import ( + is_complex, + is_float, + is_float_dtype, + is_integer, + is_numeric_dtype, + is_object_dtype, + needs_i8_conversion, + pandas_dtype, +) +from pandas.core.dtypes.missing import ( + isna, + na_value_for_dtype, + notna, +) + +if TYPE_CHECKING: + from collections.abc import Callable + +bn = import_optional_dependency("bottleneck", errors="warn") +_BOTTLENECK_INSTALLED = bn is not None +_USE_BOTTLENECK = False + + +def set_use_bottleneck(v: bool = True) -> None: + # set/unset to use bottleneck + global _USE_BOTTLENECK + if _BOTTLENECK_INSTALLED: + _USE_BOTTLENECK = v + + +set_use_bottleneck(get_option("compute.use_bottleneck")) + + +class disallow: + def __init__(self, *dtypes: Dtype) -> None: + super().__init__() + self.dtypes = tuple(pandas_dtype(dtype).type for dtype in dtypes) + + def check(self, obj) -> bool: + return hasattr(obj, "dtype") and issubclass(obj.dtype.type, self.dtypes) + + def __call__(self, f: F) -> F: + @functools.wraps(f) + def _f(*args, **kwargs): + obj_iter = itertools.chain(args, kwargs.values()) + if any(self.check(obj) for obj in obj_iter): + f_name = f.__name__.replace("nan", "") + raise TypeError( + f"reduction operation '{f_name}' not allowed for this dtype" + ) + try: + return f(*args, **kwargs) + except ValueError as e: + # we want to transform an object array + # ValueError message to the more typical TypeError + # e.g. this is normally a disallowed function on + # object arrays that contain strings + if is_object_dtype(args[0]): + raise TypeError(e) from e + raise + + return cast(F, _f) + + +class bottleneck_switch: + def __init__(self, name=None, **kwargs) -> None: + self.name = name + self.kwargs = kwargs + + def __call__(self, alt: F) -> F: + bn_name = self.name or alt.__name__ + + try: + bn_func = getattr(bn, bn_name) + except (AttributeError, NameError): # pragma: no cover + bn_func = None + + @functools.wraps(alt) + def f( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + **kwds, + ): + if len(self.kwargs) > 0: + for k, v in self.kwargs.items(): + if k not in kwds: + kwds[k] = v + + if values.size == 0 and kwds.get("min_count") is None: + # We are empty, returning NA for our type + # Only applies for the default `min_count` of None + # since that affects how empty arrays are handled. + # TODO(GH-18976) update all the nanops methods to + # correctly handle empty inputs and remove this check. + # It *may* just be `var` + return _na_for_min_count(values, axis) + + if _USE_BOTTLENECK and skipna and _bn_ok_dtype(values.dtype, bn_name): + if kwds.get("mask", None) is None: + # `mask` is not recognised by bottleneck, would raise + # TypeError if called + kwds.pop("mask", None) + result = bn_func(values, axis=axis, **kwds) + + # prefer to treat inf/-inf as NA, but must compute the func + # twice :( + if _has_infs(result): + result = alt(values, axis=axis, skipna=skipna, **kwds) + else: + result = alt(values, axis=axis, skipna=skipna, **kwds) + else: + result = alt(values, axis=axis, skipna=skipna, **kwds) + + return result + + return cast(F, f) + + +def _bn_ok_dtype(dtype: DtypeObj, name: str) -> bool: + # Bottleneck chokes on datetime64, PeriodDtype (or and EA) + if dtype != object and not needs_i8_conversion(dtype): + # GH 42878 + # Bottleneck uses naive summation leading to O(n) loss of precision + # unlike numpy which implements pairwise summation, which has O(log(n)) loss + # crossref: https://github.com/pydata/bottleneck/issues/379 + + # GH 15507 + # bottleneck does not properly upcast during the sum + # so can overflow + + # GH 9422 + # further we also want to preserve NaN when all elements + # are NaN, unlike bottleneck/numpy which consider this + # to be 0 + return name not in ["nansum", "nanprod", "nanmean"] + return False + + +def _has_infs(result) -> bool: + if isinstance(result, np.ndarray): + if result.dtype in ("f8", "f4"): + # Note: outside of a nanops-specific test, we always have + # result.ndim == 1, so there is no risk of this ravel making a copy. + return lib.has_infs(result.ravel("K")) + try: + return np.isinf(result).any() + except (TypeError, NotImplementedError): + # if it doesn't support infs, then it can't have infs + return False + + +def _get_fill_value( + dtype: DtypeObj, fill_value: Scalar | None = None, fill_value_typ=None +): + """return the correct fill value for the dtype of the values""" + if fill_value is not None: + return fill_value + if _na_ok_dtype(dtype): + if fill_value_typ is None: + return np.nan + elif fill_value_typ == "+inf": + return np.inf + else: + return -np.inf + elif fill_value_typ == "+inf": + # need the max int here + return lib.i8max + else: + return iNaT + + +def _maybe_get_mask( + values: np.ndarray, skipna: bool, mask: npt.NDArray[np.bool_] | None +) -> npt.NDArray[np.bool_] | None: + """ + Compute a mask if and only if necessary. + + This function will compute a mask iff it is necessary. Otherwise, + return the provided mask (potentially None) when a mask does not need to be + computed. + + A mask is never necessary if the values array is of boolean or integer + dtypes, as these are incapable of storing NaNs. If passing a NaN-capable + dtype that is interpretable as either boolean or integer data (eg, + timedelta64), a mask must be provided. + + If the skipna parameter is False, a new mask will not be computed. + + The mask is computed using isna() by default. Setting invert=True selects + notna() as the masking function. + + Parameters + ---------- + values : ndarray + input array to potentially compute mask for + skipna : bool + boolean for whether NaNs should be skipped + mask : Optional[ndarray] + nan-mask if known + + Returns + ------- + Optional[np.ndarray[bool]] + """ + if mask is None: + if values.dtype.kind in "biu": + # Boolean data cannot contain nulls, so signal via mask being None + return None + + if skipna or values.dtype.kind in "mM": + mask = isna(values) + + return mask + + +def _get_values( + values: np.ndarray, + skipna: bool, + fill_value: Any = None, + fill_value_typ: str | None = None, + mask: npt.NDArray[np.bool_] | None = None, +) -> tuple[np.ndarray, npt.NDArray[np.bool_] | None]: + """ + Utility to get the values view, mask, dtype, dtype_max, and fill_value. + + If both mask and fill_value/fill_value_typ are not None and skipna is True, + the values array will be copied. + + For input arrays of boolean or integer dtypes, copies will only occur if a + precomputed mask, a fill_value/fill_value_typ, and skipna=True are + provided. + + Parameters + ---------- + values : ndarray + input array to potentially compute mask for + skipna : bool + boolean for whether NaNs should be skipped + fill_value : Any + value to fill NaNs with + fill_value_typ : str + Set to '+inf' or '-inf' to handle dtype-specific infinities + mask : Optional[np.ndarray[bool]] + nan-mask if known + + Returns + ------- + values : ndarray + Potential copy of input value array + mask : Optional[ndarray[bool]] + Mask for values, if deemed necessary to compute + """ + # In _get_values is only called from within nanops, and in all cases + # with scalar fill_value. This guarantee is important for the + # np.where call below + + mask = _maybe_get_mask(values, skipna, mask) + + dtype = values.dtype + + datetimelike = False + if values.dtype.kind in "mM": + # changing timedelta64/datetime64 to int64 needs to happen after + # finding `mask` above + values = np.asarray(values.view("i8")) + datetimelike = True + + if skipna and (mask is not None): + # get our fill value (in case we need to provide an alternative + # dtype for it) + fill_value = _get_fill_value( + dtype, fill_value=fill_value, fill_value_typ=fill_value_typ + ) + + if fill_value is not None: + if mask.any(): + if datetimelike or _na_ok_dtype(dtype): + values = values.copy() + np.putmask(values, mask, fill_value) + else: + # np.where will promote if needed + values = np.where(~mask, values, fill_value) + + return values, mask + + +def _get_dtype_max(dtype: np.dtype) -> np.dtype: + # return a platform independent precision dtype + dtype_max = dtype + if dtype.kind in "bi": + dtype_max = np.dtype(np.int64) + elif dtype.kind == "u": + dtype_max = np.dtype(np.uint64) + elif dtype.kind == "f": + dtype_max = np.dtype(np.float64) + return dtype_max + + +def _na_ok_dtype(dtype: DtypeObj) -> bool: + if needs_i8_conversion(dtype): + return False + return not issubclass(dtype.type, np.integer) + + +def _wrap_results(result, dtype: np.dtype, fill_value=None): + """wrap our results if needed""" + if result is NaT: + pass + + elif dtype.kind == "M": + if fill_value is None: + # GH#24293 + fill_value = iNaT + if not isinstance(result, np.ndarray): + assert not isna(fill_value), "Expected non-null fill_value" + if result == fill_value: + result = np.nan + + if isna(result): + result = np.datetime64("NaT", "ns").astype(dtype) + else: + result = np.int64(result).view(dtype) + # retain original unit + result = result.astype(dtype, copy=False) + else: + # If we have float dtype, taking a view will give the wrong result + result = result.astype(dtype) + elif dtype.kind == "m": + if not isinstance(result, np.ndarray): + if result == fill_value or np.isnan(result): + result = np.timedelta64("NaT").astype(dtype) + + elif np.fabs(result) > lib.i8max: + # raise if we have a timedelta64[ns] which is too large + raise ValueError("overflow in timedelta operation") + else: + # return a timedelta64 with the original unit + result = np.int64(result).astype(dtype, copy=False) + + else: + result = result.astype("m8[ns]").view(dtype) + + return result + + +def _datetimelike_compat(func: F) -> F: + """ + If we have datetime64 or timedelta64 values, ensure we have a correct + mask before calling the wrapped function, then cast back afterwards. + """ + + @functools.wraps(func) + def new_func( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, + **kwargs, + ): + orig_values = values + + datetimelike = values.dtype.kind in "mM" + if datetimelike and mask is None: + mask = isna(values) + + result = func(values, axis=axis, skipna=skipna, mask=mask, **kwargs) + + if datetimelike: + result = _wrap_results(result, orig_values.dtype, fill_value=iNaT) + if not skipna: + assert mask is not None # checked above + result = _mask_datetimelike_result(result, axis, mask, orig_values) + + return result + + return cast(F, new_func) + + +def _na_for_min_count(values: np.ndarray, axis: AxisInt | None) -> Scalar | np.ndarray: + """ + Return the missing value for `values`. + + Parameters + ---------- + values : ndarray + axis : int or None + axis for the reduction, required if values.ndim > 1. + + Returns + ------- + result : scalar or ndarray + For 1-D values, returns a scalar of the correct missing type. + For 2-D values, returns a 1-D array where each element is missing. + """ + # we either return np.nan or pd.NaT + if values.dtype.kind in "iufcb": + values = values.astype("float64") + fill_value = na_value_for_dtype(values.dtype) + + if values.ndim == 1: + return fill_value + elif axis is None: + return fill_value + else: + result_shape = values.shape[:axis] + values.shape[axis + 1 :] + + return np.full(result_shape, fill_value, dtype=values.dtype) + + +def maybe_operate_rowwise(func: F) -> F: + """ + NumPy operations on C-contiguous ndarrays with axis=1 can be + very slow if axis 1 >> axis 0. + Operate row-by-row and concatenate the results. + """ + + @functools.wraps(func) + def newfunc(values: np.ndarray, *, axis: AxisInt | None = None, **kwargs): + if ( + axis == 1 + and values.ndim == 2 + and values.flags["C_CONTIGUOUS"] + # only takes this path for wide arrays (long dataframes), for threshold see + # https://github.com/pandas-dev/pandas/pull/43311#issuecomment-974891737 + and (values.shape[1] / 1000) > values.shape[0] + and values.dtype not in (object, bool) + ): + arrs = list(values) + if kwargs.get("mask") is not None: + mask = kwargs.pop("mask") + results = [ + func(arrs[i], mask=mask[i], **kwargs) for i in range(len(arrs)) + ] + else: + results = [func(x, **kwargs) for x in arrs] + return np.array(results) + + return func(values, axis=axis, **kwargs) + + return cast(F, newfunc) + + +def nanany( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, +) -> bool: + """ + Check if any elements along an axis evaluate to True. + + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : bool + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, 2]) + >>> nanops.nanany(s.values) + np.True_ + + >>> from pandas.core import nanops + >>> s = pd.Series([np.nan]) + >>> nanops.nanany(s.values) + np.False_ + """ + if values.dtype.kind in "iub" and mask is None: + # GH#26032 fastpath + # error: Incompatible return value type (got "Union[bool_, ndarray]", + # expected "bool") + return values.any(axis) # type: ignore[return-value] + + if values.dtype.kind == "M": + # GH#34479 + raise TypeError("datetime64 type does not support operation 'any'") + + values, _ = _get_values(values, skipna, fill_value=False, mask=mask) + + # For object type, any won't necessarily return + # boolean values (numpy/numpy#4352) + if values.dtype == object: + values = values.astype(bool) + + # error: Incompatible return value type (got "Union[bool_, ndarray]", expected + # "bool") + return values.any(axis) # type: ignore[return-value] + + +def nanall( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, +) -> bool: + """ + Check if all elements along an axis evaluate to True. + + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : bool + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, 2, np.nan]) + >>> nanops.nanall(s.values) + np.True_ + + >>> from pandas.core import nanops + >>> s = pd.Series([1, 0]) + >>> nanops.nanall(s.values) + np.False_ + """ + if values.dtype.kind in "iub" and mask is None: + # GH#26032 fastpath + # error: Incompatible return value type (got "Union[bool_, ndarray]", + # expected "bool") + return values.all(axis) # type: ignore[return-value] + + if values.dtype.kind == "M": + # GH#34479 + raise TypeError("datetime64 type does not support operation 'all'") + + values, _ = _get_values(values, skipna, fill_value=True, mask=mask) + + # For object type, all won't necessarily return + # boolean values (numpy/numpy#4352) + if values.dtype == object: + values = values.astype(bool) + + # error: Incompatible return value type (got "Union[bool_, ndarray]", expected + # "bool") + return values.all(axis) # type: ignore[return-value] + + +@disallow("M8") +@_datetimelike_compat +@maybe_operate_rowwise +def nansum( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + min_count: int = 0, + mask: npt.NDArray[np.bool_] | None = None, +) -> npt.NDArray[np.floating] | float | NaTType: + """ + Sum the elements along an axis ignoring NaNs + + Parameters + ---------- + values : ndarray[dtype] + axis : int, optional + skipna : bool, default True + min_count: int, default 0 + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : dtype + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, 2, np.nan]) + >>> nanops.nansum(s.values) + np.float64(3.0) + """ + dtype = values.dtype + values, mask = _get_values(values, skipna, fill_value=0, mask=mask) + dtype_sum = _get_dtype_max(dtype) + if dtype.kind == "f": + dtype_sum = dtype + elif dtype.kind == "m": + dtype_sum = np.dtype(np.float64) + + the_sum = values.sum(axis, dtype=dtype_sum) + the_sum = _maybe_null_out(the_sum, axis, mask, values.shape, min_count=min_count) + + return the_sum + + +def _mask_datetimelike_result( + result: np.ndarray | np.datetime64 | np.timedelta64, + axis: AxisInt | None, + mask: npt.NDArray[np.bool_], + orig_values: np.ndarray, +) -> np.ndarray | np.datetime64 | np.timedelta64 | NaTType: + if isinstance(result, np.ndarray): + # we need to apply the mask + result = result.astype("i8").view(orig_values.dtype) + axis_mask = mask.any(axis=axis) + result[axis_mask] = iNaT + elif mask.any(): + return np.int64(iNaT).view(orig_values.dtype) + return result + + +@bottleneck_switch() +@_datetimelike_compat +def nanmean( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, +) -> float: + """ + Compute the mean of the element along an axis ignoring NaNs + + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + float + Unless input is a float array, in which case use the same + precision as the input array. + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, 2, np.nan]) + >>> nanops.nanmean(s.values) + np.float64(1.5) + """ + if values.dtype == object and len(values) > 1_000 and mask is None: + # GH#54754 if we are going to fail, try to fail-fast + nanmean(values[:1000], axis=axis, skipna=skipna) + + dtype = values.dtype + values, mask = _get_values(values, skipna, fill_value=0, mask=mask) + dtype_sum = _get_dtype_max(dtype) + dtype_count = np.dtype(np.float64) + + # not using needs_i8_conversion because that includes period + if dtype.kind in "mM": + dtype_sum = np.dtype(np.float64) + elif dtype.kind in "iu": + dtype_sum = np.dtype(np.float64) + elif dtype.kind == "f": + dtype_sum = dtype + dtype_count = dtype + + count = _get_counts(values.shape, mask, axis, dtype=dtype_count) + the_sum = values.sum(axis, dtype=dtype_sum) + the_sum = _ensure_numeric(the_sum) + + if axis is not None and getattr(the_sum, "ndim", False): + count = cast(np.ndarray, count) + with np.errstate(all="ignore"): + # suppress division by zero warnings + the_mean = the_sum / count + ct_mask = count == 0 + if ct_mask.any(): + the_mean[ct_mask] = np.nan + else: + the_mean = the_sum / count if count > 0 else np.nan + + return the_mean + + +@bottleneck_switch() +def nanmedian( + values: np.ndarray, *, axis: AxisInt | None = None, skipna: bool = True, mask=None +) -> float | np.ndarray: + """ + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : float | ndarray + Unless input is a float array, in which case use the same + precision as the input array. + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, np.nan, 2, 2]) + >>> nanops.nanmedian(s.values) + 2.0 + + >>> s = pd.Series([np.nan, np.nan, np.nan]) + >>> nanops.nanmedian(s.values) + nan + """ + # for floats without mask, the data already uses NaN as missing value + # indicator, and `mask` will be calculated from that below -> in those + # cases we never need to set NaN to the masked values + using_nan_sentinel = values.dtype.kind == "f" and mask is None + + def get_median(x: np.ndarray, _mask=None): + if _mask is None: + _mask = notna(x) + else: + _mask = ~_mask + if not skipna and not _mask.all(): + return np.nan + with warnings.catch_warnings(): + # Suppress RuntimeWarning about All-NaN slice + warnings.filterwarnings( + "ignore", "All-NaN slice encountered", RuntimeWarning + ) + warnings.filterwarnings("ignore", "Mean of empty slice", RuntimeWarning) + res = np.nanmedian(x[_mask]) + return res + + dtype = values.dtype + values, mask = _get_values(values, skipna, mask=mask, fill_value=None) + if values.dtype.kind != "f": + if values.dtype == object: + # GH#34671 avoid casting strings to numeric + inferred = lib.infer_dtype(values) + if inferred in ["string", "mixed"]: + raise TypeError(f"Cannot convert {values} to numeric") + try: + values = values.astype("f8") + except ValueError as err: + # e.g. "could not convert string to float: 'a'" + raise TypeError(str(err)) from err + if not using_nan_sentinel and mask is not None: + if not values.flags.writeable: + values = values.copy() + values[mask] = np.nan + + notempty = values.size + + res: float | np.ndarray + + # an array from a frame + if values.ndim > 1 and axis is not None: + # there's a non-empty array to apply over otherwise numpy raises + if notempty: + if not skipna: + res = np.apply_along_axis(get_median, axis, values) + + else: + # fastpath for the skipna case + with warnings.catch_warnings(): + # Suppress RuntimeWarning about All-NaN slice + warnings.filterwarnings( + "ignore", "All-NaN slice encountered", RuntimeWarning + ) + if (values.shape[1] == 1 and axis == 0) or ( + values.shape[0] == 1 and axis == 1 + ): + # GH52788: fastpath when squeezable, nanmedian for 2D array slow + res = np.nanmedian(np.squeeze(values), keepdims=True) + else: + res = np.nanmedian(values, axis=axis) + + else: + # must return the correct shape, but median is not defined for the + # empty set so return nans of shape "everything but the passed axis" + # since "axis" is where the reduction would occur if we had a nonempty + # array + res = _get_empty_reduction_result(values.shape, axis) + + else: + # otherwise return a scalar value + res = get_median(values, mask) if notempty else np.nan + return _wrap_results(res, dtype) + + +def _get_empty_reduction_result( + shape: Shape, + axis: AxisInt, +) -> np.ndarray: + """ + The result from a reduction on an empty ndarray. + + Parameters + ---------- + shape : Tuple[int, ...] + axis : int + + Returns + ------- + np.ndarray + """ + shp = np.array(shape) + dims = np.arange(len(shape)) + ret = np.empty(shp[dims != axis], dtype=np.float64) + ret.fill(np.nan) + return ret + + +def _get_counts_nanvar( + values_shape: Shape, + mask: npt.NDArray[np.bool_] | None, + axis: AxisInt | None, + ddof: int, + dtype: np.dtype = np.dtype(np.float64), +) -> tuple[float | np.ndarray, float | np.ndarray]: + """ + Get the count of non-null values along an axis, accounting + for degrees of freedom. + + Parameters + ---------- + values_shape : Tuple[int, ...] + shape tuple from values ndarray, used if mask is None + mask : Optional[ndarray[bool]] + locations in values that should be considered missing + axis : Optional[int] + axis to count along + ddof : int + degrees of freedom + dtype : type, optional + type to use for count + + Returns + ------- + count : int, np.nan or np.ndarray + d : int, np.nan or np.ndarray + """ + count = _get_counts(values_shape, mask, axis, dtype=dtype) + d = count - dtype.type(ddof) + + # always return NaN, never inf + if is_float(count): + if count <= ddof: + # error: Incompatible types in assignment (expression has type + # "float", variable has type "Union[floating[Any], ndarray[Any, + # dtype[floating[Any]]]]") + count = np.nan # type: ignore[assignment] + d = np.nan + else: + # count is not narrowed by is_float check + count = cast(np.ndarray, count) + mask = count <= ddof + if mask.any(): + np.putmask(d, mask, np.nan) + np.putmask(count, mask, np.nan) + return count, d + + +@bottleneck_switch(ddof=1) +def nanstd( + values, + *, + axis: AxisInt | None = None, + skipna: bool = True, + ddof: int = 1, + mask=None, +): + """ + Compute the standard deviation along given axis while ignoring NaNs + + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : float + Unless input is a float array, in which case use the same + precision as the input array. + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, np.nan, 2, 3]) + >>> nanops.nanstd(s.values) + 1.0 + """ + if values.dtype.kind == "M": + unit = np.datetime_data(values.dtype)[0] + values = values.view(f"m8[{unit}]") + + orig_dtype = values.dtype + values, mask = _get_values(values, skipna, mask=mask) + + result = np.sqrt(nanvar(values, axis=axis, skipna=skipna, ddof=ddof, mask=mask)) + return _wrap_results(result, orig_dtype) + + +@disallow("M8", "m8") +@bottleneck_switch(ddof=1) +def nanvar( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + ddof: int = 1, + mask=None, +): + """ + Compute the variance along given axis while ignoring NaNs + + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : float + Unless input is a float array, in which case use the same + precision as the input array. + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, np.nan, 2, 3]) + >>> nanops.nanvar(s.values) + 1.0 + """ + dtype = values.dtype + mask = _maybe_get_mask(values, skipna, mask) + if dtype.kind in "iu": + values = values.astype("f8") + if mask is not None: + values[mask] = np.nan + + if values.dtype.kind == "f": + count, d = _get_counts_nanvar(values.shape, mask, axis, ddof, values.dtype) + else: + count, d = _get_counts_nanvar(values.shape, mask, axis, ddof) + + if skipna and mask is not None: + values = values.copy() + np.putmask(values, mask, 0) + + # xref GH10242 + # Compute variance via two-pass algorithm, which is stable against + # cancellation errors and relatively accurate for small numbers of + # observations. + # + # See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + avg = _ensure_numeric(values.sum(axis=axis, dtype=np.float64)) / count + if axis is not None: + avg = np.expand_dims(avg, axis) + if values.dtype.kind == "c": + # Need to use absolute value for complex numbers. + sqr = _ensure_numeric(abs(avg - values) ** 2) + else: + sqr = _ensure_numeric((avg - values) ** 2) + if mask is not None: + np.putmask(sqr, mask, 0) + result = sqr.sum(axis=axis, dtype=np.float64) / d + + # Return variance as np.float64 (the datatype used in the accumulator), + # unless we were dealing with a float array, in which case use the same + # precision as the original values array. + if dtype.kind == "f": + result = result.astype(dtype, copy=False) + return result + + +@disallow("M8", "m8") +def nansem( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + ddof: int = 1, + mask: npt.NDArray[np.bool_] | None = None, +) -> float: + """ + Compute the standard error in the mean along given axis while ignoring NaNs + + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : float64 + Unless input is a float array, in which case use the same + precision as the input array. + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, np.nan, 2, 3]) + >>> nanops.nansem(s.values) + np.float64(0.5773502691896258) + """ + # This checks if non-numeric-like data is passed with numeric_only=False + # and raises a TypeError otherwise + nanvar(values, axis=axis, skipna=skipna, ddof=ddof, mask=mask) + + mask = _maybe_get_mask(values, skipna, mask) + if values.dtype.kind != "f": + values = values.astype("f8") + + if not skipna and mask is not None and mask.any(): + return np.nan + + count, _ = _get_counts_nanvar(values.shape, mask, axis, ddof, values.dtype) + var = nanvar(values, axis=axis, skipna=skipna, ddof=ddof, mask=mask) + + return np.sqrt(var) / np.sqrt(count) + + +def _nanminmax(meth, fill_value_typ): + @bottleneck_switch(name=f"nan{meth}") + @_datetimelike_compat + def reduction( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, + ): + if values.size == 0: + return _na_for_min_count(values, axis) + + dtype = values.dtype + values, mask = _get_values( + values, skipna, fill_value_typ=fill_value_typ, mask=mask + ) + result = getattr(values, meth)(axis) + result = _maybe_null_out( + result, axis, mask, values.shape, datetimelike=dtype.kind in "mM" + ) + return result + + return reduction + + +nanmin = _nanminmax("min", fill_value_typ="+inf") +nanmax = _nanminmax("max", fill_value_typ="-inf") + + +def nanargmax( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, +) -> int | np.ndarray: + """ + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : int or ndarray[int] + The index/indices of max value in specified axis or -1 in the NA case + + Examples + -------- + >>> from pandas.core import nanops + >>> arr = np.array([1, 2, 3, np.nan, 4]) + >>> nanops.nanargmax(arr) + np.int64(4) + + >>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3) + >>> arr[2:, 2] = np.nan + >>> arr + array([[ 0., 1., 2.], + [ 3., 4., 5.], + [ 6., 7., nan], + [ 9., 10., nan]]) + >>> nanops.nanargmax(arr, axis=1) + array([2, 2, 1, 1]) + """ + values, mask = _get_values(values, True, fill_value_typ="-inf", mask=mask) + result = values.argmax(axis) + # error: Argument 1 to "_maybe_arg_null_out" has incompatible type "Any | + # signedinteger[Any]"; expected "ndarray[Any, Any]" + result = _maybe_arg_null_out(result, axis, mask, skipna) # type: ignore[arg-type] + return result + + +def nanargmin( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, +) -> int | np.ndarray: + """ + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : int or ndarray[int] + The index/indices of min value in specified axis or -1 in the NA case + + Examples + -------- + >>> from pandas.core import nanops + >>> arr = np.array([1, 2, 3, np.nan, 4]) + >>> nanops.nanargmin(arr) + np.int64(0) + + >>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3) + >>> arr[2:, 0] = np.nan + >>> arr + array([[ 0., 1., 2.], + [ 3., 4., 5.], + [nan, 7., 8.], + [nan, 10., 11.]]) + >>> nanops.nanargmin(arr, axis=1) + array([0, 0, 1, 1]) + """ + values, mask = _get_values(values, True, fill_value_typ="+inf", mask=mask) + result = values.argmin(axis) + # error: Argument 1 to "_maybe_arg_null_out" has incompatible type "Any | + # signedinteger[Any]"; expected "ndarray[Any, Any]" + result = _maybe_arg_null_out(result, axis, mask, skipna) # type: ignore[arg-type] + return result + + +@disallow("M8", "m8") +@maybe_operate_rowwise +def nanskew( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, +) -> float: + """ + Compute the sample skewness. + + The statistic computed here is the adjusted Fisher-Pearson standardized + moment coefficient G1. The algorithm computes this coefficient directly + from the second and third central moment. + + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : float64 + Unless input is a float array, in which case use the same + precision as the input array. + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, np.nan, 1, 2]) + >>> nanops.nanskew(s.values) + np.float64(1.7320508075688787) + """ + mask = _maybe_get_mask(values, skipna, mask) + if values.dtype.kind != "f": + values = values.astype("f8") + count = _get_counts(values.shape, mask, axis) + else: + count = _get_counts(values.shape, mask, axis, dtype=values.dtype) + + if skipna and mask is not None: + values = values.copy() + np.putmask(values, mask, 0) + elif not skipna and mask is not None and mask.any(): + return np.nan + + with np.errstate(invalid="ignore", divide="ignore"): + mean = values.sum(axis, dtype=np.float64) / count + if axis is not None: + mean = np.expand_dims(mean, axis) + + adjusted = values - mean + if skipna and mask is not None: + np.putmask(adjusted, mask, 0) + adjusted2 = adjusted**2 + adjusted3 = adjusted2 * adjusted + m2 = adjusted2.sum(axis, dtype=np.float64) + m3 = adjusted3.sum(axis, dtype=np.float64) + + # floating point error. See comment in [nankurt] + max_abs = np.abs(values).max(axis, initial=0.0) + eps = np.finfo(m2.dtype).eps + constant_tolerance2 = ((eps * max_abs) ** 2) * count + constant_tolerance3 = ((eps * max_abs) ** 3) * count + m2 = _zero_out_fperr(m2, constant_tolerance2) + m3 = _zero_out_fperr(m3, constant_tolerance3) + + with np.errstate(invalid="ignore", divide="ignore"): + result = (count * (count - 1) ** 0.5 / (count - 2)) * (m3 / m2**1.5) + + dtype = values.dtype + if dtype.kind == "f": + result = result.astype(dtype, copy=False) + + if isinstance(result, np.ndarray): + result = np.where(m2 == 0, 0, result) + result[count < 3] = np.nan + else: + result = dtype.type(0) if m2 == 0 else result + if count < 3: + return np.nan + + return result + + +@disallow("M8", "m8") +@maybe_operate_rowwise +def nankurt( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + mask: npt.NDArray[np.bool_] | None = None, +) -> float: + """ + Compute the sample excess kurtosis + + The statistic computed here is the adjusted Fisher-Pearson standardized + moment coefficient G2, computed directly from the second and fourth + central moment. + + Parameters + ---------- + values : ndarray + axis : int, optional + skipna : bool, default True + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + result : float64 + Unless input is a float array, in which case use the same + precision as the input array. + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, np.nan, 1, 3, 2]) + >>> nanops.nankurt(s.values) + np.float64(-1.2892561983471076) + """ + mask = _maybe_get_mask(values, skipna, mask) + if values.dtype.kind != "f": + values = values.astype("f8") + count = _get_counts(values.shape, mask, axis) + else: + count = _get_counts(values.shape, mask, axis, dtype=values.dtype) + + if skipna and mask is not None: + values = values.copy() + np.putmask(values, mask, 0) + elif not skipna and mask is not None and mask.any(): + return np.nan + + with np.errstate(invalid="ignore", divide="ignore"): + mean = values.sum(axis, dtype=np.float64) / count + if axis is not None: + mean = np.expand_dims(mean, axis) + + adjusted = values - mean + if skipna and mask is not None: + np.putmask(adjusted, mask, 0) + adjusted2 = adjusted**2 + adjusted4 = adjusted2**2 + m2 = adjusted2.sum(axis, dtype=np.float64) + m4 = adjusted4.sum(axis, dtype=np.float64) + + # Several floating point errors may occur during the summation due to rounding. + # This computation is similar to the one in Scipy + # https://github.com/scipy/scipy/blob/04d6d9c460b1fed83f2919ecec3d743cfa2e8317/scipy/stats/_stats_py.py#L1429 + # With a few modifications, like using the maximum value instead of the averages + # and some adaptations because they use the average and we use the sum for `m2`. + # We need to estimate an upper bound to the error to consider the data constant. + # Let's call: + # x: true value in data + # y: floating point representation + # e: relative approximation error + # n: number of observations in array + # + # We have that: + # |x - y|/|x| <= e (See https://en.wikipedia.org/wiki/Machine_epsilon) + # (|x - y|/|x|)² <= e² + # Σ (|x - y|/|x|)² <= ne² + # + # Let's say that the fperr upper bound for m2 is constrained by the summation. + # |m2 - y|/|m2| <= ne² + # |m2 - y| <= n|m2|e² + # + # We will use max (x²) to estimate |m2| + max_abs = np.abs(values).max(axis, initial=0.0) + eps = np.finfo(m2.dtype).eps + constant_tolerance2 = ((eps * max_abs) ** 2) * count + constant_tolerance4 = ((eps * max_abs) ** 4) * count + m2 = _zero_out_fperr(m2, constant_tolerance2) + m4 = _zero_out_fperr(m4, constant_tolerance4) + + with np.errstate(invalid="ignore", divide="ignore"): + adj = 3 * (count - 1) ** 2 / ((count - 2) * (count - 3)) + numerator = count * (count + 1) * (count - 1) * m4 + denominator = (count - 2) * (count - 3) * m2**2 + + if not isinstance(denominator, np.ndarray): + # if ``denom`` is a scalar, check these corner cases first before + # doing division + if count < 4: + return np.nan + if denominator == 0: + return values.dtype.type(0) + + with np.errstate(invalid="ignore", divide="ignore"): + result = numerator / denominator - adj + + dtype = values.dtype + if dtype.kind == "f": + result = result.astype(dtype, copy=False) + + if isinstance(result, np.ndarray): + result = np.where(denominator == 0, 0, result) + result[count < 4] = np.nan + + return result + + +@disallow("M8", "m8") +@maybe_operate_rowwise +def nanprod( + values: np.ndarray, + *, + axis: AxisInt | None = None, + skipna: bool = True, + min_count: int = 0, + mask: npt.NDArray[np.bool_] | None = None, +) -> float: + """ + Parameters + ---------- + values : ndarray[dtype] + axis : int, optional + skipna : bool, default True + min_count: int, default 0 + mask : ndarray[bool], optional + nan-mask if known + + Returns + ------- + Dtype + The product of all elements on a given axis. ( NaNs are treated as 1) + + Examples + -------- + >>> from pandas.core import nanops + >>> s = pd.Series([1, 2, 3, np.nan]) + >>> nanops.nanprod(s.values) + np.float64(6.0) + """ + mask = _maybe_get_mask(values, skipna, mask) + + if skipna and mask is not None: + values = values.copy() + values[mask] = 1 + result = values.prod(axis) + # error: Incompatible return value type (got "Union[ndarray, float]", expected + # "float") + return _maybe_null_out( # type: ignore[return-value] + result, axis, mask, values.shape, min_count=min_count + ) + + +def _maybe_arg_null_out( + result: np.ndarray, + axis: AxisInt | None, + mask: npt.NDArray[np.bool_] | None, + skipna: bool, +) -> np.ndarray | int: + # helper function for nanargmin/nanargmax + if mask is None: + return result + + if axis is None or not getattr(result, "ndim", False): + if skipna and mask.all(): + raise ValueError("Encountered all NA values") + elif not skipna and mask.any(): + raise ValueError("Encountered an NA value with skipna=False") + elif skipna and mask.all(axis).any(): + raise ValueError("Encountered all NA values") + elif not skipna and mask.any(axis).any(): + raise ValueError("Encountered an NA value with skipna=False") + return result + + +def _get_counts( + values_shape: Shape, + mask: npt.NDArray[np.bool_] | None, + axis: AxisInt | None, + dtype: np.dtype[np.floating] = np.dtype(np.float64), +) -> np.floating | npt.NDArray[np.floating]: + """ + Get the count of non-null values along an axis + + Parameters + ---------- + values_shape : tuple of int + shape tuple from values ndarray, used if mask is None + mask : Optional[ndarray[bool]] + locations in values that should be considered missing + axis : Optional[int] + axis to count along + dtype : type, optional + type to use for count + + Returns + ------- + count : scalar or array + """ + if axis is None: + if mask is not None: + n = mask.size - mask.sum() + else: + n = np.prod(values_shape) + return dtype.type(n) + + if mask is not None: + count = mask.shape[axis] - mask.sum(axis) + else: + count = values_shape[axis] + + if is_integer(count): + return dtype.type(count) + return count.astype(dtype, copy=False) + + +def _maybe_null_out( + result: np.ndarray | float | NaTType, + axis: AxisInt | None, + mask: npt.NDArray[np.bool_] | None, + shape: tuple[int, ...], + min_count: int = 1, + datetimelike: bool = False, +) -> np.ndarray | float | NaTType: + """ + Returns + ------- + Dtype + The product of all elements on a given axis. ( NaNs are treated as 1) + """ + if mask is None and min_count == 0: + # nothing to check; short-circuit + return result + + if axis is not None and isinstance(result, np.ndarray): + if mask is not None: + null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0 + else: + # we have no nulls, kept mask=None in _maybe_get_mask + below_count = shape[axis] - min_count < 0 + new_shape = shape[:axis] + shape[axis + 1 :] + null_mask = np.broadcast_to(below_count, new_shape) + + if np.any(null_mask): + if datetimelike: + # GH#60646 For datetimelike, no need to cast to float + result[null_mask] = iNaT + elif is_numeric_dtype(result): + if np.iscomplexobj(result): + result = result.astype("c16") + elif not is_float_dtype(result): + result = result.astype("f8", copy=False) + result[null_mask] = np.nan + else: + # GH12941, use None to auto cast null + result[null_mask] = None + elif result is not NaT: + if check_below_min_count(shape, mask, min_count): + result_dtype = getattr(result, "dtype", None) + if is_float_dtype(result_dtype): + # error: Item "None" of "Optional[Any]" has no attribute "type" + result = result_dtype.type("nan") # type: ignore[union-attr] + else: + result = np.nan + + return result + + +def check_below_min_count( + shape: tuple[int, ...], mask: npt.NDArray[np.bool_] | None, min_count: int +) -> bool: + """ + Check for the `min_count` keyword. Returns True if below `min_count` (when + missing value should be returned from the reduction). + + Parameters + ---------- + shape : tuple + The shape of the values (`values.shape`). + mask : ndarray[bool] or None + Boolean numpy array (typically of same shape as `shape`) or None. + min_count : int + Keyword passed through from sum/prod call. + + Returns + ------- + bool + """ + if min_count > 0: + if mask is None: + # no missing values, only check size + non_nulls = np.prod(shape) + else: + non_nulls = mask.size - mask.sum() + if non_nulls < min_count: + return True + return False + + +def _zero_out_fperr(arg, tol: float | np.ndarray): + # #18044 reference this behavior to fix rolling skew/kurt issue + if isinstance(arg, np.ndarray): + return np.where(np.abs(arg) < tol, 0, arg) + else: + return arg.dtype.type(0) if np.abs(arg) < tol else arg + + +@disallow("M8", "m8") +def nancorr( + a: np.ndarray, + b: np.ndarray, + *, + method: CorrelationMethod = "pearson", + min_periods: int | None = None, +) -> float: + """ + a, b: ndarrays + """ + if len(a) != len(b): + raise AssertionError("Operands to nancorr must have same size") + + if min_periods is None: + min_periods = 1 + + valid = notna(a) & notna(b) + if not valid.all(): + a = a[valid] + b = b[valid] + + if len(a) < min_periods: + return np.nan + + a = _ensure_numeric(a) + b = _ensure_numeric(b) + + f = get_corr_func(method) + return f(a, b) + + +def get_corr_func( + method: CorrelationMethod, +) -> Callable[[np.ndarray, np.ndarray], float]: + if method == "kendall": + from scipy.stats import kendalltau + + def func(a, b): + return kendalltau(a, b)[0] + + return func + elif method == "spearman": + from scipy.stats import spearmanr + + def func(a, b): + return spearmanr(a, b)[0] + + return func + elif method == "pearson": + + def func(a, b): + return np.corrcoef(a, b)[0, 1] + + return func + elif callable(method): + return method + + raise ValueError( + f"Unknown method '{method}', expected one of " + "'kendall', 'spearman', 'pearson', or callable" + ) + + +@disallow("M8", "m8") +def nancov( + a: np.ndarray, + b: np.ndarray, + *, + min_periods: int | None = None, + ddof: int | None = 1, +) -> float: + if len(a) != len(b): + raise AssertionError("Operands to nancov must have same size") + + if min_periods is None: + min_periods = 1 + + valid = notna(a) & notna(b) + if not valid.all(): + a = a[valid] + b = b[valid] + + if len(a) < min_periods: + return np.nan + + a = _ensure_numeric(a) + b = _ensure_numeric(b) + + return np.cov(a, b, ddof=ddof)[0, 1] + + +def _ensure_numeric(x): + if isinstance(x, np.ndarray): + if x.dtype.kind in "biu": + x = x.astype(np.float64) + elif x.dtype == object: + inferred = lib.infer_dtype(x) + if inferred in ["string", "mixed"]: + # GH#44008, GH#36703 avoid casting e.g. strings to numeric + raise TypeError(f"Could not convert {x} to numeric") + try: + x = x.astype(np.complex128) + except (TypeError, ValueError): + try: + x = x.astype(np.float64) + except ValueError as err: + # GH#29941 we get here with object arrays containing strs + raise TypeError(f"Could not convert {x} to numeric") from err + else: + if not np.any(np.imag(x)): + x = x.real + elif not (is_float(x) or is_integer(x) or is_complex(x)): + if isinstance(x, str): + # GH#44008, GH#36703 avoid casting e.g. strings to numeric + raise TypeError(f"Could not convert string '{x}' to numeric") + try: + x = float(x) + except (TypeError, ValueError): + # e.g. "1+1j" or "foo" + try: + x = complex(x) + except ValueError as err: + # e.g. "foo" + raise TypeError(f"Could not convert {x} to numeric") from err + return x + + +def na_accum_func(values: ArrayLike, accum_func, *, skipna: bool) -> ArrayLike: + """ + Cumulative function with skipna support. + + Parameters + ---------- + values : np.ndarray or ExtensionArray + accum_func : {np.cumprod, np.maximum.accumulate, np.cumsum, np.minimum.accumulate} + skipna : bool + + Returns + ------- + np.ndarray or ExtensionArray + """ + mask_a, mask_b = { + np.cumprod: (1.0, np.nan), + np.maximum.accumulate: (-np.inf, np.nan), + np.cumsum: (0.0, np.nan), + np.minimum.accumulate: (np.inf, np.nan), + }[accum_func] + + # This should go through ea interface + assert values.dtype.kind not in "mM" + + # We will be applying this function to block values + if skipna and not issubclass(values.dtype.type, (np.integer, np.bool_)): + vals = values.copy() + mask = isna(vals) + vals[mask] = mask_a + result = accum_func(vals, axis=0) + result[mask] = mask_b + else: + result = accum_func(values, axis=0) + + return result diff --git a/pandas/core/resample.py b/pandas/core/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..beb60faf23b55d21a78a0a35b671a48bf28d93b5 --- /dev/null +++ b/pandas/core/resample.py @@ -0,0 +1,3150 @@ +from __future__ import annotations + +import copy +from typing import ( + TYPE_CHECKING, + Concatenate, + Literal, + Self, + cast, + final, + no_type_check, + overload, +) +import warnings + +import numpy as np + +from pandas._libs import lib +from pandas._libs.tslibs import ( + BaseOffset, + IncompatibleFrequency, + NaT, + Period, + Timedelta, + Timestamp, + to_offset, +) +from pandas._typing import NDFrameT +from pandas.errors import ( + AbstractMethodError, + Pandas4Warning, +) +from pandas.util._decorators import set_module +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + PeriodDtype, +) +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCSeries, +) + +import pandas.core.algorithms as algos +from pandas.core.apply import ResamplerWindowApply +from pandas.core.arrays import ArrowExtensionArray +from pandas.core.base import ( + PandasObject, + SelectionMixin, +) +from pandas.core.generic import ( + NDFrame, +) +from pandas.core.groupby.groupby import ( + BaseGroupBy, + GroupBy, + get_groupby, +) +from pandas.core.groupby.grouper import Grouper +from pandas.core.groupby.ops import BinGrouper +from pandas.core.indexes.api import MultiIndex +from pandas.core.indexes.base import Index +from pandas.core.indexes.datetimes import ( + DatetimeIndex, + date_range, +) +from pandas.core.indexes.period import ( + PeriodIndex, + period_range, +) +from pandas.core.indexes.timedeltas import ( + TimedeltaIndex, + timedelta_range, +) +from pandas.core.reshape.concat import concat + +from pandas.tseries.frequencies import ( + is_subperiod, + is_superperiod, +) +from pandas.tseries.offsets import ( + Day, + Tick, +) + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Hashable, + ) + + from pandas._typing import ( + Any, + AnyArrayLike, + Axis, + FreqIndexT, + Frequency, + IndexLabel, + InterpolateOptions, + P, + T, + TimedeltaConvertibleTypes, + TimeGrouperOrigin, + TimestampConvertibleTypes, + TimeUnit, + npt, + ) + + from pandas import ( + DataFrame, + Series, + ) + from pandas.core.generic import NDFrame + +_shared_docs_kwargs: dict[str, str] = {} + + +@set_module("pandas.api.typing") +class Resampler(BaseGroupBy, PandasObject): + """ + Class for resampling datetimelike data, a groupby-like operation. + See aggregate, transform, and apply functions on this object. + + It's easiest to use obj.resample(...) to use Resampler. + + Parameters + ---------- + obj : Series or DataFrame + groupby : TimeGrouper + + Returns + ------- + a Resampler of the appropriate type + + Notes + ----- + After resampling, see aggregate, apply, and transform functions. + """ + + _grouper: BinGrouper + _timegrouper: TimeGrouper + binner: DatetimeIndex | TimedeltaIndex | PeriodIndex # depends on subclass + exclusions: frozenset[Hashable] = frozenset() # for SelectionMixin compat + _internal_names_set = set({"obj", "ax", "_indexer"}) + + # to the groupby descriptor + _attributes = [ + "freq", + "closed", + "label", + "convention", + "origin", + "offset", + ] + + def __init__( + self, + obj: NDFrame, + timegrouper: TimeGrouper, + *, + gpr_index: Index, + group_keys: bool = False, + selection=None, + include_groups: bool = False, + ) -> None: + if include_groups: + raise ValueError("include_groups=True is no longer allowed.") + self._timegrouper = timegrouper + self.keys = None + self.sort = True + self.group_keys = group_keys + self.as_index = True + + self.obj, self.ax, self._indexer = self._timegrouper._set_grouper( + self._convert_obj(obj), sort=True, gpr_index=gpr_index + ) + self.binner, self._grouper = self._get_binner() + self._selection = selection + if self._timegrouper.key is not None: + self.exclusions = frozenset([self._timegrouper.key]) + else: + self.exclusions = frozenset() + + @final + def __str__(self) -> str: + """ + Provide a nice str repr of our rolling object. + """ + attrs = ( + f"{k}={getattr(self._timegrouper, k)}" + for k in self._attributes + if getattr(self._timegrouper, k, None) is not None + ) + return f"{type(self).__name__} [{', '.join(attrs)}]" + + @final + def __getattr__(self, attr: str): + if attr in self._internal_names_set: + return object.__getattribute__(self, attr) + if attr in self._attributes: + return getattr(self._timegrouper, attr) + if attr in self.obj: + return self[attr] + + return object.__getattribute__(self, attr) + + @final + @property + def _from_selection(self) -> bool: + """ + Is the resampling from a DataFrame column or MultiIndex level. + """ + # upsampling and PeriodIndex resampling do not work + # with selection, this state used to catch and raise an error + return self._timegrouper is not None and ( + self._timegrouper.key is not None or self._timegrouper.level is not None + ) + + def _convert_obj(self, obj: NDFrameT) -> NDFrameT: + """ + Provide any conversions for the object in order to correctly handle. + + Parameters + ---------- + obj : Series or DataFrame + + Returns + ------- + Series or DataFrame + """ + return obj._consolidate() + + def _get_binner_for_time(self): + raise AbstractMethodError(self) + + @final + def _get_binner(self): + """ + Create the BinGrouper, assume that self.set_grouper(obj) + has already been called. + """ + binner, bins, binlabels = self._get_binner_for_time() + assert len(bins) == len(binlabels) + if self._timegrouper._arrow_dtype is not None: + binlabels = binlabels.astype(self._timegrouper._arrow_dtype) + bin_grouper = BinGrouper(bins, binlabels, indexer=self._indexer) + return binner, bin_grouper + + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: ... + + @overload + def pipe( + self, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: ... + + @final + def pipe( + self, + func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: + """ + Apply a ``func`` with arguments to this Resampler object and return its result. + + Use `.pipe` when you want to improve readability by chaining together + functions that expect Series, DataFrames, GroupBy or Resampler objects. + Instead of writing + + >>> h = lambda x, arg2, arg3: x + 1 - arg2 * arg3 + >>> g = lambda x, arg1: x * 5 / arg1 + >>> f = lambda x: x**4 + >>> df = pd.DataFrame([["a", 4], ["b", 5]], columns=["group", "value"]) + >>> h(g(f(df.groupby("group")), arg1=1), arg2=2, arg3=3) # doctest: +SKIP + + You can write + + >>> ( + ... df.groupby("group").pipe(f).pipe(g, arg1=1).pipe(h, arg2=2, arg3=3) + ... ) # doctest: +SKIP + + which is much more readable. + + Parameters + ---------- + func : callable or tuple of (callable, str) + Function to apply to this Resampler object or, alternatively, + a `(callable, data_keyword)` tuple where `data_keyword` is a + string indicating the keyword of `callable` that expects the + Resampler object. + *args : iterable, optional + Positional arguments passed into `func`. + **kwargs : dict, optional + A dictionary of keyword arguments passed into `func`. + + Returns + ------- + any + The result of applying ``func`` to the Resampler object. + + See Also + -------- + Series.pipe : Apply a function with arguments to a series. + DataFrame.pipe: Apply a function with arguments to a dataframe. + apply : Apply function to each group instead of to the + full Resampler object. + + Notes + ----- + See more `here + `_ + + Examples + -------- + >>> df = pd.DataFrame( + ... {"A": [1, 2, 3, 4]}, index=pd.date_range("2012-08-02", periods=4) + ... ) + >>> df + A + 2012-08-02 1 + 2012-08-03 2 + 2012-08-04 3 + 2012-08-05 4 + + To get the difference between each 2-day period's maximum and minimum + value in one pass, you can do + + >>> df.resample("2D").pipe(lambda x: x.max() - x.min()) + A + 2012-08-02 1 + 2012-08-04 1 + """ + return super().pipe(func, *args, **kwargs) + + @final + def aggregate(self, func=None, *args, **kwargs): + """ + Aggregate using one or more operations over the specified axis. + + Parameters + ---------- + func : function, str, list or dict + Function to use for aggregating the data. If a function, must either + work when passed a DataFrame or when passed to DataFrame.apply. + + Accepted combinations are: + + - function + - string function name + - list of functions and/or function names, e.g. ``[np.sum, 'mean']`` + - dict of axis labels -> functions, function names or list of such. + *args + Positional arguments to pass to `func`. + **kwargs + Keyword arguments to pass to `func`. + + Returns + ------- + scalar, Series or DataFrame + + The return can be: + + * scalar : when Series.agg is called with single function + * Series : when DataFrame.agg is called with a single function + * DataFrame : when DataFrame.agg is called with several functions + + See Also + -------- + DataFrame.groupby.aggregate : Aggregate using callable, string, dict, + or list of string/callables. + DataFrame.resample.transform : Transforms the Series on each group + based on the given function. + DataFrame.aggregate: Aggregate using one or more + operations over the specified axis. + + Notes + ----- + The aggregation operations are always performed over an axis, either the + index (default) or the column axis. This behavior is different from + `numpy` aggregation functions (`mean`, `median`, `prod`, `sum`, `std`, + `var`), where the default is to compute the aggregation of the flattened + array, e.g., ``numpy.mean(arr_2d)`` as opposed to + ``numpy.mean(arr_2d, axis=0)``. + + `agg` is an alias for `aggregate`. Use the alias. + + Functions that mutate the passed object can produce unexpected + behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` + for more details. + + A passed user-defined-function will be passed a Series for evaluation. + + If ``func`` defines an index relabeling, ``axis`` must be ``0`` or ``index``. + + Examples + -------- + >>> s = pd.Series( + ... [1, 2, 3, 4, 5], index=pd.date_range("20130101", periods=5, freq="s") + ... ) + >>> s + 2013-01-01 00:00:00 1 + 2013-01-01 00:00:01 2 + 2013-01-01 00:00:02 3 + 2013-01-01 00:00:03 4 + 2013-01-01 00:00:04 5 + Freq: s, dtype: int64 + + >>> r = s.resample("2s") + + >>> r.agg("sum") + 2013-01-01 00:00:00 3 + 2013-01-01 00:00:02 7 + 2013-01-01 00:00:04 5 + Freq: 2s, dtype: int64 + + >>> r.agg(["sum", "mean", "max"]) + sum mean max + 2013-01-01 00:00:00 3 1.5 2 + 2013-01-01 00:00:02 7 3.5 4 + 2013-01-01 00:00:04 5 5.0 5 + + >>> r.agg({"result": lambda x: x.mean() / x.std(), "total": "sum"}) + result total + 2013-01-01 00:00:00 2.121320 3 + 2013-01-01 00:00:02 4.949747 7 + 2013-01-01 00:00:04 NaN 5 + + >>> r.agg(average="mean", total="sum") + average total + 2013-01-01 00:00:00 1.5 3 + 2013-01-01 00:00:02 3.5 7 + 2013-01-01 00:00:04 5.0 5 + """ + result = ResamplerWindowApply(self, func, args=args, kwargs=kwargs).agg() + if result is None: + how = func + result = self._groupby_and_aggregate(how, *args, **kwargs) + + return result + + agg = aggregate + apply = aggregate + + @final + def transform(self, arg, *args, **kwargs): + """ + Call function producing a like-indexed Series on each group. + + Return a Series with the transformed values. + + Parameters + ---------- + arg : function + To apply to each group. Should return a Series with the same index. + *args, **kwargs + Additional arguments and keywords. + + Returns + ------- + Series + A Series with the transformed values, maintaining the same index as + the original object. + + See Also + -------- + core.resample.Resampler.apply : Apply a function along each group. + core.resample.Resampler.aggregate : Aggregate using one or more operations + over the specified axis. + + Examples + -------- + >>> s = pd.Series([1, 2], index=pd.date_range("20180101", periods=2, freq="1h")) + >>> s + 2018-01-01 00:00:00 1 + 2018-01-01 01:00:00 2 + Freq: h, dtype: int64 + + >>> resampled = s.resample("15min") + >>> resampled.transform(lambda x: (x - x.mean()) / x.std()) + 2018-01-01 00:00:00 NaN + 2018-01-01 01:00:00 NaN + Freq: h, dtype: float64 + """ + return self._selected_obj.groupby(self._timegrouper).transform( + arg, *args, **kwargs + ) + + def _downsample(self, how, **kwargs): + raise AbstractMethodError(self) + + def _upsample(self, f, limit: int | None = None, fill_value=None): + raise AbstractMethodError(self) + + def _gotitem(self, key, ndim: int, subset=None): + """ + Sub-classes to define. Return a sliced object. + + Parameters + ---------- + key : string / list of selections + ndim : {1, 2} + requested ndim of result + subset : object, default None + subset to act on + """ + grouper = self._grouper + if subset is None: + subset = self.obj + if key is not None: + subset = subset[key] + else: + # reached via Apply.agg_dict_like with selection=None and ndim=1 + assert subset.ndim == 1 + if ndim == 1: + assert subset.ndim == 1 + + grouped = get_groupby( + subset, by=None, grouper=grouper, group_keys=self.group_keys + ) + return grouped + + def _groupby_and_aggregate(self, how, *args, **kwargs): + """ + Re-evaluate the obj with a groupby aggregation. + """ + grouper = self._grouper + + # Excludes `on` column when provided + obj = self._obj_with_exclusions + + grouped = get_groupby(obj, by=None, grouper=grouper, group_keys=self.group_keys) + + try: + if callable(how): + # TODO: test_resample_apply_with_additional_args fails if we go + # through the non-lambda path, not clear that it should. + func = lambda x: how(x, *args, **kwargs) + result = grouped.aggregate(func) + else: + result = grouped.aggregate(how, *args, **kwargs) + except (AttributeError, KeyError): + # we have a non-reducing function; try to evaluate + # alternatively we want to evaluate only a column of the input + + # test_apply_to_one_column_of_df the function being applied references + # a DataFrame column, but aggregate_item_by_item operates column-wise + # on Series, raising AttributeError or KeyError + # (depending on whether the column lookup uses getattr/__getitem__) + result = grouped.apply(how, *args, **kwargs) + + except ValueError as err: + if "Must produce aggregated value" in str(err): + # raised in _aggregate_named + # see test_apply_without_aggregation, test_apply_with_mutated_index + pass + else: + raise + + # we have a non-reducing function + # try to evaluate + result = grouped.apply(how, *args, **kwargs) + + return self._wrap_result(result) + + @final + def _get_resampler_for_grouping( + self, + groupby: GroupBy, + key, + ): + """ + Return the correct class for resampling with groupby. + """ + return self._resampler_for_grouping( + groupby=groupby, + key=key, + parent=self, + ) + + def _wrap_result(self, result): + """ + Potentially wrap any results. + """ + if isinstance(result, ABCSeries) and self._selection is not None: + result.name = self._selection + + if isinstance(result, ABCSeries) and result.empty: + # When index is all NaT, result is empty but index is not + obj = self.obj + result.index = _asfreq_compat(obj.index[:0], freq=self.freq) + result.name = getattr(obj, "name", None) + + if self._timegrouper._arrow_dtype is not None: + result.index = result.index.astype(self._timegrouper._arrow_dtype) + result.index.name = self.obj.index.name + + return result + + @final + def ffill(self, limit: int | None = None): + """ + Forward fill the values. + + This method fills missing values by propagating the last valid + observation forward, up to the next valid observation. It is commonly + used in time series analysis when resampling data to a higher frequency + (upsampling) and filling gaps in the resampled output. + + Parameters + ---------- + limit : int, optional + Limit of how many values to fill. + + Returns + ------- + Series + The resampled data with missing values filled forward. + + See Also + -------- + Series.fillna: Fill NA/NaN values using the specified method. + DataFrame.fillna: Fill NA/NaN values using the specified method. + + Examples + -------- + Here we only create a ``Series``. + + >>> ser = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + + Example for ``ffill`` with downsampling (we have fewer dates after resampling): + + >>> ser.resample("MS").ffill() + 2023-01-01 1 + 2023-02-01 3 + Freq: MS, dtype: int64 + + Example for ``ffill`` with upsampling (fill the new dates with + the previous value): + + >>> ser.resample("W").ffill() + 2023-01-01 1 + 2023-01-08 1 + 2023-01-15 2 + 2023-01-22 2 + 2023-01-29 2 + 2023-02-05 3 + 2023-02-12 3 + 2023-02-19 4 + Freq: W-SUN, dtype: int64 + + With upsampling and limiting (only fill the first new date with the + previous value): + + >>> ser.resample("W").ffill(limit=1) + 2023-01-01 1.0 + 2023-01-08 1.0 + 2023-01-15 2.0 + 2023-01-22 2.0 + 2023-01-29 NaN + 2023-02-05 3.0 + 2023-02-12 NaN + 2023-02-19 4.0 + Freq: W-SUN, dtype: float64 + """ + return self._upsample("ffill", limit=limit) + + @final + def nearest(self, limit: int | None = None): + """ + Resample by using the nearest value. + + When resampling data, missing values may appear (e.g., when the + resampling frequency is higher than the original frequency). + The `nearest` method will replace ``NaN`` values that appeared in + the resampled data with the value from the nearest member of the + sequence, based on the index value. + Missing values that existed in the original data will not be modified. + If `limit` is given, fill only this many values in each direction for + each of the original values. + + Parameters + ---------- + limit : int, optional + Limit of how many values to fill. + + Returns + ------- + Series or DataFrame + An upsampled Series or DataFrame with ``NaN`` values filled with + their nearest value. + + See Also + -------- + bfill : Backward fill the new missing values in the resampled data. + ffill : Forward fill ``NaN`` values. + + Examples + -------- + >>> s = pd.Series([1, 2], index=pd.date_range("20180101", periods=2, freq="1h")) + >>> s + 2018-01-01 00:00:00 1 + 2018-01-01 01:00:00 2 + Freq: h, dtype: int64 + + >>> s.resample("15min").nearest() + 2018-01-01 00:00:00 1 + 2018-01-01 00:15:00 1 + 2018-01-01 00:30:00 2 + 2018-01-01 00:45:00 2 + 2018-01-01 01:00:00 2 + Freq: 15min, dtype: int64 + + Limit the number of upsampled values imputed by the nearest: + + >>> s.resample("15min").nearest(limit=1) + 2018-01-01 00:00:00 1.0 + 2018-01-01 00:15:00 1.0 + 2018-01-01 00:30:00 NaN + 2018-01-01 00:45:00 2.0 + 2018-01-01 01:00:00 2.0 + Freq: 15min, dtype: float64 + """ + return self._upsample("nearest", limit=limit) + + @final + def bfill(self, limit: int | None = None): + """ + Backward fill the new missing values in the resampled data. + + In statistics, imputation is the process of replacing missing data with + substituted values [1]_. When resampling data, missing values may + appear (e.g., when the resampling frequency is higher than the original + frequency). The backward fill will replace NaN values that appeared in + the resampled data with the next value in the original sequence. + Missing values that existed in the original data will not be modified. + + Parameters + ---------- + limit : int, optional + Limit of how many values to fill. + + Returns + ------- + Series, DataFrame + An upsampled Series or DataFrame with backward filled NaN values. + + See Also + -------- + nearest : Fill NaN values with nearest neighbor starting from center. + ffill : Forward fill NaN values. + Series.fillna : Fill NaN values in the Series using the + specified method, which can be 'backfill'. + DataFrame.fillna : Fill NaN values in the DataFrame using the + specified method, which can be 'backfill'. + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Imputation_%28statistics%29 + + Examples + -------- + Resampling a Series: + + >>> s = pd.Series( + ... [1, 2, 3], index=pd.date_range("20180101", periods=3, freq="h") + ... ) + >>> s + 2018-01-01 00:00:00 1 + 2018-01-01 01:00:00 2 + 2018-01-01 02:00:00 3 + Freq: h, dtype: int64 + + >>> s.resample("30min").bfill() + 2018-01-01 00:00:00 1 + 2018-01-01 00:30:00 2 + 2018-01-01 01:00:00 2 + 2018-01-01 01:30:00 3 + 2018-01-01 02:00:00 3 + Freq: 30min, dtype: int64 + + >>> s.resample("15min").bfill(limit=2) + 2018-01-01 00:00:00 1.0 + 2018-01-01 00:15:00 NaN + 2018-01-01 00:30:00 2.0 + 2018-01-01 00:45:00 2.0 + 2018-01-01 01:00:00 2.0 + 2018-01-01 01:15:00 NaN + 2018-01-01 01:30:00 3.0 + 2018-01-01 01:45:00 3.0 + 2018-01-01 02:00:00 3.0 + Freq: 15min, dtype: float64 + + Resampling a DataFrame that has missing values: + + >>> df = pd.DataFrame( + ... {"a": [2, np.nan, 6], "b": [1, 3, 5]}, + ... index=pd.date_range("20180101", periods=3, freq="h"), + ... ) + >>> df + a b + 2018-01-01 00:00:00 2.0 1 + 2018-01-01 01:00:00 NaN 3 + 2018-01-01 02:00:00 6.0 5 + + >>> df.resample("30min").bfill() + a b + 2018-01-01 00:00:00 2.0 1 + 2018-01-01 00:30:00 NaN 3 + 2018-01-01 01:00:00 NaN 3 + 2018-01-01 01:30:00 6.0 5 + 2018-01-01 02:00:00 6.0 5 + + >>> df.resample("15min").bfill(limit=2) + a b + 2018-01-01 00:00:00 2.0 1.0 + 2018-01-01 00:15:00 NaN NaN + 2018-01-01 00:30:00 NaN 3.0 + 2018-01-01 00:45:00 NaN 3.0 + 2018-01-01 01:00:00 NaN 3.0 + 2018-01-01 01:15:00 NaN NaN + 2018-01-01 01:30:00 6.0 5.0 + 2018-01-01 01:45:00 6.0 5.0 + 2018-01-01 02:00:00 6.0 5.0 + """ + return self._upsample("bfill", limit=limit) + + @final + def interpolate( + self, + method: InterpolateOptions = "linear", + *, + axis: Axis = 0, + limit: int | None = None, + limit_direction: Literal["forward", "backward", "both"] = "forward", + limit_area=None, + **kwargs, + ): + """ + Interpolate values between target timestamps according to different methods. + + The original index is first reindexed to target timestamps + (see :meth:`core.resample.Resampler.asfreq`), + then the interpolation of ``NaN`` values via :meth:`DataFrame.interpolate` + happens. + + Parameters + ---------- + method : str, default 'linear' + Interpolation technique to use. One of: + + * 'linear': Ignore the index and treat the values as equally + spaced. This is the only method supported on MultiIndexes. + * 'time': Works on daily and higher resolution data to interpolate + given length of interval. + * 'index', 'values': use the actual numerical values of the index. + * 'pad': Fill in NaNs using existing values. + * 'nearest', 'zero', 'slinear', 'quadratic', 'cubic', + 'barycentric', 'polynomial': Passed to + `scipy.interpolate.interp1d`, whereas 'spline' is passed to + `scipy.interpolate.UnivariateSpline`. These methods use the numerical + values of the index. Both 'polynomial' and 'spline' require that + you also specify an `order` (int), e.g. + ``df.interpolate(method='polynomial', order=5)``. Note that, + `slinear` method in Pandas refers to the Scipy first order `spline` + instead of Pandas first order `spline`. + * 'krogh', 'piecewise_polynomial', 'spline', 'pchip', 'akima', + 'cubicspline': Wrappers around the SciPy interpolation methods of + similar names. See `Notes`. + * 'from_derivatives': Refers to + `scipy.interpolate.BPoly.from_derivatives`. + + axis : {{0 or 'index', 1 or 'columns', None}}, default None + Axis to interpolate along. For `Series` this parameter is unused + and defaults to 0. + limit : int, optional + Maximum number of consecutive NaNs to fill. Must be greater than + 0. + limit_direction : {{'forward', 'backward', 'both'}}, Optional + Consecutive NaNs will be filled in this direction. + + limit_area : {{`None`, 'inside', 'outside'}}, default None + If limit is specified, consecutive NaNs will be filled with this + restriction. + + * ``None``: No fill restriction. + * 'inside': Only fill NaNs surrounded by valid values + (interpolate). + * 'outside': Only fill NaNs outside valid values (extrapolate). + + **kwargs : optional + Keyword arguments to pass on to the interpolating function. + + Returns + ------- + DataFrame or Series + Interpolated values at the specified freq. + + See Also + -------- + core.resample.Resampler.asfreq: Return the values at the new freq, + essentially a reindex. + DataFrame.interpolate: Fill NaN values using an interpolation method. + DataFrame.bfill : Backward fill NaN values in the resampled data. + DataFrame.ffill : Forward fill NaN values. + + Notes + ----- + For high-frequent or non-equidistant time-series with timestamps + the reindexing followed by interpolation may lead to information loss + as shown in the last example. + + Examples + -------- + + >>> start = "2023-03-01T07:00:00" + >>> timesteps = pd.date_range(start, periods=5, freq="s") + >>> series = pd.Series(data=[1, -1, 2, 1, 3], index=timesteps) + >>> series + 2023-03-01 07:00:00 1 + 2023-03-01 07:00:01 -1 + 2023-03-01 07:00:02 2 + 2023-03-01 07:00:03 1 + 2023-03-01 07:00:04 3 + Freq: s, dtype: int64 + + Downsample the dataframe to 0.5Hz by providing the period time of 2s. + + >>> series.resample("2s").interpolate("linear") + 2023-03-01 07:00:00 1 + 2023-03-01 07:00:02 2 + 2023-03-01 07:00:04 3 + Freq: 2s, dtype: int64 + + Upsample the dataframe to 2Hz by providing the period time of 500ms. + + >>> series.resample("500ms").interpolate("linear") + 2023-03-01 07:00:00.000 1.0 + 2023-03-01 07:00:00.500 0.0 + 2023-03-01 07:00:01.000 -1.0 + 2023-03-01 07:00:01.500 0.5 + 2023-03-01 07:00:02.000 2.0 + 2023-03-01 07:00:02.500 1.5 + 2023-03-01 07:00:03.000 1.0 + 2023-03-01 07:00:03.500 2.0 + 2023-03-01 07:00:04.000 3.0 + Freq: 500ms, dtype: float64 + + Internal reindexing with ``asfreq()`` prior to interpolation leads to + an interpolated timeseries on the basis of the reindexed timestamps + (anchors). It is assured that all available datapoints from original + series become anchors, so it also works for resampling-cases that lead + to non-aligned timestamps, as in the following example: + + >>> series.resample("400ms").interpolate("linear") + 2023-03-01 07:00:00.000 1.000000 + 2023-03-01 07:00:00.400 0.333333 + 2023-03-01 07:00:00.800 -0.333333 + 2023-03-01 07:00:01.200 0.000000 + 2023-03-01 07:00:01.600 1.000000 + 2023-03-01 07:00:02.000 2.000000 + 2023-03-01 07:00:02.400 1.666667 + 2023-03-01 07:00:02.800 1.333333 + 2023-03-01 07:00:03.200 1.666667 + 2023-03-01 07:00:03.600 2.333333 + 2023-03-01 07:00:04.000 3.000000 + Freq: 400ms, dtype: float64 + + Note that the series correctly decreases between two anchors + ``07:00:00`` and ``07:00:02``. + """ + if "inplace" in kwargs: + # GH#58690 + warnings.warn( + f"The 'inplace' keyword in {type(self).__name__}.interpolate " + "is deprecated and will be removed in a future version. " + "resample(...).interpolate is never inplace.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + inplace = kwargs.pop("inplace") + if inplace: + raise ValueError("Cannot interpolate inplace on a resampled object.") + + result = self._upsample("asfreq") + + # If the original data has timestamps which are not aligned with the + # target timestamps, we need to add those points back to the data frame + # that is supposed to be interpolated. This does not work with + # PeriodIndex, so we skip this case. GH#21351 + obj = self._selected_obj + is_period_index = isinstance(obj.index, PeriodIndex) + + # Skip this step for PeriodIndex + if not is_period_index: + final_index = result.index + if isinstance(final_index, MultiIndex): + raise NotImplementedError( + "Direct interpolation of MultiIndex data frames is not " + "supported. If you tried to resample and interpolate on a " + "grouped data frame, please use:\n" + "`df.groupby(...).apply(lambda x: x.resample(...)." + "interpolate(...))`" + "\ninstead, as resampling and interpolation has to be " + "performed for each group independently." + ) + + missing_data_points_index = obj.index.difference(final_index) + if len(missing_data_points_index) > 0: + result = concat( + [result, obj.loc[missing_data_points_index]] + ).sort_index() + + result_interpolated = result.interpolate( + method=method, + axis=axis, + limit=limit, + inplace=False, + limit_direction=limit_direction, + limit_area=limit_area, + **kwargs, + ) + + # No further steps if the original data has a PeriodIndex + if is_period_index: + return result_interpolated + + # Make sure that original data points which do not align with the + # resampled index are removed + result_interpolated = result_interpolated.loc[final_index] + + # Make sure frequency indexes are preserved + result_interpolated.index = final_index + return result_interpolated + + @final + def asfreq(self, fill_value=None): + """ + Return the values at the new freq, essentially a reindex. + + Parameters + ---------- + fill_value : scalar, optional + Value to use for missing values, applied during upsampling (note + this does not fill NaNs that already were present). + + Returns + ------- + DataFrame or Series + Values at the specified freq. + + See Also + -------- + Series.asfreq: Convert TimeSeries to specified frequency. + DataFrame.asfreq: Convert TimeSeries to specified frequency. + + Examples + -------- + + >>> ser = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-31", "2023-02-01", "2023-02-28"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-31 2 + 2023-02-01 3 + 2023-02-28 4 + dtype: int64 + >>> ser.resample("MS").asfreq() + 2023-01-01 1 + 2023-02-01 3 + Freq: MS, dtype: int64 + """ + return self._upsample("asfreq", fill_value=fill_value) + + @final + def sum( + self, + numeric_only: bool = False, + min_count: int = 0, + ): + """ + Compute sum of group values. + + This method provides a simple way to compute the sum of values within each + resampled group, particularly useful for aggregating time-based data into + daily, monthly, or yearly sums. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + + .. versionchanged:: 2.0.0 + + numeric_only no longer accepts ``None``. + + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer + than ``min_count`` non-NA values are present the result will be NA. + + Returns + ------- + Series or DataFrame + Computed sum of values within each group. + + See Also + -------- + core.resample.Resampler.mean : Compute mean of groups, excluding missing values. + core.resample.Resampler.count : Compute count of group, excluding missing + values. + DataFrame.resample : Resample time-series data. + Series.sum : Return the sum of the values over the requested axis. + + Examples + -------- + >>> ser = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + >>> ser.resample("MS").sum() + 2023-01-01 3 + 2023-02-01 7 + Freq: MS, dtype: int64 + """ + return self._downsample("sum", numeric_only=numeric_only, min_count=min_count) + + @final + def prod( + self, + numeric_only: bool = False, + min_count: int = 0, + ): + """ + Compute prod of group values. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + + .. versionchanged:: 2.0.0 + + numeric_only no longer accepts ``None``. + + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer + than ``min_count`` non-NA values are present the result will be NA. + + Returns + ------- + Series or DataFrame + Computed prod of values within each group. + + See Also + -------- + core.resample.Resampler.sum : Compute sum of groups, excluding missing values. + core.resample.Resampler.mean : Compute mean of groups, excluding missing values. + core.resample.Resampler.median : Compute median of groups, excluding missing + values. + + Examples + -------- + >>> ser = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + >>> ser.resample("MS").prod() + 2023-01-01 2 + 2023-02-01 12 + Freq: MS, dtype: int64 + """ + return self._downsample("prod", numeric_only=numeric_only, min_count=min_count) + + @final + def min( + self, + numeric_only: bool = False, + min_count: int = 0, + ): + """ + Compute min value of group. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + + .. versionchanged:: 2.0.0 + + numeric_only no longer accepts ``None``. + + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer + than ``min_count`` non-NA values are present the result will be NA. + + Returns + ------- + Series or DataFrame + Compute the minimum value in the given Series or DataFrame. + + See Also + -------- + core.resample.Resampler.max : Compute max value of group. + core.resample.Resampler.mean : Compute mean of groups, excluding missing values. + core.resample.Resampler.median : Compute median of groups, excluding missing + values. + + Examples + -------- + >>> ser = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + >>> ser.resample("MS").min() + 2023-01-01 1 + 2023-02-01 3 + Freq: MS, dtype: int64 + """ + return self._downsample("min", numeric_only=numeric_only, min_count=min_count) + + @final + def max( + self, + numeric_only: bool = False, + min_count: int = 0, + ): + """ + Compute max value of group. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + + .. versionchanged:: 2.0.0 + + numeric_only no longer accepts ``None``. + + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer + than ``min_count`` non-NA values are present the result will be NA. + + Returns + ------- + Series or DataFrame + Computes the maximum value in the given Series or Dataframe. + + See Also + -------- + core.resample.Resampler.min : Compute min value of group. + core.resample.Resampler.mean : Compute mean of groups, excluding missing values. + core.resample.Resampler.median : Compute median of groups, excluding missing + values. + + Examples + -------- + >>> ser = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + >>> ser.resample("MS").max() + 2023-01-01 2 + 2023-02-01 4 + Freq: MS, dtype: int64 + """ + return self._downsample("max", numeric_only=numeric_only, min_count=min_count) + + @final + def first( + self, + numeric_only: bool = False, + min_count: int = 0, + skipna: bool = True, + ): + """ + Compute the first non-null entry of each column. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer + than ``min_count`` non-NA values are present the result will be NA. + skipna : bool, default True + Exclude NA/null values. If an entire group is NA, the result will be NA. + + Returns + ------- + Series or DataFrame + First values within each group. + + See Also + -------- + core.resample.Resampler.last : Compute the last non-null value in each group. + core.resample.Resampler.mean : Compute mean of groups, excluding missing values. + + Examples + -------- + >>> s = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> s + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + >>> s.resample("MS").first() + 2023-01-01 1 + 2023-02-01 3 + Freq: MS, dtype: int64 + """ + return self._downsample( + "first", numeric_only=numeric_only, min_count=min_count, skipna=skipna + ) + + @final + def last( + self, + numeric_only: bool = False, + min_count: int = 0, + skipna: bool = True, + ): + """ + Compute the last non-null entry of each column. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer + than ``min_count`` non-NA values are present the result will be NA. + skipna : bool, default True + Exclude NA/null values. If an entire group is NA, the result will be NA. + + Returns + ------- + Series or DataFrame + Last of values within each group. + + See Also + -------- + core.resample.Resampler.first : Compute the first non-null value in each group. + core.resample.Resampler.mean : Compute mean of groups, excluding missing values. + + Examples + -------- + >>> s = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> s + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + >>> s.resample("MS").last() + 2023-01-01 2 + 2023-02-01 4 + Freq: MS, dtype: int64 + """ + return self._downsample( + "last", numeric_only=numeric_only, min_count=min_count, skipna=skipna + ) + + @final + def median(self, numeric_only: bool = False): + """ + Compute median of groups, excluding missing values. + + For multiple groupings, the result index will be a MultiIndex + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + + .. versionchanged:: 2.0.0 + + numeric_only no longer accepts ``None`` and defaults to False. + + Returns + ------- + Series or DataFrame + Median of values within each group. + + See Also + -------- + Series.groupby : Apply a function groupby to a Series. + DataFrame.groupby : Apply a function groupby to each row or column of a + DataFrame. + + Examples + -------- + + >>> ser = pd.Series( + ... [1, 2, 3, 3, 4, 5], + ... index=pd.DatetimeIndex( + ... [ + ... "2023-01-01", + ... "2023-01-10", + ... "2023-01-15", + ... "2023-02-01", + ... "2023-02-10", + ... "2023-02-15", + ... ] + ... ), + ... ) + >>> ser.resample("MS").median() + 2023-01-01 2.0 + 2023-02-01 4.0 + Freq: MS, dtype: float64 + """ + return self._downsample("median", numeric_only=numeric_only) + + @final + def mean( + self, + numeric_only: bool = False, + ): + """ + Compute mean of groups, excluding missing values. + + Parameters + ---------- + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + .. versionchanged:: 2.0.0 + + numeric_only now defaults to ``False``. + + Returns + ------- + DataFrame or Series + Mean of values within each group. + + See Also + -------- + core.resample.Resampler.median : Compute median of groups, excluding missing + values. + core.resample.Resampler.sum : Compute sum of groups, excluding missing values. + core.resample.Resampler.std : Compute standard deviation of groups, excluding + missing values. + core.resample.Resampler.var : Compute variance of groups, excluding missing + values. + + Examples + -------- + + >>> ser = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + >>> ser.resample("MS").mean() + 2023-01-01 1.5 + 2023-02-01 3.5 + Freq: MS, dtype: float64 + """ + return self._downsample("mean", numeric_only=numeric_only) + + @final + def std( + self, + ddof: int = 1, + numeric_only: bool = False, + ): + """ + Compute standard deviation of groups, excluding missing values. + + Parameters + ---------- + ddof : int, default 1 + Degrees of freedom. + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + .. versionchanged:: 2.0.0 + + numeric_only now defaults to ``False``. + + Returns + ------- + DataFrame or Series + Standard deviation of values within each group. + + See Also + -------- + core.resample.Resampler.mean : Compute mean of groups, excluding missing values. + core.resample.Resampler.median : Compute median of groups, excluding missing + values. + core.resample.Resampler.var : Compute variance of groups, excluding missing + values. + + Examples + -------- + + >>> ser = pd.Series( + ... [1, 3, 2, 4, 3, 8], + ... index=pd.DatetimeIndex( + ... [ + ... "2023-01-01", + ... "2023-01-10", + ... "2023-01-15", + ... "2023-02-01", + ... "2023-02-10", + ... "2023-02-15", + ... ] + ... ), + ... ) + >>> ser.resample("MS").std() + 2023-01-01 1.000000 + 2023-02-01 2.645751 + Freq: MS, dtype: float64 + """ + return self._downsample("std", ddof=ddof, numeric_only=numeric_only) + + @final + def var( + self, + ddof: int = 1, + numeric_only: bool = False, + ): + """ + Compute variance of groups, excluding missing values. + + Parameters + ---------- + ddof : int, default 1 + Degrees of freedom. + + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + .. versionchanged:: 2.0.0 + + numeric_only now defaults to ``False``. + + Returns + ------- + DataFrame or Series + Variance of values within each group. + + See Also + -------- + core.resample.Resampler.std : Compute standard deviation of groups, excluding + missing values. + core.resample.Resampler.mean : Compute mean of groups, excluding missing values. + core.resample.Resampler.median : Compute median of groups, excluding missing + values. + + Examples + -------- + + >>> ser = pd.Series( + ... [1, 3, 2, 4, 3, 8], + ... index=pd.DatetimeIndex( + ... [ + ... "2023-01-01", + ... "2023-01-10", + ... "2023-01-15", + ... "2023-02-01", + ... "2023-02-10", + ... "2023-02-15", + ... ] + ... ), + ... ) + >>> ser.resample("MS").var() + 2023-01-01 1.0 + 2023-02-01 7.0 + Freq: MS, dtype: float64 + + >>> ser.resample("MS").var(ddof=0) + 2023-01-01 0.666667 + 2023-02-01 4.666667 + Freq: MS, dtype: float64 + """ + return self._downsample("var", ddof=ddof, numeric_only=numeric_only) + + @final + def sem( + self, + ddof: int = 1, + numeric_only: bool = False, + ): + """ + Compute standard error of the mean of groups, excluding missing values. + + For multiple groupings, the result index will be a MultiIndex. + + Parameters + ---------- + ddof : int, default 1 + Degrees of freedom. + + numeric_only : bool, default False + Include only `float`, `int` or `boolean` data. + + .. versionchanged:: 2.0.0 + + numeric_only now defaults to ``False``. + + Returns + ------- + Series or DataFrame + Standard error of the mean of values within each group. + + See Also + -------- + DataFrame.sem : Return unbiased standard error of the mean over requested axis. + Series.sem : Return unbiased standard error of the mean over requested axis. + + Examples + -------- + + >>> ser = pd.Series( + ... [1, 3, 2, 4, 3, 8], + ... index=pd.DatetimeIndex( + ... [ + ... "2023-01-01", + ... "2023-01-10", + ... "2023-01-15", + ... "2023-02-01", + ... "2023-02-10", + ... "2023-02-15", + ... ] + ... ), + ... ) + >>> ser.resample("MS").sem() + 2023-01-01 0.577350 + 2023-02-01 1.527525 + Freq: MS, dtype: float64 + """ + return self._downsample("sem", ddof=ddof, numeric_only=numeric_only) + + @final + def ohlc(self): + """ + Compute open, high, low and close values of a group, excluding missing values. + + Returns + ------- + DataFrame + Open, high, low and close values within each group. + + See Also + -------- + DataFrame.agg : Aggregate using one or more operations over the specified axis. + DataFrame.resample : Resample time-series data. + DataFrame.groupby : Group DataFrame using a mapper or by a Series of columns. + + Examples + -------- + >>> ser = pd.Series( + ... [1, 3, 2, 4, 3, 5], + ... index=pd.DatetimeIndex( + ... [ + ... "2023-01-01", + ... "2023-01-10", + ... "2023-01-15", + ... "2023-02-01", + ... "2023-02-10", + ... "2023-02-15", + ... ] + ... ), + ... ) + >>> ser.resample("MS").ohlc() + open high low close + 2023-01-01 1 3 1 2 + 2023-02-01 4 5 3 5 + """ + ax = self.ax + obj = self._obj_with_exclusions + if len(ax) == 0: + # GH#42902 + obj = obj.copy() + obj.index = _asfreq_compat(obj.index, self.freq) + if obj.ndim == 1: + obj = obj.to_frame() + obj = obj.reindex(["open", "high", "low", "close"], axis=1) + else: + mi = MultiIndex.from_product( + [obj.columns, ["open", "high", "low", "close"]] + ) + obj = obj.reindex(mi, axis=1) + return obj + + return self._downsample("ohlc") + + @final + def nunique(self): + """ + Return number of unique elements in the group. + + Returns + ------- + Series + Number of unique values within each group. + + See Also + -------- + core.groupby.SeriesGroupBy.nunique : Method nunique for SeriesGroupBy. + + Examples + -------- + >>> ser = pd.Series( + ... [1, 2, 3, 3], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 3 + dtype: int64 + >>> ser.resample("MS").nunique() + 2023-01-01 2 + 2023-02-01 1 + Freq: MS, dtype: int64 + """ + return self._downsample("nunique") + + @final + def size(self): + """ + Compute group sizes. + + Returns + ------- + Series + Number of rows in each group. + + See Also + -------- + Series.groupby : Apply a function groupby to a Series. + DataFrame.groupby : Apply a function groupby to each row + or column of a DataFrame. + + Examples + -------- + >>> ser = pd.Series( + ... [1, 2, 3], + ... index=pd.DatetimeIndex(["2023-01-01", "2023-01-15", "2023-02-01"]), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + dtype: int64 + >>> ser.resample("MS").size() + 2023-01-01 2 + 2023-02-01 1 + Freq: MS, dtype: int64 + """ + result = self._downsample("size") + + # If the result is a non-empty DataFrame we stack to get a Series + # GH 46826 + if isinstance(result, ABCDataFrame) and not result.empty: + result = result.stack() + + if not len(self.ax): + from pandas import Series + + if self._selected_obj.ndim == 1: + name = self._selected_obj.name + else: + name = None + result = Series([], index=result.index, dtype="int64", name=name) + return result + + @final + def count(self): + """ + Compute count of group, excluding missing values. + + Returns + ------- + Series or DataFrame + Count of values within each group. + + See Also + -------- + Series.groupby : Apply a function groupby to a Series. + DataFrame.groupby : Apply a function groupby to each row + or column of a DataFrame. + + Examples + -------- + >>> ser = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.DatetimeIndex( + ... ["2023-01-01", "2023-01-15", "2023-02-01", "2023-02-15"] + ... ), + ... ) + >>> ser + 2023-01-01 1 + 2023-01-15 2 + 2023-02-01 3 + 2023-02-15 4 + dtype: int64 + >>> ser.resample("MS").count() + 2023-01-01 2 + 2023-02-01 2 + Freq: MS, dtype: int64 + """ + result = self._downsample("count") + if not len(self.ax): + if self._selected_obj.ndim == 1: + result = type(self._selected_obj)( + [], index=result.index, dtype="int64", name=self._selected_obj.name + ) + else: + from pandas import DataFrame + + result = DataFrame( + [], index=result.index, columns=result.columns, dtype="int64" + ) + + return result + + @final + def quantile(self, q: float | list[float] | AnyArrayLike = 0.5, **kwargs): + """ + Return value at the given quantile. + + Computes the quantile of values within each resampled group. + + Parameters + ---------- + q : float or array-like, default 0.5 (50% quantile) + Value between 0 <= q <= 1, the quantile(s) to compute. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + DataFrame or Series + Quantile of values within each group. + + See Also + -------- + Series.quantile + Return a series, where the index is q and the values are the quantiles. + DataFrame.quantile + Return a DataFrame, where the columns are the columns of self, + and the values are the quantiles. + DataFrameGroupBy.quantile + Return a DataFrame, where the columns are groupby columns, + and the values are its quantiles. + + Examples + -------- + + >>> ser = pd.Series( + ... [1, 3, 2, 4, 3, 8], + ... index=pd.DatetimeIndex( + ... [ + ... "2023-01-01", + ... "2023-01-10", + ... "2023-01-15", + ... "2023-02-01", + ... "2023-02-10", + ... "2023-02-15", + ... ] + ... ), + ... ) + >>> ser.resample("MS").quantile() + 2023-01-01 2.0 + 2023-02-01 4.0 + Freq: MS, dtype: float64 + + >>> ser.resample("MS").quantile(0.25) + 2023-01-01 1.5 + 2023-02-01 3.5 + Freq: MS, dtype: float64 + """ + return self._downsample("quantile", q=q, **kwargs) + + +class _GroupByMixin(PandasObject, SelectionMixin): + """ + Provide the groupby facilities. + """ + + _attributes: list[str] # in practice the same as Resampler._attributes + _selection: IndexLabel | None = None + _groupby: GroupBy + _timegrouper: TimeGrouper + + def __init__( + self, + *, + parent: Resampler, + groupby: GroupBy, + key=None, + selection: IndexLabel | None = None, + ) -> None: + # reached via ._gotitem and _get_resampler_for_grouping + + assert isinstance(groupby, GroupBy), type(groupby) + + # parent is always a Resampler, sometimes a _GroupByMixin + assert isinstance(parent, Resampler), type(parent) + + # initialize our GroupByMixin object with + # the resampler attributes + for attr in self._attributes: + setattr(self, attr, getattr(parent, attr)) + self._selection = selection + + self.binner = parent.binner + self.key = key + + self._groupby = groupby + self._timegrouper = copy.copy(parent._timegrouper) + + self.ax = parent.ax + self.obj = parent.obj + + @no_type_check + def _apply(self, f, *args, **kwargs): + """ + Dispatch to _upsample; we are stripping all of the _upsample kwargs and + performing the original function call on the grouped object. + """ + + def func(x): + x = self._resampler_cls(x, timegrouper=self._timegrouper, gpr_index=self.ax) + + if isinstance(f, str): + return getattr(x, f)(**kwargs) + + return x.apply(f, *args, **kwargs) + + result = self._groupby.apply(func) + + # GH 47705 + if ( + isinstance(result, ABCDataFrame) + and len(result) == 0 + and not isinstance(result.index, PeriodIndex) + ): + result = result.set_index( + _asfreq_compat(self.obj.index[:0], freq=self.freq), append=True + ) + + return self._wrap_result(result) + + _upsample = _apply + _downsample = _apply + _groupby_and_aggregate = _apply + + @final + def _gotitem(self, key, ndim, subset=None): + """ + Sub-classes to define. Return a sliced object. + + Parameters + ---------- + key : string / list of selections + ndim : {1, 2} + requested ndim of result + subset : object, default None + subset to act on + """ + # create a new object to prevent aliasing + if subset is None: + subset = self.obj + if key is not None: + subset = subset[key] + else: + # reached via Apply.agg_dict_like with selection=None, ndim=1 + assert subset.ndim == 1 + + # Try to select from a DataFrame, falling back to a Series + try: + if isinstance(key, list) and self.key not in key and self.key is not None: + key.append(self.key) + groupby = self._groupby[key] + except IndexError: + groupby = self._groupby + + selection = self._infer_selection(key, subset) + + new_rs = type(self)( + groupby=groupby, + parent=cast(Resampler, self), + selection=selection, + ) + return new_rs + + +class DatetimeIndexResampler(Resampler): + ax: DatetimeIndex + + @property + def _resampler_for_grouping(self) -> type[DatetimeIndexResamplerGroupby]: + return DatetimeIndexResamplerGroupby + + def _get_binner_for_time(self): + # this is how we are actually creating the bins + return self._timegrouper._get_time_bins(self.ax) + + def _downsample(self, how, **kwargs): + """ + Downsample the cython defined function. + + Parameters + ---------- + how : string / cython mapped function + **kwargs : kw args passed to how function + """ + ax = self.ax + + # Excludes `on` column when provided + obj = self._obj_with_exclusions + + if not len(ax): + # reset to the new freq + obj = obj.copy() + obj.index = obj.index._with_freq(self.freq) + assert obj.index.freq == self.freq, (obj.index.freq, self.freq) + return obj + + # we are downsampling + # we want to call the actual grouper method here + result = obj.groupby(self._grouper).aggregate(how, **kwargs) + return self._wrap_result(result) + + def _adjust_binner_for_upsample(self, binner): + """ + Adjust our binner when upsampling. + + The range of a new index should not be outside specified range + """ + if self.closed == "right": + binner = binner[1:] + else: + binner = binner[:-1] + return binner + + def _upsample(self, method, limit: int | None = None, fill_value=None): + """ + Parameters + ---------- + method : string {'backfill', 'bfill', 'pad', + 'ffill', 'asfreq'} method for upsampling + limit : int, default None + Maximum size gap to fill when reindexing + fill_value : scalar, default None + Value to use for missing values + """ + if self._from_selection: + raise ValueError( + "Upsampling from level= or on= selection " + "is not supported, use .set_index(...) " + "to explicitly set index to datetime-like" + ) + + ax = self.ax + obj = self._selected_obj + binner = self.binner + res_index = self._adjust_binner_for_upsample(binner) + + # if index exactly matches target grid (same freq & alignment), use fast path + if ( + limit is None + and to_offset(ax.inferred_freq) == self.freq + and len(obj) == len(res_index) + and obj.index.equals(res_index) + ): + result = obj.copy() + result.index = res_index + else: + if method == "asfreq": + method = None + result = obj.reindex( + res_index, method=method, limit=limit, fill_value=fill_value + ) + + return self._wrap_result(result) + + def _wrap_result(self, result): + result = super()._wrap_result(result) + + # we may have a different kind that we were asked originally + # convert if needed + if isinstance(self.ax, PeriodIndex) and not isinstance( + result.index, PeriodIndex + ): + if isinstance(result.index, MultiIndex): + # GH 24103 - e.g. groupby resample + if not isinstance(result.index.levels[-1], PeriodIndex): + new_level = result.index.levels[-1].to_period(self.freq) + result.index = result.index.set_levels(new_level, level=-1) + else: + result.index = result.index.to_period(self.freq) + return result + + +@set_module("pandas.api.typing") +# error: Definition of "ax" in base class "_GroupByMixin" is incompatible +# with definition in base class "DatetimeIndexResampler" +class DatetimeIndexResamplerGroupby( # type: ignore[misc] + _GroupByMixin, DatetimeIndexResampler +): + """ + Provides a resample of a groupby implementation + """ + + @property + def _resampler_cls(self): + return DatetimeIndexResampler + + +class PeriodIndexResampler(DatetimeIndexResampler): + # error: Incompatible types in assignment (expression has type "PeriodIndex", base + # class "DatetimeIndexResampler" defined the type as "DatetimeIndex") + ax: PeriodIndex # type: ignore[assignment] + + @property + def _resampler_for_grouping(self): + return PeriodIndexResamplerGroupby + + def _get_binner_for_time(self): + return self._timegrouper._get_period_bins(self.ax) + + def _convert_obj(self, obj: NDFrameT) -> NDFrameT: + obj = super()._convert_obj(obj) + + if self._from_selection: + # see GH 14008, GH 12871 + msg = ( + "Resampling from level= or on= selection " + "with a PeriodIndex is not currently supported, " + "use .set_index(...) to explicitly set index" + ) + raise NotImplementedError(msg) + + return obj + + def _downsample(self, how, **kwargs): + """ + Downsample the cython defined function. + + Parameters + ---------- + how : string / cython mapped function + **kwargs : kw args passed to how function + """ + ax = self.ax + + if is_subperiod(ax.freq, self.freq): + # Downsampling + return self._groupby_and_aggregate(how, **kwargs) + elif is_superperiod(ax.freq, self.freq): + if how == "ohlc": + # GH #13083 + # upsampling to subperiods is handled as an asfreq, which works + # for pure aggregating/reducing methods + # OHLC reduces along the time dimension, but creates multiple + # values for each period -> handle by _groupby_and_aggregate() + return self._groupby_and_aggregate(how) + return self.asfreq() + elif ax.freq == self.freq: + return self.asfreq() + + raise IncompatibleFrequency( + f"Frequency {ax.freq} cannot be resampled to {self.freq}, " + "as they are not sub or super periods" + ) + + def _upsample(self, method, limit: int | None = None, fill_value=None): + """ + Parameters + ---------- + method : {'backfill', 'bfill', 'pad', 'ffill'} + Method for upsampling. + limit : int, default None + Maximum size gap to fill when reindexing. + fill_value : scalar, default None + Value to use for missing values. + """ + ax = self.ax + obj = self.obj + new_index = self.binner + + # Start vs. end of period + memb = ax.asfreq(self.freq, how=self.convention) + + # Get the fill indexer + if method == "asfreq": + method = None + indexer = memb.get_indexer(new_index, method=method, limit=limit) + new_obj = _take_new_index( + obj, + indexer, + new_index, + ) + return self._wrap_result(new_obj) + + +@set_module("pandas.api.typing") +# error: Definition of "ax" in base class "_GroupByMixin" is incompatible with +# definition in base class "PeriodIndexResampler" +class PeriodIndexResamplerGroupby( # type: ignore[misc] + _GroupByMixin, PeriodIndexResampler +): + """ + Provides a resample of a groupby implementation. + """ + + @property + def _resampler_cls(self): + return PeriodIndexResampler + + +class TimedeltaIndexResampler(DatetimeIndexResampler): + # error: Incompatible types in assignment (expression has type "TimedeltaIndex", + # base class "DatetimeIndexResampler" defined the type as "DatetimeIndex") + ax: TimedeltaIndex # type: ignore[assignment] + + @property + def _resampler_for_grouping(self): + return TimedeltaIndexResamplerGroupby + + def _get_binner_for_time(self): + return self._timegrouper._get_time_delta_bins(self.ax) + + def _adjust_binner_for_upsample(self, binner): + """ + Adjust our binner when upsampling. + + The range of a new index is allowed to be greater than original range + so we don't need to change the length of a binner, GH 13022 + """ + return binner + + +@set_module("pandas.api.typing") +# error: Definition of "ax" in base class "_GroupByMixin" is incompatible with +# definition in base class "DatetimeIndexResampler" +class TimedeltaIndexResamplerGroupby( # type: ignore[misc] + _GroupByMixin, TimedeltaIndexResampler +): + """ + Provides a resample of a groupby implementation. + """ + + @property + def _resampler_cls(self): + return TimedeltaIndexResampler + + +def get_resampler(obj: Series | DataFrame, **kwds) -> Resampler: + """ + Create a TimeGrouper and return our resampler. + """ + tg = TimeGrouper(obj, **kwds) # type: ignore[arg-type] + return tg._get_resampler(obj) + + +get_resampler.__doc__ = Resampler.__doc__ + + +def get_resampler_for_grouping( + groupby: GroupBy, + rule, + how=None, + fill_method=None, + limit: int | None = None, + on=None, + **kwargs, +) -> Resampler: + """ + Return our appropriate resampler when grouping as well. + """ + # .resample uses 'on' similar to how .groupby uses 'key' + tg = TimeGrouper(freq=rule, key=on, **kwargs) + resampler = tg._get_resampler(groupby.obj) + return resampler._get_resampler_for_grouping(groupby=groupby, key=tg.key) + + +@set_module("pandas.api.typing") +class TimeGrouper(Grouper): + """ + Custom groupby class for time-interval grouping. + + Parameters + ---------- + freq : pandas date offset or offset alias for identifying bin edges + closed : closed end of interval; 'left' or 'right' + label : interval boundary to use for labeling; 'left' or 'right' + convention : {'start', 'end', 'e', 's'} + If axis is PeriodIndex + """ + + _attributes = ( + *Grouper._attributes, + "closed", + "label", + "how", + "convention", + "origin", + "offset", + ) + + origin: TimeGrouperOrigin + + def __init__( + self, + obj: Grouper | None = None, + freq: Frequency = "Min", + key: str | None = None, + closed: Literal["left", "right"] | None = None, + label: Literal["left", "right"] | None = None, + how: str = "mean", + fill_method=None, + limit: int | None = None, + convention: Literal["start", "end", "e", "s"] | None = None, + origin: ( + Literal["epoch", "start", "start_day", "end", "end_day"] + | TimestampConvertibleTypes + ) = "start_day", + offset: TimedeltaConvertibleTypes | None = None, + group_keys: bool = False, + **kwargs, + ) -> None: + # Check for correctness of the keyword arguments which would + # otherwise silently use the default if misspelled + if label not in {None, "left", "right"}: + raise ValueError(f"Unsupported value {label} for `label`") + if closed not in {None, "left", "right"}: + raise ValueError(f"Unsupported value {closed} for `closed`") + if convention not in {None, "start", "end", "e", "s"}: + raise ValueError(f"Unsupported value {convention} for `convention`") + + if (key is None and obj is not None and isinstance(obj.index, PeriodIndex)) or ( # type: ignore[attr-defined] + key is not None + and obj is not None + and getattr(obj[key], "dtype", None) == "period" # type: ignore[index] + ): + freq = to_offset(freq, is_period=True) + else: + freq = to_offset(freq) + + if not isinstance(freq, Tick): + if offset is not None: + warnings.warn( + "The 'offset' keyword does not take effect when resampling " + "with a 'freq' that is not Tick-like (h, m, s, ms, us, ns)", + RuntimeWarning, + stacklevel=find_stack_level(), + ) + if origin != "start_day": + warnings.warn( + "The 'origin' keyword does not take effect when resampling " + "with a 'freq' that is not Tick-like (h, m, s, ms, us, ns)", + RuntimeWarning, + stacklevel=find_stack_level(), + ) + + end_types = {"ME", "YE", "QE", "BME", "BYE", "BQE", "W"} + rule = freq.rule_code + if rule in end_types or ("-" in rule and rule[: rule.find("-")] in end_types): + if closed is None: + closed = "right" + if label is None: + label = "right" + # The backward resample sets ``closed`` to ``'right'`` by default + # since the last value should be considered as the edge point for + # the last bin. When origin in "end" or "end_day", the value for a + # specific ``Timestamp`` index stands for the resample result from + # the current ``Timestamp`` minus ``freq`` to the current + # ``Timestamp`` with a right close. + elif origin in ["end", "end_day"]: + if closed is None: + closed = "right" + if label is None: + label = "right" + else: + if closed is None: + closed = "left" + if label is None: + label = "left" + + self.closed = closed + self.label = label + self.convention = convention if convention is not None else "e" + self.how = how + self.fill_method = fill_method + self.limit = limit + self.group_keys = group_keys + self._arrow_dtype: ArrowDtype | None = None + + if origin in ("epoch", "start", "start_day", "end", "end_day"): + # error: Incompatible types in assignment (expression has type "Union[Union[ + # Timestamp, datetime, datetime64, signedinteger[_64Bit], float, str], + # Literal['epoch', 'start', 'start_day', 'end', 'end_day']]", variable has + # type "Union[Timestamp, Literal['epoch', 'start', 'start_day', 'end', + # 'end_day']]") + self.origin = origin # type: ignore[assignment] + else: + try: + self.origin = Timestamp(origin) + except (ValueError, TypeError) as err: + raise ValueError( + "'origin' should be equal to 'epoch', 'start', 'start_day', " + "'end', 'end_day' or " + f"should be a Timestamp convertible type. Got '{origin}' instead." + ) from err + + try: + self.offset = Timedelta(offset) if offset is not None else None + except (ValueError, TypeError) as err: + raise ValueError( + "'offset' should be a Timedelta convertible type. " + f"Got '{offset}' instead." + ) from err + + # always sort time groupers + kwargs["sort"] = True + + super().__init__(freq=freq, key=key, **kwargs) + + def _get_resampler(self, obj: NDFrame) -> Resampler: + """ + Return my resampler or raise if we have an invalid axis. + + Parameters + ---------- + obj : Series or DataFrame + + Returns + ------- + Resampler + + Raises + ------ + TypeError if incompatible axis + + """ + _, ax, _ = self._set_grouper(obj, gpr_index=None) + if isinstance(ax, DatetimeIndex): + return DatetimeIndexResampler( + obj, + timegrouper=self, + group_keys=self.group_keys, + gpr_index=ax, + ) + elif isinstance(ax, PeriodIndex): + return PeriodIndexResampler( + obj, + timegrouper=self, + group_keys=self.group_keys, + gpr_index=ax, + ) + elif isinstance(ax, TimedeltaIndex): + return TimedeltaIndexResampler( + obj, + timegrouper=self, + group_keys=self.group_keys, + gpr_index=ax, + ) + + raise TypeError( + "Only valid with DatetimeIndex, " + "TimedeltaIndex or PeriodIndex, " + f"but got an instance of '{type(ax).__name__}'" + ) + + def _get_grouper( + self, obj: NDFrameT, validate: bool = True, observed: bool = True + ) -> tuple[BinGrouper, NDFrameT]: + """ + Parameters + ---------- + obj : Series or DataFrame + Object being grouped. + validate : bool, default True + Unused. Only for compatibility with ``Grouper._get_grouper``. + observed : bool, default True + Unused. Only for compatibility with ``Grouper._get_grouper``. + + Returns + ------- + A tuple of grouper, obj (possibly sorted) + """ + # create the resampler and return our binner + r = self._get_resampler(obj) + return r._grouper, cast(NDFrameT, r.obj) + + def _get_time_bins(self, ax: DatetimeIndex): + if not isinstance(ax, DatetimeIndex): + raise TypeError( + "axis must be a DatetimeIndex, but got " + f"an instance of {type(ax).__name__}" + ) + + if len(ax) == 0: + binner = labels = DatetimeIndex( + data=[], freq=self.freq, name=ax.name, dtype=ax.dtype + ) + return binner, [], labels + + first, last = _get_timestamp_range_edges( + ax.min(), + ax.max(), + self.freq, + unit=ax.unit, + closed=self.closed, + origin=self.origin, + offset=self.offset, + ) + # GH #12037 + # use first/last directly instead of call replace() on them + # because replace() will swallow the nanosecond part + # thus last bin maybe slightly before the end if the end contains + # nanosecond part and lead to `Values falls after last bin` error + # GH 25758: If DST lands at midnight (e.g. 'America/Havana'), user feedback + # has noted that ambiguous=True provides the most sensible result + binner = labels = date_range( + freq=self.freq, + start=first, + end=last, + tz=ax.tz, + name=ax.name, + ambiguous=True, + nonexistent="shift_forward", + unit=ax.unit, + ) + + ax_values = ax.asi8 + binner, bin_edges = self._adjust_bin_edges(binner, ax_values) + + # general version, knowing nothing about relative frequencies + bins = lib.generate_bins_dt64( + ax_values, bin_edges, self.closed, hasnans=ax.hasnans + ) + + if self.closed == "right": + labels = binner + if self.label == "right": + labels = labels[1:] + elif self.label == "right": + labels = labels[1:] + + if ax.hasnans: + binner = binner.insert(0, NaT) + labels = labels.insert(0, NaT) + + # if we end up with more labels than bins + # adjust the labels + # GH4076 + if len(bins) < len(labels): + labels = labels[: len(bins)] + + return binner, bins, labels + + def _adjust_bin_edges( + self, binner: DatetimeIndex, ax_values: npt.NDArray[np.int64] + ) -> tuple[DatetimeIndex, npt.NDArray[np.int64]]: + # Some hacks for > daily data, see #1471, #1458, #1483 + + if self.freq.name in ("BME", "ME", "W") or self.freq.name.split("-")[0] in ( + "BQE", + "BYE", + "QE", + "YE", + "W", + ): + # If the right end-point is on the last day of the month, roll forwards + # until the last moment of that day. Note that we only do this for offsets + # which correspond to the end of a super-daily period - "month start", for + # example, is excluded. + if self.closed == "right": + # GH 21459, GH 9119: Adjust the bins relative to the wall time + edges_dti = binner.tz_localize(None) + edges_dti = ( + edges_dti + + Timedelta(days=1).as_unit(edges_dti.unit) + - Timedelta(1, unit=edges_dti.unit).as_unit(edges_dti.unit) + ) + bin_edges = edges_dti.tz_localize(binner.tz).asi8 + else: + bin_edges = binner.asi8 + + # intraday values on last day + if bin_edges[-2] > ax_values.max(): + bin_edges = bin_edges[:-1] + binner = binner[:-1] + else: + bin_edges = binner.asi8 + return binner, bin_edges + + def _get_time_delta_bins(self, ax: TimedeltaIndex): + if not isinstance(ax, TimedeltaIndex): + raise TypeError( + "axis must be a TimedeltaIndex, but got " + f"an instance of {type(ax).__name__}" + ) + + if not isinstance(self.freq, (Tick, Day)): + # GH#51896 + raise ValueError( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + f"e.g. '24h' or '3D', not {self.freq}" + ) + + if not len(ax): + binner = labels = TimedeltaIndex(data=[], freq=self.freq, name=ax.name) + return binner, [], labels + + start, end = ax.min(), ax.max() + + if self.closed == "right": + end += self.freq + + labels = binner = timedelta_range( + start=start, end=end, freq=self.freq, name=ax.name + ) + + end_stamps = labels + if self.closed == "left": + end_stamps += self.freq + + bins = ax.searchsorted(end_stamps, side=self.closed) + + if self.offset: + # GH 10530 & 31809 + labels += self.offset + + return binner, bins, labels + + def _get_time_period_bins(self, ax: DatetimeIndex): + if not isinstance(ax, DatetimeIndex): + raise TypeError( + "axis must be a DatetimeIndex, but got " + f"an instance of {type(ax).__name__}" + ) + + freq = self.freq + + if len(ax) == 0: + binner = labels = PeriodIndex( + data=[], freq=freq, name=ax.name, dtype=ax.dtype + ) + return binner, [], labels + + labels = binner = period_range(start=ax[0], end=ax[-1], freq=freq, name=ax.name) + + end_stamps = (labels + freq).asfreq(freq, "s").to_timestamp() + if ax.tz: + end_stamps = end_stamps.tz_localize(ax.tz) + bins = ax.searchsorted(end_stamps, side="left") + + return binner, bins, labels + + def _get_period_bins(self, ax: PeriodIndex): + if not isinstance(ax, PeriodIndex): + raise TypeError( + "axis must be a PeriodIndex, but got " + f"an instance of {type(ax).__name__}" + ) + + memb = ax.asfreq(self.freq, how=self.convention) + + # NaT handling as in pandas._lib.lib.generate_bins_dt64() + nat_count = 0 + if memb.hasnans: + # error: Incompatible types in assignment (expression has type + # "bool_", variable has type "int") [assignment] + nat_count = np.sum(memb._isnan) # type: ignore[assignment] + memb = memb[~memb._isnan] + + if not len(memb): + # index contains no valid (non-NaT) values + bins = np.array([], dtype=np.int64) + binner = labels = PeriodIndex(data=[], freq=self.freq, name=ax.name) + if len(ax) > 0: + # index is all NaT + binner, bins, labels = _insert_nat_bin(binner, bins, labels, len(ax)) + return binner, bins, labels + + freq_mult = self.freq.n + + start = ax.min().asfreq(self.freq, how=self.convention) + end = ax.max().asfreq(self.freq, how="end") + bin_shift = 0 + + if isinstance(self.freq, Tick): + # GH 23882 & 31809: get adjusted bin edge labels with 'origin' + # and 'origin' support. This call only makes sense if the freq is a + # Tick since offset and origin are only used in those cases. + # Not doing this check could create an extra empty bin. + p_start, end = _get_period_range_edges( + start, + end, + self.freq, + closed=self.closed, + origin=self.origin, + offset=self.offset, + ) + + # Get offset for bin edge (not label edge) adjustment + start_offset = Period(start, self.freq) - Period(p_start, self.freq) + # error: Item "Period" of "Union[Period, Any]" has no attribute "n" + bin_shift = start_offset.n % freq_mult # type: ignore[union-attr] + start = p_start + + labels = binner = period_range( + start=start, end=end, freq=self.freq, name=ax.name + ) + + i8 = memb.asi8 + + # when upsampling to subperiods, we need to generate enough bins + expected_bins_count = len(binner) * freq_mult + i8_extend = expected_bins_count - (i8[-1] - i8[0]) + rng = np.arange(i8[0], i8[-1] + i8_extend, freq_mult) + rng += freq_mult + # adjust bin edge indexes to account for base + rng -= bin_shift + + # Wrap in PeriodArray for PeriodArray.searchsorted + prng = type(memb._data)(rng, dtype=memb.dtype) + bins = memb.searchsorted(prng, side="left") + + if nat_count > 0: + binner, bins, labels = _insert_nat_bin(binner, bins, labels, nat_count) + + return binner, bins, labels + + def _set_grouper( + self, obj: NDFrameT, sort: bool = False, *, gpr_index: Index | None = None + ) -> tuple[NDFrameT, Index, npt.NDArray[np.intp] | None]: + obj, ax, indexer = super()._set_grouper(obj, sort, gpr_index=gpr_index) + if isinstance(ax.dtype, ArrowDtype) and ax.dtype.kind in "Mm": + self._arrow_dtype = ax.dtype + ax = Index( + cast(ArrowExtensionArray, ax.array)._maybe_convert_datelike_array() + ) + return obj, ax, indexer + + +@overload +def _take_new_index( + obj: DataFrame, indexer: npt.NDArray[np.intp], new_index: Index +) -> DataFrame: ... + + +@overload +def _take_new_index( + obj: Series, indexer: npt.NDArray[np.intp], new_index: Index +) -> Series: ... + + +def _take_new_index( + obj: DataFrame | Series, + indexer: npt.NDArray[np.intp], + new_index: Index, +) -> DataFrame | Series: + if isinstance(obj, ABCSeries): + new_values = algos.take_nd(obj._values, indexer) + return obj._constructor(new_values, index=new_index, name=obj.name) + elif isinstance(obj, ABCDataFrame): + new_mgr = obj._mgr.reindex_indexer(new_axis=new_index, indexer=indexer, axis=1) + return obj._constructor_from_mgr(new_mgr, axes=new_mgr.axes) + else: + raise ValueError("'obj' should be either a Series or a DataFrame") + + +def _get_timestamp_range_edges( + first: Timestamp, + last: Timestamp, + freq: BaseOffset, + unit: TimeUnit, + closed: Literal["right", "left"] = "left", + origin: TimeGrouperOrigin = "start_day", + offset: Timedelta | None = None, +) -> tuple[Timestamp, Timestamp]: + """ + Adjust the `first` Timestamp to the preceding Timestamp that resides on + the provided offset. Adjust the `last` Timestamp to the following + Timestamp that resides on the provided offset. Input Timestamps that + already reside on the offset will be adjusted depending on the type of + offset and the `closed` parameter. + + Parameters + ---------- + first : pd.Timestamp + The beginning Timestamp of the range to be adjusted. + last : pd.Timestamp + The ending Timestamp of the range to be adjusted. + freq : pd.DateOffset + The dateoffset to which the Timestamps will be adjusted. + closed : {'right', 'left'}, default "left" + Which side of bin interval is closed. + origin : {'epoch', 'start', 'start_day'} or Timestamp, default 'start_day' + The timestamp on which to adjust the grouping. The timezone of origin must + match the timezone of the index. + If a timestamp is not used, these values are also supported: + + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + offset : pd.Timedelta, default is None + An offset timedelta added to the origin. + + Returns + ------- + A tuple of length 2, containing the adjusted pd.Timestamp objects. + """ + if isinstance(freq, Tick): + index_tz = first.tz + if isinstance(origin, Timestamp) and (origin.tz is None) != (index_tz is None): + raise ValueError("The origin must have the same timezone as the index.") + if origin == "epoch": + # set the epoch based on the timezone to have similar bins results when + # resampling on the same kind of indexes on different timezones + origin = Timestamp("1970-01-01", tz=index_tz) + + first, last = _adjust_dates_anchored( + first, + last, + freq, + closed=closed, + origin=origin, + offset=offset, + unit=unit, + ) + else: + first = first.normalize() + last = last.normalize() + + if closed == "left": + first = Timestamp(freq.rollback(first)) + else: + first = Timestamp(first - freq) + + last = Timestamp(last + freq) + + return first, last + + +def _get_period_range_edges( + first: Period, + last: Period, + freq: BaseOffset, + closed: Literal["right", "left"] = "left", + origin: TimeGrouperOrigin = "start_day", + offset: Timedelta | None = None, +) -> tuple[Period, Period]: + """ + Adjust the provided `first` and `last` Periods to the respective Period of + the given offset that encompasses them. + + Parameters + ---------- + first : pd.Period + The beginning Period of the range to be adjusted. + last : pd.Period + The ending Period of the range to be adjusted. + freq : pd.DateOffset + The freq to which the Periods will be adjusted. + closed : {'right', 'left'}, default "left" + Which side of bin interval is closed. + origin : {'epoch', 'start', 'start_day'}, Timestamp, default 'start_day' + The timestamp on which to adjust the grouping. The timezone of origin must + match the timezone of the index. + + If a timestamp is not used, these values are also supported: + + - 'epoch': `origin` is 1970-01-01 + - 'start': `origin` is the first value of the timeseries + - 'start_day': `origin` is the first day at midnight of the timeseries + offset : pd.Timedelta, default is None + An offset timedelta added to the origin. + + Returns + ------- + A tuple of length 2, containing the adjusted pd.Period objects. + """ + if not all(isinstance(obj, Period) for obj in [first, last]): + raise TypeError("'first' and 'last' must be instances of type Period") + + # GH 23882 + first_ts = first.to_timestamp() + last_ts = last.to_timestamp() + adjust_first = not freq.is_on_offset(first_ts) + adjust_last = freq.is_on_offset(last_ts) + + first_ts, last_ts = _get_timestamp_range_edges( + first_ts, last_ts, freq, unit="ns", closed=closed, origin=origin, offset=offset + ) + + first = (first_ts + int(adjust_first) * freq).to_period(freq) + last = (last_ts - int(adjust_last) * freq).to_period(freq) + return first, last + + +def _insert_nat_bin( + binner: PeriodIndex, bins: np.ndarray, labels: PeriodIndex, nat_count: int +) -> tuple[PeriodIndex, np.ndarray, PeriodIndex]: + # NaT handling as in pandas._lib.lib.generate_bins_dt64() + # shift bins by the number of NaT + assert nat_count > 0 + bins += nat_count + bins = np.insert(bins, 0, nat_count) + + # Incompatible types in assignment (expression has type "Index", variable + # has type "PeriodIndex") + binner = binner.insert(0, NaT) # type: ignore[assignment] + # Incompatible types in assignment (expression has type "Index", variable + # has type "PeriodIndex") + labels = labels.insert(0, NaT) # type: ignore[assignment] + return binner, bins, labels + + +def _adjust_dates_anchored( + first: Timestamp, + last: Timestamp, + freq: Tick, + closed: Literal["right", "left"] = "right", + origin: TimeGrouperOrigin = "start_day", + offset: Timedelta | None = None, + unit: TimeUnit = "ns", +) -> tuple[Timestamp, Timestamp]: + # First and last offsets should be calculated from the start day to fix an + # error cause by resampling across multiple days when a one day period is + # not a multiple of the frequency. See GH 8683 + # To handle frequencies that are not multiple or divisible by a day we let + # the possibility to define a fixed origin timestamp. See GH 31809 + first = first.as_unit(unit) + last = last.as_unit(unit) + if offset is not None: + offset = offset.as_unit(unit) + + freq_value = Timedelta(freq).as_unit(unit)._value + + origin_timestamp = 0 # origin == "epoch" + if origin == "start_day": + origin_timestamp = first.normalize()._value + elif origin == "start": + origin_timestamp = first._value + elif isinstance(origin, Timestamp): + origin_timestamp = origin.as_unit(unit)._value + elif origin in ["end", "end_day"]: + origin_last = last if origin == "end" else last.ceil("D") + sub_freq_times = (origin_last._value - first._value) // freq_value + if closed == "left": + sub_freq_times += 1 + first = origin_last - sub_freq_times * freq + origin_timestamp = first._value + origin_timestamp += offset._value if offset else 0 + + # GH 10117 & GH 19375. If first and last contain timezone information, + # Perform the calculation in UTC in order to avoid localizing on an + # Ambiguous or Nonexistent time. + first_tzinfo = first.tzinfo + last_tzinfo = last.tzinfo + if first_tzinfo is not None: + first = first.tz_convert("UTC") + if last_tzinfo is not None: + last = last.tz_convert("UTC") + + foffset = (first._value - origin_timestamp) % freq_value + loffset = (last._value - origin_timestamp) % freq_value + + if closed == "right": + if foffset > 0: + # roll back + fresult_int = first._value - foffset + else: + fresult_int = first._value - freq_value + + if loffset > 0: + # roll forward + lresult_int = last._value + (freq_value - loffset) + else: + # already the end of the road + lresult_int = last._value + else: # closed == 'left' + if foffset > 0: + fresult_int = first._value - foffset + else: + # start of the road + fresult_int = first._value + + if loffset > 0: + # roll forward + lresult_int = last._value + (freq_value - loffset) + else: + lresult_int = last._value + freq_value + fresult = Timestamp(fresult_int, unit=unit) + lresult = Timestamp(lresult_int, unit=unit) + if first_tzinfo is not None: + fresult = fresult.tz_localize("UTC").tz_convert(first_tzinfo) + if last_tzinfo is not None: + lresult = lresult.tz_localize("UTC").tz_convert(last_tzinfo) + return fresult, lresult + + +def asfreq( + obj: NDFrameT, + freq, + method=None, + how=None, + normalize: bool = False, + fill_value=None, +) -> NDFrameT: + """ + Utility frequency conversion method for Series/DataFrame. + + See :meth:`pandas.NDFrame.asfreq` for full documentation. + """ + if isinstance(obj.index, PeriodIndex): + if method is not None: + raise NotImplementedError("'method' argument is not supported") + + if how is None: + how = "E" + + if isinstance(freq, BaseOffset): + if hasattr(freq, "_period_dtype_code"): + freq = PeriodDtype(freq)._freqstr + + new_obj = obj.copy() + new_obj.index = obj.index.asfreq(freq, how=how) + + elif len(obj.index) == 0: + new_obj = obj.copy() + + new_obj.index = _asfreq_compat(obj.index, freq) + else: + unit: TimeUnit = "ns" + if isinstance(obj.index, DatetimeIndex): + # TODO: should we disallow non-DatetimeIndex? + unit = obj.index.unit + dti = date_range(obj.index.min(), obj.index.max(), freq=freq, unit=unit) + dti.name = obj.index.name + new_obj = obj.reindex(dti, method=method, fill_value=fill_value) + if normalize: + new_obj.index = new_obj.index.normalize() + + return new_obj + + +def _asfreq_compat(index: FreqIndexT, freq) -> FreqIndexT: + """ + Helper to mimic asfreq on (empty) DatetimeIndex and TimedeltaIndex. + + Parameters + ---------- + index : PeriodIndex, DatetimeIndex, or TimedeltaIndex + freq : DateOffset + + Returns + ------- + same type as index + """ + if len(index) != 0: + # This should never be reached, always checked by the caller + raise ValueError( + "Can only set arbitrary freq for empty DatetimeIndex or TimedeltaIndex" + ) + if isinstance(index, PeriodIndex): + new_index = index.asfreq(freq=freq) + elif isinstance(index, DatetimeIndex): + new_index = DatetimeIndex([], dtype=index.dtype, freq=freq, name=index.name) + elif isinstance(index, TimedeltaIndex): + new_index = TimedeltaIndex([], dtype=index.dtype, freq=freq, name=index.name) + else: # pragma: no cover + raise TypeError(type(index)) + return new_index diff --git a/pandas/core/roperator.py b/pandas/core/roperator.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea4bea41cdeaac7b0520cafc08656b1dbe5519d --- /dev/null +++ b/pandas/core/roperator.py @@ -0,0 +1,63 @@ +""" +Reversed Operations not available in the stdlib operator module. +Defining these instead of using lambdas allows us to reference them by name. +""" + +from __future__ import annotations + +import operator + + +def radd(left, right): + return right + left + + +def rsub(left, right): + return right - left + + +def rmul(left, right): + return right * left + + +def rdiv(left, right): + return right / left + + +def rtruediv(left, right): + return right / left + + +def rfloordiv(left, right): + return right // left + + +def rmod(left, right): + # check if right is a string as % is the string + # formatting operation; this is a TypeError + # otherwise perform the op + if isinstance(right, str): + typ = type(left).__name__ + raise TypeError(f"{typ} cannot perform the operation mod") + + return right % left + + +def rdivmod(left, right): + return divmod(right, left) + + +def rpow(left, right): + return right**left + + +def rand_(left, right): + return operator.and_(right, left) + + +def ror_(left, right): + return operator.or_(right, left) + + +def rxor(left, right): + return operator.xor(right, left) diff --git a/pandas/core/sample.py b/pandas/core/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..4f476540cf406af438306a124779dd1d233f14fb --- /dev/null +++ b/pandas/core/sample.py @@ -0,0 +1,163 @@ +""" +Module containing utilities for NDFrame.sample() and .GroupBy.sample() +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from pandas._libs import lib + +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCSeries, +) + +if TYPE_CHECKING: + from pandas._typing import AxisInt + + from pandas.core.generic import NDFrame + + +def preprocess_weights(obj: NDFrame, weights, axis: AxisInt) -> np.ndarray: + """ + Process and validate the `weights` argument to `NDFrame.sample` and + `.GroupBy.sample`. + + Returns `weights` as an ndarray[np.float64], validated except for normalizing + weights (because that must be done groupwise in groupby sampling). + """ + # If a series, align with frame + if isinstance(weights, ABCSeries): + weights = weights.reindex(obj.axes[axis]) + + # Strings acceptable if a dataframe and axis = 0 + if isinstance(weights, str): + if isinstance(obj, ABCDataFrame): + if axis == 0: + try: + weights = obj[weights] + except KeyError as err: + raise KeyError( + "String passed to weights not a valid column" + ) from err + else: + raise ValueError( + "Strings can only be passed to " + "weights when sampling from rows on " + "a DataFrame" + ) + else: + raise ValueError( + "Strings cannot be passed as weights when sampling from a Series." + ) + + if isinstance(obj, ABCSeries): + func = obj._constructor + else: + func = obj._constructor_sliced + + weights = func(weights, dtype="float64")._values + + if len(weights) != obj.shape[axis]: + raise ValueError("Weights and axis to be sampled must be of same length") + + if lib.has_infs(weights): + raise ValueError("weight vector may not include `inf` values") + + if (weights < 0).any(): + raise ValueError("weight vector many not include negative values") + + missing = np.isnan(weights) + if missing.any(): + # Don't modify weights in place + weights = weights.copy() + weights[missing] = 0 + return weights + + +def process_sampling_size( + n: int | None, frac: float | None, replace: bool +) -> int | None: + """ + Process and validate the `n` and `frac` arguments to `NDFrame.sample` and + `.GroupBy.sample`. + + Returns None if `frac` should be used (variable sampling sizes), otherwise returns + the constant sampling size. + """ + # If no frac or n, default to n=1. + if n is None and frac is None: + n = 1 + elif n is not None and frac is not None: + raise ValueError("Please enter a value for `frac` OR `n`, not both") + elif n is not None: + if n < 0: + raise ValueError( + "A negative number of rows requested. Please provide `n` >= 0." + ) + if n % 1 != 0: + raise ValueError("Only integers accepted as `n` values") + else: + assert frac is not None # for mypy + if frac > 1 and not replace: + raise ValueError( + "Replace has to be set to `True` when " + "upsampling the population `frac` > 1." + ) + if frac < 0: + raise ValueError( + "A negative number of rows requested. Please provide `frac` >= 0." + ) + + return n + + +def sample( + obj_len: int, + size: int, + replace: bool, + weights: np.ndarray | None, + random_state: np.random.RandomState | np.random.Generator, +) -> np.ndarray: + """ + Randomly sample `size` indices in `np.arange(obj_len)`. + + Parameters + ---------- + obj_len : int + The length of the indices being considered + size : int + The number of values to choose + replace : bool + Allow or disallow sampling of the same row more than once. + weights : np.ndarray[np.float64] or None + If None, equal probability weighting, otherwise weights according + to the vector normalized + random_state: np.random.RandomState or np.random.Generator + State used for the random sampling + + Returns + ------- + np.ndarray[np.intp] + """ + if weights is not None: + weight_sum = weights.sum() + if weight_sum != 0: + weights = weights / weight_sum + else: + raise ValueError("Invalid weights: weights sum to zero") + + assert weights is not None # for mypy + if not replace and size * weights.max() > 1: + raise ValueError( + "Weighted sampling cannot be achieved with replace=False. Either " + "set replace=True or use smaller weights. See the docstring of " + "sample for details." + ) + + return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype( + np.intp, copy=False + ) diff --git a/pandas/core/series.py b/pandas/core/series.py new file mode 100644 index 0000000000000000000000000000000000000000..d54cbbdc67bd611650f6aa8738016b98b8dd8c8e --- /dev/null +++ b/pandas/core/series.py @@ -0,0 +1,8771 @@ +""" +Data structure for 1-dimensional cross-sectional and time series data +""" + +from __future__ import annotations + +from collections.abc import ( + Callable, + Hashable, + Iterable, + Mapping, + Sequence, +) +import functools +import operator +import sys +from textwrap import dedent +from typing import ( + IO, + TYPE_CHECKING, + Any, + Literal, + Self, + cast, + overload, +) +import warnings + +import numpy as np + +from pandas._libs import ( + lib, + properties, + reshape, +) +from pandas._libs.lib import is_range_indexer +from pandas.compat import CHAINED_WARNING_DISABLED +from pandas.compat._constants import ( + REF_COUNT, + REF_COUNT_METHOD, +) +from pandas.compat._optional import import_optional_dependency +from pandas.compat.numpy import function as nv +from pandas.errors import ( + ChainedAssignmentError, + InvalidIndexError, + Pandas4Warning, +) +from pandas.errors.cow import ( + _chained_assignment_method_update_msg, + _chained_assignment_msg, +) +from pandas.util._decorators import ( + Appender, + deprecate_nonkeyword_arguments, + doc, + set_module, +) +from pandas.util._exceptions import ( + find_stack_level, +) +from pandas.util._validators import ( + validate_ascending, + validate_bool_kwarg, + validate_percentile, +) + +from pandas.core.dtypes.astype import astype_is_view +from pandas.core.dtypes.cast import ( + LossySetitemError, + construct_1d_arraylike_from_scalar, + find_common_type, + infer_dtype_from, + maybe_box_native, + maybe_unbox_numpy_scalar, +) +from pandas.core.dtypes.common import ( + is_dict_like, + is_float, + is_integer, + is_iterator, + is_list_like, + is_object_dtype, + is_scalar, + pandas_dtype, + validate_all_hashable, +) +from pandas.core.dtypes.dtypes import ( + ExtensionDtype, +) +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCSeries, +) +from pandas.core.dtypes.inference import is_hashable +from pandas.core.dtypes.missing import ( + isna, + na_value_for_dtype, + notna, + remove_na_arraylike, +) + +from pandas.core import ( + algorithms, + base, + common as com, + nanops, + ops, + roperator, +) +from pandas.core.accessor import Accessor +from pandas.core.apply import SeriesApply +from pandas.core.arrays import ExtensionArray +from pandas.core.arrays.arrow import ( + ListAccessor, + StructAccessor, +) +from pandas.core.arrays.categorical import CategoricalAccessor +from pandas.core.arrays.sparse import SparseAccessor +from pandas.core.construction import ( + array as pd_array, + extract_array, + sanitize_array, +) +from pandas.core.generic import NDFrame +from pandas.core.indexers import ( + disallow_ndim_indexing, + unpack_1tuple, +) +from pandas.core.indexes.accessors import CombinedDatetimelikeProperties +from pandas.core.indexes.api import ( + DatetimeIndex, + Index, + MultiIndex, + PeriodIndex, + default_index, + ensure_index, + maybe_sequence_to_range, +) +import pandas.core.indexes.base as ibase +from pandas.core.indexes.multi import maybe_droplevels +from pandas.core.indexing import ( + check_bool_indexer, + check_dict_or_set_indexers, +) +from pandas.core.internals import SingleBlockManager +from pandas.core.methods import selectn +from pandas.core.shared_docs import _shared_docs +from pandas.core.sorting import ( + ensure_key_mapped, + nargsort, +) +from pandas.core.strings.accessor import StringMethods +from pandas.core.tools.datetimes import to_datetime + +import pandas.io.formats.format as fmt +from pandas.io.formats.info import ( + SeriesInfo, +) +import pandas.plotting + +if TYPE_CHECKING: + from pandas._libs.internals import BlockValuesRefs + from pandas._typing import ( + AggFuncType, + AnyAll, + AnyArrayLike, + ArrayLike, + ArrowArrayExportable, + ArrowStreamExportable, + Axis, + AxisInt, + CorrelationMethod, + DropKeep, + Dtype, + DtypeObj, + FilePath, + Frequency, + IgnoreRaise, + IndexKeyFunc, + IndexLabel, + Level, + ListLike, + MutableMappingT, + NaPosition, + NumpySorter, + NumpyValueArrayLike, + QuantileInterpolation, + ReindexMethod, + Renamer, + Scalar, + SortKind, + StorageOptions, + Suffixes, + ValueKeyFunc, + WriteBuffer, + npt, + ) + + from pandas.core.frame import DataFrame + from pandas.core.groupby.generic import SeriesGroupBy + +__all__ = ["Series"] + +_shared_doc_kwargs = { + "axes": "index", + "klass": "Series", + "axes_single_arg": "{0 or 'index'}", + "axis": """axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame.""", + "inplace": """inplace : bool, default False + If True, performs operation inplace and returns None.""", + "unique": "np.ndarray", + "duplicated": "Series", + "optional_by": "", + "optional_reindex": """ +index : array-like, optional + New labels for the index. Preferably an Index object to avoid + duplicating data. +axis : int or str, optional + Unused.""", +} + +# ---------------------------------------------------------------------- +# Series class + + +# error: Cannot override final attribute "ndim" (previously declared in base +# class "NDFrame") +# error: Cannot override final attribute "size" (previously declared in base +# class "NDFrame") +# definition in base class "NDFrame" +@set_module("pandas") +class Series(base.IndexOpsMixin, NDFrame): # type: ignore[misc] + """ + One-dimensional ndarray with axis labels (including time series). + + Labels need not be unique but must be a hashable type. The object + supports both integer- and label-based indexing and provides a host of + methods for performing operations involving the index. Statistical + methods from ndarray have been overridden to automatically exclude + missing data (currently represented as NaN). + + Operations between Series (+, -, /, \\*, \\*\\*) align values based on their + associated index values-- they need not be the same length. The result + index will be the sorted union of the two indexes. + + Parameters + ---------- + data : array-like, Iterable, dict, or scalar value + Contains data stored in Series. If data is a dict, argument order is + maintained. Unordered sets are not supported. + index : array-like or Index (1d) + Values must be hashable and have the same length as `data`. + Non-unique index values are allowed. Will default to + RangeIndex (0, 1, 2, ..., n) if not provided. If data is dict-like + and index is None, then the keys in the data are used as the index. If the + index is not None, the resulting Series is reindexed with the index values. + dtype : str, numpy.dtype, or ExtensionDtype, optional + Data type for the output Series. If not specified, this will be + inferred from `data`. + See the :ref:`user guide ` for more usages. + name : Hashable, default None + The name to give to the Series. + copy : bool, default None + Whether to copy input data, only relevant for array, Series, and Index + inputs (for other input, e.g. a list, a new array is created anyway). + Defaults to True for array input and False for Index/Series. + Even when False for Index/Series, a shallow copy of the data is made. + Set to False to avoid copying array input at your own risk (if you + know the input data won't be modified elsewhere). + Set to True to force copying Series/Index input up front. + + See Also + -------- + DataFrame : Two-dimensional, size-mutable, potentially heterogeneous tabular data. + Index : Immutable sequence used for indexing and alignment. + + Notes + ----- + Please reference the :ref:`User Guide ` for more information. + + Examples + -------- + Constructing Series from a dictionary with an Index specified + + >>> d = {"a": 1, "b": 2, "c": 3} + >>> ser = pd.Series(data=d, index=["a", "b", "c"]) + >>> ser + a 1 + b 2 + c 3 + dtype: int64 + + The keys of the dictionary match with the Index values, hence the Index + values have no effect. + + >>> d = {"a": 1, "b": 2, "c": 3} + >>> ser = pd.Series(data=d, index=["x", "y", "z"]) + >>> ser + x NaN + y NaN + z NaN + dtype: float64 + + Note that the Index is first built with the keys from the dictionary. + After this the Series is reindexed with the given Index values, hence we + get all NaN as a result. + + Constructing Series from a list with `copy=False`. + + >>> r = [1, 2] + >>> ser = pd.Series(r, copy=False) + >>> ser.iloc[0] = 999 + >>> r + [1, 2] + >>> ser + 0 999 + 1 2 + dtype: int64 + + Due to input data type the Series has a `copy` of + the original data even though `copy=False`, so + the data is unchanged. + + Constructing Series from a 1d ndarray with `copy=False`. + + >>> r = np.array([1, 2]) + >>> ser = pd.Series(r, copy=False) + >>> ser.iloc[0] = 999 + >>> r + array([999, 2]) + >>> ser + 0 999 + 1 2 + dtype: int64 + + Due to input data type the Series has a `view` on + the original data, so + the data is changed as well. + """ + + _typ = "series" + _HANDLED_TYPES = (Index, ExtensionArray, np.ndarray) + + _name: Hashable + _metadata: list[str] = ["_name"] + _internal_names_set = {"index", "name"} | NDFrame._internal_names_set + _accessors = {"dt", "cat", "str", "sparse"} + _hidden_attrs = ( + base.IndexOpsMixin._hidden_attrs | NDFrame._hidden_attrs | frozenset([]) + ) + + # similar to __array_priority__, positions Series after DataFrame + # but before Index and ExtensionArray. Should NOT be overridden by subclasses. + __pandas_priority__ = 3000 + + # Override cache_readonly bc Series is mutable + hasnans = property( + # error: "Callable[[IndexOpsMixin], bool]" has no attribute "fget" + base.IndexOpsMixin.hasnans.fget, # type: ignore[attr-defined] + doc=base.IndexOpsMixin.hasnans.__doc__, + ) + _mgr: SingleBlockManager + + # ---------------------------------------------------------------------- + # Constructors + + def __init__( + self, + data=None, + index=None, + dtype: Dtype | None = None, + name=None, + copy: bool | None = None, + ) -> None: + allow_mgr = False + if ( + isinstance(data, SingleBlockManager) + and index is None + and dtype is None + and (copy is False or copy is None) + ): + if not allow_mgr: + # GH#52419 + warnings.warn( + f"Passing a {type(data).__name__} to {type(self).__name__} " + "is deprecated and will raise in a future version. " + "Use public APIs instead.", + Pandas4Warning, + stacklevel=2, + ) + data = data.copy(deep=False) + # GH#33357 called with just the SingleBlockManager + NDFrame.__init__(self, data) + self.name = name + return + + if isinstance(data, (ExtensionArray, np.ndarray)): + if copy is not False: + if dtype is None or astype_is_view(data.dtype, pandas_dtype(dtype)): + data = data.copy() + copy = False + if copy is None: + copy = False + + if isinstance(data, SingleBlockManager) and not copy: + data = data.copy(deep=False) + + if not allow_mgr: + warnings.warn( + f"Passing a {type(data).__name__} to {type(self).__name__} " + "is deprecated and will raise in a future version. " + "Use public APIs instead.", + Pandas4Warning, + stacklevel=2, + ) + allow_mgr = True + + name = ibase.maybe_extract_name(name, data, type(self)) + + if index is not None: + index = ensure_index(index) + + if dtype is not None: + dtype = self._validate_dtype(dtype) + + if data is None: + index = index if index is not None else default_index(0) + if len(index) or dtype is not None: + data = na_value_for_dtype(pandas_dtype(dtype), compat=False) + else: + data = [] + + if isinstance(data, MultiIndex): + raise NotImplementedError( + "initializing a Series from a MultiIndex is not supported" + ) + + refs = None + if isinstance(data, Index): + if dtype is not None: + data = data.astype(dtype) + if not copy: + refs = data._references + + elif isinstance(data, np.ndarray): + if len(data.dtype): + # GH#13296 we are dealing with a compound dtype, which + # should be treated as 2D + raise ValueError( + "Cannot construct a Series from an ndarray with " + "compound dtype. Use DataFrame instead." + ) + elif isinstance(data, Series): + if index is None: + index = data.index + data = data._mgr.copy(deep=False) + else: + data = data.reindex(index) + data = data._mgr + if data._has_no_reference(0): + copy = False + elif isinstance(data, Mapping): + data, index = self._init_dict(data, index, dtype) + dtype = None + copy = False + elif isinstance(data, SingleBlockManager): + if index is None: + index = data.index + elif not data.index.equals(index) or copy: + # GH#19275 SingleBlockManager input should only be called + # internally + raise AssertionError( + "Cannot pass both SingleBlockManager " + "`data` argument and a different " + "`index` argument. `copy` must be False." + ) + + if not allow_mgr: + warnings.warn( + f"Passing a {type(data).__name__} to {type(self).__name__} " + "is deprecated and will raise in a future version. " + "Use public APIs instead.", + Pandas4Warning, + stacklevel=2, + ) + allow_mgr = True + + elif isinstance(data, ExtensionArray): + pass + else: + data = com.maybe_iterable_to_list(data) + if is_list_like(data) and not len(data) and dtype is None: + # GH 29405: Pre-2.0, this defaulted to float. + dtype = np.dtype(object) + + if index is None: + if not is_list_like(data): + data = [data] + index = default_index(len(data)) + elif is_list_like(data): + com.require_length_match(data, index) + + # create/copy the manager + if isinstance(data, SingleBlockManager): + if dtype is not None: + if not astype_is_view(data.dtype, pandas_dtype(dtype)): + copy = False + data = data.astype(dtype=dtype) + if copy: + data = data.copy(deep=True) + else: + data = sanitize_array(data, index, dtype, copy) + data = SingleBlockManager.from_array(data, index, refs=refs) + + NDFrame.__init__(self, data) + self.name = name + self._set_axis(0, index) + + def _init_dict( + self, data: Mapping, index: Index | None = None, dtype: DtypeObj | None = None + ): + """ + Derive the "_mgr" and "index" attributes of a new Series from a + dictionary input. + + Parameters + ---------- + data : dict or dict-like + Data used to populate the new Series. + index : Index or None, default None + Index for the new Series: if None, use dict keys. + dtype : np.dtype, ExtensionDtype, or None, default None + The dtype for the new Series: if None, infer from data. + + Returns + ------- + _data : BlockManager for the new Series + index : index for the new Series + """ + # Looking for NaN in dict doesn't work ({np.nan : 1}[float('nan')] + # raises KeyError), so we iterate the entire dict, and align + if data: + # GH:34717, issue was using zip to extract key and values from data. + # using generators in effects the performance. + # Below is the new way of extracting the keys and values + + keys = maybe_sequence_to_range(tuple(data.keys())) + values = list(data.values()) # Generating list of values- faster way + elif index is not None: + # fastpath for Series(data=None). Just use broadcasting a scalar + # instead of reindexing. + if len(index) or dtype is not None: + values = na_value_for_dtype(pandas_dtype(dtype), compat=False) + else: + values = [] + keys = index + else: + keys, values = default_index(0), [] + + # Input is now list-like, so rely on "standard" construction: + s = Series(values, index=keys, dtype=dtype) + + # Now we just make sure the order is respected, if any + if data and index is not None: + s = s.reindex(index) + return s._mgr, s.index + + # ---------------------------------------------------------------------- + + def __arrow_c_stream__(self, requested_schema=None): + """ + Export the pandas Series as an Arrow C stream PyCapsule. + + This relies on pyarrow to convert the pandas Series to the Arrow + format (and follows the default behavior of ``pyarrow.Array.from_pandas`` + in its handling of the index, i.e. to ignore it). + This conversion is not necessarily zero-copy. + + Parameters + ---------- + requested_schema : PyCapsule, default None + The schema to which the dataframe should be casted, passed as a + PyCapsule containing a C ArrowSchema representation of the + requested schema. + + Returns + ------- + PyCapsule + """ + pa = import_optional_dependency("pyarrow", min_version="16.0.0") + type = ( + pa.DataType._import_from_c_capsule(requested_schema) + if requested_schema is not None + else None + ) + ca = pa.array(self, type=type) + if not isinstance(ca, pa.ChunkedArray): + ca = pa.chunked_array([ca]) + return ca.__arrow_c_stream__() + + # ---------------------------------------------------------------------- + + @property + def _constructor(self) -> type[Series]: + return Series + + def _constructor_from_mgr(self, mgr, axes): + ser = Series._from_mgr(mgr, axes=axes) + ser._name = None # caller is responsible for setting real name + + if type(self) is Series: + # This would also work `if self._constructor is Series`, but + # this check is slightly faster, benefiting the most-common case. + return ser + + # We assume that the subclass __init__ knows how to handle a + # pd.Series object. + return self._constructor(ser) + + @property + def _constructor_expanddim(self) -> Callable[..., DataFrame]: + """ + Used when a manipulation result has one higher dimension as the + original, such as Series.to_frame() + """ + from pandas.core.frame import DataFrame + + return DataFrame + + def _constructor_expanddim_from_mgr(self, mgr, axes): + from pandas.core.frame import DataFrame + + df = DataFrame._from_mgr(mgr, axes=mgr.axes) + + if type(self) is Series: + # This would also work `if self._constructor_expanddim is DataFrame`, + # but this check is slightly faster, benefiting the most-common case. + return df + + # We assume that the subclass __init__ knows how to handle a + # pd.DataFrame object. + return self._constructor_expanddim(df) + + # types + @property + def _can_hold_na(self) -> bool: + return self._mgr._can_hold_na + + # ndarray compatibility + @property + def dtype(self) -> DtypeObj: + """ + Return the dtype object of the underlying data. + + See Also + -------- + Series.dtypes : Return the dtype object of the underlying data. + Series.astype : Cast a pandas object to a specified dtype dtype. + Series.convert_dtypes : Convert columns to the best possible dtypes using dtypes + supporting pd.NA. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.dtype + dtype('int64') + """ + return self._mgr.dtype + + @property + def dtypes(self) -> DtypeObj: + """ + Return the dtype object of the underlying data. + + See Also + -------- + DataFrame.dtypes : Return the dtypes in the DataFrame. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.dtypes + dtype('int64') + """ + # DataFrame compatibility + return self.dtype + + @property + def name(self) -> Hashable: + """ + Return the name of the Series. + + The name of a Series becomes its index or column name if it is used + to form a DataFrame. It is also used whenever displaying the Series + using the interpreter. + + Returns + ------- + label (hashable object) + The name of the Series, also the column name if part of a DataFrame. + + See Also + -------- + Series.rename : Sets the Series name when given a scalar input. + Index.name : Corresponding Index property. + + Examples + -------- + The Series name can be set initially when calling the constructor. + + >>> s = pd.Series([1, 2, 3], dtype=np.int64, name="Numbers") + >>> s + 0 1 + 1 2 + 2 3 + Name: Numbers, dtype: int64 + >>> s.name = "Integers" + >>> s + 0 1 + 1 2 + 2 3 + Name: Integers, dtype: int64 + + The name of a Series within a DataFrame is its column name. + + >>> df = pd.DataFrame( + ... [[1, 2], [3, 4], [5, 6]], columns=["Odd Numbers", "Even Numbers"] + ... ) + >>> df + Odd Numbers Even Numbers + 0 1 2 + 1 3 4 + 2 5 6 + >>> df["Even Numbers"].name + 'Even Numbers' + """ + return self._name + + @name.setter + def name(self, value: Hashable) -> None: + validate_all_hashable(value, error_name=f"{type(self).__name__}.name") + object.__setattr__(self, "_name", value) + + @property + def values(self): + """ + Return Series as ndarray or ndarray-like depending on the dtype. + + .. warning:: + + We recommend using :attr:`Series.array` or + :meth:`Series.to_numpy`, depending on whether you need + a reference to the underlying data or a NumPy array. + + Returns + ------- + numpy.ndarray or ndarray-like + + See Also + -------- + Series.array : Reference to the underlying data. + Series.to_numpy : A NumPy array representing the underlying data. + + Examples + -------- + >>> pd.Series([1, 2, 3]).values + array([1, 2, 3]) + + >>> pd.Series(list("aabc")).values + + ['a', 'a', 'b', 'c'] + Length: 4, dtype: str + + >>> pd.Series(list("aabc")).astype("category").values + ['a', 'a', 'b', 'c'] + Categories (3, str): ['a', 'b', 'c'] + + Timezone aware datetime data is converted to UTC: + + >>> pd.Series(pd.date_range("20130101", periods=3, tz="US/Eastern")).values + array(['2013-01-01T05:00:00.000000', + '2013-01-02T05:00:00.000000', + '2013-01-03T05:00:00.000000'], dtype='datetime64[us]') + """ + return self._mgr.external_values() + + @property + def _values(self): + """ + Return the internal repr of this data (defined by Block.interval_values). + This are the values as stored in the Block (ndarray or ExtensionArray + depending on the Block class), with datetime64[ns] and timedelta64[ns] + wrapped in ExtensionArrays to match Index._values behavior. + + Differs from the public ``.values`` for certain data types, because of + historical backwards compatibility of the public attribute (e.g. period + returns object ndarray and datetimetz a datetime64[ns] ndarray for + ``.values`` while it returns an ExtensionArray for ``._values`` in those + cases). + + Differs from ``.array`` in that this still returns the numpy array if + the Block is backed by a numpy array (except for datetime64 and + timedelta64 dtypes), while ``.array`` ensures to always return an + ExtensionArray. + + Overview: + + dtype | values | _values | array | + ----------- | ------------- | ------------- | --------------------- | + Numeric | ndarray | ndarray | NumpyExtensionArray | + Category | Categorical | Categorical | Categorical | + dt64[ns] | ndarray[M8ns] | DatetimeArray | DatetimeArray | + dt64[ns tz] | ndarray[M8ns] | DatetimeArray | DatetimeArray | + td64[ns] | ndarray[m8ns] | TimedeltaArray| TimedeltaArray | + Period | ndarray[obj] | PeriodArray | PeriodArray | + Nullable | EA | EA | EA | + + """ + return self._mgr.internal_values() + + @property + def _references(self) -> BlockValuesRefs: + return self._mgr._block.refs + + @Appender(base.IndexOpsMixin.array.__doc__) # type: ignore[prop-decorator] + @property + def array(self) -> ExtensionArray: + arr = self._mgr.array_values() + # TODO decide on read-only https://github.com/pandas-dev/pandas/issues/63099 + # arr = arr.view() + # arr._readonly = True + return arr + + def __len__(self) -> int: + """ + Return the length of the Series. + """ + return len(self._mgr) + + # ---------------------------------------------------------------------- + # NDArray Compat + def __array__( + self, dtype: npt.DTypeLike | None = None, copy: bool | None = None + ) -> np.ndarray: + """ + Return the values as a NumPy array. + + Users should not call this directly. Rather, it is invoked by + :func:`numpy.array` and :func:`numpy.asarray`. + + Parameters + ---------- + dtype : str or numpy.dtype, optional + The dtype to use for the resulting NumPy array. By default, + the dtype is inferred from the data. + + copy : bool or None, optional + See :func:`numpy.asarray`. + + Returns + ------- + numpy.ndarray + The values in the series converted to a :class:`numpy.ndarray` + with the specified `dtype`. + + See Also + -------- + array : Create a new array from data. + Series.array : Zero-copy view to the array backing the Series. + Series.to_numpy : Series method for similar behavior. + + Examples + -------- + >>> ser = pd.Series([1, 2, 3]) + >>> np.asarray(ser) + array([1, 2, 3]) + + For timezone-aware data, the timezones may be retained with + ``dtype='object'`` + + >>> tzser = pd.Series(pd.date_range("2000", periods=2, tz="CET")) + >>> np.asarray(tzser, dtype="object") + array([Timestamp('2000-01-01 00:00:00+0100', tz='CET'), + Timestamp('2000-01-02 00:00:00+0100', tz='CET')], + dtype=object) + + Or the values may be localized to UTC and the tzinfo discarded with + ``dtype='datetime64[ns]'`` + + >>> np.asarray(tzser, dtype="datetime64[ns]") # doctest: +ELLIPSIS + array(['1999-12-31T23:00:00.000000000', ...], + dtype='datetime64[ns]') + """ + values = self._values + if copy is None: + # Note: branch avoids `copy=None` for NumPy 1.x support + arr = np.asarray(values, dtype=dtype) + else: + arr = np.array(values, dtype=dtype, copy=copy) + + if copy is True: + return arr + if copy is False or astype_is_view(values.dtype, arr.dtype): + arr = arr.view() + arr.flags.writeable = False + return arr + + # ---------------------------------------------------------------------- + + # indexers + @property + def axes(self) -> list[Index]: + """ + Return a list of the row axis labels. + """ + return [self.index] + + # ---------------------------------------------------------------------- + # Indexing Methods + + def _ixs(self, i: int, axis: AxisInt = 0) -> Any: + """ + Return the i-th value or values in the Series by location. + + Parameters + ---------- + i : int + + Returns + ------- + scalar + """ + return self._values[i] + + def _slice(self, slobj: slice, axis: AxisInt = 0) -> Series: + # axis kwarg is retained for compat with NDFrame method + # _slice is *always* positional + mgr = self._mgr.get_slice(slobj, axis=axis) + out = self._constructor_from_mgr(mgr, axes=mgr.axes) + out._name = self._name + return out.__finalize__(self) + + def __getitem__(self, key): + check_dict_or_set_indexers(key) + key = com.apply_if_callable(key, self) + + if key is Ellipsis: + return self.copy(deep=False) + + key_is_scalar = is_scalar(key) + if isinstance(key, (list, tuple)): + key = unpack_1tuple(key) + + elif key_is_scalar: + # Note: GH#50617 in 3.0 we changed int key to always be treated as + # a label, matching DataFrame behavior. + return self._get_value(key) + + # Convert generator to list before going through hashable part + # (We will iterate through the generator there to check for slices) + if is_iterator(key): + key = list(key) + + if is_hashable(key, allow_slice=False): + # Otherwise index.get_value will raise InvalidIndexError + try: + # For labels that don't resolve as scalars like tuples and frozensets + result = self._get_value(key) + + return result + + except (KeyError, TypeError, InvalidIndexError): + # InvalidIndexError for e.g. generator + # see test_series_getitem_corner_generator + if isinstance(key, tuple) and isinstance(self.index, MultiIndex): + # We still have the corner case where a tuple is a key + # in the first level of our MultiIndex + return self._get_values_tuple(key) + + if isinstance(key, slice): + # Do slice check before somewhat-costly is_bool_indexer + return self._getitem_slice(key) + + if com.is_bool_indexer(key): + key = check_bool_indexer(self.index, key) + key = np.asarray(key, dtype=bool) + return self._get_rows_with_mask(key) + + return self._get_with(key) + + def _get_with(self, key): + # other: fancy integer or otherwise + if isinstance(key, ABCDataFrame): + raise TypeError( + "Indexing a Series with DataFrame is not " + "supported, use the appropriate DataFrame column" + ) + elif isinstance(key, tuple): + return self._get_values_tuple(key) + + return self.loc[key] + + def _get_values_tuple(self, key: tuple): + # mpl hackaround + if com.any_none(*key): + # mpl compat if we look up e.g. ser[:, np.newaxis]; + # see tests.series.timeseries.test_mpl_compat_hack + # the asarray is needed to avoid returning a 2D DatetimeArray + result = np.asarray(self._values[key]) + disallow_ndim_indexing(result) + return result + + if not isinstance(self.index, MultiIndex): + raise KeyError("key of type tuple not found and not a MultiIndex") + + # If key is contained, would have returned by now + indexer, new_index = self.index.get_loc_level(key) + new_ser = self._constructor(self._values[indexer], index=new_index, copy=False) + if isinstance(indexer, slice): + new_ser._mgr.add_references(self._mgr) + return new_ser.__finalize__(self) + + def _get_rows_with_mask(self, indexer: npt.NDArray[np.bool_]) -> Series: + new_mgr = self._mgr.get_rows_with_mask(indexer) + return self._constructor_from_mgr(new_mgr, axes=new_mgr.axes).__finalize__(self) + + def _get_value(self, label, takeable: bool = False): + """ + Quickly retrieve single value at passed index label. + + Parameters + ---------- + label : object + takeable : interpret the index as indexers, default False + + Returns + ------- + scalar value + """ + if takeable: + return self._values[label] + + # Similar to Index.get_value, but we do not fall back to positional + loc = self.index.get_loc(label) + + if is_integer(loc): + return self._values[loc] + + if isinstance(self.index, MultiIndex): + mi = self.index + new_values = self._values[loc] + if len(new_values) == 1 and mi.nlevels == 1: + # If more than one level left, we can not return a scalar + return new_values[0] + + new_index = mi[loc] + new_index = maybe_droplevels(new_index, label) + new_ser = self._constructor( + new_values, index=new_index, name=self.name, copy=False + ) + if isinstance(loc, slice): + new_ser._mgr.add_references(self._mgr) + return new_ser.__finalize__(self) + + else: + return self.iloc[loc] + + def __setitem__(self, key, value) -> None: + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount(self) <= REF_COUNT and not com.is_local_in_caller_frame( + self + ): + warnings.warn( + _chained_assignment_msg, ChainedAssignmentError, stacklevel=2 + ) + + check_dict_or_set_indexers(key) + key = com.apply_if_callable(key, self) + + if key is Ellipsis: + key = slice(None) + + if isinstance(key, slice): + indexer = self.index._convert_slice_indexer(key, kind="getitem") + return self._set_values(indexer, value) + + try: + self._set_with_engine(key, value) + except KeyError: + # We have a scalar (or for MultiIndex or object-dtype, scalar-like) + # key that is not present in self.index. + # GH#12862 adding a new key to the Series + self.loc[key] = value + + except (TypeError, ValueError, LossySetitemError): + # The key was OK, but we cannot set the value losslessly + indexer = self.index.get_loc(key) + self._set_values(indexer, value) + + except InvalidIndexError as err: + if isinstance(key, tuple) and not isinstance(self.index, MultiIndex): + # cases with MultiIndex don't get here bc they raise KeyError + # e.g. test_basic_getitem_setitem_corner + raise KeyError( + "key of type tuple not found and not a MultiIndex" + ) from err + + if com.is_bool_indexer(key): + key = check_bool_indexer(self.index, key) + key = np.asarray(key, dtype=bool) + + if ( + is_list_like(value) + and len(value) != len(self) + and not isinstance(value, Series) + and not is_object_dtype(self.dtype) + ): + # Series will be reindexed to have matching length inside + # _where call below + # GH#44265 + indexer = key.nonzero()[0] + self._set_values(indexer, value) + return + + # otherwise with listlike other we interpret series[mask] = other + # as series[mask] = other[mask] + try: + self._where(~key, value, inplace=True) + except InvalidIndexError: + # test_where_dups + self.iloc[key] = value + return + + else: + self._set_with(key, value) + + def _set_with_engine(self, key, value) -> None: + loc = self.index.get_loc(key) + + # this is equivalent to self._values[key] = value + self._mgr.setitem_inplace(loc, value) + + def _set_with(self, key, value) -> None: + # We got here via exception-handling off of InvalidIndexError, so + # key should always be listlike at this point. + assert not isinstance(key, tuple) + + if is_iterator(key): + # Without this, the call to infer_dtype will consume the generator + key = list(key) + + self._set_labels(key, value) + + def _set_labels(self, key, value) -> None: + key = com.asarray_tuplesafe(key) + indexer: np.ndarray = self.index.get_indexer(key) + mask = indexer == -1 + if mask.any(): + raise KeyError(f"{key[mask]} not in index") + self._set_values(indexer, value) + + def _set_values(self, key, value) -> None: + if isinstance(key, (Index, Series)): + key = key._values + + self._mgr = self._mgr.setitem(indexer=key, value=value) + + def _set_value(self, label, value, takeable: bool = False) -> None: + """ + Quickly set single value at passed label. + + If label is not contained, a new object is created with the label + placed at the end of the result index. + + Parameters + ---------- + label : object + Partial indexing with MultiIndex not allowed. + value : object + Scalar value. + takeable : interpret the index as indexers, default False + """ + if not takeable: + try: + loc = self.index.get_loc(label) + except KeyError: + # set using a non-recursive method + self.loc[label] = value + return + else: + loc = label + + self._set_values(loc, value) + + # ---------------------------------------------------------------------- + # Unsorted + + def repeat(self, repeats: int | Sequence[int], axis: None = None) -> Series: + """ + Repeat elements of a Series. + + Returns a new Series where each element of the current Series + is repeated consecutively a given number of times. + + Parameters + ---------- + repeats : int or array of ints + The number of repetitions for each element. This should be a + non-negative integer. Repeating 0 times will return an empty + Series. + axis : None + Unused. Parameter needed for compatibility with DataFrame. + + Returns + ------- + Series + Newly created Series with repeated elements. + + See Also + -------- + Index.repeat : Equivalent function for Index. + numpy.repeat : Similar method for :class:`numpy.ndarray`. + + Examples + -------- + >>> s = pd.Series(["a", "b", "c"]) + >>> s + 0 a + 1 b + 2 c + dtype: str + >>> s.repeat(2) + 0 a + 0 a + 1 b + 1 b + 2 c + 2 c + dtype: str + >>> s.repeat([1, 2, 3]) + 0 a + 1 b + 1 b + 2 c + 2 c + 2 c + dtype: str + """ + nv.validate_repeat((), {"axis": axis}) + new_index = self.index.repeat(repeats) + new_values = self._values.repeat(repeats) + return self._constructor(new_values, index=new_index, copy=False).__finalize__( + self, method="repeat" + ) + + @overload + def reset_index( + self, + level: IndexLabel = ..., + *, + drop: Literal[False] = ..., + name: Level = ..., + inplace: Literal[False] = ..., + allow_duplicates: bool = ..., + ) -> DataFrame: ... + + @overload + def reset_index( + self, + level: IndexLabel = ..., + *, + drop: Literal[True], + name: Level = ..., + inplace: Literal[False] = ..., + allow_duplicates: bool = ..., + ) -> Series: ... + + @overload + def reset_index( + self, + level: IndexLabel = ..., + *, + drop: bool = ..., + name: Level = ..., + inplace: Literal[True], + allow_duplicates: bool = ..., + ) -> None: ... + + def reset_index( + self, + level: IndexLabel | None = None, + *, + drop: bool = False, + name: Level = lib.no_default, + inplace: bool = False, + allow_duplicates: bool = False, + ) -> DataFrame | Series | None: + """ + Generate a new DataFrame or Series with the index reset. + + This is useful when the index needs to be treated as a column, or + when the index is meaningless and needs to be reset to the default + before another operation. + + Parameters + ---------- + level : int, str, tuple, or list, default optional + For a Series with a MultiIndex, only remove the specified levels + from the index. Removes all levels by default. + drop : bool, default False + Just reset the index, without inserting it as a column in + the new DataFrame. + name : object, optional + The name to use for the column containing the original Series + values. Uses ``self.name`` by default. This argument is ignored + when `drop` is True. + inplace : bool, default False + Modify the Series in place (do not create a new object). + allow_duplicates : bool, default False + Allow duplicate column labels to be created. + + Returns + ------- + Series or DataFrame or None + When `drop` is False (the default), a DataFrame is returned. + The newly created columns will come first in the DataFrame, + followed by the original Series values. + When `drop` is True, a `Series` is returned. + In either case, if ``inplace=True``, no value is returned. + + See Also + -------- + DataFrame.reset_index: Analogous function for DataFrame. + + Examples + -------- + >>> s = pd.Series( + ... [1, 2, 3, 4], + ... name="foo", + ... index=pd.Index(["a", "b", "c", "d"], name="idx"), + ... ) + + Generate a DataFrame with default index. + + >>> s.reset_index() + idx foo + 0 a 1 + 1 b 2 + 2 c 3 + 3 d 4 + + To specify the name of the new column use `name`. + + >>> s.reset_index(name="values") + idx values + 0 a 1 + 1 b 2 + 2 c 3 + 3 d 4 + + To generate a new Series with the default set `drop` to True. + + >>> s.reset_index(drop=True) + 0 1 + 1 2 + 2 3 + 3 4 + Name: foo, dtype: int64 + + The `level` parameter is interesting for Series with a multi-level + index. + + >>> arrays = [ + ... np.array(["bar", "bar", "baz", "baz"]), + ... np.array(["one", "two", "one", "two"]), + ... ] + >>> s2 = pd.Series( + ... range(4), + ... name="foo", + ... index=pd.MultiIndex.from_arrays(arrays, names=["a", "b"]), + ... ) + + To remove a specific level from the Index, use `level`. + + >>> s2.reset_index(level="a") + a foo + b + one bar 0 + two bar 1 + one baz 2 + two baz 3 + + If `level` is not set, all levels are removed from the Index. + + >>> s2.reset_index() + a b foo + 0 bar one 0 + 1 bar two 1 + 2 baz one 2 + 3 baz two 3 + """ + inplace = validate_bool_kwarg(inplace, "inplace") + if drop: + new_index = default_index(len(self)) + if level is not None: + level_list: Sequence[Hashable] + if not isinstance(level, (tuple, list)): + level_list = [level] + else: + level_list = level + level_list = [self.index._get_level_number(lev) for lev in level_list] + if len(level_list) < self.index.nlevels: + new_index = self.index.droplevel(level_list) + + if inplace: + self.index = new_index + else: + new_ser = self.copy(deep=False) + new_ser.index = new_index + return new_ser.__finalize__(self, method="reset_index") + elif inplace: + raise TypeError( + "Cannot reset_index inplace on a Series to create a DataFrame" + ) + else: + if name is lib.no_default: + # For backwards compatibility, keep columns as [0] instead of + # [None] when self.name is None + if self.name is None: + name = 0 + else: + name = self.name + + df = self.to_frame(name) + return df.reset_index( + level=level, drop=drop, allow_duplicates=allow_duplicates + ) + return None + + # ---------------------------------------------------------------------- + # Rendering Methods + + def __repr__(self) -> str: + """ + Return a string representation for a particular Series. + """ + repr_params = fmt.get_series_repr_params() + return self.to_string(**repr_params) + + @overload + def to_string( + self, + buf: None = ..., + *, + na_rep: str = ..., + float_format: str | None = ..., + header: bool = ..., + index: bool = ..., + length: bool = ..., + dtype=..., + name=..., + max_rows: int | None = ..., + min_rows: int | None = ..., + ) -> str: ... + + @overload + def to_string( + self, + buf: FilePath | WriteBuffer[str], + *, + na_rep: str = ..., + float_format: str | None = ..., + header: bool = ..., + index: bool = ..., + length: bool = ..., + dtype=..., + name=..., + max_rows: int | None = ..., + min_rows: int | None = ..., + ) -> None: ... + + @deprecate_nonkeyword_arguments( + Pandas4Warning, allowed_args=["self", "buf"], name="to_string" + ) + def to_string( + self, + buf: FilePath | WriteBuffer[str] | None = None, + na_rep: str = "NaN", + float_format: str | None = None, + header: bool = True, + index: bool = True, + length: bool = False, + dtype: bool = False, + name: bool = False, + max_rows: int | None = None, + min_rows: int | None = None, + ) -> str | None: + """ + Render a string representation of the Series. + + Parameters + ---------- + buf : StringIO-like, optional + Buffer to write to. + na_rep : str, optional + String representation of NaN to use, default 'NaN'. + float_format : one-parameter function, optional + Formatter function to apply to columns' elements if they are + floats, default None. + header : bool, default True + Add the Series header (index name). + index : bool, optional + Add index (row) labels, default True. + length : bool, default False + Add the Series length. + dtype : bool, default False + Add the Series dtype. + name : bool, default False + Add the Series name if not None. + max_rows : int, optional + Maximum number of rows to show before truncating. If None, show + all. + min_rows : int, optional + The number of rows to display in a truncated repr (when number + of rows is above `max_rows`). + + Returns + ------- + str or None + String representation of Series if ``buf=None``, otherwise None. + + See Also + -------- + Series.to_dict : Convert Series to dict object. + Series.to_frame : Convert Series to DataFrame object. + Series.to_markdown : Print Series in Markdown-friendly format. + Series.to_timestamp : Cast to DatetimeIndex of Timestamps. + + Examples + -------- + >>> ser = pd.Series([1, 2, 3]).to_string() + >>> ser + '0 1\\n1 2\\n2 3' + """ + formatter = fmt.SeriesFormatter( + self, + name=name, + length=length, + header=header, + index=index, + dtype=dtype, + na_rep=na_rep, + float_format=float_format, + min_rows=min_rows, + max_rows=max_rows, + ) + result = formatter.to_string() + + # catch contract violations + if not isinstance(result, str): + raise AssertionError( + "result must be of type str, type " + f"of result is {type(result).__name__!r}" + ) + + if buf is None: + return result + elif hasattr(buf, "write"): + buf.write(result) + else: + with open(buf, "w", encoding="utf-8") as f: + f.write(result) + return None + + @overload + def to_markdown( + self, + buf: None = ..., + *, + mode: str = ..., + index: bool = ..., + storage_options: StorageOptions | None = ..., + **kwargs, + ) -> str: ... + + @overload + def to_markdown( + self, + buf: IO[str], + *, + mode: str = ..., + index: bool = ..., + storage_options: StorageOptions | None = ..., + **kwargs, + ) -> None: ... + + @overload + def to_markdown( + self, + buf: IO[str] | None, + *, + mode: str = ..., + index: bool = ..., + storage_options: StorageOptions | None = ..., + **kwargs, + ) -> str | None: ... + + @deprecate_nonkeyword_arguments( + Pandas4Warning, allowed_args=["self", "buf"], name="to_markdown" + ) + def to_markdown( + self, + buf: IO[str] | None = None, + mode: str = "wt", + index: bool = True, + storage_options: StorageOptions | None = None, + **kwargs, + ) -> str | None: + """ + Print Series in Markdown-friendly format. + + Parameters + ---------- + buf : str, Path or StringIO-like, optional, default None + Buffer to write to. If None, the output is returned as a string. + mode : str, optional + Mode in which file is opened, "wt" by default. + index : bool, optional, default True + Add index (row) labels. + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + **kwargs + These parameters will be passed to `tabulate \ + `_. + + Returns + ------- + str + Series in Markdown-friendly format. + + See Also + -------- + Series.to_frame : Rrite a text representation of object to the system clipboard. + Series.to_latex : Render Series to LaTeX-formatted table. + + Notes + ----- + Requires the `tabulate `_ package. + + Examples + -------- + >>> s = pd.Series(["elk", "pig", "dog", "quetzal"], name="animal") + >>> print(s.to_markdown()) + | | animal | + |---:|:---------| + | 0 | elk | + | 1 | pig | + | 2 | dog | + | 3 | quetzal | + + Output markdown with a tabulate option. + + >>> print(s.to_markdown(tablefmt="grid")) + +----+----------+ + | | animal | + +====+==========+ + | 0 | elk | + +----+----------+ + | 1 | pig | + +----+----------+ + | 2 | dog | + +----+----------+ + | 3 | quetzal | + +----+----------+ + """ + return self.to_frame().to_markdown( + buf, mode=mode, index=index, storage_options=storage_options, **kwargs + ) + + # ---------------------------------------------------------------------- + + def items(self) -> Iterable[tuple[Hashable, Any]]: + """ + Lazily iterate over (index, value) tuples. + + This method returns an iterable tuple (index, value). This is + convenient if you want to create a lazy iterator. + + Returns + ------- + iterable + Iterable of tuples containing the (index, value) pairs from a + Series. + + See Also + -------- + DataFrame.items : Iterate over (column name, Series) pairs. + DataFrame.iterrows : Iterate over DataFrame rows as (index, Series) pairs. + + Examples + -------- + >>> s = pd.Series(["A", "B", "C"]) + >>> for index, value in s.items(): + ... print(f"Index : {index}, Value : {value}") + Index : 0, Value : A + Index : 1, Value : B + Index : 2, Value : C + """ + return zip(iter(self.index), iter(self), strict=True) + + # ---------------------------------------------------------------------- + # Misc public methods + + def keys(self) -> Index: + """ + Return alias for index. + + Returns + ------- + Index + Index of the Series. + + See Also + -------- + Series.index : The index (axis labels) of the Series. + + Examples + -------- + >>> s = pd.Series([1, 2, 3], index=[0, 1, 2]) + >>> s.keys() + Index([0, 1, 2], dtype='int64') + """ + return self.index + + @overload + def to_dict( + self, *, into: type[MutableMappingT] | MutableMappingT + ) -> MutableMappingT: ... + + @overload + def to_dict(self, *, into: type[dict] = ...) -> dict: ... + + # error: Incompatible default for argument "into" (default has type "type[ + # dict[Any, Any]]", argument has type "type[MutableMappingT] | MutableMappingT") + def to_dict( + self, + *, + into: type[MutableMappingT] | MutableMappingT = dict, # type: ignore[assignment] + ) -> MutableMappingT: + """ + Convert Series to {label -> value} dict or dict-like object. + + Parameters + ---------- + into : class, default dict + The collections.abc.MutableMapping subclass to use as the return + object. Can be the actual class or an empty instance of the mapping + type you want. If you want a collections.defaultdict, you must + pass it initialized. + + Returns + ------- + collections.abc.MutableMapping + Key-value representation of Series. + + See Also + -------- + Series.to_list: Converts Series to a list of the values. + Series.to_numpy: Converts Series to NumPy ndarray. + Series.array: ExtensionArray of the data backing this Series. + + Examples + -------- + >>> s = pd.Series([1, 2, 3, 4]) + >>> s.to_dict() + {0: 1, 1: 2, 2: 3, 3: 4} + >>> from collections import OrderedDict, defaultdict + >>> s.to_dict(into=OrderedDict) + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> dd = defaultdict(list) + >>> s.to_dict(into=dd) + defaultdict(, {0: 1, 1: 2, 2: 3, 3: 4}) + """ + # GH16122 + into_c = com.standardize_mapping(into) + + if is_object_dtype(self.dtype) or isinstance(self.dtype, ExtensionDtype): + return into_c((k, maybe_box_native(v)) for k, v in self.items()) + else: + # Not an object dtype => all types will be the same so let the default + # indexer return native python type + return into_c(self.items()) + + def to_frame(self, name: Hashable = lib.no_default) -> DataFrame: + """ + Convert Series to DataFrame. + + Parameters + ---------- + name : object, optional + The passed name should substitute for the series name (if it has + one). + + Returns + ------- + DataFrame + DataFrame representation of Series. + + See Also + -------- + Series.to_dict : Convert Series to dict object. + + Examples + -------- + >>> s = pd.Series(["a", "b", "c"], name="vals") + >>> s.to_frame() + vals + 0 a + 1 b + 2 c + """ + columns: Index + if name is lib.no_default: + name = self.name + if name is None: + # default to [0], same as we would get with DataFrame(self) + columns = default_index(1) + else: + columns = Index([name]) + else: + columns = Index([name]) + + mgr = self._mgr.to_2d_mgr(columns) + df = self._constructor_expanddim_from_mgr(mgr, axes=mgr.axes) + return df.__finalize__(self, method="to_frame") + + @classmethod + def from_arrow(cls, data: ArrowArrayExportable | ArrowStreamExportable) -> Series: + """ + Construct a Series from an array-like Arrow object. + + This function accepts any Arrow-compatible array-like object implementing + the `Arrow PyCapsule Protocol`_ (i.e. having an ``__arrow_c_array__`` + or ``__arrow_c_stream__`` method). + + This function currently relies on ``pyarrow`` to convert the object + in Arrow format to pandas. + + .. _Arrow PyCapsule Protocol: https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html + + .. versionadded:: 3.0 + + Parameters + ---------- + data : pyarrow.Array or Arrow-compatible object + Any array-like object implementing the Arrow PyCapsule Protocol + (i.e. has an ``__arrow_c_array__`` or ``__arrow_c_stream__`` + method). + + Returns + ------- + Series + + See Also + -------- + DataFrame.from_arrow : Construct a DataFrame from an Arrow object. + + Examples + -------- + >>> import pyarrow as pa + >>> arrow_array = pa.array([1, 2, 3]) + >>> pd.Series.from_arrow(arrow_array) + 0 1 + 1 2 + 2 3 + dtype: int64 + """ + pa = import_optional_dependency("pyarrow", min_version="14.0.0") + if not isinstance(data, (pa.Array, pa.ChunkedArray)): + if not ( + hasattr(data, "__arrow_c_array__") + or hasattr(data, "__arrow_c_stream__") + ): + # explicitly test this, because otherwise we would accept variour other + # input types through the pa.chunked_array(..) call + raise TypeError( + "Expected an Arrow-compatible array-like object (i.e. having an " + "'_arrow_c_array__' or '__arrow_c_stream__' method), got " + f"'{type(data).__name__}' instead." + ) + # using chunked_array() as it works for both arrays and streams + pa_array = pa.chunked_array(data) + else: + pa_array = data + + ser = pa_array.to_pandas() + return ser + + def _set_name(self, name, inplace: bool = False) -> Series: + """ + Set the Series name. + + Parameters + ---------- + name : str + inplace : bool + Whether to modify `self` directly or return a copy. + """ + inplace = validate_bool_kwarg(inplace, "inplace") + ser = self if inplace else self.copy(deep=False) + ser.name = name + return ser + + @Appender( + dedent( + """ + Examples + -------- + >>> ser = pd.Series([390., 350., 30., 20.], + ... index=['Falcon', 'Falcon', 'Parrot', 'Parrot'], + ... name="Max Speed") + >>> ser + Falcon 390.0 + Falcon 350.0 + Parrot 30.0 + Parrot 20.0 + Name: Max Speed, dtype: float64 + + We can pass a list of values to group the Series data by custom labels: + + >>> ser.groupby(["a", "b", "a", "b"]).mean() + a 210.0 + b 185.0 + Name: Max Speed, dtype: float64 + + Grouping by numeric labels yields similar results: + + >>> ser.groupby([0, 1, 0, 1]).mean() + 0 210.0 + 1 185.0 + Name: Max Speed, dtype: float64 + + We can group by a level of the index: + + >>> ser.groupby(level=0).mean() + Falcon 370.0 + Parrot 25.0 + Name: Max Speed, dtype: float64 + + We can group by a condition applied to the Series values: + + >>> ser.groupby(ser > 100).mean() + Max Speed + False 25.0 + True 370.0 + Name: Max Speed, dtype: float64 + + **Grouping by Indexes** + + We can groupby different levels of a hierarchical index + using the `level` parameter: + + >>> arrays = [['Falcon', 'Falcon', 'Parrot', 'Parrot'], + ... ['Captive', 'Wild', 'Captive', 'Wild']] + >>> index = pd.MultiIndex.from_arrays(arrays, names=('Animal', 'Type')) + >>> ser = pd.Series([390., 350., 30., 20.], index=index, name="Max Speed") + >>> ser + Animal Type + Falcon Captive 390.0 + Wild 350.0 + Parrot Captive 30.0 + Wild 20.0 + Name: Max Speed, dtype: float64 + + >>> ser.groupby(level=0).mean() + Animal + Falcon 370.0 + Parrot 25.0 + Name: Max Speed, dtype: float64 + + We can also group by the 'Type' level of the hierarchical index + to get the mean speed for each type: + + >>> ser.groupby(level="Type").mean() + Type + Captive 210.0 + Wild 185.0 + Name: Max Speed, dtype: float64 + + We can also choose to include `NA` in group keys or not by defining + `dropna` parameter, the default setting is `True`. + + >>> ser = pd.Series([1, 2, 3, 3], index=["a", 'a', 'b', np.nan]) + >>> ser.groupby(level=0).sum() + a 3 + b 3 + dtype: int64 + + To include `NA` values in the group keys, set `dropna=False`: + + >>> ser.groupby(level=0, dropna=False).sum() + a 3 + b 3 + NaN 3 + dtype: int64 + + We can also group by a custom list with NaN values to handle + missing group labels: + + >>> arrays = ['Falcon', 'Falcon', 'Parrot', 'Parrot'] + >>> ser = pd.Series([390., 350., 30., 20.], index=arrays, name="Max Speed") + >>> ser.groupby(["a", "b", "a", np.nan]).mean() + a 210.0 + b 350.0 + Name: Max Speed, dtype: float64 + + >>> ser.groupby(["a", "b", "a", np.nan], dropna=False).mean() + a 210.0 + b 350.0 + NaN 20.0 + Name: Max Speed, dtype: float64 + """ + ) + ) + @Appender(_shared_docs["groupby"] % _shared_doc_kwargs) + @deprecate_nonkeyword_arguments( + Pandas4Warning, allowed_args=["self", "by", "level"], name="groupby" + ) + def groupby( + self, + by=None, + level: IndexLabel | None = None, + as_index: bool = True, + sort: bool = True, + group_keys: bool = True, + observed: bool = True, + dropna: bool = True, + ) -> SeriesGroupBy: + from pandas.core.groupby.generic import SeriesGroupBy + + if level is None and by is None: + raise TypeError("You have to supply one of 'by' and 'level'") + if not as_index: + raise TypeError("as_index=False only valid with DataFrame") + + return SeriesGroupBy( + obj=self, + keys=by, + level=level, + as_index=as_index, + sort=sort, + group_keys=group_keys, + observed=observed, + dropna=dropna, + ) + + # ---------------------------------------------------------------------- + # Statistics, overridden ndarray methods + + # TODO: integrate bottleneck + def count(self) -> int: + """ + Return number of non-NA/null observations in the Series. + + Returns + ------- + int + Number of non-null values in the Series. + + See Also + -------- + DataFrame.count : Count non-NA cells for each column or row. + + Examples + -------- + >>> s = pd.Series([0.0, 1.0, np.nan]) + >>> s.count() + 2 + """ + return maybe_unbox_numpy_scalar(notna(self._values).sum().astype("int64")) + + def mode(self, dropna: bool = True) -> Series: + """ + Return the mode(s) of the Series. + + The mode is the value that appears most often. There can be multiple modes. + + Always returns Series even if only one value is returned. + + Parameters + ---------- + dropna : bool, default True + Don't consider counts of NaN/NaT. + + Returns + ------- + Series + Modes of the Series in sorted order. + + See Also + -------- + numpy.mode : Equivalent numpy function for computing median. + Series.sum : Sum of the values. + Series.median : Median of the values. + Series.std : Standard deviation of the values. + Series.var : Variance of the values. + Series.min : Minimum value. + Series.max : Maximum value. + + Examples + -------- + >>> s = pd.Series([2, 4, 2, 2, 4, None]) + >>> s.mode() + 0 2.0 + dtype: float64 + + More than one mode: + + >>> s = pd.Series([2, 4, 8, 2, 4, None]) + >>> s.mode() + 0 2.0 + 1 4.0 + dtype: float64 + + With and without considering null value: + + >>> s = pd.Series([2, 4, None, None, 4, None]) + >>> s.mode(dropna=False) + 0 NaN + dtype: float64 + >>> s = pd.Series([2, 4, None, None, 4, None]) + >>> s.mode() + 0 4.0 + dtype: float64 + """ + # TODO: Add option for bins like value_counts() + values = self._values + if isinstance(values, np.ndarray): + res_values, _ = algorithms.mode(values, dropna=dropna) + else: + res_values = values._mode(dropna=dropna) + + # Ensure index is type stable (should always use int index) + return self._constructor( + res_values, + index=range(len(res_values)), + name=self.name, + copy=False, + dtype=self.dtype, + ).__finalize__(self, method="mode") + + def unique(self) -> ArrayLike: + """ + Return unique values of Series object. + + Uniques are returned in order of appearance. Hash table-based unique, + therefore does NOT sort. + + Returns + ------- + ndarray or ExtensionArray + The unique values returned as a NumPy array. See Notes. + + See Also + -------- + Series.drop_duplicates : Return Series with duplicate values removed. + unique : Top-level unique method for any 1-d array-like object. + Index.unique : Return Index with unique values from an Index object. + + Notes + ----- + Returns the unique values as a NumPy array. In case of an + extension-array backed Series, a new + :class:`~api.extensions.ExtensionArray` of that type with just + the unique values is returned. This includes + + * Categorical + * Period + * Datetime with Timezone + * Datetime without Timezone + * Timedelta + * Interval + * Sparse + * IntegerNA + + See Examples section. + + Examples + -------- + >>> pd.Series([2, 1, 3, 3], name="A").unique() + array([2, 1, 3]) + + >>> pd.Series([pd.Timestamp("2016-01-01") for _ in range(3)]).unique() + + ['2016-01-01 00:00:00'] + Length: 1, dtype: datetime64[us] + + >>> pd.Series( + ... [pd.Timestamp("2016-01-01", tz="US/Eastern") for _ in range(3)] + ... ).unique() + + ['2016-01-01 00:00:00-05:00'] + Length: 1, dtype: datetime64[us, US/Eastern] + + A Categorical will return categories in the order of + appearance and with the same dtype. + + >>> pd.Series(pd.Categorical(list("baabc"))).unique() + ['b', 'a', 'c'] + Categories (3, str): ['a', 'b', 'c'] + >>> pd.Series( + ... pd.Categorical(list("baabc"), categories=list("abc"), ordered=True) + ... ).unique() + ['b', 'a', 'c'] + Categories (3, str): ['a' < 'b' < 'c'] + """ + return super().unique() + + @overload + def drop_duplicates( + self, + *, + keep: DropKeep = ..., + inplace: Literal[False] = ..., + ignore_index: bool = ..., + ) -> Series: ... + + @overload + def drop_duplicates( + self, *, keep: DropKeep = ..., inplace: Literal[True], ignore_index: bool = ... + ) -> None: ... + + @overload + def drop_duplicates( + self, *, keep: DropKeep = ..., inplace: bool = ..., ignore_index: bool = ... + ) -> Series | None: ... + + def drop_duplicates( + self, + *, + keep: DropKeep = "first", + inplace: bool = False, + ignore_index: bool = False, + ) -> Series | None: + """ + Return Series with duplicate values removed. + + Parameters + ---------- + keep : {'first', 'last', ``False``}, default 'first' + Method to handle dropping duplicates: + + - 'first' : Drop duplicates except for the first occurrence. + - 'last' : Drop duplicates except for the last occurrence. + - ``False`` : Drop all duplicates. + + inplace : bool, default ``False`` + If ``True``, performs operation inplace and returns None. + + ignore_index : bool, default ``False`` + If ``True``, the resulting axis will be labeled 0, 1, …, n - 1. + + .. versionadded:: 2.0.0 + + Returns + ------- + Series or None + Series with duplicates dropped or None if ``inplace=True``. + + See Also + -------- + Index.drop_duplicates : Equivalent method on Index. + DataFrame.drop_duplicates : Equivalent method on DataFrame. + Series.duplicated : Related method on Series, indicating duplicate + Series values. + Series.unique : Return unique values as an array. + + Examples + -------- + Generate a Series with duplicated entries. + + >>> s = pd.Series( + ... ["llama", "cow", "llama", "beetle", "llama", "hippo"], name="animal" + ... ) + >>> s + 0 llama + 1 cow + 2 llama + 3 beetle + 4 llama + 5 hippo + Name: animal, dtype: str + + With the 'keep' parameter, the selection behavior of duplicated values + can be changed. The value 'first' keeps the first occurrence for each + set of duplicated entries. The default value of keep is 'first'. + + >>> s.drop_duplicates() + 0 llama + 1 cow + 3 beetle + 5 hippo + Name: animal, dtype: str + + The value 'last' for parameter 'keep' keeps the last occurrence for + each set of duplicated entries. + + >>> s.drop_duplicates(keep="last") + 1 cow + 3 beetle + 4 llama + 5 hippo + Name: animal, dtype: str + + The value ``False`` for parameter 'keep' discards all sets of + duplicated entries. + + >>> s.drop_duplicates(keep=False) + 1 cow + 3 beetle + 5 hippo + Name: animal, dtype: str + """ + inplace = validate_bool_kwarg(inplace, "inplace") + result = super().drop_duplicates(keep=keep) + + if ignore_index: + result.index = default_index(len(result)) + + if inplace: + self._update_inplace(result) + return None + else: + return result + + def duplicated(self, keep: DropKeep = "first") -> Series: + """ + Indicate duplicate Series values. + + Duplicated values are indicated as ``True`` values in the resulting + Series. Either all duplicates, all except the first or all except the + last occurrence of duplicates can be indicated. + + Parameters + ---------- + keep : {'first', 'last', False}, default 'first' + Method to handle dropping duplicates: + + - 'first' : Mark duplicates as ``True`` except for the first + occurrence. + - 'last' : Mark duplicates as ``True`` except for the last + occurrence. + - ``False`` : Mark all duplicates as ``True``. + + Returns + ------- + Series[bool] + Series indicating whether each value has occurred in the + preceding values. + + See Also + -------- + Index.duplicated : Equivalent method on pandas.Index. + DataFrame.duplicated : Equivalent method on pandas.DataFrame. + Series.drop_duplicates : Remove duplicate values from Series. + + Examples + -------- + By default, for each set of duplicated values, the first occurrence is + set on False and all others on True: + + >>> animals = pd.Series(["llama", "cow", "llama", "beetle", "llama"]) + >>> animals.duplicated() + 0 False + 1 False + 2 True + 3 False + 4 True + dtype: bool + + which is equivalent to + + >>> animals.duplicated(keep="first") + 0 False + 1 False + 2 True + 3 False + 4 True + dtype: bool + + By using 'last', the last occurrence of each set of duplicated values + is set on False and all others on True: + + >>> animals.duplicated(keep="last") + 0 True + 1 False + 2 True + 3 False + 4 False + dtype: bool + + By setting keep on ``False``, all duplicates are True: + + >>> animals.duplicated(keep=False) + 0 True + 1 False + 2 True + 3 False + 4 True + dtype: bool + """ + res = self._duplicated(keep=keep) + result = self._constructor(res, index=self.index, copy=False) + return result.__finalize__(self, method="duplicated") + + def idxmin(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Hashable: + """ + Return the row label of the minimum value. + + If multiple values equal the minimum, the first row label with that + value is returned. + + Parameters + ---------- + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + skipna : bool, default True + Exclude NA/null values. If the entire Series is NA, or if ``skipna=False`` + and there is an NA value, this method will raise a ``ValueError``. + *args, **kwargs + Additional arguments and keywords have no effect but might be + accepted for compatibility with NumPy. + + Returns + ------- + Index + Label of the minimum value. + + Raises + ------ + ValueError + If the Series is empty. + + See Also + -------- + numpy.argmin : Return indices of the minimum values + along the given axis. + DataFrame.idxmin : Return index of first occurrence of minimum + over requested axis. + Series.idxmax : Return index *label* of the first occurrence + of maximum of values. + + Notes + ----- + This method is the Series version of ``ndarray.argmin``. This method + returns the label of the minimum, while ``ndarray.argmin`` returns + the position. To get the position, use ``series.values.argmin()``. + + Examples + -------- + >>> s = pd.Series(data=[1, None, 4, 1], index=["A", "B", "C", "D"]) + >>> s + A 1.0 + B NaN + C 4.0 + D 1.0 + dtype: float64 + + >>> s.idxmin() + 'A' + """ + axis = self._get_axis_number(axis) + iloc = self.argmin(axis, skipna, *args, **kwargs) + return self.index[iloc] + + def idxmax(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Hashable: + """ + Return the row label of the maximum value. + + If multiple values equal the maximum, the first row label with that + value is returned. + + Parameters + ---------- + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + skipna : bool, default True + Exclude NA/null values. If the entire Series is NA, or if ``skipna=False`` + and there is an NA value, this method will raise a ``ValueError``. + *args, **kwargs + Additional arguments and keywords have no effect but might be + accepted for compatibility with NumPy. + + Returns + ------- + Index + Label of the maximum value. + + Raises + ------ + ValueError + If the Series is empty. + + See Also + -------- + numpy.argmax : Return indices of the maximum values + along the given axis. + DataFrame.idxmax : Return index of first occurrence of maximum + over requested axis. + Series.idxmin : Return index *label* of the first occurrence + of minimum of values. + + Notes + ----- + This method is the Series version of ``ndarray.argmax``. This method + returns the label of the maximum, while ``ndarray.argmax`` returns + the position. To get the position, use ``series.values.argmax()``. + + Examples + -------- + >>> s = pd.Series(data=[1, None, 4, 3, 4], index=["A", "B", "C", "D", "E"]) + >>> s + A 1.0 + B NaN + C 4.0 + D 3.0 + E 4.0 + dtype: float64 + + >>> s.idxmax() + 'C' + """ + axis = self._get_axis_number(axis) + iloc = self.argmax(axis, skipna, *args, **kwargs) + return self.index[iloc] + + def round(self, decimals: int = 0, *args, **kwargs) -> Series: + """ + Round each value in a Series to the given number of decimals. + + Parameters + ---------- + decimals : int, default 0 + Number of decimal places to round to. If decimals is negative, + it specifies the number of positions to the left of the decimal point. + *args, **kwargs + Additional arguments and keywords have no effect but might be + accepted for compatibility with NumPy. + + Returns + ------- + Series + Rounded values of the Series. + + See Also + -------- + numpy.around : Round values of an np.array. + DataFrame.round : Round values of a DataFrame. + Series.dt.round : Round values of data to the specified freq. + + Notes + ----- + For values exactly halfway between rounded decimal values, pandas rounds + to the nearest even value (e.g. -0.5 and 0.5 round to 0.0, 1.5 and 2.5 + round to 2.0, etc.). + + Examples + -------- + >>> s = pd.Series([-0.5, 0.1, 2.5, 1.3, 2.7]) + >>> s.round() + 0 -0.0 + 1 0.0 + 2 2.0 + 3 1.0 + 4 3.0 + dtype: float64 + """ + + nv.validate_round(args, kwargs) + + if len(self) == 0: + return self.copy() + + if is_object_dtype(self.dtype): + values = self._values + result = lib.map_infer(values, lambda x: round(x, decimals), convert=False) + return self._constructor(result, index=self.index, copy=False).__finalize__( + self, method="round" + ) + new_mgr = self._mgr.round(decimals=decimals) + return self._constructor_from_mgr(new_mgr, axes=new_mgr.axes).__finalize__( + self, method="round" + ) + + @overload + def quantile( + self, q: float = ..., interpolation: QuantileInterpolation = ... + ) -> float: ... + + @overload + def quantile( + self, + q: Sequence[float] | AnyArrayLike, + interpolation: QuantileInterpolation = ..., + ) -> Series: ... + + @overload + def quantile( + self, + q: float | Sequence[float] | AnyArrayLike = ..., + interpolation: QuantileInterpolation = ..., + ) -> float | Series: ... + + def quantile( + self, + q: float | Sequence[float] | AnyArrayLike = 0.5, + interpolation: QuantileInterpolation = "linear", + ) -> float | Series: + """ + Return value at the given quantile. + + Parameters + ---------- + q : float or array-like, default 0.5 (50% quantile) + The quantile(s) to compute, which can lie in range: 0 <= q <= 1. + interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + This optional parameter specifies the interpolation method to use, + when the desired quantile lies between two data points `i` and `j`: + + * linear: `i + (j - i) * (x-i)/(j-i)`, where `(x-i)/(j-i)` is + the fractional part of the index surrounded by `i > j`. + * lower: `i`. + * higher: `j`. + * nearest: `i` or `j` whichever is nearest. + * midpoint: (`i` + `j`) / 2. + + Returns + ------- + float or Series + If ``q`` is an array, a Series will be returned where the + index is ``q`` and the values are the quantiles, otherwise + a float will be returned. + + See Also + -------- + core.window.Rolling.quantile : Calculate the rolling quantile. + numpy.percentile : Returns the q-th percentile(s) of the array elements. + + Examples + -------- + >>> s = pd.Series([1, 2, 3, 4]) + >>> s.quantile(0.5) + 2.5 + >>> s.quantile([0.25, 0.5, 0.75]) + 0.25 1.75 + 0.50 2.50 + 0.75 3.25 + dtype: float64 + """ + validate_percentile(q) + + # We dispatch to DataFrame so that core.internals only has to worry + # about 2D cases. + df = self.to_frame() + + result = df.quantile(q=q, interpolation=interpolation, numeric_only=False) + if result.ndim == 2: + result = result.iloc[:, 0] + + if is_list_like(q): + result.name = self.name + idx = Index(q, dtype=np.float64) + return self._constructor(result, index=idx, name=self.name) + else: + # scalar + return maybe_unbox_numpy_scalar(result.iloc[0]) + + def corr( + self, + other: Series, + method: CorrelationMethod = "pearson", + min_periods: int | None = None, + ) -> float: + """ + Compute correlation with `other` Series, excluding missing values. + + The two `Series` objects are not required to be the same length and will be + aligned internally before the correlation function is applied. + + Parameters + ---------- + other : Series + Series with which to compute the correlation. + method : {'pearson', 'kendall', 'spearman'} or callable + Method used to compute correlation: + + - pearson : Standard correlation coefficient + - kendall : Kendall Tau correlation coefficient + - spearman : Spearman rank correlation + - callable: Callable with input two 1d ndarrays and returning a float. + + .. warning:: + Note that the returned matrix from corr will have 1 along the + diagonals and will be symmetric regardless of the callable's + behavior. + min_periods : int, optional + Minimum number of observations needed to have a valid result. + + Returns + ------- + float + Correlation with other. + + See Also + -------- + DataFrame.corr : Compute pairwise correlation between columns. + DataFrame.corrwith : Compute pairwise correlation with another + DataFrame or Series. + + Notes + ----- + Pearson, Kendall and Spearman correlation are currently computed using pairwise complete observations. + + * `Pearson correlation coefficient `_ + * `Kendall rank correlation coefficient `_ + * `Spearman's rank correlation coefficient `_ + + Automatic data alignment: as with all pandas operations, automatic data alignment is performed for this method. + ``corr()`` automatically considers values with matching indices. + + Examples + -------- + >>> def histogram_intersection(a, b): + ... v = np.minimum(a, b).sum().round(decimals=1) + ... return v + >>> s1 = pd.Series([0.2, 0.0, 0.6, 0.2]) + >>> s2 = pd.Series([0.3, 0.6, 0.0, 0.1]) + >>> s1.corr(s2, method=histogram_intersection) + 0.3 + + Pandas auto-aligns the values with matching indices + + >>> s1 = pd.Series([1, 2, 3], index=[0, 1, 2]) + >>> s2 = pd.Series([1, 2, 3], index=[2, 1, 0]) + >>> s1.corr(s2) + -1.0 + + If the input is a constant array, the correlation is not defined in this case, + and ``np.nan`` is returned. + + >>> s1 = pd.Series([0.45, 0.45]) + >>> s1.corr(s1) + nan + """ # noqa: E501 + this, other = self.align(other, join="inner") + if len(this) == 0: + return np.nan + + this_values = this.to_numpy(dtype=float, na_value=np.nan, copy=False) + other_values = other.to_numpy(dtype=float, na_value=np.nan, copy=False) + + if method in ["pearson", "spearman", "kendall"] or callable(method): + result = nanops.nancorr( + this_values, other_values, method=method, min_periods=min_periods + ) + result = maybe_unbox_numpy_scalar(result) + return result + + raise ValueError( + "method must be either 'pearson', " + "'spearman', 'kendall', or a callable, " + f"'{method}' was supplied" + ) + + def cov( + self, + other: Series, + min_periods: int | None = None, + ddof: int | None = 1, + ) -> float: + """ + Compute covariance with Series, excluding missing values. + + The two `Series` objects are not required to be the same length and + will be aligned internally before the covariance is calculated. + + Parameters + ---------- + other : Series + Series with which to compute the covariance. + min_periods : int, optional + Minimum number of observations needed to have a valid result. + ddof : int, default 1 + Delta degrees of freedom. The divisor used in calculations + is ``N - ddof``, where ``N`` represents the number of elements. + + Returns + ------- + float + Covariance between Series and other normalized by N-1 + (unbiased estimator). + + See Also + -------- + DataFrame.cov : Compute pairwise covariance of columns. + + Examples + -------- + >>> s1 = pd.Series([0.90010907, 0.13484424, 0.62036035]) + >>> s2 = pd.Series([0.12528585, 0.26962463, 0.51111198]) + >>> s1.cov(s2) + -0.01685762652715874 + """ + this, other = self.align(other, join="inner") + if len(this) == 0: + return np.nan + this_values = this.to_numpy(dtype=float, na_value=np.nan, copy=False) + other_values = other.to_numpy(dtype=float, na_value=np.nan, copy=False) + result = nanops.nancov( + this_values, other_values, min_periods=min_periods, ddof=ddof + ) + result = maybe_unbox_numpy_scalar(result) + return result + + def diff(self, periods: int = 1) -> Series: + """ + First discrete difference of Series elements. + + Calculates the difference of a Series element compared with another + element in the Series (default is element in previous row). + + Parameters + ---------- + periods : int, default 1 + Periods to shift for calculating difference, accepts negative + values. + + Returns + ------- + Series + First differences of the Series. + + See Also + -------- + Series.pct_change: Percent change over given number of periods. + Series.shift: Shift index by desired number of periods with an + optional time freq. + DataFrame.diff: First discrete difference of object. + + Notes + ----- + For boolean dtypes, this uses :meth:`operator.xor` rather than + :meth:`operator.sub`. + The result is calculated according to current dtype in Series, + however dtype of the result is always float64. + + Examples + -------- + + Difference with previous row + + >>> s = pd.Series([1, 1, 2, 3, 5, 8]) + >>> s.diff() + 0 NaN + 1 0.0 + 2 1.0 + 3 1.0 + 4 2.0 + 5 3.0 + dtype: float64 + + Difference with 3rd previous row + + >>> s.diff(periods=3) + 0 NaN + 1 NaN + 2 NaN + 3 2.0 + 4 4.0 + 5 6.0 + dtype: float64 + + Difference with following row + + >>> s.diff(periods=-1) + 0 0.0 + 1 -1.0 + 2 -1.0 + 3 -2.0 + 4 -3.0 + 5 NaN + dtype: float64 + + Overflow in input dtype + + >>> s = pd.Series([1, 0], dtype=np.uint8) + >>> s.diff() + 0 NaN + 1 255.0 + dtype: float64 + """ + if not lib.is_integer(periods): + if not (is_float(periods) and periods.is_integer()): + raise ValueError("periods must be an integer") + result = algorithms.diff(self._values, periods) + return self._constructor( + result, index=self.index.view(), copy=False + ).__finalize__(self, method="diff") + + def autocorr(self, lag: int = 1) -> float: + """ + Compute the lag-N autocorrelation. + + This method computes the Pearson correlation between + the Series and its shifted self. + + Parameters + ---------- + lag : int, default 1 + Number of lags to apply before performing autocorrelation. + + Returns + ------- + float + The Pearson correlation between self and self.shift(lag). + + See Also + -------- + Series.corr : Compute the correlation between two Series. + Series.shift : Shift index by desired number of periods. + DataFrame.corr : Compute pairwise correlation of columns. + DataFrame.corrwith : Compute pairwise correlation between rows or + columns of two DataFrame objects. + + Notes + ----- + If the Pearson correlation is not well defined return 'NaN'. + + Examples + -------- + >>> s = pd.Series([0.25, 0.5, 0.2, -0.05]) + >>> s.autocorr() # doctest: +ELLIPSIS + 0.10355... + >>> s.autocorr(lag=2) # doctest: +ELLIPSIS + -0.99999... + + If the Pearson correlation is not well defined, then 'NaN' is returned. + + >>> s = pd.Series([1, 0, 0, 0]) + >>> s.autocorr() + nan + """ + return self.corr(cast(Series, self.shift(lag))) + + def dot(self, other: AnyArrayLike | DataFrame) -> Series | np.ndarray: + """ + Compute the dot product between the Series and the columns of other. + + This method computes the dot product between the Series and another + one, or the Series and each columns of a DataFrame, or the Series and + each columns of an array. + + It can also be called using `self @ other`. + + Parameters + ---------- + other : Series, DataFrame or array-like + The other object to compute the dot product with its columns. + + Returns + ------- + scalar, Series or numpy.ndarray + Return the dot product of the Series and other if other is a + Series, the Series of the dot product of Series and each rows of + other if other is a DataFrame or a numpy.ndarray between the Series + and each columns of the numpy array. + + See Also + -------- + DataFrame.dot: Compute the matrix product with the DataFrame. + Series.mul: Multiplication of series and other, element-wise. + + Notes + ----- + The Series and other has to share the same index if other is a Series + or a DataFrame. + + Examples + -------- + >>> s = pd.Series([0, 1, 2, 3]) + >>> other = pd.Series([-1, 2, -3, 4]) + >>> s.dot(other) + 8 + >>> s @ other + 8 + >>> df = pd.DataFrame([[0, 1], [-2, 3], [4, -5], [6, 7]]) + >>> s.dot(df) + 0 24 + 1 14 + dtype: int64 + >>> arr = np.array([[0, 1], [-2, 3], [4, -5], [6, 7]]) + >>> s.dot(arr) + array([24, 14]) + """ + if isinstance(other, (Series, ABCDataFrame)): + common = self.index.union(other.index) + if len(common) > len(self.index) or len(common) > len(other.index): + raise ValueError("matrices are not aligned") + + left = self.reindex(index=common) + right = other.reindex(index=common) + lvals = left.values + rvals = right.values + else: + lvals = self.values + rvals = np.asarray(other) + if lvals.shape[0] != rvals.shape[0]: + raise Exception( + f"Dot product shape mismatch, {lvals.shape} vs {rvals.shape}" + ) + + if isinstance(other, ABCDataFrame): + common_type = find_common_type([self.dtypes, *list(other.dtypes)]) + return self._constructor( + np.dot(lvals, rvals), index=other.columns, copy=False, dtype=common_type + ).__finalize__(self, method="dot") + elif isinstance(other, Series): + result = np.dot(lvals, rvals) + elif isinstance(rvals, np.ndarray): + result = np.dot(lvals, rvals) + else: # pragma: no cover + raise TypeError(f"unsupported type: {type(other)}") + return maybe_unbox_numpy_scalar(result) + + def __matmul__(self, other): + """ + Matrix multiplication using binary `@` operator. + """ + return self.dot(other) + + def __rmatmul__(self, other): + """ + Matrix multiplication using binary `@` operator. + """ + return self.dot(np.transpose(other)) + + # Signature of "searchsorted" incompatible with supertype "IndexOpsMixin" + def searchsorted( # type: ignore[override] + self, + value: NumpyValueArrayLike | ExtensionArray, + side: Literal["left", "right"] = "left", + sorter: NumpySorter | None = None, + ) -> npt.NDArray[np.intp] | np.intp: + """ + Find indices where elements should be inserted to maintain order. + + Find the indices into a sorted Series `self` such that, if the + corresponding elements in `value` were inserted before the indices, + the order of `self` would be preserved. + + .. note:: + The Series *must* be monotonically sorted, otherwise + wrong locations will likely be returned. Pandas does *not* + check this for you. + + Parameters + ---------- + value : array-like or scalar + Values to insert into `self`. + side : {'left', 'right'}, optional + If 'left', the index of the first suitable location found is given. + If 'right', return the last such index. If there is no suitable + index, return either 0 or N (where N is the length of `self`). + sorter : 1-D array-like, optional + Optional array of integer indices that sort `self` into ascending + order. They are typically the result of ``np.argsort``. + + Returns + ------- + int or array of int + A scalar or array of insertion points with the + same shape as `value`. + + See Also + -------- + sort_values : Sort by the values along either axis. + numpy.searchsorted : Similar method from NumPy. + + Notes + ----- + Binary search is used to find the required insertion points. + + Examples + -------- + >>> ser = pd.Series([1, 2, 3]) + >>> ser + 0 1 + 1 2 + 2 3 + dtype: int64 + >>> ser.searchsorted(4) + np.int64(3) + >>> ser.searchsorted([0, 4]) + array([0, 3]) + >>> ser.searchsorted([1, 3], side="left") + array([0, 2]) + >>> ser.searchsorted([1, 3], side="right") + array([1, 3]) + >>> ser = pd.Series(pd.to_datetime(["3/11/2000", "3/12/2000", "3/13/2000"])) + >>> ser + 0 2000-03-11 + 1 2000-03-12 + 2 2000-03-13 + dtype: datetime64[us] + >>> ser.searchsorted("3/14/2000") + np.int64(3) + >>> ser = pd.Categorical( + ... ["apple", "bread", "bread", "cheese", "milk"], ordered=True + ... ) + >>> ser + ['apple', 'bread', 'bread', 'cheese', 'milk'] + Categories (4, str): ['apple' < 'bread' < 'cheese' < 'milk'] + >>> ser.searchsorted("bread") + np.int64(1) + >>> ser.searchsorted(["bread"], side="right") + array([3]) + + If the values are not monotonically sorted, wrong locations + may be returned: + + >>> ser = pd.Series([2, 1, 3]) + >>> ser + 0 2 + 1 1 + 2 3 + dtype: int64 + >>> ser.searchsorted(1) # doctest: +SKIP + 0 # wrong result, correct would be 1 + """ + return base.IndexOpsMixin.searchsorted(self, value, side=side, sorter=sorter) + + # ------------------------------------------------------------------- + # Combination + + def _append_internal(self, to_append: Series, ignore_index: bool = False) -> Series: + from pandas.core.reshape.concat import concat + + return concat([self, to_append], ignore_index=ignore_index) + + def compare( + self, + other: Series, + align_axis: Axis = 1, + keep_shape: bool = False, + keep_equal: bool = False, + result_names: Suffixes = ("self", "other"), + ) -> DataFrame | Series: + """ + Compare to another Series and show the differences. + + Parameters + ---------- + other : Series + Object to compare with. + + align_axis : {{0 or 'index', 1 or 'columns'}}, default 1 + Determine which axis to align the comparison on. + + * 0, or 'index' : Resulting differences are stacked vertically + with rows drawn alternately from self and other. + * 1, or 'columns' : Resulting differences are aligned horizontally + with columns drawn alternately from self and other. + + keep_shape : bool, default False + If true, all rows and columns are kept. + Otherwise, only the ones with different values are kept. + + keep_equal : bool, default False + If true, the result keeps values that are equal. + Otherwise, equal values are shown as NaNs. + + result_names : tuple, default ('self', 'other') + Set the dataframes names in the comparison. + + Returns + ------- + Series or DataFrame + If axis is 0 or 'index' the result will be a Series. + The resulting index will be a MultiIndex with 'self' and 'other' + stacked alternately at the inner level. + + If axis is 1 or 'columns' the result will be a DataFrame. + It will have two columns namely 'self' and 'other'. + + See Also + -------- + DataFrame.compare : Compare with another DataFrame and show differences. + + Notes + ----- + Matching NaNs will not appear as a difference. + + Examples + -------- + >>> s1 = pd.Series(["a", "b", "c", "d", "e"]) + >>> s2 = pd.Series(["a", "a", "c", "b", "e"]) + + Align the differences on columns + + >>> s1.compare(s2) + self other + 1 b a + 3 d b + + Stack the differences on indices + + >>> s1.compare(s2, align_axis=0) + 1 self b + other a + 3 self d + other b + dtype: str + + Keep all original rows + + >>> s1.compare(s2, keep_shape=True) + self other + 0 NaN NaN + 1 b a + 2 NaN NaN + 3 d b + 4 NaN NaN + + Keep all original rows and also all original values + + >>> s1.compare(s2, keep_shape=True, keep_equal=True) + self other + 0 a a + 1 b a + 2 c c + 3 d b + 4 e e + """ + + return super().compare( + other=other, + align_axis=align_axis, + keep_shape=keep_shape, + keep_equal=keep_equal, + result_names=result_names, + ) + + def combine( + self, + other: Series | Hashable, + func: Callable[[Hashable, Hashable], Hashable], + fill_value: Hashable | None = None, + ) -> Series: + """ + Combine the Series with a Series or scalar according to `func`. + + Combine the Series and `other` using `func` to perform elementwise + selection for combined Series. + `fill_value` is assumed when value is not present at some index + from one of the two Series being combined. + + Parameters + ---------- + other : Series or scalar + The value(s) to be combined with the `Series`. + func : function + Function that takes two scalars as inputs and returns an element. + fill_value : scalar, optional + The value to assume when an index is missing from + one Series or the other. The default specifies to use the + appropriate NaN value for the underlying dtype of the Series. + + Returns + ------- + Series + The result of combining the Series with the other object. + + See Also + -------- + Series.combine_first : Combine Series values, choosing the calling + Series' values first. + + Examples + -------- + Consider 2 Datasets ``s1`` and ``s2`` containing + highest clocked speeds of different birds. + + >>> s1 = pd.Series({"falcon": 330.0, "eagle": 160.0}) + >>> s1 + falcon 330.0 + eagle 160.0 + dtype: float64 + >>> s2 = pd.Series({"falcon": 345.0, "eagle": 200.0, "duck": 30.0}) + >>> s2 + falcon 345.0 + eagle 200.0 + duck 30.0 + dtype: float64 + + Now, to combine the two datasets and view the highest speeds + of the birds across the two datasets + + >>> s1.combine(s2, max) + duck NaN + eagle 200.0 + falcon 345.0 + dtype: float64 + + In the previous example, the resulting value for duck is missing, + because the maximum of a NaN and a float is a NaN. + So, in the example, we set ``fill_value=0``, + so the maximum value returned will be the value from some dataset. + + >>> s1.combine(s2, max, fill_value=0) + duck 30.0 + eagle 200.0 + falcon 345.0 + dtype: float64 + """ + if fill_value is None: + fill_value = na_value_for_dtype(self.dtype, compat=False) + + if isinstance(other, Series): + # If other is a Series, result is based on union of Series, + # so do this element by element + new_index = self.index.union(other.index) + new_name = ops.get_op_result_name(self, other) + new_values = np.empty(len(new_index), dtype=object) + with np.errstate(all="ignore"): + for i, idx in enumerate(new_index): + lv = self.get(idx, fill_value) + rv = other.get(idx, fill_value) + new_values[i] = func(lv, rv) + else: + # Assume that other is a scalar, so apply the function for + # each element in the Series + new_index = self.index + new_values = np.empty(len(new_index), dtype=object) + with np.errstate(all="ignore"): + new_values[:] = [func(lv, other) for lv in self._values] + new_name = self.name + + res_values = self.array._cast_pointwise_result(new_values) + return self._constructor( + res_values, + dtype=res_values.dtype, + index=new_index, + name=new_name, + copy=False, + ) + + def combine_first(self, other) -> Series: + """ + Update null elements with value in the same location in 'other'. + + Combine two Series objects by filling null values in one Series with + non-null values from the other Series. Result index will be the union + of the two indexes. + + Parameters + ---------- + other : Series + The value(s) to be used for filling null values. + + Returns + ------- + Series + The result of combining the provided Series with the other object. + + See Also + -------- + Series.combine : Perform element-wise operation on two Series + using a given function. + + Examples + -------- + >>> s1 = pd.Series([1, np.nan]) + >>> s2 = pd.Series([3, 4, 5]) + >>> s1.combine_first(s2) + 0 1.0 + 1 4.0 + 2 5.0 + dtype: float64 + + Null values still persist if the location of that null value + does not exist in `other` + + >>> s1 = pd.Series({"falcon": np.nan, "eagle": 160.0}) + >>> s2 = pd.Series({"eagle": 200.0, "duck": 30.0}) + >>> s1.combine_first(s2) + duck 30.0 + eagle 160.0 + falcon NaN + dtype: float64 + """ + from pandas.core.reshape.concat import concat + + if self.dtype == other.dtype: + if self.index.equals(other.index): + return self.mask(self.isna(), other) + + new_index = self.index.union(other.index) + + this = self + # identify the index subset to keep for each series + keep_other = other.index.difference(this.index[notna(this)]) + keep_this = this.index.difference(keep_other) + + this = this.reindex(keep_this) + other = other.reindex(keep_other) + + if this.dtype.kind == "M" and other.dtype.kind != "M": + # TODO: try to match resos? + other = to_datetime(other) + warnings.warn( + # GH#62931 + "Silently casting non-datetime 'other' to datetime in " + "Series.combine_first is deprecated and will be removed " + "in a future version. Explicitly cast before calling " + "combine_first instead.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + + combined = concat([this, other]) + combined = combined.reindex(new_index) + return combined.__finalize__(self, method="combine_first") + + def update(self, other: Series | Sequence | Mapping) -> None: + """ + Modify Series in place using values from passed Series. + + Uses non-NA values from passed Series to make updates. Aligns + on index. + + Parameters + ---------- + other : Series, or object coercible into Series + Other Series that provides values to update the current Series. + + See Also + -------- + Series.combine : Perform element-wise operation on two Series + using a given function. + Series.transform: Modify a Series using a function. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.update(pd.Series([4, 5, 6])) + >>> s + 0 4 + 1 5 + 2 6 + dtype: int64 + + >>> s = pd.Series(["a", "b", "c"]) + >>> s.update(pd.Series(["d", "e"], index=[0, 2])) + >>> s + 0 d + 1 b + 2 e + dtype: str + + >>> s = pd.Series([1, 2, 3]) + >>> s.update(pd.Series([4, 5, 6, 7, 8])) + >>> s + 0 4 + 1 5 + 2 6 + dtype: int64 + + If ``other`` contains NaNs the corresponding values are not updated + in the original Series. + + >>> s = pd.Series([1, 2, 3]) + >>> s.update(pd.Series([4, np.nan, 6])) + >>> s + 0 4 + 1 2 + 2 6 + dtype: int64 + + ``other`` can also be a non-Series object type + that is coercible into a Series + + >>> s = pd.Series([1, 2, 3]) + >>> s.update([4, np.nan, 6]) + >>> s + 0 4 + 1 2 + 2 6 + dtype: int64 + + >>> s = pd.Series([1, 2, 3]) + >>> s.update({1: 9}) + >>> s + 0 1 + 1 9 + 2 3 + dtype: int64 + """ + if not CHAINED_WARNING_DISABLED: + if sys.getrefcount( + self + ) <= REF_COUNT_METHOD and not com.is_local_in_caller_frame(self): + warnings.warn( + _chained_assignment_method_update_msg, + ChainedAssignmentError, + stacklevel=2, + ) + + if not isinstance(other, Series): + other = Series(other) + + other = other.reindex_like(self) + mask = notna(other) + + self._mgr = self._mgr.putmask(mask=mask, new=other) + + # ---------------------------------------------------------------------- + # Reindexing, sorting + + @overload + def sort_values( + self, + *, + axis: Axis = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[False] = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + ignore_index: bool = ..., + key: ValueKeyFunc = ..., + ) -> Series: ... + + @overload + def sort_values( + self, + *, + axis: Axis = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[True], + kind: SortKind = ..., + na_position: NaPosition = ..., + ignore_index: bool = ..., + key: ValueKeyFunc = ..., + ) -> None: ... + + @overload + def sort_values( + self, + *, + axis: Axis = ..., + ascending: bool | Sequence[bool] = ..., + inplace: bool = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + ignore_index: bool = ..., + key: ValueKeyFunc = ..., + ) -> Series | None: ... + + def sort_values( + self, + *, + axis: Axis = 0, + ascending: bool | Sequence[bool] = True, + inplace: bool = False, + kind: SortKind = "quicksort", + na_position: NaPosition = "last", + ignore_index: bool = False, + key: ValueKeyFunc | None = None, + ) -> Series | None: + """ + Sort by the values. + + Sort a Series in ascending or descending order by some + criterion. + + Parameters + ---------- + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + ascending : bool or list of bools, default True + If True, sort values in ascending order, otherwise descending. + inplace : bool, default False + If True, perform operation in-place. + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort' + Choice of sorting algorithm. See also :func:`numpy.sort` for more + information. 'mergesort' and 'stable' are the only stable algorithms. + na_position : {'first' or 'last'}, default 'last' + Argument 'first' puts NaNs at the beginning, 'last' puts NaNs at + the end. + ignore_index : bool, default False + If True, the resulting axis will be labeled 0, 1, …, n - 1. + key : callable, optional + If not None, apply the key function to the series values + before sorting. This is similar to the `key` argument in the + builtin :meth:`sorted` function, with the notable difference that + this `key` function should be *vectorized*. It should expect a + ``Series`` and return an array-like. + + Returns + ------- + Series or None + Series ordered by values or None if ``inplace=True``. + + See Also + -------- + Series.sort_index : Sort by the Series indices. + DataFrame.sort_values : Sort DataFrame by the values along either axis. + DataFrame.sort_index : Sort DataFrame by indices. + + Examples + -------- + >>> s = pd.Series([np.nan, 1, 3, 10, 5]) + >>> s + 0 NaN + 1 1.0 + 2 3.0 + 3 10.0 + 4 5.0 + dtype: float64 + + Sort values ascending order (default behavior) + + >>> s.sort_values(ascending=True) + 1 1.0 + 2 3.0 + 4 5.0 + 3 10.0 + 0 NaN + dtype: float64 + + Sort values descending order + + >>> s.sort_values(ascending=False) + 3 10.0 + 4 5.0 + 2 3.0 + 1 1.0 + 0 NaN + dtype: float64 + + Sort values putting NAs first + + >>> s.sort_values(na_position="first") + 0 NaN + 1 1.0 + 2 3.0 + 4 5.0 + 3 10.0 + dtype: float64 + + Sort a series of strings + + >>> s = pd.Series(["z", "b", "d", "a", "c"]) + >>> s + 0 z + 1 b + 2 d + 3 a + 4 c + dtype: str + + >>> s.sort_values() + 3 a + 1 b + 4 c + 2 d + 0 z + dtype: str + + Sort using a key function. Your `key` function will be + given the ``Series`` of values and should return an array-like. + + >>> s = pd.Series(["a", "B", "c", "D", "e"]) + >>> s.sort_values() + 1 B + 3 D + 0 a + 2 c + 4 e + dtype: str + >>> s.sort_values(key=lambda x: x.str.lower()) + 0 a + 1 B + 2 c + 3 D + 4 e + dtype: str + + NumPy ufuncs work well here. For example, we can + sort by the ``sin`` of the value + + >>> s = pd.Series([-4, -2, 0, 2, 4]) + >>> s.sort_values(key=np.sin) + 1 -2 + 4 4 + 2 0 + 0 -4 + 3 2 + dtype: int64 + + More complicated user-defined functions can be used, + as long as they expect a Series and return an array-like + + >>> s.sort_values(key=lambda x: (np.tan(x.cumsum()))) + 0 -4 + 3 2 + 4 4 + 1 -2 + 2 0 + dtype: int64 + """ + inplace = validate_bool_kwarg(inplace, "inplace") + # Validate the axis parameter + self._get_axis_number(axis) + + if is_list_like(ascending): + ascending = cast(Sequence[bool], ascending) + if len(ascending) != 1: + raise ValueError( + f"Length of ascending ({len(ascending)}) must be 1 for Series" + ) + ascending = ascending[0] + + ascending = validate_ascending(ascending) + + if na_position not in ["first", "last"]: + raise ValueError(f"invalid na_position: {na_position}") + + # GH 35922. Make sorting stable by leveraging nargsort + if key: + values_to_sort = cast(Series, ensure_key_mapped(self, key))._values + else: + values_to_sort = self._values + sorted_index = nargsort(values_to_sort, kind, bool(ascending), na_position) + + if is_range_indexer(sorted_index, len(sorted_index)): + if inplace: + return self._update_inplace(self) + return self.copy(deep=False) + + result = self._constructor( + self._values[sorted_index], index=self.index[sorted_index], copy=False + ) + + if ignore_index: + result.index = default_index(len(sorted_index)) + + if not inplace: + return result.__finalize__(self, method="sort_values") + self._update_inplace(result) + return None + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[True], + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> None: ... + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: Literal[False] = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> Series: ... + + @overload + def sort_index( + self, + *, + axis: Axis = ..., + level: IndexLabel = ..., + ascending: bool | Sequence[bool] = ..., + inplace: bool = ..., + kind: SortKind = ..., + na_position: NaPosition = ..., + sort_remaining: bool = ..., + ignore_index: bool = ..., + key: IndexKeyFunc = ..., + ) -> Series | None: ... + + def sort_index( + self, + *, + axis: Axis = 0, + level: IndexLabel | None = None, + ascending: bool | Sequence[bool] = True, + inplace: bool = False, + kind: SortKind = "quicksort", + na_position: NaPosition = "last", + sort_remaining: bool = True, + ignore_index: bool = False, + key: IndexKeyFunc | None = None, + ) -> Series | None: + """ + Sort Series by index labels. + + Returns a new Series sorted by label if `inplace` argument is + ``False``, otherwise updates the original series and returns None. + + Parameters + ---------- + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + level : int, optional + If not None, sort on values in specified index level(s). + ascending : bool or list-like of bools, default True + Sort ascending vs. descending. When the index is a MultiIndex the + sort direction can be controlled for each level individually. + inplace : bool, default False + If True, perform operation in-place. + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort' + Choice of sorting algorithm. See also :func:`numpy.sort` for more + information. 'mergesort' and 'stable' are the only stable algorithms. For + DataFrames, this option is only applied when sorting on a single + column or label. + na_position : {'first', 'last'}, default 'last' + If 'first' puts NaNs at the beginning, 'last' puts NaNs at the end. + Not implemented for MultiIndex. + sort_remaining : bool, default True + If True and sorting by level and index is multilevel, sort by other + levels too (in order) after sorting by specified level. + ignore_index : bool, default False + If True, the resulting axis will be labeled 0, 1, …, n - 1. + key : callable, optional + If not None, apply the key function to the index values + before sorting. This is similar to the `key` argument in the + builtin :meth:`sorted` function, with the notable difference that + this `key` function should be *vectorized*. It should expect an + ``Index`` and return an ``Index`` of the same shape. + + Returns + ------- + Series or None + The original Series sorted by the labels or None if ``inplace=True``. + + See Also + -------- + DataFrame.sort_index: Sort DataFrame by the index. + DataFrame.sort_values: Sort DataFrame by the value. + Series.sort_values : Sort Series by the value. + + Examples + -------- + >>> s = pd.Series(["a", "b", "c", "d"], index=[3, 2, 1, 4]) + >>> s.sort_index() + 1 c + 2 b + 3 a + 4 d + dtype: str + + Sort Descending + + >>> s.sort_index(ascending=False) + 4 d + 3 a + 2 b + 1 c + dtype: str + + By default NaNs are put at the end, but use `na_position` to place + them at the beginning + + >>> s = pd.Series(["a", "b", "c", "d"], index=[3, 2, 1, np.nan]) + >>> s.sort_index(na_position="first") + NaN d + 1.0 c + 2.0 b + 3.0 a + dtype: str + + Specify index level to sort + + >>> arrays = [ + ... np.array(["qux", "qux", "foo", "foo", "baz", "baz", "bar", "bar"]), + ... np.array(["two", "one", "two", "one", "two", "one", "two", "one"]), + ... ] + >>> s = pd.Series([1, 2, 3, 4, 5, 6, 7, 8], index=arrays) + >>> s.sort_index(level=1) + bar one 8 + baz one 6 + foo one 4 + qux one 2 + bar two 7 + baz two 5 + foo two 3 + qux two 1 + dtype: int64 + + Does not sort by remaining levels when sorting by levels + + >>> s.sort_index(level=1, sort_remaining=False) + qux one 2 + foo one 4 + baz one 6 + bar one 8 + qux two 1 + foo two 3 + baz two 5 + bar two 7 + dtype: int64 + + Apply a key function before sorting + + >>> s = pd.Series([1, 2, 3, 4], index=["A", "b", "C", "d"]) + >>> s.sort_index(key=lambda x: x.str.lower()) + A 1 + b 2 + C 3 + d 4 + dtype: int64 + """ + + return super().sort_index( + axis=axis, + level=level, + ascending=ascending, + inplace=inplace, + kind=kind, + na_position=na_position, + sort_remaining=sort_remaining, + ignore_index=ignore_index, + key=key, + ) + + def argsort( + self, + axis: Axis = 0, + kind: SortKind = "quicksort", + order: None = None, + stable: None = None, + ) -> Series: + """ + Return the integer indices that would sort the Series values. + + Override ndarray.argsort. Argsorts the value, omitting NA/null values, + and places the result in the same locations as the non-NA values. + + Parameters + ---------- + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + kind : {'mergesort', 'quicksort', 'heapsort', 'stable'}, default 'quicksort' + Choice of sorting algorithm. See :func:`numpy.sort` for more + information. 'mergesort' and 'stable' are the only stable algorithms. + order : None + Has no effect but is accepted for compatibility with numpy. + stable : None + Has no effect but is accepted for compatibility with numpy. + + Returns + ------- + Series[np.intp] + Positions of values within the sort order with -1 indicating + nan values. + + See Also + -------- + numpy.ndarray.argsort : Returns the indices that would sort this array. + + Examples + -------- + >>> s = pd.Series([3, 2, 1]) + >>> s.argsort() + 0 2 + 1 1 + 2 0 + dtype: int64 + """ + if axis != -1: + # GH#54257 We allow -1 here so that np.argsort(series) works + self._get_axis_number(axis) + + result = self.array.argsort(kind=kind) + + res = self._constructor( + result, index=self.index, name=self.name, dtype=np.intp, copy=False + ) + return res.__finalize__(self, method="argsort") + + def nlargest( + self, n: int = 5, keep: Literal["first", "last", "all"] = "first" + ) -> Series: + """ + Return the largest `n` elements. + + Parameters + ---------- + n : int, default 5 + Return this many descending sorted values. + keep : {'first', 'last', 'all'}, default 'first' + When there are duplicate values that cannot all fit in a + Series of `n` elements: + + - ``first`` : return the first `n` occurrences in order + of appearance. + - ``last`` : return the last `n` occurrences in reverse + order of appearance. + - ``all`` : keep all occurrences. This can result in a Series of + size larger than `n`. + + Returns + ------- + Series + The `n` largest values in the Series, sorted in decreasing order. + + See Also + -------- + Series.nsmallest: Get the `n` smallest elements. + Series.sort_values: Sort Series by values. + Series.head: Return the first `n` rows. + + Notes + ----- + Faster than ``.sort_values(ascending=False).head(n)`` for small `n` + relative to the size of the ``Series`` object. + + Examples + -------- + >>> countries_population = { + ... "Italy": 59000000, + ... "France": 65000000, + ... "Malta": 434000, + ... "Maldives": 434000, + ... "Brunei": 434000, + ... "Iceland": 337000, + ... "Nauru": 11300, + ... "Tuvalu": 11300, + ... "Anguilla": 11300, + ... "Montserrat": 5200, + ... } + >>> s = pd.Series(countries_population) + >>> s + Italy 59000000 + France 65000000 + Malta 434000 + Maldives 434000 + Brunei 434000 + Iceland 337000 + Nauru 11300 + Tuvalu 11300 + Anguilla 11300 + Montserrat 5200 + dtype: int64 + + The `n` largest elements where ``n=5`` by default. + + >>> s.nlargest() + France 65000000 + Italy 59000000 + Malta 434000 + Maldives 434000 + Brunei 434000 + dtype: int64 + + The `n` largest elements where ``n=3``. Default `keep` value is 'first' + so Malta will be kept. + + >>> s.nlargest(3) + France 65000000 + Italy 59000000 + Malta 434000 + dtype: int64 + + The `n` largest elements where ``n=3`` and keeping the last duplicates. + Brunei will be kept since it is the last with value 434000 based on + the index order. + + >>> s.nlargest(3, keep="last") + France 65000000 + Italy 59000000 + Brunei 434000 + dtype: int64 + + The `n` largest elements where ``n=3`` with all duplicates kept. Note + that the returned Series has five elements due to the three duplicates. + + >>> s.nlargest(3, keep="all") + France 65000000 + Italy 59000000 + Malta 434000 + Maldives 434000 + Brunei 434000 + dtype: int64 + """ + return selectn.SelectNSeries(self, n=n, keep=keep).nlargest() + + def nsmallest( + self, n: int = 5, keep: Literal["first", "last", "all"] = "first" + ) -> Series: + """ + Return the smallest `n` elements. + + Parameters + ---------- + n : int, default 5 + Return this many ascending sorted values. + keep : {'first', 'last', 'all'}, default 'first' + When there are duplicate values that cannot all fit in a + Series of `n` elements: + + - ``first`` : return the first `n` occurrences in order + of appearance. + - ``last`` : return the last `n` occurrences in reverse + order of appearance. + - ``all`` : keep all occurrences. This can result in a Series of + size larger than `n`. + + Returns + ------- + Series + The `n` smallest values in the Series, sorted in increasing order. + + See Also + -------- + Series.nlargest: Get the `n` largest elements. + Series.sort_values: Sort Series by values. + Series.head: Return the first `n` rows. + + Notes + ----- + Faster than ``.sort_values().head(n)`` for small `n` relative to + the size of the ``Series`` object. + + Examples + -------- + >>> countries_population = { + ... "Italy": 59000000, + ... "France": 65000000, + ... "Brunei": 434000, + ... "Malta": 434000, + ... "Maldives": 434000, + ... "Iceland": 337000, + ... "Nauru": 11300, + ... "Tuvalu": 11300, + ... "Anguilla": 11300, + ... "Montserrat": 5200, + ... } + >>> s = pd.Series(countries_population) + >>> s + Italy 59000000 + France 65000000 + Brunei 434000 + Malta 434000 + Maldives 434000 + Iceland 337000 + Nauru 11300 + Tuvalu 11300 + Anguilla 11300 + Montserrat 5200 + dtype: int64 + + The `n` smallest elements where ``n=5`` by default. + + >>> s.nsmallest() + Montserrat 5200 + Nauru 11300 + Tuvalu 11300 + Anguilla 11300 + Iceland 337000 + dtype: int64 + + The `n` smallest elements where ``n=3``. Default `keep` value is + 'first' so Nauru and Tuvalu will be kept. + + >>> s.nsmallest(3) + Montserrat 5200 + Nauru 11300 + Tuvalu 11300 + dtype: int64 + + The `n` smallest elements where ``n=3`` and keeping the last + duplicates. Anguilla and Tuvalu will be kept since they are the last + with value 11300 based on the index order. + + >>> s.nsmallest(3, keep="last") + Montserrat 5200 + Anguilla 11300 + Tuvalu 11300 + dtype: int64 + + The `n` smallest elements where ``n=3`` with all duplicates kept. Note + that the returned Series has four elements due to the three duplicates. + + >>> s.nsmallest(3, keep="all") + Montserrat 5200 + Nauru 11300 + Tuvalu 11300 + Anguilla 11300 + dtype: int64 + """ + return selectn.SelectNSeries(self, n=n, keep=keep).nsmallest() + + def swaplevel( + self, i: Level = -2, j: Level = -1, copy: bool | lib.NoDefault = lib.no_default + ) -> Series: + """ + Swap levels i and j in a :class:`MultiIndex`. + + Default is to swap the two innermost levels of the index. + + Parameters + ---------- + i, j : int or str + Levels of the indices to be swapped. Can pass level name as string. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + Series + Series with levels swapped in MultiIndex. + + See Also + -------- + DataFrame.swaplevel : Swap levels i and j in a :class:`DataFrame`. + Series.reorder_levels : Rearrange index levels using input order. + MultiIndex.swaplevel : Swap levels i and j in a :class:`MultiIndex`. + + Examples + -------- + >>> s = pd.Series( + ... ["A", "B", "A", "C"], + ... index=[ + ... ["Final exam", "Final exam", "Coursework", "Coursework"], + ... ["History", "Geography", "History", "Geography"], + ... ["January", "February", "March", "April"], + ... ], + ... ) + >>> s + Final exam History January A + Geography February B + Coursework History March A + Geography April C + dtype: str + + In the following example, we will swap the levels of the indices. + Here, we will swap the levels column-wise, but levels can be swapped row-wise + in a similar manner. Note that column-wise is the default behavior. + By not supplying any arguments for i and j, we swap the last and second to + last indices. + + >>> s.swaplevel() + Final exam January History A + February Geography B + Coursework March History A + April Geography C + dtype: str + + By supplying one argument, we can choose which index to swap the last + index with. We can for example swap the first index with the last one as + follows. + + >>> s.swaplevel(0) + January History Final exam A + February Geography Final exam B + March History Coursework A + April Geography Coursework C + dtype: str + + We can also define explicitly which indices we want to swap by supplying values + for both i and j. Here, we for example swap the first and second indices. + + >>> s.swaplevel(0, 1) + History Final exam January A + Geography Final exam February B + History Coursework March A + Geography Coursework April C + dtype: str + """ + self._check_copy_deprecation(copy) + assert isinstance(self.index, MultiIndex) + result = self.copy(deep=False) + result.index = self.index.swaplevel(i, j) + return result + + def reorder_levels(self, order: Sequence[Level]) -> Series: + """ + Rearrange index levels using input order. + + May not drop or duplicate levels. + + Parameters + ---------- + order : list of int representing new level order + Reference level by number or key. + + Returns + ------- + Series + Type of caller with index as MultiIndex (new object). + + See Also + -------- + DataFrame.reorder_levels : Rearrange index or column levels using + input ``order``. + + Examples + -------- + >>> arrays = [ + ... np.array(["dog", "dog", "cat", "cat", "bird", "bird"]), + ... np.array(["white", "black", "white", "black", "white", "black"]), + ... ] + >>> s = pd.Series([1, 2, 3, 3, 5, 2], index=arrays) + >>> s + dog white 1 + black 2 + cat white 3 + black 3 + bird white 5 + black 2 + dtype: int64 + >>> s.reorder_levels([1, 0]) + white dog 1 + black dog 2 + white cat 3 + black cat 3 + white bird 5 + black bird 2 + dtype: int64 + """ + if not isinstance(self.index, MultiIndex): # pragma: no cover + raise Exception("Can only reorder levels on a hierarchical axis.") + + result = self.copy(deep=False) + assert isinstance(result.index, MultiIndex) + result.index = result.index.reorder_levels(order) + return result + + def explode(self, ignore_index: bool = False) -> Series: + """ + Transform each element of a list-like to a row. + + Parameters + ---------- + ignore_index : bool, default False + If True, the resulting index will be labeled 0, 1, …, n - 1. + + Returns + ------- + Series + Exploded lists to rows; index will be duplicated for these rows. + + See Also + -------- + Series.str.split : Split string values on specified separator. + Series.unstack : Unstack, a.k.a. pivot, Series with MultiIndex + to produce DataFrame. + DataFrame.melt : Unpivot a DataFrame from wide format to long format. + DataFrame.explode : Explode a DataFrame from list-like + columns to long format. + + Notes + ----- + This routine will explode list-likes including lists, tuples, sets, + Series, and np.ndarray. The result dtype of the subset rows will + be object. Scalars will be returned unchanged, and empty list-likes will + result in an np.nan for that row. In addition, the ordering of elements in + the output will be non-deterministic when exploding sets. + + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + >>> s = pd.Series([[1, 2, 3], "foo", [], [3, 4]]) + >>> s + 0 [1, 2, 3] + 1 foo + 2 [] + 3 [3, 4] + dtype: object + + >>> s.explode() + 0 1 + 0 2 + 0 3 + 1 foo + 2 NaN + 3 3 + 3 4 + dtype: object + """ + if isinstance(self.dtype, ExtensionDtype): + values, counts = self._values._explode() + elif len(self) and is_object_dtype(self.dtype): + values, counts = reshape.explode(np.asarray(self._values)) + else: + result = self.copy() + return result.reset_index(drop=True) if ignore_index else result + + if ignore_index: + index: Index = default_index(len(values)) + else: + index = self.index.repeat(counts) + + return self._constructor(values, index=index, name=self.name, copy=False) + + def unstack( + self, + level: IndexLabel = -1, + fill_value: Hashable | None = None, + sort: bool = True, + ) -> DataFrame: + """ + Unstack, also known as pivot, Series with MultiIndex to produce DataFrame. + + Parameters + ---------- + level : int, str, or list of these, default last level + Level(s) to unstack, can pass level name. + fill_value : scalar value, default None + Value to use when replacing NaN values. + sort : bool, default True + Sort the level(s) in the resulting MultiIndex columns. + + Returns + ------- + DataFrame + Unstacked Series. + + See Also + -------- + DataFrame.unstack : Pivot the MultiIndex of a DataFrame. + + Notes + ----- + Reference :ref:`the user guide ` for more examples. + + Examples + -------- + >>> s = pd.Series( + ... [1, 2, 3, 4], + ... index=pd.MultiIndex.from_product([["one", "two"], ["a", "b"]]), + ... ) + >>> s + one a 1 + b 2 + two a 3 + b 4 + dtype: int64 + + >>> s.unstack(level=-1) + a b + one 1 2 + two 3 4 + + >>> s.unstack(level=0) + one two + a 1 3 + b 2 4 + """ + from pandas.core.reshape.reshape import unstack + + return unstack(self, level, fill_value, sort) + + # ---------------------------------------------------------------------- + # function application + + def map( + self, + func: Callable | Mapping | Series | None = None, + na_action: Literal["ignore"] | None = None, + engine: Callable | None = None, + **kwargs, + ) -> Series: + """ + Map values of Series according to an input mapping or function. + + Used for substituting each value in a Series with another value, + that may be derived from a function, a ``dict`` or + a :class:`Series`. + + Parameters + ---------- + func : function, collections.abc.Mapping subclass or Series + Function or mapping correspondence. + na_action : {None, 'ignore'}, default None + If 'ignore', propagate NaN values, without passing them to the + mapping correspondence. + engine : decorator, optional + Choose the execution engine to use to run the function. Only used for + functions. If ``map`` is called with a mapping or ``Series``, an + exception will be raised. If ``engine`` is not provided the function will + be executed by the regular Python interpreter. + + Options include JIT compilers such as Numba, Bodo or Blosc2, which in some + cases can speed up the execution. To use an executor you can provide the + decorators ``numba.jit``, ``numba.njit``, ``bodo.jit`` or ``blosc2.jit``. + You can also provide the decorator with parameters, like + ``numba.jit(nogit=True)``. + + Not all functions can be executed with all execution engines. In general, + JIT compilers will require type stability in the function (no variable + should change data type during the execution). And not all pandas and + NumPy APIs are supported. Check the engine documentation for limitations. + + .. versionadded:: 3.0.0 + + **kwargs + Additional keyword arguments to pass as keywords arguments to + `arg`. + + .. versionadded:: 3.0.0 + + Returns + ------- + Series + Same index as caller. + + See Also + -------- + Series.apply : For applying more complex functions on a Series. + Series.replace: Replace values given in `to_replace` with `value`. + DataFrame.apply : Apply a function row-/column-wise. + DataFrame.map : Apply a function elementwise on a whole DataFrame. + + Notes + ----- + When ``arg`` is a dictionary, values in Series that are not in the + dictionary (as keys) are converted to ``NaN``. However, if the + dictionary is a ``dict`` subclass that defines ``__missing__`` (i.e. + provides a method for default values), then this default is used + rather than ``NaN``. + + Examples + -------- + >>> s = pd.Series(["cat", "dog", np.nan, "rabbit"]) + >>> s + 0 cat + 1 dog + 2 NaN + 3 rabbit + dtype: str + + ``map`` accepts a ``dict`` or a ``Series``. Values that are not found + in the ``dict`` are converted to ``NaN``, unless the dict has a default + value (e.g. ``defaultdict``): + + >>> s.map({"cat": "kitten", "dog": "puppy"}) + 0 kitten + 1 puppy + 2 NaN + 3 NaN + dtype: str + + It also accepts a function: + + >>> s.map("I am a {}".format) + 0 I am a cat + 1 I am a dog + 2 I am a nan + 3 I am a rabbit + dtype: str + + To avoid applying the function to missing values (and keep them as + ``NaN``) ``na_action='ignore'`` can be used: + + >>> s.map("I am a {}".format, na_action="ignore") + 0 I am a cat + 1 I am a dog + 2 NaN + 3 I am a rabbit + dtype: str + + For categorical data, the function is only applied to the categories: + + >>> s = pd.Series(list("cabaa")) + >>> s.map(print) + c + a + b + a + a + 0 None + 1 None + 2 None + 3 None + 4 None + dtype: object + + >>> s_cat = s.astype("category") + >>> s_cat.map(print) # function called once per unique category + a + b + c + 0 None + 1 None + 2 None + 3 None + 4 None + dtype: object + """ + if func is None: + if "arg" in kwargs: + # `.map(arg=my_func)` + func = kwargs.pop("arg") + # https://github.com/pandas-dev/pandas/pull/61264 + warnings.warn( + "The parameter `arg` has been renamed to `func`, and it " + "will stop being supported in a future version of pandas.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + else: + raise ValueError("The `func` parameter is required") + + if engine is not None: + if not callable(func): + raise ValueError( + "The engine argument can only be specified when func is a function" + ) + if not hasattr(engine, "__pandas_udf__"): + raise ValueError(f"Not a valid engine: {engine!r}") + result = engine.__pandas_udf__.map( # type: ignore[attr-defined] + data=self, + func=func, + args=(), + kwargs=kwargs, + decorator=engine, + skip_na=na_action == "ignore", + ) + if not isinstance(result, Series): + result = Series(result, index=self.index, name=self.name) + return result.__finalize__(self, method="map") + + if callable(func): + func = functools.partial(func, **kwargs) + new_values = self._map_values(func, na_action=na_action) + return self._constructor(new_values, index=self.index, copy=False).__finalize__( + self, method="map" + ) + + def _gotitem(self, key, ndim, subset=None) -> Self: + """ + Sub-classes to define. Return a sliced object. + + Parameters + ---------- + key : string / list of selections + ndim : {1, 2} + Requested ndim of result. + subset : object, default None + Subset to act on. + """ + return self + + _agg_see_also_doc = dedent( + """ + See Also + -------- + Series.apply : Invoke function on a Series. + Series.transform : Transform function producing a Series with like indexes. + """ + ) + + _agg_examples_doc = dedent( + """ + Examples + -------- + >>> s = pd.Series([1, 2, 3, 4]) + >>> s + 0 1 + 1 2 + 2 3 + 3 4 + dtype: int64 + + >>> s.agg('min') + 1 + + >>> s.agg(['min', 'max']) + min 1 + max 4 + dtype: int64 + """ + ) + + def aggregate(self, func=None, axis: Axis = 0, *args, **kwargs): + """ + Aggregate using one or more operations over the specified axis. + + Parameters + ---------- + func : function, str, list or dict + Function to use for aggregating the data. If a function, must either + work when passed a Series or when passed to Series.apply. + + Accepted combinations are: + + - function + - string function name + - list of functions and/or function names, e.g. ``[np.sum, 'mean']`` + - dict of axis labels -> functions, function names or list of such. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + *args + Positional arguments to pass to `func`. + **kwargs + Keyword arguments to pass to `func`. + + Returns + ------- + scalar, Series or DataFrame + The return can be: + + * scalar : when Series.agg is called with single function + * Series : when DataFrame.agg is called with a single function + * DataFrame : when DataFrame.agg is called with several functions + + See Also + -------- + Series.apply : Invoke function on a Series. + Series.transform : Transform function producing a Series with like indexes. + + Notes + ----- + The aggregation operations are always performed over an axis, either the + index (default) or the column axis. This behavior is different from + `numpy` aggregation functions (`mean`, `median`, `prod`, `sum`, `std`, + `var`), where the default is to compute the aggregation of the flattened + array, e.g., ``numpy.mean(arr_2d)`` as opposed to + ``numpy.mean(arr_2d, axis=0)``. + + `agg` is an alias for `aggregate`. Use the alias. + + Functions that mutate the passed object can produce unexpected + behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` + for more details. + + A passed user-defined-function will be passed a Series for evaluation. + + If ``func`` defines an index relabeling, ``axis`` must be ``0`` or ``index``. + + Examples + -------- + >>> s = pd.Series([1, 2, 3, 4]) + >>> s + 0 1 + 1 2 + 2 3 + 3 4 + dtype: int64 + + >>> s.agg("min") + 1 + + >>> s.agg(["min", "max"]) + min 1 + max 4 + dtype: int64 + """ + + # Validate the axis parameter + self._get_axis_number(axis) + + # if func is None, will switch to user-provided "named aggregation" kwargs + if func is None: + func = dict(kwargs.items()) + + op = SeriesApply(self, func, args=args, kwargs=kwargs) + result = op.agg() + return result + + agg = aggregate + + def transform( + self, func: AggFuncType, axis: Axis = 0, *args, **kwargs + ) -> DataFrame | Series: + """ + Call ``func`` on self producing a Series with the same axis shape as self. + + Parameters + ---------- + func : function, str, list-like or dict-like + Function to use for transforming the data. If a function, must either + work when passed a Series or when passed to Series.apply. If func + is both list-like and dict-like, dict-like behavior takes precedence. + + Accepted combinations are: + + - function + - string function name + - list-like of functions and/or function names, e.g. ``[np.exp, 'sqrt']`` + - dict-like of axis labels -> functions, function names or list-like of such + + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + + *args + Positional arguments to pass to `func`. + **kwargs + Keyword arguments to pass to `func`. + + Returns + ------- + Series + A Series that must have the same length as self. + + Raises + ------ + ValueError : If the returned Series has a different length than self. + + See Also + -------- + Series.agg : Only perform aggregating type operations. + Series.apply : Invoke function on a Series. + + Notes + ----- + Functions that mutate the passed object can produce unexpected + behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` + for more details. + + Examples + -------- + >>> df = pd.DataFrame({"A": range(3), "B": range(1, 4)}) + >>> df + A B + 0 0 1 + 1 1 2 + 2 2 3 + >>> df.transform(lambda x: x + 1) + A B + 0 1 2 + 1 2 3 + 2 3 4 + + Even though the resulting Series must have the same length as the + input Series, it is possible to provide several input functions: + + >>> s = pd.Series(range(3)) + >>> s + 0 0 + 1 1 + 2 2 + dtype: int64 + >>> s.transform([np.sqrt, np.exp]) + sqrt exp + 0 0.000000 1.000000 + 1 1.000000 2.718282 + 2 1.414214 7.389056 + + You can call transform on a GroupBy object: + + >>> df = pd.DataFrame( + ... { + ... "Date": [ + ... "2015-05-08", + ... "2015-05-07", + ... "2015-05-06", + ... "2015-05-05", + ... "2015-05-08", + ... "2015-05-07", + ... "2015-05-06", + ... "2015-05-05", + ... ], + ... "Data": [5, 8, 6, 1, 50, 100, 60, 120], + ... } + ... ) + >>> df + Date Data + 0 2015-05-08 5 + 1 2015-05-07 8 + 2 2015-05-06 6 + 3 2015-05-05 1 + 4 2015-05-08 50 + 5 2015-05-07 100 + 6 2015-05-06 60 + 7 2015-05-05 120 + >>> df.groupby("Date")["Data"].transform("sum") + 0 55 + 1 108 + 2 66 + 3 121 + 4 55 + 5 108 + 6 66 + 7 121 + Name: Data, dtype: int64 + + >>> df = pd.DataFrame( + ... { + ... "c": [1, 1, 1, 2, 2, 2, 2], + ... "type": ["m", "n", "o", "m", "m", "n", "n"], + ... } + ... ) + >>> df + c type + 0 1 m + 1 1 n + 2 1 o + 3 2 m + 4 2 m + 5 2 n + 6 2 n + >>> df["size"] = df.groupby("c")["type"].transform(len) + >>> df + c type size + 0 1 m 3 + 1 1 n 3 + 2 1 o 3 + 3 2 m 4 + 4 2 m 4 + 5 2 n 4 + 6 2 n 4 + """ + # Validate axis argument + self._get_axis_number(axis) + ser = self.copy(deep=False) + result = SeriesApply(ser, func=func, args=args, kwargs=kwargs).transform() + return result + + def apply( + self, + func: AggFuncType, + args: tuple[Any, ...] = (), + *, + by_row: Literal[False, "compat"] = "compat", + **kwargs, + ) -> DataFrame | Series: + """ + Invoke function on values of Series. + + Can be ufunc (a NumPy function that applies to the entire Series) + or a Python function that only works on single values. + + Parameters + ---------- + func : function + Python function or NumPy ufunc to apply. + args : tuple + Positional arguments passed to func after the series value. + by_row : False or "compat", default "compat" + If ``"compat"`` and func is a callable, func will be passed each element of + the Series, like ``Series.map``. If func is a list or dict of + callables, will first try to translate each func into pandas methods. If + that doesn't work, will try call to apply again with ``by_row="compat"`` + and if that fails, will call apply again with ``by_row=False`` + (backward compatible). + If False, the func will be passed the whole Series at once. + + ``by_row`` has no effect when ``func`` is a string. + + .. versionadded:: 2.1.0 + **kwargs + Additional keyword arguments passed to func. + + Returns + ------- + Series or DataFrame + If func returns a Series object the result will be a DataFrame. + + See Also + -------- + Series.map: For element-wise operations. + Series.agg: Only perform aggregating type operations. + Series.transform: Only perform transforming type operations. + + Notes + ----- + Functions that mutate the passed object can produce unexpected + behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` + for more details. + + Examples + -------- + Create a series with typical summer temperatures for each city. + + >>> s = pd.Series([20, 21, 12], index=["London", "New York", "Helsinki"]) + >>> s + London 20 + New York 21 + Helsinki 12 + dtype: int64 + + Square the values by defining a function and passing it as an + argument to ``apply()``. + + >>> def square(x): + ... return x**2 + >>> s.apply(square) + London 400 + New York 441 + Helsinki 144 + dtype: int64 + + Square the values by passing an anonymous function as an + argument to ``apply()``. + + >>> s.apply(lambda x: x**2) + London 400 + New York 441 + Helsinki 144 + dtype: int64 + + Define a custom function that needs additional positional + arguments and pass these additional arguments using the + ``args`` keyword. + + >>> def subtract_custom_value(x, custom_value): + ... return x - custom_value + + >>> s.apply(subtract_custom_value, args=(5,)) + London 15 + New York 16 + Helsinki 7 + dtype: int64 + + Define a custom function that takes keyword arguments + and pass these arguments to ``apply``. + + >>> def add_custom_values(x, **kwargs): + ... for month in kwargs: + ... x += kwargs[month] + ... return x + + >>> s.apply(add_custom_values, june=30, july=20, august=25) + London 95 + New York 96 + Helsinki 87 + dtype: int64 + + Use a function from the Numpy library. + + >>> s.apply(np.log) + London 2.995732 + New York 3.044522 + Helsinki 2.484907 + dtype: float64 + """ + return SeriesApply( + self, + func, + by_row=by_row, + args=args, + kwargs=kwargs, + ).apply() + + def _reindex_indexer( + self, + new_index: Index | None, + indexer: npt.NDArray[np.intp] | None, + ) -> Series: + # Note: new_index is None iff indexer is None + # if not None, indexer is np.intp + if indexer is None and ( + new_index is None or new_index.names == self.index.names + ): + return self.copy(deep=False) + + new_values = algorithms.take_nd( + self._values, indexer, allow_fill=True, fill_value=None + ) + return self._constructor(new_values, index=new_index, copy=False) + + def _needs_reindex_multi(self, axes, method, level) -> bool: + """ + Check if we do need a multi reindex; this is for compat with + higher dims. + """ + return False + + @overload + def rename( + self, + index: Renamer | Hashable | None = ..., + *, + axis: Axis | None = ..., + copy: bool | lib.NoDefault = ..., + inplace: Literal[True], + level: Level | None = ..., + errors: IgnoreRaise = ..., + ) -> Series | None: ... + + @overload + def rename( + self, + index: Renamer | Hashable | None = ..., + *, + axis: Axis | None = ..., + copy: bool | lib.NoDefault = ..., + inplace: Literal[False] = ..., + level: Level | None = ..., + errors: IgnoreRaise = ..., + ) -> Series: ... + + def rename( + self, + index: Renamer | Hashable | None = None, + *, + axis: Axis | None = None, + copy: bool | lib.NoDefault = lib.no_default, + inplace: bool = False, + level: Level | None = None, + errors: IgnoreRaise = "ignore", + ) -> Series | None: + """ + Alter Series index labels or name. + + Function / dict values must be unique (1-to-1). Labels not contained in + a dict / Series will be left as-is. Extra labels listed don't throw an + error. + + Alternatively, change ``Series.name`` with a scalar value. + + See the :ref:`user guide ` for more. + + Parameters + ---------- + index : scalar, hashable sequence, dict-like or function optional + Functions or dict-like are transformations to apply to + the index. + Scalar or hashable sequence-like will alter the ``Series.name`` + attribute. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + inplace : bool, default False + Whether to return a new Series. If True the value of copy is ignored. + level : int or level name, default None + In case of MultiIndex, only rename labels in the specified level. + errors : {'ignore', 'raise'}, default 'ignore' + If 'raise', raise `KeyError` when a `dict-like mapper` or + `index` contains labels that are not present in the index being transformed. + If 'ignore', existing keys will be renamed and extra keys will be ignored. + + Returns + ------- + Series + A shallow copy with index labels or name altered, or the same object + if ``inplace=True`` and index is not a dict or callable else None. + + See Also + -------- + DataFrame.rename : Corresponding DataFrame method. + Series.rename_axis : Set the name of the axis. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s + 0 1 + 1 2 + 2 3 + dtype: int64 + >>> s.rename("my_name") # scalar, changes Series.name + 0 1 + 1 2 + 2 3 + Name: my_name, dtype: int64 + >>> s.rename(lambda x: x**2) # function, changes labels + 0 1 + 1 2 + 4 3 + dtype: int64 + >>> s.rename({1: 3, 2: 5}) # mapping, changes labels + 0 1 + 3 2 + 5 3 + dtype: int64 + """ + self._check_copy_deprecation(copy) + if axis is not None: + # Make sure we raise if an invalid 'axis' is passed. + axis = self._get_axis_number(axis) + + if callable(index) or is_dict_like(index): + # error: Argument 1 to "_rename" of "NDFrame" has incompatible + # type "Union[Union[Mapping[Any, Hashable], Callable[[Any], + # Hashable]], Hashable, None]"; expected "Union[Mapping[Any, + # Hashable], Callable[[Any], Hashable], None]" + return super()._rename( + index, # type: ignore[arg-type] + inplace=inplace, + level=level, + errors=errors, + ) + else: + return self._set_name(index, inplace=inplace) + + def set_axis( + self, + labels, + *, + axis: Axis = 0, + copy: bool | lib.NoDefault = lib.no_default, + ) -> Series: + """ + Assign desired index to given axis. + + .. deprecated:: 3.0.0 + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Indexes for row labels can be changed by assigning a list-like or Index. + + Parameters + ---------- + labels : list-like or Index + The values for the new index. + axis : {0 or 'index'}, default 0 + The axis to update. The value 0 identifies the rows. For `Series` + this parameter is unused and defaults to 0. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + Returns + ------- + Series + A shallow copy of the object with axis altered to the given index. + + See Also + -------- + Series.rename_axis : Alter the name of the index. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s + 0 1 + 1 2 + 2 3 + dtype: int64 + >>> s.set_axis(["a", "b", "c"], axis=0) + a 1 + b 2 + c 3 + dtype: int64 + """ + + return super().set_axis(labels, axis=axis, copy=copy) + + # error: Cannot determine type of 'reindex' + + def reindex( # type: ignore[override] + self, + index=None, + *, + axis: Axis | None = None, + method: ReindexMethod | None = None, + copy: bool | lib.NoDefault = lib.no_default, + level: Level | None = None, + fill_value: Scalar | None = None, + limit: int | None = None, + tolerance=None, + ) -> Series: + """ + Conform Series to new index with optional filling logic. + + Places NA/NaN in locations having no value in the previous index. A new object + is produced unless the new index is equivalent to the current one and + ``copy=False``. + + Parameters + ---------- + index : scalar, list-like, dict-like or function, optional + A scalar, list-like, dict-like or functions transformations to + apply to that axis' values. + axis : {0 or 'index'}, default 0 + The axis to rename. For `Series` this parameter is unused and defaults to 0. + method : {{None, 'backfill'/'bfill', 'pad'/'ffill', 'nearest'}} + Method to use for filling holes in reindexed DataFrame. + Please note: this is only applicable to DataFrames/Series with a + monotonically increasing/decreasing index. + + * None (default): don't fill gaps + * pad / ffill: Propagate last valid observation forward to next + valid. + * backfill / bfill: Use next valid observation to fill gap. + * nearest: Use nearest valid observations to fill gap. + + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : scalar, default np.nan + Value to use for missing values. Defaults to NaN, but can be any + "compatible" value. + limit : int, default None + Maximum number of consecutive elements to forward or backward fill. + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations most + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + + Tolerance may be a scalar value, which applies the same tolerance + to all values, or list-like, which applies variable tolerance per + element. List-like includes list, tuple, array, Series, and must be + the same size as the index and its dtype must exactly match the + index's type. + + Returns + ------- + Series + Series with changed index. + + See Also + -------- + DataFrame.set_index : Set row labels. + DataFrame.reset_index : Remove row labels or move them to new columns. + DataFrame.reindex_like : Change to same indices as other DataFrame. + + Examples + -------- + ``DataFrame.reindex`` supports two calling conventions + + * ``(index=index_labels, columns=column_labels, ...)`` + * ``(labels, axis={{'index', 'columns'}}, ...)`` + + We *highly* recommend using keyword arguments to clarify your + intent. + + Create a DataFrame with some fictional data. + + >>> index = ["Firefox", "Chrome", "Safari", "IE10", "Konqueror"] + >>> columns = ["http_status", "response_time"] + >>> df = pd.DataFrame( + ... [[200, 0.04], [200, 0.02], [404, 0.07], [404, 0.08], [301, 1.0]], + ... columns=columns, + ... index=index, + ... ) + >>> df + http_status response_time + Firefox 200 0.04 + Chrome 200 0.02 + Safari 404 0.07 + IE10 404 0.08 + Konqueror 301 1.00 + + Create a new index and reindex the DataFrame. By default + values in the new index that do not have corresponding + records in the DataFrame are assigned ``NaN``. + + >>> new_index = ["Safari", "Iceweasel", "Comodo Dragon", "IE10", "Chrome"] + >>> df.reindex(new_index) + http_status response_time + Safari 404.0 0.07 + Iceweasel NaN NaN + Comodo Dragon NaN NaN + IE10 404.0 0.08 + Chrome 200.0 0.02 + + We can fill in the missing values by passing a value to + the keyword ``fill_value``. Because the index is not monotonically + increasing or decreasing, we cannot use arguments to the keyword + ``method`` to fill the ``NaN`` values. + + >>> df.reindex(new_index, fill_value=0) + http_status response_time + Safari 404 0.07 + Iceweasel 0 0.00 + Comodo Dragon 0 0.00 + IE10 404 0.08 + Chrome 200 0.02 + + >>> df.reindex(new_index, fill_value="missing") + http_status response_time + Safari 404 0.07 + Iceweasel missing missing + Comodo Dragon missing missing + IE10 404 0.08 + Chrome 200 0.02 + + We can also reindex the columns. + + >>> df.reindex(columns=["http_status", "user_agent"]) + http_status user_agent + Firefox 200 NaN + Chrome 200 NaN + Safari 404 NaN + IE10 404 NaN + Konqueror 301 NaN + + Or we can use "axis-style" keyword arguments + + >>> df.reindex(["http_status", "user_agent"], axis="columns") + http_status user_agent + Firefox 200 NaN + Chrome 200 NaN + Safari 404 NaN + IE10 404 NaN + Konqueror 301 NaN + + To further illustrate the filling functionality in + ``reindex``, we will create a DataFrame with a + monotonically increasing index (for example, a sequence + of dates). + + >>> date_index = pd.date_range("1/1/2010", periods=6, freq="D") + >>> df2 = pd.DataFrame( + ... {"prices": [100, 101, np.nan, 100, 89, 88]}, index=date_index + ... ) + >>> df2 + prices + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + + Suppose we decide to expand the DataFrame to cover a wider + date range. + + >>> date_index2 = pd.date_range("12/29/2009", periods=10, freq="D") + >>> df2.reindex(date_index2) + prices + 2009-12-29 NaN + 2009-12-30 NaN + 2009-12-31 NaN + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + 2010-01-07 NaN + + The index entries that did not have a value in the original data frame + (for example, '2009-12-29') are by default filled with ``NaN``. + If desired, we can fill in the missing values using one of several + options. + + For example, to back-propagate the last valid value to fill the ``NaN`` + values, pass ``bfill`` as an argument to the ``method`` keyword. + + >>> df2.reindex(date_index2, method="bfill") + prices + 2009-12-29 100.0 + 2009-12-30 100.0 + 2009-12-31 100.0 + 2010-01-01 100.0 + 2010-01-02 101.0 + 2010-01-03 NaN + 2010-01-04 100.0 + 2010-01-05 89.0 + 2010-01-06 88.0 + 2010-01-07 NaN + + Please note that the ``NaN`` value present in the original DataFrame + (at index value 2010-01-03) will not be filled by any of the + value propagation schemes. This is because filling while reindexing + does not look at DataFrame values, but only compares the original and + desired indexes. If you do want to fill in the ``NaN`` values present + in the original DataFrame, use the ``fillna()`` method. + + See the :ref:`user guide ` for more. + """ + return super().reindex( + index=index, + method=method, + level=level, + fill_value=fill_value, + limit=limit, + tolerance=tolerance, + copy=copy, + ) + + @overload # type: ignore[override] + def rename_axis( + self, + mapper: IndexLabel | lib.NoDefault = ..., + *, + index=..., + axis: Axis = ..., + copy: bool | lib.NoDefault = ..., + inplace: Literal[True], + ) -> None: ... + + @overload + def rename_axis( + self, + mapper: IndexLabel | lib.NoDefault = ..., + *, + index=..., + axis: Axis = ..., + copy: bool | lib.NoDefault = ..., + inplace: Literal[False] = ..., + ) -> Self: ... + + @overload + def rename_axis( + self, + mapper: IndexLabel | lib.NoDefault = ..., + *, + index=..., + axis: Axis = ..., + copy: bool | lib.NoDefault = ..., + inplace: bool = ..., + ) -> Self | None: ... + + def rename_axis( + self, + mapper: IndexLabel | lib.NoDefault = lib.no_default, + *, + index=lib.no_default, + axis: Axis = 0, + copy: bool | lib.NoDefault = lib.no_default, + inplace: bool = False, + ) -> Self | None: + """ + Set the name of the axis for the index. + + Parameters + ---------- + mapper : scalar, list-like, optional + Value to set the axis name attribute. + + Use either ``mapper`` and ``axis`` to + specify the axis to target with ``mapper``, or ``index``. + + index : scalar, list-like, dict-like or function, optional + A scalar, list-like, dict-like or functions transformations to + apply to that axis' values. + axis : {0 or 'index'}, default 0 + The axis to rename. For `Series` this parameter is unused and defaults to 0. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + inplace : bool, default False + Modifies the object directly, instead of creating a new Series + or DataFrame. + + Returns + ------- + Series, or None + The same type as the caller or None if ``inplace=True``. + + See Also + -------- + Series.rename : Alter Series index labels or name. + DataFrame.rename : Alter DataFrame index labels or name. + Index.rename : Set new names on index. + + Examples + -------- + + >>> s = pd.Series(["dog", "cat", "monkey"]) + >>> s + 0 dog + 1 cat + 2 monkey + dtype: str + >>> s.rename_axis("animal") + animal + 0 dog + 1 cat + 2 monkey + dtype: str + """ + return super().rename_axis( + mapper=mapper, + index=index, + axis=axis, + inplace=inplace, + copy=copy, + ) + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level | None = ..., + inplace: Literal[True], + errors: IgnoreRaise = ..., + ) -> None: ... + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level | None = ..., + inplace: Literal[False] = ..., + errors: IgnoreRaise = ..., + ) -> Series: ... + + @overload + def drop( + self, + labels: IndexLabel | ListLike = ..., + *, + axis: Axis = ..., + index: IndexLabel | ListLike = ..., + columns: IndexLabel | ListLike = ..., + level: Level | None = ..., + inplace: bool = ..., + errors: IgnoreRaise = ..., + ) -> Series | None: ... + + def drop( + self, + labels: IndexLabel | ListLike = None, + *, + axis: Axis = 0, + index: IndexLabel | ListLike = None, + columns: IndexLabel | ListLike = None, + level: Level | None = None, + inplace: bool = False, + errors: IgnoreRaise = "raise", + ) -> Series | None: + """ + Return Series with specified index labels removed. + + Remove elements of a Series based on specifying the index labels. + When using a multi-index, labels on different levels can be removed + by specifying the level. + + Parameters + ---------- + labels : single label or list-like + Index labels to drop. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + index : single label or list-like + Redundant for application on Series, but 'index' can be used instead + of 'labels'. + columns : single label or list-like + No change is made to the Series; use 'index' or 'labels' instead. + level : int or level name, optional + For MultiIndex, level for which the labels will be removed. + inplace : bool, default False + If True, do operation inplace and return None. + errors : {'ignore', 'raise'}, default 'raise' + If 'ignore', suppress error and only existing labels are dropped. + + Returns + ------- + Series or None + Series with specified index labels removed or None if ``inplace=True``. + + Raises + ------ + KeyError + If none of the labels are found in the index. + + See Also + -------- + Series.reindex : Return only specified index labels of Series. + Series.dropna : Return series without null values. + Series.drop_duplicates : Return Series with duplicate values removed. + DataFrame.drop : Drop specified labels from rows or columns. + + Examples + -------- + >>> s = pd.Series(data=np.arange(3), index=["A", "B", "C"]) + >>> s + A 0 + B 1 + C 2 + dtype: int64 + + Drop labels B and C + + >>> s.drop(labels=["B", "C"]) + A 0 + dtype: int64 + + Drop 2nd level label in MultiIndex Series + + >>> midx = pd.MultiIndex( + ... levels=[["llama", "cow", "falcon"], ["speed", "weight", "length"]], + ... codes=[[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]], + ... ) + >>> s = pd.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3], index=midx) + >>> s + llama speed 45.0 + weight 200.0 + length 1.2 + cow speed 30.0 + weight 250.0 + length 1.5 + falcon speed 320.0 + weight 1.0 + length 0.3 + dtype: float64 + + >>> s.drop(labels="weight", level=1) + llama speed 45.0 + length 1.2 + cow speed 30.0 + length 1.5 + falcon speed 320.0 + length 0.3 + dtype: float64 + """ + return super().drop( + labels=labels, + axis=axis, + index=index, + columns=columns, + level=level, + inplace=inplace, + errors=errors, + ) + + def pop(self, item: Hashable) -> Any: + """ + Return item and drops from series. Raise KeyError if not found. + + Parameters + ---------- + item : label + Index of the element that needs to be removed. + + Returns + ------- + scalar + Value that is popped from series. + + See Also + -------- + Series.drop: Drop specified values from Series. + Series.drop_duplicates: Return Series with duplicate values removed. + + Examples + -------- + >>> ser = pd.Series([1, 2, 3]) + + >>> ser.pop(0) + 1 + + >>> ser + 1 2 + 2 3 + dtype: int64 + """ + return maybe_unbox_numpy_scalar(super().pop(item=item)) + + def info( + self, + verbose: bool | None = None, + buf: IO[str] | None = None, + max_cols: int | None = None, + memory_usage: bool | str | None = None, + show_counts: bool = True, + ) -> None: + """ + Print a concise summary of a Series. + + This method prints information about a Series including + the index dtype, non-NA values and memory usage. + + Parameters + ---------- + verbose : bool, optional + Whether to print the full summary. By default, the setting in + ``pandas.options.display.max_info_columns`` is followed. + buf : writable buffer, defaults to sys.stdout + Where to send the output. By default, the output is printed to + sys.stdout. Pass a writable buffer if you need to further process + the output. + max_cols : int, optional + Unused, exists only for compatibility with DataFrame.info. + memory_usage : bool, str, optional + Specifies whether total memory usage of the Series + elements (including the index) should be displayed. By default, + this follows the ``pandas.options.display.memory_usage`` setting. + + True always show memory usage. False never shows memory usage. + A value of 'deep' is equivalent to "True with deep introspection". + Memory usage is shown in human-readable units (base-2 + representation). Without deep introspection a memory estimation is + made based in column dtype and number of rows assuming values + consume the same memory amount for corresponding dtypes. With deep + memory introspection, a real memory usage calculation is performed + at the cost of computational resources. See the + :ref:`Frequently Asked Questions ` for more + details. + show_counts : bool, optional + Whether to show the non-null counts. By default, this is shown + only if the DataFrame is smaller than + ``pandas.options.display.max_info_rows`` and + ``pandas.options.display.max_info_columns``. A value of True always + shows the counts, and False never shows the counts. + + Returns + ------- + None + This method prints a summary of a Series and returns None. + + See Also + -------- + Series.describe: Generate descriptive statistics of Series. + Series.memory_usage: Memory usage of Series. + + Examples + -------- + >>> int_values = [1, 2, 3, 4, 5] + >>> text_values = ["alpha", "beta", "gamma", "delta", "epsilon"] + >>> s = pd.Series(text_values, index=int_values) + >>> s.info() + + Index: 5 entries, 1 to 5 + Series name: None + Non-Null Count Dtype + -------------- ----- + 5 non-null str + dtypes: str(1) + memory usage: 106.0 bytes + + Prints a summary excluding information about its values: + + >>> s.info(verbose=False) + + Index: 5 entries, 1 to 5 + dtypes: str(1) + memory usage: 106.0 bytes + + Pipe output of Series.info to buffer instead of sys.stdout, get + buffer content and writes to a text file: + + >>> import io + >>> buffer = io.StringIO() + >>> s.info(buf=buffer) + >>> s = buffer.getvalue() + >>> with open("df_info.txt", "w", encoding="utf-8") as f: # doctest: +SKIP + ... f.write(s) + 260 + + The `memory_usage` parameter allows deep introspection mode, specially + useful for big Series and fine-tune memory optimization: + + >>> random_strings_array = np.random.choice(["a", "b", "c"], 10**6) + >>> s = pd.Series(np.random.choice(["a", "b", "c"], 10**6)) + >>> s.info() + + RangeIndex: 1000000 entries, 0 to 999999 + Series name: None + Non-Null Count Dtype + -------------- ----- + 1000000 non-null str + dtypes: str(1) + memory usage: 8.6 MB + + >>> s.info(memory_usage="deep") + + RangeIndex: 1000000 entries, 0 to 999999 + Series name: None + Non-Null Count Dtype + -------------- ----- + 1000000 non-null str + dtypes: str(1) + memory usage: 8.6 MB + """ + return SeriesInfo(self, memory_usage).render( + buf=buf, + max_cols=max_cols, + verbose=verbose, + show_counts=show_counts, + ) + + def memory_usage(self, index: bool = True, deep: bool = False) -> int: + """ + Return the memory usage of the Series. + + The memory usage can optionally include the contribution of + the index and of elements of `object` dtype. + + Parameters + ---------- + index : bool, default True + Specifies whether to include the memory usage of the Series index. + deep : bool, default False + If True, introspect the data deeply by interrogating + `object` dtypes for system-level memory consumption, and include + it in the returned value. + + Returns + ------- + int + Bytes of memory consumed. + + See Also + -------- + numpy.ndarray.nbytes : Total bytes consumed by the elements of the + array. + DataFrame.memory_usage : Bytes consumed by a DataFrame. + + Examples + -------- + >>> s = pd.Series(range(3)) + >>> s.memory_usage() + 156 + + Not including the index gives the size of the rest of the data, which + is necessarily smaller: + + >>> s.memory_usage(index=False) + 24 + + The memory footprint of `object` values is ignored by default: + + >>> s = pd.Series(["a", "b"]) + >>> s.values + + ['a', 'b'] + Length: 2, dtype: str + >>> s.memory_usage() + 150 + >>> s.memory_usage(deep=True) + 150 + """ + v = self._memory_usage(deep=deep) + if index: + v += self.index.memory_usage(deep=deep) + return v + + def isin(self, values) -> Series: + """ + Whether elements in Series are contained in `values`. + + Return a boolean Series showing whether each element in the Series + matches an element in the passed sequence of `values` exactly. + + Parameters + ---------- + values : set or list-like + The sequence of values to test. Passing in a single string will + raise a ``TypeError``. Instead, turn a single string into a + list of one element. + + Returns + ------- + Series + Series of booleans indicating if each element is in values. + + Raises + ------ + TypeError + * If `values` is a string + + See Also + -------- + DataFrame.isin : Equivalent method on DataFrame. + + Examples + -------- + >>> s = pd.Series( + ... ["llama", "cow", "llama", "beetle", "llama", "hippo"], name="animal" + ... ) + >>> s.isin(["cow", "llama"]) + 0 True + 1 True + 2 True + 3 False + 4 True + 5 False + Name: animal, dtype: bool + + To invert the boolean values, use the ``~`` operator: + + >>> ~s.isin(["cow", "llama"]) + 0 False + 1 False + 2 False + 3 True + 4 False + 5 True + Name: animal, dtype: bool + + Passing a single string as ``s.isin('llama')`` will raise an error. Use + a list of one element instead: + + >>> s.isin(["llama"]) + 0 True + 1 False + 2 True + 3 False + 4 True + 5 False + Name: animal, dtype: bool + + Strings and integers are distinct and are therefore not comparable: + + >>> pd.Series([1]).isin(["1"]) + 0 False + dtype: bool + >>> pd.Series([1.1]).isin(["1.1"]) + 0 False + dtype: bool + """ + result = algorithms.isin(self._values, values) + return self._constructor(result, index=self.index, copy=False).__finalize__( + self, method="isin" + ) + + def between( + self, + left, + right, + inclusive: Literal["both", "neither", "left", "right"] = "both", + ) -> Series: + """ + Return boolean Series equivalent to left <= series <= right. + + This function returns a boolean vector containing `True` wherever the + corresponding Series element is between the boundary values `left` and + `right`. NA values are treated as `False`. + + Parameters + ---------- + left : scalar or list-like + Left boundary. + right : scalar or list-like + Right boundary. + inclusive : {"both", "neither", "left", "right"} + Include boundaries. Whether to set each bound as closed or open. + + Returns + ------- + Series + Series representing whether each element is between left and + right (inclusive). + + See Also + -------- + Series.gt : Greater than of series and other. + Series.lt : Less than of series and other. + + Notes + ----- + This function is equivalent to ``(left <= ser) & (ser <= right)`` + + Examples + -------- + >>> s = pd.Series([2, 0, 4, 8, np.nan]) + + Boundary values are included by default: + + >>> s.between(1, 4) + 0 True + 1 False + 2 True + 3 False + 4 False + dtype: bool + + With `inclusive` set to ``"neither"`` boundary values are excluded: + + >>> s.between(1, 4, inclusive="neither") + 0 True + 1 False + 2 False + 3 False + 4 False + dtype: bool + + `left` and `right` can be any scalar value: + + >>> s = pd.Series(["Alice", "Bob", "Carol", "Eve"]) + >>> s.between("Anna", "Daniel") + 0 False + 1 True + 2 True + 3 False + dtype: bool + """ + if inclusive == "both": + lmask = self >= left + rmask = self <= right + elif inclusive == "left": + lmask = self >= left + rmask = self < right + elif inclusive == "right": + lmask = self > left + rmask = self <= right + elif inclusive == "neither": + lmask = self > left + rmask = self < right + else: + raise ValueError( + "Inclusive has to be either string of 'both'," + "'left', 'right', or 'neither'." + ) + + return lmask & rmask + + def case_when( + self, + caselist: list[ + tuple[ + ArrayLike | Callable[[Series], Series | np.ndarray | Sequence[bool]], + ArrayLike | Scalar | Callable[[Series], Series | np.ndarray], + ], + ], + ) -> Series: + """ + Replace values where the conditions are True. + + .. versionadded:: 2.2.0 + + Parameters + ---------- + caselist : A list of tuples of conditions and expected replacements + Takes the form: ``(condition0, replacement0)``, + ``(condition1, replacement1)``, ... . + ``condition`` should be a 1-D boolean array-like object + or a callable. If ``condition`` is a callable, + it is computed on the Series + and should return a boolean Series or array. + The callable must not change the input Series + (though pandas doesn`t check it). ``replacement`` should be a + 1-D array-like object, a scalar or a callable. + If ``replacement`` is a callable, it is computed on the Series + and should return a scalar or Series. The callable + must not change the input Series + (though pandas doesn`t check it). + + Returns + ------- + Series + A new Series with values replaced based on the provided conditions. + + See Also + -------- + Series.mask : Replace values where the condition is True. + + Examples + -------- + >>> c = pd.Series([6, 7, 8, 9], name="c") + >>> a = pd.Series([0, 0, 1, 2]) + >>> b = pd.Series([0, 3, 4, 5]) + + >>> c.case_when( + ... caselist=[ + ... (a.gt(0), a), # condition, replacement + ... (b.gt(0), b), + ... ] + ... ) + 0 6 + 1 3 + 2 1 + 3 2 + Name: c, dtype: int64 + """ + if not isinstance(caselist, list): + raise TypeError( + f"The caselist argument should be a list; instead got {type(caselist)}" + ) + + if not caselist: + raise ValueError( + "provide at least one boolean condition, " + "with a corresponding replacement." + ) + + for num, entry in enumerate(caselist): + if not isinstance(entry, tuple): + raise TypeError( + f"Argument {num} must be a tuple; instead got {type(entry)}." + ) + if len(entry) != 2: + raise ValueError( + f"Argument {num} must have length 2; " + "a condition and replacement; " + f"instead got length {len(entry)}." + ) + caselist = [ + ( + com.apply_if_callable(condition, self), + com.apply_if_callable(replacement, self), + ) + for condition, replacement in caselist + ] + default = self.copy(deep=False) + conditions, replacements = zip(*caselist, strict=True) + common_dtypes = [infer_dtype_from(arg)[0] for arg in [*replacements, default]] + if len(set(common_dtypes)) > 1: + common_dtype = find_common_type(common_dtypes) + updated_replacements = [] + for condition, replacement in zip(conditions, replacements, strict=True): + if is_scalar(replacement): + replacement = construct_1d_arraylike_from_scalar( + value=replacement, length=len(condition), dtype=common_dtype + ) + elif isinstance(replacement, ABCSeries): + replacement = replacement.astype(common_dtype) + else: + replacement = pd_array(replacement, dtype=common_dtype) + updated_replacements.append(replacement) + replacements = updated_replacements + default = default.astype(common_dtype) + + counter = range(len(conditions) - 1, -1, -1) + for position, condition, replacement in zip( + counter, reversed(conditions), reversed(replacements), strict=True + ): + try: + default = default.mask( + condition, other=replacement, axis=0, inplace=False, level=None + ) + except Exception as error: + raise ValueError( + f"Failed to apply condition{position} and replacement{position}." + ) from error + return default + + # error: Cannot determine type of 'isna' + def isna(self) -> Series: + """ + Detect missing values. + + Return a boolean same-sized Series indicating if the values are NA. + NA values, such as None or :attr:`numpy.NaN`, get mapped to True + values. + Everything else gets mapped to False values. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + + Returns + ------- + Series + Mask of bool values for each element in Series that + indicates whether an element is an NA value. + + See Also + -------- + DataFrame.isna : Detect missing values. + DataFrame.isnull : Alias of isna. + Series.notna : Boolean inverse of isna. + DataFrame.notna : Boolean inverse of isna. + Series.notnull : Alias of notna. + DataFrame.notnull : Alias of notna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + isna : Top-level isna. + + Examples + -------- + Show which entries in a Series are NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + >>> ser.isna() + 0 False + 1 False + 2 True + dtype: bool + """ + return NDFrame.isna(self) + + # error: Cannot determine type of 'isna' + @doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) + def isnull(self) -> Series: + """ + Series.isnull is an alias for Series.isna. + """ + return super().isnull() + + # error: Cannot determine type of 'notna' + def notna(self) -> Series: + """ + Detect existing (non-missing) values. + + Return a boolean same-sized Series indicating if the values are not NA. + Non-missing values get mapped to True. Characters such as empty + strings ``''`` or :attr:`numpy.inf` are not considered NA values. + NA values, such as None or :attr:`numpy.NaN`, get mapped to False + values. + + Returns + ------- + Series + Mask of bool values for each element in Series that + indicates whether an element is not an NA value. + + See Also + -------- + Series.isna : Detect missing values. + DataFrame.isna : Detect missing values. + Series.isnull : Alias of isna. + DataFrame.isnull : Alias of isna. + DataFrame.notna : Boolean inverse of isna. + DataFrame.notnull : Alias of notna. + Series.dropna : Omit axes labels with missing values. + DataFrame.dropna : Omit axes labels with missing values. + notna : Top-level notna. + + Examples + -------- + Show which entries in a Series are not NA. + + >>> ser = pd.Series([5, 6, np.nan]) + >>> ser + 0 5.0 + 1 6.0 + 2 NaN + dtype: float64 + >>> ser.notna() + 0 True + 1 True + 2 False + dtype: bool + """ + return super().notna() + + # error: Cannot determine type of 'notna' + @doc(NDFrame.notna, klass=_shared_doc_kwargs["klass"]) + def notnull(self) -> Series: + """ + Series.notnull is an alias for Series.notna. + """ + return super().notnull() + + @overload + def dropna( + self, + *, + axis: Axis = ..., + inplace: Literal[False] = ..., + how: AnyAll | None = ..., + ignore_index: bool = ..., + ) -> Series: ... + + @overload + def dropna( + self, + *, + axis: Axis = ..., + inplace: Literal[True], + how: AnyAll | None = ..., + ignore_index: bool = ..., + ) -> None: ... + + def dropna( + self, + *, + axis: Axis = 0, + inplace: bool = False, + how: AnyAll | None = None, + ignore_index: bool = False, + ) -> Series | None: + """ + Return a new Series with missing values removed. + + See the :ref:`User Guide ` for more on which values are + considered missing, and how to work with missing data. + + Parameters + ---------- + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + inplace : bool, default False + If True, do operation inplace and return None. + how : str, optional + Not in use. Kept for compatibility. + ignore_index : bool, default ``False`` + If ``True``, the resulting axis will be labeled 0, 1, …, n - 1. + + .. versionadded:: 2.0.0 + + Returns + ------- + Series or None + Series with NA entries dropped from it or None if ``inplace=True``. + + See Also + -------- + Series.isna: Indicate missing values. + Series.notna : Indicate existing (non-missing) values. + Series.fillna : Replace missing values. + DataFrame.dropna : Drop rows or columns which contain NA values. + Index.dropna : Drop missing indices. + + Examples + -------- + >>> ser = pd.Series([1.0, 2.0, np.nan]) + >>> ser + 0 1.0 + 1 2.0 + 2 NaN + dtype: float64 + + Drop NA values from a Series. + + >>> ser.dropna() + 0 1.0 + 1 2.0 + dtype: float64 + + Empty strings are not considered NA values. ``None`` is considered an + NA value. + + >>> ser = pd.Series([np.nan, 2, pd.NaT, "", None, "I stay"]) + >>> ser + 0 NaN + 1 2 + 2 NaT + 3 + 4 None + 5 I stay + dtype: object + >>> ser.dropna() + 1 2 + 3 + 5 I stay + dtype: object + """ + inplace = validate_bool_kwarg(inplace, "inplace") + ignore_index = validate_bool_kwarg(ignore_index, "ignore_index") + # Validate the axis parameter + self._get_axis_number(axis or 0) + + if self._can_hold_na: + result = remove_na_arraylike(self) + elif not inplace: + result = self.copy(deep=False) + else: + result = self + + if ignore_index: + result.index = default_index(len(result)) + + if inplace: + return self._update_inplace(result) + else: + return result + + # ---------------------------------------------------------------------- + # Time series-oriented methods + + def to_timestamp( + self, + freq: Frequency | None = None, + how: Literal["s", "e", "start", "end"] = "start", + copy: bool | lib.NoDefault = lib.no_default, + ) -> Series: + """ + Cast to DatetimeIndex of Timestamps, at *beginning* of period. + + This can be changed to the *end* of the period, by specifying `how="e"`. + + Parameters + ---------- + freq : str, default frequency of PeriodIndex + Desired frequency. + how : {'s', 'e', 'start', 'end'} + Convention for converting period to timestamp; start of period + vs. end. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + Series with DatetimeIndex + Series with the PeriodIndex cast to DatetimeIndex. + + See Also + -------- + Series.to_period: Inverse method to cast DatetimeIndex to PeriodIndex. + DataFrame.to_timestamp: Equivalent method for DataFrame. + + Examples + -------- + >>> idx = pd.PeriodIndex(["2023", "2024", "2025"], freq="Y") + >>> s1 = pd.Series([1, 2, 3], index=idx) + >>> s1 + 2023 1 + 2024 2 + 2025 3 + Freq: Y-DEC, dtype: int64 + + The resulting frequency of the Timestamps is `YearBegin` + + >>> s1 = s1.to_timestamp() + >>> s1 + 2023-01-01 1 + 2024-01-01 2 + 2025-01-01 3 + Freq: YS-JAN, dtype: int64 + + Using `freq` which is the offset that the Timestamps will have + + >>> s2 = pd.Series([1, 2, 3], index=idx) + >>> s2 = s2.to_timestamp(freq="M") + >>> s2 + 2023-01-31 1 + 2024-01-31 2 + 2025-01-31 3 + Freq: YE-JAN, dtype: int64 + """ + self._check_copy_deprecation(copy) + if not isinstance(self.index, PeriodIndex): + raise TypeError(f"unsupported Type {type(self.index).__name__}") + + new_obj = self.copy(deep=False) + new_index = self.index.to_timestamp(freq=freq, how=how) + setattr(new_obj, "index", new_index) + return new_obj + + def to_period( + self, + freq: str | None = None, + copy: bool | lib.NoDefault = lib.no_default, + ) -> Series: + """ + Convert Series from DatetimeIndex to PeriodIndex. + + Parameters + ---------- + freq : str, default None + Frequency associated with the PeriodIndex. + copy : bool, default False + This keyword is now ignored; changing its value will have no + impact on the method. + + .. deprecated:: 3.0.0 + + This keyword is ignored and will be removed in pandas 4.0. Since + pandas 3.0, this method always returns a new object using a lazy + copy mechanism that defers copies until necessary + (Copy-on-Write). See the `user guide on Copy-on-Write + `__ + for more details. + + Returns + ------- + Series + Series with index converted to PeriodIndex. + + See Also + -------- + DataFrame.to_period: Equivalent method for DataFrame. + Series.dt.to_period: Convert DateTime column values. + + Examples + -------- + >>> idx = pd.DatetimeIndex(["2023", "2024", "2025"]) + >>> s = pd.Series([1, 2, 3], index=idx) + >>> s = s.to_period() + >>> s + 2023 1 + 2024 2 + 2025 3 + Freq: Y-DEC, dtype: int64 + + Viewing the index + + >>> s.index + PeriodIndex(['2023', '2024', '2025'], dtype='period[Y-DEC]') + """ + self._check_copy_deprecation(copy) + if not isinstance(self.index, DatetimeIndex): + raise TypeError(f"unsupported Type {type(self.index).__name__}") + + new_obj = self.copy(deep=False) + new_index = self.index.to_period(freq=freq) + setattr(new_obj, "index", new_index) + return new_obj + + # ---------------------------------------------------------------------- + # Add index + _AXIS_ORDERS: list[Literal["index", "columns"]] = ["index"] + _AXIS_LEN = len(_AXIS_ORDERS) + _info_axis_number: Literal[0] = 0 + _info_axis_name: Literal["index"] = "index" + + index = properties.AxisProperty( + axis=0, + doc=""" + The index (axis labels) of the Series. + + The index of a Series is used to label and identify each element of the + underlying data. The index can be thought of as an immutable ordered set + (technically a multi-set, as it may contain duplicate labels), and is + used to index and align data in pandas. + + Returns + ------- + Index + The index labels of the Series. + + See Also + -------- + Series.reindex : Conform Series to new index. + Index : The base pandas index type. + + Notes + ----- + For more information on pandas indexing, see the `indexing user guide + `__. + + Examples + -------- + To create a Series with a custom index and view the index labels: + + >>> cities = ['Kolkata', 'Chicago', 'Toronto', 'Lisbon'] + >>> populations = [14.85, 2.71, 2.93, 0.51] + >>> city_series = pd.Series(populations, index=cities) + >>> city_series.index + Index(['Kolkata', 'Chicago', 'Toronto', 'Lisbon'], dtype='object') + + To change the index labels of an existing Series: + + >>> city_series.index = ['KOL', 'CHI', 'TOR', 'LIS'] + >>> city_series.index + Index(['KOL', 'CHI', 'TOR', 'LIS'], dtype='object') + """, + ) + + # ---------------------------------------------------------------------- + # Accessor Methods + # ---------------------------------------------------------------------- + str = Accessor("str", StringMethods) + dt = Accessor("dt", CombinedDatetimelikeProperties) + cat = Accessor("cat", CategoricalAccessor) + plot = Accessor("plot", pandas.plotting.PlotAccessor) + sparse = Accessor("sparse", SparseAccessor) + struct = Accessor("struct", StructAccessor) + list = Accessor("list", ListAccessor) + + # ---------------------------------------------------------------------- + # Add plotting methods to Series + hist = pandas.plotting.hist_series + + # ---------------------------------------------------------------------- + # Template-Based Arithmetic/Comparison Methods + + def _cmp_method(self, other, op): + res_name = ops.get_op_result_name(self, other) + + if isinstance(other, Series) and not self._indexed_same(other): + raise ValueError("Can only compare identically-labeled Series objects") + + lvalues = self._values + rvalues = extract_array(other, extract_numpy=True, extract_range=True) + + res_values = ops.comparison_op(lvalues, rvalues, op) + + return self._construct_result(res_values, name=res_name, other=other) + + def _logical_method(self, other, op): + res_name = ops.get_op_result_name(self, other) + self, other = self._align_for_op(other, align_asobject=True) + + lvalues = self._values + rvalues = extract_array(other, extract_numpy=True, extract_range=True) + + res_values = ops.logical_op(lvalues, rvalues, op) + return self._construct_result(res_values, name=res_name, other=other) + + def _arith_method(self, other, op): + self, other = self._align_for_op(other) + return base.IndexOpsMixin._arith_method(self, other, op) + + def _align_for_op(self, right, align_asobject: bool = False): + """align lhs and rhs Series""" + # TODO: Different from DataFrame._align_for_op, list, tuple and ndarray + # are not coerced here + # because Series has inconsistencies described in GH#13637 + left = self + + if isinstance(right, Series): + # avoid repeated alignment + if not left.index.equals(right.index): + if align_asobject: + if left.dtype not in (object, np.bool_) or right.dtype not in ( + object, + np.bool_, + ): + pass + # GH#52538 no longer cast in these cases + else: + # to keep original value's dtype for bool ops + left = left.astype(object) + right = right.astype(object) + + left, right = left.align(right) + + return left, right + + def _binop(self, other: Series, func, level=None, fill_value=None) -> Series: + """ + Perform generic binary operation with optional fill value. + + Parameters + ---------- + other : Series + func : binary operator + fill_value : float or object + Value to substitute for NA/null values. If both Series are NA in a + location, the result will be NA regardless of the passed fill value. + level : int or level name, default None + Broadcast across a level, matching Index values on the + passed MultiIndex level. + + Returns + ------- + Series + """ + this = self + + if not self.index.equals(other.index): + this, other = self.align(other, level=level, join="outer") + + this_vals, other_vals = ops.fill_binop(this._values, other._values, fill_value) + + with np.errstate(all="ignore"): + result = func(this_vals, other_vals) + + name = ops.get_op_result_name(self, other) + + out = this._construct_result(result, name, other) + return cast(Series, out) + + def _construct_result( + self, + result: ArrayLike | tuple[ArrayLike, ArrayLike], + name: Hashable, + other: AnyArrayLike | DataFrame, + ) -> Series | tuple[Series, Series]: + """ + Construct an appropriately-labelled Series from the result of an op. + + Parameters + ---------- + result : ndarray or ExtensionArray + name : Label + other : Series, DataFrame or array-like + + Returns + ------- + Series + In the case of __divmod__ or __rdivmod__, a 2-tuple of Series. + """ + if isinstance(result, tuple): + # produced by divmod or rdivmod + + res1 = self._construct_result(result[0], name=name, other=other) + res2 = self._construct_result(result[1], name=name, other=other) + + # GH#33427 assertions to keep mypy happy + assert isinstance(res1, Series) + assert isinstance(res2, Series) + return (res1, res2) + + # TODO: result should always be ArrayLike, but this fails for some + # JSONArray tests + dtype = getattr(result, "dtype", None) + out = self._constructor(result, index=self.index, dtype=dtype, copy=False) + out = out.__finalize__(self) + out = out.__finalize__(other) + + # Set the result's name after __finalize__ is called because __finalize__ + # would set it back to self.name + out.name = name + return out + + def _flex_method(self, other, op, *, level=None, fill_value=None, axis: Axis = 0): + if axis is not None: + self._get_axis_number(axis) + + res_name = ops.get_op_result_name(self, other) + + if isinstance(other, Series): + return self._binop(other, op, level=level, fill_value=fill_value) + elif isinstance(other, (np.ndarray, list, tuple, ExtensionArray)): + if len(other) != len(self): + raise ValueError("Lengths must be equal") + other = self._constructor(other, self.index, copy=False) + result = self._binop(other, op, level=level, fill_value=fill_value) + result._name = res_name + return result + elif isinstance(other, ABCDataFrame): + # GH#46179 + raise TypeError( + f"Series.{op.__name__.strip('_')} does not support a DataFrame " + f"`other`. Use df.{op.__name__.strip('_')}(ser) instead." + ) + else: + if fill_value is not None: + if isna(other): + return op(self, fill_value) + self = self.fillna(fill_value) + + return op(self, other) + + def eq( + self, + other, + level: Level | None = None, + fill_value: float | None = None, + axis: Axis = 0, + ) -> Series: + """ + Return Equal to of series and other, element-wise (binary operator `eq`). + + Equivalent to ``series == other``, but with support to substitute a fill_value + for missing data in either one of the inputs. + + Parameters + ---------- + other : object + When a Series is provided, will align on indexes. For all other types, + will behave the same as ``==`` but with possibly different results due + to the other arguments. + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : None or float value, default None (NaN) + Fill existing missing (NaN) values, and any new element needed for + successful Series alignment, with this value before computation. + If data in both corresponding Series locations is missing + the result of filling (at that location) will be missing. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + + Returns + ------- + Series + The result of the operation. + + See Also + -------- + Series.ge : Return elementwise Greater than or equal to of series and other. + Series.le : Return elementwise Less than or equal to of series and other. + Series.gt : Return elementwise Greater than of series and other. + Series.lt : Return elementwise Less than of series and other. + + Examples + -------- + >>> a = pd.Series([1, 1, 1, np.nan], index=["a", "b", "c", "d"]) + >>> a + a 1.0 + b 1.0 + c 1.0 + d NaN + dtype: float64 + >>> b = pd.Series([1, np.nan, 1, np.nan], index=["a", "b", "d", "e"]) + >>> b + a 1.0 + b NaN + d 1.0 + e NaN + dtype: float64 + >>> a.eq(b, fill_value=0) + a True + b False + c False + d False + e False + dtype: bool + """ + return self._flex_method( + other, operator.eq, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("ne", "series")) + def ne(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, operator.ne, level=level, fill_value=fill_value, axis=axis + ) + + def le(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + """ + Return Less than or equal to of series and other, \ + element-wise (binary operator `le`). + + Equivalent to ``series <= other``, but with support to substitute a + fill_value for missing data in either one of the inputs. + + Parameters + ---------- + other : object + When a Series is provided, will align on indexes. For all other types, + will behave the same as ``==`` but with possibly different results due + to the other arguments. + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : None or float value, default None (NaN) + Fill existing missing (NaN) values, and any new element needed for + successful Series alignment, with this value before computation. + If data in both corresponding Series locations is missing + the result of filling (at that location) will be missing. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + + Returns + ------- + Series + The result of the operation. + + See Also + -------- + Series.ge : Return elementwise Greater than or equal to of series and other. + Series.lt : Return elementwise Less than of series and other. + Series.gt : Return elementwise Greater than of series and other. + Series.eq : Return elementwise equal to of series and other. + + Examples + -------- + >>> a = pd.Series([1, 1, 1, np.nan, 1], index=['a', 'b', 'c', 'd', 'e']) + >>> a + a 1.0 + b 1.0 + c 1.0 + d NaN + e 1.0 + dtype: float64 + >>> b = pd.Series([0, 1, 2, np.nan, 1], index=['a', 'b', 'c', 'd', 'f']) + >>> b + a 0.0 + b 1.0 + c 2.0 + d NaN + f 1.0 + dtype: float64 + >>> a.le(b, fill_value=0) + a False + b True + c True + d False + e False + f True + dtype: bool + """ + return self._flex_method( + other, operator.le, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("lt", "series")) + def lt(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, operator.lt, level=level, fill_value=fill_value, axis=axis + ) + + def ge(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + """ + Return Greater than or equal to of series and other, \ + element-wise (binary operator `ge`). + + Equivalent to ``series >= other``, but with support to substitute a + fill_value for missing data in either one of the inputs. + + Parameters + ---------- + other : object + When a Series is provided, will align on indexes. For all other types, + will behave the same as ``==`` but with possibly different results due + to the other arguments. + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : None or float value, default None (NaN) + Fill existing missing (NaN) values, and any new element needed for + successful Series alignment, with this value before computation. + If data in both corresponding Series locations is missing + the result of filling (at that location) will be missing. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + + Returns + ------- + Series + The result of the operation. + + See Also + -------- + Series.gt : Greater than comparison, element-wise. + Series.le : Less than or equal to comparison, element-wise. + Series.lt : Less than comparison, element-wise. + Series.eq : Equal to comparison, element-wise. + Series.ne : Not equal to comparison, element-wise. + + Examples + -------- + >>> a = pd.Series([1, 1, 1, np.nan, 1], index=["a", "b", "c", "d", "e"]) + >>> a + a 1.0 + b 1.0 + c 1.0 + d NaN + e 1.0 + dtype: float64 + >>> b = pd.Series([0, 1, 2, np.nan, 1], index=["a", "b", "c", "d", "f"]) + >>> b + a 0.0 + b 1.0 + c 2.0 + d NaN + f 1.0 + dtype: float64 + >>> a.ge(b, fill_value=0) + a True + b True + c False + d False + e True + f False + dtype: bool + """ + return self._flex_method( + other, operator.ge, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("gt", "series")) + def gt(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, operator.gt, level=level, fill_value=fill_value, axis=axis + ) + + def add(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + """ + Return Addition of series and other, element-wise (binary operator `add`). + + Equivalent to ``series + other``, but with support to substitute a fill_value + for missing data in either one of the inputs. + + Parameters + ---------- + other : Series or scalar value + With which to compute the addition. + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : None or float value, default None (NaN) + Fill existing missing (NaN) values, and any new element needed for + successful Series alignment, with this value before computation. + If data in both corresponding Series locations is missing + the result of filling (at that location) will be missing. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + + Returns + ------- + Series + The result of the operation. + + See Also + -------- + Series.radd : Reverse of the Addition operator, see + `Python documentation + `_ + for more details. + + Examples + -------- + >>> a = pd.Series([1, 1, 1, np.nan], index=["a", "b", "c", "d"]) + >>> a + a 1.0 + b 1.0 + c 1.0 + d NaN + dtype: float64 + >>> b = pd.Series([1, np.nan, 1, np.nan], index=["a", "b", "d", "e"]) + >>> b + a 1.0 + b NaN + d 1.0 + e NaN + dtype: float64 + >>> a.add(b, fill_value=0) + a 2.0 + b 1.0 + c 1.0 + d 1.0 + e NaN + dtype: float64 + """ + return self._flex_method( + other, operator.add, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("radd", "series")) + def radd(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, roperator.radd, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("sub", "series")) + def sub(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, operator.sub, level=level, fill_value=fill_value, axis=axis + ) + + subtract = sub + + @Appender(ops.make_flex_doc("rsub", "series")) + def rsub(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, roperator.rsub, level=level, fill_value=fill_value, axis=axis + ) + + def mul( + self, + other, + level: Level | None = None, + fill_value: float | None = None, + axis: Axis = 0, + ) -> Series: + """ + Return Multiplication of series and other, element-wise (binary operator `mul`). + + Equivalent to ``series * other``, but with support to substitute + a fill_value for missing data in either one of the inputs. + + Parameters + ---------- + other : Series or scalar value + With which to compute the multiplication. + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : None or float value, default None (NaN) + Fill existing missing (NaN) values, and any new element needed for + successful Series alignment, with this value before computation. + If data in both corresponding Series locations is missing + the result of filling (at that location) will be missing. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + + Returns + ------- + Series + The result of the operation. + + See Also + -------- + Series.rmul : Reverse of the Multiplication operator, see + `Python documentation + `_ + for more details. + + Examples + -------- + >>> a = pd.Series([1, 1, 1, np.nan], index=["a", "b", "c", "d"]) + >>> a + a 1.0 + b 1.0 + c 1.0 + d NaN + dtype: float64 + >>> b = pd.Series([1, np.nan, 1, np.nan], index=["a", "b", "d", "e"]) + >>> b + a 1.0 + b NaN + d 1.0 + e NaN + dtype: float64 + >>> a.multiply(b, fill_value=0) + a 1.0 + b 0.0 + c 0.0 + d 0.0 + e NaN + dtype: float64 + >>> a.mul(5, fill_value=0) + a 5.0 + b 5.0 + c 5.0 + d 0.0 + dtype: float64 + """ + return self._flex_method( + other, operator.mul, level=level, fill_value=fill_value, axis=axis + ) + + multiply = mul + + @Appender(ops.make_flex_doc("rmul", "series")) + def rmul(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, roperator.rmul, level=level, fill_value=fill_value, axis=axis + ) + + def truediv(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + """ + Return Floating division of series and other, \ + element-wise (binary operator `truediv`). + + Equivalent to ``series / other``, but with support to substitute a + fill_value for missing data in either one of the inputs. + + Parameters + ---------- + other : Series or scalar value + Series with which to compute division. + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : None or float value, default None (NaN) + Fill existing missing (NaN) values, and any new element needed for + successful Series alignment, with this value before computation. + If data in both corresponding Series locations is missing + the result of filling (at that location) will be missing. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + + Returns + ------- + Series + The result of the operation. + + See Also + -------- + Series.rtruediv : Reverse of the Floating division operator, see + `Python documentation + `_ + for more details. + + Examples + -------- + >>> a = pd.Series([1, 1, 1, np.nan], index=["a", "b", "c", "d"]) + >>> a + a 1.0 + b 1.0 + c 1.0 + d NaN + dtype: float64 + >>> b = pd.Series([1, np.nan, 1, np.nan], index=["a", "b", "d", "e"]) + >>> b + a 1.0 + b NaN + d 1.0 + e NaN + dtype: float64 + >>> a.divide(b, fill_value=0) + a 1.0 + b inf + c inf + d 0.0 + e NaN + dtype: float64 + """ + return self._flex_method( + other, operator.truediv, level=level, fill_value=fill_value, axis=axis + ) + + div = truediv + divide = truediv + + @Appender(ops.make_flex_doc("rtruediv", "series")) + def rtruediv(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, roperator.rtruediv, level=level, fill_value=fill_value, axis=axis + ) + + rdiv = rtruediv + + @Appender(ops.make_flex_doc("floordiv", "series")) + def floordiv(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, operator.floordiv, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("rfloordiv", "series")) + def rfloordiv(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, roperator.rfloordiv, level=level, fill_value=fill_value, axis=axis + ) + + def mod(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + """ + Return Modulo of series and other, element-wise (binary operator `mod`). + + Equivalent to ``series % other``, but with support to substitute a + fill_value for missing data in either one of the inputs. + + Parameters + ---------- + other : Series or scalar value + Series with which to compute modulo. + level : int or name + Broadcast across a level, matching Index values on the + passed MultiIndex level. + fill_value : None or float value, default None (NaN) + Fill existing missing (NaN) values, and any new element needed for + successful Series alignment, with this value before computation. + If data in both corresponding Series locations is missing + the result of filling (at that location) will be missing. + axis : {0 or 'index'} + Unused. Parameter needed for compatibility with DataFrame. + + Returns + ------- + Series + The result of the operation. + + See Also + -------- + Series.rmod : Reverse of the Modulo operator, see + `Python documentation + `_ + for more details. + + Examples + -------- + >>> a = pd.Series([1, 1, 1, np.nan], index=["a", "b", "c", "d"]) + >>> a + a 1.0 + b 1.0 + c 1.0 + d NaN + dtype: float64 + >>> b = pd.Series([1, np.nan, 1, np.nan], index=["a", "b", "d", "e"]) + >>> b + a 1.0 + b NaN + d 1.0 + e NaN + dtype: float64 + >>> a.mod(b, fill_value=0) + a 0.0 + b NaN + c NaN + d 0.0 + e NaN + dtype: float64 + """ + return self._flex_method( + other, operator.mod, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("rmod", "series")) + def rmod(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, roperator.rmod, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("pow", "series")) + def pow(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, operator.pow, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("rpow", "series")) + def rpow(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, roperator.rpow, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("divmod", "series")) + def divmod(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, divmod, level=level, fill_value=fill_value, axis=axis + ) + + @Appender(ops.make_flex_doc("rdivmod", "series")) + def rdivmod(self, other, level=None, fill_value=None, axis: Axis = 0) -> Series: + return self._flex_method( + other, roperator.rdivmod, level=level, fill_value=fill_value, axis=axis + ) + + # ---------------------------------------------------------------------- + # Reductions + + def _reduce( + self, + op, + # error: Variable "pandas.core.series.Series.str" is not valid as a type + name: str, # type: ignore[valid-type] + *, + axis: Axis = 0, + skipna: bool = True, + numeric_only: bool = False, + filter_type=None, + **kwds, + ): + """ + Perform a reduction operation. + + If we have an ndarray as a value, then simply perform the operation, + otherwise delegate to the object. + """ + delegate = self._values + + if axis is not None: + self._get_axis_number(axis) + + if isinstance(delegate, ExtensionArray): + # dispatch to ExtensionArray interface + result = delegate._reduce(name, skipna=skipna, **kwds) + + else: + # dispatch to numpy arrays + if numeric_only and self.dtype.kind not in "iufcb": + # i.e. not is_numeric_dtype(self.dtype) + kwd_name = "numeric_only" + if name in ["any", "all"]: + kwd_name = "bool_only" + # GH#47500 - change to TypeError to match other methods + raise TypeError( + f"Series.{name} does not allow {kwd_name}={numeric_only} " + "with non-numeric dtypes." + ) + result = op(delegate, skipna=skipna, **kwds) + + result = maybe_unbox_numpy_scalar(result) + return result + + # error: Signature of "any" incompatible with supertype "NDFrame" + def any( # type: ignore[override] + self, + *, + axis: Axis = 0, + bool_only: bool = False, + skipna: bool = True, + **kwargs, + ) -> bool: + """ + Return whether any element is True, potentially over an axis. + + Returns False unless there is at least one element within a series or + along a Dataframe axis that is True or equivalent (e.g. non-zero or + non-empty). + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns', None}, default 0 + Indicate which axis or axes should be reduced. For `Series` this parameter + is unused and defaults to 0. + + * 0 / 'index' : reduce the index, return a Series whose index is the + original column labels. + * 1 / 'columns' : reduce the columns, return a Series whose index is the + original index. + * None : reduce all axes, return a scalar. + + bool_only : bool, default False + Include only boolean columns. Not implemented for Series. + skipna : bool, default True + Exclude NA/null values. If the entire row/column is NA and skipna is + True, then the result will be False, as for an empty row/column. + If skipna is False, then NA are treated as True, because these are not + equal to zero. + **kwargs : any, default None + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series or scalar + If axis=None, then a scalar boolean is returned. + Otherwise a Series is returned with index matching the index argument. + + See Also + -------- + numpy.any : Numpy version of this method. + Series.any : Return whether any element is True. + Series.all : Return whether all elements are True. + DataFrame.any : Return whether any element is True over requested axis. + DataFrame.all : Return whether all elements are True over requested axis. + + Examples + -------- + **Series** + + For Series input, the output is a scalar indicating whether any element + is True. + + >>> pd.Series([False, False]).any() + False + >>> pd.Series([True, False]).any() + True + >>> pd.Series([], dtype="float64").any() + False + >>> pd.Series([np.nan]).any() + False + >>> pd.Series([np.nan]).any(skipna=False) + True + + **DataFrame** + + Whether each column contains at least one True element (the default). + + >>> df = pd.DataFrame({"A": [1, 2], "B": [0, 2], "C": [0, 0]}) + >>> df + A B C + 0 1 0 0 + 1 2 2 0 + + >>> df.any() + A True + B True + C False + dtype: bool + + Aggregating over the columns. + + >>> df = pd.DataFrame({"A": [True, False], "B": [1, 2]}) + >>> df + A B + 0 True 1 + 1 False 2 + + >>> df.any(axis="columns") + 0 True + 1 True + dtype: bool + + >>> df = pd.DataFrame({"A": [True, False], "B": [1, 0]}) + >>> df + A B + 0 True 1 + 1 False 0 + + >>> df.any(axis="columns") + 0 True + 1 False + dtype: bool + + Aggregating over the entire DataFrame with ``axis=None``. + + >>> df.any(axis=None) + True + + `any` for an empty DataFrame is an empty Series. + + >>> pd.DataFrame([]).any() + Series([], dtype: bool) + """ + nv.validate_logical_func((), kwargs, fname="any") + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + return self._reduce( + nanops.nanany, + name="any", + axis=axis, + numeric_only=bool_only, + skipna=skipna, + filter_type="bool", + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="all") + def all( + self, + axis: Axis = 0, + bool_only: bool = False, + skipna: bool = True, + **kwargs, + ) -> bool: + """ + Return whether all elements are True, potentially over an axis. + + Returns True unless there at least one element within a series or + along a Dataframe axis that is False or equivalent (e.g. zero or + empty). + + Parameters + ---------- + axis : {0 or 'index', 1 or 'columns', None}, default 0 + Indicate which axis or axes should be reduced. For `Series` this parameter + is unused and defaults to 0. + + * 0 / 'index' : reduce the index, return a Series whose index is the + original column labels. + * 1 / 'columns' : reduce the columns, return a Series whose index is the + original index. + * None : reduce all axes, return a scalar. + + bool_only : bool, default False + Include only boolean columns. Not implemented for Series. + skipna : bool, default True + Exclude NA/null values. If the entire row/column is NA and skipna is + True, then the result will be True, as for an empty row/column. + If skipna is False, then NA are treated as True, because these are not + equal to zero. + **kwargs : any, default None + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series or scalar + If axis=None, then a scalar boolean is returned. + Otherwise a Series is returned with index matching the index argument. + + See Also + -------- + Series.all : Return True if all elements are True. + DataFrame.any : Return True if one (or more) elements are True. + + Examples + -------- + **Series** + + >>> pd.Series([True, True]).all() + True + >>> pd.Series([True, False]).all() + False + >>> pd.Series([], dtype="float64").all() + True + >>> pd.Series([np.nan]).all() + True + >>> pd.Series([np.nan]).all(skipna=False) + True + + **DataFrames** + + Create a DataFrame from a dictionary. + + >>> df = pd.DataFrame({"col1": [True, True], "col2": [True, False]}) + >>> df + col1 col2 + 0 True True + 1 True False + + Default behaviour checks if values in each column all return True. + + >>> df.all() + col1 True + col2 False + dtype: bool + + Specify ``axis='columns'`` to check if values in each row all return True. + + >>> df.all(axis="columns") + 0 True + 1 False + dtype: bool + + Or ``axis=None`` for whether every value is True. + + >>> df.all(axis=None) + False + """ + nv.validate_logical_func((), kwargs, fname="all") + validate_bool_kwarg(skipna, "skipna", none_allowed=False) + return self._reduce( + nanops.nanall, + name="all", + axis=axis, + numeric_only=bool_only, + skipna=skipna, + filter_type="bool", + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="min") + def min( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ): + """ + Return the minimum of the values over the requested axis. + + If you want the *index* of the minimum, use ``idxmin``. + This is the equivalent of the ``numpy.ndarray`` method ``argmin``. + + Parameters + ---------- + axis : {index (0)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + scalar or Series (if level specified) + The minimum of the values in the Series. + + See Also + -------- + numpy.min : Equivalent numpy function for arrays. + Series.min : Return the minimum. + Series.max : Return the maximum. + Series.idxmin : Return the index of the minimum. + Series.idxmax : Return the index of the maximum. + DataFrame.min : Return the minimum over the requested axis. + DataFrame.max : Return the maximum over the requested axis. + DataFrame.idxmin : Return the index of the minimum over the requested axis. + DataFrame.idxmax : Return the index of the maximum over the requested axis. + + Examples + -------- + >>> idx = pd.MultiIndex.from_arrays( + ... [["warm", "warm", "cold", "cold"], ["dog", "falcon", "fish", "spider"]], + ... names=["blooded", "animal"], + ... ) + >>> s = pd.Series([4, 2, 0, 8], name="legs", index=idx) + >>> s + blooded animal + warm dog 4 + falcon 2 + cold fish 0 + spider 8 + Name: legs, dtype: int64 + + >>> s.min() + 0 + """ + return NDFrame.min( + self, axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="max") + def max( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ): + """ + Return the maximum of the values over the requested axis. + + If you want the *index* of the maximum, use ``idxmax``. + This is the equivalent of the ``numpy.ndarray`` method ``argmax``. + + Parameters + ---------- + axis : {index (0)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + scalar or Series (if level specified) + The maximum of the values in the Series. + + See Also + -------- + numpy.max : Equivalent numpy function for arrays. + Series.min : Return the minimum. + Series.max : Return the maximum. + Series.idxmin : Return the index of the minimum. + Series.idxmax : Return the index of the maximum. + DataFrame.min : Return the minimum over the requested axis. + DataFrame.max : Return the maximum over the requested axis. + DataFrame.idxmin : Return the index of the minimum over the requested axis. + DataFrame.idxmax : Return the index of the maximum over the requested axis. + + Examples + -------- + >>> idx = pd.MultiIndex.from_arrays( + ... [["warm", "warm", "cold", "cold"], ["dog", "falcon", "fish", "spider"]], + ... names=["blooded", "animal"], + ... ) + >>> s = pd.Series([4, 2, 0, 8], name="legs", index=idx) + >>> s + blooded animal + warm dog 4 + falcon 2 + cold fish 0 + spider 8 + Name: legs, dtype: int64 + + >>> s.max() + 8 + """ + return NDFrame.max( + self, axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="sum") + def sum( + self, + axis: Axis | None = None, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs, + ): + """ + Return the sum of the values over the requested axis. + + This is equivalent to the method ``numpy.sum``. + + Parameters + ---------- + axis : {index (0)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.sum with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer than + ``min_count`` non-NA values are present the result will be NA. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + scalar or Series (if level specified) + Sum of the values for the requested axis. + + See Also + -------- + numpy.sum : Equivalent numpy function for computing sum. + Series.mean : Mean of the values. + Series.median : Median of the values. + Series.std : Standard deviation of the values. + Series.var : Variance of the values. + Series.min : Minimum value. + Series.max : Maximum value. + + Examples + -------- + >>> idx = pd.MultiIndex.from_arrays( + ... [["warm", "warm", "cold", "cold"], ["dog", "falcon", "fish", "spider"]], + ... names=["blooded", "animal"], + ... ) + >>> s = pd.Series([4, 2, 0, 8], name="legs", index=idx) + >>> s + blooded animal + warm dog 4 + falcon 2 + cold fish 0 + spider 8 + Name: legs, dtype: int64 + + >>> s.sum() + 14 + + By default, the sum of an empty or all-NA Series is ``0``. + + >>> pd.Series([], dtype="float64").sum() # min_count=0 is the default + 0.0 + + This can be controlled with the ``min_count`` parameter. For example, if + you'd like the sum of an empty series to be NaN, pass ``min_count=1``. + + >>> pd.Series([], dtype="float64").sum(min_count=1) + nan + + Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and + empty series identically. + + >>> pd.Series([np.nan]).sum() + 0.0 + + >>> pd.Series([np.nan]).sum(min_count=1) + nan + """ + return NDFrame.sum( + self, + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + min_count=min_count, + **kwargs, + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="prod") + def prod( + self, + axis: Axis | None = None, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs, + ): + """ + Return the product of the values over the requested axis. + + By default, missing values are skipped. To include them in the calculation, + set ``skipna`` parameter to False. + + Parameters + ---------- + axis : {index (0)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + The behavior of DataFrame.prod with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + .. versionadded:: 2.0.0 + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + min_count : int, default 0 + The required number of valid values to perform the operation. If fewer than + ``min_count`` non-NA values are present the result will be NA. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + scalar + Value containing the calculation referenced in the description. + + See Also + -------- + Series.sum : Return the sum. + Series.min : Return the minimum. + Series.max : Return the maximum. + Series.idxmin : Return the index of the minimum. + Series.idxmax : Return the index of the maximum. + + DataFrame.sum : Return the sum over the requested axis. + DataFrame.min : Return the minimum over the requested axis. + DataFrame.max : Return the maximum over the requested axis. + DataFrame.idxmin : Return the index of the minimum over the requested axis. + DataFrame.idxmax : Return the index of the maximum over the requested axis. + + Examples + -------- + By default, the product of an empty or all-NA Series is ``1`` + + >>> pd.Series([], dtype="float64").prod() + 1.0 + + This can be controlled with the ``min_count`` parameter + + >>> pd.Series([], dtype="float64").prod(min_count=1) + nan + + Thanks to the ``skipna`` parameter, ``min_count`` handles all-NA and + empty series identically. + + >>> pd.Series([np.nan]).prod() + 1.0 + >>> pd.Series([np.nan]).prod(min_count=1) + nan + """ + return NDFrame.prod( + self, + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + min_count=min_count, + **kwargs, + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="mean") + def mean( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Any: + """ + Return the mean of the values over the requested axis. + + Parameters + ---------- + axis : {index (0)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + scalar or Series (if level specified) + Mean of the values for the requested axis. + + See Also + -------- + numpy.median : Equivalent numpy function for computing median. + Series.sum : Sum of the values. + Series.median : Median of the values. + Series.std : Standard deviation of the values. + Series.var : Variance of the values. + Series.min : Minimum value. + Series.max : Maximum value. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.mean() + 2.0 + """ + return NDFrame.mean( + self, axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + + @deprecate_nonkeyword_arguments( + Pandas4Warning, allowed_args=["self"], name="median" + ) + def median( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ) -> Any: + """ + Return the median of the values over the requested axis. + + Parameters + ---------- + axis : {index (0)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + scalar or Series (if level specified) + Median of the values for the requested axis. + + See Also + -------- + numpy.median : Equivalent numpy function for computing median. + Series.sum : Sum of the values. + Series.median : Median of the values. + Series.std : Standard deviation of the values. + Series.var : Variance of the values. + Series.min : Minimum value. + Series.max : Maximum value. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.median() + 2.0 + + With a DataFrame + + >>> df = pd.DataFrame({"a": [1, 2], "b": [2, 3]}, index=["tiger", "zebra"]) + >>> df + a b + tiger 1 2 + zebra 2 3 + >>> df.median() + a 1.5 + b 2.5 + dtype: float64 + + Using axis=1 + + >>> df.median(axis=1) + tiger 1.5 + zebra 2.5 + dtype: float64 + + In this case, `numeric_only` should be set to `True` + to avoid getting an error. + + >>> df = pd.DataFrame({"a": [1, 2], "b": ["T", "Z"]}, index=["tiger", "zebra"]) + >>> df.median(numeric_only=True) + a 1.5 + dtype: float64 + """ + return NDFrame.median( + self, axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="sem") + def sem( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ): + """ + Return unbiased standard error of the mean over requested axis. + + Normalized by N-1 by default. This can be changed using the ddof argument + + Parameters + ---------- + axis : {index (0)} + This parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + **kwargs : + Additional keywords have no effect but might be accepted + for compatibility with NumPy. + + Returns + ------- + scalar or Series (if level specified) + Unbiased standard error of the mean over requested axis. + + See Also + -------- + scipy.stats.sem : Compute standard error of the mean. + Series.std : Return sample standard deviation over requested axis. + Series.var : Return unbiased variance over requested axis. + Series.mean : Return the mean of the values over the requested axis. + Series.median : Return the median of the values over the requested axis. + Series.mode : Return the mode(s) of the Series. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> round(s.sem(), 6) + 0.57735 + """ + return NDFrame.sem( + self, + axis=axis, + skipna=skipna, + ddof=ddof, + numeric_only=numeric_only, + **kwargs, + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="var") + def var( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ): + """ + Return unbiased variance over requested axis. + + Normalized by N-1 by default. This can be changed using the ddof argument. + + Parameters + ---------- + axis : {index (0)} + For `Series` this parameter is unused and defaults to 0. + + .. warning:: + + The behavior of DataFrame.var with ``axis=None`` is deprecated, + in a future version this will reduce over both axes and return a scalar + To retain the old behavior, pass axis=0 (or do not pass axis). + + skipna : bool, default True + Exclude NA/null values. If an entire row/column is NA, the result + will be NA. + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + numeric_only : bool, default False + Include only float, int, boolean columns. Not implemented for Series. + **kwargs : + Additional keywords passed. + + Returns + ------- + scalar or Series (if level specified) + Unbiased variance over requested axis. + + See Also + -------- + numpy.var : Equivalent function in NumPy. + Series.std : Returns the standard deviation of the Series. + DataFrame.var : Returns the variance of the DataFrame. + DataFrame.std : Return standard deviation of the values over + the requested axis. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "person_id": [0, 1, 2, 3], + ... "age": [21, 25, 62, 43], + ... "height": [1.61, 1.87, 1.49, 2.01], + ... } + ... ).set_index("person_id") + >>> df + age height + person_id + 0 21 1.61 + 1 25 1.87 + 2 62 1.49 + 3 43 2.01 + + >>> df.var() + age 352.916667 + height 0.056367 + dtype: float64 + + Alternatively, ``ddof=0`` can be set to normalize by N instead of N-1: + + >>> df.var(ddof=0) + age 264.687500 + height 0.042275 + dtype: float64 + """ + return NDFrame.var( + self, + axis=axis, + skipna=skipna, + ddof=ddof, + numeric_only=numeric_only, + **kwargs, + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="std") + def std( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, + ): + """ + Return sample standard deviation. + + Normalized by N-1 by default. This can be changed using the ddof argument. + + Parameters + ---------- + axis : {index (0)} + This parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If Series is NA, the result + will be NA. + ddof : int, default 1 + Delta Degrees of Freedom. The divisor used in calculations is N - ddof, + where N represents the number of elements. + numeric_only : bool, default False + Not implemented for Series. + **kwargs : + Additional keywords have no effect but might be accepted + for compatibility with NumPy. + + Returns + ------- + scalar + Standard deviation over all values in the Series. + + See Also + -------- + numpy.std : Compute the standard deviation along the specified axis. + Series.var : Return unbiased variance over requested axis. + Series.sem : Return unbiased standard error of the mean over requested axis. + Series.mean : Return the mean of the values over the requested axis. + Series.median : Return the median of the values over the requested axis. + Series.mode : Return the mode(s) of the Series. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.std() + 1.0 + + Alternatively, ``ddof=0`` can be set to normalize by $N$ instead of $N-1$: + + >>> s.std(ddof=0) + 0.816496580927726 + """ + return NDFrame.std( + self, + axis=axis, + skipna=skipna, + ddof=ddof, + numeric_only=numeric_only, + **kwargs, + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="skew") + def skew( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ): + """ + Return unbiased skew over requested axis. + + Normalized by N-1. + + Parameters + ---------- + axis : {index (0)} + This parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Unused. + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + scalar + Unbiased skew of the Series. + + See Also + -------- + + Series.var : Return unbiased variance over requested axis. + Series.std : Return unbiased standard deviation over requested axis. + + Examples + -------- + >>> s = pd.Series([1, 2, 3]) + >>> s.skew() + 0.0 + """ + return NDFrame.skew( + self, axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + + @deprecate_nonkeyword_arguments(Pandas4Warning, allowed_args=["self"], name="kurt") + def kurt( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, + ): + """ + Return unbiased kurtosis over requested axis. + + Kurtosis obtained using Fisher's definition of + kurtosis (kurtosis of normal == 0.0). Normalized by N-1. + + Parameters + ---------- + axis : {index (0)} + Axis for the function to be applied on. + For `Series` this parameter is unused and defaults to 0. + + For DataFrames, specifying ``axis=None`` will apply the aggregation + across both axes. + + .. versionadded:: 2.0.0 + + skipna : bool, default True + Exclude NA/null values when computing the result. + numeric_only : bool, default False + Include only float, int, boolean columns. + + **kwargs + Additional keyword arguments to be passed to the function. + + Returns + ------- + scalar + Unbiased kurtosis. + + See Also + -------- + Series.skew : Return unbiased skew over requested axis. + Series.var : Return unbiased variance over requested axis. + Series.std : Return unbiased standard deviation over requested axis. + + Examples + -------- + >>> s = pd.Series([1, 2, 2, 3], index=["cat", "dog", "dog", "mouse"]) + >>> s + cat 1 + dog 2 + dog 2 + mouse 3 + dtype: int64 + >>> s.kurt() + 1.5 + """ + return NDFrame.kurt( + self, axis=axis, skipna=skipna, numeric_only=numeric_only, **kwargs + ) + + kurtosis = kurt + product = prod + + def cummin(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Self: + """ + Return cumulative minimum over a Series. + + Returns a Series of the same size containing the cumulative + minimum. + + Parameters + ---------- + axis : {0 or 'index'}, default 0 + This parameter is unused and defaults to 0. + skipna : bool, default True + If the entire series is NA, the result will be NA. + *args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series + Return cumulative minimum of the Series. + + See Also + -------- + core.window.expanding.Expanding.min : Similar functionality + but ignores ``NaN`` values. + Series.min : Return the minimum value of the Series. + Series.cummax : Return cumulative maximum. + Series.cumsum : Return cumulative sum. + Series.cumprod : Return cumulative product. + + Examples + -------- + >>> s = pd.Series([2, np.nan, 5, -1, 0]) + >>> s + 0 2.0 + 1 NaN + 2 5.0 + 3 -1.0 + 4 0.0 + dtype: float64 + + By default, NA values are ignored. + + >>> s.cummin() + 0 2.0 + 1 NaN + 2 2.0 + 3 -1.0 + 4 -1.0 + dtype: float64 + + To include NA values in the operation, use ``skipna=False`` + + >>> s.cummin(skipna=False) + 0 2.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + """ + return NDFrame.cummin(self, axis, skipna, *args, **kwargs) + + def cummax(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Self: + """ + Return cumulative maximum over a Series. + + Returns a Series of the same size containing the cumulative + maximum. + + Parameters + ---------- + axis : {0 or 'index'}, default 0 + This parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If the series is NA, the result is NA. + *args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series + Return cumulative maximum of Series. + + See Also + -------- + core.window.expanding.Expanding.max : Similar functionality + but ignores ``NaN`` values. + Series.max : Return the maximum over a Series. + Series.cummin : Return cumulative minimum. + Series.cumsum : Return cumulative sum. + Series.cumprod : Return cumulative product. + + Examples + -------- + >>> s = pd.Series([2, np.nan, 5, -1, 0]) + >>> s + 0 2.0 + 1 NaN + 2 5.0 + 3 -1.0 + 4 0.0 + dtype: float64 + + By default, NA values are ignored. + + >>> s.cummax() + 0 2.0 + 1 NaN + 2 5.0 + 3 5.0 + 4 5.0 + dtype: float64 + + To include NA values in the operation, use ``skipna=False`` + + >>> s.cummax(skipna=False) + 0 2.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + """ + return NDFrame.cummax(self, axis, skipna, *args, **kwargs) + + def cumsum(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Self: + """ + Return cumulative sum over a Series. + + Returns a Series of the same size containing the cumulative sum. + + Parameters + ---------- + axis : {0 or 'index'}, default 0 + This parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If entire series is NA, the result will be NA. + *args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series + Return cumulative sum of Series. + + See Also + -------- + core.window.expanding.Expanding.sum : Similar functionality + but ignores ``NaN`` values. + Series.sum : Return the sum over Series. + Series.cummax : Return cumulative maximum. + Series.cummin : Return cumulative minimum. + Series.cumprod : Return cumulative product. + + Examples + -------- + >>> s = pd.Series([2, np.nan, 5, -1, 0]) + >>> s + 0 2.0 + 1 NaN + 2 5.0 + 3 -1.0 + 4 0.0 + dtype: float64 + + By default, NA values are ignored. + + >>> s.cumsum() + 0 2.0 + 1 NaN + 2 7.0 + 3 6.0 + 4 6.0 + dtype: float64 + + To include NA values in the operation, use ``skipna=False`` + + >>> s.cumsum(skipna=False) + 0 2.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + """ + return NDFrame.cumsum(self, axis, skipna, *args, **kwargs) + + def cumprod(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Self: + """ + Return cumulative product over a Series. + + Returns a Series of the same size containing the cumulative + product. + + Parameters + ---------- + axis : {0 or 'index'}, default 0 + This parameter is unused and defaults to 0. + skipna : bool, default True + Exclude NA/null values. If entire Series is NA, the result will be NA. + *args, **kwargs + Additional keywords have no effect but might be accepted for + compatibility with NumPy. + + Returns + ------- + Series + Return cumulative product of Series. + + See Also + -------- + core.window.expanding.Expanding.prod : Similar functionality + but ignores ``NaN`` values. + Series.prod : Return the product over Series. + Series.cummax : Return cumulative maximum. + Series.cummin : Return cumulative minimum. + Series.cumsum : Return cumulative sum. + + Examples + -------- + >>> s = pd.Series([2, np.nan, 5, -1, 0]) + >>> s + 0 2.0 + 1 NaN + 2 5.0 + 3 -1.0 + 4 0.0 + dtype: float64 + + By default, NA values are ignored. + + >>> s.cumprod() + 0 2.0 + 1 NaN + 2 10.0 + 3 -10.0 + 4 -0.0 + dtype: float64 + + To include NA values in the operation, use ``skipna=False`` + + >>> s.cumprod(skipna=False) + 0 2.0 + 1 NaN + 2 NaN + 3 NaN + 4 NaN + dtype: float64 + """ + return NDFrame.cumprod(self, axis, skipna, *args, **kwargs) diff --git a/pandas/core/shared_docs.py b/pandas/core/shared_docs.py new file mode 100644 index 0000000000000000000000000000000000000000..3f91443b8cda655d07f9cf2dc1aedb8bcce873da --- /dev/null +++ b/pandas/core/shared_docs.py @@ -0,0 +1,639 @@ +from __future__ import annotations + +_shared_docs: dict[str, str] = {} + +_shared_docs["aggregate"] = """ +Aggregate using one or more operations over the specified axis. + +Parameters +---------- +func : function, str, list or dict + Function to use for aggregating the data. If a function, must either + work when passed a {klass} or when passed to {klass}.apply. + + Accepted combinations are: + + - function + - string function name + - list of functions and/or function names, e.g. ``[np.sum, 'mean']`` + - dict of axis labels -> functions, function names or list of such. +{axis} +*args + Positional arguments to pass to `func`. +**kwargs + Keyword arguments to pass to `func`. + +Returns +------- +scalar, Series or DataFrame + + The return can be: + + * scalar : when Series.agg is called with single function + * Series : when DataFrame.agg is called with a single function + * DataFrame : when DataFrame.agg is called with several functions +{see_also} +Notes +----- +The aggregation operations are always performed over an axis, either the +index (default) or the column axis. This behavior is different from +`numpy` aggregation functions (`mean`, `median`, `prod`, `sum`, `std`, +`var`), where the default is to compute the aggregation of the flattened +array, e.g., ``numpy.mean(arr_2d)`` as opposed to +``numpy.mean(arr_2d, axis=0)``. + +`agg` is an alias for `aggregate`. Use the alias. + +Functions that mutate the passed object can produce unexpected +behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` +for more details. + +A passed user-defined-function will be passed a Series for evaluation. + +If ``func`` defines an index relabeling, ``axis`` must be ``0`` or ``index``. +{examples}""" + +_shared_docs["compare"] = """ +Compare to another {klass} and show the differences. + +Parameters +---------- +other : {klass} + Object to compare with. + +align_axis : {{0 or 'index', 1 or 'columns'}}, default 1 + Determine which axis to align the comparison on. + + * 0, or 'index' : Resulting differences are stacked vertically + with rows drawn alternately from self and other. + * 1, or 'columns' : Resulting differences are aligned horizontally + with columns drawn alternately from self and other. + +keep_shape : bool, default False + If true, all rows and columns are kept. + Otherwise, only the ones with different values are kept. + +keep_equal : bool, default False + If true, the result keeps values that are equal. + Otherwise, equal values are shown as NaNs. + +result_names : tuple, default ('self', 'other') + Set the dataframes names in the comparison. +""" + +_shared_docs["groupby"] = """ +Group %(klass)s using a mapper or by a Series of columns. + +A groupby operation involves some combination of splitting the +object, applying a function, and combining the results. This can be +used to group large amounts of data and compute operations on these +groups. + +Parameters +---------- +by : mapping, function, label, pd.Grouper or list of such + Used to determine the groups for the groupby. + If ``by`` is a function, it's called on each value of the object's + index. If a dict or Series is passed, the Series or dict VALUES + will be used to determine the groups (the Series' values are first + aligned; see ``.align()`` method). If a list or ndarray of length + equal to the selected axis is passed (see the `groupby user guide + `_), + the values are used as-is to determine the groups. A label or list + of labels may be passed to group by the columns in ``self``. + Notice that a tuple is interpreted as a (single) key. +level : int, level name, or sequence of such, default None + If the axis is a MultiIndex (hierarchical), group by a particular + level or levels. Do not specify both ``by`` and ``level``. +as_index : bool, default True + Return object with group labels as the + index. Only relevant for DataFrame input. as_index=False is + effectively "SQL-style" grouped output. This argument has no effect + on filtrations (see the `filtrations in the user guide + `_), + such as ``head()``, ``tail()``, ``nth()`` and in transformations + (see the `transformations in the user guide + `_). +sort : bool, default True + Sort group keys. Get better performance by turning this off. + Note this does not influence the order of observations within each + group. Groupby preserves the order of rows within each group. If False, + the groups will appear in the same order as they did in the original DataFrame. + This argument has no effect on filtrations (see the `filtrations in the user guide + `_), + such as ``head()``, ``tail()``, ``nth()`` and in transformations + (see the `transformations in the user guide + `_). + + .. versionchanged:: 2.0.0 + + Specifying ``sort=False`` with an ordered categorical grouper will no + longer sort the values. + +group_keys : bool, default True + When calling apply and the ``by`` argument produces a like-indexed + (i.e. :ref:`a transform `) result, add group keys to + index to identify pieces. By default group keys are not included + when the result's index (and column) labels match the inputs, and + are included otherwise. + + .. versionchanged:: 2.0.0 + + ``group_keys`` now defaults to ``True``. + +observed : bool, default True + This only applies if any of the groupers are Categoricals. + If True: only show observed values for categorical groupers. + If False: show all values for categorical groupers. + + .. versionchanged:: 3.0.0 + + The default value is now ``True``. + +dropna : bool, default True + If True, and if group keys contain NA values, NA values together + with row/column will be dropped. + If False, NA values will also be treated as the key in groups. + +Returns +------- +pandas.api.typing.%(klass)sGroupBy + Returns a groupby object that contains information about the groups. + +See Also +-------- +resample : Convenience method for frequency conversion and resampling + of time series. + +Notes +----- +See the `user guide +`__ for more +detailed usage and examples, including splitting an object into groups, +iterating through groups, selecting a group, aggregation, and more. + +The implementation of groupby is hash-based, meaning in particular that +objects that compare as equal will be considered to be in the same group. +An exception to this is that pandas has special handling of NA values: +any NA values will be collapsed to a single group, regardless of how +they compare. See the user guide linked above for more details. +""" + +_shared_docs["transform"] = """ +Call ``func`` on self producing a {klass} with the same axis shape as self. + +Parameters +---------- +func : function, str, list-like or dict-like + Function to use for transforming the data. If a function, must either + work when passed a {klass} or when passed to {klass}.apply. If func + is both list-like and dict-like, dict-like behavior takes precedence. + + Accepted combinations are: + + - function + - string function name + - list-like of functions and/or function names, e.g. ``[np.exp, 'sqrt']`` + - dict-like of axis labels -> functions, function names or list-like of such. +{axis} +*args + Positional arguments to pass to `func`. +**kwargs + Keyword arguments to pass to `func`. + +Returns +------- +{klass} + A {klass} that must have the same length as self. + +Raises +------ +ValueError : If the returned {klass} has a different length than self. + +See Also +-------- +{klass}.agg : Only perform aggregating type operations. +{klass}.apply : Invoke function on a {klass}. + +Notes +----- +Functions that mutate the passed object can produce unexpected +behavior or errors and are not supported. See :ref:`gotchas.udf-mutation` +for more details. + +Examples +-------- +>>> df = pd.DataFrame({{'A': range(3), 'B': range(1, 4)}}) +>>> df + A B +0 0 1 +1 1 2 +2 2 3 +>>> df.transform(lambda x: x + 1) + A B +0 1 2 +1 2 3 +2 3 4 + +Even though the resulting {klass} must have the same length as the +input {klass}, it is possible to provide several input functions: + +>>> s = pd.Series(range(3)) +>>> s +0 0 +1 1 +2 2 +dtype: int64 +>>> s.transform([np.sqrt, np.exp]) + sqrt exp +0 0.000000 1.000000 +1 1.000000 2.718282 +2 1.414214 7.389056 + +You can call transform on a GroupBy object: + +>>> df = pd.DataFrame({{ +... "Date": [ +... "2015-05-08", "2015-05-07", "2015-05-06", "2015-05-05", +... "2015-05-08", "2015-05-07", "2015-05-06", "2015-05-05"], +... "Data": [5, 8, 6, 1, 50, 100, 60, 120], +... }}) +>>> df + Date Data +0 2015-05-08 5 +1 2015-05-07 8 +2 2015-05-06 6 +3 2015-05-05 1 +4 2015-05-08 50 +5 2015-05-07 100 +6 2015-05-06 60 +7 2015-05-05 120 +>>> df.groupby('Date')['Data'].transform('sum') +0 55 +1 108 +2 66 +3 121 +4 55 +5 108 +6 66 +7 121 +Name: Data, dtype: int64 + +>>> df = pd.DataFrame({{ +... "c": [1, 1, 1, 2, 2, 2, 2], +... "type": ["m", "n", "o", "m", "m", "n", "n"] +... }}) +>>> df + c type +0 1 m +1 1 n +2 1 o +3 2 m +4 2 m +5 2 n +6 2 n +>>> df['size'] = df.groupby('c')['type'].transform(len) +>>> df + c type size +0 1 m 3 +1 1 n 3 +2 1 o 3 +3 2 m 4 +4 2 m 4 +5 2 n 4 +6 2 n 4 +""" + +_shared_docs["storage_options"] = """storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_.""" + +_shared_docs["compression_options"] = """compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and '%s' is + path-like, then detect compression from the following extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set + to one of {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} and + other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and to create + a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + + .. versionadded:: 1.5.0 + Added support for `.tar` files.""" + +_shared_docs["decompression_options"] = """compression : str or dict, default 'infer' + For on-the-fly decompression of on-disk data. If 'infer' and '%s' is + path-like, then detect compression from the following extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + If using 'zip' or 'tar', the ZIP file must contain only one data file to be read in. + Set to ``None`` for no decompression. + Can also be a dict with key ``'method'`` set + to one of {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} and + other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdDecompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for Zstandard decompression using a + custom compression dictionary: + ``compression={'method': 'zstd', 'dict_data': my_compression_dict}``. + + .. versionadded:: 1.5.0 + Added support for `.tar` files.""" + +_shared_docs["replace"] = """ + Replace values given in `to_replace` with `value`. + + Values of the {klass} are replaced with other values dynamically. + This differs from updating with ``.loc`` or ``.iloc``, which require + you to specify a location to update with some value. + + Parameters + ---------- + to_replace : str, regex, list, dict, Series, int, float, or None + How to find the values that will be replaced. + + * numeric, str or regex: + + - numeric: numeric values equal to `to_replace` will be + replaced with `value` + - str: string exactly matching `to_replace` will be replaced + with `value` + - regex: regexes matching `to_replace` will be replaced with + `value` + + * list of str, regex, or numeric: + + - First, if `to_replace` and `value` are both lists, they + **must** be the same length. + - Second, if ``regex=True`` then all of the strings in **both** + lists will be interpreted as regexes otherwise they will match + directly. This doesn't matter much for `value` since there + are only a few possible substitution regexes you can use. + - str, regex and numeric rules apply as above. + + * dict: + + - Dicts can be used to specify different replacement values + for different existing values. For example, + ``{{'a': 'b', 'y': 'z'}}`` replaces the value 'a' with 'b' and + 'y' with 'z'. To use a dict in this way, the optional `value` + parameter should not be given. + - For a DataFrame a dict can specify that different values + should be replaced in different columns. For example, + ``{{'a': 1, 'b': 'z'}}`` looks for the value 1 in column 'a' + and the value 'z' in column 'b' and replaces these values + with whatever is specified in `value`. The `value` parameter + should not be ``None`` in this case. You can treat this as a + special case of passing two lists except that you are + specifying the column to search in. + - For a DataFrame nested dictionaries, e.g., + ``{{'a': {{'b': np.nan}}}}``, are read as follows: look in column + 'a' for the value 'b' and replace it with NaN. The optional `value` + parameter should not be specified to use a nested dict in this + way. You can nest regular expressions as well. Note that + column names (the top-level dictionary keys in a nested + dictionary) **cannot** be regular expressions. + + * None: + + - This means that the `regex` argument must be a string, + compiled regular expression, or list, dict, ndarray or + Series of such elements. If `value` is also ``None`` then + this **must** be a nested dictionary or Series. + + See the examples section for examples of each of these. + value : scalar, dict, list, str, regex, default None + Value to replace any values matching `to_replace` with. + For a DataFrame a dict of values can be used to specify which + value to use for each column (columns not in the dict will not be + filled). Regular expressions, strings and lists or dicts of such + objects are also allowed. + {inplace} + regex : bool or same types as `to_replace`, default False + Whether to interpret `to_replace` and/or `value` as regular + expressions. Alternatively, this could be a regular expression or a + list, dict, or array of regular expressions in which case + `to_replace` must be ``None``. + + Returns + ------- + {klass} + Object after replacement. + + Raises + ------ + AssertionError + * If `regex` is not a ``bool`` and `to_replace` is not + ``None``. + + TypeError + * If `to_replace` is not a scalar, array-like, ``dict``, or ``None`` + * If `to_replace` is a ``dict`` and `value` is not a ``list``, + ``dict``, ``ndarray``, or ``Series`` + * If `to_replace` is ``None`` and `regex` is not compilable + into a regular expression or is a list, dict, ndarray, or + Series. + * When replacing multiple ``bool`` or ``datetime64`` objects and + the arguments to `to_replace` does not match the type of the + value being replaced + + ValueError + * If a ``list`` or an ``ndarray`` is passed to `to_replace` and + `value` but they are not the same length. + + See Also + -------- + Series.fillna : Fill NA values. + DataFrame.fillna : Fill NA values. + Series.where : Replace values based on boolean condition. + DataFrame.where : Replace values based on boolean condition. + DataFrame.map: Apply a function to a Dataframe elementwise. + Series.map: Map values of Series according to an input mapping or function. + Series.str.replace : Simple string replacement. + + Notes + ----- + * Regex substitution is performed under the hood with ``re.sub``. The + rules for substitution for ``re.sub`` are the same. + * Regular expressions will only substitute on strings, meaning you + cannot provide, for example, a regular expression matching floating + point numbers and expect the columns in your frame that have a + numeric dtype to be matched. However, if those floating point + numbers *are* strings, then you can do this. + * This method has *a lot* of options. You are encouraged to experiment + and play with this method to gain intuition about how it works. + * When dict is used as the `to_replace` value, it is like + key(s) in the dict are the to_replace part and + value(s) in the dict are the value parameter. + + Examples + -------- + + **Scalar `to_replace` and `value`** + + >>> s = pd.Series([1, 2, 3, 4, 5]) + >>> s.replace(1, 5) + 0 5 + 1 2 + 2 3 + 3 4 + 4 5 + dtype: int64 + + >>> df = pd.DataFrame({{'A': [0, 1, 2, 3, 4], + ... 'B': [5, 6, 7, 8, 9], + ... 'C': ['a', 'b', 'c', 'd', 'e']}}) + >>> df.replace(0, 5) + A B C + 0 5 5 a + 1 1 6 b + 2 2 7 c + 3 3 8 d + 4 4 9 e + + **List-like `to_replace`** + + >>> df.replace([0, 1, 2, 3], 4) + A B C + 0 4 5 a + 1 4 6 b + 2 4 7 c + 3 4 8 d + 4 4 9 e + + >>> df.replace([0, 1, 2, 3], [4, 3, 2, 1]) + A B C + 0 4 5 a + 1 3 6 b + 2 2 7 c + 3 1 8 d + 4 4 9 e + + **dict-like `to_replace`** + + >>> df.replace({{0: 10, 1: 100}}) + A B C + 0 10 5 a + 1 100 6 b + 2 2 7 c + 3 3 8 d + 4 4 9 e + + >>> df.replace({{'A': 0, 'B': 5}}, 100) + A B C + 0 100 100 a + 1 1 6 b + 2 2 7 c + 3 3 8 d + 4 4 9 e + + >>> df.replace({{'A': {{0: 100, 4: 400}}}}) + A B C + 0 100 5 a + 1 1 6 b + 2 2 7 c + 3 3 8 d + 4 400 9 e + + **Regular expression `to_replace`** + + >>> df = pd.DataFrame({{'A': ['bat', 'foo', 'bait'], + ... 'B': ['abc', 'bar', 'xyz']}}) + >>> df.replace(to_replace=r'^ba.$', value='new', regex=True) + A B + 0 new abc + 1 foo new + 2 bait xyz + + >>> df.replace({{'A': r'^ba.$'}}, {{'A': 'new'}}, regex=True) + A B + 0 new abc + 1 foo bar + 2 bait xyz + + >>> df.replace(regex=r'^ba.$', value='new') + A B + 0 new abc + 1 foo new + 2 bait xyz + + >>> df.replace(regex={{r'^ba.$': 'new', 'foo': 'xyz'}}) + A B + 0 new abc + 1 xyz new + 2 bait xyz + + >>> df.replace(regex=[r'^ba.$', 'foo'], value='new') + A B + 0 new abc + 1 new new + 2 bait xyz + + Compare the behavior of ``s.replace({{'a': None}})`` and + ``s.replace('a', None)`` to understand the peculiarities + of the `to_replace` parameter: + + >>> s = pd.Series([10, 'a', 'a', 'b', 'a']) + + When one uses a dict as the `to_replace` value, it is like the + value(s) in the dict are equal to the `value` parameter. + ``s.replace({{'a': None}})`` is equivalent to + ``s.replace(to_replace={{'a': None}}, value=None)``: + + >>> s.replace({{'a': None}}) + 0 10 + 1 None + 2 None + 3 b + 4 None + dtype: object + + If ``None`` is explicitly passed for ``value``, it will be respected: + + >>> s.replace('a', None) + 0 10 + 1 None + 2 None + 3 b + 4 None + dtype: object + + When ``regex=True``, ``value`` is not ``None`` and `to_replace` is a string, + the replacement will be applied in all columns of the DataFrame. + + >>> df = pd.DataFrame({{'A': [0, 1, 2, 3, 4], + ... 'B': ['a', 'b', 'c', 'd', 'e'], + ... 'C': ['f', 'g', 'h', 'i', 'j']}}) + + >>> df.replace(to_replace='^[a-g]', value='e', regex=True) + A B C + 0 0 e e + 1 1 e e + 2 2 e h + 3 3 e i + 4 4 e j + + If ``value`` is not ``None`` and `to_replace` is a dictionary, the dictionary + keys will be the DataFrame columns that the replacement will be applied. + + >>> df.replace(to_replace={{'B': '^[a-c]', 'C': '^[h-j]'}}, value='e', regex=True) + A B C + 0 0 e f + 1 1 e g + 2 2 e e + 3 3 d e + 4 4 e e +""" diff --git a/pandas/core/sorting.py b/pandas/core/sorting.py new file mode 100644 index 0000000000000000000000000000000000000000..fecc2ca9e2e0ba13e3d98a2c29c7183a9b0e5fa5 --- /dev/null +++ b/pandas/core/sorting.py @@ -0,0 +1,736 @@ +"""miscellaneous sorting / groupby utilities""" + +from __future__ import annotations + +import itertools +from typing import ( + TYPE_CHECKING, + cast, +) + +import numpy as np + +from pandas._libs import ( + algos, + hashtable, + lib, +) +from pandas._libs.hashtable import unique_label_indices + +from pandas.core.dtypes.common import ( + ensure_int64, + ensure_platform_int, +) +from pandas.core.dtypes.generic import ( + ABCMultiIndex, + ABCRangeIndex, +) +from pandas.core.dtypes.missing import isna + +from pandas.core.construction import extract_array + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Hashable, + Sequence, + ) + + from pandas._typing import ( + ArrayLike, + AxisInt, + IndexKeyFunc, + Level, + NaPosition, + Shape, + SortKind, + npt, + ) + + from pandas import ( + MultiIndex, + Series, + ) + from pandas.core.arrays import ExtensionArray + from pandas.core.indexes.base import Index + + +def get_indexer_indexer( + target: Index, + level: Level | list[Level] | None, + ascending: list[bool] | bool, + kind: SortKind, + na_position: NaPosition, + sort_remaining: bool, + key: IndexKeyFunc, +) -> npt.NDArray[np.intp] | None: + """ + Helper method that return the indexer according to input parameters for + the sort_index method of DataFrame and Series. + + Parameters + ---------- + target : Index + level : int or level name or list of ints or list of level names + ascending : bool or list of bools, default True + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'} + na_position : {'first', 'last'} + sort_remaining : bool + key : callable, optional + + Returns + ------- + Optional[ndarray[intp]] + The indexer for the new index. + """ + + # error: Incompatible types in assignment (expression has type + # "Union[ExtensionArray, ndarray[Any, Any], Index, Series]", variable has + # type "Index") + target = ensure_key_mapped(target, key, levels=level) # type: ignore[assignment] + target = target._sort_levels_monotonic() + + if level is not None: + _, indexer = target.sortlevel( + level, + ascending=ascending, + sort_remaining=sort_remaining, + na_position=na_position, + ) + elif (np.all(ascending) and target.is_monotonic_increasing) or ( + not np.any(ascending) and target.is_monotonic_decreasing + ): + # Check monotonic-ness before sort an index (GH 11080) + return None + elif isinstance(target, ABCMultiIndex): + codes = [lev.codes for lev in target._get_codes_for_sorting()] + indexer = lexsort_indexer( + codes, orders=ascending, na_position=na_position, codes_given=True + ) + else: + # ascending can only be a Sequence for MultiIndex + indexer = nargsort( + target, + kind=kind, + ascending=cast(bool, ascending), + na_position=na_position, + ) + return indexer + + +def get_group_index( + labels, shape: Shape, sort: bool, xnull: bool +) -> npt.NDArray[np.int64]: + """ + For the particular label_list, gets the offsets into the hypothetical list + representing the totally ordered cartesian product of all possible label + combinations, *as long as* this space fits within int64 bounds; + otherwise, though group indices identify unique combinations of + labels, they cannot be deconstructed. + - If `sort`, rank of returned ids preserve lexical ranks of labels. + i.e. returned id's can be used to do lexical sort on labels; + - If `xnull` nulls (-1 labels) are passed through. + + Parameters + ---------- + labels : sequence of arrays + Integers identifying levels at each location + shape : tuple[int, ...] + Number of unique levels at each location + sort : bool + If the ranks of returned ids should match lexical ranks of labels + xnull : bool + If true nulls are excluded. i.e. -1 values in the labels are + passed through. + + Returns + ------- + An array of type int64 where two elements are equal if their corresponding + labels are equal at all location. + + Notes + ----- + The length of `labels` and `shape` must be identical. + """ + + def _int64_cut_off(shape) -> int: + acc = 1 + for i, mul in enumerate(shape): + acc *= int(mul) + if not acc < lib.i8max: + return i + return len(shape) + + def maybe_lift(lab, size: int) -> tuple[np.ndarray, int]: + # promote nan values (assigned -1 label in lab array) + # so that all output values are non-negative + return (lab + 1, size + 1) if (lab == -1).any() else (lab, size) + + labels = [ensure_int64(x) for x in labels] + lshape = list(shape) + if not xnull: + for i, (lab, size) in enumerate(zip(labels, shape, strict=True)): + labels[i], lshape[i] = maybe_lift(lab, size) + + # Iteratively process all the labels in chunks sized so less + # than lib.i8max unique int ids will be required for each chunk + while True: + # how many levels can be done without overflow: + nlev = _int64_cut_off(lshape) + + # compute flat ids for the first `nlev` levels + stride = np.prod(lshape[1:nlev], dtype="i8") + out = stride * labels[0].astype("i8", subok=False, copy=False) + + for i in range(1, nlev): + if lshape[i] == 0: + stride = np.int64(0) + else: + stride //= lshape[i] + out += labels[i] * stride + + if xnull: # exclude nulls + mask = labels[0] == -1 + for lab in labels[1:nlev]: + mask |= lab == -1 + out[mask] = -1 + + if nlev == len(lshape): # all levels done! + break + + # compress what has been done so far in order to avoid overflow + # to retain lexical ranks, obs_ids should be sorted + comp_ids, obs_ids = compress_group_index(out, sort=sort) + + labels = [comp_ids, *labels[nlev:]] + lshape = [len(obs_ids), *lshape[nlev:]] + + return out + + +def get_compressed_ids( + labels, sizes: Shape +) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.int64]]: + """ + Group_index is offsets into cartesian product of all possible labels. This + space can be huge, so this function compresses it, by computing offsets + (comp_ids) into the list of unique labels (obs_group_ids). + + Parameters + ---------- + labels : list of label arrays + sizes : tuple[int] of size of the levels + + Returns + ------- + np.ndarray[np.intp] + comp_ids + np.ndarray[np.int64] + obs_group_ids + """ + ids = get_group_index(labels, sizes, sort=True, xnull=False) + return compress_group_index(ids, sort=True) + + +def is_int64_overflow_possible(shape: Shape) -> bool: + the_prod = 1 + for x in shape: + the_prod *= int(x) + + return the_prod >= lib.i8max + + +def _decons_group_index( + comp_labels: npt.NDArray[np.intp], shape: Shape +) -> list[npt.NDArray[np.intp]]: + # reconstruct labels + if is_int64_overflow_possible(shape): + # at some point group indices are factorized, + # and may not be deconstructed here! wrong path! + raise ValueError("cannot deconstruct factorized group indices!") + + label_list = [] + factor = 1 + y = np.array(0) + x = comp_labels + for i in reversed(range(len(shape))): + labels = (x - y) % (factor * shape[i]) // factor + np.putmask(labels, comp_labels < 0, -1) + label_list.append(labels) + y = labels * factor + factor *= shape[i] + return label_list[::-1] + + +def decons_obs_group_ids( + comp_ids: npt.NDArray[np.intp], + obs_ids: npt.NDArray[np.intp], + shape: Shape, + labels: Sequence[npt.NDArray[np.signedinteger]], + xnull: bool, +) -> list[npt.NDArray[np.intp]]: + """ + Reconstruct labels from observed group ids. + + Parameters + ---------- + comp_ids : np.ndarray[np.intp] + obs_ids: np.ndarray[np.intp] + shape : tuple[int] + labels : Sequence[np.ndarray[np.signedinteger]] + xnull : bool + If nulls are excluded; i.e. -1 labels are passed through. + """ + if not xnull: + lift = np.fromiter(((a == -1).any() for a in labels), dtype=np.intp) + arr_shape = np.asarray(shape, dtype=np.intp) + lift + shape = tuple(arr_shape) + + if not is_int64_overflow_possible(shape): + # obs ids are deconstructable! take the fast route! + out = _decons_group_index(obs_ids, shape) + return ( + out + if xnull or not lift.any() + else [x - y for x, y in zip(out, lift, strict=True)] + ) + + indexer = unique_label_indices(comp_ids) + return [lab[indexer].astype(np.intp, subok=False, copy=True) for lab in labels] + + +def lexsort_indexer( + keys: Sequence[ArrayLike | Index | Series], + orders=None, + na_position: str = "last", + key: Callable | None = None, + codes_given: bool = False, +) -> npt.NDArray[np.intp]: + """ + Performs lexical sorting on a set of keys + + Parameters + ---------- + keys : Sequence[ArrayLike | Index | Series] + Sequence of arrays to be sorted by the indexer + Sequence[Series] is only if key is not None. + orders : bool or list of booleans, optional + Determines the sorting order for each element in keys. If a list, + it must be the same length as keys. This determines whether the + corresponding element in keys should be sorted in ascending + (True) or descending (False) order. if bool, applied to all + elements as above. if None, defaults to True. + na_position : {'first', 'last'}, default 'last' + Determines placement of NA elements in the sorted list ("last" or "first") + key : Callable, optional + Callable key function applied to every element in keys before sorting + codes_given: bool, False + Avoid categorical materialization if codes are already provided. + + Returns + ------- + np.ndarray[np.intp] + """ + from pandas.core.arrays import Categorical + + if na_position not in ["last", "first"]: + raise ValueError(f"invalid na_position: {na_position}") + + if isinstance(orders, bool): + orders = itertools.repeat(orders, len(keys)) + elif orders is None: + orders = itertools.repeat(True, len(keys)) + else: + orders = reversed(orders) + + labels = [] + + for k, order in zip(reversed(keys), orders, strict=True): + k = ensure_key_mapped(k, key) + if codes_given: + codes = cast(np.ndarray, k) + n = codes.max() + 1 if len(codes) else 0 + else: + cat = Categorical(k, ordered=True) + codes = cat.codes + n = len(cat.categories) + + mask = codes == -1 + + if na_position == "last" and mask.any(): + codes = np.where(mask, n, codes) + + # not order means descending + if not order: + codes = np.where(mask, codes, n - codes - 1) + + labels.append(codes) + + return np.lexsort(labels) + + +def nargsort( + items: ArrayLike | Index | Series, + kind: SortKind = "quicksort", + ascending: bool = True, + na_position: str = "last", + key: Callable | None = None, + mask: npt.NDArray[np.bool_] | None = None, +) -> npt.NDArray[np.intp]: + """ + Intended to be a drop-in replacement for np.argsort which handles NaNs. + + Adds ascending, na_position, and key parameters. + + (GH #6399, #5231, #27237) + + Parameters + ---------- + items : np.ndarray, ExtensionArray, Index, or Series + kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort' + ascending : bool, default True + na_position : {'first', 'last'}, default 'last' + key : Optional[Callable], default None + mask : Optional[np.ndarray[bool]], default None + Passed when called by ExtensionArray.argsort. + + Returns + ------- + np.ndarray[np.intp] + """ + + if key is not None: + # see TestDataFrameSortKey, TestRangeIndex::test_sort_values_key + items = ensure_key_mapped(items, key) + return nargsort( + items, + kind=kind, + ascending=ascending, + na_position=na_position, + key=None, + mask=mask, + ) + + if isinstance(items, ABCRangeIndex): + return items.argsort(ascending=ascending) + elif not isinstance(items, ABCMultiIndex): + items = extract_array(items) + else: + raise TypeError( + "nargsort does not support MultiIndex. Use index.sort_values instead." + ) + + if mask is None: + mask = np.asarray(isna(items)) + + if not isinstance(items, np.ndarray): + # i.e. ExtensionArray + return items.argsort( + ascending=ascending, + kind=kind, + na_position=na_position, + ) + + idx = np.arange(len(items)) + non_nans = items[~mask] + non_nan_idx = idx[~mask] + + nan_idx = np.nonzero(mask)[0] + if not ascending: + non_nans = non_nans[::-1] + non_nan_idx = non_nan_idx[::-1] + indexer = non_nan_idx[non_nans.argsort(kind=kind)] + if not ascending: + indexer = indexer[::-1] + # Finally, place the NaNs at the end or the beginning according to + # na_position + if na_position == "last": + indexer = np.concatenate([indexer, nan_idx]) + elif na_position == "first": + indexer = np.concatenate([nan_idx, indexer]) + else: + raise ValueError(f"invalid na_position: {na_position}") + return ensure_platform_int(indexer) + + +def nargminmax(values: ExtensionArray, method: str, axis: AxisInt = 0): + """ + Implementation of np.argmin/argmax but for ExtensionArray and which + handles missing values. + + Parameters + ---------- + values : ExtensionArray + method : {"argmax", "argmin"} + axis : int, default 0 + + Returns + ------- + int + """ + assert method in {"argmax", "argmin"} + func = np.argmax if method == "argmax" else np.argmin + + mask = np.asarray(isna(values)) + arr_values = values._values_for_argsort() + + if arr_values.ndim > 1: + if mask.any(): + if axis == 1: + zipped = zip(arr_values, mask, strict=True) + else: + zipped = zip(arr_values.T, mask.T, strict=True) + return np.array([_nanargminmax(v, m, func) for v, m in zipped]) + return func(arr_values, axis=axis) + + return _nanargminmax(arr_values, mask, func) + + +def _nanargminmax(values: np.ndarray, mask: npt.NDArray[np.bool_], func) -> int: + """ + See nanargminmax.__doc__. + """ + idx = np.arange(values.shape[0]) + non_nans = values[~mask] + non_nan_idx = idx[~mask] + + return non_nan_idx[func(non_nans)] + + +def _ensure_key_mapped_multiindex( + index: MultiIndex, key: Callable, level=None +) -> MultiIndex: + """ + Returns a new MultiIndex in which key has been applied + to all levels specified in level (or all levels if level + is None). Used for key sorting for MultiIndex. + + Parameters + ---------- + index : MultiIndex + Index to which to apply the key function on the + specified levels. + key : Callable + Function that takes an Index and returns an Index of + the same shape. This key is applied to each level + separately. The name of the level can be used to + distinguish different levels for application. + level : list-like, int or str, default None + Level or list of levels to apply the key function to. + If None, key function is applied to all levels. Other + levels are left unchanged. + + Returns + ------- + labels : MultiIndex + Resulting MultiIndex with modified levels. + """ + + if level is not None: + if isinstance(level, (str, int)): + level_iter = [level] + else: + level_iter = level + + sort_levels: range | set = {index._get_level_number(lev) for lev in level_iter} + else: + sort_levels = range(index.nlevels) + + mapped = [ + ( + ensure_key_mapped(index._get_level_values(level), key) + if level in sort_levels + else index._get_level_values(level) + ) + for level in range(index.nlevels) + ] + + return type(index).from_arrays(mapped) + + +def ensure_key_mapped( + values: ArrayLike | Index | Series, key: Callable | None, levels=None +) -> ArrayLike | Index | Series: + """ + Applies a callable key function to the values function and checks + that the resulting value has the same shape. Can be called on Index + subclasses, Series, DataFrames, or ndarrays. + + Parameters + ---------- + values : Series, DataFrame, Index subclass, or ndarray + key : Optional[Callable], key to be called on the values array + levels : Optional[List], if values is a MultiIndex, list of levels to + apply the key to. + """ + from pandas.core.indexes.api import Index + + if not key: + return values + + if isinstance(values, ABCMultiIndex): + return _ensure_key_mapped_multiindex(values, key, level=levels) + + result = key(values.copy()) + if len(result) != len(values): + raise ValueError( + "User-provided `key` function must not change the shape of the array." + ) + + try: + if isinstance( + values, Index + ): # convert to a new Index subclass, not necessarily the same + result = Index(result, tupleize_cols=False) + else: + # try to revert to original type otherwise + type_of_values = type(values) + # error: Too many arguments for "ExtensionArray" + result = type_of_values(result) # type: ignore[call-arg] + except TypeError as err: + raise TypeError( + f"User-provided `key` function returned an invalid type {type(result)} \ + which could not be converted to {type(values)}." + ) from err + + return result + + +def get_indexer_dict( + label_list: list[np.ndarray], keys: list[Index] +) -> dict[Hashable, npt.NDArray[np.intp]]: + """ + Returns + ------- + dict: + Labels mapped to indexers. + """ + shape = tuple(len(x) for x in keys) + + group_index = get_group_index(label_list, shape, sort=True, xnull=True) + if np.all(group_index == -1): + # Short-circuit, lib.indices_fast will return the same + return {} + ngroups = ( + ((group_index.size and group_index.max()) + 1) + if is_int64_overflow_possible(shape) + else np.prod(shape, dtype="i8") + ) + + sorter = get_group_index_sorter(group_index, ngroups) + + sorted_labels = [lab.take(sorter) for lab in label_list] + group_index = group_index.take(sorter) + + return lib.indices_fast(sorter, group_index, keys, sorted_labels) + + +# ---------------------------------------------------------------------- +# sorting levels...cleverly? + + +def get_group_index_sorter( + group_index: npt.NDArray[np.intp], ngroups: int | None = None +) -> npt.NDArray[np.intp]: + """ + algos.groupsort_indexer implements `counting sort` and it is at least + O(ngroups), where + ngroups = prod(shape) + shape = map(len, keys) + that is, linear in the number of combinations (cartesian product) of unique + values of groupby keys. This can be huge when doing multi-key groupby. + np.argsort(kind='mergesort') is O(count x log(count)) where count is the + length of the data-frame; + Both algorithms are `stable` sort and that is necessary for correctness of + groupby operations. e.g. consider: + df.groupby(key)[col].transform('first') + + Parameters + ---------- + group_index : np.ndarray[np.intp] + signed integer dtype + ngroups : int or None, default None + + Returns + ------- + np.ndarray[np.intp] + """ + if ngroups is None: + ngroups = 1 + group_index.max() + count = len(group_index) + alpha = 0.0 # taking complexities literally; there may be + beta = 1.0 # some room for fine-tuning these parameters + do_groupsort = count > 0 and ((alpha + beta * ngroups) < (count * np.log(count))) + if do_groupsort: + sorter, _ = algos.groupsort_indexer( + ensure_platform_int(group_index), + ngroups, + ) + # sorter _should_ already be intp, but mypy is not yet able to verify + else: + sorter = group_index.argsort(kind="mergesort") + return ensure_platform_int(sorter) + + +def compress_group_index( + group_index: npt.NDArray[np.int64], sort: bool = True +) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]: + """ + Group_index is offsets into cartesian product of all possible labels. This + space can be huge, so this function compresses it, by computing offsets + (comp_ids) into the list of unique labels (obs_group_ids). + """ + if len(group_index) and np.all(group_index[1:] >= group_index[:-1]): + # GH 53806: fast path for sorted group_index + unique_mask = np.concatenate( + [group_index[:1] > -1, group_index[1:] != group_index[:-1]] + ) + comp_ids = unique_mask.cumsum() + comp_ids -= 1 + obs_group_ids = group_index[unique_mask] + else: + size_hint = len(group_index) + table = hashtable.Int64HashTable(size_hint) + + group_index = ensure_int64(group_index) + + # note, group labels come out ascending (ie, 1,2,3 etc) + comp_ids, obs_group_ids = table.get_labels_groupby(group_index) + + if sort and len(obs_group_ids) > 0: + obs_group_ids, comp_ids = _reorder_by_uniques(obs_group_ids, comp_ids) + + return ensure_int64(comp_ids), ensure_int64(obs_group_ids) + + +def _reorder_by_uniques( + uniques: npt.NDArray[np.int64], labels: npt.NDArray[np.intp] +) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.intp]]: + """ + Parameters + ---------- + uniques : np.ndarray[np.int64] + labels : np.ndarray[np.intp] + + Returns + ------- + np.ndarray[np.int64] + np.ndarray[np.intp] + """ + # sorter is index where elements ought to go + sorter = uniques.argsort() + + # reverse_indexer is where elements came from + reverse_indexer = np.empty(len(sorter), dtype=np.intp) + reverse_indexer.put(sorter, np.arange(len(sorter))) + + mask = labels < 0 + + # move labels to right locations (ie, unsort ascending labels) + labels = reverse_indexer.take(labels) + np.putmask(labels, mask, -1) + + # sort observed ids + uniques = uniques.take(sorter) + + return uniques, labels diff --git a/pandas/errors/__init__.py b/pandas/errors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18edb61c8a3f3e2202b6823c86091b20a5f30b6b --- /dev/null +++ b/pandas/errors/__init__.py @@ -0,0 +1,1087 @@ +""" +Expose public exceptions & warnings +""" + +from __future__ import annotations + +import abc +import ctypes + +from pandas._config.config import OptionError + +from pandas._libs.tslibs import ( + IncompatibleFrequency, + OutOfBoundsDatetime, + OutOfBoundsTimedelta, +) + +from pandas.util.version import InvalidVersion + + +class IntCastingNaNError(ValueError): + """ + Exception raised when converting (``astype``) an array with NaN to an integer type. + + This error occurs when attempting to cast a data structure containing non-finite + values (such as NaN or infinity) to an integer data type. Integer types do not + support non-finite values, so such conversions are explicitly disallowed to + prevent silent data corruption or unexpected behavior. + + See Also + -------- + DataFrame.astype : Method to cast a pandas DataFrame object to a specified dtype. + Series.astype : Method to cast a pandas Series object to a specified dtype. + + Examples + -------- + >>> pd.DataFrame(np.array([[1, np.nan], [2, 3]]), dtype="i8") + Traceback (most recent call last): + IntCastingNaNError: Cannot convert non-finite values (NA or inf) to integer + """ + + +class NullFrequencyError(ValueError): + """ + Exception raised when a ``freq`` cannot be null. + + Particularly ``DatetimeIndex.shift``, ``TimedeltaIndex.shift``, + ``PeriodIndex.shift``. + + See Also + -------- + Index.shift : Shift values of Index. + Series.shift : Shift values of Series. + + Examples + -------- + >>> df = pd.DatetimeIndex(["2011-01-01 10:00", "2011-01-01"], freq=None) + >>> df.shift(2) + Traceback (most recent call last): + NullFrequencyError: Cannot shift with no freq + """ + + +class PerformanceWarning(Warning): + """ + Warning raised when there is a possible performance impact. + + See Also + -------- + DataFrame.set_index : Set the DataFrame index using existing columns. + DataFrame.loc : Access a group of rows and columns by label(s) \ + or a boolean array. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"jim": [0, 0, 1, 1], "joe": ["x", "x", "z", "y"], "jolie": [1, 2, 3, 4]} + ... ) + >>> df = df.set_index(["jim", "joe"]) + >>> df + jolie + jim joe + 0 x 1 + x 2 + 1 z 3 + y 4 + >>> df.loc[(1, "z")] # doctest: +SKIP + # PerformanceWarning: indexing past lexsort depth may impact performance. + df.loc[(1, 'z')] + jolie + jim joe + 1 z 3 + """ + + +class PandasChangeWarning(Warning): + """ + Warning raised for any upcoming change. + + See Also + -------- + errors.PandasPendingDeprecationWarning : Class for deprecations that will raise a + PendingDeprecationWarning. + errors.PandasDeprecationWarning : Class for deprecations that will raise a + DeprecationWarning. + errors.PandasFutureWarning : Class for deprecations that will raise a FutureWarning. + + Examples + -------- + >>> pd.errors.PandasChangeWarning + + """ + + @classmethod + @abc.abstractmethod + def version(cls) -> str: + """Version where change will be enforced.""" + + +class PandasPendingDeprecationWarning(PandasChangeWarning, PendingDeprecationWarning): + """ + Warning raised for an upcoming change that is a PendingDeprecationWarning. + + See Also + -------- + errors.PandasChangeWarning: Class for deprecations that will raise any warning. + errors.PandasDeprecationWarning : Class for deprecations that will raise a + DeprecationWarning. + errors.PandasFutureWarning : Class for deprecations that will raise a FutureWarning. + + Examples + -------- + >>> pd.errors.PandasPendingDeprecationWarning + + """ + + +class PandasDeprecationWarning(PandasChangeWarning, DeprecationWarning): + """ + Warning raised for an upcoming change that is a DeprecationWarning. + + See Also + -------- + errors.PandasChangeWarning: Class for deprecations that will raise any warning. + errors.PandasPendingDeprecationWarning : Class for deprecations that will raise a + PendingDeprecationWarning. + errors.PandasFutureWarning : Class for deprecations that will raise a FutureWarning. + + Examples + -------- + >>> pd.errors.PandasDeprecationWarning + + """ + + +class PandasFutureWarning(PandasChangeWarning, FutureWarning): + """ + Warning raised for an upcoming change that is a FutureWarning. + + See Also + -------- + errors.PandasChangeWarning: Class for deprecations that will raise any warning. + errors.PandasPendingDeprecationWarning : Class for deprecations that will raise a + PendingDeprecationWarning. + errors.PandasDeprecationWarning : Class for deprecations that will raise a + DeprecationWarning. + + Examples + -------- + >>> pd.errors.PandasFutureWarning + + """ + + +class Pandas4Warning(PandasDeprecationWarning): + """ + Warning raised for an upcoming change that will be enforced in pandas 4.0. + + See Also + -------- + errors.PandasChangeWarning: Class for deprecations that will raise any warning. + errors.PandasPendingDeprecationWarning : Class for deprecations that will raise a + PendingDeprecationWarning. + errors.PandasDeprecationWarning : Class for deprecations that will raise a + DeprecationWarning. + errors.PandasFutureWarning : Class for deprecations that will raise a FutureWarning. + + Examples + -------- + >>> pd.errors.Pandas4Warning + + """ + + @classmethod + def version(cls) -> str: + """Version where change will be enforced.""" + return "4.0" + + +class Pandas5Warning(PandasPendingDeprecationWarning): + """ + Warning raised for an upcoming change that will be enforced in pandas 5.0. + + See Also + -------- + errors.PandasChangeWarning: Class for deprecations that will raise any warning. + errors.PandasPendingDeprecationWarning : Class for deprecations that will raise a + PendingDeprecationWarning. + errors.PandasDeprecationWarning : Class for deprecations that will raise a + DeprecationWarning. + errors.PandasFutureWarning : Class for deprecations that will raise a FutureWarning. + + Examples + -------- + >>> pd.errors.Pandas5Warning + + """ + + @classmethod + def version(cls) -> str: + """Version where change will be enforced.""" + return "5.0" + + +_CurrentDeprecationWarning = Pandas4Warning + + +class UnsupportedFunctionCall(ValueError): + """ + Exception raised when attempting to call a unsupported numpy function. + + For example, ``np.cumsum(groupby_object)``. + + See Also + -------- + DataFrame.groupby : Group DataFrame using a mapper or by a Series of columns. + Series.groupby : Group Series using a mapper or by a Series of columns. + core.groupby.GroupBy.cumsum : Compute cumulative sum for each group. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"A": [0, 0, 1, 1], "B": ["x", "x", "z", "y"], "C": [1, 2, 3, 4]} + ... ) + >>> np.cumsum(df.groupby(["A"])) + Traceback (most recent call last): + UnsupportedFunctionCall: numpy operations are not valid with groupby. + Use .groupby(...).cumsum() instead + """ + + +class UnsortedIndexError(KeyError): + """ + Error raised when slicing a MultiIndex which has not been lexsorted. + + Subclass of `KeyError`. + + See Also + -------- + DataFrame.sort_index : Sort a DataFrame by its index. + DataFrame.set_index : Set the DataFrame index using existing columns. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "cat": [0, 0, 1, 1], + ... "color": ["white", "white", "brown", "black"], + ... "lives": [4, 4, 3, 7], + ... }, + ... ) + >>> df = df.set_index(["cat", "color"]) + >>> df + lives + cat color + 0 white 4 + white 4 + 1 brown 3 + black 7 + >>> df.loc[(0, "black") : (1, "white")] + Traceback (most recent call last): + UnsortedIndexError: 'Key length (2) was greater + than MultiIndex lexsort depth (1)' + """ + + +class ParserError(ValueError): + """ + Exception that is raised by an error encountered in parsing file contents. + + This is a generic error raised for errors encountered when functions like + `read_csv` or `read_html` are parsing contents of a file. + + See Also + -------- + read_csv : Read CSV (comma-separated) file into a DataFrame. + read_html : Read HTML table into a DataFrame. + + Examples + -------- + >>> data = '''a,b,c + ... cat,foo,bar + ... dog,foo,"baz''' + >>> from io import StringIO + >>> pd.read_csv(StringIO(data), skipfooter=1, engine="python") + Traceback (most recent call last): + ParserError: ',' expected after '"'. Error could possibly be due + to parsing errors in the skipped footer rows + """ + + +class DtypeWarning(Warning): + """ + Warning raised when reading different dtypes in a column from a file. + + Raised for a dtype incompatibility. This can happen whenever `read_csv` + or `read_table` encounter non-uniform dtypes in a column(s) of a given + CSV file. + + See Also + -------- + read_csv : Read CSV (comma-separated) file into a DataFrame. + read_table : Read general delimited file into a DataFrame. + + Notes + ----- + This warning is issued when dealing with larger files because the dtype + checking happens per chunk read. + + Despite the warning, the CSV file is read with mixed types in a single + column which will be an object type. See the examples below to better + understand this issue. + + Examples + -------- + This example creates and reads a large CSV file with a column that contains + `int` and `str`. + + >>> df = pd.DataFrame( + ... { + ... "a": (["1"] * 100000 + ["X"] * 100000 + ["1"] * 100000), + ... "b": ["b"] * 300000, + ... } + ... ) # doctest: +SKIP + >>> df.to_csv("test.csv", index=False) # doctest: +SKIP + >>> df2 = pd.read_csv("test.csv") # doctest: +SKIP + ... # DtypeWarning: Columns (0: a) have mixed types + + Important to notice that ``df2`` will contain both `str` and `int` for the + same input, '1'. + + >>> df2.iloc[262140, 0] # doctest: +SKIP + '1' + >>> type(df2.iloc[262140, 0]) # doctest: +SKIP + + >>> df2.iloc[262150, 0] # doctest: +SKIP + 1 + >>> type(df2.iloc[262150, 0]) # doctest: +SKIP + + + One way to solve this issue is using the `dtype` parameter in the + `read_csv` and `read_table` functions to explicit the conversion: + + >>> df2 = pd.read_csv("test.csv", sep=",", dtype={"a": str}) # doctest: +SKIP + + No warning was issued. + """ + + +class EmptyDataError(ValueError): + """ + Exception raised in ``pd.read_csv`` when empty data or header is encountered. + + This error is typically encountered when attempting to read an empty file or + an invalid file where no data or headers are present. + + See Also + -------- + read_csv : Read a comma-separated values (CSV) file into DataFrame. + errors.ParserError : Exception that is raised by an error encountered in parsing + file contents. + errors.DtypeWarning : Warning raised when reading different dtypes in a column + from a file. + + Examples + -------- + >>> from io import StringIO + >>> empty = StringIO() + >>> pd.read_csv(empty) + Traceback (most recent call last): + EmptyDataError: No columns to parse from file + """ + + +class ParserWarning(Warning): + """ + Warning raised when reading a file that doesn't use the default 'c' parser. + + Raised by `pd.read_csv` and `pd.read_table` when it is necessary to change + parsers, generally from the default 'c' parser to 'python'. + + It happens due to a lack of support or functionality for parsing a + particular attribute of a CSV file with the requested engine. + + Currently, 'c' unsupported options include the following parameters: + + 1. `sep` other than a single character (e.g. regex separators) + 2. `skipfooter` higher than 0 + + The warning can be avoided by adding `engine='python'` as a parameter in + `pd.read_csv` and `pd.read_table` methods. + + See Also + -------- + pd.read_csv : Read CSV (comma-separated) file into DataFrame. + pd.read_table : Read general delimited file into DataFrame. + + Examples + -------- + Using a `sep` in `pd.read_csv` other than a single character: + + >>> import io + >>> csv = '''a;b;c + ... 1;1,8 + ... 1;2,1''' + >>> df = pd.read_csv(io.StringIO(csv), sep="[;,]") # doctest: +SKIP + ... # ParserWarning: Falling back to the 'python' engine... + + Adding `engine='python'` to `pd.read_csv` removes the Warning: + + >>> df = pd.read_csv(io.StringIO(csv), sep="[;,]", engine="python") + """ + + +class MergeError(ValueError): + """ + Exception raised when merging data. + + Subclass of ``ValueError``. + + See Also + -------- + DataFrame.join : For joining DataFrames on their indexes. + merge : For merging two DataFrames on a common set of keys. + + Examples + -------- + >>> left = pd.DataFrame( + ... {"a": ["a", "b", "b", "d"], "b": ["cat", "dog", "weasel", "horse"]}, + ... index=range(4), + ... ) + >>> right = pd.DataFrame( + ... {"a": ["a", "b", "c", "d"], "c": ["meow", "bark", "chirp", "nay"]}, + ... index=range(4), + ... ).set_index("a") + >>> left.join( + ... right, + ... on="a", + ... validate="one_to_one", + ... ) + Traceback (most recent call last): + MergeError: Merge keys are not unique in left dataset; not a one-to-one merge + """ + + +class AbstractMethodError(NotImplementedError): + """ + Raise this error instead of NotImplementedError for abstract methods. + + The `AbstractMethodError` is designed for use in classes that follow an abstract + base class pattern. By raising this error in the method, it ensures that a subclass + must implement the method to provide specific functionality. This is useful in a + framework or library where certain methods must be implemented by the user to + ensure correct behavior. + + Parameters + ---------- + class_instance : object + The instance of the class where the abstract method is being called. + methodtype : str, default "method" + A string indicating the type of method that is abstract. + Must be one of {"method", "classmethod", "staticmethod", "property"}. + + See Also + -------- + api.extensions.ExtensionArray + An example of a pandas extension mechanism that requires implementing + specific abstract methods. + NotImplementedError + A built-in exception that can also be used for abstract methods but lacks + the specificity of `AbstractMethodError` in indicating the need for subclass + implementation. + + Examples + -------- + >>> class Foo: + ... @classmethod + ... def classmethod(cls): + ... raise pd.errors.AbstractMethodError(cls, methodtype="classmethod") + ... + ... def method(self): + ... raise pd.errors.AbstractMethodError(self) + >>> test = Foo.classmethod() + Traceback (most recent call last): + AbstractMethodError: This classmethod must be defined in the concrete class Foo + + >>> test2 = Foo().method() + Traceback (most recent call last): + AbstractMethodError: This classmethod must be defined in the concrete class Foo + """ + + def __init__(self, class_instance, methodtype: str = "method") -> None: + types = {"method", "classmethod", "staticmethod", "property"} + if methodtype not in types: + raise ValueError( + f"methodtype must be one of {types}, got {methodtype} instead." + ) + self.methodtype = methodtype + self.class_instance = class_instance + + def __str__(self) -> str: + if self.methodtype == "classmethod": + name = self.class_instance.__name__ + else: + name = type(self.class_instance).__name__ + return f"This {self.methodtype} must be defined in the concrete class {name}" + + +class NumbaUtilError(Exception): + """ + Error raised for unsupported Numba engine routines. + + See Also + -------- + DataFrame.groupby : Group DataFrame using a mapper or by a Series of columns. + Series.groupby : Group Series using a mapper or by a Series of columns. + DataFrame.agg : Aggregate using one or more operations over the specified axis. + Series.agg : Aggregate using one or more operations over the specified axis. + + Examples + -------- + >>> df = pd.DataFrame( + ... {"key": ["a", "a", "b", "b"], "data": [1, 2, 3, 4]}, columns=["key", "data"] + ... ) + >>> def incorrect_function(x): + ... return sum(x) * 2.7 + >>> df.groupby("key").agg(incorrect_function, engine="numba") + Traceback (most recent call last): + NumbaUtilError: The first 2 arguments to incorrect_function + must be ['values', 'index'] + """ + + +class DuplicateLabelError(ValueError): + """ + Error raised when an operation would introduce duplicate labels. + + This error is typically encountered when performing operations on objects + with `allows_duplicate_labels=False` and the operation would result in + duplicate labels in the index. Duplicate labels can lead to ambiguities + in indexing and reduce data integrity. + + See Also + -------- + Series.set_flags : Return a new ``Series`` object with updated flags. + DataFrame.set_flags : Return a new ``DataFrame`` object with updated flags. + Series.reindex : Conform ``Series`` object to new index with optional filling logic. + DataFrame.reindex : Conform ``DataFrame`` object to new index with optional filling + logic. + + Examples + -------- + >>> s = pd.Series([0, 1, 2], index=["a", "b", "c"]).set_flags( + ... allows_duplicate_labels=False + ... ) + >>> s.reindex(["a", "a", "b"]) + Traceback (most recent call last): + ... + DuplicateLabelError: Index has duplicates. + positions + label + a [0, 1] + """ + + +class InvalidIndexError(Exception): + """ + Exception raised when attempting to use an invalid index key. + + This exception is triggered when a user attempts to access or manipulate + data in a pandas DataFrame or Series using an index key that is not valid + for the given object. This may occur in cases such as using a malformed + slice, a mismatched key for a ``MultiIndex``, or attempting to access an index + element that does not exist. + + See Also + -------- + MultiIndex : A multi-level, or hierarchical, index object for pandas objects. + + Examples + -------- + >>> idx = pd.MultiIndex.from_product([["x", "y"], [0, 1]]) + >>> df = pd.DataFrame([[1, 1, 2, 2], [3, 3, 4, 4]], columns=idx) + >>> df + x y + 0 1 0 1 + 0 1 1 2 2 + 1 3 3 4 4 + >>> df[:, 0] + Traceback (most recent call last): + InvalidIndexError: (slice(None, None, None), 0) + """ + + +class DataError(Exception): + """ + Exception raised when performing an operation on non-numerical data. + + For example, calling ``ohlc`` on a non-numerical column or a function + on a rolling window. + + See Also + -------- + Series.rolling : Provide rolling window calculations on Series object. + DataFrame.rolling : Provide rolling window calculations on DataFrame object. + + Examples + -------- + >>> ser = pd.Series(["a", "b", "c"]) + >>> ser.rolling(2).sum() + Traceback (most recent call last): + DataError: No numeric types to aggregate + """ + + +class SpecificationError(Exception): + """ + Exception raised by ``agg`` when the functions are ill-specified. + + The exception raised in two scenarios. + + The first way is calling ``agg`` on a + Dataframe or Series using a nested renamer (dict-of-dict). + + The second way is calling ``agg`` on a Dataframe with duplicated functions + names without assigning column name. + + See Also + -------- + DataFrame.agg : Aggregate using one or more operations over the specified axis. + Series.agg : Aggregate using one or more operations over the specified axis. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 1, 1, 2, 2], "B": range(5), "C": range(5)}) + >>> df.groupby("A").B.agg({"foo": "count"}) # doctest: +SKIP + ... # SpecificationError: nested renamer is not supported + + >>> df.groupby("A").agg({"B": {"foo": ["sum", "max"]}}) # doctest: +SKIP + ... # SpecificationError: nested renamer is not supported + + >>> df.groupby("A").agg(["min", "min"]) # doctest: +SKIP + ... # SpecificationError: nested renamer is not supported + """ + + +class ChainedAssignmentError(Warning): + """ + Warning raised when trying to set using chained assignment. + + With Copy-on-Write now always enabled, chained assignment can + never work. In such a situation, we are always setting into a temporary + object that is the result of an indexing operation (getitem), which under + Copy-on-Write always behaves as a copy. Thus, assigning through a chain + can never update the original Series or DataFrame. + + For more information on Copy-on-Write, + see :ref:`the user guide`. + + See Also + -------- + DataFrame.loc : Access a group of rows and columns by label(s) or a boolean array. + DataFrame.iloc : Purely integer-location based indexing for selection by position. + Series.loc : Access a group of rows by label(s) or a boolean array. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 1, 1, 2, 2]}, columns=["A"]) + >>> df["A"][0:3] = 10 # doctest: +SKIP + ... # ChainedAssignmentError: ... + """ + + +class NumExprClobberingError(NameError): + """ + Exception raised when trying to use a built-in numexpr name as a variable name. + + ``eval`` or ``query`` will throw the error if the engine is set + to 'numexpr'. 'numexpr' is the default engine value for these methods if the + numexpr package is installed. + + See Also + -------- + eval : Evaluate a Python expression as a string using various backends. + DataFrame.query : Query the columns of a DataFrame with a boolean expression. + + Examples + -------- + >>> df = pd.DataFrame({"abs": [1, 1, 1]}) + >>> df.query("abs > 2") # doctest: +SKIP + ... # NumExprClobberingError: Variables in expression "(abs) > (2)" overlap... + >>> sin, a = 1, 2 + >>> pd.eval("sin + a", engine="numexpr") # doctest: +SKIP + ... # NumExprClobberingError: Variables in expression "(sin) + (a)" overlap... + """ + + +class UndefinedVariableError(NameError): + """ + Exception raised by ``query`` or ``eval`` when using an undefined variable name. + + It will also specify whether the undefined variable is local or not. + + Parameters + ---------- + name : str + The name of the undefined variable. + is_local : bool or None, optional + Indicates whether the undefined variable is considered a local variable. + If ``True``, the error message specifies it as a local variable. + If ``False`` or ``None``, the variable is treated as a non-local name. + + See Also + -------- + DataFrame.query : Query the columns of a DataFrame with a boolean expression. + DataFrame.eval : Evaluate a string describing operations on DataFrame columns. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 1, 1]}) + >>> df.query("A > x") # doctest: +SKIP + ... # UndefinedVariableError: name 'x' is not defined + >>> df.query("A > @y") # doctest: +SKIP + ... # UndefinedVariableError: local variable 'y' is not defined + >>> pd.eval("x + 1") # doctest: +SKIP + ... # UndefinedVariableError: name 'x' is not defined + """ + + def __init__(self, name: str, is_local: bool | None = None) -> None: + base_msg = f"{name!r} is not defined" + if is_local: + msg = f"local variable {base_msg}" + else: + msg = f"name {base_msg}" + super().__init__(msg) + + +class IndexingError(Exception): + """ + Exception is raised when trying to index and there is a mismatch in dimensions. + + Raised by properties like :attr:`.pandas.DataFrame.iloc` when + an indexer is out of bounds or :attr:`.pandas.DataFrame.loc` when its index is + unalignable to the frame index. + + See Also + -------- + DataFrame.iloc : Purely integer-location based indexing for \ + selection by position. + DataFrame.loc : Access a group of rows and columns by label(s) \ + or a boolean array. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 1, 1]}) + >>> df.loc[..., ..., "A"] # doctest: +SKIP + ... # IndexingError: indexer may only contain one '...' entry + >>> df = pd.DataFrame({"A": [1, 1, 1]}) + >>> df.loc[1, ..., ...] # doctest: +SKIP + ... # IndexingError: Too many indexers + >>> df[pd.Series([True], dtype=bool)] # doctest: +SKIP + ... # IndexingError: Unalignable boolean Series provided as indexer... + >>> s = pd.Series(range(2), index=pd.MultiIndex.from_product([["a", "b"], ["c"]])) + >>> s.loc["a", "c", "d"] # doctest: +SKIP + ... # IndexingError: Too many indexers + """ + + +class PyperclipException(RuntimeError): + """ + Exception raised when clipboard functionality is unsupported. + + Raised by ``to_clipboard()`` and ``read_clipboard()``. + """ + + +class PyperclipWindowsException(PyperclipException): + """ + Exception raised when clipboard functionality is unsupported by Windows. + + Access to the clipboard handle would be denied due to some other + window process is accessing it. + """ + + def __init__(self, message: str) -> None: + # attr only exists on Windows, so typing fails on other platforms + message += f" ({ctypes.WinError()})" # type: ignore[attr-defined] + super().__init__(message) + + +class CSSWarning(UserWarning): + """ + Warning is raised when converting css styling fails. + + This can be due to the styling not having an equivalent value or because the + styling isn't properly formatted. + + See Also + -------- + DataFrame.style : Returns a Styler object for applying CSS-like styles. + io.formats.style.Styler : Helps style a DataFrame or Series according to the + data with HTML and CSS. + io.formats.style.Styler.to_excel : Export styled DataFrame to Excel. + io.formats.style.Styler.to_html : Export styled DataFrame to HTML. + + Examples + -------- + >>> df = pd.DataFrame({"A": [1, 1, 1]}) + >>> df.style.map(lambda x: "background-color: blueGreenRed;").to_excel( + ... "styled.xlsx" + ... ) # doctest: +SKIP + CSSWarning: Unhandled color format: 'blueGreenRed' + >>> df.style.map(lambda x: "border: 1px solid red red;").to_excel( + ... "styled.xlsx" + ... ) # doctest: +SKIP + CSSWarning: Unhandled color format: 'blueGreenRed' + """ + + +class PossibleDataLossError(Exception): + """ + Exception raised when trying to open an HDFStore file when already opened. + + This error is triggered when there is a potential risk of data loss due to + conflicting operations on an HDFStore file. It serves to prevent unintended + overwrites or data corruption by enforcing exclusive access to the file. + + See Also + -------- + HDFStore : Dict-like IO interface for storing pandas objects in PyTables. + HDFStore.open : Open an HDFStore file in the specified mode. + + Examples + -------- + >>> store = pd.HDFStore("my-store", "a") # doctest: +SKIP + >>> store.open("w") # doctest: +SKIP + """ + + +class ClosedFileError(Exception): + """ + Exception is raised when trying to perform an operation on a closed HDFStore file. + + ``ClosedFileError`` is specific to operations on ``HDFStore`` objects. Once an + HDFStore is closed, its resources are no longer available, and any further attempt + to access data or perform file operations will raise this exception. + + See Also + -------- + HDFStore.close : Closes the PyTables file handle. + HDFStore.open : Opens the file in the specified mode. + HDFStore.is_open : Returns a boolean indicating whether the file is open. + + Examples + -------- + >>> store = pd.HDFStore("my-store", "a") # doctest: +SKIP + >>> store.close() # doctest: +SKIP + >>> store.keys() # doctest: +SKIP + ... # ClosedFileError: my-store file is not open! + """ + + +class IncompatibilityWarning(Warning): + """ + Warning raised when trying to use where criteria on an incompatible HDF5 file. + """ + + +class AttributeConflictWarning(Warning): + """ + Warning raised when index attributes conflict when using HDFStore. + + Occurs when attempting to append an index with a different + name than the existing index on an HDFStore or attempting to append an index with a + different frequency than the existing index on an HDFStore. + + See Also + -------- + HDFStore : Dict-like IO interface for storing pandas objects in PyTables. + DataFrame.to_hdf : Write the contained data to an HDF5 file using HDFStore. + read_hdf : Read from an HDF5 file into a DataFrame. + + Examples + -------- + >>> idx1 = pd.Index(["a", "b"], name="name1") + >>> df1 = pd.DataFrame([[1, 2], [3, 4]], index=idx1) + >>> df1.to_hdf("file", "data", "w", append=True) # doctest: +SKIP + >>> idx2 = pd.Index(["c", "d"], name="name2") + >>> df2 = pd.DataFrame([[5, 6], [7, 8]], index=idx2) + >>> df2.to_hdf("file", "data", "a", append=True) # doctest: +SKIP + AttributeConflictWarning: the [index_name] attribute of the existing index is + [name1] which conflicts with the new [name2]... + """ + + +class DatabaseError(OSError): + """ + Error is raised when executing SQL with bad syntax or SQL that throws an error. + + Raised by :func:`.pandas.read_sql` when a bad SQL statement is passed in. + + See Also + -------- + read_sql : Read SQL query or database table into a DataFrame. + + Examples + -------- + >>> from sqlite3 import connect + >>> conn = connect(":memory:") + >>> pd.read_sql("select * test", conn) # doctest: +SKIP + """ + + +class PossiblePrecisionLoss(Warning): + """ + Warning raised by to_stata on a column with a value outside or equal to int64. + + When the column value is outside or equal to the int64 value the column is + converted to a float64 dtype. + + See Also + -------- + DataFrame.to_stata : Export DataFrame object to Stata dta format. + + Examples + -------- + >>> df = pd.DataFrame({"s": pd.Series([1, 2**53], dtype=np.int64)}) + >>> df.to_stata("test") # doctest: +SKIP + """ + + +class ValueLabelTypeMismatch(Warning): + """ + Warning raised by to_stata on a category column that contains non-string values. + + When exporting data to Stata format using the `to_stata` method, category columns + must have string values as labels. If a category column contains non-string values + (e.g., integers, floats, or other types), this warning is raised to indicate that + the Stata file may not correctly represent the data. + + See Also + -------- + DataFrame.to_stata : Export DataFrame object to Stata dta format. + Series.cat : Accessor for categorical properties of the Series values. + + Examples + -------- + >>> df = pd.DataFrame({"categories": pd.Series(["a", 2], dtype="category")}) + >>> df.to_stata("test") # doctest: +SKIP + """ + + +class InvalidColumnName(Warning): + """ + Warning raised by to_stata the column contains a non-valid stata name. + + Because the column name is an invalid Stata variable, the name needs to be + converted. + + See Also + -------- + DataFrame.to_stata : Export DataFrame object to Stata dta format. + + Examples + -------- + >>> df = pd.DataFrame({"0categories": pd.Series([2, 2])}) + >>> df.to_stata("test") # doctest: +SKIP + """ + + +class CategoricalConversionWarning(Warning): + """ + Warning is raised when reading a partial labeled Stata file using an iterator. + + This warning helps ensure data integrity and alerts users to potential issues + during the incremental reading of Stata files with labeled data, allowing for + additional checks and adjustments as necessary. + + See Also + -------- + read_stata : Read a Stata file into a DataFrame. + Categorical : Represents a categorical variable in pandas. + + Examples + -------- + >>> from pandas.io.stata import StataReader + >>> with StataReader("dta_file", chunksize=2) as reader: # doctest: +SKIP + ... for i, block in enumerate(reader): + ... print(i, block) + ... # CategoricalConversionWarning: One or more series with value labels... + """ + + +class LossySetitemError(Exception): + """ + Raised when trying to do a __setitem__ on an np.ndarray that is not lossless. + + Notes + ----- + This is an internal error. + """ + + +class NoBufferPresent(Exception): + """ + Exception is raised in _get_data_buffer to signal that there is no requested buffer. + """ + + +class InvalidComparison(Exception): + """ + Exception is raised by _validate_comparison_value to indicate an invalid comparison. + + Notes + ----- + This is an internal error. + """ + + +__all__ = [ + "AbstractMethodError", + "AttributeConflictWarning", + "CSSWarning", + "CategoricalConversionWarning", + "ChainedAssignmentError", + "ClosedFileError", + "DataError", + "DatabaseError", + "DtypeWarning", + "DuplicateLabelError", + "EmptyDataError", + "IncompatibilityWarning", + "IncompatibleFrequency", + "IndexingError", + "IntCastingNaNError", + "InvalidColumnName", + "InvalidComparison", + "InvalidIndexError", + "InvalidVersion", + "LossySetitemError", + "MergeError", + "NoBufferPresent", + "NullFrequencyError", + "NumExprClobberingError", + "NumbaUtilError", + "OptionError", + "OutOfBoundsDatetime", + "OutOfBoundsTimedelta", + "Pandas4Warning", + "Pandas5Warning", + "PandasChangeWarning", + "PandasDeprecationWarning", + "PandasFutureWarning", + "PandasPendingDeprecationWarning", + "ParserError", + "ParserWarning", + "PerformanceWarning", + "PossibleDataLossError", + "PossiblePrecisionLoss", + "PyperclipException", + "PyperclipWindowsException", + "SpecificationError", + "UndefinedVariableError", + "UnsortedIndexError", + "UnsupportedFunctionCall", + "ValueLabelTypeMismatch", +] diff --git a/pandas/errors/cow.py b/pandas/errors/cow.py new file mode 100644 index 0000000000000000000000000000000000000000..8516c33b9d9dcc85e3aeb1bd74068e7a298c9c68 --- /dev/null +++ b/pandas/errors/cow.py @@ -0,0 +1,43 @@ +_chained_assignment_msg = ( + "A value is being set on a copy of a DataFrame or Series " + "through chained assignment.\n" + "Such chained assignment never works to update the original DataFrame or " + "Series, because the intermediate object on which we are setting values " + "always behaves as a copy (due to Copy-on-Write).\n\n" + "Try using '.loc[row_indexer, col_indexer] = value' instead, to perform " + "the assignment in a single step.\n\n" + "See the documentation for a more detailed explanation: " + "https://pandas.pydata.org/pandas-docs/stable/user_guide/" + "copy_on_write.html#chained-assignment" +) + + +_chained_assignment_method_msg = ( + "A value is being set on a copy of a DataFrame or Series " + "through chained assignment using an inplace method.\n" + "Such inplace method never works to update the original DataFrame or Series, " + "because the intermediate object on which we are setting values always " + "behaves as a copy (due to Copy-on-Write).\n\n" + "For example, when doing 'df[col].method(value, inplace=True)', try " + "using 'df.method({col: value}, inplace=True)' instead, to perform " + "the operation inplace on the original object, or try to avoid an inplace " + "operation using 'df[col] = df[col].method(value)'.\n\n" + "See the documentation for a more detailed explanation: " + "https://pandas.pydata.org/pandas-docs/stable/user_guide/" + "copy_on_write.html" +) + + +_chained_assignment_method_update_msg = ( + "A value is being set on a copy of a DataFrame or Series " + "through chained assignment using an inplace method.\n" + "Such inplace method never works to update the original DataFrame or Series, " + "because the intermediate object on which we are setting values always " + "behaves as a copy (due to Copy-on-Write).\n\n" + "For example, when doing 'df[col].update(other)', try " + "using 'df.update({col: other})' instead, to perform " + "the operation inplace on the original object.\n\n" + "See the documentation for a more detailed explanation: " + "https://pandas.pydata.org/pandas-docs/stable/user_guide/" + "copy_on_write.html" +) diff --git a/pandas/io/__init__.py b/pandas/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7e531debb1426624186453b622cfccd11d44ef --- /dev/null +++ b/pandas/io/__init__.py @@ -0,0 +1,13 @@ +# ruff: noqa: TC004 +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + # import modules that have public classes/functions + from pandas.io import ( + formats, + json, + stata, + ) + + # mark only those modules as public + __all__ = ["formats", "json", "stata"] diff --git a/pandas/io/_util.py b/pandas/io/_util.py new file mode 100644 index 0000000000000000000000000000000000000000..da9ac3913cbbd45b80c47987dd5b3c523da8e8b4 --- /dev/null +++ b/pandas/io/_util.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Literal, +) + +import numpy as np + +from pandas._config import using_string_dtype + +from pandas._libs import lib +from pandas.compat import ( + pa_version_under18p0, + pa_version_under19p0, +) +from pandas.compat._optional import import_optional_dependency + +from pandas.core.dtypes.common import pandas_dtype + +import pandas as pd + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Hashable, + Sequence, + ) + + import pyarrow + + from pandas._typing import ( + DtypeArg, + DtypeBackend, + ) + + +def _arrow_dtype_mapping() -> dict: + pa = import_optional_dependency("pyarrow") + return { + pa.int8(): pd.Int8Dtype(), + pa.int16(): pd.Int16Dtype(), + pa.int32(): pd.Int32Dtype(), + pa.int64(): pd.Int64Dtype(), + pa.uint8(): pd.UInt8Dtype(), + pa.uint16(): pd.UInt16Dtype(), + pa.uint32(): pd.UInt32Dtype(), + pa.uint64(): pd.UInt64Dtype(), + pa.bool_(): pd.BooleanDtype(), + pa.string(): pd.StringDtype(), + pa.float32(): pd.Float32Dtype(), + pa.float64(): pd.Float64Dtype(), + pa.string(): pd.StringDtype(), + pa.large_string(): pd.StringDtype(), + } + + +def _arrow_string_types_mapper() -> Callable: + pa = import_optional_dependency("pyarrow") + + mapping = { + pa.string(): pd.StringDtype(na_value=np.nan), + pa.large_string(): pd.StringDtype(na_value=np.nan), + } + if not pa_version_under18p0: + mapping[pa.string_view()] = pd.StringDtype(na_value=np.nan) + + return mapping.get + + +def arrow_table_to_pandas( + table: pyarrow.Table, + dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault = lib.no_default, + null_to_int64: bool = False, + to_pandas_kwargs: dict | None = None, + dtype: DtypeArg | None = None, + names: Sequence[Hashable] | None = None, +) -> pd.DataFrame: + pa = import_optional_dependency("pyarrow") + + to_pandas_kwargs = {} if to_pandas_kwargs is None else to_pandas_kwargs + + types_mapper: type[pd.ArrowDtype] | None | Callable + if dtype_backend == "numpy_nullable": + mapping = _arrow_dtype_mapping() + if null_to_int64: + # Modify the default mapping to also map null to Int64 + # (to match other engines - only for CSV parser) + mapping[pa.null()] = pd.Int64Dtype() + types_mapper = mapping.get + elif dtype_backend == "pyarrow": + types_mapper = pd.ArrowDtype + elif using_string_dtype(): + if pa_version_under19p0: + types_mapper = _arrow_string_types_mapper() + elif dtype is not None: + # GH#56136 Avoid lossy conversion to float64 + # We'll convert to numpy below if + types_mapper = { + pa.int8(): pd.Int8Dtype(), + pa.int16(): pd.Int16Dtype(), + pa.int32(): pd.Int32Dtype(), + pa.int64(): pd.Int64Dtype(), + }.get + else: + types_mapper = None + elif dtype_backend is lib.no_default or dtype_backend == "numpy": + if dtype is not None: + # GH#56136 Avoid lossy conversion to float64 + # We'll convert to numpy below if + types_mapper = { + pa.int8(): pd.Int8Dtype(), + pa.int16(): pd.Int16Dtype(), + pa.int32(): pd.Int32Dtype(), + pa.int64(): pd.Int64Dtype(), + }.get + else: + types_mapper = None + else: + raise NotImplementedError + + df = table.to_pandas(types_mapper=types_mapper, **to_pandas_kwargs) + return _post_convert_dtypes(df, dtype_backend, dtype, names) + + +def _post_convert_dtypes( + df: pd.DataFrame, + dtype_backend: DtypeBackend | Literal["numpy"] | lib.NoDefault, + dtype: DtypeArg | None, + names: Sequence[Hashable] | None, +) -> pd.DataFrame: + if dtype is not None and ( + dtype_backend is lib.no_default or dtype_backend == "numpy" + ): + # GH#56136 apply any user-provided dtype, and convert any IntegerDtype + # columns the user didn't explicitly ask for. + if isinstance(dtype, dict): + if names is not None: + df.columns = names + + cmp_dtypes = { + pd.Int8Dtype(), + pd.Int16Dtype(), + pd.Int32Dtype(), + pd.Int64Dtype(), + } + for col in df.columns: + if col not in dtype and df[col].dtype in cmp_dtypes: + # Any key that the user didn't explicitly specify + # that got converted to IntegerDtype now gets converted + # to numpy dtype. + dtype[col] = df[col].dtype.numpy_dtype + + # Ignore non-existent columns from dtype mapping + # like other parsers do + dtype = { + key: pandas_dtype(dtype[key]) for key in dtype if key in df.columns + } + + else: + dtype = pandas_dtype(dtype) + + try: + df = df.astype(dtype) + except TypeError as err: + # GH#44901 reraise to keep api consistent + raise ValueError(str(err)) from err + + if ( + not using_string_dtype() + and dtype != "str" + and (dtype_backend is lib.no_default or dtype_backend == "numpy") + ): + # Convert any StringDtype columns back to object dtype (pyarrow always + # uses string dtype even when the infer_string option is False) + for col, dtype in zip(df.columns, df.dtypes, strict=True): + if isinstance(dtype, pd.StringDtype) and dtype.na_value is np.nan: + df[col] = df[col].astype("object").fillna(None) + if isinstance(dtype, pd.CategoricalDtype): + cat_dtype = dtype.categories.dtype + if ( + isinstance(cat_dtype, pd.StringDtype) + and cat_dtype.na_value is np.nan + ): + cat_dtype = pd.CategoricalDtype( + categories=dtype.categories.astype("object"), + ordered=dtype.ordered, + ) + df[col] = df[col].astype(cat_dtype) + + return df diff --git a/pandas/io/api.py b/pandas/io/api.py new file mode 100644 index 0000000000000000000000000000000000000000..0d9f38999f787cdc96f3162934bfdcba251ccf5d --- /dev/null +++ b/pandas/io/api.py @@ -0,0 +1,65 @@ +""" +Data I/O API +""" + +from pandas.io.clipboards import read_clipboard +from pandas.io.excel import ( + ExcelFile, + ExcelWriter, + read_excel, +) +from pandas.io.feather_format import read_feather +from pandas.io.html import read_html +from pandas.io.iceberg import read_iceberg +from pandas.io.json import read_json +from pandas.io.orc import read_orc +from pandas.io.parquet import read_parquet +from pandas.io.parsers import ( + read_csv, + read_fwf, + read_table, +) +from pandas.io.pickle import ( + read_pickle, + to_pickle, +) +from pandas.io.pytables import ( + HDFStore, + read_hdf, +) +from pandas.io.sas import read_sas +from pandas.io.spss import read_spss +from pandas.io.sql import ( + read_sql, + read_sql_query, + read_sql_table, +) +from pandas.io.stata import read_stata +from pandas.io.xml import read_xml + +__all__ = [ + "ExcelFile", + "ExcelWriter", + "HDFStore", + "read_clipboard", + "read_csv", + "read_excel", + "read_feather", + "read_fwf", + "read_hdf", + "read_html", + "read_iceberg", + "read_json", + "read_orc", + "read_parquet", + "read_pickle", + "read_sas", + "read_spss", + "read_sql", + "read_sql_query", + "read_sql_table", + "read_stata", + "read_table", + "read_xml", + "to_pickle", +] diff --git a/pandas/io/clipboards.py b/pandas/io/clipboards.py new file mode 100644 index 0000000000000000000000000000000000000000..9a562481f0e98726a67acedb3b7f48183d676057 --- /dev/null +++ b/pandas/io/clipboards.py @@ -0,0 +1,200 @@ +"""io on the clipboard""" + +from __future__ import annotations + +from io import StringIO +from typing import TYPE_CHECKING +import warnings + +from pandas._libs import lib +from pandas.util._decorators import set_module +from pandas.util._exceptions import find_stack_level +from pandas.util._validators import check_dtype_backend + +from pandas.core.dtypes.generic import ABCDataFrame + +from pandas import ( + get_option, + option_context, +) + +if TYPE_CHECKING: + from pandas._typing import DtypeBackend + + +@set_module("pandas") +def read_clipboard( + sep: str = r"\s+", + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + **kwargs, +): # pragma: no cover + r""" + Read text from clipboard and pass to :func:`~pandas.read_csv`. + + Parses clipboard contents similar to how CSV files are parsed + using :func:`~pandas.read_csv`. + + Parameters + ---------- + sep : str, default '\\s+' + A string or regex delimiter. The default of ``'\\s+'`` denotes + one or more whitespace characters. + + dtype_backend : {'numpy_nullable', 'pyarrow'} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + **kwargs + See :func:`~pandas.read_csv` for the full argument list. + + Returns + ------- + DataFrame + A parsed :class:`~pandas.DataFrame` object. + + See Also + -------- + DataFrame.to_clipboard : Copy object to the system clipboard. + read_csv : Read a comma-separated values (csv) file into DataFrame. + read_fwf : Read a table of fixed-width formatted lines into DataFrame. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=["A", "B", "C"]) + >>> df.to_clipboard() # doctest: +SKIP + >>> pd.read_clipboard() # doctest: +SKIP + A B C + 0 1 2 3 + 1 4 5 6 + """ + encoding = kwargs.pop("encoding", "utf-8") + + # only utf-8 is valid for passed value because that's what clipboard + # supports + if encoding is not None and encoding.lower().replace("-", "") != "utf8": + raise NotImplementedError("reading from clipboard only supports utf-8 encoding") + + check_dtype_backend(dtype_backend) + + from pandas.io.clipboard import clipboard_get + from pandas.io.parsers import read_csv + + text = clipboard_get() + + # Try to decode (if needed, as "text" might already be a string here). + try: + text = text.decode(kwargs.get("encoding") or get_option("display.encoding")) + except AttributeError: + pass + + # Excel copies into clipboard with \t separation + # inspect no more then the 10 first lines, if they + # all contain an equal number (>0) of tabs, infer + # that this came from excel and set 'sep' accordingly + lines = text[:10000].split("\n")[:-1][:10] + + # Need to remove leading white space, since read_csv + # accepts: + # a b + # 0 1 2 + # 1 3 4 + + counts = {x.lstrip(" ").count("\t") for x in lines} + if len(lines) > 1 and len(counts) == 1 and counts.pop() != 0: + sep = "\t" + # check the number of leading tabs in the first line + # to account for index columns + index_length = len(lines[0]) - len(lines[0].lstrip(" \t")) + if index_length != 0: + kwargs.setdefault("index_col", list(range(index_length))) + + elif not isinstance(sep, str): + raise ValueError(f"{sep=} must be a string") + + # Regex separator currently only works with python engine. + # Default to python if separator is multi-character (regex) + if len(sep) > 1 and kwargs.get("engine") is None: + kwargs["engine"] = "python" + elif len(sep) > 1 and kwargs.get("engine") == "c": + warnings.warn( + "read_clipboard with regex separator does not work properly with c engine.", + stacklevel=find_stack_level(), + ) + + return read_csv(StringIO(text), sep=sep, dtype_backend=dtype_backend, **kwargs) + + +def to_clipboard( + obj, excel: bool | None = True, sep: str | None = None, **kwargs +) -> None: # pragma: no cover + """ + Attempt to write text representation of object to the system clipboard + The clipboard can be then pasted into Excel for example. + + Parameters + ---------- + obj : the object to write to the clipboard + excel : bool, defaults to True + if True, use the provided separator, writing in a csv + format for allowing easy pasting into excel. + if False, write a string representation of the object + to the clipboard + sep : optional, defaults to tab + other keywords are passed to to_csv + + Notes + ----- + Requirements for your platform + - Linux: xclip, or xsel (with PyQt4 modules) + - Windows: + - OS X: + """ + encoding = kwargs.pop("encoding", "utf-8") + + # testing if an invalid encoding is passed to clipboard + if encoding is not None and encoding.lower().replace("-", "") != "utf8": + raise ValueError("clipboard only supports utf-8 encoding") + + from pandas.io.clipboard import clipboard_set + + if excel is None: + excel = True + + if excel: + try: + if sep is None: + sep = "\t" + buf = StringIO() + + # clipboard_set (pyperclip) expects unicode + obj.to_csv(buf, sep=sep, encoding="utf-8", **kwargs) + text = buf.getvalue() + + clipboard_set(text) + return + except TypeError: + warnings.warn( + "to_clipboard in excel mode requires a single character separator.", + stacklevel=find_stack_level(), + ) + elif sep is not None: + warnings.warn( + "to_clipboard with excel=False ignores the sep argument.", + stacklevel=find_stack_level(), + ) + + if isinstance(obj, ABCDataFrame): + # str(df) has various unhelpful defaults, like truncation + with option_context("display.max_colwidth", None): + objstr = obj.to_string(**kwargs) + else: + objstr = str(obj) + clipboard_set(objstr) diff --git a/pandas/io/common.py b/pandas/io/common.py new file mode 100644 index 0000000000000000000000000000000000000000..04f4f9f604786e8e8b3eab852c51e212cf50b9c8 --- /dev/null +++ b/pandas/io/common.py @@ -0,0 +1,1327 @@ +"""Common I/O API utilities""" + +from __future__ import annotations + +from abc import ( + ABC, + abstractmethod, +) +import codecs +from collections import defaultdict +from collections.abc import ( + Hashable, + Mapping, + Sequence, +) +import dataclasses +import functools +import gzip +from io import ( + BufferedIOBase, + BytesIO, + RawIOBase, + StringIO, + TextIOBase, + TextIOWrapper, +) +import mmap +import os +from pathlib import Path +import re +import tarfile +from typing import ( + IO, + TYPE_CHECKING, + Any, + AnyStr, + DefaultDict, + Generic, + Literal, + TypeVar, + cast, + overload, +) +from urllib.parse import ( + urljoin, + urlparse as parse_url, + uses_netloc, + uses_params, + uses_relative, +) +import warnings +import zipfile + +from pandas._typing import ( + BaseBuffer, + ReadCsvBuffer, +) +from pandas.compat._optional import import_optional_dependency +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.common import ( + is_bool, + is_file_like, + is_integer, + is_list_like, +) +from pandas.core.dtypes.generic import ABCMultiIndex + +_VALID_URLS = set(uses_relative + uses_netloc + uses_params) +_VALID_URLS.discard("") +_FSSPEC_URL_PATTERN = re.compile(r"^[A-Za-z][A-Za-z0-9+\-+.]*(::[A-Za-z0-9+\-+.]+)*://") + +BaseBufferT = TypeVar("BaseBufferT", bound=BaseBuffer) + + +if TYPE_CHECKING: + from types import TracebackType + + from pandas._typing import ( + CompressionDict, + CompressionOptions, + FilePath, + ReadBuffer, + StorageOptions, + WriteBuffer, + ) + + from pandas import MultiIndex + + +@dataclasses.dataclass +class IOArgs: + """ + Return value of io/common.py:_get_filepath_or_buffer. + """ + + filepath_or_buffer: str | BaseBuffer + encoding: str + mode: str + compression: CompressionDict + should_close: bool = False + + +@dataclasses.dataclass +class IOHandles(Generic[AnyStr]): + """ + Return value of io/common.py:get_handle + + Can be used as a context manager. + + This is used to easily close created buffers and to handle corner cases when + TextIOWrapper is inserted. + + handle: The file handle to be used. + created_handles: All file handles that are created by get_handle + is_wrapped: Whether a TextIOWrapper needs to be detached. + """ + + # handle might not implement the IO-interface + handle: IO[AnyStr] + compression: CompressionDict + created_handles: list[IO[bytes] | IO[str]] = dataclasses.field(default_factory=list) + is_wrapped: bool = False + + def close(self) -> None: + """ + Close all created buffers. + + Note: If a TextIOWrapper was inserted, it is flushed and detached to + avoid closing the potentially user-created buffer. + """ + if self.is_wrapped: + assert isinstance(self.handle, TextIOWrapper) + self.handle.flush() + self.handle.detach() + self.created_handles.remove(self.handle) + for handle in self.created_handles: + handle.close() + self.created_handles = [] + self.is_wrapped = False + + def __enter__(self) -> IOHandles[AnyStr]: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + +def is_url(url: object) -> bool: + """ + Check to see if a URL has a valid protocol. + + Parameters + ---------- + url : str or unicode + + Returns + ------- + isurl : bool + If `url` has a valid protocol return True otherwise False. + """ + if not isinstance(url, str): + return False + return parse_url(url).scheme in _VALID_URLS + + +@overload +def _expand_user(filepath_or_buffer: str) -> str: ... + + +@overload +def _expand_user(filepath_or_buffer: BaseBufferT) -> BaseBufferT: ... + + +def _expand_user(filepath_or_buffer: str | BaseBufferT) -> str | BaseBufferT: + """ + Return the argument with an initial component of ~ or ~user + replaced by that user's home directory. + + Parameters + ---------- + filepath_or_buffer : object to be converted if possible + + Returns + ------- + expanded_filepath_or_buffer : an expanded filepath or the + input if not expandable + """ + if isinstance(filepath_or_buffer, str): + return os.path.expanduser(filepath_or_buffer) + return filepath_or_buffer + + +def validate_header_arg(header: object) -> None: + if header is None: + return + if is_integer(header): + header = cast(int, header) + if header < 0: + # GH 27779 + raise ValueError( + "Passing negative integer to header is invalid. " + "For no header, use header=None instead" + ) + return + if is_list_like(header, allow_sets=False): + header = cast(Sequence, header) + if not all(map(is_integer, header)): + raise ValueError("header must be integer or list of integers") + if any(i < 0 for i in header): + raise ValueError("cannot specify multi-index header with negative integers") + return + if is_bool(header): + raise TypeError( + "Passing a bool to header is invalid. Use header=None for no header or " + "header=int or list-like of ints to specify " + "the row(s) making up the column names" + ) + # GH 16338 + raise ValueError("header must be integer or list of integers") + + +@overload +def stringify_path( + filepath_or_buffer: FilePath, convert_file_like: bool = ... +) -> str: ... + + +@overload +def stringify_path( + filepath_or_buffer: BaseBufferT, convert_file_like: bool = ... +) -> BaseBufferT: ... + + +def stringify_path( + filepath_or_buffer: FilePath | BaseBufferT, + convert_file_like: bool = False, +) -> str | BaseBufferT: + """ + Attempt to convert a path-like object to a string. + + Parameters + ---------- + filepath_or_buffer : object to be converted + + Returns + ------- + str_filepath_or_buffer : maybe a string version of the object + + Notes + ----- + Objects supporting the fspath protocol are coerced + according to its __fspath__ method. + + Any other object is passed through unchanged, which includes bytes, + strings, buffers, or anything else that's not even path-like. + """ + if not convert_file_like and is_file_like(filepath_or_buffer): + # GH 38125: some fsspec objects implement os.PathLike but have already opened a + # file. This prevents opening the file a second time. infer_compression calls + # this function with convert_file_like=True to infer the compression. + return cast(BaseBufferT, filepath_or_buffer) + + if isinstance(filepath_or_buffer, os.PathLike): + filepath_or_buffer = filepath_or_buffer.__fspath__() + return _expand_user(filepath_or_buffer) + + +def urlopen(*args: Any, **kwargs: Any) -> Any: + """ + Lazy-import wrapper for stdlib urlopen, as that imports a big chunk of + the stdlib. + """ + import urllib.request + + return urllib.request.urlopen(*args, **kwargs) # noqa: TID251 + + +def is_fsspec_url(url: FilePath | BaseBuffer) -> bool: + """ + Returns true if the given URL looks like + something fsspec can handle + """ + return ( + isinstance(url, str) + and bool(_FSSPEC_URL_PATTERN.match(url)) + and not url.startswith(("http://", "https://")) + ) + + +def _get_filepath_or_buffer( + filepath_or_buffer: FilePath | BaseBuffer, + encoding: str = "utf-8", + compression: CompressionOptions | None = None, + mode: str = "r", + storage_options: StorageOptions | None = None, +) -> IOArgs: + """ + If the filepath_or_buffer is a url, translate and return the buffer. + Otherwise passthrough. + + Parameters + ---------- + filepath_or_buffer : a url, filepath (str or pathlib.Path), + or buffer + + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and + 'filepath_or_buffer' is path-like, then detect compression from the + following extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set + to one of {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} + and other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and to + create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + + encoding : the encoding to use to decode bytes, default is 'utf-8' + mode : str, optional + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + + Returns the dataclass IOArgs. + """ + filepath_or_buffer = stringify_path(filepath_or_buffer) + + # handle compression dict + compression_method, compression = get_compression_method(compression) + compression_method = infer_compression(filepath_or_buffer, compression_method) + + # GH21227 internal compression is not used for non-binary handles. + if compression_method and hasattr(filepath_or_buffer, "write") and "b" not in mode: + warnings.warn( + "compression has no effect when passing a non-binary object as input.", + RuntimeWarning, + stacklevel=find_stack_level(), + ) + compression_method = None + + compression = dict(compression, method=compression_method) + + # bz2 and xz do not write the byte order mark for utf-16 and utf-32 + # print a warning when writing such files + if ( + "w" in mode + and compression_method in ["bz2", "xz"] + and encoding in ["utf-16", "utf-32"] + ): + warnings.warn( + f"{compression} will not write the byte order mark for {encoding}", + UnicodeWarning, + stacklevel=find_stack_level(), + ) + + if "a" in mode and compression_method in ["zip", "tar"]: + # GH56778 + warnings.warn( + "zip and tar do not support mode 'a' properly. " + "This combination will result in multiple files with same name " + "being added to the archive.", + RuntimeWarning, + stacklevel=find_stack_level(), + ) + + # Use binary mode when converting path-like objects to file-like objects (fsspec) + # except when text mode is explicitly requested. The original mode is returned if + # fsspec is not used. + fsspec_mode = mode + if "t" not in fsspec_mode and "b" not in fsspec_mode: + fsspec_mode += "b" + + if isinstance(filepath_or_buffer, str) and is_url(filepath_or_buffer): + # TODO: fsspec can also handle HTTP via requests, but leaving this + # unchanged. using fsspec appears to break the ability to infer if the + # server responded with gzipped data + storage_options = storage_options or {} + + # waiting until now for importing to match intended lazy logic of + # urlopen function defined elsewhere in this module + import urllib.request + + # assuming storage_options is to be interpreted as headers + req_info = urllib.request.Request(filepath_or_buffer, headers=storage_options) + with urlopen(req_info) as req: + content_encoding = req.headers.get("Content-Encoding", None) + if content_encoding == "gzip": + # Override compression based on Content-Encoding header + compression = {"method": "gzip"} + reader = BytesIO(req.read()) + return IOArgs( + filepath_or_buffer=reader, + encoding=encoding, + compression=compression, + should_close=True, + mode=fsspec_mode, + ) + + if is_fsspec_url(filepath_or_buffer): + assert isinstance( + filepath_or_buffer, str + ) # just to appease mypy for this branch + # two special-case s3-like protocols; these have special meaning in Hadoop, + # but are equivalent to just "s3" from fsspec's point of view + # cc #11071 + if filepath_or_buffer.startswith("s3a://"): + filepath_or_buffer = filepath_or_buffer.replace("s3a://", "s3://") + if filepath_or_buffer.startswith("s3n://"): + filepath_or_buffer = filepath_or_buffer.replace("s3n://", "s3://") + fsspec = import_optional_dependency("fsspec") + + # If botocore is installed we fallback to reading with anon=True + # to allow reads from public buckets + err_types_to_retry_with_anon: list[Any] = [] + try: + import_optional_dependency("botocore") + from botocore.exceptions import ( + ClientError, + NoCredentialsError, + ) + + err_types_to_retry_with_anon = [ + ClientError, + NoCredentialsError, + PermissionError, + ] + except ImportError: + pass + + try: + file_obj = fsspec.open( + filepath_or_buffer, mode=fsspec_mode, **(storage_options or {}) + ).open() + # GH 34626 Reads from Public Buckets without Credentials needs anon=True + except tuple(err_types_to_retry_with_anon): + if storage_options is None: + storage_options = {"anon": True} + else: + # don't mutate user input. + storage_options = dict(storage_options) + storage_options["anon"] = True + file_obj = fsspec.open( + filepath_or_buffer, mode=fsspec_mode, **(storage_options or {}) + ).open() + + return IOArgs( + filepath_or_buffer=file_obj, + encoding=encoding, + compression=compression, + should_close=True, + mode=fsspec_mode, + ) + elif storage_options: + raise ValueError( + "storage_options passed with file object or non-fsspec file path" + ) + + if isinstance(filepath_or_buffer, (str, bytes, mmap.mmap)): + return IOArgs( + filepath_or_buffer=_expand_user(filepath_or_buffer), + encoding=encoding, + compression=compression, + should_close=False, + mode=mode, + ) + + # is_file_like requires (read | write) & __iter__ but __iter__ is only + # needed for read_csv(engine=python) + if not ( + hasattr(filepath_or_buffer, "read") or hasattr(filepath_or_buffer, "write") + ): + msg = f"Invalid file path or buffer object type: {type(filepath_or_buffer)}" + raise ValueError(msg) + + return IOArgs( + filepath_or_buffer=filepath_or_buffer, + encoding=encoding, + compression=compression, + should_close=False, + mode=mode, + ) + + +def file_path_to_url(path: str) -> str: + """ + converts an absolute native path to a FILE URL. + + Parameters + ---------- + path : a path in native format + + Returns + ------- + a valid FILE URL + """ + # lazify expensive import (~30ms) + from urllib.request import pathname2url + + return urljoin("file:", pathname2url(path)) + + +extension_to_compression = { + ".tar": "tar", + ".tar.gz": "tar", + ".tar.bz2": "tar", + ".tar.xz": "tar", + ".gz": "gzip", + ".bz2": "bz2", + ".zip": "zip", + ".xz": "xz", + ".zst": "zstd", +} +_supported_compressions = set(extension_to_compression.values()) + + +def get_compression_method( + compression: CompressionOptions, +) -> tuple[str | None, CompressionDict]: + """ + Simplifies a compression argument to a compression method string and + a mapping containing additional arguments. + + Parameters + ---------- + compression : str or mapping + If string, specifies the compression method. If mapping, value at key + 'method' specifies compression method. + + Returns + ------- + tuple of ({compression method}, Optional[str] + {compression arguments}, Dict[str, Any]) + + Raises + ------ + ValueError on mapping missing 'method' key + """ + compression_method: str | None + if isinstance(compression, Mapping): + compression_args = dict(compression) + try: + compression_method = compression_args.pop("method") + except KeyError as err: + raise ValueError("If mapping, compression must have key 'method'") from err + else: + compression_args = {} + compression_method = compression + return compression_method, compression_args + + +def infer_compression( + filepath_or_buffer: FilePath | BaseBuffer, compression: str | None +) -> str | None: + """ + Get the compression method for filepath_or_buffer. If compression='infer', + the inferred compression method is returned. Otherwise, the input + compression method is returned unchanged, unless it's invalid, in which + case an error is raised. + + Parameters + ---------- + filepath_or_buffer : str or file handle + File path or object. + + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and + 'filepath_or_buffer' is path-like, then detect compression from the + following extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set + to one of {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} + and other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and to + create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + + Returns + ------- + string or None + + Raises + ------ + ValueError on invalid compression specified. + """ + if compression is None: + return None + + # Infer compression + if compression == "infer": + # Convert all path types (e.g. pathlib.Path) to strings + if isinstance(filepath_or_buffer, str) and "::" in filepath_or_buffer: + # chained URLs contain :: + filepath_or_buffer = filepath_or_buffer.split("::")[0] + filepath_or_buffer = stringify_path(filepath_or_buffer, convert_file_like=True) + if not isinstance(filepath_or_buffer, str): + # Cannot infer compression of a buffer, assume no compression + return None + + # Infer compression from the filename/URL extension + for extension, compression in extension_to_compression.items(): + if filepath_or_buffer.lower().endswith(extension): + return compression + return None + + # Compression has been specified. Check that it's valid + if compression in _supported_compressions: + return compression + + valid = ["infer", None, *sorted(_supported_compressions)] + msg = ( + f"Unrecognized compression type: {compression}\n" + f"Valid compression types are {valid}" + ) + raise ValueError(msg) + + +def check_parent_directory(path: Path | str) -> None: + """ + Check if parent directory of a file exists, raise OSError if it does not + + Parameters + ---------- + path: Path or str + Path to check parent directory of + """ + parent = Path(path).parent + if not parent.is_dir(): + raise OSError(rf"Cannot save file into a non-existent directory: '{parent}'") + + +@overload +def get_handle( + path_or_buf: FilePath | BaseBuffer, + mode: str, + *, + encoding: str | None = ..., + compression: CompressionOptions = ..., + memory_map: bool = ..., + is_text: Literal[False], + errors: str | None = ..., + storage_options: StorageOptions = ..., +) -> IOHandles[bytes]: ... + + +@overload +def get_handle( + path_or_buf: FilePath | BaseBuffer, + mode: str, + *, + encoding: str | None = ..., + compression: CompressionOptions = ..., + memory_map: bool = ..., + is_text: Literal[True] = ..., + errors: str | None = ..., + storage_options: StorageOptions = ..., +) -> IOHandles[str]: ... + + +@overload +def get_handle( + path_or_buf: FilePath | BaseBuffer, + mode: str, + *, + encoding: str | None = ..., + compression: CompressionOptions = ..., + memory_map: bool = ..., + is_text: bool = ..., + errors: str | None = ..., + storage_options: StorageOptions = ..., +) -> IOHandles[str] | IOHandles[bytes]: ... + + +def get_handle( + path_or_buf: FilePath | BaseBuffer, + mode: str, + *, + encoding: str | None = None, + compression: CompressionOptions | None = None, + memory_map: bool = False, + is_text: bool = True, + errors: str | None = None, + storage_options: StorageOptions | None = None, +) -> IOHandles[str] | IOHandles[bytes]: + """ + Get file handle for given path/buffer and mode. + + Parameters + ---------- + path_or_buf : str or file handle + File path or object. + mode : str + Mode to open path_or_buf with. + encoding : str or None + Encoding to use. + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and 'path_or_buf' + is path-like, then detect compression from the following extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set + to one of {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} + and other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and to + create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + + May be a dict with key 'method' as compression mode + and other keys as compression options if compression + mode is 'zip'. + + Passing compression options as keys in dict is + supported for compression modes 'gzip', 'bz2', 'zstd' and 'zip'. + + memory_map : bool, default False + See parsers._parser_params for more information. Only used by read_csv. + is_text : bool, default True + Whether the type of the content passed to the file/buffer is string or + bytes. This is not the same as `"b" not in mode`. If a string content is + passed to a binary file/buffer, a wrapper is inserted. + errors : str, default 'strict' + Specifies how encoding and decoding errors are to be handled. + See the errors argument for :func:`open` for a full list + of options. + storage_options: StorageOptions = None + Passed to _get_filepath_or_buffer + + Returns the dataclass IOHandles + """ + # Windows does not default to utf-8. Set to utf-8 for a consistent behavior + encoding = encoding or "utf-8" + + errors = errors or "strict" + + # read_csv does not know whether the buffer is opened in binary/text mode + if _is_binary_mode(path_or_buf, mode) and "b" not in mode: + mode += "b" + + # validate encoding and errors + codecs.lookup(encoding) + if isinstance(errors, str): + codecs.lookup_error(errors) + + # open URLs + ioargs = _get_filepath_or_buffer( + path_or_buf, + encoding=encoding, + compression=compression, + mode=mode, + storage_options=storage_options, + ) + + handle = ioargs.filepath_or_buffer + handles: list[BaseBuffer] + + # memory mapping needs to be the first step + # only used for read_csv + handle, memory_map, handles = _maybe_memory_map(handle, memory_map) + + is_path = isinstance(handle, str) + compression_args = dict(ioargs.compression) + compression = compression_args.pop("method") + + # Only for write methods + if "r" not in mode and is_path: + check_parent_directory(str(handle)) + + if compression: + if compression != "zstd": + # compression libraries do not like an explicit text-mode + ioargs.mode = ioargs.mode.replace("t", "") + elif compression == "zstd" and "b" not in ioargs.mode: + # python-zstandard defaults to text mode, but we always expect + # compression libraries to use binary mode. + ioargs.mode += "b" + + # GZ Compression + if compression == "gzip": + if isinstance(handle, str): + # error: Incompatible types in assignment (expression has type + # "GzipFile", variable has type "Union[str, BaseBuffer]") + handle = gzip.GzipFile( # type: ignore[assignment] + filename=handle, + mode=ioargs.mode, + **compression_args, + ) + else: + handle = gzip.GzipFile( + # No overload variant of "GzipFile" matches argument types + # "Union[str, BaseBuffer]", "str", "Dict[str, Any]" + fileobj=handle, # type: ignore[call-overload] + mode=ioargs.mode, + **compression_args, + ) + + # BZ Compression + elif compression == "bz2": + import bz2 + + # Overload of "BZ2File" to handle pickle protocol 5 + # "Union[str, BaseBuffer]", "str", "Dict[str, Any]" + handle = bz2.BZ2File( # type: ignore[call-overload] + handle, + mode=ioargs.mode, + **compression_args, + ) + + # ZIP Compression + elif compression == "zip": + # error: Argument 1 to "_BytesZipFile" has incompatible type + # "Union[str, BaseBuffer]"; expected "Union[Union[str, PathLike[str]], + # ReadBuffer[bytes], WriteBuffer[bytes]]" + handle = _BytesZipFile( + handle, # type: ignore[arg-type] + ioargs.mode, + **compression_args, + ) + if handle.buffer.mode == "r": + handles.append(handle) + zip_names = handle.buffer.namelist() + if len(zip_names) == 1: + handle = handle.buffer.open(zip_names.pop()) + elif not zip_names: + raise ValueError(f"Zero files found in ZIP file {path_or_buf}") + else: + raise ValueError( + "Multiple files found in ZIP file. " + f"Only one file per ZIP: {zip_names}" + ) + + # TAR Encoding + elif compression == "tar": + compression_args.setdefault("mode", ioargs.mode) + if isinstance(handle, str): + handle = _BytesTarFile(name=handle, **compression_args) + else: + # error: Argument "fileobj" to "_BytesTarFile" has incompatible + # type "BaseBuffer"; expected "Union[ReadBuffer[bytes], + # WriteBuffer[bytes], None]" + handle = _BytesTarFile( + fileobj=handle, # type: ignore[arg-type] + **compression_args, + ) + assert isinstance(handle, _BytesTarFile) + if "r" in handle.buffer.mode: + handles.append(handle) + files = handle.buffer.getnames() + if len(files) == 1: + file = handle.buffer.extractfile(files[0]) + assert file is not None + handle = file + elif not files: + raise ValueError(f"Zero files found in TAR archive {path_or_buf}") + else: + raise ValueError( + "Multiple files found in TAR archive. " + f"Only one file per TAR archive: {files}" + ) + + # XZ Compression + elif compression == "xz": + # error: Argument 1 to "LZMAFile" has incompatible type "Union[str, + # BaseBuffer]"; expected "Optional[Union[Union[str, bytes, PathLike[str], + # PathLike[bytes]], IO[bytes]], None]" + import lzma + + handle = lzma.LZMAFile( + handle, # type: ignore[arg-type] + ioargs.mode, + **compression_args, + ) + + # Zstd Compression + elif compression == "zstd": + zstd = import_optional_dependency("zstandard") + if "r" in ioargs.mode: + open_args = {"dctx": zstd.ZstdDecompressor(**compression_args)} + else: + open_args = {"cctx": zstd.ZstdCompressor(**compression_args)} + handle = zstd.open( + handle, + mode=ioargs.mode, + **open_args, + ) + + # Unrecognized Compression + else: + msg = f"Unrecognized compression type: {compression}" + raise ValueError(msg) + + assert not isinstance(handle, str) + handles.append(handle) + + elif isinstance(handle, str): + # Check whether the filename is to be opened in binary mode. + # Binary mode does not support 'encoding' and 'newline'. + if ioargs.encoding and "b" not in ioargs.mode: + # Encoding + handle = open( + handle, + ioargs.mode, + encoding=ioargs.encoding, + errors=errors, + newline="", + ) + else: + # Binary mode + handle = open(handle, ioargs.mode) + handles.append(handle) + + # Convert BytesIO or file objects passed with an encoding + is_wrapped = False + if not is_text and ioargs.mode == "rb" and isinstance(handle, TextIOBase): + # not added to handles as it does not open/buffer resources + handle = _BytesIOWrapper( + handle, + encoding=ioargs.encoding, + ) + elif is_text and ( + compression or memory_map or _is_binary_mode(handle, ioargs.mode) + ): + if ( + not hasattr(handle, "readable") + or not hasattr(handle, "writable") + or not hasattr(handle, "seekable") + ): + handle = _IOWrapper(handle) + # error: Value of type variable "_BufferT_co" of "TextIOWrapper" cannot + # be "_IOWrapper | BaseBuffer" [type-var] + handle = TextIOWrapper( + handle, # type: ignore[type-var] + encoding=ioargs.encoding, + errors=errors, + newline="", + ) + handles.append(handle) + # only marked as wrapped when the caller provided a handle + is_wrapped = not ( + isinstance(ioargs.filepath_or_buffer, str) or ioargs.should_close + ) + + if "r" in ioargs.mode and not hasattr(handle, "read"): + raise TypeError( + "Expected file path name or file-like object, " + f"got {type(ioargs.filepath_or_buffer)} type" + ) + + handles.reverse() # close the most recently added buffer first + if ioargs.should_close: + assert not isinstance(ioargs.filepath_or_buffer, str) + handles.append(ioargs.filepath_or_buffer) + + return IOHandles( + # error: Argument "handle" to "IOHandles" has incompatible type + # "Union[TextIOWrapper, GzipFile, BaseBuffer, typing.IO[bytes], + # typing.IO[Any]]"; expected "pandas._typing.IO[Any]" + handle=handle, # type: ignore[arg-type] + # error: Argument "created_handles" to "IOHandles" has incompatible type + # "List[BaseBuffer]"; expected "List[Union[IO[bytes], IO[str]]]" + created_handles=handles, # type: ignore[arg-type] + is_wrapped=is_wrapped, + compression=ioargs.compression, + ) + + +class _BufferedWriter(BytesIO, ABC): + """ + Some objects do not support multiple .write() calls (TarFile and ZipFile). + This wrapper writes to the underlying buffer on close. + """ + + buffer = BytesIO() + + @abstractmethod + def write_to_buffer(self) -> None: ... + + def close(self) -> None: + if self.closed: + # already closed + return + if self.getbuffer().nbytes: + # write to buffer + self.seek(0) + with self.buffer: + self.write_to_buffer() + else: + self.buffer.close() + super().close() + + +class _BytesTarFile(_BufferedWriter): + def __init__( + self, + name: str | None = None, + mode: Literal["r", "a", "w", "x"] = "r", + fileobj: ReadBuffer[bytes] | WriteBuffer[bytes] | None = None, + archive_name: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__() + self.archive_name = archive_name + self.name = name + # error: No overload variant of "open" of "TarFile" matches argument + # types "str | None", "str", "ReadBuffer[bytes] | WriteBuffer[bytes] | None", + # "dict[str, Any]" + # error: Incompatible types in assignment (expression has type "TarFile", + # base class "_BufferedWriter" defined the type as "BytesIO") + self.buffer: tarfile.TarFile = tarfile.TarFile.open( # type: ignore[call-overload, assignment] + name=name, + mode=self.extend_mode(mode), + fileobj=fileobj, + **kwargs, + ) + + def extend_mode(self, mode: str) -> str: + mode = mode.replace("b", "") + if mode != "w": + return mode + if self.name is not None: + suffix = Path(self.name).suffix + if suffix in (".gz", ".xz", ".bz2"): + mode = f"{mode}:{suffix[1:]}" + return mode + + def infer_filename(self) -> str | None: + """ + If an explicit archive_name is not given, we still want the file inside the zip + file not to be named something.tar, because that causes confusion (GH39465). + """ + if self.name is None: + return None + + filename = Path(self.name) + if filename.suffix == ".tar": + return filename.with_suffix("").name + elif filename.suffix in (".tar.gz", ".tar.bz2", ".tar.xz"): + return filename.with_suffix("").with_suffix("").name + return filename.name + + def write_to_buffer(self) -> None: + # TarFile needs a non-empty string + archive_name = self.archive_name or self.infer_filename() or "tar" + tarinfo = tarfile.TarInfo(name=archive_name) + tarinfo.size = len(self.getvalue()) + self.buffer.addfile(tarinfo, self) + + +class _BytesZipFile(_BufferedWriter): + def __init__( + self, + file: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes], + mode: str, + archive_name: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__() + mode = mode.replace("b", "") + self.archive_name = archive_name + + kwargs.setdefault("compression", zipfile.ZIP_DEFLATED) + # error: No overload variant of "ZipFile" matches argument types + # "str | PathLike[str] | ReadBuffer[bytes] | WriteBuffer[bytes]", + # "str", "dict[str, Any]" + # error: Incompatible types in assignment (expression has type "ZipFile", + # base class "_BufferedWriter" defined the type as "BytesIO") + self.buffer: zipfile.ZipFile = zipfile.ZipFile( # type: ignore[call-overload, assignment] + file, mode, **kwargs + ) + + def infer_filename(self) -> str | None: + """ + If an explicit archive_name is not given, we still want the file inside the zip + file not to be named something.zip, because that causes confusion (GH39465). + """ + if isinstance(self.buffer.filename, (os.PathLike, str)): + filename = Path(self.buffer.filename) + if filename.suffix == ".zip": + return filename.with_suffix("").name + return filename.name + return None + + def write_to_buffer(self) -> None: + # ZipFile needs a non-empty string + archive_name = self.archive_name or self.infer_filename() or "zip" + self.buffer.writestr(archive_name, self.getvalue()) + + +class _IOWrapper: + # TextIOWrapper is overly strict: it request that the buffer has seekable, readable, + # and writable. If we have a read-only buffer, we shouldn't need writable and vice + # versa. Some buffers, are seek/read/writ-able but they do not have the "-able" + # methods, e.g., tempfile.SpooledTemporaryFile. + # If a buffer does not have the above "-able" methods, we simple assume they are + # seek/read/writ-able. + def __init__(self, buffer: BaseBuffer) -> None: + self.buffer = buffer + + def __getattr__(self, name: str) -> Any: + return getattr(self.buffer, name) + + def readable(self) -> bool: + if hasattr(self.buffer, "readable"): + return self.buffer.readable() + return True + + def seekable(self) -> bool: + if hasattr(self.buffer, "seekable"): + return self.buffer.seekable() + return True + + def writable(self) -> bool: + if hasattr(self.buffer, "writable"): + return self.buffer.writable() + return True + + +class _BytesIOWrapper: + # Wrapper that wraps a StringIO buffer and reads bytes from it + # Created for compat with pyarrow read_csv + def __init__(self, buffer: StringIO | TextIOBase, encoding: str = "utf-8") -> None: + self.buffer = buffer + self.encoding = encoding + # Because a character can be represented by more than 1 byte, + # it is possible that reading will produce more bytes than n + # We store the extra bytes in this overflow variable, and append the + # overflow to the front of the bytestring the next time reading is performed + self.overflow = b"" + + def __getattr__(self, attr: str) -> Any: + return getattr(self.buffer, attr) + + def read(self, n: int | None = -1) -> bytes: + assert self.buffer is not None + bytestring = self.buffer.read(n).encode(self.encoding) + # When n=-1/n greater than remaining bytes: Read entire file/rest of file + combined_bytestring = self.overflow + bytestring + if n is None or n < 0 or n >= len(combined_bytestring): + self.overflow = b"" + return combined_bytestring + else: + to_return = combined_bytestring[:n] + self.overflow = combined_bytestring[n:] + return to_return + + +def _maybe_memory_map( + handle: str | BaseBuffer, memory_map: bool +) -> tuple[str | BaseBuffer, bool, list[BaseBuffer]]: + """Try to memory map file/buffer.""" + handles: list[BaseBuffer] = [] + memory_map &= hasattr(handle, "fileno") or isinstance(handle, str) + if not memory_map: + return handle, memory_map, handles + + # mmap used by only read_csv + handle = cast(ReadCsvBuffer, handle) + + # need to open the file first + if isinstance(handle, str): + handle = open(handle, "rb") + handles.append(handle) + + try: + # open mmap and adds *-able + # error: Argument 1 to "_IOWrapper" has incompatible type "mmap"; + # expected "BaseBuffer" + wrapped = _IOWrapper( + mmap.mmap( + handle.fileno(), + 0, + access=mmap.ACCESS_READ, # type: ignore[arg-type] + ) + ) + finally: + for handle in reversed(handles): + # error: "BaseBuffer" has no attribute "close" + handle.close() # type: ignore[attr-defined] + + return wrapped, memory_map, [wrapped] + + +def file_exists(filepath_or_buffer: FilePath | BaseBuffer) -> bool: + """Test whether file exists.""" + exists = False + filepath_or_buffer = stringify_path(filepath_or_buffer) + if not isinstance(filepath_or_buffer, str): + return exists + try: + exists = os.path.exists(filepath_or_buffer) + # gh-5874: if the filepath is too long will raise here + except (TypeError, ValueError): + pass + return exists + + +def _is_binary_mode(handle: FilePath | BaseBuffer, mode: str) -> bool: + """Whether the handle is opened in binary mode""" + # specified by user + if "t" in mode or "b" in mode: + return "b" in mode + + # exceptions + text_classes = ( + # classes that expect string but have 'b' in mode + codecs.StreamWriter, + codecs.StreamReader, + codecs.StreamReaderWriter, + ) + if issubclass(type(handle), text_classes): + return False + + return isinstance(handle, _get_binary_io_classes()) or "b" in getattr( + handle, "mode", mode + ) + + +@functools.lru_cache +def _get_binary_io_classes() -> tuple[type, ...]: + """IO classes that that expect bytes""" + binary_classes: tuple[type, ...] = (BufferedIOBase, RawIOBase) + + # python-zstandard doesn't use any of the builtin base classes; instead we + # have to use the `zstd.ZstdDecompressionReader` class for isinstance checks. + # Unfortunately `zstd.ZstdDecompressionReader` isn't exposed by python-zstandard + # so we have to get it from a `zstd.ZstdDecompressor` instance. + # See also https://github.com/indygreg/python-zstandard/pull/165. + zstd = import_optional_dependency("zstandard", errors="ignore") + if zstd is not None: + with zstd.ZstdDecompressor().stream_reader(b"") as reader: + binary_classes += (type(reader),) + + return binary_classes + + +def is_potential_multi_index( + columns: Sequence[Hashable] | MultiIndex, + index_col: bool | Sequence[int] | None = None, +) -> bool: + """ + Check whether or not the `columns` parameter + could be converted into a MultiIndex. + + Parameters + ---------- + columns : array-like + Object which may or may not be convertible into a MultiIndex + index_col : None, bool or list, optional + Column or columns to use as the (possibly hierarchical) index + + Returns + ------- + bool : Whether or not columns could become a MultiIndex + """ + if index_col is None or isinstance(index_col, bool): + index_columns = set() + else: + index_columns = set(index_col) + + return bool( + len(columns) + and not isinstance(columns, ABCMultiIndex) + and all(isinstance(c, tuple) for c in columns if c not in index_columns) + ) + + +def dedup_names( + names: Sequence[Hashable], is_potential_multiindex: bool +) -> Sequence[Hashable]: + """ + Rename column names if duplicates exist. + + Currently the renaming is done by appending a period and an autonumeric, + but a custom pattern may be supported in the future. + + Examples + -------- + >>> dedup_names(["x", "y", "x", "x"], is_potential_multiindex=False) + ['x', 'y', 'x.1', 'x.2'] + """ + names = list(names) # so we can index + counts: DefaultDict[Hashable, int] = defaultdict(int) + + for i, col in enumerate(names): + cur_count = counts[col] + + while cur_count > 0: + counts[col] = cur_count + 1 + + if is_potential_multiindex: + # for mypy + assert isinstance(col, tuple) + col = (*col[:-1], f"{col[-1]}.{cur_count}") + else: + col = f"{col}.{cur_count}" + cur_count = counts[col] + + names[i] = col + counts[col] = cur_count + 1 + + return names diff --git a/pandas/io/feather_format.py b/pandas/io/feather_format.py new file mode 100644 index 0000000000000000000000000000000000000000..750df6143aa56fe429bf4240bc813b5f4df4b5ac --- /dev/null +++ b/pandas/io/feather_format.py @@ -0,0 +1,181 @@ +"""feather-format compat""" + +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, +) +import warnings + +import numpy as np + +from pandas._config import using_string_dtype + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.errors import Pandas4Warning +from pandas.util._decorators import set_module +from pandas.util._validators import check_dtype_backend + +from pandas.core.api import DataFrame +from pandas.core.arrays.string_ import StringDtype + +from pandas.io._util import arrow_table_to_pandas +from pandas.io.common import get_handle + +if TYPE_CHECKING: + from collections.abc import ( + Hashable, + Sequence, + ) + + from pandas._typing import ( + DtypeBackend, + FilePath, + ReadBuffer, + StorageOptions, + WriteBuffer, + ) + + +def to_feather( + df: DataFrame, + path: FilePath | WriteBuffer[bytes], + storage_options: StorageOptions | None = None, + **kwargs: Any, +) -> None: + """ + Write a DataFrame to the binary Feather format. + + Parameters + ---------- + df : DataFrame + path : str, path object, or file-like object + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + **kwargs : + Additional keywords passed to `pyarrow.feather.write_feather`. + + """ + import_optional_dependency("pyarrow") + from pyarrow import feather + + if not isinstance(df, DataFrame): + raise ValueError("feather only support IO with DataFrames") + + with get_handle( + path, "wb", storage_options=storage_options, is_text=False + ) as handles: + feather.write_feather(df, handles.handle, **kwargs) + + +@set_module("pandas") +def read_feather( + path: FilePath | ReadBuffer[bytes], + columns: Sequence[Hashable] | None = None, + use_threads: bool = True, + storage_options: StorageOptions | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, +) -> DataFrame: + """ + Load a feather-format object from the file path. + + Feather is particularly useful for scenarios that require efficient + serialization and deserialization of tabular data. It supports + schema preservation, making it a reliable choice for use cases + such as sharing data between Python and R, or persisting intermediate + results during data processing pipelines. This method provides additional + flexibility with options for selective column reading, thread parallelism, + and choosing the backend for data types. + + Parameters + ---------- + path : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``read()`` function. The string could be a URL. + Valid URL schemes include http, ftp, s3, gs and file. For file URLs, a host is + expected. A local file could be: ``file://localhost/path/to/table.feather``. + columns : sequence, default None + If not provided, all columns are read. + use_threads : bool, default True + Whether to parallelize reading using multiple threads. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + dtype_backend : {{'numpy_nullable', 'pyarrow'}} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame`. + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + Returns + ------- + type of object stored in file + DataFrame object stored in the file. + + See Also + -------- + read_csv : Read a comma-separated values (csv) file into a pandas DataFrame. + read_excel : Read an Excel file into a pandas DataFrame. + read_spss : Read an SPSS file into a pandas DataFrame. + read_orc : Load an ORC object into a pandas DataFrame. + read_sas : Read SAS file into a pandas DataFrame. + + Examples + -------- + >>> df = pd.read_feather("path/to/file.feather") # doctest: +SKIP + """ + import_optional_dependency("pyarrow") + from pyarrow import feather + + # import utils to register the pyarrow extension types + import pandas.core.arrays.arrow.extension_types # pyright: ignore[reportUnusedImport] # noqa: F401 + + check_dtype_backend(dtype_backend) + + with get_handle( + path, "rb", storage_options=storage_options, is_text=False + ) as handles: + if dtype_backend is lib.no_default and not using_string_dtype(): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + "make_block is deprecated", + Pandas4Warning, + ) + + df = feather.read_feather( + handles.handle, columns=columns, use_threads=bool(use_threads) + ) + # Convert any StringDtype columns to object dtype (pyarrow always + # uses string dtype even when the infer_string option is False) + for col, dtype in zip(df.columns, df.dtypes, strict=True): + if isinstance(dtype, StringDtype) and dtype.na_value is np.nan: + df[col] = df[col].astype("object") + return df + + pa_table = feather.read_table( + handles.handle, columns=columns, use_threads=bool(use_threads) + ) + return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) diff --git a/pandas/io/html.py b/pandas/io/html.py new file mode 100644 index 0000000000000000000000000000000000000000..3ceba63dea7690573ae66575748074ff126a12af --- /dev/null +++ b/pandas/io/html.py @@ -0,0 +1,1245 @@ +""" +:mod:`pandas.io.html` is a module containing functionality for dealing with +HTML IO. + +""" + +from __future__ import annotations + +from collections import abc +import errno +import numbers +import os +import re +from re import Pattern +from typing import ( + TYPE_CHECKING, + Literal, + cast, +) + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.errors import ( + AbstractMethodError, + EmptyDataError, +) +from pandas.util._decorators import set_module +from pandas.util._validators import check_dtype_backend + +from pandas.core.dtypes.common import is_list_like + +from pandas import isna +from pandas.core.indexes.base import Index +from pandas.core.indexes.multi import MultiIndex +from pandas.core.series import Series + +from pandas.io.common import ( + get_handle, + is_url, + stringify_path, + validate_header_arg, +) +from pandas.io.formats.printing import pprint_thing +from pandas.io.parsers import TextParser + +if TYPE_CHECKING: + from collections.abc import ( + Iterable, + Sequence, + ) + + from pandas._typing import ( + BaseBuffer, + DtypeBackend, + FilePath, + HTMLFlavors, + ReadBuffer, + StorageOptions, + ) + + from pandas import DataFrame + +############# +# READ HTML # +############# +_RE_WHITESPACE = re.compile(r"[\r\n]+|\s{2,}") + + +def _remove_whitespace(s: str, regex: Pattern = _RE_WHITESPACE) -> str: + """ + Replace extra whitespace inside of a string with a single space. + + Parameters + ---------- + s : str or unicode + The string from which to remove extra whitespace. + regex : re.Pattern + The regular expression to use to remove extra whitespace. + + Returns + ------- + subd : str or unicode + `s` with all extra whitespace replaced with a single space. + """ + return regex.sub(" ", s.strip()) + + +def _get_skiprows(skiprows: int | Sequence[int] | slice | None) -> int | Sequence[int]: + """ + Get an iterator given an integer, slice or container. + + Parameters + ---------- + skiprows : int, slice, container + The iterator to use to skip rows; can also be a slice. + + Raises + ------ + TypeError + * If `skiprows` is not a slice, integer, or Container + + Returns + ------- + it : iterable + A proper iterator to use to skip rows of a DataFrame. + """ + if isinstance(skiprows, slice): + start, step = skiprows.start or 0, skiprows.step or 1 + return list(range(start, skiprows.stop, step)) + elif isinstance(skiprows, numbers.Integral) or is_list_like(skiprows): + return cast("int | Sequence[int]", skiprows) + elif skiprows is None: + return 0 + raise TypeError(f"{type(skiprows).__name__} is not a valid type for skipping rows") + + +def _read( + obj: FilePath | BaseBuffer, + encoding: str | None, + storage_options: StorageOptions | None, +) -> str | bytes: + """ + Try to read from a url, file or string. + + Parameters + ---------- + obj : str, unicode, path object, or file-like object + + Returns + ------- + raw_text : str + """ + try: + with get_handle( + obj, "r", encoding=encoding, storage_options=storage_options + ) as handles: + return handles.handle.read() + except OSError as err: + if not is_url(obj): + raise FileNotFoundError( + f"[Errno {errno.ENOENT}] {os.strerror(errno.ENOENT)}: {obj}" + ) from err + raise + + +class _HtmlFrameParser: + """ + Base class for parsers that parse HTML into DataFrames. + + Parameters + ---------- + io : str or file-like + This can be either a string path, a valid URL using the HTTP, + FTP, or FILE protocols or a file-like object. + + match : str or regex + The text to match in the document. + + attrs : dict + List of HTML element attributes to match. + + encoding : str + Encoding to be used by parser + + displayed_only : bool + Whether or not items with "display:none" should be ignored + + extract_links : {None, "all", "header", "body", "footer"} + Table elements in the specified section(s) with tags will have their + href extracted. + + Attributes + ---------- + io : str or file-like + raw HTML, URL, or file-like object + + match : regex + The text to match in the raw HTML + + attrs : dict-like + A dictionary of valid table attributes to use to search for table + elements. + + encoding : str + Encoding to be used by parser + + displayed_only : bool + Whether or not items with "display:none" should be ignored + + extract_links : {None, "all", "header", "body", "footer"} + Table elements in the specified section(s) with tags will have their + href extracted. + + Notes + ----- + To subclass this class effectively you must override the following methods: + * :func:`_build_doc` + * :func:`_attr_getter` + * :func:`_href_getter` + * :func:`_text_getter` + * :func:`_parse_td` + * :func:`_parse_thead_tr` + * :func:`_parse_tbody_tr` + * :func:`_parse_tfoot_tr` + * :func:`_parse_tables` + * :func:`_equals_tag` + See each method's respective documentation for details on their + functionality. + """ + + def __init__( + self, + io: FilePath | ReadBuffer[str] | ReadBuffer[bytes], + match: str | Pattern, + attrs: dict[str, str] | None, + encoding: str, + displayed_only: bool, + extract_links: Literal["header", "footer", "body", "all"] | None, + storage_options: StorageOptions = None, + ) -> None: + self.io = io + self.match = match + self.attrs = attrs + self.encoding = encoding + self.displayed_only = displayed_only + self.extract_links = extract_links + self.storage_options = storage_options + + def parse_tables(self): + """ + Parse and return all tables from the DOM. + + Returns + ------- + list of parsed (header, body, footer) tuples from tables. + """ + tables = self._parse_tables(self._build_doc(), self.match, self.attrs) + return (self._parse_thead_tbody_tfoot(table) for table in tables) + + def _attr_getter(self, obj, attr): + """ + Return the attribute value of an individual DOM node. + + Parameters + ---------- + obj : node-like + A DOM node. + + attr : str or unicode + The attribute, such as "colspan" + + Returns + ------- + str or unicode + The attribute value. + """ + # Both lxml and BeautifulSoup have the same implementation: + return obj.get(attr) + + def _href_getter(self, obj) -> str | None: + """ + Return an href if the DOM node contains a child or None. + + Parameters + ---------- + obj : node-like + A DOM node. + + Returns + ------- + href : str or unicode + The href from the child of the DOM node. + """ + raise AbstractMethodError(self) + + def _text_getter(self, obj): + """ + Return the text of an individual DOM node. + + Parameters + ---------- + obj : node-like + A DOM node. + + Returns + ------- + text : str or unicode + The text from an individual DOM node. + """ + raise AbstractMethodError(self) + + def _parse_td(self, obj): + """ + Return the td elements from a row element. + + Parameters + ---------- + obj : node-like + A DOM node. + + Returns + ------- + list of node-like + These are the elements of each row, i.e., the columns. + """ + raise AbstractMethodError(self) + + def _parse_thead_tr(self, table): + """ + Return the list of thead row elements from the parsed table element. + + Parameters + ---------- + table : a table element that contains zero or more thead elements. + + Returns + ------- + list of node-like + These are the row elements of a table. + """ + raise AbstractMethodError(self) + + def _parse_tbody_tr(self, table): + """ + Return the list of tbody row elements from the parsed table element. + + HTML5 table bodies consist of either 0 or more elements (which + only contain elements) or 0 or more elements. This method + checks for both structures. + + Parameters + ---------- + table : a table element that contains row elements. + + Returns + ------- + list of node-like + These are the row elements of a table. + """ + raise AbstractMethodError(self) + + def _parse_tfoot_tr(self, table): + """ + Return the list of tfoot row elements from the parsed table element. + + Parameters + ---------- + table : a table element that contains row elements. + + Returns + ------- + list of node-like + These are the row elements of a table. + """ + raise AbstractMethodError(self) + + def _parse_tables(self, document, match, attrs): + """ + Return all tables from the parsed DOM. + + Parameters + ---------- + document : the DOM from which to parse the table element. + + match : str or regular expression + The text to search for in the DOM tree. + + attrs : dict + A dictionary of table attributes that can be used to disambiguate + multiple tables on a page. + + Raises + ------ + ValueError : `match` does not match any text in the document. + + Returns + ------- + list of node-like + HTML
elements to be parsed into raw data. + """ + raise AbstractMethodError(self) + + def _equals_tag(self, obj, tag) -> bool: + """ + Return whether an individual DOM node matches a tag + + Parameters + ---------- + obj : node-like + A DOM node. + + tag : str + Tag name to be checked for equality. + + Returns + ------- + boolean + Whether `obj`'s tag name is `tag` + """ + raise AbstractMethodError(self) + + def _build_doc(self): + """ + Return a tree-like object that can be used to iterate over the DOM. + + Returns + ------- + node-like + The DOM from which to parse the table element. + """ + raise AbstractMethodError(self) + + def _parse_thead_tbody_tfoot(self, table_html): + """ + Given a table, return parsed header, body, and foot. + + Parameters + ---------- + table_html : node-like + + Returns + ------- + tuple of (header, body, footer), each a list of list-of-text rows. + + Notes + ----- + Header and body are lists-of-lists. Top level list is a list of + rows. Each row is a list of str text. + + Logic: Use , , elements to identify + header, body, and footer, otherwise: + - Put all rows into body + - Move rows from top of body to header only if + all elements inside row are . Move the top all- or + while body_rows and row_is_all_th(body_rows[0]): + header_rows.append(body_rows.pop(0)) + + header, rem = self._expand_colspan_rowspan(header_rows, section="header") + body, rem = self._expand_colspan_rowspan( + body_rows, + section="body", + remainder=rem, + overflow=len(footer_rows) > 0, + ) + footer, _ = self._expand_colspan_rowspan( + footer_rows, section="footer", remainder=rem, overflow=False + ) + + return header, body, footer + + def _expand_colspan_rowspan( + self, + rows, + section: Literal["header", "footer", "body"], + remainder: list[tuple[int, str | tuple, int]] | None = None, + overflow: bool = True, + ) -> tuple[list[list], list[tuple[int, str | tuple, int]]]: + """ + Given a list of s, return a list of text rows. + + Parameters + ---------- + rows : list of node-like + List of s + section : the section that the rows belong to (header, body or footer). + remainder: list[tuple[int, str | tuple, int]] | None + Any remainder from the expansion of previous section + overflow: bool + If true, return any partial rows as 'remainder'. If not, use up any + partial rows. True by default. + + Returns + ------- + list of list + Each returned row is a list of str text, or tuple (text, link) + if extract_links is not None. + remainder + Remaining partial rows if any. If overflow is False, an empty list + is returned. + + Notes + ----- + Any cell with ``rowspan`` or ``colspan`` will have its contents copied + to subsequent cells. + """ + all_texts = [] # list of rows, each a list of str + text: str | tuple + remainder = remainder if remainder is not None else [] + + for tr in rows: + texts = [] # the output for this row + next_remainder = [] + + index = 0 + tds = self._parse_td(tr) + for td in tds: + # Append texts from previous rows with rowspan>1 that come + # before this or (see _parse_thead_tr). + return row.xpath("./td|./th") + + def _parse_tables(self, document, match, kwargs): + pattern = match.pattern + + # 1. check all descendants for the given pattern and only search tables + # GH 49929 + xpath_expr = f"//table[.//text()[re:test(., {pattern!r})]]" + + # if any table attributes were given build an xpath expression to + # search for them + if kwargs: + xpath_expr += _build_xpath_expr(kwargs) + + tables = document.xpath(xpath_expr, namespaces=_re_namespace) + + tables = self._handle_hidden_tables(tables, "attrib") + if self.displayed_only: + for table in tables: + # lxml utilizes XPATH 1.0 which does not have regex + # support. As a result, we find all elements with a style + # attribute and iterate them to check for display:none + for elem in table.xpath(".//style"): + elem.drop_tree() + for elem in table.xpath(".//*[@style]"): + if "display:none" in elem.attrib.get("style", "").replace(" ", ""): + elem.drop_tree() + if not tables: + raise ValueError(f"No tables found matching regex {pattern!r}") + return tables + + def _equals_tag(self, obj, tag) -> bool: + return obj.tag == tag + + def _build_doc(self): + """ + Raises + ------ + ValueError + * If a URL that lxml cannot parse is passed. + + Exception + * Any other ``Exception`` thrown. For example, trying to parse a + URL that is syntactically correct on a machine with no internet + connection will fail. + + See Also + -------- + pandas.io.html._HtmlFrameParser._build_doc + """ + from lxml.etree import XMLSyntaxError + from lxml.html import ( + HTMLParser, + parse, + ) + + parser = HTMLParser(recover=True, encoding=self.encoding) + + if is_url(self.io): + with get_handle(self.io, "r", storage_options=self.storage_options) as f: + r = parse(f.handle, parser=parser) + else: + # try to parse the input in the simplest way + try: + r = parse(self.io, parser=parser) + except OSError as err: + raise FileNotFoundError( + f"[Errno {errno.ENOENT}] {os.strerror(errno.ENOENT)}: {self.io}" + ) from err + try: + r = r.getroot() + except AttributeError: + pass + else: + if not hasattr(r, "text_content"): + raise XMLSyntaxError("no text parsed from document", 0, 0, 0) + + for br in r.xpath("*//br"): + br.tail = "\n" + (br.tail or "") + + return r + + def _parse_thead_tr(self, table): + rows = [] + + for thead in table.xpath(".//thead"): + rows.extend(thead.xpath("./tr")) + + # HACK: lxml does not clean up the clearly-erroneous + # . (Missing ). Add + # the and _pretend_ it's a ; _parse_td() will find its + # children as though it's a . + # + # Better solution would be to use html5lib. + elements_at_root = thead.xpath("./td|./th") + if elements_at_root: + rows.append(thead) + + return rows + + def _parse_tbody_tr(self, table): + from_tbody = table.xpath(".//tbody//tr") + from_root = table.xpath("./tr") + # HTML spec: at most one of these lists has content + return from_tbody + from_root + + def _parse_tfoot_tr(self, table): + return table.xpath(".//tfoot//tr") + + +def _expand_elements(body) -> None: + data = [len(elem) for elem in body] + lens = Series(data) + lens_max = lens.max() + not_max = lens[lens != lens_max] + + empty = [""] + for ind, length in not_max.items(): + body[ind] += empty * (lens_max - length) + + +def _data_to_frame(**kwargs): + head, body, foot = kwargs.pop("data") + header = kwargs.pop("header") + kwargs["skiprows"] = _get_skiprows(kwargs["skiprows"]) + if head: + body = head + body + + # Infer header when there is a or top ") + + result1 = flavor_read_html(StringIO(data1))[0] + result2 = flavor_read_html(StringIO(data2))[0] + + tm.assert_frame_equal(result1, expected1) + tm.assert_frame_equal(result2, expected2) + + def test_parse_header_of_non_string_column(self, flavor_read_html): + # GH5048: if header is specified explicitly, an int column should be + # parsed as int while its header is parsed as str + result = flavor_read_html( + StringIO( + """ +
+ - Move rows from bottom of body to footer only if + all elements inside row are + """ + header_rows = self._parse_thead_tr(table_html) + body_rows = self._parse_tbody_tr(table_html) + footer_rows = self._parse_tfoot_tr(table_html) + + def row_is_all_th(row): + return all(self._equals_tag(t, "th") for t in self._parse_td(row)) + + if not header_rows: + # The table has no
rows from + # body_rows to header_rows. (This is a common case because many + # tables in the wild have no
+ while remainder and remainder[0][0] <= index: + prev_i, prev_text, prev_rowspan = remainder.pop(0) + texts.append(prev_text) + if prev_rowspan > 1: + next_remainder.append((prev_i, prev_text, prev_rowspan - 1)) + index += 1 + + # Append the text from this , colspan times + text = _remove_whitespace(self._text_getter(td)) + if self.extract_links in ("all", section): + href = self._href_getter(td) + text = (text, href) + rowspan = int(self._attr_getter(td, "rowspan") or 1) + colspan = int(self._attr_getter(td, "colspan") or 1) + + for _ in range(colspan): + texts.append(text) + if rowspan > 1: + next_remainder.append((index, text, rowspan - 1)) + index += 1 + + # Append texts from previous rows at the final position + for prev_i, prev_text, prev_rowspan in remainder: + texts.append(prev_text) + if prev_rowspan > 1: + next_remainder.append((prev_i, prev_text, prev_rowspan - 1)) + + all_texts.append(texts) + remainder = next_remainder + + if not overflow: + # Append rows that only appear because the previous row had non-1 + # rowspan + while remainder: + next_remainder = [] + texts = [] + for prev_i, prev_text, prev_rowspan in remainder: + texts.append(prev_text) + if prev_rowspan > 1: + next_remainder.append((prev_i, prev_text, prev_rowspan - 1)) + all_texts.append(texts) + remainder = next_remainder + + return all_texts, remainder + + def _handle_hidden_tables(self, tbl_list, attr_name: str): + """ + Return list of tables, potentially removing hidden elements + + Parameters + ---------- + tbl_list : list of node-like + Type of list elements will vary depending upon parser used + attr_name : str + Name of the accessor for retrieving HTML attributes + + Returns + ------- + list of node-like + Return type matches `tbl_list` + """ + if not self.displayed_only: + return tbl_list + + return [ + x + for x in tbl_list + if "display:none" + not in getattr(x, attr_name).get("style", "").replace(" ", "") + ] + + +class _BeautifulSoupHtml5LibFrameParser(_HtmlFrameParser): + """ + HTML to DataFrame parser that uses BeautifulSoup under the hood. + + See Also + -------- + pandas.io.html._HtmlFrameParser + pandas.io.html._LxmlFrameParser + + Notes + ----- + Documentation strings for this class are in the base class + :class:`pandas.io.html._HtmlFrameParser`. + """ + + def _parse_tables(self, document, match, attrs): + element_name = "table" + tables = document.find_all(element_name, attrs=attrs) + if not tables: + raise ValueError("No tables found") + + result = [] + unique_tables = set() + tables = self._handle_hidden_tables(tables, "attrs") + + for table in tables: + if self.displayed_only: + for elem in table.find_all("style"): + elem.decompose() + + for elem in table.find_all(style=re.compile(r"display:\s*none")): + elem.decompose() + + if table not in unique_tables and table.find(string=match) is not None: + result.append(table) + unique_tables.add(table) + if not result: + raise ValueError(f"No tables found matching pattern {match.pattern!r}") + return result + + def _href_getter(self, obj) -> str | None: + a = obj.find("a", href=True) + return None if not a else a["href"] + + def _text_getter(self, obj): + return obj.text + + def _equals_tag(self, obj, tag) -> bool: + return obj.name == tag + + def _parse_td(self, row): + return row.find_all(("td", "th"), recursive=False) + + def _parse_thead_tr(self, table): + return table.select("thead tr") + + def _parse_tbody_tr(self, table): + from_tbody = table.select("tbody tr") + from_root = table.find_all("tr", recursive=False) + # HTML spec: at most one of these lists has content + return from_tbody + from_root + + def _parse_tfoot_tr(self, table): + return table.select("tfoot tr") + + def _setup_build_doc(self): + raw_text = _read(self.io, self.encoding, self.storage_options) + if not raw_text: + raise ValueError(f"No text parsed from document: {self.io}") + return raw_text + + def _build_doc(self): + from bs4 import BeautifulSoup + + bdoc = self._setup_build_doc() + if isinstance(bdoc, bytes) and self.encoding is not None: + udoc = bdoc.decode(self.encoding) + from_encoding = None + else: + udoc = bdoc + from_encoding = self.encoding + + soup = BeautifulSoup(udoc, features="html5lib", from_encoding=from_encoding) + + for br in soup.find_all("br"): + br.replace_with("\n" + br.text) + + return soup + + +def _build_xpath_expr(attrs) -> str: + """ + Build an xpath expression to simulate bs4's ability to pass in kwargs to + search for attributes when using the lxml parser. + + Parameters + ---------- + attrs : dict + A dict of HTML attributes. These are NOT checked for validity. + + Returns + ------- + expr : unicode + An XPath expression that checks for the given HTML attributes. + """ + # give class attribute as class_ because class is a python keyword + if "class_" in attrs: + attrs["class"] = attrs.pop("class_") + + s = " and ".join([f"@{k}={v!r}" for k, v in attrs.items()]) + return f"[{s}]" + + +_re_namespace = {"re": "http://exslt.org/regular-expressions"} + + +class _LxmlFrameParser(_HtmlFrameParser): + """ + HTML to DataFrame parser that uses lxml under the hood. + + Warning + ------- + This parser can only handle HTTP, FTP, and FILE urls. + + See Also + -------- + _HtmlFrameParser + _BeautifulSoupLxmlFrameParser + + Notes + ----- + Documentation strings for this class are in the base class + :class:`_HtmlFrameParser`. + """ + + def _href_getter(self, obj) -> str | None: + href = obj.xpath(".//a/@href") + return None if not href else href[0] + + def _text_getter(self, obj): + return obj.text_content() + + def _parse_td(self, row): + # Look for direct children only: the "row" element here may be a + #
foobar
-only rows + if header is None: + if len(head) == 1: + header = 0 + else: + # ignore all-empty-text rows + header = [i for i, row in enumerate(head) if any(text for text in row)] + + if foot: + body += foot + + # fill out elements of body that are "ragged" + _expand_elements(body) + with TextParser(body, header=header, **kwargs) as tp: + return tp.read() + + +_valid_parsers = { + "lxml": _LxmlFrameParser, + None: _LxmlFrameParser, + "html5lib": _BeautifulSoupHtml5LibFrameParser, + "bs4": _BeautifulSoupHtml5LibFrameParser, +} + + +def _parser_dispatch(flavor: HTMLFlavors | None) -> type[_HtmlFrameParser]: + """ + Choose the parser based on the input flavor. + + Parameters + ---------- + flavor : {{"lxml", "html5lib", "bs4"}} or None + The type of parser to use. This must be a valid backend. + + Returns + ------- + cls : _HtmlFrameParser subclass + The parser class based on the requested input flavor. + + Raises + ------ + ValueError + * If `flavor` is not a valid backend. + ImportError + * If you do not have the requested `flavor` + """ + valid_parsers = list(_valid_parsers.keys()) + if flavor not in valid_parsers: + raise ValueError( + f"{flavor!r} is not a valid flavor, valid flavors are {valid_parsers}" + ) + + if flavor in ("bs4", "html5lib"): + import_optional_dependency("html5lib") + import_optional_dependency("bs4") + else: + import_optional_dependency("lxml.etree") + return _valid_parsers[flavor] + + +def _print_as_set(s) -> str: + arg = ", ".join([pprint_thing(el) for el in s]) + return f"{{{arg}}}" + + +def _validate_flavor(flavor): + if flavor is None: + flavor = "lxml", "bs4" + elif isinstance(flavor, str): + flavor = (flavor,) + elif isinstance(flavor, abc.Iterable): + if not all(isinstance(flav, str) for flav in flavor): + raise TypeError( + f"Object of type {type(flavor).__name__!r} " + f"is not an iterable of strings" + ) + else: + msg = repr(flavor) if isinstance(flavor, str) else str(flavor) + msg += " is not a valid flavor" + raise ValueError(msg) + + flavor = tuple(flavor) + valid_flavors = set(_valid_parsers) + flavor_set = set(flavor) + + if not flavor_set & valid_flavors: + raise ValueError( + f"{_print_as_set(flavor_set)} is not a valid set of flavors, valid " + f"flavors are {_print_as_set(valid_flavors)}" + ) + return flavor + + +def _parse( + flavor, + io, + match, + attrs, + encoding, + displayed_only, + extract_links, + storage_options, + **kwargs, +): + flavor = _validate_flavor(flavor) + compiled_match = re.compile(match) # you can pass a compiled regex here + + retained = None + for flav in flavor: + parser = _parser_dispatch(flav) + p = parser( + io, + compiled_match, + attrs, + encoding, + displayed_only, + extract_links, + storage_options, + ) + + try: + tables = p.parse_tables() + except ValueError as caught: + # if `io` is an io-like object, check if it's seekable + # and try to rewind it before trying the next parser + if hasattr(io, "seekable") and io.seekable(): + io.seek(0) + elif hasattr(io, "seekable") and not io.seekable(): + # if we couldn't rewind it, let the user know + raise ValueError( + f"The flavor {flav} failed to parse your input. " + "Since you passed a non-rewindable file " + "object, we can't rewind it to try " + "another parser. Try read_html() with a different flavor." + ) from caught + + retained = caught + else: + break + else: + assert retained is not None # for mypy + raise retained + + ret = [] + for table in tables: + try: + df = _data_to_frame(data=table, **kwargs) + # Cast MultiIndex header to an Index of tuples when extracting header + # links and replace nan with None (therefore can't use mi.to_flat_index()). + # This maintains consistency of selection (e.g. df.columns.str[1]) + if extract_links in ("all", "header") and isinstance( + df.columns, MultiIndex + ): + df.columns = Index( + ((col[0], None if isna(col[1]) else col[1]) for col in df.columns), + tupleize_cols=False, + ) + + ret.append(df) + except EmptyDataError: # empty table + continue + return ret + + +@set_module("pandas") +def read_html( + io: FilePath | ReadBuffer[str], + *, + match: str | Pattern = ".+", + flavor: HTMLFlavors | Sequence[HTMLFlavors] | None = None, + header: int | Sequence[int] | None = None, + index_col: int | Sequence[int] | None = None, + skiprows: int | Sequence[int] | slice | None = None, + attrs: dict[str, str] | None = None, + parse_dates: bool = False, + thousands: str | None = ",", + encoding: str | None = None, + decimal: str = ".", + converters: dict | None = None, + na_values: Iterable[object] | None = None, + keep_default_na: bool = True, + displayed_only: bool = True, + extract_links: Literal["header", "footer", "body", "all"] | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + storage_options: StorageOptions = None, +) -> list[DataFrame]: + r""" + Read HTML tables into a ``list`` of ``DataFrame`` objects. + + Parameters + ---------- + io : str, path object, or file-like object + String path, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a string ``read()`` function. + The string can represent a URL. Note that + lxml only accepts the http, ftp and file url protocols. If you have a + URL that starts with ``'https'`` you might try removing the ``'s'``. + + match : str or compiled regular expression, optional + The set of tables containing text matching this regex or string will be + returned. Unless the HTML is extremely simple you will probably need to + pass a non-empty string here. Defaults to '.+' (match any non-empty + string). The default value will return all tables contained on a page. + This value is converted to a regular expression so that there is + consistent behavior between Beautiful Soup and lxml. + + flavor : {{"lxml", "html5lib", "bs4"}} or list-like, optional + The parsing engine (or list of parsing engines) to use. 'bs4' and + 'html5lib' are synonymous with each other, they are both there for + backwards compatibility. The default of ``None`` tries to use ``lxml`` + to parse and if that fails it falls back on ``bs4`` + ``html5lib``. + + header : int or list-like, optional + The row (or list of rows for a :class:`~pandas.MultiIndex`) to use to + make the columns headers. + + index_col : int or list-like, optional + The column (or list of columns) to use to create the index. + + skiprows : int, list-like or slice, optional + Number of rows to skip after parsing the column integer. 0-based. If a + sequence of integers or a slice is given, will skip the rows indexed by + that sequence. Note that a single element sequence means 'skip the nth + row' whereas an integer means 'skip n rows'. + + attrs : dict, optional + This is a dictionary of attributes that you can pass to use to identify + the table in the HTML. These are not checked for validity before being + passed to lxml or Beautiful Soup. However, these attributes must be + valid HTML table attributes to work correctly. For example, :: + + attrs = {{"id": "table"}} + + is a valid attribute dictionary because the 'id' HTML tag attribute is + a valid HTML attribute for *any* HTML tag as per `this document + `__. :: + + attrs = {{"asdf": "table"}} + + is *not* a valid attribute dictionary because 'asdf' is not a valid + HTML attribute even if it is a valid XML attribute. Valid HTML 4.01 + table attributes can be found `here + `__. A + working draft of the HTML 5 spec can be found `here + `__. It contains the + latest information on table attributes for the modern web. + + parse_dates : bool, optional + See :func:`~read_csv` for more details. + + thousands : str, optional + Separator to use to parse thousands. Defaults to ``','``. + + encoding : str, optional + The encoding used to decode the web page. Defaults to ``None``.``None`` + preserves the previous encoding behavior, which depends on the + underlying parser library (e.g., the parser library will try to use + the encoding provided by the document). + + decimal : str, default '.' + Character to recognize as decimal point (e.g. use ',' for European + data). + + converters : dict, default None + Dict of functions for converting values in certain columns. Keys can + either be integers or column labels, values are functions that take one + input argument, the cell (not column) content, and return the + transformed content. + + na_values : iterable, default None + Custom NA values. + + keep_default_na : bool, default True + If na_values are specified and keep_default_na is False the default NaN + values are overridden, otherwise they're appended to. + + displayed_only : bool, default True + Whether elements with "display: none" should be parsed. + + extract_links : {{None, "all", "header", "body", "footer"}} + Table elements in the specified section(s) with tags will have their + href extracted. + + dtype_backend : {{'numpy_nullable', 'pyarrow'}} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + .. versionadded:: 2.1.0 + + Returns + ------- + dfs + A list of DataFrames. + + See Also + -------- + read_csv : Read a comma-separated values (csv) file into DataFrame. + + Notes + ----- + Before using this function you should read the :ref:`gotchas about the + HTML parsing libraries `. + + Expect to do some cleanup after you call this function. For example, you + might need to manually assign column names if the column names are + converted to NaN when you pass the `header=0` argument. We try to assume as + little as possible about the structure of the table and push the + idiosyncrasies of the HTML contained in the table to the user. + + This function searches for ```` elements and only for ```` + and ```` or ```` argument, it is used to construct + the header, otherwise the function attempts to find the header within + the body (by putting rows with only `` within on malformed HTML. + """ + result = flavor_read_html( + StringIO( + """
`` rows and ```` elements within each ``
`` + element in the table. ```` stands for "table data". This function + attempts to properly handle ``colspan`` and ``rowspan`` attributes. + If the function has a ``
`` elements into the header). + + Similar to :func:`~read_csv` the `header` argument is applied + **after** `skiprows` is applied. + + This function will *always* return a list of :class:`DataFrame` *or* + it will fail, i.e., it will *not* return an empty list, save for some + rare cases. + It might return an empty list in case of inputs with single row and + ```` containing only whitespaces. + + Examples + -------- + See the :ref:`read_html documentation in the IO section of the docs + ` for some examples of reading in HTML tables. + """ + # Type check here. We don't want to parse only to fail because of an + # invalid value of an integer skiprows. + if isinstance(skiprows, numbers.Integral) and skiprows < 0: + raise ValueError( + "cannot skip rows starting from the end of the " + "data (you passed a negative value)" + ) + if extract_links not in [None, "header", "footer", "body", "all"]: + raise ValueError( + "`extract_links` must be one of " + '{None, "header", "footer", "body", "all"}, got ' + f'"{extract_links}"' + ) + + validate_header_arg(header) + check_dtype_backend(dtype_backend) + + io = stringify_path(io) + + return _parse( + flavor=flavor, + io=io, + match=match, + header=header, + index_col=index_col, + skiprows=skiprows, + parse_dates=parse_dates, + thousands=thousands, + attrs=attrs, + encoding=encoding, + decimal=decimal, + converters=converters, + na_values=na_values, + keep_default_na=keep_default_na, + displayed_only=displayed_only, + extract_links=extract_links, + dtype_backend=dtype_backend, + storage_options=storage_options, + ) diff --git a/pandas/io/iceberg.py b/pandas/io/iceberg.py new file mode 100644 index 0000000000000000000000000000000000000000..f4361b000524e0701fd97c2f3632b422f29265bc --- /dev/null +++ b/pandas/io/iceberg.py @@ -0,0 +1,155 @@ +from typing import ( + Any, +) + +from pandas.compat._optional import import_optional_dependency +from pandas.util._decorators import set_module + +from pandas import DataFrame + + +@set_module("pandas") +def read_iceberg( + table_identifier: str, + catalog_name: str | None = None, + *, + catalog_properties: dict[str, Any] | None = None, + columns: list[str] | None = None, + row_filter: str | None = None, + case_sensitive: bool = True, + snapshot_id: int | None = None, + limit: int | None = None, + scan_properties: dict[str, Any] | None = None, +) -> DataFrame: + """ + Read an Apache Iceberg table into a pandas DataFrame. + + .. versionadded:: 3.0.0 + + .. warning:: + + read_iceberg is experimental and may change without warning. + + Parameters + ---------- + table_identifier : str + Table identifier. + catalog_name : str, optional + The name of the catalog. + catalog_properties : dict of {str: str}, optional + The properties that are used next to the catalog configuration. + columns : list of str, optional + A list of strings representing the column names to return in the output + dataframe. + row_filter : str, optional + A string that describes the desired rows. + case_sensitive : bool, default True + If True column matching is case sensitive. + snapshot_id : int, optional + Snapshot ID to time travel to. By default the table will be scanned as of the + current snapshot ID. + limit : int, optional + An integer representing the number of rows to return in the scan result. + By default all matching rows will be fetched. + scan_properties : dict of {str: obj}, optional + Additional Table properties as a dictionary of string key value pairs to use + for this scan. + + Returns + ------- + DataFrame + DataFrame based on the Iceberg table. + + See Also + -------- + read_parquet : Read a Parquet file. + + Examples + -------- + >>> df = pd.read_iceberg( + ... table_identifier="my_table", + ... catalog_name="my_catalog", + ... catalog_properties={"s3.secret-access-key": "my-secret"}, + ... row_filter="trip_distance >= 10.0", + ... columns=["VendorID", "tpep_pickup_datetime"], + ... ) # doctest: +SKIP + """ + pyiceberg_catalog = import_optional_dependency("pyiceberg.catalog") + pyiceberg_expressions = import_optional_dependency("pyiceberg.expressions") + if catalog_properties is None: + catalog_properties = {} + catalog = pyiceberg_catalog.load_catalog(catalog_name, **catalog_properties) + table = catalog.load_table(table_identifier) + if row_filter is None: + row_filter = pyiceberg_expressions.AlwaysTrue() + if columns is None: + selected_fields = ("*",) + else: + selected_fields = tuple(columns) # type: ignore[assignment] + if scan_properties is None: + scan_properties = {} + result = table.scan( + row_filter=row_filter, + selected_fields=selected_fields, + case_sensitive=case_sensitive, + snapshot_id=snapshot_id, + options=scan_properties, + limit=limit, + ) + return result.to_pandas() + + +def to_iceberg( + df: DataFrame, + table_identifier: str, + catalog_name: str | None = None, + *, + catalog_properties: dict[str, Any] | None = None, + location: str | None = None, + append: bool = False, + snapshot_properties: dict[str, str] | None = None, +) -> None: + """ + Write a DataFrame to an Apache Iceberg table. + + .. versionadded:: 3.0.0 + + Parameters + ---------- + table_identifier : str + Table identifier. + catalog_name : str, optional + The name of the catalog. + catalog_properties : dict of {str: str}, optional + The properties that are used next to the catalog configuration. + location : str, optional + Location for the table. + append : bool, default False + If ``True``, append data to the table, instead of replacing the content. + snapshot_properties : dict of {str: str}, optional + Custom properties to be added to the snapshot summary + + See Also + -------- + read_iceberg : Read an Apache Iceberg table. + DataFrame.to_parquet : Write a DataFrame in Parquet format. + """ + pa = import_optional_dependency("pyarrow") + pyiceberg_catalog = import_optional_dependency("pyiceberg.catalog") + if catalog_properties is None: + catalog_properties = {} + catalog = pyiceberg_catalog.load_catalog(catalog_name, **catalog_properties) + arrow_table = pa.Table.from_pandas(df) + table = catalog.create_table_if_not_exists( + identifier=table_identifier, + schema=arrow_table.schema, + location=location, + # we could add `partition_spec`, `sort_order` and `properties` in the + # future, but it may not be trivial without exposing PyIceberg objects + ) + if snapshot_properties is None: + snapshot_properties = {} + if append: + table.append(arrow_table, snapshot_properties=snapshot_properties) + else: + table.overwrite(arrow_table, snapshot_properties=snapshot_properties) diff --git a/pandas/io/orc.py b/pandas/io/orc.py new file mode 100644 index 0000000000000000000000000000000000000000..8851532508c7e1d55df8b7da01008c95f78b09d6 --- /dev/null +++ b/pandas/io/orc.py @@ -0,0 +1,243 @@ +"""orc compat""" + +from __future__ import annotations + +import io +from typing import ( + TYPE_CHECKING, + Any, + Literal, +) + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.util._decorators import set_module +from pandas.util._validators import check_dtype_backend + +from pandas.core.indexes.api import default_index + +from pandas.io._util import arrow_table_to_pandas +from pandas.io.common import ( + get_handle, + is_fsspec_url, +) + +if TYPE_CHECKING: + import fsspec + import pyarrow.fs + + from pandas._typing import ( + DtypeBackend, + FilePath, + ReadBuffer, + WriteBuffer, + ) + + from pandas.core.frame import DataFrame + + +@set_module("pandas") +def read_orc( + path: FilePath | ReadBuffer[bytes], + columns: list[str] | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + filesystem: pyarrow.fs.FileSystem | fsspec.spec.AbstractFileSystem | None = None, + **kwargs: Any, +) -> DataFrame: + """ + Load an ORC object from the file path, returning a DataFrame. + + This method reads an ORC (Optimized Row Columnar) file into a pandas + DataFrame using the `pyarrow.orc` library. ORC is a columnar storage format + that provides efficient compression and fast retrieval for analytical workloads. + It allows reading specific columns, handling different filesystem + types (such as local storage, cloud storage via fsspec, or pyarrow filesystem), + and supports different data type backends, including `numpy_nullable` and `pyarrow`. + + Parameters + ---------- + path : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``read()`` function. The string could be a URL. + Valid URL schemes include http, ftp, s3, and file. For file URLs, a host is + expected. A local file could be: + ``file://localhost/path/to/table.orc``. + columns : list, default None + If not None, only these columns will be read from the file. + Output always follows the ordering of the file and not the columns list. + This mirrors the original behaviour of + :external+pyarrow:py:meth:`pyarrow.orc.ORCFile.read`. + dtype_backend : {'numpy_nullable', 'pyarrow'} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + filesystem : fsspec or pyarrow filesystem, default None + Filesystem object to use when reading the orc file. + + .. versionadded:: 2.1.0 + + **kwargs + Any additional kwargs are passed to pyarrow. + + Returns + ------- + DataFrame + DataFrame based on the ORC file. + + See Also + -------- + read_csv : Read a comma-separated values (csv) file into a pandas DataFrame. + read_excel : Read an Excel file into a pandas DataFrame. + read_spss : Read an SPSS file into a pandas DataFrame. + read_sas : Load a SAS file into a pandas DataFrame. + read_feather : Load a feather-format object into a pandas DataFrame. + + Notes + ----- + Before using this function you should read the :ref:`user guide about ORC ` + and :ref:`install optional dependencies `. + + If ``path`` is a URI scheme pointing to a local or remote file (e.g. "s3://"), + a ``pyarrow.fs`` filesystem will be attempted to read the file. You can also pass a + pyarrow or fsspec filesystem object into the filesystem keyword to override this + behavior. + + Examples + -------- + >>> result = pd.read_orc("example_pa.orc") # doctest: +SKIP + """ + # we require a newer version of pyarrow than we support for orc + + orc = import_optional_dependency("pyarrow.orc") + + check_dtype_backend(dtype_backend) + + with get_handle(path, "rb", is_text=False) as handles: + source = handles.handle + if is_fsspec_url(path) and filesystem is None: + pa = import_optional_dependency("pyarrow") + pa_fs = import_optional_dependency("pyarrow.fs") + try: + filesystem, source = pa_fs.FileSystem.from_uri(path) + except (TypeError, pa.ArrowInvalid): + pass + + pa_table = orc.read_table( + source=source, columns=columns, filesystem=filesystem, **kwargs + ) + return arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) + + +def to_orc( + df: DataFrame, + path: FilePath | WriteBuffer[bytes] | None = None, + *, + engine: Literal["pyarrow"] = "pyarrow", + index: bool | None = None, + engine_kwargs: dict[str, Any] | None = None, +) -> bytes | None: + """ + Write a DataFrame to the ORC format. + + Parameters + ---------- + df : DataFrame + The dataframe to be written to ORC. Raises NotImplementedError + if dtype of one or more columns is category, unsigned integers, + intervals, periods or sparse. + path : str, file-like object or None, default None + If a string, it will be used as Root Directory path + when writing a partitioned dataset. By file-like object, + we refer to objects with a write() method, such as a file handle + (e.g. via builtin open function). If path is None, + a bytes object is returned. + engine : str, default 'pyarrow' + ORC library to use. + index : bool, optional + If ``True``, include the dataframe's index(es) in the file output. If + ``False``, they will not be written to the file. + If ``None``, similar to ``infer`` the dataframe's index(es) + will be saved. However, instead of being saved as values, + the RangeIndex will be stored as a range in the metadata so it + doesn't require much space and is faster. Other indexes will + be included as columns in the file output. + engine_kwargs : dict[str, Any] or None, default None + Additional keyword arguments passed to :func:`pyarrow.orc.write_table`. + + Returns + ------- + bytes if no path argument is provided else None + + Raises + ------ + NotImplementedError + Dtype of one or more columns is category, unsigned integers, interval, + period or sparse. + ValueError + engine is not pyarrow. + + Notes + ----- + * Before using this function you should read the + :ref:`user guide about ORC ` and + :ref:`install optional dependencies `. + * This function requires `pyarrow `_ + library. + * For supported dtypes please refer to `supported ORC features in Arrow + `__. + * Currently timezones in datetime columns are not preserved when a + dataframe is converted into ORC files. + """ + if index is None: + index = df.index.names[0] is not None + if engine_kwargs is None: + engine_kwargs = {} + + # validate index + # -------------- + + # validate that we have only a default index + # raise on anything else as we don't serialize the index + + if not df.index.equals(default_index(len(df))): + raise ValueError( + "orc does not support serializing a non-default index for the index; " + "you can .reset_index() to make the index into column(s)" + ) + + if df.index.name is not None: + raise ValueError("orc does not serialize index meta-data on a default index") + + if engine != "pyarrow": + raise ValueError("engine must be 'pyarrow'") + pa = import_optional_dependency("pyarrow") + orc = import_optional_dependency("pyarrow.orc") + + was_none = path is None + if was_none: + path = io.BytesIO() + assert path is not None # For mypy + with get_handle(path, "wb", is_text=False) as handles: + try: + orc.write_table( + pa.Table.from_pandas(df, preserve_index=index), + handles.handle, + **engine_kwargs, + ) + except (TypeError, pa.ArrowNotImplementedError) as e: + raise NotImplementedError( + "The dtype of one or more columns is not supported yet." + ) from e + + if was_none: + assert isinstance(path, io.BytesIO) # For mypy + return path.getvalue() + return None diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py new file mode 100644 index 0000000000000000000000000000000000000000..218002ebb3f6a0bef4c994b152c8ff51c0852440 --- /dev/null +++ b/pandas/io/parquet.py @@ -0,0 +1,680 @@ +"""parquet compat""" + +from __future__ import annotations + +import io +import json +import os +from typing import ( + TYPE_CHECKING, + Any, + Literal, +) +from warnings import ( + catch_warnings, + filterwarnings, +) + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.errors import ( + AbstractMethodError, + Pandas4Warning, +) +from pandas.util._decorators import set_module +from pandas.util._validators import check_dtype_backend + +from pandas import ( + DataFrame, + get_option, +) + +from pandas.io._util import arrow_table_to_pandas +from pandas.io.common import ( + IOHandles, + get_handle, + is_fsspec_url, + is_url, + stringify_path, +) + +if TYPE_CHECKING: + from pandas._typing import ( + DtypeBackend, + FilePath, + ParquetCompressionOptions, + ReadBuffer, + StorageOptions, + WriteBuffer, + ) + + +def get_engine(engine: str) -> BaseImpl: + """return our implementation""" + if engine == "auto": + engine = get_option("io.parquet.engine") + + if engine == "auto": + # try engines in this order + engine_classes = [PyArrowImpl, FastParquetImpl] + + error_msgs = "" + for engine_class in engine_classes: + try: + return engine_class() + except ImportError as err: + error_msgs += "\n - " + str(err) + + raise ImportError( + "Unable to find a usable engine; " + "tried using: 'pyarrow', 'fastparquet'.\n" + "A suitable version of " + "pyarrow or fastparquet is required for parquet " + "support.\n" + "Trying to import the above resulted in these errors:" + f"{error_msgs}" + ) + + if engine == "pyarrow": + return PyArrowImpl() + elif engine == "fastparquet": + return FastParquetImpl() + + raise ValueError("engine must be one of 'pyarrow', 'fastparquet'") + + +def _get_path_or_handle( + path: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes], + fs: Any, + storage_options: StorageOptions | None = None, + mode: str = "rb", + is_dir: bool = False, +) -> tuple[ + FilePath | ReadBuffer[bytes] | WriteBuffer[bytes], IOHandles[bytes] | None, Any +]: + """File handling for PyArrow.""" + path_or_handle = stringify_path(path) + if fs is not None: + pa_fs = import_optional_dependency("pyarrow.fs", errors="ignore") + fsspec = import_optional_dependency("fsspec", errors="ignore") + if pa_fs is not None and isinstance(fs, pa_fs.FileSystem): + if storage_options: + raise NotImplementedError( + "storage_options not supported with a pyarrow FileSystem." + ) + elif fsspec is not None and isinstance(fs, fsspec.spec.AbstractFileSystem): + pass + else: + raise ValueError( + f"filesystem must be a pyarrow or fsspec FileSystem, " + f"not a {type(fs).__name__}" + ) + if is_fsspec_url(path_or_handle) and fs is None: + if storage_options is None: + pa = import_optional_dependency("pyarrow") + pa_fs = import_optional_dependency("pyarrow.fs") + + try: + fs, path_or_handle = pa_fs.FileSystem.from_uri(path) + except (TypeError, pa.ArrowInvalid): + pass + if fs is None: + fsspec = import_optional_dependency("fsspec") + fs, path_or_handle = fsspec.core.url_to_fs( + path_or_handle, **(storage_options or {}) + ) + elif storage_options and (not is_url(path_or_handle) or mode != "rb"): + # can't write to a remote url + # without making use of fsspec at the moment + raise ValueError("storage_options passed with buffer, or non-supported URL") + + handles = None + if ( + not fs + and not is_dir + and isinstance(path_or_handle, str) + and not os.path.isdir(path_or_handle) + ): + # use get_handle only when we are very certain that it is not a directory + # fsspec resources can also point to directories + # this branch is used for example when reading from non-fsspec URLs + handles = get_handle( + path_or_handle, mode, is_text=False, storage_options=storage_options + ) + fs = None + path_or_handle = handles.handle + return path_or_handle, handles, fs + + +class BaseImpl: + @staticmethod + def validate_dataframe(df: DataFrame) -> None: + if not isinstance(df, DataFrame): + raise ValueError("to_parquet only supports IO with DataFrames") + + def write(self, df: DataFrame, path, compression, **kwargs) -> None: + raise AbstractMethodError(self) + + def read(self, path, columns=None, **kwargs) -> DataFrame: + raise AbstractMethodError(self) + + +class PyArrowImpl(BaseImpl): + def __init__(self) -> None: + import_optional_dependency( + "pyarrow", extra="pyarrow is required for parquet support." + ) + import pyarrow.parquet + + # import utils to register the pyarrow extension types + import pandas.core.arrays.arrow.extension_types # pyright: ignore[reportUnusedImport] # noqa: F401 + + self.api = pyarrow + + def write( + self, + df: DataFrame, + path: FilePath | WriteBuffer[bytes], + compression: ParquetCompressionOptions = "snappy", + index: bool | None = None, + storage_options: StorageOptions | None = None, + partition_cols: list[str] | None = None, + filesystem=None, + **kwargs, + ) -> None: + self.validate_dataframe(df) + + from_pandas_kwargs: dict[str, Any] = {"schema": kwargs.pop("schema", None)} + if index is not None: + from_pandas_kwargs["preserve_index"] = index + + table = self.api.Table.from_pandas(df, **from_pandas_kwargs) + + if df.attrs: + df_metadata = {"PANDAS_ATTRS": json.dumps(df.attrs)} + existing_metadata = table.schema.metadata + merged_metadata = {**existing_metadata, **df_metadata} + table = table.replace_schema_metadata(merged_metadata) + + path_or_handle, handles, filesystem = _get_path_or_handle( + path, + filesystem, + storage_options=storage_options, + mode="wb", + is_dir=partition_cols is not None, + ) + if ( + isinstance(path_or_handle, io.BufferedWriter) + and hasattr(path_or_handle, "name") + and isinstance(path_or_handle.name, (str, bytes)) + ): + if isinstance(path_or_handle.name, bytes): + path_or_handle = path_or_handle.name.decode() + else: + path_or_handle = path_or_handle.name + + try: + if partition_cols is not None: + # writes to multiple files under the given path + self.api.parquet.write_to_dataset( + table, + path_or_handle, + compression=compression, + partition_cols=partition_cols, + filesystem=filesystem, + **kwargs, + ) + else: + # write to single output file + self.api.parquet.write_table( + table, + path_or_handle, + compression=compression, + filesystem=filesystem, + **kwargs, + ) + finally: + if handles is not None: + handles.close() + + def read( + self, + path, + columns=None, + filters=None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + storage_options: StorageOptions | None = None, + filesystem=None, + to_pandas_kwargs: dict[str, Any] | None = None, + **kwargs, + ) -> DataFrame: + kwargs["use_pandas_metadata"] = True + + path_or_handle, handles, filesystem = _get_path_or_handle( + path, + filesystem, + storage_options=storage_options, + mode="rb", + ) + try: + pa_table = self.api.parquet.read_table( + path_or_handle, + columns=columns, + filesystem=filesystem, + filters=filters, + **kwargs, + ) + with catch_warnings(): + filterwarnings( + "ignore", + "make_block is deprecated", + Pandas4Warning, + ) + result = arrow_table_to_pandas( + pa_table, + dtype_backend=dtype_backend, + to_pandas_kwargs=to_pandas_kwargs, + ) + + if pa_table.schema.metadata: + if b"PANDAS_ATTRS" in pa_table.schema.metadata: + df_metadata = pa_table.schema.metadata[b"PANDAS_ATTRS"] + result.attrs = json.loads(df_metadata) + return result + finally: + if handles is not None: + handles.close() + + +class FastParquetImpl(BaseImpl): + def __init__(self) -> None: + # since pandas is a dependency of fastparquet + # we need to import on first use + fastparquet = import_optional_dependency( + "fastparquet", extra="fastparquet is required for parquet support." + ) + self.api = fastparquet + + def write( + self, + df: DataFrame, + path, + compression: Literal["snappy", "gzip", "brotli"] | None = "snappy", + index=None, + partition_cols=None, + storage_options: StorageOptions | None = None, + filesystem=None, + **kwargs, + ) -> None: + self.validate_dataframe(df) + + if "partition_on" in kwargs and partition_cols is not None: + raise ValueError( + "Cannot use both partition_on and " + "partition_cols. Use partition_cols for partitioning data" + ) + if "partition_on" in kwargs: + partition_cols = kwargs.pop("partition_on") + + if partition_cols is not None: + kwargs["file_scheme"] = "hive" + + if filesystem is not None: + raise NotImplementedError( + "filesystem is not implemented for the fastparquet engine." + ) + + # cannot use get_handle as write() does not accept file buffers + path = stringify_path(path) + if is_fsspec_url(path): + fsspec = import_optional_dependency("fsspec") + + # if filesystem is provided by fsspec, file must be opened in 'wb' mode. + kwargs["open_with"] = lambda path, _: fsspec.open( + path, "wb", **(storage_options or {}) + ).open() + elif storage_options: + raise ValueError( + "storage_options passed with file object or non-fsspec file path" + ) + + with catch_warnings(record=True): + self.api.write( + path, + df, + compression=compression, + write_index=index, + partition_on=partition_cols, + **kwargs, + ) + + def read( + self, + path, + columns=None, + filters=None, + storage_options: StorageOptions | None = None, + filesystem=None, + to_pandas_kwargs: dict | None = None, + **kwargs, + ) -> DataFrame: + parquet_kwargs: dict[str, Any] = {} + dtype_backend = kwargs.pop("dtype_backend", lib.no_default) + # We are disabling nullable dtypes for fastparquet pending discussion + parquet_kwargs["pandas_nulls"] = False + if dtype_backend is not lib.no_default: + raise ValueError( + "The 'dtype_backend' argument is not supported for the " + "fastparquet engine" + ) + if filesystem is not None: + raise NotImplementedError( + "filesystem is not implemented for the fastparquet engine." + ) + if to_pandas_kwargs is not None: + raise NotImplementedError( + "to_pandas_kwargs is not implemented for the fastparquet engine." + ) + path = stringify_path(path) + handles = None + if is_fsspec_url(path): + fsspec = import_optional_dependency("fsspec") + + parquet_kwargs["fs"] = fsspec.open(path, "rb", **(storage_options or {})).fs + elif isinstance(path, str) and not os.path.isdir(path): + # use get_handle only when we are very certain that it is not a directory + # fsspec resources can also point to directories + # this branch is used for example when reading from non-fsspec URLs + handles = get_handle( + path, "rb", is_text=False, storage_options=storage_options + ) + path = handles.handle + + try: + parquet_file = self.api.ParquetFile(path, **parquet_kwargs) + with catch_warnings(): + filterwarnings( + "ignore", + "make_block is deprecated", + Pandas4Warning, + ) + return parquet_file.to_pandas( + columns=columns, filters=filters, **kwargs + ) + finally: + if handles is not None: + handles.close() + + +def to_parquet( + df: DataFrame, + path: FilePath | WriteBuffer[bytes] | None = None, + engine: str = "auto", + compression: ParquetCompressionOptions = "snappy", + index: bool | None = None, + storage_options: StorageOptions | None = None, + partition_cols: list[str] | None = None, + filesystem: Any = None, + **kwargs, +) -> bytes | None: + """ + Write a DataFrame to the parquet format. + + Parameters + ---------- + df : DataFrame + path : str, path object, file-like object, or None, default None + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``write()`` function. If None, the result + is returned as bytes. If a string, it will be used as Root Directory + path when writing a partitioned dataset. The engine fastparquet does + not accept file-like objects. + engine : {{'auto', 'pyarrow', 'fastparquet'}}, default 'auto' + Parquet library to use. If 'auto', then the option + ``io.parquet.engine`` is used. The default ``io.parquet.engine`` + behavior is to try 'pyarrow', falling back to 'fastparquet' if + 'pyarrow' is unavailable. + + When using the ``'pyarrow'`` engine and no storage options are provided + and a filesystem is implemented by both ``pyarrow.fs`` and ``fsspec`` + (e.g. "s3://"), then the ``pyarrow.fs`` filesystem is attempted first. + Use the filesystem keyword with an instantiated fsspec filesystem + if you wish to use its implementation. + compression : {{'snappy', 'gzip', 'brotli', 'lz4', 'zstd', None}}, + default 'snappy'. Name of the compression to use. Use ``None`` + for no compression. + index : bool, default None + If ``True``, include the dataframe's index(es) in the file output. If + ``False``, they will not be written to the file. + If ``None``, similar to ``True`` the dataframe's index(es) + will be saved. However, instead of being saved as values, + the RangeIndex will be stored as a range in the metadata so it + doesn't require much space and is faster. Other indexes will + be included as columns in the file output. + partition_cols : str or list, optional, default None + Column names by which to partition the dataset. + Columns are partitioned in the order they are given. + Must be None if path is not a string. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value + pairs are forwarded to ``urllib.request.Request`` as header options. + For other URLs (e.g. starting with "s3://", and "gcs://") the + key-value pairs are forwarded to ``fsspec.open``. Please see ``fsspec`` + and ``urllib`` for more details, and for more examples on storage + options refer `here `_. + filesystem : fsspec or pyarrow filesystem, default None + Filesystem object to use when reading the parquet file. Only implemented + for ``engine="pyarrow"``. + + .. versionadded:: 2.1.0 + + **kwargs + Additional keyword arguments passed to the engine: + + * For ``engine="pyarrow"``: passed to :func:`pyarrow.parquet.write_table` + or :func:`pyarrow.parquet.write_to_dataset` (when using partition_cols) + * For ``engine="fastparquet"``: passed to :func:`fastparquet.write` + + Returns + ------- + bytes if no path argument is provided else None + """ + if isinstance(partition_cols, str): + partition_cols = [partition_cols] + impl = get_engine(engine) + + path_or_buf: FilePath | WriteBuffer[bytes] = io.BytesIO() if path is None else path + + impl.write( + df, + path_or_buf, + compression=compression, + index=index, + partition_cols=partition_cols, + storage_options=storage_options, + filesystem=filesystem, + **kwargs, + ) + + if path is None: + assert isinstance(path_or_buf, io.BytesIO) + return path_or_buf.getvalue() + else: + return None + + +@set_module("pandas") +def read_parquet( + path: FilePath | ReadBuffer[bytes], + engine: str = "auto", + columns: list[str] | None = None, + storage_options: StorageOptions | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + filesystem: Any = None, + filters: list[tuple] | list[list[tuple]] | None = None, + to_pandas_kwargs: dict | None = None, + **kwargs, +) -> DataFrame: + """ + Load a parquet object from the file path, returning a DataFrame. + + The function automatically handles reading the data from a parquet file + and creates a DataFrame with the appropriate structure. + + Parameters + ---------- + path : str, path object or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``read()`` function. + The string could be a URL. Valid URL schemes include http, ftp, s3, + gs, and file. For file URLs, a host is expected. A local file could be: + ``file://localhost/path/to/table.parquet``. + A file URL can also be a path to a directory that contains multiple + partitioned parquet files. Both pyarrow and fastparquet support + paths to directories as well as file URLs. A directory path could be: + ``file://localhost/path/to/tables`` or ``s3://bucket/partition_dir``. + engine : {{'auto', 'pyarrow', 'fastparquet'}}, default 'auto' + Parquet library to use. If 'auto', then the option + ``io.parquet.engine`` is used. The default ``io.parquet.engine`` + behavior is to try 'pyarrow', falling back to 'fastparquet' if + 'pyarrow' is unavailable. + + When using the ``'pyarrow'`` engine and no storage options are provided + and a filesystem is implemented by both ``pyarrow.fs`` and ``fsspec`` + (e.g. "s3://"), then the ``pyarrow.fs`` filesystem is attempted first. + Use the filesystem keyword with an instantiated fsspec filesystem + if you wish to use its implementation. + columns : list, default=None + If not None, only these columns will be read from the file. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value + pairs are forwarded to ``urllib.request.Request`` as header options. + For other URLs (e.g. starting with "s3://", and "gcs://") the + key-value pairs are forwarded to ``fsspec.open``. Please see ``fsspec`` + and ``urllib`` for more details, and for more examples on storage + options refer `here `_. + dtype_backend : {{'numpy_nullable', 'pyarrow'}} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + filesystem : fsspec or pyarrow filesystem, default None + Filesystem object to use when reading the parquet file. Only implemented + for ``engine="pyarrow"``. + + .. versionadded:: 2.1.0 + + filters : List[Tuple] or List[List[Tuple]], default None + To filter out data. + Filter syntax: [[(column, op, val), ...],...] + where op is [==, =, >, >=, <, <=, !=, in, not in] + The innermost tuples are transposed into a set of filters applied + through an `AND` operation. + The outer list combines these sets of filters through an `OR` + operation. + A single list of tuples can also be used, meaning that no `OR` + operation between set of filters is to be conducted. + + Using this argument will NOT result in row-wise filtering of the final + partitions unless ``engine="pyarrow"`` is also specified. For + other engines, filtering is only performed at the partition level, that is, + to prevent the loading of some row-groups and/or files. + + .. versionadded:: 2.1.0 + + to_pandas_kwargs : dict | None, default None + Keyword arguments to pass through to :func:`pyarrow.Table.to_pandas` + when ``engine="pyarrow"``. + + .. versionadded:: 3.0.0 + + **kwargs + Additional keyword arguments passed to the engine: + + * For ``engine="pyarrow"``: passed to :func:`pyarrow.parquet.read_table` + * For ``engine="fastparquet"``: passed to + :meth:`fastparquet.ParquetFile.to_pandas` + + Returns + ------- + DataFrame + DataFrame based on parquet file. + + See Also + -------- + DataFrame.to_parquet : Create a parquet object that serializes a DataFrame. + + Examples + -------- + >>> original_df = pd.DataFrame({"foo": range(5), "bar": range(5, 10)}) + >>> original_df + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + >>> df_parquet_bytes = original_df.to_parquet() + >>> from io import BytesIO + >>> restored_df = pd.read_parquet(BytesIO(df_parquet_bytes)) + >>> restored_df + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + >>> restored_df.equals(original_df) + True + >>> restored_bar = pd.read_parquet(BytesIO(df_parquet_bytes), columns=["bar"]) + >>> restored_bar + bar + 0 5 + 1 6 + 2 7 + 3 8 + 4 9 + >>> restored_bar.equals(original_df[["bar"]]) + True + + The function uses `kwargs` that are passed directly to the engine. + In the following example, we use the `filters` argument of the pyarrow + engine to filter the rows of the DataFrame. + + Since `pyarrow` is the default engine, we can omit the `engine` argument. + Note that the `filters` argument is implemented by the `pyarrow` engine, + which can benefit from multithreading and also potentially be more + economical in terms of memory. + + >>> sel = [("foo", ">", 2)] + >>> restored_part = pd.read_parquet(BytesIO(df_parquet_bytes), filters=sel) + >>> restored_part + foo bar + 0 3 8 + 1 4 9 + """ + + impl = get_engine(engine) + check_dtype_backend(dtype_backend) + + return impl.read( + path, + columns=columns, + filters=filters, + storage_options=storage_options, + dtype_backend=dtype_backend, + filesystem=filesystem, + to_pandas_kwargs=to_pandas_kwargs, + **kwargs, + ) diff --git a/pandas/io/pickle.py b/pandas/io/pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..0f2b380bc70bf273b2b4cd400281b4b7d2d5fc15 --- /dev/null +++ b/pandas/io/pickle.py @@ -0,0 +1,239 @@ +"""pickle compat""" + +from __future__ import annotations + +import pickle +from typing import ( + TYPE_CHECKING, + Any, +) +import warnings + +from pandas.compat import pickle_compat +from pandas.util._decorators import set_module + +from pandas.io.common import get_handle + +if TYPE_CHECKING: + from pandas._typing import ( + CompressionOptions, + FilePath, + ReadPickleBuffer, + StorageOptions, + WriteBuffer, + ) + + from pandas import ( + DataFrame, + Series, + ) + + +@set_module("pandas") +def to_pickle( + obj: Any, + filepath_or_buffer: FilePath | WriteBuffer[bytes], + compression: CompressionOptions = "infer", + protocol: int = pickle.HIGHEST_PROTOCOL, + storage_options: StorageOptions | None = None, +) -> None: + """ + Pickle (serialize) object to file. + + Parameters + ---------- + obj : any object + Any python object. + filepath_or_buffer : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``write()`` function. + Also accepts URL. URL has to be of S3 or GCS. + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and + 'filepath_or_buffer' is path-like, then detect compression from the + following extensions: '.gz', '.bz2', '.zip', '.xz', '.zst', '.tar', + '.tar.gz', '.tar.xz' or '.tar.bz2' (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set + to one of {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, + ``'tar'``} and other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression + and to create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + protocol : int + Int which indicates which protocol should be used by the pickler, + default HIGHEST_PROTOCOL (see [1], paragraph 12.1.2). The possible + values for this parameter depend on the version of Python. For Python + 2.x, possible values are 0, 1, 2. For Python>=3.0, 3 is a valid value. + For Python >= 3.4, 4 is a valid value. A negative value for the + protocol parameter is equivalent to setting its value to + HIGHEST_PROTOCOL. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + .. [1] https://docs.python.org/3/library/pickle.html + + See Also + -------- + read_pickle : Load pickled pandas object (or any object) from file. + DataFrame.to_hdf : Write DataFrame to an HDF5 file. + DataFrame.to_sql : Write DataFrame to a SQL database. + DataFrame.to_parquet : Write a DataFrame to the binary parquet format. + + Examples + -------- + >>> original_df = pd.DataFrame( + ... {{"foo": range(5), "bar": range(5, 10)}} + ... ) # doctest: +SKIP + >>> original_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + >>> pd.to_pickle(original_df, "./dummy.pkl") # doctest: +SKIP + + >>> unpickled_df = pd.read_pickle("./dummy.pkl") # doctest: +SKIP + >>> unpickled_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + """ + if protocol < 0: + protocol = pickle.HIGHEST_PROTOCOL + + with get_handle( + filepath_or_buffer, + "wb", + compression=compression, + is_text=False, + storage_options=storage_options, + ) as handles: + # letting pickle write directly to the buffer is more memory-efficient + pickle.dump(obj, handles.handle, protocol=protocol) + + +@set_module("pandas") +def read_pickle( + filepath_or_buffer: FilePath | ReadPickleBuffer, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, +) -> DataFrame | Series: + """ + Load pickled pandas object (or any object) from file and return unpickled object. + + .. warning:: + + Loading pickled data received from untrusted sources can be + unsafe. See `here `__. + + Parameters + ---------- + filepath_or_buffer : str, path object, or file-like object + String, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a binary ``readlines()`` function. + Also accepts URL. URL is not limited to S3 and GCS. + compression : str or dict, default 'infer' + For on-the-fly decompression of on-disk data. If 'infer' and + 'filepath_or_buffer' is path-like, then detect compression from the + following extensions: '.gz', '.bz2', '.zip', '.xz', '.zst', '.tar', + '.tar.gz', '.tar.xz' or '.tar.bz2' (otherwise no compression). + If using 'zip' or 'tar', the ZIP file must contain only one data file + to be read in. + Set to ``None`` for no decompression. + Can also be a dict with key ``'method'`` set + to one of {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, + ``'tar'``} and other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdDecompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for Zstandard decompression + using a custom compression dictionary: + ``compression={'method': 'zstd', 'dict_data': my_compression_dict}``. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + Returns + ------- + object + The unpickled pandas object (or any object) that was stored in file. + + See Also + -------- + DataFrame.to_pickle : Pickle (serialize) DataFrame object to file. + Series.to_pickle : Pickle (serialize) Series object to file. + read_hdf : Read HDF5 file into a DataFrame. + read_sql : Read SQL query or database table into a DataFrame. + read_parquet : Load a parquet object, returning a DataFrame. + + Notes + ----- + read_pickle is only guaranteed to be backwards compatible to pandas 1.0 + provided the object was serialized with to_pickle. + + Examples + -------- + >>> original_df = pd.DataFrame( + ... {{"foo": range(5), "bar": range(5, 10)}} + ... ) # doctest: +SKIP + >>> original_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + >>> pd.to_pickle(original_df, "./dummy.pkl") # doctest: +SKIP + + >>> unpickled_df = pd.read_pickle("./dummy.pkl") # doctest: +SKIP + >>> unpickled_df # doctest: +SKIP + foo bar + 0 0 5 + 1 1 6 + 2 2 7 + 3 3 8 + 4 4 9 + """ + # TypeError for Cython complaints about object.__new__ vs Tick.__new__ + excs_to_catch = (AttributeError, ImportError, ModuleNotFoundError, TypeError) + with get_handle( + filepath_or_buffer, + "rb", + compression=compression, + is_text=False, + storage_options=storage_options, + ) as handles: + # 1) try standard library Pickle + # 2) try pickle_compat (older pandas version) to handle subclass changes + try: + with warnings.catch_warnings(record=True): + # We want to silence any warnings about, e.g. moved modules. + warnings.simplefilter("ignore", Warning) + return pickle.load(handles.handle) + except excs_to_catch: + # e.g. + # "No module named 'pandas.core.sparse.series'" + # "Can't get attribute '_nat_unpickle' on str: + # set the encoding if we need + if encoding is None: + encoding = _default_encoding + + return encoding + + +def _ensure_str(name): + """ + Ensure that an index / column name is a str (python 3); otherwise they + may be np.string dtype. Non-string dtypes are passed through unchanged. + + https://github.com/pandas-dev/pandas/issues/13492 + """ + if isinstance(name, str): + name = str(name) + return name + + +Term: TypeAlias = PyTablesExpr + + +def _ensure_term(where, scope_level: int): + """ + Ensure that the where is a Term or a list of Term. + + This makes sure that we are capturing the scope of variables that are + passed create the terms here with a frame_level=2 (we are 2 levels down) + """ + # only consider list/tuple here as an ndarray is automatically a coordinate + # list + level = scope_level + 1 + if isinstance(where, (list, tuple)): + where = [ + Term(term, scope_level=level + 1) if maybe_expression(term) else term + for term in where + if term is not None + ] + elif maybe_expression(where): + where = Term(where, scope_level=level) + return where if where is None or len(where) else None + + +incompatibility_doc: Final = """ +where criteria is being ignored as this version [%s] is too old (or +not-defined), read the file in and write it out to a new file to upgrade (with +the copy_to method) +""" + +attribute_conflict_doc: Final = """ +the [%s] attribute of the existing index is [%s] which conflicts with the new +[%s], resetting the attribute to None +""" + +performance_doc: Final = """ +your performance may suffer as PyTables will pickle object types that it cannot +map directly to c-types [inferred_type->%s,key->%s] [items->%s] +""" + +# formats +_FORMAT_MAP = {"f": "fixed", "fixed": "fixed", "t": "table", "table": "table"} + +# axes map +_AXES_MAP = {DataFrame: [0]} + +# register our configuration options +dropna_doc: Final = """ +: boolean + drop ALL nan rows when appending to a table +""" +format_doc: Final = """ +: format + default format writing format, if None, then + put will default to 'fixed' and append will default to 'table' +""" + +with config.config_prefix("io.hdf"): + config.register_option("dropna_table", False, dropna_doc, validator=config.is_bool) + config.register_option( + "default_format", + None, + format_doc, + validator=config.is_one_of_factory(["fixed", "table", None]), + ) + +# oh the troubles to reduce import time +_table_mod: ModuleType | None = None +_table_file_open_policy_is_strict = False + + +def _tables(): + global _table_mod + global _table_file_open_policy_is_strict + if _table_mod is None: + import tables + + _table_mod = tables + + # set the file open policy + # return the file open policy; this changes as of pytables 3.1 + # depending on the HDF5 version + with suppress(AttributeError): + _table_file_open_policy_is_strict = ( + tables.file._FILE_OPEN_POLICY == "strict" + ) + + return _table_mod + + +# interface to/from ### + + +def to_hdf( + path_or_buf: FilePath | HDFStore, + key: str, + value: DataFrame | Series, + mode: str = "a", + complevel: int | None = None, + complib: str | None = None, + append: bool = False, + format: str | None = None, + index: bool = True, + min_itemsize: int | dict[str, int] | None = None, + nan_rep=None, + dropna: bool | None = None, + data_columns: Literal[True] | list[str] | None = None, + errors: str = "strict", + encoding: str = "UTF-8", +) -> None: + """store this object, close it if we opened it""" + if append: + f = lambda store: store.append( + key, + value, + format=format, + index=index, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + dropna=dropna, + data_columns=data_columns, + errors=errors, + encoding=encoding, + ) + else: + # NB: dropna is not passed to `put` + f = lambda store: store.put( + key, + value, + format=format, + index=index, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + data_columns=data_columns, + errors=errors, + encoding=encoding, + dropna=dropna, + ) + + if isinstance(path_or_buf, HDFStore): + f(path_or_buf) + else: + path_or_buf = stringify_path(path_or_buf) + with HDFStore( + path_or_buf, mode=mode, complevel=complevel, complib=complib + ) as store: + f(store) + + +@set_module("pandas") +def read_hdf( + path_or_buf: FilePath | HDFStore, + key=None, + mode: str = "r", + errors: str = "strict", + where: str | list | None = None, + start: int | None = None, + stop: int | None = None, + columns: list[str] | None = None, + iterator: bool = False, + chunksize: int | None = None, + **kwargs, +): + """ + Read from the store, close it if we opened it. + + Retrieve pandas object stored in file, optionally based on where + criteria. + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + path_or_buf : str, path object, pandas.HDFStore + Any valid string path is acceptable. Only supports the local file system, + remote URLs and file-like objects are not supported. + + If you want to pass in a path object, pandas accepts any + ``os.PathLike``. + + Alternatively, pandas accepts an open :class:`pandas.HDFStore` object. + + key : object, optional + The group identifier in the store. Can be omitted if the HDF file + contains a single pandas object. + mode : {'r', 'r+', 'a'}, default 'r' + Mode to use when opening the file. Ignored if path_or_buf is a + :class:`pandas.HDFStore`. Default is 'r'. + errors : str, default 'strict' + Specifies how encoding and decoding errors are to be handled. + See the errors argument for :func:`open` for a full list + of options. + where : list, optional + A list of Term (or convertible) objects. + start : int, optional + Row number to start selection. + stop : int, optional + Row number to stop selection. + columns : list, optional + A list of columns names to return. + iterator : bool, optional + Return an iterator object. + chunksize : int, optional + Number of rows to include in an iteration when using an iterator. + **kwargs + Additional keyword arguments passed to HDFStore. + + Returns + ------- + object + The selected object. Return type depends on the object stored. + + See Also + -------- + DataFrame.to_hdf : Write an HDF file from a DataFrame. + HDFStore : Low-level access to HDF files. + + Notes + ----- + When ``errors="surrogatepass"``, ``pd.options.future.infer_string`` is true, + and PyArrow is installed, if a UTF-16 surrogate is encountered when decoding + to UTF-8, the resulting dtype will be + ``pd.StringDtype(storage="python", na_value=np.nan)``. + + Examples + -------- + >>> df = pd.DataFrame([[1, 1.0, "a"]], columns=["x", "y", "z"]) # doctest: +SKIP + >>> df.to_hdf("./store.h5", "data") # doctest: +SKIP + >>> reread = pd.read_hdf("./store.h5") # doctest: +SKIP + """ + if mode not in ["r", "r+", "a"]: + raise ValueError( + f"mode {mode} is not allowed while performing a read. " + f"Allowed modes are r, r+ and a." + ) + # grab the scope + if where is not None: + where = _ensure_term(where, scope_level=1) + + if isinstance(path_or_buf, HDFStore): + if not path_or_buf.is_open: + raise OSError("The HDFStore must be open for reading.") + + store = path_or_buf + auto_close = False + else: + path_or_buf = stringify_path(path_or_buf) + if not isinstance(path_or_buf, str): + raise NotImplementedError( + "Support for generic buffers has not been implemented." + ) + try: + exists = os.path.exists(path_or_buf) + + # if filepath is too long + except (TypeError, ValueError): + exists = False + + if not exists: + raise FileNotFoundError(f"File {path_or_buf} does not exist") + + store = HDFStore(path_or_buf, mode=mode, errors=errors, **kwargs) + # can't auto open/close if we are using an iterator + # so delegate to the iterator + auto_close = True + + try: + if key is None: + groups = store.groups() + if len(groups) == 0: + raise ValueError( + "Dataset(s) incompatible with Pandas data types, " + "not table, or no datasets found in HDF5 file." + ) + candidate_only_group = groups[0] + + # For the HDF file to have only one dataset, all other groups + # should then be metadata groups for that candidate group. (This + # assumes that the groups() method enumerates parent groups + # before their children.) + for group_to_check in groups[1:]: + if not _is_metadata_of(group_to_check, candidate_only_group): + raise ValueError( + "key must be provided when HDF5 " + "file contains multiple datasets." + ) + key = candidate_only_group._v_pathname + return store.select( + key, + where=where, + start=start, + stop=stop, + columns=columns, + iterator=iterator, + chunksize=chunksize, + auto_close=auto_close, + ) + except (ValueError, TypeError, LookupError): + if not isinstance(path_or_buf, HDFStore): + # if there is an error, close the store if we opened it. + with suppress(AttributeError): + store.close() + + raise + + +def _is_metadata_of(group: Node, parent_group: Node) -> bool: + """Check if a given group is a metadata group for a given parent_group.""" + if group._v_depth <= parent_group._v_depth: + return False + + current = group + while current._v_depth > 1: + parent = current._v_parent + if parent == parent_group and current._v_name == "meta": + return True + current = current._v_parent + return False + + +@set_module("pandas") +class HDFStore: + """ + Dict-like IO interface for storing pandas objects in PyTables. + + Either Fixed or Table format. + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + path : str + File path to HDF5 file. + mode : {'a', 'w', 'r', 'r+'}, default 'a' + + ``'r'`` + Read-only; no data can be modified. + ``'w'`` + Write; a new file is created (an existing file with the same + name would be deleted). + ``'a'`` + Append; an existing file is opened for reading and writing, + and if the file does not exist it is created. + ``'r+'`` + It is similar to ``'a'``, but the file must already exist. + complevel : int, 0-9, default None + Specifies a compression level for data. + A value of 0 or None disables compression. + complib : {'zlib', 'lzo', 'bzip2', 'blosc'}, default 'zlib' + Specifies the compression library to be used. + These additional compressors for Blosc are supported + (default if no compressor specified: 'blosc:blosclz'): + {'blosc:blosclz', 'blosc:lz4', 'blosc:lz4hc', 'blosc:snappy', + 'blosc:zlib', 'blosc:zstd'}. + Specifying a compression library which is not available issues + a ValueError. + fletcher32 : bool, default False + If applying compression use the fletcher32 checksum. + **kwargs + These parameters will be passed to the PyTables open_file method. + + Examples + -------- + >>> bar = pd.DataFrame(np.random.randn(10, 4)) + >>> store = pd.HDFStore("test.h5") + >>> store["foo"] = bar # write to HDF5 + >>> bar = store["foo"] # retrieve + >>> store.close() + + **Create or load HDF5 file in-memory** + + When passing the `driver` option to the PyTables open_file method through + **kwargs, the HDF5 file is loaded or created in-memory and will only be + written when closed: + + >>> bar = pd.DataFrame(np.random.randn(10, 4)) + >>> store = pd.HDFStore("test.h5", driver="H5FD_CORE") + >>> store["foo"] = bar + >>> store.close() # only now, data is written to disk + """ + + _handle: File | None + _mode: str + + def __init__( + self, + path, + mode: str = "a", + complevel: int | None = None, + complib=None, + fletcher32: bool = False, + **kwargs, + ) -> None: + if "format" in kwargs: + raise ValueError("format is not a defined argument for HDFStore") + + tables = import_optional_dependency("tables") + + if complib is not None and complib not in tables.filters.all_complibs: + raise ValueError( + f"complib only supports {tables.filters.all_complibs} compression." + ) + + if complib is None and complevel is not None: + complib = tables.filters.default_complib + + self._path = stringify_path(path) + if mode is None: + mode = "a" + self._mode = mode + self._handle = None + self._complevel = complevel if complevel else 0 + self._complib = complib + self._fletcher32 = fletcher32 + self._filters = None + self.open(mode=mode, **kwargs) + + def __fspath__(self) -> str: + return self._path + + @property + def root(self): + """return the root node""" + self._check_if_open() + assert self._handle is not None # for mypy + return self._handle.root + + @property + def filename(self) -> str: + return self._path + + def __getitem__(self, key: str): + return self.get(key) + + def __setitem__(self, key: str, value) -> None: + self.put(key, value) + + def __delitem__(self, key: str) -> int | None: + return self.remove(key) + + def __getattr__(self, name: str): + """allow attribute access to get stores""" + try: + return self.get(name) + except (KeyError, ClosedFileError): + pass + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __contains__(self, key: str) -> bool: + """ + check for existence of this key + can match the exact pathname or the pathnm w/o the leading '/' + """ + node = self.get_node(key) + if node is not None: + name = node._v_pathname + if key in (name, name[1:]): + return True + return False + + def __len__(self) -> int: + return len(self.groups()) + + def __repr__(self) -> str: + pstr = pprint_thing(self._path) + return f"{type(self)}\nFile path: {pstr}\n" + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + def keys(self, include: str = "pandas") -> list[str]: + """ + Return a list of keys corresponding to objects stored in HDFStore. + + Parameters + ---------- + + include : str, default 'pandas' + When kind equals 'pandas' return pandas objects. + When kind equals 'native' return native HDF5 Table objects. + + Returns + ------- + list + List of ABSOLUTE path-names (e.g. have the leading '/'). + + Raises + ------ + raises ValueError if kind has an illegal value + + See Also + -------- + HDFStore.info : Prints detailed information on the store. + HDFStore.get_node : Returns the node with the key. + HDFStore.get_storer : Returns the storer object for a key. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + >>> store = pd.HDFStore("store.h5", "w") # doctest: +SKIP + >>> store.put("data", df) # doctest: +SKIP + >>> store.get("data") # doctest: +SKIP + >>> print(store.keys()) # doctest: +SKIP + ['/data1', '/data2'] + >>> store.close() # doctest: +SKIP + """ + if include == "pandas": + return [n._v_pathname for n in self.groups()] + + elif include == "native": + assert self._handle is not None # mypy + return [ + n._v_pathname for n in self._handle.walk_nodes("/", classname="Table") + ] + raise ValueError( + f"`include` should be either 'pandas' or 'native' but is '{include}'" + ) + + def __iter__(self) -> Iterator[str]: + return iter(self.keys()) + + def items(self) -> Iterator[tuple[str, list]]: + """ + iterate on key->group + """ + for g in self.groups(): + yield g._v_pathname, g + + def open(self, mode: str = "a", **kwargs) -> None: + """ + Open the file in the specified mode + + Parameters + ---------- + mode : {'a', 'w', 'r', 'r+'}, default 'a' + See HDFStore docstring or tables.open_file for info about modes + **kwargs + These parameters will be passed to the PyTables open_file method. + """ + tables = _tables() + + if self._mode != mode: + # if we are changing a write mode to read, ok + if self._mode in ["a", "w"] and mode in ["r", "r+"]: + pass + elif mode in ["w"]: + # this would truncate, raise here + if self.is_open: + raise PossibleDataLossError( + f"Re-opening the file [{self._path}] with mode [{self._mode}] " + "will delete the current file!" + ) + + self._mode = mode + + # close and reopen the handle + if self.is_open: + self.close() + + if self._complevel and self._complevel > 0: + self._filters = _tables().Filters( + self._complevel, self._complib, fletcher32=self._fletcher32 + ) + + if _table_file_open_policy_is_strict and self.is_open: + msg = ( + "Cannot open HDF5 file, which is already opened, " + "even in read-only mode." + ) + raise ValueError(msg) + + self._handle = tables.open_file(self._path, self._mode, **kwargs) + + def close(self) -> None: + """ + Close the PyTables file handle + """ + if self._handle is not None: + self._handle.close() + self._handle = None + + @property + def is_open(self) -> bool: + """ + return a boolean indicating whether the file is open + """ + if self._handle is None: + return False + return bool(self._handle.isopen) + + def flush(self, fsync: bool = False) -> None: + """ + Force all buffered modifications to be written to disk. + + Parameters + ---------- + fsync : bool (default False) + call ``os.fsync()`` on the file handle to force writing to disk. + + Notes + ----- + Without ``fsync=True``, flushing may not guarantee that the OS writes + to disk. With fsync, the operation will block until the OS claims the + file has been written; however, other caching layers may still + interfere. + """ + if self._handle is not None: + self._handle.flush() + if fsync: + with suppress(OSError): + os.fsync(self._handle.fileno()) + + def get(self, key: str): + """ + Retrieve pandas object stored in file. + + Parameters + ---------- + key : str + Object to retrieve from file. Raises KeyError if not found. + + Returns + ------- + object + Same type as object stored in file. + + See Also + -------- + HDFStore.get_node : Returns the node with the key. + HDFStore.get_storer : Returns the storer object for a key. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + >>> store = pd.HDFStore("store.h5", "w") # doctest: +SKIP + >>> store.put("data", df) # doctest: +SKIP + >>> store.get("data") # doctest: +SKIP + >>> store.close() # doctest: +SKIP + """ + with patch_pickle(): + # GH#31167 Without this patch, pickle doesn't know how to unpickle + # old DateOffset objects now that they are cdef classes. + group = self.get_node(key) + if group is None: + raise KeyError(f"No object named {key} in the file") + return self._read_group(group) + + def select( + self, + key: str, + where=None, + start=None, + stop=None, + columns=None, + iterator: bool = False, + chunksize: int | None = None, + auto_close: bool = False, + ): + """ + Retrieve pandas object stored in file, optionally based on where criteria. + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + key : str + Object being retrieved from file. + where : list or None + List of Term (or convertible) objects, optional. + start : int or None + Row number to start selection. + stop : int, default None + Row number to stop selection. + columns : list or None + A list of columns that if not None, will limit the return columns. + iterator : bool or False + Returns an iterator. + chunksize : int or None + Number or rows to include in iteration, return an iterator. + auto_close : bool or False + Should automatically close the store when finished. + + Returns + ------- + object + Retrieved object from file. + + See Also + -------- + HDFStore.select_as_coordinates : Returns the selection as an index. + HDFStore.select_column : Returns a single column from the table. + HDFStore.select_as_multiple : Retrieves pandas objects from multiple tables. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + >>> store = pd.HDFStore("store.h5", "w") # doctest: +SKIP + >>> store.put("data", df) # doctest: +SKIP + >>> store.get("data") # doctest: +SKIP + >>> print(store.keys()) # doctest: +SKIP + ['/data1', '/data2'] + >>> store.select("/data1") # doctest: +SKIP + A B + 0 1 2 + 1 3 4 + >>> store.select("/data1", where="columns == A") # doctest: +SKIP + A + 0 1 + 1 3 + >>> store.close() # doctest: +SKIP + """ + group = self.get_node(key) + if group is None: + raise KeyError(f"No object named {key} in the file") + + # create the storer and axes + where = _ensure_term(where, scope_level=1) + s = self._create_storer(group) + s.infer_axes() + + # function to call on iteration + def func(_start, _stop, _where): + return s.read(start=_start, stop=_stop, where=_where, columns=columns) + + # create the iterator + it = TableIterator( + self, + s, + func, + where=where, + nrows=s.nrows, + start=start, + stop=stop, + iterator=iterator, + chunksize=chunksize, + auto_close=auto_close, + ) + + return it.get_result() + + def select_as_coordinates( + self, + key: str, + where=None, + start: int | None = None, + stop: int | None = None, + ): + """ + return the selection as an Index + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + + Parameters + ---------- + key : str + where : list of Term (or convertible) objects, optional + start : integer (defaults to None), row number to start selection + stop : integer (defaults to None), row number to stop selection + """ + where = _ensure_term(where, scope_level=1) + tbl = self.get_storer(key) + if not isinstance(tbl, Table): + raise TypeError("can only read_coordinates with a table") + return tbl.read_coordinates(where=where, start=start, stop=stop) + + def select_column( + self, + key: str, + column: str, + start: int | None = None, + stop: int | None = None, + ): + """ + return a single column from the table. This is generally only useful to + select an indexable + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + key : str + column : str + The column of interest. + start : int or None, default None + stop : int or None, default None + + Raises + ------ + raises KeyError if the column is not found (or key is not a valid + store) + raises ValueError if the column can not be extracted individually (it + is part of a data block) + + """ + tbl = self.get_storer(key) + if not isinstance(tbl, Table): + raise TypeError("can only read_column with a table") + return tbl.read_column(column=column, start=start, stop=stop) + + def select_as_multiple( + self, + keys, + where=None, + selector=None, + columns=None, + start=None, + stop=None, + iterator: bool = False, + chunksize: int | None = None, + auto_close: bool = False, + ): + """ + Retrieve pandas objects from multiple tables. + + .. warning:: + + Pandas uses PyTables for reading and writing HDF5 files, which allows + serializing object-dtype data with pickle when using the "fixed" format. + Loading pickled data received from untrusted sources can be unsafe. + + See: https://docs.python.org/3/library/pickle.html for more. + + Parameters + ---------- + keys : a list of the tables + selector : the table to apply the where criteria (defaults to keys[0] + if not supplied) + columns : the columns I want back + start : integer (defaults to None), row number to start selection + stop : integer (defaults to None), row number to stop selection + iterator : bool, return an iterator, default False + chunksize : nrows to include in iteration, return an iterator + auto_close : bool, default False + Should automatically close the store when finished. + + Raises + ------ + raises KeyError if keys or selector is not found or keys is empty + raises TypeError if keys is not a list or tuple + raises ValueError if the tables are not ALL THE SAME DIMENSIONS + """ + # default to single select + where = _ensure_term(where, scope_level=1) + if isinstance(keys, (list, tuple)) and len(keys) == 1: + keys = keys[0] + if isinstance(keys, str): + return self.select( + key=keys, + where=where, + columns=columns, + start=start, + stop=stop, + iterator=iterator, + chunksize=chunksize, + auto_close=auto_close, + ) + + if not isinstance(keys, (list, tuple)): + raise TypeError("keys must be a list/tuple") + + if not len(keys): + raise ValueError("keys must have a non-zero length") + + if selector is None: + selector = keys[0] + + # collect the tables + tbls = [self.get_storer(k) for k in keys] + s = self.get_storer(selector) + + # validate rows + nrows = None + for t, k in itertools.chain([(s, selector)], zip(tbls, keys, strict=True)): + if t is None: + raise KeyError(f"Invalid table [{k}]") + if not t.is_table: + raise TypeError( + f"object [{t.pathname}] is not a table, and cannot be used in all " + "select as multiple" + ) + + if nrows is None: + nrows = t.nrows + elif t.nrows != nrows: + raise ValueError("all tables must have exactly the same nrows!") + + # The isinstance checks here are redundant with the check above, + # but necessary for mypy; see GH#29757 + _tbls = [x for x in tbls if isinstance(x, Table)] + + # axis is the concentration axes + axis = {t.non_index_axes[0][0] for t in _tbls}.pop() + + def func(_start, _stop, _where): + # retrieve the objs, _where is always passed as a set of + # coordinates here + objs = [ + t.read(where=_where, columns=columns, start=_start, stop=_stop) + for t in tbls + ] + + # concat and return + return concat(objs, axis=axis, verify_integrity=False)._consolidate() + + # create the iterator + it = TableIterator( + self, + s, + func, + where=where, + nrows=nrows, + start=start, + stop=stop, + iterator=iterator, + chunksize=chunksize, + auto_close=auto_close, + ) + + return it.get_result(coordinates=True) + + def put( + self, + key: str, + value: DataFrame | Series, + format=None, + index: bool = True, + append: bool = False, + complib=None, + complevel: int | None = None, + min_itemsize: int | dict[str, int] | None = None, + nan_rep=None, + data_columns: Literal[True] | list[str] | None = None, + encoding=None, + errors: str = "strict", + track_times: bool = True, + dropna: bool = False, + ) -> None: + """ + Store object in HDFStore. + + This method writes a pandas DataFrame or Series into an HDF5 file using + either the fixed or table format. The `table` format allows additional + operations like incremental appends and queries but may have performance + trade-offs. The `fixed` format provides faster read/write operations but + does not support appends or queries. + + Parameters + ---------- + key : str + Key of object to store in file. + value : {Series, DataFrame} + Value of object to store in file. + format : 'fixed(f)|table(t)', default is 'fixed' + Format to use when storing object in HDFStore. Value can be one of: + + ``'fixed'`` + Fixed format. Fast writing/reading. Not-appendable, nor searchable. + ``'table'`` + Table format. Write as a PyTables Table structure which may perform + worse but allow more flexible operations like searching / selecting + subsets of the data. + index : bool, default True + Write DataFrame index as a column. + append : bool, default False + This will force Table format, append the input data to the existing. + complib : default None + This parameter is currently not accepted. + complevel : int, 0-9, default None + Specifies a compression level for data. + A value of 0 or None disables compression. + min_itemsize : int, dict, or None + Dict of columns that specify minimum str sizes. + nan_rep : str + Str to use as str nan representation. + data_columns : list of columns or True, default None + List of columns to create as data columns, or True to use all columns. + See `here + `__. + encoding : str, default None + Provide an encoding for strings. + errors : str, default 'strict' + The error handling scheme to use for encoding errors. + The default is 'strict' meaning that encoding errors raise a + UnicodeEncodeError. Other possible values are 'ignore', 'replace' and + 'xmlcharrefreplace' as well as any other name registered with + codecs.register_error that can handle UnicodeEncodeErrors. + track_times : bool, default True + Parameter is propagated to 'create_table' method of 'PyTables'. + If set to False it enables to have the same h5 files (same hashes) + independent on creation time. + dropna : bool, default False, optional + Remove missing values. + + See Also + -------- + HDFStore.info : Prints detailed information on the store. + HDFStore.get_storer : Returns the storer object for a key. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + >>> store = pd.HDFStore("store.h5", "w") # doctest: +SKIP + >>> store.put("data", df) # doctest: +SKIP + """ + if format is None: + format = get_option("io.hdf.default_format") or "fixed" + format = self._validate_format(format) + self._write_to_group( + key, + value, + format=format, + index=index, + append=append, + complib=complib, + complevel=complevel, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + data_columns=data_columns, + encoding=encoding, + errors=errors, + track_times=track_times, + dropna=dropna, + ) + + def remove(self, key: str, where=None, start=None, stop=None) -> int | None: + """ + Remove pandas object partially by specifying the where condition + + Parameters + ---------- + key : str + Node to remove or delete rows from + where : list of Term (or convertible) objects, optional + start : integer (defaults to None), row number to start selection + stop : integer (defaults to None), row number to stop selection + + Returns + ------- + number of rows removed (or None if not a Table) + + Raises + ------ + raises KeyError if key is not a valid store + + """ + where = _ensure_term(where, scope_level=1) + try: + s = self.get_storer(key) + except KeyError: + # the key is not a valid store, re-raising KeyError + raise + except AssertionError: + # surface any assertion errors for e.g. debugging + raise + except Exception as err: + # In tests we get here with ClosedFileError, TypeError, and + # _table_mod.NoSuchNodeError. TODO: Catch only these? + + if where is not None: + raise ValueError( + "trying to remove a node with a non-None where clause!" + ) from err + + # we are actually trying to remove a node (with children) + node = self.get_node(key) + if node is not None: + node._f_remove(recursive=True) + return None + + # remove the node + if com.all_none(where, start, stop): + s.group._f_remove(recursive=True) + return None + + # delete from the table + if not s.is_table: + raise ValueError("can only remove with where on objects written as tables") + return s.delete(where=where, start=start, stop=stop) + + def append( + self, + key: str, + value: DataFrame | Series, + format=None, + axes=None, + index: bool | list[str] = True, + append: bool = True, + complib=None, + complevel: int | None = None, + columns=None, + min_itemsize: int | dict[str, int] | None = None, + nan_rep=None, + chunksize: int | None = None, + expectedrows=None, + dropna: bool | None = None, + data_columns: Literal[True] | list[str] | None = None, + encoding=None, + errors: str = "strict", + ) -> None: + """ + Append to Table in file. + + Node must already exist and be Table format. + + Parameters + ---------- + key : str + Key of object to append. + value : {Series, DataFrame} + Value of object to append. + format : 'table' is the default + Format to use when storing object in HDFStore. Value can be one of: + + ``'table'`` + Table format. Write as a PyTables Table structure which may perform + worse but allow more flexible operations like searching / selecting + subsets of the data. + axes : default None + This parameter is currently not accepted. + index : bool, default True + Write DataFrame index as a column. + append : bool, default True + Append the input data to the existing. + complib : default None + This parameter is currently not accepted. + complevel : int, 0-9, default None + Specifies a compression level for data. + A value of 0 or None disables compression. + columns : default None + This parameter is currently not accepted, try data_columns. + min_itemsize : int, dict, or None + Dict of columns that specify minimum str sizes. + nan_rep : str + Str to use as str nan representation. + chunksize : int or None + Size to chunk the writing. + expectedrows : int + Expected TOTAL row size of this table. + dropna : bool, default False, optional + Do not write an ALL nan row to the store settable + by the option 'io.hdf.dropna_table'. + data_columns : list of columns, or True, default None + List of columns to create as indexed data columns for on-disk + queries, or True to use all columns. By default only the axes + of the object are indexed. See `here + `__. + encoding : default None + Provide an encoding for str. + errors : str, default 'strict' + The error handling scheme to use for encoding errors. + The default is 'strict' meaning that encoding errors raise a + UnicodeEncodeError. Other possible values are 'ignore', 'replace' and + 'xmlcharrefreplace' as well as any other name registered with + codecs.register_error that can handle UnicodeEncodeErrors. + + See Also + -------- + HDFStore.append_to_multiple : Append to multiple tables. + + Notes + ----- + Does *not* check if data being appended overlaps with existing + data in the table, so be careful + + Examples + -------- + >>> df1 = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + >>> store = pd.HDFStore("store.h5", "w") # doctest: +SKIP + >>> store.put("data", df1, format="table") # doctest: +SKIP + >>> df2 = pd.DataFrame([[5, 6], [7, 8]], columns=["A", "B"]) + >>> store.append("data", df2) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + A B + 0 1 2 + 1 3 4 + 0 5 6 + 1 7 8 + """ + if columns is not None: + raise TypeError( + "columns is not a supported keyword in append, try data_columns" + ) + + if dropna is None: + dropna = get_option("io.hdf.dropna_table") + if format is None: + format = get_option("io.hdf.default_format") or "table" + format = self._validate_format(format) + self._write_to_group( + key, + value, + format=format, + axes=axes, + index=index, + append=append, + complib=complib, + complevel=complevel, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + chunksize=chunksize, + expectedrows=expectedrows, + dropna=dropna, + data_columns=data_columns, + encoding=encoding, + errors=errors, + ) + + def append_to_multiple( + self, + d: dict, + value, + selector, + data_columns=None, + axes=None, + dropna: bool = False, + **kwargs, + ) -> None: + """ + Append to multiple tables + + Parameters + ---------- + d : a dict of table_name to table_columns, None is acceptable as the + values of one node (this will get all the remaining columns) + value : a pandas object + selector : a string that designates the indexable table; all of its + columns will be designed as data_columns, unless data_columns is + passed, in which case these are used + data_columns : list of columns to create as data columns, or True to + use all columns + dropna : if evaluates to True, drop rows from all tables if any single + row in each table has all NaN. Default False. + + Notes + ----- + axes parameter is currently not accepted + + """ + if axes is not None: + raise TypeError( + "axes is currently not accepted as a parameter to append_to_multiple; " + "you can create the tables independently instead" + ) + + if not isinstance(d, dict): + raise ValueError( + "append_to_multiple must have a dictionary specified as the " + "way to split the value" + ) + + if selector not in d: + raise ValueError( + "append_to_multiple requires a selector that is in passed dict" + ) + + # figure out the splitting axis (the non_index_axis) + axis = next(iter(set(range(value.ndim)) - set(_AXES_MAP[type(value)]))) + + # figure out how to split the value + remain_key = None + remain_values: list = [] + for k, v in d.items(): + if v is None: + if remain_key is not None: + raise ValueError( + "append_to_multiple can only have one value in d that is None" + ) + remain_key = k + else: + remain_values.extend(v) + if remain_key is not None: + ordered = value.axes[axis] + ordd = ordered.difference(Index(remain_values)) + ordd = sorted(ordered.get_indexer(ordd)) + d[remain_key] = ordered.take(ordd) + + # data_columns + if data_columns is None: + data_columns = d[selector] + + # ensure rows are synchronized across the tables + if dropna: + idxs = (value[cols].dropna(how="all").index for cols in d.values()) + valid_index = next(idxs) + for index in idxs: + valid_index = valid_index.intersection(index) + value = value.loc[valid_index] + + min_itemsize = kwargs.pop("min_itemsize", None) + + # append + for k, v in d.items(): + dc = data_columns if k == selector else None + + # compute the val + val = value.reindex(v, axis=axis) + + filtered = ( + {key: value for (key, value) in min_itemsize.items() if key in v} + if min_itemsize is not None + else None + ) + self.append(k, val, data_columns=dc, min_itemsize=filtered, **kwargs) + + def create_table_index( + self, + key: str, + columns=None, + optlevel: int | None = None, + kind: str | None = None, + ) -> None: + """ + Create a pytables index on the table. + + Parameters + ---------- + key : str + columns : None, bool, or listlike[str] + Indicate which columns to create an index on. + + * False : Do not create any indexes. + * True : Create indexes on all columns. + * None : Create indexes on all columns. + * listlike : Create indexes on the given columns. + + optlevel : int or None, default None + Optimization level, if None, pytables defaults to 6. + kind : str or None, default None + Kind of index, if None, pytables defaults to "medium". + + Raises + ------ + TypeError: raises if the node is not a table + """ + # version requirements + _tables() + s = self.get_storer(key) + if s is None: + return + + if not isinstance(s, Table): + raise TypeError("cannot create table index on a Fixed format store") + s.create_index(columns=columns, optlevel=optlevel, kind=kind) + + def groups(self) -> list: + """ + Return a list of all the top-level nodes. + + Each node returned is not a pandas storage object. + + Returns + ------- + list + List of objects. + + See Also + -------- + HDFStore.get_node : Returns the node with the key. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + >>> store = pd.HDFStore("store.h5", "w") # doctest: +SKIP + >>> store.put("data", df) # doctest: +SKIP + >>> print(store.groups()) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + [/data (Group) '' + children := ['axis0' (Array), 'axis1' (Array), 'block0_values' (Array), + 'block0_items' (Array)]] + """ + _tables() + self._check_if_open() + assert self._handle is not None # for mypy + assert _table_mod is not None # for mypy + return [ + g + for g in self._handle.walk_groups() + if ( + not isinstance(g, _table_mod.link.Link) + and ( + getattr(g._v_attrs, "pandas_type", None) + or getattr(g, "table", None) + or (isinstance(g, _table_mod.table.Table) and g._v_name != "table") + ) + ) + ] + + def walk(self, where: str = "/") -> Iterator[tuple[str, list[str], list[str]]]: + """ + Walk the pytables group hierarchy for pandas objects. + + This generator will yield the group path, subgroups and pandas object + names for each group. + + Any non-pandas PyTables objects that are not a group will be ignored. + + The `where` group itself is listed first (preorder), then each of its + child groups (following an alphanumerical order) is also traversed, + following the same procedure. + + Parameters + ---------- + where : str, default "/" + Group where to start walking. + + Yields + ------ + path : str + Full path to a group (without trailing '/'). + groups : list + Names (strings) of the groups contained in `path`. + leaves : list + Names (strings) of the pandas objects contained in `path`. + + See Also + -------- + HDFStore.info : Prints detailed information on the store. + + Examples + -------- + >>> df1 = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + >>> store = pd.HDFStore("store.h5", "w") # doctest: +SKIP + >>> store.put("data", df1, format="table") # doctest: +SKIP + >>> df2 = pd.DataFrame([[5, 6], [7, 8]], columns=["A", "B"]) + >>> store.append("data", df2) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + >>> for group in store.walk(): # doctest: +SKIP + ... print(group) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + """ + _tables() + self._check_if_open() + assert self._handle is not None # for mypy + assert _table_mod is not None # for mypy + + for g in self._handle.walk_groups(where): + if getattr(g._v_attrs, "pandas_type", None) is not None: + continue + + groups = [] + leaves = [] + for child in g._v_children.values(): + pandas_type = getattr(child._v_attrs, "pandas_type", None) + if pandas_type is None: + if isinstance(child, _table_mod.group.Group): + groups.append(child._v_name) + else: + leaves.append(child._v_name) + + yield (g._v_pathname.rstrip("/"), groups, leaves) + + def get_node(self, key: str) -> Node | None: + """return the node with the key or None if it does not exist""" + self._check_if_open() + if not key.startswith("/"): + key = "/" + key + + assert self._handle is not None + assert _table_mod is not None # for mypy + try: + node = self._handle.get_node(self.root, key) + except _table_mod.exceptions.NoSuchNodeError: + return None + + assert isinstance(node, _table_mod.Node), type(node) + return node + + def get_storer(self, key: str) -> GenericFixed | Table: + """return the storer object for a key, raise if not in the file""" + group = self.get_node(key) + if group is None: + raise KeyError(f"No object named {key} in the file") + + s = self._create_storer(group) + s.infer_axes() + return s + + def copy( + self, + file, + mode: str = "w", + propindexes: bool = True, + keys=None, + complib=None, + complevel: int | None = None, + fletcher32: bool = False, + overwrite: bool = True, + ) -> HDFStore: + """ + Copy the existing store to a new file, updating in place. + + Parameters + ---------- + propindexes : bool, default True + Restore indexes in copied file. + keys : list, optional + List of keys to include in the copy (defaults to all). + overwrite : bool, default True + Whether to overwrite (remove and replace) existing nodes in the new store. + mode, complib, complevel, fletcher32 same as in HDFStore.__init__ + + Returns + ------- + open file handle of the new store + """ + new_store = HDFStore( + file, mode=mode, complib=complib, complevel=complevel, fletcher32=fletcher32 + ) + if keys is None: + keys = list(self.keys()) + if not isinstance(keys, (tuple, list)): + keys = [keys] + for k in keys: + s = self.get_storer(k) + if s is not None: + if k in new_store: + if overwrite: + new_store.remove(k) + + data = self.select(k) + if isinstance(s, Table): + index: bool | list[str] = False + if propindexes: + index = [a.name for a in s.axes if a.is_indexed] + new_store.append( + k, + data, + index=index, + data_columns=getattr(s, "data_columns", None), + encoding=s.encoding, + ) + else: + new_store.put(k, data, encoding=s.encoding) + + return new_store + + def info(self) -> str: + """ + Print detailed information on the store. + + Returns + ------- + str + A String containing the python pandas class name, filepath to the HDF5 + file and all the object keys along with their respective dataframe shapes. + + See Also + -------- + HDFStore.get_storer : Returns the storer object for a key. + + Examples + -------- + >>> df1 = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + >>> df2 = pd.DataFrame([[5, 6], [7, 8]], columns=["C", "D"]) + >>> store = pd.HDFStore("store.h5", "w") # doctest: +SKIP + >>> store.put("data1", df1) # doctest: +SKIP + >>> store.put("data2", df2) # doctest: +SKIP + >>> print(store.info()) # doctest: +SKIP + >>> store.close() # doctest: +SKIP + + File path: store.h5 + /data1 frame (shape->[2,2]) + /data2 frame (shape->[2,2]) + """ + path = pprint_thing(self._path) + output = f"{type(self)}\nFile path: {path}\n" + + if self.is_open: + lkeys = sorted(self.keys()) + if lkeys: + keys = [] + values = [] + + for k in lkeys: + try: + s = self.get_storer(k) + if s is not None: + keys.append(pprint_thing(s.pathname or k)) + values.append(pprint_thing(s or "invalid_HDFStore node")) + except AssertionError: + # surface any assertion errors for e.g. debugging + raise + except Exception as detail: + keys.append(k) + dstr = pprint_thing(detail) + values.append(f"[invalid_HDFStore node: {dstr}]") + + output += adjoin(12, keys, values) + else: + output += "Empty" + else: + output += "File is CLOSED" + + return output + + # ------------------------------------------------------------------------ + # private methods + + def _check_if_open(self) -> None: + if not self.is_open: + raise ClosedFileError(f"{self._path} file is not open!") + + def _validate_format(self, format: str) -> str: + """validate / deprecate formats""" + # validate + try: + format = _FORMAT_MAP[format.lower()] + except KeyError as err: + raise TypeError(f"invalid HDFStore format specified [{format}]") from err + + return format + + def _create_storer( + self, + group, + format=None, + value: DataFrame | Series | None = None, + encoding: str = "UTF-8", + errors: str = "strict", + ) -> GenericFixed | Table: + """return a suitable class to operate""" + cls: type[GenericFixed | Table] + + if value is not None and not isinstance(value, (Series, DataFrame)): + raise TypeError("value must be None, Series, or DataFrame") + + pt = getattr(group._v_attrs, "pandas_type", None) + tt = getattr(group._v_attrs, "table_type", None) + + # infer the pt from the passed value + if pt is None: + if value is None: + _tables() + assert _table_mod is not None # for mypy + if getattr(group, "table", None) or isinstance( + group, _table_mod.table.Table + ): + pt = "frame_table" + tt = "generic_table" + else: + raise TypeError( + "cannot create a storer if the object is not existing " + "nor a value are passed" + ) + else: + if isinstance(value, Series): + pt = "series" + else: + pt = "frame" + + # we are actually a table + if format == "table": + pt += "_table" + + # a storer node + if "table" not in pt: + _STORER_MAP = {"series": SeriesFixed, "frame": FrameFixed} + try: + cls = _STORER_MAP[pt] + except KeyError as err: + raise TypeError( + f"cannot properly create the storer for: [_STORER_MAP] [group->" + f"{group},value->{type(value)},format->{format}" + ) from err + return cls(self, group, encoding=encoding, errors=errors) + + # existing node (and must be a table) + if tt is None: + # if we are a writer, determine the tt + if value is not None: + if pt == "series_table": + index = getattr(value, "index", None) + if index is not None: + if index.nlevels == 1: + tt = "appendable_series" + elif index.nlevels > 1: + tt = "appendable_multiseries" + elif pt == "frame_table": + index = getattr(value, "index", None) + if index is not None: + if index.nlevels == 1: + tt = "appendable_frame" + elif index.nlevels > 1: + tt = "appendable_multiframe" + + _TABLE_MAP = { + "generic_table": GenericTable, + "appendable_series": AppendableSeriesTable, + "appendable_multiseries": AppendableMultiSeriesTable, + "appendable_frame": AppendableFrameTable, + "appendable_multiframe": AppendableMultiFrameTable, + "worm": WORMTable, + } + try: + cls = _TABLE_MAP[tt] # type: ignore[index] + except KeyError as err: + raise TypeError( + f"cannot properly create the storer for: [_TABLE_MAP] [group->" + f"{group},value->{type(value)},format->{format}" + ) from err + + return cls(self, group, encoding=encoding, errors=errors) + + def _write_to_group( + self, + key: str, + value: DataFrame | Series, + format, + axes=None, + index: bool | list[str] = True, + append: bool = False, + complib=None, + complevel: int | None = None, + fletcher32=None, + min_itemsize: int | dict[str, int] | None = None, + chunksize: int | None = None, + expectedrows=None, + dropna: bool = False, + nan_rep=None, + data_columns=None, + encoding=None, + errors: str = "strict", + track_times: bool = True, + ) -> None: + # we don't want to store a table node at all if our object is 0-len + # as there are not dtypes + if getattr(value, "empty", None) and (format == "table" or append): + return + + group = self._identify_group(key, append) + + s = self._create_storer(group, format, value, encoding=encoding, errors=errors) + if append: + # raise if we are trying to append to a Fixed format, + # or a table that exists (and we are putting) + if not s.is_table or (s.is_table and format == "fixed" and s.is_exists): + raise ValueError("Can only append to Tables") + if not s.is_exists: + s.set_object_info() + else: + s.set_object_info() + + if not s.is_table and complib: + raise ValueError("Compression not supported on Fixed format stores") + + # write the object + s.write( + obj=value, + axes=axes, + append=append, + complib=complib, + complevel=complevel, + fletcher32=fletcher32, + min_itemsize=min_itemsize, + chunksize=chunksize, + expectedrows=expectedrows, + dropna=dropna, + nan_rep=nan_rep, + data_columns=data_columns, + track_times=track_times, + ) + + if isinstance(s, Table) and index: + s.create_index(columns=index) + + def _read_group(self, group: Node): + s = self._create_storer(group) + s.infer_axes() + return s.read() + + def _identify_group(self, key: str, append: bool) -> Node: + """Identify HDF5 group based on key, delete/create group if needed.""" + group = self.get_node(key) + + # we make this assertion for mypy; the get_node call will already + # have raised if this is incorrect + assert self._handle is not None + + # remove the node if we are not appending + if group is not None and not append: + self._handle.remove_node(group, recursive=True) + group = None + + if group is None: + group = self._create_nodes_and_group(key) + + return group + + def _create_nodes_and_group(self, key: str) -> Node: + """Create nodes from key and return group name.""" + # assertion for mypy + assert self._handle is not None + + paths = key.split("/") + # recursively create the groups + path = "/" + for p in paths: + if not len(p): + continue + new_path = path + if not path.endswith("/"): + new_path += "/" + new_path += p + group = self.get_node(new_path) + if group is None: + group = self._handle.create_group(path, p) + path = new_path + return group + + +class TableIterator: + """ + Define the iteration interface on a table + + Parameters + ---------- + store : HDFStore + s : the referred storer + func : the function to execute the query + where : the where of the query + nrows : the rows to iterate on + start : the passed start value (default is None) + stop : the passed stop value (default is None) + iterator : bool, default False + Whether to use the default iterator. + chunksize : the passed chunking value (default is 100000) + auto_close : bool, default False + Whether to automatically close the store at the end of iteration. + """ + + chunksize: int | None + store: HDFStore + s: GenericFixed | Table + + def __init__( + self, + store: HDFStore, + s: GenericFixed | Table, + func, + where, + nrows, + start=None, + stop=None, + iterator: bool = False, + chunksize: int | None = None, + auto_close: bool = False, + ) -> None: + self.store = store + self.s = s + self.func = func + self.where = where + + # set start/stop if they are not set if we are a table + if self.s.is_table: + if nrows is None: + nrows = 0 + if start is None: + start = 0 + if stop is None: + stop = nrows + stop = min(nrows, stop) + + self.nrows = nrows + self.start = start + self.stop = stop + + self.coordinates = None + if iterator or chunksize is not None: + if chunksize is None: + chunksize = 100000 + self.chunksize = int(chunksize) + else: + self.chunksize = None + + self.auto_close = auto_close + + def __iter__(self) -> Iterator: + # iterate + current = self.start + if self.coordinates is None: + raise ValueError("Cannot iterate until get_result is called.") + while current < self.stop: + stop = min(current + self.chunksize, self.stop) + value = self.func(None, None, self.coordinates[current:stop]) + current = stop + if value is None or not len(value): + continue + + yield value + + self.close() + + def close(self) -> None: + if self.auto_close: + self.store.close() + + def get_result(self, coordinates: bool = False): + # return the actual iterator + if self.chunksize is not None: + if not isinstance(self.s, Table): + raise TypeError("can only use an iterator or chunksize on a table") + + self.coordinates = self.s.read_coordinates(where=self.where) + + return self + + # if specified read via coordinates (necessary for multiple selections + if coordinates: + if not isinstance(self.s, Table): + raise TypeError("can only read_coordinates on a table") + where = self.s.read_coordinates( + where=self.where, start=self.start, stop=self.stop + ) + else: + where = self.where + + # directly return the result + results = self.func(self.start, self.stop, where) + self.close() + return results + + +class IndexCol: + """ + an index column description class + + Parameters + ---------- + axis : axis which I reference + values : the ndarray like converted values + kind : a string description of this type + typ : the pytables type + pos : the position in the pytables + + """ + + is_an_indexable: bool = True + is_data_indexable: bool = True + _info_fields = ["freq", "tz", "index_name"] + + def __init__( + self, + name: str, + values=None, + kind=None, + typ=None, + cname: str | None = None, + axis=None, + pos=None, + freq=None, + tz=None, + index_name=None, + ordered=None, + table=None, + meta=None, + metadata=None, + ) -> None: + if not isinstance(name, str): + raise ValueError("`name` must be a str.") + + self.values = values + self.kind = kind + self.typ = typ + self.name = name + self.cname = cname or name + self.axis = axis + self.pos = pos + self.freq = freq + self.tz = tz + self.index_name = index_name + self.ordered = ordered + self.table = table + self.meta = meta + self.metadata = metadata + + if pos is not None: + self.set_pos(pos) + + # These are ensured as long as the passed arguments match the + # constructor annotations. + assert isinstance(self.name, str) + assert isinstance(self.cname, str) + + @property + def itemsize(self) -> int: + # Assumes self.typ has already been initialized + return self.typ.itemsize + + @property + def kind_attr(self) -> str: + return f"{self.name}_kind" + + def set_pos(self, pos: int) -> None: + """set the position of this column in the Table""" + self.pos = pos + if pos is not None and self.typ is not None: + self.typ._v_pos = pos + + def __repr__(self) -> str: + temp = tuple( + map(pprint_thing, (self.name, self.cname, self.axis, self.pos, self.kind)) + ) + return ",".join( + [ + f"{key}->{value}" + for key, value in zip( + ["name", "cname", "axis", "pos", "kind"], temp, strict=True + ) + ] + ) + + def __eq__(self, other: object) -> bool: + """compare 2 col items""" + return all( + getattr(self, a, None) == getattr(other, a, None) + for a in ["name", "cname", "axis", "pos"] + ) + + def __ne__(self, other) -> bool: + return not self.__eq__(other) + + @property + def is_indexed(self) -> bool: + """return whether I am an indexed column""" + if not hasattr(self.table, "cols"): + # e.g. if infer hasn't been called yet, self.table will be None. + return False + return getattr(self.table.cols, self.cname).is_indexed + + def convert( + self, values: np.ndarray, nan_rep, encoding: str, errors: str + ) -> tuple[np.ndarray, np.ndarray] | tuple[Index, Index]: + """ + Convert the data from this selection to the appropriate pandas type. + """ + assert isinstance(values, np.ndarray), type(values) + + # values is a recarray + if values.dtype.fields is not None: + # Copy, otherwise values will be a view + # preventing the original recarry from being free'ed + values = values[self.cname].copy() + + val_kind = self.kind + values = _maybe_convert(values, val_kind, encoding, errors) + kwargs = {} + kwargs["name"] = self.index_name + + if self.freq is not None: + kwargs["freq"] = self.freq + + factory: type[Index | DatetimeIndex] = Index + if lib.is_np_dtype(values.dtype, "M") or isinstance( + values.dtype, DatetimeTZDtype + ): + factory = DatetimeIndex + elif values.dtype == "i8" and "freq" in kwargs: + # PeriodIndex data is stored as i8 + # error: Incompatible types in assignment (expression has type + # "Callable[[Any, KwArg(Any)], PeriodIndex]", variable has type + # "Union[Type[Index], Type[DatetimeIndex]]") + factory = lambda x, **kwds: PeriodIndex.from_ordinals( # type: ignore[assignment] + x, freq=kwds.get("freq", None) + )._rename(kwds["name"]) + + # making an Index instance could throw a number of different errors + try: + new_pd_index = factory(values, **kwargs) + except UnicodeEncodeError as err: + if ( + errors == "surrogatepass" + and using_string_dtype() + and str(err).endswith("surrogates not allowed") + and HAS_PYARROW + ): + new_pd_index = factory( + values, + dtype=StringDtype(storage="python", na_value=np.nan), + **kwargs, + ) + else: + raise + except ValueError: + # if the output freq is different that what we recorded, + # it should be None (see also 'doc example part 2') + if "freq" in kwargs: + kwargs["freq"] = None + new_pd_index = factory(values, **kwargs) + + final_pd_index: Index + if self.tz is not None and isinstance(new_pd_index, DatetimeIndex): + final_pd_index = new_pd_index.tz_localize("UTC").tz_convert(self.tz) + else: + final_pd_index = new_pd_index + return final_pd_index, final_pd_index + + def take_data(self): + """return the values""" + return self.values + + @property + def attrs(self): + return self.table._v_attrs + + @property + def description(self): + return self.table.description + + @property + def col(self): + """return my current col description""" + return getattr(self.description, self.cname, None) + + @property + def cvalues(self): + """return my cython values""" + return self.values + + def __iter__(self) -> Iterator: + return iter(self.values) + + def maybe_set_size(self, min_itemsize=None) -> None: + """ + maybe set a string col itemsize: + min_itemsize can be an integer or a dict with this columns name + with an integer size + """ + if self.kind == "string": + if isinstance(min_itemsize, dict): + min_itemsize = min_itemsize.get(self.name) + + if min_itemsize is not None and self.typ.itemsize < min_itemsize: + self.typ = _tables().StringCol(itemsize=min_itemsize, pos=self.pos) + + def validate_names(self) -> None: + pass + + def validate_and_set(self, handler: AppendableTable, append: bool) -> None: + self.table = handler.table + self.validate_col() + self.validate_attr(append) + self.validate_metadata(handler) + self.write_metadata(handler) + self.set_attr() + + def validate_col(self, itemsize=None): + """validate this column: return the compared against itemsize""" + # validate this column for string truncation (or reset to the max size) + if self.kind == "string": + c = self.col + if c is not None: + if itemsize is None: + itemsize = self.itemsize + if c.itemsize < itemsize: + raise ValueError( + f"Trying to store a string with len [{itemsize}] in " + f"[{self.cname}] column but\nthis column has a limit of " + f"[{c.itemsize}]!\nConsider using min_itemsize to " + "preset the sizes on these columns" + ) + return c.itemsize + + return None + + def validate_attr(self, append: bool) -> None: + # check for backwards incompatibility + if append: + existing_kind = getattr(self.attrs, self.kind_attr, None) + if existing_kind is not None and existing_kind != self.kind: + raise TypeError( + f"incompatible kind in col [{existing_kind} - {self.kind}]" + ) + + def update_info(self, info) -> None: + """ + set/update the info for this indexable with the key/value + if there is a conflict raise/warn as needed + """ + for key in self._info_fields: + value = getattr(self, key, None) + idx = info.setdefault(self.name, {}) + + existing_value = idx.get(key) + if key in idx and value is not None and existing_value != value: + # frequency/name just warn + if key in ["freq", "index_name"]: + ws = attribute_conflict_doc % (key, existing_value, value) + warnings.warn( + ws, AttributeConflictWarning, stacklevel=find_stack_level() + ) + + # reset + idx[key] = None + setattr(self, key, None) + + else: + raise ValueError( + f"invalid info for [{self.name}] for [{key}], " + f"existing_value [{existing_value}] conflicts with " + f"new value [{value}]" + ) + elif value is not None or existing_value is not None: + idx[key] = value + + def set_info(self, info) -> None: + """set my state from the passed info""" + idx = info.get(self.name) + if idx is not None: + self.__dict__.update(idx) + + def set_attr(self) -> None: + """set the kind for this column""" + setattr(self.attrs, self.kind_attr, self.kind) + + def validate_metadata(self, handler: AppendableTable) -> None: + """validate that kind=category does not change the categories""" + if self.meta == "category": + new_metadata = self.metadata + cur_metadata = handler.read_metadata(self.cname) + if ( + new_metadata is not None + and cur_metadata is not None + and not array_equivalent( + new_metadata, cur_metadata, strict_nan=True, dtype_equal=True + ) + ): + raise ValueError( + "cannot append a categorical with " + "different categories to the existing" + ) + + def write_metadata(self, handler: AppendableTable) -> None: + """set the meta data""" + if self.metadata is not None: + handler.write_metadata(self.cname, self.metadata) + + +class GenericIndexCol(IndexCol): + """an index which is not represented in the data of the table""" + + @property + def is_indexed(self) -> bool: + return False + + def convert( + self, values: np.ndarray, nan_rep, encoding: str, errors: str + ) -> tuple[Index, Index]: + """ + Convert the data from this selection to the appropriate pandas type. + + Parameters + ---------- + values : np.ndarray + nan_rep : str + encoding : str + errors : str + """ + assert isinstance(values, np.ndarray), type(values) + + index = RangeIndex(len(values)) + return index, index + + def set_attr(self) -> None: + pass + + +class DataCol(IndexCol): + """ + a data holding column, by definition this is not indexable + + Parameters + ---------- + data : the actual data + cname : the column name in the table to hold the data (typically + values) + meta : a string description of the metadata + metadata : the actual metadata + """ + + is_an_indexable = False + is_data_indexable = False + _info_fields = ["tz", "ordered"] + + def __init__( + self, + name: str, + values=None, + kind=None, + typ=None, + cname: str | None = None, + pos=None, + tz=None, + ordered=None, + table=None, + meta=None, + metadata=None, + dtype: DtypeArg | None = None, + data=None, + ) -> None: + super().__init__( + name=name, + values=values, + kind=kind, + typ=typ, + pos=pos, + cname=cname, + tz=tz, + ordered=ordered, + table=table, + meta=meta, + metadata=metadata, + ) + self.dtype = dtype + self.data = data + + @property + def dtype_attr(self) -> str: + return f"{self.name}_dtype" + + @property + def meta_attr(self) -> str: + return f"{self.name}_meta" + + def __repr__(self) -> str: + temp = tuple( + map( + pprint_thing, (self.name, self.cname, self.dtype, self.kind, self.shape) + ) + ) + return ",".join( + [ + f"{key}->{value}" + for key, value in zip( + ["name", "cname", "dtype", "kind", "shape"], temp, strict=True + ) + ] + ) + + def __eq__(self, other: object) -> bool: + """compare 2 col items""" + return all( + getattr(self, a, None) == getattr(other, a, None) + for a in ["name", "cname", "dtype", "pos"] + ) + + def set_data(self, data: ArrayLike) -> None: + assert data is not None + assert self.dtype is None + + data, dtype_name = _get_data_and_dtype_name(data) + + self.data = data + self.dtype = dtype_name + self.kind = _dtype_to_kind(dtype_name) + + def take_data(self): + """return the data""" + return self.data + + @classmethod + def _get_atom(cls, values: ArrayLike) -> Col: + """ + Get an appropriately typed and shaped pytables.Col object for values. + """ + dtype = values.dtype + # error: Item "ExtensionDtype" of "Union[ExtensionDtype, dtype[Any]]" has no + # attribute "itemsize" + itemsize = dtype.itemsize # type: ignore[union-attr] + + shape = values.shape + if values.ndim == 1: + # EA, use block shape pretending it is 2D + # TODO(EA2D): not necessary with 2D EAs + shape = (1, values.size) + + if isinstance(values, Categorical): + codes = values.codes + atom = cls.get_atom_data(shape, kind=codes.dtype.name) + elif lib.is_np_dtype(dtype, "M") or isinstance(dtype, DatetimeTZDtype): + atom = cls.get_atom_datetime64(shape) + elif lib.is_np_dtype(dtype, "m"): + atom = cls.get_atom_timedelta64(shape) + elif is_complex_dtype(dtype): + atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0]) + elif is_string_dtype(dtype): + atom = cls.get_atom_string(shape, itemsize) + else: + atom = cls.get_atom_data(shape, kind=dtype.name) + + return atom + + @classmethod + def get_atom_string(cls, shape, itemsize): + return _tables().StringCol(itemsize=itemsize, shape=shape[0]) + + @classmethod + def get_atom_coltype(cls, kind: str) -> type[Col]: + """return the PyTables column class for this column""" + if kind.startswith("uint"): + k4 = kind[4:] + col_name = f"UInt{k4}Col" + elif kind.startswith("period"): + # we store as integer + col_name = "Int64Col" + else: + kcap = kind.capitalize() + col_name = f"{kcap}Col" + + return getattr(_tables(), col_name) + + @classmethod + def get_atom_data(cls, shape, kind: str) -> Col: + return cls.get_atom_coltype(kind=kind)(shape=shape[0]) + + @classmethod + def get_atom_datetime64(cls, shape): + return _tables().Int64Col(shape=shape[0]) + + @classmethod + def get_atom_timedelta64(cls, shape): + return _tables().Int64Col(shape=shape[0]) + + @property + def shape(self): + return getattr(self.data, "shape", None) + + @property + def cvalues(self): + """return my cython values""" + return self.data + + def validate_attr(self, append) -> None: + """validate that we have the same order as the existing & same dtype""" + if append: + existing_fields = getattr(self.attrs, self.kind_attr, None) + if existing_fields is not None and existing_fields != list(self.values): + raise ValueError("appended items do not match existing items in table!") + + existing_dtype = getattr(self.attrs, self.dtype_attr, None) + if existing_dtype is not None and existing_dtype != self.dtype: + raise ValueError( + "appended items dtype do not match existing items dtype in table!" + ) + + def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): + """ + Convert the data from this selection to the appropriate pandas type. + + Parameters + ---------- + values : np.ndarray + nan_rep : + encoding : str + errors : str + + Returns + ------- + index : listlike to become an Index + data : ndarraylike to become a column + """ + assert isinstance(values, np.ndarray), type(values) + + # values is a recarray + if values.dtype.fields is not None: + values = values[self.cname] + + assert self.typ is not None + if self.dtype is None: + # Note: in tests we never have timedelta64 or datetime64, + # so the _get_data_and_dtype_name may be unnecessary + converted, dtype_name = _get_data_and_dtype_name(values) + kind = _dtype_to_kind(dtype_name) + else: + converted = values + dtype_name = self.dtype + kind = self.kind + + assert isinstance(converted, np.ndarray) # for mypy + + # use the meta if needed + meta = self.meta + metadata = self.metadata + ordered = self.ordered + tz = self.tz + + assert dtype_name is not None + # convert to the correct dtype + dtype = dtype_name + + # reverse converts + if dtype.startswith("datetime64"): + # recreate with tz if indicated + if dtype == "datetime64": + dtype = "datetime64[ns]" + converted = _set_tz(converted, tz, dtype) + + elif dtype.startswith("timedelta64"): + if dtype == "timedelta64": + # from before we started storing timedelta64 unit + converted = np.asarray(converted, dtype="m8[ns]") + else: + converted = np.asarray(converted, dtype=dtype) + elif dtype == "date": + try: + converted = np.asarray( + [date.fromordinal(v) for v in converted], dtype=object + ) + except ValueError: + converted = np.asarray( + [date.fromtimestamp(v) for v in converted], dtype=object + ) + + elif meta == "category": + # we have a categorical + categories = metadata + codes = converted.ravel() + + # if we have stored a NaN in the categories + # then strip it; in theory we could have BOTH + # -1s in the codes and nulls :< + if categories is None: + # Handle case of NaN-only categorical columns in which case + # the categories are an empty array; when this is stored, + # pytables cannot write a zero-len array, so on readback + # the categories would be None and `read_hdf()` would fail. + categories = Index([], dtype=np.float64) + else: + mask = isna(categories) + if mask.any(): + categories = categories[~mask] + codes[codes != -1] -= mask.astype(int).cumsum()._values + + converted = Categorical.from_codes( + codes, categories=categories, ordered=ordered, validate=False + ) + + else: + try: + converted = converted.astype(dtype, copy=False) + except TypeError: + converted = converted.astype("O", copy=False) + + # convert nans / decode + if kind == "string": + converted = _unconvert_string_array( + converted, nan_rep=nan_rep, encoding=encoding, errors=errors + ) + + return self.values, converted + + def set_attr(self) -> None: + """set the data for this column""" + setattr(self.attrs, self.kind_attr, self.values) + setattr(self.attrs, self.meta_attr, self.meta) + assert self.dtype is not None + setattr(self.attrs, self.dtype_attr, self.dtype) + + +class DataIndexableCol(DataCol): + """represent a data column that can be indexed""" + + is_data_indexable = True + + def validate_names(self) -> None: + if not is_string_dtype(Index(self.values).dtype): + # TODO: should the message here be more specifically non-str? + raise ValueError("cannot have non-object label DataIndexableCol") + + @classmethod + def get_atom_string(cls, shape, itemsize): + return _tables().StringCol(itemsize=itemsize) + + @classmethod + def get_atom_data(cls, shape, kind: str) -> Col: + return cls.get_atom_coltype(kind=kind)() + + @classmethod + def get_atom_datetime64(cls, shape): + return _tables().Int64Col() + + @classmethod + def get_atom_timedelta64(cls, shape): + return _tables().Int64Col() + + +class GenericDataIndexableCol(DataIndexableCol): + """represent a generic pytables data column""" + + +class Fixed: + """ + represent an object in my store + facilitate read/write of various types of objects + this is an abstract base class + + Parameters + ---------- + parent : HDFStore + group : Node + The group node where the table resides. + """ + + pandas_kind: str + format_type: str = "fixed" # GH#30962 needed by dask + obj_type: type[DataFrame | Series] + ndim: int + parent: HDFStore + is_table: bool = False + + def __init__( + self, + parent: HDFStore, + group: Node, + encoding: str | None = "UTF-8", + errors: str = "strict", + ) -> None: + assert isinstance(parent, HDFStore), type(parent) + assert _table_mod is not None # needed for mypy + assert isinstance(group, _table_mod.Node), type(group) + self.parent = parent + self.group = group + self.encoding = _ensure_encoding(encoding) + self.errors = errors + + @property + def is_old_version(self) -> bool: + return self.version[0] <= 0 and self.version[1] <= 10 and self.version[2] < 1 + + @property + def version(self) -> tuple[int, int, int]: + """compute and set our version""" + version = getattr(self.group._v_attrs, "pandas_version", None) + if isinstance(version, str): + version_tup = tuple(int(x) for x in version.split(".")) + if len(version_tup) == 2: + version_tup = (*version_tup, 0) + assert len(version_tup) == 3 # needed for mypy + return version_tup + else: + return (0, 0, 0) + + @property + def pandas_type(self): + return getattr(self.group._v_attrs, "pandas_type", None) + + def __repr__(self) -> str: + """return a pretty representation of myself""" + self.infer_axes() + s = self.shape + if s is not None: + if isinstance(s, (list, tuple)): + jshape = ",".join([pprint_thing(x) for x in s]) + s = f"[{jshape}]" + return f"{self.pandas_type:12.12} (shape->{s})" + return self.pandas_type + + def set_object_info(self) -> None: + """set my pandas type & version""" + self.attrs.pandas_type = str(self.pandas_kind) + self.attrs.pandas_version = str(_version) + + def copy(self) -> Fixed: + new_self = copy.copy(self) + return new_self + + @property + def shape(self): + return self.nrows + + @property + def pathname(self): + return self.group._v_pathname + + @property + def _handle(self): + return self.parent._handle + + @property + def _filters(self): + return self.parent._filters + + @property + def _complevel(self) -> int: + return self.parent._complevel + + @property + def _fletcher32(self) -> bool: + return self.parent._fletcher32 + + @property + def attrs(self): + return self.group._v_attrs + + def set_attrs(self) -> None: + """set our object attributes""" + + def get_attrs(self) -> None: + """get our object attributes""" + + @property + def storable(self): + """return my storable""" + return self.group + + @property + def is_exists(self) -> bool: + return False + + @property + def nrows(self): + return getattr(self.storable, "nrows", None) + + def validate(self, other) -> Literal[True] | None: + """validate against an existing storable""" + if other is None: + return None + return True + + def validate_version(self, where=None) -> None: + """are we trying to operate on an old version?""" + + def infer_axes(self) -> bool: + """ + infer the axes of my storer + return a boolean indicating if we have a valid storer or not + """ + s = self.storable + if s is None: + return False + self.get_attrs() + return True + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ) -> Series | DataFrame: + raise NotImplementedError( + "cannot read on an abstract storer: subclasses should implement" + ) + + def write(self, obj, **kwargs) -> None: + raise NotImplementedError( + "cannot write on an abstract storer: subclasses should implement" + ) + + def delete( + self, where=None, start: int | None = None, stop: int | None = None + ) -> int | None: + """ + support fully deleting the node in its entirety (only) - where + specification must be None + """ + if com.all_none(where, start, stop): + self._handle.remove_node(self.group, recursive=True) + return None + + raise TypeError("cannot delete on an abstract storer") + + +class GenericFixed(Fixed): + """a generified fixed version""" + + _index_type_map = {DatetimeIndex: "datetime", PeriodIndex: "period"} + _reverse_index_map = {v: k for k, v in _index_type_map.items()} + attributes: list[str] = [] + + # indexer helpers + def _class_to_alias(self, cls) -> str: + return self._index_type_map.get(cls, "") + + def _alias_to_class(self, alias): + if isinstance(alias, type): # pragma: no cover + # compat: for a short period of time master stored types + return alias + return self._reverse_index_map.get(alias, Index) + + def _get_index_factory(self, attrs): + index_class = self._alias_to_class(getattr(attrs, "index_class", "")) + + factory: Callable + + kwargs = {} + if index_class == DatetimeIndex: + + def f(values, freq=None, tz=None): + # data are already in UTC, localize and convert if tz present + dta = DatetimeArray._simple_new( + values.values, dtype=values.dtype, freq=freq + ) + result = DatetimeIndex._simple_new(dta, name=None) + if tz is not None: + result = result.tz_localize("UTC").tz_convert(tz) + return result + + factory = f + elif index_class == PeriodIndex: + + def f(values, freq=None, tz=None): + dtype = PeriodDtype(freq) + parr = PeriodArray._simple_new(values, dtype=dtype) + return PeriodIndex._simple_new(parr, name=None) + + factory = f + else: + factory = index_class + kwargs["copy"] = False + + if "freq" in attrs: + kwargs["freq"] = attrs["freq"] + if index_class is Index: + # DTI/PI would be gotten by _alias_to_class + factory = TimedeltaIndex + + if "tz" in attrs: + kwargs["tz"] = attrs["tz"] + assert index_class is DatetimeIndex # just checking + + return factory, kwargs + + def validate_read(self, columns, where) -> None: + """ + raise if any keywords are passed which are not-None + """ + if columns is not None: + raise TypeError( + "cannot pass a column specification when reading " + "a Fixed format store. this store must be selected in its entirety" + ) + if where is not None: + raise TypeError( + "cannot pass a where specification when reading " + "from a Fixed format store. this store must be selected in its entirety" + ) + + @property + def is_exists(self) -> bool: + return True + + def set_attrs(self) -> None: + """set our object attributes""" + self.attrs.encoding = self.encoding + self.attrs.errors = self.errors + + def get_attrs(self) -> None: + """retrieve our attributes""" + self.encoding = _ensure_encoding(getattr(self.attrs, "encoding", None)) + self.errors = getattr(self.attrs, "errors", "strict") + for n in self.attributes: + setattr(self, n, getattr(self.attrs, n, None)) + + def write(self, obj, **kwargs) -> None: + self.set_attrs() + + def read_array(self, key: str, start: int | None = None, stop: int | None = None): + """read an array for the specified node (off of group""" + import tables + + node = getattr(self.group, key) + attrs = node._v_attrs + + transposed = getattr(attrs, "transposed", False) + + if isinstance(node, tables.VLArray): + ret = node[0][start:stop] + dtype = getattr(attrs, "value_type", None) + if dtype is not None: + ret = pd_array(ret, dtype=dtype) + else: + dtype = getattr(attrs, "value_type", None) + shape = getattr(attrs, "shape", None) + + if shape is not None: + # length 0 axis + ret = np.empty(shape, dtype=dtype) + else: + ret = node[start:stop] + + if dtype and dtype.startswith("datetime64"): + # reconstruct a timezone if indicated + if dtype == "datetime64": + dtype = "datetime64[ns]" + tz = getattr(attrs, "tz", None) + ret = _set_tz(ret, tz, dtype) + + elif dtype and dtype.startswith("timedelta64"): + if dtype == "timedelta64": + # This was written back before we started writing + # timedelta64 units + ret = np.asarray(ret, dtype="m8[ns]") + else: + ret = np.asarray(ret, dtype=dtype) + + if transposed: + return ret.T + else: + return ret + + def read_index( + self, key: str, start: int | None = None, stop: int | None = None + ) -> Index: + variety = getattr(self.attrs, f"{key}_variety") + + if variety == "multi": + return self.read_multi_index(key, start=start, stop=stop) + elif variety == "regular": + node = getattr(self.group, key) + index = self.read_index_node(node, start=start, stop=stop) + return index + else: # pragma: no cover + raise TypeError(f"unrecognized index variety: {variety}") + + def write_index(self, key: str, index: Index) -> None: + if isinstance(index, MultiIndex): + setattr(self.attrs, f"{key}_variety", "multi") + self.write_multi_index(key, index) + else: + setattr(self.attrs, f"{key}_variety", "regular") + converted = _convert_index("index", index, self.encoding, self.errors) + + self.write_array(key, converted.values) + + node = getattr(self.group, key) + node._v_attrs.kind = converted.kind + node._v_attrs.name = index.name + + if isinstance(index, (DatetimeIndex, PeriodIndex)): + node._v_attrs.index_class = self._class_to_alias(type(index)) + + if isinstance(index, (DatetimeIndex, PeriodIndex, TimedeltaIndex)): + node._v_attrs.freq = index.freq + + if isinstance(index, DatetimeIndex) and index.tz is not None: + node._v_attrs.tz = _get_tz(index.tz) + + def write_multi_index(self, key: str, index: MultiIndex) -> None: + setattr(self.attrs, f"{key}_nlevels", index.nlevels) + + for i, (lev, level_codes, name) in enumerate( + zip(index.levels, index.codes, index.names, strict=True) + ): + # write the level + if isinstance(lev.dtype, ExtensionDtype) and not isinstance( + lev.dtype, StringDtype + ): + raise NotImplementedError( + "Saving a MultiIndex with an extension dtype is not supported." + ) + level_key = f"{key}_level{i}" + conv_level = _convert_index(level_key, lev, self.encoding, self.errors) + self.write_array(level_key, conv_level.values) + node = getattr(self.group, level_key) + node._v_attrs.kind = conv_level.kind + node._v_attrs.name = name + + # write the name + setattr(node._v_attrs, f"{key}_name{name}", name) + + # write the labels + label_key = f"{key}_label{i}" + self.write_array(label_key, level_codes) + + def read_multi_index( + self, key: str, start: int | None = None, stop: int | None = None + ) -> MultiIndex: + nlevels = getattr(self.attrs, f"{key}_nlevels") + + levels = [] + codes = [] + names: list[Hashable] = [] + for i in range(nlevels): + level_key = f"{key}_level{i}" + node = getattr(self.group, level_key) + lev = self.read_index_node(node, start=start, stop=stop) + levels.append(lev) + names.append(lev.name) + + label_key = f"{key}_label{i}" + level_codes = self.read_array(label_key, start=start, stop=stop) + codes.append(level_codes) + + return MultiIndex( + levels=levels, codes=codes, names=names, verify_integrity=True + ) + + def read_index_node( + self, node: Node, start: int | None = None, stop: int | None = None + ) -> Index: + data = node[start:stop] + # If the index was an empty array write_array_empty() will + # have written a sentinel. Here we replace it with the original. + if "shape" in node._v_attrs and np.prod(node._v_attrs.shape) == 0: + data = np.empty(node._v_attrs.shape, dtype=node._v_attrs.value_type) + kind = node._v_attrs.kind + name = None + + if "name" in node._v_attrs: + name = _ensure_str(node._v_attrs.name) + + attrs = node._v_attrs + factory, kwargs = self._get_index_factory(attrs) + + if kind in ("date", "object"): + index = factory( + _unconvert_index( + data, kind, encoding=self.encoding, errors=self.errors + ), + dtype=object, + **kwargs, + ) + else: + try: + index = factory( + _unconvert_index( + data, kind, encoding=self.encoding, errors=self.errors + ), + **kwargs, + ) + except UnicodeEncodeError as err: + if ( + self.errors == "surrogatepass" + and using_string_dtype() + and str(err).endswith("surrogates not allowed") + and HAS_PYARROW + ): + index = factory( + _unconvert_index( + data, kind, encoding=self.encoding, errors=self.errors + ), + dtype=StringDtype(storage="python", na_value=np.nan), + **kwargs, + ) + else: + raise + + index.name = name + + return index + + def write_array_empty(self, key: str, value: ArrayLike) -> None: + """write a 0-len array""" + # ugly hack for length 0 axes + arr = np.empty((1,) * value.ndim) + self._handle.create_array(self.group, key, arr) + node = getattr(self.group, key) + node._v_attrs.value_type = str(value.dtype) + node._v_attrs.shape = value.shape + + def write_array( + self, key: str, obj: AnyArrayLike, items: Index | None = None + ) -> None: + # TODO: we only have a few tests that get here, the only EA + # that gets passed is DatetimeArray, and we never have + # both self._filters and EA + + value = extract_array(obj, extract_numpy=True) + + if key in self.group: + self._handle.remove_node(self.group, key) + + # Transform needed to interface with pytables row/col notation + empty_array = value.size == 0 + transposed = False + + if isinstance(value.dtype, CategoricalDtype): + raise NotImplementedError( + "Cannot store a category dtype in an HDF5 dataset that uses format=" + '"fixed". Use format="table".' + ) + if not empty_array: + if hasattr(value, "T"): + # ExtensionArrays (1d) may not have transpose. + value = value.T + transposed = True + + atom = None + if self._filters is not None: + with suppress(ValueError): + # get the atom for this datatype + atom = _tables().Atom.from_dtype(value.dtype) + + if atom is not None: + # We only get here if self._filters is non-None and + # the Atom.from_dtype call succeeded + + # create an empty chunked array and fill it from value + if not empty_array: + ca = self._handle.create_carray( + self.group, key, atom, value.shape, filters=self._filters + ) + ca[:] = value + + else: + self.write_array_empty(key, value) + + elif value.dtype.type == np.object_: + # infer the type, warn if we have a non-string type here (for + # performance) + inferred_type = lib.infer_dtype(value, skipna=False) + if empty_array: + pass + elif inferred_type == "string": + pass + elif get_option("performance_warnings"): + ws = performance_doc % (inferred_type, key, items) + warnings.warn(ws, PerformanceWarning, stacklevel=find_stack_level()) + + vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) + vlarr.append(value) + + elif lib.is_np_dtype(value.dtype, "M"): + self._handle.create_array(self.group, key, value.view("i8")) + getattr(self.group, key)._v_attrs.value_type = str(value.dtype) + elif isinstance(value.dtype, DatetimeTZDtype): + # store as UTC + # with a zone + + # error: "ExtensionArray" has no attribute "asi8" + self._handle.create_array( + self.group, + key, + value.asi8, # type: ignore[attr-defined] + ) + + node = getattr(self.group, key) + # error: "ExtensionArray" has no attribute "tz" + node._v_attrs.tz = _get_tz(value.tz) # type: ignore[attr-defined] + node._v_attrs.value_type = f"datetime64[{value.dtype.unit}]" + elif lib.is_np_dtype(value.dtype, "m"): + self._handle.create_array(self.group, key, value.view("i8")) + getattr(self.group, key)._v_attrs.value_type = str(value.dtype) + elif isinstance(value, BaseStringArray): + vlarr = self._handle.create_vlarray(self.group, key, _tables().ObjectAtom()) + vlarr.append(value.to_numpy()) + node = getattr(self.group, key) + node._v_attrs.value_type = str(value.dtype) + elif empty_array: + self.write_array_empty(key, value) + else: + self._handle.create_array(self.group, key, value) + + getattr(self.group, key)._v_attrs.transposed = transposed + + +class SeriesFixed(GenericFixed): + pandas_kind = "series" + attributes = ["name"] + + name: Hashable + + @property + def shape(self) -> tuple[int] | None: + try: + return (len(self.group.values),) + except (TypeError, AttributeError): + return None + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ) -> Series: + self.validate_read(columns, where) + index = self.read_index("index", start=start, stop=stop) + values = self.read_array("values", start=start, stop=stop) + try: + result = Series(values, index=index, name=self.name, copy=False) + except UnicodeEncodeError as err: + if ( + self.errors == "surrogatepass" + and using_string_dtype() + and str(err).endswith("surrogates not allowed") + and HAS_PYARROW + ): + result = Series( + values, + index=index, + name=self.name, + copy=False, + dtype=StringDtype(storage="python", na_value=np.nan), + ) + else: + raise + return result + + def write(self, obj, **kwargs) -> None: + super().write(obj, **kwargs) + self.write_index("index", obj.index) + self.write_array("values", obj) + self.attrs.name = obj.name + + +class BlockManagerFixed(GenericFixed): + attributes = ["ndim", "nblocks"] + + nblocks: int + + @property + def shape(self) -> list[int] | None: + try: + ndim = self.ndim + + # items + items = 0 + for i in range(self.nblocks): + node = getattr(self.group, f"block{i}_items") + shape = getattr(node, "shape", None) + if shape is not None: + items += shape[0] + + # data shape + node = self.group.block0_values + shape = getattr(node, "shape", None) + if shape is not None: + shape = list(shape[0 : (ndim - 1)]) + else: + shape = [] + + shape.append(items) + + return shape + except AttributeError: + return None + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ) -> DataFrame: + # start, stop applied to rows, so 0th axis only + self.validate_read(columns, where) + select_axis = self.obj_type()._get_block_manager_axis(0) + + axes = [] + for i in range(self.ndim): + _start, _stop = (start, stop) if i == select_axis else (None, None) + ax = self.read_index(f"axis{i}", start=_start, stop=_stop) + axes.append(ax) + + items = axes[0] + dfs = [] + + for i in range(self.nblocks): + blk_items = self.read_index(f"block{i}_items") + values = self.read_array(f"block{i}_values", start=_start, stop=_stop) + + columns = items[items.get_indexer(blk_items)] + df = DataFrame(values.T, columns=columns, index=axes[1], copy=False) + if ( + using_string_dtype() + and isinstance(values, np.ndarray) + and is_string_array(values, skipna=True) + ): + df = df.astype(StringDtype(na_value=np.nan)) + dfs.append(df) + + if len(dfs) > 0: + out = concat(dfs, axis=1).copy() + return out.reindex(columns=items) + + return DataFrame(columns=axes[0], index=axes[1]) + + def write(self, obj, **kwargs) -> None: + super().write(obj, **kwargs) + + data = obj._mgr + if not data.is_consolidated(): + data = data.consolidate() + + self.attrs.ndim = data.ndim + for i, ax in enumerate(data.axes): + if i == 0 and (not ax.is_unique): + raise ValueError("Columns index has to be unique for fixed format") + self.write_index(f"axis{i}", ax) + + # Supporting mixed-type DataFrame objects...nontrivial + self.attrs.nblocks = len(data.blocks) + for i, blk in enumerate(data.blocks): + # I have no idea why, but writing values before items fixed #2299 + blk_items = data.items.take(blk.mgr_locs) + self.write_array(f"block{i}_values", blk.values, items=blk_items) + self.write_index(f"block{i}_items", blk_items) + + +class FrameFixed(BlockManagerFixed): + pandas_kind = "frame" + obj_type = DataFrame + + +class Table(Fixed): + """ + represent a table: + facilitate read/write of various types of tables + + Attrs in Table Node + ------------------- + These are attributes that are store in the main table node, they are + necessary to recreate these tables when read back in. + + index_axes : a list of tuples of the (original indexing axis and + index column) + non_index_axes: a list of tuples of the (original index axis and + columns on a non-indexing axis) + values_axes : a list of the columns which comprise the data of this + table + data_columns : a list of the columns that we are allowing indexing + (these become single columns in values_axes) + nan_rep : the string to use for nan representations for string + objects + levels : the names of levels + metadata : the names of the metadata columns + """ + + pandas_kind = "wide_table" + format_type: str = "table" # GH#30962 needed by dask + table_type: str + levels: int | list[Hashable] = 1 + is_table = True + + metadata: list + + def __init__( + self, + parent: HDFStore, + group: Node, + encoding: str | None = None, + errors: str = "strict", + index_axes: list[IndexCol] | None = None, + non_index_axes: list[tuple[AxisInt, Any]] | None = None, + values_axes: list[DataCol] | None = None, + data_columns: list | None = None, + info: dict | None = None, + nan_rep=None, + ) -> None: + super().__init__(parent, group, encoding=encoding, errors=errors) + self.index_axes = index_axes or [] + self.non_index_axes = non_index_axes or [] + self.values_axes = values_axes or [] + self.data_columns = data_columns or [] + self.info = info or {} + self.nan_rep = nan_rep + + @property + def table_type_short(self) -> str: + return self.table_type.split("_")[0] + + def __repr__(self) -> str: + """return a pretty representation of myself""" + self.infer_axes() + jdc = ",".join(self.data_columns) if len(self.data_columns) else "" + dc = f",dc->[{jdc}]" + + ver = "" + if self.is_old_version: + jver = ".".join([str(x) for x in self.version]) + ver = f"[{jver}]" + + jindex_axes = ",".join([a.name for a in self.index_axes]) + return ( + f"{self.pandas_type:12.12}{ver} " + f"(typ->{self.table_type_short},nrows->{self.nrows}," + f"ncols->{self.ncols},indexers->[{jindex_axes}]{dc})" + ) + + def __getitem__(self, c: str): + """return the axis for c""" + for a in self.axes: + if c == a.name: + return a + return None + + def validate(self, other) -> None: + """validate against an existing table""" + if other is None: + return + + if other.table_type != self.table_type: + raise TypeError( + "incompatible table_type with existing " + f"[{other.table_type} - {self.table_type}]" + ) + + for c in ["index_axes", "non_index_axes", "values_axes"]: + sv = getattr(self, c, None) + ov = getattr(other, c, None) + if sv != ov: + # show the error for the specific axes + # Argument 1 to "enumerate" has incompatible type + # "Optional[Any]"; expected "Iterable[Any]" [arg-type] + for i, sax in enumerate(sv): # type: ignore[arg-type] + # Value of type "Optional[Any]" is not indexable [index] + oax = ov[i] # type: ignore[index] + if sax != oax: + if c == "values_axes" and sax.kind != oax.kind: + raise ValueError( + f"Cannot serialize the column [{oax.values[0]}] " + f"because its data contents are not [{sax.kind}] " + f"but [{oax.kind}] object dtype" + ) + raise ValueError( + f"invalid combination of [{c}] on appending data " + f"[{sax}] vs current table [{oax}]" + ) + + # should never get here + raise Exception( + f"invalid combination of [{c}] on appending data [{sv}] vs " + f"current table [{ov}]" + ) + + @property + def is_multi_index(self) -> bool: + """the levels attribute is 1 or a list in the case of a multi-index""" + return isinstance(self.levels, list) + + def validate_multiindex( + self, obj: DataFrame | Series + ) -> tuple[DataFrame, list[Hashable]]: + """ + validate that we can store the multi-index; reset and return the + new object + """ + levels = com.fill_missing_names(obj.index.names) + try: + reset_obj = obj.reset_index() + except ValueError as err: + raise ValueError( + "duplicate names/columns in the multi-index when storing as a table" + ) from err + assert isinstance(reset_obj, DataFrame) # for mypy + return reset_obj, levels + + @property + def nrows_expected(self) -> int: + """based on our axes, compute the expected nrows""" + return np.prod([i.cvalues.shape[0] for i in self.index_axes]) + + @property + def is_exists(self) -> bool: + """has this table been created""" + return "table" in self.group + + @property + def storable(self): + return getattr(self.group, "table", None) + + @property + def table(self): + """return the table group (this is my storable)""" + return self.storable + + @property + def dtype(self): + return self.table.dtype + + @property + def description(self): + return self.table.description + + @property + def axes(self) -> itertools.chain[IndexCol]: + return itertools.chain(self.index_axes, self.values_axes) + + @property + def ncols(self) -> int: + """the number of total columns in the values axes""" + return sum(len(a.values) for a in self.values_axes) + + @property + def is_transposed(self) -> bool: + return False + + @property + def data_orientation(self) -> tuple[int, ...]: + """return a tuple of my permutated axes, non_indexable at the front""" + return tuple( + itertools.chain( + [int(a[0]) for a in self.non_index_axes], + [int(a.axis) for a in self.index_axes], + ) + ) + + def queryables(self) -> dict[str, Any]: + """return a dict of the kinds allowable columns for this object""" + # mypy doesn't recognize DataFrame._AXIS_NAMES, so we re-write it here + axis_names = {0: "index", 1: "columns"} + + # compute the values_axes queryables + d1 = [(a.cname, a) for a in self.index_axes] + d2 = [(axis_names[axis], None) for axis, values in self.non_index_axes] + d3 = [ + (v.cname, v) for v in self.values_axes if v.name in set(self.data_columns) + ] + + return dict(d1 + d2 + d3) + + def index_cols(self) -> list[tuple[Any, Any]]: + """return a list of my index cols""" + # Note: each `i.cname` below is assured to be a str. + return [(i.axis, i.cname) for i in self.index_axes] + + def values_cols(self) -> list[str]: + """return a list of my values cols""" + return [i.cname for i in self.values_axes] + + def _get_metadata_path(self, key: str) -> str: + """return the metadata pathname for this key""" + group = self.group._v_pathname + return f"{group}/meta/{key}/meta" + + def write_metadata(self, key: str, values: np.ndarray) -> None: + """ + Write out a metadata array to the key as a fixed-format Series. + + Parameters + ---------- + key : str + values : ndarray + """ + self.parent.put( + self._get_metadata_path(key), + Series(values, copy=False), + format="table", + encoding=self.encoding, + errors=self.errors, + nan_rep=self.nan_rep, + ) + + def read_metadata(self, key: str): + """return the meta data array for this key""" + if getattr(getattr(self.group, "meta", None), key, None) is not None: + return self.parent.select(self._get_metadata_path(key)) + return None + + def set_attrs(self) -> None: + """set our table type & indexables""" + self.attrs.table_type = str(self.table_type) + self.attrs.index_cols = self.index_cols() + self.attrs.values_cols = self.values_cols() + self.attrs.non_index_axes = self.non_index_axes + self.attrs.data_columns = self.data_columns + self.attrs.nan_rep = self.nan_rep + self.attrs.encoding = self.encoding + self.attrs.errors = self.errors + self.attrs.levels = self.levels + self.attrs.info = self.info + + def get_attrs(self) -> None: + """retrieve our attributes""" + self.non_index_axes = getattr(self.attrs, "non_index_axes", None) or [] + self.data_columns = getattr(self.attrs, "data_columns", None) or [] + self.info = getattr(self.attrs, "info", None) or {} + self.nan_rep = getattr(self.attrs, "nan_rep", None) + self.encoding = _ensure_encoding(getattr(self.attrs, "encoding", None)) + self.errors = getattr(self.attrs, "errors", "strict") + self.levels: list[Hashable] = getattr(self.attrs, "levels", None) or [] + self.index_axes = [a for a in self.indexables if a.is_an_indexable] + self.values_axes = [a for a in self.indexables if not a.is_an_indexable] + + def validate_version(self, where=None) -> None: + """are we trying to operate on an old version?""" + if where is not None: + if self.is_old_version: + ws = incompatibility_doc % ".".join([str(x) for x in self.version]) + warnings.warn( + ws, + IncompatibilityWarning, + stacklevel=find_stack_level(), + ) + + def validate_min_itemsize(self, min_itemsize) -> None: + """ + validate the min_itemsize doesn't contain items that are not in the + axes this needs data_columns to be defined + """ + if min_itemsize is None: + return + if not isinstance(min_itemsize, dict): + return + + q = self.queryables() + for k in min_itemsize: + # ok, apply generally + if k == "values": + continue + if k not in q: + raise ValueError( + f"min_itemsize has the key [{k}] which is not an axis or " + "data_column" + ) + + @cache_readonly + def indexables(self): + """create/cache the indexables if they don't exist""" + _indexables = [] + + desc = self.description + table_attrs = self.table.attrs + + # Note: each of the `name` kwargs below are str, ensured + # by the definition in index_cols. + # index columns + for i, (axis, name) in enumerate(self.attrs.index_cols): + atom = getattr(desc, name) + md = self.read_metadata(name) + meta = "category" if md is not None else None + + kind_attr = f"{name}_kind" + kind = getattr(table_attrs, kind_attr, None) + + index_col = IndexCol( + name=name, + axis=axis, + pos=i, + kind=kind, + typ=atom, + table=self.table, + meta=meta, + metadata=md, + ) + _indexables.append(index_col) + + # values columns + dc = set(self.data_columns) + base_pos = len(_indexables) + + def f(i, c: str) -> DataCol: + assert isinstance(c, str) + klass = DataCol + if c in dc: + klass = DataIndexableCol + + atom = getattr(desc, c) + adj_name = _maybe_adjust_name(c, self.version) + + # TODO: why kind_attr here? + values = getattr(table_attrs, f"{adj_name}_kind", None) + dtype = getattr(table_attrs, f"{adj_name}_dtype", None) + # Argument 1 to "_dtype_to_kind" has incompatible type + # "Optional[Any]"; expected "str" [arg-type] + kind = _dtype_to_kind(dtype) # type: ignore[arg-type] + + md = self.read_metadata(c) + # TODO: figure out why these two versions of `meta` dont always match. + # meta = "category" if md is not None else None + meta = getattr(table_attrs, f"{adj_name}_meta", None) + + obj = klass( + name=adj_name, + cname=c, + values=values, + kind=kind, + pos=base_pos + i, + typ=atom, + table=self.table, + meta=meta, + metadata=md, + dtype=dtype, + ) + return obj + + # Note: the definition of `values_cols` ensures that each + # `c` below is a str. + _indexables.extend([f(i, c) for i, c in enumerate(self.attrs.values_cols)]) + + return _indexables + + def create_index( + self, columns=None, optlevel=None, kind: str | None = None + ) -> None: + """ + Create a pytables index on the specified columns. + + Parameters + ---------- + columns : None, bool, or listlike[str] + Indicate which columns to create an index on. + + * False : Do not create any indexes. + * True : Create indexes on all columns. + * None : Create indexes on all columns. + * listlike : Create indexes on the given columns. + + optlevel : int or None, default None + Optimization level, if None, pytables defaults to 6. + kind : str or None, default None + Kind of index, if None, pytables defaults to "medium". + + Raises + ------ + TypeError if trying to create an index on a complex-type column. + + Notes + ----- + Cannot index Time64Col or ComplexCol. + Pytables must be >= 3.0. + """ + if not self.infer_axes(): + return + if columns is False: + return + + # index all indexables and data_columns + if columns is None or columns is True: + columns = [a.cname for a in self.axes if a.is_data_indexable] + if not isinstance(columns, (tuple, list)): + columns = [columns] + + kw = {} + if optlevel is not None: + kw["optlevel"] = optlevel + if kind is not None: + kw["kind"] = kind + + table = self.table + for c in columns: + v = getattr(table.cols, c, None) + if v is not None: + # remove the index if the kind/optlevel have changed + if v.is_indexed: + index = v.index + cur_optlevel = index.optlevel + cur_kind = index.kind + + if kind is not None and cur_kind != kind: + v.remove_index() + else: + kw["kind"] = cur_kind + + if optlevel is not None and cur_optlevel != optlevel: + v.remove_index() + else: + kw["optlevel"] = cur_optlevel + + # create the index + if not v.is_indexed: + if v.type.startswith("complex"): + raise TypeError( + "Columns containing complex values can be stored but " + "cannot be indexed when using table format. Either use " + "fixed format, set index=False, or do not include " + "the columns containing complex values to " + "data_columns when initializing the table." + ) + v.create_index(**kw) + elif c in self.non_index_axes[0][1]: + # GH 28156 + raise AttributeError( + f"column {c} is not a data_column.\n" + f"In order to read column {c} you must reload the dataframe \n" + f"into HDFStore and include {c} with the data_columns argument." + ) + + def _read_axes( + self, where, start: int | None = None, stop: int | None = None + ) -> list[tuple[np.ndarray, np.ndarray] | tuple[Index, Index]]: + """ + Create the axes sniffed from the table. + + Parameters + ---------- + where : ??? + start : int or None, default None + stop : int or None, default None + + Returns + ------- + List[Tuple[index_values, column_values]] + """ + # create the selection + selection = Selection(self, where=where, start=start, stop=stop) + values = selection.select() + + results = [] + # convert the data + for a in self.axes: + a.set_info(self.info) + res = a.convert( + values, + nan_rep=self.nan_rep, + encoding=self.encoding, + errors=self.errors, + ) + results.append(res) + + return results + + @classmethod + def get_object(cls, obj, transposed: bool): + """return the data for this obj""" + return obj + + def validate_data_columns(self, data_columns, min_itemsize, non_index_axes) -> list: + """ + take the input data_columns and min_itemize and create a data + columns spec + """ + if not len(non_index_axes): + return [] + + axis, axis_labels = non_index_axes[0] + info = self.info.get(axis, {}) + if info.get("type") == "MultiIndex" and data_columns: + raise ValueError( + f"cannot use a multi-index on axis [{axis}] with " + f"data_columns {data_columns}" + ) + + # evaluate the passed data_columns, True == use all columns + # take only valid axis labels + if data_columns is True: + data_columns = list(axis_labels) + elif data_columns is None: + data_columns = [] + + # if min_itemsize is a dict, add the keys (exclude 'values') + if isinstance(min_itemsize, dict): + existing_data_columns = set(data_columns) + data_columns = list(data_columns) # ensure we do not modify + data_columns.extend( + [ + k + for k in min_itemsize.keys() + if k != "values" and k not in existing_data_columns + ] + ) + + # return valid columns in the order of our axis + return [c for c in data_columns if c in axis_labels] + + def _create_axes( + self, + axes, + obj: DataFrame, + validate: bool = True, + nan_rep=None, + data_columns=None, + min_itemsize=None, + ): + """ + Create and return the axes. + + Parameters + ---------- + axes: list or None + The names or numbers of the axes to create. + obj : DataFrame + The object to create axes on. + validate: bool, default True + Whether to validate the obj against an existing object already written. + nan_rep : + A value to use for string column nan_rep. + data_columns : List[str], True, or None, default None + Specify the columns that we want to create to allow indexing on. + + * True : Use all available columns. + * None : Use no columns. + * List[str] : Use the specified columns. + + min_itemsize: Dict[str, int] or None, default None + The min itemsize for a column in bytes. + """ + if not isinstance(obj, DataFrame): + group = self.group._v_name + raise TypeError( + f"cannot properly create the storer for: [group->{group}," + f"value->{type(obj)}]" + ) + + # set the default axes if needed + if axes is None: + axes = [0] + + # map axes to numbers + axes = [obj._get_axis_number(a) for a in axes] + + # do we have an existing table (if so, use its axes & data_columns) + if self.infer_axes(): + table_exists = True + axes = [a.axis for a in self.index_axes] + data_columns = list(self.data_columns) + nan_rep = self.nan_rep + # TODO: do we always have validate=True here? + else: + table_exists = False + + new_info = self.info + + assert self.ndim == 2 # with next check, we must have len(axes) == 1 + # currently support on ndim-1 axes + if len(axes) != self.ndim - 1: + raise ValueError( + "currently only support ndim-1 indexers in an AppendableTable" + ) + + # create according to the new data + new_non_index_axes: list = [] + + # nan_representation + if nan_rep is None: + nan_rep = "nan" + + # We construct the non-index-axis first, since that alters new_info + idx = next(x for x in [0, 1] if x not in axes) + + a = obj.axes[idx] + # we might be able to change the axes on the appending data if necessary + append_axis = list(a) + if table_exists: + indexer = len(new_non_index_axes) # i.e. 0 + exist_axis = self.non_index_axes[indexer][1] + if not array_equivalent( + np.array(append_axis), + np.array(exist_axis), + strict_nan=True, + dtype_equal=True, + ): + # ahah! -> reindex + if array_equivalent( + np.array(sorted(append_axis)), + np.array(sorted(exist_axis)), + strict_nan=True, + dtype_equal=True, + ): + append_axis = exist_axis + + # the non_index_axes info + info = new_info.setdefault(idx, {}) + info["names"] = list(a.names) + info["type"] = type(a).__name__ + + new_non_index_axes.append((idx, append_axis)) + + # Now we can construct our new index axis + idx = axes[0] + a = obj.axes[idx] + axis_name = obj._get_axis_name(idx) + new_index = _convert_index(axis_name, a, self.encoding, self.errors) + new_index.axis = idx + + # Because we are always 2D, there is only one new_index, so + # we know it will have pos=0 + new_index.set_pos(0) + new_index.update_info(new_info) + new_index.maybe_set_size(min_itemsize) # check for column conflicts + + new_index_axes = [new_index] + j = len(new_index_axes) # i.e. 1 + assert j == 1 + + # reindex by our non_index_axes & compute data_columns + assert len(new_non_index_axes) == 1 + for a in new_non_index_axes: + obj = _reindex_axis(obj, a[0], a[1]) + + transposed = new_index.axis == 1 + + # figure out data_columns and get out blocks + data_columns = self.validate_data_columns( + data_columns, min_itemsize, new_non_index_axes + ) + + frame = self.get_object(obj, transposed)._consolidate() + + blocks, blk_items = self._get_blocks_and_items( + frame, table_exists, new_non_index_axes, self.values_axes, data_columns + ) + + # add my values + vaxes = [] + for i, (blk, b_items) in enumerate(zip(blocks, blk_items, strict=True)): + # shape of the data column are the indexable axes + klass = DataCol + name = None + + # we have a data_column + if data_columns and len(b_items) == 1 and b_items[0] in data_columns: + klass = DataIndexableCol + name = b_items[0] + if not (name is None or isinstance(name, str)): + # TODO: should the message here be more specifically non-str? + raise ValueError("cannot have non-object label DataIndexableCol") + + # make sure that we match up the existing columns + # if we have an existing table + existing_col: DataCol | None + + if table_exists and validate: + try: + existing_col = self.values_axes[i] + except (IndexError, KeyError) as err: + raise ValueError( + f"Incompatible appended table [{blocks}]" + f"with existing table [{self.values_axes}]" + ) from err + else: + existing_col = None + + new_name = name or f"values_block_{i}" + data_converted = _maybe_convert_for_string_atom( + new_name, + blk.values, + existing_col=existing_col, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + encoding=self.encoding, + errors=self.errors, + columns=b_items, + ) + adj_name = _maybe_adjust_name(new_name, self.version) + + typ = klass._get_atom(data_converted) + kind = _dtype_to_kind(data_converted.dtype.name) + tz = None + if getattr(data_converted, "tz", None) is not None: + tz = _get_tz(data_converted.tz) + + meta = metadata = ordered = None + if isinstance(data_converted.dtype, CategoricalDtype): + ordered = data_converted.ordered + meta = "category" + metadata = np.asarray(data_converted.categories).ravel() + elif isinstance(blk.dtype, StringDtype): + meta = str(blk.dtype) + + data, dtype_name = _get_data_and_dtype_name(data_converted) + + col = klass( + name=adj_name, + cname=new_name, + values=list(b_items), + typ=typ, + pos=j, + kind=kind, + tz=tz, + ordered=ordered, + meta=meta, + metadata=metadata, + dtype=dtype_name, + data=data, + ) + col.update_info(new_info) + + vaxes.append(col) + + j += 1 + + dcs = [col.name for col in vaxes if col.is_data_indexable] + + new_table = type(self)( + parent=self.parent, + group=self.group, + encoding=self.encoding, + errors=self.errors, + index_axes=new_index_axes, + non_index_axes=new_non_index_axes, + values_axes=vaxes, + data_columns=dcs, + info=new_info, + nan_rep=nan_rep, + ) + if hasattr(self, "levels"): + # TODO: get this into constructor, only for appropriate subclass + new_table.levels = self.levels + + new_table.validate_min_itemsize(min_itemsize) + + if validate and table_exists: + new_table.validate(self) + + return new_table + + @staticmethod + def _get_blocks_and_items( + frame: DataFrame, + table_exists: bool, + new_non_index_axes, + values_axes, + data_columns, + ): + # Helper to clarify non-state-altering parts of _create_axes + def get_blk_items(mgr): + return [mgr.items.take(blk.mgr_locs) for blk in mgr.blocks] + + mgr = frame._mgr + blocks: list[Block] = list(mgr.blocks) + blk_items: list[Index] = get_blk_items(mgr) + + if len(data_columns): + # TODO: prove that we only get here with axis == 1? + # It is the case in all extant tests, but NOT the case + # outside this `if len(data_columns)` check. + + axis, axis_labels = new_non_index_axes[0] + new_labels = Index(axis_labels).difference(Index(data_columns)) + mgr = frame.reindex(new_labels, axis=axis)._mgr + + blocks = list(mgr.blocks) + blk_items = get_blk_items(mgr) + for c in data_columns: + # This reindex would raise ValueError if we had a duplicate + # index, so we can infer that (as long as axis==1) we + # get a single column back, so a single block. + mgr = frame.reindex([c], axis=axis)._mgr + blocks.extend(mgr.blocks) + blk_items.extend(get_blk_items(mgr)) + + # reorder the blocks in the same order as the existing table if we can + if table_exists: + by_items = { + tuple(b_items.tolist()): (b, b_items) + for b, b_items in zip(blocks, blk_items, strict=True) + } + new_blocks: list[Block] = [] + new_blk_items = [] + for ea in values_axes: + items = tuple(ea.values) + try: + b, b_items = by_items.pop(items) + new_blocks.append(b) + new_blk_items.append(b_items) + except (IndexError, KeyError) as err: + jitems = ",".join([pprint_thing(item) for item in items]) + raise ValueError( + f"cannot match existing table structure for [{jitems}] " + "on appending data" + ) from err + blocks = new_blocks + blk_items = new_blk_items + + return blocks, blk_items + + def process_axes(self, obj, selection: Selection, columns=None) -> DataFrame: + """process axes filters""" + # make a copy to avoid side effects + if columns is not None: + columns = list(columns) + + # make sure to include levels if we have them + if columns is not None and self.is_multi_index: + assert isinstance(self.levels, list) # assured by is_multi_index + for n in self.levels: + if n not in columns: + columns.insert(0, n) + + # reorder by any non_index_axes & limit to the select columns + for axis, labels in self.non_index_axes: + obj = _reindex_axis(obj, axis, labels, columns) + + def process_filter(field, filt, op): + for axis_name in obj._AXIS_ORDERS: + axis_number = obj._get_axis_number(axis_name) + axis_values = obj._get_axis(axis_name) + assert axis_number is not None + + # see if the field is the name of an axis + if field == axis_name: + # if we have a multi-index, then need to include + # the levels + if self.is_multi_index: + filt = filt.union(Index(self.levels)) + + takers = op(axis_values, filt) + return obj.loc(axis=axis_number)[takers] + + # this might be the name of a file IN an axis + elif field in axis_values: + # we need to filter on this dimension + values = ensure_index(getattr(obj, field).values) + filt = ensure_index(filt) + + # hack until we support reversed dim flags + if isinstance(obj, DataFrame): + axis_number = 1 - axis_number + + takers = op(values, filt) + return obj.loc(axis=axis_number)[takers] + + raise ValueError(f"cannot find the field [{field}] for filtering!") + + # apply the selection filters (but keep in the same order) + if selection.filter is not None: + for field, op, filt in selection.filter.format(): + obj = process_filter(field, filt, op) + + return obj + + def create_description( + self, + complib, + complevel: int | None, + fletcher32: bool, + expectedrows: int | None, + ) -> dict[str, Any]: + """create the description of the table from the axes & values""" + # provided expected rows if its passed + if expectedrows is None: + expectedrows = max(self.nrows_expected, 10000) + + d = {"name": "table", "expectedrows": expectedrows} + + # description from the axes & values + d["description"] = {a.cname: a.typ for a in self.axes} + + if complib: + if complevel is None: + complevel = self._complevel or 9 + filters = _tables().Filters( + complevel=complevel, + complib=complib, + fletcher32=fletcher32 or self._fletcher32, + ) + d["filters"] = filters + elif self._filters is not None: + d["filters"] = self._filters + + return d + + def read_coordinates( + self, where=None, start: int | None = None, stop: int | None = None + ): + """ + select coordinates (row numbers) from a table; return the + coordinates object + """ + # validate the version + self.validate_version(where) + + # infer the data kind + if not self.infer_axes(): + return False + + # create the selection + selection = Selection(self, where=where, start=start, stop=stop) + coords = selection.select_coords() + if selection.filter is not None: + for field, op, filt in selection.filter.format(): + data = self.read_column( + field, start=coords.min(), stop=coords.max() + 1 + ) + coords = coords[op(data.iloc[coords - coords.min()], filt).values] + + return Index(coords, copy=False) + + def read_column( + self, + column: str, + where=None, + start: int | None = None, + stop: int | None = None, + ): + """ + return a single column from the table, generally only indexables + are interesting + """ + # validate the version + self.validate_version() + + # infer the data kind + if not self.infer_axes(): + return False + + if where is not None: + raise TypeError("read_column does not currently accept a where clause") + + # find the axes + for a in self.axes: + if column == a.name: + if not a.is_data_indexable: + raise ValueError( + f"column [{column}] can not be extracted individually; " + "it is not data indexable" + ) + + # column must be an indexable or a data column + c = getattr(self.table.cols, column) + a.set_info(self.info) + col_values = a.convert( + c[start:stop], + nan_rep=self.nan_rep, + encoding=self.encoding, + errors=self.errors, + ) + cvs = col_values[1] + dtype = getattr(self.table.attrs, f"{column}_meta", None) + return Series(cvs, name=column, copy=False, dtype=dtype) + + raise KeyError(f"column [{column}] not found in the table") + + +class WORMTable(Table): + """ + a write-once read-many table: this format DOES NOT ALLOW appending to a + table. writing is a one-time operation the data are stored in a format + that allows for searching the data on disk + """ + + table_type = "worm" + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ): + """ + read the indices and the indexing array, calculate offset rows and return + """ + raise NotImplementedError("WORMTable needs to implement read") + + def write(self, obj, **kwargs) -> None: + """ + write in a format that we can search later on (but cannot append + to): write out the indices and the values using _write_array + (e.g. a CArray) create an indexing table so that we can search + """ + raise NotImplementedError("WORMTable needs to implement write") + + +class AppendableTable(Table): + """support the new appendable table formats""" + + table_type = "appendable" + + # error: Signature of "write" incompatible with supertype "Fixed" + def write( # type: ignore[override] + self, + obj, + axes=None, + append: bool = False, + complib=None, + complevel=None, + fletcher32=None, + min_itemsize=None, + chunksize: int | None = None, + expectedrows=None, + dropna: bool = False, + nan_rep=None, + data_columns=None, + track_times: bool = True, + ) -> None: + if not append and self.is_exists: + self._handle.remove_node(self.group, "table") + + # create the axes + table = self._create_axes( + axes=axes, + obj=obj, + validate=append, + min_itemsize=min_itemsize, + nan_rep=nan_rep, + data_columns=data_columns, + ) + + for a in table.axes: + a.validate_names() + + if not table.is_exists: + # create the table + options = table.create_description( + complib=complib, + complevel=complevel, + fletcher32=fletcher32, + expectedrows=expectedrows, + ) + + # set the table attributes + table.set_attrs() + + options["track_times"] = track_times + + # create the table + table._handle.create_table(table.group, **options) + + # update my info + table.attrs.info = table.info + + # validate the axes and set the kinds + for a in table.axes: + a.validate_and_set(table, append) + + # add the rows + table.write_data(chunksize, dropna=dropna) + + def write_data(self, chunksize: int | None, dropna: bool = False) -> None: + """ + we form the data into a 2-d including indexes,values,mask write chunk-by-chunk + """ + names = self.dtype.names + nrows = self.nrows_expected + + # if dropna==True, then drop ALL nan rows + masks = [] + if dropna: + for a in self.values_axes: + # figure the mask: only do if we can successfully process this + # column, otherwise ignore the mask + mask = isna(a.data).all(axis=0) + if isinstance(mask, np.ndarray): + masks.append(mask.astype("u1", copy=False)) + + # consolidate masks + if masks: + mask = masks[0] + for m in masks[1:]: + mask = mask & m + mask = mask.ravel() + else: + mask = None + + # broadcast the indexes if needed + indexes = [a.cvalues for a in self.index_axes] + nindexes = len(indexes) + assert nindexes == 1, nindexes # ensures we dont need to broadcast + + # transpose the values so first dimension is last + # reshape the values if needed + values = [a.take_data() for a in self.values_axes] + values = [v.transpose(np.roll(np.arange(v.ndim), v.ndim - 1)) for v in values] + bvalues = [] + for i, v in enumerate(values): + new_shape = (nrows, *self.dtype[names[nindexes + i]].shape) + bvalues.append(v.reshape(new_shape)) + + # write the chunks + if chunksize is None: + chunksize = 100000 + + rows = np.empty(min(chunksize, nrows), dtype=self.dtype) + chunks = nrows // chunksize + 1 + for i in range(chunks): + start_i = i * chunksize + end_i = min((i + 1) * chunksize, nrows) + if start_i >= end_i: + break + + self.write_data_chunk( + rows, + indexes=[a[start_i:end_i] for a in indexes], + mask=mask[start_i:end_i] if mask is not None else None, + values=[v[start_i:end_i] for v in bvalues], + ) + + def write_data_chunk( + self, + rows: np.ndarray, + indexes: list[np.ndarray], + mask: npt.NDArray[np.bool_] | None, + values: list[np.ndarray], + ) -> None: + """ + Parameters + ---------- + rows : an empty memory space where we are putting the chunk + indexes : an array of the indexes + mask : an array of the masks + values : an array of the values + """ + # 0 len + for v in values: + if not np.prod(v.shape): + return + + nrows = indexes[0].shape[0] + if nrows != len(rows): + rows = np.empty(nrows, dtype=self.dtype) + names = self.dtype.names + nindexes = len(indexes) + + # indexes + for i, idx in enumerate(indexes): + rows[names[i]] = idx + + # values + for i, v in enumerate(values): + rows[names[i + nindexes]] = v + + # mask + if mask is not None: + m = ~mask.ravel().astype(bool, copy=False) + if not m.all(): + rows = rows[m] + + if len(rows): + self.table.append(rows) + self.table.flush() + + def delete( + self, where=None, start: int | None = None, stop: int | None = None + ) -> int | None: + # delete all rows (and return the nrows) + if where is None or not len(where): + if start is None and stop is None: + nrows = self.nrows + self._handle.remove_node(self.group, recursive=True) + else: + # pytables<3.0 would remove a single row with stop=None + if stop is None: + stop = self.nrows + nrows = self.table.remove_rows(start=start, stop=stop) + self.table.flush() + return nrows + + # infer the data kind + if not self.infer_axes(): + return None + + # create the selection + table = self.table + selection = Selection(self, where, start=start, stop=stop) + values = selection.select_coords() + + # delete the rows in reverse order + sorted_series = Series(values, copy=False).sort_values() + ln = len(sorted_series) + + if ln: + # construct groups of consecutive rows + diff = sorted_series.diff() + groups = list(diff[diff > 1].index) + + # 1 group + if not groups: + groups = [0] + + # final element + if groups[-1] != ln: + groups.append(ln) + + # initial element + if groups[0] != 0: + groups.insert(0, 0) + + # we must remove in reverse order! + pg = groups.pop() + for g in reversed(groups): + rows = sorted_series.take(range(g, pg)) + table.remove_rows( + start=rows[rows.index[0]], stop=rows[rows.index[-1]] + 1 + ) + pg = g + + self.table.flush() + + # return the number of rows removed + return ln + + +class AppendableFrameTable(AppendableTable): + """support the new appendable table formats""" + + pandas_kind = "frame_table" + table_type = "appendable_frame" + ndim = 2 + obj_type: type[DataFrame | Series] = DataFrame + + @property + def is_transposed(self) -> bool: + return self.index_axes[0].axis == 1 + + @classmethod + def get_object(cls, obj, transposed: bool): + """these are written transposed""" + if transposed: + obj = obj.T + return obj + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ): + # validate the version + self.validate_version(where) + + # infer the data kind + if not self.infer_axes(): + return None + + result = self._read_axes(where=where, start=start, stop=stop) + + info = ( + self.info.get(self.non_index_axes[0][0], {}) + if len(self.non_index_axes) + else {} + ) + + inds = [i for i, ax in enumerate(self.axes) if ax is self.index_axes[0]] + assert len(inds) == 1 + ind = inds[0] + + index = result[ind][0] + + frames = [] + for i, a in enumerate(self.axes): + if a not in self.values_axes: + continue + index_vals, cvalues = result[i] + + # we could have a multi-index constructor here + # ensure_index doesn't recognized our list-of-tuples here + if info.get("type") != "MultiIndex": + cols = Index(index_vals) + else: + cols = MultiIndex.from_tuples(index_vals) + + names = info.get("names") + if names is not None: + cols.set_names(names, inplace=True) + + if self.is_transposed: + values = cvalues + index_ = cols + cols_ = Index(index, name=getattr(index, "name", None)) + else: + values = cvalues.T + index_ = Index(index, name=getattr(index, "name", None)) + cols_ = cols + + # if we have a DataIndexableCol, its shape will only be 1 dim + if values.ndim == 1 and isinstance(values, np.ndarray): + values = values.reshape((1, values.shape[0])) + + if isinstance(values, (np.ndarray, DatetimeArray)): + try: + df = DataFrame(values.T, columns=cols_, index=index_, copy=False) + except UnicodeEncodeError as err: + if ( + self.errors == "surrogatepass" + and using_string_dtype() + and str(err).endswith("surrogates not allowed") + and HAS_PYARROW + ): + df = DataFrame( + values.T, + columns=cols_, + index=index_, + copy=False, + dtype=StringDtype(storage="python", na_value=np.nan), + ) + else: + raise + elif isinstance(values, Index): + df = DataFrame(values, columns=cols_, index=index_) + else: + # Categorical + df = DataFrame._from_arrays([values], columns=cols_, index=index_) + if not (using_string_dtype() and values.dtype.kind == "O"): + assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype) + + # If str / string dtype is stored in meta, use that. + for column in cols_: + dtype = getattr(self.table.attrs, f"{column}_meta", None) + if dtype in ["str", "string"]: + df[column] = df[column].astype(dtype) + frames.append(df) + + if len(frames) == 1: + df = frames[0] + else: + df = concat(frames, axis=1) + + selection = Selection(self, where=where, start=start, stop=stop) + # apply the selection filters & axis orderings + df = self.process_axes(df, selection=selection, columns=columns) + return df + + +class AppendableSeriesTable(AppendableFrameTable): + """support the new appendable table formats""" + + pandas_kind = "series_table" + table_type = "appendable_series" + ndim = 2 + obj_type = Series + + @property + def is_transposed(self) -> bool: + return False + + @classmethod + def get_object(cls, obj, transposed: bool): + return obj + + # error: Signature of "write" incompatible with supertype "Fixed" + def write(self, obj, data_columns=None, **kwargs) -> None: # type: ignore[override] + """we are going to write this as a frame table""" + if not isinstance(obj, DataFrame): + name = obj.name or "values" + obj = obj.to_frame(name) + super().write(obj=obj, data_columns=obj.columns.tolist(), **kwargs) + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ) -> Series: + is_multi_index = self.is_multi_index + if columns is not None and is_multi_index: + assert isinstance(self.levels, list) # needed for mypy + for n in self.levels: + if n not in columns: + columns.insert(0, n) + s = super().read(where=where, columns=columns, start=start, stop=stop) + if is_multi_index: + s.set_index(self.levels, inplace=True) + + s = s.iloc[:, 0] + + # remove the default name + if s.name == "values": + s.name = None + return s + + +class AppendableMultiSeriesTable(AppendableSeriesTable): + """support the new appendable table formats""" + + pandas_kind = "series_table" + table_type = "appendable_multiseries" + + # error: Signature of "write" incompatible with supertype "Fixed" + def write(self, obj, **kwargs) -> None: # type: ignore[override] + """we are going to write this as a frame table""" + name = obj.name or "values" + newobj, self.levels = self.validate_multiindex(obj) + assert isinstance(self.levels, list) # for mypy + cols = list(self.levels) + cols.append(name) + newobj.columns = Index(cols) + super().write(obj=newobj, **kwargs) + + +class GenericTable(AppendableFrameTable): + """a table that read/writes the generic pytables table format""" + + pandas_kind = "frame_table" + table_type = "generic_table" + ndim = 2 + obj_type = DataFrame + levels: list[Hashable] + + @property + def pandas_type(self) -> str: + return self.pandas_kind + + @property + def storable(self): + return getattr(self.group, "table", None) or self.group + + def get_attrs(self) -> None: + """retrieve our attributes""" + self.non_index_axes = [] + self.nan_rep = None + self.levels = [] + + self.index_axes = [a for a in self.indexables if a.is_an_indexable] + self.values_axes = [a for a in self.indexables if not a.is_an_indexable] + self.data_columns = [a.name for a in self.values_axes] + + @cache_readonly + def indexables(self): + """create the indexables from the table description""" + d = self.description + + # TODO: can we get a typ for this? AFAICT it is the only place + # where we aren't passing one + # the index columns is just a simple index + md = self.read_metadata("index") + meta = "category" if md is not None else None + index_col = GenericIndexCol( + name="index", axis=0, table=self.table, meta=meta, metadata=md + ) + + _indexables: list[GenericIndexCol | GenericDataIndexableCol] = [index_col] + + for i, n in enumerate(d._v_names): + assert isinstance(n, str) + + atom = getattr(d, n) + md = self.read_metadata(n) + meta = "category" if md is not None else None + dc = GenericDataIndexableCol( + name=n, + pos=i, + values=[n], + typ=atom, + table=self.table, + meta=meta, + metadata=md, + ) + _indexables.append(dc) + + return _indexables + + # error: Signature of "write" incompatible with supertype "AppendableTable" + def write(self, **kwargs) -> None: # type: ignore[override] + raise NotImplementedError("cannot write on a generic table") + + +class AppendableMultiFrameTable(AppendableFrameTable): + """a frame with a multi-index""" + + table_type = "appendable_multiframe" + obj_type = DataFrame + ndim = 2 + _re_levels = re.compile(r"^level_\d+$") + + @property + def table_type_short(self) -> str: + return "appendable_multi" + + # error: Signature of "write" incompatible with supertype "Fixed" + def write(self, obj, data_columns=None, **kwargs) -> None: # type: ignore[override] + if data_columns is None: + data_columns = [] + elif data_columns is True: + data_columns = obj.columns.tolist() + obj, self.levels = self.validate_multiindex(obj) + assert isinstance(self.levels, list) # for mypy + for n in self.levels: + if n not in data_columns: + data_columns.insert(0, n) + super().write(obj=obj, data_columns=data_columns, **kwargs) + + def read( + self, + where=None, + columns=None, + start: int | None = None, + stop: int | None = None, + ) -> DataFrame: + df = super().read(where=where, columns=columns, start=start, stop=stop) + df = df.set_index(self.levels) + + # remove names for 'level_%d' + df.index = df.index.set_names( + [None if self._re_levels.search(name) else name for name in df.index.names] + ) + + return df + + +def _reindex_axis( + obj: DataFrame, axis: AxisInt, labels: Index, other=None +) -> DataFrame: + ax = obj._get_axis(axis) + labels = ensure_index(labels) + + # try not to reindex even if other is provided + # if it equals our current index + if other is not None: + other = ensure_index(other) + if (other is None or labels.equals(other)) and labels.equals(ax): + return obj + + labels = ensure_index(labels.unique()) + if other is not None: + labels = ensure_index(other.unique()).intersection(labels, sort=False) + if not labels.equals(ax): + slicer: list[slice | Index] = [slice(None, None)] * obj.ndim + slicer[axis] = labels + obj = obj.loc[tuple(slicer)] + return obj + + +# tz to/from coercion + + +def _get_tz(tz: tzinfo) -> str | tzinfo: + """for a tz-aware type, return an encoded zone""" + zone = timezones.get_timezone(tz) + return zone + + +def _set_tz( + values: npt.NDArray[np.int64], tz: str | tzinfo | None, datetime64_dtype: str +) -> DatetimeArray: + """ + Coerce the values to a DatetimeArray with appropriate tz. + + Parameters + ---------- + values : ndarray[int64] + tz : str, tzinfo, or None + datetime64_dtype : str, e.g. "datetime64[ns]", "datetime64[25s]" + """ + assert values.dtype == "i8", values.dtype + # Argument "tz" to "tz_to_dtype" has incompatible type "str | tzinfo | None"; + # expected "tzinfo" + unit, _ = np.datetime_data(datetime64_dtype) # parsing dtype: unit, count + unit = cast("TimeUnit", unit) + # error: Argument "tz" to "tz_to_dtype" has incompatible type + # "str | tzinfo | None"; expected "tzinfo" + dtype = tz_to_dtype(tz=tz, unit=unit) # type: ignore[arg-type] + dta = DatetimeArray._from_sequence(values, dtype=dtype) + return dta + + +def _convert_index(name: str, index: Index, encoding: str, errors: str) -> IndexCol: + assert isinstance(name, str) + + index_name = index.name + # error: Argument 1 to "_get_data_and_dtype_name" has incompatible type "Index"; + # expected "Union[ExtensionArray, ndarray]" + converted, dtype_name = _get_data_and_dtype_name(index) # type: ignore[arg-type] + kind = _dtype_to_kind(dtype_name) + atom = DataIndexableCol._get_atom(converted) + + if ( + lib.is_np_dtype(index.dtype, "iu") + or needs_i8_conversion(index.dtype) + or is_bool_dtype(index.dtype) + ): + # Includes Index, RangeIndex, DatetimeIndex, TimedeltaIndex, PeriodIndex, + # in which case "kind" is "integer", "integer", "datetime64", + # "timedelta64", and "integer", respectively. + return IndexCol( + name, + values=converted, + kind=kind, + typ=atom, + freq=getattr(index, "freq", None), + tz=getattr(index, "tz", None), + index_name=index_name, + ) + + if isinstance(index, MultiIndex): + raise TypeError("MultiIndex not supported here!") + + inferred_type = lib.infer_dtype(index, skipna=False) + # we won't get inferred_type of "datetime64" or "timedelta64" as these + # would go through the DatetimeIndex/TimedeltaIndex paths above + + values = np.asarray(index) + + if inferred_type == "date": + converted = np.asarray([v.toordinal() for v in values], dtype=np.int32) + return IndexCol( + name, converted, "date", _tables().Time32Col(), index_name=index_name + ) + elif inferred_type == "string": + converted = _convert_string_array(values, encoding, errors) + itemsize = converted.dtype.itemsize + return IndexCol( + name, + converted, + "string", + _tables().StringCol(itemsize), + index_name=index_name, + ) + + elif inferred_type in ["integer", "floating"]: + return IndexCol( + name, values=converted, kind=kind, typ=atom, index_name=index_name + ) + else: + assert isinstance(converted, np.ndarray) and converted.dtype == object + assert kind == "object", kind + atom = _tables().ObjectAtom() + return IndexCol(name, converted, kind, atom, index_name=index_name) + + +def _unconvert_index(data, kind: str, encoding: str, errors: str) -> np.ndarray | Index: + index: Index | np.ndarray + + if kind.startswith("datetime64"): + if kind == "datetime64": + # created before we stored resolution information + index = DatetimeIndex(data, copy=False) + else: + index = DatetimeIndex(data.view(kind), copy=False) + elif kind.startswith("timedelta64"): + if kind == "timedelta64": + # created before we stored resolution information + index = TimedeltaIndex(data, copy=False) + else: + index = TimedeltaIndex(data.view(kind), copy=False) + elif kind == "date": + try: + index = np.asarray([date.fromordinal(v) for v in data], dtype=object) + except ValueError: + index = np.asarray([date.fromtimestamp(v) for v in data], dtype=object) + elif kind in ("integer", "float", "bool"): + index = np.asarray(data) + elif kind in ("string"): + index = _unconvert_string_array( + data, nan_rep=None, encoding=encoding, errors=errors + ) + elif kind == "object": + index = np.asarray(data[0]) + else: # pragma: no cover + raise ValueError(f"unrecognized index type {kind}") + return index + + +def _maybe_convert_for_string_atom( + name: str, + bvalues: ArrayLike, + existing_col, + min_itemsize, + nan_rep, + encoding, + errors, + columns: list[str], +): + if isinstance(bvalues.dtype, StringDtype): + bvalues = bvalues.to_numpy() + if bvalues.dtype != object: + return bvalues + + bvalues = cast(np.ndarray, bvalues) + + dtype_name = bvalues.dtype.name + inferred_type = lib.infer_dtype(bvalues, skipna=False) + + if inferred_type == "date": + raise TypeError("[date] is not implemented as a table column") + if inferred_type == "datetime": + # after GH#8260 + # this only would be hit for a multi-timezone dtype which is an error + raise TypeError( + "too many timezones in this block, create separate data columns" + ) + + if not (inferred_type == "string" or dtype_name == "object"): + return bvalues + + mask = isna(bvalues) + data = bvalues.copy() + data[mask] = nan_rep + + if existing_col and mask.any() and len(nan_rep) > existing_col.itemsize: + raise ValueError("NaN representation is too large for existing column size") + + # see if we have a valid string type + inferred_type = lib.infer_dtype(data, skipna=False) + if inferred_type != "string": + # we cannot serialize this data, so report an exception on a column + # by column basis + + # expected behaviour: + # search block for a non-string object column by column + for i in range(data.shape[0]): + col = data[i] + inferred_type = lib.infer_dtype(col, skipna=False) + if inferred_type != "string": + error_column_label = columns[i] if len(columns) > i else f"No.{i}" + raise TypeError( + f"Cannot serialize the column [{error_column_label}]\n" + f"because its data contents are not [string] but " + f"[{inferred_type}] object dtype" + ) + + # itemsize is the maximum length of a string (along any dimension) + + data_converted = _convert_string_array(data, encoding, errors).reshape(data.shape) + itemsize = data_converted.itemsize + + # specified min_itemsize? + if isinstance(min_itemsize, dict): + min_itemsize = int(min_itemsize.get(name) or min_itemsize.get("values") or 0) + itemsize = max(min_itemsize or 0, itemsize) + + # check for column in the values conflicts + if existing_col is not None: + eci = existing_col.validate_col(itemsize) + if eci is not None and eci > itemsize: + itemsize = eci + + data_converted = data_converted.astype(f"|S{itemsize}", copy=False) + return data_converted + + +def _convert_string_array(data: np.ndarray, encoding: str, errors: str) -> np.ndarray: + """ + Take a string-like that is object dtype and coerce to a fixed size string type. + + Parameters + ---------- + data : np.ndarray[object] + encoding : str + errors : str + Handler for encoding errors. + + Returns + ------- + np.ndarray[fixed-length-string] + """ + # encode if needed + if len(data): + data = ( + Series(data.ravel(), copy=False, dtype="object") + .str.encode(encoding, errors) + ._values.reshape(data.shape) + ) + + # create the sized dtype + ensured = ensure_object(data.ravel()) + itemsize = max(1, libwriters.max_len_string_array(ensured)) + + data = np.asarray(data, dtype=f"S{itemsize}") + return data + + +def _unconvert_string_array( + data: np.ndarray, nan_rep, encoding: str, errors: str +) -> np.ndarray: + """ + Inverse of _convert_string_array. + + Parameters + ---------- + data : np.ndarray[fixed-length-string] + nan_rep : the storage repr of NaN + encoding : str + errors : str + Handler for encoding errors. + + Returns + ------- + np.ndarray[object] + Decoded data. + """ + shape = data.shape + data = np.asarray(data.ravel(), dtype=object) + + if len(data): + itemsize = libwriters.max_len_string_array(ensure_object(data)) + dtype = f"U{itemsize}" + + if isinstance(data[0], bytes): + ser = Series(data, copy=False).str.decode( + encoding, errors=errors, dtype="object" + ) + data = ser.to_numpy() + data.flags.writeable = True + else: + data = data.astype(dtype, copy=False).astype(object, copy=False) + + if nan_rep is None: + nan_rep = "nan" + + libwriters.string_array_replace_from_nan_rep(data, nan_rep) + return data.reshape(shape) + + +def _maybe_convert(values: np.ndarray, val_kind: str, encoding: str, errors: str): + assert isinstance(val_kind, str), type(val_kind) + if _need_convert(val_kind): + conv = _get_converter(val_kind, encoding, errors) + values = conv(values) + return values + + +def _get_converter(kind: str, encoding: str, errors: str): + if kind == "datetime64": + return lambda x: np.asarray(x, dtype="M8[ns]") + elif "datetime64" in kind: + return lambda x: np.asarray(x, dtype=kind) + elif kind == "string": + return lambda x: _unconvert_string_array( + x, nan_rep=None, encoding=encoding, errors=errors + ) + else: # pragma: no cover + raise ValueError(f"invalid kind {kind}") + + +def _need_convert(kind: str) -> bool: + if kind in ("datetime64", "string") or "datetime64" in kind: + return True + return False + + +def _maybe_adjust_name(name: str, version: Sequence[int]) -> str: + """ + Prior to 0.10.1, we named values blocks like: values_block_0 and the + name values_0, adjust the given name if necessary. + + Parameters + ---------- + name : str + version : Tuple[int, int, int] + + Returns + ------- + str + """ + if isinstance(version, str) or len(version) < 3: + raise ValueError("Version is incorrect, expected sequence of 3 integers.") + + if version[0] == 0 and version[1] <= 10 and version[2] == 0: + m = re.search(r"values_block_(\d+)", name) + if m: + grp = m.groups()[0] + name = f"values_{grp}" + return name + + +def _dtype_to_kind(dtype_str: str) -> str: + """ + Find the "kind" string describing the given dtype name. + """ + if dtype_str.startswith(("string", "bytes")): + kind = "string" + elif dtype_str.startswith("float"): + kind = "float" + elif dtype_str.startswith("complex"): + kind = "complex" + elif dtype_str.startswith(("int", "uint")): + kind = "integer" + elif dtype_str.startswith("datetime64"): + kind = dtype_str + elif dtype_str.startswith("timedelta"): + kind = dtype_str + elif dtype_str.startswith("bool"): + kind = "bool" + elif dtype_str.startswith("category"): + kind = "category" + elif dtype_str.startswith("period"): + # We store the `freq` attr so we can restore from integers + kind = "integer" + elif dtype_str == "object": + kind = "object" + elif dtype_str == "str": + kind = "str" + else: + raise ValueError(f"cannot interpret dtype of [{dtype_str}]") + + return kind + + +def _get_data_and_dtype_name(data: ArrayLike): + """ + Convert the passed data into a storable form and a dtype string. + """ + if isinstance(data, Categorical): + data = data.codes + + if isinstance(data.dtype, DatetimeTZDtype): + # For datetime64tz we need to drop the TZ in tests TODO: why? + dtype_name = f"datetime64[{data.dtype.unit}]" + else: + dtype_name = data.dtype.name + + if data.dtype.kind in "mM": + data = np.asarray(data.view("i8")) + # TODO: we used to reshape for the dt64tz case, but no longer + # doing that doesn't seem to break anything. why? + + elif isinstance(data, PeriodIndex): + data = data.asi8 + + data = np.asarray(data) + return data, dtype_name + + +class Selection: + """ + Carries out a selection operation on a tables.Table object. + + Parameters + ---------- + table : a Table object + where : list of Terms (or convertible to) + start, stop: indices to start and/or stop selection + + """ + + def __init__( + self, + table: Table, + where=None, + start: int | None = None, + stop: int | None = None, + ) -> None: + self.table = table + self.where = where + self.start = start + self.stop = stop + self.condition = None + self.filter = None + self.terms = None + self.coordinates = None + + if is_list_like(where): + # see if we have a passed coordinate like + with suppress(ValueError): + inferred = lib.infer_dtype(where, skipna=False) + if inferred in ("integer", "boolean"): + where = np.asarray(where) + if where.dtype == np.bool_: + start, stop = self.start, self.stop + if start is None: + start = 0 + if stop is None: + stop = self.table.nrows + self.coordinates = np.arange(start, stop)[where] + elif issubclass(where.dtype.type, np.integer): + if (self.start is not None and (where < self.start).any()) or ( + self.stop is not None and (where >= self.stop).any() + ): + raise ValueError( + "where must have index locations >= start and < stop" + ) + self.coordinates = where + + if self.coordinates is None: + self.terms = self.generate(where) + + # create the numexpr & the filter + if self.terms is not None: + self.condition, self.filter = self.terms.evaluate() + + @overload + def generate(self, where: dict | list | tuple | str) -> PyTablesExpr: ... + + @overload + def generate(self, where: None) -> None: ... + + def generate(self, where: dict | list | tuple | str | None) -> PyTablesExpr | None: + """where can be a : dict,list,tuple,string""" + if where is None: + return None + + q = self.table.queryables() + try: + return PyTablesExpr(where, queryables=q, encoding=self.table.encoding) + except NameError as err: + # raise a nice message, suggesting that the user should use + # data_columns + qkeys = ",".join(q.keys()) + msg = dedent( + f"""\ + The passed where expression: {where} + contains an invalid variable reference + all of the variable references must be a reference to + an axis (e.g. 'index' or 'columns'), or a data_column + The currently defined references are: {qkeys} + """ + ) + raise ValueError(msg) from err + + def select(self): + """ + generate the selection + """ + if self.condition is not None: + return self.table.table.read_where( + self.condition.format(), start=self.start, stop=self.stop + ) + elif self.coordinates is not None: + return self.table.table.read_coordinates(self.coordinates) + return self.table.table.read(start=self.start, stop=self.stop) + + def select_coords(self): + """ + generate the selection + """ + start, stop = self.start, self.stop + nrows = self.table.nrows + if start is None: + start = 0 + elif start < 0: + start += nrows + if stop is None: + stop = nrows + elif stop < 0: + stop += nrows + + if self.condition is not None: + return self.table.table.get_where_list( + self.condition.format(), start=start, stop=stop, sort=True + ) + elif self.coordinates is not None: + return self.coordinates + + return np.arange(start, stop) diff --git a/pandas/io/spss.py b/pandas/io/spss.py new file mode 100644 index 0000000000000000000000000000000000000000..522c7206a2ae55322232f2be1031cbfd30d1fdfd --- /dev/null +++ b/pandas/io/spss.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, +) + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.util._decorators import set_module +from pandas.util._validators import check_dtype_backend + +from pandas.core.dtypes.inference import is_list_like + +from pandas.io.common import stringify_path + +if TYPE_CHECKING: + from collections.abc import Sequence + from pathlib import Path + + from pandas._typing import DtypeBackend + + from pandas import DataFrame + + +@set_module("pandas") +def read_spss( + path: str | Path, + usecols: Sequence[str] | None = None, + convert_categoricals: bool = True, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + **kwargs: Any, +) -> DataFrame: + """ + Load an SPSS file from the file path, returning a DataFrame. + + Parameters + ---------- + path : str or Path + File path. + usecols : list-like, optional + Return a subset of the columns. If None, return all columns. + convert_categoricals : bool, default is True + Convert categorical columns into pd.Categorical. + dtype_backend : {'numpy_nullable', 'pyarrow'} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed + nullable :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + **kwargs + Additional keyword arguments that can be passed to :func:`pyreadstat.read_sav`. + + .. versionadded:: 3.0 + + Returns + ------- + DataFrame + DataFrame based on the SPSS file. + + See Also + -------- + read_csv : Read a comma-separated values (csv) file into a pandas DataFrame. + read_excel : Read an Excel file into a pandas DataFrame. + read_sas : Read an SAS file into a pandas DataFrame. + read_orc : Load an ORC object into a pandas DataFrame. + read_feather : Load a feather-format object into a pandas DataFrame. + + Examples + -------- + >>> df = pd.read_spss("spss_data.sav") # doctest: +SKIP + """ + pyreadstat = import_optional_dependency("pyreadstat") + check_dtype_backend(dtype_backend) + + if usecols is not None: + if not is_list_like(usecols): + raise TypeError("usecols must be list-like.") + usecols = list(usecols) # pyreadstat requires a list + + df, metadata = pyreadstat.read_sav( + stringify_path(path), + usecols=usecols, + apply_value_formats=convert_categoricals, + **kwargs, + ) + df.attrs = metadata.__dict__ + if dtype_backend is not lib.no_default: + df = df.convert_dtypes(dtype_backend=dtype_backend) + return df diff --git a/pandas/io/sql.py b/pandas/io/sql.py new file mode 100644 index 0000000000000000000000000000000000000000..52adbd42c4479804cfbbb30bf3e769f5f3645106 --- /dev/null +++ b/pandas/io/sql.py @@ -0,0 +1,2960 @@ +""" +Collection of query wrappers / abstractions to both facilitate data +retrieval and to reduce dependency on DB-specific API. +""" + +from __future__ import annotations + +from abc import ( + ABC, + abstractmethod, +) +from contextlib import ( + ExitStack, + contextmanager, +) +from datetime import ( + date, + datetime, + time, +) +from functools import partial +import re +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Self, + cast, + overload, +) +import warnings + +import numpy as np + +from pandas._config import using_string_dtype + +from pandas._libs import lib +from pandas.compat._optional import ( + VERSIONS, + import_optional_dependency, +) +from pandas.errors import ( + AbstractMethodError, + DatabaseError, +) +from pandas.util._decorators import set_module +from pandas.util._exceptions import find_stack_level +from pandas.util._validators import check_dtype_backend + +from pandas.core.dtypes.common import ( + is_dict_like, + is_list_like, + is_object_dtype, + is_string_dtype, +) +from pandas.core.dtypes.dtypes import DatetimeTZDtype +from pandas.core.dtypes.missing import isna + +from pandas import get_option +from pandas.core.api import ( + DataFrame, + Series, +) +from pandas.core.arrays import ArrowExtensionArray +from pandas.core.arrays.string_ import StringDtype +from pandas.core.base import PandasObject +import pandas.core.common as com +from pandas.core.common import maybe_make_list +from pandas.core.internals.construction import convert_object_array +from pandas.core.tools.datetimes import to_datetime + +from pandas.io._util import arrow_table_to_pandas + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Generator, + Iterator, + Mapping, + ) + + from sqlalchemy import Table + from sqlalchemy.sql.expression import ( + Delete, + Select, + TextClause, + ) + + from pandas._typing import ( + DtypeArg, + DtypeBackend, + IndexLabel, + ) + + from pandas import Index + +# ----------------------------------------------------------------------------- +# -- Helper functions + + +def _process_parse_dates_argument(parse_dates): + """Process parse_dates argument for read_sql functions""" + # handle non-list entries for parse_dates gracefully + if parse_dates is True or parse_dates is None or parse_dates is False: + parse_dates = [] + + elif not hasattr(parse_dates, "__iter__"): + parse_dates = [parse_dates] + return parse_dates + + +def _handle_date_column( + col, utc: bool = False, format: str | dict[str, Any] | None = None +): + if isinstance(format, dict): + # GH35185 Allow custom error values in parse_dates argument of + # read_sql like functions. + # Format can take on custom to_datetime argument values such as + # {"errors": "coerce"} or {"dayfirst": True} + return to_datetime(col, **format) + else: + # Allow passing of formatting string for integers + # GH17855 + if format is None and ( + issubclass(col.dtype.type, np.floating) + or issubclass(col.dtype.type, np.integer) + ): + format = "s" + if format in ["D", "d", "h", "m", "s", "ms", "us", "ns"]: + return to_datetime(col, errors="coerce", unit=format, utc=utc) + elif isinstance(col.dtype, DatetimeTZDtype): + # coerce to UTC timezone + # GH11216 + return to_datetime(col, utc=True) + else: + return to_datetime(col, errors="coerce", format=format, utc=utc) + + +def _parse_date_columns(data_frame: DataFrame, parse_dates) -> DataFrame: + """ + Force non-datetime columns to be read as such. + Supports both string formatted and integer timestamp columns. + """ + parse_dates = _process_parse_dates_argument(parse_dates) + + # we want to coerce datetime64_tz dtypes for now to UTC + # we could in theory do a 'nice' conversion from a FixedOffset tz + # GH11216 + for i, (col_name, df_col) in enumerate(data_frame.items()): + if isinstance(df_col.dtype, DatetimeTZDtype) or col_name in parse_dates: + try: + fmt = parse_dates[col_name] + except (KeyError, TypeError): + fmt = None + data_frame.isetitem(i, _handle_date_column(df_col, format=fmt)) + + return data_frame + + +def _convert_arrays_to_dataframe( + data, + columns, + coerce_float: bool = True, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", +) -> DataFrame: + content = lib.to_object_array_tuples(data) + idx_len = content.shape[0] + arrays = convert_object_array( + list(content.T), + dtype=None, + coerce_float=coerce_float, + dtype_backend=dtype_backend, + ) + if dtype_backend == "pyarrow": + pa = import_optional_dependency("pyarrow") + + result_arrays = [] + for arr in arrays: + pa_array = pa.array(arr, from_pandas=True) + if arr.dtype == "string": + # TODO: Arrow still infers strings arrays as regular strings instead + # of large_string, which is what we preserver everywhere else for + # dtype_backend="pyarrow". We may want to reconsider this + pa_array = pa_array.cast(pa.string()) + result_arrays.append(ArrowExtensionArray(pa_array)) + arrays = result_arrays # type: ignore[assignment] + if arrays: + return DataFrame._from_arrays( + arrays, columns=columns, index=range(idx_len), verify_integrity=False + ) + else: + return DataFrame(columns=columns) + + +def _wrap_result( + data, + columns, + index_col=None, + coerce_float: bool = True, + parse_dates=None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", +) -> DataFrame: + """Wrap result set of a SQLAlchemy query in a DataFrame.""" + frame = _convert_arrays_to_dataframe(data, columns, coerce_float, dtype_backend) + + if dtype: + frame = frame.astype(dtype) + + frame = _parse_date_columns(frame, parse_dates) + + if index_col is not None: + frame = frame.set_index(index_col) + + return frame + + +def _wrap_result_adbc( + df: DataFrame, + *, + index_col=None, + parse_dates=None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", +) -> DataFrame: + """Wrap result set of a SQLAlchemy query in a DataFrame.""" + if dtype: + df = df.astype(dtype) + + df = _parse_date_columns(df, parse_dates) + + if index_col is not None: + df = df.set_index(index_col) + + return df + + +# ----------------------------------------------------------------------------- +# -- Read and write to DataFrames + + +@overload +def read_sql_table( # pyright: ignore[reportOverlappingOverload] + table_name: str, + con, + schema=..., + index_col: str | list[str] | None = ..., + coerce_float=..., + parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ..., + columns: list[str] | None = ..., + chunksize: None = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., +) -> DataFrame: ... + + +@overload +def read_sql_table( + table_name: str, + con, + schema=..., + index_col: str | list[str] | None = ..., + coerce_float=..., + parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ..., + columns: list[str] | None = ..., + chunksize: int = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., +) -> Iterator[DataFrame]: ... + + +@set_module("pandas") +def read_sql_table( + table_name: str, + con, + schema: str | None = None, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = None, + columns: list[str] | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, +) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL database table into a DataFrame. + + Given a table name and a SQLAlchemy connectable, returns a DataFrame. + This function does not support DBAPI connections. + + Parameters + ---------- + table_name : str + Name of SQL table in database. + con : SQLAlchemy connectable or str + A database URI could be provided as str. + SQLite DBAPI connection mode not supported. + schema : str, default None + Name of SQL schema in database to query (if database flavor + supports this). Uses default schema if None (default). + index_col : str or list of str, optional, default: None + Column(s) to set as index(MultiIndex). + coerce_float : bool, default True + Attempts to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point. Can result in loss of Precision. + parse_dates : list or dict, default None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict corresponds + to the keyword arguments of :func:`pandas.to_datetime` + Especially useful with databases without native Datetime support, + such as SQLite. + columns : list, default None + List of column names to select from SQL table. + chunksize : int, default None + If specified, returns an iterator where `chunksize` is the number of + rows to include in each chunk. + dtype_backend : {'numpy_nullable', 'pyarrow'} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + Returns + ------- + DataFrame or Iterator[DataFrame] + A SQL table is returned as two-dimensional data structure with labeled + axes. + + See Also + -------- + read_sql_query : Read SQL query into a DataFrame. + read_sql : Read SQL query or database table into a DataFrame. + + Notes + ----- + Any datetime values with time zone information will be converted to UTC. + + Examples + -------- + >>> pd.read_sql_table("table_name", "postgres:///db_name") # doctest:+SKIP + """ + + check_dtype_backend(dtype_backend) + if dtype_backend is lib.no_default: + dtype_backend = "numpy" # type: ignore[assignment] + assert dtype_backend is not lib.no_default + + with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql: + if not pandas_sql.has_table(table_name): + raise ValueError(f"Table {table_name} not found") + + table = pandas_sql.read_table( + table_name, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + dtype_backend=dtype_backend, + ) + + if table is not None: + return table + else: + raise ValueError(f"Table {table_name} not found", con) + + +@overload +def read_sql_query( # pyright: ignore[reportOverlappingOverload] + sql, + con, + index_col: str | list[str] | None = ..., + coerce_float=..., + params: list[Any] | Mapping[str, Any] | None = ..., + parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ..., + chunksize: None = ..., + dtype: DtypeArg | None = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., +) -> DataFrame: ... + + +@overload +def read_sql_query( + sql, + con, + index_col: str | list[str] | None = ..., + coerce_float=..., + params: list[Any] | Mapping[str, Any] | None = ..., + parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ..., + chunksize: int = ..., + dtype: DtypeArg | None = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., +) -> Iterator[DataFrame]: ... + + +@set_module("pandas") +def read_sql_query( + sql, + con, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + params: list[Any] | Mapping[str, Any] | None = None, + parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, +) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL query into a DataFrame. + + Returns a DataFrame corresponding to the result set of the query + string. Optionally provide an `index_col` parameter to use one of the + columns as the index, otherwise default integer index will be used. + + Parameters + ---------- + sql : str SQL query or SQLAlchemy Selectable (select or text object) + SQL query to be executed. + con : SQLAlchemy connectable, str, or sqlite3 connection + Using SQLAlchemy makes it possible to use any DB supported by that + library. If a DBAPI2 object, only sqlite3 is supported. + index_col : str or list of str, optional, default: None + Column(s) to set as index(MultiIndex). + coerce_float : bool, default True + Attempts to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point. Useful for SQL result sets. + params : list, tuple or mapping, optional, default: None + List of parameters to pass to execute method. The syntax used + to pass parameters is database driver dependent. Check your + database driver documentation for which of the five syntax styles, + described in PEP 249's paramstyle, is supported. + Eg. for psycopg2, uses %(name)s so use params={'name' : 'value'}. + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict corresponds + to the keyword arguments of :func:`pandas.to_datetime` + Especially useful with databases without native Datetime support, + such as SQLite. + chunksize : int, default None + If specified, return an iterator where `chunksize` is the number of + rows to include in each chunk. + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'}. + dtype_backend : {'numpy_nullable', 'pyarrow'} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + Returns + ------- + DataFrame or Iterator[DataFrame] + Returns a DataFrame object that contains the result set of the + executed SQL query, in relation to the specified database connection. + + See Also + -------- + read_sql_table : Read SQL database table into a DataFrame. + read_sql : Read SQL query or database table into a DataFrame. + + Notes + ----- + Any datetime values with time zone information parsed via the `parse_dates` + parameter will be converted to UTC. + + Examples + -------- + >>> from sqlalchemy import create_engine # doctest: +SKIP + >>> engine = create_engine("sqlite:///database.db") # doctest: +SKIP + >>> sql_query = "SELECT int_column FROM test_data" # doctest: +SKIP + >>> with engine.connect() as conn, conn.begin(): # doctest: +SKIP + ... data = pd.read_sql_query(sql_query, conn) # doctest: +SKIP + """ + + check_dtype_backend(dtype_backend) + if dtype_backend is lib.no_default: + dtype_backend = "numpy" # type: ignore[assignment] + assert dtype_backend is not lib.no_default + + with pandasSQL_builder(con) as pandas_sql: + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + dtype=dtype, + dtype_backend=dtype_backend, + ) + + +@overload +def read_sql( # pyright: ignore[reportOverlappingOverload] + sql, + con, + index_col: str | list[str] | None = ..., + coerce_float=..., + params=..., + parse_dates=..., + columns: list[str] = ..., + chunksize: None = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., + dtype: DtypeArg | None = None, +) -> DataFrame: ... + + +@overload +def read_sql( + sql, + con, + index_col: str | list[str] | None = ..., + coerce_float=..., + params=..., + parse_dates=..., + columns: list[str] = ..., + chunksize: int = ..., + dtype_backend: DtypeBackend | lib.NoDefault = ..., + dtype: DtypeArg | None = None, +) -> Iterator[DataFrame]: ... + + +@set_module("pandas") +def read_sql( + sql, + con, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + params=None, + parse_dates=None, + columns: list[str] | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + dtype: DtypeArg | None = None, +) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL query or database table into a DataFrame. + + This function is a convenience wrapper around ``read_sql_table`` and + ``read_sql_query`` (for backward compatibility). It will delegate + to the specific function depending on the provided input. A SQL query + will be routed to ``read_sql_query``, while a database table name will + be routed to ``read_sql_table``. Note that the delegated function might + have more specific notes about their functionality not listed here. + + Parameters + ---------- + sql : str or SQLAlchemy Selectable (select or text object) + SQL query to be executed or a table name. + con : ADBC Connection, SQLAlchemy connectable, str, or sqlite3 connection + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library. If a DBAPI2 object, only sqlite3 is supported. The user is responsible + for engine disposal and connection closure for the ADBC connection and + SQLAlchemy connectable; str connections are closed automatically. See + `here `_. + index_col : str or list of str, optional, default: None + Column(s) to set as index(MultiIndex). + coerce_float : bool, default True + Attempts to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point, useful for SQL result sets. + params : list, tuple or dict, optional, default: None + List of parameters to pass to execute method. The syntax used + to pass parameters is database driver dependent. Check your + database driver documentation for which of the five syntax styles, + described in PEP 249's paramstyle, is supported. + Eg. for psycopg2, uses %(name)s so use params={'name' : 'value'}. + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict corresponds + to the keyword arguments of :func:`pandas.to_datetime` + Especially useful with databases without native Datetime support, + such as SQLite. + columns : list, default: None + List of column names to select from SQL table (only used when reading + a table). + chunksize : int, default None + If specified, return an iterator where `chunksize` is the + number of rows to include in each chunk. + dtype_backend : {'numpy_nullable', 'pyarrow'} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'}. + The argument is ignored if a table is passed instead of a query. + + .. versionadded:: 2.0.0 + + Returns + ------- + DataFrame or Iterator[DataFrame] + Returns a DataFrame object that contains the result set of the + executed SQL query or an SQL Table based on the provided input, + in relation to the specified database connection. + + See Also + -------- + read_sql_table : Read SQL database table into a DataFrame. + read_sql_query : Read SQL query into a DataFrame. + + Notes + ----- + ``pandas`` does not attempt to sanitize SQL statements; + instead it simply forwards the statement you are executing + to the underlying driver, which may or may not sanitize from there. + Please refer to the underlying driver documentation for any details. + Generally, be wary when accepting statements from arbitrary sources. + + Examples + -------- + Read data from SQL via either a SQL query or a SQL tablename. + When using a SQLite database only SQL queries are accepted, + providing only the SQL tablename will result in an error. + + >>> from sqlite3 import connect + >>> conn = connect(":memory:") + >>> df = pd.DataFrame( + ... data=[[0, "10/11/12"], [1, "12/11/10"]], + ... columns=["int_column", "date_column"], + ... ) + >>> df.to_sql(name="test_data", con=conn) + 2 + + >>> pd.read_sql("SELECT int_column, date_column FROM test_data", conn) + int_column date_column + 0 0 10/11/12 + 1 1 12/11/10 + + >>> pd.read_sql("test_data", "postgres:///db_name") # doctest:+SKIP + + For parameterized query, using ``params`` is recommended over string interpolation. + + >>> from sqlalchemy import text + >>> sql = text( + ... "SELECT int_column, date_column FROM test_data WHERE int_column=:int_val" + ... ) + >>> pd.read_sql(sql, conn, params={"int_val": 1}) # doctest:+SKIP + int_column date_column + 0 1 12/11/10 + + Apply date parsing to columns through the ``parse_dates`` argument + The ``parse_dates`` argument calls ``pd.to_datetime`` on the provided columns. + Custom argument values for applying ``pd.to_datetime`` on a column are specified + via a dictionary format: + + >>> pd.read_sql( + ... "SELECT int_column, date_column FROM test_data", + ... conn, + ... parse_dates={"date_column": {"format": "%d/%m/%y"}}, + ... ) + int_column date_column + 0 0 2012-11-10 + 1 1 2010-11-12 + + .. versionadded:: 2.2.0 + + pandas now supports reading via ADBC drivers + + >>> from adbc_driver_postgresql import dbapi # doctest:+SKIP + >>> with dbapi.connect("postgres:///db_name") as conn: # doctest:+SKIP + ... pd.read_sql("SELECT int_column FROM test_data", conn) + int_column + 0 0 + 1 1 + """ + + check_dtype_backend(dtype_backend) + if dtype_backend is lib.no_default: + dtype_backend = "numpy" # type: ignore[assignment] + assert dtype_backend is not lib.no_default + + with pandasSQL_builder(con) as pandas_sql: + if isinstance(pandas_sql, SQLiteDatabase): + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + dtype_backend=dtype_backend, + dtype=dtype, + ) + + try: + _is_table_name = pandas_sql.has_table(sql) + except Exception: + # using generic exception to catch errors from sql drivers (GH24988) + _is_table_name = False + + if _is_table_name: + return pandas_sql.read_table( + sql, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + dtype_backend=dtype_backend, + ) + else: + return pandas_sql.read_query( + sql, + index_col=index_col, + params=params, + coerce_float=coerce_float, + parse_dates=parse_dates, + chunksize=chunksize, + dtype_backend=dtype_backend, + dtype=dtype, + ) + + +def to_sql( + frame, + name: str, + con, + schema: str | None = None, + if_exists: Literal["fail", "replace", "append", "delete_rows"] = "fail", + index: bool = True, + index_label: IndexLabel | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, +) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + .. warning:: + The pandas library does not attempt to sanitize inputs provided via a to_sql call. + Please refer to the documentation for the underlying database driver to see if it + will properly prevent injection, or alternatively be advised of a security risk when + executing arbitrary commands in a to_sql call. + + Parameters + ---------- + frame : DataFrame, Series + name : str + Name of SQL table. + con : ADBC Connection, SQLAlchemy connectable, str, or sqlite3 connection + or sqlite3 DBAPI2 connection + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library. + If a DBAPI2 object, only sqlite3 is supported. + schema : str, optional + Name of SQL schema in database to write to (if database flavor + supports this). If None, use default schema (default). + if_exists : {'fail', 'replace', 'append', 'delete_rows'}, default 'fail' + - fail: If table exists, do nothing. + - replace: If table exists, drop it, recreate it, and insert data. + - append: If table exists, insert data. Create if does not exist. + - delete_rows: If a table exists, delete all records and insert data. + index : bool, default True + Write DataFrame index as a column. + index_label : str or sequence, optional + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + chunksize : int, optional + Specify the number of rows in each batch to be written at a time. + By default, all rows will be written at once. + dtype : dict or scalar, optional + Specifying the datatype for columns. If a dictionary is used, the + keys should be the column names and the values should be the + SQLAlchemy types or strings for the sqlite3 fallback mode. If a + scalar is provided, it will be applied to all columns. + method : {None, 'multi', callable}, optional + Controls the SQL insertion clause used: + + - None : Uses standard SQL ``INSERT`` clause (one per row). + - ``'multi'``: Pass multiple values in a single ``INSERT`` clause. + - callable with signature ``(pd_table, conn, keys, data_iter) -> int | None``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method `. + engine : {'auto', 'sqlalchemy'}, default 'auto' + SQL engine library to use. If 'auto', then the option + ``io.sql.engine`` is used. The default ``io.sql.engine`` + behavior is 'sqlalchemy' + + **engine_kwargs + Any additional kwargs are passed to the engine. + + Returns + ------- + None or int + Number of rows affected by to_sql. None is returned if the callable + passed into ``method`` does not return an integer number of rows. + + Notes + ----- + The returned rows affected is the sum of the ``rowcount`` attribute of ``sqlite3.Cursor`` + or SQLAlchemy connectable. If using ADBC the returned rows are the result + of ``Cursor.adbc_ingest``. The returned value may not reflect the exact number of written + rows as stipulated in the + `sqlite3 `__ or + `SQLAlchemy `__ + """ # noqa: E501 + if if_exists not in ("fail", "replace", "append", "delete_rows"): + raise ValueError(f"'{if_exists}' is not valid for if_exists") + + if isinstance(frame, Series): + frame = frame.to_frame() + elif not isinstance(frame, DataFrame): + raise NotImplementedError( + "'frame' argument should be either a Series or a DataFrame" + ) + + with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql: + return pandas_sql.to_sql( + frame, + name, + if_exists=if_exists, + index=index, + index_label=index_label, + schema=schema, + chunksize=chunksize, + dtype=dtype, + method=method, + engine=engine, + **engine_kwargs, + ) + + +def has_table(table_name: str, con, schema: str | None = None) -> bool: + """ + Check if DataBase has named table. + + Parameters + ---------- + table_name: string + Name of SQL table. + con: ADBC Connection, SQLAlchemy connectable, str, or sqlite3 connection + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library. + If a DBAPI2 object, only sqlite3 is supported. + schema : string, default None + Name of SQL schema in database to write to (if database flavor supports + this). If None, use default schema (default). + + Returns + ------- + boolean + """ + with pandasSQL_builder(con, schema=schema) as pandas_sql: + return pandas_sql.has_table(table_name) + + +table_exists = has_table + + +def pandasSQL_builder( + con, + schema: str | None = None, + need_transaction: bool = False, +) -> PandasSQL: + """ + Convenience function to return the correct PandasSQL subclass based on the + provided parameters. Also creates a sqlalchemy connection and transaction + if necessary. + """ + import sqlite3 + + if isinstance(con, sqlite3.Connection) or con is None: + return SQLiteDatabase(con) + + sqlalchemy = import_optional_dependency("sqlalchemy", errors="ignore") + + if isinstance(con, str) and sqlalchemy is None: + raise ImportError( + f"Using URI string without version '{VERSIONS['sqlalchemy']}' or newer " + "of 'sqlalchemy' installed." + ) + + if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)): + return SQLDatabase(con, schema, need_transaction) + + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(con, adbc.Connection): + return ADBCDatabase(con) + + warnings.warn( + "pandas only supports SQLAlchemy connectable (engine/connection) or " + "database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 " + "objects are not tested. Please consider using SQLAlchemy.", + UserWarning, + stacklevel=find_stack_level(), + ) + return SQLiteDatabase(con) + + +class SQLTable(PandasObject): + """ + For mapping Pandas tables to SQL tables. + Uses fact that table is reflected by SQLAlchemy to + do better type conversions. + Also holds various flags needed to avoid having to + pass them between functions all the time. + """ + + # TODO: support for multiIndex + + def __init__( + self, + name: str, + pandas_sql_engine, + frame=None, + index: bool | str | list[str] | None = True, + if_exists: Literal["fail", "replace", "append", "delete_rows"] = "fail", + prefix: str = "pandas", + index_label=None, + schema=None, + keys=None, + dtype: DtypeArg | None = None, + ) -> None: + self.name = name + self.pd_sql = pandas_sql_engine + self.prefix = prefix + self.frame = frame + self.index = self._index_name(index, index_label) + self.schema = schema + self.if_exists = if_exists + self.keys = keys + self.dtype = dtype + + if frame is not None: + # We want to initialize based on a dataframe + self.table = self._create_table_setup() + else: + # no data provided, read-only mode + self.table = self.pd_sql.get_table(self.name, self.schema) + + if self.table is None: + raise ValueError(f"Could not init table '{name}'") + + if not len(self.name): + raise ValueError("Empty table name specified") + + def exists(self): + return self.pd_sql.has_table(self.name, self.schema) + + def sql_schema(self) -> str: + from sqlalchemy.schema import CreateTable + + return str(CreateTable(self.table).compile(self.pd_sql.con)) + + def _execute_create(self) -> None: + # Inserting table into database, add to MetaData object + self.table = self.table.to_metadata(self.pd_sql.meta) + with self.pd_sql.run_transaction(): + self.table.create(bind=self.pd_sql.con) + + def create(self) -> None: + if self.exists(): + if self.if_exists == "fail": + raise ValueError(f"Table '{self.name}' already exists.") + elif self.if_exists == "replace": + self.pd_sql.drop_table(self.name, self.schema) + self._execute_create() + elif self.if_exists == "append": + pass + elif self.if_exists == "delete_rows": + self.pd_sql.delete_rows(self.name, self.schema) + else: + raise ValueError(f"'{self.if_exists}' is not valid for if_exists") + else: + self._execute_create() + + def _execute_insert(self, conn, keys: list[str], data_iter) -> int: + """ + Execute SQL statement inserting data + + Parameters + ---------- + conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection + keys : list of str + Column names + data_iter : generator of list + Each item contains a list of values to be inserted + """ + data = [dict(zip(keys, row, strict=True)) for row in data_iter] + result = self.pd_sql.execute(self.table.insert(), data) + return result.rowcount + + def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int: + """ + Alternative to _execute_insert for DBs support multi-value INSERT. + + Note: multi-value insert is usually faster for analytics DBs + and tables containing a few columns + but performance degrades quickly with increase of columns. + + """ + + from sqlalchemy import insert + + data = [dict(zip(keys, row, strict=True)) for row in data_iter] + stmt = insert(self.table).values(data) + result = self.pd_sql.execute(stmt) + return result.rowcount + + def insert_data(self) -> tuple[list[str], list[np.ndarray]]: + if self.index is not None: + temp = self.frame.copy(deep=False) + temp.index.names = self.index + try: + temp.reset_index(inplace=True) + except ValueError as err: + raise ValueError(f"duplicate name in index/columns: {err}") from err + else: + temp = self.frame + + column_names = list(map(str, temp.columns)) + ncols = len(column_names) + # this just pre-allocates the list: None's will be replaced with ndarrays + # error: List item 0 has incompatible type "None"; expected "ndarray" + data_list: list[np.ndarray] = [None] * ncols # type: ignore[list-item] + + for i, (_, ser) in enumerate(temp.items()): + if ser.dtype.kind == "M": + if isinstance(ser._values, ArrowExtensionArray): + import pyarrow as pa + + if pa.types.is_date(ser.dtype.pyarrow_dtype): + # GH#53854 to_pydatetime not supported for pyarrow date dtypes + d = ser._values.to_numpy(dtype=object) + else: + d = ser.dt.to_pydatetime()._values + else: + d = ser._values.to_pydatetime() + elif ser.dtype.kind == "m": + vals = ser._values + if isinstance(vals, ArrowExtensionArray): + vals = vals.to_numpy(dtype=np.dtype("m8[ns]")) + # store as integers, see GH#6921, GH#7076 + d = vals.view("i8").astype(object) + else: + d = ser._values.astype(object) + + assert isinstance(d, np.ndarray), type(d) + + if ser._can_hold_na: + # Note: this will miss timedeltas since they are converted to int + mask = isna(d) + d[mask] = None + + data_list[i] = d + + return column_names, data_list + + def insert( + self, + chunksize: int | None = None, + method: Literal["multi"] | Callable | None = None, + ) -> int | None: + # set insert method + if method is None: + exec_insert = self._execute_insert + elif method == "multi": + exec_insert = self._execute_insert_multi + elif callable(method): + exec_insert = partial(method, self) + else: + raise ValueError(f"Invalid parameter `method`: {method}") + + keys, data_list = self.insert_data() + + nrows = len(self.frame) + + if nrows == 0: + return 0 + + if chunksize is None: + chunksize = nrows + elif chunksize == 0: + raise ValueError("chunksize argument should be non-zero") + + chunks = (nrows // chunksize) + 1 + total_inserted = None + with self.pd_sql.run_transaction() as conn: + for i in range(chunks): + start_i = i * chunksize + end_i = min((i + 1) * chunksize, nrows) + if start_i >= end_i: + break + + chunk_iter = zip( + *(arr[start_i:end_i] for arr in data_list), strict=True + ) + num_inserted = exec_insert(conn, keys, chunk_iter) + # GH 46891 + if num_inserted is not None: + if total_inserted is None: + total_inserted = num_inserted + else: + total_inserted += num_inserted + return total_inserted + + def _query_iterator( + self, + result, + exit_stack: ExitStack, + chunksize: int | None, + columns, + coerce_float: bool = True, + parse_dates=None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> Generator[DataFrame]: + """Return generator through chunked result set.""" + has_read_data = False + with exit_stack: + while True: + data = result.fetchmany(chunksize) + if not data: + if not has_read_data: + yield DataFrame.from_records( + [], columns=columns, coerce_float=coerce_float + ) + break + + has_read_data = True + self.frame = _convert_arrays_to_dataframe( + data, columns, coerce_float, dtype_backend + ) + + self._harmonize_columns( + parse_dates=parse_dates, dtype_backend=dtype_backend + ) + + if self.index is not None: + self.frame.set_index(self.index, inplace=True) + + yield self.frame + + def read( + self, + exit_stack: ExitStack, + coerce_float: bool = True, + parse_dates=None, + columns=None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + from sqlalchemy import select + + if columns is not None and len(columns) > 0: + cols = [self.table.c[n] for n in columns] + if self.index is not None: + for idx in self.index[::-1]: + cols.insert(0, self.table.c[idx]) + sql_select = select(*cols) + else: + sql_select = select(self.table) + result = self.pd_sql.execute(sql_select) + column_names = result.keys() + + if chunksize is not None: + return self._query_iterator( + result, + exit_stack, + chunksize, + column_names, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype_backend=dtype_backend, + ) + else: + data = result.fetchall() + self.frame = _convert_arrays_to_dataframe( + data, column_names, coerce_float, dtype_backend + ) + + self._harmonize_columns( + parse_dates=parse_dates, dtype_backend=dtype_backend + ) + + if self.index is not None: + self.frame.set_index(self.index, inplace=True) + + return self.frame + + def _index_name(self, index, index_label): + # for writing: index=True to include index in sql table + if index is True: + nlevels = self.frame.index.nlevels + # if index_label is specified, set this as index name(s) + if index_label is not None: + if not isinstance(index_label, list): + index_label = [index_label] + if len(index_label) != nlevels: + raise ValueError( + "Length of 'index_label' should match number of " + f"levels, which is {nlevels}" + ) + return index_label + # return the used column labels for the index columns + if ( + nlevels == 1 + and "index" not in self.frame.columns + and self.frame.index.name is None + ): + return ["index"] + else: + return com.fill_missing_names(self.frame.index.names) + + # for reading: index=(list of) string to specify column to set as index + elif isinstance(index, str): + return [index] + elif isinstance(index, list): + return index + else: + return None + + def _get_column_names_and_types(self, dtype_mapper): + column_names_and_types = [] + if self.index is not None: + for i, idx_label in enumerate(self.index): + idx_type = dtype_mapper(self.frame.index._get_level_values(i)) + column_names_and_types.append((str(idx_label), idx_type, True)) + + column_names_and_types += [ + (str(self.frame.columns[i]), dtype_mapper(self.frame.iloc[:, i]), False) + for i in range(len(self.frame.columns)) + ] + + return column_names_and_types + + def _create_table_setup(self): + from sqlalchemy import ( + Column, + PrimaryKeyConstraint, + Table, + ) + from sqlalchemy.schema import MetaData + + column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type) + + columns: list[Any] = [ + Column(name, typ, index=is_index) + for name, typ, is_index in column_names_and_types + ] + + if self.keys is not None: + if not is_list_like(self.keys): + keys = [self.keys] + else: + keys = self.keys + pkc = PrimaryKeyConstraint(*keys, name=self.name + "_pk") + columns.append(pkc) + + schema = self.schema or self.pd_sql.meta.schema + + # At this point, attach to new metadata, only attach to self.meta + # once table is created. + meta = MetaData() + return Table(self.name, meta, *columns, schema=schema) + + def _harmonize_columns( + self, + parse_dates=None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> None: + """ + Make the DataFrame's column types align with the SQL table + column types. + Need to work around limited NA value support. Floats are always + fine, ints must always be floats if there are Null values. + Booleans are hard because converting bool column with None replaces + all Nones with false. Therefore only convert bool if there are no + NA values. + Datetimes should already be converted to np.datetime64 if supported, + but here we also force conversion if required. + """ + parse_dates = _process_parse_dates_argument(parse_dates) + + for sql_col in self.table.columns: + col_name = sql_col.name + try: + df_col = self.frame[col_name] + + # Handle date parsing upfront; don't try to convert columns + # twice + if col_name in parse_dates: + try: + fmt = parse_dates[col_name] + except TypeError: + fmt = None + self.frame[col_name] = _handle_date_column(df_col, format=fmt) + continue + + # the type the dataframe column should have + col_type = self._get_dtype(sql_col.type) + + if ( + col_type is datetime + or col_type is date + or col_type is DatetimeTZDtype + ): + # Convert tz-aware Datetime SQL columns to UTC + utc = col_type is DatetimeTZDtype + self.frame[col_name] = _handle_date_column(df_col, utc=utc) + elif dtype_backend == "numpy" and col_type is float: + # floats support NA, can always convert! + self.frame[col_name] = df_col.astype(col_type) + elif ( + using_string_dtype() + and is_string_dtype(col_type) + and is_object_dtype(self.frame[col_name]) + ): + self.frame[col_name] = df_col.astype(col_type) + elif dtype_backend == "numpy" and len(df_col) == df_col.count(): + # No NA values, can convert ints and bools + if col_type is np.dtype("int64") or col_type is bool: + self.frame[col_name] = df_col.astype(col_type) + except KeyError: + pass # this column not in results + + def _sqlalchemy_type(self, col: Index | Series): + dtype: DtypeArg = self.dtype or {} + if is_dict_like(dtype): + dtype = cast(dict, dtype) + if col.name in dtype: + return dtype[col.name] + + # Infer type of column, while ignoring missing values. + # Needed for inserting typed data containing NULLs, GH 8778. + col_type = lib.infer_dtype(col, skipna=True) + + from sqlalchemy.types import ( + TIMESTAMP, + BigInteger, + Boolean, + Date, + DateTime, + Float, + Integer, + SmallInteger, + Text, + Time, + ) + + if col_type in ("datetime64", "datetime"): + # GH 9086: TIMESTAMP is the suggested type if the column contains + # timezone information + try: + # error: Item "Index" of "Union[Index, Series]" has no attribute "dt" + if col.dt.tz is not None: # type: ignore[union-attr] + return TIMESTAMP(timezone=True) + except AttributeError: + # The column is actually a DatetimeIndex + # GH 26761 or an Index with date-like data e.g. 9999-01-01 + if getattr(col, "tz", None) is not None: + return TIMESTAMP(timezone=True) + return DateTime + if col_type == "timedelta64": + warnings.warn( + "the 'timedelta' type is not supported, and will be " + "written as integer values (ns frequency) to the database.", + UserWarning, + stacklevel=find_stack_level(), + ) + return BigInteger + elif col_type == "floating": + if col.dtype == "float32": + return Float(precision=23) + else: + return Float(precision=53) + elif col_type == "integer": + # GH35076 Map pandas integer to optimal SQLAlchemy integer type + if col.dtype.name.lower() in ("int8", "uint8", "int16"): + return SmallInteger + elif col.dtype.name.lower() in ("uint16", "int32"): + return Integer + elif col.dtype.name.lower() == "uint64": + raise ValueError("Unsigned 64 bit integer datatype is not supported") + else: + return BigInteger + elif col_type == "boolean": + return Boolean + elif col_type == "date": + return Date + elif col_type == "time": + return Time + elif col_type == "complex": + raise ValueError("Complex datatypes not supported") + + return Text + + def _get_dtype(self, sqltype): + from sqlalchemy.types import ( + TIMESTAMP, + Boolean, + Date, + DateTime, + Float, + Integer, + String, + ) + + if isinstance(sqltype, Float): + return float + elif isinstance(sqltype, Integer): + # TODO: Refine integer size. + return np.dtype("int64") + elif isinstance(sqltype, TIMESTAMP): + # we have a timezone capable type + if not sqltype.timezone: + return datetime + return DatetimeTZDtype + elif isinstance(sqltype, DateTime): + # Caution: np.datetime64 is also a subclass of np.number. + return datetime + elif isinstance(sqltype, Date): + return date + elif isinstance(sqltype, Boolean): + return bool + elif isinstance(sqltype, String): + if using_string_dtype(): + return StringDtype(na_value=np.nan) + + return object + + +class PandasSQL(PandasObject, ABC): + """ + Subclasses Should define read_query and to_sql. + """ + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args) -> None: + pass + + def read_table( + self, + table_name: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + columns=None, + schema: str | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + raise NotImplementedError + + @abstractmethod + def read_query( + self, + sql: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + params=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + pass + + @abstractmethod + def to_sql( + self, + frame, + name: str, + if_exists: Literal["fail", "replace", "append", "delete_rows"] = "fail", + index: bool = True, + index_label=None, + schema=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, + ) -> int | None: + pass + + @abstractmethod + def execute(self, sql: str | Select | TextClause, params=None): + pass + + @abstractmethod + def has_table(self, name: str, schema: str | None = None) -> bool: + pass + + @abstractmethod + def _create_sql_schema( + self, + frame: DataFrame, + table_name: str, + keys: list[str] | None = None, + dtype: DtypeArg | None = None, + schema: str | None = None, + ) -> str: + pass + + +class BaseEngine: + def insert_records( + self, + table: SQLTable, + con, + frame, + name: str, + index: bool | str | list[str] | None = True, + schema=None, + chunksize: int | None = None, + method=None, + **engine_kwargs, + ) -> int | None: + """ + Inserts data into already-prepared table + """ + raise AbstractMethodError(self) + + +class SQLAlchemyEngine(BaseEngine): + def __init__(self) -> None: + import_optional_dependency( + "sqlalchemy", extra="sqlalchemy is required for SQL support." + ) + + def insert_records( + self, + table: SQLTable, + con, + frame, + name: str, + index: bool | str | list[str] | None = True, + schema=None, + chunksize: int | None = None, + method=None, + **engine_kwargs, + ) -> int | None: + from sqlalchemy import exc + + try: + return table.insert(chunksize=chunksize, method=method) + except exc.StatementError as err: + # GH34431 + # https://stackoverflow.com/a/67358288/6067848 + msg = r"""(\(1054, "Unknown column 'inf(e0)?' in 'field list'"\))(?# + )|inf can not be used with MySQL""" + err_text = str(err.orig) + if re.search(msg, err_text): + raise ValueError("inf cannot be used with MySQL") from err + raise err + + +def get_engine(engine: str) -> BaseEngine: + """return our implementation""" + if engine == "auto": + engine = get_option("io.sql.engine") + + if engine == "auto": + # try engines in this order + engine_classes = [SQLAlchemyEngine] + + error_msgs = "" + for engine_class in engine_classes: + try: + return engine_class() + except ImportError as err: + error_msgs += "\n - " + str(err) + + raise ImportError( + "Unable to find a usable engine; " + "tried using: 'sqlalchemy'.\n" + "A suitable version of " + "sqlalchemy is required for sql I/O " + "support.\n" + "Trying to import the above resulted in these errors:" + f"{error_msgs}" + ) + + if engine == "sqlalchemy": + return SQLAlchemyEngine() + + raise ValueError("engine must be one of 'auto', 'sqlalchemy'") + + +class SQLDatabase(PandasSQL): + """ + This class enables conversion between DataFrame and SQL databases + using SQLAlchemy to handle DataBase abstraction. + + Parameters + ---------- + con : SQLAlchemy Connectable or URI string. + Connectable to connect with the database. Using SQLAlchemy makes it + possible to use any DB supported by that library. + schema : string, default None + Name of SQL schema in database to write to (if database flavor + supports this). If None, use default schema (default). + need_transaction : bool, default False + If True, SQLDatabase will create a transaction. + + """ + + def __init__( + self, con, schema: str | None = None, need_transaction: bool = False + ) -> None: + from sqlalchemy import create_engine + from sqlalchemy.engine import Engine + from sqlalchemy.schema import MetaData + + # self.exit_stack cleans up the Engine and Connection and commits the + # transaction if any of those objects was created below. + # Cleanup happens either in self.__exit__ or at the end of the iterator + # returned by read_sql when chunksize is not None. + self.exit_stack = ExitStack() + if isinstance(con, str): + con = create_engine(con) + self.exit_stack.callback(con.dispose) + if isinstance(con, Engine): + con = self.exit_stack.enter_context(con.connect()) + if need_transaction and not con.in_transaction(): + self.exit_stack.enter_context(con.begin()) + self.con = con + self.meta = MetaData(schema=schema) + self.returns_generator = False + + def __exit__(self, *args) -> None: + if not self.returns_generator: + self.exit_stack.close() + + @contextmanager + def run_transaction(self): + if not self.con.in_transaction(): + with self.con.begin(): + yield self.con + else: + yield self.con + + def execute(self, sql: str | Select | TextClause | Delete, params=None): + """Simple passthrough to SQLAlchemy connectable""" + from sqlalchemy.exc import SQLAlchemyError + + args = [] if params is None else [params] + if isinstance(sql, str): + execute_function = self.con.exec_driver_sql + else: + execute_function = self.con.execute + + try: + return execute_function(sql, *args) + except SQLAlchemyError as exc: + raise DatabaseError(f"Execution failed on sql '{sql}': {exc}") from exc + + def read_table( + self, + table_name: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + columns=None, + schema: str | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL database table into a DataFrame. + + Parameters + ---------- + table_name : str + Name of SQL table in database. + index_col : string, optional, default: None + Column to set as index. + coerce_float : bool, default True + Attempts to convert values of non-string, non-numeric objects + (like decimal.Decimal) to floating point. This can result in + loss of precision. + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg}``, where the arg corresponds + to the keyword arguments of :func:`pandas.to_datetime`. + Especially useful with databases without native Datetime support, + such as SQLite. + columns : list, default: None + List of column names to select from SQL table. + schema : string, default None + Name of SQL schema in database to query (if database flavor + supports this). If specified, this overwrites the default + schema of the SQL database object. + chunksize : int, default None + If specified, return an iterator where `chunksize` is the number + of rows to include in each chunk. + dtype_backend : {'numpy_nullable', 'pyarrow'} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + Returns + ------- + DataFrame + + See Also + -------- + pandas.read_sql_table + SQLDatabase.read_query + + """ + self.meta.reflect(bind=self.con, only=[table_name], views=True) + table = SQLTable(table_name, self, index=index_col, schema=schema) + if chunksize is not None: + self.returns_generator = True + return table.read( + self.exit_stack, + coerce_float=coerce_float, + parse_dates=parse_dates, + columns=columns, + chunksize=chunksize, + dtype_backend=dtype_backend, + ) + + @staticmethod + def _query_iterator( + result, + exit_stack: ExitStack, + chunksize: int, + columns, + index_col=None, + coerce_float: bool = True, + parse_dates=None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> Generator[DataFrame]: + """Return generator through chunked result set""" + has_read_data = False + with exit_stack: + while True: + data = result.fetchmany(chunksize) + if not data: + if not has_read_data: + yield _wrap_result( + [], + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + break + + has_read_data = True + yield _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + + def read_query( + self, + sql: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + params=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL query into a DataFrame. + + Parameters + ---------- + sql : str + SQL query to be executed. + index_col : string, optional, default: None + Column name to use as index for the returned DataFrame object. + coerce_float : bool, default True + Attempt to convert values of non-string, non-numeric objects (like + decimal.Decimal) to floating point, useful for SQL result sets. + params : list, tuple or dict, optional, default: None + List of parameters to pass to execute method. The syntax used + to pass parameters is database driver dependent. Check your + database driver documentation for which of the five syntax styles, + described in PEP 249's paramstyle, is supported. + Eg. for psycopg2, uses %(name)s so use params={'name' : 'value'} + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict + corresponds to the keyword arguments of + :func:`pandas.to_datetime` Especially useful with databases + without native Datetime support, such as SQLite. + chunksize : int, default None + If specified, return an iterator where `chunksize` is the number + of rows to include in each chunk. + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'} + + Returns + ------- + DataFrame + + See Also + -------- + read_sql_table : Read SQL database table into a DataFrame. + read_sql + + """ + result = self.execute(sql, params) + columns = result.keys() + + if chunksize is not None: + self.returns_generator = True + return self._query_iterator( + result, + self.exit_stack, + chunksize, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + else: + data = result.fetchall() + frame = _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + return frame + + read_sql = read_query + + def prep_table( + self, + frame, + name: str, + if_exists: Literal["fail", "replace", "append", "delete_rows"] = "fail", + index: bool | str | list[str] | None = True, + index_label=None, + schema=None, + dtype: DtypeArg | None = None, + ) -> SQLTable: + """ + Prepares table in the database for data insertion. Creates it if needed, etc. + """ + if dtype: + if not is_dict_like(dtype): + # error: Value expression in dictionary comprehension has incompatible + # type "Union[ExtensionDtype, str, dtype[Any], Type[object], + # Dict[Hashable, Union[ExtensionDtype, Union[str, dtype[Any]], + # Type[str], Type[float], Type[int], Type[complex], Type[bool], + # Type[object]]]]"; expected type "Union[ExtensionDtype, str, + # dtype[Any], Type[object]]" + dtype = dict.fromkeys(frame, dtype) # type: ignore[arg-type] + else: + dtype = cast(dict, dtype) + + from sqlalchemy.types import TypeEngine + + for col, my_type in dtype.items(): + if isinstance(my_type, type) and issubclass(my_type, TypeEngine): + pass + elif isinstance(my_type, TypeEngine): + pass + else: + raise ValueError(f"The type of {col} is not a SQLAlchemy type") + + table = SQLTable( + name, + self, + frame=frame, + index=index, + if_exists=if_exists, + index_label=index_label, + schema=schema, + dtype=dtype, + ) + table.create() + return table + + def check_case_sensitive( + self, + name: str, + schema: str | None, + ) -> None: + """ + Checks table name for issues with case-sensitivity. + Method is called after data is inserted. + """ + if not name.isdigit() and not name.islower(): + # check for potentially case sensitivity issues (GH7815) + # Only check when name is not a number and name is not lower case + from sqlalchemy import inspect as sqlalchemy_inspect + + insp = sqlalchemy_inspect(self.con) + table_names = insp.get_table_names(schema=schema or self.meta.schema) + if name not in table_names: + msg = ( + f"The provided table name '{name}' is not found exactly as " + "such in the database after writing the table, possibly " + "due to case sensitivity issues. Consider using lower " + "case table names." + ) + warnings.warn( + msg, + UserWarning, + stacklevel=find_stack_level(), + ) + + def to_sql( + self, + frame, + name: str, + if_exists: Literal["fail", "replace", "append", "delete_rows"] = "fail", + index: bool = True, + index_label=None, + schema: str | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, + ) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame : DataFrame + name : string + Name of SQL table. + if_exists : {'fail', 'replace', 'append', 'delete_rows'}, default 'fail' + - fail: If table exists, do nothing. + - replace: If table exists, drop it, recreate it, and insert data. + - append: If table exists, insert data. Create if does not exist. + - delete_rows: If a table exists, delete all records and insert data. + index : boolean, default True + Write DataFrame index as a column. + index_label : string or sequence, default None + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + schema : string, default None + Name of SQL schema in database to write to (if database flavor + supports this). If specified, this overwrites the default + schema of the SQLDatabase object. + chunksize : int, default None + If not None, then rows will be written in batches of this size at a + time. If None, all rows will be written at once. + dtype : single type or dict of column name to SQL type, default None + Optional specifying the datatype for columns. The SQL type should + be a SQLAlchemy type. If all columns are of the same type, one + single value can be used. + method : {None', 'multi', callable}, default None + Controls the SQL insertion clause used: + + * None : Uses standard SQL ``INSERT`` clause (one per row). + * 'multi': Pass multiple values in a single ``INSERT`` clause. + * callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method `. + engine : {'auto', 'sqlalchemy'}, default 'auto' + SQL engine library to use. If 'auto', then the option + ``io.sql.engine`` is used. The default ``io.sql.engine`` + behavior is 'sqlalchemy' + + **engine_kwargs + Any additional kwargs are passed to the engine. + """ + sql_engine = get_engine(engine) + + table = self.prep_table( + frame=frame, + name=name, + if_exists=if_exists, + index=index, + index_label=index_label, + schema=schema, + dtype=dtype, + ) + + total_inserted = sql_engine.insert_records( + table=table, + con=self.con, + frame=frame, + name=name, + index=index, + schema=schema, + chunksize=chunksize, + method=method, + **engine_kwargs, + ) + + self.check_case_sensitive(name=name, schema=schema) + return total_inserted + + @property + def tables(self): + return self.meta.tables + + def has_table(self, name: str, schema: str | None = None) -> bool: + from sqlalchemy import inspect as sqlalchemy_inspect + + insp = sqlalchemy_inspect(self.con) + return insp.has_table(name, schema or self.meta.schema) + + def get_table(self, table_name: str, schema: str | None = None) -> Table: + from sqlalchemy import ( + Numeric, + Table, + ) + + schema = schema or self.meta.schema + tbl = Table(table_name, self.meta, autoload_with=self.con, schema=schema) + for column in tbl.columns: + if isinstance(column.type, Numeric): + column.type.asdecimal = False + return tbl + + def drop_table(self, table_name: str, schema: str | None = None) -> None: + schema = schema or self.meta.schema + if self.has_table(table_name, schema): + self.meta.reflect( + bind=self.con, only=[table_name], schema=schema, views=True + ) + with self.run_transaction(): + self.get_table(table_name, schema).drop(bind=self.con) + self.meta.clear() + + def delete_rows(self, table_name: str, schema: str | None = None) -> None: + schema = schema or self.meta.schema + if self.has_table(table_name, schema): + self.meta.reflect( + bind=self.con, only=[table_name], schema=schema, views=True + ) + table = self.get_table(table_name, schema) + self.execute(table.delete()).close() + self.meta.clear() + + def _create_sql_schema( + self, + frame: DataFrame, + table_name: str, + keys: list[str] | None = None, + dtype: DtypeArg | None = None, + schema: str | None = None, + ) -> str: + table = SQLTable( + table_name, + self, + frame=frame, + index=False, + keys=keys, + dtype=dtype, + schema=schema, + ) + return str(table.sql_schema()) + + +# ---- SQL without SQLAlchemy --- + + +class ADBCDatabase(PandasSQL): + """ + This class enables conversion between DataFrame and SQL databases + using ADBC to handle DataBase abstraction. + + Parameters + ---------- + con : adbc_driver_manager.dbapi.Connection + """ + + def __init__(self, con) -> None: + self.con = con + + @contextmanager + def run_transaction(self): + with self.con.cursor() as cur: + try: + yield cur + except Exception: + self.con.rollback() + raise + self.con.commit() + + def execute(self, sql: str | Select | TextClause, params=None): + from adbc_driver_manager import Error + + if not isinstance(sql, str): + raise TypeError("Query must be a string unless using sqlalchemy.") + args = [] if params is None else [params] + cur = self.con.cursor() + try: + cur.execute(sql, *args) + return cur + except Error as exc: + try: + self.con.rollback() + except Error as inner_exc: # pragma: no cover + ex = DatabaseError( + f"Execution failed on sql: {sql}\n{exc}\nunable to rollback" + ) + raise ex from inner_exc + + ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}") + raise ex from exc + + def read_table( + self, + table_name: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + columns=None, + schema: str | None = None, + chunksize: int | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL database table into a DataFrame. + + Parameters + ---------- + table_name : str + Name of SQL table in database. + coerce_float : bool, default True + Raises NotImplementedError + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg}``, where the arg corresponds + to the keyword arguments of :func:`pandas.to_datetime`. + Especially useful with databases without native Datetime support, + such as SQLite. + columns : list, default: None + List of column names to select from SQL table. + schema : string, default None + Name of SQL schema in database to query (if database flavor + supports this). If specified, this overwrites the default + schema of the SQL database object. + chunksize : int, default None + Raises NotImplementedError + dtype_backend : {'numpy_nullable', 'pyarrow'} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + Returns + ------- + DataFrame + + See Also + -------- + pandas.read_sql_table + SQLDatabase.read_query + + """ + if coerce_float is not True: + raise NotImplementedError( + "'coerce_float' is not implemented for ADBC drivers" + ) + if chunksize: + raise NotImplementedError("'chunksize' is not implemented for ADBC drivers") + + if columns: + if index_col: + index_select = maybe_make_list(index_col) + else: + index_select = [] + to_select = index_select + columns + select_list = ", ".join(f'"{x}"' for x in to_select) + else: + select_list = "*" + if schema: + stmt = f"SELECT {select_list} FROM {schema}.{table_name}" + else: + stmt = f"SELECT {select_list} FROM {table_name}" + + with self.execute(stmt) as cur: + pa_table = cur.fetch_arrow_table() + df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) + + return _wrap_result_adbc( + df, + index_col=index_col, + parse_dates=parse_dates, + ) + + def read_query( + self, + sql: str, + index_col: str | list[str] | None = None, + coerce_float: bool = True, + parse_dates=None, + params=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + """ + Read SQL query into a DataFrame. + + Parameters + ---------- + sql : str + SQL query to be executed. + index_col : string, optional, default: None + Column name to use as index for the returned DataFrame object. + coerce_float : bool, default True + Raises NotImplementedError + params : list, tuple or dict, optional, default: None + Raises NotImplementedError + parse_dates : list or dict, default: None + - List of column names to parse as dates. + - Dict of ``{column_name: format string}`` where format string is + strftime compatible in case of parsing string times, or is one of + (D, s, ns, ms, us) in case of parsing integer timestamps. + - Dict of ``{column_name: arg dict}``, where the arg dict + corresponds to the keyword arguments of + :func:`pandas.to_datetime` Especially useful with databases + without native Datetime support, such as SQLite. + chunksize : int, default None + Raises NotImplementedError + dtype : Type name or dict of columns + Data type for data or columns. E.g. np.float64 or + {'a': np.float64, 'b': np.int32, 'c': 'Int64'} + + Returns + ------- + DataFrame + + See Also + -------- + read_sql_table : Read SQL database table into a DataFrame. + read_sql + + """ + if coerce_float is not True: + raise NotImplementedError( + "'coerce_float' is not implemented for ADBC drivers" + ) + if params: + raise NotImplementedError("'params' is not implemented for ADBC drivers") + if chunksize: + raise NotImplementedError("'chunksize' is not implemented for ADBC drivers") + + with self.execute(sql) as cur: + pa_table = cur.fetch_arrow_table() + df = arrow_table_to_pandas(pa_table, dtype_backend=dtype_backend) + + return _wrap_result_adbc( + df, + index_col=index_col, + parse_dates=parse_dates, + dtype=dtype, + ) + + read_sql = read_query + + def to_sql( + self, + frame, + name: str, + if_exists: Literal["fail", "replace", "append", "delete_rows"] = "fail", + index: bool = True, + index_label=None, + schema: str | None = None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, + ) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame : DataFrame + name : string + Name of SQL table. + if_exists : {'fail', 'replace', 'append'}, default 'fail' + - fail: If table exists, do nothing. + - replace: If table exists, drop it, recreate it, and insert data. + - append: If table exists, insert data. Create if does not exist. + - delete_rows: If a table exists, delete all records and insert data. + index : boolean, default True + Write DataFrame index as a column. + index_label : string or sequence, default None + Raises NotImplementedError + schema : string, default None + Name of SQL schema in database to write to (if database flavor + supports this). If specified, this overwrites the default + schema of the SQLDatabase object. + chunksize : int, default None + Raises NotImplementedError + dtype : single type or dict of column name to SQL type, default None + Raises NotImplementedError + method : {None', 'multi', callable}, default None + Raises NotImplementedError + engine : {'auto', 'sqlalchemy'}, default 'auto' + Raises NotImplementedError if not set to 'auto' + """ + pa = import_optional_dependency("pyarrow") + from adbc_driver_manager import Error + + if index_label: + raise NotImplementedError( + "'index_label' is not implemented for ADBC drivers" + ) + if chunksize: + raise NotImplementedError("'chunksize' is not implemented for ADBC drivers") + if dtype: + raise NotImplementedError("'dtype' is not implemented for ADBC drivers") + if method: + raise NotImplementedError("'method' is not implemented for ADBC drivers") + if engine != "auto": + raise NotImplementedError( + "engine != 'auto' not implemented for ADBC drivers" + ) + + if schema: + table_name = f"{schema}.{name}" + else: + table_name = name + + # pandas if_exists="append" will still create the + # table if it does not exist; ADBC is more explicit with append/create + # as applicable modes, so the semantics get blurred across + # the libraries + mode = "create" + if self.has_table(name, schema): + if if_exists == "fail": + raise ValueError(f"Table '{table_name}' already exists.") + elif if_exists == "replace": + sql_statement = f"DROP TABLE {table_name}" + self.execute(sql_statement).close() + elif if_exists == "append": + mode = "append" + elif if_exists == "delete_rows": + mode = "append" + self.delete_rows(name, schema) + + try: + tbl = pa.Table.from_pandas(frame, preserve_index=index) + except pa.ArrowNotImplementedError as exc: + raise ValueError("datatypes not supported") from exc + + with self.con.cursor() as cur: + try: + total_inserted = cur.adbc_ingest( + table_name=name, data=tbl, mode=mode, db_schema_name=schema + ) + except Error as exc: + raise DatabaseError( + f"Failed to insert records on table={name} with {mode=}" + ) from exc + + self.con.commit() + return total_inserted + + def has_table(self, name: str, schema: str | None = None) -> bool: + meta = self.con.adbc_get_objects( + db_schema_filter=schema, table_name_filter=name + ).read_all() + + for catalog_schema in meta["catalog_db_schemas"].to_pylist(): + if not catalog_schema: + continue + for schema_record in catalog_schema: + if not schema_record: + continue + + for table_record in schema_record["db_schema_tables"]: + if table_record["table_name"] == name: + return True + + return False + + def delete_rows(self, name: str, schema: str | None = None) -> None: + table_name = f"{schema}.{name}" if schema else name + if self.has_table(name, schema): + self.execute(f"DELETE FROM {table_name}").close() + + def _create_sql_schema( + self, + frame: DataFrame, + table_name: str, + keys: list[str] | None = None, + dtype: DtypeArg | None = None, + schema: str | None = None, + ) -> str: + raise NotImplementedError("not implemented for adbc") + + +# sqlite-specific sql strings and handler class +# dictionary used for readability purposes +_SQL_TYPES = { + "string": "TEXT", + "floating": "REAL", + "integer": "INTEGER", + "datetime": "TIMESTAMP", + "date": "DATE", + "time": "TIME", + "boolean": "INTEGER", +} + + +def _get_unicode_name(name: object) -> str: + try: + uname = str(name).encode("utf-8", "strict").decode("utf-8") + except UnicodeError as err: + raise ValueError(f"Cannot convert identifier to UTF-8: '{name}'") from err + return uname + + +def _get_valid_sqlite_name(name: object) -> str: + # See https://stackoverflow.com/questions/6514274/how-do-you-escape-strings\ + # -for-sqlite-table-column-names-in-python + # Ensure the string can be encoded as UTF-8. + # Ensure the string does not include any NUL characters. + # Replace all " with "". + # Wrap the entire thing in double quotes. + + uname = _get_unicode_name(name) + if not len(uname): + raise ValueError("Empty table or column name specified") + + nul_index = uname.find("\x00") + if nul_index >= 0: + raise ValueError("SQLite identifier cannot contain NULs") + return '"' + uname.replace('"', '""') + '"' + + +class SQLiteTable(SQLTable): + """ + Patch the SQLTable for fallback support. + Instead of a table variable just use the Create Table statement. + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self._register_date_adapters() + + def _register_date_adapters(self) -> None: + # GH 8341 + # register an adapter callable for datetime.time object + import sqlite3 + + # this will transform time(12,34,56,789) into '12:34:56.000789' + # (this is what sqlalchemy does) + def _adapt_time(t) -> str: + # This is faster than strftime + return f"{t.hour:02d}:{t.minute:02d}:{t.second:02d}.{t.microsecond:06d}" + + # Also register adapters for date/datetime and co + # xref https://docs.python.org/3.12/library/sqlite3.html#adapter-and-converter-recipes + # Python 3.12+ doesn't auto-register adapters for us anymore + + adapt_date_iso = lambda val: val.isoformat() + adapt_datetime_iso = lambda val: val.isoformat(" ") + + sqlite3.register_adapter(time, _adapt_time) + + sqlite3.register_adapter(date, adapt_date_iso) + sqlite3.register_adapter(datetime, adapt_datetime_iso) + + convert_date = lambda val: date.fromisoformat(val.decode()) + convert_timestamp = lambda val: datetime.fromisoformat(val.decode()) + + sqlite3.register_converter("date", convert_date) + sqlite3.register_converter("timestamp", convert_timestamp) + + def sql_schema(self) -> str: + return str(";\n".join(self.table)) + + def _execute_create(self) -> None: + with self.pd_sql.run_transaction() as cur: + for stmt in self.table: + cur.execute(stmt) + + def insert_statement(self, *, num_rows: int) -> str: + names = list(map(str, self.frame.columns)) + wld = "?" # wildcard char + escape = _get_valid_sqlite_name + + if self.index is not None: + for idx in self.index[::-1]: + names.insert(0, idx) + + bracketed_names = [escape(column) for column in names] + col_names = ",".join(bracketed_names) + + row_wildcards = ",".join([wld] * len(names)) + wildcards = ",".join([f"({row_wildcards})" for _ in range(num_rows)]) + insert_statement = ( + f"INSERT INTO {escape(self.name)} ({col_names}) VALUES {wildcards}" + ) + return insert_statement + + def _execute_insert(self, conn, keys, data_iter) -> int: + from sqlite3 import Error + + data_list = list(data_iter) + try: + conn.executemany(self.insert_statement(num_rows=1), data_list) + except Error as exc: + raise DatabaseError("Execution failed") from exc + return conn.rowcount + + def _execute_insert_multi(self, conn, keys, data_iter) -> int: + data_list = list(data_iter) + flattened_data = [x for row in data_list for x in row] + conn.execute(self.insert_statement(num_rows=len(data_list)), flattened_data) + return conn.rowcount + + def _create_table_setup(self): + """ + Return a list of SQL statements that creates a table reflecting the + structure of a DataFrame. The first entry will be a CREATE TABLE + statement while the rest will be CREATE INDEX statements. + """ + column_names_and_types = self._get_column_names_and_types(self._sql_type_name) + escape = _get_valid_sqlite_name + + create_tbl_stmts = [ + escape(cname) + " " + ctype for cname, ctype, _ in column_names_and_types + ] + + if self.keys is not None and len(self.keys): + if not is_list_like(self.keys): + keys = [self.keys] + else: + keys = self.keys + cnames_br = ", ".join([escape(c) for c in keys]) + create_tbl_stmts.append( + f"CONSTRAINT {self.name}_pk PRIMARY KEY ({cnames_br})" + ) + if self.schema: + schema_name = self.schema + "." + else: + schema_name = "" + create_stmts = [ + "CREATE TABLE " + + schema_name + + escape(self.name) + + " (\n" + + ",\n ".join(create_tbl_stmts) + + "\n)" + ] + + ix_cols = [cname for cname, _, is_index in column_names_and_types if is_index] + if ix_cols: + cnames = "_".join(ix_cols) + cnames_br = ",".join([escape(c) for c in ix_cols]) + create_stmts.append( + "CREATE INDEX " + + escape("ix_" + self.name + "_" + cnames) + + "ON " + + escape(self.name) + + " (" + + cnames_br + + ")" + ) + + return create_stmts + + def _sql_type_name(self, col): + dtype: DtypeArg = self.dtype or {} + if is_dict_like(dtype): + dtype = cast(dict, dtype) + if col.name in dtype: + return dtype[col.name] + + # Infer type of column, while ignoring missing values. + # Needed for inserting typed data containing NULLs, GH 8778. + col_type = lib.infer_dtype(col, skipna=True) + + if col_type == "timedelta64": + warnings.warn( + "the 'timedelta' type is not supported, and will be " + "written as integer values (ns frequency) to the database.", + UserWarning, + stacklevel=find_stack_level(), + ) + col_type = "integer" + + elif col_type == "datetime64": + col_type = "datetime" + + elif col_type == "empty": + col_type = "string" + + elif col_type == "complex": + raise ValueError("Complex datatypes not supported") + + if col_type not in _SQL_TYPES: + col_type = "string" + + return _SQL_TYPES[col_type] + + +class SQLiteDatabase(PandasSQL): + """ + Version of SQLDatabase to support SQLite connections (fallback without + SQLAlchemy). This should only be used internally. + + Parameters + ---------- + con : sqlite connection object + + """ + + def __init__(self, con) -> None: + self.con = con + + @contextmanager + def run_transaction(self): + cur = self.con.cursor() + try: + yield cur + self.con.commit() + except Exception: + self.con.rollback() + raise + finally: + cur.close() + + def execute(self, sql: str | Select | TextClause, params=None): + from sqlite3 import Error + + if not isinstance(sql, str): + raise TypeError("Query must be a string unless using sqlalchemy.") + args = [] if params is None else [params] + cur = self.con.cursor() + try: + cur.execute(sql, *args) + return cur + except Error as exc: + try: + self.con.rollback() + except Error as inner_exc: # pragma: no cover + ex = DatabaseError( + f"Execution failed on sql: {sql}\n{exc}\nunable to rollback" + ) + raise ex from inner_exc + + ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}") + raise ex from exc + + @staticmethod + def _query_iterator( + cursor, + chunksize: int, + columns, + index_col=None, + coerce_float: bool = True, + parse_dates=None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> Generator[DataFrame]: + """Return generator through chunked result set""" + has_read_data = False + while True: + data = cursor.fetchmany(chunksize) + if type(data) == tuple: + data = list(data) + if not data: + cursor.close() + if not has_read_data: + result = DataFrame.from_records( + [], columns=columns, coerce_float=coerce_float + ) + if dtype: + result = result.astype(dtype) + yield result + break + + has_read_data = True + yield _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + + def read_query( + self, + sql, + index_col=None, + coerce_float: bool = True, + parse_dates=None, + params=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + dtype_backend: DtypeBackend | Literal["numpy"] = "numpy", + ) -> DataFrame | Iterator[DataFrame]: + cursor = self.execute(sql, params) + columns = [col_desc[0] for col_desc in cursor.description] + + if chunksize is not None: + return self._query_iterator( + cursor, + chunksize, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + else: + data = self._fetchall_as_list(cursor) + cursor.close() + + frame = _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + dtype_backend=dtype_backend, + ) + return frame + + def _fetchall_as_list(self, cur): + result = cur.fetchall() + if not isinstance(result, list): + result = list(result) + return result + + def to_sql( + self, + frame, + name: str, + if_exists: str = "fail", + index: bool = True, + index_label=None, + schema=None, + chunksize: int | None = None, + dtype: DtypeArg | None = None, + method: Literal["multi"] | Callable | None = None, + engine: str = "auto", + **engine_kwargs, + ) -> int | None: + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame: DataFrame + name: string + Name of SQL table. + if_exists: {'fail', 'replace', 'append', 'delete_rows'}, default 'fail' + fail: If table exists, do nothing. + replace: If table exists, drop it, recreate it, and insert data. + append: If table exists, insert data. Create if it does not exist. + delete_rows: If a table exists, delete all records and insert data. + index : bool, default True + Write DataFrame index as a column + index_label : string or sequence, default None + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + schema : string, default None + Ignored parameter included for compatibility with SQLAlchemy + version of ``to_sql``. + chunksize : int, default None + If not None, then rows will be written in batches of this + size at a time. If None, all rows will be written at once. + dtype : single type or dict of column name to SQL type, default None + Optional specifying the datatype for columns. The SQL type should + be a string. If all columns are of the same type, one single value + can be used. + method : {None, 'multi', callable}, default None + Controls the SQL insertion clause used: + + * None : Uses standard SQL ``INSERT`` clause (one per row). + * 'multi': Pass multiple values in a single ``INSERT`` clause. + * callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method `. + """ + if dtype: + if not is_dict_like(dtype): + # error: Value expression in dictionary comprehension has incompatible + # type "Union[ExtensionDtype, str, dtype[Any], Type[object], + # Dict[Hashable, Union[ExtensionDtype, Union[str, dtype[Any]], + # Type[str], Type[float], Type[int], Type[complex], Type[bool], + # Type[object]]]]"; expected type "Union[ExtensionDtype, str, + # dtype[Any], Type[object]]" + dtype = dict.fromkeys(frame, dtype) # type: ignore[arg-type] + else: + dtype = cast(dict, dtype) + + for col, my_type in dtype.items(): + if not isinstance(my_type, str): + raise ValueError(f"{col} ({my_type}) not a string") + + table = SQLiteTable( + name, + self, + frame=frame, + index=index, + if_exists=if_exists, + index_label=index_label, + dtype=dtype, + ) + table.create() + return table.insert(chunksize, method) + + def has_table(self, name: str, schema: str | None = None) -> bool: + wld = "?" + query = f""" + SELECT + name + FROM + sqlite_master + WHERE + type IN ('table', 'view') + AND name={wld}; + """ + + return len(self.execute(query, [name]).fetchall()) > 0 + + def get_table(self, table_name: str, schema: str | None = None) -> None: + return None # not supported in fallback mode + + def drop_table(self, name: str, schema: str | None = None) -> None: + drop_sql = f"DROP TABLE {_get_valid_sqlite_name(name)}" + self.execute(drop_sql).close() + + def delete_rows(self, name: str, schema: str | None = None) -> None: + delete_sql = f"DELETE FROM {_get_valid_sqlite_name(name)}" + if self.has_table(name, schema): + self.execute(delete_sql).close() + + def _create_sql_schema( + self, + frame, + table_name: str, + keys=None, + dtype: DtypeArg | None = None, + schema: str | None = None, + ) -> str: + table = SQLiteTable( + table_name, + self, + frame=frame, + index=False, + keys=keys, + dtype=dtype, + schema=schema, + ) + return str(table.sql_schema()) + + +def get_schema( + frame, + name: str, + keys=None, + con=None, + dtype: DtypeArg | None = None, + schema: str | None = None, +) -> str: + """ + Get the SQL db table schema for the given frame. + + Parameters + ---------- + frame : DataFrame + name : str + name of SQL table + keys : string or sequence, default: None + columns to use a primary key + con: ADBC Connection, SQLAlchemy connectable, sqlite3 connection, default: None + ADBC provides high performance I/O with native type support, where available. + Using SQLAlchemy makes it possible to use any DB supported by that + library + If a DBAPI2 object, only sqlite3 is supported. + dtype : dict of column name to SQL type, default None + Optional specifying the datatype for columns. The SQL type should + be a SQLAlchemy type, or a string for sqlite3 fallback connection. + schema: str, default: None + Optional specifying the schema to be used in creating the table. + """ + with pandasSQL_builder(con=con) as pandas_sql: + return pandas_sql._create_sql_schema( + frame, name, keys=keys, dtype=dtype, schema=schema + ) diff --git a/pandas/io/stata.py b/pandas/io/stata.py new file mode 100644 index 0000000000000000000000000000000000000000..2b2dc9fd831b75f69c85dcac06004c6cdee1af8a --- /dev/null +++ b/pandas/io/stata.py @@ -0,0 +1,3925 @@ +""" +Module contains tools for processing Stata files into DataFrames + +The StataReader below was originally written by Joe Presbrey as part of PyDTA. +It has been extended and improved by Skipper Seabold from the Statsmodels +project who also developed the StataWriter and was finally added to pandas in +a once again improved version. + +You can find more information on http://presbrey.mit.edu/PyDTA and +https://www.statsmodels.org/devel/ +""" + +from __future__ import annotations + +from collections import abc +from datetime import ( + datetime, + timedelta, +) +from io import BytesIO +import os +import struct +import sys +from typing import ( + IO, + TYPE_CHECKING, + AnyStr, + Final, + Self, + cast, +) +import warnings + +import numpy as np + +from pandas._libs import lib +from pandas._libs.lib import infer_dtype +from pandas._libs.writers import max_len_string_array +from pandas.errors import ( + CategoricalConversionWarning, + InvalidColumnName, + Pandas4Warning, + PossiblePrecisionLoss, + ValueLabelTypeMismatch, +) +from pandas.util._decorators import ( + set_module, +) +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.base import ExtensionDtype +from pandas.core.dtypes.common import ( + ensure_object, + is_numeric_dtype, + is_string_dtype, +) +from pandas.core.dtypes.dtypes import CategoricalDtype + +from pandas import ( + Categorical, + DatetimeIndex, + NaT, + Timestamp, + isna, + to_datetime, +) +from pandas.core.frame import DataFrame +from pandas.core.indexes.base import Index +from pandas.core.indexes.range import RangeIndex +from pandas.core.series import Series +from pandas.core.shared_docs import _shared_docs + +from pandas.io.common import get_handle + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Hashable, + Sequence, + ) + from types import TracebackType + from typing import Literal + + from pandas._typing import ( + CompressionOptions, + FilePath, + ReadBuffer, + StorageOptions, + WriteBuffer, + ) + +_version_error = ( + "Version of given Stata file is {version}. pandas supports importing " + "versions 102, 103, 104, 105, 108, 110 (Stata 7), 111 (Stata 7SE), " + "113 (Stata 8/9), 114 (Stata 10/11), 115 (Stata 12), 117 (Stata 13), " + "118 (Stata 14/15/16), and 119 (Stata 15/16, over 32,767 variables)." +) + +_statafile_processing_params1 = """\ +convert_dates : bool, default True + Convert date variables to DataFrame time values. +convert_categoricals : bool, default True + Read value labels and convert columns to Categorical/Factor variables.""" + +_statafile_processing_params2 = """\ +index_col : str, optional + Column to set as index. +convert_missing : bool, default False + Flag indicating whether to convert missing values to their Stata + representations. If False, missing values are replaced with nan. + If True, columns containing missing values are returned with + object data types and missing values are represented by + StataMissingValue objects. +preserve_dtypes : bool, default True + Preserve Stata datatypes. If False, numeric data are upcast to pandas + default types for foreign data (float64 or int64). +columns : list or None + Columns to retain. Columns will be returned in the given order. None + returns all columns. +order_categoricals : bool, default True + Flag indicating whether converted categorical data are ordered.""" + +_chunksize_params = """\ +chunksize : int, default None + Return StataReader object for iterations, returns chunks with + given number of lines.""" + +_reader_notes = """\ +Notes +----- +Categorical variables read through an iterator may not have the same +categories and dtype. This occurs when a variable stored in a DTA +file is associated to an incomplete set of value labels that only +label a strict subset of the values.""" + +_stata_reader_doc = f"""\ +Class for reading Stata dta files. + +Parameters +---------- +path_or_buf : path (string), buffer or path object + string, pathlib.Path or object + implementing a binary read() functions. +{_statafile_processing_params1} +{_statafile_processing_params2} +{_chunksize_params} +{_shared_docs["decompression_options"]} +{_shared_docs["storage_options"]} + +{_reader_notes} +""" + + +_date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"] + + +stata_epoch: Final = datetime(1960, 1, 1) +unix_epoch: Final = datetime(1970, 1, 1) + + +def _stata_elapsed_date_to_datetime_vec(dates: Series, fmt: str) -> Series: + """ + Convert from SIF to datetime. https://www.stata.com/help.cgi?datetime + + Parameters + ---------- + dates : Series + The Stata Internal Format date to convert to datetime according to fmt + fmt : str + The format to convert to. Can be, tc, td, tw, tm, tq, th, ty + Returns + + Returns + ------- + converted : Series + The converted dates + + Examples + -------- + >>> dates = pd.Series([52]) + >>> _stata_elapsed_date_to_datetime_vec(dates, "%tw") + 0 1961-01-01 + dtype: datetime64[s] + + Notes + ----- + datetime/c - tc + milliseconds since 01jan1960 00:00:00.000, assuming 86,400 s/day + datetime/C - tC - NOT IMPLEMENTED + milliseconds since 01jan1960 00:00:00.000, adjusted for leap seconds + date - td + days since 01jan1960 (01jan1960 = 0) + weekly date - tw + weeks since 1960w1 + This assumes 52 weeks in a year, then adds 7 * remainder of the weeks. + The datetime value is the start of the week in terms of days in the + year, not ISO calendar weeks. + monthly date - tm + months since 1960m1 + quarterly date - tq + quarters since 1960q1 + half-yearly date - th + half-years since 1960h1 yearly + date - ty + years since 0000 + """ + + if fmt.startswith(("%tc", "tc")): + # Delta ms relative to base + td = np.timedelta64(stata_epoch - unix_epoch, "ms") + res = np.array(dates._values, dtype="M8[ms]") + td + return Series(res, index=dates.index) + + elif fmt.startswith(("%td", "td", "%d", "d")): + # Delta days relative to base + td = np.timedelta64(stata_epoch - unix_epoch, "D") + res = np.array(dates._values, dtype="M8[D]") + td + return Series(res, index=dates.index) + + elif fmt.startswith(("%tm", "tm")): + # Delta months relative to base + ordinals = dates + (stata_epoch.year - unix_epoch.year) * 12 + res = np.array(ordinals, dtype="M8[M]").astype("M8[s]") + return Series(res, index=dates.index) + + elif fmt.startswith(("%tq", "tq")): + # Delta quarters relative to base + ordinals = dates + (stata_epoch.year - unix_epoch.year) * 4 + res = np.array(ordinals, dtype="M8[3M]").astype("M8[s]") + return Series(res, index=dates.index) + + elif fmt.startswith(("%th", "th")): + # Delta half-years relative to base + ordinals = dates + (stata_epoch.year - unix_epoch.year) * 2 + res = np.array(ordinals, dtype="M8[6M]").astype("M8[s]") + return Series(res, index=dates.index) + + elif fmt.startswith(("%ty", "ty")): + # Years -- not delta + ordinals = dates - 1970 + res = np.array(ordinals, dtype="M8[Y]").astype("M8[s]") + return Series(res, index=dates.index) + + bad_locs = np.isnan(dates) + has_bad_values = False + if bad_locs.any(): + has_bad_values = True + dates._values[bad_locs] = 1.0 # Replace with NaT + dates = dates.astype(np.int64) + + if fmt.startswith(("%tC", "tC")): + warnings.warn( + "Encountered %tC format. Leaving in Stata Internal Format.", + stacklevel=find_stack_level(), + ) + conv_dates = Series(dates, dtype=object) + if has_bad_values: + conv_dates[bad_locs] = NaT + return conv_dates + # does not count leap days - 7 days is a week. + # 52nd week may have more than 7 days + elif fmt.startswith(("%tw", "tw")): + year = stata_epoch.year + dates // 52 + days = (dates % 52) * 7 + per_y = (year - 1970).array.view("Period[Y]") + per_d = per_y.asfreq("D", how="S") + per_d_shifted = per_d + days._values + per_s = per_d_shifted.asfreq("s", how="S") + conv_dates_arr = per_s.view("M8[s]") + conv_dates = Series(conv_dates_arr, index=dates.index) + + else: + raise ValueError(f"Date fmt {fmt} not understood") + + if has_bad_values: # Restore NaT for bad values + conv_dates[bad_locs] = NaT + + return conv_dates + + +def _datetime_to_stata_elapsed_vec(dates: Series, fmt: str) -> Series: + """ + Convert from datetime to SIF. https://www.stata.com/help.cgi?datetime + + Parameters + ---------- + dates : Series + Series or array containing datetime or datetime64[ns] to + convert to the Stata Internal Format given by fmt + fmt : str + The format to convert to. Can be, tc, td, tw, tm, tq, th, ty + """ + index = dates.index + NS_PER_DAY = 24 * 3600 * 1000 * 1000 * 1000 + US_PER_DAY = NS_PER_DAY / 1000 + MS_PER_DAY = NS_PER_DAY / 1_000_000 + + def parse_dates_safe( + dates: Series, delta: bool = False, year: bool = False, days: bool = False + ) -> DataFrame: + d = {} + if lib.is_np_dtype(dates.dtype, "M"): + if delta: + time_delta = dates.dt.as_unit("ms") - Timestamp(stata_epoch).as_unit( + "ms" + ) + d["delta"] = time_delta._values.view(np.int64) + if days or year: + date_index = DatetimeIndex(dates) + d["year"] = date_index._data.year + d["month"] = date_index._data.month + if days: + year_start = np.asarray(dates).astype("M8[Y]").astype(dates.dtype) + diff = dates - year_start + d["days"] = np.asarray(diff).astype("m8[D]").view("int64") + + elif infer_dtype(dates, skipna=False) == "datetime": + warnings.warn( + # GH#56536 + "Converting object-dtype columns of datetimes to datetime64 when " + "writing to stata is deprecated. Call " + "`df=df.infer_objects(copy=False)` before writing to stata instead.", + Pandas4Warning, + stacklevel=find_stack_level(), + ) + if delta: + delta = dates._values - stata_epoch + + def f(x: timedelta) -> float: + return US_PER_DAY * x.days + 1_000_000 * x.seconds + x.microseconds + + v = np.vectorize(f) + d["delta"] = v(delta) // 1_000 # convert back to ms + if year: + year_month = dates.apply(lambda x: 100 * x.year + x.month) + d["year"] = year_month._values // 100 + d["month"] = year_month._values - d["year"] * 100 + if days: + + def g(x: datetime) -> int: + return (x - datetime(x.year, 1, 1)).days + + v = np.vectorize(g) + d["days"] = v(dates) + else: + raise ValueError( + "Columns containing dates must contain either " + "datetime64, datetime or null values." + ) + + return DataFrame(d, index=index) + + bad_loc = isna(dates) + index = dates.index + if bad_loc.any(): + if lib.is_np_dtype(dates.dtype, "M"): + dates._values[bad_loc] = to_datetime(stata_epoch) + else: + dates._values[bad_loc] = stata_epoch + + if fmt in ["%tc", "tc"]: + d = parse_dates_safe(dates, delta=True) + conv_dates = d.delta + elif fmt in ["%tC", "tC"]: + warnings.warn( + "Stata Internal Format tC not supported.", + stacklevel=find_stack_level(), + ) + conv_dates = dates + elif fmt in ["%td", "td"]: + d = parse_dates_safe(dates, delta=True) + conv_dates = d.delta // MS_PER_DAY + elif fmt in ["%tw", "tw"]: + d = parse_dates_safe(dates, year=True, days=True) + conv_dates = 52 * (d.year - stata_epoch.year) + d.days // 7 + elif fmt in ["%tm", "tm"]: + d = parse_dates_safe(dates, year=True) + conv_dates = 12 * (d.year - stata_epoch.year) + d.month - 1 + elif fmt in ["%tq", "tq"]: + d = parse_dates_safe(dates, year=True) + conv_dates = 4 * (d.year - stata_epoch.year) + (d.month - 1) // 3 + elif fmt in ["%th", "th"]: + d = parse_dates_safe(dates, year=True) + conv_dates = 2 * (d.year - stata_epoch.year) + (d.month > 6).astype(int) + elif fmt in ["%ty", "ty"]: + d = parse_dates_safe(dates, year=True) + conv_dates = d.year + else: + raise ValueError(f"Format {fmt} is not a known Stata date format") + + conv_dates = Series(conv_dates, dtype=np.float64, copy=False) + missing_value = struct.unpack(" DataFrame: + """ + Checks the dtypes of the columns of a pandas DataFrame for + compatibility with the data types and ranges supported by Stata, and + converts if necessary. + + Parameters + ---------- + data : DataFrame + The DataFrame to check and convert + + Notes + ----- + Numeric columns in Stata must be one of int8, int16, int32, float32 or + float64, with some additional value restrictions. int8 and int16 columns + are checked for violations of the value restrictions and upcast if needed. + int64 data is not usable in Stata, and so it is downcast to int32 whenever + the value are in the int32 range, and sidecast to float64 when larger than + this range. If the int64 values are outside of the range of those + perfectly representable as float64 values, a warning is raised. + + bool columns are cast to int8. uint columns are converted to int of the + same size if there is no loss in precision, otherwise are upcast to a + larger type. uint64 is currently not supported since it is concerted to + object in a DataFrame. + """ + ws = "" + # original, if small, if large + conversion_data: tuple[ + tuple[type, type, type], + tuple[type, type, type], + tuple[type, type, type], + tuple[type, type, type], + tuple[type, type, type], + ] = ( + (np.bool_, np.int8, np.int8), + (np.uint8, np.int8, np.int16), + (np.uint16, np.int16, np.int32), + (np.uint32, np.int32, np.int64), + (np.uint64, np.int64, np.float64), + ) + + float32_max = struct.unpack("= 2**53: + ws = precision_loss_doc.format("uint64", "float64") + + data[col] = data[col].astype(dtype) + + # Check values and upcast if necessary + + if dtype == np.int8 and not empty_df: + if data[col].max() > 100 or data[col].min() < -127: + data[col] = data[col].astype(np.int16) + elif dtype == np.int16 and not empty_df: + if data[col].max() > 32740 or data[col].min() < -32767: + data[col] = data[col].astype(np.int32) + elif dtype == np.int64: + if empty_df or ( + data[col].max() <= 2147483620 and data[col].min() >= -2147483647 + ): + data[col] = data[col].astype(np.int32) + else: + data[col] = data[col].astype(np.float64) + if data[col].max() >= 2**53 or data[col].min() <= -(2**53): + ws = precision_loss_doc.format("int64", "float64") + elif dtype in (np.float32, np.float64): + if np.isinf(data[col]).any(): + raise ValueError( + f"Column {col} contains infinity or -infinity" + "which is outside the range supported by Stata." + ) + value = data[col].max() + if dtype == np.float32 and value > float32_max: + data[col] = data[col].astype(np.float64) + elif dtype == np.float64: + if value > float64_max: + raise ValueError( + f"Column {col} has a maximum value ({value}) outside the range " + f"supported by Stata ({float64_max})" + ) + if is_nullable_int: + if orig_missing.any(): + # Replace missing by Stata sentinel value + sentinel = StataMissingValue.BASE_MISSING_VALUES[data[col].dtype.name] + data.loc[orig_missing, col] = sentinel + if ws: + warnings.warn( + ws, + PossiblePrecisionLoss, + stacklevel=find_stack_level(), + ) + + return data + + +class StataValueLabel: + """ + Parse a categorical column and prepare formatted output + + Parameters + ---------- + catarray : Series + Categorical Series to encode + encoding : {"latin-1", "utf-8"} + Encoding to use for value labels. + """ + + def __init__( + self, catarray: Series, encoding: Literal["latin-1", "utf-8"] = "latin-1" + ) -> None: + if encoding not in ("latin-1", "utf-8"): + raise ValueError("Only latin-1 and utf-8 are supported.") + self.labname = catarray.name + self._encoding = encoding + categories = catarray.cat.categories + self.value_labels = enumerate(categories) + + self._prepare_value_labels() + + def _prepare_value_labels(self) -> None: + """Encode value labels.""" + + self.text_len = 0 + self.txt: list[bytes] = [] + self.n = 0 + # Offsets (length of categories), converted to int32 + self.off = np.array([], dtype=np.int32) + # Values, converted to int32 + self.val = np.array([], dtype=np.int32) + self.len = 0 + + # Compute lengths and setup lists of offsets and labels + offsets: list[int] = [] + values: list[float] = [] + for vl in self.value_labels: + category: str | bytes = vl[1] + if not isinstance(category, str): + category = str(category) + warnings.warn( + value_label_mismatch_doc.format(self.labname), + ValueLabelTypeMismatch, + stacklevel=find_stack_level(), + ) + category = category.encode(self._encoding) + offsets.append(self.text_len) + self.text_len += len(category) + 1 # +1 for the padding + values.append(vl[0]) + self.txt.append(category) + self.n += 1 + + # Ensure int32 + self.off = np.array(offsets, dtype=np.int32) + self.val = np.array(values, dtype=np.int32) + + # Total length + self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len + + def generate_value_label(self, byteorder: str) -> bytes: + """ + Generate the binary representation of the value labels. + + Parameters + ---------- + byteorder : str + Byte order of the output + + Returns + ------- + value_label : bytes + Bytes containing the formatted value label + """ + encoding = self._encoding + bio = BytesIO() + null_byte = b"\x00" + + # len + bio.write(struct.pack(byteorder + "i", self.len)) + + # labname + labname = str(self.labname)[:32].encode(encoding) + lab_len = 32 if encoding not in ("utf-8", "utf8") else 128 + labname = _pad_bytes(labname, lab_len + 1) + bio.write(labname) + + # padding - 3 bytes + for i in range(3): + bio.write(struct.pack("c", null_byte)) + + # value_label_table + # n - int32 + bio.write(struct.pack(byteorder + "i", self.n)) + + # textlen - int32 + bio.write(struct.pack(byteorder + "i", self.text_len)) + + # off - int32 array (n elements) + for offset in self.off: + bio.write(struct.pack(byteorder + "i", offset)) + + # val - int32 array (n elements) + for value in self.val: + bio.write(struct.pack(byteorder + "i", value)) + + # txt - Text labels, null terminated + for text in self.txt: + bio.write(text + null_byte) + + return bio.getvalue() + + +class StataNonCatValueLabel(StataValueLabel): + """ + Prepare formatted version of value labels + + Parameters + ---------- + labname : str + Value label name + value_labels: Dictionary + Mapping of values to labels + encoding : {"latin-1", "utf-8"} + Encoding to use for value labels. + """ + + def __init__( + self, + labname: str, + value_labels: dict[float, str], + encoding: Literal["latin-1", "utf-8"] = "latin-1", + ) -> None: + if encoding not in ("latin-1", "utf-8"): + raise ValueError("Only latin-1 and utf-8 are supported.") + + self.labname = labname + self._encoding = encoding + self.value_labels = sorted( # type: ignore[assignment] + value_labels.items(), key=lambda x: x[0] + ) + self._prepare_value_labels() + + +class StataMissingValue: + """ + An observation's missing value. + + Parameters + ---------- + value : {int, float} + The Stata missing value code + + Notes + ----- + More information: + + Integer missing values make the code '.', '.a', ..., '.z' to the ranges + 101 ... 127 (for int8), 32741 ... 32767 (for int16) and 2147483621 ... + 2147483647 (for int32). Missing values for floating point data types are + more complex but the pattern is simple to discern from the following table. + + np.float32 missing values (float in Stata) + 0000007f . + 0008007f .a + 0010007f .b + ... + 00c0007f .x + 00c8007f .y + 00d0007f .z + + np.float64 missing values (double in Stata) + 000000000000e07f . + 000000000001e07f .a + 000000000002e07f .b + ... + 000000000018e07f .x + 000000000019e07f .y + 00000000001ae07f .z + """ + + # Construct a dictionary of missing values + MISSING_VALUES: dict[float, str] = {} + bases: Final = (101, 32741, 2147483621) + for b in bases: + # Conversion to long to avoid hash issues on 32 bit platforms #8968 + MISSING_VALUES[b] = "." + for i in range(1, 27): + MISSING_VALUES[i + b] = "." + chr(96 + i) + + float32_base: bytes = b"\x00\x00\x00\x7f" + increment_32: int = struct.unpack(" 0: + MISSING_VALUES[key] += chr(96 + i) + int_value = struct.unpack(" 0: + MISSING_VALUES[key] += chr(96 + i) + int_value = struct.unpack("q", struct.pack(" None: + self._value = value + # Conversion to int to avoid hash issues on 32 bit platforms #8968 + value = int(value) if value < 2147483648 else float(value) + self._str = self.MISSING_VALUES[value] + + @property + def string(self) -> str: + """ + The Stata representation of the missing value: '.', '.a'..'.z' + + Returns + ------- + str + The representation of the missing value. + """ + return self._str + + @property + def value(self) -> float: + """ + The binary representation of the missing value. + + Returns + ------- + {int, float} + The binary representation of the missing value. + """ + return self._value + + def __str__(self) -> str: + return self.string + + def __repr__(self) -> str: + return f"{type(self)}({self})" + + def __eq__(self, other: object) -> bool: + return ( + isinstance(other, type(self)) + and self.string == other.string + and self.value == other.value + ) + + @classmethod + def get_base_missing_value(cls, dtype: np.dtype) -> float: + if dtype.type is np.int8: + value = cls.BASE_MISSING_VALUES["int8"] + elif dtype.type is np.int16: + value = cls.BASE_MISSING_VALUES["int16"] + elif dtype.type is np.int32: + value = cls.BASE_MISSING_VALUES["int32"] + elif dtype.type is np.float32: + value = cls.BASE_MISSING_VALUES["float32"] + elif dtype.type is np.float64: + value = cls.BASE_MISSING_VALUES["float64"] + else: + raise ValueError("Unsupported dtype") + return value + + +class StataParser: + def __init__(self) -> None: + # type code. + # -------------------- + # str1 1 = 0x01 + # str2 2 = 0x02 + # ... + # str244 244 = 0xf4 + # byte 251 = 0xfb (sic) + # int 252 = 0xfc + # long 253 = 0xfd + # float 254 = 0xfe + # double 255 = 0xff + # -------------------- + # NOTE: the byte type seems to be reserved for categorical variables + # with a label, but the underlying variable is -127 to 100 + # we're going to drop the label and cast to int + self.DTYPE_MAP = dict( + [(i, np.dtype(f"S{i}")) for i in range(1, 245)] + + [ + (251, np.dtype(np.int8)), + (252, np.dtype(np.int16)), + (253, np.dtype(np.int32)), + (254, np.dtype(np.float32)), + (255, np.dtype(np.float64)), + ] + ) + self.DTYPE_MAP_XML: dict[int, np.dtype] = { + 32768: np.dtype(np.uint8), # Keys to GSO + 65526: np.dtype(np.float64), + 65527: np.dtype(np.float32), + 65528: np.dtype(np.int32), + 65529: np.dtype(np.int16), + 65530: np.dtype(np.int8), + } + self.TYPE_MAP = list(tuple(range(251)) + tuple("bhlfd")) + self.TYPE_MAP_XML = { + # Not really a Q, unclear how to handle byteswap + 32768: "Q", + 65526: "d", + 65527: "f", + 65528: "l", + 65529: "h", + 65530: "b", + } + # NOTE: technically, some of these are wrong. there are more numbers + # that can be represented. it's the 27 ABOVE and BELOW the max listed + # numeric data type in [U] 12.2.2 of the 11.2 manual + float32_min = b"\xff\xff\xff\xfe" + float32_max = b"\xff\xff\xff\x7e" + float64_min = b"\xff\xff\xff\xff\xff\xff\xef\xff" + float64_max = b"\xff\xff\xff\xff\xff\xff\xdf\x7f" + self.VALID_RANGE = { + "b": (-127, 100), + "h": (-32767, 32740), + "l": (-2147483647, 2147483620), + "f": ( + np.float32(struct.unpack(" None: + super().__init__() + + # Arguments to the reader (can be temporarily overridden in + # calls to read). + self._convert_dates = convert_dates + self._convert_categoricals = convert_categoricals + self._index_col = index_col + self._convert_missing = convert_missing + self._preserve_dtypes = preserve_dtypes + self._columns = columns + self._order_categoricals = order_categoricals + self._original_path_or_buf = path_or_buf + self._compression = compression + self._storage_options = storage_options + self._encoding = "" + self._chunksize = chunksize + self._using_iterator = False + self._entered = False + if self._chunksize is None: + self._chunksize = 1 + elif not isinstance(chunksize, int) or chunksize <= 0: + raise ValueError("chunksize must be a positive integer when set.") + + # State variables for the file + self._close_file: Callable[[], None] | None = None + self._column_selector_set = False + self._value_label_dict: dict[str, dict[int, str]] = {} + self._value_labels_read = False + self._dtype: np.dtype | None = None + self._lines_read = 0 + + self._native_byteorder = _set_endianness(sys.byteorder) + + def _ensure_open(self) -> None: + """ + Ensure the file has been opened and its header data read. + """ + if not hasattr(self, "_path_or_buf"): + self._open_file() + + def _open_file(self) -> None: + """ + Open the file (with compression options, etc.), and read header information. + """ + if not self._entered: + warnings.warn( + "StataReader is being used without using a context manager. " + "Using StataReader as a context manager is the only supported method.", + ResourceWarning, + stacklevel=find_stack_level(), + ) + handles = get_handle( + self._original_path_or_buf, + "rb", + storage_options=self._storage_options, + is_text=False, + compression=self._compression, + ) + if hasattr(handles.handle, "seekable") and handles.handle.seekable(): + # If the handle is directly seekable, use it without an extra copy. + self._path_or_buf = handles.handle + self._close_file = handles.close + else: + # Copy to memory, and ensure no encoding. + with handles: + self._path_or_buf = BytesIO(handles.handle.read()) + self._close_file = self._path_or_buf.close + + self._read_header() + self._setup_dtype() + + def __enter__(self) -> Self: + """enter context manager""" + self._entered = True + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if self._close_file: + self._close_file() + + def _set_encoding(self) -> None: + """ + Set string encoding which depends on file version + """ + if self._format_version < 118: + self._encoding = "latin-1" + else: + self._encoding = "utf-8" + + def _read_int8(self) -> int: + return struct.unpack("b", self._path_or_buf.read(1))[0] + + def _read_uint8(self) -> int: + return struct.unpack("B", self._path_or_buf.read(1))[0] + + def _read_uint16(self) -> int: + return struct.unpack(f"{self._byteorder}H", self._path_or_buf.read(2))[0] + + def _read_uint32(self) -> int: + return struct.unpack(f"{self._byteorder}I", self._path_or_buf.read(4))[0] + + def _read_uint64(self) -> int: + return struct.unpack(f"{self._byteorder}Q", self._path_or_buf.read(8))[0] + + def _read_int16(self) -> int: + return struct.unpack(f"{self._byteorder}h", self._path_or_buf.read(2))[0] + + def _read_int32(self) -> int: + return struct.unpack(f"{self._byteorder}i", self._path_or_buf.read(4))[0] + + def _read_int64(self) -> int: + return struct.unpack(f"{self._byteorder}q", self._path_or_buf.read(8))[0] + + def _read_char8(self) -> bytes: + return struct.unpack("c", self._path_or_buf.read(1))[0] + + def _read_int16_count(self, count: int) -> tuple[int, ...]: + return struct.unpack( + f"{self._byteorder}{'h' * count}", + self._path_or_buf.read(2 * count), + ) + + def _read_header(self) -> None: + first_char = self._read_char8() + if first_char == b"<": + self._read_new_header() + else: + self._read_old_header(first_char) + + def _read_new_header(self) -> None: + # The first part of the header is common to 117 - 119. + self._path_or_buf.read(27) # stata_dta>
+ self._format_version = int(self._path_or_buf.read(3)) + if self._format_version not in [117, 118, 119]: + raise ValueError(_version_error.format(version=self._format_version)) + self._set_encoding() + self._path_or_buf.read(21) # + self._byteorder = ">" if self._path_or_buf.read(3) == b"MSF" else "<" + self._path_or_buf.read(15) # + self._nvar = ( + self._read_uint16() if self._format_version <= 118 else self._read_uint32() + ) + self._path_or_buf.read(7) # + + self._nobs = self._get_nobs() + self._path_or_buf.read(11) # + self._time_stamp = self._get_time_stamp() + self._path_or_buf.read(26) #
+ self._path_or_buf.read(8) # 0x0000000000000000 + self._path_or_buf.read(8) # position of + + self._seek_vartypes = self._read_int64() + 16 + self._seek_varnames = self._read_int64() + 10 + self._seek_sortlist = self._read_int64() + 10 + self._seek_formats = self._read_int64() + 9 + self._seek_value_label_names = self._read_int64() + 19 + + # Requires version-specific treatment + self._seek_variable_labels = self._get_seek_variable_labels() + + self._path_or_buf.read(8) # + self._data_location = self._read_int64() + 6 + self._seek_strls = self._read_int64() + 7 + self._seek_value_labels = self._read_int64() + 14 + + self._typlist, self._dtyplist = self._get_dtypes(self._seek_vartypes) + + self._path_or_buf.seek(self._seek_varnames) + self._varlist = self._get_varlist() + + self._path_or_buf.seek(self._seek_sortlist) + self._srtlist = self._read_int16_count(self._nvar + 1)[:-1] + + self._path_or_buf.seek(self._seek_formats) + self._fmtlist = self._get_fmtlist() + + self._path_or_buf.seek(self._seek_value_label_names) + self._lbllist = self._get_lbllist() + + self._path_or_buf.seek(self._seek_variable_labels) + self._variable_labels = self._get_variable_labels() + + # Get data type information, works for versions 117-119. + def _get_dtypes( + self, seek_vartypes: int + ) -> tuple[list[int | str], list[str | np.dtype]]: + self._path_or_buf.seek(seek_vartypes) + typlist = [] + dtyplist = [] + for _ in range(self._nvar): + typ = self._read_uint16() + if typ <= 2045: + typlist.append(typ) + dtyplist.append(str(typ)) + else: + try: + typlist.append(self.TYPE_MAP_XML[typ]) # type: ignore[arg-type] + dtyplist.append(self.DTYPE_MAP_XML[typ]) # type: ignore[arg-type] + except KeyError as err: + raise ValueError(f"cannot convert stata types [{typ}]") from err + + return typlist, dtyplist # type: ignore[return-value] + + def _get_varlist(self) -> list[str]: + # 33 in order formats, 129 in formats 118 and 119 + b = 33 if self._format_version < 118 else 129 + return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)] + + # Returns the format list + def _get_fmtlist(self) -> list[str]: + if self._format_version >= 118: + b = 57 + elif self._format_version > 113: + b = 49 + elif self._format_version > 104: + b = 12 + else: + b = 7 + + return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)] + + # Returns the label list + def _get_lbllist(self) -> list[str]: + if self._format_version >= 118: + b = 129 + elif self._format_version > 108: + b = 33 + else: + b = 9 + return [self._decode(self._path_or_buf.read(b)) for _ in range(self._nvar)] + + def _get_variable_labels(self) -> list[str]: + if self._format_version >= 118: + vlblist = [ + self._decode(self._path_or_buf.read(321)) for _ in range(self._nvar) + ] + elif self._format_version > 105: + vlblist = [ + self._decode(self._path_or_buf.read(81)) for _ in range(self._nvar) + ] + else: + vlblist = [ + self._decode(self._path_or_buf.read(32)) for _ in range(self._nvar) + ] + return vlblist + + def _get_nobs(self) -> int: + if self._format_version >= 118: + return self._read_uint64() + elif self._format_version >= 103: + return self._read_uint32() + else: + return self._read_uint16() + + def _get_data_label(self) -> str: + if self._format_version >= 118: + strlen = self._read_uint16() + return self._decode(self._path_or_buf.read(strlen)) + elif self._format_version == 117: + strlen = self._read_int8() + return self._decode(self._path_or_buf.read(strlen)) + elif self._format_version > 105: + return self._decode(self._path_or_buf.read(81)) + else: + return self._decode(self._path_or_buf.read(32)) + + def _get_time_stamp(self) -> str: + if self._format_version >= 118: + strlen = self._read_int8() + return self._path_or_buf.read(strlen).decode("utf-8") + elif self._format_version == 117: + strlen = self._read_int8() + return self._decode(self._path_or_buf.read(strlen)) + elif self._format_version > 104: + return self._decode(self._path_or_buf.read(18)) + else: + raise ValueError + + def _get_seek_variable_labels(self) -> int: + if self._format_version == 117: + self._path_or_buf.read(8) # , throw away + # Stata 117 data files do not follow the described format. This is + # a work around that uses the previous label, 33 bytes for each + # variable, 20 for the closing tag and 17 for the opening tag + return self._seek_value_label_names + (33 * self._nvar) + 20 + 17 + elif self._format_version >= 118: + return self._read_int64() + 17 + else: + raise ValueError + + def _read_old_header(self, first_char: bytes) -> None: + self._format_version = int(first_char[0]) + if self._format_version not in [ + 102, + 103, + 104, + 105, + 108, + 110, + 111, + 113, + 114, + 115, + ]: + raise ValueError(_version_error.format(version=self._format_version)) + self._set_encoding() + # Note 102 format will have a zero in this header position, so support + # relies on little-endian being set whenever this value isn't one, + # even though for later releases strictly speaking the value should + # be either one or two to be valid + self._byteorder = ">" if self._read_int8() == 0x1 else "<" + self._filetype = self._read_int8() + self._path_or_buf.read(1) # unused + + self._nvar = self._read_uint16() + self._nobs = self._get_nobs() + + self._data_label = self._get_data_label() + + if self._format_version >= 105: + self._time_stamp = self._get_time_stamp() + + # descriptors + if self._format_version >= 111: + typlist = [int(c) for c in self._path_or_buf.read(self._nvar)] + else: + buf = self._path_or_buf.read(self._nvar) + typlistb = np.frombuffer(buf, dtype=np.uint8) + typlist = [] + for tp in typlistb: + if tp in self.OLD_TYPE_MAPPING: + typlist.append(self.OLD_TYPE_MAPPING[tp]) + else: + typlist.append(tp - 127) # bytes + + try: + self._typlist = [self.TYPE_MAP[typ] for typ in typlist] + except ValueError as err: + invalid_types = ",".join([str(x) for x in typlist]) + raise ValueError(f"cannot convert stata types [{invalid_types}]") from err + try: + self._dtyplist = [self.DTYPE_MAP[typ] for typ in typlist] + except ValueError as err: + invalid_dtypes = ",".join([str(x) for x in typlist]) + raise ValueError(f"cannot convert stata dtypes [{invalid_dtypes}]") from err + + if self._format_version > 108: + self._varlist = [ + self._decode(self._path_or_buf.read(33)) for _ in range(self._nvar) + ] + else: + self._varlist = [ + self._decode(self._path_or_buf.read(9)) for _ in range(self._nvar) + ] + self._srtlist = self._read_int16_count(self._nvar + 1)[:-1] + + self._fmtlist = self._get_fmtlist() + + self._lbllist = self._get_lbllist() + + self._variable_labels = self._get_variable_labels() + + # ignore expansion fields (Format 105 and later) + # When reading, read five bytes; the last four bytes now tell you + # the size of the next read, which you discard. You then continue + # like this until you read 5 bytes of zeros. + + if self._format_version > 104: + while True: + data_type = self._read_int8() + if self._format_version > 108: + data_len = self._read_int32() + else: + data_len = self._read_int16() + if data_type == 0: + break + self._path_or_buf.read(data_len) + + # necessary data to continue parsing + self._data_location = self._path_or_buf.tell() + + def _setup_dtype(self) -> np.dtype: + """Map between numpy and state dtypes""" + if self._dtype is not None: + return self._dtype + + dtypes = [] # Convert struct data types to numpy data type + for i, typ in enumerate(self._typlist): + if typ in self.NUMPY_TYPE_MAP: + typ = cast(str, typ) # only strs in NUMPY_TYPE_MAP + dtypes.append((f"s{i}", f"{self._byteorder}{self.NUMPY_TYPE_MAP[typ]}")) + else: + dtypes.append((f"s{i}", f"S{typ}")) + self._dtype = np.dtype(dtypes) + + return self._dtype + + def _decode(self, s: bytes) -> str: + # have bytes not strings, so must decode + s = s.partition(b"\0")[0] + try: + return s.decode(self._encoding) + except UnicodeDecodeError: + # GH 25960, fallback to handle incorrect format produced when 117 + # files are converted to 118 files in Stata + encoding = self._encoding + msg = f""" +One or more strings in the dta file could not be decoded using {encoding}, and +so the fallback encoding of latin-1 is being used. This can happen when a file +has been incorrectly encoded by Stata or some other software. You should verify +the string values returned are correct.""" + warnings.warn( + msg, + UnicodeWarning, + stacklevel=find_stack_level(), + ) + return s.decode("latin-1") + + def _read_new_value_labels(self) -> None: + """Reads value labels with variable length strings (108 and later format)""" + if self._format_version >= 117: + self._path_or_buf.seek(self._seek_value_labels) + else: + assert self._dtype is not None + offset = self._nobs * self._dtype.itemsize + self._path_or_buf.seek(self._data_location + offset) + + while True: + if self._format_version >= 117: + if self._path_or_buf.read(5) == b" + break # end of value label table + + slength = self._path_or_buf.read(4) + if not slength: + break # end of value label table (format < 117), or end-of-file + if self._format_version == 108: + labname = self._decode(self._path_or_buf.read(9)) + elif self._format_version <= 117: + labname = self._decode(self._path_or_buf.read(33)) + else: + labname = self._decode(self._path_or_buf.read(129)) + self._path_or_buf.read(3) # padding + + n = self._read_uint32() + txtlen = self._read_uint32() + off = np.frombuffer( + self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n + ) + val = np.frombuffer( + self._path_or_buf.read(4 * n), dtype=f"{self._byteorder}i4", count=n + ) + ii = np.argsort(off) + off = off[ii] + val = val[ii] + txt = self._path_or_buf.read(txtlen) + self._value_label_dict[labname] = {} + for i in range(n): + end = off[i + 1] if i < n - 1 else txtlen + self._value_label_dict[labname][val[i]] = self._decode( + txt[off[i] : end] + ) + + if self._format_version >= 117: + self._path_or_buf.read(6) # + + def _read_old_value_labels(self) -> None: + """Reads value labels with fixed-length strings (105 and earlier format)""" + assert self._dtype is not None + offset = self._nobs * self._dtype.itemsize + self._path_or_buf.seek(self._data_location + offset) + + while True: + if not self._path_or_buf.read(2): + # end-of-file may have been reached, if so stop here + break + + # otherwise back up and read again, taking byteorder into account + self._path_or_buf.seek(-2, os.SEEK_CUR) + n = self._read_uint16() + labname = self._decode(self._path_or_buf.read(9)) + self._path_or_buf.read(1) # padding + codes = np.frombuffer( + self._path_or_buf.read(2 * n), dtype=f"{self._byteorder}i2", count=n + ) + self._value_label_dict[labname] = {} + for i in range(n): + self._value_label_dict[labname][codes[i]] = self._decode( + self._path_or_buf.read(8) + ) + + def _read_value_labels(self) -> None: + self._ensure_open() + if self._value_labels_read: + # Don't read twice + return + + if self._format_version >= 108: + self._read_new_value_labels() + else: + self._read_old_value_labels() + self._value_labels_read = True + + def _read_strls(self) -> None: + self._path_or_buf.seek(self._seek_strls) + # Wrap v_o in a string to allow uint64 values as keys on 32bit OS + self.GSO = {"0": ""} + while True: + if self._path_or_buf.read(3) != b"GSO": + break + + if self._format_version == 117: + v_o = self._read_uint64() + else: + buf = self._path_or_buf.read(12) + # Only tested on little endian machine. + v_size = 2 if self._format_version == 118 else 3 + if self._byteorder == "<": + buf = buf[0:v_size] + buf[4 : (12 - v_size)] + else: + buf = buf[4 - v_size : 4] + buf[(4 + v_size) :] + v_o = struct.unpack(f"{self._byteorder}Q", buf)[0] + typ = self._read_uint8() + length = self._read_uint32() + va = self._path_or_buf.read(length) + if typ == 130: + decoded_va = va[0:-1].decode(self._encoding) + else: + # Stata says typ 129 can be binary, so use str + decoded_va = str(va) + # Wrap v_o in a string to allow uint64 values as keys on 32bit OS + self.GSO[str(v_o)] = decoded_va + + def __next__(self) -> DataFrame: + self._using_iterator = True + return self.read(nrows=self._chunksize) + + def get_chunk(self, size: int | None = None) -> DataFrame: + """ + Reads lines from Stata file and returns as dataframe + + Parameters + ---------- + size : int, defaults to None + Number of lines to read. If None, reads whole file. + + Returns + ------- + DataFrame + """ + if size is None: + size = self._chunksize + return self.read(nrows=size) + + def read( + self, + nrows: int | None = None, + convert_dates: bool | None = None, + convert_categoricals: bool | None = None, + index_col: str | None = None, + convert_missing: bool | None = None, + preserve_dtypes: bool | None = None, + columns: Sequence[str] | None = None, + order_categoricals: bool | None = None, + ) -> DataFrame: + """ + Reads observations from Stata file, converting them into a dataframe + + Parameters + ---------- + nrows : int + Number of lines to read from data file, if None read whole file. + convert_dates : bool, default True + Convert date variables to DataFrame time values. + convert_categoricals : bool, default True + Read value labels and convert columns to Categorical/Factor variables. + index_col : str, optional + Column to set as index. + convert_missing : bool, default False + Flag indicating whether to convert missing values to their Stata + representations. If False, missing values are replaced with nan. + If True, columns containing missing values are returned with + object data types and missing values are represented by + StataMissingValue objects. + preserve_dtypes : bool, default True + Preserve Stata datatypes. If False, numeric data are upcast to pandas + default types for foreign data (float64 or int64). + columns : list or None + Columns to retain. Columns will be returned in the given order. None + returns all columns. + order_categoricals : bool, default True + Flag indicating whether converted categorical data are ordered. + + Returns + ------- + DataFrame + """ + self._ensure_open() + + # Handle options + if convert_dates is None: + convert_dates = self._convert_dates + if convert_categoricals is None: + convert_categoricals = self._convert_categoricals + if convert_missing is None: + convert_missing = self._convert_missing + if preserve_dtypes is None: + preserve_dtypes = self._preserve_dtypes + if columns is None: + columns = self._columns + if order_categoricals is None: + order_categoricals = self._order_categoricals + if index_col is None: + index_col = self._index_col + if nrows is None: + nrows = self._nobs + + # Handle empty file or chunk. If reading incrementally raise + # StopIteration. If reading the whole thing return an empty + # data frame. + if (self._nobs == 0) and nrows == 0: + data = DataFrame(columns=self._varlist) + # Apply dtypes correctly + for i, col in enumerate(data.columns): + dt = self._dtyplist[i] + if isinstance(dt, np.dtype): + if dt.char != "S": + data[col] = data[col].astype(dt) + if columns is not None: + data = self._do_select_columns(data, columns) + return data + + if (self._format_version >= 117) and (not self._value_labels_read): + self._read_strls() + + # Read data + assert self._dtype is not None + dtype = self._dtype + max_read_len = (self._nobs - self._lines_read) * dtype.itemsize + read_len = nrows * dtype.itemsize + read_len = min(read_len, max_read_len) + if read_len <= 0: + # Iterator has finished, should never be here unless + # we are reading the file incrementally + if convert_categoricals: + self._read_value_labels() + raise StopIteration + offset = self._lines_read * dtype.itemsize + self._path_or_buf.seek(self._data_location + offset) + read_lines = min(nrows, self._nobs - self._lines_read) + raw_data = np.frombuffer( + self._path_or_buf.read(read_len), dtype=dtype, count=read_lines + ) + + self._lines_read += read_lines + + # if necessary, swap the byte order to native here + if self._byteorder != self._native_byteorder: + raw_data = raw_data.byteswap().view(raw_data.dtype.newbyteorder()) + + if convert_categoricals: + self._read_value_labels() + + if len(raw_data) == 0: + data = DataFrame(columns=self._varlist) + else: + data = DataFrame.from_records(raw_data) + data.columns = Index(self._varlist) + + # If index is not specified, use actual row number rather than + # restarting at 0 for each chunk. + if index_col is None: + data.index = RangeIndex( + self._lines_read - read_lines, self._lines_read + ) # set attr instead of set_index to avoid copy + + if columns is not None: + data = self._do_select_columns(data, columns) + + # Decode strings + for col, typ in zip(data, self._typlist, strict=True): + if isinstance(typ, int): + data[col] = data[col].apply(self._decode) + + data = self._insert_strls(data) + + # Convert columns (if needed) to match input type + valid_dtypes = [i for i, dtyp in enumerate(self._dtyplist) if dtyp is not None] + object_type = np.dtype(object) + for idx in valid_dtypes: + dtype = data.iloc[:, idx].dtype + if dtype not in (object_type, self._dtyplist[idx]): + data.isetitem(idx, data.iloc[:, idx].astype(dtype)) + + data = self._do_convert_missing(data, convert_missing) + + if convert_dates: + for i, fmt in enumerate(self._fmtlist): + if any(fmt.startswith(date_fmt) for date_fmt in _date_formats): + data.isetitem( + i, _stata_elapsed_date_to_datetime_vec(data.iloc[:, i], fmt) + ) + + if convert_categoricals: + data = self._do_convert_categoricals( + data, self._value_label_dict, self._lbllist, order_categoricals + ) + + if not preserve_dtypes: + retyped_data = [] + convert = False + for col in data: + dtype = data[col].dtype + if dtype in (np.dtype(np.float16), np.dtype(np.float32)): + dtype = np.dtype(np.float64) + convert = True + elif dtype in ( + np.dtype(np.int8), + np.dtype(np.int16), + np.dtype(np.int32), + ): + dtype = np.dtype(np.int64) + convert = True + retyped_data.append((col, data[col].astype(dtype))) + if convert: + data = DataFrame.from_dict(dict(retyped_data)) + + if index_col is not None: + data = data.set_index(data.pop(index_col)) + + return data + + def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame: + # missing code for double was different in version 105 and prior + old_missingdouble = float.fromhex("0x1.0p333") + + # Check for missing values, and replace if found + replacements = {} + for i in range(len(data.columns)): + fmt = self._typlist[i] + # recode instances of the old missing code to the currently used value + if self._format_version <= 105 and fmt == "d": + data.iloc[:, i] = data.iloc[:, i].replace( + old_missingdouble, self.MISSING_VALUES["d"] + ) + + if self._format_version <= 111: + if fmt not in self.OLD_VALID_RANGE: + continue + + fmt = cast(str, fmt) # only strs in OLD_VALID_RANGE + nmin, nmax = self.OLD_VALID_RANGE[fmt] + else: + if fmt not in self.VALID_RANGE: + continue + + fmt = cast(str, fmt) # only strs in VALID_RANGE + nmin, nmax = self.VALID_RANGE[fmt] + series = data.iloc[:, i] + + # appreciably faster to do this with ndarray instead of Series + svals = series._values + missing = (svals < nmin) | (svals > nmax) + + if not missing.any(): + continue + + if convert_missing: # Replacement follows Stata notation + missing_loc = np.nonzero(np.asarray(missing))[0] + umissing, umissing_loc = np.unique(series[missing], return_inverse=True) + replacement = Series(series, dtype=object) + for j, um in enumerate(umissing): + if self._format_version <= 111: + missing_value = StataMissingValue( + float(self.MISSING_VALUES[fmt]) + ) + else: + missing_value = StataMissingValue(um) + + loc = missing_loc[umissing_loc == j] + replacement.iloc[loc] = missing_value + else: # All replacements are identical + dtype = series.dtype + if dtype not in (np.float32, np.float64): + dtype = np.float64 + replacement = Series(series, dtype=dtype) + # Note: operating on ._values is much faster than directly + # TODO: can we fix that? + replacement._values[missing] = np.nan + replacements[i] = replacement + if replacements: + for idx, value in replacements.items(): + data.isetitem(idx, value) + return data + + def _insert_strls(self, data: DataFrame) -> DataFrame: + if not hasattr(self, "GSO") or len(self.GSO) == 0: + return data + for i, typ in enumerate(self._typlist): + if typ != "Q": + continue + # Wrap v_o in a string to allow uint64 values as keys on 32bit OS + data.isetitem(i, [self.GSO[str(k)] for k in data.iloc[:, i]]) + return data + + def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFrame: + if not self._column_selector_set: + column_set = set(columns) + if len(column_set) != len(columns): + raise ValueError("columns contains duplicate entries") + unmatched = column_set.difference(data.columns) + if unmatched: + joined = ", ".join(list(unmatched)) + raise ValueError( + "The following columns were not " + f"found in the Stata data set: {joined}" + ) + # Copy information for retained columns for later processing + dtyplist = [] + typlist = [] + fmtlist = [] + lbllist = [] + for col in columns: + i = data.columns.get_loc(col) # type: ignore[no-untyped-call] + dtyplist.append(self._dtyplist[i]) + typlist.append(self._typlist[i]) + fmtlist.append(self._fmtlist[i]) + lbllist.append(self._lbllist[i]) + + self._dtyplist = dtyplist + self._typlist = typlist + self._fmtlist = fmtlist + self._lbllist = lbllist + self._column_selector_set = True + + return data[columns] + + def _do_convert_categoricals( + self, + data: DataFrame, + value_label_dict: dict[str, dict[int, str]], + lbllist: Sequence[str], + order_categoricals: bool, + ) -> DataFrame: + """ + Converts categorical columns to Categorical type. + """ + if not value_label_dict: + return data + cat_converted_data = [] + for col, label in zip(data, lbllist, strict=True): + if label in value_label_dict: + # Explicit call with ordered=True + vl = value_label_dict[label] + keys = np.array(list(vl.keys())) + column = data[col] + key_matches = column.isin(keys) + if self._using_iterator and key_matches.all(): + initial_categories: np.ndarray | None = keys + # If all categories are in the keys and we are iterating, + # use the same keys for all chunks. If some are missing + # value labels, then we will fall back to the categories + # varying across chunks. + else: + if self._using_iterator: + # warn is using an iterator + warnings.warn( + categorical_conversion_warning, + CategoricalConversionWarning, + stacklevel=find_stack_level(), + ) + initial_categories = None + cat_data = Categorical( + column, categories=initial_categories, ordered=order_categoricals + ) + if initial_categories is None: + # If None here, then we need to match the cats in the Categorical + categories = [] + for category in cat_data.categories: + if category in vl: + categories.append(vl[category]) + else: + categories.append(category) + else: + # If all cats are matched, we can use the values + categories = list(vl.values()) + try: + # Try to catch duplicate categories + # TODO: if we get a non-copying rename_categories, use that + cat_data = cat_data.rename_categories(categories) + except ValueError as err: + vc = Series(categories, copy=False).value_counts() + repeated_cats = list(vc.index[vc > 1]) + repeats = "-" * 80 + "\n" + "\n".join(repeated_cats) + # GH 25772 + msg = f""" +Value labels for column {col} are not unique. These cannot be converted to +pandas categoricals. + +Either read the file with `convert_categoricals` set to False or use the +low level interface in `StataReader` to separately read the values and the +value_labels. + +The repeated labels are: +{repeats} +""" + raise ValueError(msg) from err + # TODO: is the next line needed above in the data(...) method? + cat_series = Series(cat_data, index=data.index, copy=False) + cat_converted_data.append((col, cat_series)) + else: + cat_converted_data.append((col, data[col])) + data = DataFrame(dict(cat_converted_data), copy=False) + return data + + @property + def data_label(self) -> str: + """ + Return data label of Stata file. + + The data label is a descriptive string associated with the dataset + stored in the Stata file. This property provides access to that + label, if one is present. + + See Also + -------- + io.stata.StataReader.variable_labels : Return a dict associating each variable + name with corresponding label. + DataFrame.to_stata : Export DataFrame object to Stata dta format. + + Examples + -------- + >>> df = pd.DataFrame([(1,)], columns=["variable"]) + >>> time_stamp = pd.Timestamp(2000, 2, 29, 14, 21) + >>> data_label = "This is a data file." + >>> path = "/My_path/filename.dta" + >>> df.to_stata( + ... path, + ... time_stamp=time_stamp, # doctest: +SKIP + ... data_label=data_label, # doctest: +SKIP + ... version=None, + ... ) # doctest: +SKIP + >>> with pd.io.stata.StataReader(path) as reader: # doctest: +SKIP + ... print(reader.data_label) # doctest: +SKIP + This is a data file. + """ + self._ensure_open() + return self._data_label + + @property + def time_stamp(self) -> str: + """ + Return time stamp of Stata file. + """ + self._ensure_open() + return self._time_stamp + + def variable_labels(self) -> dict[str, str]: + """ + Return a dict associating each variable name with corresponding label. + + This method retrieves variable labels from a Stata file. Variable labels are + mappings between variable names and their corresponding descriptive labels + in a Stata dataset. + + Returns + ------- + dict + A python dictionary. + + See Also + -------- + read_stata : Read Stata file into DataFrame. + DataFrame.to_stata : Export DataFrame object to Stata dta format. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["col_1", "col_2"]) + >>> time_stamp = pd.Timestamp(2000, 2, 29, 14, 21) + >>> path = "/My_path/filename.dta" + >>> variable_labels = {"col_1": "This is an example"} + >>> df.to_stata( + ... path, + ... time_stamp=time_stamp, # doctest: +SKIP + ... variable_labels=variable_labels, + ... version=None, + ... ) # doctest: +SKIP + >>> with pd.io.stata.StataReader(path) as reader: # doctest: +SKIP + ... print(reader.variable_labels()) # doctest: +SKIP + {'index': '', 'col_1': 'This is an example', 'col_2': ''} + >>> pd.read_stata(path) # doctest: +SKIP + index col_1 col_2 + 0 0 1 2 + 1 1 3 4 + """ + self._ensure_open() + return dict(zip(self._varlist, self._variable_labels, strict=True)) + + def value_labels(self) -> dict[str, dict[int, str]]: + """ + Return a nested dict associating each variable name to its value and label. + + This method retrieves the value labels from a Stata file. Value labels are + mappings between the coded values and their corresponding descriptive labels + in a Stata dataset. + + Returns + ------- + dict + A python dictionary. + + See Also + -------- + read_stata : Read Stata file into DataFrame. + DataFrame.to_stata : Export DataFrame object to Stata dta format. + + Examples + -------- + >>> df = pd.DataFrame([[1, 2], [3, 4]], columns=["col_1", "col_2"]) + >>> time_stamp = pd.Timestamp(2000, 2, 29, 14, 21) + >>> path = "/My_path/filename.dta" + >>> value_labels = {"col_1": {3: "x"}} + >>> df.to_stata( + ... path, + ... time_stamp=time_stamp, # doctest: +SKIP + ... value_labels=value_labels, + ... version=None, + ... ) # doctest: +SKIP + >>> with pd.io.stata.StataReader(path) as reader: # doctest: +SKIP + ... print(reader.value_labels()) # doctest: +SKIP + {'col_1': {3: 'x'}} + >>> pd.read_stata(path) # doctest: +SKIP + index col_1 col_2 + 0 0 1 2 + 1 1 x 4 + """ + if not self._value_labels_read: + self._read_value_labels() + + return self._value_label_dict + + +@set_module("pandas") +def read_stata( + filepath_or_buffer: FilePath | ReadBuffer[bytes], + *, + convert_dates: bool = True, + convert_categoricals: bool = True, + index_col: str | None = None, + convert_missing: bool = False, + preserve_dtypes: bool = True, + columns: Sequence[str] | None = None, + order_categoricals: bool = True, + chunksize: int | None = None, + iterator: bool = False, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, +) -> DataFrame | StataReader: + """ + Read Stata file into DataFrame. + + Parameters + ---------- + filepath_or_buffer : str, path object or file-like object + Any valid string path is acceptable. The string could be a URL. Valid + URL schemes include http, ftp, s3, and file. For file URLs, a host is + expected. A local file could be: ``file://localhost/path/to/table.dta``. + + If you want to pass in a path object, pandas accepts any ``os.PathLike``. + + By file-like object, we refer to objects with a ``read()`` method, + such as a file handle (e.g. via builtin ``open`` function) + or ``StringIO``. + convert_dates : bool, default True + Convert date variables to DataFrame time values. + convert_categoricals : bool, default True + Read value labels and convert columns to Categorical/Factor variables. + index_col : str, optional + Column to set as index. + convert_missing : bool, default False + Flag indicating whether to convert missing values to their Stata + representations. If False, missing values are replaced with nan. + If True, columns containing missing values are returned with + object data types and missing values are represented by + StataMissingValue objects. + preserve_dtypes : bool, default True + Preserve Stata datatypes. If False, numeric data are upcast to pandas + default types for foreign data (float64 or int64). + columns : list or None + Columns to retain. Columns will be returned in the given order. None + returns all columns. + order_categoricals : bool, default True + Flag indicating whether converted categorical data are ordered. + chunksize : int, default None + Return StataReader object for iterations, returns chunks with + given number of lines. + iterator : bool, default False + Return StataReader object. + compression : str or dict, default 'infer' + For on-the-fly decompression of on-disk data. If 'infer' and + 'filepath_or_buffer' is path-like, then detect compression from the + following extensions: '.gz', '.bz2', '.zip', '.xz', '.zst', '.tar', + '.tar.gz', '.tar.xz' or '.tar.bz2' (otherwise no compression). + If using 'zip' or 'tar', the ZIP file must contain only one + data file to be read in. Set to ``None`` for no decompression. + Can also be a dict with key ``'method'`` set to one of + {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} and + other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdDecompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for Zstandard decompression using a + custom compression dictionary: + ``compression={'method': 'zstd', 'dict_data': my_compression_dict}``. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + Returns + ------- + DataFrame, pandas.api.typing.StataReader + If iterator or chunksize, returns StataReader, else DataFrame. + + See Also + -------- + io.stata.StataReader : Low-level reader for Stata data files. + DataFrame.to_stata: Export Stata data files. + + Notes + ----- + Categorical variables read through an iterator may not have the same + categories and dtype. This occurs when a variable stored in a DTA + file is associated to an incomplete set of value labels that only + label a strict subset of the values. + + Examples + -------- + + Creating a dummy stata for this example + + >>> df = pd.DataFrame( + ... { + ... "animal": ["falcon", "parrot", "falcon", "parrot"], + ... "speed": [350, 18, 361, 15], + ... } + ... ) # doctest: +SKIP + >>> df.to_stata("animals.dta") # doctest: +SKIP + + Read a Stata dta file: + + >>> df = pd.read_stata("animals.dta") # doctest: +SKIP + + Read a Stata dta file in 10,000 line chunks: + + >>> values = np.random.randint( + ... 0, 10, size=(20_000, 1), dtype="uint8" + ... ) # doctest: +SKIP + >>> df = pd.DataFrame(values, columns=["i"]) # doctest: +SKIP + >>> df.to_stata("filename.dta") # doctest: +SKIP + + >>> with pd.read_stata('filename.dta', chunksize=10000) as itr: # doctest: +SKIP + >>> for chunk in itr: + ... # Operate on a single chunk, e.g., chunk.mean() + ... pass # doctest: +SKIP + """ + reader = StataReader( + filepath_or_buffer, + convert_dates=convert_dates, + convert_categoricals=convert_categoricals, + index_col=index_col, + convert_missing=convert_missing, + preserve_dtypes=preserve_dtypes, + columns=columns, + order_categoricals=order_categoricals, + chunksize=chunksize, + storage_options=storage_options, + compression=compression, + ) + + if iterator or chunksize: + return reader + + with reader: + return reader.read() + + +def _set_endianness(endianness: str) -> str: + if endianness.lower() in ["<", "little"]: + return "<" + elif endianness.lower() in [">", "big"]: + return ">" + else: # pragma : no cover + raise ValueError(f"Endianness {endianness} not understood") + + +def _pad_bytes(name: AnyStr, length: int) -> AnyStr: + """ + Take a char string and pads it with null bytes until it's length chars. + """ + if isinstance(name, bytes): + return name + b"\x00" * (length - len(name)) + return name + "\x00" * (length - len(name)) + + +def _convert_datetime_to_stata_type(fmt: str) -> np.dtype: + """ + Convert from one of the stata date formats to a type in TYPE_MAP. + """ + if fmt in [ + "tc", + "%tc", + "td", + "%td", + "tw", + "%tw", + "tm", + "%tm", + "tq", + "%tq", + "th", + "%th", + "ty", + "%ty", + ]: + return np.dtype(np.float64) # Stata expects doubles for SIFs + else: + raise NotImplementedError(f"Format {fmt} not implemented") + + +def _maybe_convert_to_int_keys(convert_dates: dict, varlist: list[Hashable]) -> dict: + new_dict = {} + for key, value in convert_dates.items(): + if not value.startswith("%"): # make sure proper fmts + convert_dates[key] = "%" + value + if key in varlist: + new_dict[varlist.index(key)] = convert_dates[key] + else: + if not isinstance(key, int): + raise ValueError("convert_dates key must be a column or an integer") + new_dict[key] = convert_dates[key] + return new_dict + + +def _dtype_to_stata_type(dtype: np.dtype, column: Series) -> int: + """ + Convert dtype types to stata types. Returns the byte of the given ordinal. + See TYPE_MAP and comments for an explanation. This is also explained in + the dta spec. + 1 - 244 are strings of this length + Pandas Stata + 251 - for int8 byte + 252 - for int16 int + 253 - for int32 long + 254 - for float32 float + 255 - for double double + + If there are dates to convert, then dtype will already have the correct + type inserted. + """ + # TODO: expand to handle datetime to integer conversion + if dtype.type is np.object_: # try to coerce it to the biggest string + # not memory efficient, what else could we + # do? + itemsize = max_len_string_array(ensure_object(column._values)) + return max(itemsize, 1) + elif dtype.type is np.float64: + return 255 + elif dtype.type is np.float32: + return 254 + elif dtype.type is np.int32: + return 253 + elif dtype.type is np.int16: + return 252 + elif dtype.type is np.int8: + return 251 + else: # pragma : no cover + raise NotImplementedError(f"Data type {dtype} not supported.") + + +def _dtype_to_default_stata_fmt( + dtype: np.dtype, column: Series, dta_version: int = 114, force_strl: bool = False +) -> str: + """ + Map numpy dtype to stata's default format for this type. Not terribly + important since users can change this in Stata. Semantics are + + object -> "%DDs" where DD is the length of the string. If not a string, + raise ValueError + float64 -> "%10.0g" + float32 -> "%9.0g" + int64 -> "%9.0g" + int32 -> "%12.0g" + int16 -> "%8.0g" + int8 -> "%8.0g" + strl -> "%9s" + """ + # TODO: Refactor to combine type with format + # TODO: expand this to handle a default datetime format? + if dta_version < 117: + max_str_len = 244 + else: + max_str_len = 2045 + if force_strl: + return "%9s" + if dtype.type is np.object_: + itemsize = max_len_string_array(ensure_object(column._values)) + if itemsize > max_str_len: + if dta_version >= 117: + return "%9s" + else: + raise ValueError(excessive_string_length_error.format(column.name)) + return "%" + str(max(itemsize, 1)) + "s" + elif dtype == np.float64: + return "%10.0g" + elif dtype == np.float32: + return "%9.0g" + elif dtype == np.int32: + return "%12.0g" + elif dtype in (np.int8, np.int16): + return "%8.0g" + else: # pragma : no cover + raise NotImplementedError(f"Data type {dtype} not supported.") + + +class StataWriter(StataParser): + """ + A class for writing Stata binary dta files + + Parameters + ---------- + fname : path (string), buffer or path object + string, pathlib.Path or + object implementing a binary write() functions. If using a buffer + then the buffer will not be automatically closed after the file + is written. + data : DataFrame + Input to save + convert_dates : dict + Dictionary mapping columns containing datetime types to stata internal + format to use when writing the dates. Options are 'tc', 'td', 'tm', + 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name. + Datetime columns that do not have a conversion type specified will be + converted to 'tc'. Raises NotImplementedError if a datetime column has + timezone information + write_index : bool + Write the index to Stata dataset. + byteorder : str + Can be ">", "<", "little", or "big". default is `sys.byteorder` + time_stamp : datetime + A datetime to use as file creation date. Default is the current time + data_label : str + A label for the data set. Must be 80 characters or smaller. + variable_labels : dict + Dictionary containing columns as keys and variable labels as values. + Each label must be 80 characters or smaller. + compression : str or dict, default 'infer' + For on-the-fly compression of the output data. If 'infer' and 'fname' is + path-like, then detect compression from the following extensions: '.gz', + '.bz2', '.zip', '.xz', '.zst', '.tar', '.tar.gz', '.tar.xz' or '.tar.bz2' + (otherwise no compression). + Set to ``None`` for no compression. + Can also be a dict with key ``'method'`` set + to one of {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} + and other key-value pairs are forwarded to + ``zipfile.ZipFile``, ``gzip.GzipFile``, + ``bz2.BZ2File``, ``zstandard.ZstdCompressor``, ``lzma.LZMAFile`` or + ``tarfile.TarFile``, respectively. + As an example, the following could be passed for faster compression and to + create a reproducible gzip archive: + ``compression={'method': 'gzip', 'compresslevel': 1, 'mtime': 1}``. + storage_options : dict, optional + Extra options that make sense for a particular storage connection, e.g. + host, port, username, password, etc. For HTTP(S) URLs the key-value pairs + are forwarded to ``urllib.request.Request`` as header options. For other + URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are + forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more + details, and for more examples on storage options refer `here + `_. + + value_labels : dict of dicts + Dictionary containing columns as keys and dictionaries of column value + to labels as values. The combined length of all labels for a single + variable must be 32,000 characters or smaller. + + Returns + ------- + writer : StataWriter instance + The StataWriter instance has a write_file method, which will + write the file to the given `fname`. + + Raises + ------ + NotImplementedError + * If datetimes contain timezone information + ValueError + * Columns listed in convert_dates are neither datetime64[ns] + or datetime + * Column dtype is not representable in Stata + * Column listed in convert_dates is not in DataFrame + * Categorical label contains more than 32,000 characters + + Examples + -------- + >>> data = pd.DataFrame([[1.0, 1]], columns=["a", "b"]) + >>> writer = StataWriter("./data_file.dta", data) + >>> writer.write_file() + + Directly write a zip file + >>> compression = {"method": "zip", "archive_name": "data_file.dta"} + >>> writer = StataWriter("./data_file.zip", data, compression=compression) + >>> writer.write_file() + + Save a DataFrame with dates + >>> from datetime import datetime + >>> data = pd.DataFrame([[datetime(2000, 1, 1)]], columns=["date"]) + >>> writer = StataWriter("./date_data_file.dta", data, {"date": "tw"}) + >>> writer.write_file() + """ + + _max_string_length = 244 + _encoding: Literal["latin-1", "utf-8"] = "latin-1" + + def __init__( + self, + fname: FilePath | WriteBuffer[bytes], + data: DataFrame, + convert_dates: dict[Hashable, str] | None = None, + write_index: bool = True, + byteorder: str | None = None, + time_stamp: datetime | None = None, + data_label: str | None = None, + variable_labels: dict[Hashable, str] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, + *, + value_labels: dict[Hashable, dict[float, str]] | None = None, + ) -> None: + super().__init__() + self.data = data + self._convert_dates = {} if convert_dates is None else convert_dates + self._write_index = write_index + self._time_stamp = time_stamp + self._data_label = data_label + self._variable_labels = variable_labels + self._non_cat_value_labels = value_labels + self._value_labels: list[StataValueLabel] = [] + self._has_value_labels = np.array([], dtype=bool) + self._compression = compression + self._output_file: IO[bytes] | None = None + self._converted_names: dict[Hashable, str] = {} + # attach nobs, nvars, data, varlist, typlist + self._prepare_pandas(data) + self.storage_options = storage_options + + if byteorder is None: + byteorder = sys.byteorder + self._byteorder = _set_endianness(byteorder) + self._fname = fname + self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8} + + def _write(self, to_write: str) -> None: + """ + Helper to call encode before writing to file for Python 3 compat. + """ + self.handles.handle.write(to_write.encode(self._encoding)) + + def _write_bytes(self, value: bytes) -> None: + """ + Helper to assert file is open before writing. + """ + self.handles.handle.write(value) + + def _prepare_non_cat_value_labels( + self, data: DataFrame + ) -> list[StataNonCatValueLabel]: + """ + Check for value labels provided for non-categorical columns. Value + labels + """ + non_cat_value_labels: list[StataNonCatValueLabel] = [] + if self._non_cat_value_labels is None: + return non_cat_value_labels + + for labname, labels in self._non_cat_value_labels.items(): + if labname in self._converted_names: + colname = self._converted_names[labname] + elif labname in data.columns: + colname = str(labname) + else: + raise KeyError( + f"Can't create value labels for {labname}, it wasn't " + "found in the dataset." + ) + + if not is_numeric_dtype(data[colname].dtype): + # Labels should not be passed explicitly for categorical + # columns that will be converted to int + raise ValueError( + f"Can't create value labels for {labname}, value labels " + "can only be applied to numeric columns." + ) + svl = StataNonCatValueLabel(colname, labels, self._encoding) + non_cat_value_labels.append(svl) + return non_cat_value_labels + + def _prepare_categoricals(self, data: DataFrame) -> DataFrame: + """ + Check for categorical columns, retain categorical information for + Stata file and convert categorical data to int + """ + is_cat = [isinstance(dtype, CategoricalDtype) for dtype in data.dtypes] + if not any(is_cat): + return data + + self._has_value_labels |= np.array(is_cat) + + get_base_missing_value = StataMissingValue.get_base_missing_value + data_formatted = [] + for col, col_is_cat in zip(data, is_cat, strict=True): + if col_is_cat: + svl = StataValueLabel(data[col], encoding=self._encoding) + self._value_labels.append(svl) + dtype = data[col].cat.codes.dtype + if dtype == np.int64: + raise ValueError( + "It is not possible to export " + "int64-based categorical data to Stata." + ) + values = data[col].cat.codes._values.copy() + + # Upcast if needed so that correct missing values can be set + if values.max() >= get_base_missing_value(dtype): + if dtype == np.int8: + dtype = np.dtype(np.int16) + elif dtype == np.int16: + dtype = np.dtype(np.int32) + else: + dtype = np.dtype(np.float64) + values = np.array(values, dtype=dtype) + + # Replace missing values with Stata missing value for type + values[values == -1] = get_base_missing_value(dtype) + data_formatted.append((col, values)) + else: + data_formatted.append((col, data[col])) + return DataFrame.from_dict(dict(data_formatted)) + + def _replace_nans(self, data: DataFrame) -> DataFrame: + # return data + """ + Checks floating point data columns for nans, and replaces these with + the generic Stata for missing value (.) + """ + for c in data: + dtype = data[c].dtype + if dtype in (np.float32, np.float64): + if dtype == np.float32: + replacement = self.MISSING_VALUES["f"] + else: + replacement = self.MISSING_VALUES["d"] + data[c] = data[c].fillna(replacement) + + return data + + def _update_strl_names(self) -> None: + """No-op, forward compatibility""" + + def _validate_variable_name(self, name: str) -> str: + """ + Validate variable names for Stata export. + + Parameters + ---------- + name : str + Variable name + + Returns + ------- + str + The validated name with invalid characters replaced with + underscores. + + Notes + ----- + Stata 114 and 117 support ascii characters in a-z, A-Z, 0-9 + and _. + """ + for c in name: + if ( + (c < "A" or c > "Z") + and (c < "a" or c > "z") + and (c < "0" or c > "9") + and c != "_" + ): + name = name.replace(c, "_") + return name + + def _check_column_names(self, data: DataFrame) -> DataFrame: + """ + Checks column names to ensure that they are valid Stata column names. + This includes checks for: + * Non-string names + * Stata keywords + * Variables that start with numbers + * Variables with names that are too long + + When an illegal variable name is detected, it is converted, and if + dates are exported, the variable name is propagated to the date + conversion dictionary + """ + converted_names: dict[Hashable, str] = {} + columns = list(data.columns) + original_columns = columns[:] + + duplicate_var_id = 0 + for j, name in enumerate(columns): + orig_name = name + if not isinstance(name, str): + name = str(name) + + name = self._validate_variable_name(name) + + # Variable name must not be a reserved word + if name in self.RESERVED_WORDS: + name = "_" + name + + # Variable name may not start with a number + if "0" <= name[0] <= "9": + name = "_" + name + + name = name[: min(len(name), 32)] + + if not name == orig_name: + # check for duplicates + while columns.count(name) > 0: + # prepend ascending number to avoid duplicates + name = "_" + str(duplicate_var_id) + name + name = name[: min(len(name), 32)] + duplicate_var_id += 1 + converted_names[orig_name] = name + + columns[j] = name + + data.columns = Index(columns) + + # Check date conversion, and fix key if needed + if self._convert_dates: + for c, o in zip(columns, original_columns, strict=True): + if c != o: + self._convert_dates[c] = self._convert_dates[o] + del self._convert_dates[o] + + if converted_names: + conversion_warning = [] + for orig_name, name in converted_names.items(): + msg = f"{orig_name} -> {name}" + conversion_warning.append(msg) + + ws = invalid_name_doc.format("\n ".join(conversion_warning)) + warnings.warn( + ws, + InvalidColumnName, + stacklevel=find_stack_level(), + ) + + self._converted_names = converted_names + self._update_strl_names() + + return data + + def _set_formats_and_types(self, dtypes: Series) -> None: + self.fmtlist: list[str] = [] + self.typlist: list[int] = [] + for col, dtype in dtypes.items(): + self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col])) + self.typlist.append(_dtype_to_stata_type(dtype, self.data[col])) + + def _prepare_pandas(self, data: DataFrame) -> None: + # NOTE: we might need a different API / class for pandas objects so + # we can set different semantics - handle this with a PR to pandas.io + + data = data.copy() + + if self._write_index: + temp = data.reset_index() + if isinstance(temp, DataFrame): + data = temp + + # Ensure column names are strings + data = self._check_column_names(data) + + # Check columns for compatibility with stata, upcast if necessary + # Raise if outside the supported range + data = _cast_to_stata_types(data) + + # Replace NaNs with Stata missing values + data = self._replace_nans(data) + + # Set all columns to initially unlabelled + self._has_value_labels = np.repeat(False, data.shape[1]) + + # Create value labels for non-categorical data + non_cat_value_labels = self._prepare_non_cat_value_labels(data) + + non_cat_columns = [svl.labname for svl in non_cat_value_labels] + has_non_cat_val_labels = data.columns.isin(non_cat_columns) + self._has_value_labels |= has_non_cat_val_labels + self._value_labels.extend(non_cat_value_labels) + + # Convert categoricals to int data, and strip labels + data = self._prepare_categoricals(data) + + self.nobs, self.nvar = data.shape + self.data = data + self.varlist = data.columns.tolist() + + dtypes = data.dtypes + + # Ensure all date columns are converted + for col in data: + if col in self._convert_dates: + continue + if lib.is_np_dtype(data[col].dtype, "M"): + self._convert_dates[col] = "tc" + + self._convert_dates = _maybe_convert_to_int_keys( + self._convert_dates, self.varlist + ) + for key in self._convert_dates: + new_type = _convert_datetime_to_stata_type(self._convert_dates[key]) + dtypes.iloc[key] = np.dtype(new_type) + + # Verify object arrays are strings and encode to bytes + self._encode_strings() + + self._set_formats_and_types(dtypes) + + # set the given format for the datetime cols + if self._convert_dates is not None: + for key in self._convert_dates: + if isinstance(key, int): + self.fmtlist[key] = self._convert_dates[key] + + def _encode_strings(self) -> None: + """ + Encode strings in dta-specific encoding + + Do not encode columns marked for date conversion or for strL + conversion. The strL converter independently handles conversion and + also accepts empty string arrays. + """ + convert_dates = self._convert_dates + # _convert_strl is not available in dta 114 + convert_strl = getattr(self, "_convert_strl", []) + for i, col in enumerate(self.data): + # Skip columns marked for date conversion or strl conversion + if i in convert_dates or col in convert_strl: + continue + column = self.data[col] + dtype = column.dtype + # TODO could also handle string dtype here specifically + if dtype.type is np.object_: + inferred_dtype = infer_dtype(column, skipna=True) + if not ((inferred_dtype == "string") or len(column) == 0): + col = column.name + raise ValueError( + f"""\ +Column `{col}` cannot be exported.\n\nOnly string-like object arrays +containing all strings or a mix of strings and None can be exported. +Object arrays containing only null values are prohibited. Other object +types cannot be exported and must first be converted to one of the +supported types.""" + ) + encoded = self.data[col].str.encode(self._encoding) + # If larger than _max_string_length do nothing + if ( + max_len_string_array(ensure_object(self.data[col]._values)) + <= self._max_string_length + ): + self.data[col] = encoded + + def write_file(self) -> None: + """ + Export DataFrame object to Stata dta format. + + This method writes the contents of a pandas DataFrame to a `.dta` file + compatible with Stata. It includes features for handling value labels, + variable types, and metadata like timestamps and data labels. The output + file can then be read and used in Stata or other compatible statistical + tools. + + See Also + -------- + read_stata : Read Stata file into DataFrame. + DataFrame.to_stata : Export DataFrame object to Stata dta format. + io.stata.StataWriter : A class for writing Stata binary dta files. + + Examples + -------- + >>> df = pd.DataFrame( + ... { + ... "fully_labelled": [1, 2, 3, 3, 1], + ... "partially_labelled": [1.0, 2.0, np.nan, 9.0, np.nan], + ... "Y": [7, 7, 9, 8, 10], + ... "Z": pd.Categorical(["j", "k", "l", "k", "j"]), + ... } + ... ) + >>> path = "/My_path/filename.dta" + >>> labels = { + ... "fully_labelled": {1: "one", 2: "two", 3: "three"}, + ... "partially_labelled": {1.0: "one", 2.0: "two"}, + ... } + >>> writer = pd.io.stata.StataWriter( + ... path, df, value_labels=labels + ... ) # doctest: +SKIP + >>> writer.write_file() # doctest: +SKIP + >>> df = pd.read_stata(path) # doctest: +SKIP + >>> df # doctest: +SKIP + index fully_labelled partially_labeled Y Z + 0 0 one one 7 j + 1 1 two two 7 k + 2 2 three NaN 9 l + 3 3 three 9.0 8 k + 4 4 one NaN 10 j + """ + with get_handle( + self._fname, + "wb", + compression=self._compression, + is_text=False, + storage_options=self.storage_options, + ) as self.handles: + if self.handles.compression["method"] is not None: + # ZipFile creates a file (with the same name) for each write call. + # Write it first into a buffer and then write the buffer to the ZipFile. + self._output_file, self.handles.handle = self.handles.handle, BytesIO() + self.handles.created_handles.append(self.handles.handle) + + try: + self._write_header( + data_label=self._data_label, time_stamp=self._time_stamp + ) + self._write_map() + self._write_variable_types() + self._write_varnames() + self._write_sortlist() + self._write_formats() + self._write_value_label_names() + self._write_variable_labels() + self._write_expansion_fields() + self._write_characteristics() + records = self._prepare_data() + self._write_data(records) + self._write_strls() + self._write_value_labels() + self._write_file_close_tag() + self._write_map() + self._close() + except Exception as exc: + self.handles.close() + if isinstance(self._fname, (str, os.PathLike)) and os.path.isfile( + self._fname + ): + try: + os.unlink(self._fname) + except OSError: + warnings.warn( + f"This save was not successful but {self._fname} could not " + "be deleted. This file is not valid.", + ResourceWarning, + stacklevel=find_stack_level(), + ) + raise exc + + def _close(self) -> None: + """ + Close the file if it was created by the writer. + + If a buffer or file-like object was passed in, for example a GzipFile, + then leave this file open for the caller to close. + """ + # write compression + if self._output_file is not None: + assert isinstance(self.handles.handle, BytesIO) + bio, self.handles.handle = self.handles.handle, self._output_file + self.handles.handle.write(bio.getvalue()) + + def _write_map(self) -> None: + """No-op, future compatibility""" + + def _write_file_close_tag(self) -> None: + """No-op, future compatibility""" + + def _write_characteristics(self) -> None: + """No-op, future compatibility""" + + def _write_strls(self) -> None: + """No-op, future compatibility""" + + def _write_expansion_fields(self) -> None: + """Write 5 zeros for expansion fields""" + self._write(_pad_bytes("", 5)) + + def _write_value_labels(self) -> None: + for vl in self._value_labels: + self._write_bytes(vl.generate_value_label(self._byteorder)) + + def _write_header( + self, + data_label: str | None = None, + time_stamp: datetime | None = None, + ) -> None: + byteorder = self._byteorder + # ds_format - just use 114 + self._write_bytes(struct.pack("b", 114)) + # byteorder + self._write((byteorder == ">" and "\x01") or "\x02") + # filetype + self._write("\x01") + # unused + self._write("\x00") + # number of vars, 2 bytes + self._write_bytes(struct.pack(byteorder + "h", self.nvar)[:2]) + # number of obs, 4 bytes + self._write_bytes(struct.pack(byteorder + "i", self.nobs)[:4]) + # data label 81 bytes, char, null terminated + if data_label is None: + self._write_bytes(self._null_terminate_bytes(_pad_bytes("", 80))) + else: + self._write_bytes( + self._null_terminate_bytes(_pad_bytes(data_label[:80], 80)) + ) + # time stamp, 18 bytes, char, null terminated + # format dd Mon yyyy hh:mm + if time_stamp is None: + time_stamp = datetime.now() + elif not isinstance(time_stamp, datetime): + raise ValueError("time_stamp should be datetime type") + # GH #13856 + # Avoid locale-specific month conversion + months = [ + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", + ] + month_lookup = {i + 1: month for i, month in enumerate(months)} + ts = ( + time_stamp.strftime("%d ") + + month_lookup[time_stamp.month] + + time_stamp.strftime(" %Y %H:%M") + ) + self._write_bytes(self._null_terminate_bytes(ts)) + + def _write_variable_types(self) -> None: + for typ in self.typlist: + self._write_bytes(struct.pack("B", typ)) + + def _write_varnames(self) -> None: + # varlist names are checked by _check_column_names + # varlist, requires null terminated + for name in self.varlist: + name = self._null_terminate_str(name) + name = _pad_bytes(name[:32], 33) + self._write(name) + + def _write_sortlist(self) -> None: + # srtlist, 2*(nvar+1), int array, encoded by byteorder + srtlist = _pad_bytes("", 2 * (self.nvar + 1)) + self._write(srtlist) + + def _write_formats(self) -> None: + # fmtlist, 49*nvar, char array + for fmt in self.fmtlist: + self._write(_pad_bytes(fmt, 49)) + + def _write_value_label_names(self) -> None: + # lbllist, 33*nvar, char array + for i in range(self.nvar): + # Use variable name when categorical + if self._has_value_labels[i]: + name = self.varlist[i] + name = self._null_terminate_str(name) + name = _pad_bytes(name[:32], 33) + self._write(name) + else: # Default is empty label + self._write(_pad_bytes("", 33)) + + def _write_variable_labels(self) -> None: + # Missing labels are 80 blank characters plus null termination + blank = _pad_bytes("", 81) + + if self._variable_labels is None: + for i in range(self.nvar): + self._write(blank) + return + + for col in self.data: + if col in self._variable_labels: + label = self._variable_labels[col] + if len(label) > 80: + raise ValueError("Variable labels must be 80 characters or fewer") + is_latin1 = all(ord(c) < 256 for c in label) + if not is_latin1: + raise ValueError( + "Variable labels must contain only characters that " + "can be encoded in Latin-1" + ) + self._write(_pad_bytes(label, 81)) + else: + self._write(blank) + + def _convert_strls(self, data: DataFrame) -> DataFrame: + """No-op, future compatibility""" + return data + + def _prepare_data(self) -> np.rec.recarray: + data = self.data + typlist = self.typlist + convert_dates = self._convert_dates + # 1. Convert dates + if self._convert_dates is not None: + for i, col in enumerate(data): + if i in convert_dates: + data[col] = _datetime_to_stata_elapsed_vec( + data[col], self.fmtlist[i] + ) + # 2. Convert strls + data = self._convert_strls(data) + + # 3. Convert bad string data to '' and pad to correct length + dtypes = {} + native_byteorder = self._byteorder == _set_endianness(sys.byteorder) + for i, col in enumerate(data): + typ = typlist[i] + if typ <= self._max_string_length: + dc = data[col].fillna("") + data[col] = dc.apply(_pad_bytes, args=(typ,)) + stype = f"S{typ}" + dtypes[col] = stype + data[col] = data[col].astype(stype) + else: + dtype = data[col].dtype + if not native_byteorder: + dtype = dtype.newbyteorder(self._byteorder) + dtypes[col] = dtype + + return data.to_records(index=False, column_dtypes=dtypes) + + def _write_data(self, records: np.rec.recarray) -> None: + self._write_bytes(records.tobytes()) + + @staticmethod + def _null_terminate_str(s: str) -> str: + s += "\x00" + return s + + def _null_terminate_bytes(self, s: str) -> bytes: + return self._null_terminate_str(s).encode(self._encoding) + + +def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int: + """ + Converts dtype types to stata types. Returns the byte of the given ordinal. + See TYPE_MAP and comments for an explanation. This is also explained in + the dta spec. + 1 - 2045 are strings of this length + Pandas Stata + 32768 - for object strL + 65526 - for int8 byte + 65527 - for int16 int + 65528 - for int32 long + 65529 - for float32 float + 65530 - for double double + + If there are dates to convert, then dtype will already have the correct + type inserted. + """ + # TODO: expand to handle datetime to integer conversion + if force_strl: + return 32768 + if dtype.type is np.object_: # try to coerce it to the biggest string + # not memory efficient, what else could we + # do? + itemsize = max_len_string_array(ensure_object(column._values)) + itemsize = max(itemsize, 1) + if itemsize <= 2045: + return itemsize + return 32768 + elif dtype.type is np.float64: + return 65526 + elif dtype.type is np.float32: + return 65527 + elif dtype.type is np.int32: + return 65528 + elif dtype.type is np.int16: + return 65529 + elif dtype.type is np.int8: + return 65530 + else: # pragma : no cover + raise NotImplementedError(f"Data type {dtype} not supported.") + + +def _pad_bytes_new(name: str | bytes, length: int) -> bytes: + """ + Takes a bytes instance and pads it with null bytes until it's length chars. + """ + if isinstance(name, str): + name = bytes(name, "utf-8") + return name + b"\x00" * (length - len(name)) + + +class StataStrLWriter: + """ + Converter for Stata StrLs + + Stata StrLs map 8 byte values to strings which are stored using a + dictionary-like format where strings are keyed to two values. + + Parameters + ---------- + df : DataFrame + DataFrame to convert + columns : Sequence[str] + List of columns names to convert to StrL + version : int, optional + dta version. Currently supports 117, 118 and 119 + byteorder : str, optional + Can be ">", "<", "little", or "big". default is `sys.byteorder` + + Notes + ----- + Supports creation of the StrL block of a dta file for dta versions + 117, 118 and 119. These differ in how the GSO is stored. 118 and + 119 store the GSO lookup value as a uint32 and a uint64, while 117 + uses two uint32s. 118 and 119 also encode all strings as unicode + which is required by the format. 117 uses 'latin-1' a fixed width + encoding that extends the 7-bit ascii table with an additional 128 + characters. + """ + + def __init__( + self, + df: DataFrame, + columns: Sequence[str], + version: int = 117, + byteorder: str | None = None, + ) -> None: + if version not in (117, 118, 119): + raise ValueError("Only dta versions 117, 118 and 119 supported") + self._dta_ver = version + + self.df = df + self.columns = columns + self._gso_table = {"": (0, 0)} + if byteorder is None: + byteorder = sys.byteorder + self._byteorder = _set_endianness(byteorder) + # Flag whether chosen byteorder matches the system on which we're running + self._native_byteorder = self._byteorder == _set_endianness(sys.byteorder) + + gso_v_type = "I" # uint32 + gso_o_type = "Q" # uint64 + self._encoding = "utf-8" + if version == 117: + o_size = 4 + gso_o_type = "I" # 117 used uint32 + self._encoding = "latin-1" + elif version == 118: + o_size = 6 + else: # version == 119 + o_size = 5 + if self._native_byteorder: + self._o_offet = 2 ** (8 * (8 - o_size)) + else: + self._o_offet = 2 ** (8 * o_size) + self._gso_o_type = gso_o_type + self._gso_v_type = gso_v_type + + def _convert_key(self, key: tuple[int, int]) -> int: + v, o = key + if self._native_byteorder: + return v + self._o_offet * o + else: + # v, o will be swapped when applying byteorder + return o + self._o_offet * v + + def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]: + """ + Generates the GSO lookup table for the DataFrame + + Returns + ------- + gso_table : dict + Ordered dictionary using the string found as keys + and their lookup position (v,o) as values + gso_df : DataFrame + DataFrame where strl columns have been converted to + (v,o) values + + Notes + ----- + Modifies the DataFrame in-place. + + The DataFrame returned encodes the (v,o) values as uint64s. The + encoding depends on the dta version, and can be expressed as + + enc = v + o * 2 ** (o_size * 8) + + so that v is stored in the lower bits and o is in the upper + bits. o_size is + + * 117: 4 + * 118: 6 + * 119: 5 + """ + gso_table = self._gso_table + gso_df = self.df + columns = list(gso_df.columns) + selected = gso_df[self.columns] + col_index = [(col, columns.index(col)) for col in self.columns] + keys = np.empty(selected.shape, dtype=np.uint64) + for o, (idx, row) in enumerate(selected.iterrows()): + for j, (col, v) in enumerate(col_index): + val = row[col] + # Allow columns with mixed str and None or pd.NA (GH 23633) + val = "" if isna(val) else val + key = gso_table.get(val, None) + if key is None: + # Stata prefers human numbers + key = (v + 1, o + 1) + gso_table[val] = key + keys[o, j] = self._convert_key(key) + for i, col in enumerate(self.columns): + gso_df[col] = keys[:, i] + + return gso_table, gso_df + + def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes: + """ + Generates the binary blob of GSOs that is written to the dta file. + + Parameters + ---------- + gso_table : dict + Ordered dictionary (str, vo) + + Returns + ------- + gso : bytes + Binary content of dta file to be placed between strl tags + + Notes + ----- + Output format depends on dta version. 117 uses two uint32s to + express v and o while 118+ uses a uint32 for v and a uint64 for o. + """ + # Format information + # Length includes null term + # 117 + # GSOvvvvooootllllxxxxxxxxxxxxxxx...x + # 3 u4 u4 u1 u4 string + null term + # + # 118, 119 + # GSOvvvvooooooootllllxxxxxxxxxxxxxxx...x + # 3 u4 u8 u1 u4 string + null term + + bio = BytesIO() + gso = bytes("GSO", "ascii") + gso_type = struct.pack(self._byteorder + "B", 130) + null = struct.pack(self._byteorder + "B", 0) + v_type = self._byteorder + self._gso_v_type + o_type = self._byteorder + self._gso_o_type + len_type = self._byteorder + "I" + for strl, vo in gso_table.items(): + if vo == (0, 0): + continue + v, o = vo + + # GSO + bio.write(gso) + + # vvvv + bio.write(struct.pack(v_type, v)) + + # oooo / oooooooo + bio.write(struct.pack(o_type, o)) + + # t + bio.write(gso_type) + + # llll + if isinstance(strl, str): + strl_convert = bytes(strl, "utf-8") + else: + strl_convert = strl + + bio.write(struct.pack(len_type, len(strl_convert) + 1)) + + # xxx...xxx + bio.write(strl_convert) + bio.write(null) + + return bio.getvalue() + + +class StataWriter117(StataWriter): + """ + A class for writing Stata binary dta files in Stata 13 format (117) + + Parameters + ---------- + fname : path (string), buffer or path object + string, pathlib.Path or + object implementing a binary write() functions. If using a buffer + then the buffer will not be automatically closed after the file + is written. + data : DataFrame + Input to save + convert_dates : dict + Dictionary mapping columns containing datetime types to stata internal + format to use when writing the dates. Options are 'tc', 'td', 'tm', + 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name. + Datetime columns that do not have a conversion type specified will be + converted to 'tc'. Raises NotImplementedError if a datetime column has + timezone information + write_index : bool + Write the index to Stata dataset. + byteorder : str + Can be ">", "<", "little", or "big". default is `sys.byteorder` + time_stamp : datetime + A datetime to use as file creation date. Default is the current time + data_label : str + A label for the data set. Must be 80 characters or smaller. + variable_labels : dict + Dictionary containing columns as keys and variable labels as values. + Each label must be 80 characters or smaller. + convert_strl : list + List of columns names to convert to Stata StrL format. Columns with + more than 2045 characters are automatically written as StrL. + Smaller columns can be converted by including the column name. Using + StrLs can reduce output file size when strings are longer than 8 + characters, and either frequently repeated or sparse. + {compression_options} + + value_labels : dict of dicts + Dictionary containing columns as keys and dictionaries of column value + to labels as values. The combined length of all labels for a single + variable must be 32,000 characters or smaller. + + Returns + ------- + writer : StataWriter117 instance + The StataWriter117 instance has a write_file method, which will + write the file to the given `fname`. + + Raises + ------ + NotImplementedError + * If datetimes contain timezone information + ValueError + * Columns listed in convert_dates are neither datetime64[ns] + or datetime + * Column dtype is not representable in Stata + * Column listed in convert_dates is not in DataFrame + * Categorical label contains more than 32,000 characters + + Examples + -------- + >>> data = pd.DataFrame([[1.0, 1, "a"]], columns=["a", "b", "c"]) + >>> writer = pd.io.stata.StataWriter117("./data_file.dta", data) + >>> writer.write_file() + + Directly write a zip file + >>> compression = {"method": "zip", "archive_name": "data_file.dta"} + >>> writer = pd.io.stata.StataWriter117( + ... "./data_file.zip", data, compression=compression + ... ) + >>> writer.write_file() + + Or with long strings stored in strl format + >>> data = pd.DataFrame( + ... [["A relatively long string"], [""], [""]], columns=["strls"] + ... ) + >>> writer = pd.io.stata.StataWriter117( + ... "./data_file_with_long_strings.dta", data, convert_strl=["strls"] + ... ) + >>> writer.write_file() + """ + + _max_string_length = 2045 + _dta_version = 117 + + def __init__( + self, + fname: FilePath | WriteBuffer[bytes], + data: DataFrame, + convert_dates: dict[Hashable, str] | None = None, + write_index: bool = True, + byteorder: str | None = None, + time_stamp: datetime | None = None, + data_label: str | None = None, + variable_labels: dict[Hashable, str] | None = None, + convert_strl: Sequence[Hashable] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, + *, + value_labels: dict[Hashable, dict[float, str]] | None = None, + ) -> None: + # Copy to new list since convert_strl might be modified later + self._convert_strl: list[Hashable] = [] + if convert_strl is not None: + self._convert_strl.extend(convert_strl) + + super().__init__( + fname, + data, + convert_dates, + write_index, + byteorder=byteorder, + time_stamp=time_stamp, + data_label=data_label, + variable_labels=variable_labels, + value_labels=value_labels, + compression=compression, + storage_options=storage_options, + ) + self._map: dict[str, int] = {} + self._strl_blob = b"" + + @staticmethod + def _tag(val: str | bytes, tag: str) -> bytes: + """Surround val with """ + if isinstance(val, str): + val = bytes(val, "utf-8") + return bytes("<" + tag + ">", "utf-8") + val + bytes("", "utf-8") + + def _update_map(self, tag: str) -> None: + """Update map location for tag with file position""" + assert self.handles.handle is not None + self._map[tag] = self.handles.handle.tell() + + def _write_header( + self, + data_label: str | None = None, + time_stamp: datetime | None = None, + ) -> None: + """Write the file header""" + byteorder = self._byteorder + self._write_bytes(bytes("", "utf-8")) + bio = BytesIO() + # ds_format - 117 + bio.write(self._tag(bytes(str(self._dta_version), "utf-8"), "release")) + # byteorder + bio.write(self._tag((byteorder == ">" and "MSF") or "LSF", "byteorder")) + # number of vars, 2 bytes in 117 and 118, 4 byte in 119 + nvar_type = "H" if self._dta_version <= 118 else "I" + bio.write(self._tag(struct.pack(byteorder + nvar_type, self.nvar), "K")) + # 117 uses 4 bytes, 118 uses 8 + nobs_size = "I" if self._dta_version == 117 else "Q" + bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N")) + # data label 81 bytes, char, null terminated + label = data_label[:80] if data_label is not None else "" + encoded_label = label.encode(self._encoding) + label_size = "B" if self._dta_version == 117 else "H" + label_len = struct.pack(byteorder + label_size, len(encoded_label)) + encoded_label = label_len + encoded_label + bio.write(self._tag(encoded_label, "label")) + # time stamp, 18 bytes, char, null terminated + # format dd Mon yyyy hh:mm + if time_stamp is None: + time_stamp = datetime.now() + elif not isinstance(time_stamp, datetime): + raise ValueError("time_stamp should be datetime type") + # Avoid locale-specific month conversion + months = [ + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", + ] + month_lookup = {i + 1: month for i, month in enumerate(months)} + ts = ( + time_stamp.strftime("%d ") + + month_lookup[time_stamp.month] + + time_stamp.strftime(" %Y %H:%M") + ) + # '\x11' added due to inspection of Stata file + stata_ts = b"\x11" + bytes(ts, "utf-8") + bio.write(self._tag(stata_ts, "timestamp")) + self._write_bytes(self._tag(bio.getvalue(), "header")) + + def _write_map(self) -> None: + """ + Called twice during file write. The first populates the values in + the map with 0s. The second call writes the final map locations when + all blocks have been written. + """ + if not self._map: + self._map = { + "stata_data": 0, + "map": self.handles.handle.tell(), + "variable_types": 0, + "varnames": 0, + "sortlist": 0, + "formats": 0, + "value_label_names": 0, + "variable_labels": 0, + "characteristics": 0, + "data": 0, + "strls": 0, + "value_labels": 0, + "stata_data_close": 0, + "end-of-file": 0, + } + # Move to start of map + self.handles.handle.seek(self._map["map"]) + bio = BytesIO() + for val in self._map.values(): + bio.write(struct.pack(self._byteorder + "Q", val)) + self._write_bytes(self._tag(bio.getvalue(), "map")) + + def _write_variable_types(self) -> None: + self._update_map("variable_types") + bio = BytesIO() + for typ in self.typlist: + bio.write(struct.pack(self._byteorder + "H", typ)) + self._write_bytes(self._tag(bio.getvalue(), "variable_types")) + + def _write_varnames(self) -> None: + self._update_map("varnames") + bio = BytesIO() + # 118 scales by 4 to accommodate utf-8 data worst case encoding + vn_len = 32 if self._dta_version == 117 else 128 + for name in self.varlist: + name = self._null_terminate_str(name) + name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1) + bio.write(name) + self._write_bytes(self._tag(bio.getvalue(), "varnames")) + + def _write_sortlist(self) -> None: + self._update_map("sortlist") + sort_size = 2 if self._dta_version < 119 else 4 + self._write_bytes(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist")) + + def _write_formats(self) -> None: + self._update_map("formats") + bio = BytesIO() + fmt_len = 49 if self._dta_version == 117 else 57 + for fmt in self.fmtlist: + bio.write(_pad_bytes_new(fmt.encode(self._encoding), fmt_len)) + self._write_bytes(self._tag(bio.getvalue(), "formats")) + + def _write_value_label_names(self) -> None: + self._update_map("value_label_names") + bio = BytesIO() + # 118 scales by 4 to accommodate utf-8 data worst case encoding + vl_len = 32 if self._dta_version == 117 else 128 + for i in range(self.nvar): + # Use variable name when categorical + name = "" # default name + if self._has_value_labels[i]: + name = self.varlist[i] + name = self._null_terminate_str(name) + encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1) + bio.write(encoded_name) + self._write_bytes(self._tag(bio.getvalue(), "value_label_names")) + + def _write_variable_labels(self) -> None: + # Missing labels are 80 blank characters plus null termination + self._update_map("variable_labels") + bio = BytesIO() + # 118 scales by 4 to accommodate utf-8 data worst case encoding + vl_len = 80 if self._dta_version == 117 else 320 + blank = _pad_bytes_new("", vl_len + 1) + + if self._variable_labels is None: + for _ in range(self.nvar): + bio.write(blank) + self._write_bytes(self._tag(bio.getvalue(), "variable_labels")) + return + + for col in self.data: + if col in self._variable_labels: + label = self._variable_labels[col] + if len(label) > 80: + raise ValueError("Variable labels must be 80 characters or fewer") + try: + encoded = label.encode(self._encoding) + except UnicodeEncodeError as err: + raise ValueError( + "Variable labels must contain only characters that " + f"can be encoded in {self._encoding}" + ) from err + + bio.write(_pad_bytes_new(encoded, vl_len + 1)) + else: + bio.write(blank) + self._write_bytes(self._tag(bio.getvalue(), "variable_labels")) + + def _write_characteristics(self) -> None: + self._update_map("characteristics") + self._write_bytes(self._tag(b"", "characteristics")) + + def _write_data(self, records: np.rec.recarray) -> None: + self._update_map("data") + self._write_bytes(b"") + self._write_bytes(records.tobytes()) + self._write_bytes(b"") + + def _write_strls(self) -> None: + self._update_map("strls") + self._write_bytes(self._tag(self._strl_blob, "strls")) + + def _write_expansion_fields(self) -> None: + """No-op in dta 117+""" + + def _write_value_labels(self) -> None: + self._update_map("value_labels") + bio = BytesIO() + for vl in self._value_labels: + lab = vl.generate_value_label(self._byteorder) + lab = self._tag(lab, "lbl") + bio.write(lab) + self._write_bytes(self._tag(bio.getvalue(), "value_labels")) + + def _write_file_close_tag(self) -> None: + self._update_map("stata_data_close") + self._write_bytes(bytes("", "utf-8")) + self._update_map("end-of-file") + + def _update_strl_names(self) -> None: + """ + Update column names for conversion to strl if they might have been + changed to comply with Stata naming rules + """ + # Update convert_strl if names changed + for orig, new in self._converted_names.items(): + if orig in self._convert_strl: + idx = self._convert_strl.index(orig) + self._convert_strl[idx] = new + + def _convert_strls(self, data: DataFrame) -> DataFrame: + """ + Convert columns to StrLs if either very large or in the + convert_strl variable + """ + convert_cols = [ + col + for i, col in enumerate(data) + if self.typlist[i] == 32768 or col in self._convert_strl + ] + + if convert_cols: + ssw = StataStrLWriter( + data, convert_cols, version=self._dta_version, byteorder=self._byteorder + ) + tab, new_data = ssw.generate_table() + data = new_data + self._strl_blob = ssw.generate_blob(tab) + return data + + def _set_formats_and_types(self, dtypes: Series) -> None: + self.typlist = [] + self.fmtlist = [] + for col, dtype in dtypes.items(): + force_strl = col in self._convert_strl + fmt = _dtype_to_default_stata_fmt( + dtype, + self.data[col], + dta_version=self._dta_version, + force_strl=force_strl, + ) + self.fmtlist.append(fmt) + self.typlist.append( + _dtype_to_stata_type_117(dtype, self.data[col], force_strl) + ) + + +class StataWriterUTF8(StataWriter117): + """ + Stata binary dta file writing in Stata 15 (118) and 16 (119) formats + + DTA 118 and 119 format files support unicode string data (both fixed + and strL) format. Unicode is also supported in value labels, variable + labels and the dataset label. Format 119 is automatically used if the + file contains more than 32,767 variables. + + Parameters + ---------- + fname : path (string), buffer or path object + string, pathlib.Path or + object implementing a binary write() functions. If using a buffer + then the buffer will not be automatically closed after the file + is written. + data : DataFrame + Input to save + convert_dates : dict, default None + Dictionary mapping columns containing datetime types to stata internal + format to use when writing the dates. Options are 'tc', 'td', 'tm', + 'tw', 'th', 'tq', 'ty'. Column can be either an integer or a name. + Datetime columns that do not have a conversion type specified will be + converted to 'tc'. Raises NotImplementedError if a datetime column has + timezone information + write_index : bool, default True + Write the index to Stata dataset. + byteorder : str, default None + Can be ">", "<", "little", or "big". default is `sys.byteorder` + time_stamp : datetime, default None + A datetime to use as file creation date. Default is the current time + data_label : str, default None + A label for the data set. Must be 80 characters or smaller. + variable_labels : dict, default None + Dictionary containing columns as keys and variable labels as values. + Each label must be 80 characters or smaller. + convert_strl : list, default None + List of columns names to convert to Stata StrL format. Columns with + more than 2045 characters are automatically written as StrL. + Smaller columns can be converted by including the column name. Using + StrLs can reduce output file size when strings are longer than 8 + characters, and either frequently repeated or sparse. + version : int, default None + The dta version to use. By default, uses the size of data to determine + the version. 118 is used if data.shape[1] <= 32767, and 119 is used + for storing larger DataFrames. + {compression_options} + + value_labels : dict of dicts + Dictionary containing columns as keys and dictionaries of column value + to labels as values. The combined length of all labels for a single + variable must be 32,000 characters or smaller. + + Returns + ------- + StataWriterUTF8 + The instance has a write_file method, which will write the file to the + given `fname`. + + Raises + ------ + NotImplementedError + * If datetimes contain timezone information + ValueError + * Columns listed in convert_dates are neither datetime64[ns] + or datetime + * Column dtype is not representable in Stata + * Column listed in convert_dates is not in DataFrame + * Categorical label contains more than 32,000 characters + + Examples + -------- + Using Unicode data and column names + + >>> from pandas.io.stata import StataWriterUTF8 + >>> data = pd.DataFrame([[1.0, 1, "ᴬ"]], columns=["a", "β", "ĉ"]) + >>> writer = StataWriterUTF8("./data_file.dta", data) + >>> writer.write_file() + + Directly write a zip file + >>> compression = {"method": "zip", "archive_name": "data_file.dta"} + >>> writer = StataWriterUTF8("./data_file.zip", data, compression=compression) + >>> writer.write_file() + + Or with long strings stored in strl format + + >>> data = pd.DataFrame( + ... [["ᴀ relatively long ŝtring"], [""], [""]], columns=["strls"] + ... ) + >>> writer = StataWriterUTF8( + ... "./data_file_with_long_strings.dta", data, convert_strl=["strls"] + ... ) + >>> writer.write_file() + """ + + _encoding: Literal["utf-8"] = "utf-8" + + def __init__( + self, + fname: FilePath | WriteBuffer[bytes], + data: DataFrame, + convert_dates: dict[Hashable, str] | None = None, + write_index: bool = True, + byteorder: str | None = None, + time_stamp: datetime | None = None, + data_label: str | None = None, + variable_labels: dict[Hashable, str] | None = None, + convert_strl: Sequence[Hashable] | None = None, + version: int | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, + *, + value_labels: dict[Hashable, dict[float, str]] | None = None, + ) -> None: + if version is None: + version = 118 if data.shape[1] <= 32767 else 119 + elif version not in (118, 119): + raise ValueError("version must be either 118 or 119.") + elif version == 118 and data.shape[1] > 32767: + raise ValueError( + "You must use version 119 for data sets containing more than" + "32,767 variables" + ) + + super().__init__( + fname, + data, + convert_dates=convert_dates, + write_index=write_index, + byteorder=byteorder, + time_stamp=time_stamp, + data_label=data_label, + variable_labels=variable_labels, + value_labels=value_labels, + convert_strl=convert_strl, + compression=compression, + storage_options=storage_options, + ) + # Override version set in StataWriter117 init + self._dta_version = version + + def _validate_variable_name(self, name: str) -> str: + """ + Validate variable names for Stata export. + + Parameters + ---------- + name : str + Variable name + + Returns + ------- + str + The validated name with invalid characters replaced with + underscores. + + Notes + ----- + Stata 118+ support most unicode characters. The only limitation is in + the ascii range where the characters supported are a-z, A-Z, 0-9 and _. + """ + # High code points appear to be acceptable + for c in name: + if ( + ( + ord(c) < 128 + and (c < "A" or c > "Z") + and (c < "a" or c > "z") + and (c < "0" or c > "9") + and c != "_" + ) + or 128 <= ord(c) < 192 + or c in {"×", "÷"} # noqa: RUF001 + ): + name = name.replace(c, "_") + + return name diff --git a/pandas/io/xml.py b/pandas/io/xml.py new file mode 100644 index 0000000000000000000000000000000000000000..96a2c6cc5d126c02e91faeef24fc100dd028c1f2 --- /dev/null +++ b/pandas/io/xml.py @@ -0,0 +1,1155 @@ +""" +:mod:``pandas.io.xml`` is a module for reading XML. +""" + +from __future__ import annotations + +import io +from os import PathLike +from typing import ( + TYPE_CHECKING, + Any, +) + +from pandas._libs import lib +from pandas.compat._optional import import_optional_dependency +from pandas.errors import ( + AbstractMethodError, + ParserError, +) +from pandas.util._decorators import set_module +from pandas.util._validators import check_dtype_backend + +from pandas.core.dtypes.common import is_list_like + +from pandas.io.common import ( + get_handle, + infer_compression, + is_fsspec_url, + is_url, + stringify_path, +) +from pandas.io.parsers import TextParser + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Sequence, + ) + from xml.etree.ElementTree import Element + + from lxml import etree + + from pandas._typing import ( + CompressionOptions, + ConvertersArg, + DtypeArg, + DtypeBackend, + FilePath, + ParseDatesArg, + ReadBuffer, + StorageOptions, + XMLParsers, + ) + + from pandas import DataFrame + + +class _XMLFrameParser: + """ + Internal subclass to parse XML into DataFrames. + + Parameters + ---------- + path_or_buffer : a valid JSON ``str``, path object or file-like object + Any valid string path is acceptable. The string could be a URL. Valid + URL schemes include http, ftp, s3, and file. + + xpath : str or regex + The ``XPath`` expression to parse required set of nodes for + migration to :class:`~pandas.DataFrame`. ``etree`` supports limited ``XPath``. + + namespaces : dict + The namespaces defined in XML document (``xmlns:namespace='URI'``) + as dicts with key being namespace and value the URI. + + elems_only : bool + Parse only the child elements at the specified ``xpath``. + + attrs_only : bool + Parse only the attributes at the specified ``xpath``. + + names : list + Column names for :class:`~pandas.DataFrame` of parsed XML data. + + dtype : dict + Data type for data or columns. E.g. {{'a': np.float64, + 'b': np.int32, 'c': 'Int64'}} + + converters : dict, optional + Dict of functions for converting values in certain columns. Keys can + either be integers or column labels. + + parse_dates : bool or list of int or names or list of lists or dict + Converts either index or select columns to datetimes + + encoding : str + Encoding of xml object or document. + + stylesheet : str or file-like + URL, file, file-like object, or a raw string containing XSLT, + ``etree`` does not support XSLT but retained for consistency. + + iterparse : dict, optional + Dict with row element as key and list of descendant elements + and/or attributes as value to be retrieved in iterparsing of + XML document. + + compression : str or dict, default 'infer' + For on-the-fly decompression of on-disk data. If 'infer' and + 'path_or_buffer' is path-like, then detect compression from the + following extensions: '.gz', '.bz2', '.zip', '.xz', '.zst', '.tar', + '.tar.gz', '.tar.xz' or '.tar.bz2' (otherwise no compression). + If using 'zip' or 'tar', the ZIP file must contain only one data + file to be read in. Set to ``None`` for no decompression. + Can also be a dict with key ``'method'`` set to one of + {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} + and other key-value pairs are forwarded to ``zipfile.ZipFile``, + ``gzip.GzipFile``, ``bz2.BZ2File``, ``zstandard.ZstdDecompressor``, + ``lzma.LZMAFile`` or ``tarfile.TarFile``, respectively. + As an example, the following could be passed for Zstandard + decompression using a custom compression dictionary: + ``compression={'method': 'zstd', 'dict_data': my_compression_dict}``. + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, + e.g. host, port, username, password, etc. For HTTP(S) URLs the + key-value pairs are forwarded to ``urllib.request.Request`` as header + options. For other URLs (e.g. starting with "s3://", and "gcs://") + the key-value pairs are forwarded to ``fsspec.open``. Please see + ``fsspec`` and ``urllib`` for more details, and for more examples on + storage options refer `here `_. + + See also + -------- + pandas.io.xml._EtreeFrameParser + pandas.io.xml._LxmlFrameParser + + Notes + ----- + To subclass this class effectively you must override the following methods:` + * :func:`parse_data` + * :func:`_parse_nodes` + * :func:`_iterparse_nodes` + * :func:`_parse_doc` + * :func:`_validate_names` + * :func:`_validate_path` + + + See each method's respective documentation for details on their + functionality. + """ + + def __init__( + self, + path_or_buffer: FilePath | ReadBuffer[bytes] | ReadBuffer[str], + xpath: str, + namespaces: dict[str, str] | None, + elems_only: bool, + attrs_only: bool, + names: Sequence[str] | None, + dtype: DtypeArg | None, + converters: ConvertersArg | None, + parse_dates: ParseDatesArg | None, + encoding: str | None, + stylesheet: FilePath | ReadBuffer[bytes] | ReadBuffer[str] | None, + iterparse: dict[str, list[str]] | None, + compression: CompressionOptions, + storage_options: StorageOptions, + ) -> None: + self.path_or_buffer = path_or_buffer + self.xpath = xpath + self.namespaces = namespaces + self.elems_only = elems_only + self.attrs_only = attrs_only + self.names = names + self.dtype = dtype + self.converters = converters + self.parse_dates = parse_dates + self.encoding = encoding + self.stylesheet = stylesheet + self.iterparse = iterparse + self.compression: CompressionOptions = compression + self.storage_options = storage_options + + def parse_data(self) -> list[dict[str, str | None]]: + """ + Parse xml data. + + This method will call the other internal methods to + validate ``xpath``, names, parse and return specific nodes. + """ + + raise AbstractMethodError(self) + + def _parse_nodes(self, elems: list[Any]) -> list[dict[str, str | None]]: + """ + Parse xml nodes. + + This method will parse the children and attributes of elements + in ``xpath``, conditionally for only elements, only attributes + or both while optionally renaming node names. + + Raises + ------ + ValueError + * If only elements and only attributes are specified. + + Notes + ----- + Namespace URIs will be removed from return node values. Also, + elements with missing children or attributes compared to siblings + will have optional keys filled with None values. + """ + + dicts: list[dict[str, str | None]] + + if self.elems_only and self.attrs_only: + raise ValueError("Either element or attributes can be parsed not both.") + if self.elems_only: + if self.names: + dicts = [ + { + **( + {el.tag: el.text} + if el.text and not el.text.isspace() + else {} + ), + **{ + nm: ch.text if ch.text else None + for nm, ch in zip(self.names, el.findall("*"), strict=True) + }, + } + for el in elems + ] + else: + dicts = [ + {ch.tag: ch.text if ch.text else None for ch in el.findall("*")} + for el in elems + ] + + elif self.attrs_only: + dicts = [ + {k: v if v else None for k, v in el.attrib.items()} for el in elems + ] + + elif self.names: + dicts = [ + { + **el.attrib, + **({el.tag: el.text} if el.text and not el.text.isspace() else {}), + **{ + nm: ch.text if ch.text else None + for nm, ch in zip(self.names, el.findall("*"), strict=False) + }, + } + for el in elems + ] + + else: + dicts = [ + { + **el.attrib, + **({el.tag: el.text} if el.text and not el.text.isspace() else {}), + **{ch.tag: ch.text if ch.text else None for ch in el.findall("*")}, + } + for el in elems + ] + + dicts = [ + {k.split("}")[1] if "}" in k else k: v for k, v in d.items()} for d in dicts + ] + + keys = list(dict.fromkeys([k for d in dicts for k in d.keys()])) + dicts = [{k: d[k] if k in d.keys() else None for k in keys} for d in dicts] + + if self.names: + dicts = [dict(zip(self.names, d.values(), strict=True)) for d in dicts] + + return dicts + + def _iterparse_nodes(self, iterparse: Callable) -> list[dict[str, str | None]]: + """ + Iterparse xml nodes. + + This method will read in local disk, decompressed XML files for elements + and underlying descendants using iterparse, a method to iterate through + an XML tree without holding entire XML tree in memory. + + Raises + ------ + TypeError + * If ``iterparse`` is not a dict or its dict value is not list-like. + ParserError + * If ``path_or_buffer`` is not a physical file on disk or file-like object. + * If no data is returned from selected items in ``iterparse``. + + Notes + ----- + Namespace URIs will be removed from return node values. Also, + elements with missing children or attributes in submitted list + will have optional keys filled with None values. + """ + + dicts: list[dict[str, str | None]] = [] + row: dict[str, str | None] | None = None + + if not isinstance(self.iterparse, dict): + raise TypeError( + f"{type(self.iterparse).__name__} is not a valid type for iterparse" + ) + + row_node = next(iter(self.iterparse.keys())) if self.iterparse else "" + if not is_list_like(self.iterparse[row_node]): + raise TypeError( + f"{type(self.iterparse[row_node])} is not a valid type " + "for value in iterparse" + ) + + if (not hasattr(self.path_or_buffer, "read")) and ( + not isinstance(self.path_or_buffer, (str, PathLike)) + or is_url(self.path_or_buffer) + or is_fsspec_url(self.path_or_buffer) + or ( + isinstance(self.path_or_buffer, str) + and self.path_or_buffer.startswith((" list[Any]: + """ + Validate ``xpath``. + + This method checks for syntax, evaluation, or empty nodes return. + + Raises + ------ + SyntaxError + * If xpah is not supported or issues with namespaces. + + ValueError + * If xpah does not return any nodes. + """ + + raise AbstractMethodError(self) + + def _validate_names(self) -> None: + """ + Validate names. + + This method will check if names is a list-like and aligns + with length of parse nodes. + + Raises + ------ + ValueError + * If value is not a list and less then length of nodes. + """ + raise AbstractMethodError(self) + + def _parse_doc( + self, raw_doc: FilePath | ReadBuffer[bytes] | ReadBuffer[str] + ) -> Element | etree._Element: + """ + Build tree from path_or_buffer. + + This method will parse XML object into tree + either from string/bytes or file location. + """ + raise AbstractMethodError(self) + + +class _EtreeFrameParser(_XMLFrameParser): + """ + Internal class to parse XML into DataFrames with the Python + standard library XML module: `xml.etree.ElementTree`. + """ + + def parse_data(self) -> list[dict[str, str | None]]: + from xml.etree.ElementTree import iterparse + + if self.stylesheet is not None: + raise ValueError( + "To use stylesheet, you need lxml installed and selected as parser." + ) + + if self.iterparse is None: + self.xml_doc = self._parse_doc(self.path_or_buffer) + elems = self._validate_path() + + self._validate_names() + + xml_dicts: list[dict[str, str | None]] = ( + self._parse_nodes(elems) + if self.iterparse is None + else self._iterparse_nodes(iterparse) + ) + + return xml_dicts + + def _validate_path(self) -> list[Any]: + """ + Notes + ----- + ``etree`` supports limited ``XPath``. If user attempts a more complex + expression syntax error will raise. + """ + + msg = ( + "xpath does not return any nodes or attributes. " + "Be sure to specify in `xpath` the parent nodes of " + "children and attributes to parse. " + "If document uses namespaces denoted with " + "xmlns, be sure to define namespaces and " + "use them in xpath." + ) + try: + elems = self.xml_doc.findall(self.xpath, namespaces=self.namespaces) + children = [ch for el in elems for ch in el.findall("*")] + attrs = {k: v for el in elems for k, v in el.attrib.items()} + + if elems is None: + raise ValueError(msg) + + if elems is not None: + if self.elems_only and children == []: + raise ValueError(msg) + if self.attrs_only and attrs == {}: + raise ValueError(msg) + if children == [] and attrs == {}: + raise ValueError(msg) + + except (KeyError, SyntaxError) as err: + raise SyntaxError( + "You have used an incorrect or unsupported XPath " + "expression for etree library or you used an " + "undeclared namespace prefix." + ) from err + + return elems + + def _validate_names(self) -> None: + children: list[Any] + + if self.names: + if self.iterparse: + children = self.iterparse[next(iter(self.iterparse))] + else: + parent = self.xml_doc.find(self.xpath, namespaces=self.namespaces) + children = parent.findall("*") if parent is not None else [] + + if is_list_like(self.names): + if len(self.names) < len(children): + raise ValueError( + "names does not match length of child elements in xpath." + ) + else: + raise TypeError( + f"{type(self.names).__name__} is not a valid type for names" + ) + + def _parse_doc( + self, raw_doc: FilePath | ReadBuffer[bytes] | ReadBuffer[str] + ) -> Element: + from xml.etree.ElementTree import ( + XMLParser, + parse, + ) + + handle_data = get_data_from_filepath( + filepath_or_buffer=raw_doc, + encoding=self.encoding, + compression=self.compression, + storage_options=self.storage_options, + ) + + with handle_data as xml_data: + curr_parser = XMLParser(encoding=self.encoding) + document = parse(xml_data, parser=curr_parser) + + return document.getroot() + + +class _LxmlFrameParser(_XMLFrameParser): + """ + Internal class to parse XML into :class:`~pandas.DataFrame` with third-party + full-featured XML library, ``lxml``, that supports + ``XPath`` 1.0 and XSLT 1.0. + """ + + def parse_data(self) -> list[dict[str, str | None]]: + """ + Parse xml data. + + This method will call the other internal methods to + validate ``xpath``, names, optionally parse and run XSLT, + and parse original or transformed XML and return specific nodes. + """ + from lxml.etree import iterparse + + if self.iterparse is None: + self.xml_doc = self._parse_doc(self.path_or_buffer) + + if self.stylesheet: + self.xsl_doc = self._parse_doc(self.stylesheet) + self.xml_doc = self._transform_doc() + + elems = self._validate_path() + + self._validate_names() + + xml_dicts: list[dict[str, str | None]] = ( + self._parse_nodes(elems) + if self.iterparse is None + else self._iterparse_nodes(iterparse) + ) + + return xml_dicts + + def _validate_path(self) -> list[Any]: + msg = ( + "xpath does not return any nodes or attributes. " + "Be sure to specify in `xpath` the parent nodes of " + "children and attributes to parse. " + "If document uses namespaces denoted with " + "xmlns, be sure to define namespaces and " + "use them in xpath." + ) + + elems = self.xml_doc.xpath(self.xpath, namespaces=self.namespaces) + children = [ch for el in elems for ch in el.xpath("*")] + attrs = {k: v for el in elems for k, v in el.attrib.items()} + + if elems == []: + raise ValueError(msg) + + if elems != []: + if self.elems_only and children == []: + raise ValueError(msg) + if self.attrs_only and attrs == {}: + raise ValueError(msg) + if children == [] and attrs == {}: + raise ValueError(msg) + + return elems + + def _validate_names(self) -> None: + children: list[Any] + + if self.names: + if self.iterparse: + children = self.iterparse[next(iter(self.iterparse))] + else: + children = self.xml_doc.xpath( + self.xpath + "[1]/*", namespaces=self.namespaces + ) + + if is_list_like(self.names): + if len(self.names) < len(children): + raise ValueError( + "names does not match length of child elements in xpath." + ) + else: + raise TypeError( + f"{type(self.names).__name__} is not a valid type for names" + ) + + def _parse_doc( + self, raw_doc: FilePath | ReadBuffer[bytes] | ReadBuffer[str] + ) -> etree._Element: + from lxml.etree import ( + XMLParser, + fromstring, + parse, + ) + + handle_data = get_data_from_filepath( + filepath_or_buffer=raw_doc, + encoding=self.encoding, + compression=self.compression, + storage_options=self.storage_options, + ) + + with handle_data as xml_data: + curr_parser = XMLParser(encoding=self.encoding) + + if isinstance(xml_data, io.StringIO): + if self.encoding is None: + raise TypeError( + "Can not pass encoding None when input is StringIO." + ) + + document = fromstring( + xml_data.getvalue().encode(self.encoding), parser=curr_parser + ) + else: + document = parse(xml_data, parser=curr_parser) + + return document + + def _transform_doc(self) -> etree._XSLTResultTree: + """ + Transform original tree using stylesheet. + + This method will transform original xml using XSLT script into + am ideally flatter xml document for easier parsing and migration + to Data Frame. + """ + from lxml.etree import XSLT + + transformer = XSLT(self.xsl_doc) + new_doc = transformer(self.xml_doc) + + return new_doc + + +def get_data_from_filepath( + filepath_or_buffer: FilePath | ReadBuffer[bytes] | ReadBuffer[str], + encoding: str | None, + compression: CompressionOptions, + storage_options: StorageOptions, +): + """ + Extract raw XML data. + + The method accepts two input types: + 1. filepath (string-like) + 2. file-like object (e.g. open file object, StringIO) + """ + filepath_or_buffer = stringify_path(filepath_or_buffer) + with get_handle( + filepath_or_buffer, + "r", + encoding=encoding, + compression=compression, + storage_options=storage_options, + ) as handle_obj: + return ( + preprocess_data(handle_obj.handle.read()) + if hasattr(handle_obj.handle, "read") + else handle_obj.handle + ) + + +def preprocess_data( + data: str | bytes | io.StringIO | io.BytesIO, +) -> io.StringIO | io.BytesIO: + """ + Convert extracted raw data. + + This method will return underlying data of extracted XML content. + The data either has a `read` attribute (e.g. a file object or a + StringIO/BytesIO) or is a string or bytes that is an XML document. + """ + + if isinstance(data, str): + data = io.StringIO(data) + + elif isinstance(data, bytes): + data = io.BytesIO(data) + + return data + + +def _data_to_frame(data: list[dict[str, str | None]], **kwargs) -> DataFrame: + """ + Convert parsed data to Data Frame. + + This method will bind xml dictionary data of keys and values + into named columns of Data Frame using the built-in TextParser + class that build Data Frame and infers specific dtypes. + """ + + tags = next(iter(data)) + nodes = [list(d.values()) for d in data] + + try: + with TextParser(nodes, names=tags, **kwargs) as tp: + return tp.read() + except ParserError as err: + raise ParserError( + "XML document may be too complex for import. " + "Try to flatten document and use distinct " + "element and attribute names." + ) from err + + +def _parse( + path_or_buffer: FilePath | ReadBuffer[bytes] | ReadBuffer[str], + xpath: str, + namespaces: dict[str, str] | None, + elems_only: bool, + attrs_only: bool, + names: Sequence[str] | None, + dtype: DtypeArg | None, + converters: ConvertersArg | None, + parse_dates: ParseDatesArg | None, + encoding: str | None, + parser: XMLParsers, + stylesheet: FilePath | ReadBuffer[bytes] | ReadBuffer[str] | None, + iterparse: dict[str, list[str]] | None, + compression: CompressionOptions, + storage_options: StorageOptions, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, + **kwargs, +) -> DataFrame: + """ + Call internal parsers. + + This method will conditionally call internal parsers: + LxmlFrameParser and/or EtreeParser. + + Raises + ------ + ImportError + * If lxml is not installed if selected as parser. + + ValueError + * If parser is not lxml or etree. + """ + + p: _EtreeFrameParser | _LxmlFrameParser + + if parser == "lxml": + lxml = import_optional_dependency("lxml.etree", errors="ignore") + + if lxml is not None: + p = _LxmlFrameParser( + path_or_buffer, + xpath, + namespaces, + elems_only, + attrs_only, + names, + dtype, + converters, + parse_dates, + encoding, + stylesheet, + iterparse, + compression, + storage_options, + ) + else: + raise ImportError("lxml not found, please install or use the etree parser.") + + elif parser == "etree": + p = _EtreeFrameParser( + path_or_buffer, + xpath, + namespaces, + elems_only, + attrs_only, + names, + dtype, + converters, + parse_dates, + encoding, + stylesheet, + iterparse, + compression, + storage_options, + ) + else: + raise ValueError("Values for parser can only be lxml or etree.") + + data_dicts = p.parse_data() + + return _data_to_frame( + data=data_dicts, + dtype=dtype, + converters=converters, + parse_dates=parse_dates, + dtype_backend=dtype_backend, + **kwargs, + ) + + +@set_module("pandas") +def read_xml( + path_or_buffer: FilePath | ReadBuffer[bytes] | ReadBuffer[str], + *, + xpath: str = "./*", + namespaces: dict[str, str] | None = None, + elems_only: bool = False, + attrs_only: bool = False, + names: Sequence[str] | None = None, + dtype: DtypeArg | None = None, + converters: ConvertersArg | None = None, + parse_dates: ParseDatesArg | None = None, + # encoding can not be None for lxml and StringIO input + encoding: str | None = "utf-8", + parser: XMLParsers = "lxml", + stylesheet: FilePath | ReadBuffer[bytes] | ReadBuffer[str] | None = None, + iterparse: dict[str, list[str]] | None = None, + compression: CompressionOptions = "infer", + storage_options: StorageOptions | None = None, + dtype_backend: DtypeBackend | lib.NoDefault = lib.no_default, +) -> DataFrame: + r""" + Read XML document into a :class:`~pandas.DataFrame` object. + + Parameters + ---------- + path_or_buffer : str, path object, or file-like object + String path, path object (implementing ``os.PathLike[str]``), or file-like + object implementing a ``read()`` function. The string can be a path. + The string can further be a URL. Valid URL schemes + include http, ftp, s3, and file. + + xpath : str, optional, default './\*' + The ``XPath`` to parse required set of nodes for migration to + :class:`~pandas.DataFrame`.``XPath`` should return a collection of elements + and not a single element. Note: The ``etree`` parser supports limited ``XPath`` + expressions. For more complex ``XPath``, use ``lxml`` which requires + installation. + + namespaces : dict, optional + The namespaces defined in XML document as dicts with key being + namespace prefix and value the URI. There is no need to include all + namespaces in XML, only the ones used in ``xpath`` expression. + Note: if XML document uses default namespace denoted as + `xmlns=''` without a prefix, you must assign any temporary + namespace prefix such as 'doc' to the URI in order to parse + underlying nodes and/or attributes. + + elems_only : bool, optional, default False + Parse only the child elements at the specified ``xpath``. By default, + all child elements and non-empty text nodes are returned. + + attrs_only : bool, optional, default False + Parse only the attributes at the specified ``xpath``. + By default, all attributes are returned. + + names : list-like, optional + Column names for DataFrame of parsed XML data. Use this parameter to + rename original element names and distinguish same named elements and + attributes. + + dtype : Type name or dict of column -> type, optional + Data type for data or columns. E.g. {{'a': np.float64, 'b': np.int32, + 'c': 'Int64'}} + Use `str` or `object` together with suitable `na_values` settings + to preserve and not interpret dtype. + If converters are specified, they will be applied INSTEAD + of dtype conversion. + + converters : dict, optional + Dict of functions for converting values in certain columns. Keys can either + be integers or column labels. + + parse_dates : bool or list of int or names or list of lists or dict, default False + Identifiers to parse index or columns to datetime. The behavior is as follows: + + * boolean. If True -> try parsing the index. + * list of int or names. e.g. If [1, 2, 3] -> try parsing columns 1, 2, 3 + each as a separate date column. + * list of lists. e.g. If [[1, 3]] -> combine columns 1 and 3 and parse as + a single date column. + * dict, e.g. {{'foo' : [1, 3]}} -> parse columns 1, 3 as date and call + result 'foo' + + encoding : str, optional, default 'utf-8' + Encoding of XML document. + + parser : {{'lxml','etree'}}, default 'lxml' + Parser module to use for retrieval of data. Only 'lxml' and + 'etree' are supported. With 'lxml' more complex ``XPath`` searches + and ability to use XSLT stylesheet are supported. + + stylesheet : str, path object or file-like object + A URL, file-like object, or a string path containing an XSLT script. + This stylesheet should flatten complex, deeply nested XML documents + for easier parsing. To use this feature you must have ``lxml`` module + installed and specify 'lxml' as ``parser``. The ``xpath`` must + reference nodes of transformed XML document generated after XSLT + transformation and not the original XML document. Only XSLT 1.0 + scripts and not later versions is currently supported. + + iterparse : dict, optional + The nodes or attributes to retrieve in iterparsing of XML document + as a dict with key being the name of repeating element and value being + list of elements or attribute names that are descendants of the repeated + element. Note: If this option is used, it will replace ``xpath`` parsing + and unlike ``xpath``, descendants do not need to relate to each other but can + exist any where in document under the repeating element. This memory- + efficient method should be used for very large XML files (500MB, 1GB, or 5GB+). + For example, ``{{"row_element": ["child_elem", "attr", "grandchild_elem"]}}``. + + compression : str or dict, default 'infer' + For on-the-fly decompression of on-disk data. If 'infer' and + 'path_or_buffer' is path-like, then detect compression from the + following extensions: '.gz', '.bz2', '.zip', '.xz', '.zst', '.tar', + '.tar.gz', '.tar.xz' or '.tar.bz2' (otherwise no compression). + If using 'zip' or 'tar', the ZIP file must contain only one data + file to be read in. Set to ``None`` for no decompression. + Can also be a dict with key ``'method'`` set to one of + {``'zip'``, ``'gzip'``, ``'bz2'``, ``'zstd'``, ``'xz'``, ``'tar'``} + and other key-value pairs are forwarded to ``zipfile.ZipFile``, + ``gzip.GzipFile``, ``bz2.BZ2File``, ``zstandard.ZstdDecompressor``, + ``lzma.LZMAFile`` or ``tarfile.TarFile``, respectively. + As an example, the following could be passed for Zstandard + decompression using a custom compression dictionary: + ``compression={'method': 'zstd', 'dict_data': my_compression_dict}``. + + storage_options : dict, optional + Extra options that make sense for a particular storage connection, + e.g. host, port, username, password, etc. For HTTP(S) URLs the + key-value pairs are forwarded to ``urllib.request.Request`` as header + options. For other URLs (e.g. starting with "s3://", and "gcs://") + the key-value pairs are forwarded to ``fsspec.open``. Please see + ``fsspec`` and ``urllib`` for more details, and for more examples on + storage options refer `here `_. + + dtype_backend : {{'numpy_nullable', 'pyarrow'}} + Back-end data type applied to the resultant :class:`DataFrame` + (still experimental). If not specified, the default behavior + is to not use nullable data types. If specified, the behavior + is as follows: + + * ``"numpy_nullable"``: returns nullable-dtype-backed :class:`DataFrame` + * ``"pyarrow"``: returns pyarrow-backed nullable + :class:`ArrowDtype` :class:`DataFrame` + + .. versionadded:: 2.0 + + Returns + ------- + df + A DataFrame. + + See Also + -------- + read_json : Convert a JSON string to pandas object. + read_html : Read HTML tables into a list of DataFrame objects. + + Notes + ----- + This method is best designed to import shallow XML documents in + following format which is the ideal fit for the two-dimensions of a + ``DataFrame`` (row by column). :: + + + + data + data + data + ... + + + ... + + ... + + + As a file format, XML documents can be designed any way including + layout of elements and attributes as long as it conforms to W3C + specifications. Therefore, this method is a convenience handler for + a specific flatter design and not all possible XML structures. + + However, for more complex XML documents, ``stylesheet`` allows you to + temporarily redesign original document with XSLT (a special purpose + language) for a flatter version for migration to a DataFrame. + + This function will *always* return a single :class:`DataFrame` or raise + exceptions due to issues with XML document, ``xpath``, or other + parameters. + + See the :ref:`read_xml documentation in the IO section of the docs + ` for more information in using this method to parse XML + files to DataFrames. + + Examples + -------- + >>> from io import StringIO + >>> xml = ''' + ... + ... + ... square + ... 360 + ... 4.0 + ... + ... + ... circle + ... 360 + ... + ... + ... + ... triangle + ... 180 + ... 3.0 + ... + ... ''' + + >>> df = pd.read_xml(StringIO(xml)) + >>> df + shape degrees sides + 0 square 360 4.0 + 1 circle 360 NaN + 2 triangle 180 3.0 + + >>> xml = ''' + ... + ... + ... + ... + ... ''' + + >>> df = pd.read_xml(StringIO(xml), xpath=".//row") + >>> df + shape degrees sides + 0 square 360 4.0 + 1 circle 360 NaN + 2 triangle 180 3.0 + + >>> xml = ''' + ... + ... + ... square + ... 360 + ... 4.0 + ... + ... + ... circle + ... 360 + ... + ... + ... + ... triangle + ... 180 + ... 3.0 + ... + ... ''' + + >>> df = pd.read_xml( + ... StringIO(xml), + ... xpath="//doc:row", + ... namespaces={"doc": "https://example.com"}, + ... ) + >>> df + shape degrees sides + 0 square 360 4.0 + 1 circle 360 NaN + 2 triangle 180 3.0 + + >>> xml_data = ''' + ... + ... + ... 0 + ... 1 + ... 2.5 + ... True + ... a + ... 2019-12-31 00:00:00 + ... + ... + ... 1 + ... 4.5 + ... False + ... b + ... 2019-12-31 00:00:00 + ... + ... + ... ''' + + >>> df = pd.read_xml( + ... StringIO(xml_data), dtype_backend="numpy_nullable", parse_dates=["e"] + ... ) + >>> df + index a b c d e + 0 0 1 2.5 True a 2019-12-31 + 1 1 4.5 False b 2019-12-31 + """ + check_dtype_backend(dtype_backend) + + return _parse( + path_or_buffer=path_or_buffer, + xpath=xpath, + namespaces=namespaces, + elems_only=elems_only, + attrs_only=attrs_only, + names=names, + dtype=dtype, + converters=converters, + parse_dates=parse_dates, + encoding=encoding, + parser=parser, + stylesheet=stylesheet, + iterparse=iterparse, + compression=compression, + storage_options=storage_options, + dtype_backend=dtype_backend, + ) diff --git a/pandas/plotting/__init__.py b/pandas/plotting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..837bfaf82ca272672cb6bd91adb83154e66a508e --- /dev/null +++ b/pandas/plotting/__init__.py @@ -0,0 +1,99 @@ +""" +Plotting public API. + +Authors of third-party plotting backends should implement a module with a +public ``plot(data, kind, **kwargs)``. The parameter `data` will contain +the data structure and can be a `Series` or a `DataFrame`. For example, +for ``df.plot()`` the parameter `data` will contain the DataFrame `df`. +In some cases, the data structure is transformed before being sent to +the backend (see PlotAccessor.__call__ in pandas/plotting/_core.py for +the exact transformations). + +The parameter `kind` will be one of: + +- line +- bar +- barh +- box +- hist +- kde +- area +- pie +- scatter +- hexbin + +See the pandas API reference for documentation on each kind of plot. + +Any other keyword argument is currently assumed to be backend specific, +but some parameters may be unified and added to the signature in the +future (e.g. `title` which should be useful for any backend). + +Currently, all the Matplotlib functions in pandas are accessed through +the selected backend. For example, `pandas.plotting.boxplot` (equivalent +to `DataFrame.boxplot`) is also accessed in the selected backend. This +is expected to change, and the exact API is under discussion. But with +the current version, backends are expected to implement the next functions: + +- plot (describe above, used for `Series.plot` and `DataFrame.plot`) +- hist_series and hist_frame (for `Series.hist` and `DataFrame.hist`) +- boxplot (`pandas.plotting.boxplot(df)` equivalent to `DataFrame.boxplot`) +- boxplot_frame and boxplot_frame_groupby +- register and deregister (register converters for the tick formats) +- Plots not called as `Series` and `DataFrame` methods: + - table + - andrews_curves + - autocorrelation_plot + - bootstrap_plot + - lag_plot + - parallel_coordinates + - radviz + - scatter_matrix + +Use the code in pandas/plotting/_matplotib.py and +https://github.com/pyviz/hvplot as a reference on how to write a backend. + +For the discussion about the API see +https://github.com/pandas-dev/pandas/issues/26747. +""" + +from pandas.plotting._core import ( + PlotAccessor, + boxplot, + boxplot_frame, + boxplot_frame_groupby, + hist_frame, + hist_series, +) +from pandas.plotting._misc import ( + andrews_curves, + autocorrelation_plot, + bootstrap_plot, + deregister as deregister_matplotlib_converters, + lag_plot, + parallel_coordinates, + plot_params, + radviz, + register as register_matplotlib_converters, + scatter_matrix, + table, +) + +__all__ = [ + "PlotAccessor", + "andrews_curves", + "autocorrelation_plot", + "bootstrap_plot", + "boxplot", + "boxplot_frame", + "boxplot_frame_groupby", + "deregister_matplotlib_converters", + "hist_frame", + "hist_series", + "lag_plot", + "parallel_coordinates", + "plot_params", + "radviz", + "register_matplotlib_converters", + "scatter_matrix", + "table", +] diff --git a/pandas/plotting/_core.py b/pandas/plotting/_core.py new file mode 100644 index 0000000000000000000000000000000000000000..e75bb32313b03da9aeee2cfbcd9227e10edcd780 --- /dev/null +++ b/pandas/plotting/_core.py @@ -0,0 +1,2255 @@ +from __future__ import annotations + +import importlib +from typing import ( + TYPE_CHECKING, + Literal, +) + +from pandas._config import get_option + +from pandas.util._decorators import set_module + +from pandas.core.dtypes.common import ( + is_integer, + is_list_like, +) +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCSeries, +) + +from pandas.core.base import PandasObject + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Hashable, + Sequence, + ) + import types + + from matplotlib.axes import Axes + import numpy as np + + from pandas._typing import IndexLabel + + from pandas import ( + DataFrame, + Index, + Series, + ) + from pandas.core.groupby.generic import DataFrameGroupBy + + +def holds_integer(column: Index) -> bool: + return column.dtype.kind in "iu" + + +@set_module("pandas.plotting") +def hist_series( + self: Series, + by=None, + ax=None, + grid: bool = True, + xlabelsize: int | None = None, + xrot: float | None = None, + ylabelsize: int | None = None, + yrot: float | None = None, + figsize: tuple[int, int] | None = None, + bins: int | Sequence[int] = 10, + backend: str | None = None, + legend: bool = False, + **kwargs, +): + """ + Draw histogram of the input series using matplotlib. + + Parameters + ---------- + by : object, optional + If passed, then used to form histograms for separate groups. + ax : matplotlib axis object + If not passed, uses gca(). + grid : bool, default True + Whether to show axis grid lines. + xlabelsize : int, default None + If specified changes the x-axis label size. + xrot : float, default None + Rotation of x axis labels. + ylabelsize : int, default None + If specified changes the y-axis label size. + yrot : float, default None + Rotation of y axis labels. + figsize : tuple, default None + Figure size in inches by default. + bins : int or sequence, default 10 + Number of histogram bins to be used. If an integer is given, bins + 1 + bin edges are calculated and returned. If bins is a sequence, gives + bin edges, including left edge of first bin and right edge of last + bin. In this case, bins is returned unmodified. + backend : str, default None + Backend to use instead of the backend specified in the option + ``plotting.backend``. For instance, 'matplotlib'. Alternatively, to + specify the ``plotting.backend`` for the whole session, set + ``pd.options.plotting.backend``. + legend : bool, default False + Whether to show the legend. + + **kwargs + To be passed to the actual plotting function. + + Returns + ------- + matplotlib.axes.Axes + A histogram plot. + + See Also + -------- + matplotlib.axes.Axes.hist : Plot a histogram using matplotlib. + + Examples + -------- + For Series: + + .. plot:: + :context: close-figs + + >>> lst = ["a", "a", "a", "b", "b", "b"] + >>> ser = pd.Series([1, 2, 2, 4, 6, 6], index=lst) + >>> hist = ser.hist() + + For Groupby: + + .. plot:: + :context: close-figs + + >>> lst = ["a", "a", "a", "b", "b", "b"] + >>> ser = pd.Series([1, 2, 2, 4, 6, 6], index=lst) + >>> hist = ser.groupby(level=0).hist() + """ + plot_backend = _get_plot_backend(backend) + return plot_backend.hist_series( + self, + by=by, + ax=ax, + grid=grid, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + figsize=figsize, + bins=bins, + legend=legend, + **kwargs, + ) + + +@set_module("pandas.plotting") +def hist_frame( + data: DataFrame, + column: IndexLabel | None = None, + by=None, + grid: bool = True, + xlabelsize: int | None = None, + xrot: float | None = None, + ylabelsize: int | None = None, + yrot: float | None = None, + ax=None, + sharex: bool = False, + sharey: bool = False, + figsize: tuple[int, int] | None = None, + layout: tuple[int, int] | None = None, + bins: int | Sequence[int] = 10, + backend: str | None = None, + legend: bool = False, + **kwargs, +): + """ + Make a histogram of the DataFrame's columns. + + A `histogram`_ is a representation of the distribution of data. + This function calls :meth:`matplotlib.pyplot.hist`, on each series in + the DataFrame, resulting in one histogram per column. + + .. _histogram: https://en.wikipedia.org/wiki/Histogram + + Parameters + ---------- + data : DataFrame + The pandas object holding the data. + column : str or sequence, optional + If passed, will be used to limit data to a subset of columns. + by : object, optional + If passed, then used to form histograms for separate groups. + grid : bool, default True + Whether to show axis grid lines. + xlabelsize : int, default None + If specified changes the x-axis label size. + xrot : float, default None + Rotation of x axis labels. For example, a value of 90 displays the + x labels rotated 90 degrees clockwise. + ylabelsize : int, default None + If specified changes the y-axis label size. + yrot : float, default None + Rotation of y axis labels. For example, a value of 90 displays the + y labels rotated 90 degrees clockwise. + ax : Matplotlib axes object, default None + The axes to plot the histogram on. + sharex : bool, default True if ax is None else False + In case subplots=True, share x axis and set some x axis labels to + invisible; defaults to True if ax is None otherwise False if an ax + is passed in. + Note that passing in both an ax and sharex=True will alter all x axis + labels for all subplots in a figure. + sharey : bool, default False + In case subplots=True, share y axis and set some y axis labels to + invisible. + figsize : tuple, optional + The size in inches of the figure to create. Uses the value in + `matplotlib.rcParams` by default. + layout : tuple, optional + Tuple of (rows, columns) for the layout of the histograms. + bins : int or sequence, default 10 + Number of histogram bins to be used. If an integer is given, bins + 1 + bin edges are calculated and returned. If bins is a sequence, gives + bin edges, including left edge of first bin and right edge of last + bin. In this case, bins is returned unmodified. + + backend : str, default None + Backend to use instead of the backend specified in the option + ``plotting.backend``. For instance, 'matplotlib'. Alternatively, to + specify the ``plotting.backend`` for the whole session, set + ``pd.options.plotting.backend``. + + legend : bool, default False + Whether to show the legend. + + **kwargs + All other plotting keyword arguments to be passed to + :meth:`matplotlib.pyplot.hist`. + + Returns + ------- + np.ndarray + 2D NumPy Array of :class:`matplotlib.axes.Axes`. + + See Also + -------- + matplotlib.pyplot.hist : Plot a histogram using matplotlib. + + Examples + -------- + This example draws a histogram based on the length and width of + some animals, displayed in three bins + + .. plot:: + :context: close-figs + + >>> data = { + ... "length": [1.5, 0.5, 1.2, 0.9, 3], + ... "width": [0.7, 0.2, 0.15, 0.2, 1.1], + ... } + >>> index = ["pig", "rabbit", "duck", "chicken", "horse"] + >>> df = pd.DataFrame(data, index=index) + >>> hist = df.hist(bins=3) + """ + plot_backend = _get_plot_backend(backend) + return plot_backend.hist_frame( + data, + column=column, + by=by, + grid=grid, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + ax=ax, + sharex=sharex, + sharey=sharey, + figsize=figsize, + layout=layout, + legend=legend, + bins=bins, + **kwargs, + ) + + +@set_module("pandas.plotting") +def boxplot( + data: DataFrame, + column: str | list[str] | None = None, + by: str | list[str] | None = None, + ax: Axes | None = None, + fontsize: float | str | None = None, + rot: int = 0, + grid: bool = True, + figsize: tuple[float, float] | None = None, + layout: tuple[int, int] | None = None, + return_type: str | None = None, + **kwargs, +): + """ + Make a box plot from DataFrame columns. + + Make a box-and-whisker plot from DataFrame columns, optionally grouped + by some other columns. A box plot is a method for graphically depicting + groups of numerical data through their quartiles. + The box extends from the Q1 to Q3 quartile values of the data, + with a line at the median (Q2). The whiskers extend from the edges + of box to show the range of the data. By default, they extend no more than + `1.5 * IQR (IQR = Q3 - Q1)` from the edges of the box, ending at the farthest + data point within that interval. Outliers are plotted as separate dots. + + For further details see + Wikipedia's entry for `boxplot `_. + + Parameters + ---------- + data : DataFrame + The data to visualize. + column : str or list of str, optional + Column name or list of names, or vector. + Can be any valid input to :meth:`pandas.DataFrame.groupby`. + by : str or array-like, optional + Column in the DataFrame to :meth:`pandas.DataFrame.groupby`. + One box-plot will be done per value of columns in `by`. + ax : object of class matplotlib.axes.Axes, optional + The matplotlib axes to be used by boxplot. + fontsize : float or str + Tick label font size in points or as a string (e.g., `large`). + rot : float, default 0 + The rotation angle of labels (in degrees) + with respect to the screen coordinate system. + grid : bool, default True + Setting this to True will show the grid. + figsize : A tuple (width, height) in inches + The size of the figure to create in matplotlib. + layout : tuple (rows, columns), optional + For example, (3, 5) will display the subplots + using 3 rows and 5 columns, starting from the top-left. + return_type : {'axes', 'dict', 'both'} or None, default 'axes' + The kind of object to return. The default is ``axes``. + + * 'axes' returns the matplotlib axes the boxplot is drawn on. + * 'dict' returns a dictionary whose values are the matplotlib + lines of the boxplot. + * 'both' returns a namedtuple with the axes and dict. + * when grouping with ``by``, a Series mapping columns to + ``return_type`` is returned. + + If ``return_type`` is `None`, a NumPy array + of axes with the same shape as ``layout`` is returned. + + **kwargs + All other plotting keyword arguments to be passed to + :func:`matplotlib.pyplot.boxplot`. + + Returns + ------- + result + See Notes. + + See Also + -------- + Series.plot.hist: Make a histogram. + matplotlib.pyplot.boxplot : Matplotlib equivalent plot. + + Notes + ----- + The return type depends on the `return_type` parameter: + + * 'axes' : object of class matplotlib.axes.Axes + * 'dict' : dict of matplotlib.lines.Line2D objects + * 'both' : a namedtuple with structure (ax, lines) + + For data grouped with ``by``, return a Series of the above or a numpy + array: + + * :class:`~pandas.Series` + * :class:`~numpy.array` (for ``return_type = None``) + + Use ``return_type='dict'`` when you want to tweak the appearance + of the lines after plotting. In this case a dict containing the Lines + making up the boxes, caps, fliers, medians, and whiskers is returned. + + Examples + -------- + + Boxplots can be created for every column in the dataframe + by ``df.boxplot()`` or indicating the columns to be used: + + .. plot:: + :context: close-figs + + >>> np.random.seed(1234) + >>> df = pd.DataFrame( + ... np.random.randn(10, 4), columns=["Col1", "Col2", "Col3", "Col4"] + ... ) + >>> boxplot = df.boxplot(column=["Col1", "Col2", "Col3"]) # doctest: +SKIP + + Boxplots of variables distributions grouped by the values of a third + variable can be created using the option ``by``. For instance: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame(np.random.randn(10, 2), columns=["Col1", "Col2"]) + >>> df["X"] = pd.Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + >>> boxplot = df.boxplot(by="X") + + A list of strings (i.e. ``['X', 'Y']``) can be passed to boxplot + in order to group the data by combination of the variables in the x-axis: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame(np.random.randn(10, 3), columns=["Col1", "Col2", "Col3"]) + >>> df["X"] = pd.Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + >>> df["Y"] = pd.Series(["A", "B", "A", "B", "A", "B", "A", "B", "A", "B"]) + >>> boxplot = df.boxplot(column=["Col1", "Col2"], by=["X", "Y"]) + + The layout of boxplot can be adjusted giving a tuple to ``layout``: + + .. plot:: + :context: close-figs + + >>> boxplot = df.boxplot(column=["Col1", "Col2"], by="X", layout=(2, 1)) + + Additional formatting can be done to the boxplot, like suppressing the grid + (``grid=False``), rotating the labels in the x-axis (i.e. ``rot=45``) + or changing the fontsize (i.e. ``fontsize=15``): + + .. plot:: + :context: close-figs + + >>> boxplot = df.boxplot(grid=False, rot=45, fontsize=15) # doctest: +SKIP + + The parameter ``return_type`` can be used to select the type of element + returned by `boxplot`. When ``return_type='axes'`` is selected, + the matplotlib axes on which the boxplot is drawn are returned: + + >>> boxplot = df.boxplot(column=["Col1", "Col2"], return_type="axes") + >>> type(boxplot) + + + When grouping with ``by``, a Series mapping columns to ``return_type`` + is returned: + + >>> boxplot = df.boxplot(column=["Col1", "Col2"], by="X", return_type="axes") + >>> type(boxplot) + + + If ``return_type`` is `None`, a NumPy array of axes with the same shape + as ``layout`` is returned: + + >>> boxplot = df.boxplot(column=["Col1", "Col2"], by="X", return_type=None) + >>> type(boxplot) + + """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.boxplot( + data, + column=column, + by=by, + ax=ax, + fontsize=fontsize, + rot=rot, + grid=grid, + figsize=figsize, + layout=layout, + return_type=return_type, + **kwargs, + ) + + +@set_module("pandas.plotting") +def boxplot_frame( + self: DataFrame, + column=None, + by=None, + ax=None, + fontsize: int | None = None, + rot: int = 0, + grid: bool = True, + figsize: tuple[float, float] | None = None, + layout=None, + return_type=None, + backend=None, + **kwargs, +): + """ + Make a box plot from DataFrame columns. + + Make a box-and-whisker plot from DataFrame columns, optionally grouped + by some other columns. A box plot is a method for graphically depicting + groups of numerical data through their quartiles. + The box extends from the Q1 to Q3 quartile values of the data, + with a line at the median (Q2). The whiskers extend from the edges + of box to show the range of the data. By default, they extend no more than + `1.5 * IQR (IQR = Q3 - Q1)` from the edges of the box, ending at the farthest + data point within that interval. Outliers are plotted as separate dots. + + For further details see + Wikipedia's entry for `boxplot `_. + + Parameters + ---------- + column : str or list of str, optional + Column name or list of names, or vector. + Can be any valid input to :meth:`pandas.DataFrame.groupby`. + by : str or array-like, optional + Column in the DataFrame to :meth:`pandas.DataFrame.groupby`. + One box-plot will be done per value of columns in `by`. + ax : object of class matplotlib.axes.Axes, optional + The matplotlib axes to be used by boxplot. + fontsize : float or str + Tick label font size in points or as a string (e.g., `large`). + rot : float, default 0 + The rotation angle of labels (in degrees) + with respect to the screen coordinate system. + grid : bool, default True + Setting this to True will show the grid. + figsize : A tuple (width, height) in inches + The size of the figure to create in matplotlib. + layout : tuple (rows, columns), optional + For example, (3, 5) will display the subplots + using 3 rows and 5 columns, starting from the top-left. + return_type : {'axes', 'dict', 'both'} or None, default 'axes' + The kind of object to return. The default is ``axes``. + + * 'axes' returns the matplotlib axes the boxplot is drawn on. + * 'dict' returns a dictionary whose values are the matplotlib + lines of the boxplot. + * 'both' returns a namedtuple with the axes and dict. + * when grouping with ``by``, a Series mapping columns to + ``return_type`` is returned. + + If ``return_type`` is `None`, a NumPy array + of axes with the same shape as ``layout`` is returned. + backend : str, default None + Backend to use instead of the backend specified in the option + ``plotting.backend``. For instance, 'matplotlib'. Alternatively, to + specify the ``plotting.backend`` for the whole session, set + ``pd.options.plotting.backend``. + + **kwargs + All other plotting keyword arguments to be passed to + :func:`matplotlib.pyplot.boxplot`. + + Returns + ------- + result + See Notes. + + See Also + -------- + Series.plot.hist: Make a histogram. + matplotlib.pyplot.boxplot : Matplotlib equivalent plot. + + Notes + ----- + The return type depends on the `return_type` parameter: + + * 'axes' : object of class matplotlib.axes.Axes + * 'dict' : dict of matplotlib.lines.Line2D objects + * 'both' : a namedtuple with structure (ax, lines) + + For data grouped with ``by``, return a Series of the above or a numpy + array: + + * :class:`~pandas.Series` + * :class:`~numpy.array` (for ``return_type = None``) + + Use ``return_type='dict'`` when you want to tweak the appearance + of the lines after plotting. In this case a dict containing the Lines + making up the boxes, caps, fliers, medians, and whiskers is returned. + + Examples + -------- + + Boxplots can be created for every column in the dataframe + by ``df.boxplot()`` or indicating the columns to be used: + + .. plot:: + :context: close-figs + + >>> np.random.seed(1234) + >>> df = pd.DataFrame( + ... np.random.randn(10, 4), columns=["Col1", "Col2", "Col3", "Col4"] + ... ) + >>> boxplot = df.boxplot(column=["Col1", "Col2", "Col3"]) # doctest: +SKIP + + Boxplots of variables distributions grouped by the values of a third + variable can be created using the option ``by``. For instance: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame(np.random.randn(10, 2), columns=["Col1", "Col2"]) + >>> df["X"] = pd.Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + >>> boxplot = df.boxplot(by="X") + + A list of strings (i.e. ``['X', 'Y']``) can be passed to boxplot + in order to group the data by combination of the variables in the x-axis: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame(np.random.randn(10, 3), columns=["Col1", "Col2", "Col3"]) + >>> df["X"] = pd.Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + >>> df["Y"] = pd.Series(["A", "B", "A", "B", "A", "B", "A", "B", "A", "B"]) + >>> boxplot = df.boxplot(column=["Col1", "Col2"], by=["X", "Y"]) + + The layout of boxplot can be adjusted giving a tuple to ``layout``: + + .. plot:: + :context: close-figs + + >>> boxplot = df.boxplot(column=["Col1", "Col2"], by="X", layout=(2, 1)) + + Additional formatting can be done to the boxplot, like suppressing the grid + (``grid=False``), rotating the labels in the x-axis (i.e. ``rot=45``) + or changing the fontsize (i.e. ``fontsize=15``): + + .. plot:: + :context: close-figs + + >>> boxplot = df.boxplot(grid=False, rot=45, fontsize=15) # doctest: +SKIP + + The parameter ``return_type`` can be used to select the type of element + returned by `boxplot`. When ``return_type='axes'`` is selected, + the matplotlib axes on which the boxplot is drawn are returned: + + .. plot:: + :context: close-figs + + >>> boxplot = df.boxplot(column=["Col1", "Col2"], return_type="axes") + >>> type(boxplot) + + + When grouping with ``by``, a Series mapping columns to ``return_type`` + is returned: + + .. plot:: + :context: close-figs + + >>> boxplot = df.boxplot(column=["Col1", "Col2"], by="X", return_type="axes") + >>> type(boxplot) + + + If ``return_type`` is `None`, a NumPy array of axes with the same shape + as ``layout`` is returned: + + .. plot:: + :context: close-figs + + >>> boxplot = df.boxplot(column=["Col1", "Col2"], by="X", return_type=None) + >>> type(boxplot) + + """ + + plot_backend = _get_plot_backend(backend) + return plot_backend.boxplot_frame( + self, + column=column, + by=by, + ax=ax, + fontsize=fontsize, + rot=rot, + grid=grid, + figsize=figsize, + layout=layout, + return_type=return_type, + **kwargs, + ) + + +@set_module("pandas.plotting") +def boxplot_frame_groupby( + grouped: DataFrameGroupBy, + subplots: bool = True, + column=None, + fontsize: int | None = None, + rot: int = 0, + grid: bool = True, + ax=None, + figsize: tuple[float, float] | None = None, + layout=None, + sharex: bool = False, + sharey: bool = True, + backend=None, + **kwargs, +): + """ + Make box plots from DataFrameGroupBy data. + + Parameters + ---------- + grouped : DataFrameGroupBy + The grouped DataFrame object over which to create the box plots. + subplots : bool + * ``False`` - no subplots will be used + * ``True`` - create a subplot for each group. + column : column name or list of names, or vector + Can be any valid input to groupby. + fontsize : float or str + Font size for the labels. + rot : float + Rotation angle of labels (in degrees) on the x-axis. + grid : bool + Whether to show grid lines on the plot. + ax : Matplotlib axis object, default None + The axes on which to draw the plots. If None, uses the current axes. + figsize : tuple of (float, float) + The figure size in inches (width, height). + layout : tuple (optional) + The layout of the plot: (rows, columns). + sharex : bool, default False + Whether x-axes will be shared among subplots. + sharey : bool, default True + Whether y-axes will be shared among subplots. + backend : str, default None + Backend to use instead of the backend specified in the option + ``plotting.backend``. For instance, 'matplotlib'. Alternatively, to + specify the ``plotting.backend`` for the whole session, set + ``pd.options.plotting.backend``. + **kwargs + All other plotting keyword arguments to be passed to + matplotlib's boxplot function. + + Returns + ------- + dict or DataFrame.boxplot return value + If ``subplots=True``, returns a dictionary of group keys to the boxplot + return values. If ``subplots=False``, returns the boxplot return value + of a single DataFrame. + + See Also + -------- + DataFrame.boxplot : Create a box plot from a DataFrame. + Series.plot : Plot a Series. + + Examples + -------- + You can create boxplots for grouped data and show them as separate subplots: + + .. plot:: + :context: close-figs + + >>> import itertools + >>> tuples = [t for t in itertools.product(range(1000), range(4))] + >>> index = pd.MultiIndex.from_tuples(tuples, names=["lvl0", "lvl1"]) + >>> data = np.random.randn(len(index), 4) + >>> df = pd.DataFrame(data, columns=list("ABCD"), index=index) + >>> grouped = df.groupby(level="lvl1") + >>> grouped.boxplot(rot=45, fontsize=12, figsize=(8, 10)) # doctest: +SKIP + + The ``subplots=False`` option shows the boxplots in a single figure. + + .. plot:: + :context: close-figs + + >>> grouped.boxplot(subplots=False, rot=45, fontsize=12) # doctest: +SKIP + """ + plot_backend = _get_plot_backend(backend) + return plot_backend.boxplot_frame_groupby( + grouped, + subplots=subplots, + column=column, + fontsize=fontsize, + rot=rot, + grid=grid, + ax=ax, + figsize=figsize, + layout=layout, + sharex=sharex, + sharey=sharey, + **kwargs, + ) + + +@set_module("pandas.plotting") +class PlotAccessor(PandasObject): + """ + Make plots of Series or DataFrame. + + Uses the backend specified by the + option ``plotting.backend``. By default, matplotlib is used. + + Parameters + ---------- + data : Series or DataFrame + The object for which the method is called. + + Attributes + ---------- + x : label or position, default None + Only used if data is a DataFrame. + y : label, position or list of label, positions, default None + Allows plotting of one column versus another. Only used if data is a + DataFrame. + kind : str + The kind of plot to produce: + + - 'line' : line plot (default) + - 'bar' : vertical bar plot + - 'barh' : horizontal bar plot + - 'hist' : histogram + - 'box' : boxplot + - 'kde' : Kernel Density Estimation plot + - 'density' : same as 'kde' + - 'area' : area plot + - 'pie' : pie plot + - 'scatter' : scatter plot (DataFrame only) + - 'hexbin' : hexbin plot (DataFrame only) + ax : matplotlib axes object, default None + An axes of the current figure. + subplots : bool or sequence of iterables, default False + Whether to group columns into subplots: + + - ``False`` : No subplots will be used + - ``True`` : Make separate subplots for each column. + - sequence of iterables of column labels: Create a subplot for each + group of columns. For example `[('a', 'c'), ('b', 'd')]` will + create 2 subplots: one with columns 'a' and 'c', and one + with columns 'b' and 'd'. Remaining columns that aren't specified + will be plotted in additional subplots (one per column). + + sharex : bool, default True if ax is None else False + In case ``subplots=True``, share x axis and set some x axis labels + to invisible; defaults to True if ax is None otherwise False if + an ax is passed in; Be aware, that passing in both an ax and + ``sharex=True`` will alter all x axis labels for all axis in a figure. + sharey : bool, default False + In case ``subplots=True``, share y axis and set some y axis labels to invisible. + layout : tuple, optional + (rows, columns) for the layout of subplots. + figsize : a tuple (width, height) in inches + Size of a figure object. + use_index : bool, default True + Use index as ticks for x axis. + title : str or list + Title to use for the plot. If a string is passed, print the string + at the top of the figure. If a list is passed and `subplots` is + True, print each item in the list above the corresponding subplot. + grid : bool, default None (matlab style default) + Axis grid lines. + legend : bool or {'reverse'} + Place legend on axis subplots. + style : list or dict + The matplotlib line style per column. + logx : bool or 'sym', default False + Use log scaling or symlog scaling on x axis. + + logy : bool or 'sym' default False + Use log scaling or symlog scaling on y axis. + + loglog : bool or 'sym', default False + Use log scaling or symlog scaling on both x and y axes. + + xticks : sequence + Values to use for the xticks. + yticks : sequence + Values to use for the yticks. + xlim : 2-tuple/list + Set the x limits of the current axes. + ylim : 2-tuple/list + Set the y limits of the current axes. + xlabel : label, optional + Name to use for the xlabel on x-axis. Default uses index name as xlabel, or the + x-column name for planar plots. + + .. versionchanged:: 2.0.0 + + Now applicable to histograms. + + ylabel : label, optional + Name to use for the ylabel on y-axis. Default will show no ylabel, or the + y-column name for planar plots. + + .. versionchanged:: 2.0.0 + + Now applicable to histograms. + + rot : float, default None + Rotation for ticks (xticks for vertical, yticks for horizontal + plots). + fontsize : float, default None + Font size for xticks and yticks. + colormap : str or matplotlib colormap object, default None + Colormap to select colors from. If string, load colormap with that + name from matplotlib. + colorbar : bool, optional + If True, plot colorbar (only relevant for 'scatter' and 'hexbin' + plots). + position : float + Specify relative alignments for bar plot layout. + From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 + (center). + table : bool, Series or DataFrame, default False + If True, draw a table using the data in the DataFrame and the data + will be transposed to meet matplotlib's default layout. + If a Series or DataFrame is passed, use passed data to draw a + table. + yerr : DataFrame, Series, array-like, dict and str + See :ref:`Plotting with Error Bars ` for + detail. + xerr : DataFrame, Series, array-like, dict and str + Equivalent to yerr. + stacked : bool, default False in line and bar plots, and True in area plot + If True, create stacked plot. + secondary_y : bool or sequence, default False + Whether to plot on the secondary y-axis if a list/tuple, which + columns to plot on secondary y-axis. + mark_right : bool, default True + When using a secondary_y axis, automatically mark the column + labels with "(right)" in the legend. + include_bool : bool, default is False + If True, boolean values can be plotted. + backend : str, default None + Backend to use instead of the backend specified in the option + ``plotting.backend``. For instance, 'matplotlib'. Alternatively, to + specify the ``plotting.backend`` for the whole session, set + ``pd.options.plotting.backend``. + **kwargs + Options to pass to matplotlib plotting method. + + Returns + ------- + :class:`matplotlib.axes.Axes` or numpy.ndarray of them + If the backend is not the default matplotlib one, the return value + will be the object returned by the backend. + + See Also + -------- + matplotlib.pyplot.plot : Plot y versus x as lines and/or markers. + DataFrame.hist : Make a histogram. + DataFrame.boxplot : Make a box plot. + DataFrame.plot.scatter : Make a scatter plot with varying marker + point size and color. + DataFrame.plot.hexbin : Make a hexagonal binning plot of + two variables. + DataFrame.plot.kde : Make Kernel Density Estimate plot using + Gaussian kernels. + DataFrame.plot.area : Make a stacked area plot. + DataFrame.plot.bar : Make a bar plot. + DataFrame.plot.barh : Make a horizontal bar plot. + + Notes + ----- + - See matplotlib documentation online for more on this subject + - If `kind` = 'bar' or 'barh', you can specify relative alignments + for bar plot layout by `position` keyword. + From 0 (left/bottom-end) to 1 (right/top-end). Default is 0.5 + (center) + + Examples + -------- + For Series: + + .. plot:: + :context: close-figs + + >>> ser = pd.Series([1, 2, 3, 3]) + >>> plot = ser.plot(kind="hist", title="My plot") + + For DataFrame: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame( + ... { + ... "length": [1.5, 0.5, 1.2, 0.9, 3], + ... "width": [0.7, 0.2, 0.15, 0.2, 1.1], + ... }, + ... index=["pig", "rabbit", "duck", "chicken", "horse"], + ... ) + >>> plot = df.plot(title="DataFrame Plot") + + For SeriesGroupBy: + + .. plot:: + :context: close-figs + + >>> lst = [-1, -2, -3, 1, 2, 3] + >>> ser = pd.Series([1, 2, 2, 4, 6, 6], index=lst) + >>> plot = ser.groupby(lambda x: x > 0).plot(title="SeriesGroupBy Plot") + + For DataFrameGroupBy: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame({"col1": [1, 2, 3, 4], "col2": ["A", "B", "A", "B"]}) + >>> plot = df.groupby("col2").plot(kind="bar", title="DataFrameGroupBy Plot") + """ + + _common_kinds = ("line", "bar", "barh", "kde", "density", "area", "hist", "box") + _series_kinds = ("pie",) + _dataframe_kinds = ("scatter", "hexbin") + _kind_aliases = {"density": "kde"} + _all_kinds = _common_kinds + _series_kinds + _dataframe_kinds + + def __init__(self, data: Series | DataFrame) -> None: + self._parent = data + + @staticmethod + def _get_call_args(backend_name: str, data: Series | DataFrame, args, kwargs): + """ + This function makes calls to this accessor `__call__` method compatible + with the previous `SeriesPlotMethods.__call__` and + `DataFramePlotMethods.__call__`. Those had slightly different + signatures, since `DataFramePlotMethods` accepted `x` and `y` + parameters. + """ + if isinstance(data, ABCSeries): + arg_def = [ + ("kind", "line"), + ("ax", None), + ("figsize", None), + ("use_index", True), + ("title", None), + ("grid", None), + ("legend", False), + ("style", None), + ("logx", False), + ("logy", False), + ("loglog", False), + ("xticks", None), + ("yticks", None), + ("xlim", None), + ("ylim", None), + ("rot", None), + ("fontsize", None), + ("colormap", None), + ("table", False), + ("yerr", None), + ("xerr", None), + ("label", None), + ("secondary_y", False), + ("xlabel", None), + ("ylabel", None), + ] + elif isinstance(data, ABCDataFrame): + arg_def = [ + ("x", None), + ("y", None), + ("kind", "line"), + ("ax", None), + ("subplots", False), + ("sharex", None), + ("sharey", False), + ("layout", None), + ("figsize", None), + ("use_index", True), + ("title", None), + ("grid", None), + ("legend", True), + ("style", None), + ("logx", False), + ("logy", False), + ("loglog", False), + ("xticks", None), + ("yticks", None), + ("xlim", None), + ("ylim", None), + ("rot", None), + ("fontsize", None), + ("colormap", None), + ("table", False), + ("yerr", None), + ("xerr", None), + ("secondary_y", False), + ("xlabel", None), + ("ylabel", None), + ] + else: + raise TypeError( + f"Called plot accessor for type {type(data).__name__}, " + "expected Series or DataFrame" + ) + + if args and isinstance(data, ABCSeries): + positional_args = str(args)[1:-1] + keyword_args = ", ".join( + [ + f"{name}={value!r}" + for (name, _), value in zip(arg_def, args, strict=False) + ] + ) + msg = ( + "`Series.plot()` should not be called with positional " + "arguments, only keyword arguments. The order of " + "positional arguments will change in the future. " + f"Use `Series.plot({keyword_args})` instead of " + f"`Series.plot({positional_args})`." + ) + raise TypeError(msg) + + pos_args = { + name: value for (name, _), value in zip(arg_def, args, strict=False) + } + if backend_name == "pandas.plotting._matplotlib": + kwargs = dict(arg_def, **pos_args, **kwargs) + else: + kwargs = dict(pos_args, **kwargs) + + x = kwargs.pop("x", None) + y = kwargs.pop("y", None) + kind = kwargs.pop("kind", "line") + return x, y, kind, kwargs + + def __call__(self, *args, **kwargs): + plot_backend = _get_plot_backend(kwargs.pop("backend", None)) + + x, y, kind, kwargs = self._get_call_args( + plot_backend.__name__, self._parent, args, kwargs + ) + + kind = self._kind_aliases.get(kind, kind) + + # when using another backend, get out of the way + if plot_backend.__name__ != "pandas.plotting._matplotlib": + return plot_backend.plot(self._parent, x=x, y=y, kind=kind, **kwargs) + + if kind not in self._all_kinds: + raise ValueError( + f"{kind} is not a valid plot kind Valid plot kinds: {self._all_kinds}" + ) + + data = self._parent + + if isinstance(data, ABCSeries): + kwargs["reuse_plot"] = True + + if kind in self._dataframe_kinds: + if isinstance(data, ABCDataFrame): + return plot_backend.plot(data, x=x, y=y, kind=kind, **kwargs) + else: + raise ValueError(f"plot kind {kind} can only be used for data frames") + elif kind in self._series_kinds: + if isinstance(data, ABCDataFrame): + if y is None and kwargs.get("subplots") is False: + raise ValueError( + f"{kind} requires either y column or 'subplots=True'" + ) + if y is not None: + if is_integer(y) and not holds_integer(data.columns): + y = data.columns[y] + # converted to series actually. copy to not modify + data = data[y].copy(deep=False) + data.index.name = y + elif isinstance(data, ABCDataFrame): + data_cols = data.columns + if x is not None: + if is_integer(x) and not holds_integer(data.columns): + x = data_cols[x] + elif not isinstance(data[x], ABCSeries): + raise ValueError("x must be a label or position") + data = data.set_index(x) + if y is not None: + # check if we have y as int or list of ints + int_ylist = is_list_like(y) and all(is_integer(c) for c in y) + int_y_arg = is_integer(y) or int_ylist + if int_y_arg and not holds_integer(data.columns): + y = data_cols[y] + + label_kw = kwargs["label"] if "label" in kwargs else False + for kw in ["xerr", "yerr"]: + if kw in kwargs and ( + isinstance(kwargs[kw], str) or is_integer(kwargs[kw]) + ): + try: + kwargs[kw] = data[kwargs[kw]] + except (IndexError, KeyError, TypeError): + pass + + data = data[y] + + if isinstance(data, ABCSeries): + label_name = label_kw or y + data.name = label_name + else: + # error: Argument 1 to "len" has incompatible type "Any | bool"; + # expected "Sized" [arg-type] + match = is_list_like(label_kw) and len(label_kw) == len(y) # type: ignore[arg-type] + if label_kw and not match: + raise ValueError( + "label should be list-like and same length as y" + ) + label_name = label_kw or data.columns + data.columns = label_name + + return plot_backend.plot(data, kind=kind, **kwargs) + + __call__.__doc__ = __doc__ + + def line( + self, + x: Hashable | None = None, + y: Hashable | None = None, + color: str | Sequence[str] | dict | None = None, + **kwargs, + ) -> PlotAccessor: + """ + Plot Series or DataFrame as lines. + + This function is useful to plot lines using DataFrame's values + as coordinates. + + Parameters + ---------- + x : label or position, optional + Allows plotting of one column versus another. If not specified, + the index of the DataFrame is used. + y : label or position, optional + Allows plotting of one column versus another. If not specified, + all numerical columns are used. + color : str, array-like, or dict, optional + The color for each of the DataFrame's columns. Possible values are: + + - A single color string referred to by name, RGB or RGBA code, + for instance 'red' or '#a98d19'. + + - A sequence of color strings referred to by name, RGB or RGBA + code, which will be used for each column recursively. For + instance ['green','yellow'] each column's line will be filled in + green or yellow, alternatively. If there is only a single column to + be plotted, then only the first color from the color list will be + used. + + - A dict of the form {column name : color}, so that each column will be + colored accordingly. For example, if your columns are called `a` and + `b`, then passing {'a': 'green', 'b': 'red'} will color lines for + column `a` in green and lines for column `b` in red. + + **kwargs + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns + ------- + matplotlib.axes.Axes or np.ndarray of them + An ndarray is returned with one :class:`matplotlib.axes.Axes` + per column when ``subplots=True``. + + See Also + -------- + matplotlib.pyplot.plot : Plot y versus x as lines and/or markers. + + Examples + -------- + + .. plot:: + :context: close-figs + + >>> s = pd.Series([1, 3, 2]) + >>> s.plot.line() # doctest: +SKIP + + .. plot:: + :context: close-figs + + The following example shows the populations for some animals + over the years. + + >>> df = pd.DataFrame( + ... { + ... "pig": [20, 18, 489, 675, 1776], + ... "horse": [4, 25, 281, 600, 1900], + ... }, + ... index=[1990, 1997, 2003, 2009, 2014], + ... ) + >>> lines = df.plot.line() + + .. plot:: + :context: close-figs + + An example with subplots, so an array of axes is returned. + + >>> axes = df.plot.line(subplots=True) + >>> type(axes) + + + .. plot:: + :context: close-figs + + Let's repeat the same example, but specifying colors for + each column (in this case, for each animal). + + >>> axes = df.plot.line( + ... subplots=True, color={"pig": "pink", "horse": "#742802"} + ... ) + + .. plot:: + :context: close-figs + + The following example shows the relationship between both + populations. + + >>> lines = df.plot.line(x="pig", y="horse") + """ + if color is not None: + kwargs["color"] = color + return self(kind="line", x=x, y=y, **kwargs) + + def bar( + self, + x: Hashable | None = None, + y: Hashable | None = None, + color: str | Sequence[str] | dict | None = None, + **kwargs, + ) -> PlotAccessor: + """ + Vertical bar plot. + + A bar plot is a plot that presents categorical data with + rectangular bars with lengths proportional to the values that they + represent. A bar plot shows comparisons among discrete categories. One + axis of the plot shows the specific categories being compared, and the + other axis represents a measured value. + + Parameters + ---------- + x : label or position, optional + Allows plotting of one column versus another. If not specified, + the index of the DataFrame is used. + y : label or position, optional + Allows plotting of one column versus another. If not specified, + all numerical columns are used. + color : str, array-like, or dict, optional + The color for each of the DataFrame's columns. Possible values are: + + - A single color string referred to by name, RGB or RGBA code, + for instance 'red' or '#a98d19'. + + - A sequence of color strings referred to by name, RGB or RGBA + code, which will be used for each column recursively. For + instance ['green','yellow'] each column's bar will be filled in + green or yellow, alternatively. If there is only a single column to + be plotted, then only the first color from the color list will be + used. + + - A dict of the form {column name : color}, so that each column will be + colored accordingly. For example, if your columns are called `a` and + `b`, then passing {'a': 'green', 'b': 'red'} will color bars for + column `a` in green and bars for column `b` in red. + + **kwargs + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns + ------- + matplotlib.axes.Axes or np.ndarray of them + An ndarray is returned with one :class:`matplotlib.axes.Axes` + per column when ``subplots=True``. + + See Also + -------- + DataFrame.plot.barh : Horizontal bar plot. + DataFrame.plot : Make plots of a DataFrame. + matplotlib.pyplot.bar : Make a bar plot with matplotlib. + + Examples + -------- + Basic plot. + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20]}) + >>> ax = df.plot.bar(x="lab", y="val", rot=0) + + Plot a whole dataframe to a bar plot. Each column is assigned a + distinct color, and each row is nested in a group along the + horizontal axis. + + .. plot:: + :context: close-figs + + >>> speed = [0.1, 17.5, 40, 48, 52, 69, 88] + >>> lifespan = [2, 8, 70, 1.5, 25, 12, 28] + >>> index = [ + ... "snail", + ... "pig", + ... "elephant", + ... "rabbit", + ... "giraffe", + ... "coyote", + ... "horse", + ... ] + >>> df = pd.DataFrame({"speed": speed, "lifespan": lifespan}, index=index) + >>> ax = df.plot.bar(rot=0) + + Plot stacked bar charts for the DataFrame + + .. plot:: + :context: close-figs + + >>> ax = df.plot.bar(stacked=True) + + Instead of nesting, the figure can be split by column with + ``subplots=True``. In this case, a :class:`numpy.ndarray` of + :class:`matplotlib.axes.Axes` are returned. + + .. plot:: + :context: close-figs + + >>> axes = df.plot.bar(rot=0, subplots=True) + >>> axes[1].legend(loc=2) # doctest: +SKIP + + If you don't like the default colours, you can specify how you'd + like each column to be colored. + + .. plot:: + :context: close-figs + + >>> axes = df.plot.bar( + ... rot=0, + ... subplots=True, + ... color={"speed": "red", "lifespan": "green"}, + ... ) + >>> axes[1].legend(loc=2) # doctest: +SKIP + + Plot a single column. + + .. plot:: + :context: close-figs + + >>> ax = df.plot.bar(y="speed", rot=0) + + Plot only selected categories for the DataFrame. + + .. plot:: + :context: close-figs + + >>> ax = df.plot.bar(x="lifespan", rot=0) + """ + if color is not None: + kwargs["color"] = color + return self(kind="bar", x=x, y=y, **kwargs) + + def barh( + self, + x: Hashable | None = None, + y: Hashable | None = None, + color: str | Sequence[str] | dict | None = None, + **kwargs, + ) -> PlotAccessor: + """ + Make a horizontal bar plot. + + A horizontal bar plot is a plot that presents quantitative data with + rectangular bars with lengths proportional to the values that they + represent. A bar plot shows comparisons among discrete categories. One + axis of the plot shows the specific categories being compared, and the + other axis represents a measured value. + + Parameters + ---------- + x : label or position, optional + Allows plotting of one column versus another. If not specified, + the index of the DataFrame is used. + y : label or position, optional + Allows plotting of one column versus another. If not specified, + all numerical columns are used. + color : str, array-like, or dict, optional + The color for each of the DataFrame's columns. Possible values are: + + - A single color string referred to by name, RGB or RGBA code, + for instance 'red' or '#a98d19'. + + - A sequence of color strings referred to by name, RGB or RGBA + code, which will be used for each column recursively. For + instance ['green','yellow'] each column's bar will be filled in + green or yellow, alternatively. If there is only a single column to + be plotted, then only the first color from the color list will be + used. + + - A dict of the form {column name : color}, so that each column will be + colored accordingly. For example, if your columns are called `a` and + `b`, then passing {'a': 'green', 'b': 'red'} will color bars for + column `a` in green and bars for column `b` in red. + + **kwargs + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns + ------- + matplotlib.axes.Axes or np.ndarray of them + An ndarray is returned with one :class:`matplotlib.axes.Axes` + per column when ``subplots=True``. + + See Also + -------- + DataFrame.plot.bar : Vertical bar plot. + DataFrame.plot : Make plots of DataFrame using matplotlib. + matplotlib.axes.Axes.bar : Plot a vertical bar plot using matplotlib. + + Examples + -------- + Basic example + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20]}) + >>> ax = df.plot.barh(x="lab", y="val") + + Plot a whole DataFrame to a horizontal bar plot + + .. plot:: + :context: close-figs + + >>> speed = [0.1, 17.5, 40, 48, 52, 69, 88] + >>> lifespan = [2, 8, 70, 1.5, 25, 12, 28] + >>> index = [ + ... "snail", + ... "pig", + ... "elephant", + ... "rabbit", + ... "giraffe", + ... "coyote", + ... "horse", + ... ] + >>> df = pd.DataFrame({"speed": speed, "lifespan": lifespan}, index=index) + >>> ax = df.plot.barh() + + Plot stacked barh charts for the DataFrame + + .. plot:: + :context: close-figs + + >>> ax = df.plot.barh(stacked=True) + + We can specify colors for each column + + .. plot:: + :context: close-figs + + >>> ax = df.plot.barh(color={"speed": "red", "lifespan": "green"}) + + Plot a column of the DataFrame to a horizontal bar plot + + .. plot:: + :context: close-figs + + >>> speed = [0.1, 17.5, 40, 48, 52, 69, 88] + >>> lifespan = [2, 8, 70, 1.5, 25, 12, 28] + >>> index = [ + ... "snail", + ... "pig", + ... "elephant", + ... "rabbit", + ... "giraffe", + ... "coyote", + ... "horse", + ... ] + >>> df = pd.DataFrame({"speed": speed, "lifespan": lifespan}, index=index) + >>> ax = df.plot.barh(y="speed") + + Plot DataFrame versus the desired column + + .. plot:: + :context: close-figs + + >>> speed = [0.1, 17.5, 40, 48, 52, 69, 88] + >>> lifespan = [2, 8, 70, 1.5, 25, 12, 28] + >>> index = [ + ... "snail", + ... "pig", + ... "elephant", + ... "rabbit", + ... "giraffe", + ... "coyote", + ... "horse", + ... ] + >>> df = pd.DataFrame({"speed": speed, "lifespan": lifespan}, index=index) + >>> ax = df.plot.barh(x="lifespan") + """ + if color is not None: + kwargs["color"] = color + return self(kind="barh", x=x, y=y, **kwargs) + + def box(self, by: IndexLabel | None = None, **kwargs) -> PlotAccessor: + r""" + Make a box plot of the DataFrame columns. + + A box plot is a method for graphically depicting groups of numerical + data through their quartiles. + The box extends from the Q1 to Q3 quartile values of the data, + with a line at the median (Q2). The whiskers extend from the edges + of box to show the range of the data. The position of the whiskers + is set by default to 1.5*IQR (IQR = Q3 - Q1) from the edges of the + box. Outlier points are those past the end of the whiskers. + + For further details see Wikipedia's + entry for `boxplot `__. + + A consideration when using this chart is that the box and the whiskers + can overlap, which is very common when plotting small sets of data. + + Parameters + ---------- + by : str or sequence + Column in the DataFrame to group by. + + **kwargs + Additional keywords are documented in + :meth:`DataFrame.plot`. + + Returns + ------- + :class:`matplotlib.axes.Axes` or numpy.ndarray of them + The matplotlib axes containing the box plot. + + See Also + -------- + DataFrame.boxplot: Another method to draw a box plot. + Series.plot.box: Draw a box plot from a Series object. + matplotlib.pyplot.boxplot: Draw a box plot in matplotlib. + + Examples + -------- + Draw a box plot from a DataFrame with four columns of randomly + generated data. + + .. plot:: + :context: close-figs + + >>> data = np.random.randn(25, 4) + >>> df = pd.DataFrame(data, columns=list("ABCD")) + >>> ax = df.plot.box() + + You can also generate groupings if you specify the `by` parameter (which + can take a column name, or a list or tuple of column names): + + .. plot:: + :context: close-figs + + >>> age_list = [8, 10, 12, 14, 72, 74, 76, 78, 20, 25, 30, 35, 60, 85] + >>> df = pd.DataFrame({"gender": list("MMMMMMMMFFFFFF"), "age": age_list}) + >>> ax = df.plot.box(column="age", by="gender", figsize=(10, 8)) + """ + return self(kind="box", by=by, **kwargs) + + def hist( + self, by: IndexLabel | None = None, bins: int = 10, **kwargs + ) -> PlotAccessor: + """ + Draw one histogram of the DataFrame's columns. + + A histogram is a representation of the distribution of data. + This function groups the values of all given Series in the DataFrame + into bins and draws all bins in one :class:`matplotlib.axes.Axes`. + This is useful when the DataFrame's Series are in a similar scale. + + Parameters + ---------- + by : str or sequence, optional + Column in the DataFrame to group by. + bins : int, default 10 + Number of histogram bins to be used. + **kwargs + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns + ------- + :class:`matplotlib.axes.Axes` + Return a histogram plot. + + See Also + -------- + DataFrame.hist : Draw histograms per DataFrame's Series. + Series.hist : Draw a histogram with Series' data. + + Examples + -------- + When we roll a die 6000 times, we expect to get each value around 1000 + times. But when we roll two dice and sum the result, the distribution + is going to be quite different. A histogram illustrates those + distributions. + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame(np.random.randint(1, 7, 6000), columns=["one"]) + >>> df["two"] = df["one"] + np.random.randint(1, 7, 6000) + >>> ax = df.plot.hist(bins=12, alpha=0.5) + + A grouped histogram can be generated by providing the parameter `by` (which + can be a column name, or a list of column names): + + .. plot:: + :context: close-figs + + >>> age_list = [8, 10, 12, 14, 72, 74, 76, 78, 20, 25, 30, 35, 60, 85] + >>> df = pd.DataFrame({"gender": list("MMMMMMMMFFFFFF"), "age": age_list}) + >>> ax = df.plot.hist(column=["age"], by="gender", figsize=(10, 8)) + """ + return self(kind="hist", by=by, bins=bins, **kwargs) + + def kde( + self, + bw_method: Literal["scott", "silverman"] | float | Callable | None = None, + ind: np.ndarray | int | None = None, + weights: np.ndarray | None = None, + **kwargs, + ) -> PlotAccessor: + """ + Generate Kernel Density Estimate plot using Gaussian kernels. + + In statistics, `kernel density estimation`_ (KDE) is a non-parametric + way to estimate the probability density function (PDF) of a random + variable. This function uses Gaussian kernels and includes automatic + bandwidth determination. + + .. _kernel density estimation: + https://en.wikipedia.org/wiki/Kernel_density_estimation + + Parameters + ---------- + bw_method : str, scalar or callable, optional + The method used to calculate the estimator bandwidth. This can be + 'scott', 'silverman', a scalar constant or a callable. + If None (default), 'scott' is used. + See :class:`scipy.stats.gaussian_kde` for more information. + ind : NumPy array or int, optional + Evaluation points for the estimated PDF. If None (default), + 1000 equally spaced points are used. If `ind` is a NumPy array, the + KDE is evaluated at the points passed. If `ind` is an integer, + `ind` number of equally spaced points are used. + weights : NumPy array, optional + Weights of datapoints. This must be the same shape as datapoints. + If None (default), the samples are assumed to be equally weighted. + **kwargs + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns + ------- + matplotlib.axes.Axes or numpy.ndarray of them + The matplotlib axes containing the KDE plot. + + See Also + -------- + scipy.stats.gaussian_kde : Representation of a kernel-density + estimate using Gaussian kernels. This is the function used + internally to estimate the PDF. + + Examples + -------- + Given a Series of points randomly sampled from an unknown + distribution, estimate its PDF using KDE with automatic + bandwidth determination and plot the results, evaluating them at + 1000 equally spaced points (default): + + .. plot:: + :context: close-figs + + >>> s = pd.Series([1, 2, 2.5, 3, 3.5, 4, 5]) + >>> ax = s.plot.kde() + + A scalar bandwidth can be specified. Using a small bandwidth value can + lead to over-fitting, while using a large bandwidth value may result + in under-fitting: + + .. plot:: + :context: close-figs + + >>> ax = s.plot.kde(bw_method=0.3) + + .. plot:: + :context: close-figs + + >>> ax = s.plot.kde(bw_method=3) + + Finally, the `ind` parameter determines the evaluation points for the + plot of the estimated PDF: + + .. plot:: + :context: close-figs + + >>> ax = s.plot.kde(ind=[1, 2, 3, 4, 5]) + + For DataFrame, it works in the same way: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame( + ... { + ... "x": [1, 2, 2.5, 3, 3.5, 4, 5], + ... "y": [4, 4, 4.5, 5, 5.5, 6, 6], + ... } + ... ) + >>> ax = df.plot.kde() + + A scalar bandwidth can be specified. Using a small bandwidth value can + lead to over-fitting, while using a large bandwidth value may result + in under-fitting: + + .. plot:: + :context: close-figs + + >>> ax = df.plot.kde(bw_method=0.3) + + .. plot:: + :context: close-figs + + >>> ax = df.plot.kde(bw_method=3) + + Finally, the `ind` parameter determines the evaluation points for the + plot of the estimated PDF: + + .. plot:: + :context: close-figs + + >>> ax = df.plot.kde(ind=[1, 2, 3, 4, 5, 6]) + """ + return self(kind="kde", bw_method=bw_method, ind=ind, weights=weights, **kwargs) + + density = kde + + def area( + self, + x: Hashable | None = None, + y: Hashable | None = None, + stacked: bool = True, + **kwargs, + ) -> PlotAccessor: + """ + Draw a stacked area plot. + + An area plot displays quantitative data visually. + This function wraps the matplotlib area function. + + Parameters + ---------- + x : label or position, optional + Coordinates for the X axis. By default uses the index. + y : label or position, optional + Column to plot. By default uses all columns. + stacked : bool, default True + Area plots are stacked by default. Set to False to create a + unstacked plot. + **kwargs + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns + ------- + matplotlib.axes.Axes or numpy.ndarray + Area plot, or array of area plots if subplots is True. + + See Also + -------- + DataFrame.plot : Make plots of DataFrame using matplotlib. + + Examples + -------- + Draw an area plot based on basic business metrics: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame( + ... { + ... "sales": [3, 2, 3, 9, 10, 6], + ... "signups": [5, 5, 6, 12, 14, 13], + ... "visits": [20, 42, 28, 62, 81, 50], + ... }, + ... index=pd.date_range( + ... start="2018/01/01", end="2018/07/01", freq="ME" + ... ), + ... ) + >>> ax = df.plot.area() + + Area plots are stacked by default. To produce an unstacked plot, + pass ``stacked=False``: + + .. plot:: + :context: close-figs + + >>> ax = df.plot.area(stacked=False) + + Draw an area plot for a single column: + + .. plot:: + :context: close-figs + + >>> ax = df.plot.area(y="sales") + + Draw with a different `x`: + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame( + ... { + ... "sales": [3, 2, 3], + ... "visits": [20, 42, 28], + ... "day": [1, 2, 3], + ... } + ... ) + >>> ax = df.plot.area(x="day") + """ + return self(kind="area", x=x, y=y, stacked=stacked, **kwargs) + + def pie(self, y: IndexLabel | None = None, **kwargs) -> PlotAccessor: + """ + Generate a pie plot. + + A pie plot is a proportional representation of the numerical data in a + column. This function wraps :meth:`matplotlib.pyplot.pie` for the + specified column. If no column reference is passed and + ``subplots=True`` a pie plot is drawn for each numerical column + independently. + + Parameters + ---------- + y : int or label, optional + Label or position of the column to plot. + If not provided, ``subplots=True`` argument must be passed. + **kwargs + Keyword arguments to pass on to :meth:`DataFrame.plot`. + + Returns + ------- + matplotlib.axes.Axes or np.ndarray of them + A NumPy array is returned when `subplots` is True. + + See Also + -------- + Series.plot.pie : Generate a pie plot for a Series. + DataFrame.plot : Make plots of a DataFrame. + + Examples + -------- + In the example below we have a DataFrame with the information about + planet's mass and radius. We pass the 'mass' column to the + pie function to get a pie plot. + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame( + ... {"mass": [0.330, 4.87, 5.97], "radius": [2439.7, 6051.8, 6378.1]}, + ... index=["Mercury", "Venus", "Earth"], + ... ) + >>> plot = df.plot.pie(y="mass", figsize=(5, 5)) + + .. plot:: + :context: close-figs + + >>> plot = df.plot.pie(subplots=True, figsize=(11, 6)) + """ + if y is not None: + kwargs["y"] = y + if ( + isinstance(self._parent, ABCDataFrame) + and kwargs.get("y", None) is None + and not kwargs.get("subplots", False) + ): + raise ValueError("pie requires either y column or 'subplots=True'") + return self(kind="pie", **kwargs) + + def scatter( + self, + x: Hashable, + y: Hashable, + s: Hashable | Sequence[Hashable] | None = None, + c: Hashable | Sequence[Hashable] | None = None, + **kwargs, + ) -> PlotAccessor: + """ + Create a scatter plot with varying marker point size and color. + + The coordinates of each point are defined by two dataframe columns and + filled circles are used to represent each point. This kind of plot is + useful to see complex correlations between two variables. Points could + be for instance natural 2D coordinates like longitude and latitude in + a map or, in general, any pair of metrics that can be plotted against + each other. + + Parameters + ---------- + x : int or str + The column name or column position to be used as horizontal + coordinates for each point. + y : int or str + The column name or column position to be used as vertical + coordinates for each point. + s : str, scalar or array-like, optional + The size of each point. Possible values are: + + - A string with the name of the column to be used for marker's size. + + - A single scalar so all points have the same size. + + - A sequence of scalars, which will be used for each point's size + recursively. For instance, when passing [2,14] all points size + will be either 2 or 14, alternatively. + + c : str, int or array-like, optional + The color of each point. Possible values are: + + - A single color string referred to by name, RGB or RGBA code, + for instance 'red' or '#a98d19'. + + - A sequence of color strings referred to by name, RGB or RGBA + code, which will be used for each point's color recursively. For + instance ['green','yellow'] all points will be filled in green or + yellow, alternatively. + + - A column name or position whose values will be used to color the + marker points according to a colormap. + + **kwargs + Keyword arguments to pass on to :meth:`DataFrame.plot`. + + Returns + ------- + :class:`matplotlib.axes.Axes` or numpy.ndarray of them + The matplotlib axes containing the scatter plot. + + See Also + -------- + matplotlib.pyplot.scatter : Scatter plot using multiple input data + formats. + + Examples + -------- + Let's see how to draw a scatter plot using coordinates from the values + in a DataFrame's columns. + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame( + ... [ + ... [5.1, 3.5, 0], + ... [4.9, 3.0, 0], + ... [7.0, 3.2, 1], + ... [6.4, 3.2, 1], + ... [5.9, 3.0, 2], + ... ], + ... columns=["length", "width", "species"], + ... ) + >>> ax1 = df.plot.scatter(x="length", y="width", c="DarkBlue") + + And now with the color determined by a column as well. + + .. plot:: + :context: close-figs + + >>> ax2 = df.plot.scatter( + ... x="length", y="width", c="species", colormap="viridis" + ... ) + """ + return self(kind="scatter", x=x, y=y, s=s, c=c, **kwargs) + + def hexbin( + self, + x: Hashable, + y: Hashable, + C: Hashable | None = None, + reduce_C_function: Callable | None = None, + gridsize: int | tuple[int, int] | None = None, + **kwargs, + ) -> PlotAccessor: + """ + Generate a hexagonal binning plot. + + Generate a hexagonal binning plot of `x` versus `y`. If `C` is `None` + (the default), this is a histogram of the number of occurrences + of the observations at ``(x[i], y[i])``. + + If `C` is specified, specifies values at given coordinates + ``(x[i], y[i])``. These values are accumulated for each hexagonal + bin and then reduced according to `reduce_C_function`, + having as default the NumPy's mean function (:meth:`numpy.mean`). + (If `C` is specified, it must also be a 1-D sequence + of the same length as `x` and `y`, or a column label.) + + Parameters + ---------- + x : int or str + The column label or position for x points. + y : int or str + The column label or position for y points. + C : int or str, optional + The column label or position for the value of `(x, y)` point. + reduce_C_function : callable, default `np.mean` + Function of one argument that reduces all the values in a bin to + a single number (e.g. `np.mean`, `np.max`, `np.sum`, `np.std`). + gridsize : int or tuple of (int, int), default 100 + The number of hexagons in the x-direction. + The corresponding number of hexagons in the y-direction is + chosen in a way that the hexagons are approximately regular. + Alternatively, gridsize can be a tuple with two elements + specifying the number of hexagons in the x-direction and the + y-direction. + **kwargs + Additional keyword arguments are documented in + :meth:`DataFrame.plot`. + + Returns + ------- + matplotlib.Axes + The matplotlib ``Axes`` on which the hexbin is plotted. + + See Also + -------- + DataFrame.plot : Make plots of a DataFrame. + matplotlib.pyplot.hexbin : Hexagonal binning plot using matplotlib, + the matplotlib function that is used under the hood. + + Examples + -------- + The following examples are generated with random data from + a normal distribution. + + .. plot:: + :context: close-figs + + >>> n = 10000 + >>> df = pd.DataFrame({"x": np.random.randn(n), "y": np.random.randn(n)}) + >>> ax = df.plot.hexbin(x="x", y="y", gridsize=20) + + The next example uses `C` and `np.sum` as `reduce_C_function`. + Note that `'observations'` values ranges from 1 to 5 but the result + plot shows values up to more than 25. This is because of the + `reduce_C_function`. + + .. plot:: + :context: close-figs + + >>> n = 500 + >>> df = pd.DataFrame( + ... { + ... "coord_x": np.random.uniform(-3, 3, size=n), + ... "coord_y": np.random.uniform(30, 50, size=n), + ... "observations": np.random.randint(1, 5, size=n), + ... } + ... ) + >>> ax = df.plot.hexbin( + ... x="coord_x", + ... y="coord_y", + ... C="observations", + ... reduce_C_function=np.sum, + ... gridsize=10, + ... cmap="viridis", + ... ) + """ + if reduce_C_function is not None: + kwargs["reduce_C_function"] = reduce_C_function + if gridsize is not None: + kwargs["gridsize"] = gridsize + + return self(kind="hexbin", x=x, y=y, C=C, **kwargs) + + +_backends: dict[str, types.ModuleType] = {} + + +def _load_backend(backend: str) -> types.ModuleType: + """ + Load a pandas plotting backend. + + Parameters + ---------- + backend : str + The identifier for the backend. Either an entrypoint item registered + with importlib.metadata, "matplotlib", or a module name. + + Returns + ------- + types.ModuleType + The imported backend. + """ + from importlib.metadata import entry_points + + if backend == "matplotlib": + # Because matplotlib is an optional dependency and first-party backend, + # we need to attempt an import here to raise an ImportError if needed. + try: + module = importlib.import_module("pandas.plotting._matplotlib") + except ImportError: + raise ImportError( + "matplotlib is required for plotting when the " + 'default backend "matplotlib" is selected.' + ) from None + return module + + found_backend = False + + eps = entry_points() + key = "pandas_plotting_backends" + # entry_points lost dict API ~ PY 3.10 + # https://github.com/python/importlib_metadata/issues/298 + if hasattr(eps, "select"): + entry = eps.select(group=key) + else: + # Argument 2 to "get" of "dict" has incompatible type "Tuple[]"; + # expected "EntryPoints" [arg-type] + entry = eps.get(key, ()) # type: ignore[arg-type] + for entry_point in entry: + found_backend = entry_point.name == backend + if found_backend: + module = entry_point.load() + break + + if not found_backend: + # Fall back to unregistered, module name approach. + try: + module = importlib.import_module(backend) + found_backend = True + except ImportError: + # We re-raise later on. + pass + + if found_backend: + if hasattr(module, "plot"): + # Validate that the interface is implemented when the option is set, + # rather than at plot time. + return module + + raise ValueError( + f"Could not find plotting backend '{backend}'. Ensure that you've " + f"installed the package providing the '{backend}' entrypoint, or that " + "the package has a top-level `.plot` method." + ) + + +def _get_plot_backend(backend: str | None = None): + """ + Return the plotting backend to use (e.g. `pandas.plotting._matplotlib`). + + The plotting system of pandas uses matplotlib by default, but the idea here + is that it can also work with other third-party backends. This function + returns the module which provides a top-level `.plot` method that will + actually do the plotting. The backend is specified from a string, which + either comes from the keyword argument `backend`, or, if not specified, from + the option `pandas.options.plotting.backend`. All the rest of the code in + this file uses the backend specified there for the plotting. + + The backend is imported lazily, as matplotlib is a soft dependency, and + pandas can be used without it being installed. + + Notes + ----- + Modifies `_backends` with imported backend as a side effect. + """ + backend_str: str = backend or get_option("plotting.backend") + + if backend_str in _backends: + return _backends[backend_str] + + module = _load_backend(backend_str) + _backends[backend_str] = module + return module diff --git a/pandas/plotting/_matplotlib/__init__.py b/pandas/plotting/_matplotlib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff28868aa003326355f0e3e4b5b7914edb63121c --- /dev/null +++ b/pandas/plotting/_matplotlib/__init__.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pandas.plotting._matplotlib.boxplot import ( + BoxPlot, + boxplot, + boxplot_frame, + boxplot_frame_groupby, +) +from pandas.plotting._matplotlib.converter import ( + deregister, + register, +) +from pandas.plotting._matplotlib.core import ( + AreaPlot, + BarhPlot, + BarPlot, + HexBinPlot, + LinePlot, + PiePlot, + ScatterPlot, +) +from pandas.plotting._matplotlib.hist import ( + HistPlot, + KdePlot, + hist_frame, + hist_series, +) +from pandas.plotting._matplotlib.misc import ( + andrews_curves, + autocorrelation_plot, + bootstrap_plot, + lag_plot, + parallel_coordinates, + radviz, + scatter_matrix, +) +from pandas.plotting._matplotlib.tools import table + +if TYPE_CHECKING: + from pandas.plotting._matplotlib.core import MPLPlot + +PLOT_CLASSES: dict[str, type[MPLPlot]] = { + "line": LinePlot, + "bar": BarPlot, + "barh": BarhPlot, + "box": BoxPlot, + "hist": HistPlot, + "kde": KdePlot, + "area": AreaPlot, + "pie": PiePlot, + "scatter": ScatterPlot, + "hexbin": HexBinPlot, +} + + +def plot(data, kind, **kwargs): + # Importing pyplot at the top of the file (before the converters are + # registered) causes problems in matplotlib 2 (converters seem to not + # work) + import matplotlib.pyplot as plt + + if kwargs.pop("reuse_plot", False): + ax = kwargs.get("ax") + if ax is None and len(plt.get_fignums()) > 0: + with plt.rc_context(): + ax = plt.gca() + kwargs["ax"] = getattr(ax, "left_ax", ax) + plot_obj = PLOT_CLASSES[kind](data, **kwargs) + plot_obj.generate() + plt.draw_if_interactive() + return plot_obj.result + + +__all__ = [ + "andrews_curves", + "autocorrelation_plot", + "bootstrap_plot", + "boxplot", + "boxplot_frame", + "boxplot_frame_groupby", + "deregister", + "hist_frame", + "hist_series", + "lag_plot", + "parallel_coordinates", + "plot", + "radviz", + "register", + "scatter_matrix", + "table", +] diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb185c51478f9c892631934f0b818effc5a5c96 --- /dev/null +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -0,0 +1,563 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Literal, + NamedTuple, +) +import warnings + +import matplotlib as mpl +import numpy as np + +from pandas._libs import lib +from pandas.util._decorators import cache_readonly +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.common import is_dict_like +from pandas.core.dtypes.generic import ABCSeries +from pandas.core.dtypes.missing import remove_na_arraylike + +import pandas as pd +import pandas.core.common as com +from pandas.util.version import Version + +from pandas.io.formats.printing import pprint_thing +from pandas.plotting._matplotlib.core import ( + LinePlot, + MPLPlot, +) +from pandas.plotting._matplotlib.groupby import create_iter_data_given_by +from pandas.plotting._matplotlib.style import get_standard_colors +from pandas.plotting._matplotlib.tools import ( + create_subplots, + flatten_axes, + maybe_adjust_figure, +) + +if TYPE_CHECKING: + from collections.abc import Collection + + from matplotlib.axes import Axes + from matplotlib.figure import Figure + from matplotlib.lines import Line2D + + from pandas._typing import MatplotlibColor + + +def _set_ticklabels(ax: Axes, labels: list[str], is_vertical: bool, **kwargs) -> None: + """Set the tick labels of a given axis. + + Due to https://github.com/matplotlib/matplotlib/pull/17266, we need to handle the + case of repeated ticks (due to `FixedLocator`) and thus we duplicate the number of + labels. + """ + ticks = ax.get_xticks() if is_vertical else ax.get_yticks() + if len(ticks) != len(labels): + i, remainder = divmod(len(ticks), len(labels)) + if Version(mpl.__version__) < Version("3.10"): + assert remainder == 0, remainder + labels *= i + if is_vertical: + ax.set_xticklabels(labels, **kwargs) + else: + ax.set_yticklabels(labels, **kwargs) + + +class BoxPlot(LinePlot): + @property + def _kind(self) -> Literal["box"]: + return "box" + + _layout_type = "horizontal" + + _valid_return_types = (None, "axes", "dict", "both") + + class BP(NamedTuple): + # namedtuple to hold results + ax: Axes + lines: dict[str, list[Line2D]] + + def __init__(self, data, return_type: str = "axes", **kwargs) -> None: + if return_type not in self._valid_return_types: + raise ValueError("return_type must be {None, 'axes', 'dict', 'both'}") + + self.return_type = return_type + # Do not call LinePlot.__init__ which may fill nan + MPLPlot.__init__(self, data, **kwargs) + + if self.subplots: + # Disable label ax sharing. Otherwise, all subplots shows last + # column label + if self.orientation == "vertical": + self.sharex = False + else: + self.sharey = False + + # error: Signature of "_plot" incompatible with supertype "MPLPlot" + @classmethod + def _plot( # type: ignore[override] + cls, ax: Axes, y: np.ndarray, column_num=None, return_type: str = "axes", **kwds + ): + ys: np.ndarray | list[np.ndarray] + if y.ndim == 2: + ys = [remove_na_arraylike(v) for v in y] + # Boxplot fails with empty arrays, so need to add a NaN + # if any cols are empty + # GH 8181 + ys = [v if v.size > 0 else np.array([np.nan]) for v in ys] + else: + ys = remove_na_arraylike(y) + bp = ax.boxplot(ys, **kwds) + + if return_type == "dict": + return bp, bp + elif return_type == "both": + return cls.BP(ax=ax, lines=bp), bp + else: + return ax, bp + + def _validate_color_args(self, color, colormap): + if color is lib.no_default: + return None + + if colormap is not None: + warnings.warn( + "'color' and 'colormap' cannot be used simultaneously. Using 'color'", + stacklevel=find_stack_level(), + ) + + if isinstance(color, dict): + valid_keys = ["boxes", "whiskers", "medians", "caps"] + for key in color: + if key not in valid_keys: + raise ValueError( + f"color dict contains invalid key '{key}'. " + f"The key must be either {valid_keys}" + ) + return color + + @cache_readonly + def _color_attrs(self): + # get standard colors for default + # use 2 colors by default, for box/whisker and median + # flier colors isn't needed here + # because it can be specified by ``sym`` kw + return get_standard_colors(num_colors=3, colormap=self.colormap, color=None) + + @cache_readonly + def _boxes_c(self): + return self._color_attrs[0] + + @cache_readonly + def _whiskers_c(self): + return self._color_attrs[0] + + @cache_readonly + def _medians_c(self): + return self._color_attrs[2] + + @cache_readonly + def _caps_c(self): + return self._color_attrs[0] + + def _get_colors( + self, + num_colors=None, + color_kwds: dict[str, MatplotlibColor] + | MatplotlibColor + | Collection[MatplotlibColor] + | None = "color", + ) -> None: + pass + + def maybe_color_bp(self, bp) -> None: + if isinstance(self.color, dict): + boxes = self.color.get("boxes", self._boxes_c) + whiskers = self.color.get("whiskers", self._whiskers_c) + medians = self.color.get("medians", self._medians_c) + caps = self.color.get("caps", self._caps_c) + else: + # Other types are forwarded to matplotlib + # If None, use default colors + boxes = self.color or self._boxes_c + whiskers = self.color or self._whiskers_c + medians = self.color or self._medians_c + caps = self.color or self._caps_c + + color_tup = (boxes, whiskers, medians, caps) + maybe_color_bp(bp, color_tup=color_tup, **self.kwds) + + def _make_plot(self, fig: Figure) -> None: + if self.subplots: + obj_axes = [] + obj_labels = [] + + # Re-create iterated data if `by` is assigned by users + data = ( + create_iter_data_given_by(self.data, self._kind) + if self.by is not None + else self.data + ) + + for i, (label, y) in enumerate(self._iter_data(data=data)): + ax = self._get_ax(i) + kwds = self.kwds.copy() + + # When by is applied, show title for subplots to know which group it is + # just like df.boxplot, and need to apply T on y to provide right input + if self.by is not None: + y = y.T + ax.set_title(pprint_thing(label)) + + # When `by` is assigned, the ticklabels will become unique grouped + # values, instead of label which is used as subtitle in this case. + # error: "Index" has no attribute "levels"; maybe "nlevels"? + levels = self.data.columns.levels # type: ignore[attr-defined] + ticklabels = [pprint_thing(col) for col in levels[0]] + else: + ticklabels = [pprint_thing(label)] + + ret, bp = self._plot( + ax, y, column_num=i, return_type=self.return_type, **kwds + ) + self.maybe_color_bp(bp) + obj_axes.append(ret) + obj_labels.append(label) + _set_ticklabels( + ax=ax, labels=ticklabels, is_vertical=self.orientation == "vertical" + ) + self._return_obj = pd.Series(obj_axes, index=obj_labels, dtype=object) + else: + y = self.data.values.T + ax = self._get_ax(0) + kwds = self.kwds.copy() + + ret, bp = self._plot( + ax, y, column_num=0, return_type=self.return_type, **kwds + ) + self.maybe_color_bp(bp) + self._return_obj = ret + + labels = [pprint_thing(left) for left in self.data.columns] + if not self.use_index: + labels = [pprint_thing(key) for key in range(len(labels))] + _set_ticklabels( + ax=ax, labels=labels, is_vertical=self.orientation == "vertical" + ) + + def _make_legend(self) -> None: + pass + + def _post_plot_logic(self, ax: Axes, data) -> None: + # GH 45465: make sure that the boxplot doesn't ignore xlabel/ylabel + if self.xlabel: + ax.set_xlabel(pprint_thing(self.xlabel)) + if self.ylabel: + ax.set_ylabel(pprint_thing(self.ylabel)) + + @property + def orientation(self) -> Literal["horizontal", "vertical"]: + if self.kwds.get("vert", True): + return "vertical" + else: + return "horizontal" + + @property + def result(self): + if self.return_type is None: + return super().result + else: + return self._return_obj + + +def maybe_color_bp(bp, color_tup, **kwds) -> None: + # GH#30346, when users specifying those arguments explicitly, our defaults + # for these four kwargs should be overridden; if not, use Pandas settings + if not kwds.get("boxprops"): + mpl.artist.setp(bp["boxes"], color=color_tup[0], alpha=1) + if not kwds.get("whiskerprops"): + mpl.artist.setp(bp["whiskers"], color=color_tup[1], alpha=1) + if not kwds.get("medianprops"): + mpl.artist.setp(bp["medians"], color=color_tup[2], alpha=1) + if not kwds.get("capprops"): + mpl.artist.setp(bp["caps"], color=color_tup[3], alpha=1) + + +def _grouped_plot_by_column( + plotf, + data, + columns=None, + by=None, + numeric_only: bool = True, + grid: bool = False, + figsize: tuple[float, float] | None = None, + ax=None, + layout=None, + return_type=None, + **kwargs, +): + grouped = data.groupby(by, observed=False) + if columns is None: + if not isinstance(by, (list, tuple)): + by = [by] + columns = data._get_numeric_data().columns.difference(by) + naxes = len(columns) + fig, axes = create_subplots( + naxes=naxes, + sharex=kwargs.pop("sharex", True), + sharey=kwargs.pop("sharey", True), + figsize=figsize, + ax=ax, + layout=layout, + ) + + # GH 45465: move the "by" label based on "vert" + xlabel, ylabel = kwargs.pop("xlabel", None), kwargs.pop("ylabel", None) + if kwargs.get("vert", True): + xlabel = xlabel or by + else: + ylabel = ylabel or by + + ax_values = [] + + for ax, col in zip(flatten_axes(axes), columns, strict=False): + gp_col = grouped[col] + keys, values = zip(*gp_col, strict=True) + re_plotf = plotf(keys, values, ax, xlabel=xlabel, ylabel=ylabel, **kwargs) + ax.set_title(col) + ax_values.append(re_plotf) + ax.grid(grid) + + result = pd.Series(ax_values, index=columns, copy=False) + + # Return axes in multiplot case, maybe revisit later # 985 + if return_type is None: + result = axes + + byline = by[0] if len(by) == 1 else by + fig.suptitle(f"Boxplot grouped by {byline}") + maybe_adjust_figure(fig, bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2) + + return result + + +def boxplot( + data, + column=None, + by=None, + ax=None, + fontsize: int | None = None, + rot: int = 0, + grid: bool = True, + figsize: tuple[float, float] | None = None, + layout=None, + return_type=None, + **kwds, +): + import matplotlib.pyplot as plt + + # validate return_type: + if return_type not in BoxPlot._valid_return_types: + raise ValueError("return_type must be {'axes', 'dict', 'both'}") + + if isinstance(data, ABCSeries): + data = data.to_frame("x") + column = "x" + + def _get_colors(): + # num_colors=3 is required as method maybe_color_bp takes the colors + # in positions 0 and 2. + # if colors not provided, use same defaults as DataFrame.plot.box + result_list = get_standard_colors(num_colors=3) + result = np.take(result_list, [0, 0, 2]) + result = np.append(result, "k") + + colors = kwds.pop("color", None) + if colors: + if is_dict_like(colors): + # replace colors in result array with user-specified colors + # taken from the colors dict parameter + # "boxes" value placed in position 0, "whiskers" in 1, etc. + valid_keys = ["boxes", "whiskers", "medians", "caps"] + key_to_index = dict(zip(valid_keys, range(4), strict=True)) + for key, value in colors.items(): + if key in valid_keys: + result[key_to_index[key]] = value + else: + raise ValueError( + f"color dict contains invalid key '{key}'. " + f"The key must be either {valid_keys}" + ) + else: + result.fill(colors) + + return result + + def plot_group(keys, values, ax: Axes, **kwds): + # GH 45465: xlabel/ylabel need to be popped out before plotting happens + xlabel, ylabel = kwds.pop("xlabel", None), kwds.pop("ylabel", None) + if xlabel: + ax.set_xlabel(pprint_thing(xlabel)) + if ylabel: + ax.set_ylabel(pprint_thing(ylabel)) + + keys = [pprint_thing(x) for x in keys] + values = [remove_na_arraylike(v) for v in values] + bp = ax.boxplot(values, **kwds) + if fontsize is not None: + ax.tick_params(axis="both", labelsize=fontsize) + + # GH 45465: x/y are flipped when "vert" changes + _set_ticklabels( + ax=ax, labels=keys, is_vertical=kwds.get("vert", True), rotation=rot + ) + maybe_color_bp(bp, color_tup=colors, **kwds) + + # Return axes in multiplot case, maybe revisit later # 985 + if return_type == "dict": + return bp + elif return_type == "both": + return BoxPlot.BP(ax=ax, lines=bp) + else: + return ax + + colors = _get_colors() + if column is None: + columns = None + elif isinstance(column, (list, tuple)): + columns = column + else: + columns = [column] + + if by is not None: + # Prefer array return type for 2-D plots to match the subplot layout + # https://github.com/pandas-dev/pandas/pull/12216#issuecomment-241175580 + result = _grouped_plot_by_column( + plot_group, + data, + columns=columns, + by=by, + grid=grid, + figsize=figsize, + ax=ax, + layout=layout, + return_type=return_type, + **kwds, + ) + else: + if return_type is None: + return_type = "axes" + if layout is not None: + raise ValueError("The 'layout' keyword is not supported when 'by' is None") + + if ax is None: + rc = {"figure.figsize": figsize} if figsize is not None else {} + with mpl.rc_context(rc): + ax = plt.gca() + data = data._get_numeric_data() + naxes = len(data.columns) + if naxes == 0: + raise ValueError( + "boxplot method requires numerical columns, nothing to plot." + ) + if columns is None: + columns = data.columns + else: + data = data[columns] + + result = plot_group(columns, data.values.T, ax, **kwds) + ax.grid(grid) + + return result + + +def boxplot_frame( + self, + column=None, + by=None, + ax=None, + fontsize: int | None = None, + rot: int = 0, + grid: bool = True, + figsize: tuple[float, float] | None = None, + layout=None, + return_type=None, + **kwds, +): + import matplotlib.pyplot as plt + + ax = boxplot( + self, + column=column, + by=by, + ax=ax, + fontsize=fontsize, + grid=grid, + rot=rot, + figsize=figsize, + layout=layout, + return_type=return_type, + **kwds, + ) + plt.draw_if_interactive() + return ax + + +def boxplot_frame_groupby( + grouped, + subplots: bool = True, + column=None, + fontsize: int | None = None, + rot: int = 0, + grid: bool = True, + ax=None, + figsize: tuple[float, float] | None = None, + layout=None, + sharex: bool = False, + sharey: bool = True, + **kwds, +): + if subplots is True: + naxes = len(grouped) + fig, axes = create_subplots( + naxes=naxes, + squeeze=False, + ax=ax, + sharex=sharex, + sharey=sharey, + figsize=figsize, + layout=layout, + ) + data = {} + for (key, group), ax in zip(grouped, flatten_axes(axes), strict=False): + d = group.boxplot( + ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds + ) + ax.set_title(pprint_thing(key)) + data[key] = d + ret = pd.Series(data) + maybe_adjust_figure(fig, bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2) + else: + keys, frames = zip(*grouped, strict=True) + df = pd.concat(frames, keys=keys, axis=1) + + # GH 16748, DataFrameGroupby fails when subplots=False and `column` argument + # is assigned, and in this case, since `df` here becomes MI after groupby, + # so we need to couple the keys (grouped values) and column (original df + # column) together to search for subset to plot + if column is not None: + column = com.convert_to_list_like(column) + multi_key = pd.MultiIndex.from_product([keys, column]) + column = list(multi_key.values) + ret = df.boxplot( + column=column, + fontsize=fontsize, + rot=rot, + grid=grid, + ax=ax, + figsize=figsize, + layout=layout, + **kwds, + ) + return ret diff --git a/pandas/plotting/_matplotlib/converter.py b/pandas/plotting/_matplotlib/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..813bd984cf2972ad950c35aa3f6511d3cfd1aa0d --- /dev/null +++ b/pandas/plotting/_matplotlib/converter.py @@ -0,0 +1,1130 @@ +from __future__ import annotations + +import contextlib +import datetime as pydt +from datetime import ( + datetime, + tzinfo, +) +import functools +from typing import ( + TYPE_CHECKING, + Any, + cast, +) +import warnings + +import matplotlib as mpl +import matplotlib.dates as mdates +import matplotlib.units as munits +import numpy as np + +from pandas._libs import lib +from pandas._libs.tslibs import ( + Timestamp, + to_offset, +) +from pandas._libs.tslibs.dtypes import ( + FreqGroup, + periods_per_day, +) +from pandas._typing import ( + F, + npt, +) + +from pandas.core.dtypes.common import ( + is_float, + is_float_dtype, + is_integer, + is_integer_dtype, + is_nested_list_like, +) + +from pandas import ( + Index, + Series, + get_option, +) +import pandas.core.common as com +from pandas.core.indexes.datetimes import date_range +from pandas.core.indexes.period import ( + Period, + PeriodIndex, + period_range, +) +import pandas.core.tools.datetimes as tools + +if TYPE_CHECKING: + from collections.abc import Generator + + from matplotlib.axis import Axis + + from pandas._libs.tslibs.offsets import BaseOffset + from pandas._typing import TimeUnit + + +_mpl_units: dict = {} # Cache for units overwritten by us + + +def get_pairs() -> list[tuple[type, type[mdates.DateConverter]]]: + pairs = [ + (Timestamp, DatetimeConverter), + (Period, PeriodConverter), + (pydt.datetime, DatetimeConverter), + (pydt.date, DatetimeConverter), + (pydt.time, TimeConverter), + (np.datetime64, DatetimeConverter), + ] + return pairs + + +def register_pandas_matplotlib_converters(func: F) -> F: + """ + Decorator applying pandas_converters. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with pandas_converters(): + return func(*args, **kwargs) + + return cast(F, wrapper) + + +@contextlib.contextmanager +def pandas_converters() -> Generator[None]: + """ + Context manager registering pandas' converters for a plot. + + See Also + -------- + register_pandas_matplotlib_converters : Decorator that applies this. + """ + value = get_option("plotting.matplotlib.register_converters") + + if value: + # register for True or "auto" + register() + try: + yield + finally: + if value == "auto": + # only deregister for "auto" + deregister() + + +def register() -> None: + pairs = get_pairs() + for type_, cls in pairs: + # Cache previous converter if present + if type_ in munits.registry and not isinstance(munits.registry[type_], cls): + previous = munits.registry[type_] + _mpl_units[type_] = previous + # Replace with pandas converter + munits.registry[type_] = cls() + + +def deregister() -> None: + # Renamed in pandas.plotting.__init__ + for type_, cls in get_pairs(): + # We use type to catch our classes directly, no inheritance + if type(munits.registry.get(type_)) is cls: + munits.registry.pop(type_) + + # restore the old keys + for unit, formatter in _mpl_units.items(): + if type(formatter) not in {DatetimeConverter, PeriodConverter, TimeConverter}: + # make it idempotent by excluding ours. + munits.registry[unit] = formatter + + +def _to_ordinalf(tm: pydt.time) -> float: + tot_sec = tm.hour * 3600 + tm.minute * 60 + tm.second + tm.microsecond / 10**6 + return tot_sec + + +def time2num(d): + if isinstance(d, str): + parsed = Timestamp(d) + return _to_ordinalf(parsed.time()) + if isinstance(d, pydt.time): + return _to_ordinalf(d) + return d + + +class TimeConverter(munits.ConversionInterface): + @staticmethod + def convert(value, unit, axis): + valid_types = (str, pydt.time) + if isinstance(value, valid_types) or is_integer(value) or is_float(value): + return time2num(value) + if isinstance(value, Index): + return value.map(time2num) + if isinstance(value, (list, tuple, np.ndarray, Index)): + return [time2num(x) for x in value] + return value + + @staticmethod + def axisinfo(unit, axis) -> munits.AxisInfo | None: + if unit != "time": + return None + + majloc = mpl.ticker.AutoLocator() # pyright: ignore[reportAttributeAccessIssue] + majfmt = TimeFormatter(majloc) + return munits.AxisInfo(majloc=majloc, majfmt=majfmt, label="time") + + @staticmethod + def default_units(x, axis) -> str: + return "time" + + +# time formatter +class TimeFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue] + def __init__(self, locs) -> None: + self.locs = locs + + def __call__(self, x, pos: int | None = 0) -> str: + """ + Return the time of day as a formatted string. + + Parameters + ---------- + x : float + The time of day specified as seconds since 00:00 (midnight), + with up to microsecond precision. + pos + Unused + + Returns + ------- + str + A string in HH:MM:SS.mmmuuu format. Microseconds, + milliseconds and seconds are only displayed if non-zero. + """ + fmt = "%H:%M:%S.%f" + s = int(x) + msus = round((x - s) * 10**6) + ms = msus // 1000 + us = msus % 1000 + m, s = divmod(s, 60) + h, m = divmod(m, 60) + _, h = divmod(h, 24) + if us != 0: + return pydt.time(h, m, s, msus).strftime(fmt) + elif ms != 0: + return pydt.time(h, m, s, msus).strftime(fmt)[:-3] + elif s != 0: + return pydt.time(h, m, s).strftime("%H:%M:%S") + + return pydt.time(h, m).strftime("%H:%M") + + +# Period Conversion + + +class PeriodConverter(mdates.DateConverter): + @staticmethod + def convert(values, unit, axis: Axis): + # Reached via e.g. `ax.set_xlim` + + # In tests as of 2025-09-24, unit is always None except for 3 tests + # that directly call this with unit=""; + # axis is always specifically a matplotlib.axis.XAxis + + if not hasattr(axis, "freq"): + raise TypeError("Axis must have `freq` set to convert to Periods") + freq = to_offset(axis.freq, is_period=True) # pyright: ignore[reportAttributeAccessIssue] + return PeriodConverter.convert_from_freq(values, freq) + + @staticmethod + def convert_from_freq(values, freq: BaseOffset): + if is_nested_list_like(values): + values = [PeriodConverter._convert_1d(v, freq) for v in values] + else: + values = PeriodConverter._convert_1d(values, freq) + return values + + @staticmethod + def _convert_1d(values, freq: BaseOffset): + valid_types = (str, datetime, Period, pydt.date, np.datetime64) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "Period with BDay freq is deprecated", category=FutureWarning + ) + warnings.filterwarnings( + "ignore", r"PeriodDtype\[B\] is deprecated", category=FutureWarning + ) + if ( + isinstance(values, valid_types) + or is_integer(values) + or is_float(values) + ): + return _get_datevalue(values, freq) + elif isinstance(values, PeriodIndex): + return values.asfreq(freq).asi8 + elif isinstance(values, Index): + return values.map(lambda x: _get_datevalue(x, freq)) + elif lib.infer_dtype(values, skipna=False) == "period": + # https://github.com/pandas-dev/pandas/issues/24304 + # convert ndarray[period] -> PeriodIndex + return PeriodIndex(values, freq=freq).asi8 + elif isinstance(values, (list, tuple, np.ndarray)): + return [_get_datevalue(x, freq) for x in values] + return values + + +def _get_datevalue(date, freq: BaseOffset): + if isinstance(date, Period): + return date.asfreq(freq).ordinal + elif isinstance(date, (str, datetime, pydt.date, np.datetime64)): + return Period(date, freq).ordinal # pyright: ignore[reportAttributeAccessIssue] + elif is_integer(date) or is_float(date): + return date + elif date is None: + return None + raise ValueError(f"Unrecognizable date '{date}'") + + +# Datetime Conversion +class DatetimeConverter(mdates.DateConverter): + @staticmethod + def convert(values, unit, axis: Axis): + # Reached via e.g. `ax.set_xlim` + + # In tests as of 2025-09-24, unit is always None except for 3 tests + # that directly call this with unit=""; + # axis is always specifically a matplotlib.axis.XAxis + + # values might be a 1-d array, or a list-like of arrays. + if is_nested_list_like(values): + values = [DatetimeConverter._convert_1d(v, unit, axis) for v in values] + else: + values = DatetimeConverter._convert_1d(values, unit, axis) + return values + + @staticmethod + def _convert_1d(values, unit, axis): + def try_parse(values): + try: + return mdates.date2num(tools.to_datetime(values)) + except Exception: + return values + + if isinstance(values, (datetime, pydt.date, np.datetime64, pydt.time)): + return mdates.date2num(values) + elif is_integer(values) or is_float(values): + return values + elif isinstance(values, str): + return try_parse(values) + elif isinstance(values, (list, tuple, np.ndarray, Index, Series)): + if isinstance(values, Series): + # https://github.com/matplotlib/matplotlib/issues/11391 + # Series was skipped. Convert to DatetimeIndex to get asi8 + values = Index(values) + if isinstance(values, Index): + values = values.values + if not isinstance(values, np.ndarray): + values = com.asarray_tuplesafe(values) + + if is_integer_dtype(values) or is_float_dtype(values): + return values + + try: + values = tools.to_datetime(values) + except Exception: + pass + + values = mdates.date2num(values) + + return values + + @staticmethod + def axisinfo(unit: tzinfo | None, axis) -> munits.AxisInfo: + """ + Return the :class:`~matplotlib.units.AxisInfo` for *unit*. + + *unit* is a tzinfo instance or None. + The *axis* argument is required but not used. + """ + tz = unit + + majloc = PandasAutoDateLocator(tz=tz) + majfmt = PandasAutoDateFormatter(majloc, tz=tz) + datemin = pydt.date(2000, 1, 1) + datemax = pydt.date(2010, 1, 1) + + return munits.AxisInfo( + majloc=majloc, majfmt=majfmt, label="", default_limits=(datemin, datemax) + ) + + +class PandasAutoDateFormatter(mdates.AutoDateFormatter): + def __init__(self, locator, tz=None, defaultfmt: str = "%Y-%m-%d") -> None: + mdates.AutoDateFormatter.__init__(self, locator, tz, defaultfmt) + + +class PandasAutoDateLocator(mdates.AutoDateLocator): + def get_locator(self, dmin, dmax): + """Pick the best locator based on a distance.""" + tot_sec = (dmax - dmin).total_seconds() + + if abs(tot_sec) < self.minticks: + self._freq = -1 + locator = MilliSecondLocator(self.tz) + locator.set_axis(self.axis) + + # error: Item "None" of "Axis | _DummyAxis | _AxisWrapper | None" + # has no attribute "get_data_interval" + locator.axis.set_view_interval( # type: ignore[union-attr] + *self.axis.get_view_interval() # type: ignore[union-attr] + ) + locator.axis.set_data_interval( # type: ignore[union-attr] + *self.axis.get_data_interval() # type: ignore[union-attr] + ) + return locator + + return mdates.AutoDateLocator.get_locator(self, dmin, dmax) + + def _get_unit(self): + return MilliSecondLocator.get_unit_generic(self._freq) + + +class MilliSecondLocator(mdates.DateLocator): + UNIT = 1.0 / (24 * 3600 * 1000) + + def __init__(self, tz) -> None: + mdates.DateLocator.__init__(self, tz) + self._interval = 1.0 + + def _get_unit(self): + return self.get_unit_generic(-1) + + @staticmethod + def get_unit_generic(freq): + unit = mdates.RRuleLocator.get_unit_generic(freq) + if unit < 0: + return MilliSecondLocator.UNIT + return unit + + def __call__(self): + # if no data have been set, this will tank with a ValueError + try: + dmin, dmax = self.viewlim_to_dt() + except ValueError: + return [] + + # We need to cap at the endpoints of valid datetime + nmax, nmin = mdates.date2num((dmax, dmin)) + + num = (nmax - nmin) * 86400 * 1000 + max_millis_ticks = 6 + for interval in [1, 10, 50, 100, 200, 500]: + if num <= interval * (max_millis_ticks - 1): + self._interval = interval + break + # We went through the whole loop without breaking, default to 1 + self._interval = 1000.0 + + estimate = (nmax - nmin) / (self._get_unit() * self._get_interval()) + + if estimate > self.MAXTICKS * 2: + raise RuntimeError( + "MillisecondLocator estimated to generate " + f"{estimate:d} ticks from {dmin} to {dmax}: exceeds Locator.MAXTICKS" + f"* 2 ({self.MAXTICKS * 2:d}) " + ) + + interval = self._get_interval() + freq = f"{interval}ms" + tz = self.tz.tzname(None) + st = dmin.replace(tzinfo=None) + ed = dmax.replace(tzinfo=None) + all_dates = date_range(start=st, end=ed, freq=freq, tz=tz).astype(object) + + try: + if len(all_dates) > 0: + locs = self.raise_if_exceeds(mdates.date2num(all_dates)) + return locs + except Exception: # pragma: no cover + pass + + lims = mdates.date2num([dmin, dmax]) + return lims + + def _get_interval(self): + return self._interval + + def autoscale(self): + """ + Set the view limits to include the data range. + """ + # We need to cap at the endpoints of valid datetime + dmin, dmax = self.datalim_to_dt() + + vmin = mdates.date2num(dmin) + vmax = mdates.date2num(dmax) + + return self.nonsingular(vmin, vmax) + + +# Fixed frequency dynamic tick locators and formatters + +# ------------------------------------------------------------------------- +# --- Locators --- +# ------------------------------------------------------------------------- + + +def _get_default_annual_spacing(nyears) -> tuple[int, int]: + """ + Returns a default spacing between consecutive ticks for annual data. + """ + if nyears < 11: + (min_spacing, maj_spacing) = (1, 1) + elif nyears < 20: + (min_spacing, maj_spacing) = (1, 2) + elif nyears < 50: + (min_spacing, maj_spacing) = (1, 5) + elif nyears < 100: + (min_spacing, maj_spacing) = (5, 10) + elif nyears < 200: + (min_spacing, maj_spacing) = (5, 25) + elif nyears < 600: + (min_spacing, maj_spacing) = (10, 50) + else: + factor = nyears // 1000 + 1 + (min_spacing, maj_spacing) = (factor * 20, factor * 100) + return (min_spacing, maj_spacing) + + +def _period_break(dates: PeriodIndex, period: str) -> npt.NDArray[np.intp]: + """ + Returns the indices where the given period changes. + + Parameters + ---------- + dates : PeriodIndex + Array of intervals to monitor. + period : str + Name of the period to monitor. + """ + mask = _period_break_mask(dates, period) + return np.nonzero(mask)[0] + + +def _period_break_mask(dates: PeriodIndex, period: str) -> npt.NDArray[np.bool_]: + current = getattr(dates, period) + previous = getattr(dates - 1 * dates.freq, period) + return current != previous + + +def has_level_label(label_flags: npt.NDArray[np.intp], vmin: float) -> bool: + """ + Returns true if the ``label_flags`` indicate there is at least one label + for this level. + + if the minimum view limit is not an exact integer, then the first tick + label won't be shown, so we must adjust for that. + """ + if label_flags.size == 0 or ( + label_flags.size == 1 and label_flags[0] == 0 and vmin % 1 > 0.0 + ): + return False + else: + return True + + +def _get_periods_per_ymd(freq: BaseOffset) -> tuple[int, int, int]: + # error: "BaseOffset" has no attribute "_period_dtype_code" + dtype_code = freq._period_dtype_code # type: ignore[attr-defined] + freq_group = FreqGroup.from_period_dtype_code(dtype_code) + + ppd = -1 # placeholder for above-day freqs + + if dtype_code >= FreqGroup.FR_HR.value: # pyright: ignore[reportAttributeAccessIssue] + # error: "BaseOffset" has no attribute "_creso" + ppd = periods_per_day(freq._creso) # type: ignore[attr-defined] + ppm = 28 * ppd + ppy = 365 * ppd + elif freq_group == FreqGroup.FR_BUS: + ppm = 19 + ppy = 261 + elif freq_group == FreqGroup.FR_DAY: + ppm = 28 + ppy = 365 + elif freq_group == FreqGroup.FR_WK: + ppm = 3 + ppy = 52 + elif freq_group == FreqGroup.FR_MTH: + ppm = 1 + ppy = 12 + elif freq_group == FreqGroup.FR_QTR: + ppm = -1 # placerholder + ppy = 4 + elif freq_group == FreqGroup.FR_ANN: + ppm = -1 # placeholder + ppy = 1 + else: + raise NotImplementedError(f"Unsupported frequency: {dtype_code}") + + return ppd, ppm, ppy + + +@functools.cache +def _daily_finder(vmin: float, vmax: float, freq: BaseOffset) -> np.ndarray: + # error: "BaseOffset" has no attribute "_period_dtype_code" + dtype_code = freq._period_dtype_code # type: ignore[attr-defined] + + periodsperday, periodspermonth, periodsperyear = _get_periods_per_ymd(freq) + + # save this for later usage + vmin_orig = vmin + (vmin, vmax) = (int(vmin), int(vmax)) + span = vmax - vmin + 1 + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", "Period with BDay freq is deprecated", category=FutureWarning + ) + warnings.filterwarnings( + "ignore", r"PeriodDtype\[B\] is deprecated", category=FutureWarning + ) + dates_ = period_range( + start=Period(ordinal=vmin, freq=freq), + end=Period(ordinal=vmax, freq=freq), + freq=freq, + ) + + # Initialize the output + info = np.zeros( + span, dtype=[("val", np.int64), ("maj", bool), ("min", bool), ("fmt", "|S20")] + ) + info["val"][:] = dates_.asi8 + info["fmt"][:] = "" + info["maj"][[0, -1]] = True + # .. and set some shortcuts + info_maj = info["maj"] + info_min = info["min"] + info_fmt = info["fmt"] + + def first_label(label_flags): + if (label_flags[0] == 0) and (label_flags.size > 1) and ((vmin_orig % 1) > 0.0): + return label_flags[1] + else: + return label_flags[0] + + # Case 1. Less than a month + if span <= periodspermonth: + day_start = _period_break(dates_, "day") + month_start = _period_break(dates_, "month") + year_start = _period_break(dates_, "year") + + def _hour_finder(label_interval: int, force_year_start: bool) -> None: + target = dates_.hour + mask = _period_break_mask(dates_, "hour") + info_maj[day_start] = True + info_min[mask & (target % label_interval == 0)] = True + info_fmt[mask & (target % label_interval == 0)] = "%H:%M" + info_fmt[day_start] = "%H:%M\n%d-%b" + info_fmt[year_start] = "%H:%M\n%d-%b\n%Y" + if force_year_start and not has_level_label(year_start, vmin_orig): + info_fmt[first_label(day_start)] = "%H:%M\n%d-%b\n%Y" + + def _minute_finder(label_interval: int) -> None: + target = dates_.minute + hour_start = _period_break(dates_, "hour") + mask = _period_break_mask(dates_, "minute") + info_maj[hour_start] = True + info_min[mask & (target % label_interval == 0)] = True + info_fmt[mask & (target % label_interval == 0)] = "%H:%M" + info_fmt[day_start] = "%H:%M\n%d-%b" + info_fmt[year_start] = "%H:%M\n%d-%b\n%Y" + + def _second_finder(label_interval: int) -> None: + target = dates_.second + minute_start = _period_break(dates_, "minute") + mask = _period_break_mask(dates_, "second") + info_maj[minute_start] = True + info_min[mask & (target % label_interval == 0)] = True + info_fmt[mask & (target % label_interval == 0)] = "%H:%M:%S" + info_fmt[day_start] = "%H:%M:%S\n%d-%b" + info_fmt[year_start] = "%H:%M:%S\n%d-%b\n%Y" + + if span < periodsperday / 12000: + _second_finder(1) + elif span < periodsperday / 6000: + _second_finder(2) + elif span < periodsperday / 2400: + _second_finder(5) + elif span < periodsperday / 1200: + _second_finder(10) + elif span < periodsperday / 800: + _second_finder(15) + elif span < periodsperday / 400: + _second_finder(30) + elif span < periodsperday / 150: + _minute_finder(1) + elif span < periodsperday / 70: + _minute_finder(2) + elif span < periodsperday / 24: + _minute_finder(5) + elif span < periodsperday / 12: + _minute_finder(15) + elif span < periodsperday / 6: + _minute_finder(30) + elif span < periodsperday / 2.5: + _hour_finder(1, False) + elif span < periodsperday / 1.5: + _hour_finder(2, False) + elif span < periodsperday * 1.25: + _hour_finder(3, False) + elif span < periodsperday * 2.5: + _hour_finder(6, True) + elif span < periodsperday * 4: + _hour_finder(12, True) + else: + info_maj[month_start] = True + info_min[day_start] = True + info_fmt[day_start] = "%d" + info_fmt[month_start] = "%d\n%b" + info_fmt[year_start] = "%d\n%b\n%Y" + if not has_level_label(year_start, vmin_orig): + if not has_level_label(month_start, vmin_orig): + info_fmt[first_label(day_start)] = "%d\n%b\n%Y" + else: + info_fmt[first_label(month_start)] = "%d\n%b\n%Y" + + # Case 2. Less than three months + elif span <= periodsperyear // 4: + month_start = _period_break(dates_, "month") + info_maj[month_start] = True + if dtype_code < FreqGroup.FR_HR.value: # pyright: ignore[reportAttributeAccessIssue] + info["min"] = True + else: + day_start = _period_break(dates_, "day") + info["min"][day_start] = True + week_start = _period_break(dates_, "week") + year_start = _period_break(dates_, "year") + info_fmt[week_start] = "%d" + info_fmt[month_start] = "\n\n%b" + info_fmt[year_start] = "\n\n%b\n%Y" + if not has_level_label(year_start, vmin_orig): + if not has_level_label(month_start, vmin_orig): + info_fmt[first_label(week_start)] = "\n\n%b\n%Y" + else: + info_fmt[first_label(month_start)] = "\n\n%b\n%Y" + # Case 3. Less than 14 months ............... + elif span <= 1.15 * periodsperyear: + year_start = _period_break(dates_, "year") + month_start = _period_break(dates_, "month") + week_start = _period_break(dates_, "week") + info_maj[month_start] = True + info_min[week_start] = True + info_min[year_start] = False + info_min[month_start] = False + info_fmt[month_start] = "%b" + info_fmt[year_start] = "%b\n%Y" + if not has_level_label(year_start, vmin_orig): + info_fmt[first_label(month_start)] = "%b\n%Y" + # Case 4. Less than 2.5 years ............... + elif span <= 2.5 * periodsperyear: + year_start = _period_break(dates_, "year") + quarter_start = _period_break(dates_, "quarter") + month_start = _period_break(dates_, "month") + info_maj[quarter_start] = True + info_min[month_start] = True + info_fmt[quarter_start] = "%b" + info_fmt[year_start] = "%b\n%Y" + # Case 4. Less than 4 years ................. + elif span <= 4 * periodsperyear: + year_start = _period_break(dates_, "year") + month_start = _period_break(dates_, "month") + info_maj[year_start] = True + info_min[month_start] = True + info_min[year_start] = False + + month_break = dates_[month_start].month + jan_or_jul = month_start[(month_break == 1) | (month_break == 7)] + info_fmt[jan_or_jul] = "%b" + info_fmt[year_start] = "%b\n%Y" + # Case 5. Less than 11 years ................ + elif span <= 11 * periodsperyear: + year_start = _period_break(dates_, "year") + quarter_start = _period_break(dates_, "quarter") + info_maj[year_start] = True + info_min[quarter_start] = True + info_min[year_start] = False + info_fmt[year_start] = "%Y" + # Case 6. More than 12 years ................ + else: + year_start = _period_break(dates_, "year") + year_break = dates_[year_start].year + nyears = span / periodsperyear + (min_anndef, maj_anndef) = _get_default_annual_spacing(nyears) + major_idx = year_start[(year_break % maj_anndef == 0)] + info_maj[major_idx] = True + minor_idx = year_start[(year_break % min_anndef == 0)] + info_min[minor_idx] = True + info_fmt[major_idx] = "%Y" + + return info + + +@functools.cache +def _monthly_finder(vmin: float, vmax: float, freq: BaseOffset) -> np.ndarray: + _, _, periodsperyear = _get_periods_per_ymd(freq) + + vmin_orig = vmin + (vmin, vmax) = (int(vmin), int(vmax)) + span = vmax - vmin + 1 + + # Initialize the output + info = np.zeros( + span, dtype=[("val", int), ("maj", bool), ("min", bool), ("fmt", "|S8")] + ) + info["val"] = np.arange(vmin, vmax + 1) + dates_ = info["val"] + info["fmt"] = "" + year_start = (dates_ % 12 == 0).nonzero()[0] + info_maj = info["maj"] + info_fmt = info["fmt"] + + if span <= 1.15 * periodsperyear: + info_maj[year_start] = True + info["min"] = True + + info_fmt[:] = "%b" + info_fmt[year_start] = "%b\n%Y" + + if not has_level_label(year_start, vmin_orig): + if dates_.size > 1: + idx = 1 + else: + idx = 0 + info_fmt[idx] = "%b\n%Y" + + elif span <= 2.5 * periodsperyear: + quarter_start = (dates_ % 3 == 0).nonzero() + info_maj[year_start] = True + # TODO: Check the following : is it really info['fmt'] ? + # 2023-09-15 this is reached in test_finder_monthly + info["fmt"][quarter_start] = True + info["min"] = True + + info_fmt[quarter_start] = "%b" + info_fmt[year_start] = "%b\n%Y" + + elif span <= 4 * periodsperyear: + info_maj[year_start] = True + info["min"] = True + + jan_or_jul = (dates_ % 12 == 0) | (dates_ % 12 == 6) + info_fmt[jan_or_jul] = "%b" + info_fmt[year_start] = "%b\n%Y" + + elif span <= 11 * periodsperyear: + quarter_start = (dates_ % 3 == 0).nonzero() + info_maj[year_start] = True + info["min"][quarter_start] = True + + info_fmt[year_start] = "%Y" + + else: + nyears = span / periodsperyear + (min_anndef, maj_anndef) = _get_default_annual_spacing(nyears) + years = dates_[year_start] // 12 + 1 + major_idx = year_start[(years % maj_anndef == 0)] + info_maj[major_idx] = True + info["min"][year_start[(years % min_anndef == 0)]] = True + + info_fmt[major_idx] = "%Y" + + return info + + +@functools.cache +def _quarterly_finder(vmin: float, vmax: float, freq: BaseOffset) -> np.ndarray: + _, _, periodsperyear = _get_periods_per_ymd(freq) + vmin_orig = vmin + (vmin, vmax) = (int(vmin), int(vmax)) + span = vmax - vmin + 1 + + info = np.zeros( + span, dtype=[("val", int), ("maj", bool), ("min", bool), ("fmt", "|S8")] + ) + info["val"] = np.arange(vmin, vmax + 1) + info["fmt"] = "" + dates_ = info["val"] + info_maj = info["maj"] + info_fmt = info["fmt"] + year_start = (dates_ % 4 == 0).nonzero()[0] + + if span <= 3.5 * periodsperyear: + info_maj[year_start] = True + info["min"] = True + + info_fmt[:] = "Q%q" + info_fmt[year_start] = "Q%q\n%F" + if not has_level_label(year_start, vmin_orig): + if dates_.size > 1: + idx = 1 + else: + idx = 0 + info_fmt[idx] = "Q%q\n%F" + + elif span <= 11 * periodsperyear: + info_maj[year_start] = True + info["min"] = True + info_fmt[year_start] = "%F" + + else: + # https://github.com/pandas-dev/pandas/pull/47602 + years = dates_[year_start] // 4 + 1970 + nyears = span / periodsperyear + (min_anndef, maj_anndef) = _get_default_annual_spacing(nyears) + major_idx = year_start[(years % maj_anndef == 0)] + info_maj[major_idx] = True + info["min"][year_start[(years % min_anndef == 0)]] = True + info_fmt[major_idx] = "%F" + + return info + + +@functools.cache +def _annual_finder(vmin: float, vmax: float, freq: BaseOffset) -> np.ndarray: + # Note: small difference here vs other finders in adding 1 to vmax + (vmin, vmax) = (int(vmin), int(vmax + 1)) + span = vmax - vmin + 1 + + info = np.zeros( + span, dtype=[("val", int), ("maj", bool), ("min", bool), ("fmt", "|S8")] + ) + info["val"] = np.arange(vmin, vmax + 1) + info["fmt"] = "" + dates_ = info["val"] + + (min_anndef, maj_anndef) = _get_default_annual_spacing(span) + major_idx = dates_ % maj_anndef == 0 + minor_idx = dates_ % min_anndef == 0 + info["maj"][major_idx] = True + info["min"][minor_idx] = True + info["fmt"][major_idx] = "%Y" + + return info + + +def get_finder(freq: BaseOffset): + # error: "BaseOffset" has no attribute "_period_dtype_code" + dtype_code = freq._period_dtype_code # type: ignore[attr-defined] + fgroup = FreqGroup.from_period_dtype_code(dtype_code) + + if fgroup == FreqGroup.FR_ANN: + return _annual_finder + elif fgroup == FreqGroup.FR_QTR: + return _quarterly_finder + elif fgroup == FreqGroup.FR_MTH: + return _monthly_finder + elif (dtype_code >= FreqGroup.FR_BUS.value) or fgroup == FreqGroup.FR_WK: # pyright: ignore[reportAttributeAccessIssue] + return _daily_finder + else: # pragma: no cover + raise NotImplementedError(f"Unsupported frequency: {dtype_code}") + + +class TimeSeries_DateLocator(mpl.ticker.Locator): # pyright: ignore[reportAttributeAccessIssue] + """ + Locates the ticks along an axis controlled by a :class:`Series`. + + Parameters + ---------- + freq : BaseOffset + Valid frequency specifier. + minor_locator : {False, True}, optional + Whether the locator is for minor ticks (True) or not. + dynamic_mode : {True, False}, optional + Whether the locator should work in dynamic mode. + base : {int}, optional + quarter : {int}, optional + month : {int}, optional + day : {int}, optional + """ + + axis: Axis + + def __init__( + self, + freq: BaseOffset, + minor_locator: bool = False, + dynamic_mode: bool = True, + base: int = 1, + quarter: int = 1, + month: int = 1, + day: int = 1, + plot_obj=None, + ) -> None: + freq = to_offset(freq, is_period=True) + self.freq = freq + self.base = base + (self.quarter, self.month, self.day) = (quarter, month, day) + self.isminor = minor_locator + self.isdynamic = dynamic_mode + self.offset = 0 + self.plot_obj = plot_obj + self.finder = get_finder(freq) + + def _get_default_locs(self, vmin, vmax): + """Returns the default locations of ticks.""" + locator = self.finder(vmin, vmax, self.freq) + + if self.isminor: + return np.compress(locator["min"], locator["val"]) + return np.compress(locator["maj"], locator["val"]) + + def __call__(self): + """Return the locations of the ticks.""" + # axis calls Locator.set_axis inside set_m_formatter + + vi = tuple(self.axis.get_view_interval()) + vmin, vmax = vi + if vmax < vmin: + vmin, vmax = vmax, vmin + if self.isdynamic: + locs = self._get_default_locs(vmin, vmax) + else: # pragma: no cover + base = self.base + (d, m) = divmod(vmin, base) + vmin = (d + 1) * base + # error: No overload variant of "range" matches argument types "float", + # "float", "int" + locs = list(range(vmin, vmax + 1, base)) # type: ignore[call-overload] + return locs + + def autoscale(self): + """ + Sets the view limits to the nearest multiples of base that contain the + data. + """ + # requires matplotlib >= 0.98.0 + (vmin, vmax) = self.axis.get_data_interval() + + locs = self._get_default_locs(vmin, vmax) + (vmin, vmax) = locs[[0, -1]] + if vmin == vmax: + vmin -= 1 + vmax += 1 + return mpl.transforms.nonsingular(vmin, vmax) + + +# ------------------------------------------------------------------------- +# --- Formatter --- +# ------------------------------------------------------------------------- + + +class TimeSeries_DateFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue] + """ + Formats the ticks along an axis controlled by a :class:`PeriodIndex`. + + Parameters + ---------- + freq : BaseOffset + Valid frequency specifier. + minor_locator : bool, default False + Whether the current formatter should apply to minor ticks (True) or + major ticks (False). + dynamic_mode : bool, default True + Whether the formatter works in dynamic mode or not. + """ + + axis: Axis + + def __init__( + self, + freq: BaseOffset, + minor_locator: bool = False, + dynamic_mode: bool = True, + plot_obj=None, + ) -> None: + freq = to_offset(freq, is_period=True) + self.format = None + self.freq = freq + self.locs: list[Any] = [] # unused, for matplotlib compat + self.formatdict: dict[Any, Any] | None = None + self.isminor = minor_locator + self.isdynamic = dynamic_mode + self.offset = 0 + self.plot_obj = plot_obj + self.finder = get_finder(freq) + + def _set_default_format(self, vmin, vmax): + """Returns the default ticks spacing.""" + info = self.finder(vmin, vmax, self.freq) + + if self.isminor: + format = np.compress(info["min"] & np.logical_not(info["maj"]), info) + else: + format = np.compress(info["maj"], info) + self.formatdict = {x: f for (x, _, _, f) in format} + return self.formatdict + + def set_locs(self, locs) -> None: + """Sets the locations of the ticks""" + # don't actually use the locs. This is just needed to work with + # matplotlib. Force to use vmin, vmax + + self.locs = locs + + (vmin, vmax) = tuple(self.axis.get_view_interval()) + if vmax < vmin: + (vmin, vmax) = (vmax, vmin) + self._set_default_format(vmin, vmax) + + def __call__(self, x, pos: int | None = 0) -> str: + if self.formatdict is None: + return "" + else: + fmt = self.formatdict.pop(x, "") + if isinstance(fmt, np.bytes_): + fmt = fmt.decode("utf-8") + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + "Period with BDay freq is deprecated", + category=FutureWarning, + ) + period = Period(ordinal=int(x), freq=self.freq) + assert isinstance(period, Period) + return period.strftime(fmt) + + +class TimeSeries_TimedeltaFormatter(mpl.ticker.Formatter): # pyright: ignore[reportAttributeAccessIssue] + """ + Formats the ticks along an axis controlled by a :class:`TimedeltaIndex`. + """ + + def __init__(self, unit: TimeUnit = "ns"): + self.unit = unit + super().__init__() + + axis: Axis + + @staticmethod + def format_timedelta_ticks(x, pos, n_decimals: int, exp: int = 9) -> str: + """ + Convert seconds to 'D days HH:MM:SS.F' + """ + s, ns = divmod(x, 10**exp) + m, s = divmod(s, 60) + h, m = divmod(m, 60) + d, h = divmod(h, 24) + decimals = int(ns * 10 ** (n_decimals - exp)) + s = f"{int(h):02d}:{int(m):02d}:{int(s):02d}" + if n_decimals > 0: + s += f".{decimals:0{n_decimals}d}" + if d != 0: + s = f"{int(d):d} days {s}" + return s + + def __call__(self, x, pos: int | None = 0) -> str: + exp = {"ns": 9, "us": 6, "ms": 3, "s": 0}[self.unit] + (vmin, vmax) = tuple(self.axis.get_view_interval()) + n_decimals = min(int(np.ceil(np.log10(100 * 10**exp / abs(vmax - vmin)))), exp) + return self.format_timedelta_ticks(x, pos, n_decimals, exp) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py new file mode 100644 index 0000000000000000000000000000000000000000..0834501c4429db97e2ebb0a7f7eeb88416aedc0a --- /dev/null +++ b/pandas/plotting/_matplotlib/core.py @@ -0,0 +1,2207 @@ +from __future__ import annotations + +from abc import ( + ABC, + abstractmethod, +) +from collections.abc import ( + Hashable, + Iterable, + Iterator, + Sequence, +) +from typing import ( + TYPE_CHECKING, + Any, + Literal, + cast, + final, +) +import warnings + +import matplotlib as mpl +import numpy as np + +from pandas._libs import lib +from pandas.errors import AbstractMethodError +from pandas.util._decorators import cache_readonly +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.common import ( + is_any_real_numeric_dtype, + is_bool, + is_float, + is_float_dtype, + is_hashable, + is_integer, + is_integer_dtype, + is_iterator, + is_list_like, + is_number, + is_numeric_dtype, +) +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + ExtensionDtype, +) +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCDatetimeIndex, + ABCIndex, + ABCMultiIndex, + ABCPeriodIndex, + ABCSeries, +) +from pandas.core.dtypes.missing import isna + +import pandas.core.common as com + +from pandas.io.formats.printing import pprint_thing +from pandas.plotting._matplotlib import tools +from pandas.plotting._matplotlib.converter import ( + PeriodConverter, + register_pandas_matplotlib_converters, +) +from pandas.plotting._matplotlib.groupby import reconstruct_data_with_by +from pandas.plotting._matplotlib.misc import unpack_single_str_list +from pandas.plotting._matplotlib.style import get_standard_colors +from pandas.plotting._matplotlib.timeseries import ( + format_dateaxis, + maybe_convert_index, + prepare_ts_data, + use_dynamic_x, +) +from pandas.plotting._matplotlib.tools import ( + create_subplots, + flatten_axes, + format_date_labels, + get_all_lines, + get_xlim, + handle_shared_axes, +) + +if TYPE_CHECKING: + from matplotlib.artist import Artist + from matplotlib.axes import Axes + from matplotlib.axis import Axis + from matplotlib.figure import Figure + + from pandas._typing import ( + IndexLabel, + NDFrameT, + PlottingOrientation, + npt, + ) + + from pandas import ( + DataFrame, + Index, + Series, + ) + + +def holds_integer(column: Index) -> bool: + return column.dtype.kind in "iu" + + +def _color_in_style(style: str) -> bool: + """ + Check if there is a color letter in the style string. + """ + return not set(mpl.colors.BASE_COLORS).isdisjoint(style) + + +class MPLPlot(ABC): + """ + Base class for assembling a pandas plot using matplotlib + + Parameters + ---------- + data : + + """ + + @property + @abstractmethod + def _kind(self) -> str: + """Specify kind str. Must be overridden in child class""" + raise NotImplementedError + + _layout_type = "vertical" + _default_rot = 0 + + @property + def orientation(self) -> str | None: + return None + + data: DataFrame + + def __init__( + self, + data, + kind=None, + by: IndexLabel | None = None, + subplots: bool | Sequence[Sequence[str]] = False, + sharex: bool | None = None, + sharey: bool = False, + use_index: bool = True, + figsize: tuple[float, float] | None = None, + grid=None, + legend: bool | str = True, + rot=None, + ax=None, + fig=None, + title=None, + xlim=None, + ylim=None, + xticks=None, + yticks=None, + xlabel: Hashable | None = None, + ylabel: Hashable | None = None, + fontsize: int | None = None, + secondary_y: bool | tuple | list | np.ndarray = False, + colormap=None, + table: bool = False, + layout=None, + include_bool: bool = False, + column: IndexLabel | None = None, + *, + logx: bool | None | Literal["sym"] = False, + logy: bool | None | Literal["sym"] = False, + loglog: bool | None | Literal["sym"] = False, + mark_right: bool = True, + stacked: bool = False, + label: Hashable | None = None, + style=None, + **kwds, + ) -> None: + # if users assign an empty list or tuple, raise `ValueError` + # similar to current `df.box` and `df.hist` APIs. + if by in ([], ()): + raise ValueError("No group keys passed!") + self.by = com.maybe_make_list(by) + + # Assign the rest of columns into self.columns if by is explicitly defined + # while column is not, only need `columns` in hist/box plot when it's DF + # TODO: Might deprecate `column` argument in future PR (#28373) + if isinstance(data, ABCDataFrame): + if column: + self.columns = com.maybe_make_list(column) + elif self.by is None: + self.columns = [ + col for col in data.columns if is_numeric_dtype(data[col]) + ] + else: + self.columns = [ + col + for col in data.columns + if col not in self.by and is_numeric_dtype(data[col]) + ] + + # For `hist` plot, need to get grouped original data before `self.data` is + # updated later + if self.by is not None and self._kind == "hist": + self._grouped = data.groupby(unpack_single_str_list(self.by)) + + self.kind = kind + + self.subplots = type(self)._validate_subplots_kwarg( + subplots, data, kind=self._kind + ) + + self.sharex = type(self)._validate_sharex(sharex, ax, by) + self.sharey = sharey + self.figsize = figsize + self.layout = layout + + self.xticks = xticks + self.yticks = yticks + self.xlim = xlim + self.ylim = ylim + self.title = title + self.use_index = use_index + self.xlabel = xlabel + self.ylabel = ylabel + + self.fontsize = fontsize + + if rot is not None: + self.rot = rot + # need to know for format_date_labels since it's rotated to 30 by + # default + self._rot_set = True + else: + self._rot_set = False + self.rot = self._default_rot + + if grid is None: + grid = False if secondary_y else mpl.rcParams["axes.grid"] + + self.grid = grid + self.legend = legend + self.legend_handles: list[Artist] = [] + self.legend_labels: list[Hashable] = [] + + self.logx = type(self)._validate_log_kwd("logx", logx) + self.logy = type(self)._validate_log_kwd("logy", logy) + self.loglog = type(self)._validate_log_kwd("loglog", loglog) + self.label = label + self.style = style + self.mark_right = mark_right + self.stacked = stacked + + # ax may be an Axes object or (if self.subplots) an ndarray of + # Axes objects + self.ax = ax + # TODO: deprecate fig keyword as it is ignored, not passed in tests + # as of 2023-11-05 + + # parse errorbar input if given + xerr = kwds.pop("xerr", None) + yerr = kwds.pop("yerr", None) + nseries = self._get_nseries(data) + xerr, data = type(self)._parse_errorbars("xerr", xerr, data, nseries) + yerr, data = type(self)._parse_errorbars("yerr", yerr, data, nseries) + self.errors = {"xerr": xerr, "yerr": yerr} + self.data = data + + if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, ABCIndex)): + secondary_y = [secondary_y] + self.secondary_y = secondary_y + + # ugly TypeError if user passes matplotlib's `cmap` name. + # Probably better to accept either. + if "cmap" in kwds and colormap: + raise TypeError("Only specify one of `cmap` and `colormap`.") + if "cmap" in kwds: + self.colormap = kwds.pop("cmap") + else: + self.colormap = colormap + + self.table = table + self.include_bool = include_bool + + self.kwds = kwds + + color = kwds.pop("color", lib.no_default) + self.color = self._validate_color_args(color, self.colormap) + assert "color" not in self.kwds + + self.data = self._ensure_frame(self.data) + + from pandas.plotting import plot_params + + self.x_compat = plot_params["x_compat"] + if "x_compat" in self.kwds: + self.x_compat = bool(self.kwds.pop("x_compat")) + + @final + def _is_ts_plot(self) -> bool: + # this is slightly deceptive + return not self.x_compat and self.use_index and self._use_dynamic_x() + + @final + def _use_dynamic_x(self) -> bool: + return use_dynamic_x(self._get_ax(0), self.data.index) + + @final + @staticmethod + def _validate_sharex(sharex: bool | None, ax, by) -> bool: + if sharex is None: + # if by is defined, subplots are used and sharex should be False + if ax is None and by is None: + sharex = True + else: + # if we get an axis, the users should do the visibility + # setting... + sharex = False + elif not is_bool(sharex): + raise TypeError("sharex must be a bool or None") + return bool(sharex) + + @classmethod + def _validate_log_kwd( + cls, + kwd: str, + value: bool | None | Literal["sym"], + ) -> bool | None | Literal["sym"]: + if ( + value is None + or isinstance(value, bool) + or (isinstance(value, str) and value == "sym") + ): + return value + raise ValueError( + f"keyword '{kwd}' should be bool, None, or 'sym', not '{value}'" + ) + + @final + @staticmethod + def _validate_subplots_kwarg( + subplots: bool | Sequence[Sequence[str]], data: Series | DataFrame, kind: str + ) -> bool | list[tuple[int, ...]]: + """ + Validate the subplots parameter + + - check type and content + - check for duplicate columns + - check for invalid column names + - convert column names into indices + - add missing columns in a group of their own + See comments in code below for more details. + + Parameters + ---------- + subplots : subplots parameters as passed to PlotAccessor + + Returns + ------- + validated subplots : a bool or a list of tuples of column indices. Columns + in the same tuple will be grouped together in the resulting plot. + """ + + if isinstance(subplots, bool): + return subplots + elif not isinstance(subplots, Iterable): + raise ValueError("subplots should be a bool or an iterable") + + supported_kinds = ( + "line", + "bar", + "barh", + "hist", + "kde", + "density", + "area", + "pie", + ) + if kind not in supported_kinds: + raise ValueError( + "When subplots is an iterable, kind must be " + f"one of {', '.join(supported_kinds)}. Got {kind}." + ) + + if isinstance(data, ABCSeries): + raise NotImplementedError( + "An iterable subplots for a Series is not supported." + ) + + columns = data.columns + if isinstance(columns, ABCMultiIndex): + raise NotImplementedError( + "An iterable subplots for a DataFrame with a MultiIndex column " + "is not supported." + ) + + if columns.nunique() != len(columns): + raise NotImplementedError( + "An iterable subplots for a DataFrame with non-unique column " + "labels is not supported." + ) + + # subplots is a list of tuples where each tuple is a group of + # columns to be grouped together (one ax per group). + # we consolidate the subplots list such that: + # - the tuples contain indices instead of column names + # - the columns that aren't yet in the list are added in a group + # of their own. + # For example with columns from a to g, and + # subplots = [(a, c), (b, f, e)], + # we end up with [(ai, ci), (bi, fi, ei), (di,), (gi,)] + # This way, we can handle self.subplots in a homogeneous manner + # later. + # TODO: also accept indices instead of just names? + + out = [] + seen_columns: set[Hashable] = set() + for group in subplots: + if not is_list_like(group): + raise ValueError( + "When subplots is an iterable, each entry " + "should be a list/tuple of column names." + ) + idx_locs = columns.get_indexer_for(group) + if (idx_locs == -1).any(): + bad_labels = np.extract(idx_locs == -1, group) + raise ValueError( + f"Column label(s) {list(bad_labels)} not found in the DataFrame." + ) + unique_columns = set(group) + duplicates = seen_columns.intersection(unique_columns) + if duplicates: + raise ValueError( + "Each column should be in only one subplot. " + f"Columns {duplicates} were found in multiple subplots." + ) + seen_columns = seen_columns.union(unique_columns) + out.append(tuple(idx_locs)) + + unseen_columns = columns.difference(seen_columns) + for column in unseen_columns: + idx_loc = columns.get_loc(column) + out.append((idx_loc,)) + return out + + def _validate_color_args(self, color, colormap): + if color is lib.no_default: + # It was not provided by the user + if "colors" in self.kwds and colormap is not None: + warnings.warn( + "'color' and 'colormap' cannot be used simultaneously. " + "Using 'color'", + stacklevel=find_stack_level(), + ) + return None + if self.nseries == 1 and color is not None and not is_list_like(color): + # support series.plot(color='green') + color = [color] + + if isinstance(color, tuple) and self.nseries == 1 and len(color) in (3, 4): + # support RGB and RGBA tuples in series plot + color = [color] + + if colormap is not None: + warnings.warn( + "'color' and 'colormap' cannot be used simultaneously. Using 'color'", + stacklevel=find_stack_level(), + ) + + if self.style is not None: + if isinstance(self.style, dict): + styles = [self.style[col] for col in self.columns if col in self.style] + elif is_list_like(self.style): + styles = self.style + else: + styles = [self.style] + # need only a single match + for s in styles: + if _color_in_style(s): + raise ValueError( + "Cannot pass 'style' string with a color symbol and " + "'color' keyword argument. Please use one or the " + "other or pass 'style' without a color symbol" + ) + return color + + @final + @staticmethod + def _iter_data( + data: DataFrame | dict[Hashable, Series | DataFrame], + ) -> Iterator[tuple[Hashable, np.ndarray]]: + for col, values in data.items(): + # This was originally written to use values.values before EAs + # were implemented; adding np.asarray(...) to keep consistent + # typing. + yield col, np.asarray(values.values) + + def _get_nseries(self, data: Series | DataFrame) -> int: + # When `by` is explicitly assigned, grouped data size will be defined, and + # this will determine number of subplots to have, aka `self.nseries` + if data.ndim == 1: + return 1 + elif self.by is not None and self._kind == "hist": + return len(self._grouped) + elif self.by is not None and self._kind == "box": + return len(self.columns) + else: + return data.shape[1] + + @final + @property + def nseries(self) -> int: + return self._get_nseries(self.data) + + @final + def generate(self) -> None: + self._compute_plot_data() + fig = self.fig + self._make_plot(fig) + self._add_table() + self._make_legend() + self._adorn_subplots(fig) + + for ax in self.axes: + self._post_plot_logic_common(ax) + self._post_plot_logic(ax, self.data) + + @final + @staticmethod + def _has_plotted_object(ax: Axes) -> bool: + """check whether ax has data""" + return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0 + + @final + def _maybe_right_yaxis(self, ax: Axes, axes_num: int) -> Axes: + if not self.on_right(axes_num): + # secondary axes may be passed via ax kw + return self._get_ax_layer(ax) + + if hasattr(ax, "right_ax"): + # if it has right_ax property, ``ax`` must be left axes + return ax.right_ax + elif hasattr(ax, "left_ax"): + # if it has left_ax property, ``ax`` must be right axes + return ax + else: + # otherwise, create twin axes + orig_ax, new_ax = ax, ax.twinx() + # TODO: use Matplotlib public API when available + new_ax._get_lines = orig_ax._get_lines # type: ignore[attr-defined] + # TODO #54485 + new_ax._get_patches_for_fill = ( # type: ignore[attr-defined] + orig_ax._get_patches_for_fill # type: ignore[attr-defined] + ) + # TODO #54485 + orig_ax.right_ax, new_ax.left_ax = ( # type: ignore[attr-defined] + new_ax, + orig_ax, + ) + + if not self._has_plotted_object(orig_ax): # no data on left y + orig_ax.get_yaxis().set_visible(False) + + if self.logy is True or self.loglog is True: + new_ax.set_yscale("log") + elif self.logy == "sym" or self.loglog == "sym": + new_ax.set_yscale("symlog") + return new_ax + + @final + @cache_readonly + def fig(self) -> Figure: + return self._axes_and_fig[1] + + @final + @cache_readonly + # TODO: can we annotate this as both a Sequence[Axes] and ndarray[object]? + def axes(self) -> Sequence[Axes]: + return self._axes_and_fig[0] + + @final + @cache_readonly + def _axes_and_fig(self) -> tuple[Sequence[Axes], Figure]: + import matplotlib.pyplot as plt + + if self.subplots: + naxes = ( + self.nseries if isinstance(self.subplots, bool) else len(self.subplots) + ) + fig, axes = create_subplots( + naxes=naxes, + sharex=self.sharex, + sharey=self.sharey, + figsize=self.figsize, + ax=self.ax, + layout=self.layout, + layout_type=self._layout_type, + ) + elif self.ax is None: + fig = plt.figure(figsize=self.figsize) + axes = fig.add_subplot(111) + else: + fig = self.ax.get_figure() + if self.figsize is not None: + fig.set_size_inches(self.figsize) + axes = self.ax + + axes = np.fromiter(flatten_axes(axes), dtype=object) + + if self.logx is True or self.loglog is True: + [a.set_xscale("log") for a in axes] + elif self.logx == "sym" or self.loglog == "sym": + [a.set_xscale("symlog") for a in axes] + + if self.logy is True or self.loglog is True: + [a.set_yscale("log") for a in axes] + elif self.logy == "sym" or self.loglog == "sym": + [a.set_yscale("symlog") for a in axes] + + axes_seq = cast(Sequence["Axes"], axes) + return axes_seq, fig + + @property + def result(self): + """ + Return result axes + """ + if self.subplots: + if self.layout is not None and not is_list_like(self.ax): + # error: "Sequence[Any]" has no attribute "reshape" + return self.axes.reshape(*self.layout) # type: ignore[attr-defined] + else: + return self.axes + else: + sec_true = isinstance(self.secondary_y, bool) and self.secondary_y + # error: Argument 1 to "len" has incompatible type "Union[bool, + # Tuple[Any, ...], List[Any], ndarray[Any, Any]]"; expected "Sized" + all_sec = ( + is_list_like(self.secondary_y) and len(self.secondary_y) == self.nseries # type: ignore[arg-type] + ) + if sec_true or all_sec: + # if all data is plotted on secondary, return right axes + return self._get_ax_layer(self.axes[0], primary=False) + else: + return self.axes[0] + + @final + @staticmethod + def _convert_to_ndarray(data): + # GH31357: categorical columns are processed separately + if isinstance(data.dtype, CategoricalDtype): + return data + + # GH32073: cast to float if values contain nulled integers + if (is_integer_dtype(data.dtype) or is_float_dtype(data.dtype)) and isinstance( + data.dtype, ExtensionDtype + ): + return data.to_numpy(dtype="float", na_value=np.nan) + + # GH25587: cast ExtensionArray of pandas (IntegerArray, etc.) to + # np.ndarray before plot. + if len(data) > 0: + return np.asarray(data) + + return data + + @final + def _ensure_frame(self, data) -> DataFrame: + if isinstance(data, ABCSeries): + label = self.label + if label is None and data.name is None: + label = "" + if label is None: + # We'll end up with columns of [0] instead of [None] + data = data.to_frame() + else: + data = data.to_frame(name=label) + elif self._kind in ("hist", "box"): + cols = self.columns if self.by is None else self.columns + self.by + data = data.loc[:, cols] + return data + + @final + def _compute_plot_data(self) -> None: + data = self.data + + # GH15079 reconstruct data if by is defined + if self.by is not None: + self.subplots = True + data = reconstruct_data_with_by(self.data, by=self.by, cols=self.columns) + + # GH16953, infer_objects is needed as fallback, for ``Series`` + # with ``dtype == object`` + data = data.infer_objects() + include_type = [np.number, "datetime", "datetimetz", "timedelta"] + + # GH23719, allow plotting boolean + if self.include_bool is True: + include_type.append(np.bool_) + + # GH22799, exclude datetime-like type for boxplot + exclude_type = None + if self._kind == "box": + # TODO: change after solving issue 27881 + include_type = [np.number] + exclude_type = ["timedelta"] + + # GH 18755, include object and category type for scatter plot + if self._kind == "scatter": + include_type.extend(["object", "category", "string"]) + + numeric_data = data.select_dtypes(include=include_type, exclude=exclude_type) + + is_empty = numeric_data.shape[-1] == 0 + # no non-numeric frames or series allowed + if is_empty: + raise TypeError("no numeric data to plot") + + self.data = numeric_data.apply(type(self)._convert_to_ndarray) + + def _make_plot(self, fig: Figure) -> None: + raise AbstractMethodError(self) + + @final + def _add_table(self) -> None: + if self.table is False: + return + elif self.table is True: + data = self.data.transpose() + else: + data = self.table + ax = self._get_ax(0) + tools.table(ax, data) + + @final + def _post_plot_logic_common(self, ax: Axes) -> None: + """Common post process for each axes""" + if self.orientation == "vertical" or self.orientation is None: + type(self)._apply_axis_properties( + ax.xaxis, rot=self.rot, fontsize=self.fontsize + ) + type(self)._apply_axis_properties(ax.yaxis, fontsize=self.fontsize) + + if hasattr(ax, "right_ax"): + type(self)._apply_axis_properties( + ax.right_ax.yaxis, fontsize=self.fontsize + ) + + elif self.orientation == "horizontal": + type(self)._apply_axis_properties( + ax.yaxis, rot=self.rot, fontsize=self.fontsize + ) + type(self)._apply_axis_properties(ax.xaxis, fontsize=self.fontsize) + + if hasattr(ax, "right_ax"): + type(self)._apply_axis_properties( + ax.right_ax.yaxis, fontsize=self.fontsize + ) + else: # pragma no cover + raise ValueError + + @abstractmethod + def _post_plot_logic(self, ax: Axes, data) -> None: + """Post process for each axes. Overridden in child classes""" + + @final + def _adorn_subplots(self, fig: Figure) -> None: + """Common post process unrelated to data""" + if len(self.axes) > 0: + all_axes = self._get_subplots(fig) + nrows, ncols = self._get_axes_layout(fig) + handle_shared_axes( + axarr=all_axes, + nplots=len(all_axes), + naxes=nrows * ncols, + nrows=nrows, + ncols=ncols, + sharex=self.sharex, + sharey=self.sharey, + ) + + for ax in self.axes: + ax = getattr(ax, "right_ax", ax) + if self.yticks is not None: + ax.set_yticks(self.yticks) + + if self.xticks is not None: + ax.set_xticks(self.xticks) + + if self.ylim is not None: + ax.set_ylim(self.ylim) + + if self.xlim is not None: + ax.set_xlim(self.xlim) + + # GH9093, currently Pandas does not show ylabel, so if users provide + # ylabel will set it as ylabel in the plot. + if self.ylabel is not None: + ax.set_ylabel(pprint_thing(self.ylabel)) + + ax.grid(self.grid) + + if self.title: + if self.subplots: + if is_list_like(self.title): + if not isinstance(self.subplots, bool): + if len(self.subplots) != len(self.title): + raise ValueError( + f"The number of titles ({len(self.title)}) must equal " + f"the number of subplots ({len(self.subplots)})." + ) + elif len(self.title) != self.nseries: + raise ValueError( + "The length of `title` must equal the number " + "of columns if using `title` of type `list` " + "and `subplots=True`.\n" + f"length of title = {len(self.title)}\n" + f"number of columns = {self.nseries}" + ) + + for ax, title in zip(self.axes, self.title, strict=False): + ax.set_title(title) + else: + fig.suptitle(self.title) + else: + if is_list_like(self.title): + msg = ( + "Using `title` of type `list` is not supported " + "unless `subplots=True` is passed" + ) + raise ValueError(msg) + self.axes[0].set_title(self.title) + + @final + @staticmethod + def _apply_axis_properties( + axis: Axis, rot=None, fontsize: int | None = None + ) -> None: + """ + Tick creation within matplotlib is reasonably expensive and is + internally deferred until accessed as Ticks are created/destroyed + multiple times per draw. It's therefore beneficial for us to avoid + accessing unless we will act on the Tick. + """ + if rot is not None or fontsize is not None: + # rot=0 is a valid setting, hence the explicit None check + labels = axis.get_majorticklabels() + axis.get_minorticklabels() + for label in labels: + if rot is not None: + label.set_rotation(rot) + if fontsize is not None: + label.set_fontsize(fontsize) + + @final + @property + def legend_title(self) -> str | None: + if not isinstance(self.data.columns, ABCMultiIndex): + name = self.data.columns.name + if name is not None: + name = pprint_thing(name) + return name + else: + stringified = map(pprint_thing, self.data.columns.names) + return ",".join(stringified) + + @final + def _mark_right_label(self, label: str, index: int) -> str: + """ + Append ``(right)`` to the label of a line if it's plotted on the right axis. + + Note that ``(right)`` is only appended when ``subplots=False``. + """ + if not self.subplots and self.mark_right and self.on_right(index): + label += " (right)" + return label + + @final + def _append_legend_handles_labels(self, handle: Artist, label: str) -> None: + """ + Append current handle and label to ``legend_handles`` and ``legend_labels``. + + These will be used to make the legend. + """ + self.legend_handles.append(handle) + self.legend_labels.append(label) + + def _make_legend(self) -> None: + ax, leg = self._get_ax_legend(self.axes[0]) + + handles = [] + labels = [] + title = "" + + if not self.subplots: + if leg is not None: + title = leg.get_title().get_text() + # Replace leg.legend_handles because it misses marker info + handles = leg.legend_handles + labels = [x.get_text() for x in leg.get_texts()] + + if self.legend: + if self.legend == "reverse": + handles += reversed(self.legend_handles) + labels += reversed(self.legend_labels) + else: + handles += self.legend_handles + labels += self.legend_labels + + if self.legend_title is not None: + title = self.legend_title + + if len(handles) > 0: + ax.legend(handles, labels, loc="best", title=title) + + elif self.subplots and self.legend: + for ax in self.axes: + if ax.get_visible(): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + "No artists with labels found to put in legend.", + UserWarning, + ) + ax.legend(loc="best") + + @final + @staticmethod + def _get_ax_legend(ax: Axes): + """ + Take in axes and return ax and legend under different scenarios + """ + leg = ax.get_legend() + + other_ax = cast( + "Axes", getattr(ax, "left_ax", None) or getattr(ax, "right_ax", None) + ) + other_leg = None + if other_ax is not None: + other_leg = other_ax.get_legend() + if leg is None and other_leg is not None: + leg = other_leg + ax = other_ax + return ax, leg + + _need_to_set_index = False + + @final + def _get_xticks(self): + index = self.data.index + is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time") + + # TODO: be stricter about x? + x: list[int] | np.ndarray + if self.use_index: + if isinstance(index, ABCPeriodIndex): + # test_mixed_freq_irreg_period + x = index.to_timestamp()._mpl_repr() + # TODO: why do we need to do to_timestamp() here but not other + # places where we call mpl_repr? + elif is_any_real_numeric_dtype(index.dtype): + # Matplotlib supports numeric values or datetime objects as + # xaxis values. Taking LBYL approach here, by the time + # matplotlib raises exception when using non numeric/datetime + # values for xaxis, several actions are already taken by plt. + x = index._mpl_repr() + elif isinstance(index, ABCDatetimeIndex) or is_datetype: + x = index._mpl_repr() + else: + self._need_to_set_index = True + x = list(range(len(index))) + else: + x = list(range(len(index))) + + return x + + @classmethod + @register_pandas_matplotlib_converters + def _plot( + cls, ax: Axes, x, y: np.ndarray, style=None, is_errorbar: bool = False, **kwds + ): + mask = isna(y) + if mask.any(): + y = np.ma.array(y) + y = np.ma.masked_where(mask, y) + + if isinstance(x, ABCIndex): + x = x._mpl_repr() + + if is_errorbar: + if "xerr" in kwds: + kwds["xerr"] = np.array(kwds.get("xerr")) + if "yerr" in kwds: + kwds["yerr"] = np.array(kwds.get("yerr")) + return ax.errorbar(x, y, **kwds) + else: + # prevent style kwarg from going to errorbar, where it is unsupported + args = (x, y, style) if style is not None else (x, y) + return ax.plot(*args, **kwds) + + def _get_custom_index_name(self): + """Specify whether xlabel/ylabel should be used to override index name""" + return self.xlabel + + @final + def _get_index_name(self) -> str | None: + if isinstance(self.data.index, ABCMultiIndex): + name = self.data.index.names + if com.any_not_none(*name): + name = ",".join([pprint_thing(x) for x in name]) + else: + name = None + else: + name = self.data.index.name + if name is not None: + name = pprint_thing(name) + + # GH 45145, override the default axis label if one is provided. + index_name = self._get_custom_index_name() + if index_name is not None: + name = pprint_thing(index_name) + + return name + + @final + @classmethod + def _get_ax_layer(cls, ax, primary: bool = True): + """get left (primary) or right (secondary) axes""" + if primary: + return getattr(ax, "left_ax", ax) + else: + return getattr(ax, "right_ax", ax) + + @final + def _col_idx_to_axis_idx(self, col_idx: int) -> int: + """Return the index of the axis where the column at col_idx should be plotted""" + if isinstance(self.subplots, list): + # Subplots is a list: some columns will be grouped together in the same ax + return next( + group_idx + for (group_idx, group) in enumerate(self.subplots) + if col_idx in group + ) + else: + # subplots is True: one ax per column + return col_idx + + @final + def _get_ax(self, i: int) -> Axes: + # get the twinx ax if appropriate + if self.subplots: + i = self._col_idx_to_axis_idx(i) + ax = self.axes[i] + ax = self._maybe_right_yaxis(ax, i) + # error: Unsupported target for indexed assignment ("Sequence[Any]") + self.axes[i] = ax # type: ignore[index] + else: + ax = self.axes[0] + ax = self._maybe_right_yaxis(ax, i) + + ax.get_yaxis().set_visible(True) + return ax + + @final + def on_right(self, i: int) -> bool: + if isinstance(self.secondary_y, bool): + return self.secondary_y + + if isinstance(self.secondary_y, (tuple, list, np.ndarray, ABCIndex)): + return self.data.columns[i] in self.secondary_y + + @final + def _apply_style_colors( + self, colors, kwds: dict[str, Any], col_num: int, label: str + ): + """ + Manage style and color based on column number and its label. + Returns tuple of appropriate style and kwds which "color" may be added. + """ + style = None + if self.style is not None: + if isinstance(self.style, list): + try: + style = self.style[col_num] + except IndexError: + pass + elif isinstance(self.style, dict): + style = self.style.get(label, style) + else: + style = self.style + + has_color = "color" in kwds or self.colormap is not None + nocolor_style = style is None or not _color_in_style(style) + if (has_color or self.subplots) and nocolor_style: + if isinstance(colors, dict): + kwds["color"] = colors[label] + else: + kwds["color"] = colors[col_num % len(colors)] + return style, kwds + + def _get_colors( + self, + num_colors: int | None = None, + color_kwds: str = "color", + ): + if num_colors is None: + num_colors = self.nseries + if color_kwds == "color": + color = self.color + else: + color = self.kwds.get(color_kwds) + return get_standard_colors( + num_colors=num_colors, + colormap=self.colormap, + color=color, + ) + + # TODO: tighter typing for first return? + @final + @staticmethod + def _parse_errorbars( + label: str, err, data: NDFrameT, nseries: int + ) -> tuple[Any, NDFrameT]: + """ + Look for error keyword arguments and return the actual errorbar data + or return the error DataFrame/dict + + Error bars can be specified in several ways: + Series: the user provides a pandas.Series object of the same + length as the data + ndarray: provides an np.ndarray of the same length as the data + DataFrame/dict: error values are paired with keys matching the + key in the plotted DataFrame + str: the name of the column within the plotted DataFrame + + Asymmetrical error bars are also supported, however raw error values + must be provided in this case. For an ``N`` length :class:`Series`, a + ``2xN`` array should be provided indicating lower and upper (or left + and right) errors. For an ``MxN`` :class:`DataFrame`, asymmetrical errors + should be in an ``Mx2xN`` array. + """ + if err is None: + return None, data + + def match_labels(data, e): + e = e.reindex(data.index) + return e + + # key-matched DataFrame + if isinstance(err, ABCDataFrame): + err = match_labels(data, err) + # key-matched dict + elif isinstance(err, dict): + pass + + # Series of error values + elif isinstance(err, ABCSeries): + # broadcast error series across data + err = match_labels(data, err) + err = np.atleast_2d(err) + err = np.tile(err, (nseries, 1)) + + # errors are a column in the dataframe + elif isinstance(err, str): + evalues = data[err].values + data = data[data.columns.drop(err)] + err = np.atleast_2d(evalues) + err = np.tile(err, (nseries, 1)) + + elif is_list_like(err): + if is_iterator(err): + err = np.atleast_2d(list(err)) + else: + # raw error values + err = np.atleast_2d(err) + + err_shape = err.shape + + # asymmetrical error bars + if isinstance(data, ABCSeries) and err_shape[0] == 2: + err = np.expand_dims(err, 0) + err_shape = err.shape + if err_shape[2] != len(data): + raise ValueError( + "Asymmetrical error bars should be provided " + f"with the shape (2, {len(data)})" + ) + elif isinstance(data, ABCDataFrame) and err.ndim == 3: + if ( + (err_shape[0] != nseries) + or (err_shape[1] != 2) + or (err_shape[2] != len(data)) + ): + raise ValueError( + "Asymmetrical error bars should be provided " + f"with the shape ({nseries}, 2, {len(data)})" + ) + + # broadcast errors to each data series + if len(err) == 1: + err = np.tile(err, (nseries, 1)) + + elif is_number(err): + err = np.tile( + [err], + (nseries, len(data)), + ) + + else: + msg = f"No valid {label} detected" + raise ValueError(msg) + + return err, data + + @final + def _get_errorbars( + self, label=None, index=None, xerr: bool = True, yerr: bool = True + ) -> dict[str, Any]: + errors = {} + + for kw, flag in zip(["xerr", "yerr"], [xerr, yerr], strict=True): + if flag: + err = self.errors[kw] + # user provided label-matched dataframe of errors + if isinstance(err, (ABCDataFrame, dict)): + if label is not None and label in err.keys(): + err = err[label] + else: + err = None + elif index is not None and err is not None: + err = err[index] + + if err is not None: + errors[kw] = err + return errors + + @final + def _get_subplots(self, fig: Figure) -> list[Axes]: + return [ + ax + for ax in fig.get_axes() + if (isinstance(ax, mpl.axes.Axes) and ax.get_subplotspec() is not None) + ] + + @final + def _get_axes_layout(self, fig: Figure) -> tuple[int, int]: + axes = self._get_subplots(fig) + x_set = set() + y_set = set() + for ax in axes: + # check axes coordinates to estimate layout + points = ax.get_position().get_points() + x_set.add(points[0][0]) + y_set.add(points[0][1]) + return (len(y_set), len(x_set)) + + +class PlanePlot(MPLPlot, ABC): + """ + Abstract class for plotting on plane, currently scatter and hexbin. + """ + + _layout_type = "single" + + def __init__(self, data, x, y, **kwargs) -> None: + MPLPlot.__init__(self, data, **kwargs) + if x is None or y is None: + raise ValueError(self._kind + " requires an x and y column") + if is_integer(x) and not holds_integer(self.data.columns): + x = self.data.columns[x] + if is_integer(y) and not holds_integer(self.data.columns): + y = self.data.columns[y] + + self.x = x + self.y = y + + @final + def _get_nseries(self, data: Series | DataFrame) -> int: + return 1 + + @final + def _post_plot_logic(self, ax: Axes, data) -> None: + x, y = self.x, self.y + xlabel = self.xlabel if self.xlabel is not None else pprint_thing(x) + ylabel = self.ylabel if self.ylabel is not None else pprint_thing(y) + # error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible + # type "Hashable"; expected "str" + ax.set_xlabel(xlabel) # type: ignore[arg-type] + ax.set_ylabel(ylabel) # type: ignore[arg-type] + + @final + def _plot_colorbar(self, ax: Axes, *, fig: Figure, **kwds): + # Addresses issues #10611 and #10678: + # When plotting scatterplots and hexbinplots in IPython + # inline backend the colorbar axis height tends not to + # exactly match the parent axis height. + # The difference is due to small fractional differences + # in floating points with similar representation. + # To deal with this, this method forces the colorbar + # height to take the height of the parent axes. + # For a more detailed description of the issue + # see the following link: + # https://github.com/ipython/ipython/issues/11215 + + # GH33389, if ax is used multiple times, we should always + # use the last one which contains the latest information + # about the ax + img = ax.collections[-1] + return fig.colorbar(img, ax=ax, **kwds) + + +class ScatterPlot(PlanePlot): + @property + def _kind(self) -> Literal["scatter"]: + return "scatter" + + def __init__( + self, + data, + x, + y, + s=None, + c=None, + *, + colorbar: bool | lib.NoDefault = lib.no_default, + norm=None, + **kwargs, + ) -> None: + if s is None: + # hide the matplotlib default for size, in case we want to change + # the handling of this argument later + s = 20 + elif is_hashable(s) and s in data.columns: + s = data[s] + self.s = s + + self.colorbar = colorbar + self.norm = norm + + super().__init__(data, x, y, **kwargs) + if is_integer(c) and not holds_integer(self.data.columns): + c = self.data.columns[c] + self.c = c + + @register_pandas_matplotlib_converters + def _make_plot(self, fig: Figure) -> None: + x, y, c, data = self.x, self.y, self.c, self.data + ax = self.axes[0] + + from pandas import Series + + x_data = data[x] + s = Series(index=x_data) + if use_dynamic_x(ax, s.index): + s = maybe_convert_index(ax, s) + freq, s = prepare_ts_data(s, ax, self.kwds) + x_data = s.index + + c_is_column = is_hashable(c) and c in self.data.columns + + color_by_categorical = c_is_column and isinstance( + self.data[c].dtype, CategoricalDtype + ) + + color = self.color + c_values = self._get_c_values(color, color_by_categorical, c_is_column) + norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical) + cb = self._get_colorbar(c_values, c_is_column) + + if self.legend: + label = self.label + else: + label = None + + # if a list of non-color strings is passed in as c, color points + # by uniqueness of the strings, such same strings get same color + create_colors = not self._are_valid_colors(c_values) + if create_colors: + color_mapping = self._get_color_mapping(c_values) + c_values = [color_mapping[s] for s in c_values] + + # build legend for labeling custom colors + ax.legend( + handles=[ + mpl.patches.Circle((0, 0), facecolor=c, label=s) + for s, c in color_mapping.items() + ] + ) + + scatter = ax.scatter( + x_data.values, + data[y].values, + c=c_values, + label=label, + cmap=cmap, + norm=norm, + s=self.s, + **self.kwds, + ) + + if cb: + cbar_label = c if c_is_column else "" + cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label) + if color_by_categorical: + n_cats = len(self.data[c].cat.categories) + cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats)) + cbar.ax.set_yticklabels(self.data[c].cat.categories) + + if label is not None: + self._append_legend_handles_labels( + # error: Argument 2 to "_append_legend_handles_labels" of + # "MPLPlot" has incompatible type "Hashable"; expected "str" + scatter, + label, # type: ignore[arg-type] + ) + + errors_x = self._get_errorbars(label=x, index=0, yerr=False) + errors_y = self._get_errorbars(label=y, index=0, xerr=False) + if len(errors_x) > 0 or len(errors_y) > 0: + err_kwds = dict(errors_x, **errors_y) + err_kwds["ecolor"] = scatter.get_facecolor()[0] + ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds) + + def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool): + c = self.c + if c is not None and color is not None: + raise TypeError("Specify exactly one of `c` and `color`") + if c is None and color is None: + c_values = mpl.rcParams["patch.facecolor"] + elif color is not None: + c_values = color + elif color_by_categorical: + c_values = self.data[c].cat.codes + elif c_is_column: + c_values = self.data[c].values + else: + c_values = c + return c_values + + def _are_valid_colors(self, c_values: Series) -> bool: + # check if c_values contains strings and if these strings are valid mpl colors. + # no need to check numerics as these (and mpl colors) will be validated for us + # in .Axes.scatter._parse_scatter_color_args(...) + unique = np.unique(c_values) + try: + if len(c_values) and all(isinstance(c, str) for c in unique): + mpl.colors.to_rgba_array(unique) + + return True + + except (TypeError, ValueError) as _: + return False + + def _get_color_mapping(self, c_values: Series) -> dict[str, np.ndarray]: + unique = np.unique(c_values) + n_colors = len(unique) + + # passing `None` here will default to :rc:`image.cmap` + cmap = mpl.colormaps.get_cmap(self.colormap) + colors = cmap(np.linspace(0, 1, n_colors)) # RGB tuples + + return dict(zip(unique, colors, strict=True)) + + def _get_norm_and_cmap(self, c_values, color_by_categorical: bool): + c = self.c + if self.colormap is not None: + cmap = mpl.colormaps.get_cmap(self.colormap) + # cmap is only used if c_values are integers, otherwise UserWarning. + # GH-53908: additionally call isinstance() because is_integer_dtype + # returns True for "b" (meaning "blue" and not int8 in this context) + elif not isinstance(c_values, str) and is_integer_dtype(c_values): + # pandas uses colormap, matplotlib uses cmap. + cmap = mpl.colormaps["Greys"] + else: + cmap = None + + if color_by_categorical and cmap is not None: + n_cats = len(self.data[c].cat.categories) + cmap = mpl.colors.ListedColormap([cmap(i) for i in range(cmap.N)]) + bounds = np.linspace(0, n_cats, n_cats + 1) + norm = mpl.colors.BoundaryNorm(bounds, cmap.N) + # TODO: warn that we are ignoring self.norm if user specified it? + # Doesn't happen in any tests 2023-11-09 + else: + norm = self.norm + return norm, cmap + + def _get_colorbar(self, c_values, c_is_column: bool) -> bool: + # plot colorbar if + # 1. colormap is assigned, and + # 2.`c` is a column containing only numeric values + plot_colorbar = self.colormap or c_is_column + cb = self.colorbar + if cb is lib.no_default: + return is_numeric_dtype(c_values) and plot_colorbar + return cb + + +class HexBinPlot(PlanePlot): + @property + def _kind(self) -> Literal["hexbin"]: + return "hexbin" + + def __init__(self, data, x, y, C=None, *, colorbar: bool = True, **kwargs) -> None: + super().__init__(data, x, y, **kwargs) + if is_integer(C) and not holds_integer(self.data.columns): + C = self.data.columns[C] + self.C = C + + self.colorbar = colorbar + + # Scatter plot allows to plot objects data + if len(self.data[self.x]._get_numeric_data()) == 0: + raise ValueError(self._kind + " requires x column to be numeric") + if len(self.data[self.y]._get_numeric_data()) == 0: + raise ValueError(self._kind + " requires y column to be numeric") + + def _make_plot(self, fig: Figure) -> None: + x, y, data, C = self.x, self.y, self.data, self.C + ax = self.axes[0] + # pandas uses colormap, matplotlib uses cmap. + cmap = self.colormap or "BuGn" + cmap = mpl.colormaps.get_cmap(cmap) + cb = self.colorbar + + if C is None: + c_values = None + else: + c_values = data[C].values + + ax.hexbin(data[x].values, data[y].values, C=c_values, cmap=cmap, **self.kwds) + if cb: + self._plot_colorbar(ax, fig=fig) + + def _make_legend(self) -> None: + pass + + +class LinePlot(MPLPlot): + _default_rot = 0 + + @property + def orientation(self) -> PlottingOrientation: + return "vertical" + + @property + def _kind(self) -> Literal["line", "area", "hist", "kde", "box"]: + return "line" + + def __init__(self, data, **kwargs) -> None: + MPLPlot.__init__(self, data, **kwargs) + if self.stacked: + self.data = self.data.fillna(value=0) + + def _make_plot(self, fig: Figure) -> None: + if self._is_ts_plot(): + data = maybe_convert_index(self._get_ax(0), self.data) + + x = data.index # dummy, not used + plotf = self._ts_plot + it = data.items() + else: + x = self._get_xticks() + # error: Incompatible types in assignment (expression has type + # "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has + # type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]") + plotf = self._plot # type: ignore[assignment] + # error: Incompatible types in assignment (expression has type + # "Iterator[tuple[Hashable, ndarray[Any, Any]]]", variable has + # type "Iterable[tuple[Hashable, Series]]") + it = self._iter_data(data=self.data) # type: ignore[assignment] + + stacking_id = self._get_stacking_id() + is_errorbar = com.any_not_none(*self.errors.values()) + + colors = self._get_colors() + for i, (label, y) in enumerate(it): + ax = self._get_ax(i) + kwds = self.kwds.copy() + if self.color is not None: + kwds["color"] = self.color + style, kwds = self._apply_style_colors( + colors, + kwds, + i, + # error: Argument 4 to "_apply_style_colors" of "MPLPlot" has + # incompatible type "Hashable"; expected "str" + label, # type: ignore[arg-type] + ) + + errors = self._get_errorbars(label=label, index=i) + kwds = dict(kwds, **errors) + + label = pprint_thing(label) + label = self._mark_right_label(label, index=i) + kwds["label"] = label + + newlines = plotf( + ax, + x, + y, + style=style, + column_num=i, + stacking_id=stacking_id, + is_errorbar=is_errorbar, + **kwds, + ) + self._append_legend_handles_labels(newlines[0], label) + + if self._is_ts_plot(): + # reset of xlim should be used for ts data + # TODO: GH28021, should find a way to change view limit on xaxis + lines = get_all_lines(ax) + left, right = get_xlim(lines) + ax.set_xlim(left, right) + + # error: Signature of "_plot" incompatible with supertype "MPLPlot" + @classmethod + def _plot( # type: ignore[override] + cls, + ax: Axes, + x, + y: np.ndarray, + style=None, + column_num=None, + stacking_id=None, + **kwds, + ): + # column_num is used to get the target column from plotf in line and + # area plots + if column_num == 0: + cls._initialize_stacker(ax, stacking_id, len(y)) + y_values = cls._get_stacked_values(ax, stacking_id, y, kwds["label"]) + lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds) + cls._update_stacker(ax, stacking_id, y) + return lines + + @final + def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds): + # accept x to be consistent with normal plot func, + # x is not passed to tsplot as it uses data.index as x coordinate + # column_num must be in kwds for stacking purpose + freq, data = prepare_ts_data(data, ax, kwds) + + # TODO #54485 + ax._plot_data.append((data, self._kind, kwds)) # type: ignore[attr-defined] + + lines = self._plot(ax, data.index, np.asarray(data.values), style=style, **kwds) + # set date formatter, locators and rescale limits + # TODO #54485 + format_dateaxis(ax, ax.freq, data.index) # type: ignore[arg-type, attr-defined] + return lines + + @final + def _get_stacking_id(self) -> int | None: + if self.stacked: + return id(self.data) + else: + return None + + @final + @classmethod + def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None: + if stacking_id is None: + return + if not hasattr(ax, "_stacker_pos_prior"): + # TODO #54485 + ax._stacker_pos_prior = {} # type: ignore[attr-defined] + if not hasattr(ax, "_stacker_neg_prior"): + # TODO #54485 + ax._stacker_neg_prior = {} # type: ignore[attr-defined] + # TODO #54485 + ax._stacker_pos_prior[stacking_id] = np.zeros(n) # type: ignore[attr-defined] + # TODO #54485 + ax._stacker_neg_prior[stacking_id] = np.zeros(n) # type: ignore[attr-defined] + + @final + @classmethod + def _get_stacked_values( + cls, ax: Axes, stacking_id: int | None, values: np.ndarray, label + ) -> np.ndarray: + if stacking_id is None: + return values + if not hasattr(ax, "_stacker_pos_prior"): + # stacker may not be initialized for subplots + cls._initialize_stacker(ax, stacking_id, len(values)) + + if (values >= 0).all(): + # TODO #54485 + return ( + ax._stacker_pos_prior[stacking_id] # type: ignore[attr-defined] + + values + ) + elif (values <= 0).all(): + # TODO #54485 + return ( + ax._stacker_neg_prior[stacking_id] # type: ignore[attr-defined] + + values + ) + + raise ValueError( + "When stacked is True, each column must be either " + "all positive or all negative. " + f"Column '{label}' contains both positive and negative values" + ) + + @final + @classmethod + def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None: + if stacking_id is None: + return + if (values >= 0).all(): + # TODO #54485 + ax._stacker_pos_prior[stacking_id] += values # type: ignore[attr-defined] + elif (values <= 0).all(): + # TODO #54485 + ax._stacker_neg_prior[stacking_id] += values # type: ignore[attr-defined] + + def _post_plot_logic(self, ax: Axes, data) -> None: + def get_label(i): + if is_float(i) and i.is_integer(): + i = int(i) + try: + return pprint_thing(data.index[i]) + except Exception: + return "" + + if self._need_to_set_index: + xticks = ax.get_xticks() + xticklabels = [get_label(x) for x in xticks] + # error: Argument 1 to "FixedLocator" has incompatible type "ndarray[Any, + # Any]"; expected "Sequence[float]" + ax.xaxis.set_major_locator(mpl.ticker.FixedLocator(xticks)) # type: ignore[arg-type] + ax.set_xticklabels(xticklabels) + + # If the index is an irregular time series, then by default + # we rotate the tick labels. The exception is if there are + # subplots which don't share their x-axes, in which we case + # we don't rotate the ticklabels as by default the subplots + # would be too close together. + condition = ( + not self._use_dynamic_x() + and (data.index._is_all_dates and self.use_index) + and (not self.subplots or (self.subplots and self.sharex)) + ) + + index_name = self._get_index_name() + + if condition: + # irregular TS rotated 30 deg. by default + # probably a better place to check / set this. + if not self._rot_set: + self.rot = 30 + format_date_labels(ax, rot=self.rot) + + if index_name is not None and self.use_index: + ax.set_xlabel(index_name) + + +class AreaPlot(LinePlot): + @property + def _kind(self) -> Literal["area"]: + return "area" + + def __init__(self, data, **kwargs) -> None: + kwargs.setdefault("stacked", True) + data = data.fillna(value=0) + LinePlot.__init__(self, data, **kwargs) + + if not self.stacked: + # use smaller alpha to distinguish overlap + self.kwds.setdefault("alpha", 0.5) + + if self.logy or self.loglog: + raise ValueError("Log-y scales are not supported in area plot") + + # error: Signature of "_plot" incompatible with supertype "MPLPlot" + @classmethod + def _plot( # type: ignore[override] + cls, + ax: Axes, + x, + y: np.ndarray, + style=None, + column_num=None, + stacking_id=None, + is_errorbar: bool = False, + **kwds, + ): + if column_num == 0: + cls._initialize_stacker(ax, stacking_id, len(y)) + y_values = cls._get_stacked_values(ax, stacking_id, y, kwds["label"]) + + # need to remove label, because subplots uses mpl legend as it is + line_kwds = kwds.copy() + line_kwds.pop("label") + lines = MPLPlot._plot(ax, x, y_values, style=style, **line_kwds) + + # get data from the line to get coordinates for fill_between + xdata, y_values = lines[0].get_data(orig=False) + + # unable to use ``_get_stacked_values`` here to get starting point + if stacking_id is None: + start = np.zeros(len(y)) + elif (y >= 0).all(): + # TODO #54485 + start = ax._stacker_pos_prior[stacking_id] # type: ignore[attr-defined] + elif (y <= 0).all(): + # TODO #54485 + start = ax._stacker_neg_prior[stacking_id] # type: ignore[attr-defined] + else: + start = np.zeros(len(y)) + + if "color" not in kwds: + kwds["color"] = lines[0].get_color() + + rect = ax.fill_between(xdata, start, y_values, **kwds) + cls._update_stacker(ax, stacking_id, y) + + # LinePlot expects list of artists + res = [rect] + return res + + def _post_plot_logic(self, ax: Axes, data) -> None: + LinePlot._post_plot_logic(self, ax, data) + + is_shared_y = len(list(ax.get_shared_y_axes())) > 0 + # do not override the default axis behaviour in case of shared y axes + if self.ylim is None and not is_shared_y: + if (data >= 0).all().all(): + ax.set_ylim(0, None) + elif (data <= 0).all().all(): + ax.set_ylim(None, 0) + + +class BarPlot(MPLPlot): + @property + def _kind(self) -> Literal["bar", "barh"]: + return "bar" + + _default_rot = 90 + + @property + def orientation(self) -> PlottingOrientation: + return "vertical" + + def __init__( + self, + data, + *, + align="center", + bottom=0, + left=0, + width=0.5, + position=0.5, + log=False, + **kwargs, + ) -> None: + # we have to treat a series differently than a + # 1-column DataFrame w.r.t. color handling + self._is_series = isinstance(data, ABCSeries) + self.bar_width = width + self._align = align + self._position = position + + if is_list_like(bottom): + bottom = np.array(bottom) + if is_list_like(left): + left = np.array(left) + self.bottom = bottom + self.left = left + + self.log = log + + MPLPlot.__init__(self, data, **kwargs) + + if self._is_ts_plot(): + self.tick_pos = np.array( + PeriodConverter.convert_from_freq( + self._get_xticks(), + data.index.freq, + ) + ) + else: + self.tick_pos = np.arange(len(data)) + + @cache_readonly + def ax_pos(self) -> np.ndarray: + return self.tick_pos - self.tickoffset + + @cache_readonly + def tickoffset(self): + if self.stacked or self.subplots: + return self.bar_width * self._position + elif self._align == "edge": + w = self.bar_width / self.nseries + return self.bar_width * (self._position - 0.5) + w * 0.5 + else: + return self.bar_width * self._position + + @cache_readonly + def lim_offset(self): + if self.stacked or self.subplots: + if self._align == "edge": + return self.bar_width / 2 + else: + return 0 + elif self._align == "edge": + w = self.bar_width / self.nseries + return w * 0.5 + else: + return 0 + + # error: Signature of "_plot" incompatible with supertype "MPLPlot" + @classmethod + @register_pandas_matplotlib_converters + def _plot( # type: ignore[override] + cls, + ax: Axes, + x, + y: np.ndarray, + w, + start: int | npt.NDArray[np.intp] = 0, + log: bool = False, + **kwds, + ): + return ax.bar(x, y, w, bottom=start, log=log, **kwds) + + @property + def _start_base(self): + return self.bottom + + def _make_plot(self, fig: Figure) -> None: + colors = self._get_colors() + ncolors = len(colors) + + pos_prior = neg_prior = np.zeros(len(self.data)) + K = self.nseries + + data = self.data.fillna(0) + + _stacked_subplots_ind: dict[int, int] = {} + _stacked_subplots_offsets = [] + + self.subplots: list[Any] + + if not isinstance(self.subplots, bool): + if bool(self.subplots) and self.stacked: + for i, sub_plot in enumerate(self.subplots): + if len(sub_plot) <= 1: + continue + for plot in sub_plot: + _stacked_subplots_ind[int(plot)] = i + _stacked_subplots_offsets.append([0, 0]) + + for i, (label, y) in enumerate(self._iter_data(data=data)): + ax = self._get_ax(i) + kwds = self.kwds.copy() + if self._is_series: + kwds["color"] = colors + elif isinstance(colors, dict): + kwds["color"] = colors[label] + else: + kwds["color"] = colors[i % ncolors] + + errors = self._get_errorbars(label=label, index=i) + kwds = dict(kwds, **errors) + + label = pprint_thing(label) + label = self._mark_right_label(label, index=i) + + if (("yerr" in kwds) or ("xerr" in kwds)) and (kwds.get("ecolor") is None): + kwds["ecolor"] = mpl.rcParams["xtick.color"] + + start = 0 + if self.log and (y >= 1).all(): + start = 1 + start = start + self._start_base + + kwds["align"] = self._align + + if i in _stacked_subplots_ind: + offset_index = _stacked_subplots_ind[i] + pos_prior, neg_prior = _stacked_subplots_offsets[offset_index] # type: ignore[assignment] + mask = y >= 0 + start = np.where(mask, pos_prior, neg_prior) + self._start_base + w = self.bar_width / 2 + rect = self._plot( + ax, + self.ax_pos + w, + y, + self.bar_width, + start=start, + label=label, + log=self.log, + **kwds, + ) + pos_new = pos_prior + np.where(mask, y, 0) + neg_new = neg_prior + np.where(mask, 0, y) + _stacked_subplots_offsets[offset_index] = [pos_new, neg_new] + + elif self.subplots: + w = self.bar_width / 2 + rect = self._plot( + ax, + self.ax_pos + w, + y, + self.bar_width, + start=start, + label=label, + log=self.log, + **kwds, + ) + ax.set_title(label) + elif self.stacked: + mask = y >= 0 + start = np.where(mask, pos_prior, neg_prior) + self._start_base + w = self.bar_width / 2 + rect = self._plot( + ax, + self.ax_pos + w, + y, + self.bar_width, + start=start, + label=label, + log=self.log, + **kwds, + ) + pos_prior = pos_prior + np.where(mask, y, 0) + neg_prior = neg_prior + np.where(mask, 0, y) + else: + w = self.bar_width / K + rect = self._plot( + ax, + self.ax_pos + (i + 0.5) * w, + y, + w, + start=start, + label=label, + log=self.log, + **kwds, + ) + self._append_legend_handles_labels(rect, label) + + def _post_plot_logic(self, ax: Axes, data) -> None: + if self.use_index: + str_index = [pprint_thing(key) for key in data.index] + else: + str_index = [pprint_thing(key) for key in range(data.shape[0])] + + s_edge = self.ax_pos[0] - 0.25 + self.lim_offset + e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset + + self._decorate_ticks(ax, self._get_index_name(), str_index, s_edge, e_edge) + + def _decorate_ticks( + self, + ax: Axes, + name: str | None, + ticklabels: list[str], + start_edge: float, + end_edge: float, + ) -> None: + ax.set_xlim((start_edge, end_edge)) + + if self.xticks is not None: + ax.set_xticks(np.array(self.xticks)) + else: + ax.set_xticks(self.tick_pos) + ax.set_xticklabels(ticklabels) + + if name is not None and self.use_index: + ax.set_xlabel(name) + + +class BarhPlot(BarPlot): + @property + def _kind(self) -> Literal["barh"]: + return "barh" + + _default_rot = 0 + + @property + def orientation(self) -> Literal["horizontal"]: + return "horizontal" + + @property + def _start_base(self): + return self.left + + # error: Signature of "_plot" incompatible with supertype "MPLPlot" + @classmethod + def _plot( # type: ignore[override] + cls, + ax: Axes, + x, + y: np.ndarray, + w, + start: int | npt.NDArray[np.intp] = 0, + log: bool = False, + **kwds, + ): + return ax.barh(x, y, w, left=start, log=log, **kwds) + + def _get_custom_index_name(self): + return self.ylabel + + def _decorate_ticks( + self, + ax: Axes, + name: str | None, + ticklabels: list[str], + start_edge: float, + end_edge: float, + ) -> None: + # horizontal bars + ax.set_ylim((start_edge, end_edge)) + ax.set_yticks(self.tick_pos) + ax.set_yticklabels(ticklabels) + if name is not None and self.use_index: + ax.set_ylabel(name) + # error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible type + # "Hashable | None"; expected "str" + ax.set_xlabel(self.xlabel) # type: ignore[arg-type] + + +class PiePlot(MPLPlot): + @property + def _kind(self) -> Literal["pie"]: + return "pie" + + _layout_type = "horizontal" + + def __init__(self, data: Series | DataFrame, kind=None, **kwargs) -> None: + data = data.fillna(value=0) + lt_zero = data < 0 + if isinstance(data, ABCDataFrame) and lt_zero.any().any(): + raise ValueError(f"{self._kind} plot doesn't allow negative values") + elif isinstance(data, ABCSeries) and lt_zero.any(): + raise ValueError(f"{self._kind} plot doesn't allow negative values") + MPLPlot.__init__(self, data, kind=kind, **kwargs) + + @classmethod + def _validate_log_kwd( + cls, + kwd: str, + value: bool | None | Literal["sym"], + ) -> bool | None | Literal["sym"]: + super()._validate_log_kwd(kwd=kwd, value=value) + if value is not False: + warnings.warn( + f"PiePlot ignores the '{kwd}' keyword", + UserWarning, + stacklevel=find_stack_level(), + ) + return False + + def _validate_color_args(self, color, colormap) -> None: + # TODO: warn if color is passed and ignored? + return None + + def _make_plot(self, fig: Figure) -> None: + colors = self._get_colors(num_colors=len(self.data), color_kwds="colors") + self.kwds.setdefault("colors", colors) + + for i, (label, y) in enumerate(self._iter_data(data=self.data)): + ax = self._get_ax(i) + + kwds = self.kwds.copy() + + def blank_labeler(label, value): + if value == 0: + return "" + else: + return label + + idx = [pprint_thing(v) for v in self.data.index] + labels = kwds.pop("labels", idx) + # labels is used for each wedge's labels + # Blank out labels for values of 0 so they don't overlap + # with nonzero wedges + if labels is not None: + blabels = [ + blank_labeler(left, value) + for left, value in zip(labels, y, strict=True) + ] + else: + blabels = None + results = ax.pie(y, labels=blabels, **kwds) + + if kwds.get("autopct", None) is not None: + # error: Need more than 2 values to unpack (3 expected) + patches, texts, autotexts = results # type: ignore[misc] + else: + # error: Too many values to unpack (2 expected, 3 provided) + patches, texts = results # type: ignore[misc] + autotexts = [] + + if self.fontsize is not None: + for t in texts + autotexts: + t.set_fontsize(self.fontsize) + + # leglabels is used for legend labels + leglabels = labels if labels is not None else idx + for _patch, _leglabel in zip(patches, leglabels, strict=True): + self._append_legend_handles_labels(_patch, _leglabel) + + def _post_plot_logic(self, ax: Axes, data) -> None: + pass diff --git a/pandas/plotting/_matplotlib/groupby.py b/pandas/plotting/_matplotlib/groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..783f79710097c7e471d3f531fac3be8cd711014a --- /dev/null +++ b/pandas/plotting/_matplotlib/groupby.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from pandas.core.dtypes.missing import remove_na_arraylike + +from pandas import ( + MultiIndex, + concat, +) + +from pandas.plotting._matplotlib.misc import unpack_single_str_list + +if TYPE_CHECKING: + from collections.abc import Hashable + + from pandas._typing import IndexLabel + + from pandas import ( + DataFrame, + Series, + ) + + +def create_iter_data_given_by( + data: DataFrame, kind: str = "hist" +) -> dict[Hashable, DataFrame | Series]: + """ + Create data for iteration given `by` is assigned or not, and it is only + used in both hist and boxplot. + + If `by` is assigned, return a dictionary of DataFrames in which the key of + dictionary is the values in groups. + If `by` is not assigned, return input as is, and this preserves current + status of iter_data. + + Parameters + ---------- + data : reformatted grouped data from `_compute_plot_data` method. + kind : str, plot kind. This function is only used for `hist` and `box` plots. + + Returns + ------- + iter_data : DataFrame or Dictionary of DataFrames + + Examples + -------- + If `by` is assigned: + + >>> import numpy as np + >>> tuples = [("h1", "a"), ("h1", "b"), ("h2", "a"), ("h2", "b")] + >>> mi = pd.MultiIndex.from_tuples(tuples) + >>> value = [[1, 3, np.nan, np.nan], [3, 4, np.nan, np.nan], [np.nan, np.nan, 5, 6]] + >>> data = pd.DataFrame(value, columns=mi) + >>> create_iter_data_given_by(data) + {'h1': h1 + a b + 0 1.0 3.0 + 1 3.0 4.0 + 2 NaN NaN, 'h2': h2 + a b + 0 NaN NaN + 1 NaN NaN + 2 5.0 6.0} + """ + + # For `hist` plot, before transformation, the values in level 0 are values + # in groups and subplot titles, and later used for column subselection and + # iteration; For `box` plot, values in level 1 are column names to show, + # and are used for iteration and as subplots titles. + if kind == "hist": + level = 0 + else: + level = 1 + + # Select sub-columns based on the value of level of MI, and if `by` is + # assigned, data must be a MI DataFrame + assert isinstance(data.columns, MultiIndex) + return { + col: data.loc[:, data.columns.get_level_values(level) == col] + for col in data.columns.levels[level] + } + + +def reconstruct_data_with_by( + data: DataFrame, by: IndexLabel, cols: IndexLabel +) -> DataFrame: + """ + Internal function to group data, and reassign multiindex column names onto the + result in order to let grouped data be used in _compute_plot_data method. + + Parameters + ---------- + data : Original DataFrame to plot + by : grouped `by` parameter selected by users + cols : columns of data set (excluding columns used in `by`) + + Returns + ------- + Output is the reconstructed DataFrame with MultiIndex columns. The first level + of MI is unique values of groups, and second level of MI is the columns + selected by users. + + Examples + -------- + >>> d = {"h": ["h1", "h1", "h2"], "a": [1, 3, 5], "b": [3, 4, 6]} + >>> df = pd.DataFrame(d) + >>> reconstruct_data_with_by(df, by="h", cols=["a", "b"]) + h1 h2 + a b a b + 0 1.0 3.0 NaN NaN + 1 3.0 4.0 NaN NaN + 2 NaN NaN 5.0 6.0 + """ + by_modified = unpack_single_str_list(by) + grouped = data.groupby(by_modified) + + data_list = [] + for key, group in grouped: + # error: List item 1 has incompatible type "Union[Hashable, + # Sequence[Hashable]]"; expected "Iterable[Hashable]" + columns = MultiIndex.from_product([[key], cols]) # type: ignore[list-item] + sub_group = group[cols] + sub_group.columns = columns + data_list.append(sub_group) + + data = concat(data_list, axis=1) + return data + + +def reformat_hist_y_given_by(y: np.ndarray, by: IndexLabel | None) -> np.ndarray: + """Internal function to reformat y given `by` is applied or not for hist plot. + + If by is None, input y is 1-d with NaN removed; and if by is not None, groupby + will take place and input y is multi-dimensional array. + """ + if by is not None and len(y.shape) > 1: + return np.array([remove_na_arraylike(col) for col in y.T]).T + return remove_na_arraylike(y) diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py new file mode 100644 index 0000000000000000000000000000000000000000..029db85b315fd5d7849cd4441e26d0e50a046f01 --- /dev/null +++ b/pandas/plotting/_matplotlib/hist.py @@ -0,0 +1,574 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Literal, + final, +) + +import numpy as np + +from pandas.core.dtypes.common import ( + is_integer, + is_list_like, +) +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCIndex, +) +from pandas.core.dtypes.missing import ( + isna, + remove_na_arraylike, +) + +from pandas.io.formats.printing import pprint_thing +from pandas.plotting._matplotlib.core import ( + LinePlot, + MPLPlot, +) +from pandas.plotting._matplotlib.groupby import ( + create_iter_data_given_by, + reformat_hist_y_given_by, +) +from pandas.plotting._matplotlib.misc import unpack_single_str_list +from pandas.plotting._matplotlib.tools import ( + create_subplots, + flatten_axes, + maybe_adjust_figure, + set_ticks_props, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.container import BarContainer + from matplotlib.figure import Figure + from matplotlib.patches import Polygon + + from pandas._typing import PlottingOrientation + + from pandas import ( + DataFrame, + Series, + ) + + +class HistPlot(LinePlot): + @property + def _kind(self) -> Literal["hist", "kde"]: + return "hist" + + def __init__( + self, + data, + bins: int | np.ndarray | list[np.ndarray] = 10, + bottom: int | np.ndarray = 0, + *, + range=None, + weights=None, + **kwargs, + ) -> None: + if is_list_like(bottom): + bottom = np.array(bottom) + self.bottom = bottom + + self._bin_range = range + self.weights = weights + + self.xlabel = kwargs.get("xlabel") + self.ylabel = kwargs.get("ylabel") + # Do not call LinePlot.__init__ which may fill nan + MPLPlot.__init__(self, data, **kwargs) + + self.bins = self._adjust_bins(bins) + + def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]): + if is_integer(bins): + if self.by is not None: + by_modified = unpack_single_str_list(self.by) + grouped = self.data.groupby(by_modified)[self.columns] + bins = [self._calculate_bins(group, bins) for key, group in grouped] + else: + bins = self._calculate_bins(self.data, bins) + return bins + + def _calculate_bins(self, data: Series | DataFrame, bins) -> np.ndarray: + """Calculate bins given data""" + nd_values = data.infer_objects()._get_numeric_data() + values = nd_values.values + if nd_values.ndim == 2: + values = values.reshape(-1) + values = values[~isna(values)] + + return np.histogram_bin_edges(values, bins=bins, range=self._bin_range) + + # error: Signature of "_plot" incompatible with supertype "LinePlot" + @classmethod + def _plot( # type: ignore[override] + cls, + ax: Axes, + y: np.ndarray, + style=None, + bottom: int | np.ndarray = 0, + column_num: int = 0, + stacking_id=None, + *, + bins, + **kwds, + # might return a subset from the possible return types of Axes.hist(...)[2]? + ) -> BarContainer | Polygon | list[BarContainer | Polygon]: + if column_num == 0: + cls._initialize_stacker(ax, stacking_id, len(bins) - 1) + + base = np.zeros(len(bins) - 1) + bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"]) + # ignore style + n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds) + cls._update_stacker(ax, stacking_id, n) + return patches + + def _make_plot(self, fig: Figure) -> None: + colors = self._get_colors() + stacking_id = self._get_stacking_id() + + # Re-create iterated data if `by` is assigned by users + data = ( + create_iter_data_given_by(self.data, self._kind) + if self.by is not None + else self.data + ) + for i, (label, y) in enumerate(self._iter_data(data=data)): + ax = self._get_ax(i) + + kwds = self.kwds.copy() + if self.color is not None: + kwds["color"] = self.color + + label = pprint_thing(label) + label = self._mark_right_label(label, index=i) + kwds["label"] = label + + style, kwds = self._apply_style_colors(colors, kwds, i, label) + if style is not None: + kwds["style"] = style + + self._make_plot_keywords(kwds, y) + + # the bins is multi-dimension array now and each plot need only 1-d and + # when by is applied, label should be columns that are grouped + if self.by is not None: + kwds["bins"] = kwds["bins"][i] + kwds["label"] = self.columns + kwds.pop("color") + + if self.weights is not None: + kwds["weights"] = type(self)._get_column_weights(self.weights, i, y) + + y = reformat_hist_y_given_by(y, self.by) + + artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds) + + # when by is applied, show title for subplots to know which group it is + if self.by is not None: + ax.set_title(pprint_thing(label)) + + # error: Value of type "Polygon" is not indexable + self._append_legend_handles_labels(artists[0], label) # type: ignore[index,arg-type] + + def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None: + """merge BoxPlot/KdePlot properties to passed kwds""" + # y is required for KdePlot + kwds["bottom"] = self.bottom + kwds["bins"] = self.bins + + @final + @staticmethod + def _get_column_weights(weights, i: int, y): + # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array, + # and each sub-array (10,) will be called in each iteration. If users only + # provide 1D array, we assume the same weights is used for all iterations + if weights is not None: + if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1: + try: + weights = weights[:, i] + except IndexError as err: + raise ValueError( + "weights must have the same shape as data, " + "or be a single column" + ) from err + weights = weights[~isna(y)] + return weights + + def _post_plot_logic(self, ax: Axes, data) -> None: + if self.orientation == "horizontal": + # error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible + # type "Hashable"; expected "str" + ax.set_xlabel( + "Frequency" if self.xlabel is None else self.xlabel # type: ignore[arg-type] + ) + ax.set_ylabel(self.ylabel) # type: ignore[arg-type] + else: + ax.set_xlabel(self.xlabel) # type: ignore[arg-type] + ax.set_ylabel( + "Frequency" if self.ylabel is None else self.ylabel # type: ignore[arg-type] + ) + + @property + def orientation(self) -> PlottingOrientation: + if self.kwds.get("orientation", None) == "horizontal": + return "horizontal" + else: + return "vertical" + + +class KdePlot(HistPlot): + @property + def _kind(self) -> Literal["kde"]: + return "kde" + + @property + def orientation(self) -> Literal["vertical"]: + return "vertical" + + def __init__( + self, data, bw_method=None, ind=None, *, weights=None, **kwargs + ) -> None: + # Do not call LinePlot.__init__ which may fill nan + MPLPlot.__init__(self, data, **kwargs) + self.bw_method = bw_method + self.ind = ind + self.weights = weights + + @staticmethod + def _get_ind(y: np.ndarray, ind): + if ind is None: + # np.nanmax() and np.nanmin() ignores the missing values + sample_range = np.nanmax(y) - np.nanmin(y) + ind = np.linspace( + np.nanmin(y) - 0.5 * sample_range, + np.nanmax(y) + 0.5 * sample_range, + 1000, + ) + elif is_integer(ind): + sample_range = np.nanmax(y) - np.nanmin(y) + ind = np.linspace( + np.nanmin(y) - 0.5 * sample_range, + np.nanmax(y) + 0.5 * sample_range, + ind, + ) + return ind + + @classmethod + # error: Signature of "_plot" incompatible with supertype "MPLPlot" + def _plot( # type: ignore[override] + cls, + ax: Axes, + y: np.ndarray, + style=None, + bw_method=None, + weights=None, + ind=None, + column_num=None, + stacking_id: int | None = None, + **kwds, + ): + from scipy.stats import gaussian_kde + + y = remove_na_arraylike(y) + gkde = gaussian_kde(y, bw_method=bw_method, weights=weights) + + # gaussian_kde.evaluate(None) raises TypeError, so pyright requires this check + assert ind is not None + y = gkde.evaluate(ind) + lines = MPLPlot._plot(ax, ind, y, style=style, **kwds) + return lines + + def _make_plot_keywords(self, kwds: dict[str, Any], y: np.ndarray) -> None: + kwds["bw_method"] = self.bw_method + kwds["ind"] = type(self)._get_ind(y, ind=self.ind) + + def _post_plot_logic(self, ax: Axes, data) -> None: + ax.set_ylabel("Density") + + +def _grouped_plot( + plotf, + data: Series | DataFrame, + column=None, + by=None, + numeric_only: bool = True, + figsize: tuple[float, float] | None = None, + sharex: bool = True, + sharey: bool = True, + layout=None, + rot: float = 0, + ax=None, + **kwargs, +): + # error: Non-overlapping equality check (left operand type: "Optional[Tuple[float, + # float]]", right operand type: "Literal['default']") + if figsize == "default": # type: ignore[comparison-overlap] + # allowed to specify mpl default with 'default' + raise ValueError( + "figsize='default' is no longer supported. " + "Specify figure size by tuple instead" + ) + + grouped = data.groupby(by) + if column is not None: + grouped = grouped[column] + + naxes = len(grouped) + fig, axes = create_subplots( + naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout + ) + + for ax, (key, group) in zip(flatten_axes(axes), grouped, strict=False): + if numeric_only and isinstance(group, ABCDataFrame): + group = group._get_numeric_data() + plotf(group, ax, **kwargs) + ax.set_title(pprint_thing(key)) + + return fig, axes + + +def _grouped_hist( + data: Series | DataFrame, + column=None, + by=None, + ax=None, + bins: int = 50, + figsize: tuple[float, float] | None = None, + layout=None, + sharex: bool = False, + sharey: bool = False, + rot: float = 90, + grid: bool = True, + xlabelsize: int | None = None, + xrot=None, + ylabelsize: int | None = None, + yrot=None, + legend: bool = False, + **kwargs, +): + """ + Grouped histogram + + Parameters + ---------- + data : Series/DataFrame + column : object, optional + by : object, optional + ax : axes, optional + bins : int, default 50 + figsize : tuple, optional + layout : optional + sharex : bool, default False + sharey : bool, default False + rot : float, default 90 + grid : bool, default True + legend: : bool, default False + kwargs : dict, keyword arguments passed to matplotlib.Axes.hist + + Returns + ------- + collection of Matplotlib Axes + """ + if legend: + assert "label" not in kwargs + if data.ndim == 1: + kwargs["label"] = data.name + elif column is None: + kwargs["label"] = data.columns + else: + kwargs["label"] = column + + def plot_group(group, ax) -> None: + ax.hist(group.dropna().values, bins=bins, **kwargs) + if legend: + ax.legend() + + if xrot is None: + xrot = rot + + fig, axes = _grouped_plot( + plot_group, + data, + column=column, + by=by, + sharex=sharex, + sharey=sharey, + ax=ax, + figsize=figsize, + layout=layout, + rot=rot, + ) + + set_ticks_props( + axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot + ) + + maybe_adjust_figure( + fig, bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3 + ) + return axes + + +def hist_series( + self: Series, + by=None, + ax=None, + grid: bool = True, + xlabelsize: int | None = None, + xrot=None, + ylabelsize: int | None = None, + yrot=None, + figsize: tuple[float, float] | None = None, + bins: int = 10, + legend: bool = False, + **kwds, +): + import matplotlib.pyplot as plt + + if legend and "label" in kwds: + raise ValueError("Cannot use both legend and label") + + if by is None: + if kwds.get("layout", None) is not None: + raise ValueError("The 'layout' keyword is not supported when 'by' is None") + # hack until the plotting interface is a bit more unified + fig = kwds.pop( + "figure", plt.gcf() if plt.get_fignums() else plt.figure(figsize=figsize) + ) + if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()): + fig.set_size_inches(*figsize, forward=True) + if ax is None: + ax = fig.gca() + elif ax.get_figure() != fig: + raise AssertionError("passed axis not bound to passed figure") + values = self.dropna().values + if legend: + kwds["label"] = self.name + ax.hist(values, bins=bins, **kwds) + if legend: + ax.legend() + ax.grid(grid) + axes = np.array([ax]) + + set_ticks_props( + axes, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + ) + + else: + if "figure" in kwds: + raise ValueError( + "Cannot pass 'figure' when using the " + "'by' argument, since a new 'Figure' instance will be created" + ) + axes = _grouped_hist( + self, + by=by, + ax=ax, + grid=grid, + figsize=figsize, + bins=bins, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + legend=legend, + **kwds, + ) + + if hasattr(axes, "ndim"): + if axes.ndim == 1 and len(axes) == 1: + return axes[0] + return axes + + +def hist_frame( + data: DataFrame, + column=None, + by=None, + grid: bool = True, + xlabelsize: int | None = None, + xrot=None, + ylabelsize: int | None = None, + yrot=None, + ax=None, + sharex: bool = False, + sharey: bool = False, + figsize: tuple[float, float] | None = None, + layout=None, + bins: int = 10, + legend: bool = False, + **kwds, +): + if legend and "label" in kwds: + raise ValueError("Cannot use both legend and label") + if by is not None: + axes = _grouped_hist( + data, + column=column, + by=by, + ax=ax, + grid=grid, + figsize=figsize, + sharex=sharex, + sharey=sharey, + layout=layout, + bins=bins, + xlabelsize=xlabelsize, + xrot=xrot, + ylabelsize=ylabelsize, + yrot=yrot, + legend=legend, + **kwds, + ) + return axes + + if column is not None: + if not isinstance(column, (list, np.ndarray, ABCIndex)): + column = [column] + data = data[column] + # GH32590 + data = data.select_dtypes( + include=(np.number, "datetime64", "datetimetz"), exclude="timedelta" + ) + naxes = len(data.columns) + + if naxes == 0: + raise ValueError( + "hist method requires numerical or datetime columns, nothing to plot." + ) + + fig, axes = create_subplots( + naxes=naxes, + ax=ax, + squeeze=False, + sharex=sharex, + sharey=sharey, + figsize=figsize, + layout=layout, + ) + can_set_label = "label" not in kwds + + for ax, col in zip(flatten_axes(axes), data.columns, strict=False): + if legend and can_set_label: + kwds["label"] = col + ax.hist(data[col].dropna().values, bins=bins, **kwds) + ax.set_title(col) + ax.grid(grid) + if legend: + ax.legend() + + set_ticks_props( + axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot + ) + maybe_adjust_figure(fig, wspace=0.3, hspace=0.3) + + return axes diff --git a/pandas/plotting/_matplotlib/misc.py b/pandas/plotting/_matplotlib/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..271b8f1dc7dc9733fa47d6c132389ed5b66e5c24 --- /dev/null +++ b/pandas/plotting/_matplotlib/misc.py @@ -0,0 +1,480 @@ +from __future__ import annotations + +import random +from typing import TYPE_CHECKING + +import matplotlib as mpl +import numpy as np + +from pandas.core.dtypes.missing import notna + +from pandas.io.formats.printing import pprint_thing +from pandas.plotting._matplotlib.style import get_standard_colors +from pandas.plotting._matplotlib.tools import ( + create_subplots, + do_adjust_figure, + maybe_adjust_figure, + set_ticks_props, +) + +if TYPE_CHECKING: + from collections.abc import Hashable + + from matplotlib.axes import Axes + from matplotlib.figure import Figure + + from pandas import ( + DataFrame, + Index, + Series, + ) + + +def scatter_matrix( + frame: DataFrame, + alpha: float = 0.5, + figsize: tuple[float, float] | None = None, + ax=None, + grid: bool = False, + diagonal: str = "hist", + marker: str = ".", + density_kwds=None, + hist_kwds=None, + range_padding: float = 0.05, + **kwds, +): + df = frame._get_numeric_data() + n = df.columns.size + naxes = n * n + fig, axes = create_subplots(naxes=naxes, figsize=figsize, ax=ax, squeeze=False) + + # no gaps between subplots + maybe_adjust_figure(fig, wspace=0, hspace=0) + + mask = notna(df) + + marker = _get_marker_compat(marker) + + hist_kwds = hist_kwds or {} + density_kwds = density_kwds or {} + + # GH 14855 + kwds.setdefault("edgecolors", "none") + + boundaries_list = [] + for a in df.columns: + values = df[a].values[mask[a].values] + rmin_, rmax_ = np.min(values), np.max(values) + rdelta_ext = (rmax_ - rmin_) * range_padding / 2 + boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext)) + + for i, a in enumerate(df.columns): + for j, b in enumerate(df.columns): + ax = axes[i, j] + + if i == j: + values = df[a].values[mask[a].values] + + # Deal with the diagonal by drawing a histogram there. + if diagonal == "hist": + ax.hist(values, **hist_kwds) + + elif diagonal in ("kde", "density"): + from scipy.stats import gaussian_kde + + y = values + gkde = gaussian_kde(y) + ind = np.linspace(y.min(), y.max(), 1000) + ax.plot(ind, gkde.evaluate(ind), **density_kwds) + + ax.set_xlim(boundaries_list[i]) + + else: + common = (mask[a] & mask[b]).values + + ax.scatter( + df[b][common], df[a][common], marker=marker, alpha=alpha, **kwds + ) + + ax.set_xlim(boundaries_list[j]) + ax.set_ylim(boundaries_list[i]) + + ax.set_xlabel(b) + ax.set_ylabel(a) + + if j != 0: + ax.yaxis.set_visible(False) + if i != n - 1: + ax.xaxis.set_visible(False) + + if len(df.columns) > 1: + lim1 = boundaries_list[0] + locs = axes[0][1].yaxis.get_majorticklocs() + locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])] + adj = (locs - lim1[0]) / (lim1[1] - lim1[0]) + + lim0 = axes[0][0].get_ylim() + adj = adj * (lim0[1] - lim0[0]) + lim0[0] + axes[0][0].yaxis.set_ticks(adj) + + if np.all(locs == locs.astype(int)): + # if all ticks are int + locs = locs.astype(int) + axes[0][0].yaxis.set_ticklabels(locs) + + set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0) + + return axes + + +def _get_marker_compat(marker): + if marker not in mpl.lines.lineMarkers: + return "o" + return marker + + +def radviz( + frame: DataFrame, + class_column, + ax: Axes | None = None, + color=None, + colormap=None, + **kwds, +) -> Axes: + import matplotlib.pyplot as plt + + def normalize(series): + a = min(series) + b = max(series) + return (series - a) / (b - a) + + n = len(frame) + classes = frame[class_column].drop_duplicates() + class_col = frame[class_column] + df = frame.drop(class_column, axis=1).apply(normalize) + + if ax is None: + ax = plt.gca() + ax.set_xlim(-1, 1) + ax.set_ylim(-1, 1) + + to_plot: dict[Hashable, list[list]] = {} + colors = get_standard_colors( + num_colors=len(classes), colormap=colormap, color_type="random", color=color + ) + + for kls in classes: + to_plot[kls] = [[], []] + + m = len(frame.columns) - 1 + s = np.array( + [(np.cos(t), np.sin(t)) for t in [2 * np.pi * (i / m) for i in range(m)]] + ) + + for i in range(n): + row = df.iloc[i].values + row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1) + y = (s * row_).sum(axis=0) / row.sum() + kls = class_col.iat[i] + to_plot[kls][0].append(y[0]) + to_plot[kls][1].append(y[1]) + + for i, kls in enumerate(classes): + ax.scatter( + to_plot[kls][0], + to_plot[kls][1], + color=colors[i], + label=pprint_thing(kls), + **kwds, + ) + ax.legend() + + ax.add_patch(mpl.patches.Circle((0.0, 0.0), radius=1.0, facecolor="none")) + + for xy, name in zip(s, df.columns, strict=True): + ax.add_patch(mpl.patches.Circle(xy, radius=0.025, facecolor="gray")) + + if xy[0] < 0.0 and xy[1] < 0.0: + ax.text( + xy[0] - 0.025, xy[1] - 0.025, name, ha="right", va="top", size="small" + ) + elif xy[0] < 0.0 <= xy[1]: + ax.text( + xy[0] - 0.025, + xy[1] + 0.025, + name, + ha="right", + va="bottom", + size="small", + ) + elif xy[1] < 0.0 <= xy[0]: + ax.text( + xy[0] + 0.025, xy[1] - 0.025, name, ha="left", va="top", size="small" + ) + elif xy[0] >= 0.0 and xy[1] >= 0.0: + ax.text( + xy[0] + 0.025, xy[1] + 0.025, name, ha="left", va="bottom", size="small" + ) + + ax.axis("equal") + return ax + + +def andrews_curves( + frame: DataFrame, + class_column, + ax: Axes | None = None, + samples: int = 200, + color=None, + colormap=None, + **kwds, +) -> Axes: + import matplotlib.pyplot as plt + + def function(amplitudes): + def f(t): + x1 = amplitudes[0] + result = x1 / np.sqrt(2.0) + + # Take the rest of the coefficients and resize them + # appropriately. Take a copy of amplitudes as otherwise numpy + # deletes the element from amplitudes itself. + coeffs = np.delete(np.copy(amplitudes), 0) + coeffs = np.resize(coeffs, (int((coeffs.size + 1) / 2), 2)) + + # Generate the harmonics and arguments for the sin and cos + # functions. + harmonics = np.arange(0, coeffs.shape[0]) + 1 + trig_args = np.outer(harmonics, t) + + result += np.sum( + coeffs[:, 0, np.newaxis] * np.sin(trig_args) + + coeffs[:, 1, np.newaxis] * np.cos(trig_args), + axis=0, + ) + return result + + return f + + n = len(frame) + class_col = frame[class_column] + classes = frame[class_column].drop_duplicates() + df = frame.drop(class_column, axis=1) + t = np.linspace(-np.pi, np.pi, samples) + used_legends: set[str] = set() + + color_values = get_standard_colors( + num_colors=len(classes), colormap=colormap, color_type="random", color=color + ) + colors = dict(zip(classes, color_values, strict=False)) + if ax is None: + ax = plt.gca() + ax.set_xlim(-np.pi, np.pi) + for i in range(n): + row = df.iloc[i].values + f = function(row) + y = f(t) + kls = class_col.iat[i] + label = pprint_thing(kls) + if label not in used_legends: + used_legends.add(label) + ax.plot(t, y, color=colors[kls], label=label, **kwds) + else: + ax.plot(t, y, color=colors[kls], **kwds) + + ax.legend(loc="upper right") + ax.grid() + return ax + + +def bootstrap_plot( + series: Series, + fig: Figure | None = None, + size: int = 50, + samples: int = 500, + **kwds, +) -> Figure: + import matplotlib.pyplot as plt + + # TODO: is the failure mentioned below still relevant? + # random.sample(ndarray, int) fails on python 3.3, sigh + data = list(series.values) + samplings = [random.sample(data, size) for _ in range(samples)] + + means = np.array([np.mean(sampling) for sampling in samplings]) + medians = np.array([np.median(sampling) for sampling in samplings]) + midranges = np.array( + [(min(sampling) + max(sampling)) * 0.5 for sampling in samplings] + ) + if fig is None: + fig = plt.figure() + x = list(range(samples)) + axes = [] + ax1 = fig.add_subplot(2, 3, 1) + ax1.set_xlabel("Sample") + axes.append(ax1) + ax1.plot(x, means, **kwds) + ax2 = fig.add_subplot(2, 3, 2) + ax2.set_xlabel("Sample") + axes.append(ax2) + ax2.plot(x, medians, **kwds) + ax3 = fig.add_subplot(2, 3, 3) + ax3.set_xlabel("Sample") + axes.append(ax3) + ax3.plot(x, midranges, **kwds) + ax4 = fig.add_subplot(2, 3, 4) + ax4.set_xlabel("Mean") + axes.append(ax4) + ax4.hist(means, **kwds) + ax5 = fig.add_subplot(2, 3, 5) + ax5.set_xlabel("Median") + axes.append(ax5) + ax5.hist(medians, **kwds) + ax6 = fig.add_subplot(2, 3, 6) + ax6.set_xlabel("Midrange") + axes.append(ax6) + ax6.hist(midranges, **kwds) + for axis in axes: + plt.setp(axis.get_xticklabels(), fontsize=8) + plt.setp(axis.get_yticklabels(), fontsize=8) + if do_adjust_figure(fig): + plt.tight_layout() + return fig + + +def parallel_coordinates( + frame: DataFrame, + class_column, + cols=None, + ax: Axes | None = None, + color=None, + use_columns: bool = False, + xticks=None, + colormap=None, + axvlines: bool = True, + axvlines_kwds=None, + sort_labels: bool = False, + **kwds, +) -> Axes: + import matplotlib.pyplot as plt + + if axvlines_kwds is None: + axvlines_kwds = {"linewidth": 1, "color": "black"} + + n = len(frame) + classes = frame[class_column].drop_duplicates() + class_col = frame[class_column] + + if cols is None: + df = frame.drop(class_column, axis=1) + else: + df = frame[cols] + + used_legends: set[str] = set() + + ncols = len(df.columns) + + # determine values to use for xticks + x: list[int] | Index + if use_columns is True: + if not np.all(np.isreal(list(df.columns))): + raise ValueError("Columns must be numeric to be used as xticks") + x = df.columns + elif xticks is not None: + if not np.all(np.isreal(xticks)): + raise ValueError("xticks specified must be numeric") + if len(xticks) != ncols: + raise ValueError("Length of xticks must match number of columns") + x = xticks + else: + x = list(range(ncols)) + + if ax is None: + ax = plt.gca() + + color_values = get_standard_colors( + num_colors=len(classes), colormap=colormap, color_type="random", color=color + ) + + if sort_labels: + classes = sorted(classes) + color_values = sorted(color_values) + colors = dict(zip(classes, color_values, strict=True)) + + for i in range(n): + y = df.iloc[i].values + kls = class_col.iat[i] + label = pprint_thing(kls) + if label not in used_legends: + used_legends.add(label) + ax.plot(x, y, color=colors[kls], label=label, **kwds) + else: + ax.plot(x, y, color=colors[kls], **kwds) + + if axvlines: + for i in x: + ax.axvline(i, **axvlines_kwds) + + ax.set_xticks(x) + ax.set_xticklabels(df.columns) + ax.set_xlim(x[0], x[-1]) + ax.legend(loc="upper right") + ax.grid() + return ax + + +def lag_plot(series: Series, lag: int = 1, ax: Axes | None = None, **kwds) -> Axes: + # workaround because `c='b'` is hardcoded in matplotlib's scatter method + import matplotlib.pyplot as plt + + kwds.setdefault("c", plt.rcParams["patch.facecolor"]) + + data = series.values + y1 = data[:-lag] + y2 = data[lag:] + if ax is None: + ax = plt.gca() + ax.set_xlabel("y(t)") + ax.set_ylabel(f"y(t + {lag})") + ax.scatter(y1, y2, **kwds) + return ax + + +def autocorrelation_plot(series: Series, ax: Axes | None = None, **kwds) -> Axes: + import matplotlib.pyplot as plt + + n = len(series) + data = np.asarray(series) + if ax is None: + ax = plt.gca() + ax.set_xlim(1, n) + ax.set_ylim(-1.0, 1.0) + mean = np.mean(data) + c0 = np.sum((data - mean) ** 2) / n + + def r(h): + return ((data[: n - h] - mean) * (data[h:] - mean)).sum() / n / c0 + + x = np.arange(n) + 1 + y = [r(loc) for loc in x] + z95 = 1.959963984540054 + z99 = 2.5758293035489004 + ax.axhline(y=z99 / np.sqrt(n), linestyle="--", color="grey") + ax.axhline(y=z95 / np.sqrt(n), color="grey") + ax.axhline(y=0.0, color="black") + ax.axhline(y=-z95 / np.sqrt(n), color="grey") + ax.axhline(y=-z99 / np.sqrt(n), linestyle="--", color="grey") + ax.set_xlabel("Lag") + ax.set_ylabel("Autocorrelation") + ax.plot(x, y, **kwds) + if "label" in kwds: + ax.legend() + ax.grid() + return ax + + +def unpack_single_str_list(keys): + # GH 42795 + if isinstance(keys, list) and len(keys) == 1: + keys = keys[0] + return keys diff --git a/pandas/plotting/_matplotlib/style.py b/pandas/plotting/_matplotlib/style.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf63c86213924927524c7018bf6dad87f2de636 --- /dev/null +++ b/pandas/plotting/_matplotlib/style.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +from collections.abc import ( + Collection, + Iterator, + Sequence, +) +import itertools +from typing import ( + TYPE_CHECKING, + cast, + overload, +) +import warnings + +import matplotlib as mpl +import matplotlib.colors +import numpy as np + +from pandas._typing import MatplotlibColor as Color +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.common import is_list_like + +if TYPE_CHECKING: + from matplotlib.colors import Colormap + + +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: dict[str, Color], +) -> dict[str, Color]: ... + + +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: Color | Sequence[Color] | None = ..., +) -> list[Color]: ... + + +@overload +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = ..., + color_type: str = ..., + *, + color: dict[str, Color] | Color | Sequence[Color] | None = ..., +) -> dict[str, Color] | list[Color]: ... + + +def get_standard_colors( + num_colors: int, + colormap: Colormap | None = None, + color_type: str = "default", + *, + color: dict[str, Color] | Color | Sequence[Color] | None = None, +) -> dict[str, Color] | list[Color]: + """ + Get standard colors based on `colormap`, `color_type` or `color` inputs. + + Parameters + ---------- + num_colors : int + Minimum number of colors to be returned. + Ignored if `color` is a dictionary. + colormap : :py:class:`matplotlib.colors.Colormap`, optional + Matplotlib colormap. + When provided, the resulting colors will be derived from the colormap. + color_type : {"default", "random"}, optional + Type of colors to derive. Used if provided `color` and `colormap` are None. + Ignored if either `color` or `colormap` are not None. + color : dict or str or sequence, optional + Color(s) to be used for deriving sequence of colors. + Can be either be a dictionary, or a single color (single color string, + or sequence of floats representing a single color), + or a sequence of colors. + + Returns + ------- + dict or list + Standard colors. Can either be a mapping if `color` was a dictionary, + or a list of colors with a length of `num_colors` or more. + + Warns + ----- + UserWarning + If both `colormap` and `color` are provided. + Parameter `color` will override. + """ + if isinstance(color, dict): + return color + + colors = _derive_colors( + color=color, + colormap=colormap, + color_type=color_type, + num_colors=num_colors, + ) + + return list(_cycle_colors(colors, num_colors=num_colors)) + + +def _derive_colors( + *, + color: Color | Collection[Color] | None, + colormap: str | Colormap | None, + color_type: str, + num_colors: int, +) -> list[Color]: + """ + Derive colors from either `colormap`, `color_type` or `color` inputs. + + Get a list of colors either from `colormap`, or from `color`, + or from `color_type` (if both `colormap` and `color` are None). + + Parameters + ---------- + color : str or sequence, optional + Color(s) to be used for deriving sequence of colors. + Can be either be a single color (single color string, or sequence of floats + representing a single color), or a sequence of colors. + colormap : :py:class:`matplotlib.colors.Colormap`, optional + Matplotlib colormap. + When provided, the resulting colors will be derived from the colormap. + color_type : {"default", "random"}, optional + Type of colors to derive. Used if provided `color` and `colormap` are None. + Ignored if either `color` or `colormap`` are not None. + num_colors : int + Number of colors to be extracted. + + Returns + ------- + list + List of colors extracted. + + Warns + ----- + UserWarning + If both `colormap` and `color` are provided. + Parameter `color` will override. + """ + if color is None and colormap is not None: + return _get_colors_from_colormap(colormap, num_colors=num_colors) + elif color is not None: + if colormap is not None: + warnings.warn( + "'color' and 'colormap' cannot be used simultaneously. Using 'color'", + stacklevel=find_stack_level(), + ) + return _get_colors_from_color(color) + else: + return _get_colors_from_color_type(color_type, num_colors=num_colors) + + +def _cycle_colors(colors: list[Color], num_colors: int) -> Iterator[Color]: + """Cycle colors until achieving max of `num_colors` or length of `colors`. + + Extra colors will be ignored by matplotlib if there are more colors + than needed and nothing needs to be done here. + """ + max_colors = max(num_colors, len(colors)) + yield from itertools.islice(itertools.cycle(colors), max_colors) + + +def _get_colors_from_colormap( + colormap: str | Colormap, + num_colors: int, +) -> list[Color]: + """Get colors from colormap.""" + cmap = _get_cmap_instance(colormap) + return [cmap(num) for num in np.linspace(0, 1, num=num_colors)] + + +def _get_cmap_instance(colormap: str | Colormap) -> Colormap: + """Get instance of matplotlib colormap.""" + if isinstance(colormap, str): + cmap = colormap + colormap = mpl.colormaps[colormap] + if colormap is None: + raise ValueError(f"Colormap {cmap} is not recognized") + return colormap + + +def _get_colors_from_color( + color: Color | Collection[Color], +) -> list[Color]: + """Get colors from user input color.""" + if len(color) == 0: + raise ValueError(f"Invalid color argument: {color}") + + if _is_single_color(color): + color = cast(Color, color) + return [color] + + color = cast(Collection[Color], color) + return list(_gen_list_of_colors_from_iterable(color)) + + +def _is_single_color(color: Color | Collection[Color]) -> bool: + """Check if `color` is a single color, not a sequence of colors. + + Single color is of these kinds: + - Named color "red", "C0", "firebrick" + - Alias "g" + - Sequence of floats, such as (0.1, 0.2, 0.3) or (0.1, 0.2, 0.3, 0.4). + + See Also + -------- + _is_single_string_color + """ + if isinstance(color, str) and _is_single_string_color(color): + # GH #36972 + return True + + if _is_floats_color(color): + return True + + return False + + +def _gen_list_of_colors_from_iterable(color: Collection[Color]) -> Iterator[Color]: + """ + Yield colors from string of several letters or from collection of colors. + """ + for x in color: + if _is_single_color(x): + yield x + else: + raise ValueError(f"Invalid color {x}") + + +def _is_floats_color(color: Color | Collection[Color]) -> bool: + """Check if color comprises a sequence of floats representing color.""" + return bool( + is_list_like(color) + and (len(color) == 3 or len(color) == 4) + and all(isinstance(x, (int, float)) for x in color) + ) + + +def _get_colors_from_color_type(color_type: str, num_colors: int) -> list[Color]: + """Get colors from user input color type.""" + if color_type == "default": + prop_cycle = mpl.rcParams["axes.prop_cycle"] + return [ + c["color"] + for c in itertools.islice(prop_cycle, min(num_colors, len(prop_cycle))) + ] + elif color_type == "random": + return np.random.default_rng(num_colors).random((num_colors, 3)).tolist() + else: + raise ValueError("color_type must be either 'default' or 'random'") + + +def _is_single_string_color(color: Color) -> bool: + """Check if `color` is a single string color. + + Examples of single string colors: + - 'r' + - 'g' + - 'red' + - 'green' + - 'C3' + - 'firebrick' + + Parameters + ---------- + color : Color + Color string or sequence of floats. + + Returns + ------- + bool + True if `color` looks like a valid color. + False otherwise. + """ + conv = matplotlib.colors.ColorConverter() + try: + # error: Argument 1 to "to_rgba" of "ColorConverter" has incompatible type + # "str | Sequence[float]"; expected "tuple[float, float, float] | ..." + conv.to_rgba(color) # type: ignore[arg-type] + except ValueError: + return False + else: + return True diff --git a/pandas/plotting/_matplotlib/timeseries.py b/pandas/plotting/_matplotlib/timeseries.py new file mode 100644 index 0000000000000000000000000000000000000000..5023867445adb844f0d5c3e28183bbf2474e027a --- /dev/null +++ b/pandas/plotting/_matplotlib/timeseries.py @@ -0,0 +1,364 @@ +# TODO: Use the fact that axis can have units to simplify the process + +from __future__ import annotations + +import functools +from typing import ( + TYPE_CHECKING, + Any, +) +import warnings + +from pandas._libs.tslibs import ( + BaseOffset, + Period, + to_offset, +) +from pandas._libs.tslibs.dtypes import ( + OFFSET_TO_PERIOD_FREQSTR, + FreqGroup, +) + +from pandas.core.dtypes.generic import ( + ABCDatetimeIndex, + ABCPeriodIndex, + ABCTimedeltaIndex, +) + +from pandas.io.formats.printing import pprint_thing +from pandas.plotting._matplotlib.converter import ( + TimeSeries_DateFormatter, + TimeSeries_DateLocator, + TimeSeries_TimedeltaFormatter, +) +from pandas.tseries.frequencies import ( + get_period_alias, + is_subperiod, + is_superperiod, +) + +if TYPE_CHECKING: + from datetime import timedelta + + from matplotlib.axes import Axes + + from pandas._typing import NDFrameT + + from pandas import ( + DatetimeIndex, + Index, + PeriodIndex, + Series, + ) + +# --------------------------------------------------------------------- +# Plotting functions and monkey patches + + +def maybe_resample(series: Series, ax: Axes, kwargs: dict[str, Any]): + # resample against axes freq if necessary + + if "how" in kwargs: + raise ValueError( + "'how' is not a valid keyword for plotting functions. If plotting " + "multiple objects on shared axes, resample manually first." + ) + + freq, ax_freq = _get_freq(ax, series) + + if freq is None: # pragma: no cover + raise ValueError("Cannot use dynamic axis without frequency info") + + # Convert DatetimeIndex to PeriodIndex + if isinstance(series.index, ABCDatetimeIndex): + series = series.to_period(freq=freq) + + if ax_freq is not None and freq != ax_freq: + if is_superperiod(freq, ax_freq): # upsample input + series = series.copy(deep=False) + # error: "Index" has no attribute "asfreq" + series.index = series.index.asfreq( # type: ignore[attr-defined] + ax_freq, how="s" + ) + freq = ax_freq + elif _is_sup(freq, ax_freq): # one is weekly + how = "last" + series = getattr(series.resample("D"), how)().dropna() + series = getattr(series.resample(ax_freq), how)().dropna() + freq = ax_freq + elif is_subperiod(freq, ax_freq) or _is_sub(freq, ax_freq): + _upsample_others(ax, freq, kwargs) + else: # pragma: no cover + raise ValueError("Incompatible frequency conversion") + return freq, series + + +def _is_sub(f1: str, f2: str) -> bool: + return (f1.startswith("W") and is_subperiod("D", f2)) or ( + f2.startswith("W") and is_subperiod(f1, "D") + ) + + +def _is_sup(f1: str, f2: str) -> bool: + return (f1.startswith("W") and is_superperiod("D", f2)) or ( + f2.startswith("W") and is_superperiod(f1, "D") + ) + + +def _upsample_others(ax: Axes, freq: BaseOffset, kwargs: dict[str, Any]) -> None: + legend = ax.get_legend() + lines, labels = _replot_ax(ax, freq) + _replot_ax(ax, freq) + + other_ax = None + if hasattr(ax, "left_ax"): + other_ax = ax.left_ax + if hasattr(ax, "right_ax"): + other_ax = ax.right_ax + + if other_ax is not None: + rlines, rlabels = _replot_ax(other_ax, freq) + lines.extend(rlines) + labels.extend(rlabels) + + if legend is not None and kwargs.get("legend", True) and len(lines) > 0: + title: str | None = legend.get_title().get_text() + if title == "None": + title = None + ax.legend(lines, labels, loc="best", title=title) + + +def _replot_ax(ax: Axes, freq: BaseOffset): + data = getattr(ax, "_plot_data", None) + + # clear current axes and data + # TODO #54485 + ax._plot_data = [] # type: ignore[attr-defined] + ax.clear() + + decorate_axes(ax, freq) + + lines = [] + labels = [] + if data is not None: + for series, plotf, kwds in data: + series = series.copy(deep=False) + idx = series.index.asfreq(freq, how="S") + series.index = idx + # TODO #54485 + ax._plot_data.append((series, plotf, kwds)) # type: ignore[attr-defined] + + # for tsplot + if isinstance(plotf, str): + from pandas.plotting._matplotlib import PLOT_CLASSES + + plotf = PLOT_CLASSES[plotf]._plot + + lines.append(plotf(ax, series.index._mpl_repr(), series.values, **kwds)[0]) + labels.append(pprint_thing(series.name)) + + return lines, labels + + +def decorate_axes(ax: Axes, freq: BaseOffset) -> None: + """Initialize axes for time-series plotting""" + if not hasattr(ax, "_plot_data"): + # TODO #54485 + ax._plot_data = [] # type: ignore[attr-defined] + + # TODO #54485 + ax.freq = freq # type: ignore[attr-defined] + xaxis = ax.get_xaxis() + # TODO #54485 + xaxis.freq = freq # type: ignore[attr-defined] + + +def _get_ax_freq(ax: Axes): + """ + Get the freq attribute of the ax object if set. + Also checks shared axes (eg when using secondary yaxis, sharex=True + or twinx) + """ + ax_freq = getattr(ax, "freq", None) + if ax_freq is None: + # check for left/right ax in case of secondary yaxis + if hasattr(ax, "left_ax"): + ax_freq = getattr(ax.left_ax, "freq", None) + elif hasattr(ax, "right_ax"): + ax_freq = getattr(ax.right_ax, "freq", None) + if ax_freq is None: + # check if a shared ax (sharex/twinx) has already freq set + shared_axes = ax.get_shared_x_axes().get_siblings(ax) + if len(shared_axes) > 1: + for shared_ax in shared_axes: + ax_freq = getattr(shared_ax, "freq", None) + if ax_freq is not None: + break + return ax_freq + + +def _get_period_alias(freq: timedelta | BaseOffset | str) -> str | None: + if isinstance(freq, BaseOffset): + freqstr = freq.name + else: + freqstr = to_offset(freq, is_period=True).rule_code + + return get_period_alias(freqstr) + + +def _get_freq(ax: Axes, series: Series): + # get frequency from data + freq = getattr(series.index, "freq", None) + if freq is None: + freq = getattr(series.index, "inferred_freq", None) + freq = to_offset(freq, is_period=True) + + ax_freq = _get_ax_freq(ax) + + # use axes freq if no data freq + if freq is None: + freq = ax_freq + + # get the period frequency + freq = _get_period_alias(freq) + return freq, ax_freq + + +def use_dynamic_x(ax: Axes, index: Index) -> bool: + freq = _get_index_freq(index) + ax_freq = _get_ax_freq(ax) + + if freq is None: # convert irregular if axes has freq info + freq = ax_freq + # do not use tsplot if irregular was plotted first + elif (ax_freq is None) and (len(ax.get_lines()) > 0): + return False + + if freq is None: + return False + + freq_str = _get_period_alias(freq) + + if freq_str is None: + return False + + # FIXME: hack this for 0.10.1, creating more technical debt...sigh + if isinstance(index, ABCDatetimeIndex): + # error: "BaseOffset" has no attribute "_period_dtype_code" + freq_str = OFFSET_TO_PERIOD_FREQSTR.get(freq_str, freq_str) + base = to_offset(freq_str, is_period=True)._period_dtype_code # type: ignore[attr-defined] + if base <= FreqGroup.FR_DAY.value: + return index[:1].is_normalized + period = Period(index[0], freq_str) + assert isinstance(period, Period) + return period.to_timestamp().tz_localize(index.tz) == index[0] + return True + + +def _get_index_freq(index: Index) -> BaseOffset | None: + freq = getattr(index, "freq", None) + if freq is None: + freq = getattr(index, "inferred_freq", None) + freq = to_offset(freq) + return freq + + +def maybe_convert_index(ax: Axes, data: NDFrameT) -> NDFrameT: + # tsplot converts automatically, but don't want to convert index + # over and over for DataFrames + if isinstance(data.index, (ABCDatetimeIndex, ABCPeriodIndex)): + freq = _get_index_freq(data.index) + + if freq is None: + freq = _get_ax_freq(ax) + + if freq is None: + raise ValueError("Could not get frequency alias for plotting") + + freq_str = _get_period_alias(freq) + + with warnings.catch_warnings(): + # suppress Period[B] deprecation warning + # TODO: need to find an alternative to this before the deprecation + # is enforced! + warnings.filterwarnings( + "ignore", + r"PeriodDtype\[B\] is deprecated", + category=FutureWarning, + ) + + if isinstance(data.index, ABCDatetimeIndex): + data = data.tz_localize(None).to_period(freq=freq_str) + elif isinstance(data.index, ABCPeriodIndex): + data.index = data.index.asfreq(freq=freq_str, how="start") + return data + + +# Patch methods for subplot. + + +def _format_coord(freq: BaseOffset, t, y) -> str: + time_period = Period(ordinal=int(t), freq=freq) + return f"t = {time_period} y = {y:8f}" + + +def format_dateaxis( + subplot, freq: BaseOffset, index: DatetimeIndex | PeriodIndex +) -> None: + """ + Pretty-formats the date axis (x-axis). + + Major and minor ticks are automatically set for the frequency of the + current underlying series. As the dynamic mode is activated by + default, changing the limits of the x axis will intelligently change + the positions of the ticks. + """ + import matplotlib.pyplot as plt + + # handle index specific formatting + # Note: DatetimeIndex does not use this + # interface. DatetimeIndex uses matplotlib.date directly + if isinstance(index, ABCPeriodIndex): + majlocator = TimeSeries_DateLocator( + freq, dynamic_mode=True, minor_locator=False, plot_obj=subplot + ) + minlocator = TimeSeries_DateLocator( + freq, dynamic_mode=True, minor_locator=True, plot_obj=subplot + ) + subplot.xaxis.set_major_locator(majlocator) + subplot.xaxis.set_minor_locator(minlocator) + + majformatter = TimeSeries_DateFormatter( + freq, dynamic_mode=True, minor_locator=False, plot_obj=subplot + ) + minformatter = TimeSeries_DateFormatter( + freq, dynamic_mode=True, minor_locator=True, plot_obj=subplot + ) + subplot.xaxis.set_major_formatter(majformatter) + subplot.xaxis.set_minor_formatter(minformatter) + + # x and y coord info + subplot.format_coord = functools.partial(_format_coord, freq) + + elif isinstance(index, ABCTimedeltaIndex): + subplot.xaxis.set_major_formatter(TimeSeries_TimedeltaFormatter(index.unit)) + else: + raise TypeError("index type not supported") + + plt.draw_if_interactive() + + +def prepare_ts_data( + series: Series, ax: Axes, kwargs: dict[str, Any] +) -> tuple[BaseOffset | str, Series]: + freq, data = maybe_resample(series, ax, kwargs) + + # Set ax with freq info + decorate_axes(ax, freq) + # digging deeper + if hasattr(ax, "left_ax"): + decorate_axes(ax.left_ax, freq) + if hasattr(ax, "right_ax"): + decorate_axes(ax.right_ax, freq) + + return freq, data diff --git a/pandas/plotting/_matplotlib/tools.py b/pandas/plotting/_matplotlib/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..8ee75e7fe553ee7ddae5418da9d32dd768857a78 --- /dev/null +++ b/pandas/plotting/_matplotlib/tools.py @@ -0,0 +1,491 @@ +# being a bit too dynamic +from __future__ import annotations + +from math import ceil +from typing import TYPE_CHECKING +import warnings + +import matplotlib as mpl +import numpy as np + +from pandas.util._exceptions import find_stack_level + +from pandas.core.dtypes.common import is_list_like +from pandas.core.dtypes.generic import ( + ABCDataFrame, + ABCIndex, + ABCSeries, +) + +if TYPE_CHECKING: + from collections.abc import ( + Generator, + Iterable, + ) + + from matplotlib.axes import Axes + from matplotlib.axis import Axis + from matplotlib.figure import Figure + from matplotlib.lines import Line2D + from matplotlib.table import Table + + from pandas import ( + DataFrame, + Series, + ) + + +def do_adjust_figure(fig: Figure) -> bool: + """Whether fig has constrained_layout enabled.""" + if not hasattr(fig, "get_constrained_layout"): + return False + return not fig.get_constrained_layout() + + +def maybe_adjust_figure(fig: Figure, *args, **kwargs) -> None: + """Call fig.subplots_adjust unless fig has constrained_layout enabled.""" + if do_adjust_figure(fig): + fig.subplots_adjust(*args, **kwargs) + + +def format_date_labels(ax: Axes, rot) -> None: + # mini version of autofmt_xdate + for label in ax.get_xticklabels(): + label.set_horizontalalignment("right") + label.set_rotation(rot) + fig = ax.get_figure() + if fig is not None: + # should always be a Figure but can technically be None + maybe_adjust_figure(fig, bottom=0.2) # type: ignore[arg-type] + + +def table( + ax, data: DataFrame | Series, rowLabels=None, colLabels=None, **kwargs +) -> Table: + if isinstance(data, ABCSeries): + data = data.to_frame() + elif isinstance(data, ABCDataFrame): + pass + else: + raise ValueError("Input data must be DataFrame or Series") + + if rowLabels is None: + rowLabels = data.index + + if colLabels is None: + colLabels = data.columns + + cellText = data.values + + # error: Argument "cellText" to "table" has incompatible type "ndarray[Any, + # Any]"; expected "Sequence[Sequence[str]] | None" + return mpl.table.table( + ax, + cellText=cellText, # type: ignore[arg-type] + rowLabels=rowLabels, + colLabels=colLabels, + **kwargs, + ) + + +def _get_layout( + nplots: int, + layout: tuple[int, int] | None = None, + layout_type: str = "box", +) -> tuple[int, int]: + if layout is not None: + if not isinstance(layout, (tuple, list)) or len(layout) != 2: + raise ValueError("Layout must be a tuple of (rows, columns)") + + nrows, ncols = layout + + if nrows == -1 and ncols > 0: + layout = (ceil(nplots / ncols), ncols) + elif ncols == -1 and nrows > 0: + layout = (nrows, ceil(nplots / nrows)) + elif ncols <= 0 and nrows <= 0: + msg = "At least one dimension of layout must be positive" + raise ValueError(msg) + + nrows, ncols = layout + if nrows * ncols < nplots: + raise ValueError( + f"Layout of {nrows}x{ncols} must be larger than required size {nplots}" + ) + + return layout + + if layout_type == "single": + return (1, 1) + elif layout_type == "horizontal": + return (1, nplots) + elif layout_type == "vertical": + return (nplots, 1) + + layouts = {1: (1, 1), 2: (1, 2), 3: (2, 2), 4: (2, 2)} + try: + return layouts[nplots] + except KeyError: + k = 1 + while k**2 < nplots: + k += 1 + + if (k - 1) * k >= nplots: + return k, (k - 1) + else: + return k, k + + +# copied from matplotlib/pyplot.py and modified for pandas.plotting + + +def create_subplots( + naxes: int, + sharex: bool = False, + sharey: bool = False, + squeeze: bool = True, + subplot_kw=None, + ax=None, + layout=None, + layout_type: str = "box", + **fig_kw, +): + """ + Create a figure with a set of subplots already made. + + This utility wrapper makes it convenient to create common layouts of + subplots, including the enclosing figure object, in a single call. + + Parameters + ---------- + naxes : int + Number of required axes. Exceeded axes are set invisible. Default is + nrows * ncols. + + sharex : bool + If True, the X axis will be shared amongst all subplots. + + sharey : bool + If True, the Y axis will be shared amongst all subplots. + + squeeze : bool + + If True, extra dimensions are squeezed out from the returned axis object: + - if only one subplot is constructed (nrows=ncols=1), the resulting + single Axis object is returned as a scalar. + - for Nx1 or 1xN subplots, the returned object is a 1-d numpy object + array of Axis objects are returned as numpy 1-d arrays. + - for NxM subplots with N>1 and M>1 are returned as a 2d array. + + If False, no squeezing is done: the returned axis object is always + a 2-d array containing Axis instances, even if it ends up being 1x1. + + subplot_kw : dict + Dict with keywords passed to the add_subplot() call used to create each + subplots. + + ax : Matplotlib axis object, optional + + layout : tuple + Number of rows and columns of the subplot grid. + If not specified, calculated from naxes and layout_type + + layout_type : {'box', 'horizontal', 'vertical'}, default 'box' + Specify how to layout the subplot grid. + + fig_kw : Other keyword arguments to be passed to the figure() call. + Note that all keywords not recognized above will be + automatically included here. + + Returns + ------- + fig, ax : tuple + - fig is the Matplotlib Figure object + - ax can be either a single axis object or an array of axis objects if + more than one subplot was created. The dimensions of the resulting array + can be controlled with the squeeze keyword, see above. + + Examples + -------- + x = np.linspace(0, 2*np.pi, 400) + y = np.sin(x**2) + + # Just a figure and one subplot + f, ax = plt.subplots() + ax.plot(x, y) + ax.set_title('Simple plot') + + # Two subplots, unpack the output array immediately + f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) + ax1.plot(x, y) + ax1.set_title('Sharing Y axis') + ax2.scatter(x, y) + + # Four polar axes + plt.subplots(2, 2, subplot_kw=dict(polar=True)) + """ + import matplotlib.pyplot as plt + + if subplot_kw is None: + subplot_kw = {} + + if ax is None: + fig = plt.figure(**fig_kw) + else: + if is_list_like(ax): + if squeeze: + ax = np.fromiter(flatten_axes(ax), dtype=object) + if layout is not None: + warnings.warn( + "When passing multiple axes, layout keyword is ignored.", + UserWarning, + stacklevel=find_stack_level(), + ) + if sharex or sharey: + warnings.warn( + "When passing multiple axes, sharex and sharey " + "are ignored. These settings must be specified when creating axes.", + UserWarning, + stacklevel=find_stack_level(), + ) + if ax.size == naxes: + fig = ax.flat[0].get_figure() + return fig, ax + else: + raise ValueError( + f"The number of passed axes must be {naxes}, the " + "same as the output plot" + ) + + fig = ax.get_figure() + # if ax is passed and a number of subplots is 1, return ax as it is + if naxes == 1: + if squeeze: + return fig, ax + else: + return fig, np.fromiter(flatten_axes(ax), dtype=object) + else: + warnings.warn( + "To output multiple subplots, the figure containing " + "the passed axes is being cleared.", + UserWarning, + stacklevel=find_stack_level(), + ) + fig.clear() + + nrows, ncols = _get_layout(naxes, layout=layout, layout_type=layout_type) + nplots = nrows * ncols + + # Create empty object array to hold all axes. It's easiest to make it 1-d + # so we can just append subplots upon creation, and then + axarr = np.empty(nplots, dtype=object) + + # Create first subplot separately, so we can share it if requested + ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw) + + if sharex: + subplot_kw["sharex"] = ax0 + if sharey: + subplot_kw["sharey"] = ax0 + axarr[0] = ax0 + + # Note off-by-one counting because add_subplot uses the MATLAB 1-based + # convention. + for i in range(1, nplots): + kwds = subplot_kw.copy() + # Set sharex and sharey to None for blank/dummy axes, these can + # interfere with proper axis limits on the visible axes if + # they share axes e.g. issue #7528 + if i >= naxes: + kwds["sharex"] = None + kwds["sharey"] = None + ax = fig.add_subplot(nrows, ncols, i + 1, **kwds) + axarr[i] = ax + + if naxes != nplots: + for ax in axarr[naxes:]: + ax.set_visible(False) + + handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey) + + if squeeze: + # Reshape the array to have the final desired dimension (nrow,ncol), + # though discarding unneeded dimensions that equal 1. If we only have + # one subplot, just return it instead of a 1-element array. + if nplots == 1: + axes = axarr[0] + else: + axes = axarr.reshape(nrows, ncols).squeeze() + else: + # returned axis array will be always 2-d, even if nrows=ncols=1 + axes = axarr.reshape(nrows, ncols) + + return fig, axes + + +def _remove_labels_from_axis(axis: Axis) -> None: + for t in axis.get_majorticklabels(): + t.set_visible(False) + + # set_visible will not be effective if + # minor axis has NullLocator and NullFormatter (default) + if isinstance(axis.get_minor_locator(), mpl.ticker.NullLocator): + axis.set_minor_locator(mpl.ticker.AutoLocator()) + if isinstance(axis.get_minor_formatter(), mpl.ticker.NullFormatter): + axis.set_minor_formatter(mpl.ticker.FormatStrFormatter("")) + for t in axis.get_minorticklabels(): + t.set_visible(False) + + axis.get_label().set_visible(False) + + +def _has_externally_shared_axis(ax1: Axes, compare_axis: str) -> bool: + """ + Return whether an axis is externally shared. + + Parameters + ---------- + ax1 : matplotlib.axes.Axes + Axis to query. + compare_axis : str + `"x"` or `"y"` according to whether the X-axis or Y-axis is being + compared. + + Returns + ------- + bool + `True` if the axis is externally shared. Otherwise `False`. + + Notes + ----- + If two axes with different positions are sharing an axis, they can be + referred to as *externally* sharing the common axis. + + If two axes sharing an axis also have the same position, they can be + referred to as *internally* sharing the common axis (a.k.a twinning). + + _handle_shared_axes() is only interested in axes externally sharing an + axis, regardless of whether either of the axes is also internally sharing + with a third axis. + """ + if compare_axis == "x": + axes = ax1.get_shared_x_axes() + elif compare_axis == "y": + axes = ax1.get_shared_y_axes() + else: + raise ValueError( + "_has_externally_shared_axis() needs 'x' or 'y' as a second parameter" + ) + + axes_siblings = axes.get_siblings(ax1) + + # Retain ax1 and any of its siblings which aren't in the same position as it + ax1_points = ax1.get_position().get_points() + + for ax2 in axes_siblings: + if not np.array_equal(ax1_points, ax2.get_position().get_points()): + return True + + return False + + +def handle_shared_axes( + axarr: Iterable[Axes], + nplots: int, + naxes: int, + nrows: int, + ncols: int, + sharex: bool, + sharey: bool, +) -> None: + if nplots > 1: + row_num = lambda x: x.get_subplotspec().rowspan.start + col_num = lambda x: x.get_subplotspec().colspan.start + + is_first_col = lambda x: x.get_subplotspec().is_first_col() + + if nrows > 1: + try: + # first find out the ax layout, + # so that we can correctly handle 'gaps" + layout = np.zeros((nrows + 1, ncols + 1), dtype=np.bool_) + for ax in axarr: + layout[row_num(ax), col_num(ax)] = ax.get_visible() + + for ax in axarr: + # only the last row of subplots should get x labels -> all + # other off layout handles the case that the subplot is + # the last in the column, because below is no subplot/gap. + if not layout[row_num(ax) + 1, col_num(ax)]: + continue + if sharex or _has_externally_shared_axis(ax, "x"): + _remove_labels_from_axis(ax.xaxis) + + except IndexError: + # if gridspec is used, ax.rowNum and ax.colNum may different + # from layout shape. in this case, use last_row logic + is_last_row = lambda x: x.get_subplotspec().is_last_row() + for ax in axarr: + if is_last_row(ax): + continue + if sharex or _has_externally_shared_axis(ax, "x"): + _remove_labels_from_axis(ax.xaxis) + + if ncols > 1: + for ax in axarr: + # only the first column should get y labels -> set all other to + # off as we only have labels in the first column and we always + # have a subplot there, we can skip the layout test + if is_first_col(ax): + continue + if sharey or _has_externally_shared_axis(ax, "y"): + _remove_labels_from_axis(ax.yaxis) + + +def flatten_axes(axes: Axes | Iterable[Axes]) -> Generator[Axes]: + if not is_list_like(axes): + yield axes # type: ignore[misc] + elif isinstance(axes, (np.ndarray, ABCIndex)): + yield from np.asarray(axes).reshape(-1) + else: + yield from axes # type: ignore[misc] + + +def set_ticks_props( + axes: Axes | Iterable[Axes], + xlabelsize: int | None = None, + xrot=None, + ylabelsize: int | None = None, + yrot=None, +): + for ax in flatten_axes(axes): + if xlabelsize is not None: + mpl.artist.setp(ax.get_xticklabels(), fontsize=xlabelsize) # type: ignore[arg-type] + if xrot is not None: + mpl.artist.setp(ax.get_xticklabels(), rotation=xrot) # type: ignore[arg-type] + if ylabelsize is not None: + mpl.artist.setp(ax.get_yticklabels(), fontsize=ylabelsize) # type: ignore[arg-type] + if yrot is not None: + mpl.artist.setp(ax.get_yticklabels(), rotation=yrot) # type: ignore[arg-type] + return axes + + +def get_all_lines(ax: Axes) -> list[Line2D]: + lines = ax.get_lines() + + if hasattr(ax, "right_ax"): + lines += ax.right_ax.get_lines() + + if hasattr(ax, "left_ax"): + lines += ax.left_ax.get_lines() + + return lines + + +def get_xlim(lines: Iterable[Line2D]) -> tuple[float, float]: + left, right = np.inf, -np.inf + for line in lines: + x = line.get_xdata(orig=False) + left = min(np.nanmin(x), left) + right = max(np.nanmax(x), right) + return left, right diff --git a/pandas/plotting/_misc.py b/pandas/plotting/_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..4c445c570ae33ddf1022eb953319bf8badd84ad9 --- /dev/null +++ b/pandas/plotting/_misc.py @@ -0,0 +1,780 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import ( + TYPE_CHECKING, + Any, +) + +from pandas.util._decorators import set_module + +from pandas.plotting._core import _get_plot_backend + +if TYPE_CHECKING: + from collections.abc import ( + Generator, + Mapping, + ) + + from matplotlib.axes import Axes + from matplotlib.colors import Colormap + from matplotlib.figure import Figure + from matplotlib.table import Table + import numpy as np + + from pandas import ( + DataFrame, + Series, + ) + + +@set_module("pandas.plotting") +def table(ax: Axes, data: DataFrame | Series, **kwargs) -> Table: + """ + Helper function to convert DataFrame and Series to matplotlib.table. + + This method provides an easy way to visualize tabular data within a Matplotlib + figure. It automatically extracts index and column labels from the DataFrame + or Series, unless explicitly specified. This function is particularly useful + when displaying summary tables alongside other plots or when creating static + reports. It utilizes the `matplotlib.pyplot.table` backend and allows + customization through various styling options available in Matplotlib. + + Parameters + ---------- + ax : Matplotlib axes object + The axes on which to draw the table. + data : DataFrame or Series + Data for table contents. + **kwargs + Keyword arguments to be passed to matplotlib.table.table. + If `rowLabels` or `colLabels` is not specified, data index or column + names will be used. + + Returns + ------- + matplotlib table object + The created table as a matplotlib Table object. + + See Also + -------- + DataFrame.plot : Make plots of DataFrame using matplotlib. + matplotlib.pyplot.table : Create a table from data in a Matplotlib plot. + + Examples + -------- + + .. plot:: + :context: close-figs + + >>> import matplotlib.pyplot as plt + >>> df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}) + >>> fig, ax = plt.subplots() + >>> ax.axis("off") + (np.float64(0.0), np.float64(1.0), np.float64(0.0), np.float64(1.0)) + >>> table = pd.plotting.table( + ... ax, df, loc="center", cellLoc="center", colWidths=[0.2, 0.2] + ... ) + """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.table( + ax=ax, data=data, rowLabels=None, colLabels=None, **kwargs + ) + + +@set_module("pandas.plotting") +def register() -> None: + """ + Register pandas formatters and converters with matplotlib. + + This function modifies the global ``matplotlib.units.registry`` + dictionary. pandas adds custom converters for + + * pd.Timestamp + * pd.Period + * np.datetime64 + * datetime.datetime + * datetime.date + * datetime.time + + See Also + -------- + deregister_matplotlib_converters : Remove pandas formatters and converters. + + Examples + -------- + .. plot:: + :context: close-figs + + The following line is done automatically by pandas so + the plot can be rendered: + + >>> pd.plotting.register_matplotlib_converters() + + >>> df = pd.DataFrame( + ... {"ts": pd.period_range("2020", periods=2, freq="M"), "y": [1, 2]} + ... ) + >>> plot = df.plot.line(x="ts", y="y") + + Unsetting the register manually an error will be raised: + + >>> pd.set_option( + ... "plotting.matplotlib.register_converters", False + ... ) # doctest: +SKIP + >>> df.plot.line(x="ts", y="y") # doctest: +SKIP + Traceback (most recent call last): + TypeError: float() argument must be a string or a real number, not 'Period' + """ + plot_backend = _get_plot_backend("matplotlib") + plot_backend.register() + + +@set_module("pandas.plotting") +def deregister() -> None: + """ + Remove pandas formatters and converters. + + Removes the custom converters added by :func:`register`. This + attempts to set the state of the registry back to the state before + pandas registered its own units. Converters for pandas' own types like + Timestamp and Period are removed completely. Converters for types + pandas overwrites, like ``datetime.datetime``, are restored to their + original value. + + See Also + -------- + register_matplotlib_converters : Register pandas formatters and converters + with matplotlib. + + Examples + -------- + .. plot:: + :context: close-figs + + The following line is done automatically by pandas so + the plot can be rendered: + + >>> pd.plotting.register_matplotlib_converters() + + >>> df = pd.DataFrame( + ... {"ts": pd.period_range("2020", periods=2, freq="M"), "y": [1, 2]} + ... ) + >>> plot = df.plot.line(x="ts", y="y") + + Unsetting the register manually an error will be raised: + + >>> pd.set_option( + ... "plotting.matplotlib.register_converters", False + ... ) # doctest: +SKIP + >>> df.plot.line(x="ts", y="y") # doctest: +SKIP + Traceback (most recent call last): + TypeError: float() argument must be a string or a real number, not 'Period' + """ + plot_backend = _get_plot_backend("matplotlib") + plot_backend.deregister() + + +@set_module("pandas.plotting") +def scatter_matrix( + frame: DataFrame, + alpha: float = 0.5, + figsize: tuple[float, float] | None = None, + ax: Axes | None = None, + grid: bool = False, + diagonal: str = "hist", + marker: str = ".", + density_kwds: Mapping[str, Any] | None = None, + hist_kwds: Mapping[str, Any] | None = None, + range_padding: float = 0.05, + **kwargs, +) -> np.ndarray: + """ + Draw a matrix of scatter plots. + + Each pair of numeric columns in the DataFrame is plotted against each other, + resulting in a matrix of scatter plots. The diagonal plots can display either + histograms or Kernel Density Estimation (KDE) plots for each variable. + + Parameters + ---------- + frame : DataFrame + The data to be plotted. + alpha : float, optional + Amount of transparency applied. + figsize : (float,float), optional + A tuple (width, height) in inches. + ax : Matplotlib axis object, optional + An existing Matplotlib axis object for the plots. If None, a new axis is + created. + grid : bool, optional + Setting this to True will show the grid. + diagonal : {'hist', 'kde'} + Pick between 'kde' and 'hist' for either Kernel Density Estimation or + Histogram plot in the diagonal. + marker : str, optional + Matplotlib marker type, default '.'. + density_kwds : keywords + Keyword arguments to be passed to kernel density estimate plot. + hist_kwds : keywords + Keyword arguments to be passed to hist function. + range_padding : float, default 0.05 + Relative extension of axis range in x and y with respect to + (x_max - x_min) or (y_max - y_min). + **kwargs + Keyword arguments to be passed to scatter function. + + Returns + ------- + numpy.ndarray + A matrix of scatter plots. + + See Also + -------- + plotting.parallel_coordinates : Plots parallel coordinates for multivariate data. + plotting.andrews_curves : Generates Andrews curves for visualizing clusters of + multivariate data. + plotting.radviz : Creates a RadViz visualization. + plotting.bootstrap_plot : Visualizes uncertainty in data via bootstrap sampling. + + Examples + -------- + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame(np.random.randn(1000, 4), columns=["A", "B", "C", "D"]) + >>> pd.plotting.scatter_matrix(df, alpha=0.2) + array([[, , + , ], + [, , + , ], + [, , + , ], + [, , + , ]], + dtype=object) + """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.scatter_matrix( + frame=frame, + alpha=alpha, + figsize=figsize, + ax=ax, + grid=grid, + diagonal=diagonal, + marker=marker, + density_kwds=density_kwds, + hist_kwds=hist_kwds, + range_padding=range_padding, + **kwargs, + ) + + +@set_module("pandas.plotting") +def radviz( + frame: DataFrame, + class_column: str, + ax: Axes | None = None, + color: list[str] | tuple[str, ...] | None = None, + colormap: Colormap | str | None = None, + **kwds, +) -> Axes: + """ + Plot a multidimensional dataset in 2D. + + Each Series in the DataFrame is represented as an evenly distributed + slice on a circle. Each data point is rendered in the circle according to + the value on each Series. Highly correlated `Series` in the `DataFrame` + are placed closer on the unit circle. + + RadViz allow to project an N-dimensional data set into a 2D space where the + influence of each dimension can be interpreted as a balance between the + influence of all dimensions. + + More info available at the `original article + `_ + describing RadViz. + + Parameters + ---------- + frame : `DataFrame` + Object holding the data. + class_column : str + Column name containing the name of the data point category. + ax : :class:`matplotlib.axes.Axes`, optional + A plot instance to which to add the information. + color : list[str] or tuple[str], optional + Assign a color to each category. Example: ['blue', 'green']. + colormap : str or :class:`matplotlib.colors.Colormap`, default None + Colormap to select colors from. If string, load colormap with that + name from matplotlib. + **kwds + Options to pass to matplotlib scatter plotting method. + + Returns + ------- + :class:`matplotlib.axes.Axes` + The Axes object from Matplotlib. + + See Also + -------- + plotting.andrews_curves : Plot clustering visualization. + + Examples + -------- + + .. plot:: + :context: close-figs + + >>> df = pd.DataFrame( + ... { + ... "SepalLength": [6.5, 7.7, 5.1, 5.8, 7.6, 5.0, 5.4, 4.6, 6.7, 4.6], + ... "SepalWidth": [3.0, 3.8, 3.8, 2.7, 3.0, 2.3, 3.0, 3.2, 3.3, 3.6], + ... "PetalLength": [5.5, 6.7, 1.9, 5.1, 6.6, 3.3, 4.5, 1.4, 5.7, 1.0], + ... "PetalWidth": [1.8, 2.2, 0.4, 1.9, 2.1, 1.0, 1.5, 0.2, 2.1, 0.2], + ... "Category": [ + ... "virginica", + ... "virginica", + ... "setosa", + ... "virginica", + ... "virginica", + ... "versicolor", + ... "versicolor", + ... "setosa", + ... "virginica", + ... "setosa", + ... ], + ... } + ... ) + >>> pd.plotting.radviz(df, "Category") # doctest: +SKIP + """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.radviz( + frame=frame, + class_column=class_column, + ax=ax, + color=color, + colormap=colormap, + **kwds, + ) + + +@set_module("pandas.plotting") +def andrews_curves( + frame: DataFrame, + class_column: str, + ax: Axes | None = None, + samples: int = 200, + color: list[str] | tuple[str, ...] | None = None, + colormap: Colormap | str | None = None, + **kwargs, +) -> Axes: + """ + Generate a matplotlib plot for visualizing clusters of multivariate data. + + Andrews curves have the functional form: + + .. math:: + f(t) = \\frac{x_1}{\\sqrt{2}} + x_2 \\sin(t) + x_3 \\cos(t) + + x_4 \\sin(2t) + x_5 \\cos(2t) + \\cdots + + Where :math:`x` coefficients correspond to the values of each dimension + and :math:`t` is linearly spaced between :math:`-\\pi` and :math:`+\\pi`. + Each row of frame then corresponds to a single curve. + + Parameters + ---------- + frame : DataFrame + Data to be plotted, preferably normalized to (0.0, 1.0). + class_column : label + Name of the column containing class names. + ax : axes object, default None + Axes to use. + samples : int + Number of points to plot in each curve. + color : str, list[str] or tuple[str], optional + Colors to use for the different classes. Colors can be strings + or 3-element floating point RGB values. + colormap : str or matplotlib colormap object, default None + Colormap to select colors from. If a string, load colormap with that + name from matplotlib. + **kwargs + Options to pass to matplotlib plotting method. + + Returns + ------- + :class:`matplotlib.axes.Axes` + The matplotlib Axes object with the plot. + + See Also + -------- + plotting.parallel_coordinates : Plot parallel coordinates chart. + DataFrame.plot : Make plots of Series or DataFrame. + + Examples + -------- + + .. plot:: + :context: close-figs + + >>> df = pd.read_csv( + ... "https://raw.githubusercontent.com/pandas-dev/" + ... "pandas/main/pandas/tests/io/data/csv/iris.csv" + ... ) # doctest: +SKIP + >>> pd.plotting.andrews_curves(df, "Name") # doctest: +SKIP + """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.andrews_curves( + frame=frame, + class_column=class_column, + ax=ax, + samples=samples, + color=color, + colormap=colormap, + **kwargs, + ) + + +@set_module("pandas.plotting") +def bootstrap_plot( + series: Series, + fig: Figure | None = None, + size: int = 50, + samples: int = 500, + **kwds, +) -> Figure: + """ + Bootstrap plot on mean, median and mid-range statistics. + + The bootstrap plot is used to estimate the uncertainty of a statistic + by relying on random sampling with replacement [1]_. This function will + generate bootstrapping plots for mean, median and mid-range statistics + for the given number of samples of the given size. + + .. [1] "Bootstrapping (statistics)" in \ + https://en.wikipedia.org/wiki/Bootstrapping_%28statistics%29 + + Parameters + ---------- + series : pandas.Series + Series from where to get the samplings for the bootstrapping. + fig : matplotlib.figure.Figure, default None + If given, it will use the `fig` reference for plotting instead of + creating a new one with default parameters. + size : int, default 50 + Number of data points to consider during each sampling. It must be + less than or equal to the length of the `series`. + samples : int, default 500 + Number of times the bootstrap procedure is performed. + **kwds + Options to pass to matplotlib plotting method. + + Returns + ------- + matplotlib.figure.Figure + Matplotlib figure. + + See Also + -------- + DataFrame.plot : Basic plotting for DataFrame objects. + Series.plot : Basic plotting for Series objects. + + Examples + -------- + This example draws a basic bootstrap plot for a Series. + + .. plot:: + :context: close-figs + + >>> s = pd.Series(np.random.uniform(size=100)) + >>> pd.plotting.bootstrap_plot(s) # doctest: +SKIP +
+ """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.bootstrap_plot( + series=series, fig=fig, size=size, samples=samples, **kwds + ) + + +@set_module("pandas.plotting") +def parallel_coordinates( + frame: DataFrame, + class_column: str, + cols: list[str] | None = None, + ax: Axes | None = None, + color: list[str] | tuple[str, ...] | None = None, + use_columns: bool = False, + xticks: list | tuple | None = None, + colormap: Colormap | str | None = None, + axvlines: bool = True, + axvlines_kwds: Mapping[str, Any] | None = None, + sort_labels: bool = False, + **kwargs, +) -> Axes: + """ + Parallel coordinates plotting. + + Parameters + ---------- + frame : DataFrame + The DataFrame to be plotted. + class_column : str + Column name containing class names. + cols : list, optional + A list of column names to use. + ax : matplotlib.axis, optional + Matplotlib axis object. + color : list or tuple, optional + Colors to use for the different classes. + use_columns : bool, optional + If true, columns will be used as xticks. + xticks : list or tuple, optional + A list of values to use for xticks. + colormap : str or matplotlib colormap, default None + Colormap to use for line colors. + axvlines : bool, optional + If true, vertical lines will be added at each xtick. + axvlines_kwds : keywords, optional + Options to be passed to axvline method for vertical lines. + sort_labels : bool, default False + Sort class_column labels, useful when assigning colors. + **kwargs + Options to pass to matplotlib plotting method. + + Returns + ------- + matplotlib.axes.Axes + The matplotlib axes containing the parallel coordinates plot. + + See Also + -------- + plotting.andrews_curves : Generate a matplotlib plot for visualizing clusters + of multivariate data. + plotting.radviz : Plot a multidimensional dataset in 2D. + + Examples + -------- + + .. plot:: + :context: close-figs + + >>> df = pd.read_csv( + ... "https://raw.githubusercontent.com/pandas-dev/" + ... "pandas/main/pandas/tests/io/data/csv/iris.csv" + ... ) # doctest: +SKIP + >>> pd.plotting.parallel_coordinates( + ... df, "Name", color=("#556270", "#4ECDC4", "#C7F464") + ... ) # doctest: +SKIP + """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.parallel_coordinates( + frame=frame, + class_column=class_column, + cols=cols, + ax=ax, + color=color, + use_columns=use_columns, + xticks=xticks, + colormap=colormap, + axvlines=axvlines, + axvlines_kwds=axvlines_kwds, + sort_labels=sort_labels, + **kwargs, + ) + + +@set_module("pandas.plotting") +def lag_plot(series: Series, lag: int = 1, ax: Axes | None = None, **kwds) -> Axes: + """ + Lag plot for time series. + + A lag plot is a scatter plot of a time series against a lag of itself. It helps + in visualizing the temporal dependence between observations by plotting the values + at time `t` on the x-axis and the values at time `t + lag` on the y-axis. + + Parameters + ---------- + series : Series + The time series to visualize. + lag : int, default 1 + Lag length of the scatter plot. + ax : Matplotlib axis object, optional + The matplotlib axis object to use. + **kwds + Matplotlib scatter method keyword arguments. + + Returns + ------- + matplotlib.axes.Axes + The matplotlib Axes object containing the lag plot. + + See Also + -------- + plotting.autocorrelation_plot : Autocorrelation plot for time series. + matplotlib.pyplot.scatter : A scatter plot of y vs. x with varying marker size + and/or color in Matplotlib. + + Examples + -------- + Lag plots are most commonly used to look for patterns in time series data. + + Given the following time series + + .. plot:: + :context: close-figs + + >>> np.random.seed(5) + >>> x = np.cumsum(np.random.normal(loc=1, scale=5, size=50)) + >>> s = pd.Series(x) + >>> s.plot() # doctest: +SKIP + + A lag plot with ``lag=1`` returns + + .. plot:: + :context: close-figs + + >>> _ = pd.plotting.lag_plot(s, lag=1) + """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.lag_plot(series=series, lag=lag, ax=ax, **kwds) + + +@set_module("pandas.plotting") +def autocorrelation_plot(series: Series, ax: Axes | None = None, **kwargs) -> Axes: + """ + Autocorrelation plot for time series. + + This method generates an autocorrelation plot for a given time series, + which helps to identify any periodic structure or correlation within the + data across various lags. It shows the correlation of a time series with a + delayed copy of itself as a function of delay. Autocorrelation plots are useful for + checking randomness in a data set. If the data are random, the autocorrelations + should be near zero for any and all time-lag separations. If the data are not + random, then one or more of the autocorrelations will be significantly + non-zero. + + Parameters + ---------- + series : Series + The time series to visualize. + ax : Matplotlib axis object, optional + The matplotlib axis object to use. + **kwargs + Options to pass to matplotlib plotting method. + + Returns + ------- + matplotlib.axes.Axes + The matplotlib axes containing the autocorrelation plot. + + See Also + -------- + Series.autocorr : Compute the lag-N autocorrelation for a Series. + plotting.lag_plot : Lag plot for time series. + + Examples + -------- + The horizontal lines in the plot correspond to 95% and 99% confidence bands. + + The dashed line is 99% confidence band. + + .. plot:: + :context: close-figs + + >>> spacing = np.linspace(-9 * np.pi, 9 * np.pi, num=1000) + >>> s = pd.Series(0.7 * np.random.rand(1000) + 0.3 * np.sin(spacing)) + >>> pd.plotting.autocorrelation_plot(s) # doctest: +SKIP + """ + plot_backend = _get_plot_backend("matplotlib") + return plot_backend.autocorrelation_plot(series=series, ax=ax, **kwargs) + + +class _Options(dict): + """ + Stores pandas plotting options. + + Allows for parameter aliasing so you can just use parameter names that are + the same as the plot function parameters, but is stored in a canonical + format that makes it easy to breakdown into groups later. + + See Also + -------- + plotting.register_matplotlib_converters : Register pandas formatters and + converters with matplotlib. + plotting.bootstrap_plot : Bootstrap plot on mean, median and mid-range statistics. + plotting.autocorrelation_plot : Autocorrelation plot for time series. + plotting.lag_plot : Lag plot for time series. + + Examples + -------- + + .. plot:: + :context: close-figs + + >>> np.random.seed(42) + >>> df = pd.DataFrame( + ... {"A": np.random.randn(10), "B": np.random.randn(10)}, + ... index=pd.date_range("1/1/2000", freq="4MS", periods=10), + ... ) + >>> with pd.plotting.plot_params.use("x_compat", True): + ... _ = df["A"].plot(color="r") + ... _ = df["B"].plot(color="g") + """ + + # alias so the names are same as plotting method parameter names + _ALIASES = {"x_compat": "xaxis.compat"} + _DEFAULT_KEYS = ["xaxis.compat"] + + def __init__(self) -> None: + super().__setitem__("xaxis.compat", False) + + def __getitem__(self, key): + key = self._get_canonical_key(key) + if key not in self: + raise ValueError(f"{key} is not a valid pandas plotting option") + return super().__getitem__(key) + + def __setitem__(self, key, value) -> None: + key = self._get_canonical_key(key) + super().__setitem__(key, value) + + def __delitem__(self, key) -> None: + key = self._get_canonical_key(key) + if key in self._DEFAULT_KEYS: + raise ValueError(f"Cannot remove default parameter {key}") + super().__delitem__(key) + + def __contains__(self, key) -> bool: + key = self._get_canonical_key(key) + return super().__contains__(key) + + def reset(self) -> None: + """ + Reset the option store to its initial state + + Returns + ------- + None + """ + # error: Cannot access "__init__" directly + self.__init__() # type: ignore[misc] + + def _get_canonical_key(self, key: str) -> str: + return self._ALIASES.get(key, key) + + @contextmanager + def use(self, key, value) -> Generator[_Options]: + """ + Temporarily set a parameter value using the with statement. + Aliasing allowed. + """ + old_value = self[key] + try: + self[key] = value + yield self + finally: + self[key] = old_value + + +plot_params = _Options() +plot_params.__module__ = "pandas.plotting" diff --git a/pandas/tests/__init__.py b/pandas/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/api/__init__.py b/pandas/tests/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/api/test_api.py b/pandas/tests/api/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..4eefb1163352c16e4d4a263a8af1a457ad510599 --- /dev/null +++ b/pandas/tests/api/test_api.py @@ -0,0 +1,577 @@ +from __future__ import annotations + +import importlib +import inspect +import pathlib +import pkgutil + +import pytest + +import pandas as pd +from pandas import api +import pandas._testing as tm +from pandas.api import ( + executors as api_executors, + extensions as api_extensions, + indexers as api_indexers, + interchange as api_interchange, + types as api_types, + typing as api_typing, +) +from pandas.api.typing import aliases as api_aliases + + +class Base: + def check(self, namespace, expected, ignored=None): + # see which names are in the namespace, minus optional + # ignored ones + # compare vs the expected + + result = sorted( + f for f in dir(namespace) if not f.startswith("__") and f != "annotations" + ) + if ignored is not None: + result = sorted(set(result) - set(ignored)) + + expected = sorted(expected) + tm.assert_almost_equal(result, expected) + + +class TestPDApi(Base): + # these are optionally imported based on testing + # & need to be ignored + ignored = ["tests", "locale", "conftest", "_version_meson"] + + # top-level sub-packages + public_lib = [ + "api", + "arrays", + "options", + "test", + "testing", + "errors", + "plotting", + "io", + "tseries", + ] + private_lib = ["compat", "core", "pandas", "util", "_built_with_meson"] + + # misc + misc = ["IndexSlice", "NaT", "NA"] + + # top-level classes + classes = [ + "ArrowDtype", + "Categorical", + "CategoricalIndex", + "DataFrame", + "DateOffset", + "DatetimeIndex", + "ExcelFile", + "ExcelWriter", + "Flags", + "Grouper", + "HDFStore", + "Index", + "MultiIndex", + "Period", + "PeriodIndex", + "RangeIndex", + "Series", + "SparseDtype", + "StringDtype", + "Timedelta", + "TimedeltaIndex", + "Timestamp", + "Interval", + "IntervalIndex", + "CategoricalDtype", + "PeriodDtype", + "IntervalDtype", + "DatetimeTZDtype", + "BooleanDtype", + "Int8Dtype", + "Int16Dtype", + "Int32Dtype", + "Int64Dtype", + "UInt8Dtype", + "UInt16Dtype", + "UInt32Dtype", + "UInt64Dtype", + "Float32Dtype", + "Float64Dtype", + "NamedAgg", + ] + + # these are already deprecated; awaiting removal + deprecated_classes: list[str] = [] + + # external modules exposed in pandas namespace + modules: list[str] = [] + + # top-level functions + funcs = [ + "array", + "bdate_range", + "col", + "concat", + "crosstab", + "cut", + "date_range", + "interval_range", + "eval", + "factorize", + "get_dummies", + "from_dummies", + "infer_freq", + "isna", + "isnull", + "lreshape", + "melt", + "notna", + "notnull", + "offsets", + "merge", + "merge_ordered", + "merge_asof", + "period_range", + "pivot", + "pivot_table", + "qcut", + "show_versions", + "timedelta_range", + "unique", + "wide_to_long", + ] + + # top-level option funcs + funcs_option = [ + "reset_option", + "describe_option", + "get_option", + "option_context", + "set_option", + "set_eng_float_format", + ] + + # top-level read_* funcs + funcs_read = [ + "read_clipboard", + "read_csv", + "read_excel", + "read_fwf", + "read_hdf", + "read_html", + "read_xml", + "read_json", + "read_pickle", + "read_sas", + "read_sql", + "read_sql_query", + "read_sql_table", + "read_stata", + "read_table", + "read_feather", + "read_parquet", + "read_orc", + "read_spss", + "read_iceberg", + ] + + # top-level json funcs + funcs_json = ["json_normalize"] + + # top-level to_* funcs + funcs_to = ["to_datetime", "to_numeric", "to_pickle", "to_timedelta"] + + # top-level to deprecate in the future + deprecated_funcs_in_future: list[str] = [] + + # these are already deprecated; awaiting removal + deprecated_funcs: list[str] = [] + + # private modules in pandas namespace + private_modules = [ + "_config", + "_libs", + "_is_numpy_dev", + "_pandas_datetime_CAPI", + "_pandas_parser_CAPI", + "_testing", + "_typing", + ] + if not pd._built_with_meson: + private_modules.append("_version") + + def test_api(self): + checkthese = ( + self.public_lib + + self.private_lib + + self.misc + + self.modules + + self.classes + + self.funcs + + self.funcs_option + + self.funcs_read + + self.funcs_json + + self.funcs_to + + self.private_modules + ) + self.check(namespace=pd, expected=checkthese, ignored=self.ignored) + + def test_api_all(self): + expected = set( + self.public_lib + + self.misc + + self.modules + + self.classes + + self.funcs + + self.funcs_option + + self.funcs_read + + self.funcs_json + + self.funcs_to + ) - set(self.deprecated_classes) + actual = set(pd.__all__) + + extraneous = actual - expected + assert not extraneous + + missing = expected - actual + assert not missing + + def test_depr(self): + deprecated_list = ( + self.deprecated_classes + + self.deprecated_funcs + + self.deprecated_funcs_in_future + ) + for depr in deprecated_list: + with tm.assert_produces_warning(FutureWarning): + _ = getattr(pd, depr) + + +class TestApi(Base): + allowed_api_dirs = [ + "executors", + "types", + "extensions", + "indexers", + "interchange", + "typing", + "internals", + ] + allowed_typing = [ + "DataFrameGroupBy", + "DatetimeIndexResamplerGroupby", + "Expanding", + "ExpandingGroupby", + "ExponentialMovingWindow", + "ExponentialMovingWindowGroupby", + "Expression", + "FrozenList", + "JsonReader", + "NaTType", + "NAType", + "NoDefault", + "PeriodIndexResamplerGroupby", + "Resampler", + "Rolling", + "RollingGroupby", + "SeriesGroupBy", + "StataReader", + "SASReader", + "TimedeltaIndexResamplerGroupby", + "TimeGrouper", + "Window", + "aliases", + ] + allowed_api_types = [ + "is_any_real_numeric_dtype", + "is_array_like", + "is_bool", + "is_bool_dtype", + "is_categorical_dtype", + "is_complex", + "is_complex_dtype", + "is_datetime64_any_dtype", + "is_datetime64_dtype", + "is_datetime64_ns_dtype", + "is_datetime64tz_dtype", + "is_dict_like", + "is_dtype_equal", + "is_extension_array_dtype", + "is_file_like", + "is_float", + "is_float_dtype", + "is_hashable", + "is_int64_dtype", + "is_integer", + "is_integer_dtype", + "is_interval_dtype", + "is_iterator", + "is_list_like", + "is_named_tuple", + "is_number", + "is_numeric_dtype", + "is_object_dtype", + "is_period_dtype", + "is_re", + "is_re_compilable", + "is_scalar", + "is_signed_integer_dtype", + "is_sparse", + "is_string_dtype", + "is_timedelta64_dtype", + "is_timedelta64_ns_dtype", + "is_unsigned_integer_dtype", + "pandas_dtype", + "infer_dtype", + "union_categoricals", + "CategoricalDtype", + "DatetimeTZDtype", + "IntervalDtype", + "PeriodDtype", + ] + allowed_api_interchange = ["from_dataframe", "DataFrame"] + allowed_api_indexers = [ + "check_array_indexer", + "BaseIndexer", + "FixedForwardWindowIndexer", + "VariableOffsetWindowIndexer", + ] + allowed_api_extensions = [ + "no_default", + "ExtensionDtype", + "register_extension_dtype", + "register_dataframe_accessor", + "register_index_accessor", + "register_series_accessor", + "take", + "ExtensionArray", + "ExtensionScalarOpsMixin", + ] + allowed_api_executors = ["BaseExecutionEngine"] + allowed_api_aliases = [ + "AggFuncType", + "AlignJoin", + "AnyAll", + "AnyArrayLike", + "ArrayLike", + "AstypeArg", + "Axes", + "Axis", + "CSVEngine", + "ColspaceArgType", + "CompressionOptions", + "CorrelationMethod", + "DropKeep", + "Dtype", + "DtypeArg", + "DtypeBackend", + "DtypeObj", + "ExcelWriterIfSheetExists", + "ExcelWriterMergeCells", + "FilePath", + "FillnaOptions", + "FloatFormatType", + "FormattersType", + "FromDictOrient", + "HTMLFlavors", + "IgnoreRaise", + "IndexLabel", + "InterpolateOptions", + "IntervalClosedType", + "IntervalLeftRight", + "JSONEngine", + "JSONSerializable", + "JoinHow", + "JoinValidate", + "ListLike", + "MergeHow", + "MergeValidate", + "NaPosition", + "NsmallestNlargestKeep", + "OpenFileErrors", + "Ordered", + "ParquetCompressionOptions", + "QuantileInterpolation", + "ReadBuffer", + "ReadCsvBuffer", + "ReadPickleBuffer", + "ReindexMethod", + "Scalar", + "ScalarIndexer", + "SequenceIndexer", + "SequenceNotStr", + "SliceType", + "SortKind", + "StorageOptions", + "Suffixes", + "TakeIndexer", + "TimeAmbiguous", + "TimeGrouperOrigin", + "TimeNonexistent", + "TimeUnit", + "TimedeltaConvertibleTypes", + "TimestampConvertibleTypes", + "ToStataByteorder", + "ToTimestampHow", + "UpdateJoin", + "UsecolsArgType", + "WindowingRankType", + "WriteBuffer", + "WriteExcelBuffer", + "XMLParsers", + ] + + def test_api(self): + self.check(api, self.allowed_api_dirs) + + def test_api_typing(self): + self.check(api_typing, self.allowed_typing) + + def test_api_types(self): + self.check(api_types, self.allowed_api_types) + + def test_api_interchange(self): + self.check(api_interchange, self.allowed_api_interchange) + + def test_api_indexers(self): + self.check(api_indexers, self.allowed_api_indexers) + + def test_api_extensions(self): + self.check(api_extensions, self.allowed_api_extensions) + + def test_api_executors(self): + self.check(api_executors, self.allowed_api_executors) + + def test_api_typing_aliases(self): + self.check(api_aliases, self.allowed_api_aliases) + + +class TestErrors(Base): + def test_errors(self): + ignored = ["_CurrentDeprecationWarning", "abc", "ctypes", "cow"] + self.check(pd.errors, pd.errors.__all__, ignored=ignored) + + +class TestUtil(Base): + def test_util(self): + self.check( + pd.util, + ["hash_array", "hash_pandas_object"], + ignored=[ + "_decorators", + "_test_decorators", + "_exceptions", + "_validators", + "capitalize_first_letter", + "version", + "_print_versions", + "_tester", + ], + ) + + +class TestTesting(Base): + funcs = [ + "assert_frame_equal", + "assert_series_equal", + "assert_index_equal", + "assert_extension_array_equal", + ] + + def test_testing(self): + from pandas import testing + + self.check(testing, self.funcs) + + def test_util_in_top_level(self): + with pytest.raises(AttributeError, match="foo"): + pd.util.foo + + +def get_pandas_objects( + module_name: str, recurse: bool +) -> list[tuple[str, str, object]]: + """ + Get all pandas objects within a module. + + An object is determined to be part of pandas if it has a string + __module__ attribute that starts with ``"pandas"``. + + Parameters + ---------- + module_name : str + Name of the module to search. + recurse : bool + Whether to search submodules. + + Returns + ------- + List of all objects that are determined to be a part of pandas. + """ + module = importlib.import_module(module_name) + objs = [] + + for name, obj in inspect.getmembers(module): + module_dunder = getattr(obj, "__module__", None) + if isinstance(module_dunder, str) and module_dunder.startswith("pandas"): + objs.append((module_name, name, obj)) + + if not recurse: + return objs + + # __file__ can, but shouldn't, be None + assert isinstance(module.__file__, str) + paths = [pathlib.Path(module.__file__).parent] + for module_info in pkgutil.walk_packages(paths): + name = module_info.name + if name.startswith("_") or name == "internals": + continue + objs.extend( + get_pandas_objects(f"{module.__name__}.{name}", recurse=module_info.ispkg) + ) + return objs + + +@pytest.mark.slow +@pytest.mark.parametrize( + "module_name", + [ + "pandas", + "pandas.api", + "pandas.arrays", + "pandas.errors", + pytest.param("pandas.io", marks=pytest.mark.xfail(reason="Private imports")), + "pandas.plotting", + "pandas.testing", + ], +) +def test_attributes_module(module_name): + """ + Ensures that all public objects have their __module__ set to the public import path. + """ + recurse = module_name not in ["pandas", "pandas.testing"] + objs = get_pandas_objects(module_name, recurse=recurse) + failures = [ + (module_name, name, type(obj), obj.__module__) + for module_name, name, obj in objs + if not ( + obj.__module__ == module_name + # Explicit exceptions + or ("Dtype" in name and obj.__module__ == "pandas") + or (name == "Categorical" and obj.__module__ == "pandas") + ) + ] + assert len(failures) == 0, "\n".join(str(e) for e in failures) + + # Check that all objects can indeed be imported from their __module__ + failures = [] + for module_name, name, obj in objs: + module = importlib.import_module(obj.__module__) + try: + getattr(module, name) + except Exception: + failures.append((module_name, name, type(obj), obj.__module__)) + assert len(failures) == 0, "\n".join(str(e) for e in failures) diff --git a/pandas/tests/api/test_types.py b/pandas/tests/api/test_types.py new file mode 100644 index 0000000000000000000000000000000000000000..bf39370c49d76762760a98e820ad2985a8a81222 --- /dev/null +++ b/pandas/tests/api/test_types.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +import pandas._testing as tm +from pandas.api import types +from pandas.tests.api.test_api import Base + + +class TestTypes(Base): + allowed = [ + "is_any_real_numeric_dtype", + "is_bool", + "is_bool_dtype", + "is_categorical_dtype", + "is_complex", + "is_complex_dtype", + "is_datetime64_any_dtype", + "is_datetime64_dtype", + "is_datetime64_ns_dtype", + "is_datetime64tz_dtype", + "is_dtype_equal", + "is_float", + "is_float_dtype", + "is_int64_dtype", + "is_integer", + "is_integer_dtype", + "is_number", + "is_numeric_dtype", + "is_object_dtype", + "is_scalar", + "is_sparse", + "is_string_dtype", + "is_signed_integer_dtype", + "is_timedelta64_dtype", + "is_timedelta64_ns_dtype", + "is_unsigned_integer_dtype", + "is_period_dtype", + "is_interval_dtype", + "is_re", + "is_re_compilable", + "is_dict_like", + "is_iterator", + "is_file_like", + "is_list_like", + "is_hashable", + "is_array_like", + "is_named_tuple", + "pandas_dtype", + "union_categoricals", + "infer_dtype", + "is_extension_array_dtype", + ] + deprecated: list[str] = [] + dtypes = ["CategoricalDtype", "DatetimeTZDtype", "PeriodDtype", "IntervalDtype"] + + def test_types(self): + self.check(types, self.allowed + self.dtypes + self.deprecated) + + def test_deprecated_from_api_types(self): + for t in self.deprecated: + with tm.assert_produces_warning(FutureWarning): + getattr(types, t)(1) diff --git a/pandas/tests/apply/__init__.py b/pandas/tests/apply/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/apply/common.py b/pandas/tests/apply/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b4d153df54059ca2a82f336e19afb4297eb218a2 --- /dev/null +++ b/pandas/tests/apply/common.py @@ -0,0 +1,7 @@ +from pandas.core.groupby.base import transformation_kernels + +# There is no Series.cumcount or DataFrame.cumcount +series_transform_kernels = [ + x for x in sorted(transformation_kernels) if x != "cumcount" +] +frame_transform_kernels = [x for x in sorted(transformation_kernels) if x != "cumcount"] diff --git a/pandas/tests/apply/conftest.py b/pandas/tests/apply/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..aecf82f5a941948da66c9dda09ec9a826a2706ca --- /dev/null +++ b/pandas/tests/apply/conftest.py @@ -0,0 +1,63 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, +) +from pandas.api.executors import BaseExecutionEngine + + +class MockExecutionEngine(BaseExecutionEngine): + """ + Execution Engine to test if the execution engine interface receives and + uses all parameters provided by the user. + + Making this engine work as the default Python engine by calling it, no extra + functionality is implemented here. + + When testing, this will be called when this engine is provided, and then the + same pandas.map and pandas.apply function will be called, but without engine, + executing the default behavior from the python engine. + """ + + def map(data, func, args, kwargs, decorator, skip_na): + kwargs_to_pass = kwargs if isinstance(data, DataFrame) else {} + return data.map(func, na_action="ignore" if skip_na else None, **kwargs_to_pass) + + def apply(data, func, args, kwargs, decorator, axis): + if isinstance(data, Series): + return data.apply(func, convert_dtype=True, args=args, by_row=False) + elif isinstance(data, DataFrame): + return data.apply( + func, + axis=axis, + raw=False, + result_type=None, + args=args, + by_row="compat", + **kwargs, + ) + else: + assert isinstance(data, np.ndarray) + + def wrap_function(func): + # https://github.com/numpy/numpy/issues/8352 + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + if isinstance(result, str): + result = np.array(result, dtype=object) + return result + + return wrapper + + return np.apply_along_axis(wrap_function(func), axis, data, *args, **kwargs) + + +class MockEngineDecorator: + __pandas_udf__ = MockExecutionEngine + + +@pytest.fixture(params=[None, MockEngineDecorator]) +def engine(request): + return request.param diff --git a/pandas/tests/apply/test_frame_apply.py b/pandas/tests/apply/test_frame_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..0c16425ac2ac73f2ea96173fe6ec97c4b8ef0cb9 --- /dev/null +++ b/pandas/tests/apply/test_frame_apply.py @@ -0,0 +1,1875 @@ +from datetime import datetime +import warnings + +import numpy as np +import pytest + +from pandas.compat import is_platform_arm + +from pandas.core.dtypes.dtypes import CategoricalDtype + +import pandas as pd +from pandas import ( + DataFrame, + MultiIndex, + Series, + Timestamp, + date_range, +) +import pandas._testing as tm +from pandas.tests.apply.conftest import MockEngineDecorator +from pandas.tests.frame.common import zip_frames +from pandas.util.version import Version + + +@pytest.fixture +def int_frame_const_col(): + """ + Fixture for DataFrame of ints which are constant per column + + Columns are ['A', 'B', 'C'], with values (per column): [1, 2, 3] + """ + df = DataFrame( + np.tile(np.arange(3, dtype="int64"), 6).reshape(6, -1) + 1, + columns=["A", "B", "C"], + ) + return df + + +@pytest.fixture( + params=[ + "python", + pytest.param("numba", marks=pytest.mark.single_cpu), + MockEngineDecorator, + ] +) +def engine(request): + if request.param == "numba": + pytest.importorskip("numba") + return request.param + + +def test_apply(float_frame, engine, request): + if engine == "numba": + mark = pytest.mark.xfail(reason="numba engine not supporting numpy ufunc yet") + request.node.add_marker(mark) + with np.errstate(all="ignore"): + # ufunc + result = np.sqrt(float_frame["A"]) + expected = float_frame.apply(np.sqrt, engine=engine)["A"] + tm.assert_series_equal(result, expected) + + # aggregator + result = float_frame.apply(np.mean, engine=engine)["A"] + expected = np.mean(float_frame["A"]) + assert result == expected + + d = float_frame.index[0] + result = float_frame.apply(np.mean, axis=1, engine=engine) + expected = np.mean(float_frame.xs(d)) + assert result[d] == expected + assert result.index is float_frame.index + + +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("raw", [True, False]) +@pytest.mark.parametrize("nopython", [True, False]) +def test_apply_args(float_frame, axis, raw, engine, nopython): + numba = pytest.importorskip("numba") + if ( + engine == "numba" + and Version(numba.__version__) == Version("0.61") + and is_platform_arm() + ): + pytest.skip(f"Segfaults on ARM platforms with numba {numba.__version__}") + engine_kwargs = {"nopython": nopython} + result = float_frame.apply( + lambda x, y: x + y, + axis, + args=(1,), + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, + ) + expected = float_frame + 1 + tm.assert_frame_equal(result, expected) + + # GH:58712 + result = float_frame.apply( + lambda x, a, b: x + a + b, + args=(1,), + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, + ) + expected = float_frame + 3 + tm.assert_frame_equal(result, expected) + + if engine == "numba": + # py signature binding + with pytest.raises(TypeError, match="missing a required argument: 'a'"): + float_frame.apply( + lambda x, a: x + a, + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, + ) + + # keyword-only arguments are not supported in numba + with pytest.raises( + pd.errors.NumbaUtilError, + match="numba does not support keyword-only arguments", + ): + float_frame.apply( + lambda x, a, *, b: x + a + b, + args=(1,), + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, + ) + + with pytest.raises( + pd.errors.NumbaUtilError, + match="numba does not support keyword-only arguments", + ): + float_frame.apply( + lambda *x, b: x[0] + x[1] + b, + args=(1,), + b=2, + raw=raw, + engine=engine, + engine_kwargs=engine_kwargs, + ) + + +def test_apply_categorical_func(): + # GH 9573 + df = DataFrame({"c0": ["A", "A", "B", "B"], "c1": ["C", "C", "D", "D"]}) + result = df.apply(lambda ts: ts.astype("category")) + + assert result.shape == (4, 2) + assert isinstance(result["c0"].dtype, CategoricalDtype) + assert isinstance(result["c1"].dtype, CategoricalDtype) + + +def test_apply_axis1_with_ea(): + # GH#36785 + expected = DataFrame({"A": [Timestamp("2013-01-01", tz="UTC")]}) + result = expected.apply(lambda x: x, axis=1) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data, dtype", + [(1, None), (1, CategoricalDtype([1])), (Timestamp("2013-01-01", tz="UTC"), None)], +) +def test_agg_axis1_duplicate_index(data, dtype): + # GH 42380 + expected = DataFrame([[data], [data]], index=["a", "a"], dtype=dtype) + result = expected.agg(lambda x: x, axis=1) + tm.assert_frame_equal(result, expected) + + +def test_apply_mixed_datetimelike(): + # mixed datetimelike + # GH 7778 + expected = DataFrame( + { + "A": date_range("20130101", periods=3), + "B": pd.to_timedelta(np.arange(3), unit="s"), + } + ) + result = expected.apply(lambda x: x, axis=1) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", [np.sqrt, np.mean]) +def test_apply_empty(func, engine): + # empty + empty_frame = DataFrame() + + result = empty_frame.apply(func, engine=engine) + assert result.empty + + +def test_apply_float_frame(float_frame, engine): + no_rows = float_frame[:0] + result = no_rows.apply(lambda x: x.mean(), engine=engine) + expected = Series(np.nan, index=float_frame.columns) + tm.assert_series_equal(result, expected) + + no_cols = float_frame.loc[:, []] + result = no_cols.apply(lambda x: x.mean(), axis=1, engine=engine) + expected = Series(np.nan, index=float_frame.index) + tm.assert_series_equal(result, expected) + + +def test_apply_empty_except_index(engine): + # GH 2476 + expected = DataFrame(index=["a"]) + result = expected.apply(lambda x: x["a"], axis=1, engine=engine) + tm.assert_frame_equal(result, expected) + + +def test_apply_with_reduce_empty(): + # reduce with an empty DataFrame + empty_frame = DataFrame() + + x = [] + result = empty_frame.apply(x.append, axis=1, result_type="expand") + tm.assert_frame_equal(result, empty_frame) + result = empty_frame.apply(x.append, axis=1, result_type="reduce") + expected = Series([], dtype=np.float64) + tm.assert_series_equal(result, expected) + + empty_with_cols = DataFrame(columns=["a", "b", "c"]) + result = empty_with_cols.apply(x.append, axis=1, result_type="expand") + tm.assert_frame_equal(result, empty_with_cols) + result = empty_with_cols.apply(x.append, axis=1, result_type="reduce") + expected = Series([], dtype=np.float64) + tm.assert_series_equal(result, expected) + + # Ensure that x.append hasn't been called + assert x == [] + + +@pytest.mark.parametrize("func", ["sum", "prod", "any", "all"]) +def test_apply_funcs_over_empty(func): + # GH 28213 + df = DataFrame(columns=["a", "b", "c"]) + + result = df.apply(getattr(np, func)) + expected = getattr(df, func)() + if func in ("sum", "prod"): + expected = expected.astype(float) + tm.assert_series_equal(result, expected) + + +def test_nunique_empty(): + # GH 28213 + df = DataFrame(columns=["a", "b", "c"]) + + result = df.nunique() + expected = Series(0, index=df.columns) + tm.assert_series_equal(result, expected) + + result = df.T.nunique() + expected = Series([], dtype=np.float64) + tm.assert_series_equal(result, expected) + + +def test_apply_standard_nonunique(): + df = DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]], index=["a", "a", "c"]) + + result = df.apply(lambda s: s[0], axis=1) + expected = Series([1, 4, 7], ["a", "a", "c"]) + tm.assert_series_equal(result, expected) + + result = df.T.apply(lambda s: s[0], axis=0) + tm.assert_series_equal(result, expected) + + +def test_apply_broadcast_scalars(float_frame): + # scalars + result = float_frame.apply(np.mean, result_type="broadcast") + expected = DataFrame([float_frame.mean()], index=float_frame.index) + tm.assert_frame_equal(result, expected) + + +def test_apply_broadcast_scalars_axis1(float_frame): + result = float_frame.apply(np.mean, axis=1, result_type="broadcast") + m = float_frame.mean(axis=1) + expected = DataFrame(dict.fromkeys(float_frame.columns, m)) + tm.assert_frame_equal(result, expected) + + +def test_apply_broadcast_lists_columns(float_frame): + # lists + result = float_frame.apply( + lambda x: list(range(len(float_frame.columns))), + axis=1, + result_type="broadcast", + ) + m = list(range(len(float_frame.columns))) + expected = DataFrame( + [m] * len(float_frame.index), + dtype="float64", + index=float_frame.index, + columns=float_frame.columns, + ) + tm.assert_frame_equal(result, expected) + + +def test_apply_broadcast_lists_index(float_frame): + result = float_frame.apply( + lambda x: list(range(len(float_frame.index))), result_type="broadcast" + ) + m = list(range(len(float_frame.index))) + expected = DataFrame( + dict.fromkeys(float_frame.columns, m), + dtype="float64", + index=float_frame.index, + ) + tm.assert_frame_equal(result, expected) + + +def test_apply_broadcast_list_lambda_func(int_frame_const_col): + # preserve columns + df = int_frame_const_col + result = df.apply(lambda x: [1, 2, 3], axis=1, result_type="broadcast") + tm.assert_frame_equal(result, df) + + +def test_apply_broadcast_series_lambda_func(int_frame_const_col): + df = int_frame_const_col + result = df.apply( + lambda x: Series([1, 2, 3], index=list("abc")), + axis=1, + result_type="broadcast", + ) + expected = df.copy() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_apply_raw_float_frame(float_frame, axis, engine): + if engine == "numba": + pytest.skip("numba can't handle when UDF returns None.") + + def _assert_raw(x): + assert isinstance(x, np.ndarray) + assert x.ndim == 1 + + float_frame.apply(_assert_raw, axis=axis, engine=engine, raw=True) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_apply_raw_float_frame_lambda(float_frame, axis, engine): + result = float_frame.apply(np.mean, axis=axis, engine=engine, raw=True) + expected = float_frame.apply(lambda x: x.values.mean(), axis=axis) + tm.assert_series_equal(result, expected) + + +def test_apply_raw_float_frame_no_reduction(float_frame, engine): + # no reduction + result = float_frame.apply(lambda x: x * 2, engine=engine, raw=True) + expected = float_frame * 2 + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("axis", [0, 1]) +def test_apply_raw_mixed_type_frame(axis, engine): + if engine == "numba": + pytest.skip("isinstance check doesn't work with numba") + + def _assert_raw(x): + assert isinstance(x, np.ndarray) + assert x.ndim == 1 + + # Mixed dtype (GH-32423) + df = DataFrame( + { + "a": 1.0, + "b": 2, + "c": "foo", + "float32": np.array([1.0] * 10, dtype="float32"), + "int32": np.array([1] * 10, dtype="int32"), + }, + index=np.arange(10), + ) + df.apply(_assert_raw, axis=axis, engine=engine, raw=True) + + +def test_apply_axis1(float_frame): + d = float_frame.index[0] + result = float_frame.apply(np.mean, axis=1)[d] + expected = np.mean(float_frame.xs(d)) + assert result == expected + + +def test_apply_mixed_dtype_corner(): + df = DataFrame({"A": ["foo"], "B": [1.0]}) + result = df[:0].apply(np.mean, axis=1) + # the result here is actually kind of ambiguous, should it be a Series + # or a DataFrame? + expected = Series(dtype=np.float64) + tm.assert_series_equal(result, expected) + + +def test_apply_mixed_dtype_corner_indexing(): + df = DataFrame({"A": ["foo"], "B": [1.0]}) + result = df.apply(lambda x: x["A"], axis=1) + expected = Series(["foo"], index=range(1)) + tm.assert_series_equal(result, expected) + + result = df.apply(lambda x: x["B"], axis=1) + expected = Series([1.0], index=range(1)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +@pytest.mark.parametrize("ax", ["index", "columns"]) +@pytest.mark.parametrize( + "func", [lambda x: x, lambda x: x.mean()], ids=["identity", "mean"] +) +@pytest.mark.parametrize("raw", [True, False]) +@pytest.mark.parametrize("axis", [0, 1]) +def test_apply_empty_infer_type(ax, func, raw, axis, engine, request): + df = DataFrame(**{ax: ["a", "b", "c"]}) + + with np.errstate(all="ignore"): + test_res = func(np.array([], dtype="f8")) + is_reduction = not isinstance(test_res, np.ndarray) + + result = df.apply(func, axis=axis, engine=engine, raw=raw) + if is_reduction: + agg_axis = df._get_agg_axis(axis) + assert isinstance(result, Series) + assert result.index is agg_axis + else: + assert isinstance(result, DataFrame) + + +def test_apply_empty_infer_type_broadcast(): + no_cols = DataFrame(index=["a", "b", "c"]) + result = no_cols.apply(lambda x: x.mean(), result_type="broadcast") + assert isinstance(result, DataFrame) + + +def test_apply_with_args_kwds_add_some(float_frame): + def add_some(x, howmuch=0): + return x + howmuch + + result = float_frame.apply(add_some, howmuch=2) + expected = float_frame.apply(lambda x: x + 2) + tm.assert_frame_equal(result, expected) + + +def test_apply_with_args_kwds_agg_and_add(float_frame): + def agg_and_add(x, howmuch=0): + return x.mean() + howmuch + + result = float_frame.apply(agg_and_add, howmuch=2) + expected = float_frame.apply(lambda x: x.mean() + 2) + tm.assert_series_equal(result, expected) + + +def test_apply_with_args_kwds_subtract_and_divide(float_frame): + def subtract_and_divide(x, sub, divide=1): + return (x - sub) / divide + + result = float_frame.apply(subtract_and_divide, args=(2,), divide=2) + expected = float_frame.apply(lambda x: (x - 2.0) / 2.0) + tm.assert_frame_equal(result, expected) + + +def test_apply_yield_list(float_frame): + result = float_frame.apply(list) + tm.assert_frame_equal(result, float_frame) + + +def test_apply_reduce_Series(float_frame): + float_frame.iloc[::2, float_frame.columns.get_loc("A")] = np.nan + expected = float_frame.mean(axis=1) + result = float_frame.apply(np.mean, axis=1) + tm.assert_series_equal(result, expected) + + +def test_apply_reduce_to_dict(): + # GH 25196 37544 + data = DataFrame([[1, 2], [3, 4]], columns=["c0", "c1"], index=["i0", "i1"]) + + result = data.apply(dict, axis=0) + expected = Series([{"i0": 1, "i1": 3}, {"i0": 2, "i1": 4}], index=data.columns) + tm.assert_series_equal(result, expected) + + result = data.apply(dict, axis=1) + expected = Series([{"c0": 1, "c1": 2}, {"c0": 3, "c1": 4}], index=data.index) + tm.assert_series_equal(result, expected) + + +def test_apply_differently_indexed(): + df = DataFrame(np.random.default_rng(2).standard_normal((20, 10))) + + result = df.apply(Series.describe, axis=0) + expected = DataFrame({i: v.describe() for i, v in df.items()}, columns=df.columns) + tm.assert_frame_equal(result, expected) + + result = df.apply(Series.describe, axis=1) + expected = DataFrame({i: v.describe() for i, v in df.T.items()}, columns=df.index).T + tm.assert_frame_equal(result, expected) + + +def test_apply_bug(): + # GH 6125 + positions = DataFrame( + [ + [1, "ABC0", 50], + [1, "YUM0", 20], + [1, "DEF0", 20], + [2, "ABC1", 50], + [2, "YUM1", 20], + [2, "DEF1", 20], + ], + columns=["a", "market", "position"], + ) + + def f(r): + return r["market"] + + expected = positions.apply(f, axis=1) + + positions = DataFrame( + [ + [datetime(2013, 1, 1), "ABC0", 50], + [datetime(2013, 1, 2), "YUM0", 20], + [datetime(2013, 1, 3), "DEF0", 20], + [datetime(2013, 1, 4), "ABC1", 50], + [datetime(2013, 1, 5), "YUM1", 20], + [datetime(2013, 1, 6), "DEF1", 20], + ], + columns=["a", "market", "position"], + ) + result = positions.apply(f, axis=1) + tm.assert_series_equal(result, expected) + + +def test_apply_convert_objects(): + expected = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + result = expected.apply(lambda x: x, axis=1) + tm.assert_frame_equal(result, expected) + + +def test_apply_attach_name(float_frame): + result = float_frame.apply(lambda x: x.name) + expected = Series(float_frame.columns, index=float_frame.columns) + tm.assert_series_equal(result, expected) + + +def test_apply_attach_name_axis1(float_frame): + result = float_frame.apply(lambda x: x.name, axis=1) + expected = Series(float_frame.index, index=float_frame.index) + tm.assert_series_equal(result, expected) + + +def test_apply_attach_name_non_reduction(float_frame): + # non-reductions + result = float_frame.apply(lambda x: np.repeat(x.name, len(x))) + expected = DataFrame( + np.tile(float_frame.columns, (len(float_frame.index), 1)), + index=float_frame.index, + columns=float_frame.columns, + ) + tm.assert_frame_equal(result, expected) + + +def test_apply_attach_name_non_reduction_axis1(float_frame): + result = float_frame.apply(lambda x: np.repeat(x.name, len(x)), axis=1) + expected = Series( + np.repeat(t[0], len(float_frame.columns)) for t in float_frame.itertuples() + ) + expected.index = float_frame.index + tm.assert_series_equal(result, expected) + + +def test_apply_multi_index(): + index = MultiIndex.from_arrays([["a", "a", "b"], ["c", "d", "d"]]) + s = DataFrame([[1, 2], [3, 4], [5, 6]], index=index, columns=["col1", "col2"]) + result = s.apply(lambda x: Series({"min": min(x), "max": max(x)}), 1) + expected = DataFrame([[1, 2], [3, 4], [5, 6]], index=index, columns=["min", "max"]) + tm.assert_frame_equal(result, expected, check_like=True) + + +@pytest.mark.parametrize( + "df, dicts", + [ + [ + DataFrame([["foo", "bar"], ["spam", "eggs"]]), + Series([{0: "foo", 1: "spam"}, {0: "bar", 1: "eggs"}]), + ], + [DataFrame([[0, 1], [2, 3]]), Series([{0: 0, 1: 2}, {0: 1, 1: 3}])], + ], +) +def test_apply_dict(df, dicts): + # GH 8735 + fn = lambda x: x.to_dict() + reduce_true = df.apply(fn, result_type="reduce") + reduce_false = df.apply(fn, result_type="expand") + reduce_none = df.apply(fn) + + tm.assert_series_equal(reduce_true, dicts) + tm.assert_frame_equal(reduce_false, df) + tm.assert_series_equal(reduce_none, dicts) + + +def test_apply_non_numpy_dtype(): + # GH 12244 + df = DataFrame({"dt": date_range("2015-01-01", periods=3, tz="Europe/Brussels")}) + result = df.apply(lambda x: x) + tm.assert_frame_equal(result, df) + + result = df.apply(lambda x: x + pd.Timedelta("1day")) + expected = DataFrame( + {"dt": date_range("2015-01-02", periods=3, tz="Europe/Brussels")} + ) + tm.assert_frame_equal(result, expected) + + +def test_apply_non_numpy_dtype_category(): + df = DataFrame({"dt": ["a", "b", "c", "a"]}, dtype="category") + result = df.apply(lambda x: x) + tm.assert_frame_equal(result, df) + + +def test_apply_dup_names_multi_agg(): + # GH 21063 + df = DataFrame([[0, 1], [2, 3]], columns=["a", "a"]) + expected = DataFrame([[0, 1]], columns=["a", "a"], index=["min"]) + result = df.agg(["min"]) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("op", ["apply", "agg"]) +def test_apply_nested_result_axis_1(op): + # GH 13820 + def apply_list(row): + return [2 * row["A"], 2 * row["C"], 2 * row["B"]] + + df = DataFrame(np.zeros((4, 4)), columns=list("ABCD")) + result = getattr(df, op)(apply_list, axis=1) + expected = Series( + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]] + ) + tm.assert_series_equal(result, expected) + + +def test_apply_noreduction_tzaware_object(): + # https://github.com/pandas-dev/pandas/issues/31505 + expected = DataFrame( + {"foo": [Timestamp("2020", tz="UTC")]}, dtype="datetime64[ns, UTC]" + ) + result = expected.apply(lambda x: x) + tm.assert_frame_equal(result, expected) + result = expected.apply(lambda x: x.copy()) + tm.assert_frame_equal(result, expected) + + +def test_apply_function_runs_once(): + # https://github.com/pandas-dev/pandas/issues/30815 + + df = DataFrame({"a": [1, 2, 3]}) + names = [] # Save row names function is applied to + + def reducing_function(row): + names.append(row.name) + + def non_reducing_function(row): + names.append(row.name) + return row + + for func in [reducing_function, non_reducing_function]: + del names[:] + + df.apply(func, axis=1) + assert names == list(df.index) + + +def test_apply_raw_function_runs_once(engine): + # https://github.com/pandas-dev/pandas/issues/34506 + if engine == "numba": + pytest.skip("appending to list outside of numba func is not supported") + + df = DataFrame({"a": [1, 2, 3]}) + values = [] # Save row values function is applied to + + def reducing_function(row): + values.extend(row) + + def non_reducing_function(row): + values.extend(row) + return row + + for func in [reducing_function, non_reducing_function]: + del values[:] + + df.apply(func, engine=engine, raw=True, axis=1) + assert values == list(df.a.to_list()) + + +def test_apply_with_byte_string(): + # GH 34529 + df = DataFrame(np.array([b"abcd", b"efgh"]), columns=["col"]) + expected = DataFrame(np.array([b"abcd", b"efgh"]), columns=["col"], dtype=object) + # After we make the apply we expect a dataframe just + # like the original but with the object datatype + result = df.apply(lambda x: x.astype("object")) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("val", ["asd", 12, None, np.nan]) +def test_apply_category_equalness(val): + # Check if categorical comparisons on apply, GH 21239 + df_values = ["asd", None, 12, "asd", "cde", np.nan] + df = DataFrame({"a": df_values}, dtype="category") + + result = df.a.apply(lambda x: x == val) + expected = Series( + [False if pd.isnull(x) else x == val for x in df_values], name="a" + ) + # False since behavior of NaN for categorical dtype has been changed (GH 59966) + tm.assert_series_equal(result, expected) + + +# the user has supplied an opaque UDF where +# they are transforming the input that requires +# us to infer the output + + +def test_infer_row_shape(): + # GH 17437 + # if row shape is changing, infer it + df = DataFrame(np.random.default_rng(2).random((10, 2))) + result = df.apply(np.fft.fft, axis=0).shape + assert result == (10, 2) + + result = df.apply(np.fft.rfft, axis=0).shape + assert result == (6, 2) + + +@pytest.mark.parametrize( + "ops, by_row, expected", + [ + ({"a": lambda x: x + 1}, "compat", DataFrame({"a": [2, 3]})), + ({"a": lambda x: x + 1}, False, DataFrame({"a": [2, 3]})), + ({"a": lambda x: x.sum()}, "compat", Series({"a": 3})), + ({"a": lambda x: x.sum()}, False, Series({"a": 3})), + ( + {"a": ["sum", np.sum, lambda x: x.sum()]}, + "compat", + DataFrame({"a": [3, 3, 3]}, index=["sum", "sum", ""]), + ), + ( + {"a": ["sum", np.sum, lambda x: x.sum()]}, + False, + DataFrame({"a": [3, 3, 3]}, index=["sum", "sum", ""]), + ), + ({"a": lambda x: 1}, "compat", DataFrame({"a": [1, 1]})), + ({"a": lambda x: 1}, False, Series({"a": 1})), + ], +) +def test_dictlike_lambda(ops, by_row, expected): + # GH53601 + df = DataFrame({"a": [1, 2]}) + result = df.apply(ops, by_row=by_row) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "ops", + [ + {"a": lambda x: x + 1}, + {"a": lambda x: x.sum()}, + {"a": ["sum", np.sum, lambda x: x.sum()]}, + {"a": lambda x: 1}, + ], +) +def test_dictlike_lambda_raises(ops): + # GH53601 + df = DataFrame({"a": [1, 2]}) + with pytest.raises(ValueError, match="by_row=True not allowed"): + df.apply(ops, by_row=True) + + +def test_with_dictlike_columns(): + # GH 17602 + df = DataFrame([[1, 2], [1, 2]], columns=["a", "b"]) + result = df.apply(lambda x: {"s": x["a"] + x["b"]}, axis=1) + expected = Series([{"s": 3} for t in df.itertuples()]) + tm.assert_series_equal(result, expected) + + df["tm"] = [ + Timestamp("2017-05-01 00:00:00"), + Timestamp("2017-05-02 00:00:00"), + ] + result = df.apply(lambda x: {"s": x["a"] + x["b"]}, axis=1) + tm.assert_series_equal(result, expected) + + # compose a series + result = (df["a"] + df["b"]).apply(lambda x: {"s": x}) + expected = Series([{"s": 3}, {"s": 3}]) + tm.assert_series_equal(result, expected) + + +def test_with_dictlike_columns_with_datetime(): + # GH 18775 + df = DataFrame() + df["author"] = ["X", "Y", "Z"] + df["publisher"] = ["BBC", "NBC", "N24"] + df["date"] = pd.to_datetime( + ["17-10-2010 07:15:30", "13-05-2011 08:20:35", "15-01-2013 09:09:09"], + dayfirst=True, + ) + result = df.apply(lambda x: {}, axis=1) + expected = Series([{}, {}, {}]) + tm.assert_series_equal(result, expected) + + +def test_with_dictlike_columns_with_infer(): + # GH 17602 + df = DataFrame([[1, 2], [1, 2]], columns=["a", "b"]) + result = df.apply(lambda x: {"s": x["a"] + x["b"]}, axis=1, result_type="expand") + expected = DataFrame({"s": [3, 3]}) + tm.assert_frame_equal(result, expected) + + df["tm"] = [ + Timestamp("2017-05-01 00:00:00"), + Timestamp("2017-05-02 00:00:00"), + ] + result = df.apply(lambda x: {"s": x["a"] + x["b"]}, axis=1, result_type="expand") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "ops, by_row, expected", + [ + ([lambda x: x + 1], "compat", DataFrame({("a", ""): [2, 3]})), + ([lambda x: x + 1], False, DataFrame({("a", ""): [2, 3]})), + ([lambda x: x.sum()], "compat", DataFrame({"a": [3]}, index=[""])), + ([lambda x: x.sum()], False, DataFrame({"a": [3]}, index=[""])), + ( + ["sum", np.sum, lambda x: x.sum()], + "compat", + DataFrame({"a": [3, 3, 3]}, index=["sum", "sum", ""]), + ), + ( + ["sum", np.sum, lambda x: x.sum()], + False, + DataFrame({"a": [3, 3, 3]}, index=["sum", "sum", ""]), + ), + ( + [lambda x: x + 1, lambda x: 3], + "compat", + DataFrame([[2, 3], [3, 3]], columns=[["a", "a"], ["", ""]]), + ), + ( + [lambda x: 2, lambda x: 3], + False, + DataFrame({"a": [2, 3]}, ["", ""]), + ), + ], +) +def test_listlike_lambda(ops, by_row, expected): + # GH53601 + df = DataFrame({"a": [1, 2]}) + result = df.apply(ops, by_row=by_row) + tm.assert_equal(result, expected) + + +def test_listlike_datetime_index_unsorted(): + # https://github.com/pandas-dev/pandas/pull/62843 + values = [datetime(2024, 1, 1), datetime(2024, 1, 2), datetime(2024, 1, 3)] + df = DataFrame({"a": [1, 2]}, index=[values[1], values[0]]) + result = df.apply([lambda x: x, lambda x: x.shift(freq="D")], by_row=False) + expected = DataFrame( + [[1.0, 2.0], [2.0, np.nan], [np.nan, 1.0]], + index=[values[1], values[0], values[2]], + columns=MultiIndex([["a"], [""]], codes=[[0, 0], [0, 0]]), + ) + tm.assert_frame_equal(result, expected) + + +def test_dictlike_datetime_index_unsorted(): + # https://github.com/pandas-dev/pandas/pull/62843 + values = [datetime(2024, 1, 1), datetime(2024, 1, 2), datetime(2024, 1, 3)] + df = DataFrame({"a": [1, 2], "b": [3, 4]}, index=[values[1], values[0]]) + result = df.apply( + {"a": lambda x: x, "b": lambda x: x.shift(freq="D")}, by_row=False + ) + expected = DataFrame( + { + "a": [1.0, 2.0, np.nan], + "b": [4.0, np.nan, 3.0], + }, + index=[values[1], values[0], values[2]], + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "ops", + [ + [lambda x: x + 1], + [lambda x: x.sum()], + ["sum", np.sum, lambda x: x.sum()], + [lambda x: x + 1, lambda x: 3], + ], +) +def test_listlike_lambda_raises(ops): + # GH53601 + df = DataFrame({"a": [1, 2]}) + with pytest.raises(ValueError, match="by_row=True not allowed"): + df.apply(ops, by_row=True) + + +def test_with_listlike_columns(): + # GH 17348 + df = DataFrame( + { + "a": Series(np.random.default_rng(2).standard_normal(4)), + "b": ["a", "list", "of", "words"], + "ts": date_range("2016-10-01", periods=4, freq="h"), + } + ) + + result = df[["a", "b"]].apply(tuple, axis=1) + expected = Series([t[1:] for t in df[["a", "b"]].itertuples()]) + tm.assert_series_equal(result, expected) + + result = df[["a", "ts"]].apply(tuple, axis=1) + expected = Series([t[1:] for t in df[["a", "ts"]].itertuples()]) + tm.assert_series_equal(result, expected) + + +def test_with_listlike_columns_returning_list(): + # GH 18919 + df = DataFrame({"x": Series([["a", "b"], ["q"]]), "y": Series([["z"], ["q", "t"]])}) + df.index = MultiIndex.from_tuples([("i0", "j0"), ("i1", "j1")]) + + result = df.apply(lambda row: [el for el in row["x"] if el in row["y"]], axis=1) + expected = Series([[], ["q"]], index=df.index) + tm.assert_series_equal(result, expected) + + +def test_infer_output_shape_columns(): + # GH 18573 + + df = DataFrame( + { + "number": [1.0, 2.0], + "string": ["foo", "bar"], + "datetime": [ + Timestamp("2017-11-29 03:30:00"), + Timestamp("2017-11-29 03:45:00"), + ], + } + ) + result = df.apply(lambda row: (row.number, row.string), axis=1) + expected = Series([(t.number, t.string) for t in df.itertuples()]) + tm.assert_series_equal(result, expected) + + +def test_infer_output_shape_listlike_columns(): + # GH 16353 + + df = DataFrame( + np.random.default_rng(2).standard_normal((6, 3)), columns=["A", "B", "C"] + ) + + result = df.apply(lambda x: [1, 2, 3], axis=1) + expected = Series([[1, 2, 3] for t in df.itertuples()]) + tm.assert_series_equal(result, expected) + + result = df.apply(lambda x: [1, 2], axis=1) + expected = Series([[1, 2] for t in df.itertuples()]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("val", [1, 2]) +def test_infer_output_shape_listlike_columns_np_func(val): + # GH 17970 + df = DataFrame({"a": [1, 2, 3]}, index=list("abc")) + + result = df.apply(lambda row: np.ones(val), axis=1) + expected = Series([np.ones(val) for t in df.itertuples()], index=df.index) + tm.assert_series_equal(result, expected) + + +def test_infer_output_shape_listlike_columns_with_timestamp(): + # GH 17892 + df = DataFrame( + { + "a": [ + Timestamp("2010-02-01"), + Timestamp("2010-02-04"), + Timestamp("2010-02-05"), + Timestamp("2010-02-06"), + ], + "b": [9, 5, 4, 3], + "c": [5, 3, 4, 2], + "d": [1, 2, 3, 4], + } + ) + + def fun(x): + return (1, 2) + + result = df.apply(fun, axis=1) + expected = Series([(1, 2) for t in df.itertuples()]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("lst", [[1, 2, 3], [1, 2]]) +def test_consistent_coerce_for_shapes(lst): + # we want column names to NOT be propagated + # just because the shape matches the input shape + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 3)), columns=["A", "B", "C"] + ) + + result = df.apply(lambda x: lst, axis=1) + expected = Series([lst for t in df.itertuples()]) + tm.assert_series_equal(result, expected) + + +def test_consistent_names(int_frame_const_col): + # if a Series is returned, we should use the resulting index names + df = int_frame_const_col + + result = df.apply( + lambda x: Series([1, 2, 3], index=["test", "other", "cols"]), axis=1 + ) + expected = int_frame_const_col.rename( + columns={"A": "test", "B": "other", "C": "cols"} + ) + tm.assert_frame_equal(result, expected) + + result = df.apply(lambda x: Series([1, 2], index=["test", "other"]), axis=1) + expected = expected[["test", "other"]] + tm.assert_frame_equal(result, expected) + + +def test_result_type(int_frame_const_col): + # result_type should be consistent no matter which + # path we take in the code + df = int_frame_const_col + + result = df.apply(lambda x: [1, 2, 3], axis=1, result_type="expand") + expected = df.copy() + expected.columns = range(3) + tm.assert_frame_equal(result, expected) + + +def test_result_type_shorter_list(int_frame_const_col): + # result_type should be consistent no matter which + # path we take in the code + df = int_frame_const_col + result = df.apply(lambda x: [1, 2], axis=1, result_type="expand") + expected = df[["A", "B"]].copy() + expected.columns = range(2) + tm.assert_frame_equal(result, expected) + + +def test_result_type_broadcast(int_frame_const_col, request, engine): + # result_type should be consistent no matter which + # path we take in the code + if engine == "numba": + mark = pytest.mark.xfail(reason="numba engine doesn't support list return") + request.node.add_marker(mark) + df = int_frame_const_col + if engine is MockEngineDecorator: + with pytest.raises( + NotImplementedError, + match="result_type='broadcast' only implemented for the default engine", + ): + df.apply( + lambda x: [1, 2, 3], axis=1, result_type="broadcast", engine=engine + ) + else: + # broadcast result + result = df.apply( + lambda x: [1, 2, 3], axis=1, result_type="broadcast", engine=engine + ) + expected = df.copy() + tm.assert_frame_equal(result, expected) + + +def test_result_type_broadcast_series_func(int_frame_const_col, engine, request): + # result_type should be consistent no matter which + # path we take in the code + if engine == "numba": + mark = pytest.mark.xfail( + reason="numba Series constructor only support ndarrays not list data" + ) + request.node.add_marker(mark) + df = int_frame_const_col + columns = ["other", "col", "names"] + + if engine is MockEngineDecorator: + with pytest.raises( + NotImplementedError, + match="result_type='broadcast' only implemented for the default engine", + ): + df.apply( + lambda x: Series([1, 2, 3], index=columns), + axis=1, + result_type="broadcast", + engine=engine, + ) + else: + result = df.apply( + lambda x: Series([1, 2, 3], index=columns), + axis=1, + result_type="broadcast", + engine=engine, + ) + expected = df.copy() + tm.assert_frame_equal(result, expected) + + +def test_result_type_series_result(int_frame_const_col, engine, request): + # result_type should be consistent no matter which + # path we take in the code + if engine == "numba": + mark = pytest.mark.xfail( + reason="numba Series constructor only support ndarrays not list data" + ) + request.node.add_marker(mark) + df = int_frame_const_col + # series result + result = df.apply(lambda x: Series([1, 2, 3], index=x.index), axis=1, engine=engine) + expected = df.copy() + tm.assert_frame_equal(result, expected) + + +def test_result_type_series_result_other_index(int_frame_const_col, engine, request): + # result_type should be consistent no matter which + # path we take in the code + + if engine == "numba": + mark = pytest.mark.xfail( + reason="no support in numba Series constructor for list of columns" + ) + request.node.add_marker(mark) + df = int_frame_const_col + # series result with other index + columns = ["other", "col", "names"] + result = df.apply(lambda x: Series([1, 2, 3], index=columns), axis=1, engine=engine) + expected = df.copy() + expected.columns = columns + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "box", + [lambda x: list(x), lambda x: tuple(x), lambda x: np.array(x, dtype="int64")], + ids=["list", "tuple", "array"], +) +def test_consistency_for_boxed(box, int_frame_const_col): + # passing an array or list should not affect the output shape + df = int_frame_const_col + + result = df.apply(lambda x: box([1, 2]), axis=1) + expected = Series([box([1, 2]) for t in df.itertuples()]) + tm.assert_series_equal(result, expected) + + result = df.apply(lambda x: box([1, 2]), axis=1, result_type="expand") + expected = int_frame_const_col[["A", "B"]].rename(columns={"A": 0, "B": 1}) + tm.assert_frame_equal(result, expected) + + +def test_agg_transform(axis, float_frame): + other_axis = 1 if axis in {0, "index"} else 0 + + with np.errstate(all="ignore"): + f_abs = np.abs(float_frame) + f_sqrt = np.sqrt(float_frame) + + # ufunc + expected = f_sqrt.copy() + result = float_frame.apply(np.sqrt, axis=axis) + tm.assert_frame_equal(result, expected) + + # list-like + result = float_frame.apply([np.sqrt], axis=axis) + expected = f_sqrt.copy() + if axis in {0, "index"}: + expected.columns = MultiIndex.from_product([float_frame.columns, ["sqrt"]]) + else: + expected.index = MultiIndex.from_product([float_frame.index, ["sqrt"]]) + tm.assert_frame_equal(result, expected) + + # multiple items in list + # these are in the order as if we are applying both + # functions per series and then concatting + result = float_frame.apply([np.abs, np.sqrt], axis=axis) + expected = zip_frames([f_abs, f_sqrt], axis=other_axis) + if axis in {0, "index"}: + expected.columns = MultiIndex.from_product( + [float_frame.columns, ["absolute", "sqrt"]] + ) + else: + expected.index = MultiIndex.from_product( + [float_frame.index, ["absolute", "sqrt"]] + ) + tm.assert_frame_equal(result, expected) + + +def test_demo(): + # demonstration tests + df = DataFrame({"A": range(5), "B": 5}) + + result = df.agg(["min", "max"]) + expected = DataFrame( + {"A": [0, 4], "B": [5, 5]}, columns=["A", "B"], index=["min", "max"] + ) + tm.assert_frame_equal(result, expected) + + +def test_demo_dict_agg(): + # demonstration tests + df = DataFrame({"A": range(5), "B": 5}) + result = df.agg({"A": ["min", "max"], "B": ["sum", "max"]}) + expected = DataFrame( + {"A": [4.0, 0.0, np.nan], "B": [5.0, np.nan, 25.0]}, + columns=["A", "B"], + index=["max", "min", "sum"], + ) + tm.assert_frame_equal(result.reindex_like(expected), expected) + + +def test_agg_with_name_as_column_name(): + # GH 36212 - Column name is "name" + data = {"name": ["foo", "bar"]} + df = DataFrame(data) + + # result's name should be None + result = df.agg({"name": "count"}) + expected = Series({"name": 2}) + tm.assert_series_equal(result, expected) + + # Check if name is still preserved when aggregating series instead + result = df["name"].agg({"name": "count"}) + expected = Series({"name": 2}, name="name") + tm.assert_series_equal(result, expected) + + +def test_agg_multiple_mixed(): + # GH 20909 + mdf = DataFrame( + { + "A": [1, 2, 3], + "B": [1.0, 2.0, 3.0], + "C": ["foo", "bar", "baz"], + } + ) + expected = DataFrame( + { + "A": [1, 6], + "B": [1.0, 6.0], + "C": ["bar", "foobarbaz"], + }, + index=["min", "sum"], + ) + # sorted index + result = mdf.agg(["min", "sum"]) + tm.assert_frame_equal(result, expected) + + result = mdf[["C", "B", "A"]].agg(["sum", "min"]) + # GH40420: the result of .agg should have an index that is sorted + # according to the arguments provided to agg. + expected = expected[["C", "B", "A"]].reindex(["sum", "min"]) + tm.assert_frame_equal(result, expected) + + +def test_agg_multiple_mixed_raises(): + # GH 20909 + mdf = DataFrame( + { + "A": [1, 2, 3], + "B": [1.0, 2.0, 3.0], + "C": ["foo", "bar", "baz"], + "D": date_range("20130101", periods=3), + } + ) + + # sorted index + msg = "does not support operation" + with pytest.raises(TypeError, match=msg): + mdf.agg(["min", "sum"]) + + with pytest.raises(TypeError, match=msg): + mdf[["D", "C", "B", "A"]].agg(["sum", "min"]) + + +def test_agg_reduce(axis, float_frame): + other_axis = 1 if axis in {0, "index"} else 0 + name1, name2 = float_frame.axes[other_axis].unique()[:2].sort_values() + + # all reducers + expected = pd.concat( + [ + float_frame.mean(axis=axis), + float_frame.max(axis=axis), + float_frame.sum(axis=axis), + ], + axis=1, + ) + expected.columns = ["mean", "max", "sum"] + expected = expected.T if axis in {0, "index"} else expected + + result = float_frame.agg(["mean", "max", "sum"], axis=axis) + tm.assert_frame_equal(result, expected) + + # dict input with scalars + func = {name1: "mean", name2: "sum"} + result = float_frame.agg(func, axis=axis) + expected = Series( + [ + float_frame.loc(other_axis)[name1].mean(), + float_frame.loc(other_axis)[name2].sum(), + ], + index=[name1, name2], + ) + tm.assert_series_equal(result, expected) + + # dict input with lists + func = {name1: ["mean"], name2: ["sum"]} + result = float_frame.agg(func, axis=axis) + expected = DataFrame( + { + name1: Series([float_frame.loc(other_axis)[name1].mean()], index=["mean"]), + name2: Series([float_frame.loc(other_axis)[name2].sum()], index=["sum"]), + } + ) + expected = expected.T if axis in {1, "columns"} else expected + tm.assert_frame_equal(result, expected) + + # dict input with lists with multiple + func = {name1: ["mean", "sum"], name2: ["sum", "max"]} + result = float_frame.agg(func, axis=axis) + expected = pd.concat( + { + name1: Series( + [ + float_frame.loc(other_axis)[name1].mean(), + float_frame.loc(other_axis)[name1].sum(), + ], + index=["mean", "sum"], + ), + name2: Series( + [ + float_frame.loc(other_axis)[name2].sum(), + float_frame.loc(other_axis)[name2].max(), + ], + index=["sum", "max"], + ), + }, + axis=1, + ) + expected = expected.T if axis in {1, "columns"} else expected + tm.assert_frame_equal(result, expected) + + +def test_named_agg_reduce_axis1_raises(float_frame): + name1, name2 = float_frame.axes[0].unique()[:2].sort_values() + msg = "Named aggregation is not supported when axis=1." + for axis in [1, "columns"]: + with pytest.raises(NotImplementedError, match=msg): + float_frame.agg(row1=(name1, "sum"), row2=(name2, "max"), axis=axis) + + +def test_nuiscance_columns(): + # GH 15015 + df = DataFrame( + { + "A": [1, 2, 3], + "B": [1.0, 2.0, 3.0], + "C": ["foo", "bar", "baz"], + "D": date_range("20130101", periods=3), + } + ) + + result = df.agg("min") + expected = Series([1, 1.0, "bar", Timestamp("20130101")], index=df.columns) + tm.assert_series_equal(result, expected) + + result = df.agg(["min"]) + expected = DataFrame( + [[1, 1.0, "bar", Timestamp("20130101")]], + index=["min"], + columns=df.columns, + ) + tm.assert_frame_equal(result, expected) + + msg = "does not support operation" + with pytest.raises(TypeError, match=msg): + df.agg("sum") + + result = df[["A", "B", "C"]].agg("sum") + expected = Series([6, 6.0, "foobarbaz"], index=["A", "B", "C"]) + tm.assert_series_equal(result, expected) + + msg = "does not support operation" + with pytest.raises(TypeError, match=msg): + df.agg(["sum"]) + + +@pytest.mark.parametrize("how", ["agg", "apply"]) +def test_non_callable_aggregates(how): + # GH 16405 + # 'size' is a property of frame/series + # validate that this is working + # GH 39116 - expand to apply + df = DataFrame( + {"A": [None, 2, 3], "B": [1.0, np.nan, 3.0], "C": ["foo", None, "bar"]} + ) + + # Function aggregate + result = getattr(df, how)({"A": "count"}) + expected = Series({"A": 2}) + + tm.assert_series_equal(result, expected) + + # Non-function aggregate + result = getattr(df, how)({"A": "size"}) + expected = Series({"A": 3}) + + tm.assert_series_equal(result, expected) + + # Mix function and non-function aggs + result1 = getattr(df, how)(["count", "size"]) + result2 = getattr(df, how)( + {"A": ["count", "size"], "B": ["count", "size"], "C": ["count", "size"]} + ) + expected = DataFrame( + { + "A": {"count": 2, "size": 3}, + "B": {"count": 2, "size": 3}, + "C": {"count": 2, "size": 3}, + } + ) + + tm.assert_frame_equal(result1, result2, check_like=True) + tm.assert_frame_equal(result2, expected, check_like=True) + + # Just functional string arg is same as calling df.arg() + result = getattr(df, how)("count") + expected = df.count() + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("how", ["agg", "apply"]) +def test_size_as_str(how, axis): + # GH 39934 + df = DataFrame( + {"A": [None, 2, 3], "B": [1.0, np.nan, 3.0], "C": ["foo", None, "bar"]} + ) + # Just a string attribute arg same as calling df.arg + # on the columns + result = getattr(df, how)("size", axis=axis) + if axis in (0, "index"): + expected = Series(df.shape[0], index=df.columns) + else: + expected = Series(df.shape[1], index=df.index) + tm.assert_series_equal(result, expected) + + +def test_agg_listlike_result(): + # GH-29587 user defined function returning list-likes + df = DataFrame({"A": [2, 2, 3], "B": [1.5, np.nan, 1.5], "C": ["foo", None, "bar"]}) + + def func(group_col): + return list(group_col.dropna().unique()) + + result = df.agg(func) + expected = Series([[2, 3], [1.5], ["foo", "bar"]], index=["A", "B", "C"]) + tm.assert_series_equal(result, expected) + + result = df.agg([func]) + expected = expected.to_frame("func").T + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize( + "args, kwargs", + [ + ((1, 2, 3), {}), + ((8, 7, 15), {}), + ((1, 2), {}), + ((1,), {"b": 2}), + ((), {"a": 1, "b": 2}), + ((), {"a": 2, "b": 1}), + ((), {"a": 1, "b": 2, "c": 3}), + ], +) +def test_agg_args_kwargs(axis, args, kwargs): + def f(x, a, b, c=3): + return x.sum() + (a + b) / c + + df = DataFrame([[1, 2], [3, 4]]) + + if axis == 0: + expected = Series([5.0, 7.0]) + else: + expected = Series([4.0, 8.0]) + + result = df.agg(f, axis, *args, **kwargs) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("num_cols", [2, 3, 5]) +def test_frequency_is_original(num_cols, engine, request): + # GH 22150 + if engine == "numba": + mark = pytest.mark.xfail(reason="numba engine only supports numeric indices") + request.node.add_marker(mark) + index = pd.DatetimeIndex(["1950-06-30", "1952-10-24", "1953-05-29"]) + original = index.copy() + df = DataFrame(1, index=index, columns=range(num_cols)) + df.apply(lambda x: x, engine=engine) + assert index.freq == original.freq + + +def test_apply_datetime_tz_issue(engine, request): + # GH 29052 + + if engine == "numba": + mark = pytest.mark.xfail( + reason="numba engine doesn't support non-numeric indexes" + ) + request.node.add_marker(mark) + + timestamps = [ + Timestamp("2019-03-15 12:34:31.909000+0000", tz="UTC"), + Timestamp("2019-03-15 12:34:34.359000+0000", tz="UTC"), + Timestamp("2019-03-15 12:34:34.660000+0000", tz="UTC"), + ] + df = DataFrame(data=[0, 1, 2], index=timestamps) + result = df.apply(lambda x: x.name, axis=1, engine=engine) + expected = Series(index=timestamps, data=timestamps) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("df", [DataFrame({"A": ["a", None], "B": ["c", "d"]})]) +@pytest.mark.parametrize("method", ["min", "max", "sum"]) +def test_mixed_column_raises(df, method, using_infer_string): + # GH 16832 + if method == "sum": + msg = r'can only concatenate str \(not "int"\) to str|does not support' + else: + msg = "not supported between instances of 'str' and 'float'" + if not using_infer_string: + with pytest.raises(TypeError, match=msg): + getattr(df, method)() + else: + getattr(df, method)() + + +@pytest.mark.parametrize("col", [1, 1.0, True, "a", np.nan]) +def test_apply_dtype(col): + # GH 31466 + df = DataFrame([[1.0, col]], columns=["a", "b"]) + result = df.apply(lambda x: x.dtype) + expected = df.dtypes + + tm.assert_series_equal(result, expected) + + +def test_apply_mutating(): + # GH#35462 case where applied func pins a new BlockManager to a row + df = DataFrame({"a": range(10), "b": range(10, 20)}) + df_orig = df.copy() + + def func(row): + mgr = row._mgr + row.loc["a"] += 1 + assert row._mgr is not mgr + return row + + expected = df.copy() + expected["a"] += 1 + + result = df.apply(func, axis=1) + + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(df, df_orig) + + +def test_apply_empty_list_reduce(): + # GH#35683 get columns correct + df = DataFrame([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]], columns=["a", "b"]) + + result = df.apply(lambda x: [], result_type="reduce") + expected = Series({"a": [], "b": []}, dtype=object) + tm.assert_series_equal(result, expected) + + +def test_apply_no_suffix_index(engine, request): + # GH36189 + if engine == "numba": + mark = pytest.mark.xfail( + reason="numba engine doesn't support list-likes/dict-like callables" + ) + request.node.add_marker(mark) + pdf = DataFrame([[4, 9]] * 3, columns=["A", "B"]) + result = pdf.apply(["sum", lambda x: x.sum(), lambda x: x.sum()], engine=engine) + expected = DataFrame( + {"A": [12, 12, 12], "B": [27, 27, 27]}, index=["sum", "", ""] + ) + + tm.assert_frame_equal(result, expected) + + +def test_apply_raw_returns_string(engine): + # https://github.com/pandas-dev/pandas/issues/35940 + if engine == "numba": + pytest.skip("No object dtype support in numba") + df = DataFrame({"A": ["aa", "bbb"]}) + result = df.apply(lambda x: x[0], engine=engine, axis=1, raw=True) + expected = Series(["aa", "bbb"]) + tm.assert_series_equal(result, expected) + + +def test_aggregation_func_column_order(): + # GH40420: the result of .agg should have an index that is sorted + # according to the arguments provided to agg. + df = DataFrame( + [ + (1, 0, 0), + (2, 0, 0), + (3, 0, 0), + (4, 5, 4), + (5, 6, 6), + (6, 7, 7), + ], + columns=("att1", "att2", "att3"), + ) + + def sum_div2(s): + return s.sum() / 2 + + aggs = ["sum", sum_div2, "count", "min"] + result = df.agg(aggs) + expected = DataFrame( + { + "att1": [21.0, 10.5, 6.0, 1.0], + "att2": [18.0, 9.0, 6.0, 0.0], + "att3": [17.0, 8.5, 6.0, 0.0], + }, + index=["sum", "sum_div2", "count", "min"], + ) + tm.assert_frame_equal(result, expected) + + +def test_apply_getitem_axis_1(engine, request): + # GH 13427 + if engine == "numba": + mark = pytest.mark.xfail( + reason="numba engine not supporting duplicate index values" + ) + request.node.add_marker(mark) + df = DataFrame({"a": [0, 1, 2], "b": [1, 2, 3]}) + result = df[["a", "a"]].apply( + lambda x: x.iloc[0] + x.iloc[1], axis=1, engine=engine + ) + expected = Series([0, 2, 4]) + tm.assert_series_equal(result, expected) + + +def test_nuisance_depr_passes_through_warnings(): + # GH 43740 + # DataFrame.agg with list-likes may emit warnings for both individual + # args and for entire columns, but we only want to emit once. We + # catch and suppress the warnings for individual args, but need to make + # sure if some other warnings were raised, they get passed through to + # the user. + + def expected_warning(x): + warnings.warn("Hello, World!") + return x.sum() + + df = DataFrame({"a": [1, 2, 3]}) + with tm.assert_produces_warning(UserWarning, match="Hello, World!"): + df.agg([expected_warning]) + + +def test_apply_type(): + # GH 46719 + df = DataFrame( + {"col1": [3, "string", float], "col2": [0.25, datetime(2020, 1, 1), np.nan]}, + index=["a", "b", "c"], + ) + + # axis=0 + result = df.apply(type, axis=0) + expected = Series({"col1": Series, "col2": Series}) + tm.assert_series_equal(result, expected) + + # axis=1 + result = df.apply(type, axis=1) + expected = Series({"a": Series, "b": Series, "c": Series}) + tm.assert_series_equal(result, expected) + + +def test_apply_on_empty_dataframe(engine): + # GH 39111 + df = DataFrame({"a": [1, 2], "b": [3, 0]}) + result = df.head(0).apply(lambda x: max(x["a"], x["b"]), axis=1, engine=engine) + expected = Series([], dtype=np.float64) + tm.assert_series_equal(result, expected) + + +def test_apply_return_list(): + df = DataFrame({"a": [1, 2], "b": [2, 3]}) + result = df.apply(lambda x: [x.values]) + expected = DataFrame({"a": [[1, 2]], "b": [[2, 3]]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "test, constant", + [ + ({"a": [1, 2, 3], "b": [1, 1, 1]}, {"a": [1, 2, 3], "b": [1]}), + ({"a": [2, 2, 2], "b": [1, 1, 1]}, {"a": [2], "b": [1]}), + ], +) +def test_unique_agg_type_is_series(test, constant): + # GH#22558 + df1 = DataFrame(test) + expected = Series(data=constant, index=["a", "b"], dtype="object") + aggregation = {"a": "unique", "b": "unique"} + + result = df1.agg(aggregation) + + tm.assert_series_equal(result, expected) + + +def test_any_apply_keyword_non_zero_axis_regression(): + # https://github.com/pandas-dev/pandas/issues/48656 + df = DataFrame({"A": [1, 2, 0], "B": [0, 2, 0], "C": [0, 0, 0]}) + expected = Series([True, True, False]) + tm.assert_series_equal(df.any(axis=1), expected) + + result = df.apply("any", axis=1) + tm.assert_series_equal(result, expected) + + result = df.apply("any", 1) + tm.assert_series_equal(result, expected) + + +def test_agg_mapping_func_deprecated(): + # GH 53325 + df = DataFrame({"x": [1, 2, 3]}) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + # single func already takes the vectorized path + result = df.agg(foo1, 0, 3, c=4) + expected = df + 7 + tm.assert_frame_equal(result, expected) + + result = df.agg([foo1, foo2], 0, 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], columns=[["x", "x"], ["foo1", "foo2"]] + ) + tm.assert_frame_equal(result, expected) + + # TODO: the result below is wrong, should be fixed (GH53325) + result = df.agg({"x": foo1}, 0, 3, c=4) + expected = DataFrame([2, 3, 4], columns=["x"]) + tm.assert_frame_equal(result, expected) + + +def test_agg_std(): + df = DataFrame(np.arange(6).reshape(3, 2), columns=["A", "B"]) + + result = df.agg(np.std, ddof=1) + expected = Series({"A": 2.0, "B": 2.0}, dtype=float) + tm.assert_series_equal(result, expected) + + result = df.agg([np.std], ddof=1) + expected = DataFrame({"A": 2.0, "B": 2.0}, index=["std"]) + tm.assert_frame_equal(result, expected) + + +def test_agg_np_size(): + # GH#42203, GH#48328 + df = DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]], columns=["A", "B", "C"]) + + result = df.agg({"A": [np.size]}) + expected = DataFrame({"A": [3]}, index=["size"]) + tm.assert_frame_equal(result, expected) + + result = df.agg({"A": np.size}) + expected = Series({"A": 3}) + tm.assert_series_equal(result, expected) + + result = df.agg({"A": [np.mean, np.size]}) + expected = DataFrame({"A": [4.0, 3.0]}, index=["mean", "size"]) + tm.assert_frame_equal(result, expected) + + +def test_agg_dist_like_and_nonunique_columns(): + # GH#51099 + df = DataFrame( + {"A": [None, 2, 3], "B": [1.0, np.nan, 3.0], "C": ["foo", None, "bar"]} + ) + df.columns = ["A", "A", "C"] + + result = df.agg({"A": "count"}) + expected = df["A"].count() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("engine_name", ["unknown", 25]) +def test_wrong_engine(engine_name): + with pytest.raises(ValueError, match="Unknown engine "): + DataFrame().apply(lambda x: x, engine=engine_name) diff --git a/pandas/tests/apply/test_frame_apply_relabeling.py b/pandas/tests/apply/test_frame_apply_relabeling.py new file mode 100644 index 0000000000000000000000000000000000000000..86918ec09aa97d7db9af0a8655e3273a53b7aad0 --- /dev/null +++ b/pandas/tests/apply/test_frame_apply_relabeling.py @@ -0,0 +1,105 @@ +import numpy as np + +import pandas as pd +import pandas._testing as tm + + +def test_agg_relabel(): + # GH 26513 + df = pd.DataFrame({"A": [1, 2, 1, 2], "B": [1, 2, 3, 4], "C": [3, 4, 5, 6]}) + + # simplest case with one column, one func + result = df.agg(foo=("B", "sum")) + expected = pd.DataFrame({"B": [10]}, index=pd.Index(["foo"])) + tm.assert_frame_equal(result, expected) + + # test on same column with different methods + result = df.agg(foo=("B", "sum"), bar=("B", "min")) + expected = pd.DataFrame({"B": [10, 1]}, index=pd.Index(["foo", "bar"])) + + tm.assert_frame_equal(result, expected) + + +def test_agg_relabel_multi_columns_multi_methods(): + # GH 26513, test on multiple columns with multiple methods + df = pd.DataFrame({"A": [1, 2, 1, 2], "B": [1, 2, 3, 4], "C": [3, 4, 5, 6]}) + result = df.agg( + foo=("A", "sum"), + bar=("B", "mean"), + cat=("A", "min"), + dat=("B", "max"), + f=("A", "max"), + g=("C", "min"), + ) + expected = pd.DataFrame( + { + "A": [6.0, np.nan, 1.0, np.nan, 2.0, np.nan], + "B": [np.nan, 2.5, np.nan, 4.0, np.nan, np.nan], + "C": [np.nan, np.nan, np.nan, np.nan, np.nan, 3.0], + }, + index=pd.Index(["foo", "bar", "cat", "dat", "f", "g"]), + ) + tm.assert_frame_equal(result, expected) + + +def test_agg_relabel_partial_functions(): + # GH 26513, test on partial, functools or more complex cases + df = pd.DataFrame({"A": [1, 2, 1, 2], "B": [1, 2, 3, 4], "C": [3, 4, 5, 6]}) + result = df.agg(foo=("A", np.mean), bar=("A", "mean"), cat=("A", min)) + expected = pd.DataFrame( + {"A": [1.5, 1.5, 1.0]}, index=pd.Index(["foo", "bar", "cat"]) + ) + tm.assert_frame_equal(result, expected) + + result = df.agg( + foo=("A", min), + bar=("B", np.min), + cat=("B", max), + dat=("C", "min"), + f=("B", np.sum), + kk=("B", lambda x: min(x)), + ) + expected = pd.DataFrame( + { + "A": [1.0, np.nan, np.nan, np.nan, np.nan, np.nan], + "B": [np.nan, 1.0, 4.0, np.nan, 10.0, 1.0], + "C": [np.nan, np.nan, np.nan, 3.0, np.nan, np.nan], + }, + index=pd.Index(["foo", "bar", "cat", "dat", "f", "kk"]), + ) + tm.assert_frame_equal(result, expected) + + +def test_agg_namedtuple(): + # GH 26513 + df = pd.DataFrame({"A": [0, 1], "B": [1, 2]}) + result = df.agg( + foo=pd.NamedAgg("B", "sum"), + bar=pd.NamedAgg("B", "min"), + cat=pd.NamedAgg(column="B", aggfunc="count"), + fft=pd.NamedAgg("B", aggfunc="max"), + ) + + expected = pd.DataFrame( + {"B": [3, 1, 2, 2]}, index=pd.Index(["foo", "bar", "cat", "fft"]) + ) + tm.assert_frame_equal(result, expected) + + result = df.agg( + foo=pd.NamedAgg("A", "min"), + bar=pd.NamedAgg(column="B", aggfunc="max"), + cat=pd.NamedAgg(column="A", aggfunc="max"), + ) + expected = pd.DataFrame( + {"A": [0.0, np.nan, 1.0], "B": [np.nan, 2.0, np.nan]}, + index=pd.Index(["foo", "bar", "cat"]), + ) + tm.assert_frame_equal(result, expected) + + +def test_reconstruct_func(): + # GH 28472, test to ensure reconstruct_func isn't moved; + # This method is used by other libraries (e.g. dask) + result = pd.core.apply.reconstruct_func("min") + expected = (False, "min", None, None) + tm.assert_equal(result, expected) diff --git a/pandas/tests/apply/test_frame_transform.py b/pandas/tests/apply/test_frame_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..558d76ae8fdc4b95d46bbe94e15822779bd7c53f --- /dev/null +++ b/pandas/tests/apply/test_frame_transform.py @@ -0,0 +1,264 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + MultiIndex, + Series, +) +import pandas._testing as tm +from pandas.tests.apply.common import frame_transform_kernels +from pandas.tests.frame.common import zip_frames + + +def unpack_obj(obj, klass, axis): + """ + Helper to ensure we have the right type of object for a test parametrized + over frame_or_series. + """ + if klass is not DataFrame: + obj = obj["A"] + if axis != 0: + pytest.skip(f"Test is only for DataFrame with axis={axis}") + return obj + + +def test_transform_ufunc(axis, float_frame, frame_or_series): + # GH 35964 + obj = unpack_obj(float_frame, frame_or_series, axis) + + with np.errstate(all="ignore"): + f_sqrt = np.sqrt(obj) + + # ufunc + result = obj.transform(np.sqrt, axis=axis) + expected = f_sqrt + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "ops, names", + [ + ([np.sqrt], ["sqrt"]), + ([np.abs, np.sqrt], ["absolute", "sqrt"]), + (np.array([np.sqrt]), ["sqrt"]), + (np.array([np.abs, np.sqrt]), ["absolute", "sqrt"]), + ], +) +def test_transform_listlike(axis, float_frame, ops, names): + # GH 35964 + other_axis = 1 if axis in {0, "index"} else 0 + with np.errstate(all="ignore"): + expected = zip_frames([op(float_frame) for op in ops], axis=other_axis) + if axis in {0, "index"}: + expected.columns = MultiIndex.from_product([float_frame.columns, names]) + else: + expected.index = MultiIndex.from_product([float_frame.index, names]) + result = float_frame.transform(ops, axis=axis) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("ops", [[], np.array([])]) +def test_transform_empty_listlike(float_frame, ops, frame_or_series): + obj = unpack_obj(float_frame, frame_or_series, 0) + + with pytest.raises(ValueError, match="No transform functions were provided"): + obj.transform(ops) + + +def test_transform_listlike_func_with_args(): + # GH 50624 + df = DataFrame({"x": [1, 2, 3]}) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + df.transform([foo1, foo2], 0, 3, b=3, c=4) + + result = df.transform([foo1, foo2], 0, 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], + columns=MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("box", [dict, Series]) +def test_transform_dictlike(axis, float_frame, box): + # GH 35964 + if axis in (0, "index"): + e = float_frame.columns[0] + expected = float_frame[[e]].transform(np.abs) + else: + e = float_frame.index[0] + expected = float_frame.iloc[[0]].transform(np.abs) + result = float_frame.transform(box({e: np.abs}), axis=axis) + tm.assert_frame_equal(result, expected) + + +def test_transform_dictlike_mixed(): + # GH 40018 - mix of lists and non-lists in values of a dictionary + df = DataFrame({"a": [1, 2], "b": [1, 4], "c": [1, 4]}) + result = df.transform({"b": ["sqrt", "abs"], "c": "sqrt"}) + expected = DataFrame( + [[1.0, 1, 1.0], [2.0, 4, 2.0]], + columns=MultiIndex([("b", "c"), ("sqrt", "abs")], [(0, 0, 1), (0, 1, 0)]), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "ops", + [ + {}, + {"A": []}, + {"A": [], "B": "cumsum"}, + {"A": "cumsum", "B": []}, + {"A": [], "B": ["cumsum"]}, + {"A": ["cumsum"], "B": []}, + ], +) +def test_transform_empty_dictlike(float_frame, ops, frame_or_series): + obj = unpack_obj(float_frame, frame_or_series, 0) + + with pytest.raises(ValueError, match="No transform functions were provided"): + obj.transform(ops) + + +@pytest.mark.parametrize("use_apply", [True, False]) +def test_transform_udf(axis, float_frame, use_apply, frame_or_series): + # GH 35964 + obj = unpack_obj(float_frame, frame_or_series, axis) + + # transform uses UDF either via apply or passing the entire DataFrame + def func(x): + # transform is using apply iff x is not a DataFrame + if use_apply == isinstance(x, frame_or_series): + # Force transform to fallback + raise ValueError + return x + 1 + + result = obj.transform(func, axis=axis) + expected = obj + 1 + tm.assert_equal(result, expected) + + +wont_fail = ["ffill", "bfill", "fillna", "pad", "backfill", "shift"] +frame_kernels_raise = [x for x in frame_transform_kernels if x not in wont_fail] + + +@pytest.mark.parametrize("op", [*frame_kernels_raise, lambda x: x + 1]) +def test_transform_bad_dtype(op, frame_or_series, request): + # GH 35964 + if op == "ngroup": + request.applymarker( + pytest.mark.xfail(raises=ValueError, reason="ngroup not valid for NDFrame") + ) + + obj = DataFrame({"A": 3 * [object]}) # DataFrame that will fail on most transforms + obj = tm.get_obj(obj, frame_or_series) + error = TypeError + msg = "|".join( + [ + "not supported between instances of 'type' and 'type'", + "unsupported operand type", + ] + ) + + with pytest.raises(error, match=msg): + obj.transform(op) + with pytest.raises(error, match=msg): + obj.transform([op]) + with pytest.raises(error, match=msg): + obj.transform({"A": op}) + with pytest.raises(error, match=msg): + obj.transform({"A": [op]}) + + +@pytest.mark.parametrize("op", frame_kernels_raise) +def test_transform_failure_typeerror(request, op): + # GH 35964 + + if op == "ngroup": + request.applymarker( + pytest.mark.xfail(raises=ValueError, reason="ngroup not valid for NDFrame") + ) + + # Using object makes most transform kernels fail + df = DataFrame({"A": 3 * [object], "B": [1, 2, 3]}) + error = TypeError + msg = "|".join( + [ + "not supported between instances of 'type' and 'type'", + "unsupported operand type", + ] + ) + + with pytest.raises(error, match=msg): + df.transform([op]) + + with pytest.raises(error, match=msg): + df.transform({"A": op, "B": op}) + + with pytest.raises(error, match=msg): + df.transform({"A": [op], "B": [op]}) + + with pytest.raises(error, match=msg): + df.transform({"A": [op, "shift"], "B": [op]}) + + +def test_transform_failure_valueerror(): + # GH 40211 + def op(x): + if np.sum(np.sum(x)) < 10: + raise ValueError + return x + + df = DataFrame({"A": [1, 2, 3], "B": [400, 500, 600]}) + msg = "Transform function failed" + + with pytest.raises(ValueError, match=msg): + df.transform([op]) + + with pytest.raises(ValueError, match=msg): + df.transform({"A": op, "B": op}) + + with pytest.raises(ValueError, match=msg): + df.transform({"A": [op], "B": [op]}) + + with pytest.raises(ValueError, match=msg): + df.transform({"A": [op, "shift"], "B": [op]}) + + +@pytest.mark.parametrize("use_apply", [True, False]) +def test_transform_passes_args(use_apply, frame_or_series): + # GH 35964 + # transform uses UDF either via apply or passing the entire DataFrame + expected_args = [1, 2] + expected_kwargs = {"c": 3} + + def f(x, a, b, c): + # transform is using apply iff x is not a DataFrame + if use_apply == isinstance(x, frame_or_series): + # Force transform to fallback + raise ValueError + assert [a, b] == expected_args + assert c == expected_kwargs["c"] + return x + + frame_or_series([1]).transform(f, 0, *expected_args, **expected_kwargs) + + +def test_transform_empty_dataframe(): + # https://github.com/pandas-dev/pandas/issues/39636 + df = DataFrame([], columns=["col1", "col2"]) + result = df.transform(lambda x: x + 10) + tm.assert_frame_equal(result, df) + + result = df["col1"].transform(lambda x: x + 10) + tm.assert_series_equal(result, df["col1"]) diff --git a/pandas/tests/apply/test_invalid_arg.py b/pandas/tests/apply/test_invalid_arg.py new file mode 100644 index 0000000000000000000000000000000000000000..0503bf9166ec7b6c06edf95293cc286140787d60 --- /dev/null +++ b/pandas/tests/apply/test_invalid_arg.py @@ -0,0 +1,375 @@ +# Tests specifically aimed at detecting bad arguments. +# This file is organized by reason for exception. +# 1. always invalid argument values +# 2. missing column(s) +# 3. incompatible ops/dtype/args/kwargs +# 4. invalid result shape/type +# If your test does not fit into one of these categories, add to this list. + +from itertools import chain +import re + +import numpy as np +import pytest + +from pandas.errors import SpecificationError + +from pandas import ( + DataFrame, + Series, + date_range, +) +import pandas._testing as tm + + +@pytest.mark.parametrize("result_type", ["foo", 1]) +def test_result_type_error(result_type): + # allowed result_type + df = DataFrame( + np.tile(np.arange(3, dtype="int64"), 6).reshape(6, -1) + 1, + columns=["A", "B", "C"], + ) + + msg = ( + "invalid value for result_type, must be one of " + "{None, 'reduce', 'broadcast', 'expand'}" + ) + with pytest.raises(ValueError, match=msg): + df.apply(lambda x: [1, 2, 3], axis=1, result_type=result_type) + + +def test_apply_invalid_axis_value(): + df = DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]], index=["a", "a", "c"]) + msg = "No axis named 2 for object type DataFrame" + with pytest.raises(ValueError, match=msg): + df.apply(lambda x: x, 2) + + +def test_agg_raises(): + # GH 26513 + df = DataFrame({"A": [0, 1], "B": [1, 2]}) + msg = "Must provide" + + with pytest.raises(TypeError, match=msg): + df.agg() + + +def test_map_with_invalid_na_action_raises(): + # https://github.com/pandas-dev/pandas/issues/32815 + s = Series([1, 2, 3]) + msg = "na_action must either be 'ignore' or None" + with pytest.raises(ValueError, match=msg): + s.map(lambda x: x, na_action="____") + + +@pytest.mark.parametrize("input_na_action", ["____", True]) +def test_map_arg_is_dict_with_invalid_na_action_raises(input_na_action): + # https://github.com/pandas-dev/pandas/issues/46588 + s = Series([1, 2, 3]) + msg = f"na_action must either be 'ignore' or None, {input_na_action} was passed" + with pytest.raises(ValueError, match=msg): + s.map({1: 2}, na_action=input_na_action) + + +@pytest.mark.parametrize("method", ["apply", "agg", "transform"]) +@pytest.mark.parametrize("func", [{"A": {"B": "sum"}}, {"A": {"B": ["sum"]}}]) +def test_nested_renamer(frame_or_series, method, func): + # GH 35964 + obj = frame_or_series({"A": [1]}) + match = "nested renamer is not supported" + with pytest.raises(SpecificationError, match=match): + getattr(obj, method)(func) + + +@pytest.mark.parametrize( + "renamer", + [{"foo": ["min", "max"]}, {"foo": ["min", "max"], "bar": ["sum", "mean"]}], +) +def test_series_nested_renamer(renamer): + s = Series(range(6), dtype="int64", name="series") + msg = "nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + s.agg(renamer) + + +def test_apply_dict_depr(): + tsdf = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), + columns=["A", "B", "C"], + index=date_range("1/1/2000", periods=10), + ) + msg = "nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + tsdf.A.agg({"foo": ["sum", "mean"]}) + + +@pytest.mark.parametrize("method", ["agg", "transform"]) +def test_dict_nested_renaming_depr(method): + df = DataFrame({"A": range(5), "B": 5}) + + # nested renaming + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + getattr(df, method)({"A": {"foo": "min"}, "B": {"bar": "max"}}) + + +@pytest.mark.parametrize("method", ["apply", "agg", "transform"]) +@pytest.mark.parametrize("func", [{"B": "sum"}, {"B": ["sum"]}]) +def test_missing_column(method, func): + # GH 40004 + obj = DataFrame({"A": [1]}) + msg = r"Label\(s\) \['B'\] do not exist" + with pytest.raises(KeyError, match=msg): + getattr(obj, method)(func) + + +def test_transform_mixed_column_name_dtypes(): + # GH39025 + df = DataFrame({"a": ["1"]}) + msg = r"Label\(s\) \[1, 'b'\] do not exist" + with pytest.raises(KeyError, match=msg): + df.transform({"a": int, 1: str, "b": int}) + + +@pytest.mark.parametrize( + "how, args", [("pct_change", ()), ("nsmallest", (1, ["a", "b"])), ("tail", 1)] +) +def test_apply_str_axis_1_raises(how, args): + # GH 39211 - some ops don't support axis=1 + df = DataFrame({"a": [1, 2], "b": [3, 4]}) + msg = f"Operation {how} does not support axis=1" + with pytest.raises(ValueError, match=msg): + df.apply(how, axis=1, args=args) + + +def test_transform_axis_1_raises(): + # GH 35964 + msg = "No axis named 1 for object type Series" + with pytest.raises(ValueError, match=msg): + Series([1]).transform("sum", axis=1) + + +def test_apply_modify_traceback(): + data = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + data.loc[4, "C"] = np.nan + + def transform(row): + if row["C"].startswith("shin") and row["A"] == "foo": + row["D"] = 7 + return row + + msg = "'float' object has no attribute 'startswith'" + with pytest.raises(AttributeError, match=msg): + data.apply(transform, axis=1) + + +@pytest.mark.parametrize( + "df, func, expected", + tm.get_cython_table_params( + DataFrame([["a", "b"], ["b", "a"]]), [["cumprod", TypeError]] + ), +) +def test_agg_cython_table_raises_frame(df, func, expected, axis, using_infer_string): + # GH 21224 + if using_infer_string: + expected = (expected, NotImplementedError) + + msg = ( + "can't multiply sequence by non-int of type 'str'" + "|cannot perform cumprod with type str" # NotImplementedError python backend + "|operation 'cumprod' not supported for dtype 'str'" # TypeError pyarrow + ) + warn = None if isinstance(func, str) else FutureWarning + with pytest.raises(expected, match=msg): + with tm.assert_produces_warning(warn, match="using DataFrame.cumprod"): + df.agg(func, axis=axis) + + +@pytest.mark.parametrize( + "series, func, expected", + chain( + tm.get_cython_table_params( + Series("a b c".split()), + [ + ("mean", TypeError), # mean raises TypeError + ("prod", TypeError), + ("std", TypeError), + ("var", TypeError), + ("median", TypeError), + ("cumprod", TypeError), + ], + ) + ), +) +def test_agg_cython_table_raises_series(series, func, expected, using_infer_string): + # GH21224 + msg = r"[Cc]ould not convert|can't multiply sequence by non-int of type" + if func == "median" or func is np.nanmedian or func is np.median: + msg = r"Cannot convert \['a' 'b' 'c'\] to numeric" + + if using_infer_string and func == "cumprod": + expected = (expected, NotImplementedError) + + msg = ( + msg + "|does not support|has no kernel|Cannot perform|cannot perform|operation" + ) + warn = None if isinstance(func, str) else FutureWarning + + with pytest.raises(expected, match=msg): + # e.g. Series('a b'.split()).cumprod() will raise + with tm.assert_produces_warning(warn, match="is currently using Series.*"): + series.agg(func) + + +def test_agg_none_to_type(): + # GH 40543 + df = DataFrame({"a": [None]}) + msg = re.escape("int() argument must be a string") + with pytest.raises(TypeError, match=msg): + df.agg({"a": lambda x: int(x.iloc[0])}) + + +def test_transform_none_to_type(): + # GH#34377 + df = DataFrame({"a": [None]}) + msg = "argument must be a" + with pytest.raises(TypeError, match=msg): + df.transform({"a": lambda x: int(x.iloc[0])}) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: np.array([1, 2]).reshape(-1, 2), + lambda x: [1, 2], + lambda x: Series([1, 2]), + ], +) +def test_apply_broadcast_error(func): + df = DataFrame( + np.tile(np.arange(3, dtype="int64"), 6).reshape(6, -1) + 1, + columns=["A", "B", "C"], + ) + + # > 1 ndim + msg = "too many dims to broadcast|cannot broadcast result" + with pytest.raises(ValueError, match=msg): + df.apply(func, axis=1, result_type="broadcast") + + +def test_transform_and_agg_err_agg(axis, float_frame): + # cannot both transform and agg + msg = "cannot combine transform and aggregation operations" + with pytest.raises(ValueError, match=msg): + with np.errstate(all="ignore"): + float_frame.agg(["max", "sqrt"], axis=axis) + + +@pytest.mark.filterwarnings("ignore::FutureWarning") # GH53325 +@pytest.mark.parametrize( + "func, msg", + [ + (["sqrt", "max"], "cannot combine transform and aggregation"), + ( + {"foo": np.sqrt, "bar": "sum"}, + "cannot perform both aggregation and transformation", + ), + ], +) +def test_transform_and_agg_err_series(string_series, func, msg): + # we are trying to transform with an aggregator + with pytest.raises(ValueError, match=msg): + with np.errstate(all="ignore"): + string_series.agg(func) + + +@pytest.mark.parametrize("func", [["max", "min"], ["max", "sqrt"]]) +def test_transform_wont_agg_frame(axis, float_frame, func): + # GH 35964 + # cannot both transform and agg + msg = "Function did not transform" + with pytest.raises(ValueError, match=msg): + float_frame.transform(func, axis=axis) + + +@pytest.mark.parametrize("func", [["min", "max"], ["sqrt", "max"]]) +def test_transform_wont_agg_series(string_series, func): + # GH 35964 + # we are trying to transform with an aggregator + msg = "Function did not transform" + + with pytest.raises(ValueError, match=msg): + string_series.transform(func) + + +@pytest.mark.parametrize( + "op_wrapper", [lambda x: x, lambda x: [x], lambda x: {"A": x}, lambda x: {"A": [x]}] +) +def test_transform_reducer_raises(all_reductions, frame_or_series, op_wrapper): + # GH 35964 + op = op_wrapper(all_reductions) + + obj = DataFrame({"A": [1, 2, 3]}) + obj = tm.get_obj(obj, frame_or_series) + + msg = "Function did not transform" + with pytest.raises(ValueError, match=msg): + obj.transform(op) + + +def test_transform_missing_labels_raises(): + # GH 58474 + df = DataFrame({"foo": [2, 4, 6], "bar": [1, 2, 3]}, index=["A", "B", "C"]) + msg = r"Label\(s\) \['A', 'B'\] do not exist" + with pytest.raises(KeyError, match=msg): + df.transform({"A": lambda x: x + 2, "B": lambda x: x * 2}, axis=0) + + msg = r"Label\(s\) \['bar', 'foo'\] do not exist" + with pytest.raises(KeyError, match=msg): + df.transform({"foo": lambda x: x + 2, "bar": lambda x: x * 2}, axis=1) diff --git a/pandas/tests/apply/test_numba.py b/pandas/tests/apply/test_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..75bc3f5b74b9deff5587a6c0b0a3c25a266f9a1e --- /dev/null +++ b/pandas/tests/apply/test_numba.py @@ -0,0 +1,129 @@ +import numpy as np +import pytest + +from pandas.compat import is_platform_arm +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, +) +import pandas._testing as tm +from pandas.util.version import Version + +pytestmark = [td.skip_if_no("numba"), pytest.mark.single_cpu, pytest.mark.skipif()] + +numba = pytest.importorskip("numba") +pytestmark.append( + pytest.mark.skipif( + Version(numba.__version__) == Version("0.61") and is_platform_arm(), + reason=f"Segfaults on ARM platforms with numba {numba.__version__}", + ) +) + + +@pytest.fixture(params=[0, 1]) +def apply_axis(request): + return request.param + + +def test_numba_vs_python_noop(float_frame, apply_axis): + func = lambda x: x + result = float_frame.apply(func, engine="numba", axis=apply_axis) + expected = float_frame.apply(func, engine="python", axis=apply_axis) + tm.assert_frame_equal(result, expected) + + +def test_numba_vs_python_string_index(): + # GH#56189 + df = DataFrame( + 1, + index=Index(["a", "b"], dtype=pd.StringDtype(na_value=np.nan)), + columns=Index(["x", "y"], dtype=pd.StringDtype(na_value=np.nan)), + ) + func = lambda x: x + result = df.apply(func, engine="numba", axis=0) + expected = df.apply(func, engine="python", axis=0) + tm.assert_frame_equal( + result, expected, check_column_type=False, check_index_type=False + ) + + +def test_numba_vs_python_indexing(): + frame = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7.0, 8.0, 9.0]}, + index=Index(["A", "B", "C"]), + ) + row_func = lambda x: x["c"] + result = frame.apply(row_func, engine="numba", axis=1) + expected = frame.apply(row_func, engine="python", axis=1) + tm.assert_series_equal(result, expected) + + col_func = lambda x: x["A"] + result = frame.apply(col_func, engine="numba", axis=0) + expected = frame.apply(col_func, engine="python", axis=0) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "reduction", + [lambda x: x.mean(), lambda x: x.min(), lambda x: x.max(), lambda x: x.sum()], +) +def test_numba_vs_python_reductions(reduction, apply_axis): + df = DataFrame(np.ones((4, 4), dtype=np.float64)) + result = df.apply(reduction, engine="numba", axis=apply_axis) + expected = df.apply(reduction, engine="python", axis=apply_axis) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("colnames", [[1, 2, 3], [1.0, 2.0, 3.0]]) +def test_numba_numeric_colnames(colnames): + # Check that numeric column names lower properly and can be indexed on + df = DataFrame( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int64), columns=colnames + ) + first_col = colnames[0] + f = lambda x: x[first_col] # Get the first column + result = df.apply(f, engine="numba", axis=1) + expected = df.apply(f, engine="python", axis=1) + tm.assert_series_equal(result, expected) + + +def test_numba_parallel_unsupported(float_frame): + f = lambda x: x + with pytest.raises( + NotImplementedError, + match="Parallel apply is not supported when raw=False and engine='numba'", + ): + float_frame.apply(f, engine="numba", engine_kwargs={"parallel": True}) + + +def test_numba_nonunique_unsupported(apply_axis): + f = lambda x: x + df = DataFrame({"a": [1, 2]}, index=Index(["a", "a"])) + with pytest.raises( + NotImplementedError, + match="The index/columns must be unique when raw=False and engine='numba'", + ): + df.apply(f, engine="numba", axis=apply_axis) + + +def test_numba_unsupported_dtypes(apply_axis): + pytest.importorskip("pyarrow") + f = lambda x: x + df = DataFrame({"a": [1, 2], "b": ["a", "b"], "c": [4, 5]}) + df["c"] = df["c"].astype("double[pyarrow]") + + with pytest.raises( + ValueError, + match="Column b must have a numeric dtype. Found 'object|str' instead", + ): + df.apply(f, engine="numba", axis=apply_axis) + + with pytest.raises( + ValueError, + match="Column c is backed by an extension array, " + "which is not supported by the numba engine.", + ): + df["c"].to_frame().apply(f, engine="numba", axis=apply_axis) diff --git a/pandas/tests/apply/test_series_apply.py b/pandas/tests/apply/test_series_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..cea6fb793c0c7b4687bbacc6b57b5e13dd7a2aee --- /dev/null +++ b/pandas/tests/apply/test_series_apply.py @@ -0,0 +1,669 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + concat, + date_range, + timedelta_range, +) +import pandas._testing as tm +from pandas.tests.apply.common import series_transform_kernels + + +@pytest.fixture(params=[False, "compat"]) +def by_row(request): + return request.param + + +def test_series_map_box_timedelta(by_row): + # GH#11349 + ser = Series(timedelta_range("1 day 1 s", periods=3, freq="h")) + + def f(x): + return x.total_seconds() if by_row else x.dt.total_seconds() + + result = ser.apply(f, by_row=by_row) + + expected = ser.map(lambda x: x.total_seconds()) + tm.assert_series_equal(result, expected) + + expected = Series([86401.0, 90001.0, 93601.0]) + tm.assert_series_equal(result, expected) + + +def test_apply(datetime_series, by_row): + result = datetime_series.apply(np.sqrt, by_row=by_row) + with np.errstate(all="ignore"): + expected = np.sqrt(datetime_series) + tm.assert_series_equal(result, expected) + + # element-wise apply (ufunc) + result = datetime_series.apply(np.exp, by_row=by_row) + expected = np.exp(datetime_series) + tm.assert_series_equal(result, expected) + + # empty series + s = Series(dtype=object, name="foo", index=Index([], name="bar")) + rs = s.apply(lambda x: x, by_row=by_row) + tm.assert_series_equal(s, rs) + + # check all metadata (GH 9322) + assert s is not rs + assert s.index is rs.index + assert s.dtype == rs.dtype + assert s.name == rs.name + + # index but no data + s = Series(index=[1, 2, 3], dtype=np.float64) + rs = s.apply(lambda x: x, by_row=by_row) + tm.assert_series_equal(s, rs) + + +def test_apply_map_same_length_inference_bug(): + s = Series([1, 2]) + + def f(x): + return (x, x + 1) + + result = s.apply(f, by_row="compat") + expected = s.map(f) + tm.assert_series_equal(result, expected) + + +def test_apply_args(): + s = Series(["foo,bar"]) + + result = s.apply(str.split, args=(",",)) + assert result[0] == ["foo", "bar"] + assert isinstance(result[0], list) + + +@pytest.mark.parametrize( + "args, kwargs, increment", + [((), {}, 0), ((), {"a": 1}, 1), ((2, 3), {}, 32), ((1,), {"c": 2}, 201)], +) +def test_agg_args(args, kwargs, increment): + # GH 43357 + def f(x, a=0, b=0, c=0): + return x + a + 10 * b + 100 * c + + s = Series([1, 2]) + result = s.agg(f, 0, *args, **kwargs) + expected = s + increment + tm.assert_series_equal(result, expected) + + +def test_agg_mapping_func_deprecated(): + # GH 53325 + s = Series([1, 2, 3]) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + s.agg(foo1, 0, 3, c=4) + s.agg([foo1, foo2], 0, 3, c=4) + s.agg({"a": foo1, "b": foo2}, 0, 3, c=4) + + +def test_series_apply_map_box_timestamps(by_row): + # GH#2689, GH#2627 + ser = Series(date_range("1/1/2000", periods=10)) + + def func(x): + return (x.hour, x.day, x.month) + + if not by_row: + msg = "Series' object has no attribute 'hour'" + with pytest.raises(AttributeError, match=msg): + ser.apply(func, by_row=by_row) + return + + result = ser.apply(func, by_row=by_row) + expected = ser.map(func) + tm.assert_series_equal(result, expected) + + +def test_apply_box_dt64(): + # ufunc will not be boxed. Same test cases as the test_map_box + vals = [pd.Timestamp("2011-01-01"), pd.Timestamp("2011-01-02")] + ser = Series(vals, dtype="M8[ns]") + assert ser.dtype == "datetime64[ns]" + # boxed value must be Timestamp instance + res = ser.apply(lambda x: f"{type(x).__name__}_{x.day}_{x.tz}", by_row="compat") + exp = Series(["Timestamp_1_None", "Timestamp_2_None"]) + tm.assert_series_equal(res, exp) + + +def test_apply_box_dt64tz(): + vals = [ + pd.Timestamp("2011-01-01", tz="US/Eastern"), + pd.Timestamp("2011-01-02", tz="US/Eastern"), + ] + ser = Series(vals, dtype="M8[ns, US/Eastern]") + assert ser.dtype == "datetime64[ns, US/Eastern]" + res = ser.apply(lambda x: f"{type(x).__name__}_{x.day}_{x.tz}", by_row="compat") + exp = Series(["Timestamp_1_US/Eastern", "Timestamp_2_US/Eastern"]) + tm.assert_series_equal(res, exp) + + +def test_apply_box_td64(): + # timedelta + vals = [pd.Timedelta("1 days"), pd.Timedelta("2 days")] + ser = Series(vals) + assert ser.dtype == "timedelta64[us]" + res = ser.apply(lambda x: f"{type(x).__name__}_{x.days}", by_row="compat") + exp = Series(["Timedelta_1", "Timedelta_2"]) + tm.assert_series_equal(res, exp) + + +def test_apply_box_period(): + # period + vals = [pd.Period("2011-01-01", freq="M"), pd.Period("2011-01-02", freq="M")] + ser = Series(vals) + assert ser.dtype == "Period[M]" + res = ser.apply(lambda x: f"{type(x).__name__}_{x.freqstr}", by_row="compat") + exp = Series(["Period_M", "Period_M"]) + tm.assert_series_equal(res, exp) + + +def test_apply_datetimetz(by_row): + values = date_range("2011-01-01", "2011-01-02", freq="h").tz_localize("Asia/Tokyo") + s = Series(values, name="XX") + + result = s.apply(lambda x: x + pd.offsets.Day(), by_row=by_row) + exp_values = date_range("2011-01-02", "2011-01-03", freq="h").tz_localize( + "Asia/Tokyo" + ) + exp = Series(exp_values, name="XX") + tm.assert_series_equal(result, exp) + + result = s.apply(lambda x: x.hour if by_row else x.dt.hour, by_row=by_row) + exp = Series([*list(range(24)), 0], name="XX", dtype="int64" if by_row else "int32") + tm.assert_series_equal(result, exp) + + # not vectorized + def f(x): + return str(x.tz) if by_row else str(x.dt.tz) + + result = s.apply(f, by_row=by_row) + if by_row: + exp = Series(["Asia/Tokyo"] * 25, name="XX") + tm.assert_series_equal(result, exp) + else: + assert result == "Asia/Tokyo" + + +def test_apply_categorical(by_row, using_infer_string): + values = pd.Categorical(list("ABBABCD"), categories=list("DCBA"), ordered=True) + ser = Series(values, name="XX", index=list("abcdefg")) + + if not by_row: + msg = "Series' object has no attribute 'lower" + with pytest.raises(AttributeError, match=msg): + ser.apply(lambda x: x.lower(), by_row=by_row) + assert ser.apply(lambda x: "A", by_row=by_row) == "A" + return + + result = ser.apply(lambda x: x.lower(), by_row=by_row) + + # should be categorical dtype when the number of categories are + # the same + values = pd.Categorical(list("abbabcd"), categories=list("dcba"), ordered=True) + exp = Series(values, name="XX", index=list("abcdefg")) + tm.assert_series_equal(result, exp) + tm.assert_categorical_equal(result.values, exp.values) + + result = ser.apply(lambda x: "A") + exp = Series(["A"] * 7, name="XX", index=list("abcdefg")) + tm.assert_series_equal(result, exp) + assert result.dtype == object if not using_infer_string else "str" + + +@pytest.mark.parametrize("series", [["1-1", "1-1", np.nan], ["1-1", "1-2", np.nan]]) +def test_apply_categorical_with_nan_values(series, by_row): + # GH 20714 bug fixed in: GH 24275 + s = Series(series, dtype="category") + if not by_row: + msg = "'Series' object has no attribute 'split'" + with pytest.raises(AttributeError, match=msg): + s.apply(lambda x: x.split("-")[0], by_row=by_row) + return + # NaN for cat dtype fixed in (GH 59966) + result = s.apply(lambda x: x.split("-")[0] if pd.notna(x) else False, by_row=by_row) + result = result.astype(object) + expected = Series(["1", "1", False], dtype="category") + expected = expected.astype(object) + tm.assert_series_equal(result, expected) + + +def test_apply_empty_integer_series_with_datetime_index(by_row): + # GH 21245 + s = Series([], index=date_range(start="2018-01-01", periods=0), dtype=int) + result = s.apply(lambda x: x, by_row=by_row) + tm.assert_series_equal(result, s) + + +def test_apply_dataframe_iloc(): + uintDF = DataFrame(np.uint64([1, 2, 3, 4, 5]), columns=["Numbers"]) + indexDF = DataFrame([2, 3, 2, 1, 2], columns=["Indices"]) + + def retrieve(targetRow, targetDF): + val = targetDF["Numbers"].iloc[targetRow] + return val + + result = indexDF["Indices"].apply(retrieve, args=(uintDF,)) + expected = Series([3, 4, 3, 2, 3], name="Indices", dtype="uint64") + tm.assert_series_equal(result, expected) + + +def test_transform(string_series, by_row): + # transforming functions + + with np.errstate(all="ignore"): + f_sqrt = np.sqrt(string_series) + f_abs = np.abs(string_series) + + # ufunc + result = string_series.apply(np.sqrt, by_row=by_row) + expected = f_sqrt.copy() + tm.assert_series_equal(result, expected) + + # list-like + result = string_series.apply([np.sqrt], by_row=by_row) + expected = f_sqrt.to_frame().copy() + expected.columns = ["sqrt"] + tm.assert_frame_equal(result, expected) + + result = string_series.apply(["sqrt"], by_row=by_row) + tm.assert_frame_equal(result, expected) + + # multiple items in list + # these are in the order as if we are applying both functions per + # series and then concatting + expected = concat([f_sqrt, f_abs], axis=1) + expected.columns = ["sqrt", "absolute"] + result = string_series.apply([np.sqrt, np.abs], by_row=by_row) + tm.assert_frame_equal(result, expected) + + # dict, provide renaming + expected = concat([f_sqrt, f_abs], axis=1) + expected.columns = ["foo", "bar"] + expected = expected.unstack().rename("series") + + result = string_series.apply({"foo": np.sqrt, "bar": np.abs}, by_row=by_row) + tm.assert_series_equal(result.reindex_like(expected), expected) + + +@pytest.mark.parametrize("op", series_transform_kernels) +def test_transform_partial_failure(op, request): + # GH 35964 + if op in ("ffill", "bfill", "shift"): + request.applymarker( + pytest.mark.xfail(reason=f"{op} is successful on any dtype") + ) + + # Using object makes most transform kernels fail + ser = Series(3 * [object]) + + if op in ("fillna", "ngroup"): + error = ValueError + msg = "Transform function failed" + else: + error = TypeError + msg = "|".join( + [ + "not supported between instances of 'type' and 'type'", + "unsupported operand type", + ] + ) + + with pytest.raises(error, match=msg): + ser.transform([op, "shift"]) + + with pytest.raises(error, match=msg): + ser.transform({"A": op, "B": "shift"}) + + with pytest.raises(error, match=msg): + ser.transform({"A": [op], "B": ["shift"]}) + + with pytest.raises(error, match=msg): + ser.transform({"A": [op, "shift"], "B": [op]}) + + +def test_transform_partial_failure_valueerror(): + # GH 40211 + def noop(x): + return x + + def raising_op(_): + raise ValueError + + ser = Series(3 * [object]) + msg = "Transform function failed" + + with pytest.raises(ValueError, match=msg): + ser.transform([noop, raising_op]) + + with pytest.raises(ValueError, match=msg): + ser.transform({"A": raising_op, "B": noop}) + + with pytest.raises(ValueError, match=msg): + ser.transform({"A": [raising_op], "B": [noop]}) + + with pytest.raises(ValueError, match=msg): + ser.transform({"A": [noop, raising_op], "B": [noop]}) + + +def test_demo(): + # demonstration tests + s = Series(range(6), dtype="int64", name="series") + + result = s.agg(["min", "max"]) + expected = Series([0, 5], index=["min", "max"], name="series") + tm.assert_series_equal(result, expected) + + result = s.agg({"foo": "min"}) + expected = Series([0], index=["foo"], name="series") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", [str, lambda x: str(x)]) +def test_apply_map_evaluate_lambdas_the_same(string_series, func, by_row, engine): + # test that we are evaluating row-by-row first if by_row="compat" + # else vectorized evaluation + result = string_series.apply(func, by_row=by_row) + + if by_row: + expected = string_series.map(func, engine=engine) + tm.assert_series_equal(result, expected) + else: + assert result == str(string_series) + + +def test_agg_evaluate_lambdas(string_series): + # GH53325 + result = string_series.agg(lambda x: type(x)) + assert result is Series + + result = string_series.agg(type) + assert result is Series + + +@pytest.mark.parametrize("op_name", ["agg", "apply"]) +def test_with_nested_series(datetime_series, op_name): + # GH 2316 & GH52123 + # .agg with a reducer and a transform, what to do + result = getattr(datetime_series, op_name)( + lambda x: Series([x, x**2], index=["x", "x^2"]) + ) + if op_name == "apply": + expected = DataFrame({"x": datetime_series, "x^2": datetime_series**2}) + tm.assert_frame_equal(result, expected) + else: + expected = Series([datetime_series, datetime_series**2], index=["x", "x^2"]) + tm.assert_series_equal(result, expected) + + +def test_replicate_describe(string_series): + # this also tests a result set that is all scalars + expected = string_series.describe() + result = string_series.apply( + { + "count": "count", + "mean": "mean", + "std": "std", + "min": "min", + "25%": lambda x: x.quantile(0.25), + "50%": "median", + "75%": lambda x: x.quantile(0.75), + "max": "max", + }, + ) + tm.assert_series_equal(result, expected) + + +def test_reduce(string_series): + # reductions with named functions + result = string_series.agg(["sum", "mean"]) + expected = Series( + [string_series.sum(), string_series.mean()], + ["sum", "mean"], + name=string_series.name, + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "how, kwds", + [("agg", {}), ("apply", {"by_row": "compat"}), ("apply", {"by_row": False})], +) +def test_non_callable_aggregates(how, kwds): + # test agg using non-callable series attributes + # GH 39116 - expand to apply + s = Series([1, 2, None]) + + # Calling agg w/ just a string arg same as calling s.arg + result = getattr(s, how)("size", **kwds) + expected = s.size + assert result == expected + + # test when mixed w/ callable reducers + result = getattr(s, how)(["size", "count", "mean"], **kwds) + expected = Series({"size": 3.0, "count": 2.0, "mean": 1.5}) + tm.assert_series_equal(result, expected) + + result = getattr(s, how)({"size": "size", "count": "count", "mean": "mean"}, **kwds) + tm.assert_series_equal(result, expected) + + +def test_series_apply_no_suffix_index(by_row): + # GH36189 + s = Series([4] * 3) + result = s.apply(["sum", lambda x: x.sum(), lambda x: x.sum()], by_row=by_row) + expected = Series([12, 12, 12], index=["sum", "", ""]) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dti,exp", + [ + ( + Series([1, 2], index=pd.DatetimeIndex([0, 31536000000])), + DataFrame(np.repeat([[1, 2]], 2, axis=0), dtype="int64"), + ), + ( + Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ), + DataFrame(np.repeat([[1, 2]], 10, axis=0), dtype="int64"), + ), + ], +) +@pytest.mark.parametrize("aware", [True, False]) +def test_apply_series_on_date_time_index_aware_series(dti, exp, aware): + # GH 25959 + # Calling apply on a localized time series should not cause an error + if aware: + index = dti.tz_localize("UTC").index + else: + index = dti.index + result = Series(index).apply(lambda x: Series([1, 2])) + tm.assert_frame_equal(result, exp) + + +@pytest.mark.parametrize( + "by_row, expected", [("compat", Series(np.ones(10), dtype="int64")), (False, 1)] +) +def test_apply_scalar_on_date_time_index_aware_series(by_row, expected): + # GH 25959 + # Calling apply on a localized time series should not cause an error + series = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10, tz="UTC"), + ) + result = Series(series.index).apply(lambda x: 1, by_row=by_row) + tm.assert_equal(result, expected) + + +def test_apply_to_timedelta(by_row): + list_of_valid_strings = ["00:00:01", "00:00:02"] + a = pd.to_timedelta(list_of_valid_strings) + b = Series(list_of_valid_strings).apply(pd.to_timedelta, by_row=by_row) + tm.assert_series_equal(Series(a), b) + + list_of_strings = ["00:00:01", np.nan, pd.NaT, pd.NaT] + + a = pd.to_timedelta(list_of_strings) + ser = Series(list_of_strings) + b = ser.apply(pd.to_timedelta, by_row=by_row) + tm.assert_series_equal(Series(a), b) + + +@pytest.mark.parametrize( + "ops, names", + [ + ([np.sum], ["sum"]), + ([np.sum, np.mean], ["sum", "mean"]), + (np.array([np.sum]), ["sum"]), + (np.array([np.sum, np.mean]), ["sum", "mean"]), + ], +) +@pytest.mark.parametrize( + "how, kwargs", + [["agg", {}], ["apply", {"by_row": "compat"}], ["apply", {"by_row": False}]], +) +def test_apply_listlike_reducer(string_series, ops, names, how, kwargs): + # GH 39140 + expected = Series( + {name: op(string_series) for name, op in zip(names, ops, strict=True)} + ) + expected.name = "series" + result = getattr(string_series, how)(ops, **kwargs) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "ops", + [ + {"A": np.sum}, + {"A": np.sum, "B": np.mean}, + Series({"A": np.sum}), + Series({"A": np.sum, "B": np.mean}), + ], +) +@pytest.mark.parametrize( + "how, kwargs", + [["agg", {}], ["apply", {"by_row": "compat"}], ["apply", {"by_row": False}]], +) +def test_apply_dictlike_reducer(string_series, ops, how, kwargs, by_row): + # GH 39140 + expected = Series({name: op(string_series) for name, op in ops.items()}) + expected.name = string_series.name + result = getattr(string_series, how)(ops, **kwargs) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "ops, names", + [ + ([np.sqrt], ["sqrt"]), + ([np.abs, np.sqrt], ["absolute", "sqrt"]), + (np.array([np.sqrt]), ["sqrt"]), + (np.array([np.abs, np.sqrt]), ["absolute", "sqrt"]), + ], +) +def test_apply_listlike_transformer(string_series, ops, names, by_row): + # GH 39140 + with np.errstate(all="ignore"): + expected = concat([op(string_series) for op in ops], axis=1) + expected.columns = names + result = string_series.apply(ops, by_row=by_row) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "ops, expected", + [ + ([lambda x: x], DataFrame({"": [1, 2, 3]})), + ([lambda x: x.sum()], Series([6], index=[""])), + ], +) +def test_apply_listlike_lambda(ops, expected, by_row): + # GH53400 + ser = Series([1, 2, 3]) + result = ser.apply(ops, by_row=by_row) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "ops", + [ + {"A": np.sqrt}, + {"A": np.sqrt, "B": np.exp}, + Series({"A": np.sqrt}), + Series({"A": np.sqrt, "B": np.exp}), + ], +) +def test_apply_dictlike_transformer(string_series, ops, by_row): + # GH 39140 + with np.errstate(all="ignore"): + expected = concat({name: op(string_series) for name, op in ops.items()}) + expected.name = string_series.name + result = string_series.apply(ops, by_row=by_row) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "ops, expected", + [ + ( + {"a": lambda x: x}, + Series([1, 2, 3], index=MultiIndex.from_arrays([["a"] * 3, range(3)])), + ), + ({"a": lambda x: x.sum()}, Series([6], index=["a"])), + ], +) +def test_apply_dictlike_lambda(ops, by_row, expected): + # GH53400 + ser = Series([1, 2, 3]) + result = ser.apply(ops, by_row=by_row) + tm.assert_equal(result, expected) + + +def test_apply_retains_column_name(by_row): + # GH 16380 + df = DataFrame({"x": range(3)}, Index(range(3), name="x")) + result = df.x.apply(lambda x: Series(range(x + 1), Index(range(x + 1), name="y"))) + expected = DataFrame( + [[0.0, np.nan, np.nan], [0.0, 1.0, np.nan], [0.0, 1.0, 2.0]], + columns=Index(range(3), name="y"), + index=Index(range(3), name="x"), + ) + tm.assert_frame_equal(result, expected) + + +def test_apply_type(): + # GH 46719 + s = Series([3, "string", float], index=["a", "b", "c"]) + result = s.apply(type) + expected = Series([int, str, type], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + +def test_series_apply_unpack_nested_data(): + # GH#55189 + ser = Series([[1, 2, 3], [4, 5, 6, 7]]) + result = ser.apply(lambda x: Series(x)) + expected = DataFrame({0: [1.0, 4.0], 1: [2.0, 5.0], 2: [3.0, 6.0], 3: [np.nan, 7]}) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/apply/test_series_apply_relabeling.py b/pandas/tests/apply/test_series_apply_relabeling.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a285e6eb38cc26da155755108ef2c814229384 --- /dev/null +++ b/pandas/tests/apply/test_series_apply_relabeling.py @@ -0,0 +1,33 @@ +import pandas as pd +import pandas._testing as tm + + +def test_relabel_no_duplicated_method(): + # this is to test there is no duplicated method used in agg + df = pd.DataFrame({"A": [1, 2, 1, 2], "B": [1, 2, 3, 4]}) + + result = df["A"].agg(foo="sum") + expected = df["A"].agg({"foo": "sum"}) + tm.assert_series_equal(result, expected) + + result = df["B"].agg(foo="min", bar="max") + expected = df["B"].agg({"foo": "min", "bar": "max"}) + tm.assert_series_equal(result, expected) + + result = df["B"].agg(foo=sum, bar=min, cat="max") + expected = df["B"].agg({"foo": sum, "bar": min, "cat": "max"}) + tm.assert_series_equal(result, expected) + + +def test_relabel_duplicated_method(): + # this is to test with nested renaming, duplicated method can be used + # if they are assigned with different new names + df = pd.DataFrame({"A": [1, 2, 1, 2], "B": [1, 2, 3, 4]}) + + result = df["A"].agg(foo="sum", bar="sum") + expected = pd.Series([6, 6], index=["foo", "bar"], name="A") + tm.assert_series_equal(result, expected) + + result = df["B"].agg(foo=min, bar="min") + expected = pd.Series([1, 1], index=["foo", "bar"], name="B") + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/apply/test_series_transform.py b/pandas/tests/apply/test_series_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..82592c4711ece5a7f4b6d421d743e1adbd78c345 --- /dev/null +++ b/pandas/tests/apply/test_series_transform.py @@ -0,0 +1,84 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + MultiIndex, + Series, + concat, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "args, kwargs, increment", + [((), {}, 0), ((), {"a": 1}, 1), ((2, 3), {}, 32), ((1,), {"c": 2}, 201)], +) +def test_agg_args(args, kwargs, increment): + # GH 43357 + def f(x, a=0, b=0, c=0): + return x + a + 10 * b + 100 * c + + s = Series([1, 2]) + result = s.transform(f, 0, *args, **kwargs) + expected = s + increment + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "ops, names", + [ + ([np.sqrt], ["sqrt"]), + ([np.abs, np.sqrt], ["absolute", "sqrt"]), + (np.array([np.sqrt]), ["sqrt"]), + (np.array([np.abs, np.sqrt]), ["absolute", "sqrt"]), + ], +) +def test_transform_listlike(string_series, ops, names): + # GH 35964 + with np.errstate(all="ignore"): + expected = concat([op(string_series) for op in ops], axis=1) + expected.columns = names + result = string_series.transform(ops) + tm.assert_frame_equal(result, expected) + + +def test_transform_listlike_func_with_args(): + # GH 50624 + + s = Series([1, 2, 3]) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + s.transform([foo1, foo2], 0, 3, b=3, c=4) + + result = s.transform([foo1, foo2], 0, 3, c=4) + expected = DataFrame({"foo1": [8, 9, 10], "foo2": [8, 9, 10]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("box", [dict, Series]) +def test_transform_dictlike(string_series, box): + # GH 35964 + with np.errstate(all="ignore"): + expected = concat([np.sqrt(string_series), np.abs(string_series)], axis=1) + expected.columns = ["foo", "bar"] + result = string_series.transform(box({"foo": np.sqrt, "bar": np.abs})) + tm.assert_frame_equal(result, expected) + + +def test_transform_dictlike_mixed(): + # GH 40018 - mix of lists and non-lists in values of a dictionary + df = Series([1, 4]) + result = df.transform({"b": ["sqrt", "abs"], "c": "sqrt"}) + expected = DataFrame( + [[1.0, 1, 1.0], [2.0, 4, 2.0]], + columns=MultiIndex([("b", "c"), ("sqrt", "abs")], [(0, 0, 1), (0, 1, 0)]), + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/apply/test_str.py b/pandas/tests/apply/test_str.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a9492630b13a8ac03e976a699f8e58752887f2 --- /dev/null +++ b/pandas/tests/apply/test_str.py @@ -0,0 +1,307 @@ +from itertools import chain +import operator + +import numpy as np +import pytest + +from pandas.compat import ( + WASM, +) + +from pandas.core.dtypes.common import is_number + +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm +from pandas.tests.apply.common import ( + frame_transform_kernels, + series_transform_kernels, +) + + +@pytest.mark.parametrize("func", ["sum", "mean", "min", "max", "std"]) +@pytest.mark.parametrize( + "kwds", + [ + pytest.param({}, id="no_kwds"), + pytest.param({"axis": 1}, id="on_axis"), + pytest.param({"numeric_only": True}, id="func_kwds"), + pytest.param({"axis": 1, "numeric_only": True}, id="axis_and_func_kwds"), + ], +) +@pytest.mark.parametrize("how", ["agg", "apply"]) +def test_apply_with_string_funcs(float_frame, func, kwds, how): + result = getattr(float_frame, how)(func, **kwds) + expected = getattr(float_frame, func)(**kwds) + tm.assert_series_equal(result, expected) + + +def test_with_string_args(datetime_series, all_numeric_reductions): + result = datetime_series.apply(all_numeric_reductions) + expected = getattr(datetime_series, all_numeric_reductions)() + assert result == expected + + +@pytest.mark.parametrize("op", ["mean", "median", "std", "var"]) +@pytest.mark.parametrize("how", ["agg", "apply"]) +def test_apply_np_reducer(op, how): + # GH 39116 + float_frame = DataFrame({"a": [1, 2], "b": [3, 4]}) + result = getattr(float_frame, how)(op) + # pandas ddof defaults to 1, numpy to 0 + kwargs = {"ddof": 1} if op in ("std", "var") else {} + expected = Series( + getattr(np, op)(float_frame, axis=0, **kwargs), index=float_frame.columns + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.skipif(WASM, reason="No fp exception support in wasm") +@pytest.mark.parametrize( + "op", ["abs", "ceil", "cos", "cumsum", "exp", "log", "sqrt", "square"] +) +@pytest.mark.parametrize("how", ["transform", "apply"]) +def test_apply_np_transformer(float_frame, op, how): + # GH 39116 + + # float_frame will _usually_ have negative values, which will + # trigger the warning here, but let's put one in just to be sure + float_frame.iloc[0, 0] = -1.0 + warn = None + if op in ["log", "sqrt"]: + warn = RuntimeWarning + + with tm.assert_produces_warning(warn, check_stacklevel=False): + # float_frame fixture is defined in conftest.py, so we don't check the + # stacklevel as otherwise the test would fail. + result = getattr(float_frame, how)(op) + expected = getattr(np, op)(float_frame) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "series, func, expected", + chain( + tm.get_cython_table_params( + Series(dtype=np.float64), + [ + ("sum", 0), + ("max", np.nan), + ("min", np.nan), + ("all", True), + ("any", False), + ("mean", np.nan), + ("prod", 1), + ("std", np.nan), + ("var", np.nan), + ("median", np.nan), + ], + ), + tm.get_cython_table_params( + Series([np.nan, 1, 2, 3]), + [ + ("sum", 6), + ("max", 3), + ("min", 1), + ("all", True), + ("any", True), + ("mean", 2), + ("prod", 6), + ("std", 1), + ("var", 1), + ("median", 2), + ], + ), + tm.get_cython_table_params( + Series("a b c".split()), + [ + ("sum", "abc"), + ("max", "c"), + ("min", "a"), + ("all", True), + ("any", True), + ], + ), + ), +) +def test_agg_cython_table_series(series, func, expected): + # GH21224 + # test reducing functions in + # pandas.core.base.SelectionMixin._cython_table + warn = None if isinstance(func, str) else FutureWarning + with tm.assert_produces_warning(warn, match="is currently using Series.*"): + result = series.agg(func) + if is_number(expected): + assert np.isclose(result, expected, equal_nan=True) + else: + assert result == expected + + +@pytest.mark.parametrize( + "series, func, expected", + chain( + tm.get_cython_table_params( + Series(dtype=np.float64), + [ + ("cumprod", Series([], dtype=np.float64)), + ("cumsum", Series([], dtype=np.float64)), + ], + ), + tm.get_cython_table_params( + Series([np.nan, 1, 2, 3]), + [ + ("cumprod", Series([np.nan, 1, 2, 6])), + ("cumsum", Series([np.nan, 1, 3, 6])), + ], + ), + tm.get_cython_table_params( + Series("a b c".split()), [("cumsum", Series(["a", "ab", "abc"]))] + ), + ), +) +def test_agg_cython_table_transform_series(series, func, expected): + # GH21224 + # test transforming functions in + # pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum) + warn = None if isinstance(func, str) else FutureWarning + with tm.assert_produces_warning(warn, match="is currently using Series.*"): + result = series.agg(func) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "df, func, expected", + chain( + tm.get_cython_table_params( + DataFrame(), + [ + ("sum", Series(dtype="float64")), + ("max", Series(dtype="float64")), + ("min", Series(dtype="float64")), + ("all", Series(dtype=bool)), + ("any", Series(dtype=bool)), + ("mean", Series(dtype="float64")), + ("prod", Series(dtype="float64")), + ("std", Series(dtype="float64")), + ("var", Series(dtype="float64")), + ("median", Series(dtype="float64")), + ], + ), + tm.get_cython_table_params( + DataFrame([[np.nan, 1], [1, 2]]), + [ + ("sum", Series([1.0, 3])), + ("max", Series([1.0, 2])), + ("min", Series([1.0, 1])), + ("all", Series([True, True])), + ("any", Series([True, True])), + ("mean", Series([1, 1.5])), + ("prod", Series([1.0, 2])), + ("std", Series([np.nan, 0.707107])), + ("var", Series([np.nan, 0.5])), + ("median", Series([1, 1.5])), + ], + ), + ), +) +def test_agg_cython_table_frame(df, func, expected, axis): + # GH 21224 + # test reducing functions in + # pandas.core.base.SelectionMixin._cython_table + warn = None if isinstance(func, str) else FutureWarning + with tm.assert_produces_warning(warn, match="is currently using DataFrame.*"): + # GH#53425 + result = df.agg(func, axis=axis) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "df, func, expected", + chain( + tm.get_cython_table_params( + DataFrame(), [("cumprod", DataFrame()), ("cumsum", DataFrame())] + ), + tm.get_cython_table_params( + DataFrame([[np.nan, 1], [1, 2]]), + [ + ("cumprod", DataFrame([[np.nan, 1], [1, 2]])), + ("cumsum", DataFrame([[np.nan, 1], [1, 3]])), + ], + ), + ), +) +def test_agg_cython_table_transform_frame(df, func, expected, axis): + # GH 21224 + # test transforming functions in + # pandas.core.base.SelectionMixin._cython_table (cumprod, cumsum) + if axis in ("columns", 1): + # operating blockwise doesn't let us preserve dtypes + expected = expected.astype("float64") + + warn = None if isinstance(func, str) else FutureWarning + with tm.assert_produces_warning(warn, match="is currently using DataFrame.*"): + # GH#53425 + result = df.agg(func, axis=axis) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("op", series_transform_kernels) +def test_transform_groupby_kernel_series(request, string_series, op): + # GH 35964 + if op == "ngroup": + request.applymarker( + pytest.mark.xfail(raises=ValueError, reason="ngroup not valid for NDFrame") + ) + args = [0.0] if op == "fillna" else [] + ones = np.ones(string_series.shape[0]) + + warn = FutureWarning if op == "fillna" else None + msg = "SeriesGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=msg): + expected = string_series.groupby(ones).transform(op, *args) + result = string_series.transform(op, 0, *args) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("op", frame_transform_kernels) +def test_transform_groupby_kernel_frame(request, float_frame, op): + if op == "ngroup": + request.applymarker( + pytest.mark.xfail(raises=ValueError, reason="ngroup not valid for NDFrame") + ) + + # GH 35964 + + args = [0.0] if op == "fillna" else [] + ones = np.ones(float_frame.shape[0]) + gb = float_frame.groupby(ones) + + warn = FutureWarning if op == "fillna" else None + op_msg = "DataFrameGroupBy.fillna is deprecated" + with tm.assert_produces_warning(warn, match=op_msg): + expected = gb.transform(op, *args) + + result = float_frame.transform(op, 0, *args) + tm.assert_frame_equal(result, expected) + + # same thing, but ensuring we have multiple blocks + assert "E" not in float_frame.columns + float_frame["E"] = float_frame["A"].copy() + assert len(float_frame._mgr.blocks) > 1 + + ones = np.ones(float_frame.shape[0]) + gb2 = float_frame.groupby(ones) + expected2 = gb2.transform(op, *args) + result2 = float_frame.transform(op, 0, *args) + tm.assert_frame_equal(result2, expected2) + + +@pytest.mark.parametrize("method", ["abs", "shift", "pct_change", "cumsum", "rank"]) +def test_transform_method_name(method): + # GH 19760 + df = DataFrame({"A": [-1, 2]}) + result = df.transform(method) + expected = operator.methodcaller(method)(df) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/arithmetic/__init__.py b/pandas/tests/arithmetic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/arithmetic/common.py b/pandas/tests/arithmetic/common.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea9d2b0ee23ad14168a4366e332e1d49d3c0c85 --- /dev/null +++ b/pandas/tests/arithmetic/common.py @@ -0,0 +1,158 @@ +""" +Assertion helpers for arithmetic tests. +""" + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, + array, +) +import pandas._testing as tm +from pandas.core.arrays import ( + BooleanArray, + NumpyExtensionArray, +) + + +def assert_cannot_add(left, right, msg="cannot add"): + """ + Helper function to assert that two objects cannot be added. + + Parameters + ---------- + left : object + The first operand. + right : object + The second operand. + msg : str, default "cannot add" + The error message expected in the TypeError. + """ + with pytest.raises(TypeError, match=msg): + left + right + with pytest.raises(TypeError, match=msg): + right + left + + +def assert_invalid_addsub_type(left, right, msg=None): + """ + Helper function to assert that two objects can + neither be added nor subtracted. + + Parameters + ---------- + left : object + The first operand. + right : object + The second operand. + msg : str or None, default None + The error message expected in the TypeError. + """ + with pytest.raises(TypeError, match=msg): + left + right + with pytest.raises(TypeError, match=msg): + right + left + with pytest.raises(TypeError, match=msg): + left - right + with pytest.raises(TypeError, match=msg): + right - left + + +def get_upcast_box(left, right, is_cmp: bool = False): + """ + Get the box to use for 'expected' in an arithmetic or comparison operation. + + Parameters + left : Any + right : Any + is_cmp : bool, default False + Whether the operation is a comparison method. + """ + + if isinstance(left, DataFrame) or isinstance(right, DataFrame): + return DataFrame + if isinstance(left, Series) or isinstance(right, Series): + if is_cmp and isinstance(left, Index): + # Index does not defer for comparisons + return np.array + return Series + if isinstance(left, Index) or isinstance(right, Index): + if is_cmp: + return np.array + return Index + return tm.to_array + + +def assert_invalid_comparison(left, right, box): + """ + Assert that comparison operations with mismatched types behave correctly. + + Parameters + ---------- + left : np.ndarray, ExtensionArray, Index, or Series + right : object + box : {pd.DataFrame, pd.Series, pd.Index, pd.array, tm.to_array} + """ + # Not for tznaive-tzaware comparison + + # Note: not quite the same as how we do this for tm.box_expected + xbox = box if box not in [Index, array] else np.array + + def xbox2(x): + # Eventually we'd like this to be tighter, but for now we'll + # just exclude NumpyExtensionArray[bool] + if isinstance(x, NumpyExtensionArray): + return x._ndarray + if isinstance(x, BooleanArray): + # NB: we are assuming no pd.NAs for now + return x.astype(bool) + return x + + result = xbox2(left == right) + expected = xbox(np.zeros(result.shape, dtype=np.bool_)) + + tm.assert_equal(result, expected) + + result = xbox2(right == left) + tm.assert_equal(result, xbox(expected)) + + result = xbox2(left != right) + tm.assert_equal(result, ~expected) + + result = xbox2(right != left) + tm.assert_equal(result, xbox(~expected)) + + msg = "|".join( + [ + "Invalid comparison between", + "Cannot compare type", + "not supported between", + "invalid type promotion", + ( + # GH#36706 npdev 1.20.0 2020-09-28 + r"The DTypes and " + r" do not have a common DType. " + "For example they cannot be stored in a single array unless the " + "dtype is `object`." + ), + ] + ) + with pytest.raises(TypeError, match=msg): + left < right + with pytest.raises(TypeError, match=msg): + left <= right + with pytest.raises(TypeError, match=msg): + left > right + with pytest.raises(TypeError, match=msg): + left >= right + with pytest.raises(TypeError, match=msg): + right < left + with pytest.raises(TypeError, match=msg): + right <= left + with pytest.raises(TypeError, match=msg): + right > left + with pytest.raises(TypeError, match=msg): + right >= left diff --git a/pandas/tests/arithmetic/conftest.py b/pandas/tests/arithmetic/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..44838394f2183dd7f18c4a960ff60a9f5e3f29cc --- /dev/null +++ b/pandas/tests/arithmetic/conftest.py @@ -0,0 +1,139 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import Index + + +@pytest.fixture(params=[1, np.array(1, dtype=np.int64)]) +def one(request): + """ + Several variants of integer value 1. The zero-dim integer array + behaves like an integer. + + This fixture can be used to check that datetimelike indexes handle + addition and subtraction of integers and zero-dimensional arrays + of integers. + + Examples + -------- + dti = pd.date_range('2016-01-01', periods=2, freq='h') + dti + DatetimeIndex(['2016-01-01 00:00:00', '2016-01-01 01:00:00'], + dtype='datetime64[ns]', freq='h') + dti + one + DatetimeIndex(['2016-01-01 01:00:00', '2016-01-01 02:00:00'], + dtype='datetime64[ns]', freq='h') + """ + return request.param + + +zeros = [ + box_cls([0] * 5, dtype=dtype) + for box_cls in [Index, np.array, pd.array] + for dtype in [np.int64, np.uint64, np.float64] +] +zeros.extend([box_cls([-0.0] * 5, dtype=np.float64) for box_cls in [Index, np.array]]) +zeros.extend([np.array(0, dtype=dtype) for dtype in [np.int64, np.uint64, np.float64]]) +zeros.extend([np.array(-0.0, dtype=np.float64)]) +zeros.extend([0, 0.0, -0.0]) + + +@pytest.fixture(params=zeros) +def zero(request): + """ + Several types of scalar zeros and length 5 vectors of zeros. + + This fixture can be used to check that numeric-dtype indexes handle + division by any zero numeric-dtype. + + Uses vector of length 5 for broadcasting with `numeric_idx` fixture, + which creates numeric-dtype vectors also of length 5. + + Examples + -------- + arr = RangeIndex(5) + arr / zeros + Index([nan, inf, inf, inf, inf], dtype='float64') + """ + return request.param + + +# ------------------------------------------------------------------ +# Scalar Fixtures + + +@pytest.fixture( + params=[ + pd.Timedelta("10m7s").to_pytimedelta(), + pd.Timedelta("10m7s"), + pd.Timedelta("10m7s").to_timedelta64(), + ], + ids=lambda x: type(x).__name__, +) +def scalar_td(request): + """ + Several variants of Timedelta scalars representing 10 minutes and 7 seconds. + """ + return request.param + + +@pytest.fixture( + params=[ + pd.offsets.Day(3), + pd.offsets.Hour(72), + pd.Timedelta(days=3).to_pytimedelta(), + pd.Timedelta("72:00:00"), + np.timedelta64(3, "D"), + np.timedelta64(72, "h"), + ], + ids=lambda x: type(x).__name__, +) +def three_days(request): + """ + Several timedelta-like and DateOffset objects that each represent + a 3-day timedelta + """ + return request.param + + +@pytest.fixture( + params=[ + pd.offsets.Hour(2), + pd.offsets.Minute(120), + pd.Timedelta(hours=2).to_pytimedelta(), + pd.Timedelta(seconds=2 * 3600), + np.timedelta64(2, "h"), + np.timedelta64(120, "m"), + ], + ids=lambda x: type(x).__name__, +) +def two_hours(request): + """ + Several timedelta-like and DateOffset objects that each represent + a 2-hour timedelta + """ + return request.param + + +_common_mismatch = [ + pd.offsets.YearBegin(2), + pd.offsets.MonthBegin(1), + pd.offsets.Minute(), +] + + +@pytest.fixture( + params=[ + np.timedelta64(4, "h"), + pd.Timedelta(hours=23).to_pytimedelta(), + pd.Timedelta("23:00:00"), + *_common_mismatch, + ] +) +def not_daily(request): + """ + Several timedelta-like and DateOffset instances that are _not_ + compatible with Daily frequencies. + """ + return request.param diff --git a/pandas/tests/arithmetic/test_array_ops.py b/pandas/tests/arithmetic/test_array_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f026092e0d1133ea861f9cdb2699924e3618298e --- /dev/null +++ b/pandas/tests/arithmetic/test_array_ops.py @@ -0,0 +1,78 @@ +import operator + +import numpy as np +import pytest + +from pandas.core.dtypes.missing import isna + +import pandas._testing as tm +from pandas.core.ops.array_ops import ( + comparison_op, + na_logical_op, +) + + +def test_na_logical_op_2d(): + left = np.arange(8).reshape(4, 2) + right = left.astype(object) + right[0, 0] = np.nan + + # Check that we fall back to the vec_binop branch + with pytest.raises(TypeError, match="unsupported operand type"): + operator.or_(left, right) + + result = na_logical_op(left, right, operator.or_) + expected = right + tm.assert_numpy_array_equal(result, expected) + + +def test_object_comparison_2d(): + left = np.arange(9).reshape(3, 3).astype(object) + right = left.T + + result = comparison_op(left, right, operator.eq) + expected = np.eye(3).astype(bool) + tm.assert_numpy_array_equal(result, expected) + + # Ensure that cython doesn't raise on non-writeable arg, which + # we can get from np.broadcast_to + right.flags.writeable = False + result = comparison_op(left, right, operator.ne) + tm.assert_numpy_array_equal(result, ~expected) + + +@pytest.mark.parametrize("rvalues", [1, [1, 1, 1], np.nan, None]) +@pytest.mark.parametrize( + "op", [operator.eq, operator.ne, operator.lt, operator.le, operator.gt, operator.ge] +) +def test_comparison_for_subclasses(rvalues, op): + # GH#63205 Ensure subclasses of ndarray are correctly handled in comparison_op + # Define a custom ndarray subclass + class TestArray(np.ndarray): + def __new__(cls, input_array): + return np.asarray(input_array).view(cls) + + def __array_finalize__(self, obj) -> None: + self._is_test_array = True + + def expected_with_na_handling(lvalues, rvalues, op): + # Similar to comparison_op, handle zerodim arrays with na value separately + if (rvalues.ndim == 0) and isna(rvalues.item()): + # numpy does not like comparisons vs None + if op is operator.ne: + return np.ones(lvalues.shape, dtype=bool) + else: + return np.zeros(lvalues.shape, dtype=bool) + return op(lvalues, rvalues) + + # Define test data + lvalues = [1, 2, 3] + + # Test with both ndarray and TestArray + result = comparison_op(np.array(lvalues), np.array(rvalues), op) + expected = expected_with_na_handling(np.array(lvalues), np.array(rvalues), op) + tm.assert_numpy_array_equal(result, expected) + + result = comparison_op(TestArray(lvalues), TestArray(rvalues), op) + expected = expected_with_na_handling(TestArray(lvalues), TestArray(rvalues), op) + tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/arithmetic/test_bool.py b/pandas/tests/arithmetic/test_bool.py new file mode 100644 index 0000000000000000000000000000000000000000..3723b7042a3ce77bcf21d34c77fff01ed31eceb4 --- /dev/null +++ b/pandas/tests/arithmetic/test_bool.py @@ -0,0 +1,28 @@ +import pytest + +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm + + +def test_divmod_bool_raises(box_with_array): + # GH#46043 // raises, so divmod should too + ser = Series([True, False]) + obj = tm.box_expected(ser, box_with_array) + + msg = "operator 'floordiv' not implemented for bool dtypes" + with pytest.raises(NotImplementedError, match=msg): + obj // obj + + if box_with_array is DataFrame: + msg = "operator 'floordiv' not implemented for bool dtypes" + else: + msg = "operator 'divmod' not implemented for bool dtypes" + with pytest.raises(NotImplementedError, match=msg): + divmod(obj, obj) + + # go through __rdivmod__ + with pytest.raises(NotImplementedError, match=msg): + divmod(True, obj) diff --git a/pandas/tests/arithmetic/test_categorical.py b/pandas/tests/arithmetic/test_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..d6f3a13ce670596a12ca10b9e8d02d69d63c96fb --- /dev/null +++ b/pandas/tests/arithmetic/test_categorical.py @@ -0,0 +1,25 @@ +import numpy as np + +from pandas import ( + Categorical, + Series, +) +import pandas._testing as tm + + +class TestCategoricalComparisons: + def test_categorical_nan_equality(self): + cat = Series(Categorical(["a", "b", "c", np.nan])) + expected = Series([True, True, True, False]) + result = cat == cat + tm.assert_series_equal(result, expected) + + def test_categorical_tuple_equality(self): + # GH 18050 + ser = Series([(0, 0), (0, 1), (0, 0), (1, 0), (1, 1)]) + expected = Series([True, False, True, False, False]) + result = ser == (0, 0) + tm.assert_series_equal(result, expected) + + result = ser.astype("category") == (0, 0) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/arithmetic/test_datetime64.py b/pandas/tests/arithmetic/test_datetime64.py new file mode 100644 index 0000000000000000000000000000000000000000..05d0a9c0626af83b3eb43a683a94a0229376ff88 --- /dev/null +++ b/pandas/tests/arithmetic/test_datetime64.py @@ -0,0 +1,2500 @@ +# Arithmetic tests for DataFrame/Series/Index/Array classes that should +# behave identically. +# Specifically for datetime64 and datetime64tz dtypes +from datetime import ( + datetime, + time, + timedelta, + timezone, +) +from itertools import ( + product, +) +import operator + +import numpy as np +import pytest + +from pandas._libs.tslibs.conversion import localize_pydatetime +from pandas._libs.tslibs.offsets import shift_months + +import pandas as pd +from pandas import ( + DateOffset, + DatetimeIndex, + NaT, + Period, + Series, + Timedelta, + TimedeltaIndex, + Timestamp, + date_range, +) +import pandas._testing as tm +from pandas.core import roperator +from pandas.tests.arithmetic.common import ( + assert_cannot_add, + assert_invalid_addsub_type, + assert_invalid_comparison, + get_upcast_box, +) + +# ------------------------------------------------------------------ +# Comparisons + + +class TestDatetime64ArrayLikeComparisons: + # Comparison tests for datetime64 vectors fully parametrized over + # DataFrame/Series/DatetimeIndex/DatetimeArray. Ideally all comparison + # tests will eventually end up here. + + def test_compare_zerodim(self, tz_naive_fixture, box_with_array): + # Test comparison with zero-dimensional array is unboxed + tz = tz_naive_fixture + box = box_with_array + dti = date_range("20130101", periods=3, tz=tz) + + other = np.array(dti.to_numpy()[0]) + + dtarr = tm.box_expected(dti, box) + xbox = get_upcast_box(dtarr, other, True) + result = dtarr <= other + expected = np.array([True, False, False]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "other", + [ + "foo", + -1, + 99, + 4.0, + object(), + timedelta(days=2), + # GH#19800, GH#19301 datetime.date comparison raises to + # match DatetimeIndex/Timestamp. This also matches the behavior + # of stdlib datetime.datetime + datetime(2001, 1, 1).date(), + # GH#19301 None and NaN are *not* cast to NaT for comparisons + None, + np.nan, + ], + ) + def test_dt64arr_cmp_scalar_invalid(self, other, tz_naive_fixture, box_with_array): + # GH#22074, GH#15966 + tz = tz_naive_fixture + + rng = date_range("1/1/2000", periods=10, tz=tz) + dtarr = tm.box_expected(rng, box_with_array) + assert_invalid_comparison(dtarr, other, box_with_array) + + @pytest.mark.parametrize( + "other", + [ + # GH#4968 invalid date/int comparisons + list(range(10)), + np.arange(10), + np.arange(10).astype(np.float32), + np.arange(10).astype(object), + pd.timedelta_range("1ns", periods=10).array, + np.array(pd.timedelta_range("1ns", periods=10)), + list(pd.timedelta_range("1ns", periods=10)), + pd.timedelta_range("1 Day", periods=10).astype(object), + pd.period_range("1971-01-01", freq="D", periods=10).array, + pd.period_range("1971-01-01", freq="D", periods=10).astype(object), + ], + ) + def test_dt64arr_cmp_arraylike_invalid( + self, other, tz_naive_fixture, box_with_array + ): + tz = tz_naive_fixture + + dta = date_range("1970-01-01", freq="ns", periods=10, tz=tz)._data + obj = tm.box_expected(dta, box_with_array) + assert_invalid_comparison(obj, other, box_with_array) + + def test_dt64arr_cmp_mixed_invalid(self, tz_naive_fixture): + tz = tz_naive_fixture + + dta = date_range("1970-01-01", freq="h", periods=5, tz=tz)._data + + other = np.array([0, 1, 2, dta[3], Timedelta(days=1)]) + result = dta == other + expected = np.array([False, False, False, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = dta != other + tm.assert_numpy_array_equal(result, ~expected) + + msg = "Invalid comparison between|Cannot compare type|not supported between" + with pytest.raises(TypeError, match=msg): + dta < other + with pytest.raises(TypeError, match=msg): + dta > other + with pytest.raises(TypeError, match=msg): + dta <= other + with pytest.raises(TypeError, match=msg): + dta >= other + + def test_dt64arr_nat_comparison(self, tz_naive_fixture, box_with_array): + # GH#22242, GH#22163 DataFrame considered NaT == ts incorrectly + tz = tz_naive_fixture + box = box_with_array + + ts = Timestamp("2021-01-01", tz=tz) + ser = Series([ts, NaT]) + + obj = tm.box_expected(ser, box) + xbox = get_upcast_box(obj, ts, True) + + expected = Series([True, False], dtype=np.bool_) + expected = tm.box_expected(expected, xbox) + + result = obj == ts + tm.assert_equal(result, expected) + + +class TestDatetime64SeriesComparison: + # TODO: moved from tests.series.test_operators; needs cleanup + + @pytest.mark.parametrize( + "pair", + [ + ( + [Timestamp("2011-01-01"), NaT, Timestamp("2011-01-03")], + [NaT, NaT, Timestamp("2011-01-03")], + ), + ( + [Timedelta("1 days"), NaT, Timedelta("3 days")], + [NaT, NaT, Timedelta("3 days")], + ), + ( + [Period("2011-01", freq="M"), NaT, Period("2011-03", freq="M")], + [NaT, NaT, Period("2011-03", freq="M")], + ), + ], + ) + @pytest.mark.parametrize("reverse", [True, False]) + @pytest.mark.parametrize("dtype", [None, object]) + @pytest.mark.parametrize( + "op, expected", + [ + (operator.eq, [False, False, True]), + (operator.ne, [True, True, False]), + (operator.lt, [False, False, False]), + (operator.gt, [False, False, False]), + (operator.ge, [False, False, True]), + (operator.le, [False, False, True]), + ], + ) + def test_nat_comparisons( + self, + dtype, + index_or_series, + reverse, + pair, + op, + expected, + ): + box = index_or_series + lhs, rhs = pair + if reverse: + # add lhs / rhs switched data + lhs, rhs = rhs, lhs + + left = Series(lhs, dtype=dtype) + right = box(rhs, dtype=dtype) + + result = op(left, right) + + tm.assert_series_equal(result, Series(expected)) + + @pytest.mark.parametrize( + "data", + [ + [Timestamp("2011-01-01"), NaT, Timestamp("2011-01-03")], + [Timedelta("1 days"), NaT, Timedelta("3 days")], + [Period("2011-01", freq="M"), NaT, Period("2011-03", freq="M")], + ], + ) + @pytest.mark.parametrize("dtype", [None, object]) + def test_nat_comparisons_scalar(self, dtype, data, box_with_array): + box = box_with_array + + left = Series(data, dtype=dtype) + left = tm.box_expected(left, box) + xbox = get_upcast_box(left, NaT, True) + + expected = [False, False, False] + expected = tm.box_expected(expected, xbox) + if box is pd.array and dtype is object: + expected = pd.array(expected, dtype="bool") + + tm.assert_equal(left == NaT, expected) + tm.assert_equal(NaT == left, expected) + + expected = [True, True, True] + expected = tm.box_expected(expected, xbox) + if box is pd.array and dtype is object: + expected = pd.array(expected, dtype="bool") + tm.assert_equal(left != NaT, expected) + tm.assert_equal(NaT != left, expected) + + expected = [False, False, False] + expected = tm.box_expected(expected, xbox) + if box is pd.array and dtype is object: + expected = pd.array(expected, dtype="bool") + tm.assert_equal(left < NaT, expected) + tm.assert_equal(NaT > left, expected) + tm.assert_equal(left <= NaT, expected) + tm.assert_equal(NaT >= left, expected) + + tm.assert_equal(left > NaT, expected) + tm.assert_equal(NaT < left, expected) + tm.assert_equal(left >= NaT, expected) + tm.assert_equal(NaT <= left, expected) + + @pytest.mark.parametrize("val", [datetime(2000, 1, 4), datetime(2000, 1, 5)]) + def test_series_comparison_scalars(self, val): + series = Series(date_range("1/1/2000", periods=10)) + + result = series > val + expected = Series([x > val for x in series]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "left,right", [("lt", "gt"), ("le", "ge"), ("eq", "eq"), ("ne", "ne")] + ) + def test_timestamp_compare_series(self, left, right): + # see gh-4982 + # Make sure we can compare Timestamps on the right AND left hand side. + ser = Series(date_range("20010101", periods=10), name="dates") + s_nat = ser.copy(deep=True) + + ser[0] = Timestamp("nat") + ser[3] = Timestamp("nat") + + left_f = getattr(operator, left) + right_f = getattr(operator, right) + + # No NaT + expected = left_f(ser, Timestamp("20010109")) + result = right_f(Timestamp("20010109"), ser) + tm.assert_series_equal(result, expected) + + # NaT + expected = left_f(ser, Timestamp("nat")) + result = right_f(Timestamp("nat"), ser) + tm.assert_series_equal(result, expected) + + # Compare to Timestamp with series containing NaT + expected = left_f(s_nat, Timestamp("20010109")) + result = right_f(Timestamp("20010109"), s_nat) + tm.assert_series_equal(result, expected) + + # Compare to NaT with series containing NaT + expected = left_f(s_nat, NaT) + result = right_f(NaT, s_nat) + tm.assert_series_equal(result, expected) + + def test_dt64arr_timestamp_equality(self, box_with_array): + # GH#11034 + box = box_with_array + + ser = Series([Timestamp("2000-01-29 01:59:00"), Timestamp("2000-01-30"), NaT]) + ser = tm.box_expected(ser, box) + xbox = get_upcast_box(ser, ser, True) + + result = ser != ser + expected = tm.box_expected([False, False, True], xbox) + tm.assert_equal(result, expected) + + if box is pd.DataFrame: + # alignment for frame vs series comparisons deprecated + # in GH#46795 enforced 2.0 + with pytest.raises(ValueError, match="not aligned"): + ser != ser[0] + + else: + result = ser != ser[0] + expected = tm.box_expected([False, True, True], xbox) + tm.assert_equal(result, expected) + + if box is pd.DataFrame: + # alignment for frame vs series comparisons deprecated + # in GH#46795 enforced 2.0 + with pytest.raises(ValueError, match="not aligned"): + ser != ser[2] + else: + result = ser != ser[2] + expected = tm.box_expected([True, True, True], xbox) + tm.assert_equal(result, expected) + + result = ser == ser + expected = tm.box_expected([True, True, False], xbox) + tm.assert_equal(result, expected) + + if box is pd.DataFrame: + # alignment for frame vs series comparisons deprecated + # in GH#46795 enforced 2.0 + with pytest.raises(ValueError, match="not aligned"): + ser == ser[0] + else: + result = ser == ser[0] + expected = tm.box_expected([True, False, False], xbox) + tm.assert_equal(result, expected) + + if box is pd.DataFrame: + # alignment for frame vs series comparisons deprecated + # in GH#46795 enforced 2.0 + with pytest.raises(ValueError, match="not aligned"): + ser == ser[2] + else: + result = ser == ser[2] + expected = tm.box_expected([False, False, False], xbox) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "datetimelike", + [ + Timestamp("20130101"), + datetime(2013, 1, 1), + np.datetime64("2013-01-01T00:00", "ns"), + ], + ) + @pytest.mark.parametrize( + "op,expected", + [ + (operator.lt, [True, False, False, False]), + (operator.le, [True, True, False, False]), + (operator.eq, [False, True, False, False]), + (operator.gt, [False, False, False, True]), + ], + ) + def test_dt64_compare_datetime_scalar(self, datetimelike, op, expected): + # GH#17965, test for ability to compare datetime64[ns] columns + # to datetimelike + ser = Series( + [ + Timestamp("20120101"), + Timestamp("20130101"), + np.nan, + Timestamp("20130103"), + ], + name="A", + ) + result = op(ser, datetimelike) + expected = Series(expected, name="A") + tm.assert_series_equal(result, expected) + + def test_ts_series_numpy_maximum(self): + # GH#50864, test numpy.maximum does not fail + # given a TimeStamp and Series(with dtype datetime64) comparison + ts = Timestamp("2024-07-01") + ts_series = Series( + ["2024-06-01", "2024-07-01", "2024-08-01"], + dtype="datetime64[us]", + ) + + expected = Series( + ["2024-07-01", "2024-07-01", "2024-08-01"], + dtype="datetime64[us]", + ) + + tm.assert_series_equal(expected, np.maximum(ts, ts_series)) + + +class TestDatetimeIndexComparisons: + # TODO: moved from tests.indexes.test_base; parametrize and de-duplicate + def test_comparators(self, comparison_op): + index = date_range("2020-01-01", periods=10) + element = index[len(index) // 2] + element = Timestamp(element).to_datetime64() + + arr = np.array(index) + arr_result = comparison_op(arr, element) + index_result = comparison_op(index, element) + + assert isinstance(index_result, np.ndarray) + tm.assert_numpy_array_equal(arr_result, index_result) + + @pytest.mark.parametrize( + "other", + [datetime(2016, 1, 1), Timestamp("2016-01-01"), np.datetime64("2016-01-01")], + ) + def test_dti_cmp_datetimelike(self, other, tz_naive_fixture): + tz = tz_naive_fixture + dti = date_range("2016-01-01", periods=2, tz=tz) + if tz is not None: + if isinstance(other, np.datetime64): + pytest.skip(f"{type(other).__name__} is not tz aware") + other = localize_pydatetime(other, dti.tzinfo) + + result = dti == other + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = dti > other + expected = np.array([False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = dti >= other + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) + + result = dti < other + expected = np.array([False, False]) + tm.assert_numpy_array_equal(result, expected) + + result = dti <= other + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", [None, object]) + def test_dti_cmp_nat(self, dtype, box_with_array): + left = DatetimeIndex([Timestamp("2011-01-01"), NaT, Timestamp("2011-01-03")]) + right = DatetimeIndex([NaT, NaT, Timestamp("2011-01-03")]) + + left = tm.box_expected(left, box_with_array) + right = tm.box_expected(right, box_with_array) + xbox = get_upcast_box(left, right, True) + + lhs, rhs = left, right + if dtype is object: + lhs, rhs = left.astype(object), right.astype(object) + + result = rhs == lhs + expected = np.array([False, False, True]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(result, expected) + + result = lhs != rhs + expected = np.array([True, True, False]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(result, expected) + + expected = np.array([False, False, False]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(lhs == NaT, expected) + tm.assert_equal(NaT == rhs, expected) + + expected = np.array([True, True, True]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(lhs != NaT, expected) + tm.assert_equal(NaT != lhs, expected) + + expected = np.array([False, False, False]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(lhs < NaT, expected) + tm.assert_equal(NaT > lhs, expected) + + def test_dti_cmp_nat_behaves_like_float_cmp_nan(self): + fidx1 = pd.Index([1.0, np.nan, 3.0, np.nan, 5.0, 7.0]) + fidx2 = pd.Index([2.0, 3.0, np.nan, np.nan, 6.0, 7.0]) + + didx1 = DatetimeIndex( + ["2014-01-01", NaT, "2014-03-01", NaT, "2014-05-01", "2014-07-01"] + ) + didx2 = DatetimeIndex( + ["2014-02-01", "2014-03-01", NaT, NaT, "2014-06-01", "2014-07-01"] + ) + darr = np.array( + [ + np.datetime64("2014-02-01 00:00"), + np.datetime64("2014-03-01 00:00"), + np.datetime64("nat"), + np.datetime64("nat"), + np.datetime64("2014-06-01 00:00"), + np.datetime64("2014-07-01 00:00"), + ] + ) + + cases = [(fidx1, fidx2), (didx1, didx2), (didx1, darr)] + + # Check pd.NaT is handles as the same as np.nan + with tm.assert_produces_warning(None): + for idx1, idx2 in cases: + result = idx1 < idx2 + expected = np.array([True, False, False, False, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = idx2 > idx1 + expected = np.array([True, False, False, False, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 <= idx2 + expected = np.array([True, False, False, False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + result = idx2 >= idx1 + expected = np.array([True, False, False, False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 == idx2 + expected = np.array([False, False, False, False, False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 != idx2 + expected = np.array([True, True, True, True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + with tm.assert_produces_warning(None): + for idx1, val in [(fidx1, np.nan), (didx1, NaT)]: + result = idx1 < val + expected = np.array([False, False, False, False, False, False]) + tm.assert_numpy_array_equal(result, expected) + result = idx1 > val + tm.assert_numpy_array_equal(result, expected) + + result = idx1 <= val + tm.assert_numpy_array_equal(result, expected) + result = idx1 >= val + tm.assert_numpy_array_equal(result, expected) + + result = idx1 == val + tm.assert_numpy_array_equal(result, expected) + + result = idx1 != val + expected = np.array([True, True, True, True, True, True]) + tm.assert_numpy_array_equal(result, expected) + + # Check pd.NaT is handles as the same as np.nan + with tm.assert_produces_warning(None): + for idx1, val in [(fidx1, 3), (didx1, datetime(2014, 3, 1))]: + result = idx1 < val + expected = np.array([True, False, False, False, False, False]) + tm.assert_numpy_array_equal(result, expected) + result = idx1 > val + expected = np.array([False, False, False, False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 <= val + expected = np.array([True, False, True, False, False, False]) + tm.assert_numpy_array_equal(result, expected) + result = idx1 >= val + expected = np.array([False, False, True, False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 == val + expected = np.array([False, False, True, False, False, False]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 != val + expected = np.array([True, True, False, True, True, True]) + tm.assert_numpy_array_equal(result, expected) + + def test_comparison_tzawareness_compat(self, comparison_op, box_with_array): + # GH#18162 + op = comparison_op + box = box_with_array + + dr = date_range("2016-01-01", periods=6, unit="ns") + dz = dr.tz_localize("US/Pacific") + + dr = tm.box_expected(dr, box) + dz = tm.box_expected(dz, box) + + if box is pd.DataFrame: + tolist = lambda x: x.astype(object).values.tolist()[0] + else: + tolist = list + + if op not in [operator.eq, operator.ne]: + msg = ( + r"Invalid comparison between dtype=datetime64\[ns.*\] " + "and (Timestamp|DatetimeArray|list|ndarray)" + ) + with pytest.raises(TypeError, match=msg): + op(dr, dz) + + with pytest.raises(TypeError, match=msg): + op(dr, tolist(dz)) + with pytest.raises(TypeError, match=msg): + op(dr, np.array(tolist(dz), dtype=object)) + with pytest.raises(TypeError, match=msg): + op(dz, dr) + + with pytest.raises(TypeError, match=msg): + op(dz, tolist(dr)) + with pytest.raises(TypeError, match=msg): + op(dz, np.array(tolist(dr), dtype=object)) + + # The aware==aware and naive==naive comparisons should *not* raise + assert np.all(dr == dr) + assert np.all(dr == tolist(dr)) + assert np.all(tolist(dr) == dr) + assert np.all(np.array(tolist(dr), dtype=object) == dr) + assert np.all(dr == np.array(tolist(dr), dtype=object)) + + assert np.all(dz == dz) + assert np.all(dz == tolist(dz)) + assert np.all(tolist(dz) == dz) + assert np.all(np.array(tolist(dz), dtype=object) == dz) + assert np.all(dz == np.array(tolist(dz), dtype=object)) + + def test_comparison_tzawareness_compat_scalars(self, comparison_op, box_with_array): + # GH#18162 + op = comparison_op + + dr = date_range("2016-01-01", periods=6, unit="ns") + dz = dr.tz_localize("US/Pacific") + + dr = tm.box_expected(dr, box_with_array) + dz = tm.box_expected(dz, box_with_array) + + # Check comparisons against scalar Timestamps + ts = Timestamp("2000-03-14 01:59") + ts_tz = Timestamp("2000-03-14 01:59", tz="Europe/Amsterdam") + + assert np.all(dr > ts) + msg = r"Invalid comparison between dtype=datetime64\[ns.*\] and Timestamp" + if op not in [operator.eq, operator.ne]: + with pytest.raises(TypeError, match=msg): + op(dr, ts_tz) + + assert np.all(dz > ts_tz) + if op not in [operator.eq, operator.ne]: + with pytest.raises(TypeError, match=msg): + op(dz, ts) + + if op not in [operator.eq, operator.ne]: + # GH#12601: Check comparison against Timestamps and DatetimeIndex + with pytest.raises(TypeError, match=msg): + op(ts, dz) + + @pytest.mark.parametrize( + "other", + [datetime(2016, 1, 1), Timestamp("2016-01-01"), np.datetime64("2016-01-01")], + ) + # Bug in NumPy? https://github.com/numpy/numpy/issues/13841 + # Raising in __eq__ will fallback to NumPy, which warns, fails, + # then re-raises the original exception. So we just need to ignore. + @pytest.mark.filterwarnings("ignore:elementwise comp:DeprecationWarning") + def test_scalar_comparison_tzawareness( + self, comparison_op, other, tz_aware_fixture, box_with_array + ): + op = comparison_op + tz = tz_aware_fixture + dti = date_range("2016-01-01", periods=2, tz=tz, unit="ns") + + dtarr = tm.box_expected(dti, box_with_array) + xbox = get_upcast_box(dtarr, other, True) + if op in [operator.eq, operator.ne]: + exbool = op is operator.ne + expected = np.array([exbool, exbool], dtype=bool) + expected = tm.box_expected(expected, xbox) + + result = op(dtarr, other) + tm.assert_equal(result, expected) + + result = op(other, dtarr) + tm.assert_equal(result, expected) + else: + msg = ( + r"Invalid comparison between dtype=datetime64\[ns, .*\] " + f"and {type(other).__name__}" + ) + with pytest.raises(TypeError, match=msg): + op(dtarr, other) + with pytest.raises(TypeError, match=msg): + op(other, dtarr) + + def test_nat_comparison_tzawareness(self, comparison_op): + # GH#19276 + # tzaware DatetimeIndex should not raise when compared to NaT + op = comparison_op + + dti = DatetimeIndex( + ["2014-01-01", NaT, "2014-03-01", NaT, "2014-05-01", "2014-07-01"] + ) + expected = np.array([op == operator.ne] * len(dti)) + result = op(dti, NaT) + tm.assert_numpy_array_equal(result, expected) + + result = op(dti.tz_localize("US/Pacific"), NaT) + tm.assert_numpy_array_equal(result, expected) + + def test_dti_cmp_str(self, tz_naive_fixture): + # GH#22074 + # regardless of tz, we expect these comparisons are valid + tz = tz_naive_fixture + rng = date_range("1/1/2000", periods=10, tz=tz) + other = "1/1/2000" + + result = rng == other + expected = np.array([True] + [False] * 9) + tm.assert_numpy_array_equal(result, expected) + + result = rng != other + expected = np.array([False] + [True] * 9) + tm.assert_numpy_array_equal(result, expected) + + result = rng < other + expected = np.array([False] * 10) + tm.assert_numpy_array_equal(result, expected) + + result = rng <= other + expected = np.array([True] + [False] * 9) + tm.assert_numpy_array_equal(result, expected) + + result = rng > other + expected = np.array([False] + [True] * 9) + tm.assert_numpy_array_equal(result, expected) + + result = rng >= other + expected = np.array([True] * 10) + tm.assert_numpy_array_equal(result, expected) + + def test_dti_cmp_list(self): + rng = date_range("1/1/2000", periods=10) + + result = rng == list(rng) + expected = rng == rng + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "other", + [ + pd.timedelta_range("1D", periods=10), + pd.timedelta_range("1D", periods=10).to_series(), + pd.timedelta_range("1D", periods=10).asi8.view("m8[ns]"), + ], + ids=lambda x: type(x).__name__, + ) + def test_dti_cmp_tdi_tzawareness(self, other): + # GH#22074 + # reversion test that we _don't_ call _assert_tzawareness_compat + # when comparing against TimedeltaIndex + dti = date_range("2000-01-01", periods=10, tz="Asia/Tokyo") + + result = dti == other + expected = np.array([False] * 10) + if isinstance(other, Series): + tm.assert_series_equal(result, Series(expected, index=other.index)) + else: + tm.assert_numpy_array_equal(result, expected) + + result = dti != other + expected = np.array([True] * 10) + if isinstance(other, Series): + tm.assert_series_equal(result, Series(expected, index=other.index)) + else: + tm.assert_numpy_array_equal(result, expected) + + msg = "Invalid comparison between" + with pytest.raises(TypeError, match=msg): + dti < other + with pytest.raises(TypeError, match=msg): + dti <= other + with pytest.raises(TypeError, match=msg): + dti > other + with pytest.raises(TypeError, match=msg): + dti >= other + + def test_dti_cmp_object_dtype(self): + # GH#22074 + dti = date_range("2000-01-01", periods=10, tz="Asia/Tokyo") + + other = dti.astype("O") + + result = dti == other + expected = np.array([True] * 10) + tm.assert_numpy_array_equal(result, expected) + + other = dti.tz_localize(None) + result = dti != other + tm.assert_numpy_array_equal(result, expected) + + other = np.array(list(dti[:5]) + [Timedelta(days=1)] * 5) + result = dti == other + expected = np.array([True] * 5 + [False] * 5) + tm.assert_numpy_array_equal(result, expected) + msg = ">=' not supported between instances of 'Timestamp' and 'Timedelta'" + with pytest.raises(TypeError, match=msg): + dti >= other + + +# ------------------------------------------------------------------ +# Arithmetic + + +class TestDatetime64Arithmetic: + # This class is intended for "finished" tests that are fully parametrized + # over DataFrame/Series/Index/DatetimeArray + + # ------------------------------------------------------------- + # Addition/Subtraction of timedelta-like + + @pytest.mark.arm_slow + def test_dt64arr_add_timedeltalike_scalar( + self, tz_naive_fixture, two_hours, box_with_array + ): + # GH#22005, GH#22163 check DataFrame doesn't raise TypeError + tz = tz_naive_fixture + + rng = date_range("2000-01-01", "2000-02-01", tz=tz, unit="ns") + expected = date_range("2000-01-01 02:00", "2000-02-01 02:00", tz=tz, unit="ns") + if tz is not None: + expected = expected._with_freq(None) + + rng = tm.box_expected(rng, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = rng + two_hours + tm.assert_equal(result, expected) + + result = two_hours + rng + tm.assert_equal(result, expected) + + rng += two_hours + tm.assert_equal(rng, expected) + + def test_dt64arr_sub_timedeltalike_scalar( + self, tz_naive_fixture, two_hours, box_with_array + ): + tz = tz_naive_fixture + + rng = date_range("2000-01-01", "2000-02-01", tz=tz, unit="ns") + expected = date_range("1999-12-31 22:00", "2000-01-31 22:00", tz=tz, unit="ns") + if tz is not None: + expected = expected._with_freq(None) + + rng = tm.box_expected(rng, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = rng - two_hours + tm.assert_equal(result, expected) + + rng -= two_hours + tm.assert_equal(rng, expected) + + def test_dt64_array_sub_dt_with_different_timezone(self, box_with_array): + t1 = date_range("20130101", periods=3).tz_localize("US/Eastern") + t1 = tm.box_expected(t1, box_with_array) + t2 = Timestamp("20130101").tz_localize("CET") + tnaive = Timestamp(20130101) + + result = t1 - t2 + expected = TimedeltaIndex( + ["0 days 06:00:00", "1 days 06:00:00", "2 days 06:00:00"] + ) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + result = t2 - t1 + expected = TimedeltaIndex( + ["-1 days +18:00:00", "-2 days +18:00:00", "-3 days +18:00:00"] + ) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + msg = "Cannot subtract tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg): + t1 - tnaive + + with pytest.raises(TypeError, match=msg): + tnaive - t1 + + def test_dt64_array_sub_dt64_array_with_different_timezone(self, box_with_array): + t1 = date_range("20130101", periods=3).tz_localize("US/Eastern") + t1 = tm.box_expected(t1, box_with_array) + t2 = date_range("20130101", periods=3).tz_localize("CET") + t2 = tm.box_expected(t2, box_with_array) + tnaive = date_range("20130101", periods=3) + + result = t1 - t2 + expected = TimedeltaIndex( + ["0 days 06:00:00", "0 days 06:00:00", "0 days 06:00:00"] + ) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + result = t2 - t1 + expected = TimedeltaIndex( + ["-1 days +18:00:00", "-1 days +18:00:00", "-1 days +18:00:00"] + ) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + msg = "Cannot subtract tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg): + t1 - tnaive + + with pytest.raises(TypeError, match=msg): + tnaive - t1 + + def test_dt64arr_add_sub_td64_nat(self, box_with_array, tz_naive_fixture): + # GH#23320 special handling for timedelta64("NaT") + tz = tz_naive_fixture + + dti = date_range("1994-04-01", periods=9, tz=tz, freq="QS", unit="ns") + other = np.timedelta64("NaT") + expected = DatetimeIndex(["NaT"] * 9, tz=tz).as_unit("ns") + + obj = tm.box_expected(dti, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = obj + other + tm.assert_equal(result, expected) + result = other + obj + tm.assert_equal(result, expected) + result = obj - other + tm.assert_equal(result, expected) + msg = "cannot subtract" + with pytest.raises(TypeError, match=msg): + other - obj + + def test_dt64arr_add_sub_td64ndarray(self, tz_naive_fixture, box_with_array): + tz = tz_naive_fixture + dti = date_range("2016-01-01", periods=3, tz=tz) + tdi = TimedeltaIndex(["-1 Day", "-1 Day", "-1 Day"]) + tdarr = tdi.values + + expected = date_range("2015-12-31", "2016-01-02", periods=3, tz=tz) + + dtarr = tm.box_expected(dti, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = dtarr + tdarr + tm.assert_equal(result, expected) + result = tdarr + dtarr + tm.assert_equal(result, expected) + + expected = date_range("2016-01-02", "2016-01-04", periods=3, tz=tz) + expected = tm.box_expected(expected, box_with_array) + + result = dtarr - tdarr + tm.assert_equal(result, expected) + msg = "|".join( + [ + "cannot subtract DatetimeArray from ndarray", + "cannot subtract a datelike from a TimedeltaArray", + "cannot subtract DatetimeArray from Timedelta", + ] + ) + with pytest.raises(TypeError, match=msg): + tdarr - dtarr + + # ----------------------------------------------------------------- + # Subtraction of datetime-like scalars + + @pytest.mark.parametrize( + "ts", + [ + Timestamp("2013-01-01"), + Timestamp("2013-01-01").to_pydatetime(), + Timestamp("2013-01-01").to_datetime64(), + # GH#7996, GH#22163 ensure non-nano datetime64 is converted to nano + # for DataFrame operation + np.datetime64("2013-01-01", "D"), + ], + ) + def test_dt64arr_sub_dtscalar(self, box_with_array, ts): + # GH#8554, GH#22163 DataFrame op should _not_ return dt64 dtype + idx = date_range("2013-01-01", periods=3)._with_freq(None) + idx = tm.box_expected(idx, box_with_array) + + expected = TimedeltaIndex(["0 Days", "1 Day", "2 Days"]) + expected = tm.box_expected(expected, box_with_array) + + result = idx - ts + tm.assert_equal(result, expected) + + result = ts - idx + tm.assert_equal(result, -expected) + tm.assert_equal(result, -expected) + + def test_dt64arr_sub_timestamp_tzaware(self, box_with_array): + ser = date_range("2014-03-17", periods=2, freq="D", tz="US/Eastern", unit="ns") + ser = ser._with_freq(None) + ts = ser[0] + + ser = tm.box_expected(ser, box_with_array) + + delta_series = Series( + [np.timedelta64(0, "D"), np.timedelta64(1, "D")], dtype="m8[ns]" + ) + expected = tm.box_expected(delta_series, box_with_array) + + tm.assert_equal(ser - ts, expected) + tm.assert_equal(ts - ser, -expected) + + def test_dt64arr_sub_NaT(self, box_with_array, unit): + # GH#18808 + dti = DatetimeIndex([NaT, Timestamp("19900315")]).as_unit(unit) + ser = tm.box_expected(dti, box_with_array) + + result = ser - NaT + expected = Series([NaT, NaT], dtype=f"timedelta64[{unit}]") + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + dti_tz = dti.tz_localize("Asia/Tokyo") + ser_tz = tm.box_expected(dti_tz, box_with_array) + + result = ser_tz - NaT + expected = Series([NaT, NaT], dtype=f"timedelta64[{unit}]") + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + # ------------------------------------------------------------- + # Subtraction of datetime-like array-like + + def test_dt64arr_sub_dt64object_array( + self, performance_warning, box_with_array, tz_naive_fixture + ): + dti = date_range("2016-01-01", periods=3, tz=tz_naive_fixture) + expected = dti - dti + + obj = tm.box_expected(dti, box_with_array) + expected = tm.box_expected(expected, box_with_array).astype(object) + + with tm.assert_produces_warning(performance_warning): + result = obj - obj.astype(object) + tm.assert_equal(result, expected) + + def test_dt64arr_naive_sub_dt64ndarray(self, box_with_array): + dti = date_range("2016-01-01", periods=3, tz=None) + dt64vals = dti.values + + dtarr = tm.box_expected(dti, box_with_array) + + expected = dtarr - dtarr + result = dtarr - dt64vals + tm.assert_equal(result, expected) + result = dt64vals - dtarr + tm.assert_equal(result, expected) + + def test_dt64arr_aware_sub_dt64ndarray_raises( + self, tz_aware_fixture, box_with_array + ): + tz = tz_aware_fixture + dti = date_range("2016-01-01", periods=3, tz=tz) + dt64vals = dti.values + + dtarr = tm.box_expected(dti, box_with_array) + msg = "Cannot subtract tz-naive and tz-aware datetime" + with pytest.raises(TypeError, match=msg): + dtarr - dt64vals + with pytest.raises(TypeError, match=msg): + dt64vals - dtarr + + # ------------------------------------------------------------- + # Addition of datetime-like others (invalid) + + def test_dt64arr_add_dtlike_raises(self, tz_naive_fixture, box_with_array): + # GH#22163 ensure DataFrame doesn't cast Timestamp to i8 + # GH#9631 + tz = tz_naive_fixture + + dti = date_range("2016-01-01", periods=3, tz=tz) + if tz is None: + dti2 = dti.tz_localize("US/Eastern") + else: + dti2 = dti.tz_localize(None) + dtarr = tm.box_expected(dti, box_with_array) + + assert_cannot_add(dtarr, dti.values) + assert_cannot_add(dtarr, dti) + assert_cannot_add(dtarr, dtarr) + assert_cannot_add(dtarr, dti[0]) + assert_cannot_add(dtarr, dti[0].to_pydatetime()) + assert_cannot_add(dtarr, dti[0].to_datetime64()) + assert_cannot_add(dtarr, dti2[0]) + assert_cannot_add(dtarr, dti2[0].to_pydatetime()) + assert_cannot_add(dtarr, np.datetime64("2011-01-01", "D")) + + # ------------------------------------------------------------- + # Other Invalid Addition/Subtraction + + # Note: freq here includes both Tick and non-Tick offsets; this is + # relevant because historically integer-addition was allowed if we had + # a freq. + @pytest.mark.parametrize("freq", ["h", "D", "W", "2ME", "MS", "QE", "B", None]) + @pytest.mark.parametrize("dtype", [None, "uint8"]) + def test_dt64arr_addsub_intlike( + self, dtype, index_or_series_or_array, freq, tz_naive_fixture + ): + # GH#19959, GH#19123, GH#19012 + # GH#55860 use index_or_series_or_array instead of box_with_array + # bc DataFrame alignment makes it inapplicable + tz = tz_naive_fixture + + if freq is None: + dti = DatetimeIndex(["NaT", "2017-04-05 06:07:08"], tz=tz) + else: + dti = date_range("2016-01-01", periods=2, freq=freq, tz=tz) + + obj = index_or_series_or_array(dti) + other = np.array([4, -1]) + if dtype is not None: + other = other.astype(dtype) + + msg = "|".join( + [ + "Addition/subtraction of integers", + "cannot subtract DatetimeArray from", + # IntegerArray + "can only perform ops with numeric values", + "unsupported operand type.*Categorical", + r"unsupported operand type\(s\) for -: 'int' and 'Timestamp'", + ] + ) + assert_invalid_addsub_type(obj, 1, msg) + assert_invalid_addsub_type(obj, np.int64(2), msg) + assert_invalid_addsub_type(obj, np.array(3, dtype=np.int64), msg) + assert_invalid_addsub_type(obj, other, msg) + assert_invalid_addsub_type(obj, np.array(other), msg) + assert_invalid_addsub_type(obj, pd.array(other), msg) + assert_invalid_addsub_type(obj, pd.Categorical(other), msg) + assert_invalid_addsub_type(obj, pd.Index(other), msg) + assert_invalid_addsub_type(obj, Series(other), msg) + + @pytest.mark.parametrize( + "other", + [ + 3.14, + np.array([2.0, 3.0]), + # GH#13078 datetime +/- Period is invalid + Period("2011-01-01", freq="D"), + # https://github.com/pandas-dev/pandas/issues/10329 + time(1, 2, 3), + ], + ) + @pytest.mark.parametrize("dti_freq", [None, "D"]) + def test_dt64arr_add_sub_invalid(self, dti_freq, other, box_with_array): + dti = DatetimeIndex(["2011-01-01", "2011-01-02"], freq=dti_freq) + dtarr = tm.box_expected(dti, box_with_array) + msg = "|".join( + [ + "unsupported operand type", + "cannot (add|subtract)", + "cannot use operands with types", + "ufunc '?(add|subtract)'? cannot use operands with types", + "Concatenation operation is not implemented for NumPy arrays", + ] + ) + assert_invalid_addsub_type(dtarr, other, msg) + + @pytest.mark.parametrize("pi_freq", ["D", "W", "Q", "h"]) + @pytest.mark.parametrize("dti_freq", [None, "D"]) + def test_dt64arr_add_sub_parr( + self, dti_freq, pi_freq, box_with_array, box_with_array2 + ): + # GH#20049 subtracting PeriodIndex should raise TypeError + dti = DatetimeIndex(["2011-01-01", "2011-01-02"], freq=dti_freq) + pi = dti.to_period(pi_freq) + + dtarr = tm.box_expected(dti, box_with_array) + parr = tm.box_expected(pi, box_with_array2) + msg = "|".join( + [ + "cannot (add|subtract)", + "unsupported operand", + "descriptor.*requires", + "ufunc.*cannot use operands", + ] + ) + assert_invalid_addsub_type(dtarr, parr, msg) + + @pytest.mark.filterwarnings("ignore::pandas.errors.PerformanceWarning") + def test_dt64arr_addsub_time_objects_raises(self, box_with_array, tz_naive_fixture): + # https://github.com/pandas-dev/pandas/issues/10329 + + tz = tz_naive_fixture + + obj1 = date_range("2012-01-01", periods=3, tz=tz) + obj2 = [time(i, i, i) for i in range(3)] + + obj1 = tm.box_expected(obj1, box_with_array) + obj2 = tm.box_expected(obj2, box_with_array) + + msg = "|".join( + [ + "unsupported operand", + "cannot subtract DatetimeArray from ndarray", + ] + ) + # pandas.errors.PerformanceWarning: Non-vectorized DateOffset being + # applied to Series or DatetimeIndex + # we aren't testing that here, so ignore. + assert_invalid_addsub_type(obj1, obj2, msg=msg) + + # ------------------------------------------------------------- + # Other invalid operations + + @pytest.mark.parametrize( + "dt64_series", + [ + Series([Timestamp("19900315"), Timestamp("19900315")]), + Series([NaT, Timestamp("19900315")]), + Series([NaT, NaT], dtype="datetime64[ns]"), + ], + ) + @pytest.mark.parametrize("one", [1, 1.0, np.array(1)]) + def test_dt64_mul_div_numeric_invalid(self, one, dt64_series, box_with_array): + obj = tm.box_expected(dt64_series, box_with_array) + + msg = "cannot perform .* with this index type" + + # multiplication + with pytest.raises(TypeError, match=msg): + obj * one + with pytest.raises(TypeError, match=msg): + one * obj + + # division + with pytest.raises(TypeError, match=msg): + obj / one + with pytest.raises(TypeError, match=msg): + one / obj + + +class TestDatetime64DateOffsetArithmetic: + # ------------------------------------------------------------- + # Tick DateOffsets + + # TODO: parametrize over timezone? + def test_dt64arr_series_add_tick_DateOffset(self, box_with_array, unit): + # GH#4532 + # operate with pd.offsets + ser = Series( + [Timestamp("20130101 9:01"), Timestamp("20130101 9:02")] + ).dt.as_unit(unit) + expected = Series( + [Timestamp("20130101 9:01:05"), Timestamp("20130101 9:02:05")] + ).dt.as_unit(unit) + + ser = tm.box_expected(ser, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = ser + pd.offsets.Second(5) + tm.assert_equal(result, expected) + + result2 = pd.offsets.Second(5) + ser + tm.assert_equal(result2, expected) + + def test_dt64arr_series_sub_tick_DateOffset(self, box_with_array): + # GH#4532 + # operate with pd.offsets + ser = Series([Timestamp("20130101 9:01"), Timestamp("20130101 9:02")]) + expected = Series( + [Timestamp("20130101 9:00:55"), Timestamp("20130101 9:01:55")] + ) + + ser = tm.box_expected(ser, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = ser - pd.offsets.Second(5) + tm.assert_equal(result, expected) + + result2 = -pd.offsets.Second(5) + ser + tm.assert_equal(result2, expected) + msg = "cannot subtract DatetimeArray from Second" + with pytest.raises(TypeError, match=msg): + pd.offsets.Second(5) - ser + + @pytest.mark.parametrize( + "cls_name", ["Day", "Hour", "Minute", "Second", "Milli", "Micro", "Nano"] + ) + def test_dt64arr_add_sub_tick_DateOffset_smoke(self, cls_name, box_with_array): + # GH#4532 + # smoke tests for valid DateOffsets + ser = Series([Timestamp("20130101 9:01"), Timestamp("20130101 9:02")]) + ser = tm.box_expected(ser, box_with_array) + + offset_cls = getattr(pd.offsets, cls_name) + ser + offset_cls(5) + offset_cls(5) + ser + ser - offset_cls(5) + + def test_dti_add_tick_tzaware(self, tz_aware_fixture, box_with_array): + # GH#21610, GH#22163 ensure DataFrame doesn't return object-dtype + tz = tz_aware_fixture + if tz == "US/Pacific": + dates = date_range("2012-11-01", periods=3, tz=tz, unit="ns") + offset = dates + pd.offsets.Hour(5) + assert dates[0] + pd.offsets.Hour(5) == offset[0] + + dates = date_range("2010-11-01 00:00", periods=3, tz=tz, freq="h", unit="ns") + expected = DatetimeIndex( + ["2010-11-01 05:00", "2010-11-01 06:00", "2010-11-01 07:00"], + freq="h", + tz=tz, + ).as_unit("ns") + + dates = tm.box_expected(dates, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + for scalar in [pd.offsets.Hour(5), np.timedelta64(5, "h"), timedelta(hours=5)]: + offset = dates + scalar + tm.assert_equal(offset, expected) + offset = scalar + dates + tm.assert_equal(offset, expected) + + roundtrip = offset - scalar + tm.assert_equal(roundtrip, dates) + + msg = "cannot subtract DatetimeArray from" + with pytest.raises(TypeError, match=msg): + scalar - dates + + # ------------------------------------------------------------- + # RelativeDelta DateOffsets + + def test_dt64arr_add_sub_relativedelta_offsets(self, box_with_array, unit): + # GH#10699 + vec = DatetimeIndex( + [ + Timestamp("2000-01-05 00:15:00"), + Timestamp("2000-01-31 00:23:00"), + Timestamp("2000-01-01"), + Timestamp("2000-03-31"), + Timestamp("2000-02-29"), + Timestamp("2000-12-31"), + Timestamp("2000-05-15"), + Timestamp("2001-06-15"), + ] + ).as_unit(unit) + vec = tm.box_expected(vec, box_with_array) + vec_items = vec.iloc[0] if box_with_array is pd.DataFrame else vec + + # DateOffset relativedelta fastpath + relative_kwargs = [ + ("years", 2), + ("months", 5), + ("days", 3), + ("hours", 5), + ("minutes", 10), + ("seconds", 2), + ("microseconds", 5), + ] + for i, (offset_unit, value) in enumerate(relative_kwargs): + off = DateOffset(**{offset_unit: value}) + + exp_unit = unit + if offset_unit == "microseconds" and unit != "ns": + exp_unit = "us" + + # TODO(GH#55564): as_unit will be unnecessary + expected = DatetimeIndex([x + off for x in vec_items]).as_unit(exp_unit) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(expected, vec + off) + + expected = DatetimeIndex([x - off for x in vec_items]).as_unit(exp_unit) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(expected, vec - off) + + off = DateOffset(**dict(relative_kwargs[: i + 1])) + + expected = DatetimeIndex([x + off for x in vec_items]).as_unit(exp_unit) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(expected, vec + off) + + expected = DatetimeIndex([x - off for x in vec_items]).as_unit(exp_unit) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(expected, vec - off) + msg = "cannot subtract DatetimeArray from" + with pytest.raises(TypeError, match=msg): + off - vec + + # ------------------------------------------------------------- + # Non-Tick, Non-RelativeDelta DateOffsets + + # TODO: redundant with test_dt64arr_add_sub_DateOffset? that includes + # tz-aware cases which this does not + @pytest.mark.filterwarnings("ignore::pandas.errors.PerformanceWarning") + @pytest.mark.parametrize( + "cls_and_kwargs", + [ + "YearBegin", + ("YearBegin", {"month": 5}), + "YearEnd", + ("YearEnd", {"month": 5}), + "MonthBegin", + "MonthEnd", + "SemiMonthEnd", + "SemiMonthBegin", + "Week", + ("Week", {"weekday": 3}), + ("Week", {"weekday": 6}), + "BusinessDay", + "BDay", + "QuarterEnd", + "QuarterBegin", + "CustomBusinessDay", + "CDay", + "CBMonthEnd", + "CBMonthBegin", + "BMonthBegin", + "BMonthEnd", + "BusinessHour", + "BYearBegin", + "BYearEnd", + "BQuarterBegin", + ("LastWeekOfMonth", {"weekday": 2}), + ( + "FY5253Quarter", + { + "qtr_with_extra_week": 1, + "startingMonth": 1, + "weekday": 2, + "variation": "nearest", + }, + ), + ("FY5253", {"weekday": 0, "startingMonth": 2, "variation": "nearest"}), + ("WeekOfMonth", {"weekday": 2, "week": 2}), + "Easter", + ("DateOffset", {"day": 4}), + ("DateOffset", {"month": 5}), + ], + ) + @pytest.mark.parametrize("normalize", [True, False]) + @pytest.mark.parametrize("n", [0, 5]) + @pytest.mark.parametrize("tz", [None, "US/Central"]) + def test_dt64arr_add_sub_DateOffsets( + self, box_with_array, n, normalize, cls_and_kwargs, unit, tz + ): + # GH#10699 + # assert vectorized operation matches pointwise operations + + if isinstance(cls_and_kwargs, tuple): + # If cls_name param is a tuple, then 2nd entry is kwargs for + # the offset constructor + cls_name, kwargs = cls_and_kwargs + else: + cls_name = cls_and_kwargs + kwargs = {} + + if n == 0 and cls_name in [ + "WeekOfMonth", + "LastWeekOfMonth", + "FY5253Quarter", + "FY5253", + ]: + # passing n = 0 is invalid for these offset classes + return + + vec = ( + DatetimeIndex( + [ + Timestamp("2000-01-05 00:15:00"), + Timestamp("2000-01-31 00:23:00"), + Timestamp("2000-01-01"), + Timestamp("2000-03-31"), + Timestamp("2000-02-29"), + Timestamp("2000-12-31"), + Timestamp("2000-05-15"), + Timestamp("2001-06-15"), + ] + ) + .as_unit(unit) + .tz_localize(tz) + ) + vec = tm.box_expected(vec, box_with_array) + vec_items = vec.iloc[0] if box_with_array is pd.DataFrame else vec + + offset_cls = getattr(pd.offsets, cls_name) + offset = offset_cls(n, normalize=normalize, **kwargs) + + # TODO(GH#55564): as_unit will be unnecessary + expected = DatetimeIndex([x + offset for x in vec_items]).as_unit(unit) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(expected, vec + offset) + tm.assert_equal(expected, offset + vec) + + expected = DatetimeIndex([x - offset for x in vec_items]).as_unit(unit) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(expected, vec - offset) + + expected = DatetimeIndex([offset + x for x in vec_items]).as_unit(unit) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(expected, offset + vec) + msg = "cannot subtract DatetimeArray from" + with pytest.raises(TypeError, match=msg): + offset - vec + + @pytest.mark.parametrize( + "other", + [ + [pd.offsets.MonthEnd(), pd.offsets.Day(n=2)], + [pd.offsets.DateOffset(years=1), pd.offsets.MonthEnd()], + # matching offsets + [pd.offsets.DateOffset(years=1), pd.offsets.DateOffset(years=1)], + ], + ) + @pytest.mark.parametrize("op", [operator.add, roperator.radd, operator.sub]) + def test_dt64arr_add_sub_offset_array( + self, performance_warning, tz_naive_fixture, box_with_array, op, other + ): + # GH#18849 + # GH#10699 array of offsets + + tz = tz_naive_fixture + dti = date_range("2017-01-01", periods=2, tz=tz) + dtarr = tm.box_expected(dti, box_with_array) + other = np.array(other) + expected = DatetimeIndex([op(dti[n], other[n]) for n in range(len(dti))]) + expected = tm.box_expected(expected, box_with_array).astype(object) + + with tm.assert_produces_warning(performance_warning): + res = op(dtarr, other) + tm.assert_equal(res, expected) + + # Same thing but boxing other + other = tm.box_expected(other, box_with_array) + if box_with_array is pd.array and op is roperator.radd: + # We expect a NumpyExtensionArray, not ndarray[object] here + expected = pd.array(expected, dtype=object) + with tm.assert_produces_warning(performance_warning): + res = op(dtarr, other) + tm.assert_equal(res, expected) + + @pytest.mark.parametrize( + "op, offset, exp, exp_freq", + [ + ( + "__add__", + DateOffset(months=3, days=10), + [ + Timestamp("2014-04-11"), + Timestamp("2015-04-11"), + Timestamp("2016-04-11"), + Timestamp("2017-04-11"), + ], + None, + ), + ( + "__add__", + DateOffset(months=3), + [ + Timestamp("2014-04-01"), + Timestamp("2015-04-01"), + Timestamp("2016-04-01"), + Timestamp("2017-04-01"), + ], + "YS-APR", + ), + ( + "__sub__", + DateOffset(months=3, days=10), + [ + Timestamp("2013-09-21"), + Timestamp("2014-09-21"), + Timestamp("2015-09-21"), + Timestamp("2016-09-21"), + ], + None, + ), + ( + "__sub__", + DateOffset(months=3), + [ + Timestamp("2013-10-01"), + Timestamp("2014-10-01"), + Timestamp("2015-10-01"), + Timestamp("2016-10-01"), + ], + "YS-OCT", + ), + ], + ) + def test_dti_add_sub_nonzero_mth_offset( + self, op, offset, exp, exp_freq, tz_aware_fixture, box_with_array + ): + # GH 26258 + tz = tz_aware_fixture + date = date_range( + start="01 Jan 2014", end="01 Jan 2017", freq="YS", tz=tz, unit="ns" + ) + date = tm.box_expected(date, box_with_array, False) + mth = getattr(date, op) + result = mth(offset) + + expected = DatetimeIndex(exp, tz=tz).as_unit("ns") + expected = tm.box_expected(expected, box_with_array, False) + tm.assert_equal(result, expected) + + def test_dt64arr_series_add_DateOffset_with_milli(self): + # GH 57529 + dti = DatetimeIndex( + [ + "2000-01-01 00:00:00.012345678", + "2000-01-31 00:00:00.012345678", + "2000-02-29 00:00:00.012345678", + ], + dtype="datetime64[ns]", + ) + result = dti + DateOffset(milliseconds=4) + expected = DatetimeIndex( + [ + "2000-01-01 00:00:00.016345678", + "2000-01-31 00:00:00.016345678", + "2000-02-29 00:00:00.016345678", + ], + dtype="datetime64[ns]", + ) + tm.assert_index_equal(result, expected) + + result = dti + DateOffset(days=1, milliseconds=4) + expected = DatetimeIndex( + [ + "2000-01-02 00:00:00.016345678", + "2000-02-01 00:00:00.016345678", + "2000-03-01 00:00:00.016345678", + ], + dtype="datetime64[ns]", + ) + tm.assert_index_equal(result, expected) + + +class TestDatetime64OverflowHandling: + # TODO: box + de-duplicate + + def test_dt64_overflow_masking(self, box_with_array): + # GH#25317 + left = Series([Timestamp("1969-12-31")], dtype="M8[ns]") + right = Series([NaT]) + + left = tm.box_expected(left, box_with_array) + right = tm.box_expected(right, box_with_array) + + expected = TimedeltaIndex([NaT], dtype="m8[ns]") + expected = tm.box_expected(expected, box_with_array) + + result = left - right + tm.assert_equal(result, expected) + + def test_dt64_series_arith_overflow(self): + # GH#12534, fixed by GH#19024 + dt = Timestamp("1700-01-31") + td = Timedelta("20000 Days") + dti = date_range("1949-09-30", freq="100YE", periods=4, unit="ns") + ser = Series(dti) + msg = "Overflow in int64 addition" + with pytest.raises(OverflowError, match=msg): + ser - dt + with pytest.raises(OverflowError, match=msg): + dt - ser + with pytest.raises(OverflowError, match=msg): + ser + td + with pytest.raises(OverflowError, match=msg): + td + ser + + ser.iloc[-1] = NaT + expected = Series( + ["2004-10-03", "2104-10-04", "2204-10-04", "NaT"], dtype="datetime64[ns]" + ) + res = ser + td + tm.assert_series_equal(res, expected) + res = td + ser + tm.assert_series_equal(res, expected) + + ser.iloc[1:] = NaT + expected = Series(["91279 Days", "NaT", "NaT", "NaT"], dtype="timedelta64[ns]") + res = ser - dt + tm.assert_series_equal(res, expected) + res = dt - ser + tm.assert_series_equal(res, -expected) + + def test_datetimeindex_sub_timestamp_overflow(self): + dtimax = pd.to_datetime(["2021-12-28 17:19", Timestamp.max]).as_unit("ns") + dtimin = pd.to_datetime(["2021-12-28 17:19", Timestamp.min]).as_unit("ns") + + tsneg = Timestamp("1950-01-01").as_unit("ns") + ts_neg_variants = [ + tsneg, + tsneg.to_pydatetime(), + tsneg.to_datetime64().astype("datetime64[ns]"), + tsneg.to_datetime64().astype("datetime64[D]"), + ] + + tspos = Timestamp("1980-01-01").as_unit("ns") + ts_pos_variants = [ + tspos, + tspos.to_pydatetime(), + tspos.to_datetime64().astype("datetime64[ns]"), + tspos.to_datetime64().astype("datetime64[D]"), + ] + msg = "Overflow in int64 addition" + for variant in ts_neg_variants: + with pytest.raises(OverflowError, match=msg): + dtimax - variant + + expected = Timestamp.max._value - tspos._value + for variant in ts_pos_variants: + res = dtimax - variant + assert res[1]._value == expected + + expected = Timestamp.min._value - tsneg._value + for variant in ts_neg_variants: + res = dtimin - variant + assert res[1]._value == expected + + for variant in ts_pos_variants: + with pytest.raises(OverflowError, match=msg): + dtimin - variant + + def test_datetimeindex_sub_datetimeindex_overflow(self): + # GH#22492, GH#22508 + dtimax = pd.to_datetime(["2021-12-28 17:19", Timestamp.max]).as_unit("ns") + dtimin = pd.to_datetime(["2021-12-28 17:19", Timestamp.min]).as_unit("ns") + + ts_neg = pd.to_datetime(["1950-01-01", "1950-01-01"]).as_unit("ns") + ts_pos = pd.to_datetime(["1980-01-01", "1980-01-01"]).as_unit("ns") + + # General tests + expected = Timestamp.max._value - ts_pos[1]._value + result = dtimax - ts_pos + assert result[1]._value == expected + + expected = Timestamp.min._value - ts_neg[1]._value + result = dtimin - ts_neg + assert result[1]._value == expected + msg = "Overflow in int64 addition" + with pytest.raises(OverflowError, match=msg): + dtimax - ts_neg + + with pytest.raises(OverflowError, match=msg): + dtimin - ts_pos + + # Edge cases + tmin = pd.to_datetime([Timestamp.min]) + t1 = tmin + Timedelta.max + Timedelta("1us") + with pytest.raises(OverflowError, match=msg): + t1 - tmin + + tmax = pd.to_datetime([Timestamp.max]) + t2 = tmax + Timedelta.min - Timedelta("1us") + with pytest.raises(OverflowError, match=msg): + tmax - t2 + + +class TestTimestampSeriesArithmetic: + def test_empty_series_add_sub(self, box_with_array): + # GH#13844 + a = Series(dtype="M8[ns]") + b = Series(dtype="m8[ns]") + a = box_with_array(a) + b = box_with_array(b) + tm.assert_equal(a, a + b) + tm.assert_equal(a, a - b) + tm.assert_equal(a, b + a) + msg = "cannot subtract" + with pytest.raises(TypeError, match=msg): + b - a + + def test_operators_datetimelike(self): + # ## timedelta64 ### + td1 = Series([timedelta(minutes=5, seconds=3)] * 3) + td1.iloc[2] = np.nan + + # ## datetime64 ### + dt1 = Series( + [ + Timestamp("20111230"), + Timestamp("20120101"), + Timestamp("20120103"), + ] + ) + dt1.iloc[2] = np.nan + dt2 = Series( + [ + Timestamp("20111231"), + Timestamp("20120102"), + Timestamp("20120104"), + ] + ) + dt1 - dt2 + dt2 - dt1 + + # datetime64 with timetimedelta + dt1 + td1 + td1 + dt1 + dt1 - td1 + + # timetimedelta with datetime64 + td1 + dt1 + dt1 + td1 + + def test_dt64ser_sub_datetime_dtype(self, unit): + ts = Timestamp(datetime(1993, 1, 7, 13, 30, 00)) + dt = datetime(1993, 6, 22, 13, 30) + ser = Series([ts], dtype=f"M8[{unit}]") + result = ser - dt + + # the expected unit is the max of `unit` and the unit imputed to `dt`, + # which is "us" + exp_unit = tm.get_finest_unit(unit, "us") + assert result.dtype == f"timedelta64[{exp_unit}]" + + # ------------------------------------------------------------- + # TODO: This next block of tests came from tests.series.test_operators, + # needs to be de-duplicated and parametrized over `box` classes + + @pytest.mark.parametrize( + "left, right, op_fail", + [ + [ + [Timestamp("20111230"), Timestamp("20120101"), NaT], + [Timestamp("20111231"), Timestamp("20120102"), Timestamp("20120104")], + ["__sub__", "__rsub__"], + ], + [ + [Timestamp("20111230"), Timestamp("20120101"), NaT], + [timedelta(minutes=5, seconds=3), timedelta(minutes=5, seconds=3), NaT], + ["__add__", "__radd__", "__sub__"], + ], + [ + [ + Timestamp("20111230", tz="US/Eastern"), + Timestamp("20111230", tz="US/Eastern"), + NaT, + ], + [timedelta(minutes=5, seconds=3), NaT, timedelta(minutes=5, seconds=3)], + ["__add__", "__radd__", "__sub__"], + ], + ], + ) + def test_operators_datetimelike_invalid( + self, left, right, op_fail, all_arithmetic_operators + ): + # these are all TypeError ops + op_str = all_arithmetic_operators + arg1 = Series(left) + arg2 = Series(right) + # check that we are getting a TypeError + # with 'operate' (from core/ops.py) for the ops that are not + # defined + op = getattr(arg1, op_str, None) + # Previously, _validate_for_numeric_binop in core/indexes/base.py + # did this for us. + if op_str not in op_fail: + with pytest.raises( + TypeError, match="operate|[cC]annot|unsupported operand" + ): + op(arg2) + else: + # Smoke test + op(arg2) + + def test_sub_single_tz(self, unit): + # GH#12290 + s1 = Series([Timestamp("2016-02-10", tz="America/Sao_Paulo")]).dt.as_unit(unit) + s2 = Series([Timestamp("2016-02-08", tz="America/Sao_Paulo")]).dt.as_unit(unit) + result = s1 - s2 + expected = Series([Timedelta("2days")]).dt.as_unit(unit) + tm.assert_series_equal(result, expected) + result = s2 - s1 + expected = Series([Timedelta("-2days")]).dt.as_unit(unit) + tm.assert_series_equal(result, expected) + + def test_dt64tz_series_sub_dtitz(self): + # GH#19071 subtracting tzaware DatetimeIndex from tzaware Series + # (with same tz) raises, fixed by #19024 + dti = date_range("1999-09-30", periods=10, tz="US/Pacific") + ser = Series(dti) + expected = Series(TimedeltaIndex(["0days"] * 10)) + + res = dti - ser + tm.assert_series_equal(res, expected) + res = ser - dti + tm.assert_series_equal(res, expected) + + def test_sub_datetime_compat(self, unit): + # see GH#14088 + ser = Series([datetime(2016, 8, 23, 12, tzinfo=timezone.utc), NaT]).dt.as_unit( + unit + ) + dt = datetime(2016, 8, 22, 12, tzinfo=timezone.utc) + # The datetime object has "us" so we upcast lower units + exp_unit = tm.get_finest_unit(unit, "us") + exp = Series([Timedelta("1 days"), NaT]).dt.as_unit(exp_unit) + result = ser - dt + tm.assert_series_equal(result, exp) + result2 = ser - Timestamp(dt) + tm.assert_series_equal(result2, exp) + + def test_dt64_series_add_mixed_tick_DateOffset(self): + # GH#4532 + # operate with pd.offsets + s = Series([Timestamp("20130101 9:01"), Timestamp("20130101 9:02")]) + + result = s + pd.offsets.Milli(5) + result2 = pd.offsets.Milli(5) + s + expected = Series( + [Timestamp("20130101 9:01:00.005"), Timestamp("20130101 9:02:00.005")] + ) + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + result = s + pd.offsets.Minute(5) + pd.offsets.Milli(5) + expected = Series( + [Timestamp("20130101 9:06:00.005"), Timestamp("20130101 9:07:00.005")] + ) + tm.assert_series_equal(result, expected) + + def test_datetime64_ops_nat(self, unit): + # GH#11349 + datetime_series = Series([NaT, Timestamp("19900315")]).dt.as_unit(unit) + nat_series_dtype_timestamp = Series([NaT, NaT], dtype=f"datetime64[{unit}]") + single_nat_dtype_datetime = Series([NaT], dtype=f"datetime64[{unit}]") + + # subtraction + tm.assert_series_equal(-NaT + datetime_series, nat_series_dtype_timestamp) + msg = "bad operand type for unary -: 'DatetimeArray'" + with pytest.raises(TypeError, match=msg): + -single_nat_dtype_datetime + datetime_series + + tm.assert_series_equal( + -NaT + nat_series_dtype_timestamp, nat_series_dtype_timestamp + ) + with pytest.raises(TypeError, match=msg): + -single_nat_dtype_datetime + nat_series_dtype_timestamp + + # addition + tm.assert_series_equal( + nat_series_dtype_timestamp + NaT, nat_series_dtype_timestamp + ) + tm.assert_series_equal( + NaT + nat_series_dtype_timestamp, nat_series_dtype_timestamp + ) + + tm.assert_series_equal( + nat_series_dtype_timestamp + NaT, nat_series_dtype_timestamp + ) + tm.assert_series_equal( + NaT + nat_series_dtype_timestamp, nat_series_dtype_timestamp + ) + + # ------------------------------------------------------------- + # Timezone-Centric Tests + + def test_operators_datetimelike_with_timezones(self): + tz = "US/Eastern" + dt1 = Series(date_range("2000-01-01 09:00:00", periods=5, tz=tz), name="foo") + dt2 = dt1.copy() + dt2.iloc[2] = np.nan + + td1 = Series(pd.timedelta_range("1 days 1 min", periods=5, freq="h")) + td2 = td1.copy() + td2.iloc[1] = np.nan + assert td2._values.freq is None + + result = dt1 + td1[0] + exp = (dt1.dt.tz_localize(None) + td1[0]).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + + result = dt2 + td2[0] + exp = (dt2.dt.tz_localize(None) + td2[0]).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + + # odd numpy behavior with scalar timedeltas + result = td1[0] + dt1 + exp = (dt1.dt.tz_localize(None) + td1[0]).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + + result = td2[0] + dt2 + exp = (dt2.dt.tz_localize(None) + td2[0]).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + + result = dt1 - td1[0] + exp = (dt1.dt.tz_localize(None) - td1[0]).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + msg = "cannot subtract DatetimeArray from" + with pytest.raises(TypeError, match=msg): + td1[0] - dt1 + + result = dt2 - td2[0] + exp = (dt2.dt.tz_localize(None) - td2[0]).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + with pytest.raises(TypeError, match=msg): + td2[0] - dt2 + + result = dt1 + td1 + exp = (dt1.dt.tz_localize(None) + td1).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + + result = dt2 + td2 + exp = (dt2.dt.tz_localize(None) + td2).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + + result = dt1 - td1 + exp = (dt1.dt.tz_localize(None) - td1).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + + result = dt2 - td2 + exp = (dt2.dt.tz_localize(None) - td2).dt.tz_localize(tz) + tm.assert_series_equal(result, exp) + msg = "cannot (add|subtract)" + with pytest.raises(TypeError, match=msg): + td1 - dt1 + with pytest.raises(TypeError, match=msg): + td2 - dt2 + + +class TestDatetimeIndexArithmetic: + # ------------------------------------------------------------- + # Binary operations DatetimeIndex and TimedeltaIndex/array + + def test_dti_add_tdi(self, tz_naive_fixture): + # GH#17558 + tz = tz_naive_fixture + dti = DatetimeIndex([Timestamp("2017-01-01", tz=tz)] * 10) + tdi = pd.timedelta_range("0 days", periods=10) + expected = date_range("2017-01-01", periods=10, tz=tz) + expected = expected._with_freq(None) + + # add with TimedeltaIndex + result = dti + tdi + tm.assert_index_equal(result, expected) + + result = tdi + dti + tm.assert_index_equal(result, expected) + + # add with timedelta64 array + result = dti + tdi.values + tm.assert_index_equal(result, expected) + + result = tdi.values + dti + tm.assert_index_equal(result, expected) + + def test_dti_iadd_tdi(self, tz_naive_fixture): + # GH#17558 + tz = tz_naive_fixture + dti = DatetimeIndex([Timestamp("2017-01-01", tz=tz)] * 10) + tdi = pd.timedelta_range("0 days", periods=10) + expected = date_range("2017-01-01", periods=10, tz=tz) + expected = expected._with_freq(None) + + # iadd with TimedeltaIndex + result = DatetimeIndex([Timestamp("2017-01-01", tz=tz)] * 10) + result += tdi + tm.assert_index_equal(result, expected) + + result = pd.timedelta_range("0 days", periods=10) + result += dti + tm.assert_index_equal(result, expected) + + # iadd with timedelta64 array + result = DatetimeIndex([Timestamp("2017-01-01", tz=tz)] * 10) + result += tdi.values + tm.assert_index_equal(result, expected) + + result = pd.timedelta_range("0 days", periods=10) + result += dti + tm.assert_index_equal(result, expected) + + def test_dti_sub_tdi(self, tz_naive_fixture): + # GH#17558 + tz = tz_naive_fixture + dti = DatetimeIndex([Timestamp("2017-01-01", tz=tz)] * 10) + tdi = pd.timedelta_range("0 days", periods=10) + expected = date_range("2017-01-01", periods=10, tz=tz, freq="-1D") + expected = expected._with_freq(None) + + # sub with TimedeltaIndex + result = dti - tdi + tm.assert_index_equal(result, expected) + + msg = "cannot subtract .*TimedeltaArray" + with pytest.raises(TypeError, match=msg): + tdi - dti + + # sub with timedelta64 array + result = dti - tdi.values + tm.assert_index_equal(result, expected) + + msg = "cannot subtract a datelike from a TimedeltaArray" + with pytest.raises(TypeError, match=msg): + tdi.values - dti + + def test_dti_isub_tdi(self, tz_naive_fixture, unit): + # GH#17558 + tz = tz_naive_fixture + dti = DatetimeIndex([Timestamp("2017-01-01", tz=tz)] * 10).as_unit(unit) + tdi = pd.timedelta_range("0 days", periods=10, unit=unit) + expected = date_range("2017-01-01", periods=10, tz=tz, freq="-1D", unit=unit) + expected = expected._with_freq(None) + + # isub with TimedeltaIndex + result = DatetimeIndex([Timestamp("2017-01-01", tz=tz)] * 10).as_unit(unit) + result -= tdi + tm.assert_index_equal(result, expected) + + # DTA.__isub__ GH#43904 + dta = dti._data.copy() + dta -= tdi + tm.assert_datetime_array_equal(dta, expected._data) + + out = dti._data.copy() + np.subtract(out, tdi, out=out) + tm.assert_datetime_array_equal(out, expected._data) + + msg = "cannot subtract a datelike from a TimedeltaArray" + with pytest.raises(TypeError, match=msg): + tdi -= dti + + # isub with timedelta64 array + result = DatetimeIndex([Timestamp("2017-01-01", tz=tz)] * 10).as_unit(unit) + result -= tdi.values + tm.assert_index_equal(result, expected) + + with pytest.raises(TypeError, match=msg): + tdi.values -= dti + + with pytest.raises(TypeError, match=msg): + tdi._values -= dti + + # ------------------------------------------------------------- + # Binary Operations DatetimeIndex and datetime-like + # TODO: A couple other tests belong in this section. Move them in + # A PR where there isn't already a giant diff. + + # ------------------------------------------------------------- + + def test_dta_add_sub_index(self, tz_naive_fixture): + # Check that DatetimeArray defers to Index classes + dti = date_range("20130101", periods=3, tz=tz_naive_fixture) + dta = dti.array + result = dta - dti + expected = dti - dti + tm.assert_index_equal(result, expected) + + tdi = result + result = dta + tdi + expected = dti + tdi + tm.assert_index_equal(result, expected) + + result = dta - tdi + expected = dti - tdi + tm.assert_index_equal(result, expected) + + def test_sub_dti_dti(self, unit): + # previously performed setop (deprecated in 0.16.0), now changed to + # return subtraction -> TimeDeltaIndex (GH ...) + + dti = date_range("20130101", periods=3, unit=unit) + dti_tz = date_range("20130101", periods=3, unit=unit).tz_localize("US/Eastern") + expected = TimedeltaIndex([0, 0, 0]).as_unit(unit) + + result = dti - dti + tm.assert_index_equal(result, expected) + + result = dti_tz - dti_tz + tm.assert_index_equal(result, expected) + msg = "Cannot subtract tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg): + dti_tz - dti + + with pytest.raises(TypeError, match=msg): + dti - dti_tz + + # isub + dti -= dti + tm.assert_index_equal(dti, expected) + + # different length raises ValueError + dti1 = date_range("20130101", periods=3, unit=unit) + dti2 = date_range("20130101", periods=4, unit=unit) + msg = "cannot add indices of unequal length" + with pytest.raises(ValueError, match=msg): + dti1 - dti2 + + # NaN propagation + dti1 = DatetimeIndex(["2012-01-01", np.nan, "2012-01-03"]).as_unit(unit) + dti2 = DatetimeIndex(["2012-01-02", "2012-01-03", np.nan]).as_unit(unit) + expected = TimedeltaIndex(["1 days", np.nan, np.nan]).as_unit(unit) + result = dti2 - dti1 + tm.assert_index_equal(result, expected) + + # ------------------------------------------------------------------- + # TODO: Most of this block is moved from series or frame tests, needs + # cleanup, box-parametrization, and de-duplication + + @pytest.mark.parametrize("op", [operator.add, operator.sub]) + def test_timedelta64_equal_timedelta_supported_ops(self, op, box_with_array): + ser = Series( + [ + Timestamp("20130301"), + Timestamp("20130228 23:00:00"), + Timestamp("20130228 22:00:00"), + Timestamp("20130228 21:00:00"), + ] + ) + obj = box_with_array(ser) + + intervals = ["D", "h", "m", "s", "us"] + + def timedelta64(*args): + # see casting notes in NumPy gh-12927 + return np.sum(list(map(np.timedelta64, args, intervals))) + + for d, h, m, s, us in product(*([range(2)] * 5)): + nptd = timedelta64(d, h, m, s, us) + pytd = timedelta(days=d, hours=h, minutes=m, seconds=s, microseconds=us) + lhs = op(obj, nptd) + rhs = op(obj, pytd) + + tm.assert_equal(lhs, rhs) + + def test_ops_nat_mixed_datetime64_timedelta64(self): + # GH#11349 + timedelta_series = Series([NaT, Timedelta("1s")]) + datetime_series = Series([NaT, Timestamp("19900315")]) + nat_series_dtype_timedelta = Series([NaT, NaT], dtype="timedelta64[ns]") + nat_series_dtype_timestamp = Series([NaT, NaT], dtype="datetime64[ns]") + single_nat_dtype_datetime = Series([NaT], dtype="datetime64[ns]") + single_nat_dtype_timedelta = Series([NaT], dtype="timedelta64[ns]") + + # subtraction + tm.assert_series_equal( + datetime_series - single_nat_dtype_datetime, nat_series_dtype_timedelta + ) + + tm.assert_series_equal( + datetime_series - single_nat_dtype_timedelta, nat_series_dtype_timestamp + ) + tm.assert_series_equal( + -single_nat_dtype_timedelta + datetime_series, nat_series_dtype_timestamp + ) + + # without a Series wrapping the NaT, it is ambiguous + # whether it is a datetime64 or timedelta64 + # defaults to interpreting it as timedelta64 + tm.assert_series_equal( + nat_series_dtype_timestamp - single_nat_dtype_datetime, + nat_series_dtype_timedelta, + ) + + tm.assert_series_equal( + nat_series_dtype_timestamp - single_nat_dtype_timedelta, + nat_series_dtype_timestamp, + ) + tm.assert_series_equal( + -single_nat_dtype_timedelta + nat_series_dtype_timestamp, + nat_series_dtype_timestamp, + ) + msg = "cannot subtract a datelike" + with pytest.raises(TypeError, match=msg): + timedelta_series - single_nat_dtype_datetime + + # addition + tm.assert_series_equal( + nat_series_dtype_timestamp + single_nat_dtype_timedelta, + nat_series_dtype_timestamp, + ) + tm.assert_series_equal( + single_nat_dtype_timedelta + nat_series_dtype_timestamp, + nat_series_dtype_timestamp, + ) + + tm.assert_series_equal( + nat_series_dtype_timestamp + single_nat_dtype_timedelta, + nat_series_dtype_timestamp, + ) + tm.assert_series_equal( + single_nat_dtype_timedelta + nat_series_dtype_timestamp, + nat_series_dtype_timestamp, + ) + + tm.assert_series_equal( + nat_series_dtype_timedelta + single_nat_dtype_datetime, + nat_series_dtype_timestamp, + ) + tm.assert_series_equal( + single_nat_dtype_datetime + nat_series_dtype_timedelta, + nat_series_dtype_timestamp, + ) + + def test_ufunc_coercions(self, unit): + idx = date_range("2011-01-01", periods=3, freq="2D", name="x", unit=unit) + + delta = np.timedelta64(1, "D") + exp = date_range("2011-01-02", periods=3, freq="2D", name="x", unit=unit) + for result in [idx + delta, np.add(idx, delta)]: + assert isinstance(result, DatetimeIndex) + tm.assert_index_equal(result, exp) + assert result.freq == "2D" + + exp = date_range("2010-12-31", periods=3, freq="2D", name="x", unit=unit) + + for result in [idx - delta, np.subtract(idx, delta)]: + assert isinstance(result, DatetimeIndex) + tm.assert_index_equal(result, exp) + assert result.freq == "2D" + + # When adding/subtracting an ndarray (which has no .freq), the result + # does not infer freq + idx = idx._with_freq(None) + delta = np.array( + [np.timedelta64(1, "D"), np.timedelta64(2, "D"), np.timedelta64(3, "D")] + ) + exp = DatetimeIndex( + ["2011-01-02", "2011-01-05", "2011-01-08"], name="x" + ).as_unit(unit) + + for result in [idx + delta, np.add(idx, delta)]: + tm.assert_index_equal(result, exp) + assert result.freq == exp.freq + + exp = DatetimeIndex( + ["2010-12-31", "2011-01-01", "2011-01-02"], name="x" + ).as_unit(unit) + for result in [idx - delta, np.subtract(idx, delta)]: + assert isinstance(result, DatetimeIndex) + tm.assert_index_equal(result, exp) + assert result.freq == exp.freq + + def test_dti_add_series(self, tz_naive_fixture, names): + # GH#13905 + tz = tz_naive_fixture + index = DatetimeIndex( + ["2016-06-28 05:30", "2016-06-28 05:31"], tz=tz, name=names[0] + ).as_unit("ns") + ser = Series([Timedelta(seconds=5)] * 2, index=index, name=names[1]) + expected = Series(index + Timedelta(seconds=5), index=index, name=names[2]) + + # passing name arg isn't enough when names[2] is None + expected.name = names[2] + assert expected.dtype == index.dtype + result = ser + index + tm.assert_series_equal(result, expected) + result2 = index + ser + tm.assert_series_equal(result2, expected) + + expected = index + Timedelta(seconds=5) + result3 = ser.values + index + tm.assert_index_equal(result3, expected) + result4 = index + ser.values + tm.assert_index_equal(result4, expected) + + @pytest.mark.parametrize("op", [operator.add, roperator.radd, operator.sub]) + def test_dti_addsub_offset_arraylike( + self, performance_warning, tz_naive_fixture, names, op, index_or_series + ): + # GH#18849, GH#19744 + other_box = index_or_series + + tz = tz_naive_fixture + dti = date_range("2017-01-01", periods=2, tz=tz, name=names[0]) + other = other_box([pd.offsets.MonthEnd(), pd.offsets.Day(n=2)], name=names[1]) + + xbox = get_upcast_box(dti, other) + + with tm.assert_produces_warning(performance_warning): + res = op(dti, other) + + expected = DatetimeIndex( + [op(dti[n], other[n]) for n in range(len(dti))], name=names[2], freq="infer" + ) + expected = tm.box_expected(expected, xbox).astype(object) + tm.assert_equal(res, expected) + + @pytest.mark.parametrize("other_box", [pd.Index, np.array]) + def test_dti_addsub_object_arraylike( + self, performance_warning, tz_naive_fixture, box_with_array, other_box + ): + tz = tz_naive_fixture + + dti = date_range("2017-01-01", periods=2, tz=tz) + dtarr = tm.box_expected(dti, box_with_array) + other = other_box([pd.offsets.MonthEnd(), Timedelta(days=4)]) + xbox = get_upcast_box(dtarr, other) + + expected = DatetimeIndex(["2017-01-31", "2017-01-06"], tz=tz_naive_fixture) + expected = tm.box_expected(expected, xbox).astype(object) + + with tm.assert_produces_warning(performance_warning): + result = dtarr + other + tm.assert_equal(result, expected) + + expected = DatetimeIndex(["2016-12-31", "2016-12-29"], tz=tz_naive_fixture) + expected = tm.box_expected(expected, xbox).astype(object) + + with tm.assert_produces_warning(performance_warning): + result = dtarr - other + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("years", [-1, 0, 1]) +@pytest.mark.parametrize("months", [-2, 0, 2]) +def test_shift_months(years, months, unit): + dti = DatetimeIndex( + [ + Timestamp("2000-01-05 00:15:00"), + Timestamp("2000-01-31 00:23:00"), + Timestamp("2000-01-01"), + Timestamp("2000-02-29"), + Timestamp("2000-12-31"), + ] + ).as_unit(unit) + shifted = shift_months(dti.asi8, years * 12 + months, reso=dti._data._creso) + shifted_dt64 = shifted.view(f"M8[{dti.unit}]") + actual = DatetimeIndex(shifted_dt64) + + raw = [x + pd.offsets.DateOffset(years=years, months=months) for x in dti] + expected = DatetimeIndex(raw).as_unit(dti.unit) + tm.assert_index_equal(actual, expected) + + +def test_dt64arr_addsub_object_dtype_2d(performance_warning): + # block-wise DataFrame operations will require operating on 2D + # DatetimeArray/TimedeltaArray, so check that specifically. + dti = date_range("1994-02-13", freq="2W", periods=4) + dta = dti._data.reshape((4, 1)) + + other = np.array([[pd.offsets.Day(n)] for n in range(4)]) + assert other.shape == dta.shape + + with tm.assert_produces_warning(performance_warning): + result = dta + other + with tm.assert_produces_warning(performance_warning): + expected = (dta[:, 0] + other[:, 0]).reshape(-1, 1) + + tm.assert_numpy_array_equal(result, expected) + + with tm.assert_produces_warning(performance_warning): + # Case where we expect to get a TimedeltaArray back + result2 = dta - dta.astype(object) + + assert result2.shape == (4, 1) + assert all(td._value == 0 for td in result2.ravel()) + + +def test_non_nano_dt64_addsub_np_nat_scalars(): + # GH 52295 + ser = Series([1233242342344, 232432434324, 332434242344], dtype="datetime64[ms]") + result = ser - np.datetime64("nat", "ms") + expected = Series([NaT] * 3, dtype="timedelta64[ms]") + tm.assert_series_equal(result, expected) + + result = ser + np.timedelta64("nat", "ms") + expected = Series([NaT] * 3, dtype="datetime64[ms]") + tm.assert_series_equal(result, expected) + + +def test_non_nano_dt64_addsub_np_nat_scalars_unitless(): + # GH 52295 + # TODO: Can we default to the ser unit? + ser = Series([1233242342344, 232432434324, 332434242344], dtype="datetime64[ms]") + result = ser - np.datetime64("nat") + expected = Series([NaT] * 3, dtype="timedelta64[ms]") + tm.assert_series_equal(result, expected) + + result = ser + np.timedelta64("nat") + expected = Series([NaT] * 3, dtype="datetime64[ms]") + tm.assert_series_equal(result, expected) + + +def test_non_nano_dt64_addsub_np_nat_scalars_unsupported_unit(): + # GH 52295 + ser = Series([12332, 23243, 33243], dtype="datetime64[s]") + result = ser - np.datetime64("nat", "D") + expected = Series([NaT] * 3, dtype="timedelta64[s]") + tm.assert_series_equal(result, expected) + + result = ser + np.timedelta64("nat", "D") + expected = Series([NaT] * 3, dtype="datetime64[s]") + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/arithmetic/test_interval.py b/pandas/tests/arithmetic/test_interval.py new file mode 100644 index 0000000000000000000000000000000000000000..e2353da0dad568875ba150570b52262aa81d0ea4 --- /dev/null +++ b/pandas/tests/arithmetic/test_interval.py @@ -0,0 +1,308 @@ +import operator + +import numpy as np +import pytest + +from pandas.core.dtypes.common import is_list_like + +import pandas as pd +from pandas import ( + Categorical, + Index, + Interval, + IntervalIndex, + Period, + Series, + Timedelta, + Timestamp, + date_range, + period_range, + timedelta_range, +) +import pandas._testing as tm +from pandas.core.arrays import ( + BooleanArray, + IntervalArray, +) +from pandas.tests.arithmetic.common import get_upcast_box + + +@pytest.fixture( + params=[ + (Index([0, 2, 4, 4]), Index([1, 3, 5, 8])), + (Index([0.0, 1.0, 2.0, np.nan]), Index([1.0, 2.0, 3.0, np.nan])), + ( + timedelta_range("0 days", periods=3).insert(3, pd.NaT), + timedelta_range("1 day", periods=3).insert(3, pd.NaT), + ), + ( + date_range("20170101", periods=3).insert(3, pd.NaT), + date_range("20170102", periods=3).insert(3, pd.NaT), + ), + ( + date_range("20170101", periods=3, tz="US/Eastern").insert(3, pd.NaT), + date_range("20170102", periods=3, tz="US/Eastern").insert(3, pd.NaT), + ), + ], + ids=lambda x: str(x[0].dtype), +) +def left_right_dtypes(request): + """ + Fixture for building an IntervalArray from various dtypes + """ + return request.param + + +@pytest.fixture +def interval_array(left_right_dtypes): + """ + Fixture to generate an IntervalArray of various dtypes containing NA if possible + """ + left, right = left_right_dtypes + return IntervalArray.from_arrays(left, right) + + +def create_categorical_intervals(left, right, closed="right"): + return Categorical(IntervalIndex.from_arrays(left, right, closed)) + + +def create_series_intervals(left, right, closed="right"): + return Series(IntervalArray.from_arrays(left, right, closed)) + + +def create_series_categorical_intervals(left, right, closed="right"): + return Series(Categorical(IntervalIndex.from_arrays(left, right, closed))) + + +class TestComparison: + @pytest.fixture(params=[operator.eq, operator.ne]) + def op(self, request): + return request.param + + @pytest.fixture( + params=[ + IntervalArray.from_arrays, + IntervalIndex.from_arrays, + create_categorical_intervals, + create_series_intervals, + create_series_categorical_intervals, + ], + ids=[ + "IntervalArray", + "IntervalIndex", + "Categorical[Interval]", + "Series[Interval]", + "Series[Categorical[Interval]]", + ], + ) + def interval_constructor(self, request): + """ + Fixture for all pandas native interval constructors. + To be used as the LHS of IntervalArray comparisons. + """ + return request.param + + def elementwise_comparison(self, op, interval_array, other): + """ + Helper that performs elementwise comparisons between `array` and `other` + """ + other = other if is_list_like(other) else [other] * len(interval_array) + expected = np.array( + [op(x, y) for x, y in zip(interval_array, other, strict=True)] + ) + if isinstance(other, Series): + return Series(expected, index=other.index) + return expected + + def test_compare_scalar_interval(self, op, interval_array): + # matches first interval + other = interval_array[0] + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_numpy_array_equal(result, expected) + + # matches on a single endpoint but not both + other = Interval(interval_array.left[0], interval_array.right[1]) + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_scalar_interval_mixed_closed(self, op, closed, other_closed): + interval_array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed) + other = Interval(0, 1, closed=other_closed) + + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_scalar_na(self, op, interval_array, nulls_fixture, box_with_array): + box = box_with_array + obj = tm.box_expected(interval_array, box) + result = op(obj, nulls_fixture) + + if nulls_fixture is pd.NA: + # GH#31882 + exp = np.ones(interval_array.shape, dtype=bool) + expected = BooleanArray(exp, exp) + else: + expected = self.elementwise_comparison(op, interval_array, nulls_fixture) + + if not (box is Index and nulls_fixture is pd.NA): + # don't cast expected from BooleanArray to ndarray[object] + xbox = get_upcast_box(obj, nulls_fixture, True) + expected = tm.box_expected(expected, xbox) + + tm.assert_equal(result, expected) + + rev = op(nulls_fixture, obj) + tm.assert_equal(rev, expected) + + @pytest.mark.parametrize( + "other", + [ + 0, + 1.0, + True, + "foo", + Timestamp("2017-01-01"), + Timestamp("2017-01-01", tz="US/Eastern"), + Timedelta("0 days"), + Period("2017-01-01", "D"), + ], + ) + def test_compare_scalar_other(self, op, interval_array, other): + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_list_like_interval(self, op, interval_array, interval_constructor): + # same endpoints + other = interval_constructor(interval_array.left, interval_array.right) + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_equal(result, expected) + + # different endpoints + other = interval_constructor( + interval_array.left[::-1], interval_array.right[::-1] + ) + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_equal(result, expected) + + # all nan endpoints + other = interval_constructor([np.nan] * 4, [np.nan] * 4) + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_equal(result, expected) + + def test_compare_list_like_interval_mixed_closed( + self, op, interval_constructor, closed, other_closed + ): + interval_array = IntervalArray.from_arrays(range(2), range(1, 3), closed=closed) + other = interval_constructor(range(2), range(1, 3), closed=other_closed) + + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "other", + [ + ( + Interval(0, 1), + Interval(Timedelta("1 day"), Timedelta("2 days")), + Interval(4, 5, "both"), + Interval(10, 20, "neither"), + ), + (0, 1.5, Timestamp("20170103"), np.nan), + ( + Timestamp("20170102", tz="US/Eastern"), + Timedelta("2 days"), + "baz", + pd.NaT, + ), + ], + ) + def test_compare_list_like_object(self, op, interval_array, other): + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_numpy_array_equal(result, expected) + + def test_compare_list_like_nan(self, op, interval_array, nulls_fixture): + other = [nulls_fixture] * 4 + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "other", + [ + np.arange(4, dtype="int64"), + np.arange(4, dtype="float64"), + date_range("2017-01-01", periods=4), + date_range("2017-01-01", periods=4, tz="US/Eastern"), + timedelta_range("0 days", periods=4), + period_range("2017-01-01", periods=4, freq="D"), + Categorical(list("abab")), + Categorical(date_range("2017-01-01", periods=4)), + pd.array(list("abcd")), + pd.array(["foo", 3.14, None, object()], dtype=object), + ], + ids=lambda x: str(x.dtype), + ) + def test_compare_list_like_other(self, op, interval_array, other): + result = op(interval_array, other) + expected = self.elementwise_comparison(op, interval_array, other) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("length", [1, 3, 5]) + @pytest.mark.parametrize("other_constructor", [IntervalArray, list]) + def test_compare_length_mismatch_errors(self, op, other_constructor, length): + interval_array = IntervalArray.from_arrays(range(4), range(1, 5)) + other = other_constructor([Interval(0, 1)] * length) + with pytest.raises(ValueError, match="Lengths must match to compare"): + op(interval_array, other) + + @pytest.mark.parametrize( + "constructor, expected_type, assert_func", + [ + (IntervalIndex, np.array, tm.assert_numpy_array_equal), + (Series, Series, tm.assert_series_equal), + ], + ) + def test_index_series_compat(self, op, constructor, expected_type, assert_func): + # IntervalIndex/Series that rely on IntervalArray for comparisons + breaks = range(4) + index = constructor(IntervalIndex.from_breaks(breaks)) + + # scalar comparisons + other = index[0] + result = op(index, other) + expected = expected_type(self.elementwise_comparison(op, index, other)) + assert_func(result, expected) + + other = breaks[0] + result = op(index, other) + expected = expected_type(self.elementwise_comparison(op, index, other)) + assert_func(result, expected) + + # list-like comparisons + other = IntervalArray.from_breaks(breaks) + result = op(index, other) + expected = expected_type(self.elementwise_comparison(op, index, other)) + assert_func(result, expected) + + other = [index[0], breaks[0], "foo"] + result = op(index, other) + expected = expected_type(self.elementwise_comparison(op, index, other)) + assert_func(result, expected) + + @pytest.mark.parametrize("scalars", ["a", False, 1, 1.0, None]) + def test_comparison_operations(self, scalars): + # GH #28981 + expected = Series([False, False]) + s = Series([Interval(0, 1), Interval(1, 2)], dtype="interval") + result = s == scalars + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/arithmetic/test_numeric.py b/pandas/tests/arithmetic/test_numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..5878246126d617e96dbb11b2d437e78fe87a3aec --- /dev/null +++ b/pandas/tests/arithmetic/test_numeric.py @@ -0,0 +1,1585 @@ +# Arithmetic tests for DataFrame/Series/Index/Array classes that should +# behave identically. +# Specifically for numeric dtypes +from __future__ import annotations + +from collections import abc +from datetime import timedelta +from decimal import Decimal +import operator + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Index, + RangeIndex, + Series, + Timedelta, + TimedeltaIndex, + array, + date_range, +) +import pandas._testing as tm +from pandas.core import ops +from pandas.core.computation import expressions as expr +from pandas.tests.arithmetic.common import ( + assert_invalid_addsub_type, + assert_invalid_comparison, +) + + +@pytest.fixture(autouse=True, params=[0, 1000000], ids=["numexpr", "python"]) +def switch_numexpr_min_elements(request, monkeypatch): + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", request.param) + yield request.param + + +@pytest.fixture( + params=[ + # TODO: add more dtypes here + Index(np.arange(5, dtype="float64")), + Index(np.arange(5, dtype="int64")), + Index(np.arange(5, dtype="uint64")), + RangeIndex(5), + ], + ids=lambda x: type(x).__name__, +) +def numeric_idx(request): + """ + Several types of numeric-dtypes Index objects + """ + return request.param + + +def adjust_negative_zero(zero, expected): + """ + Helper to adjust the expected result if we are dividing by -0.0 + as opposed to 0.0 + """ + if np.signbit(np.array(zero)).any(): + # All entries in the `zero` fixture should be either + # all-negative or no-negative. + assert np.signbit(np.array(zero)).all() + + expected *= -1 + + return expected + + +def compare_op(series, other, op): + left = np.abs(series) if op in (ops.rpow, operator.pow) else series + right = np.abs(other) if op in (ops.rpow, operator.pow) else other + + cython_or_numpy = op(left, right) + python = left.combine(right, op) + if isinstance(other, Series) and not other.index.equals(series.index): + python.index = python.index._with_freq(None) + tm.assert_series_equal(cython_or_numpy, python) + + +# TODO: remove this kludge once mypy stops giving false positives here +# List comprehension has incompatible type List[PandasObject]; expected List[RangeIndex] +# See GH#29725 +_ldtypes = ["i1", "i2", "i4", "i8", "u1", "u2", "u4", "u8", "f2", "f4", "f8"] +lefts: list[Index | Series] = [RangeIndex(10, 40, 10)] +lefts.extend([Series([10, 20, 30], dtype=dtype) for dtype in _ldtypes]) +lefts.extend([Index([10, 20, 30], dtype=dtype) for dtype in _ldtypes if dtype != "f2"]) + +# ------------------------------------------------------------------ +# Comparisons + + +class TestNumericComparisons: + def test_operator_series_comparison_zerorank(self): + # GH#13006 + result = np.float64(0) > Series([1, 2, 3]) + expected = 0.0 > Series([1, 2, 3]) + tm.assert_series_equal(result, expected) + result = Series([1, 2, 3]) < np.float64(0) + expected = Series([1, 2, 3]) < 0.0 + tm.assert_series_equal(result, expected) + result = np.array([0, 1, 2])[0] > Series([0, 1, 2]) + expected = 0.0 > Series([1, 2, 3]) + tm.assert_series_equal(result, expected) + + def test_df_numeric_cmp_dt64_raises(self, box_with_array, fixed_now_ts): + # GH#8932, GH#22163 + ts = fixed_now_ts + obj = np.array(range(5)) + obj = tm.box_expected(obj, box_with_array) + + assert_invalid_comparison(obj, ts, box_with_array) + + def test_compare_invalid(self): + # GH#8058 + # ops testing + a = Series(np.random.default_rng(2).standard_normal(5), name=0) + b = Series(np.random.default_rng(2).standard_normal(5)) + b.name = pd.Timestamp("2000-01-01") + tm.assert_series_equal(a / b, 1 / (b / a)) + + def test_numeric_cmp_string_numexpr_path(self, box_with_array, monkeypatch): + # GH#36377, GH#35700 + box = box_with_array + xbox = box if box is not Index else np.ndarray + + obj = Series(np.random.default_rng(2).standard_normal(51)) + obj = tm.box_expected(obj, box, transpose=False) + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 50) + result = obj == "a" + + expected = Series(np.zeros(51, dtype=bool)) + expected = tm.box_expected(expected, xbox, transpose=False) + tm.assert_equal(result, expected) + + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 50) + result = obj != "a" + tm.assert_equal(result, ~expected) + + msg = "Invalid comparison between dtype=float64 and str" + with pytest.raises(TypeError, match=msg): + obj < "a" + + +# ------------------------------------------------------------------ +# Numeric dtypes Arithmetic with Datetime/Timedelta Scalar + + +class TestNumericArraylikeArithmeticWithDatetimeLike: + def test_mul_timedelta_list(self, box_with_array): + # GH#62524 + box = box_with_array + left = np.array([3, 4]) + left = tm.box_expected(left, box) + + right = [Timedelta(days=1), Timedelta(days=2)] + + result = left * right + + expected = TimedeltaIndex([Timedelta(days=3), Timedelta(days=8)]) + expected = tm.box_expected(expected, box) + tm.assert_equal(result, expected) + + result2 = right * left + tm.assert_equal(result2, expected) + + @pytest.mark.parametrize("box_cls", [np.array, Index, Series]) + @pytest.mark.parametrize( + "left", lefts, ids=lambda x: type(x).__name__ + str(x.dtype) + ) + def test_mul_td64arr(self, left, box_cls): + # GH#22390 + right = np.array([1, 2, 3], dtype="m8[s]") + right = box_cls(right) + + expected = TimedeltaIndex(["10s", "40s", "90s"], dtype=right.dtype) + + if isinstance(left, Series) or box_cls is Series: + expected = Series(expected) + assert expected.dtype == right.dtype + + result = left * right + tm.assert_equal(result, expected) + + result = right * left + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("box_cls", [np.array, Index, Series]) + @pytest.mark.parametrize( + "left", lefts, ids=lambda x: type(x).__name__ + str(x.dtype) + ) + def test_div_td64arr(self, left, box_cls): + # GH#22390 + right = np.array([10, 40, 90], dtype="m8[s]") + right = box_cls(right) + + expected = TimedeltaIndex(["1s", "2s", "3s"], dtype=right.dtype) + if isinstance(left, Series) or box_cls is Series: + expected = Series(expected) + assert expected.dtype == right.dtype + + result = right / left + tm.assert_equal(result, expected) + + result = right // left + tm.assert_equal(result, expected) + + # (true_) needed for min-versions build 2022-12-26 + msg = "ufunc '(true_)?divide' cannot use operands with types" + with pytest.raises(TypeError, match=msg): + left / right + + msg = "ufunc 'floor_divide' cannot use operands with types" + with pytest.raises(TypeError, match=msg): + left // right + + # TODO: also test Tick objects; + # see test_numeric_arr_rdiv_tdscalar for note on these failing + @pytest.mark.parametrize( + "scalar_td", + [ + Timedelta(days=1).as_unit("ns"), + Timedelta(days=1).as_unit("ns").to_timedelta64(), + Timedelta(days=1).to_pytimedelta(), + Timedelta(days=1).to_timedelta64().astype("timedelta64[s]"), + Timedelta(days=1).to_timedelta64().astype("timedelta64[ms]"), + ], + ids=lambda x: type(x).__name__, + ) + def test_numeric_arr_mul_tdscalar(self, scalar_td, numeric_idx, box_with_array): + # GH#19333 + box = box_with_array + index = numeric_idx + expected = TimedeltaIndex( + [Timedelta(days=n) for n in range(len(index))], dtype="m8[ns]" + ) + if isinstance(scalar_td, np.timedelta64): + dtype = scalar_td.dtype + expected = expected.astype(dtype) + elif type(scalar_td) is timedelta: + expected = expected.astype("m8[us]") + + index = tm.box_expected(index, box) + expected = tm.box_expected(expected, box) + + result = index * scalar_td + tm.assert_equal(result, expected) + + commute = scalar_td * index + tm.assert_equal(commute, expected) + + @pytest.mark.parametrize( + "scalar_td", + [ + Timedelta(days=1).as_unit("ns"), + Timedelta(days=1).as_unit("ns").to_timedelta64(), + Timedelta(days=1).as_unit("ns").to_pytimedelta(), + ], + ids=lambda x: type(x).__name__, + ) + @pytest.mark.parametrize("dtype", [np.int64, np.float64]) + def test_numeric_arr_mul_tdscalar_numexpr_path( + self, dtype, scalar_td, box_with_array + ): + # GH#44772 for the float64 case + box = box_with_array + + arr_i8 = np.arange(2 * 10**4).astype(np.int64, copy=False) + arr = arr_i8.astype(dtype, copy=False) + obj = tm.box_expected(arr, box, transpose=False) + + expected = arr_i8.view("timedelta64[D]").astype("timedelta64[ns]") + if type(scalar_td) is timedelta: + expected = expected.astype("timedelta64[us]") + + expected = tm.box_expected(expected, box, transpose=False) + + result = obj * scalar_td + tm.assert_equal(result, expected) + + result = scalar_td * obj + tm.assert_equal(result, expected) + + def test_numeric_arr_rdiv_tdscalar(self, three_days, numeric_idx, box_with_array): + box = box_with_array + + index = numeric_idx[1:3] + + expected = TimedeltaIndex(["3 Days", "36 Hours"]) + if isinstance(three_days, np.timedelta64): + dtype = three_days.dtype + if dtype < np.dtype("m8[s]"): + # i.e. resolution is lower -> use lowest supported resolution + dtype = np.dtype("m8[s]") + expected = expected.astype(dtype) + elif type(three_days) is timedelta or ( + isinstance(three_days, Timedelta) and three_days.unit == "us" + ): + expected = expected.astype("m8[us]") + elif isinstance( + three_days, + (pd.offsets.Day, pd.offsets.Hour, pd.offsets.Minute, pd.offsets.Second), + ): + # closest reso is Second + expected = expected.astype("m8[s]") + + index = tm.box_expected(index, box) + expected = tm.box_expected(expected, box) + + if isinstance(three_days, pd.offsets.Day): + # GH#41943 Day is no longer timedelta-like + msg = "unsupported operand type" + with pytest.raises(TypeError, match=msg): + three_days / index + else: + result = three_days / index + tm.assert_equal(result, expected) + msg = "cannot use operands with types dtype" + + with pytest.raises(TypeError, match=msg): + index / three_days + + @pytest.mark.parametrize( + "other", + [ + Timedelta(hours=31), + Timedelta(hours=31).to_pytimedelta(), + Timedelta(hours=31).to_timedelta64(), + Timedelta(hours=31).to_timedelta64().astype("m8[h]"), + np.timedelta64("NaT"), + np.timedelta64("NaT", "D"), + pd.offsets.Minute(3), + pd.offsets.Second(0), + # GH#28080 numeric+datetimelike should raise; Timestamp used + # to raise NullFrequencyError but that behavior was removed in 1.0 + pd.Timestamp("2021-01-01", tz="Asia/Tokyo"), + pd.Timestamp("2021-01-01"), + pd.Timestamp("2021-01-01").to_pydatetime(), + pd.Timestamp("2021-01-01", tz="UTC").to_pydatetime(), + pd.Timestamp("2021-01-01").to_datetime64(), + np.datetime64("NaT", "ns"), + pd.NaT, + ], + ids=repr, + ) + def test_add_sub_datetimedeltalike_invalid( + self, numeric_idx, other, box_with_array + ): + box = box_with_array + + left = tm.box_expected(numeric_idx, box) + msg = "|".join( + [ + "unsupported operand type", + "Addition/subtraction of integers and integer-arrays", + "Instead of adding/subtracting", + "cannot use operands with types dtype", + "Concatenation operation is not implemented for NumPy arrays", + "Cannot (add|subtract) NaT (to|from) ndarray", + # pd.array vs np.datetime64 case + r"operand type\(s\) all returned NotImplemented from __array_ufunc__", + "can only perform ops with numeric values", + "cannot subtract DatetimeArray from ndarray", + # pd.Timedelta(1) + Index([0, 1, 2]) + "Cannot add or subtract Timedelta from integers", + ] + ) + assert_invalid_addsub_type(left, other, msg) + + +# ------------------------------------------------------------------ +# Arithmetic + + +class TestDivisionByZero: + def test_div_zero(self, zero, numeric_idx): + idx = numeric_idx + + expected = Index([np.nan, np.inf, np.inf, np.inf, np.inf], dtype=np.float64) + # We only adjust for Index, because Series does not yet apply + # the adjustment correctly. + expected2 = adjust_negative_zero(zero, expected) + + result = idx / zero + tm.assert_index_equal(result, expected2) + ser_compat = Series(idx).astype("i8") / np.array(zero).astype("i8") + tm.assert_series_equal(ser_compat, Series(expected)) + + def test_floordiv_zero(self, zero, numeric_idx): + idx = numeric_idx + + expected = Index([np.nan, np.inf, np.inf, np.inf, np.inf], dtype=np.float64) + # We only adjust for Index, because Series does not yet apply + # the adjustment correctly. + expected2 = adjust_negative_zero(zero, expected) + + result = idx // zero + tm.assert_index_equal(result, expected2) + ser_compat = Series(idx).astype("i8") // np.array(zero).astype("i8") + tm.assert_series_equal(ser_compat, Series(expected)) + + def test_mod_zero(self, zero, numeric_idx): + idx = numeric_idx + + expected = Index([np.nan, np.nan, np.nan, np.nan, np.nan], dtype=np.float64) + result = idx % zero + tm.assert_index_equal(result, expected) + ser_compat = Series(idx).astype("i8") % np.array(zero).astype("i8") + tm.assert_series_equal(ser_compat, Series(result)) + + def test_divmod_zero(self, zero, numeric_idx): + idx = numeric_idx + + exleft = Index([np.nan, np.inf, np.inf, np.inf, np.inf], dtype=np.float64) + exright = Index([np.nan, np.nan, np.nan, np.nan, np.nan], dtype=np.float64) + exleft = adjust_negative_zero(zero, exleft) + + result = divmod(idx, zero) + tm.assert_index_equal(result[0], exleft) + tm.assert_index_equal(result[1], exright) + + @pytest.mark.parametrize("op", [operator.truediv, operator.floordiv]) + def test_div_negative_zero(self, zero, numeric_idx, op): + # Check that -1 / -0.0 returns np.inf, not -np.inf + if numeric_idx.dtype == np.uint64: + pytest.skip(f"Div by negative 0 not relevant for {numeric_idx.dtype}") + idx = numeric_idx - 3 + + expected = Index([-np.inf, -np.inf, -np.inf, np.nan, np.inf], dtype=np.float64) + expected = adjust_negative_zero(zero, expected) + + result = op(idx, zero) + tm.assert_index_equal(result, expected) + + # ------------------------------------------------------------------ + + @pytest.mark.parametrize("dtype1", [np.int64, np.float64, np.uint64]) + def test_ser_div_ser( + self, + switch_numexpr_min_elements, + dtype1, + any_real_numpy_dtype, + ): + # no longer do integer div for any ops, but deal with the 0's + dtype2 = any_real_numpy_dtype + + first = Series([3, 4, 5, 8], name="first").astype(dtype1) + second = Series([0, 0, 0, 3], name="second").astype(dtype2) + + with np.errstate(all="ignore"): + expected = Series( + first.values.astype(np.float64) / second.values, + dtype="float64", + name=None, + ) + expected.iloc[0:3] = np.inf + if first.dtype == "int64" and second.dtype == "float32": + # when using numexpr, the casting rules are slightly different + # and int64/float32 combo results in float32 instead of float64 + if expr.USE_NUMEXPR and switch_numexpr_min_elements == 0: + expected = expected.astype("float32") + + result = first / second + tm.assert_series_equal(result, expected) + assert not result.equals(second / first) + + @pytest.mark.parametrize("dtype1", [np.int64, np.float64, np.uint64]) + def test_ser_divmod_zero(self, dtype1, any_real_numpy_dtype): + # GH#26987 + dtype2 = any_real_numpy_dtype + left = Series([1, 1]).astype(dtype1) + right = Series([0, 2]).astype(dtype2) + + # GH#27321 pandas convention is to set 1 // 0 to np.inf, as opposed + # to numpy which sets to np.nan; patch `expected[0]` below + expected = left // right, left % right + expected = list(expected) + expected[0] = expected[0].astype(np.float64) + expected[0][0] = np.inf + result = divmod(left, right) + + tm.assert_series_equal(result[0], expected[0]) + tm.assert_series_equal(result[1], expected[1]) + + # rdivmod case + result = divmod(left.values, right) + tm.assert_series_equal(result[0], expected[0]) + tm.assert_series_equal(result[1], expected[1]) + + def test_ser_divmod_inf(self): + left = Series([np.inf, 1.0]) + right = Series([np.inf, 2.0]) + + expected = left // right, left % right + result = divmod(left, right) + + tm.assert_series_equal(result[0], expected[0]) + tm.assert_series_equal(result[1], expected[1]) + + # rdivmod case + result = divmod(left.values, right) + tm.assert_series_equal(result[0], expected[0]) + tm.assert_series_equal(result[1], expected[1]) + + def test_rdiv_zero_compat(self): + # GH#8674 + zero_array = np.array([0] * 5) + data = np.random.default_rng(2).standard_normal(5) + expected = Series([0.0] * 5) + + result = zero_array / Series(data) + tm.assert_series_equal(result, expected) + + result = Series(zero_array) / data + tm.assert_series_equal(result, expected) + + result = Series(zero_array) / Series(data) + tm.assert_series_equal(result, expected) + + def test_div_zero_inf_signs(self): + # GH#9144, inf signing + ser = Series([-1, 0, 1], name="first") + expected = Series([-np.inf, np.nan, np.inf], name="first") + + result = ser / 0 + tm.assert_series_equal(result, expected) + + def test_rdiv_zero(self): + # GH#9144 + ser = Series([-1, 0, 1], name="first") + expected = Series([0.0, np.nan, 0.0], name="first") + + result = 0 / ser + tm.assert_series_equal(result, expected) + + def test_floordiv_div(self): + # GH#9144 + ser = Series([-1, 0, 1], name="first") + + result = ser // 0 + expected = Series([-np.inf, np.nan, np.inf], name="first") + tm.assert_series_equal(result, expected) + + def test_df_div_zero_df(self): + # integer div, but deal with the 0's (GH#9144) + df = pd.DataFrame({"first": [3, 4, 5, 8], "second": [0, 0, 0, 3]}) + result = df / df + + first = Series([1.0, 1.0, 1.0, 1.0]) + second = Series([np.nan, np.nan, np.nan, 1]) + expected = pd.DataFrame({"first": first, "second": second}) + tm.assert_frame_equal(result, expected) + + def test_df_div_zero_array(self): + # integer div, but deal with the 0's (GH#9144) + df = pd.DataFrame({"first": [3, 4, 5, 8], "second": [0, 0, 0, 3]}) + + first = Series([1.0, 1.0, 1.0, 1.0]) + second = Series([np.nan, np.nan, np.nan, 1]) + expected = pd.DataFrame({"first": first, "second": second}) + + with np.errstate(all="ignore"): + arr = df.values.astype("float") / df.values + result = pd.DataFrame(arr, index=df.index, columns=df.columns) + tm.assert_frame_equal(result, expected) + + def test_df_div_zero_int(self): + # integer div, but deal with the 0's (GH#9144) + df = pd.DataFrame({"first": [3, 4, 5, 8], "second": [0, 0, 0, 3]}) + + result = df / 0 + expected = pd.DataFrame(np.inf, index=df.index, columns=df.columns) + expected.iloc[0:3, 1] = np.nan + tm.assert_frame_equal(result, expected) + + # numpy has a slightly different (wrong) treatment + with np.errstate(all="ignore"): + arr = df.values.astype("float64") / 0 + result2 = pd.DataFrame(arr, index=df.index, columns=df.columns) + tm.assert_frame_equal(result2, expected) + + def test_df_div_zero_series_does_not_commute(self): + # integer div, but deal with the 0's (GH#9144) + df = pd.DataFrame(np.random.default_rng(2).standard_normal((10, 5))) + ser = df[0] + res = ser / df + res2 = df / ser + assert not res.fillna(0).equals(res2.fillna(0)) + + # ------------------------------------------------------------------ + # Mod By Zero + + def test_df_mod_zero_df(self): + # GH#3590, modulo as ints + df = pd.DataFrame({"first": [3, 4, 5, 8], "second": [0, 0, 0, 3]}) + # this is technically wrong, as the integer portion is coerced to float + first = Series([0, 0, 0, 0]) + first = first.astype("float64") + second = Series([np.nan, np.nan, np.nan, 0]) + expected = pd.DataFrame({"first": first, "second": second}) + result = df % df + tm.assert_frame_equal(result, expected) + + # GH#38939 If we dont pass copy=False, df is consolidated and + # result["first"] is float64 instead of int64 + df = pd.DataFrame({"first": [3, 4, 5, 8], "second": [0, 0, 0, 3]}, copy=False) + first = Series([0, 0, 0, 0], dtype="int64") + second = Series([np.nan, np.nan, np.nan, 0]) + expected = pd.DataFrame({"first": first, "second": second}) + result = df % df + tm.assert_frame_equal(result, expected) + + def test_df_mod_zero_array(self): + # GH#3590, modulo as ints + df = pd.DataFrame({"first": [3, 4, 5, 8], "second": [0, 0, 0, 3]}) + + # this is technically wrong, as the integer portion is coerced to float + # ### + first = Series([0, 0, 0, 0], dtype="float64") + second = Series([np.nan, np.nan, np.nan, 0]) + expected = pd.DataFrame({"first": first, "second": second}) + + # numpy has a slightly different (wrong) treatment + with np.errstate(all="ignore"): + arr = df.values % df.values + result2 = pd.DataFrame(arr, index=df.index, columns=df.columns, dtype="float64") + result2.iloc[0:3, 1] = np.nan + tm.assert_frame_equal(result2, expected) + + def test_df_mod_zero_int(self): + # GH#3590, modulo as ints + df = pd.DataFrame({"first": [3, 4, 5, 8], "second": [0, 0, 0, 3]}) + + result = df % 0 + expected = pd.DataFrame(np.nan, index=df.index, columns=df.columns) + tm.assert_frame_equal(result, expected) + + # numpy has a slightly different (wrong) treatment + with np.errstate(all="ignore"): + arr = df.values.astype("float64") % 0 + result2 = pd.DataFrame(arr, index=df.index, columns=df.columns) + tm.assert_frame_equal(result2, expected) + + def test_df_mod_zero_series_does_not_commute(self): + # GH#3590, modulo as ints + # not commutative with series + df = pd.DataFrame(np.random.default_rng(2).standard_normal((10, 5))) + ser = df[0] + res = ser % df + res2 = df % ser + assert not res.fillna(0).equals(res2.fillna(0)) + + +class TestMultiplicationDivision: + # __mul__, __rmul__, __div__, __rdiv__, __floordiv__, __rfloordiv__ + # for non-timestamp/timedelta/period dtypes + + def test_divide_decimal(self, box_with_array): + # resolves issue GH#9787 + box = box_with_array + ser = Series([Decimal(10)]) + expected = Series([Decimal(5)]) + + ser = tm.box_expected(ser, box) + expected = tm.box_expected(expected, box) + + result = ser / Decimal(2) + + tm.assert_equal(result, expected) + + result = ser // Decimal(2) + tm.assert_equal(result, expected) + + def test_div_equiv_binop(self): + # Test Series.div as well as Series.__div__ + # float/integer issue + # GH#7785 + first = Series([1, 0], name="first") + second = Series([-0.01, -0.02], name="second") + expected = Series([-0.01, -np.inf]) + + result = second.div(first) + tm.assert_series_equal(result, expected, check_names=False) + + result = second / first + tm.assert_series_equal(result, expected) + + def test_div_int(self, numeric_idx): + idx = numeric_idx + result = idx / 1 + expected = idx.astype("float64") + tm.assert_index_equal(result, expected) + + result = idx / 2 + expected = Index(idx.values / 2) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("op", [operator.mul, ops.rmul, operator.floordiv]) + def test_mul_int_identity(self, op, numeric_idx, box_with_array): + idx = numeric_idx + idx = tm.box_expected(idx, box_with_array) + + result = op(idx, 1) + tm.assert_equal(result, idx) + + def test_mul_int_array(self, numeric_idx): + idx = numeric_idx + didx = idx * idx + + result = idx * np.array(5, dtype="int64") + tm.assert_index_equal(result, idx * 5) + + arr_dtype = "uint64" if idx.dtype == np.uint64 else "int64" + result = idx * np.arange(5, dtype=arr_dtype) + tm.assert_index_equal(result, didx) + + def test_mul_int_series(self, numeric_idx): + idx = numeric_idx + didx = idx * idx + + arr_dtype = "uint64" if idx.dtype == np.uint64 else "int64" + result = idx * Series(np.arange(5, dtype=arr_dtype)) + tm.assert_series_equal(result, Series(didx)) + + def test_mul_float_series(self, numeric_idx): + idx = numeric_idx + rng5 = np.arange(5, dtype="float64") + + result = idx * Series(rng5 + 0.1) + expected = Series(rng5 * (rng5 + 0.1)) + tm.assert_series_equal(result, expected) + + def test_mul_index(self, numeric_idx): + idx = numeric_idx + + result = idx * idx + tm.assert_index_equal(result, idx**2) + + def test_mul_datelike_raises(self, numeric_idx): + idx = numeric_idx + msg = "cannot perform __rmul__ with this index type" + with pytest.raises(TypeError, match=msg): + idx * date_range("20130101", periods=5) + + def test_mul_size_mismatch_raises(self, numeric_idx): + idx = numeric_idx + msg = "operands could not be broadcast together" + with pytest.raises(ValueError, match=msg): + idx * idx[0:3] + with pytest.raises(ValueError, match=msg): + idx * np.array([1, 2]) + + @pytest.mark.parametrize("op", [operator.pow, ops.rpow]) + def test_pow_float(self, op, numeric_idx, box_with_array): + # test power calculations both ways, GH#14973 + box = box_with_array + idx = numeric_idx + expected = Index(op(idx.values, 2.0)) + + idx = tm.box_expected(idx, box) + expected = tm.box_expected(expected, box) + + result = op(idx, 2.0) + tm.assert_equal(result, expected) + + def test_modulo(self, numeric_idx, box_with_array): + # GH#9244 + box = box_with_array + idx = numeric_idx + expected = Index(idx.values % 2) + + idx = tm.box_expected(idx, box) + expected = tm.box_expected(expected, box) + + result = idx % 2 + tm.assert_equal(result, expected) + + def test_divmod_scalar(self, numeric_idx): + idx = numeric_idx + + result = divmod(idx, 2) + with np.errstate(all="ignore"): + div, mod = divmod(idx.values, 2) + + expected = Index(div), Index(mod) + for r, e in zip(result, expected, strict=True): + tm.assert_index_equal(r, e) + + def test_divmod_ndarray(self, numeric_idx): + idx = numeric_idx + other = np.ones(idx.values.shape, dtype=idx.values.dtype) * 2 + + result = divmod(idx, other) + with np.errstate(all="ignore"): + div, mod = divmod(idx.values, other) + + expected = Index(div), Index(mod) + for r, e in zip(result, expected, strict=True): + tm.assert_index_equal(r, e) + + def test_divmod_series(self, numeric_idx): + idx = numeric_idx + other = np.ones(idx.values.shape, dtype=idx.values.dtype) * 2 + + result = divmod(idx, Series(other)) + with np.errstate(all="ignore"): + div, mod = divmod(idx.values, other) + + expected = Series(div), Series(mod) + for r, e in zip(result, expected, strict=True): + tm.assert_series_equal(r, e) + + @pytest.mark.parametrize("other", [np.nan, 7, -23, 2.718, -3.14, np.inf]) + def test_ops_np_scalar(self, other): + vals = np.random.default_rng(2).standard_normal((5, 3)) + f = lambda x: pd.DataFrame( + x, index=list("ABCDE"), columns=["jim", "joe", "jolie"] + ) + + df = f(vals) + + tm.assert_frame_equal(df / np.array(other), f(vals / other)) + tm.assert_frame_equal(np.array(other) * df, f(vals * other)) + tm.assert_frame_equal(df + np.array(other), f(vals + other)) + tm.assert_frame_equal(np.array(other) - df, f(other - vals)) + + # TODO: This came from series.test.test_operators, needs cleanup + def test_operators_frame(self): + # rpow does not work with DataFrame + ts = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + ts.name = "ts" + + df = pd.DataFrame({"A": ts}) + + tm.assert_series_equal(ts + ts, ts + df["A"], check_names=False) + tm.assert_series_equal(ts**ts, ts ** df["A"], check_names=False) + tm.assert_series_equal(ts < ts, ts < df["A"], check_names=False) + tm.assert_series_equal(ts / ts, ts / df["A"], check_names=False) + + # TODO: this came from tests.series.test_analytics, needs cleanup and + # de-duplication with test_modulo above + def test_modulo2(self): + with np.errstate(all="ignore"): + # GH#3590, modulo as ints + p = pd.DataFrame({"first": [3, 4, 5, 8], "second": [0, 0, 0, 3]}) + result = p["first"] % p["second"] + expected = Series(p["first"].values % p["second"].values, dtype="float64") + expected.iloc[0:3] = np.nan + tm.assert_series_equal(result, expected) + + result = p["first"] % 0 + expected = Series(np.nan, index=p.index, name="first") + tm.assert_series_equal(result, expected) + + p = p.astype("float64") + result = p["first"] % p["second"] + expected = Series(p["first"].values % p["second"].values) + tm.assert_series_equal(result, expected) + + p = p.astype("float64") + result = p["first"] % p["second"] + result2 = p["second"] % p["first"] + assert not result.equals(result2) + + def test_modulo_zero_int(self): + # GH#9144 + with np.errstate(all="ignore"): + s = Series([0, 1]) + + result = s % 0 + expected = Series([np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + result = 0 % s + expected = Series([np.nan, 0.0]) + tm.assert_series_equal(result, expected) + + def test_non_1d_ea_raises_notimplementederror(self): + # GH#61866 + ea_array = array([1, 2, 3, 4, 5], dtype="Int64").reshape(5, 1) + np_array = np.array([1, 2, 3, 4, 5], dtype=np.int64).reshape(5, 1) + + msg = "can only perform ops with 1-d structures" + + with pytest.raises(NotImplementedError, match=msg): + ea_array * np_array + + with pytest.raises(NotImplementedError, match=msg): + np_array * ea_array + + +class TestAdditionSubtraction: + # __add__, __sub__, __radd__, __rsub__, __iadd__, __isub__ + # for non-timestamp/timedelta/period dtypes + + @pytest.mark.parametrize( + "first, second, expected", + [ + ( + Series([1, 2, 3], index=list("ABC"), name="x"), + Series([2, 2, 2], index=list("ABD"), name="x"), + Series([3.0, 4.0, np.nan, np.nan], index=list("ABCD"), name="x"), + ), + ( + Series([1, 2, 3], index=list("ABC"), name="x"), + Series([2, 2, 2, 2], index=list("ABCD"), name="x"), + Series([3, 4, 5, np.nan], index=list("ABCD"), name="x"), + ), + ], + ) + def test_add_series(self, first, second, expected): + # GH#1134 + tm.assert_series_equal(first + second, expected) + tm.assert_series_equal(second + first, expected) + + @pytest.mark.parametrize( + "first, second, expected", + [ + ( + pd.DataFrame({"x": [1, 2, 3]}, index=list("ABC")), + pd.DataFrame({"x": [2, 2, 2]}, index=list("ABD")), + pd.DataFrame({"x": [3.0, 4.0, np.nan, np.nan]}, index=list("ABCD")), + ), + ( + pd.DataFrame({"x": [1, 2, 3]}, index=list("ABC")), + pd.DataFrame({"x": [2, 2, 2, 2]}, index=list("ABCD")), + pd.DataFrame({"x": [3, 4, 5, np.nan]}, index=list("ABCD")), + ), + ], + ) + def test_add_frames(self, first, second, expected): + # GH#1134 + tm.assert_frame_equal(first + second, expected) + tm.assert_frame_equal(second + first, expected) + + # TODO: This came from series.test.test_operators, needs cleanup + def test_series_frame_radd_bug(self, fixed_now_ts): + # GH#353 + vals = Series([str(i) for i in range(5)]) + result = "foo_" + vals + expected = vals.map(lambda x: "foo_" + x) + tm.assert_series_equal(result, expected) + + frame = pd.DataFrame({"vals": vals}) + result = "foo_" + frame + expected = pd.DataFrame({"vals": vals.map(lambda x: "foo_" + x)}) + tm.assert_frame_equal(result, expected) + + ts = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + + # really raise this time + fix_now = fixed_now_ts.to_pydatetime() + msg = "|".join( + [ + "unsupported operand type", + # wrong error message, see https://github.com/numpy/numpy/issues/18832 + "Concatenation operation", + ] + ) + with pytest.raises(TypeError, match=msg): + fix_now + ts + + with pytest.raises(TypeError, match=msg): + ts + fix_now + + # TODO: This came from series.test.test_operators, needs cleanup + def test_datetime64_with_index(self): + # arithmetic integer ops with an index + ser = Series(np.random.default_rng(2).standard_normal(5)) + expected = ser - ser.index.to_series() + result = ser - ser.index + tm.assert_series_equal(result, expected) + + # GH#4629 + # arithmetic datetime64 ops with an index + ser = Series( + date_range("20130101", periods=5), + index=date_range("20130101", periods=5), + ) + expected = ser - ser.index.to_series() + result = ser - ser.index + tm.assert_series_equal(result, expected) + + msg = "cannot subtract PeriodArray from DatetimeArray" + with pytest.raises(TypeError, match=msg): + # GH#18850 + result = ser - ser.index.to_period() + + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), + index=date_range("20130101", periods=5), + ) + df["date"] = pd.Timestamp("20130102") + df["expected"] = df["date"] - df.index.to_series() + df["result"] = df["date"] - df.index + tm.assert_series_equal(df["result"], df["expected"], check_names=False) + + # TODO: taken from tests.frame.test_operators, needs cleanup + def test_frame_operators(self, float_frame): + frame = float_frame + + garbage = np.random.default_rng(2).random(4) + colSeries = Series(garbage, index=np.array(frame.columns)) + + idSum = frame + frame + seriesSum = frame + colSeries + + for col, series in idSum.items(): + for idx, val in series.items(): + origVal = frame[col][idx] * 2 + if not np.isnan(val): + assert val == origVal + else: + assert np.isnan(origVal) + + for col, series in seriesSum.items(): + for idx, val in series.items(): + origVal = frame[col][idx] + colSeries[col] + if not np.isnan(val): + assert val == origVal + else: + assert np.isnan(origVal) + + def test_frame_operators_col_align(self, float_frame): + frame2 = pd.DataFrame(float_frame, columns=["D", "C", "B", "A"]) + added = frame2 + frame2 + expected = frame2 * 2 + tm.assert_frame_equal(added, expected) + + def test_frame_operators_none_to_nan(self): + df = pd.DataFrame({"a": ["a", None, "b"]}) + tm.assert_frame_equal(df + df, pd.DataFrame({"a": ["aa", np.nan, "bb"]})) + + @pytest.mark.parametrize("dtype", ("float", "int64")) + def test_frame_operators_empty_like(self, dtype): + # Test for issue #10181 + frames = [ + pd.DataFrame(dtype=dtype), + pd.DataFrame(columns=["A"], dtype=dtype), + pd.DataFrame(index=[0], dtype=dtype), + ] + for df in frames: + assert (df + df).equals(df) + tm.assert_frame_equal(df + df, df) + + @pytest.mark.parametrize( + "func", + [lambda x: x * 2, lambda x: x[::2], lambda x: 5], + ids=["multiply", "slice", "constant"], + ) + def test_series_operators_arithmetic(self, all_arithmetic_functions, func): + op = all_arithmetic_functions + series = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + other = func(series) + compare_op(series, other, op) + + @pytest.mark.parametrize( + "func", [lambda x: x + 1, lambda x: 5], ids=["add", "constant"] + ) + def test_series_operators_compare(self, comparison_op, func): + op = comparison_op + series = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + other = func(series) + compare_op(series, other, op) + + @pytest.mark.parametrize( + "func", + [lambda x: x * 2, lambda x: x[::2], lambda x: 5], + ids=["multiply", "slice", "constant"], + ) + def test_divmod(self, func): + series = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + other = func(series) + results = divmod(series, other) + if isinstance(other, abc.Iterable) and len(series) != len(other): + # if the lengths don't match, this is the test where we use + # `tser[::2]`. Pad every other value in `other_np` with nan. + other_np = [] + for n in other: + other_np.append(n) + other_np.append(np.nan) + else: + other_np = other + other_np = np.asarray(other_np) + with np.errstate(all="ignore"): + expecteds = divmod(series.values, np.asarray(other_np)) + + for result, expected in zip(results, expecteds, strict=True): + # check the values, name, and index separately + tm.assert_almost_equal(np.asarray(result), expected) + + assert result.name == series.name + tm.assert_index_equal(result.index, series.index._with_freq(None)) + + def test_series_divmod_zero(self): + # Check that divmod uses pandas convention for division by zero, + # which does not match numpy. + # pandas convention has + # 1/0 == np.inf + # -1/0 == -np.inf + # 1/-0.0 == -np.inf + # -1/-0.0 == np.inf + tser = Series( + np.arange(1, 11, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + other = tser * 0 + + result = divmod(tser, other) + exp1 = Series([np.inf] * len(tser), index=tser.index, name="ts") + exp2 = Series([np.nan] * len(tser), index=tser.index, name="ts") + tm.assert_series_equal(result[0], exp1) + tm.assert_series_equal(result[1], exp2) + + +class TestUFuncCompat: + # TODO: add more dtypes + @pytest.mark.parametrize("holder", [Index, RangeIndex, Series]) + @pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float64]) + def test_ufunc_compat(self, holder, dtype): + box = Series if holder is Series else Index + + if holder is RangeIndex: + if dtype != np.int64: + pytest.skip(f"dtype {dtype} not relevant for RangeIndex") + idx = RangeIndex(0, 5, name="foo") + else: + idx = holder(np.arange(5, dtype=dtype), name="foo") + result = np.sin(idx) + expected = box(np.sin(np.arange(5, dtype=dtype)), name="foo") + tm.assert_equal(result, expected) + + # TODO: add more dtypes + @pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float64]) + def test_ufunc_coercions(self, index_or_series, dtype): + idx = index_or_series([1, 2, 3, 4, 5], dtype=dtype, name="x") + box = index_or_series + + result = np.sqrt(idx) + assert result.dtype == "f8" and isinstance(result, box) + exp = Index(np.sqrt(np.array([1, 2, 3, 4, 5], dtype=np.float64)), name="x") + exp = tm.box_expected(exp, box) + tm.assert_equal(result, exp) + + result = np.divide(idx, 2.0) + assert result.dtype == "f8" and isinstance(result, box) + exp = Index([0.5, 1.0, 1.5, 2.0, 2.5], dtype=np.float64, name="x") + exp = tm.box_expected(exp, box) + tm.assert_equal(result, exp) + + # _evaluate_numeric_binop + result = idx + 2.0 + assert result.dtype == "f8" and isinstance(result, box) + exp = Index([3.0, 4.0, 5.0, 6.0, 7.0], dtype=np.float64, name="x") + exp = tm.box_expected(exp, box) + tm.assert_equal(result, exp) + + result = idx - 2.0 + assert result.dtype == "f8" and isinstance(result, box) + exp = Index([-1.0, 0.0, 1.0, 2.0, 3.0], dtype=np.float64, name="x") + exp = tm.box_expected(exp, box) + tm.assert_equal(result, exp) + + result = idx * 1.0 + assert result.dtype == "f8" and isinstance(result, box) + exp = Index([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float64, name="x") + exp = tm.box_expected(exp, box) + tm.assert_equal(result, exp) + + result = idx / 2.0 + assert result.dtype == "f8" and isinstance(result, box) + exp = Index([0.5, 1.0, 1.5, 2.0, 2.5], dtype=np.float64, name="x") + exp = tm.box_expected(exp, box) + tm.assert_equal(result, exp) + + # TODO: add more dtypes + @pytest.mark.parametrize("holder", [Index, Series]) + @pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float64]) + def test_ufunc_multiple_return_values(self, holder, dtype): + obj = holder([1, 2, 3], dtype=dtype, name="x") + box = Series if holder is Series else Index + + result = np.modf(obj) + assert isinstance(result, tuple) + exp1 = Index([0.0, 0.0, 0.0], dtype=np.float64, name="x") + exp2 = Index([1.0, 2.0, 3.0], dtype=np.float64, name="x") + tm.assert_equal(result[0], tm.box_expected(exp1, box)) + tm.assert_equal(result[1], tm.box_expected(exp2, box)) + + def test_ufunc_at(self): + s = Series([0, 1, 2], index=[1, 2, 3], name="x") + np.add.at(s, [0, 2], 10) + expected = Series([10, 1, 12], index=[1, 2, 3], name="x") + tm.assert_series_equal(s, expected) + + +class TestObjectDtypeEquivalence: + # Tests that arithmetic operations match operations executed elementwise + + @pytest.mark.parametrize("dtype", [None, object]) + def test_numarr_with_dtype_add_nan(self, dtype, box_with_array): + box = box_with_array + ser = Series([1, 2, 3], dtype=dtype) + expected = Series([np.nan, np.nan, np.nan], dtype=dtype) + + ser = tm.box_expected(ser, box) + expected = tm.box_expected(expected, box) + + result = np.nan + ser + tm.assert_equal(result, expected) + + result = ser + np.nan + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("dtype", [None, object]) + def test_numarr_with_dtype_add_int(self, dtype, box_with_array): + box = box_with_array + ser = Series([1, 2, 3], dtype=dtype) + expected = Series([2, 3, 4], dtype=dtype) + + ser = tm.box_expected(ser, box) + expected = tm.box_expected(expected, box) + + result = 1 + ser + tm.assert_equal(result, expected) + + result = ser + 1 + tm.assert_equal(result, expected) + + # TODO: moved from tests.series.test_operators; needs cleanup + @pytest.mark.parametrize( + "op", + [operator.add, operator.sub, operator.mul, operator.truediv, operator.floordiv], + ) + def test_operators_reverse_object(self, op): + # GH#56 + arr = Series( + np.random.default_rng(2).standard_normal(10), + index=np.arange(10), + dtype=object, + ) + + result = op(1.0, arr) + expected = op(1.0, arr.astype(float)) + tm.assert_series_equal(result.astype(float), expected) + + +class TestNumericArithmeticUnsorted: + # Tests in this class have been moved from type-specific test modules + # but not yet sorted, parametrized, and de-duplicated + @pytest.mark.parametrize( + "op", + [ + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + ], + ) + @pytest.mark.parametrize( + "idx1", + [ + RangeIndex(0, 10, 1), + RangeIndex(0, 20, 2), + RangeIndex(-10, 10, 2), + RangeIndex(5, -5, -1), + ], + ) + @pytest.mark.parametrize( + "idx2", + [ + RangeIndex(0, 10, 1), + RangeIndex(0, 20, 2), + RangeIndex(-10, 10, 2), + RangeIndex(5, -5, -1), + ], + ) + def test_binops_index(self, op, idx1, idx2): + idx1 = idx1._rename("foo") + idx2 = idx2._rename("bar") + result = op(idx1, idx2) + expected = op(Index(idx1.to_numpy()), Index(idx2.to_numpy())) + tm.assert_index_equal(result, expected, exact="equiv") + + @pytest.mark.parametrize( + "op", + [ + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + ], + ) + @pytest.mark.parametrize( + "idx", + [ + RangeIndex(0, 10, 1), + RangeIndex(0, 20, 2), + RangeIndex(-10, 10, 2), + RangeIndex(5, -5, -1), + ], + ) + @pytest.mark.parametrize("scalar", [-1, 1, 2]) + def test_binops_index_scalar(self, op, idx, scalar): + result = op(idx, scalar) + expected = op(Index(idx.to_numpy()), scalar) + tm.assert_index_equal(result, expected, exact="equiv") + + @pytest.mark.parametrize("idx1", [RangeIndex(0, 10, 1), RangeIndex(0, 20, 2)]) + @pytest.mark.parametrize("idx2", [RangeIndex(0, 10, 1), RangeIndex(0, 20, 2)]) + def test_binops_index_pow(self, idx1, idx2): + # numpy does not allow powers of negative integers so test separately + # https://github.com/numpy/numpy/pull/8127 + idx1 = idx1._rename("foo") + idx2 = idx2._rename("bar") + result = pow(idx1, idx2) + expected = pow(Index(idx1.to_numpy()), Index(idx2.to_numpy())) + tm.assert_index_equal(result, expected, exact="equiv") + + @pytest.mark.parametrize("idx", [RangeIndex(0, 10, 1), RangeIndex(0, 20, 2)]) + @pytest.mark.parametrize("scalar", [1, 2]) + def test_binops_index_scalar_pow(self, idx, scalar): + # numpy does not allow powers of negative integers so test separately + # https://github.com/numpy/numpy/pull/8127 + result = pow(idx, scalar) + expected = pow(Index(idx.to_numpy()), scalar) + tm.assert_index_equal(result, expected, exact="equiv") + + # TODO: divmod? + @pytest.mark.parametrize( + "op", + [ + operator.add, + operator.sub, + operator.mul, + operator.floordiv, + operator.truediv, + operator.pow, + operator.mod, + ], + ) + def test_arithmetic_with_frame_or_series(self, op): + # check that we return NotImplemented when operating with Series + # or DataFrame + index = RangeIndex(5) + other = Series(np.random.default_rng(2).standard_normal(5)) + + expected = op(Series(index), other) + result = op(index, other) + tm.assert_series_equal(result, expected) + + other = pd.DataFrame(np.random.default_rng(2).standard_normal((2, 5))) + expected = op(pd.DataFrame([index, index]), other) + result = op(index, other) + tm.assert_frame_equal(result, expected) + + def test_numeric_compat2(self): + # validate that we are handling the RangeIndex overrides to numeric ops + # and returning RangeIndex where possible + + idx = RangeIndex(0, 10, 2) + + result = idx * 2 + expected = RangeIndex(0, 20, 4) + tm.assert_index_equal(result, expected, exact=True) + + result = idx + 2 + expected = RangeIndex(2, 12, 2) + tm.assert_index_equal(result, expected, exact=True) + + result = idx - 2 + expected = RangeIndex(-2, 8, 2) + tm.assert_index_equal(result, expected, exact=True) + + result = idx / 2 + expected = RangeIndex(0, 5, 1).astype("float64") + tm.assert_index_equal(result, expected, exact=True) + + result = idx / 4 + expected = RangeIndex(0, 10, 2) / 4 + tm.assert_index_equal(result, expected, exact=True) + + result = idx // 1 + expected = idx + tm.assert_index_equal(result, expected, exact=True) + + # __mul__ + result = idx * idx + expected = Index(idx.values * idx.values) + tm.assert_index_equal(result, expected, exact=True) + + # __pow__ + idx = RangeIndex(0, 1000, 2) + result = idx**2 + expected = Index(idx._values) ** 2 + tm.assert_index_equal(Index(result.values), expected, exact=True) + + @pytest.mark.parametrize( + "idx, div, expected", + [ + # TODO: add more dtypes + (RangeIndex(0, 1000, 2), 2, RangeIndex(0, 500, 1)), + (RangeIndex(-99, -201, -3), -3, RangeIndex(33, 67, 1)), + ( + RangeIndex(0, 1000, 1), + 2, + Index(RangeIndex(0, 1000, 1)._values) // 2, + ), + ( + RangeIndex(0, 100, 1), + 2.0, + Index(RangeIndex(0, 100, 1)._values) // 2.0, + ), + (RangeIndex(0), 50, RangeIndex(0)), + (RangeIndex(2, 4, 2), 3, RangeIndex(0, 1, 1)), + (RangeIndex(-5, -10, -6), 4, RangeIndex(-2, -1, 1)), + (RangeIndex(-100, -200, 3), 2, RangeIndex(0)), + ], + ) + def test_numeric_compat2_floordiv(self, idx, div, expected): + # __floordiv__ + tm.assert_index_equal(idx // div, expected, exact=True) + + @pytest.mark.parametrize("dtype", [np.int64, np.float64]) + @pytest.mark.parametrize("delta", [1, 0, -1]) + def test_addsub_arithmetic(self, dtype, delta): + # GH#8142 + delta = dtype(delta) + index = Index([10, 11, 12], dtype=dtype) + result = index + delta + expected = Index(index.values + delta, dtype=dtype) + tm.assert_index_equal(result, expected) + + # this subtraction used to fail + result = index - delta + expected = Index(index.values - delta, dtype=dtype) + tm.assert_index_equal(result, expected) + + tm.assert_index_equal(index + index, 2 * index) + tm.assert_index_equal(index - index, 0 * index) + assert not (index - index).empty + + def test_pow_nan_with_zero(self, box_with_array): + left = Index([np.nan, np.nan, np.nan]) + right = Index([0, 0, 0]) + expected = Index([1.0, 1.0, 1.0]) + + left = tm.box_expected(left, box_with_array) + right = tm.box_expected(right, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = left**right + tm.assert_equal(result, expected) + + +def test_fill_value_inf_masking(): + # GH #27464 make sure we mask 0/1 with Inf and not NaN + df = pd.DataFrame({"A": [0, 1, 2], "B": [1.1, None, 1.1]}) + + other = pd.DataFrame({"A": [1.1, 1.2, 1.3]}, index=[0, 2, 3]) + + result = df.rfloordiv(other, fill_value=1) + + expected = pd.DataFrame( + {"A": [np.inf, 1.0, 0.0, 1.0], "B": [0.0, np.nan, 0.0, np.nan]} + ) + tm.assert_frame_equal(result, expected, check_index_type=False) + + +def test_dataframe_div_silenced(): + # GH#26793 + pdf1 = pd.DataFrame( + { + "A": np.arange(10), + "B": [np.nan, 1, 2, 3, 4] * 2, + "C": [np.nan] * 10, + "D": np.arange(10), + }, + index=list("abcdefghij"), + columns=list("ABCD"), + ) + pdf2 = pd.DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + index=list("abcdefghjk"), + columns=list("ABCX"), + ) + with tm.assert_produces_warning(None): + pdf1.div(pdf2, fill_value=0) + + +@pytest.mark.parametrize( + "data, expected_data", + [([0, 1, 2], [0, 2, 4])], +) +@pytest.mark.parametrize("box_pandas_1d_array", [Index, Series, tm.to_array]) +@pytest.mark.parametrize("box_1d_array", [Index, Series, tm.to_array, np.array, list]) +def test_integer_array_add_list_like( + box_pandas_1d_array, box_1d_array, data, expected_data +): + # GH22606 Verify operators with IntegerArray and list-likes + arr = array(data, dtype="Int64") + container = box_pandas_1d_array(arr) + left = container + box_1d_array(data) + right = box_1d_array(data) + container + + if Series in [box_1d_array, box_pandas_1d_array]: + cls = Series + elif Index in [box_1d_array, box_pandas_1d_array]: + cls = Index + else: + cls = array + + expected = cls(expected_data, dtype="Int64") + + tm.assert_equal(left, expected) + tm.assert_equal(right, expected) + + +def test_sub_multiindex_swapped_levels(): + # GH 9952 + df = pd.DataFrame( + {"a": np.random.default_rng(2).standard_normal(6)}, + index=pd.MultiIndex.from_product( + [["a", "b"], [0, 1, 2]], names=["levA", "levB"] + ), + ) + df2 = df.copy() + df2.index = df2.index.swaplevel(0, 1) + result = df - df2 + expected = pd.DataFrame([0.0] * 6, columns=["a"], index=df.index) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("power", [1, 2, 5]) +@pytest.mark.parametrize("string_size", [0, 1, 2, 5]) +def test_empty_str_comparison(power, string_size): + # GH 37348 + a = np.array(range(10**power)) + right = pd.DataFrame(a, dtype=np.int64) + left = " " * string_size + + result = right == left + expected = pd.DataFrame(np.zeros(right.shape, dtype=bool)) + tm.assert_frame_equal(result, expected) + + +def test_series_add_sub_with_UInt64(): + # GH 22023 + series1 = Series([1, 2, 3]) + series2 = Series([2, 1, 3], dtype="UInt64") + + result = series1 + series2 + expected = Series([3, 3, 6], dtype="Float64") + tm.assert_series_equal(result, expected) + + result = series1 - series2 + expected = Series([-1, 1, 0], dtype="Float64") + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/arithmetic/test_object.py b/pandas/tests/arithmetic/test_object.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0f78d3aa01af5353bc9385848fb49a7784020d --- /dev/null +++ b/pandas/tests/arithmetic/test_object.py @@ -0,0 +1,410 @@ +# Arithmetic tests for DataFrame/Series/Index/Array classes that should +# behave identically. +# Specifically for object dtype +import datetime +from decimal import Decimal +import operator + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Series, + Timestamp, + option_context, +) +import pandas._testing as tm +from pandas.core import ops + +# ------------------------------------------------------------------ +# Comparisons + + +class TestObjectComparisons: + def test_comparison_object_numeric_nas(self, comparison_op): + ser = Series(np.random.default_rng(2).standard_normal(10), dtype=object) + shifted = ser.shift(2) + + func = comparison_op + + result = func(ser, shifted) + expected = func(ser.astype(float), shifted.astype(float)) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] + ) + def test_object_comparisons(self, infer_string): + with option_context("future.infer_string", infer_string): + ser = Series(["a", "b", np.nan, "c", "a"]) + + result = ser == "a" + expected = Series([True, False, False, False, True]) + tm.assert_series_equal(result, expected) + + result = ser < "a" + expected = Series([False, False, False, False, False]) + tm.assert_series_equal(result, expected) + + result = ser != "a" + expected = -(ser == "a") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("dtype", [None, object]) + def test_more_na_comparisons(self, dtype): + left = Series(["a", np.nan, "c"], dtype=dtype) + right = Series(["a", np.nan, "d"], dtype=dtype) + + result = left == right + expected = Series([True, False, False]) + tm.assert_series_equal(result, expected) + + result = left != right + expected = Series([False, True, True]) + tm.assert_series_equal(result, expected) + + result = left == np.nan + expected = Series([False, False, False]) + tm.assert_series_equal(result, expected) + + result = left != np.nan + expected = Series([True, True, True]) + tm.assert_series_equal(result, expected) + + +# ------------------------------------------------------------------ +# Arithmetic + + +class TestArithmetic: + def test_add_period_to_array_of_offset(self): + # GH#50162 + per = pd.Period("2012-1-1", freq="D") + pi = pd.period_range("2012-1-1", periods=10, freq="D") + idx = per - pi + + expected = pd.Index([x + per for x in idx], dtype=object) + result = idx + per + tm.assert_index_equal(result, expected) + + result = per + idx + tm.assert_index_equal(result, expected) + + # TODO: parametrize + def test_pow_ops_object(self): + # GH#22922 + # pow is weird with masking & 1, so testing here + a = Series([1, np.nan, 1, np.nan], dtype=object) + b = Series([1, np.nan, np.nan, 1], dtype=object) + result = a**b + expected = Series(a.values**b.values, dtype=object) + tm.assert_series_equal(result, expected) + + result = b**a + expected = Series(b.values**a.values, dtype=object) + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("op", [operator.add, ops.radd]) + @pytest.mark.parametrize("other", ["category", "Int64"]) + def test_add_extension_scalar(self, other, box_with_array, op): + # GH#22378 + # Check that scalars satisfying is_extension_array_dtype(obj) + # do not incorrectly try to dispatch to an ExtensionArray operation + + arr = Series(["a", "b", "c"]) + expected = Series([op(x, other) for x in arr]) + + arr = tm.box_expected(arr, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = op(arr, other) + tm.assert_equal(result, expected) + + def test_objarr_add_str(self, box_with_array): + ser = Series(["x", np.nan, "x"]) + expected = Series(["xa", np.nan, "xa"]) + + ser = tm.box_expected(ser, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = ser + "a" + tm.assert_equal(result, expected) + + def test_objarr_radd_str(self, box_with_array): + ser = Series(["x", np.nan, "x"]) + expected = Series(["ax", np.nan, "ax"]) + + ser = tm.box_expected(ser, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = "a" + ser + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "data", + [ + [1, 2, 3], + [1.1, 2.2, 3.3], + [Timestamp("2011-01-01"), Timestamp("2011-01-02"), pd.NaT], + ["x", "y", 1], + ], + ) + @pytest.mark.parametrize("dtype", [None, object]) + def test_objarr_radd_str_invalid(self, dtype, data, box_with_array): + ser = Series(data, dtype=dtype) + + ser = tm.box_expected(ser, box_with_array) + msg = "|".join( + [ + "can only concatenate str", + "did not contain a loop with signature matching types", + "unsupported operand type", + "must be str", + ] + ) + with pytest.raises(TypeError, match=msg): + "foo_" + ser + + @pytest.mark.parametrize("op", [operator.add, ops.radd, operator.sub, ops.rsub]) + def test_objarr_add_invalid(self, op, box_with_array): + # invalid ops + box = box_with_array + + obj_ser = Series(list("abc"), dtype=object, name="objects") + + obj_ser = tm.box_expected(obj_ser, box) + msg = "|".join( + [ + "can only concatenate str", + "unsupported operand type", + "must be str", + "has no kernel", + ] + ) + with pytest.raises(Exception, match=msg): + op(obj_ser, 1) + with pytest.raises(Exception, match=msg): + op(obj_ser, np.array(1, dtype=np.int64)) + + # TODO: Moved from tests.series.test_operators; needs cleanup + def test_operators_na_handling(self): + ser = Series(["foo", "bar", "baz", np.nan]) + result = "prefix_" + ser + expected = Series(["prefix_foo", "prefix_bar", "prefix_baz", np.nan]) + tm.assert_series_equal(result, expected) + + result = ser + "_suffix" + expected = Series(["foo_suffix", "bar_suffix", "baz_suffix", np.nan]) + tm.assert_series_equal(result, expected) + + # TODO: parametrize over box + @pytest.mark.parametrize("dtype", [None, object]) + def test_series_with_dtype_radd_timedelta(self, dtype): + # note this test is _not_ aimed at timedelta64-dtyped Series + # as of 2.0 we retain object dtype when ser.dtype == object + ser = Series( + [pd.Timedelta("1 days"), pd.Timedelta("2 days"), pd.Timedelta("3 days")], + dtype=dtype, + ) + expected = Series( + [pd.Timedelta("4 days"), pd.Timedelta("5 days"), pd.Timedelta("6 days")], + dtype=dtype, + ) + + result = pd.Timedelta("3 days") + ser + tm.assert_series_equal(result, expected) + + result = ser + pd.Timedelta("3 days") + tm.assert_series_equal(result, expected) + + # TODO: cleanup & parametrize over box + def test_mixed_timezone_series_ops_object(self): + # GH#13043 + ser = Series( + [ + Timestamp("2015-01-01", tz="US/Eastern"), + Timestamp("2015-01-01", tz="Asia/Tokyo"), + ], + name="xxx", + ) + assert ser.dtype == object + + exp = Series( + [ + Timestamp("2015-01-02", tz="US/Eastern"), + Timestamp("2015-01-02", tz="Asia/Tokyo"), + ], + name="xxx", + ) + tm.assert_series_equal(ser + pd.Timedelta("1 days"), exp) + tm.assert_series_equal(pd.Timedelta("1 days") + ser, exp) + + # object series & object series + ser2 = Series( + [ + Timestamp("2015-01-03", tz="US/Eastern"), + Timestamp("2015-01-05", tz="Asia/Tokyo"), + ], + name="xxx", + ) + assert ser2.dtype == object + exp = Series( + [pd.Timedelta("2 days"), pd.Timedelta("4 days")], name="xxx", dtype=object + ) + tm.assert_series_equal(ser2 - ser, exp) + tm.assert_series_equal(ser - ser2, -exp) + + ser = Series( + [pd.Timedelta("01:00:00"), pd.Timedelta("02:00:00")], + name="xxx", + dtype=object, + ) + assert ser.dtype == object + + exp = Series( + [pd.Timedelta("01:30:00"), pd.Timedelta("02:30:00")], + name="xxx", + dtype=object, + ) + tm.assert_series_equal(ser + pd.Timedelta("00:30:00"), exp) + tm.assert_series_equal(pd.Timedelta("00:30:00") + ser, exp) + + # TODO: cleanup & parametrize over box + def test_iadd_preserves_name(self): + # GH#17067, GH#19723 __iadd__ and __isub__ should preserve index name + ser = Series([1, 2, 3]) + ser.index.name = "foo" + + ser.index += 1 + assert ser.index.name == "foo" + + ser.index -= 1 + assert ser.index.name == "foo" + + def test_add_string(self): + # from bug report + index = pd.Index(["a", "b", "c"]) + index2 = index + "foo" + + assert "a" not in index2 + assert "afoo" in index2 + + def test_iadd_string(self): + index = pd.Index(["a", "b", "c"]) + # doesn't fail test unless there is a check before `+=` + assert "a" in index + + index += "_x" + assert "a_x" in index + + def test_add(self): + index = pd.Index([str(i) for i in range(10)]) + expected = pd.Index(index.values * 2) + tm.assert_index_equal(index + index, expected) + tm.assert_index_equal(index + index.tolist(), expected) + tm.assert_index_equal(index.tolist() + index, expected) + + # test add and radd + index = pd.Index(list("abc")) + expected = pd.Index(["a1", "b1", "c1"]) + tm.assert_index_equal(index + "1", expected) + expected = pd.Index(["1a", "1b", "1c"]) + tm.assert_index_equal("1" + index, expected) + + def test_sub_fail(self): + index = pd.Index([str(i) for i in range(10)]) + + msg = "unsupported operand type|Cannot broadcast|sub' not supported" + with pytest.raises(TypeError, match=msg): + index - "a" + with pytest.raises(TypeError, match=msg): + index - index + with pytest.raises(TypeError, match=msg): + index - index.tolist() + with pytest.raises(TypeError, match=msg): + index.tolist() - index + + def test_sub_object(self): + # GH#19369 + index = pd.Index([Decimal(1), Decimal(2)]) + expected = pd.Index([Decimal(0), Decimal(1)]) + + result = index - Decimal(1) + tm.assert_index_equal(result, expected) + + result = index - pd.Index([Decimal(1), Decimal(1)]) + tm.assert_index_equal(result, expected) + + msg = "unsupported operand type" + with pytest.raises(TypeError, match=msg): + index - "foo" + + with pytest.raises(TypeError, match=msg): + index - np.array([2, "foo"], dtype=object) + + def test_rsub_object(self, fixed_now_ts): + # GH#19369 + index = pd.Index([Decimal(1), Decimal(2)]) + expected = pd.Index([Decimal(1), Decimal(0)]) + + result = Decimal(2) - index + tm.assert_index_equal(result, expected) + + result = np.array([Decimal(2), Decimal(2)]) - index + tm.assert_index_equal(result, expected) + + msg = "unsupported operand type" + with pytest.raises(TypeError, match=msg): + "foo" - index + + with pytest.raises(TypeError, match=msg): + np.array([True, fixed_now_ts]) - index + + +class MyIndex(pd.Index): + # Simple index subclass that tracks ops calls. + + _calls: int + + @classmethod + def _simple_new(cls, values, name=None, dtype=None): + result = object.__new__(cls) + result._data = values + result._name = name + result._calls = 0 + result._reset_identity() + + return result + + def __add__(self, other): + self._calls += 1 + return self._simple_new(self._data) + + def __radd__(self, other): + return self.__add__(other) + + +@pytest.mark.parametrize( + "other", + [ + [datetime.timedelta(1), datetime.timedelta(2)], + [datetime.datetime(2000, 1, 1), datetime.datetime(2000, 1, 2)], + [pd.Period("2000"), pd.Period("2001")], + ["a", "b"], + ], + ids=["timedelta", "datetime", "period", "object"], +) +def test_index_ops_defer_to_unknown_subclasses(other): + # https://github.com/pandas-dev/pandas/issues/31109 + values = np.array( + [datetime.date(2000, 1, 1), datetime.date(2000, 1, 2)], dtype=object + ) + a = MyIndex._simple_new(values) + other = pd.Index(other) + result = other + a + assert isinstance(result, MyIndex) + assert a._calls == 1 diff --git a/pandas/tests/arithmetic/test_period.py b/pandas/tests/arithmetic/test_period.py new file mode 100644 index 0000000000000000000000000000000000000000..24733f3b3e5634e96742e8035d00bcb3edf7ccd7 --- /dev/null +++ b/pandas/tests/arithmetic/test_period.py @@ -0,0 +1,1679 @@ +# Arithmetic tests for DataFrame/Series/Index/Array classes that should +# behave identically. +# Specifically for Period dtype +import operator + +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + IncompatibleFrequency, + Period, + Timestamp, + to_offset, +) + +import pandas as pd +from pandas import ( + PeriodIndex, + Series, + Timedelta, + TimedeltaIndex, + period_range, +) +import pandas._testing as tm +from pandas.core import ops +from pandas.core.arrays import TimedeltaArray +from pandas.tests.arithmetic.common import ( + assert_invalid_addsub_type, + assert_invalid_comparison, + get_upcast_box, +) + +_common_mismatch = [ + pd.offsets.YearBegin(2), + pd.offsets.MonthBegin(1), + pd.offsets.Minute(), +] + + +@pytest.fixture( + params=[ + Timedelta(minutes=30).to_pytimedelta(), + np.timedelta64(30, "s"), + Timedelta(seconds=30), + *_common_mismatch, + ] +) +def not_hourly(request): + """ + Several timedelta-like and DateOffset instances that are _not_ + compatible with Hourly frequencies. + """ + return request.param + + +@pytest.fixture( + params=[ + np.timedelta64(365, "D"), + Timedelta(days=365).to_pytimedelta(), + Timedelta(days=365), + *_common_mismatch, + ] +) +def mismatched_freq(request): + """ + Several timedelta-like and DateOffset instances that are _not_ + compatible with Monthly or Annual frequencies. + """ + return request.param + + +# ------------------------------------------------------------------ +# Comparisons + + +class TestPeriodArrayLikeComparisons: + # Comparison tests for PeriodDtype vectors fully parametrized over + # DataFrame/Series/PeriodIndex/PeriodArray. Ideally all comparison + # tests will eventually end up here. + + @pytest.mark.parametrize("other", ["2017", Period("2017", freq="D")]) + def test_eq_scalar(self, other, box_with_array): + idx = PeriodIndex(["2017", "2017", "2018"], freq="D") + idx = tm.box_expected(idx, box_with_array) + xbox = get_upcast_box(idx, other, True) + + expected = np.array([True, True, False]) + expected = tm.box_expected(expected, xbox) + + result = idx == other + + tm.assert_equal(result, expected) + + def test_compare_zerodim(self, box_with_array): + # GH#26689 make sure we unbox zero-dimensional arrays + + pi = period_range("2000", periods=4) + other = np.array(pi.to_numpy()[0]) + + pi = tm.box_expected(pi, box_with_array) + xbox = get_upcast_box(pi, other, True) + + result = pi <= other + expected = np.array([True, False, False, False]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "scalar", + [ + "foo", + Timestamp("2021-01-01"), + Timedelta(days=4), + 9, + 9.5, + 2000, # specifically don't consider 2000 to match Period("2000", "D") + False, + None, + ], + ) + def test_compare_invalid_scalar(self, box_with_array, scalar): + # GH#28980 + # comparison with scalar that cannot be interpreted as a Period + pi = period_range("2000", periods=4) + parr = tm.box_expected(pi, box_with_array) + assert_invalid_comparison(parr, scalar, box_with_array) + + @pytest.mark.parametrize( + "other", + [ + pd.date_range("2000", periods=4).array, + pd.timedelta_range("1D", periods=4).array, + np.arange(4), + np.arange(4).astype(np.float64), + list(range(4)), + # match Period semantics by not treating integers as Periods + [2000, 2001, 2002, 2003], + np.arange(2000, 2004), + np.arange(2000, 2004).astype(object), + pd.Index([2000, 2001, 2002, 2003]), + ], + ) + def test_compare_invalid_listlike(self, box_with_array, other): + pi = period_range("2000", periods=4) + parr = tm.box_expected(pi, box_with_array) + assert_invalid_comparison(parr, other, box_with_array) + + @pytest.mark.parametrize("other_box", [list, np.array, lambda x: x.astype(object)]) + def test_compare_object_dtype(self, box_with_array, other_box): + pi = period_range("2000", periods=5) + parr = tm.box_expected(pi, box_with_array) + + other = other_box(pi) + xbox = get_upcast_box(parr, other, True) + + expected = np.array([True, True, True, True, True]) + expected = tm.box_expected(expected, xbox) + + result = parr == other + tm.assert_equal(result, expected) + result = parr <= other + tm.assert_equal(result, expected) + result = parr >= other + tm.assert_equal(result, expected) + + result = parr != other + tm.assert_equal(result, ~expected) + result = parr < other + tm.assert_equal(result, ~expected) + result = parr > other + tm.assert_equal(result, ~expected) + + other = other_box(pi[::-1]) + + expected = np.array([False, False, True, False, False]) + expected = tm.box_expected(expected, xbox) + result = parr == other + tm.assert_equal(result, expected) + + expected = np.array([True, True, True, False, False]) + expected = tm.box_expected(expected, xbox) + result = parr <= other + tm.assert_equal(result, expected) + + expected = np.array([False, False, True, True, True]) + expected = tm.box_expected(expected, xbox) + result = parr >= other + tm.assert_equal(result, expected) + + expected = np.array([True, True, False, True, True]) + expected = tm.box_expected(expected, xbox) + result = parr != other + tm.assert_equal(result, expected) + + expected = np.array([True, True, False, False, False]) + expected = tm.box_expected(expected, xbox) + result = parr < other + tm.assert_equal(result, expected) + + expected = np.array([False, False, False, True, True]) + expected = tm.box_expected(expected, xbox) + result = parr > other + tm.assert_equal(result, expected) + + +class TestPeriodIndexComparisons: + # TODO: parameterize over boxes + + def test_pi_cmp_period(self): + idx = period_range("2007-01", periods=20, freq="M") + per = idx[10] + + result = idx < per + exp = idx.values < idx.values[10] + tm.assert_numpy_array_equal(result, exp) + + # Tests Period.__richcmp__ against ndarray[object, ndim=2] + result = idx.values.reshape(10, 2) < per + tm.assert_numpy_array_equal(result, exp.reshape(10, 2)) + + # Tests Period.__richcmp__ against ndarray[object, ndim=0] + result = idx < np.array(per) + tm.assert_numpy_array_equal(result, exp) + + # TODO: moved from test_datetime64; de-duplicate with version below + def test_parr_cmp_period_scalar2(self, box_with_array): + pi = period_range("2000-01-01", periods=10, freq="D") + + val = pi[3] + expected = [x > val for x in pi] + + ser = tm.box_expected(pi, box_with_array) + xbox = get_upcast_box(ser, val, True) + + expected = tm.box_expected(expected, xbox) + result = ser > val + tm.assert_equal(result, expected) + + val = pi[5] + result = ser > val + expected = [x > val for x in pi] + expected = tm.box_expected(expected, xbox) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("freq", ["M", "2M", "3M"]) + def test_parr_cmp_period_scalar(self, freq, box_with_array): + # GH#13200 + base = PeriodIndex(["2011-01", "2011-02", "2011-03", "2011-04"], freq=freq) + base = tm.box_expected(base, box_with_array) + per = Period("2011-02", freq=freq) + xbox = get_upcast_box(base, per, True) + + exp = np.array([False, True, False, False]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base == per, exp) + tm.assert_equal(per == base, exp) + + exp = np.array([True, False, True, True]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base != per, exp) + tm.assert_equal(per != base, exp) + + exp = np.array([False, False, True, True]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base > per, exp) + tm.assert_equal(per < base, exp) + + exp = np.array([True, False, False, False]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base < per, exp) + tm.assert_equal(per > base, exp) + + exp = np.array([False, True, True, True]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base >= per, exp) + tm.assert_equal(per <= base, exp) + + exp = np.array([True, True, False, False]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base <= per, exp) + tm.assert_equal(per >= base, exp) + + @pytest.mark.parametrize("freq", ["M", "2M", "3M"]) + def test_parr_cmp_pi(self, freq, box_with_array): + # GH#13200 + base = PeriodIndex(["2011-01", "2011-02", "2011-03", "2011-04"], freq=freq) + base = tm.box_expected(base, box_with_array) + + # TODO: could also box idx? + idx = PeriodIndex(["2011-02", "2011-01", "2011-03", "2011-05"], freq=freq) + + xbox = get_upcast_box(base, idx, True) + + exp = np.array([False, False, True, False]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base == idx, exp) + + exp = np.array([True, True, False, True]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base != idx, exp) + + exp = np.array([False, True, False, False]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base > idx, exp) + + exp = np.array([True, False, False, True]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base < idx, exp) + + exp = np.array([False, True, True, False]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base >= idx, exp) + + exp = np.array([True, False, True, True]) + exp = tm.box_expected(exp, xbox) + tm.assert_equal(base <= idx, exp) + + @pytest.mark.parametrize("freq", ["M", "2M", "3M"]) + def test_parr_cmp_pi_mismatched_freq(self, freq, box_with_array): + # GH#13200 + # different base freq + base = PeriodIndex(["2011-01", "2011-02", "2011-03", "2011-04"], freq=freq) + base = tm.box_expected(base, box_with_array) + + msg = rf"Invalid comparison between dtype=period\[{freq}\] and Period" + with pytest.raises(TypeError, match=msg): + base <= Period("2011", freq="Y") + + with pytest.raises(TypeError, match=msg): + Period("2011", freq="Y") >= base + + # TODO: Could parametrize over boxes for idx? + idx = PeriodIndex(["2011", "2012", "2013", "2014"], freq="Y") + rev_msg = r"Invalid comparison between dtype=period\[Y-DEC\] and PeriodArray" + idx_msg = rev_msg if box_with_array in [tm.to_array, pd.array] else msg + with pytest.raises(TypeError, match=idx_msg): + base <= idx + + # Different frequency + msg = rf"Invalid comparison between dtype=period\[{freq}\] and Period" + with pytest.raises(TypeError, match=msg): + base <= Period("2011", freq="4M") + + with pytest.raises(TypeError, match=msg): + Period("2011", freq="4M") >= base + + idx = PeriodIndex(["2011", "2012", "2013", "2014"], freq="4M") + rev_msg = r"Invalid comparison between dtype=period\[4M\] and PeriodArray" + idx_msg = rev_msg if box_with_array in [tm.to_array, pd.array] else msg + with pytest.raises(TypeError, match=idx_msg): + base <= idx + + @pytest.mark.parametrize("freq", ["M", "2M", "3M"]) + def test_pi_cmp_nat(self, freq): + idx1 = PeriodIndex(["2011-01", "2011-02", "NaT", "2011-05"], freq=freq) + per = idx1[1] + + result = idx1 > per + exp = np.array([False, False, False, True]) + tm.assert_numpy_array_equal(result, exp) + result = per < idx1 + tm.assert_numpy_array_equal(result, exp) + + result = idx1 == pd.NaT + exp = np.array([False, False, False, False]) + tm.assert_numpy_array_equal(result, exp) + result = pd.NaT == idx1 + tm.assert_numpy_array_equal(result, exp) + + result = idx1 != pd.NaT + exp = np.array([True, True, True, True]) + tm.assert_numpy_array_equal(result, exp) + result = pd.NaT != idx1 + tm.assert_numpy_array_equal(result, exp) + + idx2 = PeriodIndex(["2011-02", "2011-01", "2011-04", "NaT"], freq=freq) + result = idx1 < idx2 + exp = np.array([True, False, False, False]) + tm.assert_numpy_array_equal(result, exp) + + result = idx1 == idx2 + exp = np.array([False, False, False, False]) + tm.assert_numpy_array_equal(result, exp) + + result = idx1 != idx2 + exp = np.array([True, True, True, True]) + tm.assert_numpy_array_equal(result, exp) + + result = idx1 == idx1 + exp = np.array([True, True, False, True]) + tm.assert_numpy_array_equal(result, exp) + + result = idx1 != idx1 + exp = np.array([False, False, True, False]) + tm.assert_numpy_array_equal(result, exp) + + @pytest.mark.parametrize("freq", ["M", "2M", "3M"]) + def test_pi_cmp_nat_mismatched_freq_raises(self, freq): + idx1 = PeriodIndex(["2011-01", "2011-02", "NaT", "2011-05"], freq=freq) + + diff = PeriodIndex(["2011-02", "2011-01", "2011-04", "NaT"], freq="4M") + msg = rf"Invalid comparison between dtype=period\[{freq}\] and PeriodArray" + with pytest.raises(TypeError, match=msg): + idx1 > diff + + result = idx1 == diff + expected = np.array([False, False, False, False], dtype=bool) + tm.assert_numpy_array_equal(result, expected) + + # TODO: De-duplicate with test_pi_cmp_nat + @pytest.mark.parametrize("dtype", [object, None]) + def test_comp_nat(self, dtype): + left = PeriodIndex([Period("2011-01-01"), pd.NaT, Period("2011-01-03")]) + right = PeriodIndex([pd.NaT, pd.NaT, Period("2011-01-03")]) + + if dtype is not None: + left = left.astype(dtype) + right = right.astype(dtype) + + result = left == right + expected = np.array([False, False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = left != right + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + expected = np.array([False, False, False]) + tm.assert_numpy_array_equal(left == pd.NaT, expected) + tm.assert_numpy_array_equal(pd.NaT == right, expected) + + expected = np.array([True, True, True]) + tm.assert_numpy_array_equal(left != pd.NaT, expected) + tm.assert_numpy_array_equal(pd.NaT != left, expected) + + expected = np.array([False, False, False]) + tm.assert_numpy_array_equal(left < pd.NaT, expected) + tm.assert_numpy_array_equal(pd.NaT > left, expected) + + +class TestPeriodSeriesComparisons: + def test_cmp_series_period_series_mixed_freq(self): + # GH#13200 + base = Series( + [ + Period("2011", freq="Y"), + Period("2011-02", freq="M"), + Period("2013", freq="Y"), + Period("2011-04", freq="M"), + ] + ) + + ser = Series( + [ + Period("2012", freq="Y"), + Period("2011-01", freq="M"), + Period("2013", freq="Y"), + Period("2011-05", freq="M"), + ] + ) + + exp = Series([False, False, True, False]) + tm.assert_series_equal(base == ser, exp) + + exp = Series([True, True, False, True]) + tm.assert_series_equal(base != ser, exp) + + exp = Series([False, True, False, False]) + tm.assert_series_equal(base > ser, exp) + + exp = Series([True, False, False, True]) + tm.assert_series_equal(base < ser, exp) + + exp = Series([False, True, True, False]) + tm.assert_series_equal(base >= ser, exp) + + exp = Series([True, False, True, True]) + tm.assert_series_equal(base <= ser, exp) + + +class TestPeriodIndexSeriesComparisonConsistency: + """Test PeriodIndex and Period Series Ops consistency""" + + # TODO: needs parametrization+de-duplication + + def _check(self, values, func, expected): + # Test PeriodIndex and Period Series Ops consistency + + idx = PeriodIndex(values) + result = func(idx) + + # check that we don't pass an unwanted type to tm.assert_equal + assert isinstance(expected, (pd.Index, np.ndarray)) + tm.assert_equal(result, expected) + + s = Series(values) + result = func(s) + + exp = Series(expected, name=values.name) + tm.assert_series_equal(result, exp) + + def test_pi_comp_period(self): + idx = PeriodIndex( + ["2011-01", "2011-02", "2011-03", "2011-04"], freq="M", name="idx" + ) + per = idx[2] + + f = lambda x: x == per + exp = np.array([False, False, True, False], dtype=np.bool_) + self._check(idx, f, exp) + f = lambda x: per == x + self._check(idx, f, exp) + + f = lambda x: x != per + exp = np.array([True, True, False, True], dtype=np.bool_) + self._check(idx, f, exp) + f = lambda x: per != x + self._check(idx, f, exp) + + f = lambda x: per >= x + exp = np.array([True, True, True, False], dtype=np.bool_) + self._check(idx, f, exp) + + f = lambda x: x > per + exp = np.array([False, False, False, True], dtype=np.bool_) + self._check(idx, f, exp) + + f = lambda x: per >= x + exp = np.array([True, True, True, False], dtype=np.bool_) + self._check(idx, f, exp) + + def test_pi_comp_period_nat(self): + idx = PeriodIndex( + ["2011-01", "NaT", "2011-03", "2011-04"], freq="M", name="idx" + ) + per = idx[2] + + f = lambda x: x == per + exp = np.array([False, False, True, False], dtype=np.bool_) + self._check(idx, f, exp) + f = lambda x: per == x + self._check(idx, f, exp) + + f = lambda x: x == pd.NaT + exp = np.array([False, False, False, False], dtype=np.bool_) + self._check(idx, f, exp) + f = lambda x: pd.NaT == x + self._check(idx, f, exp) + + f = lambda x: x != per + exp = np.array([True, True, False, True], dtype=np.bool_) + self._check(idx, f, exp) + f = lambda x: per != x + self._check(idx, f, exp) + + f = lambda x: x != pd.NaT + exp = np.array([True, True, True, True], dtype=np.bool_) + self._check(idx, f, exp) + f = lambda x: pd.NaT != x + self._check(idx, f, exp) + + f = lambda x: per >= x + exp = np.array([True, False, True, False], dtype=np.bool_) + self._check(idx, f, exp) + + f = lambda x: x < per + exp = np.array([True, False, False, False], dtype=np.bool_) + self._check(idx, f, exp) + + f = lambda x: x > pd.NaT + exp = np.array([False, False, False, False], dtype=np.bool_) + self._check(idx, f, exp) + + f = lambda x: pd.NaT >= x + exp = np.array([False, False, False, False], dtype=np.bool_) + self._check(idx, f, exp) + + +# ------------------------------------------------------------------ +# Arithmetic + + +class TestPeriodFrameArithmetic: + def test_ops_frame_period(self): + # GH#13043 + df = pd.DataFrame( + { + "A": [Period("2015-01", freq="M"), Period("2015-02", freq="M")], + "B": [Period("2014-01", freq="M"), Period("2014-02", freq="M")], + } + ) + assert df["A"].dtype == "Period[M]" + assert df["B"].dtype == "Period[M]" + + p = Period("2015-03", freq="M") + off = p.freq + # dtype will be object because of original dtype + exp = pd.DataFrame( + { + "A": np.array([2 * off, 1 * off], dtype=object), + "B": np.array([14 * off, 13 * off], dtype=object), + } + ) + tm.assert_frame_equal(p - df, exp) + tm.assert_frame_equal(df - p, -1 * exp) + + df2 = pd.DataFrame( + { + "A": [Period("2015-05", freq="M"), Period("2015-06", freq="M")], + "B": [Period("2015-05", freq="M"), Period("2015-06", freq="M")], + } + ) + assert df2["A"].dtype == "Period[M]" + assert df2["B"].dtype == "Period[M]" + + exp = pd.DataFrame( + { + "A": np.array([4 * off, 4 * off], dtype=object), + "B": np.array([16 * off, 16 * off], dtype=object), + } + ) + tm.assert_frame_equal(df2 - df, exp) + tm.assert_frame_equal(df - df2, -1 * exp) + + +class TestPeriodIndexArithmetic: + # --------------------------------------------------------------- + # __add__/__sub__ with PeriodIndex + # PeriodIndex + other is defined for integers and timedelta-like others + # PeriodIndex - other is defined for integers, timedelta-like others, + # and PeriodIndex (with matching freq) + + def test_parr_add_iadd_parr_raises(self, box_with_array): + rng = period_range("1/1/2000", freq="D", periods=5) + other = period_range("1/6/2000", freq="D", periods=5) + # TODO: parametrize over boxes for other? + + rng = tm.box_expected(rng, box_with_array) + # An earlier implementation of PeriodIndex addition performed + # a set operation (union). This has since been changed to + # raise a TypeError. See GH#14164 and GH#13077 for historical + # reference. + msg = r"unsupported operand type\(s\) for \+: .* and .*" + with pytest.raises(TypeError, match=msg): + rng + other + + with pytest.raises(TypeError, match=msg): + rng += other + + def test_pi_sub_isub_pi(self): + # GH#20049 + # For historical reference see GH#14164, GH#13077. + # PeriodIndex subtraction originally performed set difference, + # then changed to raise TypeError before being implemented in GH#20049 + rng = period_range("1/1/2000", freq="D", periods=5) + other = period_range("1/6/2000", freq="D", periods=5) + + off = rng.freq + expected = pd.Index([-5 * off] * 5) + result = rng - other + tm.assert_index_equal(result, expected) + + rng -= other + tm.assert_index_equal(rng, expected) + + def test_pi_sub_pi_with_nat(self): + rng = period_range("1/1/2000", freq="D", periods=5) + other = rng[1:].insert(0, pd.NaT) + assert other[1:].equals(rng[1:]) + + result = rng - other + off = rng.freq + expected = pd.Index([pd.NaT, 0 * off, 0 * off, 0 * off, 0 * off]) + tm.assert_index_equal(result, expected) + + def test_parr_sub_pi_mismatched_freq(self, box_with_array, box_with_array2): + rng = period_range("1/1/2000", freq="D", periods=5) + other = period_range("1/6/2000", freq="h", periods=5) + + rng = tm.box_expected(rng, box_with_array) + other = tm.box_expected(other, box_with_array2) + msg = r"Input has different freq=[hD] from PeriodArray\(freq=[Dh]\)" + with pytest.raises(IncompatibleFrequency, match=msg): + rng - other + + @pytest.mark.parametrize("n", [1, 2, 3, 4]) + def test_sub_n_gt_1_ticks(self, tick_classes, n): + # GH 23878 + p1_d = "19910905" + p2_d = "19920406" + p1 = PeriodIndex([p1_d], freq=tick_classes(n)) + p2 = PeriodIndex([p2_d], freq=tick_classes(n)) + + expected = PeriodIndex([p2_d], freq=p2.freq.base) - PeriodIndex( + [p1_d], freq=p1.freq.base + ) + + tm.assert_index_equal((p2 - p1), expected) + + @pytest.mark.parametrize("n", [1, 2, 3, 4]) + @pytest.mark.parametrize( + "offset, kwd_name", + [ + (pd.offsets.YearEnd, "month"), + (pd.offsets.QuarterEnd, "startingMonth"), + (pd.offsets.MonthEnd, None), + (pd.offsets.Week, "weekday"), + ], + ) + def test_sub_n_gt_1_offsets(self, offset, kwd_name, n): + # GH 23878 + kwds = {kwd_name: 3} if kwd_name is not None else {} + p1_d = "19910905" + p2_d = "19920406" + freq = offset(n, normalize=False, **kwds) + p1 = PeriodIndex([p1_d], freq=freq) + p2 = PeriodIndex([p2_d], freq=freq) + + result = p2 - p1 + expected = PeriodIndex([p2_d], freq=freq.base) - PeriodIndex( + [p1_d], freq=freq.base + ) + + tm.assert_index_equal(result, expected) + + # ------------------------------------------------------------- + # Invalid Operations + + @pytest.mark.parametrize( + "other", + [ + # datetime scalars + Timestamp("2016-01-01"), + Timestamp("2016-01-01").to_pydatetime(), + Timestamp("2016-01-01").to_datetime64(), + # datetime-like arrays + pd.date_range("2016-01-01", periods=3, freq="h"), + pd.date_range("2016-01-01", periods=3, tz="Europe/Brussels"), + pd.date_range("2016-01-01", periods=3, freq="s")._data, + pd.date_range("2016-01-01", periods=3, tz="Asia/Tokyo")._data, + # Miscellaneous invalid types + 3.14, + np.array([2.0, 3.0, 4.0]), + ], + ) + def test_parr_add_sub_invalid(self, other, box_with_array): + # GH#23215 + rng = period_range("1/1/2000", freq="D", periods=3) + rng = tm.box_expected(rng, box_with_array) + + msg = "|".join( + [ + r"(:?cannot add PeriodArray and .*)", + r"(:?cannot subtract .* from (:?a\s)?.*)", + r"(:?unsupported operand type\(s\) for \+: .* and .*)", + r"unsupported operand type\(s\) for [+-]: .* and .*", + ] + ) + assert_invalid_addsub_type(rng, other, msg) + with pytest.raises(TypeError, match=msg): + rng + other + with pytest.raises(TypeError, match=msg): + other + rng + with pytest.raises(TypeError, match=msg): + rng - other + with pytest.raises(TypeError, match=msg): + other - rng + + # ----------------------------------------------------------------- + # __add__/__sub__ with ndarray[datetime64] and ndarray[timedelta64] + + def test_pi_add_sub_td64_array_non_tick_raises(self): + rng = period_range("1/1/2000", freq="Q", periods=3) + tdi = TimedeltaIndex(["-1 Day", "-1 Day", "-1 Day"]) + tdarr = tdi.values + + msg = r"Cannot add or subtract timedelta64\[ns\] dtype from period\[Q-DEC\]" + with pytest.raises(TypeError, match=msg): + rng + tdarr + with pytest.raises(TypeError, match=msg): + tdarr + rng + + with pytest.raises(TypeError, match=msg): + rng - tdarr + msg = r"cannot subtract PeriodArray from TimedeltaArray" + with pytest.raises(TypeError, match=msg): + tdarr - rng + + def test_pi_add_sub_td64_array_tick(self): + # PeriodIndex + Timedelta-like is allowed only with + # tick-like frequencies + rng = period_range("1/1/2000", freq="90D", periods=3) + tdi = TimedeltaIndex(["-1 Day", "-1 Day", "-1 Day"]) + tdarr = tdi.values + + expected = period_range("12/31/1999", freq="90D", periods=3) + result = rng + tdi + tm.assert_index_equal(result, expected) + result = rng + tdarr + tm.assert_index_equal(result, expected) + result = tdi + rng + tm.assert_index_equal(result, expected) + result = tdarr + rng + tm.assert_index_equal(result, expected) + + expected = period_range("1/2/2000", freq="90D", periods=3) + + result = rng - tdi + tm.assert_index_equal(result, expected) + result = rng - tdarr + tm.assert_index_equal(result, expected) + + msg = r"cannot subtract .* from .*" + with pytest.raises(TypeError, match=msg): + tdarr - rng + + with pytest.raises(TypeError, match=msg): + tdi - rng + + @pytest.mark.parametrize("pi_freq", ["D", "W", "Q", "h"]) + @pytest.mark.parametrize("tdi_freq", [None, "h"]) + def test_parr_sub_td64array(self, box_with_array, tdi_freq, pi_freq): + box = box_with_array + xbox = box if box not in [pd.array, tm.to_array] else pd.Index + + tdi = TimedeltaIndex(["1 hours", "2 hours"], freq=tdi_freq) + dti = Timestamp("2018-03-07 17:16:40") + tdi + pi = dti.to_period(pi_freq) + + # TODO: parametrize over box for pi? + td64obj = tm.box_expected(tdi, box) + + if pi_freq == "h": + result = pi - td64obj + expected = (pi.to_timestamp("s") - tdi).to_period(pi_freq) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(result, expected) + + # Subtract from scalar + result = pi[0] - td64obj + expected = (pi[0].to_timestamp("s") - tdi).to_period(pi_freq) + expected = tm.box_expected(expected, box) + tm.assert_equal(result, expected) + + elif pi_freq == "D": + # Tick, but non-compatible + msg = ( + "Cannot add/subtract timedelta-like from PeriodArray that is " + "not an integer multiple of the PeriodArray's freq." + ) + with pytest.raises(IncompatibleFrequency, match=msg): + pi - td64obj + + with pytest.raises(IncompatibleFrequency, match=msg): + pi[0] - td64obj + + else: + # With non-Tick freq, we could not add timedelta64 array regardless + # of what its resolution is + msg = "Cannot add or subtract timedelta64" + with pytest.raises(TypeError, match=msg): + pi - td64obj + with pytest.raises(TypeError, match=msg): + pi[0] - td64obj + + # ----------------------------------------------------------------- + # operations with array/Index of DateOffset objects + + @pytest.mark.parametrize("box", [np.array, pd.Index]) + def test_pi_add_offset_array(self, performance_warning, box): + # GH#18849 + pi = PeriodIndex([Period("2015Q1"), Period("2016Q2")]) + offs = box( + [ + pd.offsets.QuarterEnd(n=1, startingMonth=12), + pd.offsets.QuarterEnd(n=-2, startingMonth=12), + ] + ) + expected = PeriodIndex([Period("2015Q2"), Period("2015Q4")]).astype(object) + + with tm.assert_produces_warning(performance_warning): + res = pi + offs + tm.assert_index_equal(res, expected) + + with tm.assert_produces_warning(performance_warning): + res2 = offs + pi + tm.assert_index_equal(res2, expected) + + unanchored = np.array([pd.offsets.Hour(n=1), pd.offsets.Minute(n=-2)]) + # addition/subtraction ops with incompatible offsets should issue + # a PerformanceWarning and _then_ raise a TypeError. + msg = r"Input cannot be converted to Period\(freq=Q-DEC\)" + with pytest.raises(IncompatibleFrequency, match=msg): + with tm.assert_produces_warning(performance_warning): + pi + unanchored + with pytest.raises(IncompatibleFrequency, match=msg): + with tm.assert_produces_warning(performance_warning): + unanchored + pi + + @pytest.mark.parametrize("box", [np.array, pd.Index]) + def test_pi_sub_offset_array(self, performance_warning, box): + # GH#18824 + pi = PeriodIndex([Period("2015Q1"), Period("2016Q2")]) + other = box( + [ + pd.offsets.QuarterEnd(n=1, startingMonth=12), + pd.offsets.QuarterEnd(n=-2, startingMonth=12), + ] + ) + + expected = PeriodIndex([pi[n] - other[n] for n in range(len(pi))]) + expected = expected.astype(object) + + with tm.assert_produces_warning(performance_warning): + res = pi - other + tm.assert_index_equal(res, expected) + + anchored = box([pd.offsets.MonthEnd(), pd.offsets.Day(n=2)]) + + # addition/subtraction ops with anchored offsets should issue + # a PerformanceWarning and _then_ raise a TypeError. + msg = r"Input has different freq=-1M from Period\(freq=Q-DEC\)" + with pytest.raises(IncompatibleFrequency, match=msg): + with tm.assert_produces_warning(performance_warning): + pi - anchored + with pytest.raises(IncompatibleFrequency, match=msg): + with tm.assert_produces_warning(performance_warning): + anchored - pi + + def test_pi_add_iadd_int(self, one): + # Variants of `one` for #19012 + rng = period_range("2000-01-01 09:00", freq="h", periods=10) + result = rng + one + expected = period_range("2000-01-01 10:00", freq="h", periods=10) + tm.assert_index_equal(result, expected) + rng += one + tm.assert_index_equal(rng, expected) + + def test_pi_sub_isub_int(self, one): + """ + PeriodIndex.__sub__ and __isub__ with several representations of + the integer 1, e.g. int, np.int64, np.uint8, ... + """ + rng = period_range("2000-01-01 09:00", freq="h", periods=10) + result = rng - one + expected = period_range("2000-01-01 08:00", freq="h", periods=10) + tm.assert_index_equal(result, expected) + rng -= one + tm.assert_index_equal(rng, expected) + + @pytest.mark.parametrize("five", [5, np.array(5, dtype=np.int64)]) + def test_pi_sub_intlike(self, five): + rng = period_range("2007-01", periods=50) + + result = rng - five + exp = rng + (-five) + tm.assert_index_equal(result, exp) + + def test_pi_add_sub_int_array_freqn_gt1(self): + # GH#47209 test adding array of ints when freq.n > 1 matches + # scalar behavior + pi = period_range("2016-01-01", periods=10, freq="2D") + arr = np.arange(10) + result = pi + arr + expected = pd.Index([x + y for x, y in zip(pi, arr, strict=True)]) + tm.assert_index_equal(result, expected) + + result = pi - arr + expected = pd.Index([x - y for x, y in zip(pi, arr, strict=True)]) + tm.assert_index_equal(result, expected) + + def test_pi_sub_isub_offset(self): + # offset + # DateOffset + rng = period_range("2014", "2024", freq="Y") + result = rng - pd.offsets.YearEnd(5) + expected = period_range("2009", "2019", freq="Y") + tm.assert_index_equal(result, expected) + rng -= pd.offsets.YearEnd(5) + tm.assert_index_equal(rng, expected) + + rng = period_range("2014-01", "2016-12", freq="M") + result = rng - pd.offsets.MonthEnd(5) + expected = period_range("2013-08", "2016-07", freq="M") + tm.assert_index_equal(result, expected) + + rng -= pd.offsets.MonthEnd(5) + tm.assert_index_equal(rng, expected) + + @pytest.mark.parametrize("transpose", [True, False]) + def test_pi_add_offset_n_gt1(self, box_with_array, transpose): + # GH#23215 + # add offset to PeriodIndex with freq.n > 1 + + per = Period("2016-01", freq="2M") + pi = PeriodIndex([per]) + + expected = PeriodIndex(["2016-03"], freq="2M") + + pi = tm.box_expected(pi, box_with_array, transpose=transpose) + expected = tm.box_expected(expected, box_with_array, transpose=transpose) + + result = pi + per.freq + tm.assert_equal(result, expected) + + result = per.freq + pi + tm.assert_equal(result, expected) + + def test_pi_add_offset_n_gt1_not_divisible(self, box_with_array): + # GH#23215 + # PeriodIndex with freq.n > 1 add offset with offset.n % freq.n != 0 + pi = PeriodIndex(["2016-01"], freq="2M") + expected = PeriodIndex(["2016-04"], freq="2M") + + pi = tm.box_expected(pi, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = pi + to_offset("3ME") + tm.assert_equal(result, expected) + + result = to_offset("3ME") + pi + tm.assert_equal(result, expected) + + # --------------------------------------------------------------- + # __add__/__sub__ with integer arrays + + @pytest.mark.parametrize("int_holder", [np.array, pd.Index]) + @pytest.mark.parametrize("op", [operator.add, ops.radd]) + def test_pi_add_intarray(self, int_holder, op): + # GH#19959 + pi = PeriodIndex([Period("2015Q1"), Period("NaT")]) + other = int_holder([4, -1]) + + result = op(pi, other) + expected = PeriodIndex([Period("2016Q1"), Period("NaT")]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("int_holder", [np.array, pd.Index]) + def test_pi_sub_intarray(self, int_holder): + # GH#19959 + pi = PeriodIndex([Period("2015Q1"), Period("NaT")]) + other = int_holder([4, -1]) + + result = pi - other + expected = PeriodIndex([Period("2014Q1"), Period("NaT")]) + tm.assert_index_equal(result, expected) + + msg = r"bad operand type for unary -: 'PeriodArray'" + with pytest.raises(TypeError, match=msg): + other - pi + + # --------------------------------------------------------------- + # Timedelta-like (timedelta, timedelta64, Timedelta, Tick) + # TODO: Some of these are misnomers because of non-Tick DateOffsets + + def test_parr_add_timedeltalike_minute_gt1(self, three_days, box_with_array): + # GH#23031 adding a time-delta-like offset to a PeriodArray that has + # minute frequency with n != 1. A more general case is tested below + # in test_pi_add_timedeltalike_tick_gt1, but here we write out the + # expected result more explicitly. + other = three_days + rng = period_range("2014-05-01", periods=3, freq="2D") + rng = tm.box_expected(rng, box_with_array) + + expected = PeriodIndex(["2014-05-04", "2014-05-06", "2014-05-08"], freq="2D") + expected = tm.box_expected(expected, box_with_array) + + result = rng + other + tm.assert_equal(result, expected) + + result = other + rng + tm.assert_equal(result, expected) + + # subtraction + expected = PeriodIndex(["2014-04-28", "2014-04-30", "2014-05-02"], freq="2D") + expected = tm.box_expected(expected, box_with_array) + result = rng - other + tm.assert_equal(result, expected) + + msg = "|".join( + [ + r"bad operand type for unary -: 'PeriodArray'", + r"cannot subtract PeriodArray from timedelta64\[[hD]\]", + ] + ) + with pytest.raises(TypeError, match=msg): + other - rng + + @pytest.mark.parametrize("freqstr", ["5ns", "5us", "5ms", "5s", "5min", "5h", "5D"]) + def test_parr_add_timedeltalike_tick_gt1(self, three_days, freqstr, box_with_array): + # GH#23031 adding a time-delta-like offset to a PeriodArray that has + # tick-like frequency with n != 1 + other = three_days + rng = period_range("2014-05-01", periods=6, freq=freqstr) + first = rng[0] + rng = tm.box_expected(rng, box_with_array) + + expected = period_range(first + other, periods=6, freq=freqstr) + expected = tm.box_expected(expected, box_with_array) + + result = rng + other + tm.assert_equal(result, expected) + + result = other + rng + tm.assert_equal(result, expected) + + # subtraction + expected = period_range(first - other, periods=6, freq=freqstr) + expected = tm.box_expected(expected, box_with_array) + result = rng - other + tm.assert_equal(result, expected) + msg = "|".join( + [ + r"bad operand type for unary -: 'PeriodArray'", + r"cannot subtract PeriodArray from timedelta64\[[hD]\]", + ] + ) + with pytest.raises(TypeError, match=msg): + other - rng + + def test_pi_add_iadd_timedeltalike_daily(self, three_days): + # Tick + other = three_days + rng = period_range("2014-05-01", "2014-05-15", freq="D") + expected = period_range("2014-05-04", "2014-05-18", freq="D") + + result = rng + other + tm.assert_index_equal(result, expected) + + rng += other + tm.assert_index_equal(rng, expected) + + def test_pi_sub_isub_timedeltalike_daily(self, three_days): + # Tick-like 3 Days + other = three_days + rng = period_range("2014-05-01", "2014-05-15", freq="D") + expected = period_range("2014-04-28", "2014-05-12", freq="D") + + result = rng - other + tm.assert_index_equal(result, expected) + + rng -= other + tm.assert_index_equal(rng, expected) + + def test_parr_add_sub_timedeltalike_freq_mismatch_daily( + self, not_daily, box_with_array + ): + other = not_daily + rng = period_range("2014-05-01", "2014-05-15", freq="D") + rng = tm.box_expected(rng, box_with_array) + + msg = "|".join( + [ + # non-timedelta-like DateOffset + "Input has different freq(=.+)? from Period.*?\\(freq=D\\)", + # timedelta/td64/Timedelta but not a multiple of 24H + "Cannot add/subtract timedelta-like from PeriodArray that is " + "not an integer multiple of the PeriodArray's freq.", + ] + ) + with pytest.raises(IncompatibleFrequency, match=msg): + rng + other + with pytest.raises(IncompatibleFrequency, match=msg): + rng += other + with pytest.raises(IncompatibleFrequency, match=msg): + rng - other + with pytest.raises(IncompatibleFrequency, match=msg): + rng -= other + + def test_pi_add_iadd_timedeltalike_hourly(self, two_hours): + other = two_hours + rng = period_range("2014-01-01 10:00", "2014-01-05 10:00", freq="h") + expected = period_range("2014-01-01 12:00", "2014-01-05 12:00", freq="h") + + result = rng + other + tm.assert_index_equal(result, expected) + + rng += other + tm.assert_index_equal(rng, expected) + + def test_parr_add_timedeltalike_mismatched_freq_hourly( + self, not_hourly, box_with_array + ): + other = not_hourly + rng = period_range("2014-01-01 10:00", "2014-01-05 10:00", freq="h") + rng = tm.box_expected(rng, box_with_array) + msg = "|".join( + [ + # non-timedelta-like DateOffset + "Input has different freq(=.+)? from Period.*?\\(freq=h\\)", + # timedelta/td64/Timedelta but not a multiple of 24H + "Cannot add/subtract timedelta-like from PeriodArray that is " + "not an integer multiple of the PeriodArray's freq.", + ] + ) + + with pytest.raises(IncompatibleFrequency, match=msg): + rng + other + + with pytest.raises(IncompatibleFrequency, match=msg): + rng += other + + def test_pi_sub_isub_timedeltalike_hourly(self, two_hours): + other = two_hours + rng = period_range("2014-01-01 10:00", "2014-01-05 10:00", freq="h") + expected = period_range("2014-01-01 08:00", "2014-01-05 08:00", freq="h") + + result = rng - other + tm.assert_index_equal(result, expected) + + rng -= other + tm.assert_index_equal(rng, expected) + + def test_add_iadd_timedeltalike_annual(self): + # offset + # DateOffset + rng = period_range("2014", "2024", freq="Y") + result = rng + pd.offsets.YearEnd(5) + expected = period_range("2019", "2029", freq="Y") + tm.assert_index_equal(result, expected) + rng += pd.offsets.YearEnd(5) + tm.assert_index_equal(rng, expected) + + def test_pi_add_sub_timedeltalike_freq_mismatch_annual(self, mismatched_freq): + other = mismatched_freq + rng = period_range("2014", "2024", freq="Y") + msg = "Input has different freq(=.+)? from Period.*?\\(freq=Y-DEC\\)" + with pytest.raises(IncompatibleFrequency, match=msg): + rng + other + with pytest.raises(IncompatibleFrequency, match=msg): + rng += other + with pytest.raises(IncompatibleFrequency, match=msg): + rng - other + with pytest.raises(IncompatibleFrequency, match=msg): + rng -= other + + def test_pi_add_iadd_timedeltalike_M(self): + rng = period_range("2014-01", "2016-12", freq="M") + expected = period_range("2014-06", "2017-05", freq="M") + + result = rng + pd.offsets.MonthEnd(5) + tm.assert_index_equal(result, expected) + + rng += pd.offsets.MonthEnd(5) + tm.assert_index_equal(rng, expected) + + def test_pi_add_sub_timedeltalike_freq_mismatch_monthly(self, mismatched_freq): + other = mismatched_freq + rng = period_range("2014-01", "2016-12", freq="M") + msg = "Input has different freq(=.+)? from Period.*?\\(freq=M\\)" + with pytest.raises(IncompatibleFrequency, match=msg): + rng + other + with pytest.raises(IncompatibleFrequency, match=msg): + rng += other + with pytest.raises(IncompatibleFrequency, match=msg): + rng - other + with pytest.raises(IncompatibleFrequency, match=msg): + rng -= other + + @pytest.mark.parametrize("transpose", [True, False]) + def test_parr_add_sub_td64_nat(self, box_with_array, transpose): + # GH#23320 special handling for timedelta64("NaT") + pi = period_range("1994-04-01", periods=9, freq="19D") + other = np.timedelta64("NaT") + expected = PeriodIndex(["NaT"] * 9, freq="19D") + + obj = tm.box_expected(pi, box_with_array, transpose=transpose) + expected = tm.box_expected(expected, box_with_array, transpose=transpose) + + result = obj + other + tm.assert_equal(result, expected) + result = other + obj + tm.assert_equal(result, expected) + result = obj - other + tm.assert_equal(result, expected) + msg = r"cannot subtract .* from .*" + with pytest.raises(TypeError, match=msg): + other - obj + + @pytest.mark.parametrize( + "other", + [ + np.array(["NaT"] * 9, dtype="m8[ns]"), + TimedeltaArray._from_sequence(["NaT"] * 9, dtype="m8[ns]"), + ], + ) + def test_parr_add_sub_tdt64_nat_array(self, box_with_array, other): + pi = period_range("1994-04-01", periods=9, freq="19D") + expected = PeriodIndex(["NaT"] * 9, freq="19D") + + obj = tm.box_expected(pi, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = obj + other + tm.assert_equal(result, expected) + result = other + obj + tm.assert_equal(result, expected) + result = obj - other + tm.assert_equal(result, expected) + msg = r"cannot subtract .* from .*" + with pytest.raises(TypeError, match=msg): + other - obj + + # some but not *all* NaT + other = other.copy() + other[0] = np.timedelta64(0, "ns") + expected = PeriodIndex([pi[0]] + ["NaT"] * 8, freq="19D") + expected = tm.box_expected(expected, box_with_array) + + result = obj + other + tm.assert_equal(result, expected) + result = other + obj + tm.assert_equal(result, expected) + result = obj - other + tm.assert_equal(result, expected) + with pytest.raises(TypeError, match=msg): + other - obj + + # --------------------------------------------------------------- + # Unsorted + + def test_parr_add_sub_index(self): + # Check that PeriodArray defers to Index on arithmetic ops + pi = period_range("2000-12-31", periods=3) + parr = pi.array + + result = parr - pi + expected = pi - pi + tm.assert_index_equal(result, expected) + + def test_parr_add_sub_object_array(self, performance_warning): + pi = period_range("2000-12-31", periods=3, freq="D") + parr = pi.array + + other = np.array([Timedelta(days=1), pd.offsets.Day(2), 3]) + + with tm.assert_produces_warning(performance_warning): + result = parr + other + + expected = PeriodIndex( + ["2001-01-01", "2001-01-03", "2001-01-05"], freq="D" + )._data.astype(object) + tm.assert_equal(result, expected) + + with tm.assert_produces_warning(performance_warning): + result = parr - other + + expected = PeriodIndex(["2000-12-30"] * 3, freq="D")._data.astype(object) + tm.assert_equal(result, expected) + + def test_period_add_timestamp_raises(self, box_with_array): + # GH#17983 + ts = Timestamp("2017") + per = Period("2017", freq="M") + + arr = pd.Index([per], dtype="Period[M]") + arr = tm.box_expected(arr, box_with_array) + + msg = "cannot add PeriodArray and Timestamp" + with pytest.raises(TypeError, match=msg): + arr + ts + with pytest.raises(TypeError, match=msg): + ts + arr + + msg = "cannot add PeriodArray and DatetimeArray" + with pytest.raises(TypeError, match=msg): + arr + Series([ts]) + with pytest.raises(TypeError, match=msg): + Series([ts]) + arr + with pytest.raises(TypeError, match=msg): + arr + pd.Index([ts]) + with pytest.raises(TypeError, match=msg): + pd.Index([ts]) + arr + + if box_with_array is pd.DataFrame: + msg = "cannot add PeriodArray and DatetimeArray" + else: + msg = r"unsupported operand type\(s\) for \+: 'Period' and 'DatetimeArray" + with pytest.raises(TypeError, match=msg): + arr + pd.DataFrame([ts]) + if box_with_array is pd.DataFrame: + msg = "cannot add PeriodArray and DatetimeArray" + else: + msg = r"unsupported operand type\(s\) for \+: 'DatetimeArray' and 'Period'" + with pytest.raises(TypeError, match=msg): + pd.DataFrame([ts]) + arr + + +class TestPeriodSeriesArithmetic: + def test_parr_add_timedeltalike_scalar(self, three_days, box_with_array): + # GH#13043 + ser = Series( + [Period("2015-01-01", freq="D"), Period("2015-01-02", freq="D")], + name="xxx", + ) + assert ser.dtype == "Period[D]" + + expected = Series( + [Period("2015-01-04", freq="D"), Period("2015-01-05", freq="D")], + name="xxx", + ) + + obj = tm.box_expected(ser, box_with_array) + if box_with_array is pd.DataFrame: + assert (obj.dtypes == "Period[D]").all() + + expected = tm.box_expected(expected, box_with_array) + + result = obj + three_days + tm.assert_equal(result, expected) + + result = three_days + obj + tm.assert_equal(result, expected) + + def test_ops_series_period(self): + # GH#13043 + ser = Series( + [Period("2015-01-01", freq="D"), Period("2015-01-02", freq="D")], + name="xxx", + ) + assert ser.dtype == "Period[D]" + + per = Period("2015-01-10", freq="D") + off = per.freq + # dtype will be object because of original dtype + expected = Series([9 * off, 8 * off], name="xxx", dtype=object) + tm.assert_series_equal(per - ser, expected) + tm.assert_series_equal(ser - per, -1 * expected) + + s2 = Series( + [Period("2015-01-05", freq="D"), Period("2015-01-04", freq="D")], + name="xxx", + ) + assert s2.dtype == "Period[D]" + + expected = Series([4 * off, 2 * off], name="xxx", dtype=object) + tm.assert_series_equal(s2 - ser, expected) + tm.assert_series_equal(ser - s2, -1 * expected) + + +class TestPeriodIndexSeriesMethods: + """Test PeriodIndex and Period Series Ops consistency""" + + def _check(self, values, func, expected): + idx = PeriodIndex(values) + result = func(idx) + tm.assert_equal(result, expected) + + ser = Series(values) + result = func(ser) + + exp = Series(expected, name=values.name) + tm.assert_series_equal(result, exp) + + def test_pi_ops(self): + idx = PeriodIndex( + ["2011-01", "2011-02", "2011-03", "2011-04"], freq="M", name="idx" + ) + + expected = PeriodIndex( + ["2011-03", "2011-04", "2011-05", "2011-06"], freq="M", name="idx" + ) + + self._check(idx, lambda x: x + 2, expected) + self._check(idx, lambda x: 2 + x, expected) + + self._check(idx + 2, lambda x: x - 2, idx) + + result = idx - Period("2011-01", freq="M") + off = idx.freq + exp = pd.Index([0 * off, 1 * off, 2 * off, 3 * off], name="idx") + tm.assert_index_equal(result, exp) + + result = Period("2011-01", freq="M") - idx + exp = pd.Index([0 * off, -1 * off, -2 * off, -3 * off], name="idx") + tm.assert_index_equal(result, exp) + + @pytest.mark.parametrize("ng", ["str", 1.5]) + @pytest.mark.parametrize( + "func", + [ + lambda obj, ng: obj + ng, + lambda obj, ng: ng + obj, + lambda obj, ng: obj - ng, + lambda obj, ng: ng - obj, + lambda obj, ng: np.add(obj, ng), + lambda obj, ng: np.add(ng, obj), + lambda obj, ng: np.subtract(obj, ng), + lambda obj, ng: np.subtract(ng, obj), + ], + ) + def test_parr_ops_errors(self, ng, func, box_with_array): + idx = PeriodIndex( + ["2011-01", "2011-02", "2011-03", "2011-04"], freq="M", name="idx" + ) + obj = tm.box_expected(idx, box_with_array) + msg = "|".join( + [ + r"unsupported operand type\(s\)", + "can only concatenate", + r"must be str", + "object to str implicitly", + ] + ) + + with pytest.raises(TypeError, match=msg): + func(obj, ng) + + def test_pi_ops_nat(self): + idx = PeriodIndex( + ["2011-01", "2011-02", "NaT", "2011-04"], freq="M", name="idx" + ) + expected = PeriodIndex( + ["2011-03", "2011-04", "NaT", "2011-06"], freq="M", name="idx" + ) + + self._check(idx, lambda x: x + 2, expected) + self._check(idx, lambda x: 2 + x, expected) + self._check(idx, lambda x: np.add(x, 2), expected) + + self._check(idx + 2, lambda x: x - 2, idx) + self._check(idx + 2, lambda x: np.subtract(x, 2), idx) + + # freq with mult + idx = PeriodIndex( + ["2011-01", "2011-02", "NaT", "2011-04"], freq="2M", name="idx" + ) + expected = PeriodIndex( + ["2011-07", "2011-08", "NaT", "2011-10"], freq="2M", name="idx" + ) + + self._check(idx, lambda x: x + 3, expected) + self._check(idx, lambda x: 3 + x, expected) + self._check(idx, lambda x: np.add(x, 3), expected) + + self._check(idx + 3, lambda x: x - 3, idx) + self._check(idx + 3, lambda x: np.subtract(x, 3), idx) + + def test_pi_ops_array_int(self): + idx = PeriodIndex( + ["2011-01", "2011-02", "NaT", "2011-04"], freq="M", name="idx" + ) + f = lambda x: x + np.array([1, 2, 3, 4]) + exp = PeriodIndex( + ["2011-02", "2011-04", "NaT", "2011-08"], freq="M", name="idx" + ) + self._check(idx, f, exp) + + f = lambda x: np.add(x, np.array([4, -1, 1, 2])) + exp = PeriodIndex( + ["2011-05", "2011-01", "NaT", "2011-06"], freq="M", name="idx" + ) + self._check(idx, f, exp) + + f = lambda x: x - np.array([1, 2, 3, 4]) + exp = PeriodIndex( + ["2010-12", "2010-12", "NaT", "2010-12"], freq="M", name="idx" + ) + self._check(idx, f, exp) + + f = lambda x: np.subtract(x, np.array([3, 2, 3, -2])) + exp = PeriodIndex( + ["2010-10", "2010-12", "NaT", "2011-06"], freq="M", name="idx" + ) + self._check(idx, f, exp) + + def test_pi_ops_offset(self): + idx = PeriodIndex( + ["2011-01-01", "2011-02-01", "2011-03-01", "2011-04-01"], + freq="D", + name="idx", + ) + f = lambda x: x + pd.offsets.Day() + exp = PeriodIndex( + ["2011-01-02", "2011-02-02", "2011-03-02", "2011-04-02"], + freq="D", + name="idx", + ) + self._check(idx, f, exp) + + f = lambda x: x + pd.offsets.Day(2) + exp = PeriodIndex( + ["2011-01-03", "2011-02-03", "2011-03-03", "2011-04-03"], + freq="D", + name="idx", + ) + self._check(idx, f, exp) + + f = lambda x: x - pd.offsets.Day(2) + exp = PeriodIndex( + ["2010-12-30", "2011-01-30", "2011-02-27", "2011-03-30"], + freq="D", + name="idx", + ) + self._check(idx, f, exp) + + def test_pi_offset_errors(self): + idx = PeriodIndex( + ["2011-01-01", "2011-02-01", "2011-03-01", "2011-04-01"], + freq="D", + name="idx", + ) + ser = Series(idx) + + msg = ( + "Cannot add/subtract timedelta-like from PeriodArray that is not " + "an integer multiple of the PeriodArray's freq" + ) + for obj in [idx, ser]: + with pytest.raises(IncompatibleFrequency, match=msg): + obj + pd.offsets.Hour(2) + + with pytest.raises(IncompatibleFrequency, match=msg): + pd.offsets.Hour(2) + obj + + with pytest.raises(IncompatibleFrequency, match=msg): + obj - pd.offsets.Hour(2) + + def test_pi_sub_period(self): + # GH#13071 + idx = PeriodIndex( + ["2011-01", "2011-02", "2011-03", "2011-04"], freq="M", name="idx" + ) + + result = idx - Period("2012-01", freq="M") + off = idx.freq + exp = pd.Index([-12 * off, -11 * off, -10 * off, -9 * off], name="idx") + tm.assert_index_equal(result, exp) + + result = np.subtract(idx, Period("2012-01", freq="M")) + tm.assert_index_equal(result, exp) + + result = Period("2012-01", freq="M") - idx + exp = pd.Index([12 * off, 11 * off, 10 * off, 9 * off], name="idx") + tm.assert_index_equal(result, exp) + + result = np.subtract(Period("2012-01", freq="M"), idx) + tm.assert_index_equal(result, exp) + + exp = TimedeltaIndex( + [np.nan, np.nan, np.nan, np.nan], name="idx", dtype="m8[ns]" + ) + result = idx - Period("NaT", freq="M") + tm.assert_index_equal(result, exp) + assert result.freq == exp.freq + + result = Period("NaT", freq="M") - idx + tm.assert_index_equal(result, exp) + assert result.freq == exp.freq + + def test_pi_sub_pdnat(self): + # GH#13071, GH#19389 + idx = PeriodIndex( + ["2011-01", "2011-02", "NaT", "2011-04"], freq="M", name="idx" + ) + exp = TimedeltaIndex([pd.NaT] * 4, name="idx", dtype="m8[ns]") + tm.assert_index_equal(pd.NaT - idx, exp) + tm.assert_index_equal(idx - pd.NaT, exp) + + def test_pi_sub_period_nat(self): + # GH#13071 + idx = PeriodIndex( + ["2011-01", "NaT", "2011-03", "2011-04"], freq="M", name="idx" + ) + + result = idx - Period("2012-01", freq="M") + off = idx.freq + exp = pd.Index([-12 * off, pd.NaT, -10 * off, -9 * off], name="idx") + tm.assert_index_equal(result, exp) + + result = Period("2012-01", freq="M") - idx + exp = pd.Index([12 * off, pd.NaT, 10 * off, 9 * off], name="idx") + tm.assert_index_equal(result, exp) + + exp = TimedeltaIndex( + [np.nan, np.nan, np.nan, np.nan], name="idx", dtype="m8[ns]" + ) + tm.assert_index_equal(idx - Period("NaT", freq="M"), exp) + tm.assert_index_equal(Period("NaT", freq="M") - idx, exp) diff --git a/pandas/tests/arithmetic/test_string.py b/pandas/tests/arithmetic/test_string.py new file mode 100644 index 0000000000000000000000000000000000000000..46a3d1e8386eb29153700937bb82e43ddf23883d --- /dev/null +++ b/pandas/tests/arithmetic/test_string.py @@ -0,0 +1,472 @@ +import operator +from pathlib import Path + +import numpy as np +import pytest + +from pandas.compat import HAS_PYARROW +from pandas.errors import Pandas4Warning +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + NA, + ArrowDtype, + Series, + StringDtype, +) +import pandas._testing as tm +from pandas.core.construction import extract_array + + +def string_dtype_highest_priority(dtype1, dtype2): + if HAS_PYARROW: + DTYPE_HIERARCHY = [ + StringDtype("python", na_value=np.nan), + StringDtype("pyarrow", na_value=np.nan), + StringDtype("python", na_value=NA), + StringDtype("pyarrow", na_value=NA), + ] + else: + DTYPE_HIERARCHY = [ + StringDtype("python", na_value=np.nan), + StringDtype("python", na_value=NA), + ] + + h1 = DTYPE_HIERARCHY.index(dtype1) + h2 = DTYPE_HIERARCHY.index(dtype2) + return DTYPE_HIERARCHY[max(h1, h2)] + + +def test_eq_all_na(): + pytest.importorskip("pyarrow") + a = pd.array([NA, NA], dtype=StringDtype("pyarrow")) + result = a == a + expected = pd.array([NA, NA], dtype="boolean[pyarrow]") + tm.assert_extension_array_equal(result, expected) + + +def test_reversed_logical_ops(any_string_dtype): + # GH#60234 + dtype = any_string_dtype + warn = None if dtype == object else Pandas4Warning + left = Series([True, False, False, True]) + right = Series(["", "", "b", "c"], dtype=dtype) + + msg = "operations between boolean dtype and" + with tm.assert_produces_warning(warn, match=msg): + result = left | right + expected = left | right.astype(bool) + tm.assert_series_equal(result, expected) + + with tm.assert_produces_warning(warn, match=msg): + result = left & right + expected = left & right.astype(bool) + tm.assert_series_equal(result, expected) + + with tm.assert_produces_warning(warn, match=msg): + result = left ^ right + expected = left ^ right.astype(bool) + tm.assert_series_equal(result, expected) + + +def test_pathlib_path_division(any_string_dtype, request): + # GH#61940 + if any_string_dtype == object: + mark = pytest.mark.xfail( + reason="with NA present we go through _masked_arith_op which " + "raises TypeError bc Path is not recognized by lib.is_scalar." + ) + request.applymarker(mark) + + item = Path("/Users/Irv/") + ser = Series(["A", "B", NA], dtype=any_string_dtype) + + result = item / ser + expected = Series([item / "A", item / "B", ser.dtype.na_value], dtype=object) + tm.assert_series_equal(result, expected) + + result = ser / item + expected = Series(["A" / item, "B" / item, ser.dtype.na_value], dtype=object) + tm.assert_series_equal(result, expected) + + +def test_mixed_object_comparison(any_string_dtype): + # GH#60228 + dtype = any_string_dtype + ser = Series(["a", "b"], dtype=dtype) + + mixed = Series([1, "b"], dtype=object) + + result = ser == mixed + expected = Series([False, True], dtype=bool) + if dtype == object: + pass + elif dtype.storage == "python" and dtype.na_value is NA: + expected = expected.astype("boolean") + elif dtype.storage == "pyarrow" and dtype.na_value is NA: + expected = expected.astype("bool[pyarrow]") + + tm.assert_series_equal(result, expected) + + +def test_pyarrow_numpy_string_invalid(): + # GH#56008 + pa = pytest.importorskip("pyarrow") + ser = Series([False, True]) + ser2 = Series(["a", "b"], dtype=StringDtype(na_value=np.nan)) + result = ser == ser2 + expected_eq = Series(False, index=ser.index) + tm.assert_series_equal(result, expected_eq) + + result = ser != ser2 + expected_ne = Series(True, index=ser.index) + tm.assert_series_equal(result, expected_ne) + + with pytest.raises(TypeError, match="Invalid comparison"): + ser > ser2 + + # GH#59505 + ser3 = ser2.astype("string[pyarrow]") + result3_eq = ser3 == ser + tm.assert_series_equal(result3_eq, expected_eq.astype("bool[pyarrow]")) + result3_ne = ser3 != ser + tm.assert_series_equal(result3_ne, expected_ne.astype("bool[pyarrow]")) + + with pytest.raises(TypeError, match="Invalid comparison"): + ser > ser3 + + ser4 = ser2.astype(ArrowDtype(pa.string())) + result4_eq = ser4 == ser + tm.assert_series_equal(result4_eq, expected_eq.astype("bool[pyarrow]")) + result4_ne = ser4 != ser + tm.assert_series_equal(result4_ne, expected_ne.astype("bool[pyarrow]")) + + with pytest.raises(TypeError, match="Invalid comparison"): + ser > ser4 + + +def test_mul_bool_invalid(any_string_dtype): + # GH#62595 + dtype = any_string_dtype + ser = Series(["a", "b", "c"], dtype=dtype) + + if dtype == object: + pytest.skip("This is not expect to raise") + elif dtype.storage == "python": + msg = "Cannot multiply StringArray by bools. Explicitly cast to integers" + else: + msg = "Can only string multiply by an integer" + + with pytest.raises(TypeError, match=msg): + False * ser + with pytest.raises(TypeError, match=msg): + ser * True + with pytest.raises(TypeError, match=msg): + ser * np.array([True, False, True], dtype=bool) + with pytest.raises(TypeError, match=msg): + np.array([True, False, True], dtype=bool) * ser + + +def test_add(any_string_dtype, request): + dtype = any_string_dtype + if dtype == object: + mark = pytest.mark.xfail( + reason="Need to update expected for numpy object dtype" + ) + request.applymarker(mark) + + a = Series(["a", "b", "c", None, None], dtype=dtype) + b = Series(["x", "y", None, "z", None], dtype=dtype) + + result = a + b + expected = Series(["ax", "by", None, None, None], dtype=dtype) + tm.assert_series_equal(result, expected) + + result = a.add(b) + tm.assert_series_equal(result, expected) + + result = a.radd(b) + expected = Series(["xa", "yb", None, None, None], dtype=dtype) + tm.assert_series_equal(result, expected) + + result = a.add(b, fill_value="-") + expected = Series(["ax", "by", "c-", "-z", None], dtype=dtype) + tm.assert_series_equal(result, expected) + + +def test_add_2d(any_string_dtype, request): + dtype = any_string_dtype + + if dtype == object or dtype.storage == "pyarrow": + reason = "Failed: DID NOT RAISE " + mark = pytest.mark.xfail(raises=None, reason=reason) + request.applymarker(mark) + + a = pd.array(["a", "b", "c"], dtype=dtype) + b = np.array([["a", "b", "c"]], dtype=object) + with pytest.raises(ValueError, match="3 != 1"): + a + b + + s = Series(a) + with pytest.raises(ValueError, match="3 != 1"): + s + b + + +def test_add_sequence(any_string_dtype, request, using_infer_string): + dtype = any_string_dtype + if ( + dtype != object + and dtype.storage == "python" + and dtype.na_value is np.nan + and HAS_PYARROW + and using_infer_string + ): + mark = pytest.mark.xfail( + reason="As of GH#62522, the list gets wrapped with sanitize_array, " + "which casts to a higher-priority StringArray, so we get " + "NotImplemented." + ) + request.applymarker(mark) + if dtype == np.dtype(object) and using_infer_string: + mark = pytest.mark.xfail(reason="Cannot broadcast list") + request.applymarker(mark) + + a = pd.array(["a", "b", None, None], dtype=dtype) + other = ["x", None, "y", None] + + result = a + other + expected = pd.array(["ax", None, None, None], dtype=dtype) + tm.assert_extension_array_equal(result, expected) + + result = other + a + expected = pd.array(["xa", None, None, None], dtype=dtype) + tm.assert_extension_array_equal(result, expected) + + +def test_mul(any_string_dtype): + dtype = any_string_dtype + a = pd.array(["a", "b", None], dtype=dtype) + result = a * 2 + expected = pd.array(["aa", "bb", None], dtype=dtype) + tm.assert_extension_array_equal(result, expected) + + result = 2 * a + tm.assert_extension_array_equal(result, expected) + + +def test_add_strings(any_string_dtype, request): + dtype = any_string_dtype + if dtype != np.dtype(object): + mark = pytest.mark.xfail(reason="GH-28527") + request.applymarker(mark) + arr = pd.array(["a", "b", "c", "d"], dtype=dtype) + df = pd.DataFrame([["t", "y", "v", "w"]], dtype=object) + assert arr.__add__(df) is NotImplemented + + result = arr + df + expected = pd.DataFrame([["at", "by", "cv", "dw"]]).astype(dtype) + tm.assert_frame_equal(result, expected) + + result = df + arr + expected = pd.DataFrame([["ta", "yb", "vc", "wd"]]).astype(dtype) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.xfail(reason="GH-28527") +def test_add_frame(dtype): + arr = pd.array(["a", "b", np.nan, np.nan], dtype=dtype) + df = pd.DataFrame([["x", np.nan, "y", np.nan]]) + + assert arr.__add__(df) is NotImplemented + + result = arr + df + expected = pd.DataFrame([["ax", np.nan, np.nan, np.nan]]).astype(dtype) + tm.assert_frame_equal(result, expected) + + result = df + arr + expected = pd.DataFrame([["xa", np.nan, np.nan, np.nan]]).astype(dtype) + tm.assert_frame_equal(result, expected) + + +def test_comparison_methods_scalar(comparison_op, any_string_dtype): + dtype = any_string_dtype + op_name = f"__{comparison_op.__name__}__" + a = pd.array(["a", None, "c"], dtype=dtype) + other = "a" + result = getattr(a, op_name)(other) + if dtype == object or dtype.na_value is np.nan: + expected = np.array([getattr(item, op_name)(other) for item in a]) + if comparison_op == operator.ne: + expected[1] = True + else: + expected[1] = False + result = extract_array(result, extract_numpy=True) + tm.assert_numpy_array_equal(result, expected.astype(np.bool_)) + else: + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" + expected = np.array([getattr(item, op_name)(other) for item in a], dtype=object) + expected = pd.array(expected, dtype=expected_dtype) + tm.assert_extension_array_equal(result, expected) + + +def test_comparison_methods_scalar_pd_na(comparison_op, any_string_dtype): + dtype = any_string_dtype + op_name = f"__{comparison_op.__name__}__" + a = pd.array(["a", None, "c"], dtype=dtype) + result = getattr(a, op_name)(NA) + + if dtype == np.dtype(object) or dtype.na_value is np.nan: + if operator.ne == comparison_op: + expected = np.array([True, True, True]) + else: + expected = np.array([False, False, False]) + result = extract_array(result, extract_numpy=True) + tm.assert_numpy_array_equal(result, expected) + else: + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" + expected = pd.array([None, None, None], dtype=expected_dtype) + tm.assert_extension_array_equal(result, expected) + tm.assert_extension_array_equal(result, expected) + + +def test_comparison_methods_scalar_not_string(comparison_op, any_string_dtype): + op_name = f"__{comparison_op.__name__}__" + dtype = any_string_dtype + + a = pd.array(["a", None, "c"], dtype=dtype) + other = 42 + + if op_name not in ["__eq__", "__ne__"]: + with pytest.raises(TypeError, match="Invalid comparison|not supported between"): + getattr(a, op_name)(other) + + return + + result = getattr(a, op_name)(other) + result = extract_array(result, extract_numpy=True) + + if dtype == np.dtype(object) or dtype.na_value is np.nan: + expected_data = { + "__eq__": [False, False, False], + "__ne__": [True, True, True], + }[op_name] + expected = np.array(expected_data) + tm.assert_numpy_array_equal(result, expected) + else: + expected_data = {"__eq__": [False, None, False], "__ne__": [True, None, True]}[ + op_name + ] + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" + expected = pd.array(expected_data, dtype=expected_dtype) + tm.assert_extension_array_equal(result, expected) + + +def test_comparison_methods_array(comparison_op, any_string_dtype, any_string_dtype2): + op_name = f"__{comparison_op.__name__}__" + dtype = any_string_dtype + dtype2 = any_string_dtype2 + + a = pd.array(["a", None, "c"], dtype=dtype) + other = pd.array([None, None, "c"], dtype=dtype2) + result = comparison_op(a, other) + result = extract_array(result, extract_numpy=True) + + # ensure operation is commutative + result2 = comparison_op(other, a) + result2 = extract_array(result2, extract_numpy=True) + tm.assert_equal(result, result2) + + if (dtype == object or dtype.na_value is np.nan) and ( + dtype2 == object or dtype2.na_value is np.nan + ): + if operator.ne == comparison_op: + expected = np.array([True, True, False]) + else: + expected = np.array([False, False, False]) + expected[-1] = getattr(other[-1], op_name)(a[-1]) + result = extract_array(result, extract_numpy=True) + tm.assert_numpy_array_equal(result, expected) + + else: + if dtype == object: + max_dtype = dtype2 + elif dtype2 == object: + max_dtype = dtype + else: + max_dtype = string_dtype_highest_priority(dtype, dtype2) + if max_dtype.storage == "python": + expected_dtype = "boolean" + else: + expected_dtype = "bool[pyarrow]" + + expected = np.full(len(a), fill_value=None, dtype="object") + expected[-1] = getattr(other[-1], op_name)(a[-1]) + expected = pd.array(expected, dtype=expected_dtype) + tm.assert_equal(result, expected) + + +@td.skip_if_no("pyarrow") +def test_comparison_methods_array_arrow_extension(comparison_op, any_string_dtype): + # Test pd.ArrowDtype(pa.string()) against other string arrays + import pyarrow as pa + + dtype2 = any_string_dtype + + op_name = f"__{comparison_op.__name__}__" + dtype = ArrowDtype(pa.string()) + a = pd.array(["a", None, "c"], dtype=dtype) + other = pd.array([None, None, "c"], dtype=dtype2) + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + expected = pd.array([None, None, True], dtype="bool[pyarrow]") + expected[-1] = getattr(other[-1], op_name)(a[-1]) + tm.assert_extension_array_equal(result, expected) + + +@pytest.mark.parametrize("box", [pd.array, pd.Index, Series]) +def test_comparison_methods_list(comparison_op, any_string_dtype, box, request): + dtype = any_string_dtype + + if box is pd.array and dtype != object and dtype.na_value is np.nan: + mark = pytest.mark.xfail( + reason="After wrapping list, op returns NotImplemented, see GH#62522" + ) + request.applymarker(mark) + + op_name = f"__{comparison_op.__name__}__" + + a = box(pd.array(["a", None, "c"], dtype=dtype)) + item = "c" + other = [None, None, "c"] + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + if dtype == np.dtype(object) or dtype.na_value is np.nan: + if operator.ne == comparison_op: + expected = np.array([True, True, False]) + else: + expected = np.array([False, False, False]) + expected[-1] = getattr(item, op_name)(item) + if box is not pd.Index: + # if GH#62766 is addressed this check can be removed + expected = box(expected, dtype=expected.dtype) + tm.assert_equal(result, expected) + + else: + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" + expected = np.full(len(a), fill_value=None, dtype="object") + expected[-1] = getattr(item, op_name)(item) + expected = pd.array(expected, dtype=expected_dtype) + expected = extract_array(expected, extract_numpy=True) + if box is not pd.Index: + # if GH#62766 is addressed this check can be removed + expected = tm.box_expected(expected, box) + tm.assert_equal(result, expected) diff --git a/pandas/tests/arithmetic/test_timedelta64.py b/pandas/tests/arithmetic/test_timedelta64.py new file mode 100644 index 0000000000000000000000000000000000000000..89a9148bed5575f24b99c0446f85fe4826062518 --- /dev/null +++ b/pandas/tests/arithmetic/test_timedelta64.py @@ -0,0 +1,2331 @@ +# Arithmetic tests for DataFrame/Series/Index/Array classes that should +# behave identically. +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs import timezones +from pandas.compat import WASM +from pandas.errors import OutOfBoundsDatetime +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + NaT, + Series, + Timedelta, + TimedeltaIndex, + Timestamp, + offsets, + timedelta_range, +) +import pandas._testing as tm +from pandas.core.arrays import NumpyExtensionArray +from pandas.tests.arithmetic.common import ( + assert_invalid_addsub_type, + assert_invalid_comparison, + get_upcast_box, +) + + +def assert_dtype(obj, expected_dtype): + """ + Helper to check the dtype for a Series, Index, or single-column DataFrame. + """ + dtype = tm.get_dtype(obj) + + assert dtype == expected_dtype + + +def get_expected_name(box, names): + if box is DataFrame: + # Since we are operating with a DataFrame and a non-DataFrame, + # the non-DataFrame is cast to Series and its name ignored. + exname = names[0] + elif box in [tm.to_array, pd.array]: + exname = names[1] + else: + exname = names[2] + return exname + + +# ------------------------------------------------------------------ +# Timedelta64[ns] dtype Comparisons + + +class TestTimedelta64ArrayLikeComparisons: + # Comparison tests for timedelta64[ns] vectors fully parametrized over + # DataFrame/Series/TimedeltaIndex/TimedeltaArray. Ideally all comparison + # tests will eventually end up here. + + def test_compare_timedelta64_zerodim(self, box_with_array): + # GH#26689 should unbox when comparing with zerodim array + box = box_with_array + xbox = box_with_array if box_with_array not in [Index, pd.array] else np.ndarray + + tdi = timedelta_range("2h", periods=4) + other = np.array(tdi.to_numpy()[0]) + + tdi = tm.box_expected(tdi, box) + res = tdi <= other + expected = np.array([True, False, False, False]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(res, expected) + + @pytest.mark.parametrize( + "td_scalar", + [ + timedelta(days=1), + Timedelta(days=1), + Timedelta(days=1).to_timedelta64(), + offsets.Hour(24), + ], + ) + def test_compare_timedeltalike_scalar(self, box_with_array, td_scalar): + # regression test for GH#5963 + box = box_with_array + xbox = box if box not in [Index, pd.array] else np.ndarray + + ser = Series([timedelta(days=1), timedelta(days=2)]) + ser = tm.box_expected(ser, box) + actual = ser > td_scalar + expected = Series([False, True]) + expected = tm.box_expected(expected, xbox) + tm.assert_equal(actual, expected) + + @pytest.mark.parametrize( + "invalid", + [ + 345600000000000, + "a", + Timestamp("2021-01-01"), + Timestamp("2021-01-01").now("UTC"), + Timestamp("2021-01-01").now().to_datetime64(), + Timestamp("2021-01-01").now().to_pydatetime(), + Timestamp("2021-01-01").date(), + np.array(4), # zero-dim mismatched dtype + ], + ) + def test_td64_comparisons_invalid(self, box_with_array, invalid): + # GH#13624 for str + box = box_with_array + + rng = timedelta_range("1 days", periods=10) + obj = tm.box_expected(rng, box) + + assert_invalid_comparison(obj, invalid, box) + + @pytest.mark.parametrize( + "other", + [ + list(range(10)), + np.arange(10), + np.arange(10).astype(np.float32), + np.arange(10).astype(object), + pd.date_range("1970-01-01", periods=10, tz="UTC").array, + np.array(pd.date_range("1970-01-01", periods=10)), + list(pd.date_range("1970-01-01", periods=10)), + pd.date_range("1970-01-01", periods=10).astype(object), + pd.period_range("1971-01-01", freq="D", periods=10).array, + pd.period_range("1971-01-01", freq="D", periods=10).astype(object), + ], + ) + def test_td64arr_cmp_arraylike_invalid(self, other, box_with_array): + # We don't parametrize this over box_with_array because listlike + # other plays poorly with assert_invalid_comparison reversed checks + + rng = timedelta_range("1 days", periods=10)._data + rng = tm.box_expected(rng, box_with_array) + assert_invalid_comparison(rng, other, box_with_array) + + def test_td64arr_cmp_mixed_invalid(self): + rng = timedelta_range("1 days", periods=5)._data + other = np.array([0, 1, 2, rng[3], Timestamp("2021-01-01")]) + + result = rng == other + expected = np.array([False, False, False, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = rng != other + tm.assert_numpy_array_equal(result, ~expected) + + msg = "Invalid comparison between|Cannot compare type|not supported between" + with pytest.raises(TypeError, match=msg): + rng < other + with pytest.raises(TypeError, match=msg): + rng > other + with pytest.raises(TypeError, match=msg): + rng <= other + with pytest.raises(TypeError, match=msg): + rng >= other + + +class TestTimedelta64ArrayComparisons: + # TODO: All of these need to be parametrized over box + + @pytest.mark.parametrize("dtype", [None, object]) + def test_comp_nat(self, dtype): + left = TimedeltaIndex([Timedelta("1 days"), NaT, Timedelta("3 days")]) + right = TimedeltaIndex([NaT, NaT, Timedelta("3 days")]) + + lhs, rhs = left, right + if dtype is object: + lhs, rhs = left.astype(object), right.astype(object) + + result = rhs == lhs + expected = np.array([False, False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = rhs != lhs + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + expected = np.array([False, False, False]) + tm.assert_numpy_array_equal(lhs == NaT, expected) + tm.assert_numpy_array_equal(NaT == rhs, expected) + + expected = np.array([True, True, True]) + tm.assert_numpy_array_equal(lhs != NaT, expected) + tm.assert_numpy_array_equal(NaT != lhs, expected) + + expected = np.array([False, False, False]) + tm.assert_numpy_array_equal(lhs < NaT, expected) + tm.assert_numpy_array_equal(NaT > lhs, expected) + + @pytest.mark.parametrize( + "idx2", + [ + TimedeltaIndex( + ["2 day", "2 day", NaT, NaT, "1 day 00:00:02", "5 days 00:00:03"] + ), + np.array( + [ + np.timedelta64(2, "D"), + np.timedelta64(2, "D"), + np.timedelta64("nat"), + np.timedelta64("nat"), + np.timedelta64(1, "D") + np.timedelta64(2, "s"), + np.timedelta64(5, "D") + np.timedelta64(3, "s"), + ] + ), + ], + ) + def test_comparisons_nat(self, idx2): + idx1 = TimedeltaIndex( + [ + "1 day", + NaT, + "1 day 00:00:01", + NaT, + "1 day 00:00:01", + "5 day 00:00:03", + ] + ) + # Check pd.NaT is handles as the same as np.nan + result = idx1 < idx2 + expected = np.array([True, False, False, False, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = idx2 > idx1 + expected = np.array([True, False, False, False, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 <= idx2 + expected = np.array([True, False, False, False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + result = idx2 >= idx1 + expected = np.array([True, False, False, False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 == idx2 + expected = np.array([False, False, False, False, False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = idx1 != idx2 + expected = np.array([True, True, True, True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + # TODO: better name + def test_comparisons_coverage(self): + rng = timedelta_range("1 days", periods=10) + + result = rng < rng[3] + expected = np.array([True, True, True] + [False] * 7) + tm.assert_numpy_array_equal(result, expected) + + result = rng == list(rng) + exp = rng == rng + tm.assert_numpy_array_equal(result, exp) + + +# ------------------------------------------------------------------ +# Timedelta64[ns] dtype Arithmetic Operations + + +class TestTimedelta64ArithmeticUnsorted: + # Tests moved from type-specific test files but not + # yet sorted/parametrized/de-duplicated + + def test_td64_op_with_list(self, box_with_array): + # GH#62353 + box = box_with_array + + left = TimedeltaIndex(["2D", "4D"]) + left = tm.box_expected(left, box) + + right = [Timestamp("2016-01-01"), Timestamp("2016-02-01")] + + result = left + right + expected = DatetimeIndex(["2016-01-03", "2016-02-05"], dtype="M8[us]") + expected = tm.box_expected(expected, box) + tm.assert_equal(result, expected) + + result2 = right + left + tm.assert_equal(result2, expected) + + def test_ufunc_coercions(self): + # normal ops are also tested in tseries/test_timedeltas.py + idx = TimedeltaIndex(["2h", "4h", "6h", "8h", "10h"], freq="2h", name="x") + + for result in [idx * 2, np.multiply(idx, 2)]: + assert isinstance(result, TimedeltaIndex) + exp = TimedeltaIndex(["4h", "8h", "12h", "16h", "20h"], freq="4h", name="x") + tm.assert_index_equal(result, exp) + assert result.freq == "4h" + + for result in [idx / 2, np.divide(idx, 2)]: + assert isinstance(result, TimedeltaIndex) + exp = TimedeltaIndex(["1h", "2h", "3h", "4h", "5h"], freq="h", name="x") + tm.assert_index_equal(result, exp) + assert result.freq == "h" + + for result in [-idx, np.negative(idx)]: + assert isinstance(result, TimedeltaIndex) + exp = TimedeltaIndex( + ["-2h", "-4h", "-6h", "-8h", "-10h"], freq="-2h", name="x" + ) + tm.assert_index_equal(result, exp) + assert result.freq == "-2h" + + idx = TimedeltaIndex(["-2h", "-1h", "0h", "1h", "2h"], freq="h", name="x") + for result in [abs(idx), np.absolute(idx)]: + assert isinstance(result, TimedeltaIndex) + exp = TimedeltaIndex(["2h", "1h", "0h", "1h", "2h"], freq=None, name="x") + tm.assert_index_equal(result, exp) + assert result.freq is None + + def test_subtraction_ops(self): + # with datetimes/timedelta and tdi/dti + tdi = TimedeltaIndex(["1 days", NaT, "2 days"], name="foo") + dti = pd.date_range("20130101", periods=3, name="bar") + td = Timedelta("1 days") + dt = Timestamp("20130101") + + msg = "cannot subtract a datelike from a TimedeltaArray" + with pytest.raises(TypeError, match=msg): + tdi - dt + with pytest.raises(TypeError, match=msg): + tdi - dti + + msg = r"unsupported operand type\(s\) for -" + with pytest.raises(TypeError, match=msg): + td - dt + + msg = "cannot subtract DatetimeArray from Timedelta" + with pytest.raises(TypeError, match=msg): + td - dti + + result = dt - dti + expected = TimedeltaIndex(["0 days", "-1 days", "-2 days"], name="bar") + tm.assert_index_equal(result, expected) + + result = dti - dt + expected = TimedeltaIndex(["0 days", "1 days", "2 days"], name="bar") + tm.assert_index_equal(result, expected) + + result = tdi - td + expected = TimedeltaIndex(["0 days", NaT, "1 days"], name="foo") + tm.assert_index_equal(result, expected) + + result = td - tdi + expected = TimedeltaIndex(["0 days", NaT, "-1 days"], name="foo") + tm.assert_index_equal(result, expected) + + result = dti - td + expected = DatetimeIndex( + ["20121231", "20130101", "20130102"], dtype="M8[us]", freq="D", name="bar" + ) + tm.assert_index_equal(result, expected) + + result = dt - tdi + expected = DatetimeIndex( + ["20121231", NaT, "20121230"], dtype="M8[us]", name="foo" + ) + tm.assert_index_equal(result, expected) + + def test_subtraction_ops_with_tz(self, box_with_array): + # check that dt/dti subtraction ops with tz are validated + dti = pd.date_range("20130101", periods=3) + dti = tm.box_expected(dti, box_with_array) + ts = Timestamp("20130101") + dt = ts.to_pydatetime() + dti_tz = pd.date_range("20130101", periods=3).tz_localize("US/Eastern") + dti_tz = tm.box_expected(dti_tz, box_with_array) + ts_tz = Timestamp("20130101").tz_localize("US/Eastern") + ts_tz2 = Timestamp("20130101").tz_localize("CET") + dt_tz = ts_tz.to_pydatetime() + td = Timedelta("1 days") + + def _check(result, expected): + assert result == expected + assert isinstance(result, Timedelta) + + # scalars + result = ts - ts + expected = Timedelta("0 days") + _check(result, expected) + + result = dt_tz - ts_tz + expected = Timedelta("0 days") + _check(result, expected) + + result = ts_tz - dt_tz + expected = Timedelta("0 days") + _check(result, expected) + + # tz mismatches + msg = "Cannot subtract tz-naive and tz-aware datetime-like objects." + with pytest.raises(TypeError, match=msg): + dt_tz - ts + msg = "can't subtract offset-naive and offset-aware datetimes" + with pytest.raises(TypeError, match=msg): + dt_tz - dt + msg = "can't subtract offset-naive and offset-aware datetimes" + with pytest.raises(TypeError, match=msg): + dt - dt_tz + msg = "Cannot subtract tz-naive and tz-aware datetime-like objects." + with pytest.raises(TypeError, match=msg): + ts - dt_tz + with pytest.raises(TypeError, match=msg): + ts_tz2 - ts + with pytest.raises(TypeError, match=msg): + ts_tz2 - dt + + msg = "Cannot subtract tz-naive and tz-aware" + # with dti + with pytest.raises(TypeError, match=msg): + dti - ts_tz + with pytest.raises(TypeError, match=msg): + dti_tz - ts + + result = dti_tz - dt_tz + expected = TimedeltaIndex(["0 days", "1 days", "2 days"]) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + result = dt_tz - dti_tz + expected = TimedeltaIndex(["0 days", "-1 days", "-2 days"]) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + result = dti_tz - ts_tz + expected = TimedeltaIndex(["0 days", "1 days", "2 days"]) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + result = ts_tz - dti_tz + expected = TimedeltaIndex(["0 days", "-1 days", "-2 days"]) + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + result = td - td + expected = Timedelta("0 days") + _check(result, expected) + + result = dti_tz - td + expected = DatetimeIndex(["20121231", "20130101", "20130102"], tz="US/Eastern") + expected = tm.box_expected(expected, box_with_array) + tm.assert_equal(result, expected) + + def test_dti_tdi_numeric_ops(self): + # These are normally union/diff set-like ops + tdi = TimedeltaIndex(["1 days", NaT, "2 days"], name="foo") + dti = pd.date_range("20130101", periods=3, name="bar") + + result = tdi - tdi + expected = TimedeltaIndex(["0 days", NaT, "0 days"], name="foo") + tm.assert_index_equal(result, expected) + + result = tdi + tdi + expected = TimedeltaIndex(["2 days", NaT, "4 days"], name="foo") + tm.assert_index_equal(result, expected) + + result = dti - tdi # name will be reset + expected = DatetimeIndex(["20121231", NaT, "20130101"], dtype="M8[us]") + tm.assert_index_equal(result, expected) + + def test_addition_ops(self): + # with datetimes/timedelta and tdi/dti + tdi = TimedeltaIndex(["1 days", NaT, "2 days"], name="foo") + dti = pd.date_range("20130101", periods=3, name="bar") + td = Timedelta("1 days") + dt = Timestamp("20130101") + + result = tdi + dt + expected = DatetimeIndex( + ["20130102", NaT, "20130103"], dtype="M8[us]", name="foo" + ) + tm.assert_index_equal(result, expected) + + result = dt + tdi + expected = DatetimeIndex( + ["20130102", NaT, "20130103"], dtype="M8[us]", name="foo" + ) + tm.assert_index_equal(result, expected) + + result = td + tdi + expected = TimedeltaIndex(["2 days", NaT, "3 days"], name="foo") + tm.assert_index_equal(result, expected) + + result = tdi + td + expected = TimedeltaIndex(["2 days", NaT, "3 days"], name="foo") + tm.assert_index_equal(result, expected) + + # unequal length + msg = "cannot add indices of unequal length" + with pytest.raises(ValueError, match=msg): + tdi + dti[0:1] + with pytest.raises(ValueError, match=msg): + tdi[0:1] + dti + + # random indexes + msg = "Addition/subtraction of integers and integer-arrays" + with pytest.raises(TypeError, match=msg): + tdi + Index([1, 2, 3], dtype=np.int64) + + # this is a union! + # FIXME: don't leave commented-out + # pytest.raises(TypeError, lambda : Index([1,2,3]) + tdi) + + result = tdi + dti # name will be reset + expected = DatetimeIndex(["20130102", NaT, "20130105"], dtype="M8[us]") + tm.assert_index_equal(result, expected) + + result = dti + tdi # name will be reset + expected = DatetimeIndex(["20130102", NaT, "20130105"], dtype="M8[us]") + tm.assert_index_equal(result, expected) + + result = dt + td + expected = Timestamp("20130102") + assert result == expected + + result = td + dt + expected = Timestamp("20130102") + assert result == expected + + # TODO: Needs more informative name, probably split up into + # more targeted tests + @pytest.mark.parametrize("freq", ["D", "B"]) + def test_timedelta(self, freq): + index = pd.date_range("1/1/2000", periods=50, freq=freq, unit="ns") + + shifted = index + timedelta(1) + back = shifted + timedelta(-1) + back = back._with_freq("infer") + tm.assert_index_equal(index, back) + + if freq == "D": + expected = pd.tseries.offsets.Day(1) + assert index.freq == expected + assert shifted.freq == expected + assert back.freq == expected + else: # freq == 'B' + assert index.freq == pd.tseries.offsets.BusinessDay(1) + assert shifted.freq is None + assert back.freq == pd.tseries.offsets.BusinessDay(1) + + result = index - timedelta(1) + expected = index + timedelta(-1) + tm.assert_index_equal(result, expected) + + def test_timedelta_tick_arithmetic(self): + # GH#4134, buggy with timedeltas + rng = pd.date_range("2013", "2014") + s = Series(rng) + result1 = rng - offsets.Hour(1) + result2 = DatetimeIndex(s - np.timedelta64(100000000)) + result3 = rng - np.timedelta64(100000000) + result4 = DatetimeIndex(s - offsets.Hour(1)) + + assert result1.freq == rng.freq + result1 = result1._with_freq(None) + tm.assert_index_equal(result1, result4) + + assert result3.freq == rng.freq + result3 = result3._with_freq(None) + tm.assert_index_equal(result2, result3) + + def test_tda_add_sub_index(self): + # Check that TimedeltaArray defers to Index on arithmetic ops + tdi = TimedeltaIndex(["1 days", NaT, "2 days"]) + tda = tdi.array + + dti = pd.date_range("1999-12-31", periods=3, freq="D") + + result = tda + dti + expected = tdi + dti + tm.assert_index_equal(result, expected) + + result = tda + tdi + expected = tdi + tdi + tm.assert_index_equal(result, expected) + + result = tda - tdi + expected = tdi - tdi + tm.assert_index_equal(result, expected) + + def test_tda_add_dt64_object_array( + self, performance_warning, box_with_array, tz_naive_fixture + ): + # Result should be cast back to DatetimeArray + box = box_with_array + + dti = pd.date_range("2016-01-01", periods=3, tz=tz_naive_fixture) + dti = dti._with_freq(None) + tdi = dti - dti + + obj = tm.box_expected(tdi, box) + other = tm.box_expected(dti, box) + + with tm.assert_produces_warning(performance_warning): + result = obj + other.astype(object) + tm.assert_equal(result, other.astype(object)) + + # ------------------------------------------------------------- + # Binary operations TimedeltaIndex and timedelta-like + + def test_tdi_iadd_timedeltalike(self, two_hours, box_with_array): + # only test adding/sub offsets as + is now numeric + rng = timedelta_range("1 days", "10 days") + expected = timedelta_range("1 days 02:00:00", "10 days 02:00:00", freq="D") + if ( + isinstance(two_hours, Timedelta) + and two_hours.unit == "ns" + and box_with_array is not pd.array + ): + # The EA op has to be _actually_ inplace so does not cast to a + # new dtype. For the others, the op can assign a new array + # and get the dtype that normally results from `rng + two_hours` + expected = expected.as_unit("ns") + + rng = tm.box_expected(rng, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + orig_rng = rng + rng += two_hours + tm.assert_equal(rng, expected) + if box_with_array is not Index: + # Check that operation is actually inplace + tm.assert_equal(orig_rng, expected) + + def test_tdi_isub_timedeltalike(self, two_hours, box_with_array): + # only test adding/sub offsets as - is now numeric + rng = timedelta_range("1 days", "10 days") + expected = timedelta_range("0 days 22:00:00", "9 days 22:00:00") + if ( + isinstance(two_hours, Timedelta) + and two_hours.unit == "ns" + and box_with_array is not pd.array + ): + # The EA op has to be _actually_ inplace so does not cast to a + # new dtype. For the others, the op can assign a new array + # and get the dtype that normally results from `rng - two_hours` + expected = expected.as_unit("ns") + + rng = tm.box_expected(rng, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + orig_rng = rng + rng -= two_hours + tm.assert_equal(rng, expected) + if box_with_array is not Index: + # Check that operation is actually inplace + tm.assert_equal(orig_rng, expected) + + # ------------------------------------------------------------- + + def test_tdi_ops_attributes(self): + rng = timedelta_range("2 days", periods=5, freq="2D", name="x") + + result = rng + 1 * rng.freq + exp = timedelta_range("4 days", periods=5, freq="2D", name="x") + tm.assert_index_equal(result, exp) + assert result.freq == "2D" + + result = rng - 2 * rng.freq + exp = timedelta_range("-2 days", periods=5, freq="2D", name="x") + tm.assert_index_equal(result, exp) + assert result.freq == "2D" + + result = rng * 2 + exp = timedelta_range("4 days", periods=5, freq="4D", name="x") + tm.assert_index_equal(result, exp) + assert result.freq == "4D" + + result = rng / 2 + exp = timedelta_range("1 days", periods=5, freq="D", name="x") + tm.assert_index_equal(result, exp) + assert result.freq == "D" + + result = -rng + exp = timedelta_range("-2 days", periods=5, freq="-2D", name="x") + tm.assert_index_equal(result, exp) + assert result.freq == "-2D" + + rng = timedelta_range("-2 days", periods=5, freq="D", name="x") + + result = abs(rng) + exp = TimedeltaIndex( + ["2 days", "1 days", "0 days", "1 days", "2 days"], name="x" + ) + tm.assert_index_equal(result, exp) + assert result.freq is None + + +class TestAddSubNaTMasking: + # TODO: parametrize over boxes + + @pytest.mark.parametrize("str_ts", ["1950-01-01", "1980-01-01"]) + def test_tdarr_add_timestamp_nat_masking(self, box_with_array, str_ts): + # GH#17991 checking for overflow-masking with NaT + tdinat = pd.to_timedelta(["24658 days 11:15:00", "NaT"]) + tdobj = tm.box_expected(tdinat, box_with_array) + + ts = Timestamp(str_ts) + ts_variants = [ + ts, + ts.to_pydatetime(), + ts.to_datetime64().astype("datetime64[ns]"), + ts.to_datetime64().astype("datetime64[D]"), + ] + + for variant in ts_variants: + res = tdobj + variant + if box_with_array is DataFrame: + assert res.iloc[1, 1] is NaT + else: + assert res[1] is NaT + + def test_tdi_add_overflow(self): + # See GH#14068 + # preliminary test scalar analogue of vectorized tests below + # TODO: Make raised error message more informative and test + ts = Timestamp("2000").as_unit("ns") + with pytest.raises(OutOfBoundsDatetime, match="10155196800000000000"): + pd.to_timedelta(106580, "D") + ts + with pytest.raises(OutOfBoundsDatetime, match="10155196800000000000"): + ts + pd.to_timedelta(106580, "D") + + _NaT = NaT._value + 1 + td = pd.to_timedelta([106580], "D").as_unit("ns") + msg = "Overflow in int64 addition" + with pytest.raises(OverflowError, match=msg): + td + Timestamp("2000") + with pytest.raises(OverflowError, match=msg): + Timestamp("2000") + td + with pytest.raises(OverflowError, match=msg): + pd.to_timedelta([_NaT]) - Timedelta("1 days") + with pytest.raises(OverflowError, match=msg): + pd.to_timedelta(["5 days", _NaT]) - Timedelta("1 days") + with pytest.raises(OverflowError, match=msg): + ( + pd.to_timedelta([_NaT, "5 days", "1 hours"]) + - pd.to_timedelta(["7 seconds", _NaT, "4 hours"]) + ) + + # These should not overflow! + exp = TimedeltaIndex([NaT], dtype="m8[us]") + result = pd.to_timedelta([NaT]) - Timedelta("1 days") + tm.assert_index_equal(result, exp) + + exp = TimedeltaIndex(["4 days", NaT]) + result = pd.to_timedelta(["5 days", NaT]) - Timedelta("1 days") + tm.assert_index_equal(result, exp) + + exp = TimedeltaIndex([NaT, NaT, "5 hours"]) + result = pd.to_timedelta([NaT, "5 days", "1 hours"]) + pd.to_timedelta( + ["7 seconds", NaT, "4 hours"] + ) + tm.assert_index_equal(result, exp) + + +class TestTimedeltaArraylikeAddSubOps: + # Tests for timedelta64[ns] __add__, __sub__, __radd__, __rsub__ + + def test_sub_nat_retain_unit(self): + ser = pd.to_timedelta(Series(["00:00:01"])).astype("m8[s]") + + result = ser - NaT + expected = Series([NaT], dtype="m8[s]") + tm.assert_series_equal(result, expected) + + # TODO: moved from tests.indexes.timedeltas.test_arithmetic; needs + # parametrization+de-duplication + def test_timedelta_ops_with_missing_values(self): + # setup + s1 = pd.to_timedelta(Series(["00:00:01"])) + s2 = pd.to_timedelta(Series(["00:00:02"])) + + sn = pd.to_timedelta(Series([NaT], dtype="m8[us]")) + + df1 = DataFrame(["00:00:01"]).apply(pd.to_timedelta) + df2 = DataFrame(["00:00:02"]).apply(pd.to_timedelta) + + dfn = DataFrame([NaT._value]).apply(pd.to_timedelta).astype("m8[us]") + + scalar1 = pd.to_timedelta("00:00:01") + scalar2 = pd.to_timedelta("00:00:02") + timedelta_NaT = pd.to_timedelta("NaT") + + actual = scalar1 + scalar1 + assert actual == scalar2 + actual = scalar2 - scalar1 + assert actual == scalar1 + + actual = s1 + s1 + tm.assert_series_equal(actual, s2) + actual = s2 - s1 + tm.assert_series_equal(actual, s1) + + actual = s1 + scalar1 + tm.assert_series_equal(actual, s2) + actual = scalar1 + s1 + tm.assert_series_equal(actual, s2) + actual = s2 - scalar1 + tm.assert_series_equal(actual, s1) + actual = -scalar1 + s2 + tm.assert_series_equal(actual, s1) + + actual = s1 + timedelta_NaT + tm.assert_series_equal(actual, sn) + actual = timedelta_NaT + s1 + tm.assert_series_equal(actual, sn) + actual = s1 - timedelta_NaT + tm.assert_series_equal(actual, sn) + actual = -timedelta_NaT + s1 + tm.assert_series_equal(actual, sn) + + msg = "unsupported operand type" + with pytest.raises(TypeError, match=msg): + s1 + np.nan + with pytest.raises(TypeError, match=msg): + np.nan + s1 + with pytest.raises(TypeError, match=msg): + s1 - np.nan + with pytest.raises(TypeError, match=msg): + -np.nan + s1 + + actual = s1 + NaT + tm.assert_series_equal(actual, sn) + actual = s2 - NaT + tm.assert_series_equal(actual, sn) + + actual = s1 + df1 + tm.assert_frame_equal(actual, df2) + actual = s2 - df1 + tm.assert_frame_equal(actual, df1) + actual = df1 + s1 + tm.assert_frame_equal(actual, df2) + actual = df2 - s1 + tm.assert_frame_equal(actual, df1) + + actual = df1 + df1 + tm.assert_frame_equal(actual, df2) + actual = df2 - df1 + tm.assert_frame_equal(actual, df1) + + actual = df1 + scalar1 + tm.assert_frame_equal(actual, df2) + actual = df2 - scalar1 + tm.assert_frame_equal(actual, df1) + + actual = df1 + timedelta_NaT + tm.assert_frame_equal(actual, dfn) + actual = df1 - timedelta_NaT + tm.assert_frame_equal(actual, dfn) + + msg = "cannot subtract a datelike from|unsupported operand type" + with pytest.raises(TypeError, match=msg): + df1 + np.nan + with pytest.raises(TypeError, match=msg): + df1 - np.nan + + actual = df1 + NaT # NaT is datetime, not timedelta + tm.assert_frame_equal(actual, dfn) + actual = df1 - NaT + tm.assert_frame_equal(actual, dfn) + + # TODO: moved from tests.series.test_operators, needs splitting, cleanup, + # de-duplication, box-parametrization... + def test_operators_timedelta64(self): + # series ops + v1 = pd.date_range("2012-1-1", periods=3, freq="D", unit="ns") + v2 = pd.date_range("2012-1-2", periods=3, freq="D", unit="ns") + rs = Series(v2) - Series(v1) + xp = Series(1e9 * 3600 * 24, rs.index).astype("int64").astype("timedelta64[ns]") + tm.assert_series_equal(rs, xp) + assert rs.dtype == "timedelta64[ns]" + + df = DataFrame({"A": v1}) + td = Series([timedelta(days=i) for i in range(3)], dtype="m8[ns]") + assert td.dtype == "timedelta64[ns]" + + # series on the rhs + result = df["A"] - df["A"].shift() + assert result.dtype == "timedelta64[ns]" + + result = df["A"] + td + assert result.dtype == "M8[ns]" + + # scalar Timestamp on rhs + maxa = df["A"].max() + assert isinstance(maxa, Timestamp) + + resultb = df["A"] - df["A"].max() + assert resultb.dtype == "timedelta64[ns]" + + # timestamp on lhs + result = resultb + df["A"] + values = [Timestamp("20111230"), Timestamp("20120101"), Timestamp("20120103")] + expected = Series(values, dtype="M8[ns]", name="A") + tm.assert_series_equal(result, expected) + + # datetimes on rhs + result = df["A"] - datetime(2001, 1, 1) + expected = Series( + [timedelta(days=4017 + i) for i in range(3)], name="A", dtype="m8[ns]" + ) + tm.assert_series_equal(result, expected) + assert result.dtype == "m8[ns]" + + d = datetime(2001, 1, 1, 3, 4) + resulta = df["A"] - d + assert resulta.dtype == "m8[ns]" + + # roundtrip + resultb = resulta + d + tm.assert_series_equal(df["A"], resultb) + + # timedeltas on rhs + td = timedelta(days=1) + resulta = df["A"] + td + resultb = resulta - td + tm.assert_series_equal(resultb, df["A"]) + assert resultb.dtype == "M8[ns]" + + # roundtrip + td = timedelta(minutes=5, seconds=3) + resulta = df["A"] + td + resultb = resulta - td + tm.assert_series_equal(df["A"], resultb) + assert resultb.dtype == "M8[ns]" + + # inplace + value = rs[2] + np.timedelta64(timedelta(minutes=5, seconds=1)) + rs[2] += np.timedelta64(timedelta(minutes=5, seconds=1)) + assert rs[2] == value + + def test_timedelta64_ops_nat(self): + # GH 11349 + timedelta_series = Series([NaT, Timedelta("1s")]) + nat_series_dtype_timedelta = Series([NaT, NaT], dtype="timedelta64[us]") + single_nat_dtype_timedelta = Series([NaT], dtype="timedelta64[us]") + + # subtraction + tm.assert_series_equal(timedelta_series - NaT, nat_series_dtype_timedelta) + tm.assert_series_equal(-NaT + timedelta_series, nat_series_dtype_timedelta) + + tm.assert_series_equal( + timedelta_series - single_nat_dtype_timedelta, nat_series_dtype_timedelta + ) + tm.assert_series_equal( + -single_nat_dtype_timedelta + timedelta_series, nat_series_dtype_timedelta + ) + + # addition + tm.assert_series_equal( + nat_series_dtype_timedelta + NaT, nat_series_dtype_timedelta + ) + tm.assert_series_equal( + NaT + nat_series_dtype_timedelta, nat_series_dtype_timedelta + ) + + tm.assert_series_equal( + nat_series_dtype_timedelta + single_nat_dtype_timedelta, + nat_series_dtype_timedelta, + ) + tm.assert_series_equal( + single_nat_dtype_timedelta + nat_series_dtype_timedelta, + nat_series_dtype_timedelta, + ) + + tm.assert_series_equal(timedelta_series + NaT, nat_series_dtype_timedelta) + tm.assert_series_equal(NaT + timedelta_series, nat_series_dtype_timedelta) + + tm.assert_series_equal( + timedelta_series + single_nat_dtype_timedelta, nat_series_dtype_timedelta + ) + tm.assert_series_equal( + single_nat_dtype_timedelta + timedelta_series, nat_series_dtype_timedelta + ) + + tm.assert_series_equal( + nat_series_dtype_timedelta + NaT, nat_series_dtype_timedelta + ) + tm.assert_series_equal( + NaT + nat_series_dtype_timedelta, nat_series_dtype_timedelta + ) + + tm.assert_series_equal( + nat_series_dtype_timedelta + single_nat_dtype_timedelta, + nat_series_dtype_timedelta, + ) + tm.assert_series_equal( + single_nat_dtype_timedelta + nat_series_dtype_timedelta, + nat_series_dtype_timedelta, + ) + + # multiplication + tm.assert_series_equal( + nat_series_dtype_timedelta * 1.0, nat_series_dtype_timedelta + ) + tm.assert_series_equal( + 1.0 * nat_series_dtype_timedelta, nat_series_dtype_timedelta + ) + + tm.assert_series_equal(timedelta_series * 1, timedelta_series) + tm.assert_series_equal(1 * timedelta_series, timedelta_series) + + tm.assert_series_equal(timedelta_series * 1.5, Series([NaT, Timedelta("1.5s")])) + tm.assert_series_equal(1.5 * timedelta_series, Series([NaT, Timedelta("1.5s")])) + + tm.assert_series_equal(timedelta_series * np.nan, nat_series_dtype_timedelta) + tm.assert_series_equal(np.nan * timedelta_series, nat_series_dtype_timedelta) + + # division + tm.assert_series_equal(timedelta_series / 2, Series([NaT, Timedelta("0.5s")])) + tm.assert_series_equal(timedelta_series / 2.0, Series([NaT, Timedelta("0.5s")])) + tm.assert_series_equal(timedelta_series / np.nan, nat_series_dtype_timedelta) + + # ------------------------------------------------------------- + # Binary operations td64 arraylike and datetime-like + + @pytest.mark.parametrize("cls", [Timestamp, datetime, np.datetime64]) + def test_td64arr_add_sub_datetimelike_scalar( + self, cls, box_with_array, tz_naive_fixture + ): + # GH#11925, GH#29558, GH#23215 + tz = tz_naive_fixture + + dt_scalar = Timestamp("2012-01-01", tz=tz) + if cls is datetime: + ts = dt_scalar.to_pydatetime() + elif cls is np.datetime64: + if tz_naive_fixture is not None: + pytest.skip(f"{cls} doesn't support {tz_naive_fixture}") + ts = dt_scalar.to_datetime64() + else: + ts = dt_scalar + + tdi = timedelta_range("1 day", periods=3) + expected = pd.date_range("2012-01-02", periods=3, tz=tz) + if tz is not None and not timezones.is_utc(expected.tz): + # Day is no longer preserved by timedelta add/sub in pandas3 because + # it represents Calendar-Day instead of 24h + expected = expected._with_freq(None) + + tdarr = tm.box_expected(tdi, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + tm.assert_equal(ts + tdarr, expected) + tm.assert_equal(tdarr + ts, expected) + + expected2 = pd.date_range("2011-12-31", periods=3, freq="-1D", tz=tz) + if tz is not None and not timezones.is_utc(expected2.tz): + # Day is no longer preserved by timedelta add/sub in pandas3 because + # it represents Calendar-Day instead of 24h + expected2 = expected2._with_freq(None) + expected2 = tm.box_expected(expected2, box_with_array) + + tm.assert_equal(ts - tdarr, expected2) + tm.assert_equal(ts + (-tdarr), expected2) + + msg = "cannot subtract a datelike" + with pytest.raises(TypeError, match=msg): + tdarr - ts + + def test_td64arr_add_datetime64_nat(self, box_with_array): + # GH#23215 + other = np.datetime64("NaT") + + tdi = timedelta_range("1 day", periods=3) + expected = DatetimeIndex(["NaT", "NaT", "NaT"], dtype="M8[us]") + + tdser = tm.box_expected(tdi, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + tm.assert_equal(tdser + other, expected) + tm.assert_equal(other + tdser, expected) + + def test_td64arr_sub_dt64_array(self, box_with_array): + dti = pd.date_range("2016-01-01", periods=3) + tdi = TimedeltaIndex(["-1 Day"] * 3) + dtarr = dti.values + expected = DatetimeIndex(dtarr) - tdi + + tdi = tm.box_expected(tdi, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + msg = "cannot subtract a datelike from" + with pytest.raises(TypeError, match=msg): + tdi - dtarr + + # TimedeltaIndex.__rsub__ + result = dtarr - tdi + tm.assert_equal(result, expected) + + def test_td64arr_add_dt64_array(self, box_with_array): + dti = pd.date_range("2016-01-01", periods=3) + tdi = TimedeltaIndex(["-1 Day"] * 3) + dtarr = dti.values + expected = DatetimeIndex(dtarr) + tdi + + tdi = tm.box_expected(tdi, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = tdi + dtarr + tm.assert_equal(result, expected) + result = dtarr + tdi + tm.assert_equal(result, expected) + + # ------------------------------------------------------------------ + # Invalid __add__/__sub__ operations + + @pytest.mark.parametrize("pi_freq", ["D", "W", "Q", "h"]) + @pytest.mark.parametrize("tdi_freq", [None, "h"]) + def test_td64arr_sub_periodlike( + self, box_with_array, box_with_array2, tdi_freq, pi_freq + ): + # GH#20049 subtracting PeriodIndex should raise TypeError + tdi = TimedeltaIndex(["1 hours", "2 hours"], freq=tdi_freq) + dti = Timestamp("2018-03-07 17:16:40") + tdi + pi = dti.to_period(pi_freq) + per = pi[0] + + tdi = tm.box_expected(tdi, box_with_array) + pi = tm.box_expected(pi, box_with_array2) + msg = "|".join( + [ + "cannot subtract", + "unsupported operand type", + r"bad operand type for unary \-: 'PeriodArray'", + r"Input has different freq=-1h from PeriodArray\(.*\)", + "Cannot add/subtract timedelta-like from PeriodArray", + ] + ) + with pytest.raises(TypeError, match=msg): + tdi - pi + + # GH#13078 subtraction of Period scalar not supported + with pytest.raises(TypeError, match=msg): + tdi - per + + @pytest.mark.parametrize( + "other", + [ + # GH#12624 for str case + "a", + # GH#19123 + 1, + 1.5, + np.array(2), + ], + ) + def test_td64arr_addsub_numeric_scalar_invalid(self, box_with_array, other): + # vector-like others are tested in test_td64arr_add_sub_numeric_arr_invalid + tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]") + tdarr = tm.box_expected(tdser, box_with_array) + + assert_invalid_addsub_type(tdarr, other) + + @pytest.mark.parametrize( + "vec", + [ + np.array([1, 2, 3]), + Index([1, 2, 3]), + Series([1, 2, 3]), + DataFrame([[1, 2, 3]]), + ], + ids=lambda x: type(x).__name__, + ) + def test_td64arr_addsub_numeric_arr_invalid( + self, box_with_array, vec, any_real_numpy_dtype + ): + tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]") + tdarr = tm.box_expected(tdser, box_with_array) + + vector = vec.astype(any_real_numpy_dtype) + assert_invalid_addsub_type(tdarr, vector) + + def test_td64arr_add_sub_int(self, box_with_array, one): + # Variants of `one` for #19012, deprecated GH#22535 + rng = timedelta_range("1 days 09:00:00", freq="h", periods=10) + tdarr = tm.box_expected(rng, box_with_array) + + msg = "Addition/subtraction of integers" + assert_invalid_addsub_type(tdarr, one, msg) + + # TODO: get inplace ops into assert_invalid_addsub_type + with pytest.raises(TypeError, match=msg): + tdarr += one + with pytest.raises(TypeError, match=msg): + tdarr -= one + + def test_td64arr_add_sub_integer_array(self, box_with_array): + # GH#19959, deprecated GH#22535 + # GH#22696 for DataFrame case, check that we don't dispatch to numpy + # implementation, which treats int64 as m8[ns] + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + rng = timedelta_range("1 days 09:00:00", freq="h", periods=3) + tdarr = tm.box_expected(rng, box) + other = tm.box_expected([4, 3, 2], xbox) + + msg = "Addition/subtraction of integers and integer-arrays" + assert_invalid_addsub_type(tdarr, other, msg) + + def test_td64arr_addsub_integer_array_no_freq(self, box_with_array): + # GH#19959 + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + tdi = TimedeltaIndex(["1 Day", "NaT", "3 Hours"]) + tdarr = tm.box_expected(tdi, box) + other = tm.box_expected([14, -1, 16], xbox) + + msg = "Addition/subtraction of integers" + assert_invalid_addsub_type(tdarr, other, msg) + + # ------------------------------------------------------------------ + # Operations with timedelta-like others + + def test_td64arr_add_sub_td64_array(self, box_with_array): + box = box_with_array + dti = pd.date_range("2016-01-01", periods=3) + tdi = dti - dti.shift(1) + tdarr = tdi.values + + expected = 2 * tdi + tdi = tm.box_expected(tdi, box) + expected = tm.box_expected(expected, box) + + result = tdi + tdarr + tm.assert_equal(result, expected) + result = tdarr + tdi + tm.assert_equal(result, expected) + + expected_sub = 0 * tdi + result = tdi - tdarr + tm.assert_equal(result, expected_sub) + result = tdarr - tdi + tm.assert_equal(result, expected_sub) + + def test_td64arr_add_sub_tdi(self, box_with_array, names): + # GH#17250 make sure result dtype is correct + # GH#19043 make sure names are propagated correctly + box = box_with_array + exname = get_expected_name(box, names) + + tdi = TimedeltaIndex(["0 days", "1 day"], name=names[1]) + tdi = np.array(tdi) if box in [tm.to_array, pd.array] else tdi + ser = Series( + [Timedelta(hours=3), Timedelta(hours=4)], name=names[0], dtype="m8[ns]" + ) + expected = Series( + [Timedelta(hours=3), Timedelta(days=1, hours=4)], + name=exname, + dtype="m8[ns]", + ) + + ser = tm.box_expected(ser, box) + expected = tm.box_expected(expected, box) + + result = tdi + ser + tm.assert_equal(result, expected) + assert_dtype(result, "timedelta64[ns]") + + result = ser + tdi + tm.assert_equal(result, expected) + assert_dtype(result, "timedelta64[ns]") + + expected = Series( + [Timedelta(hours=-3), Timedelta(days=1, hours=-4)], + name=exname, + dtype="m8[ns]", + ) + expected = tm.box_expected(expected, box) + + result = tdi - ser + tm.assert_equal(result, expected) + assert_dtype(result, "timedelta64[ns]") + + result = ser - tdi + tm.assert_equal(result, -expected) + assert_dtype(result, "timedelta64[ns]") + + @pytest.mark.parametrize("tdnat", [np.timedelta64("NaT"), NaT]) + def test_td64arr_add_sub_td64_nat(self, box_with_array, tdnat): + # GH#18808, GH#23320 special handling for timedelta64("NaT") + box = box_with_array + tdi = TimedeltaIndex([NaT, Timedelta("1s")]) + expected = TimedeltaIndex(["NaT"] * 2).as_unit("us") + + obj = tm.box_expected(tdi, box) + expected = tm.box_expected(expected, box) + + result = obj + tdnat + tm.assert_equal(result, expected) + result = tdnat + obj + tm.assert_equal(result, expected) + result = obj - tdnat + tm.assert_equal(result, expected) + result = tdnat - obj + tm.assert_equal(result, expected) + + def test_td64arr_add_timedeltalike(self, two_hours, box_with_array): + # only test adding/sub offsets as + is now numeric + # GH#10699 for Tick cases + box = box_with_array + rng = timedelta_range("1 days", "10 days") + expected = timedelta_range("1 days 02:00:00", "10 days 02:00:00", freq="D") + if isinstance(two_hours, Timedelta) and two_hours.unit == "ns": + expected = expected.as_unit("ns") + + rng = tm.box_expected(rng, box) + expected = tm.box_expected(expected, box) + + result = rng + two_hours + tm.assert_equal(result, expected) + + result = two_hours + rng + tm.assert_equal(result, expected) + + def test_td64arr_sub_timedeltalike(self, two_hours, box_with_array): + # only test adding/sub offsets as - is now numeric + # GH#10699 for Tick cases + box = box_with_array + rng = timedelta_range("1 days", "10 days") + expected = timedelta_range("0 days 22:00:00", "9 days 22:00:00") + if isinstance(two_hours, Timedelta) and two_hours.unit == "ns": + expected = expected.as_unit("ns") + + rng = tm.box_expected(rng, box) + expected = tm.box_expected(expected, box) + + result = rng - two_hours + tm.assert_equal(result, expected) + + result = two_hours - rng + tm.assert_equal(result, -expected) + + # ------------------------------------------------------------------ + # __add__/__sub__ with DateOffsets and arrays of DateOffsets + + def test_td64arr_add_sub_offset_index( + self, performance_warning, names, box_with_array + ): + # GH#18849, GH#19744 + box = box_with_array + exname = get_expected_name(box, names) + + tdi = TimedeltaIndex(["1 days 00:00:00", "3 days 04:00:00"], name=names[0]) + other = Index([offsets.Hour(n=1), offsets.Minute(n=-2)], name=names[1]) + other = np.array(other) if box in [tm.to_array, pd.array] else other + + expected = TimedeltaIndex( + [tdi[n] + other[n] for n in range(len(tdi))], freq="infer", name=exname + ) + expected_sub = TimedeltaIndex( + [tdi[n] - other[n] for n in range(len(tdi))], freq="infer", name=exname + ) + + tdi = tm.box_expected(tdi, box) + expected = tm.box_expected(expected, box).astype(object) + expected_sub = tm.box_expected(expected_sub, box).astype(object) + + with tm.assert_produces_warning(performance_warning): + res = tdi + other + tm.assert_equal(res, expected) + + with tm.assert_produces_warning(performance_warning): + res2 = other + tdi + tm.assert_equal(res2, expected) + + with tm.assert_produces_warning(performance_warning): + res_sub = tdi - other + tm.assert_equal(res_sub, expected_sub) + + def test_td64arr_add_sub_offset_array(self, performance_warning, box_with_array): + # GH#18849, GH#18824 + box = box_with_array + tdi = TimedeltaIndex(["1 days 00:00:00", "3 days 04:00:00"]) + other = np.array([offsets.Hour(n=1), offsets.Minute(n=-2)]) + + expected = TimedeltaIndex( + [tdi[n] + other[n] for n in range(len(tdi))], freq="infer" + ) + expected_sub = TimedeltaIndex( + [tdi[n] - other[n] for n in range(len(tdi))], freq="infer" + ) + + tdi = tm.box_expected(tdi, box) + expected = tm.box_expected(expected, box).astype(object) + + with tm.assert_produces_warning(performance_warning): + res = tdi + other + tm.assert_equal(res, expected) + + with tm.assert_produces_warning(performance_warning): + res2 = other + tdi + tm.assert_equal(res2, expected) + + expected_sub = tm.box_expected(expected_sub, box_with_array).astype(object) + with tm.assert_produces_warning(performance_warning): + res_sub = tdi - other + tm.assert_equal(res_sub, expected_sub) + + def test_td64arr_with_offset_series( + self, performance_warning, names, box_with_array + ): + # GH#18849 + box = box_with_array + box2 = Series if box in [Index, tm.to_array, pd.array] else box + exname = get_expected_name(box, names) + + tdi = TimedeltaIndex(["1 days 00:00:00", "3 days 04:00:00"], name=names[0]) + other = Series([offsets.Hour(n=1), offsets.Minute(n=-2)], name=names[1]) + + expected_add = Series( + [tdi[n] + other[n] for n in range(len(tdi))], name=exname, dtype=object + ) + obj = tm.box_expected(tdi, box) + expected_add = tm.box_expected(expected_add, box2).astype(object) + + with tm.assert_produces_warning(performance_warning): + res = obj + other + tm.assert_equal(res, expected_add) + + with tm.assert_produces_warning(performance_warning): + res2 = other + obj + tm.assert_equal(res2, expected_add) + + expected_sub = Series( + [tdi[n] - other[n] for n in range(len(tdi))], name=exname, dtype=object + ) + expected_sub = tm.box_expected(expected_sub, box2).astype(object) + + with tm.assert_produces_warning(performance_warning): + res3 = obj - other + tm.assert_equal(res3, expected_sub) + + @pytest.mark.parametrize("obox", [np.array, Index, Series]) + def test_td64arr_addsub_anchored_offset_arraylike( + self, performance_warning, obox, box_with_array + ): + # GH#18824 + tdi = TimedeltaIndex(["1 days 00:00:00", "3 days 04:00:00"]) + tdi = tm.box_expected(tdi, box_with_array) + + anchored = obox([offsets.MonthEnd(), offsets.Day(n=2)]) + + # addition/subtraction ops with anchored offsets should issue + # a PerformanceWarning and _then_ raise a TypeError. + msg = "has incorrect type|cannot add the type MonthEnd" + with pytest.raises(TypeError, match=msg): + with tm.assert_produces_warning(performance_warning): + tdi + anchored + with pytest.raises(TypeError, match=msg): + with tm.assert_produces_warning(performance_warning): + anchored + tdi + with pytest.raises(TypeError, match=msg): + with tm.assert_produces_warning(performance_warning): + tdi - anchored + with pytest.raises(TypeError, match=msg): + with tm.assert_produces_warning(performance_warning): + anchored - tdi + + # ------------------------------------------------------------------ + # Unsorted + + def test_td64arr_add_sub_object_array(self, performance_warning, box_with_array): + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + tdi = timedelta_range("1 day", periods=3, freq="D") + tdarr = tm.box_expected(tdi, box) + + other = np.array([Timedelta(days=1), offsets.Day(2), Timestamp("2000-01-04")]) + + with tm.assert_produces_warning(performance_warning): + result = tdarr + other + + expected = Index( + [Timedelta(days=2), Timedelta(days=4), Timestamp("2000-01-07")] + ) + expected = tm.box_expected(expected, xbox).astype(object) + tm.assert_equal(result, expected) + + msg = "unsupported operand type|cannot subtract a datelike" + with pytest.raises(TypeError, match=msg): + with tm.assert_produces_warning(performance_warning): + tdarr - other + + with tm.assert_produces_warning(performance_warning): + result = other - tdarr + + expected = Index([Timedelta(0), Timedelta(0), Timestamp("2000-01-01")]) + expected = tm.box_expected(expected, xbox).astype(object) + tm.assert_equal(result, expected) + + +class TestTimedeltaArraylikeMulDivOps: + # Tests for timedelta64[ns] + # __mul__, __rmul__, __div__, __rdiv__, __floordiv__, __rfloordiv__ + + # ------------------------------------------------------------------ + # Multiplication + # organized with scalar others first, then array-like + + def test_td64arr_mul_int(self, box_with_array): + idx = TimedeltaIndex(np.arange(5, dtype="int64")) + idx = tm.box_expected(idx, box_with_array) + + result = idx * 1 + tm.assert_equal(result, idx) + + result = 1 * idx + tm.assert_equal(result, idx) + + def test_td64arr_mul_tdlike_scalar_raises(self, two_hours, box_with_array): + rng = timedelta_range("1 days", "10 days", name="foo") + rng = tm.box_expected(rng, box_with_array) + msg = "|".join( + [ + "argument must be an integer", + "cannot use operands with types dtype", + "Cannot multiply with", + r"unsupported operand type\(s\) for \*", + ] + ) + with pytest.raises(TypeError, match=msg): + rng * two_hours + + def test_tdi_mul_int_array_zerodim(self, box_with_array): + rng5 = np.arange(5, dtype="int64") + idx = TimedeltaIndex(rng5) + expected = TimedeltaIndex(rng5 * 5) + + idx = tm.box_expected(idx, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = idx * np.array(5, dtype="int64") + tm.assert_equal(result, expected) + + def test_tdi_mul_int_array(self, box_with_array): + rng5 = np.arange(5, dtype="int64") + idx = TimedeltaIndex(rng5) + expected = TimedeltaIndex(rng5**2) + + idx = tm.box_expected(idx, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = idx * rng5 + tm.assert_equal(result, expected) + + def test_tdi_mul_int_series(self, box_with_array): + box = box_with_array + xbox = Series if box in [Index, tm.to_array, pd.array] else box + + idx = TimedeltaIndex(np.arange(5, dtype="int64")) + expected = TimedeltaIndex(np.arange(5, dtype="int64") ** 2) + + idx = tm.box_expected(idx, box) + expected = tm.box_expected(expected, xbox) + + result = idx * Series(np.arange(5, dtype="int64")) + tm.assert_equal(result, expected) + + def test_tdi_mul_float_series(self, box_with_array): + box = box_with_array + xbox = Series if box in [Index, tm.to_array, pd.array] else box + + idx = TimedeltaIndex(np.arange(5, dtype="int64")) + idx = tm.box_expected(idx, box) + + rng5f = np.arange(5, dtype="float64") + expected = TimedeltaIndex(rng5f * (rng5f + 1.0)) + expected = tm.box_expected(expected, xbox) + + result = idx * Series(rng5f + 1.0) + tm.assert_equal(result, expected) + + # TODO: Put Series/DataFrame in others? + @pytest.mark.parametrize( + "other", + [ + np.arange(1, 11), + Index(np.arange(1, 11), np.int64), + Index(range(1, 11), np.uint64), + Index(range(1, 11), np.float64), + pd.RangeIndex(1, 11), + ], + ids=lambda x: type(x).__name__, + ) + def test_tdi_rmul_arraylike(self, other, box_with_array): + box = box_with_array + + tdi = TimedeltaIndex(["1 Day"] * 10) + expected = timedelta_range("1 days", "10 days")._with_freq(None) + + tdi = tm.box_expected(tdi, box) + xbox = get_upcast_box(tdi, other) + + expected = tm.box_expected(expected, xbox) + + result = other * tdi + tm.assert_equal(result, expected) + commute = tdi * other + tm.assert_equal(commute, expected) + + def test_td64arr_mul_bool_scalar_raises(self, box_with_array): + # GH#58054 + ser = Series(np.arange(5) * timedelta(hours=1), dtype="m8[ns]") + obj = tm.box_expected(ser, box_with_array) + + msg = r"Cannot multiply 'timedelta64\[ns\]' by bool" + with pytest.raises(TypeError, match=msg): + True * obj + with pytest.raises(TypeError, match=msg): + obj * True + with pytest.raises(TypeError, match=msg): + np.True_ * obj + with pytest.raises(TypeError, match=msg): + obj * np.True_ + + @pytest.mark.parametrize( + "dtype", + [ + bool, + "boolean", + pytest.param("bool[pyarrow]", marks=td.skip_if_no("pyarrow")), + ], + ) + def test_td64arr_mul_bool_raises(self, dtype, box_with_array): + # GH#58054 + ser = Series(np.arange(5) * timedelta(hours=1), dtype="m8[ns]") + obj = tm.box_expected(ser, box_with_array) + + other = Series(np.arange(5) < 0.5, dtype=dtype) + other = tm.box_expected(other, box_with_array) + + msg = r"Cannot multiply 'timedelta64\[ns\]' by bool" + with pytest.raises(TypeError, match=msg): + obj * other + + msg2 = msg.replace("rmul", "mul") + if dtype == "bool[pyarrow]": + # We go through ArrowEA.__mul__ which gives a different message + msg2 = ( + r"operation 'mul' not supported for dtype 'bool\[pyarrow\]' " + r"with dtype 'timedelta64\[ns\]'" + ) + with pytest.raises(TypeError, match=msg2): + other * obj + + @pytest.mark.parametrize( + "dtype", + [ + "Int64", + "Float64", + pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")), + ], + ) + def test_td64arr_mul_masked(self, dtype, box_with_array): + ser = Series(np.arange(5) * timedelta(hours=1), dtype="m8[ns]") + obj = tm.box_expected(ser, box_with_array) + + other = Series(np.arange(5), dtype=dtype) + other = tm.box_expected(other, box_with_array) + + expected = Series([Timedelta(hours=n**2) for n in range(5)], dtype="m8[ns]") + expected = tm.box_expected(expected, box_with_array) + if dtype == "int64[pyarrow]": + expected = expected.astype("duration[ns][pyarrow]") + + result = obj * other + tm.assert_equal(result, expected) + result = other * obj + tm.assert_equal(result, expected) + + # ------------------------------------------------------------------ + # __div__, __rdiv__ + + def test_td64arr_div_nat_invalid(self, box_with_array): + # don't allow division by NaT (maybe could in the future) + rng = timedelta_range("1 days", "10 days", name="foo") + rng = tm.box_expected(rng, box_with_array) + + with pytest.raises(TypeError, match="unsupported operand type"): + rng / NaT + with pytest.raises(TypeError, match="Cannot divide NaTType by"): + NaT / rng + + dt64nat = np.datetime64("NaT", "ns") + msg = "|".join( + [ + # 'divide' on npdev as of 2021-12-18 + "ufunc '(true_divide|divide)' cannot use operands", + "cannot perform __r?truediv__", + "Cannot divide datetime64 by TimedeltaArray", + ] + ) + with pytest.raises(TypeError, match=msg): + rng / dt64nat + with pytest.raises(TypeError, match=msg): + dt64nat / rng + + def test_td64arr_div_td64nat(self, box_with_array): + # GH#23829 + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + rng = timedelta_range("1 days", "10 days") + rng = tm.box_expected(rng, box) + + other = np.timedelta64("NaT") + + expected = np.array([np.nan] * 10) + expected = tm.box_expected(expected, xbox) + + result = rng / other + tm.assert_equal(result, expected) + + result = other / rng + tm.assert_equal(result, expected) + + def test_td64arr_div_int(self, box_with_array): + idx = TimedeltaIndex(np.arange(5, dtype="int64")) + idx = tm.box_expected(idx, box_with_array) + + result = idx / 1 + tm.assert_equal(result, idx) + + with pytest.raises(TypeError, match="Cannot divide"): + # GH#23829 + 1 / idx + + def test_td64arr_div_tdlike_scalar(self, two_hours, box_with_array): + # GH#20088, GH#22163 ensure DataFrame returns correct dtype + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + rng = timedelta_range("1 days", "10 days", name="foo") + expected = Index((np.arange(10) + 1) * 12, dtype=np.float64, name="foo") + + rng = tm.box_expected(rng, box) + expected = tm.box_expected(expected, xbox) + + result = rng / two_hours + tm.assert_equal(result, expected) + + result = two_hours / rng + expected = 1 / expected + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("m", [1, 3, 10]) + @pytest.mark.parametrize("unit", ["D", "h", "m", "s", "ms", "us", "ns"]) + def test_td64arr_div_td64_scalar(self, m, unit, box_with_array): + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + ser = Series([Timedelta(days=59)] * 3) + ser[2] = np.nan + flat = ser + ser = tm.box_expected(ser, box) + + # op + expected = Series([x / np.timedelta64(m, unit) for x in flat]) + expected = tm.box_expected(expected, xbox) + result = ser / np.timedelta64(m, unit) + tm.assert_equal(result, expected) + + # reverse op + expected = Series([Timedelta(np.timedelta64(m, unit)) / x for x in flat]) + expected = tm.box_expected(expected, xbox) + result = np.timedelta64(m, unit) / ser + tm.assert_equal(result, expected) + + def test_td64arr_div_tdlike_scalar_with_nat(self, two_hours, box_with_array): + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + rng = TimedeltaIndex(["1 days", NaT, "2 days"], name="foo") + expected = Index([12, np.nan, 24], dtype=np.float64, name="foo") + + rng = tm.box_expected(rng, box) + expected = tm.box_expected(expected, xbox) + + result = rng / two_hours + tm.assert_equal(result, expected) + + result = two_hours / rng + expected = 1 / expected + tm.assert_equal(result, expected) + + def test_td64arr_div_td64_ndarray(self, box_with_array): + # GH#22631 + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + rng = TimedeltaIndex(["1 days", NaT, "2 days"]) + expected = Index([12, np.nan, 24], dtype=np.float64) + + rng = tm.box_expected(rng, box) + expected = tm.box_expected(expected, xbox) + + other = np.array([2, 4, 2], dtype="m8[h]") + result = rng / other + tm.assert_equal(result, expected) + + result = rng / tm.box_expected(other, box) + tm.assert_equal(result, expected) + + result = rng / other.astype(object) + tm.assert_equal(result, expected.astype(object)) + + result = rng / list(other) + tm.assert_equal(result, expected) + + # reversed op + expected = 1 / expected + result = other / rng + tm.assert_equal(result, expected) + + result = tm.box_expected(other, box) / rng + tm.assert_equal(result, expected) + + result = other.astype(object) / rng + tm.assert_equal(result, expected) + + result = list(other) / rng + tm.assert_equal(result, expected) + + def test_tdarr_div_length_mismatch(self, box_with_array): + rng = TimedeltaIndex(["1 days", NaT, "2 days"]) + mismatched = [1, 2, 3, 4] + + rng = tm.box_expected(rng, box_with_array) + msg = "Cannot divide vectors|Unable to coerce to Series" + for obj in [mismatched, mismatched[:2]]: + # one shorter, one longer + for other in [obj, np.array(obj), Index(obj)]: + with pytest.raises(ValueError, match=msg): + rng / other + with pytest.raises(ValueError, match=msg): + other / rng + + def test_td64_div_object_mixed_result(self, box_with_array): + # Case where we having a NaT in the result inseat of timedelta64("NaT") + # is misleading + orig = timedelta_range("1 Day", periods=3).insert(1, NaT) + tdi = tm.box_expected(orig, box_with_array, transpose=False) + + other = np.array([orig[0], 1.5, 2.0, orig[2]], dtype=object) + other = tm.box_expected(other, box_with_array, transpose=False) + + res = tdi / other + + expected = Index([1.0, np.timedelta64("NaT", "us"), orig[0], 1.5], dtype=object) + expected = tm.box_expected(expected, box_with_array, transpose=False) + if isinstance(expected, NumpyExtensionArray): + expected = expected.to_numpy() + tm.assert_equal(res, expected) + if box_with_array is DataFrame: + # We have an np.timedelta64(NaT), not pd.NaT + assert isinstance(res.iloc[1, 0], np.timedelta64) + + res = tdi // other + + expected = Index([1, np.timedelta64("NaT", "us"), orig[0], 1], dtype=object) + expected = tm.box_expected(expected, box_with_array, transpose=False) + if isinstance(expected, NumpyExtensionArray): + expected = expected.to_numpy() + tm.assert_equal(res, expected) + if box_with_array is DataFrame: + # We have an np.timedelta64(NaT), not pd.NaT + assert isinstance(res.iloc[1, 0], np.timedelta64) + + # ------------------------------------------------------------------ + # __floordiv__, __rfloordiv__ + + @pytest.mark.skipif(WASM, reason="no fp exception support in wasm") + def test_td64arr_floordiv_td64arr_with_nat(self, box_with_array): + # GH#35529 + box = box_with_array + xbox = np.ndarray if box is pd.array else box + + left = Series([1000, 222330, 30], dtype="timedelta64[ns]") + right = Series([1000, 222330, None], dtype="timedelta64[ns]") + + left = tm.box_expected(left, box) + right = tm.box_expected(right, box) + + expected = np.array([1.0, 1.0, np.nan], dtype=np.float64) + expected = tm.box_expected(expected, xbox) + + with tm.maybe_produces_warning( + RuntimeWarning, box is pd.array, check_stacklevel=False + ): + result = left // right + + tm.assert_equal(result, expected) + + # case that goes through __rfloordiv__ with arraylike + with tm.maybe_produces_warning( + RuntimeWarning, box is pd.array, check_stacklevel=False + ): + result = np.asarray(left) // right + tm.assert_equal(result, expected) + + @pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") + def test_td64arr_floordiv_tdscalar(self, box_with_array, scalar_td): + # GH#18831, GH#19125 + box = box_with_array + xbox = np.ndarray if box is pd.array else box + td = Timedelta("5m3s") # i.e. (scalar_td - 1sec) / 2 + + td1 = Series([td, td, NaT], dtype="m8[ns]") + td1 = tm.box_expected(td1, box, transpose=False) + + expected = Series([0, 0, np.nan]) + expected = tm.box_expected(expected, xbox, transpose=False) + + result = td1 // scalar_td + tm.assert_equal(result, expected) + + # Reversed op + expected = Series([2, 2, np.nan]) + expected = tm.box_expected(expected, xbox, transpose=False) + + result = scalar_td // td1 + tm.assert_equal(result, expected) + + # same thing buts let's be explicit about calling __rfloordiv__ + result = td1.__rfloordiv__(scalar_td) + tm.assert_equal(result, expected) + + def test_td64arr_floordiv_int(self, box_with_array): + idx = TimedeltaIndex(np.arange(5, dtype="int64")) + idx = tm.box_expected(idx, box_with_array) + result = idx // 1 + tm.assert_equal(result, idx) + + pattern = "floor_divide cannot use operands|Cannot divide int by Timedelta*" + with pytest.raises(TypeError, match=pattern): + 1 // idx + + # ------------------------------------------------------------------ + # mod, divmod + # TODO: operations with timedelta-like arrays, numeric arrays, + # reversed ops + + def test_td64arr_mod_tdscalar( + self, performance_warning, box_with_array, three_days + ): + tdi = timedelta_range("1 Day", "9 days") + tdarr = tm.box_expected(tdi, box_with_array) + + expected = TimedeltaIndex(["1 Day", "2 Days", "0 Days"] * 3) + expected = tm.box_expected(expected, box_with_array) + + if isinstance(three_days, offsets.Day): + msg = "unsupported operand type" + with pytest.raises(TypeError, match=msg): + tdarr % three_days + with pytest.raises(TypeError, match=msg): + divmod(tdarr, three_days) + with pytest.raises(TypeError, match=msg): + tdarr // three_days + return + + result = tdarr % three_days + tm.assert_equal(result, expected) + + if box_with_array is DataFrame and isinstance(three_days, pd.DateOffset): + # TODO: making expected be object here a result of DataFrame.__divmod__ + # being defined in a naive way that does not dispatch to the underlying + # array's __divmod__ + expected = expected.astype(object) + else: + performance_warning = False + + with tm.assert_produces_warning(performance_warning): + result = divmod(tdarr, three_days) + + tm.assert_equal(result[1], expected) + tm.assert_equal(result[0], tdarr // three_days) + + def test_td64arr_mod_int(self, box_with_array): + tdi = timedelta_range("1 ns", "10 ns", periods=10) + tdarr = tm.box_expected(tdi, box_with_array) + + expected = TimedeltaIndex(["1 ns", "0 ns"] * 5) + expected = tm.box_expected(expected, box_with_array) + + result = tdarr % 2 + tm.assert_equal(result, expected) + + msg = "Cannot divide int by" + with pytest.raises(TypeError, match=msg): + 2 % tdarr + + result = divmod(tdarr, 2) + tm.assert_equal(result[1], expected) + tm.assert_equal(result[0], tdarr // 2) + + def test_td64arr_rmod_tdscalar(self, box_with_array, three_days): + tdi = timedelta_range("1 Day", "9 days") + tdarr = tm.box_expected(tdi, box_with_array) + + expected = ["0 Days", "1 Day", "0 Days"] + ["3 Days"] * 6 + expected = TimedeltaIndex(expected) + expected = tm.box_expected(expected, box_with_array) + + if isinstance(three_days, offsets.Day): + msg = "Cannot divide Day by TimedeltaArray" + with pytest.raises(TypeError, match=msg): + three_days % tdarr + return + + result = three_days % tdarr + tm.assert_equal(result, expected) + + result = divmod(three_days, tdarr) + tm.assert_equal(result[1], expected) + tm.assert_equal(result[0], three_days // tdarr) + + # ------------------------------------------------------------------ + # Operations with invalid others + + def test_td64arr_mul_tdscalar_invalid(self, box_with_array, scalar_td): + td1 = Series([timedelta(minutes=5, seconds=3)] * 3) + td1.iloc[2] = np.nan + + td1 = tm.box_expected(td1, box_with_array) + + # check that we are getting a TypeError + # with 'operate' (from core/ops.py) for the ops that are not + # defined + pattern = "operate|unsupported|cannot|not supported" + with pytest.raises(TypeError, match=pattern): + td1 * scalar_td + with pytest.raises(TypeError, match=pattern): + scalar_td * td1 + + def test_td64arr_mul_too_short_raises(self, box_with_array): + idx = TimedeltaIndex(np.arange(5, dtype="int64")) + idx = tm.box_expected(idx, box_with_array) + msg = "|".join( + [ + "cannot use operands with types dtype", + "Cannot multiply with unequal lengths", + "Unable to coerce to Series", + ] + ) + with pytest.raises(TypeError, match=msg): + # length check before dtype check + idx * idx[:3] + with pytest.raises(ValueError, match=msg): + idx * np.array([1, 2]) + + def test_td64arr_mul_td64arr_raises(self, box_with_array): + idx = TimedeltaIndex(np.arange(5, dtype="int64")) + idx = tm.box_expected(idx, box_with_array) + msg = "cannot use operands with types dtype" + with pytest.raises(TypeError, match=msg): + idx * idx + + # ------------------------------------------------------------------ + # Operations with numeric others + + def test_td64arr_mul_numeric_scalar(self, box_with_array, one): + # GH#4521 + # divide/multiply by integers + tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]") + expected = Series(["-59 Days", "-59 Days", "NaT"], dtype="timedelta64[ns]") + + tdser = tm.box_expected(tdser, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = tdser * (-one) + tm.assert_equal(result, expected) + result = (-one) * tdser + tm.assert_equal(result, expected) + + expected = Series(["118 Days", "118 Days", "NaT"], dtype="timedelta64[ns]") + expected = tm.box_expected(expected, box_with_array) + + result = tdser * (2 * one) + tm.assert_equal(result, expected) + result = (2 * one) * tdser + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("two", [2, 2.0, np.array(2), np.array(2.0)]) + def test_td64arr_div_numeric_scalar(self, box_with_array, two): + # GH#4521 + # divide/multiply by integers + tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]") + expected = Series(["29.5D", "29.5D", "NaT"], dtype="timedelta64[ns]") + + tdser = tm.box_expected(tdser, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = tdser / two + tm.assert_equal(result, expected) + + with pytest.raises(TypeError, match="Cannot divide"): + two / tdser + + @pytest.mark.parametrize("two", [2, 2.0, np.array(2), np.array(2.0)]) + def test_td64arr_floordiv_numeric_scalar(self, box_with_array, two): + tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]") + expected = Series(["29.5D", "29.5D", "NaT"], dtype="timedelta64[ns]") + + tdser = tm.box_expected(tdser, box_with_array) + expected = tm.box_expected(expected, box_with_array) + + result = tdser // two + tm.assert_equal(result, expected) + + with pytest.raises(TypeError, match="Cannot divide"): + two // tdser + + @pytest.mark.parametrize( + "klass", + [np.array, Index, Series], + ids=lambda x: x.__name__, + ) + def test_td64arr_rmul_numeric_array( + self, + box_with_array, + klass, + any_real_numpy_dtype, + ): + # GH#4521 + # divide/multiply by integers + + vector = klass([20, 30, 40]) + tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]") + vector = vector.astype(any_real_numpy_dtype) + + expected = Series(["1180 Days", "1770 Days", "NaT"], dtype="timedelta64[ns]") + + tdser = tm.box_expected(tdser, box_with_array) + xbox = get_upcast_box(tdser, vector) + + expected = tm.box_expected(expected, xbox) + + result = tdser * vector + tm.assert_equal(result, expected) + + result = vector * tdser + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "klass", + [np.array, Index, Series], + ids=lambda x: x.__name__, + ) + def test_td64arr_div_numeric_array( + self, box_with_array, klass, any_real_numpy_dtype + ): + # GH#4521 + # divide/multiply by integers + + vector = klass([20, 30, 40]) + tdser = Series(["59 Days", "59 Days", "NaT"], dtype="m8[ns]") + vector = vector.astype(any_real_numpy_dtype) + + expected = Series(["2.95D", "1D 23h 12m", "NaT"], dtype="timedelta64[ns]") + + tdser = tm.box_expected(tdser, box_with_array) + xbox = get_upcast_box(tdser, vector) + expected = tm.box_expected(expected, xbox) + + result = tdser / vector + tm.assert_equal(result, expected) + + pattern = "|".join( + [ + "true_divide'? cannot use operands", + "cannot perform __div__", + "cannot perform __truediv__", + "unsupported operand", + "Cannot divide", + "ufunc 'divide' cannot use operands with types", + ] + ) + with pytest.raises(TypeError, match=pattern): + vector / tdser + + result = tdser / vector.astype(object) + if box_with_array is DataFrame: + expected = [tdser.iloc[0, n] / vector[n] for n in range(len(vector))] + expected = tm.box_expected(expected, xbox).astype(object) + expected[2] = expected[2].fillna(np.timedelta64("NaT", "ns")) + else: + expected = [tdser[n] / vector[n] for n in range(len(tdser))] + expected = [ + x if x is not NaT else np.timedelta64("NaT", "ns") for x in expected + ] + if xbox is tm.to_array: + expected = tm.to_array(expected).astype(object) + else: + expected = xbox(expected, dtype=object) + + tm.assert_equal(result, expected) + + with pytest.raises(TypeError, match=pattern): + vector.astype(object) / tdser + + def test_td64arr_mul_int_series(self, box_with_array, names): + # GH#19042 test for correct name attachment + box = box_with_array + exname = get_expected_name(box, names) + + tdi = TimedeltaIndex( + ["0days", "1day", "2days", "3days", "4days"], name=names[0] + ) + # TODO: Should we be parametrizing over types for `ser` too? + ser = Series([0, 1, 2, 3, 4], dtype=np.int64, name=names[1]) + + expected = Series( + ["0days", "1day", "4days", "9days", "16days"], + dtype="timedelta64[us]", + name=exname, + ) + + tdi = tm.box_expected(tdi, box) + xbox = get_upcast_box(tdi, ser) + + expected = tm.box_expected(expected, xbox) + + result = ser * tdi + tm.assert_equal(result, expected) + + result = tdi * ser + tm.assert_equal(result, expected) + + # TODO: Should we be parametrizing over types for `ser` too? + def test_float_series_rdiv_td64arr(self, box_with_array, names): + # GH#19042 test for correct name attachment + box = box_with_array + tdi = TimedeltaIndex( + ["0days", "1day", "2days", "3days", "4days"], name=names[0] + ) + ser = Series([1.5, 3, 4.5, 6, 7.5], dtype=np.float64, name=names[1]) + + xname = names[2] if box not in [tm.to_array, pd.array] else names[1] + expected = Series( + [tdi[n] / ser[n] for n in range(len(ser))], + dtype="timedelta64[us]", + name=xname, + ) + + tdi = tm.box_expected(tdi, box) + xbox = get_upcast_box(tdi, ser) + expected = tm.box_expected(expected, xbox) + + result = ser.__rtruediv__(tdi) + if box is DataFrame: + assert result is NotImplemented + else: + tm.assert_equal(result, expected) + + def test_td64arr_all_nat_div_object_dtype_numeric(self, box_with_array): + # GH#39750 make sure we infer the result as td64 + tdi = TimedeltaIndex([NaT, NaT], dtype="m8[ns]") + + left = tm.box_expected(tdi, box_with_array) + right = np.array([2, 2.0], dtype=object) + + tdnat = np.timedelta64("NaT", "ns") + expected = Index([tdnat] * 2, dtype=object) + if box_with_array is not Index: + expected = tm.box_expected(expected, box_with_array).astype(object) + if box_with_array in [Series, DataFrame]: + expected = expected.fillna(tdnat) # GH#18463 + + result = left / right + tm.assert_equal(result, expected) + + result = left // right + tm.assert_equal(result, expected) + + +class TestTimedelta64ArrayLikeArithmetic: + # Arithmetic tests for timedelta64[ns] vectors fully parametrized over + # DataFrame/Series/TimedeltaIndex/TimedeltaArray. Ideally all arithmetic + # tests will eventually end up here. + + def test_td64arr_pow_invalid(self, scalar_td, box_with_array): + td1 = Series([timedelta(minutes=5, seconds=3)] * 3) + td1.iloc[2] = np.nan + + td1 = tm.box_expected(td1, box_with_array) + + # check that we are getting a TypeError + # with 'operate' (from core/ops.py) for the ops that are not + # defined + pattern = "operate|unsupported|cannot|not supported" + with pytest.raises(TypeError, match=pattern): + scalar_td**td1 + + with pytest.raises(TypeError, match=pattern): + td1**scalar_td + + +def test_add_timestamp_to_timedelta(): + # GH: 35897 + timestamp = Timestamp("2021-01-01") + result = timestamp + timedelta_range("0s", "1s", periods=31) + expected = DatetimeIndex( + [ + timestamp + + ( + pd.to_timedelta("0.033333s") * i + + pd.to_timedelta("0.000001s") * divmod(i, 3)[0] + ) + for i in range(31) + ] + ) + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/arrays/__init__.py b/pandas/tests/arrays/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/arrays/masked_shared.py b/pandas/tests/arrays/masked_shared.py new file mode 100644 index 0000000000000000000000000000000000000000..545b14af2c98bcdfeea2969d859ca097e7e0db8b --- /dev/null +++ b/pandas/tests/arrays/masked_shared.py @@ -0,0 +1,155 @@ +""" +Tests shared by MaskedArray subclasses. +""" + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.tests.extension.base import BaseOpsUtil + + +class ComparisonOps(BaseOpsUtil): + def _compare_other(self, data, op, other): + # array + result = pd.Series(op(data, other)) + expected = pd.Series(op(data._data, other), dtype="boolean") + + # fill the nan locations + expected[data._mask] = pd.NA + + tm.assert_series_equal(result, expected) + + # series + ser = pd.Series(data) + result = op(ser, other) + + # Set nullable dtype here to avoid upcasting when setting to pd.NA below + expected = op(pd.Series(data._data), other).astype("boolean") + + # fill the nan locations + expected[data._mask] = pd.NA + + tm.assert_series_equal(result, expected) + + # subclass will override to parametrize 'other' + def test_scalar(self, other, comparison_op, dtype): + op = comparison_op + left = pd.array([1, 0, None], dtype=dtype) + + result = op(left, other) + + if other is pd.NA: + expected = pd.array([None, None, None], dtype="boolean") + else: + values = op(left._data, other) + expected = pd.arrays.BooleanArray(values, left._mask, copy=True) + tm.assert_extension_array_equal(result, expected) + + # ensure we haven't mutated anything inplace + result[0] = pd.NA + tm.assert_extension_array_equal(left, pd.array([1, 0, None], dtype=dtype)) + + +class NumericOps: + # Shared by IntegerArray and FloatingArray, not BooleanArray + + def test_searchsorted_nan(self, dtype): + # The base class casts to object dtype, for which searchsorted returns + # 0 from the left and 10 from the right. + arr = pd.array(range(10), dtype=dtype) + + assert arr.searchsorted(np.nan, side="left") == 10 + assert arr.searchsorted(np.nan, side="right") == 10 + + def test_no_shared_mask(self, data): + result = data + 1 + assert not tm.shares_memory(result, data) + + def test_array(self, comparison_op, dtype): + op = comparison_op + + left = pd.array([0, 1, 2, None, None, None], dtype=dtype) + right = pd.array([0, 1, None, 0, 1, None], dtype=dtype) + + result = op(left, right) + values = op(left._data, right._data) + mask = left._mask | right._mask + + expected = pd.arrays.BooleanArray(values, mask) + tm.assert_extension_array_equal(result, expected) + + # ensure we haven't mutated anything inplace + result[0] = pd.NA + tm.assert_extension_array_equal( + left, pd.array([0, 1, 2, None, None, None], dtype=dtype) + ) + tm.assert_extension_array_equal( + right, pd.array([0, 1, None, 0, 1, None], dtype=dtype) + ) + + def test_compare_with_booleanarray(self, comparison_op, dtype): + op = comparison_op + + left = pd.array([True, False, None] * 3, dtype="boolean") + right = pd.array([0] * 3 + [1] * 3 + [None] * 3, dtype=dtype) + other = pd.array([False] * 3 + [True] * 3 + [None] * 3, dtype="boolean") + + expected = op(left, other) + result = op(left, right) + tm.assert_extension_array_equal(result, expected) + + # reversed op + expected = op(other, left) + result = op(right, left) + tm.assert_extension_array_equal(result, expected) + + def test_compare_to_string(self, dtype): + # GH#28930 + ser = pd.Series([1, None], dtype=dtype) + result = ser == "a" + expected = pd.Series([False, pd.NA], dtype="boolean") + + tm.assert_series_equal(result, expected) + + def test_ufunc_with_out(self, dtype): + arr = pd.array([1, 2, 3], dtype=dtype) + arr2 = pd.array([1, 2, pd.NA], dtype=dtype) + + mask = arr == arr + mask2 = arr2 == arr2 + + result = np.zeros(3, dtype=bool) + result |= mask + # If MaskedArray.__array_ufunc__ handled "out" appropriately, + # `result` should still be an ndarray. + assert isinstance(result, np.ndarray) + assert result.all() + + # result |= mask worked because mask could be cast losslessly to + # boolean ndarray. mask2 can't, so this raises + result = np.zeros(3, dtype=bool) + msg = "Specify an appropriate 'na_value' for this dtype" + with pytest.raises(ValueError, match=msg): + result |= mask2 + + # addition + res = np.add(arr, arr2) + expected = pd.array([2, 4, pd.NA], dtype=dtype) + tm.assert_extension_array_equal(res, expected) + + # when passing out=arr, we will modify 'arr' inplace. + res = np.add(arr, arr2, out=arr) + assert res is arr + tm.assert_extension_array_equal(res, expected) + tm.assert_extension_array_equal(arr, expected) + + def test_mul_td64_array(self, dtype): + # GH#45622 + arr = pd.array([1, 2, pd.NA], dtype=dtype) + other = np.arange(3, dtype=np.int64).view("m8[ns]") + + result = arr * other + expected = pd.array([pd.Timedelta(0), pd.Timedelta(2), pd.NaT]) + tm.assert_extension_array_equal(result, expected) diff --git a/pandas/tests/arrays/test_array.py b/pandas/tests/arrays/test_array.py new file mode 100644 index 0000000000000000000000000000000000000000..a02926dd5e158cd914a3eff0bc061a01cabea323 --- /dev/null +++ b/pandas/tests/arrays/test_array.py @@ -0,0 +1,539 @@ +import datetime +import decimal +import zoneinfo + +import numpy as np +import pytest + +from pandas._config import using_string_dtype + +import pandas as pd +import pandas._testing as tm +from pandas.api.extensions import register_extension_dtype +from pandas.arrays import ( + BooleanArray, + DatetimeArray, + FloatingArray, + IntegerArray, + IntervalArray, + SparseArray, + TimedeltaArray, +) +from pandas.core.arrays import ( + NumpyExtensionArray, + period_array, +) +from pandas.tests.extension.decimal import ( + DecimalArray, + DecimalDtype, + to_decimal, +) + + +@pytest.mark.parametrize("dtype_unit", ["M8[h]", "M8[m]", "m8[h]"]) +def test_dt64_array(dtype_unit): + # GH#53817 + dtype_var = np.dtype(dtype_unit) + msg = ( + r"datetime64 and timedelta64 dtype resolutions other than " + r"'s', 'ms', 'us', and 'ns' are no longer supported." + ) + with pytest.raises(ValueError, match=msg): + pd.array([], dtype=dtype_var) + + +@pytest.mark.parametrize( + "data, dtype, expected", + [ + # Basic NumPy defaults. + ([], None, FloatingArray._from_sequence([], dtype="Float64")), + ([1, 2], None, IntegerArray._from_sequence([1, 2], dtype="Int64")), + ([1, 2], object, NumpyExtensionArray(np.array([1, 2], dtype=object))), + ( + [1, 2], + np.dtype("float32"), + NumpyExtensionArray(np.array([1.0, 2.0], dtype=np.dtype("float32"))), + ), + ( + np.array([], dtype=object), + None, + NumpyExtensionArray(np.array([], dtype=object)), + ), + ( + np.array([1, 2], dtype="int64"), + None, + IntegerArray._from_sequence([1, 2], dtype="Int64"), + ), + ( + np.array([1.0, 2.0], dtype="float64"), + None, + FloatingArray._from_sequence([1.0, 2.0], dtype="Float64"), + ), + # String alias passes through to NumPy + ([1, 2], "float32", NumpyExtensionArray(np.array([1, 2], dtype="float32"))), + ([1, 2], "int64", NumpyExtensionArray(np.array([1, 2], dtype=np.int64))), + # GH#44715 FloatingArray does not support float16, so fall + # back to NumpyExtensionArray + ( + np.array([1, 2], dtype=np.float16), + None, + NumpyExtensionArray(np.array([1, 2], dtype=np.float16)), + ), + # idempotency with e.g. pd.array(pd.array([1, 2], dtype="int64")) + ( + NumpyExtensionArray(np.array([1, 2], dtype=np.int32)), + None, + NumpyExtensionArray(np.array([1, 2], dtype=np.int32)), + ), + # Period alias + ( + [pd.Period("2000", "D"), pd.Period("2001", "D")], + "Period[D]", + period_array(["2000", "2001"], freq="D"), + ), + # Period dtype + ( + [pd.Period("2000", "D")], + pd.PeriodDtype("D"), + period_array(["2000"], freq="D"), + ), + # Datetime (naive) + ( + [1, 2], + np.dtype("datetime64[ns]"), + DatetimeArray._from_sequence( + np.array([1, 2], dtype="M8[ns]"), dtype="M8[ns]" + ), + ), + ( + [1, 2], + np.dtype("datetime64[s]"), + DatetimeArray._from_sequence( + np.array([1, 2], dtype="M8[s]"), dtype="M8[s]" + ), + ), + ( + np.array([1, 2], dtype="datetime64[ns]"), + None, + DatetimeArray._from_sequence( + np.array([1, 2], dtype="M8[ns]"), dtype="M8[ns]" + ), + ), + ( + pd.DatetimeIndex(["2000", "2001"]), + np.dtype("datetime64[ns]"), + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[ns]"), + ), + ( + pd.DatetimeIndex(["2000", "2001"]), + None, + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[us]"), + ), + ( + ["2000", "2001"], + np.dtype("datetime64[ns]"), + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[ns]"), + ), + ( + [pd.NaT, pd.NaT], + None, + DatetimeArray._from_sequence([pd.NaT, pd.NaT], dtype="M8[s]"), + ), + # Datetime (tz-aware) + ( + ["2000", "2001"], + pd.DatetimeTZDtype(tz="CET"), + DatetimeArray._from_sequence( + ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz="CET") + ), + ), + # Timedelta + ( + ["1h", "2h"], + np.dtype("timedelta64[ns]"), + TimedeltaArray._from_sequence(["1h", "2h"], dtype="m8[ns]"), + ), + ( + pd.TimedeltaIndex(["1h", "2h"]), + np.dtype("timedelta64[ns]"), + TimedeltaArray._from_sequence(["1h", "2h"], dtype="m8[ns]"), + ), + ( + np.array([1, 2], dtype="m8[s]"), + np.dtype("timedelta64[s]"), + TimedeltaArray._from_sequence( + np.array([1, 2], dtype="m8[s]"), dtype="m8[s]" + ), + ), + ( + pd.TimedeltaIndex(["1h", "2h"]), + None, + TimedeltaArray._from_sequence(["1h", "2h"], dtype="m8[us]"), + ), + ( + # preserve non-nano, i.e. don't cast to NumpyExtensionArray + TimedeltaArray._simple_new( + np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]") + ), + None, + TimedeltaArray._simple_new( + np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]") + ), + ), + ( + # preserve non-nano, i.e. don't cast to NumpyExtensionArray + TimedeltaArray._simple_new( + np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]") + ), + np.dtype("m8[s]"), + TimedeltaArray._simple_new( + np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]") + ), + ), + # Category + (["a", "b"], "category", pd.Categorical(["a", "b"])), + ( + ["a", "b"], + pd.CategoricalDtype(None, ordered=True), + pd.Categorical(["a", "b"], ordered=True), + ), + # Interval + ( + [pd.Interval(1, 2), pd.Interval(3, 4)], + "interval", + IntervalArray.from_tuples([(1, 2), (3, 4)]), + ), + # Sparse + ([0, 1], "Sparse[int64]", SparseArray([0, 1], dtype="int64")), + # IntegerNA + ([1, None], "Int16", pd.array([1, None], dtype="Int16")), + ( + pd.Series([1, 2]), + None, + NumpyExtensionArray(np.array([1, 2], dtype=np.int64)), + ), + # String + ( + ["a", None], + "string", + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", None], dtype=pd.StringDtype()), + ), + ( + ["a", None], + "str", + pd.StringDtype(na_value=np.nan) + .construct_array_type() + ._from_sequence(["a", None], dtype=pd.StringDtype(na_value=np.nan)) + if using_string_dtype() + else NumpyExtensionArray(np.array(["a", "None"])), + ), + ( + ["a", None], + pd.StringDtype(), + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", None], dtype=pd.StringDtype()), + ), + ( + ["a", None], + pd.StringDtype(na_value=np.nan), + pd.StringDtype(na_value=np.nan) + .construct_array_type() + ._from_sequence(["a", None], dtype=pd.StringDtype(na_value=np.nan)), + ), + ( + # numpy array with string dtype + np.array(["a", "b"], dtype=str), + pd.StringDtype(), + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", "b"], dtype=pd.StringDtype()), + ), + ( + # numpy array with string dtype + np.array(["a", "b"], dtype=str), + pd.StringDtype(na_value=np.nan), + pd.StringDtype(na_value=np.nan) + .construct_array_type() + ._from_sequence(["a", "b"], dtype=pd.StringDtype(na_value=np.nan)), + ), + # Boolean + ( + [True, None], + "boolean", + BooleanArray._from_sequence([True, None], dtype="boolean"), + ), + ( + [True, None], + pd.BooleanDtype(), + BooleanArray._from_sequence([True, None], dtype="boolean"), + ), + # Index + (pd.Index([1, 2]), None, NumpyExtensionArray(np.array([1, 2], dtype=np.int64))), + # Series[EA] returns the EA + ( + pd.Series(pd.Categorical(["a", "b"], categories=["a", "b", "c"])), + None, + pd.Categorical(["a", "b"], categories=["a", "b", "c"]), + ), + # "3rd party" EAs work + ([decimal.Decimal(0), decimal.Decimal(1)], "decimal", to_decimal([0, 1])), + # pass an ExtensionArray, but a different dtype + ( + period_array(["2000", "2001"], freq="D"), + "category", + pd.Categorical([pd.Period("2000", "D"), pd.Period("2001", "D")]), + ), + # Complex + ( + np.array([complex(1), complex(2)], dtype=np.complex128), + None, + NumpyExtensionArray( + np.array([complex(1), complex(2)], dtype=np.complex128) + ), + ), + ], +) +def test_array(data, dtype, expected): + result = pd.array(data, dtype=dtype) + tm.assert_equal(result, expected) + + +def test_array_copy(): + a = np.array([1, 2]) + # default is to copy + b = pd.array(a, dtype=a.dtype) + assert not tm.shares_memory(a, b) + + # copy=True + b = pd.array(a, dtype=a.dtype, copy=True) + assert not tm.shares_memory(a, b) + + # copy=False + b = pd.array(a, dtype=a.dtype, copy=False) + assert tm.shares_memory(a, b) + + +@pytest.mark.parametrize( + "data, expected", + [ + # period + ( + [pd.Period("2000", "D"), pd.Period("2001", "D")], + period_array(["2000", "2001"], freq="D"), + ), + # interval + ([pd.Interval(0, 1), pd.Interval(1, 2)], IntervalArray.from_breaks([0, 1, 2])), + # datetime + ( + [pd.Timestamp("2000").as_unit("s"), pd.Timestamp("2001").as_unit("s")], + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[s]"), + ), + ( + [datetime.datetime(2000, 1, 1), datetime.datetime(2001, 1, 1)], + DatetimeArray._from_sequence(["2000", "2001"], dtype="M8[us]"), + ), + ( + np.array([1, 2], dtype="M8[ns]"), + DatetimeArray._from_sequence(np.array([1, 2], dtype="M8[ns]")), + ), + ( + np.array([1, 2], dtype="M8[us]"), + DatetimeArray._simple_new( + np.array([1, 2], dtype="M8[us]"), dtype=np.dtype("M8[us]") + ), + ), + # datetimetz + ( + [ + pd.Timestamp("2000", tz="CET").as_unit("s"), + pd.Timestamp("2001", tz="CET").as_unit("s"), + ], + DatetimeArray._from_sequence( + ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz="CET", unit="s") + ), + ), + ( + [ + datetime.datetime( + 2000, 1, 1, tzinfo=zoneinfo.ZoneInfo("Europe/Berlin") + ), + datetime.datetime( + 2001, 1, 1, tzinfo=zoneinfo.ZoneInfo("Europe/Berlin") + ), + ], + DatetimeArray._from_sequence( + ["2000", "2001"], + dtype=pd.DatetimeTZDtype( + tz=zoneinfo.ZoneInfo("Europe/Berlin"), unit="us" + ), + ), + ), + # timedelta + ( + [pd.Timedelta("1h"), pd.Timedelta("2h")], + TimedeltaArray._from_sequence(["1h", "2h"], dtype="m8[us]"), + ), + ( + np.array([1, 2], dtype="m8[ns]"), + TimedeltaArray._from_sequence( + np.array([1, 2], dtype="m8[ns]"), dtype=np.dtype("m8[ns]") + ), + ), + ( + np.array([1, 2], dtype="m8[us]"), + TimedeltaArray._from_sequence( + np.array([1, 2], dtype="m8[us]"), dtype=np.dtype("m8[us]") + ), + ), + # integer + ([1, 2], IntegerArray._from_sequence([1, 2], dtype="Int64")), + ([1, None], IntegerArray._from_sequence([1, None], dtype="Int64")), + ([1, pd.NA], IntegerArray._from_sequence([1, pd.NA], dtype="Int64")), + ([1, np.nan], IntegerArray._from_sequence([1, pd.NA], dtype="Int64")), + # float + ([0.1, 0.2], FloatingArray._from_sequence([0.1, 0.2], dtype="Float64")), + ([0.1, None], FloatingArray._from_sequence([0.1, pd.NA], dtype="Float64")), + ([0.1, np.nan], FloatingArray._from_sequence([0.1, pd.NA], dtype="Float64")), + ([0.1, pd.NA], FloatingArray._from_sequence([0.1, pd.NA], dtype="Float64")), + # integer-like float + ([1.0, 2.0], FloatingArray._from_sequence([1.0, 2.0], dtype="Float64")), + ([1.0, None], FloatingArray._from_sequence([1.0, pd.NA], dtype="Float64")), + ([1.0, np.nan], FloatingArray._from_sequence([1.0, pd.NA], dtype="Float64")), + ([1.0, pd.NA], FloatingArray._from_sequence([1.0, pd.NA], dtype="Float64")), + # mixed-integer-float + ([1, 2.0], FloatingArray._from_sequence([1.0, 2.0], dtype="Float64")), + ( + [1, np.nan, 2.0], + FloatingArray._from_sequence([1.0, None, 2.0], dtype="Float64"), + ), + # string + ( + ["a", "b"], + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", "b"], dtype=pd.StringDtype()), + ), + ( + ["a", None], + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", None], dtype=pd.StringDtype()), + ), + ( + # numpy array with string dtype + np.array(["a", "b"], dtype=str), + pd.StringDtype() + .construct_array_type() + ._from_sequence(["a", "b"], dtype=pd.StringDtype()), + ), + # Boolean + ([True, False], BooleanArray._from_sequence([True, False], dtype="boolean")), + ([True, None], BooleanArray._from_sequence([True, None], dtype="boolean")), + ], +) +def test_array_inference(data, expected): + result = pd.array(data) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + # mix of frequencies + [pd.Period("2000", "D"), pd.Period("2001", "Y")], + # mix of closed + [pd.Interval(0, 1, closed="left"), pd.Interval(1, 2, closed="right")], + # Mix of timezones + [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000", tz="UTC")], + # Mix of tz-aware and tz-naive + [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000")], + np.array([pd.Timestamp("2000"), pd.Timestamp("2000", tz="CET")]), + ], +) +def test_array_inference_fails(data): + result = pd.array(data) + expected = NumpyExtensionArray(np.array(data, dtype=object)) + tm.assert_extension_array_equal(result, expected) + + +@pytest.mark.parametrize("data", [np.array(0)]) +def test_nd_raises(data): + with pytest.raises(ValueError, match="NumpyExtensionArray must be 1-dimensional"): + pd.array(data, dtype="int64") + + +def test_scalar_raises(): + with pytest.raises(ValueError, match="Cannot pass scalar '1'"): + pd.array(1) + + +def test_dataframe_raises(): + # GH#51167 don't accidentally cast to StringArray by doing inference on columns + df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + msg = "Cannot pass DataFrame to 'pandas.array'" + with pytest.raises(TypeError, match=msg): + pd.array(df) + + +def test_bounds_check(): + # GH21796 + with pytest.raises( + TypeError, match=r"cannot safely cast non-equivalent int(32|64) to uint16" + ): + pd.array([-1, 2, 3], dtype="UInt16") + + +# --------------------------------------------------------------------------- +# A couple dummy classes to ensure that Series and Indexes are unboxed before +# getting to the EA classes. + + +@register_extension_dtype +class DecimalDtype2(DecimalDtype): + name = "decimal2" + + def construct_array_type(self): + """ + Return the array type associated with this dtype. + + Returns + ------- + type + """ + return DecimalArray2 + + +class DecimalArray2(DecimalArray): + @classmethod + def _from_sequence(cls, scalars, *, dtype=None, copy=False): + if isinstance(scalars, (pd.Series, pd.Index)): + raise TypeError("scalars should not be of type pd.Series or pd.Index") + + return super()._from_sequence(scalars, dtype=dtype, copy=copy) + + +def test_array_unboxes(index_or_series): + box = index_or_series + + data = box([decimal.Decimal("1"), decimal.Decimal("2")]) + dtype = DecimalDtype2() + # make sure it works + with pytest.raises( + TypeError, match="scalars should not be of type pd.Series or pd.Index" + ): + DecimalArray2._from_sequence(data, dtype=dtype) + + result = pd.array(data, dtype="decimal2") + expected = DecimalArray2._from_sequence(data.values, dtype=dtype) + tm.assert_equal(result, expected) + + +def test_array_to_numpy_na(): + # GH#40638 + arr = pd.array([pd.NA, 1], dtype="string[python]") + result = arr.to_numpy(na_value=True, dtype=bool) + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6b3491a74693312e8c6c6e6f92faee127bc357 --- /dev/null +++ b/pandas/tests/arrays/test_datetimelike.py @@ -0,0 +1,1390 @@ +from __future__ import annotations + +import re +import warnings + +import numpy as np +import pytest + +from pandas._libs import ( + NaT, + Timestamp, +) +from pandas._libs.tslibs import to_offset +from pandas.compat.numpy import np_version_gt2 + +from pandas.core.dtypes.dtypes import PeriodDtype + +import pandas as pd +from pandas import ( + DatetimeIndex, + Period, + PeriodIndex, + TimedeltaIndex, +) +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + NumpyExtensionArray, + PeriodArray, + TimedeltaArray, +) + + +# TODO: more freq variants +@pytest.fixture(params=["D", "B", "W", "ME", "QE", "YE"]) +def freqstr(request): + """Fixture returning parametrized frequency in string format.""" + return request.param + + +@pytest.fixture +def period_index(freqstr): + """ + A fixture to provide PeriodIndex objects with different frequencies. + + Most PeriodArray behavior is already tested in PeriodIndex tests, + so here we just test that the PeriodArray behavior matches + the PeriodIndex behavior. + """ + # TODO: non-monotone indexes; NaTs, different start dates + with warnings.catch_warnings(): + # suppress deprecation of Period[B] + warnings.filterwarnings( + "ignore", message="Period with BDay freq", category=FutureWarning + ) + freqstr = PeriodDtype(to_offset(freqstr))._freqstr + pi = pd.period_range(start=Timestamp("2000-01-01"), periods=100, freq=freqstr) + return pi + + +@pytest.fixture +def datetime_index(freqstr): + """ + A fixture to provide DatetimeIndex objects with different frequencies. + + Most DatetimeArray behavior is already tested in DatetimeIndex tests, + so here we just test that the DatetimeArray behavior matches + the DatetimeIndex behavior. + """ + # TODO: non-monotone indexes; NaTs, different start dates, timezones + dti = pd.date_range( + start=Timestamp("2000-01-01"), periods=100, freq=freqstr, unit="ns" + ) + return dti + + +@pytest.fixture +def timedelta_index(): + """ + A fixture to provide TimedeltaIndex objects with different frequencies. + Most TimedeltaArray behavior is already tested in TimedeltaIndex tests, + so here we just test that the TimedeltaArray behavior matches + the TimedeltaIndex behavior. + """ + # TODO: flesh this out + return TimedeltaIndex(["1 Day", "3 Hours", "NaT"]) + + +class SharedTests: + index_cls: type[DatetimeIndex | PeriodIndex | TimedeltaIndex] + + @pytest.fixture + def arr1d(self): + """Fixture returning DatetimeArray with daily frequency.""" + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, freq="D") + else: + arr = self.index_cls(data, freq="D")._data + return arr + + def test_compare_len1_raises(self, arr1d): + # make sure we raise when comparing with different lengths, specific + # to the case where one has length-1, which numpy would broadcast + arr = arr1d + idx = self.index_cls(arr) + + with pytest.raises(ValueError, match="Lengths must match"): + arr == arr[:1] + + # test the index classes while we're at it, GH#23078 + with pytest.raises(ValueError, match="Lengths must match"): + idx <= idx[[0]] + + @pytest.mark.parametrize( + "result", + [ + pd.date_range("2020", periods=3), + pd.date_range("2020", periods=3, tz="UTC"), + pd.timedelta_range("0 days", periods=3), + pd.period_range("2020Q1", periods=3, freq="Q"), + ], + ) + def test_compare_with_Categorical(self, result): + expected = pd.Categorical(result) + assert all(result == expected) + assert not any(result != expected) + + @pytest.mark.parametrize("reverse", [True, False]) + @pytest.mark.parametrize("as_index", [True, False]) + def test_compare_categorical_dtype(self, arr1d, as_index, reverse, ordered): + other = pd.Categorical(arr1d, ordered=ordered) + if as_index: + other = pd.CategoricalIndex(other) + + left, right = arr1d, other + if reverse: + left, right = right, left + + ones = np.ones(arr1d.shape, dtype=bool) + zeros = ~ones + + result = left == right + tm.assert_numpy_array_equal(result, ones) + + result = left != right + tm.assert_numpy_array_equal(result, zeros) + + if not reverse and not as_index: + # Otherwise Categorical raises TypeError bc it is not ordered + # TODO: we should probably get the same behavior regardless? + result = left < right + tm.assert_numpy_array_equal(result, zeros) + + result = left <= right + tm.assert_numpy_array_equal(result, ones) + + result = left > right + tm.assert_numpy_array_equal(result, zeros) + + result = left >= right + tm.assert_numpy_array_equal(result, ones) + + def test_take(self): + data = np.arange(100, dtype="i8") * 24 * 3600 * 10**9 + np.random.default_rng(2).shuffle(data) + + if self.array_cls is PeriodArray: + arr = PeriodArray(data, dtype="period[D]") + else: + arr = self.index_cls(data)._data + idx = self.index_cls._simple_new(arr) + + takers = [1, 4, 94] + result = arr.take(takers) + expected = idx.take(takers) + + tm.assert_index_equal(self.index_cls(result), expected) + + takers = np.array([1, 4, 94]) + result = arr.take(takers) + expected = idx.take(takers) + + tm.assert_index_equal(self.index_cls(result), expected) + + @pytest.mark.parametrize("fill_value", [2, 2.0, Timestamp(2021, 1, 1, 12).time]) + def test_take_fill_raises(self, fill_value, arr1d): + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + arr1d.take([0, 1], allow_fill=True, fill_value=fill_value) + + def test_take_fill(self, arr1d): + arr = arr1d + + result = arr.take([-1, 1], allow_fill=True, fill_value=None) + assert result[0] is NaT + + result = arr.take([-1, 1], allow_fill=True, fill_value=np.nan) + assert result[0] is NaT + + result = arr.take([-1, 1], allow_fill=True, fill_value=NaT) + assert result[0] is NaT + + @pytest.mark.filterwarnings( + "ignore:Period with BDay freq is deprecated:FutureWarning" + ) + def test_take_fill_str(self, arr1d): + # Cast str fill_value matching other fill_value-taking methods + result = arr1d.take([-1, 1], allow_fill=True, fill_value=str(arr1d[-1])) + expected = arr1d[[-1, 1]] + tm.assert_equal(result, expected) + + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + arr1d.take([-1, 1], allow_fill=True, fill_value="foo") + + def test_concat_same_type(self, arr1d): + arr = arr1d + idx = self.index_cls(arr) + idx = idx.insert(0, NaT) + arr = arr1d + + result = arr._concat_same_type([arr[:-1], arr[1:], arr]) + arr2 = arr.astype(object) + expected = self.index_cls(np.concatenate([arr2[:-1], arr2[1:], arr2])) + + tm.assert_index_equal(self.index_cls(result), expected) + + def test_unbox_scalar(self, arr1d): + result = arr1d._unbox_scalar(arr1d[0]) + expected = arr1d._ndarray.dtype.type + assert isinstance(result, expected) + + result = arr1d._unbox_scalar(NaT) + assert isinstance(result, expected) + + msg = f"'value' should be a {self.scalar_type.__name__}." + with pytest.raises(ValueError, match=msg): + arr1d._unbox_scalar("foo") + + def test_check_compatible_with(self, arr1d): + arr1d._check_compatible_with(arr1d[0]) + arr1d._check_compatible_with(arr1d[:1]) + arr1d._check_compatible_with(NaT) + + def test_scalar_from_string(self, arr1d): + result = arr1d._scalar_from_string(str(arr1d[0])) + assert result == arr1d[0] + + def test_reduce_invalid(self, arr1d): + msg = "does not support operation 'not a method'" + with pytest.raises(TypeError, match=msg): + arr1d._reduce("not a method") + + @pytest.mark.parametrize("method", ["pad", "backfill"]) + def test_fillna_method_doesnt_change_orig(self, method): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, dtype="period[D]") + else: + dtype = "M8[ns]" if self.array_cls is DatetimeArray else "m8[ns]" + arr = self.array_cls._from_sequence(data, dtype=np.dtype(dtype)) + arr[4] = NaT + + fill_value = arr[3] if method == "pad" else arr[5] + + result = arr._pad_or_backfill(method=method) + assert result[4] == fill_value + + # check that the original was not changed + assert arr[4] is NaT + + def test_searchsorted(self): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, dtype="period[D]") + else: + dtype = "M8[ns]" if self.array_cls is DatetimeArray else "m8[ns]" + arr = self.array_cls._from_sequence(data, dtype=np.dtype(dtype)) + + # scalar + result = arr.searchsorted(arr[1]) + assert result == 1 + + result = arr.searchsorted(arr[2], side="right") + assert result == 3 + + # own-type + result = arr.searchsorted(arr[1:3]) + expected = np.array([1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + result = arr.searchsorted(arr[1:3], side="right") + expected = np.array([2, 3], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + # GH#29884 match numpy convention on whether NaT goes + # at the end or the beginning + result = arr.searchsorted(NaT) + assert result == 10 + + @pytest.mark.parametrize("box", [None, "index", "series"]) + def test_searchsorted_castable_strings( + self, arr1d, box, string_storage, using_infer_string + ): + arr = arr1d + if box is None: + pass + elif box == "index": + # Test the equivalent Index.searchsorted method while we're here + arr = self.index_cls(arr) + else: + # Test the equivalent Series.searchsorted method while we're here + arr = pd.Series(arr) + + # scalar + result = arr.searchsorted(str(arr[1])) + assert result == 1 + + result = arr.searchsorted(str(arr[2]), side="right") + assert result == 3 + + result = arr.searchsorted([str(x) for x in arr[1:3]]) + expected = np.array([1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + with pytest.raises( + TypeError, + match=re.escape( + f"value should be a '{arr1d._scalar_type.__name__}', 'NaT', " + "or array of those. Got 'str' instead." + ), + ): + arr.searchsorted("foo") + + msg = re.escape( + f"value should be a '{arr1d._scalar_type.__name__}', 'NaT', " + "or array of those. Got str array instead." + ) + if not using_infer_string: + msg = msg.replace("str", "string") + with pd.option_context("string_storage", string_storage): + with pytest.raises( + TypeError, + match=msg, + ): + arr.searchsorted([str(arr[1]), "baz"]) + + def test_getitem_near_implementation_bounds(self): + # We only check tz-naive for DTA bc the bounds are slightly different + # for other tzs + i8vals = np.asarray([NaT._value + n for n in range(1, 5)], dtype="i8") + if self.array_cls is PeriodArray: + arr = self.array_cls(i8vals, dtype="period[ns]") + else: + arr = self.index_cls(i8vals, freq="ns")._data + arr[0] # should not raise OutOfBoundsDatetime + + index = pd.Index(arr) + index[0] # should not raise OutOfBoundsDatetime + + ser = pd.Series(arr) + ser[0] # should not raise OutOfBoundsDatetime + + def test_getitem_2d(self, arr1d): + # 2d slicing on a 1D array + expected = type(arr1d)._simple_new( + arr1d._ndarray[:, np.newaxis], dtype=arr1d.dtype + ) + result = arr1d[:, np.newaxis] + tm.assert_equal(result, expected) + + # Lookup on a 2D array + arr2d = expected + expected = type(arr2d)._simple_new(arr2d._ndarray[:3, 0], dtype=arr2d.dtype) + result = arr2d[:3, 0] + tm.assert_equal(result, expected) + + # Scalar lookup + result = arr2d[-1, 0] + expected = arr1d[-1] + assert result == expected + + def test_iter_2d(self, arr1d): + data2d = arr1d._ndarray[:3, np.newaxis] + arr2d = type(arr1d)._simple_new(data2d, dtype=arr1d.dtype) + result = list(arr2d) + assert len(result) == 3 + for x in result: + assert isinstance(x, type(arr1d)) + assert x.ndim == 1 + assert x.dtype == arr1d.dtype + + def test_repr_2d(self, arr1d): + data2d = arr1d._ndarray[:3, np.newaxis] + arr2d = type(arr1d)._simple_new(data2d, dtype=arr1d.dtype) + + result = repr(arr2d) + + if isinstance(arr2d, TimedeltaArray): + expected = ( + f"<{type(arr2d).__name__}>\n" + "[\n" + f"['{arr1d[0]._repr_base()}'],\n" + f"['{arr1d[1]._repr_base()}'],\n" + f"['{arr1d[2]._repr_base()}']\n" + "]\n" + f"Shape: (3, 1), dtype: {arr1d.dtype}" + ) + else: + expected = ( + f"<{type(arr2d).__name__}>\n" + "[\n" + f"['{arr1d[0]}'],\n" + f"['{arr1d[1]}'],\n" + f"['{arr1d[2]}']\n" + "]\n" + f"Shape: (3, 1), dtype: {arr1d.dtype}" + ) + + assert result == expected + + def test_setitem(self): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, dtype="period[D]") + else: + arr = self.index_cls(data, freq="D")._data + + arr[0] = arr[1] + expected = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + expected[0] = expected[1] + + tm.assert_numpy_array_equal(arr.asi8, expected) + + arr[:2] = arr[-2:] + expected[:2] = expected[-2:] + tm.assert_numpy_array_equal(arr.asi8, expected) + + def test_setitem_list_of_nats(self, arr1d): + # GH#63420 + arr1d[:] = [NaT] * len(arr1d) + assert arr1d.isna().all() + + @pytest.mark.parametrize( + "box", + [ + pd.Index, + pd.Series, + np.array, + list, + NumpyExtensionArray, + ], + ) + def test_setitem_object_dtype(self, box, arr1d): + expected = arr1d.copy()[::-1] + if expected.dtype.kind in ["m", "M"]: + expected = expected._with_freq(None) + + vals = expected + if box is list: + vals = list(vals) + elif box is np.array: + # if we do np.array(x).astype(object) then dt64 and td64 cast to ints + vals = np.array(vals.astype(object)) + elif box is NumpyExtensionArray: + vals = box(np.asarray(vals, dtype=object)) + else: + vals = box(vals).astype(object) + + arr1d[:] = vals + + tm.assert_equal(arr1d, expected) + + def test_setitem_strs(self, arr1d): + # Check that we parse strs in both scalar and listlike + + # Setting list-like of strs + expected = arr1d.copy() + expected[[0, 1]] = arr1d[-2:] + + result = arr1d.copy() + result[:2] = [str(x) for x in arr1d[-2:]] + tm.assert_equal(result, expected) + + # Same thing but now for just a scalar str + expected = arr1d.copy() + expected[0] = arr1d[-1] + + result = arr1d.copy() + result[0] = str(arr1d[-1]) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("as_index", [True, False]) + def test_setitem_categorical(self, arr1d, as_index): + expected = arr1d.copy()[::-1] + if not isinstance(expected, PeriodArray): + expected = expected._with_freq(None) + + cat = pd.Categorical(arr1d) + if as_index: + cat = pd.CategoricalIndex(cat) + + arr1d[:] = cat[::-1] + + tm.assert_equal(arr1d, expected) + + def test_setitem_raises(self, arr1d): + arr = arr1d[:10] + val = arr[0] + + with pytest.raises(IndexError, match="index 12 is out of bounds"): + arr[12] = val + + with pytest.raises(TypeError, match="value should be a.* 'object'"): + arr[0] = object() + + msg = "cannot set using a list-like indexer with a different length" + with pytest.raises(ValueError, match=msg): + # GH#36339 + arr[[]] = [arr[1]] + + msg = "cannot set using a slice indexer with a different length than" + with pytest.raises(ValueError, match=msg): + # GH#36339 + arr[1:1] = arr[:3] + + @pytest.mark.parametrize("box", [list, np.array, pd.Index, pd.Series]) + def test_setitem_numeric_raises(self, arr1d, box): + # We dont case e.g. int64 to our own dtype for setitem + + msg = ( + f"value should be a '{arr1d._scalar_type.__name__}', " + "'NaT', or array of those. Got" + ) + with pytest.raises(TypeError, match=msg): + arr1d[:2] = box([0, 1]) + + with pytest.raises(TypeError, match=msg): + arr1d[:2] = box([0.0, 1.0]) + + def test_inplace_arithmetic(self): + # GH#24115 check that iadd and isub are actually in-place + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + if self.array_cls is PeriodArray: + arr = self.array_cls(data, dtype="period[D]") + else: + arr = self.index_cls(data, freq="D")._data + + expected = arr + pd.Timedelta(days=1) + arr += pd.Timedelta(days=1) + tm.assert_equal(arr, expected) + + expected = arr - pd.Timedelta(days=1) + arr -= pd.Timedelta(days=1) + tm.assert_equal(arr, expected) + + def test_shift_fill_int_deprecated(self, arr1d): + # GH#31971, enforced in 2.0 + with pytest.raises(TypeError, match="value should be a"): + arr1d.shift(1, fill_value=1) + + def test_median(self, arr1d): + arr = arr1d + if len(arr) % 2 == 0: + # make it easier to define `expected` + arr = arr[:-1] + + expected = arr[len(arr) // 2] + + result = arr.median() + assert type(result) is type(expected) + assert result == expected + + arr[len(arr) // 2] = NaT + if not isinstance(expected, Period): + expected = arr[len(arr) // 2 - 1 : len(arr) // 2 + 2].mean() + + assert arr.median(skipna=False) is NaT + + result = arr.median() + assert type(result) is type(expected) + assert result == expected + + assert arr[:0].median() is NaT + assert arr[:0].median(skipna=False) is NaT + + # 2d Case + arr2 = arr.reshape(-1, 1) + + result = arr2.median(axis=None) + assert type(result) is type(expected) + assert result == expected + + assert arr2.median(axis=None, skipna=False) is NaT + + result = arr2.median(axis=0) + expected2 = type(arr)._from_sequence([expected], dtype=arr.dtype) + tm.assert_equal(result, expected2) + + result = arr2.median(axis=0, skipna=False) + expected2 = type(arr)._from_sequence([NaT], dtype=arr.dtype) + tm.assert_equal(result, expected2) + + result = arr2.median(axis=1) + tm.assert_equal(result, arr) + + result = arr2.median(axis=1, skipna=False) + tm.assert_equal(result, arr) + + def test_from_integer_array(self): + arr = np.array([1, 2, 3], dtype=np.int64) + data = pd.array(arr, dtype="Int64") + if self.array_cls is PeriodArray: + expected = self.array_cls(arr, dtype=self.example_dtype) + result = self.array_cls(data, dtype=self.example_dtype) + else: + expected = self.array_cls._from_sequence(arr, dtype=self.example_dtype) + result = self.array_cls._from_sequence(data, dtype=self.example_dtype) + + tm.assert_extension_array_equal(result, expected) + + +class TestDatetimeArray(SharedTests): + index_cls = DatetimeIndex + array_cls = DatetimeArray + scalar_type = Timestamp + example_dtype = "M8[ns]" + + @pytest.fixture + def arr1d(self, tz_naive_fixture, freqstr): + """ + Fixture returning DatetimeArray with parametrized frequency and + timezones + """ + tz = tz_naive_fixture + dti = pd.date_range( + "2016-01-01 01:01:00", periods=5, freq=freqstr, tz=tz, unit="ns" + ) + dta = dti._data + return dta + + def test_round(self, arr1d): + # GH#24064 + dti = self.index_cls(arr1d) + + result = dti.round(freq="2min") + expected = dti - pd.Timedelta(minutes=1) + expected = expected._with_freq(None) + tm.assert_index_equal(result, expected) + + dta = dti._data + result = dta.round(freq="2min") + expected = expected._data._with_freq(None) + tm.assert_datetime_array_equal(result, expected) + + def test_array_interface(self, datetime_index): + arr = datetime_index._data + copy_false = None if np_version_gt2 else False + + # default asarray gives the same underlying data (for tz naive) + result = np.asarray(arr) + expected = arr._ndarray + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, copy=copy_false) + assert result is expected + tm.assert_numpy_array_equal(result, expected) + + # specifying M8[ns] gives the same result as default + result = np.asarray(arr, dtype="datetime64[ns]") + expected = arr._ndarray + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, dtype="datetime64[ns]", copy=copy_false) + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, dtype="datetime64[ns]") + if not np_version_gt2: + # TODO: GH 57739 + assert result is not expected + tm.assert_numpy_array_equal(result, expected) + + # to object dtype + result = np.asarray(arr, dtype=object) + expected = np.array(list(arr), dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # to other dtype always copies + result = np.asarray(arr, dtype="int64") + assert result is not arr.asi8 + assert not np.may_share_memory(arr, result) + expected = arr.asi8.copy() + tm.assert_numpy_array_equal(result, expected) + + # other dtypes handled by numpy + for dtype in ["float64", str]: + result = np.asarray(arr, dtype=dtype) + expected = np.asarray(arr).astype(dtype) + tm.assert_numpy_array_equal(result, expected) + + def test_array_object_dtype(self, arr1d): + # GH#23524 + arr = arr1d + dti = self.index_cls(arr1d) + + expected = np.array(list(dti)) + + result = np.array(arr, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # also test the DatetimeIndex method while we're at it + result = np.array(dti, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + def test_array_tz(self, arr1d): + # GH#23524 + arr = arr1d + dti = self.index_cls(arr1d, copy=False) + copy_false = None if np_version_gt2 else False + + expected = dti.asi8.view("M8[ns]") + result = np.array(arr, dtype="M8[ns]") + tm.assert_numpy_array_equal(result, expected) + + result = np.array(arr, dtype="datetime64[ns]") + tm.assert_numpy_array_equal(result, expected) + + # check that we are not making copies when setting copy=copy_false + result = np.array(arr, dtype="M8[ns]", copy=copy_false) + assert result.base is expected.base + assert result.base is not None + result = np.array(arr, dtype="datetime64[ns]", copy=copy_false) + assert result.base is expected.base + assert result.base is not None + + def test_array_i8_dtype(self, arr1d): + arr = arr1d + dti = self.index_cls(arr1d) + copy_false = None if np_version_gt2 else False + + expected = dti.asi8 + result = np.array(arr, dtype="i8") + tm.assert_numpy_array_equal(result, expected) + + result = np.array(arr, dtype=np.int64) + tm.assert_numpy_array_equal(result, expected) + + # check that we are still making copies when setting copy=copy_false + result = np.array(arr, dtype="i8", copy=copy_false) + assert result.base is not expected.base + assert result.base is None + + def test_from_array_keeps_base(self): + # Ensure that DatetimeArray._ndarray.base isn't lost. + arr = np.array(["2000-01-01", "2000-01-02"], dtype="M8[ns]") + dta = DatetimeArray._from_sequence(arr, dtype=arr.dtype) + + assert dta._ndarray is arr + dta = DatetimeArray._from_sequence(arr[:0], dtype=arr.dtype) + assert dta._ndarray.base is arr + + def test_from_dti(self, arr1d): + arr = arr1d + dti = self.index_cls(arr1d) + assert list(dti) == list(arr) + + # Check that Index.__new__ knows what to do with DatetimeArray + dti2 = pd.Index(arr) + assert isinstance(dti2, DatetimeIndex) + assert list(dti2) == list(arr) + + def test_astype_object(self, arr1d): + arr = arr1d + dti = self.index_cls(arr1d) + + asobj = arr.astype("O") + assert isinstance(asobj, np.ndarray) + assert asobj.dtype == "O" + assert list(asobj) == list(dti) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_to_period(self, datetime_index, freqstr): + dti = datetime_index + arr = dti._data + + freqstr = PeriodDtype(to_offset(freqstr))._freqstr + expected = dti.to_period(freq=freqstr) + result = arr.to_period(freq=freqstr) + assert isinstance(result, PeriodArray) + + tm.assert_equal(result, expected._data) + + def test_to_period_2d(self, arr1d): + arr2d = arr1d.reshape(1, -1) + + warn = None if arr1d.tz is None else UserWarning + with tm.assert_produces_warning(warn, match="will drop timezone information"): + result = arr2d.to_period("D") + expected = arr1d.to_period("D").reshape(1, -1) + tm.assert_period_array_equal(result, expected) + + @pytest.mark.parametrize("propname", DatetimeArray._bool_ops) + def test_bool_properties(self, arr1d, propname): + # in this case _bool_ops is just `is_leap_year` + dti = self.index_cls(arr1d) + arr = arr1d + assert dti.freq == arr.freq + + result = getattr(arr, propname) + expected = np.array(getattr(dti, propname), dtype=result.dtype) + + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("propname", DatetimeArray._field_ops) + def test_int_properties(self, arr1d, propname): + dti = self.index_cls(arr1d) + arr = arr1d + + result = getattr(arr, propname) + expected = np.array(getattr(dti, propname), dtype=result.dtype) + + tm.assert_numpy_array_equal(result, expected) + + def test_take_fill_valid(self, arr1d, fixed_now_ts): + arr = arr1d + dti = self.index_cls(arr1d) + + now = fixed_now_ts.tz_localize(dti.tz) + result = arr.take([-1, 1], allow_fill=True, fill_value=now) + assert result[0] == now + + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + # fill_value Timedelta invalid + arr.take([-1, 1], allow_fill=True, fill_value=now - now) + + with pytest.raises(TypeError, match=msg): + # fill_value Period invalid + arr.take([-1, 1], allow_fill=True, fill_value=Period("2014Q1")) + + tz = None if dti.tz is not None else "US/Eastern" + now = fixed_now_ts.tz_localize(tz) + msg = "Cannot compare tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg): + # Timestamp with mismatched tz-awareness + arr.take([-1, 1], allow_fill=True, fill_value=now) + + value = NaT._value + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + # require NaT, not iNaT, as it could be confused with an integer + arr.take([-1, 1], allow_fill=True, fill_value=value) + + value = np.timedelta64("NaT", "ns") + with pytest.raises(TypeError, match=msg): + # require appropriate-dtype if we have an NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + + if arr.tz is not None: + # GH#37356 + # Assuming here that arr1d fixture does not include Australia/Melbourne + value = fixed_now_ts.tz_localize("Australia/Melbourne") + result = arr.take([-1, 1], allow_fill=True, fill_value=value) + + expected = arr.take( + [-1, 1], + allow_fill=True, + fill_value=value.tz_convert(arr.dtype.tz), + ) + tm.assert_equal(result, expected) + + def test_concat_same_type_invalid(self, arr1d): + # different timezones + arr = arr1d + + if arr.tz is None: + other = arr.tz_localize("UTC") + else: + other = arr.tz_localize(None) + + with pytest.raises(ValueError, match="to_concat must have the same"): + arr._concat_same_type([arr, other]) + + def test_concat_same_type_different_freq(self, unit): + # we *can* concatenate DTI with different freqs. + a = pd.date_range("2000", periods=2, freq="D", tz="US/Central", unit=unit)._data + b = pd.date_range("2000", periods=2, freq="h", tz="US/Central", unit=unit)._data + result = DatetimeArray._concat_same_type([a, b]) + expected = ( + pd.to_datetime( + [ + "2000-01-01 00:00:00", + "2000-01-02 00:00:00", + "2000-01-01 00:00:00", + "2000-01-01 01:00:00", + ] + ) + .tz_localize("US/Central") + .as_unit(unit) + ._data + ) + + tm.assert_datetime_array_equal(result, expected) + + def test_strftime(self, arr1d, using_infer_string): + arr = arr1d + + result = arr.strftime("%Y %b") + expected = np.array([ts.strftime("%Y %b") for ts in arr], dtype=object) + if using_infer_string: + expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan)) + tm.assert_equal(result, expected) + + def test_strftime_nat(self, using_infer_string): + # GH 29578 + arr = DatetimeIndex(["2019-01-01", NaT])._data + + result = arr.strftime("%Y-%m-%d") + expected = np.array(["2019-01-01", np.nan], dtype=object) + if using_infer_string: + expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan)) + tm.assert_equal(result, expected) + + +class TestTimedeltaArray(SharedTests): + index_cls = TimedeltaIndex + array_cls = TimedeltaArray + scalar_type = pd.Timedelta + example_dtype = "m8[ns]" + + def test_from_tdi(self): + tdi = TimedeltaIndex(["1 Day", "3 Hours"]) + arr = tdi._data + assert list(arr) == list(tdi) + + # Check that Index.__new__ knows what to do with TimedeltaArray + tdi2 = pd.Index(arr) + assert isinstance(tdi2, TimedeltaIndex) + assert list(tdi2) == list(arr) + + def test_astype_object(self): + tdi = TimedeltaIndex(["1 Day", "3 Hours"]) + arr = tdi._data + asobj = arr.astype("O") + assert isinstance(asobj, np.ndarray) + assert asobj.dtype == "O" + assert list(asobj) == list(tdi) + + def test_to_pytimedelta(self, timedelta_index): + tdi = timedelta_index + arr = tdi._data + + expected = tdi.to_pytimedelta() + result = arr.to_pytimedelta() + + tm.assert_numpy_array_equal(result, expected) + + def test_total_seconds(self, timedelta_index): + tdi = timedelta_index + arr = tdi._data + + expected = tdi.total_seconds() + result = arr.total_seconds() + + tm.assert_numpy_array_equal(result, expected.values) + + @pytest.mark.parametrize("propname", TimedeltaArray._field_ops) + def test_int_properties(self, timedelta_index, propname): + tdi = timedelta_index + arr = tdi._data + + result = getattr(arr, propname) + expected = np.array(getattr(tdi, propname), dtype=result.dtype) + + tm.assert_numpy_array_equal(result, expected) + + def test_array_interface(self, timedelta_index): + arr = timedelta_index._data + copy_false = None if np_version_gt2 else False + + # default asarray gives the same underlying data + result = np.asarray(arr) + expected = arr._ndarray + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, copy=copy_false) + assert result is expected + tm.assert_numpy_array_equal(result, expected) + + # specifying m8[us] gives the same result as default + result = np.asarray(arr, dtype="timedelta64[us]") + expected = arr._ndarray + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, dtype="timedelta64[us]", copy=copy_false) + assert result is expected + tm.assert_numpy_array_equal(result, expected) + result = np.array(arr, dtype="timedelta64[us]") + if not np_version_gt2: + # TODO: GH 57739 + assert result is not expected + tm.assert_numpy_array_equal(result, expected) + + # to object dtype + result = np.asarray(arr, dtype=object) + expected = np.array(list(arr), dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # to other dtype always copies + result = np.asarray(arr, dtype="int64") + assert result is not arr.asi8 + assert not np.may_share_memory(arr, result) + expected = arr.asi8.copy() + tm.assert_numpy_array_equal(result, expected) + + # other dtypes handled by numpy + for dtype in ["float64", str]: + result = np.asarray(arr, dtype=dtype) + expected = np.asarray(arr).astype(dtype) + tm.assert_numpy_array_equal(result, expected) + + def test_take_fill_valid(self, timedelta_index, fixed_now_ts): + tdi = timedelta_index + arr = tdi._data + + td1 = pd.Timedelta(days=1) + result = arr.take([-1, 1], allow_fill=True, fill_value=td1) + assert result[0] == td1 + + value = fixed_now_ts + msg = f"value should be a '{arr._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + # fill_value Timestamp invalid + arr.take([0, 1], allow_fill=True, fill_value=value) + + value = fixed_now_ts.to_period("D") + with pytest.raises(TypeError, match=msg): + # fill_value Period invalid + arr.take([0, 1], allow_fill=True, fill_value=value) + + value = np.datetime64("NaT", "ns") + with pytest.raises(TypeError, match=msg): + # require appropriate-dtype if we have an NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + + +@pytest.mark.filterwarnings(r"ignore:Period with BDay freq is deprecated:FutureWarning") +@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") +class TestPeriodArray(SharedTests): + index_cls = PeriodIndex + array_cls = PeriodArray + scalar_type = Period + example_dtype = PeriodIndex([], freq="W").dtype + + @pytest.fixture + def arr1d(self, period_index): + """ + Fixture returning DatetimeArray from parametrized PeriodIndex objects + """ + return period_index._data + + def test_from_pi(self, arr1d): + pi = self.index_cls(arr1d) + arr = arr1d + assert list(arr) == list(pi) + + # Check that Index.__new__ knows what to do with PeriodArray + pi2 = pd.Index(arr) + assert isinstance(pi2, PeriodIndex) + assert list(pi2) == list(arr) + + def test_astype_object(self, arr1d): + pi = self.index_cls(arr1d) + arr = arr1d + asobj = arr.astype("O") + assert isinstance(asobj, np.ndarray) + assert asobj.dtype == "O" + assert list(asobj) == list(pi) + + def test_take_fill_valid(self, arr1d): + arr = arr1d + + value = NaT._value + msg = f"value should be a '{arr1d._scalar_type.__name__}' or 'NaT'. Got" + with pytest.raises(TypeError, match=msg): + # require NaT, not iNaT, as it could be confused with an integer + arr.take([-1, 1], allow_fill=True, fill_value=value) + + value = np.timedelta64("NaT", "ns") + with pytest.raises(TypeError, match=msg): + # require appropriate-dtype if we have an NA value + arr.take([-1, 1], allow_fill=True, fill_value=value) + + @pytest.mark.parametrize("how", ["S", "E"]) + def test_to_timestamp(self, how, arr1d): + pi = self.index_cls(arr1d) + arr = arr1d + + expected = DatetimeIndex(pi.to_timestamp(how=how))._data + result = arr.to_timestamp(how=how) + assert isinstance(result, DatetimeArray) + + tm.assert_equal(result, expected) + + def test_to_timestamp_roundtrip_bday(self): + # Case where infer_freq inside would choose "D" instead of "B" + dta = pd.date_range("2021-10-18", periods=3, freq="B", unit="ns")._data + parr = dta.to_period() + result = parr.to_timestamp() + assert result.freq == "B" + tm.assert_extension_array_equal(result, dta.as_unit("us")) + + dta2 = dta[::2] + parr2 = dta2.to_period() + result2 = parr2.to_timestamp() + assert result2.freq == "2B" + tm.assert_extension_array_equal(result2, dta2.as_unit("us")) + + parr3 = dta.to_period("2B") + result3 = parr3.to_timestamp() + assert result3.freq == "B" + tm.assert_extension_array_equal(result3, dta.as_unit("us")) + + def test_to_timestamp_out_of_bounds(self): + # GH#19643 previously overflowed silently + pi = pd.period_range("1500", freq="Y", periods=3) + pi.to_timestamp() + dta = pi._data.to_timestamp() + assert dta[0] == Timestamp(1500, 1, 1) + + @pytest.mark.parametrize("propname", PeriodArray._bool_ops) + def test_bool_properties(self, arr1d, propname): + # in this case _bool_ops is just `is_leap_year` + pi = self.index_cls(arr1d) + arr = arr1d + + result = getattr(arr, propname) + expected = np.array(getattr(pi, propname)) + + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("propname", PeriodArray._field_ops) + def test_int_properties(self, arr1d, propname): + pi = self.index_cls(arr1d) + arr = arr1d + + result = getattr(arr, propname) + expected = np.array(getattr(pi, propname)) + + tm.assert_numpy_array_equal(result, expected) + + def test_array_interface(self, arr1d): + arr = arr1d + + # default asarray gives objects + result = np.asarray(arr) + expected = np.array(list(arr), dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # to object dtype (same as default) + result = np.asarray(arr, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + # to int64 gives the underlying representation + result = np.asarray(arr, dtype="int64") + tm.assert_numpy_array_equal(result, arr.asi8) + + result2 = np.asarray(arr, dtype="int64") + assert np.may_share_memory(result, result2) + + result_copy1 = np.array(arr, dtype="int64", copy=True) + result_copy2 = np.array(arr, dtype="int64", copy=True) + assert not np.may_share_memory(result_copy1, result_copy2) + + # to other dtypes + msg = r"float\(\) argument must be a string or a( real)? number, not 'Period'" + with pytest.raises(TypeError, match=msg): + np.asarray(arr, dtype="float64") + + result = np.asarray(arr, dtype="S20") + expected = np.asarray(arr).astype("S20") + tm.assert_numpy_array_equal(result, expected) + + def test_strftime(self, arr1d, using_infer_string): + arr = arr1d + + result = arr.strftime("%Y") + expected = np.array([per.strftime("%Y") for per in arr], dtype=object) + if using_infer_string: + expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan)) + tm.assert_equal(result, expected) + + def test_strftime_nat(self, using_infer_string): + # GH 29578 + arr = PeriodArray(PeriodIndex(["2019-01-01", NaT], dtype="period[D]")) + + result = arr.strftime("%Y-%m-%d") + expected = np.array(["2019-01-01", np.nan], dtype=object) + if using_infer_string: + expected = pd.array(expected, dtype=pd.StringDtype(na_value=np.nan)) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "arr,casting_nats", + [ + ( + TimedeltaIndex(["1 Day", "3 Hours", "NaT"])._data, + (NaT, np.timedelta64("NaT", "ns")), + ), + ( + pd.date_range("2000-01-01", periods=3, freq="D")._data, + (NaT, np.datetime64("NaT", "ns")), + ), + (pd.period_range("2000-01-01", periods=3, freq="D")._data, (NaT,)), + ], + ids=lambda x: type(x).__name__, +) +def test_casting_nat_setitem_array(arr, casting_nats): + expected = type(arr)._from_sequence([NaT, arr[1], arr[2]], dtype=arr.dtype) + + for nat in casting_nats: + arr = arr.copy() + arr[0] = nat + tm.assert_equal(arr, expected) + + +@pytest.mark.parametrize( + "arr,non_casting_nats", + [ + ( + TimedeltaIndex(["1 Day", "3 Hours", "NaT"])._data, + (np.datetime64("NaT", "ns"), NaT._value), + ), + ( + pd.date_range("2000-01-01", periods=3, freq="D")._data, + (np.timedelta64("NaT", "ns"), NaT._value), + ), + ( + pd.period_range("2000-01-01", periods=3, freq="D")._data, + (np.datetime64("NaT", "ns"), np.timedelta64("NaT", "ns"), NaT._value), + ), + ], + ids=lambda x: type(x).__name__, +) +def test_invalid_nat_setitem_array(arr, non_casting_nats): + msg = ( + "value should be a '(Timestamp|Timedelta|Period)', 'NaT', or array of those. " + "Got '(timedelta64|datetime64|int)' instead." + ) + + for nat in non_casting_nats: + with pytest.raises(TypeError, match=msg): + arr[0] = nat + + +@pytest.mark.parametrize( + "arr", + [ + pd.date_range("2000", periods=4)._values, + pd.timedelta_range("2000", periods=4)._values, + ], +) +def test_to_numpy_extra(arr): + arr[0] = NaT + original = arr.copy() + + result = arr.to_numpy() + assert np.isnan(result[0]) + + result = arr.to_numpy(dtype="int64") + assert result[0] == -9223372036854775808 + + result = arr.to_numpy(dtype="int64", na_value=0) + assert result[0] == 0 + + result = arr.to_numpy(na_value=arr[1].to_numpy()) + assert result[0] == result[1] + + result = arr.to_numpy(na_value=arr[1].to_numpy(copy=False)) + assert result[0] == result[1] + + tm.assert_equal(arr, original) + + +@pytest.mark.parametrize( + "arr", + [ + pd.date_range("2000", periods=4)._values, + pd.timedelta_range("2000", periods=4)._values, + ], +) +def test_to_numpy_extra_readonly(arr): + arr[0] = NaT + original = arr.copy() + arr._readonly = True + + result = arr.to_numpy(dtype=object) + assert result.flags.writeable + + # numpy does not do zero-copy conversion from M8 to i8 + result = arr.to_numpy(dtype="int64") + assert result.flags.writeable + + tm.assert_equal(arr, original) + + +@pytest.mark.parametrize("as_index", [True, False]) +@pytest.mark.parametrize( + "values", + [ + pd.to_datetime(["2020-01-01", "2020-02-01"]), + pd.to_timedelta([1, 2], unit="D"), + PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"), + ], +) +@pytest.mark.parametrize( + "klass", + [ + list, + np.array, + pd.array, + pd.Series, + pd.Index, + pd.Categorical, + pd.CategoricalIndex, + ], +) +def test_searchsorted_datetimelike_with_listlike(values, klass, as_index): + # https://github.com/pandas-dev/pandas/issues/32762 + if not as_index: + values = values._data + + result = values.searchsorted(klass(values)) + expected = np.array([0, 1], dtype=result.dtype) + + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize( + "values", + [ + pd.to_datetime(["2020-01-01", "2020-02-01"]), + pd.to_timedelta([1, 2], unit="D"), + PeriodIndex(["2020-01-01", "2020-02-01"], freq="D"), + ], +) +@pytest.mark.parametrize( + "arg", [[1, 2], ["a", "b"], [Timestamp("2020-01-01", tz="Europe/London")] * 2] +) +def test_searchsorted_datetimelike_with_listlike_invalid_dtype(values, arg): + # https://github.com/pandas-dev/pandas/issues/32762 + msg = "[Unexpected type|Cannot compare]" + with pytest.raises(TypeError, match=msg): + values.searchsorted(arg) + + +@pytest.mark.parametrize("klass", [list, tuple, np.array, pd.Series]) +def test_period_index_construction_from_strings(klass): + # https://github.com/pandas-dev/pandas/issues/26109 + strings = ["2020Q1", "2020Q2"] * 2 + data = klass(strings) + result = PeriodIndex(data, freq="Q") + expected = PeriodIndex([Period(s) for s in strings]) + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) +def test_from_pandas_array(dtype): + # GH#24615 + data = np.array([1, 2, 3], dtype=dtype) + arr = NumpyExtensionArray(data) + + cls = {"M8[ns]": DatetimeArray, "m8[ns]": TimedeltaArray}[dtype] + + result = cls._from_sequence(arr, dtype=dtype) + expected = cls._from_sequence(data, dtype=dtype) + tm.assert_extension_array_equal(result, expected) + + func = {"M8[ns]": pd.to_datetime, "m8[ns]": pd.to_timedelta}[dtype] + result = func(arr).array + expected = func(data).array + tm.assert_equal(result, expected) + + # Let's check the Indexes while we're here + idx_cls = {"M8[ns]": DatetimeIndex, "m8[ns]": TimedeltaIndex}[dtype] + result = idx_cls(arr) + expected = idx_cls(data) + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py new file mode 100644 index 0000000000000000000000000000000000000000..7f1be30e5c0221d73abc148e15922db225d8dfed --- /dev/null +++ b/pandas/tests/arrays/test_datetimes.py @@ -0,0 +1,848 @@ +""" +Tests for DatetimeArray +""" + +from __future__ import annotations + +from datetime import timedelta +import operator + +import numpy as np +import pytest + +from pandas._libs.tslibs import tz_compare +from pandas.errors import Pandas4Warning + +from pandas.core.dtypes.dtypes import DatetimeTZDtype + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) + + +class TestNonNano: + @pytest.fixture(params=["s", "ms", "us"]) + def unit(self, request): + """Fixture returning parametrized time units""" + return request.param + + @pytest.fixture + def dtype(self, unit, tz_naive_fixture): + tz = tz_naive_fixture + if tz is None: + return np.dtype(f"datetime64[{unit}]") + else: + return DatetimeTZDtype(unit=unit, tz=tz) + + @pytest.fixture + def dta_dti(self, unit, dtype): + tz = getattr(dtype, "tz", None) + + dti = pd.date_range("2016-01-01", periods=55, freq="D", tz=tz, unit="ns") + if tz is None: + arr = np.asarray(dti).astype(f"M8[{unit}]") + else: + arr = np.asarray(dti.tz_convert("UTC").tz_localize(None)).astype( + f"M8[{unit}]" + ) + + dta = DatetimeArray._simple_new(arr, dtype=dtype) + return dta, dti + + @pytest.fixture + def dta(self, dta_dti): + dta, dti = dta_dti + return dta + + def test_non_nano(self, unit, dtype): + arr = np.arange(5, dtype=np.int64).view(f"M8[{unit}]") + dta = DatetimeArray._simple_new(arr, dtype=dtype) + + assert dta.dtype == dtype + assert dta[0].unit == unit + assert tz_compare(dta.tz, dta[0].tz) + assert (dta[0] == dta[:1]).all() + + @pytest.mark.parametrize( + "field", DatetimeArray._field_ops + DatetimeArray._bool_ops + ) + def test_fields(self, unit, field, dtype, dta_dti): + dta, dti = dta_dti + + assert (dti == dta).all() + + res = getattr(dta, field) + expected = getattr(dti._data, field) + tm.assert_numpy_array_equal(res, expected) + + def test_normalize(self, unit): + dti = pd.date_range("2016-01-01 06:00:00", periods=55, freq="D") + arr = np.asarray(dti).astype(f"M8[{unit}]") + + dta = DatetimeArray._simple_new(arr, dtype=arr.dtype) + + assert not dta.is_normalized + + # TODO: simplify once we can just .astype to other unit + exp = np.asarray(dti.normalize()).astype(f"M8[{unit}]") + expected = DatetimeArray._simple_new(exp, dtype=exp.dtype) + + res = dta.normalize() + tm.assert_extension_array_equal(res, expected) + + def test_normalize_overflow_raises(self): + # GH#60583 + ts = pd.Timestamp.min + dta = DatetimeArray._from_sequence([ts], dtype="M8[ns]") + + msg = "Cannot normalize Timestamp without integer overflow" + with pytest.raises(ValueError, match=msg): + dta.normalize() + + def test_simple_new_requires_match(self, unit): + arr = np.arange(5, dtype=np.int64).view(f"M8[{unit}]") + dtype = DatetimeTZDtype(unit, "UTC") + + dta = DatetimeArray._simple_new(arr, dtype=dtype) + assert dta.dtype == dtype + + wrong = DatetimeTZDtype("ns", "UTC") + with pytest.raises(AssertionError, match="^$"): + DatetimeArray._simple_new(arr, dtype=wrong) + + def test_std_non_nano(self, unit): + dti = pd.date_range("2016-01-01", periods=55, freq="D", unit="ns") + arr = np.asarray(dti).astype(f"M8[{unit}]") + + dta = DatetimeArray._simple_new(arr, dtype=arr.dtype) + + # we should match the nano-reso std, but floored to our reso. + res = dta.std() + assert res._creso == dta._creso + assert res == dti.std().floor(unit) + + @pytest.mark.filterwarnings("ignore:Converting to PeriodArray.*:UserWarning") + def test_to_period(self, dta_dti): + dta, dti = dta_dti + result = dta.to_period("D") + expected = dti._data.to_period("D") + + tm.assert_extension_array_equal(result, expected) + + def test_iter(self, dta): + res = next(iter(dta)) + expected = dta[0] + + assert type(res) is pd.Timestamp + assert res._value == expected._value + assert res._creso == expected._creso + assert res == expected + + def test_astype_object(self, dta): + result = dta.astype(object) + assert all(x._creso == dta._creso for x in result) + assert all(x == y for x, y in zip(result, dta, strict=True)) + + def test_to_pydatetime(self, dta_dti): + dta, dti = dta_dti + + result = dta.to_pydatetime() + expected = dti.to_pydatetime() + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("meth", ["time", "timetz", "date"]) + def test_time_date(self, dta_dti, meth): + dta, dti = dta_dti + + result = getattr(dta, meth) + expected = getattr(dti, meth) + tm.assert_numpy_array_equal(result, expected) + + def test_format_native_types(self, unit, dtype, dta_dti): + # In this case we should get the same formatted values with our nano + # version dti._data as we do with the non-nano dta + dta, dti = dta_dti + + res = dta._format_native_types() + exp = dti._data._format_native_types() + tm.assert_numpy_array_equal(res, exp) + + def test_repr(self, dta_dti, unit): + dta, dti = dta_dti + + assert repr(dta) == repr(dti._data).replace("[ns", f"[{unit}") + + # TODO: tests with td64 + def test_compare_mismatched_resolutions(self, comparison_op): + # comparison that numpy gets wrong bc of silent overflows + op = comparison_op + + iinfo = np.iinfo(np.int64) + vals = np.array([iinfo.min, iinfo.min + 1, iinfo.max], dtype=np.int64) + + # Construct so that arr2[1] < arr[1] < arr[2] < arr2[2] + arr = np.array(vals).view("M8[ns]") + arr2 = arr.view("M8[s]") + + left = DatetimeArray._simple_new(arr, dtype=arr.dtype) + right = DatetimeArray._simple_new(arr2, dtype=arr2.dtype) + + if comparison_op is operator.eq: + expected = np.array([False, False, False]) + elif comparison_op is operator.ne: + expected = np.array([True, True, True]) + elif comparison_op in [operator.lt, operator.le]: + expected = np.array([False, False, True]) + else: + expected = np.array([False, True, False]) + + result = op(left, right) + tm.assert_numpy_array_equal(result, expected) + + result = op(left[1], right) + tm.assert_numpy_array_equal(result, expected) + + if op not in [operator.eq, operator.ne]: + # check that numpy still gets this wrong; if it is fixed we may be + # able to remove compare_mismatched_resolutions + np_res = op(left._ndarray, right._ndarray) + tm.assert_numpy_array_equal(np_res[1:], ~expected[1:]) + + def test_add_mismatched_reso_doesnt_downcast(self): + # https://github.com/pandas-dev/pandas/pull/48748#issuecomment-1260181008 + td = pd.Timedelta(microseconds=1) + dti = pd.date_range("2016-01-01", periods=3) - td + dta = dti._data.as_unit("us") + + res = dta + td.as_unit("us") + # even though the result is an even number of days + # (so we _could_ downcast to unit="s"), we do not. + assert res.unit == "us" + + @pytest.mark.parametrize( + "scalar", + [ + timedelta(hours=2), + pd.Timedelta(hours=2), + np.timedelta64(2, "h"), + np.timedelta64(2 * 3600 * 1000, "ms"), + pd.offsets.Minute(120), + pd.offsets.Hour(2), + ], + ) + def test_add_timedeltalike_scalar_mismatched_reso(self, dta_dti, scalar): + dta, dti = dta_dti + + td = pd.Timedelta(scalar) + exp_unit = tm.get_finest_unit(dta.unit, td.unit) + + expected = (dti + td)._data.as_unit(exp_unit) + result = dta + scalar + tm.assert_extension_array_equal(result, expected) + + result = scalar + dta + tm.assert_extension_array_equal(result, expected) + + expected = (dti - td)._data.as_unit(exp_unit) + result = dta - scalar + tm.assert_extension_array_equal(result, expected) + + def test_sub_datetimelike_scalar_mismatch(self): + dti = pd.date_range("2016-01-01", periods=3) + dta = dti._data.as_unit("us") + + ts = dta[0].as_unit("s") + + result = dta - ts + expected = (dti - dti[0])._data.as_unit("us") + assert result.dtype == "m8[us]" + tm.assert_extension_array_equal(result, expected) + + def test_sub_datetime64_reso_mismatch(self): + dti = pd.date_range("2016-01-01", periods=3) + left = dti._data.as_unit("s") + right = left.as_unit("ms") + + result = left - right + exp_values = np.array([0, 0, 0], dtype="m8[ms]") + expected = TimedeltaArray._simple_new( + exp_values, + dtype=exp_values.dtype, + ) + tm.assert_extension_array_equal(result, expected) + result2 = right - left + tm.assert_extension_array_equal(result2, expected) + + +class TestDatetimeArrayComparisons: + # TODO: merge this into tests/arithmetic/test_datetime64 once it is + # sufficiently robust + + def test_cmp_dt64_arraylike_tznaive(self, comparison_op): + # arbitrary tz-naive DatetimeIndex + op = comparison_op + + dti = pd.date_range("2016-01-1", freq="MS", periods=9, tz=None) + arr = dti._data + assert arr.freq == dti.freq + assert arr.tz == dti.tz + + right = dti + + expected = np.ones(len(arr), dtype=bool) + if comparison_op.__name__ in ["ne", "gt", "lt"]: + # for these the comparisons should be all-False + expected = ~expected + + result = op(arr, arr) + tm.assert_numpy_array_equal(result, expected) + for other in [ + right, + np.array(right), + list(right), + tuple(right), + right.astype(object), + ]: + result = op(arr, other) + tm.assert_numpy_array_equal(result, expected) + + result = op(other, arr) + tm.assert_numpy_array_equal(result, expected) + + +class TestDatetimeArray: + def test_astype_ns_to_ms_near_bounds(self): + # GH#55979 + ts = pd.Timestamp("1677-09-21 00:12:43.145225") + target = ts.as_unit("ms") + + dta = DatetimeArray._from_sequence([ts], dtype="M8[ns]") + assert (dta.view("i8") == ts.as_unit("ns").value).all() + + result = dta.astype("M8[ms]") + assert result[0] == target + + expected = DatetimeArray._from_sequence([ts], dtype="M8[ms]") + assert (expected.view("i8") == target._value).all() + + tm.assert_datetime_array_equal(result, expected) + + def test_astype_non_nano_tznaive(self): + dti = pd.date_range("2016-01-01", periods=3) + + res = dti.astype("M8[s]") + assert res.dtype == "M8[s]" + + dta = dti._data + res = dta.astype("M8[s]") + assert res.dtype == "M8[s]" + assert isinstance(res, pd.core.arrays.DatetimeArray) # used to be ndarray + + def test_astype_non_nano_tzaware(self): + dti = pd.date_range("2016-01-01", periods=3, tz="UTC") + + res = dti.astype("M8[s, US/Pacific]") + assert res.dtype == "M8[s, US/Pacific]" + + dta = dti._data + res = dta.astype("M8[s, US/Pacific]") + assert res.dtype == "M8[s, US/Pacific]" + + # from non-nano to non-nano, preserving reso + res2 = res.astype("M8[s, UTC]") + assert res2.dtype == "M8[s, UTC]" + assert not tm.shares_memory(res2, res) + + res3 = res.astype("M8[s, UTC]", copy=False) + assert res2.dtype == "M8[s, UTC]" + assert tm.shares_memory(res3, res) + + def test_astype_to_same(self): + arr = DatetimeArray._from_sequence( + ["2000"], dtype=DatetimeTZDtype(tz="US/Central") + ) + result = arr.astype(DatetimeTZDtype(tz="US/Central"), copy=False) + assert result is arr + + @pytest.mark.parametrize("dtype", ["datetime64[ns]", "datetime64[ns, UTC]"]) + @pytest.mark.parametrize( + "other", ["datetime64[ns]", "datetime64[ns, UTC]", "datetime64[ns, CET]"] + ) + def test_astype_copies(self, dtype, other): + # https://github.com/pandas-dev/pandas/pull/32490 + ser = pd.Series([1, 2], dtype=dtype) + orig = ser.copy() + + err = False + if (dtype == "datetime64[ns]") ^ (other == "datetime64[ns]"): + # deprecated in favor of tz_localize + err = True + + if err: + if dtype == "datetime64[ns]": + msg = "Use obj.tz_localize instead or series.dt.tz_localize instead" + else: + msg = "from timezone-aware dtype to timezone-naive dtype" + with pytest.raises(TypeError, match=msg): + ser.astype(other) + else: + t = ser.astype(other) + t[:] = pd.NaT + tm.assert_series_equal(ser, orig) + + @pytest.mark.parametrize("dtype", [int, np.int32, np.int64, "uint32", "uint64"]) + def test_astype_int(self, dtype): + arr = DatetimeArray._from_sequence( + [pd.Timestamp("2000"), pd.Timestamp("2001")], dtype="M8[ns]" + ) + + if np.dtype(dtype) != np.int64: + with pytest.raises(TypeError, match=r"Do obj.astype\('int64'\)"): + arr.astype(dtype) + return + + result = arr.astype(dtype) + expected = arr._ndarray.view("i8") + tm.assert_numpy_array_equal(result, expected) + + def test_astype_to_sparse_dt64(self): + # GH#50082 + dti = pd.date_range("2016-01-01", periods=4) + dta = dti._data + result = dta.astype("Sparse[datetime64[ns]]") + + assert result.dtype == "Sparse[datetime64[ns]]" + assert (result == dta).all() + + def test_tz_setter_raises(self): + arr = DatetimeArray._from_sequence( + ["2000"], dtype=DatetimeTZDtype(tz="US/Central") + ) + with pytest.raises(AttributeError, match="tz_localize"): + arr.tz = "UTC" + + def test_setitem_str_impute_tz(self, tz_naive_fixture): + # Like for getitem, if we are passed a naive-like string, we impute + # our own timezone. + tz = tz_naive_fixture + + data = np.array([1, 2, 3], dtype="M8[ns]") + dtype = data.dtype if tz is None else DatetimeTZDtype(tz=tz) + arr = DatetimeArray._from_sequence(data, dtype=dtype) + expected = arr.copy() + + ts = pd.Timestamp("2020-09-08 16:50").tz_localize(tz) + setter = str(ts.tz_localize(None)) + + # Setting a scalar tznaive string + expected[0] = ts + arr[0] = setter + tm.assert_equal(arr, expected) + + # Setting a listlike of tznaive strings + expected[1] = ts + arr[:2] = [setter, setter] + tm.assert_equal(arr, expected) + + def test_setitem_different_tz_raises(self): + # pre-2.0 we required exact tz match, in 2.0 we require only + # tzawareness-match + data = np.array([1, 2, 3], dtype="M8[ns]") + arr = DatetimeArray._from_sequence( + data, copy=False, dtype=DatetimeTZDtype(tz="US/Central") + ) + with pytest.raises(TypeError, match="Cannot compare tz-naive and tz-aware"): + arr[0] = pd.Timestamp("2000") + + ts = pd.Timestamp("2000", tz="US/Eastern") + arr[0] = ts + assert arr[0] == ts.tz_convert("US/Central") + + def test_setitem_clears_freq(self): + a = pd.date_range("2000", periods=2, freq="D", tz="US/Central")._data + a[0] = pd.Timestamp("2000", tz="US/Central") + assert a.freq is None + + @pytest.mark.parametrize( + "obj", + [ + pd.Timestamp("2021-01-01"), + pd.Timestamp("2021-01-01").to_datetime64(), + pd.Timestamp("2021-01-01").to_pydatetime(), + ], + ) + def test_setitem_objects(self, obj): + # make sure we accept datetime64 and datetime in addition to Timestamp + dti = pd.date_range("2000", periods=2, freq="D") + arr = dti._data + + arr[0] = obj + assert arr[0] == obj + + def test_repeat_preserves_tz(self): + dti = pd.date_range("2000", periods=2, freq="D", tz="US/Central") + arr = dti._data + + repeated = arr.repeat([1, 1]) + + # preserves tz and values, but not freq + expected = DatetimeArray._from_sequence(arr.asi8, dtype=arr.dtype) + tm.assert_equal(repeated, expected) + + def test_value_counts_preserves_tz(self): + dti = pd.date_range("2000", periods=2, freq="D", tz="US/Central") + arr = dti._data.repeat([4, 3]) + + result = arr.value_counts() + + # Note: not tm.assert_index_equal, since `freq`s do not match + assert result.index.equals(dti) + + arr[-2] = pd.NaT + result = arr.value_counts(dropna=False) + expected = pd.Series([4, 2, 1], index=[dti[0], dti[1], pd.NaT], name="count") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("method", ["pad", "backfill"]) + def test_fillna_preserves_tz(self, method): + dti = pd.date_range( + "2000-01-01", periods=5, freq="D", tz="US/Central", unit="ns" + ) + arr = DatetimeArray._from_sequence(dti, dtype=dti.dtype, copy=True) + arr[2] = pd.NaT + + fill_val = dti[1] if method == "pad" else dti[3] + expected = DatetimeArray._from_sequence( + [dti[0], dti[1], fill_val, dti[3], dti[4]], + dtype=DatetimeTZDtype(tz="US/Central"), + ) + + result = arr._pad_or_backfill(method=method) + tm.assert_extension_array_equal(result, expected) + + # assert that arr and dti were not modified in-place + assert arr[2] is pd.NaT + assert dti[2] == pd.Timestamp("2000-01-03", tz="US/Central") + + def test_fillna_2d(self): + dti = pd.date_range("2016-01-01", periods=6, tz="US/Pacific") + dta = dti._data.reshape(3, 2).copy() + dta[0, 1] = pd.NaT + dta[1, 0] = pd.NaT + + res1 = dta._pad_or_backfill(method="pad") + expected1 = dta.copy() + expected1[1, 0] = dta[0, 0] + tm.assert_extension_array_equal(res1, expected1) + + res2 = dta._pad_or_backfill(method="backfill") + expected2 = dta.copy() + expected2 = dta.copy() + expected2[1, 0] = dta[2, 0] + expected2[0, 1] = dta[1, 1] + tm.assert_extension_array_equal(res2, expected2) + + # with different ordering for underlying ndarray; behavior should + # be unchanged + dta2 = dta._from_backing_data(dta._ndarray.copy(order="F")) + assert dta2._ndarray.flags["F_CONTIGUOUS"] + assert not dta2._ndarray.flags["C_CONTIGUOUS"] + tm.assert_extension_array_equal(dta, dta2) + + res3 = dta2._pad_or_backfill(method="pad") + tm.assert_extension_array_equal(res3, expected1) + + res4 = dta2._pad_or_backfill(method="backfill") + tm.assert_extension_array_equal(res4, expected2) + + # test the DataFrame method while we're here + df = pd.DataFrame(dta) + res = df.ffill() + expected = pd.DataFrame(expected1) + tm.assert_frame_equal(res, expected) + + res = df.bfill() + expected = pd.DataFrame(expected2) + tm.assert_frame_equal(res, expected) + + def test_array_interface_tz(self): + tz = "US/Central" + data = pd.date_range("2017", periods=2, tz=tz, unit="ns")._data + result = np.asarray(data) + + expected = np.array( + [ + pd.Timestamp("2017-01-01T00:00:00", tz=tz), + pd.Timestamp("2017-01-02T00:00:00", tz=tz), + ], + dtype=object, + ) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(data, dtype=object) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(data, dtype="M8[ns]") + + expected = np.array( + ["2017-01-01T06:00:00", "2017-01-02T06:00:00"], dtype="M8[ns]" + ) + tm.assert_numpy_array_equal(result, expected) + + def test_array_interface(self): + data = pd.date_range("2017", periods=2, unit="ns")._data + expected = np.array( + ["2017-01-01T00:00:00", "2017-01-02T00:00:00"], dtype="datetime64[ns]" + ) + + result = np.asarray(data) + tm.assert_numpy_array_equal(result, expected) + + result = np.asarray(data, dtype=object) + expected = np.array( + [pd.Timestamp("2017-01-01T00:00:00"), pd.Timestamp("2017-01-02T00:00:00")], + dtype=object, + ) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_different_tz(self, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + arr = pd.DatetimeIndex(data, freq="D")._data.tz_localize("Asia/Tokyo") + if index: + arr = pd.Index(arr) + + expected = arr.searchsorted(arr[2]) + result = arr.searchsorted(arr[2].tz_convert("UTC")) + assert result == expected + + expected = arr.searchsorted(arr[2:6]) + result = arr.searchsorted(arr[2:6].tz_convert("UTC")) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_tzawareness_compat(self, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + arr = pd.DatetimeIndex(data, freq="D")._data + if index: + arr = pd.Index(arr) + + mismatch = arr.tz_localize("Asia/Tokyo") + + msg = "Cannot compare tz-naive and tz-aware datetime-like objects" + with pytest.raises(TypeError, match=msg): + arr.searchsorted(mismatch[0]) + with pytest.raises(TypeError, match=msg): + arr.searchsorted(mismatch) + + with pytest.raises(TypeError, match=msg): + mismatch.searchsorted(arr[0]) + with pytest.raises(TypeError, match=msg): + mismatch.searchsorted(arr) + + @pytest.mark.parametrize( + "other", + [ + 1, + np.int64(1), + 1.0, + np.timedelta64("NaT"), + pd.Timedelta(days=2), + "invalid", + np.arange(10, dtype="i8") * 24 * 3600 * 10**9, + np.arange(10).view("timedelta64[ns]") * 24 * 3600 * 10**9, + pd.Timestamp("2021-01-01").to_period("D"), + ], + ) + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_invalid_types(self, other, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + arr = pd.DatetimeIndex(data, freq="D")._data + if index: + arr = pd.Index(arr) + + msg = "|".join( + [ + "searchsorted requires compatible dtype or scalar", + "value should be a 'Timestamp', 'NaT', or array of those. Got", + ] + ) + with pytest.raises(TypeError, match=msg): + arr.searchsorted(other) + + def test_shift_fill_value(self): + dti = pd.date_range("2016-01-01", periods=3) + + dta = dti._data + expected = DatetimeArray._from_sequence( + np.roll(dta._ndarray, 1), dtype=dti.dtype + ) + + fv = dta[-1] + for fill_value in [fv, fv.to_pydatetime(), fv.to_datetime64()]: + result = dta.shift(1, fill_value=fill_value) + tm.assert_datetime_array_equal(result, expected) + + dta = dta.tz_localize("UTC") + expected = expected.tz_localize("UTC") + fv = dta[-1] + for fill_value in [fv, fv.to_pydatetime()]: + result = dta.shift(1, fill_value=fill_value) + tm.assert_datetime_array_equal(result, expected) + + def test_shift_value_tzawareness_mismatch(self): + dti = pd.date_range("2016-01-01", periods=3) + + dta = dti._data + + fv = dta[-1].tz_localize("UTC") + for invalid in [fv, fv.to_pydatetime()]: + with pytest.raises(TypeError, match="Cannot compare"): + dta.shift(1, fill_value=invalid) + + dta = dta.tz_localize("UTC") + fv = dta[-1].tz_localize(None) + for invalid in [fv, fv.to_pydatetime(), fv.to_datetime64()]: + with pytest.raises(TypeError, match="Cannot compare"): + dta.shift(1, fill_value=invalid) + + def test_shift_requires_tzmatch(self): + # pre-2.0 we required exact tz match, in 2.0 we require just + # matching tzawareness + dti = pd.date_range("2016-01-01", periods=3, tz="UTC") + dta = dti._data + + fill_value = pd.Timestamp("2020-10-18 18:44", tz="US/Pacific") + + result = dta.shift(1, fill_value=fill_value) + expected = dta.shift(1, fill_value=fill_value.tz_convert("UTC")) + tm.assert_equal(result, expected) + + def test_tz_localize_t2d(self): + dti = pd.date_range("1994-05-12", periods=12, tz="US/Pacific") + dta = dti._data.reshape(3, 4) + result = dta.tz_localize(None) + + expected = dta.ravel().tz_localize(None).reshape(dta.shape) + tm.assert_datetime_array_equal(result, expected) + + roundtrip = expected.tz_localize("US/Pacific") + tm.assert_datetime_array_equal(roundtrip, dta) + + @pytest.mark.parametrize( + "tz", ["US/Eastern", "dateutil/US/Eastern", "pytz/US/Eastern"] + ) + def test_iter_zoneinfo_fold(self, tz): + # GH#49684 + if tz.startswith("pytz/"): + pytz = pytest.importorskip("pytz") + tz = pytz.timezone(tz.removeprefix("pytz/")) + utc_vals = np.array( + [1320552000, 1320555600, 1320559200, 1320562800], dtype=np.int64 + ) + utc_vals *= 1_000_000_000 + + dta = ( + DatetimeArray._from_sequence(utc_vals, dtype=np.dtype("M8[ns]")) + .tz_localize("UTC") + .tz_convert(tz) + ) + + left = dta[2] + right = list(dta)[2] + assert str(left) == str(right) + # previously there was a bug where with non-pytz right would be + # Timestamp('2011-11-06 01:00:00-0400', tz='US/Eastern') + # while left would be + # Timestamp('2011-11-06 01:00:00-0500', tz='US/Eastern') + # The .value's would match (so they would compare as equal), + # but the folds would not + assert left.utcoffset() == right.utcoffset() + + # The same bug in ints_to_pydatetime affected .astype, so we test + # that here. + right2 = dta.astype(object)[2] + assert str(left) == str(right2) + assert left.utcoffset() == right2.utcoffset() + + @pytest.mark.parametrize( + "freq", + ["2M", "2SM", "2sm", "2Q", "2Q-SEP", "1Y", "2Y-MAR", "2m", "2q-sep", "2y"], + ) + def test_date_range_frequency_M_Q_Y_raises(self, freq): + msg = f"Invalid frequency: {freq}" + + with pytest.raises(ValueError, match=msg): + pd.date_range("1/1/2000", periods=4, freq=freq) + + @pytest.mark.parametrize("freq_depr", ["2MIN", "2nS", "2Us"]) + def test_date_range_uppercase_frequency_deprecated(self, freq_depr): + # GH#9586, GH#54939 + depr_msg = ( + f"'{freq_depr[1:]}' is deprecated and will be removed in a " + f"future version, please use '{freq_depr.lower()[1:]}' instead." + ) + + expected = pd.date_range("1/1/2000", periods=4, freq=freq_depr.lower()) + with tm.assert_produces_warning(Pandas4Warning, match=depr_msg): + result = pd.date_range("1/1/2000", periods=4, freq=freq_depr) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "freq", + [ + "2ye-mar", + "2ys", + "2qe", + "2qs-feb", + "2bqs", + "2sms", + "2bms", + "2cbme", + "2me", + ], + ) + def test_date_range_lowercase_frequency_raises(self, freq): + msg = f"Invalid frequency: {freq}" + + with pytest.raises(ValueError, match=msg): + pd.date_range("1/1/2000", periods=4, freq=freq) + + def test_date_range_lowercase_frequency_deprecated(self): + # GH#9586, GH#54939 + depr_msg = "'w' is deprecated and will be removed in a future version" + + expected = pd.date_range("1/1/2000", periods=4, freq="2W") + with tm.assert_produces_warning(Pandas4Warning, match=depr_msg): + result = pd.date_range("1/1/2000", periods=4, freq="2w") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("freq", ["1A", "2A-MAR", "2a-mar"]) + def test_date_range_frequency_A_raises(self, freq): + msg = f"Invalid frequency: {freq}" + + with pytest.raises(ValueError, match=msg): + pd.date_range("1/1/2000", periods=4, freq=freq) + + @pytest.mark.parametrize("freq", ["2H", "2CBH", "2S"]) + def test_date_range_uppercase_frequency_raises(self, freq): + msg = f"Invalid frequency: {freq}" + + with pytest.raises(ValueError, match=msg): + pd.date_range("1/1/2000", periods=4, freq=freq) + + +def test_factorize_sort_without_freq(): + dta = DatetimeArray._from_sequence([0, 2, 1], dtype="M8[ns]") + + msg = r"call pd.factorize\(obj, sort=True\) instead" + with pytest.raises(NotImplementedError, match=msg): + dta.factorize(sort=True) + + # Do TimedeltaArray while we're here + tda = dta - dta[0] + with pytest.raises(NotImplementedError, match=msg): + tda.factorize(sort=True) diff --git a/pandas/tests/arrays/test_ndarray_backed.py b/pandas/tests/arrays/test_ndarray_backed.py new file mode 100644 index 0000000000000000000000000000000000000000..2af59a03a5b3e774c1c0692399c285f0ec26a1dc --- /dev/null +++ b/pandas/tests/arrays/test_ndarray_backed.py @@ -0,0 +1,76 @@ +""" +Tests for subclasses of NDArrayBackedExtensionArray +""" + +import numpy as np + +from pandas import ( + CategoricalIndex, + date_range, +) +from pandas.core.arrays import ( + Categorical, + DatetimeArray, + NumpyExtensionArray, + TimedeltaArray, +) + + +class TestEmpty: + def test_empty_categorical(self): + ci = CategoricalIndex(["a", "b", "c"], ordered=True) + dtype = ci.dtype + + # case with int8 codes + shape = (4,) + result = Categorical._empty(shape, dtype=dtype) + assert isinstance(result, Categorical) + assert result.shape == shape + assert result._ndarray.dtype == np.int8 + + # case where repr would segfault if we didn't override base implementation + result = Categorical._empty((4096,), dtype=dtype) + assert isinstance(result, Categorical) + assert result.shape == (4096,) + assert result._ndarray.dtype == np.int8 + repr(result) + + # case with int16 codes + ci = CategoricalIndex(list(range(512)) * 4, ordered=False) + dtype = ci.dtype + result = Categorical._empty(shape, dtype=dtype) + assert isinstance(result, Categorical) + assert result.shape == shape + assert result._ndarray.dtype == np.int16 + + def test_empty_dt64tz(self): + dti = date_range("2016-01-01", periods=2, tz="Asia/Tokyo") + dtype = dti.dtype + + shape = (0,) + result = DatetimeArray._empty(shape, dtype=dtype) + assert result.dtype == dtype + assert isinstance(result, DatetimeArray) + assert result.shape == shape + + def test_empty_dt64(self): + shape = (3, 9) + result = DatetimeArray._empty(shape, dtype="datetime64[ns]") + assert isinstance(result, DatetimeArray) + assert result.shape == shape + + def test_empty_td64(self): + shape = (3, 9) + result = TimedeltaArray._empty(shape, dtype="m8[ns]") + assert isinstance(result, TimedeltaArray) + assert result.shape == shape + + def test_empty_pandas_array(self): + arr = NumpyExtensionArray(np.array([1, 2])) + dtype = arr.dtype + + shape = (3, 9) + result = NumpyExtensionArray._empty(shape, dtype=dtype) + assert isinstance(result, NumpyExtensionArray) + assert result.dtype == dtype + assert result.shape == shape diff --git a/pandas/tests/arrays/test_period.py b/pandas/tests/arrays/test_period.py new file mode 100644 index 0000000000000000000000000000000000000000..48453ba19e9a1f6971a2e56872ec42f1856d1dd0 --- /dev/null +++ b/pandas/tests/arrays/test_period.py @@ -0,0 +1,184 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs import iNaT +from pandas._libs.tslibs.period import IncompatibleFrequency + +from pandas.core.dtypes.base import _registry as registry +from pandas.core.dtypes.dtypes import PeriodDtype + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import PeriodArray + +# ---------------------------------------------------------------------------- +# Dtype + + +def test_registered(): + assert PeriodDtype in registry.dtypes + result = registry.find("Period[D]") + expected = PeriodDtype("D") + assert result == expected + + +# ---------------------------------------------------------------------------- +# period_array + + +def test_asi8(): + result = PeriodArray._from_sequence(["2000", "2001", None], dtype="period[D]").asi8 + expected = np.array([10957, 11323, iNaT]) + tm.assert_numpy_array_equal(result, expected) + + +def test_take_raises(): + arr = PeriodArray._from_sequence(["2000", "2001"], dtype="period[D]") + with pytest.raises(IncompatibleFrequency, match="freq"): + arr.take([0, -1], allow_fill=True, fill_value=pd.Period("2000", freq="W")) + + msg = "value should be a 'Period' or 'NaT'. Got 'str' instead" + with pytest.raises(TypeError, match=msg): + arr.take([0, -1], allow_fill=True, fill_value="foo") + + +def test_fillna_raises(): + arr = PeriodArray._from_sequence(["2000", "2001", "2002"], dtype="period[D]") + with pytest.raises(ValueError, match="Length"): + arr.fillna(arr[:2]) + + +def test_fillna_copies(): + arr = PeriodArray._from_sequence(["2000", "2001", "2002"], dtype="period[D]") + result = arr.fillna(pd.Period("2000", "D")) + assert result is not arr + + +# ---------------------------------------------------------------------------- +# setitem + + +@pytest.mark.parametrize( + "key, value, expected", + [ + ([0], pd.Period("2000", "D"), [10957, 1, 2]), + ([0], None, [iNaT, 1, 2]), + ([0], np.nan, [iNaT, 1, 2]), + ([0, 1, 2], pd.Period("2000", "D"), [10957] * 3), + ( + [0, 1, 2], + [pd.Period("2000", "D"), pd.Period("2001", "D"), pd.Period("2002", "D")], + [10957, 11323, 11688], + ), + ], +) +def test_setitem(key, value, expected): + arr = PeriodArray(np.arange(3), dtype="period[D]") + expected = PeriodArray(expected, dtype="period[D]") + arr[key] = value + tm.assert_period_array_equal(arr, expected) + + +def test_setitem_raises_incompatible_freq(): + arr = PeriodArray(np.arange(3), dtype="period[D]") + with pytest.raises(IncompatibleFrequency, match="freq"): + arr[0] = pd.Period("2000", freq="Y") + + other = PeriodArray._from_sequence(["2000", "2001"], dtype="period[Y]") + with pytest.raises(IncompatibleFrequency, match="freq"): + arr[[0, 1]] = other + + +def test_setitem_raises_length(): + arr = PeriodArray(np.arange(3), dtype="period[D]") + with pytest.raises(ValueError, match="length"): + arr[[0, 1]] = [pd.Period("2000", freq="D")] + + +def test_setitem_raises_type(): + arr = PeriodArray(np.arange(3), dtype="period[D]") + with pytest.raises(TypeError, match="int"): + arr[0] = 1 + + +# ---------------------------------------------------------------------------- +# Ops + + +def test_sub_period(): + arr = PeriodArray._from_sequence(["2000", "2001"], dtype="period[D]") + other = pd.Period("2000", freq="M") + with pytest.raises(IncompatibleFrequency, match="freq"): + arr - other + + +def test_sub_period_overflow(): + # GH#47538 + dti = pd.date_range("1677-09-22", periods=2, freq="D") + pi = dti.to_period("ns") + + per = pd.Period._from_ordinal(10**14, pi.freq) + + with pytest.raises(OverflowError, match="Overflow in int64 addition"): + pi - per + + with pytest.raises(OverflowError, match="Overflow in int64 addition"): + per - pi + + +# ---------------------------------------------------------------------------- +# Methods + + +@pytest.mark.parametrize( + "other", + [ + pd.Period("2000", freq="h"), + PeriodArray._from_sequence(["2000", "2001", "2000"], dtype="period[h]"), + ], +) +def test_where_different_freq_raises(other): + # GH#45768 The PeriodArray method raises, the Series method coerces + ser = pd.Series( + PeriodArray._from_sequence(["2000", "2001", "2002"], dtype="period[D]") + ) + cond = np.array([True, False, True]) + + with pytest.raises(IncompatibleFrequency, match="freq"): + ser.array._where(cond, other) + + res = ser.where(cond, other) + expected = ser.astype(object).where(cond, other) + tm.assert_series_equal(res, expected) + + +# ---------------------------------------------------------------------------- +# Printing + + +def test_repr_small(): + arr = PeriodArray._from_sequence(["2000", "2001"], dtype="period[D]") + result = str(arr) + expected = ( + "\n['2000-01-01', '2001-01-01']\nLength: 2, dtype: period[D]" + ) + assert result == expected + + +def test_repr_large(): + arr = PeriodArray._from_sequence(["2000", "2001"] * 500, dtype="period[D]") + result = str(arr) + expected = ( + "\n" + "['2000-01-01', '2001-01-01', '2000-01-01', '2001-01-01', " + "'2000-01-01',\n" + " '2001-01-01', '2000-01-01', '2001-01-01', '2000-01-01', " + "'2001-01-01',\n" + " ...\n" + " '2000-01-01', '2001-01-01', '2000-01-01', '2001-01-01', " + "'2000-01-01',\n" + " '2001-01-01', '2000-01-01', '2001-01-01', '2000-01-01', " + "'2001-01-01']\n" + "Length: 1000, dtype: period[D]" + ) + assert result == expected diff --git a/pandas/tests/arrays/test_timedeltas.py b/pandas/tests/arrays/test_timedeltas.py new file mode 100644 index 0000000000000000000000000000000000000000..fb7c7afdc6ff984c175c7a000e92b84b607c3b70 --- /dev/null +++ b/pandas/tests/arrays/test_timedeltas.py @@ -0,0 +1,312 @@ +from datetime import timedelta + +import numpy as np +import pytest + +import pandas as pd +from pandas import Timedelta +import pandas._testing as tm +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) + + +class TestNonNano: + @pytest.fixture(params=["s", "ms", "us"]) + def unit(self, request): + return request.param + + @pytest.fixture + def tda(self, unit): + arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") + return TimedeltaArray._simple_new(arr, dtype=arr.dtype) + + def test_non_nano(self, unit): + arr = np.arange(5, dtype=np.int64).view(f"m8[{unit}]") + tda = TimedeltaArray._simple_new(arr, dtype=arr.dtype) + + assert tda.dtype == arr.dtype + assert tda[0].unit == unit + + def test_as_unit_raises(self, tda): + # GH#50616 + with pytest.raises(ValueError, match="Supported units"): + tda.as_unit("D") + + tdi = pd.Index(tda) + with pytest.raises(ValueError, match="Supported units"): + tdi.as_unit("D") + + @pytest.mark.parametrize("field", TimedeltaArray._field_ops) + def test_fields(self, tda, field): + as_nano = tda._ndarray.astype("m8[ns]") + tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) + + result = getattr(tda, field) + expected = getattr(tda_nano, field) + tm.assert_numpy_array_equal(result, expected) + + def test_to_pytimedelta(self, tda): + as_nano = tda._ndarray.astype("m8[ns]") + tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) + + result = tda.to_pytimedelta() + expected = tda_nano.to_pytimedelta() + tm.assert_numpy_array_equal(result, expected) + + def test_total_seconds(self, unit, tda): + as_nano = tda._ndarray.astype("m8[ns]") + tda_nano = TimedeltaArray._simple_new(as_nano, dtype=as_nano.dtype) + + result = tda.total_seconds() + expected = tda_nano.total_seconds() + tm.assert_numpy_array_equal(result, expected) + + def test_timedelta_array_total_seconds(self): + # GH34290 + expected = Timedelta("2 min").total_seconds() + + result = pd.array([Timedelta("2 min")]).total_seconds()[0] + assert result == expected + + def test_total_seconds_nanoseconds(self): + # issue #48521 + start_time = pd.Series(["2145-11-02 06:00:00"]).astype("datetime64[ns]") + end_time = pd.Series(["2145-11-02 07:06:00"]).astype("datetime64[ns]") + expected = (end_time - start_time).values / np.timedelta64(1, "s") + result = (end_time - start_time).dt.total_seconds().values + assert result == expected + + @pytest.mark.parametrize( + "nat", [np.datetime64("NaT", "ns"), np.datetime64("NaT", "us")] + ) + def test_add_nat_datetimelike_scalar(self, nat, tda): + result = tda + nat + assert isinstance(result, DatetimeArray) + assert result._creso == tda._creso + assert result.isna().all() + + result = nat + tda + assert isinstance(result, DatetimeArray) + assert result._creso == tda._creso + assert result.isna().all() + + def test_add_pdnat(self, tda): + result = tda + pd.NaT + assert isinstance(result, TimedeltaArray) + assert result._creso == tda._creso + assert result.isna().all() + + result = pd.NaT + tda + assert isinstance(result, TimedeltaArray) + assert result._creso == tda._creso + assert result.isna().all() + + # TODO: 2022-07-11 this is the only test that gets to DTA.tz_convert + # or tz_localize with non-nano; implement tests specific to that. + def test_add_datetimelike_scalar(self, tda, tz_naive_fixture): + ts = pd.Timestamp("2016-01-01", tz=tz_naive_fixture).as_unit("ns") + + expected = tda.as_unit("ns") + ts + res = tda + ts + tm.assert_extension_array_equal(res, expected) + res = ts + tda + tm.assert_extension_array_equal(res, expected) + + ts += Timedelta(1) # case where we can't cast losslessly + + exp_values = tda._ndarray + ts.asm8 + expected = ( + DatetimeArray._simple_new(exp_values, dtype=exp_values.dtype) + .tz_localize("UTC") + .tz_convert(ts.tz) + ) + + result = tda + ts + tm.assert_extension_array_equal(result, expected) + + result = ts + tda + tm.assert_extension_array_equal(result, expected) + + def test_mul_scalar(self, tda): + other = 2 + result = tda * other + expected = TimedeltaArray._simple_new(tda._ndarray * other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_mul_listlike(self, tda): + other = np.arange(len(tda)) + result = tda * other + expected = TimedeltaArray._simple_new(tda._ndarray * other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_mul_listlike_object(self, tda): + other = np.arange(len(tda)) + result = tda * other.astype(object) + expected = TimedeltaArray._simple_new(tda._ndarray * other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_div_numeric_scalar(self, tda): + other = 2 + result = tda / other + expected = TimedeltaArray._simple_new(tda._ndarray / other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_div_td_scalar(self, tda): + other = timedelta(seconds=1) + result = tda / other + expected = tda._ndarray / np.timedelta64(1, "s") + tm.assert_numpy_array_equal(result, expected) + + def test_div_numeric_array(self, tda): + other = np.arange(len(tda)) + result = tda / other + expected = TimedeltaArray._simple_new(tda._ndarray / other, dtype=tda.dtype) + tm.assert_extension_array_equal(result, expected) + assert result._creso == tda._creso + + def test_div_td_array(self, tda): + other = tda._ndarray + tda._ndarray[-1] + result = tda / other + expected = tda._ndarray / other + tm.assert_numpy_array_equal(result, expected) + + def test_add_timedeltaarraylike(self, tda): + tda_nano = tda.astype("m8[ns]") + + expected = tda_nano * 2 + res = tda_nano + tda + tm.assert_extension_array_equal(res, expected) + res = tda + tda_nano + tm.assert_extension_array_equal(res, expected) + + expected = tda_nano * 0 + res = tda - tda_nano + tm.assert_extension_array_equal(res, expected) + + res = tda_nano - tda + tm.assert_extension_array_equal(res, expected) + + +class TestTimedeltaArray: + def test_astype_int(self, any_int_numpy_dtype): + arr = TimedeltaArray._from_sequence( + [Timedelta("1h"), Timedelta("2h")], dtype="m8[ns]" + ) + + if np.dtype(any_int_numpy_dtype) != np.int64: + with pytest.raises(TypeError, match=r"Do obj.astype\('int64'\)"): + arr.astype(any_int_numpy_dtype) + return + + result = arr.astype(any_int_numpy_dtype) + expected = arr._ndarray.view("i8") + tm.assert_numpy_array_equal(result, expected) + + def test_setitem_clears_freq(self): + a = pd.timedelta_range("1h", periods=2, freq="h")._data + a[0] = Timedelta("1h") + assert a.freq is None + + @pytest.mark.parametrize( + "obj", + [ + Timedelta(seconds=1), + Timedelta(seconds=1).to_timedelta64(), + Timedelta(seconds=1).to_pytimedelta(), + ], + ) + def test_setitem_objects(self, obj): + # make sure we accept timedelta64 and timedelta in addition to Timedelta + tdi = pd.timedelta_range("2 Days", periods=4, freq="h") + arr = tdi._data + + arr[0] = obj + assert arr[0] == Timedelta(seconds=1) + + @pytest.mark.parametrize( + "other", + [ + 1, + np.int64(1), + 1.0, + np.datetime64("NaT"), + pd.Timestamp("2021-01-01"), + "invalid", + np.arange(10, dtype="i8") * 24 * 3600 * 10**9, + (np.arange(10) * 24 * 3600 * 10**9).view("datetime64[ns]"), + pd.Timestamp("2021-01-01").to_period("D"), + ], + ) + @pytest.mark.parametrize("index", [True, False]) + def test_searchsorted_invalid_types(self, other, index): + data = np.arange(10, dtype="i8") * 24 * 3600 * 10**9 + arr = pd.TimedeltaIndex(data, freq="D")._data + if index: + arr = pd.Index(arr) + + msg = "|".join( + [ + "searchsorted requires compatible dtype or scalar", + "value should be a 'Timedelta', 'NaT', or array of those. Got", + ] + ) + with pytest.raises(TypeError, match=msg): + arr.searchsorted(other) + + +class TestUnaryOps: + def test_abs(self): + vals = np.array([-3600 * 10**9, "NaT", 7200 * 10**9], dtype="m8[ns]") + arr = TimedeltaArray._from_sequence(vals, dtype=vals.dtype) + + evals = np.array([3600 * 10**9, "NaT", 7200 * 10**9], dtype="m8[ns]") + expected = TimedeltaArray._from_sequence(evals, dtype=evals.dtype) + + result = abs(arr) + tm.assert_timedelta_array_equal(result, expected) + + result2 = np.abs(arr) + tm.assert_timedelta_array_equal(result2, expected) + + def test_pos(self): + vals = np.array([-3600 * 10**9, "NaT", 7200 * 10**9], dtype="m8[ns]") + arr = TimedeltaArray._from_sequence(vals, dtype=vals.dtype) + + result = +arr + tm.assert_timedelta_array_equal(result, arr) + assert not tm.shares_memory(result, arr) + + result2 = np.positive(arr) + tm.assert_timedelta_array_equal(result2, arr) + assert not tm.shares_memory(result2, arr) + + def test_neg(self): + vals = np.array([-3600 * 10**9, "NaT", 7200 * 10**9], dtype="m8[ns]") + arr = TimedeltaArray._from_sequence(vals, dtype=vals.dtype) + + evals = np.array([3600 * 10**9, "NaT", -7200 * 10**9], dtype="m8[ns]") + expected = TimedeltaArray._from_sequence(evals) + + result = -arr + tm.assert_timedelta_array_equal(result, expected) + + result2 = np.negative(arr) + tm.assert_timedelta_array_equal(result2, expected) + + def test_neg_freq(self): + tdi = pd.timedelta_range("2 Days", periods=4, freq="h") + arr = tdi._data + + expected = -tdi._data + + result = -arr + tm.assert_timedelta_array_equal(result, expected) + + result2 = np.negative(arr) + tm.assert_timedelta_array_equal(result2, expected) diff --git a/pandas/tests/computation/__init__.py b/pandas/tests/computation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/computation/test_compat.py b/pandas/tests/computation/test_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..856a5b3a22a95d35cc577050f52d762b065e3ddf --- /dev/null +++ b/pandas/tests/computation/test_compat.py @@ -0,0 +1,32 @@ +import pytest + +from pandas.compat._optional import VERSIONS + +import pandas as pd +from pandas.core.computation import expr +from pandas.core.computation.engines import ENGINES +from pandas.util.version import Version + + +def test_compat(): + # test we have compat with our version of numexpr + + from pandas.core.computation.check import NUMEXPR_INSTALLED + + ne = pytest.importorskip("numexpr") + + ver = ne.__version__ + if Version(ver) < Version(VERSIONS["numexpr"]): + assert not NUMEXPR_INSTALLED + else: + assert NUMEXPR_INSTALLED + + +@pytest.mark.parametrize("engine", ENGINES) +@pytest.mark.parametrize("parser", expr.PARSERS) +def test_invalid_numexpr_version(engine, parser): + if engine == "numexpr": + pytest.importorskip("numexpr") + a, b = 1, 2 # noqa: F841 + res = pd.eval("a + b", engine=engine, parser=parser) + assert res == 3 diff --git a/pandas/tests/computation/test_eval.py b/pandas/tests/computation/test_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..e14b997310baa321f85090e6c4dfc9068867e891 --- /dev/null +++ b/pandas/tests/computation/test_eval.py @@ -0,0 +1,2044 @@ +from __future__ import annotations + +from functools import reduce +from itertools import product +import operator + +import numpy as np +import pytest + +from pandas.compat import ( + PY312, + PY314, +) +from pandas.compat._optional import import_optional_dependency +from pandas.errors import ( + NumExprClobberingError, + PerformanceWarning, + UndefinedVariableError, +) +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import ( + is_bool, + is_float, + is_list_like, + is_scalar, +) + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + date_range, + period_range, + timedelta_range, +) +import pandas._testing as tm +from pandas.core.computation import ( + expr, + pytables, +) +from pandas.core.computation.engines import ENGINES +from pandas.core.computation.expr import ( + BaseExprVisitor, + PandasExprVisitor, + PythonExprVisitor, +) +from pandas.core.computation.expressions import ( + NUMEXPR_INSTALLED, + USE_NUMEXPR, +) +from pandas.core.computation.ops import ( + ARITH_OPS_SYMS, + _binary_math_ops, + _binary_ops_dict, + _unary_math_ops, +) +from pandas.core.computation.scope import DEFAULT_GLOBALS +from pandas.util.version import Version + +numexpr = import_optional_dependency("numexpr", errors="ignore") + + +@pytest.fixture( + params=( + pytest.param( + engine, + marks=[ + pytest.mark.skipif( + engine == "numexpr" and not USE_NUMEXPR, + reason=f"numexpr enabled->{USE_NUMEXPR}, " + f"installed->{NUMEXPR_INSTALLED}", + ), + td.skip_if_no("numexpr"), + ], + ) + for engine in ENGINES + ) +) +def engine(request): + return request.param + + +@pytest.fixture(params=expr.PARSERS) +def parser(request): + return request.param + + +def _eval_single_bin(lhs, cmp1, rhs, engine): + c = _binary_ops_dict[cmp1] + if ENGINES[engine].has_neg_frac: + try: + return c(lhs, rhs) + except ValueError as e: + if str(e).startswith( + "negative number cannot be raised to a fractional power" + ): + return np.nan + raise + return c(lhs, rhs) + + +# TODO: using range(5) here is a kludge +@pytest.fixture( + params=list(range(5)), + ids=["DataFrame", "Series", "SeriesNaN", "DataFrameNaN", "float"], +) +def lhs(request): + rng = np.random.default_rng(2) + if request.param == 0: + return DataFrame(rng.standard_normal((10, 5))) + elif request.param == 1: + return Series(rng.standard_normal(5)) + elif request.param == 2: + return Series([1, 2, np.nan, np.nan, 5]) + elif request.param == 3: + nan_df1 = DataFrame(rng.standard_normal((10, 5))) + nan_df1[nan_df1 > 0.5] = np.nan + return nan_df1 + elif request.param == 4: + return rng.standard_normal() + else: + raise ValueError(f"{request.param}") + + +rhs = lhs +midhs = lhs + + +@pytest.fixture +def idx_func_dict(): + return { + "i": lambda n: Index(np.arange(n), dtype=np.int64), + "f": lambda n: Index(np.arange(n), dtype=np.float64), + "s": lambda n: Index([f"{i}_{chr(i)}" for i in range(97, 97 + n)]), + "dt": lambda n: date_range("2020-01-01", periods=n), + "td": lambda n: timedelta_range("1 day", periods=n), + "p": lambda n: period_range("2020-01-01", periods=n, freq="D"), + } + + +class TestEval: + @pytest.mark.parametrize( + "cmp1", + ["!=", "==", "<=", ">=", "<", ">"], + ids=["ne", "eq", "le", "ge", "lt", "gt"], + ) + @pytest.mark.parametrize("cmp2", [">", "<"], ids=["gt", "lt"]) + @pytest.mark.parametrize("binop", expr.BOOL_OPS_SYMS) + def test_complex_cmp_ops(self, cmp1, cmp2, binop, lhs, rhs, engine, parser): + if parser == "python" and binop in ["and", "or"]: + msg = "'BoolOp' nodes are not implemented" + ex = f"(lhs {cmp1} rhs) {binop} (lhs {cmp2} rhs)" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + return + + lhs_new = _eval_single_bin(lhs, cmp1, rhs, engine) + rhs_new = _eval_single_bin(lhs, cmp2, rhs, engine) + expected = _eval_single_bin(lhs_new, binop, rhs_new, engine) + + ex = f"(lhs {cmp1} rhs) {binop} (lhs {cmp2} rhs)" + result = pd.eval(ex, engine=engine, parser=parser) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("cmp_op", expr.CMP_OPS_SYMS) + def test_simple_cmp_ops(self, cmp_op, lhs, rhs, engine, parser): + lhs = lhs < 0 + rhs = rhs < 0 + + if parser == "python" and cmp_op in ["in", "not in"]: + msg = "'(In|NotIn)' nodes are not implemented" + ex = f"lhs {cmp_op} rhs" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + return + + ex = f"lhs {cmp_op} rhs" + msg = "|".join( + [ + r"only list-like( or dict-like)? objects are allowed to be " + r"passed to (DataFrame\.)?isin\(\), you passed a " + r"(`|')bool(`|')", + "argument of type 'bool' is not .*", + ] + ) + if cmp_op in ("in", "not in") and not is_list_like(rhs): + with pytest.raises(TypeError, match=msg): + pd.eval( + ex, + engine=engine, + parser=parser, + local_dict={"lhs": lhs, "rhs": rhs}, + ) + else: + expected = _eval_single_bin(lhs, cmp_op, rhs, engine) + result = pd.eval(ex, engine=engine, parser=parser) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("op", expr.CMP_OPS_SYMS) + def test_compound_invert_op(self, op, lhs, rhs, request, engine, parser): + if parser == "python" and op in ["in", "not in"]: + msg = "'(In|NotIn)' nodes are not implemented" + ex = f"~(lhs {op} rhs)" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + return + + if ( + is_float(lhs) + and not is_float(rhs) + and op in ["in", "not in"] + and engine == "python" + and parser == "pandas" + ): + mark = pytest.mark.xfail( + reason="Looks like expected is negative, unclear whether " + "expected is incorrect or result is incorrect" + ) + request.applymarker(mark) + skip_these = ["in", "not in"] + ex = f"~(lhs {op} rhs)" + + msg = "|".join( + [ + r"only list-like( or dict-like)? objects are allowed to be " + r"passed to (DataFrame\.)?isin\(\), you passed a " + r"(`|')float(`|')", + "argument of type 'float' is not .*", + ] + ) + if is_scalar(rhs) and op in skip_these: + with pytest.raises(TypeError, match=msg): + pd.eval( + ex, + engine=engine, + parser=parser, + local_dict={"lhs": lhs, "rhs": rhs}, + ) + else: + # compound + if is_scalar(lhs) and is_scalar(rhs): + lhs, rhs = (np.array([x]) for x in (lhs, rhs)) + expected = _eval_single_bin(lhs, op, rhs, engine) + if is_scalar(expected): + expected = not expected + else: + expected = ~expected + result = pd.eval(ex, engine=engine, parser=parser) + tm.assert_almost_equal(expected, result) + + @pytest.mark.parametrize("cmp1", ["<", ">"]) + @pytest.mark.parametrize("cmp2", ["<", ">"]) + def test_chained_cmp_op(self, cmp1, cmp2, lhs, midhs, rhs, engine, parser): + mid = midhs + if parser == "python": + ex1 = f"lhs {cmp1} mid {cmp2} rhs" + msg = "'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex1, engine=engine, parser=parser) + return + + lhs_new = _eval_single_bin(lhs, cmp1, mid, engine) + rhs_new = _eval_single_bin(mid, cmp2, rhs, engine) + + if lhs_new is not None and rhs_new is not None: + ex1 = f"lhs {cmp1} mid {cmp2} rhs" + ex2 = f"lhs {cmp1} mid and mid {cmp2} rhs" + ex3 = f"(lhs {cmp1} mid) & (mid {cmp2} rhs)" + expected = _eval_single_bin(lhs_new, "&", rhs_new, engine) + + for ex in (ex1, ex2, ex3): + result = pd.eval(ex, engine=engine, parser=parser) + + tm.assert_almost_equal(result, expected) + + @pytest.mark.parametrize( + "arith1", sorted(set(ARITH_OPS_SYMS).difference({"**", "//", "%"})) + ) + def test_binary_arith_ops(self, arith1, lhs, rhs, engine, parser): + ex = f"lhs {arith1} rhs" + result = pd.eval(ex, engine=engine, parser=parser) + expected = _eval_single_bin(lhs, arith1, rhs, engine) + + tm.assert_almost_equal(result, expected) + ex = f"lhs {arith1} rhs {arith1} rhs" + result = pd.eval(ex, engine=engine, parser=parser) + nlhs = _eval_single_bin(lhs, arith1, rhs, engine) + try: + nlhs, ghs = nlhs.align(rhs) + except (ValueError, TypeError, AttributeError): + # ValueError: series frame or frame series align + # TypeError, AttributeError: series or frame with scalar align + return + else: + if engine == "numexpr": + import numexpr as ne + + # direct numpy comparison + expected = ne.evaluate(f"nlhs {arith1} ghs") + # Update assert statement due to unreliable numerical + # precision component (GH37328) + # TODO: update testing code so that assert_almost_equal statement + # can be replaced again by the assert_numpy_array_equal statement + tm.assert_almost_equal(result.values, expected) + else: + expected = eval(f"nlhs {arith1} ghs") + tm.assert_almost_equal(result, expected) + + # modulus, pow, and floor division require special casing + + def test_modulus(self, lhs, rhs, engine, parser): + ex = r"lhs % rhs" + result = pd.eval(ex, engine=engine, parser=parser) + expected = lhs % rhs + tm.assert_almost_equal(result, expected) + + if engine == "numexpr": + import numexpr as ne + + expected = ne.evaluate(r"expected % rhs") + if isinstance(result, (DataFrame, Series)): + tm.assert_almost_equal(result.values, expected) + else: + tm.assert_almost_equal(result, expected.item()) + else: + expected = _eval_single_bin(expected, "%", rhs, engine) + tm.assert_almost_equal(result, expected) + + def test_floor_division(self, lhs, rhs, engine, parser): + ex = "lhs // rhs" + + if engine == "python" or ( + engine == "numexpr" and Version(numexpr.__version__) >= Version("2.13.0") + ): + res = pd.eval(ex, engine=engine, parser=parser) + expected = lhs // rhs + tm.assert_equal(res, expected) + else: + msg = ( + r"unsupported operand type\(s\) for //: 'VariableNode' and " + "'VariableNode'" + ) + with pytest.raises(TypeError, match=msg): + pd.eval( + ex, + local_dict={"lhs": lhs, "rhs": rhs}, + engine=engine, + parser=parser, + ) + + @td.skip_if_windows + def test_pow(self, lhs, rhs, engine, parser): + # odd failure on win32 platform, so skip + ex = "lhs ** rhs" + expected = _eval_single_bin(lhs, "**", rhs, engine) + result = pd.eval(ex, engine=engine, parser=parser) + + if ( + is_scalar(lhs) + and is_scalar(rhs) + and isinstance(expected, (complex, np.complexfloating)) + and np.isnan(result) + ): + msg = "(DataFrame.columns|numpy array) are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_numpy_array_equal(result, expected) + else: + tm.assert_almost_equal(result, expected) + + ex = "(lhs ** rhs) ** rhs" + result = pd.eval(ex, engine=engine, parser=parser) + + middle = _eval_single_bin(lhs, "**", rhs, engine) + expected = _eval_single_bin(middle, "**", rhs, engine) + tm.assert_almost_equal(result, expected) + + def test_check_single_invert_op(self, lhs, engine, parser): + # simple + try: + elb = lhs.astype(bool) + except AttributeError: + elb = np.array([bool(lhs)]) + expected = ~elb + result = pd.eval("~elb", engine=engine, parser=parser) + tm.assert_almost_equal(expected, result) + + def test_frame_invert(self, engine, parser): + expr = "~lhs" + + # ~ ## + # frame + # float always raises + lhs = DataFrame(np.random.default_rng(2).standard_normal((5, 2))) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'invert_dd'" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + msg = "ufunc 'invert' not supported for the input types" + with pytest.raises(TypeError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + + # int raises on numexpr + lhs = DataFrame(np.random.default_rng(2).integers(5, size=(5, 2))) + if engine == "numexpr" and Version(numexpr.__version__) < Version("2.13.0"): + msg = "couldn't find matching opcode for 'invert" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + expect = ~lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + # bool always works + lhs = DataFrame(np.random.default_rng(2).standard_normal((5, 2)) > 0.5) + expect = ~lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + # object raises + lhs = DataFrame( + {"b": ["a", 1, 2.0], "c": np.random.default_rng(2).standard_normal(3) > 0.5} + ) + if engine == "numexpr": + with pytest.raises(ValueError, match="unknown type object"): + pd.eval(expr, engine=engine, parser=parser) + else: + msg = "bad operand type for unary ~: 'str'" + with pytest.raises(TypeError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + + def test_series_invert(self, engine, parser): + # ~ #### + expr = "~lhs" + + # series + # float raises + lhs = Series(np.random.default_rng(2).standard_normal(5)) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'invert_dd'" + with pytest.raises(NotImplementedError, match=msg): + result = pd.eval(expr, engine=engine, parser=parser) + else: + msg = "ufunc 'invert' not supported for the input types" + with pytest.raises(TypeError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + + # int raises on numexpr + lhs = Series(np.random.default_rng(2).integers(5, size=5)) + if engine == "numexpr" and Version(numexpr.__version__) < Version("2.13.0"): + msg = "couldn't find matching opcode for 'invert" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + expect = ~lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + # bool + lhs = Series(np.random.default_rng(2).standard_normal(5) > 0.5) + expect = ~lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + # float + # int + # bool + + # object + lhs = Series(["a", 1, 2.0]) + if engine == "numexpr": + with pytest.raises(ValueError, match="unknown type object"): + pd.eval(expr, engine=engine, parser=parser) + else: + msg = "bad operand type for unary ~: 'str'" + with pytest.raises(TypeError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + + def test_frame_negate(self, engine, parser): + expr = "-lhs" + + # float + lhs = DataFrame(np.random.default_rng(2).standard_normal((5, 2))) + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + # int + lhs = DataFrame(np.random.default_rng(2).integers(5, size=(5, 2))) + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + # bool doesn't work with numexpr but works elsewhere + lhs = DataFrame(np.random.default_rng(2).standard_normal((5, 2)) > 0.5) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'neg_bb'" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + def test_series_negate(self, engine, parser): + expr = "-lhs" + + # float + lhs = Series(np.random.default_rng(2).standard_normal(5)) + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + # int + lhs = Series(np.random.default_rng(2).integers(5, size=5)) + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + # bool doesn't work with numexpr but works elsewhere + lhs = Series(np.random.default_rng(2).standard_normal(5) > 0.5) + if engine == "numexpr": + msg = "couldn't find matching opcode for 'neg_bb'" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(expr, engine=engine, parser=parser) + else: + expect = -lhs + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + @pytest.mark.parametrize( + "lhs", + [ + # Float + np.random.default_rng(2).standard_normal((5, 2)), + # Int + np.random.default_rng(2).integers(5, size=(5, 2)), + # bool doesn't work with numexpr but works elsewhere + np.array([True, False, True, False, True], dtype=np.bool_), + ], + ) + def test_frame_pos(self, lhs, engine, parser): + lhs = DataFrame(lhs) + expr = "+lhs" + expect = lhs + + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_frame_equal(expect, result) + + @pytest.mark.parametrize( + "lhs", + [ + # Float + np.random.default_rng(2).standard_normal(5), + # Int + np.random.default_rng(2).integers(5, size=5), + # bool doesn't work with numexpr but works elsewhere + np.array([True, False, True, False, True], dtype=np.bool_), + ], + ) + def test_series_pos(self, lhs, engine, parser): + lhs = Series(lhs) + expr = "+lhs" + expect = lhs + + result = pd.eval(expr, engine=engine, parser=parser) + tm.assert_series_equal(expect, result) + + def test_scalar_unary(self, engine, parser): + msg = "bad operand type for unary ~: 'float'" + warn = None + if (PY314 and engine == "numexpr" and parser == "pandas") or ( + PY312 and not (engine == "numexpr" and parser == "pandas") + ): + warn = DeprecationWarning + with pytest.raises(TypeError, match=msg): + pd.eval("~1.0", engine=engine, parser=parser) + + assert pd.eval("-1.0", parser=parser, engine=engine) == -1.0 + assert pd.eval("+1.0", parser=parser, engine=engine) == +1.0 + assert pd.eval("~1", parser=parser, engine=engine) == ~1 + assert pd.eval("-1", parser=parser, engine=engine) == -1 + assert pd.eval("+1", parser=parser, engine=engine) == +1 + with tm.assert_produces_warning( + warn, match="Bitwise inversion", check_stacklevel=False + ): + assert pd.eval("~True", parser=parser, engine=engine) == ~True + with tm.assert_produces_warning( + warn, match="Bitwise inversion", check_stacklevel=False + ): + assert pd.eval("~False", parser=parser, engine=engine) == ~False + assert pd.eval("-True", parser=parser, engine=engine) == -True + assert pd.eval("-False", parser=parser, engine=engine) == -False + assert pd.eval("+True", parser=parser, engine=engine) == +True + assert pd.eval("+False", parser=parser, engine=engine) == +False + + def test_unary_in_array(self): + # GH 11235 + # TODO: 2022-01-29: result return list with numexpr 2.7.3 in CI + # but cannot reproduce locally + result = np.array( + pd.eval("[-True, True, +True, -False, False, +False, -37, 37, ~37, +37]"), + dtype=np.object_, + ) + expected = np.array( + [ + -True, + True, + +True, + -False, + False, + +False, + -37, + 37, + ~37, + +37, + ], + dtype=np.object_, + ) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("expr", ["x < -0.1", "-5 > x"]) + def test_float_comparison_bin_op(self, float_numpy_dtype, expr): + # GH 16363 + df = DataFrame({"x": np.array([0], dtype=float_numpy_dtype)}) + res = df.eval(expr) + assert res.values == np.array([False]) + + def test_unary_in_function(self): + # GH 46471 + df = DataFrame({"x": [0, 1, np.nan]}) + + result = df.eval("x.fillna(-1)") + expected = df.x.fillna(-1) + # column name becomes None if using numexpr + # only check names when the engine is not numexpr + tm.assert_series_equal(result, expected, check_names=not USE_NUMEXPR) + + result = df.eval("x.shift(1, fill_value=-1)") + expected = df.x.shift(1, fill_value=-1) + tm.assert_series_equal(result, expected, check_names=not USE_NUMEXPR) + + @pytest.mark.parametrize( + "ex", + ( + "1 or 2", + "1 and 2", + "a and b", + "a or b", + "1 or 2 and (3 + 2) > 3", + "2 * x > 2 or 1 and 2", + "2 * df > 3 and 1 or a", + ), + ) + def test_disallow_scalar_bool_ops(self, ex, engine, parser): + x, a, b = np.random.default_rng(2).standard_normal(3), 1, 2 # noqa: F841 + df = DataFrame(np.random.default_rng(2).standard_normal((3, 2))) # noqa: F841 + + msg = "cannot evaluate scalar only bool ops|'BoolOp' nodes are not" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + + def test_identical(self, engine, parser): + # see gh-10546 + x = 1 + result = pd.eval("x", engine=engine, parser=parser) + assert result == 1 + assert is_scalar(result) + + x = 1.5 + result = pd.eval("x", engine=engine, parser=parser) + assert result == 1.5 + assert is_scalar(result) + + x = False + result = pd.eval("x", engine=engine, parser=parser) + assert not result + assert is_bool(result) + assert is_scalar(result) + + x = np.array([1]) + result = pd.eval("x", engine=engine, parser=parser) + tm.assert_numpy_array_equal(result, np.array([1])) + assert result.shape == (1,) + + x = np.array([1.5]) + result = pd.eval("x", engine=engine, parser=parser) + tm.assert_numpy_array_equal(result, np.array([1.5])) + assert result.shape == (1,) + + x = np.array([False]) # noqa: F841 + result = pd.eval("x", engine=engine, parser=parser) + tm.assert_numpy_array_equal(result, np.array([False])) + assert result.shape == (1,) + + def test_line_continuation(self, engine, parser): + # GH 11149 + exp = """1 + 2 * \ + 5 - 1 + 2 """ + result = pd.eval(exp, engine=engine, parser=parser) + assert result == 12 + + def test_float_truncation(self, engine, parser): + # GH 14241 + exp = "1000000000.006" + result = pd.eval(exp, engine=engine, parser=parser) + expected = np.float64(exp) + assert result == expected + + df = DataFrame({"A": [1000000000.0009, 1000000000.0011, 1000000000.0015]}) + cutoff = 1000000000.0006 + result = df.query(f"A < {cutoff:.4f}") + assert result.empty + + cutoff = 1000000000.0010 + result = df.query(f"A > {cutoff:.4f}") + expected = df.loc[[1, 2], :] + tm.assert_frame_equal(expected, result) + + exact = 1000000000.0011 + result = df.query(f"A == {exact:.4f}") + expected = df.loc[[1], :] + tm.assert_frame_equal(expected, result) + + def test_disallow_python_keywords(self): + # GH 18221 + df = DataFrame([[0, 0, 0]], columns=["foo", "bar", "class"]) + msg = "Python keyword not valid identifier in numexpr query" + with pytest.raises(SyntaxError, match=msg): + df.query("class == 0") + + df = DataFrame() + df.index.name = "lambda" + with pytest.raises(SyntaxError, match=msg): + df.query("lambda == 0") + + def test_true_false_logic(self): + # GH 25823 + # This behavior is deprecated in Python 3.12 + with tm.maybe_produces_warning( + DeprecationWarning, PY312, check_stacklevel=False + ): + assert pd.eval("not True") == -2 + assert pd.eval("not False") == -1 + assert pd.eval("True and not True") == 0 + + def test_and_logic_string_match(self): + # GH 25823 + event = Series({"a": "hello"}) + assert pd.eval(f"{event.str.match('hello').a}") + assert pd.eval(f"{event.str.match('hello').a and event.str.match('hello').a}") + + def test_eval_keep_name(self, engine, parser): + df = Series([2, 15, 28], name="a").to_frame() + res = df.eval("a + a", engine=engine, parser=parser) + expected = Series([4, 30, 56], name="a") + tm.assert_series_equal(expected, res) + + def test_eval_unmatching_names(self, engine, parser): + variable_name = Series([42], name="series_name") + res = pd.eval("variable_name + 0", engine=engine, parser=parser) + tm.assert_series_equal(variable_name, res) + + +# ------------------------------------- +# gh-12388: Typecasting rules consistency with python + + +class TestTypeCasting: + @pytest.mark.parametrize("op", ["+", "-", "*", "**", "/"]) + # maybe someday... numexpr has too many upcasting rules now + # chain(*(np.core.sctypes[x] for x in ['uint', 'int', 'float'])) + @pytest.mark.parametrize("left_right", [("df", "3"), ("3", "df")]) + def test_binop_typecasting( + self, engine, parser, op, complex_or_float_dtype, left_right, request + ): + # GH#21374 + dtype = complex_or_float_dtype + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3)), dtype=dtype) + left, right = left_right + s = f"{left} {op} {right}" + res = pd.eval(s, engine=engine, parser=parser) + if dtype == "complex64" and engine == "numexpr": + mark = pytest.mark.xfail( + reason="numexpr issue with complex that are upcast " + "to complex 128 " + "https://github.com/pydata/numexpr/issues/492" + ) + request.applymarker(mark) + assert df.values.dtype == dtype + assert res.values.dtype == dtype + tm.assert_frame_equal(res, eval(s), check_exact=False) + + +# ------------------------------------- +# Basic and complex alignment + + +def should_warn(*args): + not_mono = not any(map(operator.attrgetter("is_monotonic_increasing"), args)) + only_one_dt = reduce( + operator.xor, (issubclass(x.dtype.type, np.datetime64) for x in args) + ) + return not_mono and only_one_dt + + +class TestAlignment: + index_types = ["i", "s", "dt"] + lhs_index_types = [*index_types, "s"] # 'p' + + def test_align_nested_unary_op(self, engine, parser): + s = "df * ~2" + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + res = pd.eval(s, engine=engine, parser=parser) + tm.assert_frame_equal(res, df * ~2) + + @pytest.mark.filterwarnings("always::RuntimeWarning") + @pytest.mark.parametrize("lr_idx_type", lhs_index_types) + @pytest.mark.parametrize("rr_idx_type", index_types) + @pytest.mark.parametrize("c_idx_type", index_types) + def test_basic_frame_alignment( + self, engine, parser, lr_idx_type, rr_idx_type, c_idx_type, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 10)), + index=idx_func_dict[lr_idx_type](10), + columns=idx_func_dict[c_idx_type](10), + ) + df2 = DataFrame( + np.random.default_rng(2).standard_normal((20, 10)), + index=idx_func_dict[rr_idx_type](20), + columns=idx_func_dict[c_idx_type](10), + ) + # only warns if not monotonic and not sortable + if should_warn(df.index, df2.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("df + df2", engine=engine, parser=parser) + else: + res = pd.eval("df + df2", engine=engine, parser=parser) + tm.assert_frame_equal(res, df + df2) + + @pytest.mark.parametrize("r_idx_type", lhs_index_types) + @pytest.mark.parametrize("c_idx_type", lhs_index_types) + def test_frame_comparison( + self, engine, parser, r_idx_type, c_idx_type, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 10)), + index=idx_func_dict[r_idx_type](10), + columns=idx_func_dict[c_idx_type](10), + ) + res = pd.eval("df < 2", engine=engine, parser=parser) + tm.assert_frame_equal(res, df < 2) + + df3 = DataFrame( + np.random.default_rng(2).standard_normal(df.shape), + index=df.index, + columns=df.columns, + ) + res = pd.eval("df < df3", engine=engine, parser=parser) + tm.assert_frame_equal(res, df < df3) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("r1", lhs_index_types) + @pytest.mark.parametrize("c1", index_types) + @pytest.mark.parametrize("r2", index_types) + @pytest.mark.parametrize("c2", index_types) + def test_medium_complex_frame_alignment( + self, engine, parser, r1, c1, r2, c2, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 2)), + index=idx_func_dict[r1](3), + columns=idx_func_dict[c1](2), + ) + df2 = DataFrame( + np.random.default_rng(2).standard_normal((4, 2)), + index=idx_func_dict[r2](4), + columns=idx_func_dict[c2](2), + ) + df3 = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), + index=idx_func_dict[r2](5), + columns=idx_func_dict[c2](2), + ) + if should_warn(df.index, df2.index, df3.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("df + df2 + df3", engine=engine, parser=parser) + else: + res = pd.eval("df + df2 + df3", engine=engine, parser=parser) + tm.assert_frame_equal(res, df + df2 + df3) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("index_name", ["index", "columns"]) + @pytest.mark.parametrize("c_idx_type", index_types) + @pytest.mark.parametrize("r_idx_type", lhs_index_types) + def test_basic_frame_series_alignment( + self, engine, parser, index_name, r_idx_type, c_idx_type, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 10)), + index=idx_func_dict[r_idx_type](10), + columns=idx_func_dict[c_idx_type](10), + ) + index = getattr(df, index_name) + s = Series(np.random.default_rng(2).standard_normal(5), index[:5]) + + if should_warn(df.index, s.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("df + s", engine=engine, parser=parser) + else: + res = pd.eval("df + s", engine=engine, parser=parser) + + if r_idx_type == "dt" or c_idx_type == "dt": + expected = df.add(s) if engine == "numexpr" else df + s + else: + expected = df + s + tm.assert_frame_equal(res, expected) + + @pytest.mark.parametrize("index_name", ["index", "columns"]) + @pytest.mark.parametrize( + "r_idx_type, c_idx_type", + [*list(product(["i", "s"], ["i", "s"])), ("dt", "dt")], + ) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_basic_series_frame_alignment( + self, request, engine, parser, index_name, r_idx_type, c_idx_type, idx_func_dict + ): + if ( + engine == "numexpr" + and parser in ("pandas", "python") + and index_name == "index" + and r_idx_type == "i" + and c_idx_type == "s" + ): + reason = ( + f"Flaky column ordering when engine={engine}, " + f"parser={parser}, index_name={index_name}, " + f"r_idx_type={r_idx_type}, c_idx_type={c_idx_type}" + ) + request.applymarker(pytest.mark.xfail(reason=reason, strict=False)) + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 7)), + index=idx_func_dict[r_idx_type](10), + columns=idx_func_dict[c_idx_type](7), + ) + index = getattr(df, index_name) + s = Series(np.random.default_rng(2).standard_normal(5), index[:5]) + if should_warn(s.index, df.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("s + df", engine=engine, parser=parser) + else: + res = pd.eval("s + df", engine=engine, parser=parser) + + if r_idx_type == "dt" or c_idx_type == "dt": + expected = df.add(s) if engine == "numexpr" else s + df + else: + expected = s + df + tm.assert_frame_equal(res, expected) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("c_idx_type", index_types) + @pytest.mark.parametrize("r_idx_type", lhs_index_types) + @pytest.mark.parametrize("index_name", ["index", "columns"]) + @pytest.mark.parametrize("op", ["+", "*"]) + def test_series_frame_commutativity( + self, engine, parser, index_name, op, r_idx_type, c_idx_type, idx_func_dict + ): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 10)), + index=idx_func_dict[r_idx_type](10), + columns=idx_func_dict[c_idx_type](10), + ) + index = getattr(df, index_name) + s = Series(np.random.default_rng(2).standard_normal(5), index[:5]) + + lhs = f"s {op} df" + rhs = f"df {op} s" + if should_warn(df.index, s.index): + with tm.assert_produces_warning(RuntimeWarning): + a = pd.eval(lhs, engine=engine, parser=parser) + with tm.assert_produces_warning(RuntimeWarning): + b = pd.eval(rhs, engine=engine, parser=parser) + else: + a = pd.eval(lhs, engine=engine, parser=parser) + b = pd.eval(rhs, engine=engine, parser=parser) + + if r_idx_type != "dt" and c_idx_type != "dt": + if engine == "numexpr": + tm.assert_frame_equal(a, b) + + @pytest.mark.filterwarnings("always::RuntimeWarning") + @pytest.mark.parametrize("r1", lhs_index_types) + @pytest.mark.parametrize("c1", index_types) + @pytest.mark.parametrize("r2", index_types) + @pytest.mark.parametrize("c2", index_types) + def test_complex_series_frame_alignment( + self, engine, parser, r1, c1, r2, c2, idx_func_dict + ): + n = 3 + m1 = 5 + m2 = 2 * m1 + df = DataFrame( + np.random.default_rng(2).standard_normal((m1, n)), + index=idx_func_dict[r1](m1), + columns=idx_func_dict[c1](n), + ) + df2 = DataFrame( + np.random.default_rng(2).standard_normal((m2, n)), + index=idx_func_dict[r2](m2), + columns=idx_func_dict[c2](n), + ) + index = df2.columns + ser = Series(np.random.default_rng(2).standard_normal(n), index[:n]) + + if r2 == "dt" or c2 == "dt": + if engine == "numexpr": + expected2 = df2.add(ser) + else: + expected2 = df2 + ser + else: + expected2 = df2 + ser + + if r1 == "dt" or c1 == "dt": + if engine == "numexpr": + expected = expected2.add(df) + else: + expected = expected2 + df + else: + expected = expected2 + df + + if should_warn(df2.index, ser.index, df.index): + with tm.assert_produces_warning(RuntimeWarning): + res = pd.eval("df2 + ser + df", engine=engine, parser=parser) + else: + res = pd.eval("df2 + ser + df", engine=engine, parser=parser) + assert res.shape == expected.shape + tm.assert_frame_equal(res, expected) + + def test_performance_warning_for_poor_alignment( + self, performance_warning, engine, parser + ): + df = DataFrame(np.random.default_rng(2).standard_normal((1000, 10))) + s = Series(np.random.default_rng(2).standard_normal(10000)) + if engine == "numexpr" and performance_warning: + seen = PerformanceWarning + else: + seen = False + + msg = "Alignment difference on axis 1 is larger than an order of magnitude" + with tm.assert_produces_warning(seen, match=msg): + pd.eval("df + s", engine=engine, parser=parser) + + s = Series(np.random.default_rng(2).standard_normal(1000)) + with tm.assert_produces_warning(False): + pd.eval("df + s", engine=engine, parser=parser) + + df = DataFrame(np.random.default_rng(2).standard_normal((10, 10000))) + s = Series(np.random.default_rng(2).standard_normal(10000)) + with tm.assert_produces_warning(False): + pd.eval("df + s", engine=engine, parser=parser) + + df = DataFrame(np.random.default_rng(2).standard_normal((10, 10))) + s = Series(np.random.default_rng(2).standard_normal(10000)) + + is_python_engine = engine == "python" + + if not is_python_engine and performance_warning: + wrn = PerformanceWarning + else: + wrn = False + + with tm.assert_produces_warning(wrn, match=msg) as w: + pd.eval("df + s", engine=engine, parser=parser) + + if not is_python_engine and performance_warning: + assert len(w) == 1 + msg = str(w[0].message) + logged = np.log10(s.size - df.shape[1]) + expected = ( + f"Alignment difference on axis 1 is larger " + f"than an order of magnitude on term 'df', " + f"by more than {logged:.4g}; performance may suffer." + ) + assert msg == expected + + +# ------------------------------------ +# Slightly more complex ops + + +class TestOperations: + def eval(self, *args, **kwargs): + kwargs["level"] = kwargs.pop("level", 0) + 1 + return pd.eval(*args, **kwargs) + + def test_simple_arith_ops(self, engine, parser): + exclude_arith = [] + if parser == "python": + exclude_arith = ["in", "not in"] + + arith_ops = [ + op + for op in expr.ARITH_OPS_SYMS + expr.CMP_OPS_SYMS + if op not in exclude_arith + ] + + ops = (op for op in arith_ops if op != "//") + + for op in ops: + ex = f"1 {op} 1" + ex2 = f"x {op} 1" + ex3 = f"1 {op} (x + 1)" + + if op in ("in", "not in"): + msg = "argument of type 'int' is not .*" + with pytest.raises(TypeError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + else: + expec = _eval_single_bin(1, op, 1, engine) + x = self.eval(ex, engine=engine, parser=parser) + assert x == expec + + expec = _eval_single_bin(x, op, 1, engine) + y = self.eval(ex2, local_dict={"x": x}, engine=engine, parser=parser) + assert y == expec + + expec = _eval_single_bin(1, op, x + 1, engine) + y = self.eval(ex3, local_dict={"x": x}, engine=engine, parser=parser) + assert y == expec + + @pytest.mark.parametrize("rhs", [True, False]) + @pytest.mark.parametrize("lhs", [True, False]) + @pytest.mark.parametrize("op", expr.BOOL_OPS_SYMS) + def test_simple_bool_ops(self, rhs, lhs, op): + ex = f"{lhs} {op} {rhs}" + + if parser == "python" and op in ["and", "or"]: + msg = "'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + self.eval(ex) + return + + res = self.eval(ex) + exp = eval(ex) + assert res == exp + + @pytest.mark.parametrize("rhs", [True, False]) + @pytest.mark.parametrize("lhs", [True, False]) + @pytest.mark.parametrize("op", expr.BOOL_OPS_SYMS) + def test_bool_ops_with_constants(self, rhs, lhs, op): + ex = f"{lhs} {op} {rhs}" + + if parser == "python" and op in ["and", "or"]: + msg = "'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + self.eval(ex) + return + + res = self.eval(ex) + exp = eval(ex) + assert res == exp + + def test_4d_ndarray_fails(self): + x = np.random.default_rng(2).standard_normal((3, 4, 5, 6)) + y = Series(np.random.default_rng(2).standard_normal(10)) + msg = "N-dimensional objects, where N > 2, are not supported with eval" + with pytest.raises(NotImplementedError, match=msg): + self.eval("x + y", local_dict={"x": x, "y": y}) + + def test_constant(self): + x = self.eval("1") + assert x == 1 + + def test_single_variable(self): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2))) + df2 = self.eval("df", local_dict={"df": df}) + tm.assert_frame_equal(df, df2) + + def test_failing_subscript_with_name_error(self): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) # noqa: F841 + with pytest.raises(NameError, match="name 'x' is not defined"): + self.eval("df[x > 2] > 2") + + def test_lhs_expression_subscript(self): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + result = self.eval("(df + 1)[df > 2]", local_dict={"df": df}) + expected = (df + 1)[df > 2] + tm.assert_frame_equal(result, expected) + + def test_attr_expression(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), columns=list("abc") + ) + expr1 = "df.a < df.b" + expec1 = df.a < df.b + expr2 = "df.a + df.b + df.c" + expec2 = df.a + df.b + df.c + expr3 = "df.a + df.b + df.c[df.b < 0]" + expec3 = df.a + df.b + df.c[df.b < 0] + exprs = expr1, expr2, expr3 + expecs = expec1, expec2, expec3 + for e, expec in zip(exprs, expecs, strict=True): + tm.assert_series_equal(expec, self.eval(e, local_dict={"df": df})) + + def test_assignment_fails(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), columns=list("abc") + ) + df2 = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + expr1 = "df = df2" + msg = "cannot assign without a target object" + with pytest.raises(ValueError, match=msg): + self.eval(expr1, local_dict={"df": df, "df2": df2}) + + def test_assignment_column_multiple_raise(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # multiple assignees + with pytest.raises(SyntaxError, match="invalid syntax"): + df.eval("d c = a + b") + + def test_assignment_column_invalid_assign(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # invalid assignees + msg = "left hand side of an assignment must be a single name" + with pytest.raises(SyntaxError, match=msg): + df.eval("d,c = a + b") + + def test_assignment_column_invalid_assign_function_call(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + msg = "cannot assign to function call" + with pytest.raises(SyntaxError, match=msg): + df.eval('Timestamp("20131001") = a + b') + + def test_assignment_single_assign_existing(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # single assignment - existing variable + expected = df.copy() + expected["a"] = expected["a"] + expected["b"] + df.eval("a = a + b", inplace=True) + tm.assert_frame_equal(df, expected) + + def test_assignment_single_assign_new(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # single assignment - new variable + expected = df.copy() + expected["c"] = expected["a"] + expected["b"] + df.eval("c = a + b", inplace=True) + tm.assert_frame_equal(df, expected) + + def test_assignment_single_assign_local_overlap(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + df = df.copy() + a = 1 # noqa: F841 + df.eval("a = 1 + b", inplace=True) + + expected = df.copy() + expected["a"] = 1 + expected["b"] + tm.assert_frame_equal(df, expected) + + def test_assignment_single_assign_name(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + + a = 1 # noqa: F841 + old_a = df.a.copy() + df.eval("a = a + b", inplace=True) + result = old_a + df.b + tm.assert_series_equal(result, df.a, check_names=False) + assert result.name is None + + def test_assignment_multiple_raises(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # multiple assignment + df.eval("c = a + b", inplace=True) + msg = "can only assign a single expression" + with pytest.raises(SyntaxError, match=msg): + df.eval("c = a = b") + + def test_assignment_explicit(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + # explicit targets + self.eval("c = df.a + df.b", local_dict={"df": df}, target=df, inplace=True) + expected = df.copy() + expected["c"] = expected["a"] + expected["b"] + tm.assert_frame_equal(df, expected) + + def test_column_in(self, engine): + # GH 11235 + df = DataFrame({"a": [11], "b": [-32]}) + result = df.eval("a in [11, -32]", engine=engine) + expected = Series([True], name="a") + tm.assert_series_equal(result, expected) + + @pytest.mark.xfail(reason="Unknown: Omitted test_ in name prior.") + def test_assignment_not_inplace(self): + # see gh-9297 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=list("ab") + ) + + actual = df.eval("c = a + b", inplace=False) + assert actual is not None + + expected = df.copy() + expected["c"] = expected["a"] + expected["b"] + tm.assert_frame_equal(df, expected) + + def test_multi_line_expression(self): + # GH 11149 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expected = df.copy() + + expected["c"] = expected["a"] + expected["b"] + expected["d"] = expected["c"] + expected["b"] + answer = df.eval( + """ + c = a + b + d = c + b""", + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + expected["a"] = expected["a"] - 1 + expected["e"] = expected["a"] + 2 + answer = df.eval( + """ + a = a - 1 + e = a + 2""", + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + # multi-line not valid if not all assignments + msg = "Multi-line expressions are only valid if all expressions contain" + with pytest.raises(ValueError, match=msg): + df.eval( + """ + a = b + 2 + b - 2""", + inplace=False, + ) + + def test_multi_line_expression_not_inplace(self): + # GH 11149 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expected = df.copy() + + expected["c"] = expected["a"] + expected["b"] + expected["d"] = expected["c"] + expected["b"] + df = df.eval( + """ + c = a + b + d = c + b""", + inplace=False, + ) + tm.assert_frame_equal(expected, df) + + expected["a"] = expected["a"] - 1 + expected["e"] = expected["a"] + 2 + df = df.eval( + """ + a = a - 1 + e = a + 2""", + inplace=False, + ) + tm.assert_frame_equal(expected, df) + + def test_multi_line_expression_local_variable(self): + # GH 15342 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expected = df.copy() + + local_var = 7 + expected["c"] = expected["a"] * local_var + expected["d"] = expected["c"] + local_var + answer = df.eval( + """ + c = a * @local_var + d = c + @local_var + """, + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + def test_multi_line_expression_callable_local_variable(self): + # 26426 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + def local_func(a, b): + return b + + expected = df.copy() + expected["c"] = expected["a"] * local_func(1, 7) + expected["d"] = expected["c"] + local_func(1, 7) + answer = df.eval( + """ + c = a * @local_func(1, 7) + d = c + @local_func(1, 7) + """, + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + def test_multi_line_expression_callable_local_variable_with_kwargs(self): + # 26426 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + def local_func(a, b): + return b + + expected = df.copy() + expected["c"] = expected["a"] * local_func(b=7, a=1) + expected["d"] = expected["c"] + local_func(b=7, a=1) + answer = df.eval( + """ + c = a * @local_func(b=7, a=1) + d = c + @local_func(b=7, a=1) + """, + inplace=True, + ) + tm.assert_frame_equal(expected, df) + assert answer is None + + def test_assignment_in_query(self): + # GH 8664 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_orig = df.copy() + msg = "cannot assign without a target object" + with pytest.raises(ValueError, match=msg): + df.query("a = 1") + tm.assert_frame_equal(df, df_orig) + + def test_query_inplace(self): + # see gh-11149 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expected = df.copy() + expected = expected[expected["a"] == 2] + df.query("a == 2", inplace=True) + tm.assert_frame_equal(expected, df) + + df = {} + expected = {"a": 3} + + self.eval("a = 1 + 2", target=df, inplace=True) + tm.assert_dict_equal(df, expected) + + @pytest.mark.parametrize("invalid_target", [1, "cat", [1, 2], np.array([]), (1, 3)]) + def test_cannot_item_assign(self, invalid_target): + msg = "Cannot assign expression output to target" + expression = "a = 1 + 2" + + with pytest.raises(ValueError, match=msg): + self.eval(expression, target=invalid_target, inplace=True) + + if hasattr(invalid_target, "copy"): + with pytest.raises(ValueError, match=msg): + self.eval(expression, target=invalid_target, inplace=False) + + @pytest.mark.parametrize("invalid_target", [1, "cat", (1, 3)]) + def test_cannot_copy_item(self, invalid_target): + msg = "Cannot return a copy of the target" + expression = "a = 1 + 2" + + with pytest.raises(ValueError, match=msg): + self.eval(expression, target=invalid_target, inplace=False) + + @pytest.mark.parametrize("target", [1, "cat", [1, 2], np.array([]), (1, 3), {1: 2}]) + def test_inplace_no_assignment(self, target): + expression = "1 + 2" + + assert self.eval(expression, target=target, inplace=False) == 3 + + msg = "Cannot operate inplace if there is no assignment" + with pytest.raises(ValueError, match=msg): + self.eval(expression, target=target, inplace=True) + + def test_basic_period_index_boolean_expression(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((2, 2)), + columns=period_range("2020-01-01", freq="D", periods=2), + ) + e = df < 2 + r = self.eval("df < 2", local_dict={"df": df}) + x = df < 2 + + tm.assert_frame_equal(r, e) + tm.assert_frame_equal(x, e) + + def test_basic_period_index_subscript_expression(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((2, 2)), + columns=period_range("2020-01-01", freq="D", periods=2), + ) + r = self.eval("df[df < 2 + 3]", local_dict={"df": df}) + e = df[df < 2 + 3] + tm.assert_frame_equal(r, e) + + def test_nested_period_index_subscript_expression(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((2, 2)), + columns=period_range("2020-01-01", freq="D", periods=2), + ) + r = self.eval("df[df[df < 2] < 2] + df * 2", local_dict={"df": df}) + e = df[df[df < 2] < 2] + df * 2 + tm.assert_frame_equal(r, e) + + def test_date_boolean(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df["dates1"] = date_range("1/1/2012", periods=5) + res = self.eval( + "df.dates1 < 20130101", + local_dict={"df": df}, + engine=engine, + parser=parser, + ) + expec = df.dates1 < "20130101" + tm.assert_series_equal(res, expec) + + def test_simple_in_ops(self, engine, parser): + if parser != "python": + res = pd.eval("1 in [1, 2]", engine=engine, parser=parser) + assert res + + res = pd.eval("2 in (1, 2)", engine=engine, parser=parser) + assert res + + res = pd.eval("3 in (1, 2)", engine=engine, parser=parser) + assert not res + + res = pd.eval("3 not in (1, 2)", engine=engine, parser=parser) + assert res + + res = pd.eval("[3] not in (1, 2)", engine=engine, parser=parser) + assert res + + res = pd.eval("[3] in ([3], 2)", engine=engine, parser=parser) + assert res + + res = pd.eval("[[3]] in [[[3]], 2]", engine=engine, parser=parser) + assert res + + res = pd.eval("(3,) in [(3,), 2]", engine=engine, parser=parser) + assert res + + res = pd.eval("(3,) not in [(3,), 2]", engine=engine, parser=parser) + assert not res + + res = pd.eval("[(3,)] in [[(3,)], 2]", engine=engine, parser=parser) + assert res + else: + msg = "'In' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + pd.eval("1 in [1, 2]", engine=engine, parser=parser) + with pytest.raises(NotImplementedError, match=msg): + pd.eval("2 in (1, 2)", engine=engine, parser=parser) + with pytest.raises(NotImplementedError, match=msg): + pd.eval("3 in (1, 2)", engine=engine, parser=parser) + with pytest.raises(NotImplementedError, match=msg): + pd.eval("[(3,)] in (1, 2, [(3,)])", engine=engine, parser=parser) + msg = "'NotIn' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + pd.eval("3 not in (1, 2)", engine=engine, parser=parser) + with pytest.raises(NotImplementedError, match=msg): + pd.eval("[3] not in (1, 2, [[3]])", engine=engine, parser=parser) + + def test_check_many_exprs(self, engine, parser): + a = 1 # noqa: F841 + expr = " * ".join("a" * 33) + expected = 1 + res = pd.eval(expr, engine=engine, parser=parser) + assert res == expected + + @pytest.mark.parametrize( + "expr", + [ + "df > 2 and df > 3", + "df > 2 or df > 3", + "not df > 2", + ], + ) + def test_fails_and_or_not(self, expr, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + if parser == "python": + msg = "'BoolOp' nodes are not implemented" + if "not" in expr: + msg = "'Not' nodes are not implemented" + + with pytest.raises(NotImplementedError, match=msg): + pd.eval( + expr, + local_dict={"df": df}, + parser=parser, + engine=engine, + ) + else: + # smoke-test, should not raise + pd.eval( + expr, + local_dict={"df": df}, + parser=parser, + engine=engine, + ) + + @pytest.mark.parametrize("char", ["|", "&"]) + def test_fails_ampersand_pipe(self, char, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) # noqa: F841 + ex = f"(df + 2)[df > 1] > 0 {char} (df > 0)" + if parser == "python": + msg = "cannot evaluate scalar only bool ops" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, parser=parser, engine=engine) + else: + # smoke-test, should not raise + pd.eval(ex, parser=parser, engine=engine) + + +class TestMath: + def eval(self, *args, **kwargs): + kwargs["level"] = kwargs.pop("level", 0) + 1 + return pd.eval(*args, **kwargs) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("fn", _unary_math_ops) + def test_unary_functions(self, fn, engine, parser): + df = DataFrame({"a": np.random.default_rng(2).standard_normal(10)}) + a = df.a + + expr = f"{fn}(a)" + got = self.eval(expr, engine=engine, parser=parser) + with np.errstate(all="ignore"): + expect = getattr(np, fn)(a) + tm.assert_series_equal(got, expect) + + @pytest.mark.parametrize("fn", _binary_math_ops) + def test_binary_functions(self, fn, engine, parser): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + } + ) + a = df.a + b = df.b + + expr = f"{fn}(a, b)" + got = self.eval(expr, engine=engine, parser=parser) + with np.errstate(all="ignore"): + expect = getattr(np, fn)(a, b) + tm.assert_almost_equal(got, expect) + + def test_df_use_case(self, engine, parser): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + } + ) + df.eval( + "e = arctan2(sin(a), b)", + engine=engine, + parser=parser, + inplace=True, + ) + got = df.e + expect = np.arctan2(np.sin(df.a), df.b).rename("e") + tm.assert_series_equal(got, expect) + + def test_df_arithmetic_subexpression(self, engine, parser): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + } + ) + df.eval("e = sin(a + b)", engine=engine, parser=parser, inplace=True) + got = df.e + expect = np.sin(df.a + df.b).rename("e") + tm.assert_series_equal(got, expect) + + @pytest.mark.parametrize( + "dtype, expect_dtype", + [ + (np.int32, np.float64), + (np.int64, np.float64), + (np.float32, np.float32), + (np.float64, np.float64), + pytest.param(np.complex128, np.complex128, marks=td.skip_if_windows), + ], + ) + def test_result_types(self, dtype, expect_dtype, engine, parser): + # xref https://github.com/pandas-dev/pandas/issues/12293 + # this fails on Windows, apparently a floating point precision issue + + # Did not test complex64 because DataFrame is converting it to + # complex128. Due to https://github.com/pandas-dev/pandas/issues/10952 + df = DataFrame( + {"a": np.random.default_rng(2).standard_normal(10).astype(dtype)} + ) + assert df.a.dtype == dtype + df.eval("b = sin(a)", engine=engine, parser=parser, inplace=True) + got = df.b + expect = np.sin(df.a).rename("b") + assert expect.dtype == got.dtype + assert expect_dtype == got.dtype + tm.assert_series_equal(got, expect) + + def test_undefined_func(self, engine, parser): + df = DataFrame({"a": np.random.default_rng(2).standard_normal(10)}) + msg = '"mysin" is not a supported function' + + with pytest.raises(ValueError, match=msg): + df.eval("mysin(a)", engine=engine, parser=parser) + + def test_keyword_arg(self, engine, parser): + df = DataFrame({"a": np.random.default_rng(2).standard_normal(10)}) + msg = 'Function "sin" does not support keyword arguments' + + with pytest.raises(TypeError, match=msg): + df.eval("sin(x=a)", engine=engine, parser=parser) + + +_var_s = np.random.default_rng(2).standard_normal(10) + + +class TestScope: + def test_global_scope(self, engine, parser): + e = "_var_s * 2" + tm.assert_numpy_array_equal( + _var_s * 2, pd.eval(e, engine=engine, parser=parser) + ) + + def test_no_new_locals(self, engine, parser): + x = 1 + lcls = locals().copy() + pd.eval("x + 1", local_dict=lcls, engine=engine, parser=parser) + lcls2 = locals().copy() + lcls2.pop("lcls") + assert lcls == lcls2 + + def test_no_new_globals(self, engine, parser): + x = 1 # noqa: F841 + gbls = globals().copy() + pd.eval("x + 1", engine=engine, parser=parser) + gbls2 = globals().copy() + assert gbls == gbls2 + + def test_empty_locals(self, engine, parser): + # GH 47084 + x = 1 # noqa: F841 + msg = "name 'x' is not defined" + with pytest.raises(UndefinedVariableError, match=msg): + pd.eval("x + 1", engine=engine, parser=parser, local_dict={}) + + def test_empty_globals(self, engine, parser): + # GH 47084 + msg = "name '_var_s' is not defined" + e = "_var_s * 2" + with pytest.raises(UndefinedVariableError, match=msg): + pd.eval(e, engine=engine, parser=parser, global_dict={}) + + +@td.skip_if_no("numexpr") +def test_invalid_engine(): + msg = "Invalid engine 'asdf' passed" + with pytest.raises(KeyError, match=msg): + pd.eval("x + y", local_dict={"x": 1, "y": 2}, engine="asdf") + + +@td.skip_if_no("numexpr") +@pytest.mark.parametrize( + ("use_numexpr", "expected"), + ( + (True, "numexpr"), + (False, "python"), + ), +) +def test_numexpr_option_respected(use_numexpr, expected): + # GH 32556 + from pandas.core.computation.eval import _check_engine + + with pd.option_context("compute.use_numexpr", use_numexpr): + result = _check_engine(None) + assert result == expected + + +@td.skip_if_no("numexpr") +def test_numexpr_option_incompatible_op(): + # GH 32556 + with pd.option_context("compute.use_numexpr", False): + df = DataFrame( + {"A": [True, False, True, False, None, None], "B": [1, 2, 3, 4, 5, 6]} + ) + result = df.query("A.isnull()") + expected = DataFrame({"A": [None, None], "B": [5, 6]}, index=range(4, 6)) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numexpr") +def test_invalid_parser(): + msg = "Invalid parser 'asdf' passed" + with pytest.raises(KeyError, match=msg): + pd.eval("x + y", local_dict={"x": 1, "y": 2}, parser="asdf") + + +_parsers: dict[str, type[BaseExprVisitor]] = { + "python": PythonExprVisitor, + "pytables": pytables.PyTablesExprVisitor, + "pandas": PandasExprVisitor, +} + + +@pytest.mark.parametrize("engine", ENGINES) +@pytest.mark.parametrize("parser", _parsers) +def test_disallowed_nodes(engine, parser): + VisitorClass = _parsers[parser] + inst = VisitorClass("x + 1", engine, parser) + + for ops in VisitorClass.unsupported_nodes: + msg = "nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + getattr(inst, ops)() + + +def test_syntax_error_exprs(engine, parser): + e = "s +" + with pytest.raises(SyntaxError, match="invalid syntax"): + pd.eval(e, engine=engine, parser=parser) + + +def test_name_error_exprs(engine, parser): + e = "s + t" + msg = "name 's' is not defined" + with pytest.raises(NameError, match=msg): + pd.eval(e, engine=engine, parser=parser) + + +@pytest.mark.parametrize("express", ["a + @b", "@a + b", "@a + @b"]) +def test_invalid_local_variable_reference(engine, parser, express): + a, b = 1, 2 # noqa: F841 + + if parser != "pandas": + with pytest.raises(SyntaxError, match="The '@' prefix is only"): + pd.eval(express, engine=engine, parser=parser) + else: + with pytest.raises(SyntaxError, match="The '@' prefix is not"): + pd.eval(express, engine=engine, parser=parser) + + +def test_numexpr_builtin_raises(engine, parser): + sin, dotted_line = 1, 2 + if engine == "numexpr": + msg = "Variables in expression .+" + with pytest.raises(NumExprClobberingError, match=msg): + pd.eval("sin + dotted_line", engine=engine, parser=parser) + else: + res = pd.eval("sin + dotted_line", engine=engine, parser=parser) + assert res == sin + dotted_line + + +def test_bad_resolver_raises(engine, parser): + cannot_resolve = 42, 3.0 + with pytest.raises(TypeError, match="Resolver of type .+"): + pd.eval("1 + 2", resolvers=cannot_resolve, engine=engine, parser=parser) + + +def test_empty_string_raises(engine, parser): + # GH 13139 + with pytest.raises(ValueError, match="expr cannot be an empty string"): + pd.eval("", engine=engine, parser=parser) + + +def test_more_than_one_expression_raises(engine, parser): + with pytest.raises(SyntaxError, match="only a single expression is allowed"): + pd.eval("1 + 1; 2 + 2", engine=engine, parser=parser) + + +@pytest.mark.parametrize("cmp", ("and", "or")) +@pytest.mark.parametrize("lhs", (int, float)) +@pytest.mark.parametrize("rhs", (int, float)) +def test_bool_ops_fails_on_scalars(lhs, cmp, rhs, engine, parser): + gen = { + int: lambda: np.random.default_rng(2).integers(10), + float: np.random.default_rng(2).standard_normal, + } + + mid = gen[lhs]() # noqa: F841 + lhs = gen[lhs]() + rhs = gen[rhs]() + + ex1 = f"lhs {cmp} mid {cmp} rhs" + ex2 = f"lhs {cmp} mid and mid {cmp} rhs" + ex3 = f"(lhs {cmp} mid) & (mid {cmp} rhs)" + for ex in (ex1, ex2, ex3): + msg = "cannot evaluate scalar only bool ops|'BoolOp' nodes are not" + with pytest.raises(NotImplementedError, match=msg): + pd.eval(ex, engine=engine, parser=parser) + + +@pytest.mark.parametrize( + "other", + [ + "'x'", + "...", + ], +) +def test_equals_various(other): + df = DataFrame({"A": ["a", "b", "c"]}, dtype=object) + result = df.eval(f"A == {other}") + expected = Series([False, False, False], name="A") + tm.assert_series_equal(result, expected) + + +def test_inf(engine, parser): + s = "inf + 1" + expected = np.inf + result = pd.eval(s, engine=engine, parser=parser) + assert result == expected + + +@pytest.mark.parametrize("column", ["Temp(°C)", "Capacitance(μF)"]) +def test_query_token(engine, column): + # See: https://github.com/pandas-dev/pandas/pull/42826 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=[column, "b"] + ) + expected = df[df[column] > 5] + query_string = f"`{column}` > 5" + result = df.query(query_string, engine=engine) + tm.assert_frame_equal(result, expected) + + +def test_negate_lt_eq_le(engine, parser): + df = DataFrame([[0, 10], [1, 20]], columns=["cat", "count"]) + expected = df[~(df.cat > 0)] + + result = df.query("~(cat > 0)", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + if parser == "python": + msg = "'Not' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + df.query("not (cat > 0)", engine=engine, parser=parser) + else: + result = df.query("not (cat > 0)", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "column", + DEFAULT_GLOBALS.keys(), +) +def test_eval_no_support_column_name(request, column): + # GH 44603 + if column in ["True", "False", "inf", "Inf"]: + request.applymarker( + pytest.mark.xfail( + raises=KeyError, + reason=f"GH 47859 DataFrame eval not supported with {column}", + ) + ) + + df = DataFrame( + np.random.default_rng(2).integers(0, 100, size=(10, 2)), + columns=[column, "col1"], + ) + expected = df[df[column] > 6] + result = df.query(f"{column}>6") + + tm.assert_frame_equal(result, expected) + + +def test_set_inplace(): + # https://github.com/pandas-dev/pandas/issues/47449 + # Ensure we don't only update the DataFrame inplace, but also the actual + # column values, such that references to this column also get updated + df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result_view = df[:] + ser = df["A"] + df.eval("A = B + C", inplace=True) + expected = DataFrame({"A": [11, 13, 15], "B": [4, 5, 6], "C": [7, 8, 9]}) + tm.assert_frame_equal(df, expected) + expected = Series([1, 2, 3], name="A") + tm.assert_series_equal(ser, expected) + tm.assert_series_equal(result_view["A"], expected) + + +@pytest.mark.parametrize("value", [1, "True", [1, 2, 3], 5.0]) +def test_validate_bool_args(value): + msg = 'For argument "inplace" expected type bool, received type' + with pytest.raises(ValueError, match=msg): + pd.eval("2+2", inplace=value) + + +@td.skip_if_no("numexpr") +def test_eval_float_div_numexpr(): + # GH 59736 + result = pd.eval("1 / 2", engine="numexpr") + expected = 0.5 + assert result == expected + + +def test_method_calls_on_binop(): + # GH 61175 + x = Series([1, 2, 3, 5]) + y = Series([2, 3, 4]) + + # Method call on binary operation result + result = pd.eval("(x + y).dropna()") + expected = (x + y).dropna() + tm.assert_series_equal(result, expected) + + # Test with other binary operations + result = pd.eval("(x * y).dropna()") + expected = (x * y).dropna() + tm.assert_series_equal(result, expected) + + # Test with method chaining + result = pd.eval("(x + y).dropna().reset_index(drop=True)") + expected = (x + y).dropna().reset_index(drop=True) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/config/__init__.py b/pandas/tests/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/config/test_config.py b/pandas/tests/config/test_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e2f526f0ca3249503b05b84ba08804236ba41a67 --- /dev/null +++ b/pandas/tests/config/test_config.py @@ -0,0 +1,499 @@ +import pytest + +from pandas._config import config as cf +from pandas._config.config import OptionError + +from pandas.errors import Pandas4Warning + +import pandas as pd +import pandas._testing as tm + + +class TestConfig: + @pytest.fixture(autouse=True) + def clean_config(self, monkeypatch): + with monkeypatch.context() as m: + m.setattr(cf, "_global_config", {}) + m.setattr(cf, "options", cf.DictWrapper(cf._global_config)) + m.setattr(cf, "_deprecated_options", {}) + m.setattr(cf, "_registered_options", {}) + + # Our test fixture in conftest.py sets "chained_assignment" + # to "raise" only after all test methods have been setup. + # However, after this setup, there is no longer any + # "chained_assignment" option, so re-register it. + cf.register_option("chained_assignment", "raise") + yield + + def test_api(self): + # the pandas object exposes the user API + assert hasattr(pd, "get_option") + assert hasattr(pd, "set_option") + assert hasattr(pd, "reset_option") + assert hasattr(pd, "describe_option") + + def test_is_one_of_factory(self): + v = cf.is_one_of_factory([None, 12]) + + v(12) + v(None) + msg = r"Value must be one of None\|12" + with pytest.raises(ValueError, match=msg): + v(1.1) + + def test_register_option(self): + cf.register_option("a", 1, "doc") + + # can't register an already registered option + msg = "Option 'a' has already been registered" + with pytest.raises(OptionError, match=msg): + cf.register_option("a", 1, "doc") + + # can't register an already registered option + msg = "Path prefix to option 'a' is already an option" + with pytest.raises(OptionError, match=msg): + cf.register_option("a.b.c.d1", 1, "doc") + with pytest.raises(OptionError, match=msg): + cf.register_option("a.b.c.d2", 1, "doc") + + # no python keywords + msg = "for is a python keyword" + with pytest.raises(ValueError, match=msg): + cf.register_option("for", 0) + with pytest.raises(ValueError, match=msg): + cf.register_option("a.for.b", 0) + # must be valid identifier (ensure attribute access works) + msg = "oh my goddess! is not a valid identifier" + with pytest.raises(ValueError, match=msg): + cf.register_option("Oh my Goddess!", 0) + + # we can register options several levels deep + # without predefining the intermediate steps + # and we can define differently named options + # in the same namespace + cf.register_option("k.b.c.d1", 1, "doc") + cf.register_option("k.b.c.d2", 1, "doc") + + def test_describe_option(self): + cf.register_option("a", 1, "doc") + cf.register_option("b", 1, "doc2") + cf.deprecate_option("b", FutureWarning) + + cf.register_option("c.d.e1", 1, "doc3") + cf.register_option("c.d.e2", 1, "doc4") + cf.register_option("f", 1) + cf.register_option("g.h", 1) + cf.register_option("k", 2) + cf.deprecate_option("g.h", FutureWarning, rkey="k") + cf.register_option("l", "foo") + + # non-existent keys raise KeyError + msg = r"No such keys\(s\)" + with pytest.raises(OptionError, match=msg): + cf.describe_option("no.such.key") + + # we can get the description for any key we registered + assert "doc" in cf.describe_option("a", _print_desc=False) + assert "doc2" in cf.describe_option("b", _print_desc=False) + assert "precated" in cf.describe_option("b", _print_desc=False) + assert "doc3" in cf.describe_option("c.d.e1", _print_desc=False) + assert "doc4" in cf.describe_option("c.d.e2", _print_desc=False) + + # if no doc is specified we get a default message + # saying "description not available" + assert "available" in cf.describe_option("f", _print_desc=False) + assert "available" in cf.describe_option("g.h", _print_desc=False) + assert "precated" in cf.describe_option("g.h", _print_desc=False) + assert "k" in cf.describe_option("g.h", _print_desc=False) + + # default is reported + assert "foo" in cf.describe_option("l", _print_desc=False) + # current value is reported + assert "bar" not in cf.describe_option("l", _print_desc=False) + cf.set_option("l", "bar") + assert "bar" in cf.describe_option("l", _print_desc=False) + + @pytest.mark.parametrize("category", [DeprecationWarning, FutureWarning]) + def test_case_insensitive(self, category): + cf.register_option("KanBAN", 1, "doc") + + assert "doc" in cf.describe_option("kanbaN", _print_desc=False) + assert cf.get_option("kanBaN") == 1 + cf.set_option("KanBan", 2) + assert cf.get_option("kAnBaN") == 2 + + # gets of non-existent keys fail + msg = r"No such keys\(s\): 'no_such_option'" + with pytest.raises(OptionError, match=msg): + cf.get_option("no_such_option") + + cf.deprecate_option("KanBan", category) + msg = "'kanban' is deprecated, please refrain from using it." + with tm.assert_produces_warning(category, match=msg): + cf.get_option("kAnBaN") + + def test_get_option(self): + cf.register_option("a", 1, "doc") + cf.register_option("b.c", "hullo", "doc2") + cf.register_option("b.b", None, "doc2") + + # gets of existing keys succeed + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "hullo" + assert cf.get_option("b.b") is None + + # gets of non-existent keys fail + msg = r"No such keys\(s\): 'no_such_option'" + with pytest.raises(OptionError, match=msg): + cf.get_option("no_such_option") + + def test_set_option(self): + cf.register_option("a", 1, "doc") + cf.register_option("b.c", "hullo", "doc2") + cf.register_option("b.b", None, "doc2") + + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "hullo" + assert cf.get_option("b.b") is None + + cf.set_option("a", 2) + cf.set_option("b.c", "wurld") + cf.set_option("b.b", 1.1) + + assert cf.get_option("a") == 2 + assert cf.get_option("b.c") == "wurld" + assert cf.get_option("b.b") == 1.1 + + msg = r"No such keys\(s\): 'no.such.key'" + with pytest.raises(OptionError, match=msg): + cf.set_option("no.such.key", None) + + def test_set_option_empty_args(self): + msg = "Must provide an even number of non-keyword arguments" + with pytest.raises(ValueError, match=msg): + cf.set_option() + + def test_set_option_uneven_args(self): + msg = "Must provide an even number of non-keyword arguments" + with pytest.raises(ValueError, match=msg): + cf.set_option("a.b", 2, "b.c") + + def test_set_option_invalid_single_argument_type(self): + msg = "Must provide an even number of non-keyword arguments" + with pytest.raises(ValueError, match=msg): + cf.set_option(2) + + def test_set_option_multiple(self): + cf.register_option("a", 1, "doc") + cf.register_option("b.c", "hullo", "doc2") + cf.register_option("b.b", None, "doc2") + + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "hullo" + assert cf.get_option("b.b") is None + + cf.set_option("a", "2", "b.c", None, "b.b", 10.0) + + assert cf.get_option("a") == "2" + assert cf.get_option("b.c") is None + assert cf.get_option("b.b") == 10.0 + + def test_set_option_dict(self): + # GH 61093 + + cf.register_option("a", 1, "doc") + cf.register_option("b.c", "hullo", "doc2") + cf.register_option("b.b", None, "doc2") + + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "hullo" + assert cf.get_option("b.b") is None + + options_dict = {"a": "2", "b.c": None, "b.b": 10.0} + cf.set_option(options_dict) + + assert cf.get_option("a") == "2" + assert cf.get_option("b.c") is None + assert cf.get_option("b.b") == 10.0 + + def test_validation(self): + cf.register_option("a", 1, "doc", validator=cf.is_int) + cf.register_option("d", 1, "doc", validator=cf.is_nonnegative_int) + cf.register_option("b.c", "hullo", "doc2", validator=cf.is_text) + + msg = "Value must have type ''" + with pytest.raises(ValueError, match=msg): + cf.register_option("a.b.c.d2", "NO", "doc", validator=cf.is_int) + + cf.set_option("a", 2) # int is_int + cf.set_option("b.c", "wurld") # str is_str + cf.set_option("d", 2) + cf.set_option("d", None) # non-negative int can be None + + # None not is_int + with pytest.raises(ValueError, match=msg): + cf.set_option("a", None) + with pytest.raises(ValueError, match=msg): + cf.set_option("a", "ab") + + msg = "Value must be a nonnegative integer or None" + with pytest.raises(ValueError, match=msg): + cf.register_option("a.b.c.d3", "NO", "doc", validator=cf.is_nonnegative_int) + with pytest.raises(ValueError, match=msg): + cf.register_option("a.b.c.d3", -2, "doc", validator=cf.is_nonnegative_int) + + msg = r"Value must be an instance of \|" + with pytest.raises(ValueError, match=msg): + cf.set_option("b.c", 1) + + validator = cf.is_one_of_factory([None, cf.is_callable]) + cf.register_option("b", lambda: None, "doc", validator=validator) + cf.set_option("b", "%.1f".format) # Formatter is callable + cf.set_option("b", None) # Formatter is none (default) + with pytest.raises(ValueError, match="Value must be a callable"): + cf.set_option("b", "%.1f") + + def test_reset_option(self): + cf.register_option("a", 1, "doc", validator=cf.is_int) + cf.register_option("b.c", "hullo", "doc2", validator=cf.is_str) + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "hullo" + + cf.set_option("a", 2) + cf.set_option("b.c", "wurld") + assert cf.get_option("a") == 2 + assert cf.get_option("b.c") == "wurld" + + cf.reset_option("a") + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "wurld" + cf.reset_option("b.c") + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "hullo" + + def test_reset_option_all(self): + cf.register_option("a", 1, "doc", validator=cf.is_int) + cf.register_option("b.c", "hullo", "doc2", validator=cf.is_str) + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "hullo" + + cf.set_option("a", 2) + cf.set_option("b.c", "wurld") + assert cf.get_option("a") == 2 + assert cf.get_option("b.c") == "wurld" + + cf.reset_option("all") + assert cf.get_option("a") == 1 + assert cf.get_option("b.c") == "hullo" + + def test_deprecate_option(self): + # we can deprecate non-existent options + cf.deprecate_option("foo", FutureWarning) + + with tm.assert_produces_warning(FutureWarning, match="deprecated"): + with pytest.raises(KeyError, match="No such keys.s.: 'foo'"): + cf.get_option("foo") + + cf.register_option("a", 1, "doc", validator=cf.is_int) + cf.register_option("b.c", "hullo", "doc2") + cf.register_option("foo", "hullo", "doc2") + + cf.deprecate_option("a", FutureWarning, removal_ver="nifty_ver") + with tm.assert_produces_warning(FutureWarning, match="eprecated.*nifty_ver"): + cf.get_option("a") + + msg = "Option 'a' has already been defined as deprecated" + with pytest.raises(OptionError, match=msg): + cf.deprecate_option("a", FutureWarning) + + cf.deprecate_option("b.c", FutureWarning, "zounds!") + with tm.assert_produces_warning(FutureWarning, match="zounds!"): + cf.get_option("b.c") + + # test rerouting keys + cf.register_option("d.a", "foo", "doc2") + cf.register_option("d.dep", "bar", "doc2") + assert cf.get_option("d.a") == "foo" + assert cf.get_option("d.dep") == "bar" + + cf.deprecate_option("d.dep", FutureWarning, rkey="d.a") # reroute d.dep to d.a + with tm.assert_produces_warning(FutureWarning, match="eprecated"): + assert cf.get_option("d.dep") == "foo" + + with tm.assert_produces_warning(FutureWarning, match="eprecated"): + cf.set_option("d.dep", "baz") # should overwrite "d.a" + + with tm.assert_produces_warning(FutureWarning, match="eprecated"): + assert cf.get_option("d.dep") == "baz" + + def test_config_prefix(self): + with cf.config_prefix("base"): + cf.register_option("a", 1, "doc1") + cf.register_option("b", 2, "doc2") + assert cf.get_option("a") == 1 + assert cf.get_option("b") == 2 + + cf.set_option("a", 3) + cf.set_option("b", 4) + assert cf.get_option("a") == 3 + assert cf.get_option("b") == 4 + + assert cf.get_option("base.a") == 3 + assert cf.get_option("base.b") == 4 + assert "doc1" in cf.describe_option("base.a", _print_desc=False) + assert "doc2" in cf.describe_option("base.b", _print_desc=False) + + cf.reset_option("base.a") + cf.reset_option("base.b") + + with cf.config_prefix("base"): + assert cf.get_option("a") == 1 + assert cf.get_option("b") == 2 + + def test_callback(self): + k = [None] + v = [None] + + def callback(key): + k.append(key) + v.append(cf.get_option(key)) + + cf.register_option("d.a", "foo", cb=callback) + cf.register_option("d.b", "foo", cb=callback) + + del k[-1], v[-1] + cf.set_option("d.a", "fooz") + assert k[-1] == "d.a" + assert v[-1] == "fooz" + + del k[-1], v[-1] + cf.set_option("d.b", "boo") + assert k[-1] == "d.b" + assert v[-1] == "boo" + + del k[-1], v[-1] + cf.reset_option("d.b") + assert k[-1] == "d.b" + + def test_set_ContextManager(self): + def eq(val): + assert cf.get_option("a") == val + + cf.register_option("a", 0) + eq(0) + with cf.option_context("a", 15): + eq(15) + with cf.option_context("a", 25): + eq(25) + eq(15) + eq(0) + + cf.set_option("a", 17) + eq(17) + + # Test that option_context can be used as a decorator too (#34253). + @cf.option_context("a", 123) + def f(): + eq(123) + + f() + + def test_set_ContextManager_dict(self): + def eq(val): + assert cf.get_option("a") == val + assert cf.get_option("b.c") == val + + cf.register_option("a", 0) + cf.register_option("b.c", 0) + + eq(0) + with cf.option_context({"a": 15, "b.c": 15}): + eq(15) + with cf.option_context({"a": 25, "b.c": 25}): + eq(25) + eq(15) + eq(0) + + cf.set_option("a", 17) + cf.set_option("b.c", 17) + eq(17) + + # Test that option_context can be used as a decorator too + @cf.option_context({"a": 123, "b.c": 123}) + def f(): + eq(123) + + f() + + def test_attribute_access(self): + holder = [] + + def f3(key): + holder.append(True) + + cf.register_option("a", 0) + cf.register_option("c", 0, cb=f3) + options = cf.options + + assert options.a == 0 + with cf.option_context("a", 15): + assert options.a == 15 + + options.a = 500 + assert cf.get_option("a") == 500 + + cf.reset_option("a") + assert options.a == cf.get_option("a") + + msg = "You can only set the value of existing options" + with pytest.raises(OptionError, match=msg): + options.b = 1 + with pytest.raises(OptionError, match=msg): + options.display = 1 + + # make sure callback kicks when using this form of setting + options.c = 1 + assert len(holder) == 1 + + def test_option_context_scope(self): + # Ensure that creating a context does not affect the existing + # environment as it is supposed to be used with the `with` statement. + # See https://github.com/pandas-dev/pandas/issues/8514 + + original_value = 60 + context_value = 10 + option_name = "a" + + cf.register_option(option_name, original_value) + + # Ensure creating contexts didn't affect the current context. + ctx = cf.option_context(option_name, context_value) + assert cf.get_option(option_name) == original_value + + # Ensure the correct value is available inside the context. + with ctx: + assert cf.get_option(option_name) == context_value + + # Ensure the current context is reset + assert cf.get_option(option_name) == original_value + + def test_dictwrapper_getattr(self): + options = cf.options + # GH 19789 + with pytest.raises(OptionError, match="No such option"): + options.bananas + assert not hasattr(options, "bananas") + + +def test_no_silent_downcasting_deprecated(): + # GH#59502 + with tm.assert_produces_warning(Pandas4Warning, match="is deprecated"): + cf.get_option("future.no_silent_downcasting") + with tm.assert_produces_warning(Pandas4Warning, match="is deprecated"): + cf.set_option("future.no_silent_downcasting", True) + + +def test_option_context_invalid_option(): + with pytest.raises(OptionError, match="No such keys"): + with cf.option_context("invalid", True): + pass diff --git a/pandas/tests/config/test_localization.py b/pandas/tests/config/test_localization.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a0a44bf8c89f10537bc2b7c64d3bb418a9a4a7 --- /dev/null +++ b/pandas/tests/config/test_localization.py @@ -0,0 +1,155 @@ +import codecs +import locale +import os + +import pytest + +from pandas._config.localization import ( + can_set_locale, + get_locales, + set_locale, +) + +from pandas.compat import ISMUSL + +import pandas as pd + +_all_locales = get_locales() + +# Don't run any of these tests if we have no locales. +pytestmark = pytest.mark.skipif(not _all_locales, reason="Need locales") + +_skip_if_only_one_locale = pytest.mark.skipif( + len(_all_locales) <= 1, reason="Need multiple locales for meaningful test" +) + + +def _get_current_locale(lc_var: int = locale.LC_ALL) -> str: + # getlocale is not always compliant with setlocale, use setlocale. GH#46595 + return locale.setlocale(lc_var) + + +@pytest.mark.parametrize("lc_var", (locale.LC_ALL, locale.LC_CTYPE, locale.LC_TIME)) +def test_can_set_current_locale(lc_var): + # Can set the current locale + before_locale = _get_current_locale(lc_var) + assert can_set_locale(before_locale, lc_var=lc_var) + after_locale = _get_current_locale(lc_var) + assert before_locale == after_locale + + +@pytest.mark.parametrize("lc_var", (locale.LC_ALL, locale.LC_CTYPE, locale.LC_TIME)) +def test_can_set_locale_valid_set(lc_var): + # Can set the default locale. + before_locale = _get_current_locale(lc_var) + assert can_set_locale("", lc_var=lc_var) + after_locale = _get_current_locale(lc_var) + assert before_locale == after_locale + + +@pytest.mark.parametrize( + "lc_var", + ( + locale.LC_ALL, + locale.LC_CTYPE, + pytest.param( + locale.LC_TIME, + marks=pytest.mark.skipif( + ISMUSL, reason="MUSL allows setting invalid LC_TIME." + ), + ), + ), +) +def test_can_set_locale_invalid_set(lc_var): + # Cannot set an invalid locale. + before_locale = _get_current_locale(lc_var) + assert not can_set_locale("non-existent_locale", lc_var=lc_var) + after_locale = _get_current_locale(lc_var) + assert before_locale == after_locale + + +@pytest.mark.parametrize( + "lang,enc", + [ + ("it_CH", "UTF-8"), + ("en_US", "ascii"), + ("zh_CN", "GB2312"), + ("it_IT", "ISO-8859-1"), + ], +) +@pytest.mark.parametrize("lc_var", (locale.LC_ALL, locale.LC_CTYPE, locale.LC_TIME)) +def test_can_set_locale_no_leak(lang, enc, lc_var): + # Test that can_set_locale does not leak even when returning False. See GH#46595 + before_locale = _get_current_locale(lc_var) + can_set_locale((lang, enc), locale.LC_ALL) + after_locale = _get_current_locale(lc_var) + assert before_locale == after_locale + + +def test_can_set_locale_invalid_get(monkeypatch): + # see GH#22129 + # In some cases, an invalid locale can be set, + # but a subsequent getlocale() raises a ValueError. + + def mock_get_locale(): + raise ValueError + + with monkeypatch.context() as m: + m.setattr(locale, "getlocale", mock_get_locale) + assert not can_set_locale("") + + +def test_get_locales_at_least_one(): + # see GH#9744 + assert len(_all_locales) > 0 + + +@_skip_if_only_one_locale +def test_get_locales_prefix(): + first_locale = _all_locales[0] + assert len(get_locales(prefix=first_locale[:2])) > 0 + + +@_skip_if_only_one_locale +@pytest.mark.parametrize( + "lang,enc", + [ + ("it_CH", "UTF-8"), + ("en_US", "ascii"), + ("zh_CN", "GB2312"), + ("it_IT", "ISO-8859-1"), + ], +) +def test_set_locale(lang, enc): + before_locale = _get_current_locale() + + enc = codecs.lookup(enc).name + new_locale = lang, enc + + if not can_set_locale(new_locale): + msg = "unsupported locale setting" + + with pytest.raises(locale.Error, match=msg): + with set_locale(new_locale): + pass + else: + with set_locale(new_locale) as normalized_locale: + new_lang, new_enc = normalized_locale.split(".") + new_enc = codecs.lookup(enc).name + + normalized_locale = new_lang, new_enc + assert normalized_locale == new_locale + + # Once we exit the "with" statement, locale should be back to what it was. + after_locale = _get_current_locale() + assert before_locale == after_locale + + +def test_encoding_detected(): + system_locale = os.environ.get("LC_ALL") + system_encoding = system_locale.split(".")[-1] if system_locale else "utf-8" + + assert ( + codecs.lookup(pd.options.display.encoding).name + == codecs.lookup(system_encoding).name + ) diff --git a/pandas/tests/construction/__init__.py b/pandas/tests/construction/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/construction/test_extract_array.py b/pandas/tests/construction/test_extract_array.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd3eda8c995ce022e9d46b907323e79bcd679f8 --- /dev/null +++ b/pandas/tests/construction/test_extract_array.py @@ -0,0 +1,18 @@ +from pandas import Index +import pandas._testing as tm +from pandas.core.construction import extract_array + + +def test_extract_array_rangeindex(): + ri = Index(range(5)) + + expected = ri._values + res = extract_array(ri, extract_numpy=True, extract_range=True) + tm.assert_numpy_array_equal(res, expected) + res = extract_array(ri, extract_numpy=False, extract_range=True) + tm.assert_numpy_array_equal(res, expected) + + res = extract_array(ri, extract_numpy=True, extract_range=False) + tm.assert_index_equal(res, ri) + res = extract_array(ri, extract_numpy=False, extract_range=False) + tm.assert_index_equal(res, ri) diff --git a/pandas/tests/copy_view/__init__.py b/pandas/tests/copy_view/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/copy_view/test_array.py b/pandas/tests/copy_view/test_array.py new file mode 100644 index 0000000000000000000000000000000000000000..22976f307cae6d8da9852c12aad9f85b14c5dd64 --- /dev/null +++ b/pandas/tests/copy_view/test_array.py @@ -0,0 +1,229 @@ +import numpy as np +import pytest + +from pandas.compat.numpy import np_version_gt2 + +from pandas import ( + DataFrame, + Series, + date_range, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + +# ----------------------------------------------------------------------------- +# Copy/view behaviour for accessing underlying array of Series/DataFrame + + +@pytest.mark.parametrize( + "method", + [ + lambda ser: ser.values, + lambda ser: np.asarray(ser.array), + lambda ser: np.asarray(ser), + lambda ser: np.array(ser, copy=False), + ], + ids=["values", "array", "np.asarray", "np.array"], +) +def test_series_values(request, method): + ser = Series([1, 2, 3], name="name") + ser_orig = ser.copy() + + arr = method(ser) + + if request.node.callspec.id == "array": + # https://github.com/pandas-dev/pandas/issues/63099 + # .array for now does not return a read-only view + assert arr.flags.writeable is True + # updating the array updates the series + arr[0] = 0 + assert ser.iloc[0] == 0 + return + + # .values still gives a view but is read-only + assert np.shares_memory(arr, get_array(ser, "name")) + assert arr.flags.writeable is False + + # mutating series through arr therefore doesn't work + with pytest.raises(ValueError, match="read-only"): + arr[0] = 0 + tm.assert_series_equal(ser, ser_orig) + + # mutating the series itself still works + ser.iloc[0] = 0 + assert ser.values[0] == 0 + + +@pytest.mark.parametrize( + "method", + [ + lambda df: df.values, + lambda df: np.asarray(df), + lambda ser: np.array(ser, copy=False), + ], + ids=["values", "asarray", "array"], +) +def test_dataframe_values(method): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_orig = df.copy() + + arr = method(df) + + # .values still gives a view but is read-only + assert np.shares_memory(arr, get_array(df, "a")) + assert arr.flags.writeable is False + + # mutating series through arr therefore doesn't work + with pytest.raises(ValueError, match="read-only"): + arr[0, 0] = 0 + tm.assert_frame_equal(df, df_orig) + + # mutating the series itself still works + df.iloc[0, 0] = 0 + assert df.values[0, 0] == 0 + + +def test_series_to_numpy(): + ser = Series([1, 2, 3], name="name") + ser_orig = ser.copy() + + # default: copy=False, no dtype or NAs + arr = ser.to_numpy() + # to_numpy still gives a view but is read-only + assert np.shares_memory(arr, get_array(ser, "name")) + assert arr.flags.writeable is False + + # mutating series through arr therefore doesn't work + with pytest.raises(ValueError, match="read-only"): + arr[0] = 0 + tm.assert_series_equal(ser, ser_orig) + + # mutating the series itself still works + ser.iloc[0] = 0 + assert ser.values[0] == 0 + + # specify copy=True gives a writeable array + ser = Series([1, 2, 3], name="name") + arr = ser.to_numpy(copy=True) + assert not np.shares_memory(arr, get_array(ser, "name")) + assert arr.flags.writeable is True + + # specifying a dtype that already causes a copy also gives a writeable array + ser = Series([1, 2, 3], name="name") + arr = ser.to_numpy(dtype="float64") + assert not np.shares_memory(arr, get_array(ser, "name")) + assert arr.flags.writeable is True + + +@pytest.mark.parametrize( + "method", + [ + lambda ser: np.asarray(ser.values), + lambda ser: np.asarray(ser.array), + lambda ser: np.asarray(ser), + lambda ser: np.asarray(ser, dtype="int64"), + lambda ser: np.array(ser, copy=False), + ], + ids=["values", "array", "np.asarray", "np.asarray-dtype", "np.array"], +) +def test_series_values_ea_dtypes(request, method): + ser = Series([1, 2, 3], dtype="Int64") + ser_orig = ser.copy() + + arr = method(ser) + + if request.node.callspec.id in ("values", "array"): + # https://github.com/pandas-dev/pandas/issues/63099 + # .array/values for now does not return a read-only view + assert arr.flags.writeable is True + # updating the array updates the series + arr[0] = 0 + assert ser.iloc[0] == 0 + return + + # conversion to ndarray gives a view but is read-only + assert np.shares_memory(arr, get_array(ser)) + assert arr.flags.writeable is False + + # mutating series through arr therefore doesn't work + with pytest.raises(ValueError, match="read-only"): + arr[0] = 0 + tm.assert_series_equal(ser, ser_orig) + + # mutating the series itself still works + ser.iloc[0] = 0 + assert ser.values[0] == 0 + + +@pytest.mark.parametrize( + "method", + [ + lambda df: df.values, + lambda df: np.asarray(df), + lambda df: np.asarray(df, dtype="int64"), + lambda df: np.array(df, copy=False), + ], + ids=["values", "np.asarray", "np.asarray-dtype", "np.array"], +) +def test_dataframe_array_ea_dtypes(method): + df = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + arr = method(df) + + assert np.shares_memory(arr, get_array(df, "a")) + assert arr.flags.writeable is False + + +def test_dataframe_array_string_dtype(): + df = DataFrame({"a": ["a", "b"]}, dtype="string[python]") + arr = np.asarray(df) + assert np.shares_memory(arr, get_array(df, "a")) + assert arr.flags.writeable is False + + +def test_series_array_string_dtype(any_string_dtype): + ser = Series(["a", "b"], dtype=any_string_dtype) + arr = np.asarray(ser) + if any_string_dtype == "string" and any_string_dtype.storage == "pyarrow": + # for pyarrow strings, the numpy arrays is not a view, so also does + # not need to be read-only (https://github.com/pandas-dev/pandas/pull/64035) + assert not np.shares_memory(arr, get_array(ser)) + assert arr.flags.writeable is True + else: + assert np.shares_memory(arr, get_array(ser)) + assert arr.flags.writeable is False + + +def test_dataframe_multiple_numpy_dtypes(): + df = DataFrame({"a": [1, 2, 3], "b": 1.5}) + arr = np.asarray(df) + assert not np.shares_memory(arr, get_array(df, "a")) + assert arr.flags.writeable is True + + if np_version_gt2: + # copy=False semantics are only supported in NumPy>=2. + + with pytest.raises(ValueError, match="Unable to avoid copy while creating"): + arr = np.array(df, copy=False) + + arr = np.array(df, copy=True) + assert arr.flags.writeable is True + + +def test_dataframe_single_block_copy_true(): + # the copy=False/None cases are tested above in test_dataframe_values + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + arr = np.array(df, copy=True) + assert not np.shares_memory(arr, get_array(df, "a")) + assert arr.flags.writeable is True + + +def test_values_is_ea(): + df = DataFrame({"a": date_range("2012-01-01", periods=3)}) + arr = np.asarray(df) + assert arr.flags.writeable is False + + +def test_empty_dataframe(): + df = DataFrame() + arr = np.asarray(df) + assert arr.flags.writeable is True diff --git a/pandas/tests/copy_view/test_astype.py b/pandas/tests/copy_view/test_astype.py new file mode 100644 index 0000000000000000000000000000000000000000..c436391739ab282ffd612186a1406422e3b0774a --- /dev/null +++ b/pandas/tests/copy_view/test_astype.py @@ -0,0 +1,230 @@ +import pickle + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + Timestamp, + date_range, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +def test_astype_single_dtype(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": 1.5}) + df_orig = df.copy() + df2 = df.astype("float64") + + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + # mutating df2 triggers a copy-on-write for that column/block + df2.iloc[0, 2] = 5.5 + assert not np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + tm.assert_frame_equal(df, df_orig) + + # mutating parent also doesn't update result + df2 = df.astype("float64") + df.iloc[0, 2] = 5.5 + tm.assert_frame_equal(df2, df_orig.astype("float64")) + + +@pytest.mark.parametrize("dtype", ["int64", "Int64"]) +@pytest.mark.parametrize("new_dtype", ["int64", "Int64", "int64[pyarrow]"]) +def test_astype_avoids_copy(dtype, new_dtype): + if new_dtype == "int64[pyarrow]": + pytest.importorskip("pyarrow") + df = DataFrame({"a": [1, 2, 3]}, dtype=dtype) + df_orig = df.copy() + df2 = df.astype(new_dtype) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + # mutating df2 triggers a copy-on-write for that column/block + df2.iloc[0, 0] = 10 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + # mutating parent also doesn't update result + df2 = df.astype(new_dtype) + df.iloc[0, 0] = 100 + tm.assert_frame_equal(df2, df_orig.astype(new_dtype)) + + +@pytest.mark.parametrize("dtype", ["float64", "int32", "Int32", "int32[pyarrow]"]) +def test_astype_different_target_dtype(dtype): + if dtype == "int32[pyarrow]": + pytest.importorskip("pyarrow") + df = DataFrame({"a": [1, 2, 3]}) + df_orig = df.copy() + df2 = df.astype(dtype) + + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert df2._mgr._has_no_reference(0) + + df2.iloc[0, 0] = 5 + tm.assert_frame_equal(df, df_orig) + + # mutating parent also doesn't update result + df2 = df.astype(dtype) + df.iloc[0, 0] = 100 + tm.assert_frame_equal(df2, df_orig.astype(dtype)) + + +def test_astype_numpy_to_ea(): + ser = Series([1, 2, 3]) + result = ser.astype("Int64") + assert np.shares_memory(get_array(ser), get_array(result)) + + +@pytest.mark.parametrize( + "dtype, new_dtype", [("object", "string[python]"), ("string[python]", "object")] +) +def test_astype_string_and_object(dtype, new_dtype): + df = DataFrame({"a": ["a", "b", "c"]}, dtype=dtype) + df_orig = df.copy() + df2 = df.astype(new_dtype) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + df2.iloc[0, 0] = "x" + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "dtype, new_dtype", [("object", "string[python]"), ("string[python]", "object")] +) +def test_astype_string_and_object_update_original(dtype, new_dtype): + df = DataFrame({"a": ["a", "b", "c"]}, dtype=dtype) + df2 = df.astype(new_dtype) + df_orig = df2.copy() + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + df.iloc[0, 0] = "x" + tm.assert_frame_equal(df2, df_orig) + + +def test_astype_str_copy_on_pickle_roundrip(): + # TODO(infer_string) this test can be removed after 3.0 (once str is the default) + # https://github.com/pandas-dev/pandas/issues/54654 + # ensure_string_array may alter array inplace + base = Series(np.array([(1, 2), None, 1], dtype="object")) + base_copy = pickle.loads(pickle.dumps(base)) + base_copy.astype(str) + tm.assert_series_equal(base, base_copy) + + +def test_astype_string_copy_on_pickle_roundrip(any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/54654 + # ensure_string_array may alter array inplace + base = Series(np.array([(1, 2), None, 1], dtype="object")) + base_copy = pickle.loads(pickle.dumps(base)) + base_copy.astype(any_string_dtype) + tm.assert_series_equal(base, base_copy) + + +def test_astype_string_read_only_on_pickle_roundrip(any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/54654 + # ensure_string_array may alter read-only array inplace + base = Series(np.array([(1, 2), None, 1], dtype="object")) + base_copy = pickle.loads(pickle.dumps(base)) + base_copy._values.flags.writeable = False + base_copy.astype(any_string_dtype) + tm.assert_series_equal(base, base_copy) + + +def test_astype_dict_dtypes(): + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": Series([1.5, 1.5, 1.5], dtype="float64")} + ) + df_orig = df.copy() + df2 = df.astype({"a": "float64", "c": "float64"}) + + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + # mutating df2 triggers a copy-on-write for that column/block + df2.iloc[0, 2] = 5.5 + assert not np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + + df2.iloc[0, 1] = 10 + assert not np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + tm.assert_frame_equal(df, df_orig) + + +def test_astype_different_datetime_resos(): + df = DataFrame({"a": date_range("2019-12-31", periods=2, freq="D")}) + result = df.astype("datetime64[ms]") + + assert not np.shares_memory(get_array(df, "a"), get_array(result, "a")) + assert result._mgr._has_no_reference(0) + + +def test_astype_different_timezones(): + df = DataFrame( + {"a": date_range("2019-12-31", periods=5, freq="D", tz="US/Pacific", unit="ns")} + ) + result = df.astype("datetime64[ns, Europe/Berlin]") + assert not result._mgr._has_no_reference(0) + assert np.shares_memory(get_array(df, "a"), get_array(result, "a")) + + +def test_astype_different_timezones_different_reso(): + df = DataFrame( + {"a": date_range("2019-12-31", periods=5, freq="D", tz="US/Pacific", unit="ns")} + ) + result = df.astype("datetime64[ms, Europe/Berlin]") + assert result._mgr._has_no_reference(0) + assert not np.shares_memory(get_array(df, "a"), get_array(result, "a")) + + +def test_astype_arrow_timestamp(): + pytest.importorskip("pyarrow") + df = DataFrame( + { + "a": [ + Timestamp("2020-01-01 01:01:01.000001"), + Timestamp("2020-01-01 01:01:01.000001"), + ] + }, + dtype="M8[ns]", + ) + result = df.astype("timestamp[ns][pyarrow]") + assert not result._mgr._has_no_reference(0) + assert np.shares_memory(get_array(df, "a"), get_array(result, "a")._pa_array) + + +def test_convert_dtypes_infer_objects(): + ser = Series(["a", "b", "c"]) + ser_orig = ser.copy() + result = ser.convert_dtypes( + convert_integer=False, + convert_boolean=False, + convert_floating=False, + convert_string=False, + ) + + assert tm.shares_memory(get_array(ser), get_array(result)) + result.iloc[0] = "x" + tm.assert_series_equal(ser, ser_orig) + + +def test_convert_dtypes(using_infer_string): + df = DataFrame({"a": ["a", "b"], "b": [1, 2], "c": [1.5, 2.5], "d": [True, False]}) + df_orig = df.copy() + df2 = df.convert_dtypes() + + if using_infer_string: + # String column is already Arrow-backed, so memory is shared + assert tm.shares_memory(get_array(df2, "a"), get_array(df, "a")) + else: + # String column converts from object to Arrow, no memory sharing + assert not tm.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert tm.shares_memory(get_array(df2, "d"), get_array(df, "d")) + assert tm.shares_memory(get_array(df2, "b"), get_array(df, "b")) + assert tm.shares_memory(get_array(df2, "c"), get_array(df, "c")) + df2.iloc[0, 0] = "x" + df2.iloc[0, 1] = 10 + tm.assert_frame_equal(df, df_orig) diff --git a/pandas/tests/copy_view/test_chained_assignment_deprecation.py b/pandas/tests/copy_view/test_chained_assignment_deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a75fcd380c446a6822f4c551af7e152053a56a --- /dev/null +++ b/pandas/tests/copy_view/test_chained_assignment_deprecation.py @@ -0,0 +1,104 @@ +import numpy as np +import pytest + +from pandas.compat import CHAINED_WARNING_DISABLED +from pandas.errors import ChainedAssignmentError + +from pandas import DataFrame +import pandas._testing as tm + + +@pytest.mark.parametrize( + "indexer", [0, [0, 1], slice(0, 2), np.array([True, False, True])] +) +def test_series_setitem(indexer): + # ensure we only get a single warning for those typical cases of chained + # assignment + df = DataFrame({"a": [1, 2, 3], "b": 1}) + + # using custom check instead of tm.assert_produces_warning because that doesn't + # fail if multiple warnings are raised + if CHAINED_WARNING_DISABLED: + return + with pytest.warns() as record: # noqa: TID251 + df["a"][indexer] = 0 + assert len(record) == 1 + assert record[0].category == ChainedAssignmentError + + +@pytest.mark.parametrize( + "indexer", ["a", ["a", "b"], slice(0, 2), np.array([True, False, True])] +) +def test_frame_setitem(indexer): + df = DataFrame({"a": [1, 2, 3, 4, 5], "b": 1}) + + with tm.raises_chained_assignment_error(): + df[0:3][indexer] = 10 + + +@pytest.mark.parametrize( + "indexer", [0, [0, 1], slice(0, 2), np.array([True, False, True])] +) +def test_series_iloc_setitem(indexer): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + + with tm.raises_chained_assignment_error(): + df["a"].iloc[indexer] = 0 + + +@pytest.mark.parametrize( + "indexer", [0, [0, 1], slice(0, 2), np.array([True, False, True])] +) +def test_frame_iloc_setitem(indexer): + df = DataFrame({"a": [1, 2, 3, 4, 5], "b": 1}) + + with tm.raises_chained_assignment_error(): + df[0:3].iloc[indexer] = 10 + + +@pytest.mark.parametrize( + "indexer", [0, [0, 1], slice(0, 2), np.array([True, False, True])] +) +def test_series_loc_setitem(indexer): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + + with tm.raises_chained_assignment_error(): + df["a"].loc[indexer] = 0 + + +@pytest.mark.parametrize( + "indexer", [0, [0, 1], (0, "a"), slice(0, 2), np.array([True, False, True])] +) +def test_frame_loc_setitem(indexer): + df = DataFrame({"a": [1, 2, 3, 4, 5], "b": 1}) + + with tm.raises_chained_assignment_error(): + df[0:3].loc[indexer] = 10 + + +def test_series_at_setitem(): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + + with tm.raises_chained_assignment_error(): + df["a"].at[0] = 0 + + +def test_frame_at_setitem(): + df = DataFrame({"a": [1, 2, 3, 4, 5], "b": 1}) + + with tm.raises_chained_assignment_error(): + df[0:3].at[0, "a"] = 10 + + +def test_series_iat_setitem(): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + + with tm.raises_chained_assignment_error(): + df["a"].iat[0] = 0 + + +def test_frame_iat_setitem(): + df = DataFrame({"a": [1, 2, 3, 4, 5], "b": 1}) + + with tm.raises_chained_assignment_error(): + df[0:3].iat[0, 0] = 10 diff --git a/pandas/tests/copy_view/test_clip.py b/pandas/tests/copy_view/test_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..56df33db6d416e6ae2307139b531f48a012f8d4c --- /dev/null +++ b/pandas/tests/copy_view/test_clip.py @@ -0,0 +1,72 @@ +import numpy as np + +from pandas import DataFrame +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +def test_clip_inplace_reference(): + df = DataFrame({"a": [1.5, 2, 3]}) + df_copy = df.copy() + arr_a = get_array(df, "a") + view = df[:] + df.clip(lower=2, inplace=True) + + assert not np.shares_memory(get_array(df, "a"), arr_a) + assert df._mgr._has_no_reference(0) + assert view._mgr._has_no_reference(0) + tm.assert_frame_equal(df_copy, view) + + +def test_clip_inplace_reference_no_op(): + df = DataFrame({"a": [1.5, 2, 3]}) + df_copy = df.copy() + arr_a = get_array(df, "a") + view = df[:] + df.clip(lower=0, inplace=True) + + assert np.shares_memory(get_array(df, "a"), arr_a) + + assert not df._mgr._has_no_reference(0) + assert not view._mgr._has_no_reference(0) + tm.assert_frame_equal(df_copy, view) + + +def test_clip_inplace(): + df = DataFrame({"a": [1.5, 2, 3]}) + arr_a = get_array(df, "a") + df.clip(lower=2, inplace=True) + + assert np.shares_memory(get_array(df, "a"), arr_a) + assert df._mgr._has_no_reference(0) + + +def test_clip(): + df = DataFrame({"a": [1.5, 2, 3]}) + df_orig = df.copy() + df2 = df.clip(lower=2) + + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + assert df._mgr._has_no_reference(0) + tm.assert_frame_equal(df_orig, df) + + +def test_clip_no_op(): + df = DataFrame({"a": [1.5, 2, 3]}) + df2 = df.clip(lower=0) + + assert not df._mgr._has_no_reference(0) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + +def test_clip_chained_inplace(): + df = DataFrame({"a": [1, 4, 2], "b": 1}) + df_orig = df.copy() + with tm.raises_chained_assignment_error(): + df["a"].clip(1, 2, inplace=True) + tm.assert_frame_equal(df, df_orig) + + with tm.raises_chained_assignment_error(): + df[["a"]].clip(1, 2, inplace=True) + tm.assert_frame_equal(df, df_orig) diff --git a/pandas/tests/copy_view/test_constructors.py b/pandas/tests/copy_view/test_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..7204aea950314f6b6e08f64291b42491c1fd415d --- /dev/null +++ b/pandas/tests/copy_view/test_constructors.py @@ -0,0 +1,382 @@ +import numpy as np +import pytest + +from pandas._config import using_string_dtype + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + Period, + PeriodIndex, + Series, + Timedelta, + TimedeltaIndex, + Timestamp, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + +# ----------------------------------------------------------------------------- +# Copy/view behaviour for Series / DataFrame constructors + + +@pytest.mark.parametrize("dtype", [None, "int64"]) +def test_series_from_series(dtype): + # Case: constructing a Series from another Series object follows CoW rules: + # a new object is returned and thus mutations are not propagated + ser = Series([1, 2, 3], name="name") + + # default is copy=False -> new Series is a shallow copy / view of original + result = Series(ser, dtype=dtype) + + # the shallow copy still shares memory + assert np.shares_memory(get_array(ser), get_array(result)) + + assert result._mgr.blocks[0].refs.has_reference() + + # mutating new series copy doesn't mutate original + result.iloc[0] = 0 + assert ser.iloc[0] == 1 + # mutating triggered a copy-on-write -> no longer shares memory + assert not np.shares_memory(get_array(ser), get_array(result)) + + # the same when modifying the parent + result = Series(ser, dtype=dtype) + + # mutating original doesn't mutate new series + ser.iloc[0] = 0 + assert result.iloc[0] == 1 + + # forcing copy=False still gives a CoW shallow copy + result = Series(ser, dtype=dtype, copy=False) + assert np.shares_memory(get_array(ser), get_array(result)) + assert result._mgr.blocks[0].refs.has_reference() + + # forcing copy=True still results in an actual hard copy up front + result = Series(ser, dtype=dtype, copy=True) + assert not np.shares_memory(get_array(ser), get_array(result)) + assert ser._mgr._has_no_reference(0) + + +def test_series_from_series_with_reindex(): + # Case: constructing a Series from another Series with specifying an index + # that potentially requires a reindex of the values + ser = Series([1, 2, 3], name="name") + + # passing an index that doesn't actually require a reindex of the values + # -> still getting a CoW shallow copy + for index in [ + ser.index, + ser.index.copy(), + list(ser.index), + ser.index.rename("idx"), + ]: + result = Series(ser, index=index) + assert np.shares_memory(ser.values, result.values) + result.iloc[0] = 0 + assert ser.iloc[0] == 1 + + # forcing copy=True still results in an actual hard copy up front + result = Series(ser, index=index, copy=True) + assert not np.shares_memory(ser.values, result.values) + assert not result._mgr.blocks[0].refs.has_reference() + + # ensure that if an actual reindex is needed, we don't have any refs + # (mutating the result wouldn't trigger CoW) + result = Series(ser, index=[0, 1, 2, 3]) + assert not np.shares_memory(ser.values, result.values) + assert not result._mgr.blocks[0].refs.has_reference() + + +@pytest.mark.parametrize("dtype", [None, "int64"]) +@pytest.mark.parametrize("idx", [None, pd.RangeIndex(start=0, stop=3, step=1)]) +@pytest.mark.parametrize( + "arr", [np.array([1, 2, 3], dtype="int64"), pd.array([1, 2, 3], dtype="Int64")] +) +def test_series_from_array(idx, dtype, arr): + ser = Series(arr, dtype=dtype, index=idx) + ser_orig = ser.copy() + data = getattr(arr, "_data", arr) + assert not np.shares_memory(get_array(ser), data) + + arr[0] = 100 + tm.assert_series_equal(ser, ser_orig) + + # if the user explicitly passes copy=False, we get an actual view + # not protected by CoW + ser = Series(arr, dtype=dtype, index=idx, copy=False) + assert np.shares_memory(get_array(ser), data) + arr[0] = 50 + assert ser.iloc[0] == 50 + + +@pytest.mark.parametrize("copy", [True, False, None]) +def test_series_from_array_different_dtype(copy): + arr = np.array([1, 2, 3], dtype="int64") + ser = Series(arr, dtype="int32", copy=copy) + assert not np.shares_memory(get_array(ser), arr) + + +@pytest.mark.parametrize( + "idx", + [ + Index([1, 2]), + DatetimeIndex([Timestamp("2019-12-31"), Timestamp("2020-12-31")]), + PeriodIndex([Period("2019-12-31"), Period("2020-12-31")]), + TimedeltaIndex([Timedelta("1 days"), Timedelta("2 days")]), + ], +) +def test_series_from_index(idx): + ser = Series(idx) + expected = idx.copy(deep=True) + assert np.shares_memory(get_array(ser), get_array(idx)) + assert not ser._mgr._has_no_reference(0) + ser.iloc[0] = ser.iloc[1] + tm.assert_index_equal(idx, expected) + + # forcing copy=False still gives a CoW shallow copy + ser = Series(idx, copy=False) + assert np.shares_memory(get_array(ser), get_array(idx)) + assert not ser._mgr._has_no_reference(0) + ser.iloc[0] = ser.iloc[1] + tm.assert_index_equal(idx, expected) + + # forcing copy=True still results in a copy + ser = Series(idx, copy=True) + assert not np.shares_memory(get_array(ser), get_array(idx)) + assert ser._mgr._has_no_reference(0) + + +@pytest.mark.parametrize("copy", [True, False, None]) +def test_series_from_index_different_dtypes(copy): + idx = Index([1, 2, 3], dtype="int64", copy=copy) + ser = Series(idx, dtype="int32") + assert not np.shares_memory(get_array(ser), get_array(idx)) + assert ser._mgr._has_no_reference(0) + + +def test_series_from_block_manager_different_dtype(): + ser = Series([1, 2, 3], dtype="int64") + msg = "Passing a SingleBlockManager to Series" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + ser2 = Series(ser._mgr, dtype="int32") + assert not np.shares_memory(get_array(ser), get_array(ser2)) + assert ser2._mgr._has_no_reference(0) + + +@pytest.mark.parametrize("use_mgr", [True, False]) +@pytest.mark.parametrize("columns", [None, ["a"]]) +def test_dataframe_constructor_mgr_or_df(columns, use_mgr): + df = DataFrame({"a": [1, 2, 3]}) + df_orig = df.copy() + + if use_mgr: + data = df._mgr + warn = DeprecationWarning + else: + data = df + warn = None + msg = "Passing a BlockManager to DataFrame" + with tm.assert_produces_warning(warn, match=msg, check_stacklevel=False): + new_df = DataFrame(data) + + assert np.shares_memory(get_array(df, "a"), get_array(new_df, "a")) + new_df.iloc[0] = 100 + + assert not np.shares_memory(get_array(df, "a"), get_array(new_df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("dtype", [None, "int64", "Int64"]) +@pytest.mark.parametrize("index", [None, [0, 1, 2]]) +@pytest.mark.parametrize("columns", [None, ["a", "b"], ["a", "b", "c"]]) +def test_dataframe_from_dict_of_series(columns, index, dtype): + # Case: constructing a DataFrame from Series objects with copy=False + # has to do a lazy following CoW rules + # (the default for DataFrame(dict) is still to copy to ensure consolidation) + s1 = Series([1, 2, 3]) + s2 = Series([4, 5, 6]) + s1_orig = s1.copy() + expected = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6]}, index=index, columns=columns, dtype=dtype + ) + + result = DataFrame( + {"a": s1, "b": s2}, index=index, columns=columns, dtype=dtype, copy=False + ) + + # the shallow copy still shares memory + assert np.shares_memory(get_array(result, "a"), get_array(s1)) + + # mutating the new dataframe doesn't mutate original + result.iloc[0, 0] = 10 + assert not np.shares_memory(get_array(result, "a"), get_array(s1)) + tm.assert_series_equal(s1, s1_orig) + + # the same when modifying the parent series + s1 = Series([1, 2, 3]) + s2 = Series([4, 5, 6]) + result = DataFrame( + {"a": s1, "b": s2}, index=index, columns=columns, dtype=dtype, copy=False + ) + s1.iloc[0] = 10 + assert not np.shares_memory(get_array(result, "a"), get_array(s1)) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [None, "int64"]) +def test_dataframe_from_dict_of_series_with_reindex(dtype): + # Case: constructing a DataFrame from Series objects with copy=False + # and passing an index that requires an actual (no-view) reindex -> need + # to ensure the result doesn't have refs set up to unnecessarily trigger + # a copy on write + s1 = Series([1, 2, 3]) + s2 = Series([4, 5, 6]) + df = DataFrame({"a": s1, "b": s2}, index=[1, 2, 3], dtype=dtype, copy=False) + + # df should own its memory, so mutating shouldn't trigger a copy + arr_before = get_array(df, "a") + assert not np.shares_memory(arr_before, get_array(s1)) + df.iloc[0, 0] = 100 + arr_after = get_array(df, "a") + assert np.shares_memory(arr_before, arr_after) + + +@pytest.mark.parametrize( + "data, dtype", + [ + ([1, 2], "int64"), + # 1D-only EA + ([1, 2], "Int64"), + pytest.param( + ["a", "b"], + "str", + marks=pytest.mark.xfail( + reason="TODO bug with infer_string=False and specifying dtype='str'" + ) + if not using_string_dtype() + else [], + ), + (["a", "b"], object), + # 2D EA + ( + [Timestamp("2020", tz="UTC"), Timestamp("2021", tz="UTC")], + "datetime64[ns, UTC]", + ), + ], + ids=["int", "int-ea", "str", "object", "datetime64tz"], +) +def test_dataframe_from_series_or_index(data, dtype, index_or_series): + obj = index_or_series(data, dtype=dtype) + obj_orig = obj.copy(deep=True) # deep=True needed for Index + + # default is copy=False -> DataFrame holds a shallow copy of original Index/Series + df = DataFrame(obj) + assert tm.shares_memory(get_array(obj), get_array(df, 0)) + assert not df._mgr._has_no_reference(0) + + df.iloc[0, 0] = data[-1] + tm.assert_equal(obj, obj_orig) + + # with passing the (identical) dtype -> same + df = DataFrame(obj, dtype=dtype) + assert tm.shares_memory(get_array(obj), get_array(df, 0)) + assert not df._mgr._has_no_reference(0) + + df.iloc[0, 0] = data[-1] + tm.assert_equal(obj, obj_orig) + + # forcing copy=True still results in an actual hard copy up front + df = DataFrame(obj, copy=True) + if not (obj.dtype == "str" and obj.dtype.storage == "pyarrow"): + # ArrowExtensionArray deep copy still points to the same underlying data + assert not tm.shares_memory(get_array(obj), get_array(df, 0)) + assert df._mgr._has_no_reference(0) + + df.iloc[0, 0] = data[-1] + tm.assert_equal(obj, obj_orig) + + +def test_dataframe_from_series_or_index_different_dtype(index_or_series): + obj = index_or_series([1, 2], dtype="int64") + df = DataFrame(obj, dtype="int32") + assert not np.shares_memory(get_array(obj), get_array(df, 0)) + assert df._mgr._has_no_reference(0) + + +def test_dataframe_from_series_dont_infer_datetime(): + ser = Series([Timestamp("2019-12-31"), Timestamp("2020-12-31")], dtype=object) + df = DataFrame(ser) + assert df.dtypes.iloc[0] == np.dtype(object) + assert np.shares_memory(get_array(ser), get_array(df, 0)) + assert not df._mgr._has_no_reference(0) + + +@pytest.mark.parametrize("index", [None, [0, 1, 2]]) +def test_dataframe_from_dict_of_series_with_dtype(index): + # Variant of above, but now passing a dtype that causes a copy + # -> need to ensure the result doesn't have refs set up to unnecessarily + # trigger a copy on write + s1 = Series([1.0, 2.0, 3.0]) + s2 = Series([4, 5, 6]) + df = DataFrame({"a": s1, "b": s2}, index=index, dtype="int64", copy=False) + + # df should own its memory, so mutating shouldn't trigger a copy + arr_before = get_array(df, "a") + assert not np.shares_memory(arr_before, get_array(s1)) + df.iloc[0, 0] = 100 + arr_after = get_array(df, "a") + assert np.shares_memory(arr_before, arr_after) + + +@pytest.mark.parametrize("copy", [False, None, True]) +def test_dataframe_from_numpy_array(copy): + arr = np.array([[1, 2], [3, 4]]) + df = DataFrame(arr, copy=copy) + + if copy is not False or copy is True: + assert not np.shares_memory(get_array(df, 0), arr) + else: + assert np.shares_memory(get_array(df, 0), arr) + + +@pytest.mark.parametrize( + "data, dtype", + [ + # 1D-only EA + ([1, 2], "Int64"), + # 2D EA + ( + [Timestamp("2020", tz="UTC"), Timestamp("2021", tz="UTC")], + "datetime64[ns, UTC]", + ), + ], + ids=["int-ea", "datetime64tz"], +) +@pytest.mark.parametrize("copy", [False, None, True]) +def test_dataframe_from_extension_array(copy, data, dtype): + arr = pd.array(data, dtype=dtype) + df = DataFrame(arr, copy=copy) + + if arr.dtype == "Int64": + # to ensure tm.shares_memory works correctly + # TODO fix in tm.shares_memory or get_array? + arr = arr._data + + if copy is None or copy is True: + assert not tm.shares_memory(get_array(df, 0), arr) + else: + assert tm.shares_memory(get_array(df, 0), arr) + + +def test_frame_from_dict_of_index(): + idx = Index([1, 2, 3]) + expected = idx.copy(deep=True) + df = DataFrame({"a": idx}, copy=False) + assert np.shares_memory(get_array(df, "a"), idx._values) + assert not df._mgr._has_no_reference(0) + + df.iloc[0, 0] = 100 + tm.assert_index_equal(idx, expected) diff --git a/pandas/tests/copy_view/test_copy_deprecation.py b/pandas/tests/copy_view/test_copy_deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..acc87787dbe0a3b678bbdb347f775b59dab90d8b --- /dev/null +++ b/pandas/tests/copy_view/test_copy_deprecation.py @@ -0,0 +1,100 @@ +import pytest + +from pandas.errors import Pandas4Warning + +import pandas as pd +from pandas import ( + concat, + merge, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "meth, kwargs", + [ + ("truncate", {}), + ("tz_convert", {"tz": "UTC"}), + ("tz_localize", {"tz": "UTC"}), + ("infer_objects", {}), + ("astype", {"dtype": "float64"}), + ("reindex", {"index": [2, 0, 1]}), + ("transpose", {}), + ("set_axis", {"labels": [1, 2, 3]}), + ("rename", {"index": {1: 2}}), + ("set_flags", {}), + ("to_period", {}), + ("to_timestamp", {}), + ("swaplevel", {"i": 0, "j": 1}), + ], +) +def test_copy_deprecation(meth, kwargs): + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": 1}) + + if meth in ("tz_convert", "tz_localize", "to_period"): + tz = None if meth in ("tz_localize", "to_period") else "US/Eastern" + df.index = pd.date_range("2020-01-01", freq="D", periods=len(df), tz=tz) + elif meth == "to_timestamp": + df.index = pd.period_range("2020-01-01", freq="D", periods=len(df)) + elif meth == "swaplevel": + df = df.set_index(["b", "c"]) + + if meth != "swaplevel": + with tm.assert_produces_warning(Pandas4Warning, match="copy"): + getattr(df, meth)(copy=False, **kwargs) + + if meth != "transpose": + with tm.assert_produces_warning(Pandas4Warning, match="copy"): + getattr(df.a, meth)(copy=False, **kwargs) + + +def test_copy_deprecation_reindex_like_align(): + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + # Somehow the stack level check is incorrect here + with tm.assert_produces_warning( + Pandas4Warning, match="copy", check_stacklevel=False + ): + df.reindex_like(df, copy=False) + + with tm.assert_produces_warning( + Pandas4Warning, match="copy", check_stacklevel=False + ): + df.a.reindex_like(df.a, copy=False) + + with tm.assert_produces_warning( + Pandas4Warning, match="copy", check_stacklevel=False + ): + df.align(df, copy=False) + + with tm.assert_produces_warning( + Pandas4Warning, match="copy", check_stacklevel=False + ): + df.a.align(df.a, copy=False) + + +def test_copy_deprecation_merge_concat(): + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + with tm.assert_produces_warning( + Pandas4Warning, match="copy", check_stacklevel=False + ): + df.merge(df, copy=False) + + with tm.assert_produces_warning( + Pandas4Warning, match="copy", check_stacklevel=False + ): + merge(df, df, copy=False) + + with tm.assert_produces_warning( + Pandas4Warning, match="copy", check_stacklevel=False + ): + concat([df, df], copy=False) + + +@pytest.mark.parametrize("value", [False, True, "warn"]) +def test_copy_on_write_deprecation_option(value): + msg = "Copy-on-Write can no longer be disabled" + # stacklevel points to contextlib due to use of context manager. + with tm.assert_produces_warning(Pandas4Warning, match=msg, check_stacklevel=False): + with pd.option_context("mode.copy_on_write", value): + pass diff --git a/pandas/tests/copy_view/test_core_functionalities.py b/pandas/tests/copy_view/test_core_functionalities.py new file mode 100644 index 0000000000000000000000000000000000000000..ad16bafdf0ee431b6af53835fc8c8ccbee7cac97 --- /dev/null +++ b/pandas/tests/copy_view/test_core_functionalities.py @@ -0,0 +1,93 @@ +import numpy as np +import pytest + +from pandas import DataFrame +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +def test_assigning_to_same_variable_removes_references(): + df = DataFrame({"a": [1, 2, 3]}) + df = df.reset_index() + assert df._mgr._has_no_reference(1) + arr = get_array(df, "a") + df.iloc[0, 1] = 100 # Write into a + + assert np.shares_memory(arr, get_array(df, "a")) + + +def test_setitem_dont_track_unnecessary_references(): + df = DataFrame({"a": [1, 2, 3], "b": 1, "c": 1}) + + df["b"] = 100 + arr = get_array(df, "a") + # We split the block in setitem, if we are not careful the new blocks will + # reference each other triggering a copy + df.iloc[0, 0] = 100 + assert np.shares_memory(arr, get_array(df, "a")) + + +def test_setitem_with_view_copies(): + df = DataFrame({"a": [1, 2, 3], "b": 1, "c": 1}) + view = df[:] + expected = df.copy() + + df["b"] = 100 + arr = get_array(df, "a") + df.iloc[0, 0] = 100 # Check that we correctly track reference + assert not np.shares_memory(arr, get_array(df, "a")) + tm.assert_frame_equal(view, expected) + + +def test_setitem_with_view_invalidated_does_not_copy(request): + df = DataFrame({"a": [1, 2, 3], "b": 1, "c": 1}) + view = df[:] + + df["b"] = 100 + arr = get_array(df, "a") + view = None # noqa: F841 + # TODO(CoW) block gets split because of `df["b"] = 100` + # which introduces additional refs, even when those of `view` go out of scopes + df.iloc[0, 0] = 100 + # Setitem split the block. Since the old block shared data with view + # all the new blocks are referencing view and each other. When view + # goes out of scope, they don't share data with any other block, + # so we should not trigger a copy + mark = pytest.mark.xfail(reason="blk.delete does not track references correctly") + request.applymarker(mark) + assert np.shares_memory(arr, get_array(df, "a")) + + +def test_out_of_scope(): + def func(): + df = DataFrame({"a": [1, 2], "b": 1.5, "c": 1}) + # create some subset + result = df[["a", "b"]] + return result + + result = func() + assert not result._mgr.blocks[0].refs.has_reference() + assert not result._mgr.blocks[1].refs.has_reference() + + +def test_delete(): + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 3)), columns=["a", "b", "c"] + ) + del df["b"] + assert not df._mgr.blocks[0].refs.has_reference() + assert not df._mgr.blocks[1].refs.has_reference() + + df = df[["a"]] + assert not df._mgr.blocks[0].refs.has_reference() + + +def test_delete_reference(): + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 3)), columns=["a", "b", "c"] + ) + x = df[:] + del df["b"] + assert df._mgr.blocks[0].refs.has_reference() + assert df._mgr.blocks[1].refs.has_reference() + assert x._mgr.blocks[0].refs.has_reference() diff --git a/pandas/tests/copy_view/test_functions.py b/pandas/tests/copy_view/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..7e355ce1b5ed59cda09f0546ef4a76f40eb2f7d5 --- /dev/null +++ b/pandas/tests/copy_view/test_functions.py @@ -0,0 +1,332 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, + concat, + merge, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +def test_concat_frames(): + df = DataFrame({"b": ["a"] * 3}, dtype=object) + df2 = DataFrame({"a": ["a"] * 3}, dtype=object) + df_orig = df.copy() + result = concat([df, df2], axis=1) + + assert np.shares_memory(get_array(result, "b"), get_array(df, "b")) + assert np.shares_memory(get_array(result, "a"), get_array(df2, "a")) + + result.iloc[0, 0] = "d" + assert not np.shares_memory(get_array(result, "b"), get_array(df, "b")) + assert np.shares_memory(get_array(result, "a"), get_array(df2, "a")) + + result.iloc[0, 1] = "d" + assert not np.shares_memory(get_array(result, "a"), get_array(df2, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_concat_frames_updating_input(): + df = DataFrame({"b": ["a"] * 3}, dtype=object) + df2 = DataFrame({"a": ["a"] * 3}, dtype=object) + result = concat([df, df2], axis=1) + + assert np.shares_memory(get_array(result, "b"), get_array(df, "b")) + assert np.shares_memory(get_array(result, "a"), get_array(df2, "a")) + + expected = result.copy() + df.iloc[0, 0] = "d" + assert not np.shares_memory(get_array(result, "b"), get_array(df, "b")) + assert np.shares_memory(get_array(result, "a"), get_array(df2, "a")) + + df2.iloc[0, 0] = "d" + assert not np.shares_memory(get_array(result, "a"), get_array(df2, "a")) + tm.assert_frame_equal(result, expected) + + +def test_concat_series(): + ser = Series([1, 2], name="a") + ser2 = Series([3, 4], name="b") + ser_orig = ser.copy() + ser2_orig = ser2.copy() + result = concat([ser, ser2], axis=1) + + assert np.shares_memory(get_array(result, "a"), ser.values) + assert np.shares_memory(get_array(result, "b"), ser2.values) + + result.iloc[0, 0] = 100 + assert not np.shares_memory(get_array(result, "a"), ser.values) + assert np.shares_memory(get_array(result, "b"), ser2.values) + + result.iloc[0, 1] = 1000 + assert not np.shares_memory(get_array(result, "b"), ser2.values) + tm.assert_series_equal(ser, ser_orig) + tm.assert_series_equal(ser2, ser2_orig) + + +def test_concat_frames_chained(): + df1 = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]}) + df2 = DataFrame({"c": [4, 5, 6]}) + df3 = DataFrame({"d": [4, 5, 6]}) + result = concat([concat([df1, df2], axis=1), df3], axis=1) + expected = result.copy() + + assert np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "c"), get_array(df2, "c")) + assert np.shares_memory(get_array(result, "d"), get_array(df3, "d")) + + df1.iloc[0, 0] = 100 + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + + tm.assert_frame_equal(result, expected) + + +def test_concat_series_chained(): + ser1 = Series([1, 2, 3], name="a") + ser2 = Series([4, 5, 6], name="c") + ser3 = Series([4, 5, 6], name="d") + result = concat([concat([ser1, ser2], axis=1), ser3], axis=1) + expected = result.copy() + + assert np.shares_memory(get_array(result, "a"), get_array(ser1, "a")) + assert np.shares_memory(get_array(result, "c"), get_array(ser2, "c")) + assert np.shares_memory(get_array(result, "d"), get_array(ser3, "d")) + + ser1.iloc[0] = 100 + assert not np.shares_memory(get_array(result, "a"), get_array(ser1, "a")) + + tm.assert_frame_equal(result, expected) + + +def test_concat_series_updating_input(): + ser = Series([1, 2], name="a") + ser2 = Series([3, 4], name="b") + expected = DataFrame({"a": [1, 2], "b": [3, 4]}) + result = concat([ser, ser2], axis=1) + + assert np.shares_memory(get_array(result, "a"), get_array(ser, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(ser2, "b")) + + ser.iloc[0] = 100 + assert not np.shares_memory(get_array(result, "a"), get_array(ser, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(ser2, "b")) + tm.assert_frame_equal(result, expected) + + ser2.iloc[0] = 1000 + assert not np.shares_memory(get_array(result, "b"), get_array(ser2, "b")) + tm.assert_frame_equal(result, expected) + + +def test_concat_mixed_series_frame(): + df = DataFrame({"a": [1, 2, 3], "c": 1}) + ser = Series([4, 5, 6], name="d") + result = concat([df, ser], axis=1) + expected = result.copy() + + assert np.shares_memory(get_array(result, "a"), get_array(df, "a")) + assert np.shares_memory(get_array(result, "c"), get_array(df, "c")) + assert np.shares_memory(get_array(result, "d"), get_array(ser, "d")) + + ser.iloc[0] = 100 + assert not np.shares_memory(get_array(result, "d"), get_array(ser, "d")) + + df.iloc[0, 0] = 100 + assert not np.shares_memory(get_array(result, "a"), get_array(df, "a")) + tm.assert_frame_equal(result, expected) + + +def test_concat_copy_keyword(): + df = DataFrame({"a": [1, 2]}) + df2 = DataFrame({"b": [1.5, 2.5]}) + + result = concat([df, df2], axis=1) + + assert np.shares_memory(get_array(df, "a"), get_array(result, "a")) + assert np.shares_memory(get_array(df2, "b"), get_array(result, "b")) + + +@pytest.mark.parametrize( + "func", + [ + lambda df1, df2, **kwargs: df1.merge(df2, **kwargs), + lambda df1, df2, **kwargs: merge(df1, df2, **kwargs), + ], +) +def test_merge_on_key(func): + df1 = DataFrame({"key": Series(["a", "b", "c"], dtype=object), "a": [1, 2, 3]}) + df2 = DataFrame({"key": Series(["a", "b", "c"], dtype=object), "b": [4, 5, 6]}) + df1_orig = df1.copy() + df2_orig = df2.copy() + + result = func(df1, df2, on="key") + + assert np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + assert np.shares_memory(get_array(result, "key"), get_array(df1, "key")) + assert not np.shares_memory(get_array(result, "key"), get_array(df2, "key")) + + result.iloc[0, 1] = 0 + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + result.iloc[0, 2] = 0 + assert not np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + tm.assert_frame_equal(df1, df1_orig) + tm.assert_frame_equal(df2, df2_orig) + + +def test_merge_on_index(): + df1 = DataFrame({"a": [1, 2, 3]}) + df2 = DataFrame({"b": [4, 5, 6]}) + df1_orig = df1.copy() + df2_orig = df2.copy() + + result = merge(df1, df2, left_index=True, right_index=True) + + assert np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + result.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + result.iloc[0, 1] = 0 + assert not np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + tm.assert_frame_equal(df1, df1_orig) + tm.assert_frame_equal(df2, df2_orig) + + +@pytest.mark.parametrize( + "func, how", + [ + (lambda df1, df2, **kwargs: merge(df2, df1, on="key", **kwargs), "right"), + (lambda df1, df2, **kwargs: merge(df1, df2, on="key", **kwargs), "left"), + ], +) +def test_merge_on_key_enlarging_one(func, how): + df1 = DataFrame({"key": Series(["a", "b", "c"], dtype=object), "a": [1, 2, 3]}) + df2 = DataFrame({"key": Series(["a", "b"], dtype=object), "b": [4, 5]}) + df1_orig = df1.copy() + df2_orig = df2.copy() + + result = func(df1, df2, how=how) + + assert np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert not np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + assert df2._mgr._has_no_reference(1) + assert df2._mgr._has_no_reference(0) + assert np.shares_memory(get_array(result, "key"), get_array(df1, "key")) is ( + how == "left" + ) + assert not np.shares_memory(get_array(result, "key"), get_array(df2, "key")) + + if how == "left": + result.iloc[0, 1] = 0 + else: + result.iloc[0, 2] = 0 + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + tm.assert_frame_equal(df1, df1_orig) + tm.assert_frame_equal(df2, df2_orig) + + +def test_merge_copy_keyword(): + df = DataFrame({"a": [1, 2]}) + df2 = DataFrame({"b": [3, 4.5]}) + + result = df.merge(df2, left_index=True, right_index=True) + + assert np.shares_memory(get_array(df, "a"), get_array(result, "a")) + assert np.shares_memory(get_array(df2, "b"), get_array(result, "b")) + + +def test_merge_upcasting_no_copy(): + left = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + left_copy = left.copy() + right = DataFrame({"a": [1, 2, 3], "c": [7, 8, 9]}, dtype=object) + result = merge(left, right, on="a") + assert np.shares_memory(get_array(result, "b"), get_array(left, "b")) + assert not np.shares_memory(get_array(result, "a"), get_array(left, "a")) + tm.assert_frame_equal(left, left_copy) + + result = merge(right, left, on="a") + assert np.shares_memory(get_array(result, "b"), get_array(left, "b")) + assert not np.shares_memory(get_array(result, "a"), get_array(left, "a")) + tm.assert_frame_equal(left, left_copy) + + +def test_merge_indicator_no_deep_copy(): + left = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + right = DataFrame({"a": [1, 2, 3], "c": [7, 8, 9]}) + result = merge(left, right, on="a", indicator=True) + assert np.shares_memory(get_array(result, "b"), get_array(left, "b")) + assert np.shares_memory(get_array(result, "c"), get_array(right, "c")) + + +@pytest.mark.parametrize("dtype", [object, "str"]) +def test_join_on_key(dtype): + df_index = Index(["a", "b", "c"], name="key", dtype=dtype) + + df1 = DataFrame({"a": [1, 2, 3]}, index=df_index.copy(deep=True)) + df2 = DataFrame({"b": [4, 5, 6]}, index=df_index.copy(deep=True)) + + df1_orig = df1.copy() + df2_orig = df2.copy() + + result = df1.join(df2, on="key") + + assert np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + assert tm.shares_memory(get_array(result.index), get_array(df1.index)) + assert not np.shares_memory(get_array(result.index), get_array(df2.index)) + + result.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + result.iloc[0, 1] = 0 + assert not np.shares_memory(get_array(result, "b"), get_array(df2, "b")) + + tm.assert_frame_equal(df1, df1_orig) + tm.assert_frame_equal(df2, df2_orig) + + +def test_join_multiple_dataframes_on_key(): + df_index = Index(["a", "b", "c"], name="key", dtype=object) + + df1 = DataFrame({"a": [1, 2, 3]}, index=df_index.copy(deep=True)) + dfs_list = [ + DataFrame({"b": [4, 5, 6]}, index=df_index.copy(deep=True)), + DataFrame({"c": [7, 8, 9]}, index=df_index.copy(deep=True)), + ] + + df1_orig = df1.copy() + dfs_list_orig = [df.copy() for df in dfs_list] + + result = df1.join(dfs_list) + + assert np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(dfs_list[0], "b")) + assert np.shares_memory(get_array(result, "c"), get_array(dfs_list[1], "c")) + assert np.shares_memory(get_array(result.index), get_array(df1.index)) + assert not np.shares_memory(get_array(result.index), get_array(dfs_list[0].index)) + assert not np.shares_memory(get_array(result.index), get_array(dfs_list[1].index)) + + result.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a")) + assert np.shares_memory(get_array(result, "b"), get_array(dfs_list[0], "b")) + assert np.shares_memory(get_array(result, "c"), get_array(dfs_list[1], "c")) + + result.iloc[0, 1] = 0 + assert not np.shares_memory(get_array(result, "b"), get_array(dfs_list[0], "b")) + assert np.shares_memory(get_array(result, "c"), get_array(dfs_list[1], "c")) + + result.iloc[0, 2] = 0 + assert not np.shares_memory(get_array(result, "c"), get_array(dfs_list[1], "c")) + + tm.assert_frame_equal(df1, df1_orig) + for df, df_orig in zip(dfs_list, dfs_list_orig, strict=True): + tm.assert_frame_equal(df, df_orig) diff --git a/pandas/tests/copy_view/test_indexing.py b/pandas/tests/copy_view/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..3e32b88849e836502e303e73296dc1f80ae253b9 --- /dev/null +++ b/pandas/tests/copy_view/test_indexing.py @@ -0,0 +1,902 @@ +import numpy as np +import pytest + +from pandas.core.dtypes.common import is_float_dtype + +import pandas as pd +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +@pytest.fixture(params=["numpy", "nullable"]) +def backend(request): + if request.param == "numpy": + + def make_dataframe(*args, **kwargs): + return DataFrame(*args, **kwargs) + + def make_series(*args, **kwargs): + return Series(*args, **kwargs) + + elif request.param == "nullable": + + def make_dataframe(*args, **kwargs): + df = DataFrame(*args, **kwargs) + df_nullable = df.convert_dtypes() + # convert_dtypes will try to cast float to int if there is no loss in + # precision -> undo that change + for col in df.columns: + if is_float_dtype(df[col].dtype) and not is_float_dtype( + df_nullable[col].dtype + ): + df_nullable[col] = df_nullable[col].astype("Float64") + # copy final result to ensure we start with a fully self-owning DataFrame + return df_nullable.copy() + + def make_series(*args, **kwargs): + ser = Series(*args, **kwargs) + return ser.convert_dtypes().copy() + + return request.param, make_dataframe, make_series + + +# ----------------------------------------------------------------------------- +# Indexing operations taking subset + modifying the subset/parent + + +def test_subset_column_selection(backend): + # Case: taking a subset of the columns of a DataFrame + # + afterwards modifying the subset + _, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + + subset = df[["a", "c"]] + + assert subset.index is not df.index + + # the subset shares memory ... + assert np.shares_memory(get_array(subset, "a"), get_array(df, "a")) + # ... but uses CoW when being modified + subset.iloc[0, 0] = 0 + + assert not np.shares_memory(get_array(subset, "a"), get_array(df, "a")) + + expected = DataFrame({"a": [0, 2, 3], "c": [0.1, 0.2, 0.3]}) + tm.assert_frame_equal(subset, expected) + tm.assert_frame_equal(df, df_orig) + + +def test_subset_column_selection_modify_parent(backend): + # Case: taking a subset of the columns of a DataFrame + # + afterwards modifying the parent + _, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + + subset = df[["a", "c"]] + + # the subset shares memory ... + assert np.shares_memory(get_array(subset, "a"), get_array(df, "a")) + # ... but parent uses CoW parent when it is modified + df.iloc[0, 0] = 0 + + assert not np.shares_memory(get_array(subset, "a"), get_array(df, "a")) + # different column/block still shares memory + assert np.shares_memory(get_array(subset, "c"), get_array(df, "c")) + + expected = DataFrame({"a": [1, 2, 3], "c": [0.1, 0.2, 0.3]}) + tm.assert_frame_equal(subset, expected) + + +def test_subset_row_slice(backend): + # Case: taking a subset of the rows of a DataFrame using a slice + # + afterwards modifying the subset + _, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + + subset = df[1:3] + subset._mgr._verify_integrity() + + assert subset.columns is not df.columns + assert np.shares_memory(get_array(subset, "a"), get_array(df, "a")) + + subset.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(subset, "a"), get_array(df, "a")) + + subset._mgr._verify_integrity() + + expected = DataFrame({"a": [0, 3], "b": [5, 6], "c": [0.2, 0.3]}, index=range(1, 3)) + tm.assert_frame_equal(subset, expected) + # original parent dataframe is not modified (CoW) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +def test_subset_column_slice(backend, dtype): + # Case: taking a subset of the columns of a DataFrame using a slice + # + afterwards modifying the subset + dtype_backend, DataFrame, _ = backend + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)} + ) + df_orig = df.copy() + + subset = df.iloc[:, 1:] + subset._mgr._verify_integrity() + + assert subset.index is not df.index + assert np.shares_memory(get_array(subset, "b"), get_array(df, "b")) + + subset.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(subset, "b"), get_array(df, "b")) + + expected = DataFrame({"b": [0, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)}) + tm.assert_frame_equal(subset, expected) + # original parent dataframe is not modified (also not for BlockManager case, + # except for single block) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +@pytest.mark.parametrize( + "row_indexer", + [slice(1, 2), np.array([False, True, True]), np.array([1, 2])], + ids=["slice", "mask", "array"], +) +@pytest.mark.parametrize( + "column_indexer", + [slice("b", "c"), np.array([False, True, True]), ["b", "c"]], + ids=["slice", "mask", "array"], +) +def test_subset_loc_rows_columns( + backend, + dtype, + row_indexer, + column_indexer, +): + # Case: taking a subset of the rows+columns of a DataFrame using .loc + # + afterwards modifying the subset + # Generic test for several combinations of row/column indexers, not all + # of those could actually return a view / need CoW (so this test is not + # checking memory sharing, only ensuring subsequent mutation doesn't + # affect the parent dataframe) + dtype_backend, DataFrame, _ = backend + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)} + ) + df_orig = df.copy() + + subset = df.loc[row_indexer, column_indexer] + + assert subset.index is not df.index + assert subset.columns is not df.columns + + # modifying the subset never modifies the parent + subset.iloc[0, 0] = 0 + + expected = DataFrame( + {"b": [0, 6], "c": np.array([8, 9], dtype=dtype)}, index=range(1, 3) + ) + tm.assert_frame_equal(subset, expected) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +@pytest.mark.parametrize( + "row_indexer", + [slice(1, 3), np.array([False, True, True]), np.array([1, 2])], + ids=["slice", "mask", "array"], +) +@pytest.mark.parametrize( + "column_indexer", + [slice(1, 3), np.array([False, True, True]), [1, 2]], + ids=["slice", "mask", "array"], +) +def test_subset_iloc_rows_columns( + backend, + dtype, + row_indexer, + column_indexer, +): + # Case: taking a subset of the rows+columns of a DataFrame using .iloc + # + afterwards modifying the subset + # Generic test for several combinations of row/column indexers, not all + # of those could actually return a view / need CoW (so this test is not + # checking memory sharing, only ensuring subsequent mutation doesn't + # affect the parent dataframe) + dtype_backend, DataFrame, _ = backend + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)} + ) + df_orig = df.copy() + + subset = df.iloc[row_indexer, column_indexer] + + assert subset.index is not df.index + assert subset.columns is not df.columns + + # modifying the subset never modifies the parent + subset.iloc[0, 0] = 0 + + expected = DataFrame( + {"b": [0, 6], "c": np.array([8, 9], dtype=dtype)}, index=range(1, 3) + ) + tm.assert_frame_equal(subset, expected) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "indexer", + [slice(0, 2), np.array([True, True, False]), np.array([0, 1])], + ids=["slice", "mask", "array"], +) +def test_subset_set_with_row_indexer(backend, indexer_si, indexer): + # Case: setting values with a row indexer on a viewing subset + # subset[indexer] = value and subset.iloc[indexer] = value + _, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7], "c": [0.1, 0.2, 0.3, 0.4]}) + df_orig = df.copy() + subset = df[1:4] + + if ( + indexer_si is tm.setitem + and isinstance(indexer, np.ndarray) + and indexer.dtype == "int" + ): + pytest.skip("setitem with labels selects on columns") + + indexer_si(subset)[indexer] = 0 + + expected = DataFrame( + {"a": [0, 0, 4], "b": [0, 0, 7], "c": [0.0, 0.0, 0.4]}, index=range(1, 4) + ) + tm.assert_frame_equal(subset, expected) + # original parent dataframe is not modified (CoW) + tm.assert_frame_equal(df, df_orig) + + +def test_subset_set_with_mask(backend): + # Case: setting values with a mask on a viewing subset: subset[mask] = value + _, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7], "c": [0.1, 0.2, 0.3, 0.4]}) + df_orig = df.copy() + subset = df[1:4] + + mask = subset > 3 + + subset[mask] = 0 + + expected = DataFrame( + {"a": [2, 3, 0], "b": [0, 0, 0], "c": [0.20, 0.3, 0.4]}, index=range(1, 4) + ) + tm.assert_frame_equal(subset, expected) + tm.assert_frame_equal(df, df_orig) + + +def test_subset_set_column(backend): + # Case: setting a single column on a viewing subset -> subset[col] = value + dtype_backend, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + subset = df[1:3] + + if dtype_backend == "numpy": + arr = np.array([10, 11], dtype="int64") + else: + arr = pd.array([10, 11], dtype="Int64") + + subset["a"] = arr + subset._mgr._verify_integrity() + expected = DataFrame( + {"a": [10, 11], "b": [5, 6], "c": [0.2, 0.3]}, index=range(1, 3) + ) + tm.assert_frame_equal(subset, expected) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +def test_subset_set_column_with_loc(backend, dtype): + # Case: setting a single column with loc on a viewing subset + # -> subset.loc[:, col] = value + _, DataFrame, _ = backend + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)} + ) + df_orig = df.copy() + subset = df[1:3] + + subset.loc[:, "a"] = np.array([10, 11], dtype="int64") + + subset._mgr._verify_integrity() + expected = DataFrame( + {"a": [10, 11], "b": [5, 6], "c": np.array([8, 9], dtype=dtype)}, + index=range(1, 3), + ) + tm.assert_frame_equal(subset, expected) + # original parent dataframe is not modified (CoW) + tm.assert_frame_equal(df, df_orig) + + +def test_subset_set_column_with_loc2(backend): + # Case: setting a single column with loc on a viewing subset + # -> subset.loc[:, col] = value + # separate test for case of DataFrame of a single column -> takes a separate + # code path + _, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3]}) + df_orig = df.copy() + subset = df[1:3] + + subset.loc[:, "a"] = 0 + + subset._mgr._verify_integrity() + expected = DataFrame({"a": [0, 0]}, index=range(1, 3)) + tm.assert_frame_equal(subset, expected) + # original parent dataframe is not modified (CoW) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +def test_subset_set_columns(backend, dtype): + # Case: setting multiple columns on a viewing subset + # -> subset[[col1, col2]] = value + dtype_backend, DataFrame, _ = backend + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)} + ) + df_orig = df.copy() + subset = df[1:3] + + subset[["a", "c"]] = 0 + + subset._mgr._verify_integrity() + # first and third column should certainly have no references anymore + assert all(subset._mgr._has_no_reference(i) for i in [0, 2]) + expected = DataFrame({"a": [0, 0], "b": [5, 6], "c": [0, 0]}, index=range(1, 3)) + if dtype_backend == "nullable": + # there is not yet a global option, so overriding a column by setting a scalar + # defaults to numpy dtype even if original column was nullable + expected["a"] = expected["a"].astype("int64") + expected["c"] = expected["c"].astype("int64") + + tm.assert_frame_equal(subset, expected) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "indexer", + [slice("a", "b"), np.array([True, True, False]), ["a", "b"]], + ids=["slice", "mask", "array"], +) +def test_subset_set_with_column_indexer(backend, indexer): + # Case: setting multiple columns with a column indexer on a viewing subset + # -> subset.loc[:, [col1, col2]] = value + _, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3], "c": [4, 5, 6]}) + df_orig = df.copy() + subset = df[1:3] + + subset.loc[:, indexer] = 0 + + subset._mgr._verify_integrity() + expected = DataFrame({"a": [0, 0], "b": [0.0, 0.0], "c": [5, 6]}, index=range(1, 3)) + tm.assert_frame_equal(subset, expected) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "method", + [ + lambda df: df[["a", "b"]][0:2], + lambda df: df[0:2][["a", "b"]], + lambda df: df[["a", "b"]].iloc[0:2], + lambda df: df[["a", "b"]].loc[0:1], + lambda df: df[0:2].iloc[:, 0:2], + lambda df: df[0:2].loc[:, "a":"b"], # type: ignore[misc] + ], + ids=[ + "row-getitem-slice", + "column-getitem", + "row-iloc-slice", + "row-loc-slice", + "column-iloc-slice", + "column-loc-slice", + ], +) +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +def test_subset_chained_getitem( + request, + backend, + method, + dtype, +): + # Case: creating a subset using multiple, chained getitem calls using views + # still needs to guarantee proper CoW behaviour + _, DataFrame, _ = backend + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)} + ) + df_orig = df.copy() + + # modify subset -> don't modify parent + subset = method(df) + + subset.iloc[0, 0] = 0 + tm.assert_frame_equal(df, df_orig) + + # modify parent -> don't modify subset + subset = method(df) + df.iloc[0, 0] = 0 + expected = DataFrame({"a": [1, 2], "b": [4, 5]}) + tm.assert_frame_equal(subset, expected) + + +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +def test_subset_chained_getitem_column(backend, dtype): + # Case: creating a subset using multiple, chained getitem calls using views + # still needs to guarantee proper CoW behaviour + dtype_backend, DataFrame, Series = backend + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)} + ) + df_orig = df.copy() + + # modify subset -> don't modify parent + subset = df[:]["a"][0:2] + subset.iloc[0] = 0 + tm.assert_frame_equal(df, df_orig) + + # modify parent -> don't modify subset + subset = df[:]["a"][0:2] + df.iloc[0, 0] = 0 + expected = Series([1, 2], name="a") + tm.assert_series_equal(subset, expected) + + +@pytest.mark.parametrize( + "method", + [ + lambda s: s["a":"c"]["a":"b"], # type: ignore[misc] + lambda s: s.iloc[0:3].iloc[0:2], + lambda s: s.loc["a":"c"].loc["a":"b"], # type: ignore[misc] + lambda s: s.loc["a":"c"] # type: ignore[misc] + .iloc[0:3] + .iloc[0:2] + .loc["a":"b"] # type: ignore[misc] + .iloc[0:1], + ], + ids=["getitem", "iloc", "loc", "long-chain"], +) +def test_subset_chained_getitem_series(backend, method): + # Case: creating a subset using multiple, chained getitem calls using views + # still needs to guarantee proper CoW behaviour + _, _, Series = backend + s = Series([1, 2, 3], index=["a", "b", "c"]) + s_orig = s.copy() + + # modify subset -> don't modify parent + subset = method(s) + subset.iloc[0] = 0 + tm.assert_series_equal(s, s_orig) + + # modify parent -> don't modify subset + subset = s.iloc[0:3].iloc[0:2] + s.iloc[0] = 0 + expected = Series([1, 2], index=["a", "b"]) + tm.assert_series_equal(subset, expected) + + +def test_subset_chained_single_block_row(): + # not parametrizing this for dtype backend, since this explicitly tests single block + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + df_orig = df.copy() + + # modify subset -> don't modify parent + subset = df[:].iloc[0].iloc[0:2] + subset.iloc[0] = 0 + tm.assert_frame_equal(df, df_orig) + + # modify parent -> don't modify subset + subset = df[:].iloc[0].iloc[0:2] + df.iloc[0, 0] = 0 + expected = Series([1, 4], index=["a", "b"], name=0) + tm.assert_series_equal(subset, expected) + + +@pytest.mark.parametrize( + "method", + [ + lambda df: df[:], + lambda df: df.loc[:, :], + lambda df: df.loc[:], + lambda df: df.iloc[:, :], + lambda df: df.iloc[:], + ], + ids=["getitem", "loc", "loc-rows", "iloc", "iloc-rows"], +) +def test_null_slice(backend, method): + # Case: also all variants of indexing with a null slice (:) should return + # new objects to ensure we correctly use CoW for the results + dtype_backend, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + df_orig = df.copy() + + df2 = method(df) + + # we always return new objects (shallow copy), regardless of CoW or not + assert df2 is not df + assert df2.index is not df.index + assert df2.columns is not df.columns + + # and those trigger CoW when mutated + df2.iloc[0, 0] = 0 + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "method", + [ + lambda s: s[:], + lambda s: s.loc[:], + lambda s: s.iloc[:], + ], + ids=["getitem", "loc", "iloc"], +) +def test_null_slice_series(backend, method): + _, _, Series = backend + s = Series([1, 2, 3], index=["a", "b", "c"]) + s_orig = s.copy() + + s2 = method(s) + + # we always return new objects, regardless of CoW or not + assert s2 is not s + assert s2.index is not s.index + + # and those trigger CoW when mutated + s2.iloc[0] = 0 + tm.assert_series_equal(s, s_orig) + + +# TODO add more tests modifying the parent + + +# ----------------------------------------------------------------------------- +# Series -- Indexing operations taking subset + modifying the subset/parent + + +def test_series_getitem_slice(backend): + # Case: taking a slice of a Series + afterwards modifying the subset + _, _, Series = backend + s = Series([1, 2, 3], index=["a", "b", "c"]) + s_orig = s.copy() + + subset = s[:] + assert np.shares_memory(get_array(subset), get_array(s)) + assert subset.index is not s.index + + subset.iloc[0] = 0 + + assert not np.shares_memory(get_array(subset), get_array(s)) + + expected = Series([0, 2, 3], index=["a", "b", "c"]) + tm.assert_series_equal(subset, expected) + + # original parent series is not modified (CoW) + tm.assert_series_equal(s, s_orig) + + +def test_series_getitem_ellipsis(): + # Case: taking a view of a Series using Ellipsis + afterwards modifying the subset + s = Series([1, 2, 3]) + s_orig = s.copy() + + subset = s[...] + assert np.shares_memory(get_array(subset), get_array(s)) + assert subset.index is not s.index + + subset.iloc[0] = 0 + + assert not np.shares_memory(get_array(subset), get_array(s)) + + expected = Series([0, 2, 3]) + tm.assert_series_equal(subset, expected) + + # original parent series is not modified (CoW) + tm.assert_series_equal(s, s_orig) + + +@pytest.mark.parametrize( + "indexer", + [slice(0, 2), np.array([True, True, False]), np.array([0, 1])], + ids=["slice", "mask", "array"], +) +def test_series_subset_set_with_indexer(backend, indexer_si, indexer): + # Case: setting values in a viewing Series with an indexer + _, _, Series = backend + s = Series([1, 2, 3], index=["a", "b", "c"]) + s_orig = s.copy() + subset = s[:] + + if ( + indexer_si is tm.setitem + and isinstance(indexer, np.ndarray) + and indexer.dtype.kind == "i" + ): + # In 3.0 we treat integers as always-labels + with pytest.raises(KeyError): + indexer_si(subset)[indexer] = 0 + return + + indexer_si(subset)[indexer] = 0 + expected = Series([0, 0, 3], index=["a", "b", "c"]) + tm.assert_series_equal(subset, expected) + + tm.assert_series_equal(s, s_orig) + + +# ----------------------------------------------------------------------------- +# del operator + + +def test_del_frame(backend): + # Case: deleting a column with `del` on a viewing child dataframe should + # not modify parent + update the references + dtype_backend, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df[:] + + assert np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + del df2["b"] + + assert np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + tm.assert_frame_equal(df, df_orig) + tm.assert_frame_equal(df2, df_orig[["a", "c"]]) + df2._mgr._verify_integrity() + + df.loc[0, "b"] = 200 + assert np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + df_orig = df.copy() + + df2.loc[0, "a"] = 100 + # modifying child after deleting a column still doesn't update parent + tm.assert_frame_equal(df, df_orig) + + +def test_del_series(backend): + _, _, Series = backend + s = Series([1, 2, 3], index=["a", "b", "c"]) + s_orig = s.copy() + s2 = s[:] + + assert np.shares_memory(get_array(s), get_array(s2)) + + del s2["a"] + + assert not np.shares_memory(get_array(s), get_array(s2)) + tm.assert_series_equal(s, s_orig) + tm.assert_series_equal(s2, s_orig[["b", "c"]]) + + # modifying s2 doesn't need copy on write (due to `del`, s2 is backed by new array) + values = s2.values + s2.loc["b"] = 100 + assert values[0] == 100 + + +# ----------------------------------------------------------------------------- +# Accessing column as Series + + +def test_column_as_series(backend): + # Case: selecting a single column now also uses Copy-on-Write + dtype_backend, DataFrame, Series = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + + s = df["a"] + + assert s.index is not df.index + assert np.shares_memory(get_array(s, "a"), get_array(df, "a")) + + s[0] = 0 + + expected = Series([0, 2, 3], name="a") + tm.assert_series_equal(s, expected) + # assert not np.shares_memory(s.values, get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + # ensure cached series on getitem is not the changed series + tm.assert_series_equal(df["a"], df_orig["a"]) + + +def test_column_as_series_set_with_upcast(backend): + # Case: selecting a single column now also uses Copy-on-Write -> when + # setting a value causes an upcast, we don't need to update the parent + # DataFrame through the cache mechanism + dtype_backend, DataFrame, Series = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + + s = df["a"] + if dtype_backend == "nullable": + with pytest.raises(TypeError, match="Invalid value"): + s[0] = "foo" + expected = Series([1, 2, 3], name="a") + tm.assert_series_equal(s, expected) + tm.assert_frame_equal(df, df_orig) + # ensure cached series on getitem is not the changed series + tm.assert_series_equal(df["a"], df_orig["a"]) + else: + with pytest.raises(TypeError, match="Invalid value"): + s[0] = "foo" + + +@pytest.mark.parametrize( + "method", + [ + lambda df: df["a"], + lambda df: df.loc[:, "a"], + lambda df: df.iloc[:, 0], + ], + ids=["getitem", "loc", "iloc"], +) +def test_column_as_series_no_item_cache(request, backend, method): + # Case: selecting a single column (which now also uses Copy-on-Write to protect + # the view) should always give a new object (i.e. not make use of a cache) + dtype_backend, DataFrame, _ = backend + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + + s1 = method(df) + s2 = method(df) + + assert s1 is not s2 + assert s1.index is not df.index + assert s1.index is not s2.index + + s1.iloc[0] = 0 + + tm.assert_series_equal(s2, df_orig["a"]) + tm.assert_frame_equal(df, df_orig) + + +# TODO add tests for other indexing methods on the Series + + +def test_dataframe_add_column_from_series(backend): + # Case: adding a new column to a DataFrame from an existing column/series + # -> delays copy under CoW + _, DataFrame, Series = backend + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]}) + + s = Series([10, 11, 12]) + df["new"] = s + assert np.shares_memory(get_array(df, "new"), get_array(s)) + + # editing series -> doesn't modify column in frame + s[0] = 0 + expected = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3], "new": [10, 11, 12]}) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.parametrize("val", [100, "a"]) +@pytest.mark.parametrize( + "indexer_func, indexer", + [ + (tm.loc, (0, "a")), + (tm.iloc, (0, 0)), + (tm.loc, ([0], "a")), + (tm.iloc, ([0], 0)), + (tm.loc, (slice(None), "a")), + (tm.iloc, (slice(None), 0)), + ], +) +@pytest.mark.parametrize( + "col", [[0.1, 0.2, 0.3], [7, 8, 9]], ids=["mixed-block", "single-block"] +) +def test_set_value_copy_only_necessary_column(indexer_func, indexer, val, col): + # When setting inplace, only copy column that is modified instead of the whole + # block (by splitting the block) + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": col}) + df_orig = df.copy() + view = df[:] + + if val == "a": + with pytest.raises(TypeError, match="Invalid value"): + indexer_func(df)[indexer] = val + else: + indexer_func(df)[indexer] = val + + assert np.shares_memory(get_array(df, "b"), get_array(view, "b")) + assert not np.shares_memory(get_array(df, "a"), get_array(view, "a")) + tm.assert_frame_equal(view, df_orig) + + +def test_series_midx_slice(): + ser = Series([1, 2, 3], index=pd.MultiIndex.from_arrays([[1, 1, 2], [3, 4, 5]])) + ser_orig = ser.copy() + result = ser[1] + assert np.shares_memory(get_array(ser), get_array(result)) + result.iloc[0] = 100 + tm.assert_series_equal(ser, ser_orig) + + +def test_getitem_midx_slice(): + df = DataFrame({("a", "x"): [1, 2], ("a", "y"): 1, ("b", "x"): 2}) + df_orig = df.copy() + new_df = df[("a",)] + + assert not new_df._mgr._has_no_reference(0) + + assert np.shares_memory(get_array(df, ("a", "x")), get_array(new_df, "x")) + new_df.iloc[0, 0] = 100 + tm.assert_frame_equal(df_orig, df) + + +def test_series_midx_tuples_slice(): + ser = Series( + [1, 2, 3], + index=pd.MultiIndex.from_tuples([((1, 2), 3), ((1, 2), 4), ((2, 3), 4)]), + ) + result = ser[(1, 2)] + assert np.shares_memory(get_array(ser), get_array(result)) + result.iloc[0] = 100 + expected = Series( + [1, 2, 3], + index=pd.MultiIndex.from_tuples([((1, 2), 3), ((1, 2), 4), ((2, 3), 4)]), + ) + tm.assert_series_equal(ser, expected) + + +def test_midx_read_only_bool_indexer(): + # GH#56635 + def mklbl(prefix, n): + return [f"{prefix}{i}" for i in range(n)] + + idx = pd.MultiIndex.from_product( + [mklbl("A", 4), mklbl("B", 2), mklbl("C", 4), mklbl("D", 2)] + ) + cols = pd.MultiIndex.from_tuples( + [("a", "foo"), ("a", "bar"), ("b", "foo"), ("b", "bah")], names=["lvl0", "lvl1"] + ) + df = DataFrame(1, index=idx, columns=cols).sort_index().sort_index(axis=1) + + mask = df[("a", "foo")] == 1 + expected_mask = mask.copy() + result = df.loc[pd.IndexSlice[mask, :, ["C1", "C3"]], :] + expected = df.loc[pd.IndexSlice[:, :, ["C1", "C3"]], :] + tm.assert_frame_equal(result, expected) + tm.assert_series_equal(mask, expected_mask) + + +def test_loc_enlarging_with_dataframe(): + df = DataFrame({"a": [1, 2, 3]}) + rhs = DataFrame({"b": [1, 2, 3], "c": [4, 5, 6]}) + rhs_orig = rhs.copy() + df.loc[:, ["b", "c"]] = rhs + assert np.shares_memory(get_array(df, "b"), get_array(rhs, "b")) + assert np.shares_memory(get_array(df, "c"), get_array(rhs, "c")) + assert not df._mgr._has_no_reference(1) + + df.iloc[0, 1] = 100 + tm.assert_frame_equal(rhs, rhs_orig) diff --git a/pandas/tests/copy_view/test_internals.py b/pandas/tests/copy_view/test_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..b7baf01ecc36e5a3cfed4ba443d244ac669344fd --- /dev/null +++ b/pandas/tests/copy_view/test_internals.py @@ -0,0 +1,112 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +def test_consolidate(): + # create unconsolidated DataFrame + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]}) + df["c"] = [4, 5, 6] + + # take a viewing subset + subset = df[:] + + # each block of subset references a block of df + assert all(blk.refs.has_reference() for blk in subset._mgr.blocks) + + # consolidate the two int64 blocks + subset._consolidate_inplace() + + # the float64 block still references the parent one because it still a view + assert subset._mgr.blocks[0].refs.has_reference() + # equivalent of assert np.shares_memory(df["b"].values, subset["b"].values) + # but avoids caching df["b"] + assert np.shares_memory(get_array(df, "b"), get_array(subset, "b")) + + # the new consolidated int64 block does not reference another + assert not subset._mgr.blocks[1].refs.has_reference() + + # the parent dataframe now also only is linked for the float column + assert not df._mgr.blocks[0].refs.has_reference() + assert df._mgr.blocks[1].refs.has_reference() + assert not df._mgr.blocks[2].refs.has_reference() + + # and modifying subset still doesn't modify parent + subset.iloc[0, 1] = 0.0 + assert not df._mgr.blocks[1].refs.has_reference() + assert df.loc[0, "b"] == 0.1 + + +@pytest.mark.parametrize("dtype", [np.intp, np.int8]) +@pytest.mark.parametrize( + "locs, arr", + [ + ([0], np.array([-1, -2, -3])), + ([1], np.array([-1, -2, -3])), + ([5], np.array([-1, -2, -3])), + ([0, 1], np.array([[-1, -2, -3], [-4, -5, -6]]).T), + ([0, 2], np.array([[-1, -2, -3], [-4, -5, -6]]).T), + ([0, 1, 2], np.array([[-1, -2, -3], [-4, -5, -6], [-4, -5, -6]]).T), + ([1, 2], np.array([[-1, -2, -3], [-4, -5, -6]]).T), + ([1, 3], np.array([[-1, -2, -3], [-4, -5, -6]]).T), + ], +) +def test_iset_splits_blocks_inplace(locs, arr, dtype): + # Nothing currently calls iset with + # more than 1 loc with inplace=True (only happens with inplace=False) + # but ensure that it works + df = DataFrame( + { + "a": [1, 2, 3], + "b": [4, 5, 6], + "c": [7, 8, 9], + "d": [10, 11, 12], + "e": [13, 14, 15], + "f": Series(["a", "b", "c"], dtype=object), + }, + ) + arr = arr.astype(dtype) + df_orig = df.copy() + df2 = df.copy(deep=False) # Trigger a CoW (if enabled, otherwise makes copy) + df2._mgr.iset(locs, arr, inplace=True) + + tm.assert_frame_equal(df, df_orig) + for i, col in enumerate(df.columns): + if i not in locs: + assert np.shares_memory(get_array(df, col), get_array(df2, col)) + + +def test_exponential_backoff(): + # GH#55518 + df = DataFrame({"a": [1, 2, 3]}) + for i in range(490): + df.copy(deep=False) + + assert len(df._mgr.blocks[0].refs.referenced_blocks) == 491 + + df = DataFrame({"a": [1, 2, 3]}) + dfs = [df.copy(deep=False) for i in range(510)] + + for i in range(20): + df.copy(deep=False) + assert len(df._mgr.blocks[0].refs.referenced_blocks) == 531 + assert df._mgr.blocks[0].refs.clear_counter == 1000 + + for i in range(500): + df.copy(deep=False) + + # Don't reduce since we still have over 500 objects alive + assert df._mgr.blocks[0].refs.clear_counter == 1000 + + dfs = dfs[:300] + for i in range(500): + df.copy(deep=False) + + # Reduce since there are less than 500 objects alive + assert df._mgr.blocks[0].refs.clear_counter == 500 diff --git a/pandas/tests/copy_view/test_interp_fillna.py b/pandas/tests/copy_view/test_interp_fillna.py new file mode 100644 index 0000000000000000000000000000000000000000..d5880e99df5d7122c1bb396189ae735dfe2e77ba --- /dev/null +++ b/pandas/tests/copy_view/test_interp_fillna.py @@ -0,0 +1,307 @@ +import numpy as np +import pytest + +from pandas import ( + NA, + DataFrame, + Interval, + NaT, + Series, + Timestamp, + interval_range, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +@pytest.mark.parametrize("method", ["pad", "nearest", "linear"]) +def test_interpolate_no_op(method): + df = DataFrame({"a": [1, 2]}) + df_orig = df.copy() + + if method == "pad": + msg = f"Can not interpolate with method={method}" + with pytest.raises(ValueError, match=msg): + df.interpolate(method=method) + else: + result = df.interpolate(method=method) + assert np.shares_memory(get_array(result, "a"), get_array(df, "a")) + assert result.index is not df.index + assert result.columns is not df.columns + + result.iloc[0, 0] = 100 + + assert not np.shares_memory(get_array(result, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("func", ["ffill", "bfill"]) +def test_interp_fill_functions(func): + # Check that these takes the same code paths as interpolate + df = DataFrame({"a": [1, 2]}) + df_orig = df.copy() + + result = getattr(df, func)() + + assert np.shares_memory(get_array(result, "a"), get_array(df, "a")) + assert result.index is not df.index + assert result.columns is not df.columns + + result.iloc[0, 0] = 100 + assert not np.shares_memory(get_array(result, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("func", ["ffill", "bfill"]) +@pytest.mark.parametrize( + "vals", [[1, np.nan, 2], [Timestamp("2019-12-31"), NaT, Timestamp("2020-12-31")]] +) +def test_interpolate_triggers_copy(vals, func): + df = DataFrame({"a": vals}) + result = getattr(df, func)() + + assert not np.shares_memory(get_array(result, "a"), get_array(df, "a")) + # Check that we don't have references when triggering a copy + assert result._mgr._has_no_reference(0) + + +@pytest.mark.parametrize( + "vals", [[1, np.nan, 2], [Timestamp("2019-12-31"), NaT, Timestamp("2020-12-31")]] +) +def test_interpolate_inplace_no_reference_no_copy(vals): + df = DataFrame({"a": vals}) + arr = get_array(df, "a") + df.interpolate(method="linear", inplace=True) + + assert np.shares_memory(arr, get_array(df, "a")) + # Check that we don't have references when triggering a copy + assert df._mgr._has_no_reference(0) + + +@pytest.mark.parametrize( + "vals", [[1, np.nan, 2], [Timestamp("2019-12-31"), NaT, Timestamp("2020-12-31")]] +) +def test_interpolate_inplace_with_refs(vals): + df = DataFrame({"a": [1, np.nan, 2]}) + df_orig = df.copy() + arr = get_array(df, "a") + view = df[:] + df.interpolate(method="linear", inplace=True) + # Check that copy was triggered in interpolate and that we don't + # have any references left + assert not np.shares_memory(arr, get_array(df, "a")) + tm.assert_frame_equal(df_orig, view) + assert df._mgr._has_no_reference(0) + assert view._mgr._has_no_reference(0) + + +@pytest.mark.parametrize("func", ["ffill", "bfill"]) +@pytest.mark.parametrize("dtype", ["float64", "Float64"]) +def test_interp_fill_functions_inplace(func, dtype): + # Check that these takes the same code paths as interpolate + df = DataFrame({"a": [1, np.nan, 2]}, dtype=dtype) + df_orig = df.copy() + arr = get_array(df, "a") + view = df[:] + + getattr(df, func)(inplace=True) + + # Check that copy was triggered in interpolate and that we don't + # have any references left + assert not np.shares_memory(arr, get_array(df, "a")) + tm.assert_frame_equal(df_orig, view) + assert df._mgr._has_no_reference(0) + assert view._mgr._has_no_reference(0) + + +def test_interpolate_cannot_with_object_dtype(): + df = DataFrame({"a": ["a", np.nan, "c"], "b": 1}) + df["a"] = df["a"].astype(object) + + msg = "DataFrame cannot interpolate with object dtype" + with pytest.raises(TypeError, match=msg): + df.interpolate() + + +def test_interpolate_object_convert_no_op(): + df = DataFrame({"a": ["a", "b", "c"], "b": 1}) + df["a"] = df["a"].astype(object) + arr_a = get_array(df, "a") + + # Now CoW makes a copy, it should not! + assert df._mgr._has_no_reference(0) + assert np.shares_memory(arr_a, get_array(df, "a")) + + +def test_interpolate_object_convert_copies(): + df = DataFrame({"a": [1, np.nan, 2.5], "b": 1}) + arr_a = get_array(df, "a") + msg = "Can not interpolate with method=pad" + with pytest.raises(ValueError, match=msg): + df.interpolate(method="pad", inplace=True) + + assert df._mgr._has_no_reference(0) + assert np.shares_memory(arr_a, get_array(df, "a")) + + +def test_interpolate_downcast_reference_triggers_copy(): + df = DataFrame({"a": [1, np.nan, 2.5], "b": 1}) + df_orig = df.copy() + arr_a = get_array(df, "a") + view = df[:] + + msg = "Can not interpolate with method=pad" + with pytest.raises(ValueError, match=msg): + df.interpolate(method="pad", inplace=True) + assert df._mgr._has_no_reference(0) + assert not np.shares_memory(arr_a, get_array(df, "a")) + + tm.assert_frame_equal(df_orig, view) + + +def test_fillna(): + df = DataFrame({"a": [1.5, np.nan], "b": 1}) + df_orig = df.copy() + + df2 = df.fillna(5.5) + assert np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + assert df2.index is not df.index + assert df2.columns is not df.columns + + df2.iloc[0, 1] = 100 + tm.assert_frame_equal(df_orig, df) + + +def test_fillna_dict(): + df = DataFrame({"a": [1.5, np.nan], "b": 1}) + df_orig = df.copy() + + df2 = df.fillna({"a": 100.5}) + assert np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + assert not np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + df2.iloc[0, 1] = 100 + tm.assert_frame_equal(df_orig, df) + + +def test_fillna_inplace(): + df = DataFrame({"a": [1.5, np.nan], "b": 1}) + arr_a = get_array(df, "a") + arr_b = get_array(df, "b") + + df.fillna(5.5, inplace=True) + assert np.shares_memory(get_array(df, "a"), arr_a) + assert np.shares_memory(get_array(df, "b"), arr_b) + assert df._mgr._has_no_reference(0) + assert df._mgr._has_no_reference(1) + + +def test_fillna_inplace_reference(): + df = DataFrame({"a": [1.5, np.nan], "b": 1}) + df_orig = df.copy() + arr_a = get_array(df, "a") + arr_b = get_array(df, "b") + view = df[:] + + df.fillna(5.5, inplace=True) + assert not np.shares_memory(get_array(df, "a"), arr_a) + assert np.shares_memory(get_array(df, "b"), arr_b) + assert view._mgr._has_no_reference(0) + assert df._mgr._has_no_reference(0) + tm.assert_frame_equal(view, df_orig) + expected = DataFrame({"a": [1.5, 5.5], "b": 1}) + tm.assert_frame_equal(df, expected) + + +def test_fillna_interval_inplace_reference(): + # Set dtype explicitly to avoid implicit cast when setting nan + ser = Series( + interval_range(start=0, end=5), name="a", dtype="interval[float64, right]" + ) + ser.iloc[1] = np.nan + + ser_orig = ser.copy() + view = ser[:] + ser.fillna(value=Interval(left=0, right=5), inplace=True) + + assert not np.shares_memory( + get_array(ser, "a").left.values, get_array(view, "a").left.values + ) + tm.assert_series_equal(view, ser_orig) + + +def test_fillna_series_empty_arg(): + ser = Series([1, np.nan, 2]) + ser_orig = ser.copy() + result = ser.fillna({}) + assert np.shares_memory(get_array(ser), get_array(result)) + + ser.iloc[0] = 100.5 + tm.assert_series_equal(ser_orig, result) + + +def test_fillna_series_empty_arg_inplace(): + ser = Series([1, np.nan, 2]) + arr = get_array(ser) + ser.fillna({}, inplace=True) + + assert np.shares_memory(get_array(ser), arr) + assert ser._mgr._has_no_reference(0) + + +def test_fillna_ea_noop_shares_memory(any_numeric_ea_and_arrow_dtype): + df = DataFrame({"a": [1, NA, 3], "b": 1}, dtype=any_numeric_ea_and_arrow_dtype) + df_orig = df.copy() + df2 = df.fillna(100) + + assert not np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + assert np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + assert not df2._mgr._has_no_reference(1) + tm.assert_frame_equal(df_orig, df) + + df2.iloc[0, 1] = 100 + assert not np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + assert df2._mgr._has_no_reference(1) + assert df._mgr._has_no_reference(1) + tm.assert_frame_equal(df_orig, df) + + +def test_fillna_inplace_ea_noop_shares_memory(any_numeric_ea_and_arrow_dtype): + df = DataFrame({"a": [1, NA, 3], "b": 1}, dtype=any_numeric_ea_and_arrow_dtype) + df_orig = df.copy() + view = df[:] + df.fillna(100, inplace=True) + assert not np.shares_memory(get_array(df, "a"), get_array(view, "a")) + + assert np.shares_memory(get_array(df, "b"), get_array(view, "b")) + assert not df._mgr._has_no_reference(1) + assert not view._mgr._has_no_reference(1) + + df.iloc[0, 1] = 100 + tm.assert_frame_equal(df_orig, view) + + +def test_fillna_chained_assignment(): + df = DataFrame({"a": [1, np.nan, 2], "b": 1}) + df_orig = df.copy() + with tm.raises_chained_assignment_error(): + df["a"].fillna(100, inplace=True) + tm.assert_frame_equal(df, df_orig) + + with tm.raises_chained_assignment_error(): + df[["a"]].fillna(100, inplace=True) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("func", ["interpolate", "ffill", "bfill"]) +def test_interpolate_chained_assignment(func): + df = DataFrame({"a": [1, np.nan, 2], "b": 1}) + df_orig = df.copy() + with tm.raises_chained_assignment_error(): + getattr(df["a"], func)(inplace=True) + tm.assert_frame_equal(df, df_orig) + + with tm.raises_chained_assignment_error(): + getattr(df[["a"]], func)(inplace=True) + tm.assert_frame_equal(df, df_orig) diff --git a/pandas/tests/copy_view/test_methods.py b/pandas/tests/copy_view/test_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..bb19132d16a072304faa16599711a7d4be4122b8 --- /dev/null +++ b/pandas/tests/copy_view/test_methods.py @@ -0,0 +1,1601 @@ +import numpy as np +import pytest + +from pandas.compat import HAS_PYARROW + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Period, + Series, + Timestamp, + date_range, + period_range, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array +from pandas.util.version import Version + + +def test_copy(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_copy = df.copy() + + # the deep copy by defaults takes a shallow copy of the Index + assert df_copy.index is not df.index + assert df_copy.columns is not df.columns + assert df_copy.index.is_(df.index) + assert df_copy.columns.is_(df.columns) + + # the deep copy doesn't share memory + assert not np.shares_memory(get_array(df_copy, "a"), get_array(df, "a")) + assert not df_copy._mgr.blocks[0].refs.has_reference() + assert not df_copy._mgr.blocks[1].refs.has_reference() + + assert df_copy.index is not df.index + assert df_copy.columns is not df.columns + + # mutating copy doesn't mutate original + df_copy.iloc[0, 0] = 0 + assert df.iloc[0, 0] == 1 + + +def test_copy_shallow(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_copy = df.copy(deep=False) + + # the shallow copy also makes a shallow copy of the index + assert df_copy.index is not df.index + assert df_copy.columns is not df.columns + assert df_copy.index.is_(df.index) + assert df_copy.columns.is_(df.columns) + + # the shallow copy still shares memory + assert np.shares_memory(get_array(df_copy, "a"), get_array(df, "a")) + assert df_copy._mgr.blocks[0].refs.has_reference() + assert df_copy._mgr.blocks[1].refs.has_reference() + + # mutating shallow copy doesn't mutate original + df_copy.iloc[0, 0] = 0 + assert df.iloc[0, 0] == 1 + # mutating triggered a copy-on-write -> no longer shares memory + assert not np.shares_memory(get_array(df_copy, "a"), get_array(df, "a")) + # but still shares memory for the other columns/blocks + assert np.shares_memory(get_array(df_copy, "c"), get_array(df, "c")) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.parametrize("copy", [True, None, False]) +@pytest.mark.parametrize( + "method", + [ + lambda df, copy: df.rename(columns=str.lower, copy=copy), + lambda df, copy: df.reindex(columns=["a", "c"], copy=copy), + lambda df, copy: df.reindex_like(df, copy=copy), + lambda df, copy: df.align(df, copy=copy)[0], + lambda df, copy: df.set_axis(["a", "b", "c"], axis="index", copy=copy), + lambda df, copy: df.rename_axis(index="test", copy=copy), + lambda df, copy: df.rename_axis(columns="test", copy=copy), + lambda df, copy: df.astype({"b": "int64"}, copy=copy), + # lambda df, copy: df.swaplevel(0, 0, copy=copy), + lambda df, copy: df.truncate(0, 5, copy=copy), + lambda df, copy: df.infer_objects(copy=copy), + lambda df, copy: df.to_timestamp(copy=copy), + lambda df, copy: df.to_period(freq="D", copy=copy), + lambda df, copy: df.tz_localize("US/Central", copy=copy), + lambda df, copy: df.tz_convert("US/Central", copy=copy), + lambda df, copy: df.set_flags(allows_duplicate_labels=False, copy=copy), + ], + ids=[ + "rename", + "reindex", + "reindex_like", + "align", + "set_axis", + "rename_axis0", + "rename_axis1", + "astype", + # "swaplevel", # only series + "truncate", + "infer_objects", + "to_timestamp", + "to_period", + "tz_localize", + "tz_convert", + "set_flags", + ], +) +def test_methods_copy_keyword(request, method, copy): + index = None + if "to_timestamp" in request.node.callspec.id: + index = period_range("2012-01-01", freq="D", periods=3) + elif "to_period" in request.node.callspec.id: + index = date_range("2012-01-01", freq="D", periods=3) + elif "tz_localize" in request.node.callspec.id: + index = date_range("2012-01-01", freq="D", periods=3) + elif "tz_convert" in request.node.callspec.id: + index = date_range("2012-01-01", freq="D", periods=3, tz="Europe/Brussels") + + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}, index=index) + df2 = method(df, copy=copy) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +@pytest.mark.parametrize("copy", [True, None, False]) +@pytest.mark.parametrize( + "method", + [ + lambda ser, copy: ser.rename(index={0: 100}, copy=copy), + lambda ser, copy: ser.rename(None, copy=copy), + lambda ser, copy: ser.reindex(index=ser.index, copy=copy), + lambda ser, copy: ser.reindex_like(ser, copy=copy), + lambda ser, copy: ser.align(ser, copy=copy)[0], + lambda ser, copy: ser.set_axis(["a", "b", "c"], axis="index", copy=copy), + lambda ser, copy: ser.rename_axis(index="test", copy=copy), + lambda ser, copy: ser.astype("int64", copy=copy), + lambda ser, copy: ser.swaplevel(0, 1, copy=copy), + lambda ser, copy: ser.truncate(0, 5, copy=copy), + lambda ser, copy: ser.infer_objects(copy=copy), + lambda ser, copy: ser.to_timestamp(copy=copy), + lambda ser, copy: ser.to_period(freq="D", copy=copy), + lambda ser, copy: ser.tz_localize("US/Central", copy=copy), + lambda ser, copy: ser.tz_convert("US/Central", copy=copy), + lambda ser, copy: ser.set_flags(allows_duplicate_labels=False, copy=copy), + ], + ids=[ + "rename (dict)", + "rename", + "reindex", + "reindex_like", + "align", + "set_axis", + "rename_axis0", + "astype", + "swaplevel", + "truncate", + "infer_objects", + "to_timestamp", + "to_period", + "tz_localize", + "tz_convert", + "set_flags", + ], +) +def test_methods_series_copy_keyword(request, method, copy): + index = None + if "to_timestamp" in request.node.callspec.id: + index = period_range("2012-01-01", freq="D", periods=3) + elif "to_period" in request.node.callspec.id: + index = date_range("2012-01-01", freq="D", periods=3) + elif "tz_localize" in request.node.callspec.id: + index = date_range("2012-01-01", freq="D", periods=3) + elif "tz_convert" in request.node.callspec.id: + index = date_range("2012-01-01", freq="D", periods=3, tz="Europe/Brussels") + elif "swaplevel" in request.node.callspec.id: + index = MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]]) + + ser = Series([1, 2, 3], index=index) + ser2 = method(ser, copy=copy) + assert np.shares_memory(get_array(ser2), get_array(ser)) + + +# ----------------------------------------------------------------------------- +# DataFrame methods returning new DataFrame using shallow copy + + +def test_reset_index(): + # Case: resetting the index (i.e. adding a new column) + mutating the + # resulting dataframe + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}, index=[10, 11, 12] + ) + df_orig = df.copy() + df2 = df.reset_index() + df2._mgr._verify_integrity() + + # still shares memory (df2 is a shallow copy) + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + # mutating df2 triggers a copy-on-write for that column / block + df2.iloc[0, 2] = 0 + assert not np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("index", [pd.RangeIndex(0, 2), Index([1, 2])]) +def test_reset_index_series_drop(index): + ser = Series([1, 2], index=index) + ser_orig = ser.copy() + ser2 = ser.reset_index(drop=True) + assert np.shares_memory(get_array(ser), get_array(ser2)) + assert not ser._mgr._has_no_reference(0) + + ser2.iloc[0] = 100 + tm.assert_series_equal(ser, ser_orig) + + +def test_groupby_column_index_in_references(): + df = DataFrame( + {"A": ["a", "b", "c", "d"], "B": [1, 2, 3, 4], "C": ["a", "a", "b", "b"]} + ) + df = df.set_index("A") + key = df["C"] + result = df.groupby(key, observed=True).sum() + expected = df.groupby("C", observed=True).sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_modify_series(): + # https://github.com/pandas-dev/pandas/issues/63219 + # Modifying a Series after using it to groupby should not impact + # the groupby operation. + ser = Series([1, 2, 1]) + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + gb = df.groupby(ser) + ser.iloc[0] = 100 + result = gb.sum() + expected = DataFrame({"a": [4, 2], "b": [10, 5]}, index=[1, 2]) + tm.assert_frame_equal(result, expected) + + +def test_rename_columns(): + # Case: renaming columns returns a new dataframe + # + afterwards modifying the result + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.rename(columns=str.upper) + + assert np.shares_memory(get_array(df2, "A"), get_array(df, "a")) + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "A"), get_array(df, "a")) + assert np.shares_memory(get_array(df2, "C"), get_array(df, "c")) + expected = DataFrame({"A": [0, 2, 3], "B": [4, 5, 6], "C": [0.1, 0.2, 0.3]}) + tm.assert_frame_equal(df2, expected) + tm.assert_frame_equal(df, df_orig) + + +def test_rename_columns_modify_parent(): + # Case: renaming columns returns a new dataframe + # + afterwards modifying the original (parent) dataframe + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df2 = df.rename(columns=str.upper) + df2_orig = df2.copy() + + assert np.shares_memory(get_array(df2, "A"), get_array(df, "a")) + df.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "A"), get_array(df, "a")) + assert np.shares_memory(get_array(df2, "C"), get_array(df, "c")) + expected = DataFrame({"a": [0, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + tm.assert_frame_equal(df, expected) + tm.assert_frame_equal(df2, df2_orig) + + +def test_pipe(): + df = DataFrame({"a": [1, 2, 3], "b": 1.5}) + df_orig = df.copy() + + def testfunc(df): + return df + + df2 = df.pipe(testfunc) + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + # mutating df2 triggers a copy-on-write for that column + df2.iloc[0, 0] = 0 + tm.assert_frame_equal(df, df_orig) + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + + +def test_pipe_modify_df(): + df = DataFrame({"a": [1, 2, 3], "b": 1.5}) + df_orig = df.copy() + + def testfunc(df): + df.iloc[0, 0] = 100 + return df + + df2 = df.pipe(testfunc) + + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + + tm.assert_frame_equal(df, df_orig) + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + + +def test_reindex_columns(): + # Case: reindexing the column returns a new dataframe + # + afterwards modifying the result + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.reindex(columns=["a", "c"]) + + # still shares memory (df2 is a shallow copy) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + # mutating df2 triggers a copy-on-write for that column + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "index", + [ + lambda idx: idx, + lambda idx: idx.view(), + lambda idx: idx.copy(), + lambda idx: list(idx), + ], + ids=["identical", "view", "copy", "values"], +) +def test_reindex_rows(index): + # Case: reindexing the rows with an index that matches the current index + # can use a shallow copy + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.reindex(index=index(df.index)) + + # still shares memory (df2 is a shallow copy) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + # mutating df2 triggers a copy-on-write for that column + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + tm.assert_frame_equal(df, df_orig) + + +def test_drop_on_column(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.drop(columns="a") + df2._mgr._verify_integrity() + + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + tm.assert_frame_equal(df, df_orig) + + +def test_select_dtypes(): + # Case: selecting columns using `select_dtypes()` returns a new dataframe + # + afterwards modifying the result + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.select_dtypes("int64") + df2._mgr._verify_integrity() + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + # mutating df2 triggers a copy-on-write for that column/block + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "filter_kwargs", [{"items": ["a"]}, {"like": "a"}, {"regex": "a"}] +) +def test_filter(filter_kwargs): + # Case: selecting columns using `filter()` returns a new dataframe + # + afterwards modifying the result + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.filter(**filter_kwargs) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + # mutating df2 triggers a copy-on-write for that column/block + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_shift_no_op(): + df = DataFrame( + [[1, 2], [3, 4], [5, 6]], + index=date_range("2020-01-01", "2020-01-03"), + columns=["a", "b"], + ) + df_orig = df.copy() + df2 = df.shift(periods=0) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert df2.index is not df.index + assert df2.columns is not df.columns + + df.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + assert np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + tm.assert_frame_equal(df2, df_orig) + + +def test_shift_index(): + df = DataFrame( + [[1, 2], [3, 4], [5, 6]], + index=date_range("2020-01-01", "2020-01-03"), + columns=["a", "b"], + ) + df2 = df.shift(periods=1, axis=0) + + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert df2.index is not df.index + assert df2.columns is not df.columns + + +def test_shift_rows_freq(): + df = DataFrame( + [[1, 2], [3, 4], [5, 6]], + index=date_range("2020-01-01", "2020-01-03"), + columns=["a", "b"], + ) + df_orig = df.copy() + df_orig.index = date_range("2020-01-02", "2020-01-04") + df2 = df.shift(periods=1, freq="1D") + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + df.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + tm.assert_frame_equal(df2, df_orig) + + +def test_shift_columns(): + df = DataFrame( + [[1, 2], [3, 4], [5, 6]], columns=date_range("2020-01-01", "2020-01-02") + ) + df2 = df.shift(periods=1, axis=1) + + assert np.shares_memory(get_array(df2, "2020-01-02"), get_array(df, "2020-01-01")) + df.iloc[0, 0] = 0 + assert not np.shares_memory( + get_array(df2, "2020-01-02"), get_array(df, "2020-01-01") + ) + expected = DataFrame( + [[np.nan, 1], [np.nan, 3], [np.nan, 5]], + columns=date_range("2020-01-01", "2020-01-02"), + ) + tm.assert_frame_equal(df2, expected) + + +def test_pop(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + view_original = df[:] + result = df.pop("a") + + assert np.shares_memory(result.values, get_array(view_original, "a")) + assert np.shares_memory(get_array(df, "b"), get_array(view_original, "b")) + + result.iloc[0] = 0 + assert not np.shares_memory(result.values, get_array(view_original, "a")) + df.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df, "b"), get_array(view_original, "b")) + tm.assert_frame_equal(view_original, df_orig) + + +@pytest.mark.parametrize( + "func", + [ + lambda x, y: x.align(y), + lambda x, y: x.align(y.a, axis=0), + lambda x, y: x.align(y.a.iloc[slice(0, 1)], axis=1), + ], +) +def test_align_frame(func): + df = DataFrame({"a": [1, 2, 3], "b": "a"}) + df_orig = df.copy() + df_changed = df[["b", "a"]].copy() + df2, _ = func(df, df_changed) + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_align_series(): + ser = Series([1, 2]) + ser_orig = ser.copy() + ser_other = ser.copy() + ser2, ser_other_result = ser.align(ser_other) + + assert np.shares_memory(ser2.values, ser.values) + assert np.shares_memory(ser_other_result.values, ser_other.values) + ser2.iloc[0] = 0 + ser_other_result.iloc[0] = 0 + assert not np.shares_memory(ser2.values, ser.values) + assert not np.shares_memory(ser_other_result.values, ser_other.values) + tm.assert_series_equal(ser, ser_orig) + tm.assert_series_equal(ser_other, ser_orig) + + +def test_align_copy_false(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_orig = df.copy() + df2, df3 = df.align(df) + + assert np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + assert np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + df2.loc[0, "a"] = 0 + tm.assert_frame_equal(df, df_orig) # Original is unchanged + + df3.loc[0, "a"] = 0 + tm.assert_frame_equal(df, df_orig) # Original is unchanged + + +def test_align_with_series_copy_false(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + ser = Series([1, 2, 3], name="x") + ser_orig = ser.copy() + df_orig = df.copy() + df2, ser2 = df.align(ser, axis=0) + + assert np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + assert np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + assert np.shares_memory(get_array(ser, "x"), get_array(ser2, "x")) + + df2.loc[0, "a"] = 0 + tm.assert_frame_equal(df, df_orig) # Original is unchanged + + ser2.loc[0] = 0 + tm.assert_series_equal(ser, ser_orig) # Original is unchanged + + +def test_to_frame(): + # Case: converting a Series to a DataFrame with to_frame + ser = Series([1, 2, 3]) + ser_orig = ser.copy() + + df = ser[:].to_frame() + + # currently this always returns a "view" + assert np.shares_memory(ser.values, get_array(df, 0)) + + df.iloc[0, 0] = 0 + + # mutating df triggers a copy-on-write for that column + assert not np.shares_memory(ser.values, get_array(df, 0)) + tm.assert_series_equal(ser, ser_orig) + + # modify original series -> don't modify dataframe + df = ser[:].to_frame() + ser.iloc[0] = 0 + + tm.assert_frame_equal(df, ser_orig.to_frame()) + + df = ser.to_frame() + assert df.index is not ser.index + + +@pytest.mark.parametrize( + "method, idx", + [ + (lambda df: df.copy(deep=False).copy(deep=False), 0), + (lambda df: df.reset_index().reset_index(), 2), + (lambda df: df.rename(columns=str.upper).rename(columns=str.lower), 0), + (lambda df: df.copy(deep=False).select_dtypes(include="number"), 0), + ], + ids=["shallow-copy", "reset_index", "rename", "select_dtypes"], +) +def test_chained_methods(method, idx): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + + # modify df2 -> don't modify df + df2 = method(df) + df2.iloc[0, idx] = 0 + tm.assert_frame_equal(df, df_orig) + + # modify df -> don't modify df2 + df2 = method(df) + df.iloc[0, 0] = 0 + tm.assert_frame_equal(df2.iloc[:, idx:], df_orig) + + +@pytest.mark.parametrize("obj", [Series([1, 2], name="a"), DataFrame({"a": [1, 2]})]) +def test_to_timestamp(obj): + obj.index = Index([Period("2012-1-1", freq="D"), Period("2012-1-2", freq="D")]) + + obj_orig = obj.copy() + obj2 = obj.to_timestamp() + + assert np.shares_memory(get_array(obj2, "a"), get_array(obj, "a")) + + # mutating obj2 triggers a copy-on-write for that column / block + obj2.iloc[0] = 0 + assert not np.shares_memory(get_array(obj2, "a"), get_array(obj, "a")) + tm.assert_equal(obj, obj_orig) + + +@pytest.mark.parametrize("obj", [Series([1, 2], name="a"), DataFrame({"a": [1, 2]})]) +def test_to_period(obj): + obj.index = Index([Timestamp("2019-12-31"), Timestamp("2020-12-31")]) + + obj_orig = obj.copy() + obj2 = obj.to_period(freq="Y") + + assert np.shares_memory(get_array(obj2, "a"), get_array(obj, "a")) + + # mutating obj2 triggers a copy-on-write for that column / block + obj2.iloc[0] = 0 + assert not np.shares_memory(get_array(obj2, "a"), get_array(obj, "a")) + tm.assert_equal(obj, obj_orig) + + +def test_set_index(): + # GH 49473 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.set_index("a") + + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + + # mutating df2 triggers a copy-on-write for that column / block + df2.iloc[0, 1] = 0 + assert not np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + tm.assert_frame_equal(df, df_orig) + + +def test_set_index_mutating_parent_does_not_mutate_index(): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + result = df.set_index("a") + expected = result.copy() + + df.iloc[0, 0] = 100 + tm.assert_frame_equal(result, expected) + + +def test_add_prefix(): + # GH 49473 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.add_prefix("CoW_") + + assert np.shares_memory(get_array(df2, "CoW_a"), get_array(df, "a")) + df2.iloc[0, 0] = 0 + + assert not np.shares_memory(get_array(df2, "CoW_a"), get_array(df, "a")) + + assert np.shares_memory(get_array(df2, "CoW_c"), get_array(df, "c")) + expected = DataFrame( + {"CoW_a": [0, 2, 3], "CoW_b": [4, 5, 6], "CoW_c": [0.1, 0.2, 0.3]} + ) + tm.assert_frame_equal(df2, expected) + tm.assert_frame_equal(df, df_orig) + + +def test_add_suffix(): + # GH 49473 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.add_suffix("_CoW") + assert np.shares_memory(get_array(df2, "a_CoW"), get_array(df, "a")) + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a_CoW"), get_array(df, "a")) + assert np.shares_memory(get_array(df2, "c_CoW"), get_array(df, "c")) + expected = DataFrame( + {"a_CoW": [0, 2, 3], "b_CoW": [4, 5, 6], "c_CoW": [0.1, 0.2, 0.3]} + ) + tm.assert_frame_equal(df2, expected) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("axis, val", [(0, 5.5), (1, np.nan)]) +def test_dropna(axis, val): + df = DataFrame({"a": [1, 2, 3], "b": [4, val, 6], "c": "d"}) + df_orig = df.copy() + df2 = df.dropna(axis=axis) + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("val", [5, 5.5]) +def test_dropna_series(val): + ser = Series([1, val, 4]) + ser_orig = ser.copy() + ser2 = ser.dropna() + assert np.shares_memory(ser2.values, ser.values) + + ser2.iloc[0] = 0 + assert not np.shares_memory(ser2.values, ser.values) + tm.assert_series_equal(ser, ser_orig) + + +@pytest.mark.parametrize( + "method", + [ + lambda df: df.head(), + lambda df: df.head(2), + lambda df: df.tail(), + lambda df: df.tail(3), + ], +) +def test_head_tail(method): + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = method(df) + df2._mgr._verify_integrity() + + # We are explicitly deviating for CoW here to make an eager copy (avoids + # tracking references for very cheap ops) + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert not np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + + # modify df2 to trigger CoW for that block + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_infer_objects(using_infer_string): + df = DataFrame( + {"a": [1, 2], "b": Series(["x", "y"], dtype=object), "c": 1, "d": "x"} + ) + df_orig = df.copy() + df2 = df.infer_objects() + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + if using_infer_string and HAS_PYARROW: + assert not tm.shares_memory(get_array(df2, "b"), get_array(df, "b")) + else: + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + + df2.iloc[0, 0] = 0 + df2.iloc[0, 1] = "d" + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert not np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + tm.assert_frame_equal(df, df_orig) + + +def test_infer_objects_no_reference(using_infer_string): + df = DataFrame( + { + "a": [1, 2], + "b": Series(["x", "y"], dtype=object), + "c": 1, + "d": Series( + [Timestamp("2019-12-31"), Timestamp("2020-12-31")], dtype="object" + ), + "e": Series(["z", "w"], dtype=object), + } + ) + df = df.infer_objects() + + arr_a = get_array(df, "a") + arr_b = get_array(df, "b") + arr_d = get_array(df, "d") + + df.iloc[0, 0] = 0 + df.iloc[0, 1] = "d" + df.iloc[0, 3] = Timestamp("2018-12-31") + assert np.shares_memory(arr_a, get_array(df, "a")) + if using_infer_string and HAS_PYARROW: + # note that the underlying memory of arr_b has been copied anyway + # because of the assignment, but the EA is updated inplace so still + # appears the share memory + assert tm.shares_memory(arr_b, get_array(df, "b")) + else: + # TODO(CoW): Block splitting causes references here + assert not np.shares_memory(arr_b, get_array(df, "b")) + assert np.shares_memory(arr_d, get_array(df, "d")) + + +def test_infer_objects_reference(): + df = DataFrame( + { + "a": [1, 2], + "b": Series(["x", "y"], dtype=object), + "c": 1, + "d": Series( + [Timestamp("2019-12-31"), Timestamp("2020-12-31")], dtype="object" + ), + } + ) + view = df[:] # noqa: F841 + df = df.infer_objects() + + arr_a = get_array(df, "a") + arr_b = get_array(df, "b") + arr_d = get_array(df, "d") + + df.iloc[0, 0] = 0 + df.iloc[0, 1] = "d" + df.iloc[0, 3] = Timestamp("2018-12-31") + assert not np.shares_memory(arr_a, get_array(df, "a")) + assert not np.shares_memory(arr_b, get_array(df, "b")) + assert np.shares_memory(arr_d, get_array(df, "d")) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"before": "a", "after": "b", "axis": 1}, + {"before": 0, "after": 1, "axis": 0}, + ], +) +def test_truncate(kwargs): + df = DataFrame({"a": [1, 2, 3], "b": 1, "c": 2}) + df_orig = df.copy() + df2 = df.truncate(**kwargs) + df2._mgr._verify_integrity() + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("method", ["assign", "drop_duplicates"]) +def test_assign_drop_duplicates(method): + df = DataFrame({"a": [1, 2, 3]}) + df_orig = df.copy() + df2 = getattr(df, method)() + df2._mgr._verify_integrity() + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("obj", [Series([1, 2]), DataFrame({"a": [1, 2]})]) +def test_take(obj): + # Check that no copy is made when we take all rows in original order + obj_orig = obj.copy() + obj2 = obj.take([0, 1]) + assert np.shares_memory(obj2.values, obj.values) + + obj2.iloc[0] = 0 + assert not np.shares_memory(obj2.values, obj.values) + tm.assert_equal(obj, obj_orig) + + +@pytest.mark.parametrize("obj", [Series([1, 2]), DataFrame({"a": [1, 2]})]) +def test_between_time(obj): + obj.index = date_range("2018-04-09", periods=2, freq="1D20min") + obj_orig = obj.copy() + obj2 = obj.between_time("0:00", "1:00") + assert np.shares_memory(obj2.values, obj.values) + + obj2.iloc[0] = 0 + assert not np.shares_memory(obj2.values, obj.values) + tm.assert_equal(obj, obj_orig) + + +def test_reindex_like(): + df = DataFrame({"a": [1, 2], "b": "a"}) + other = DataFrame({"b": "a", "a": [1, 2]}) + + df_orig = df.copy() + df2 = df.reindex_like(other) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + df2.iloc[0, 1] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_sort_index(): + # GH 49473 + ser = Series([1, 2, 3]) + ser_orig = ser.copy() + ser2 = ser.sort_index() + assert np.shares_memory(ser.values, ser2.values) + + # mutating ser triggers a copy-on-write for the column / block + ser2.iloc[0] = 0 + assert not np.shares_memory(ser2.values, ser.values) + tm.assert_series_equal(ser, ser_orig) + + +@pytest.mark.parametrize( + "obj, kwargs", + [(Series([1, 2, 3], name="a"), {}), (DataFrame({"a": [1, 2, 3]}), {"by": "a"})], +) +def test_sort_values(obj, kwargs): + obj_orig = obj.copy() + obj2 = obj.sort_values(**kwargs) + assert np.shares_memory(get_array(obj2, "a"), get_array(obj, "a")) + + # mutating df triggers a copy-on-write for the column / block + obj2.iloc[0] = 0 + assert not np.shares_memory(get_array(obj2, "a"), get_array(obj, "a")) + tm.assert_equal(obj, obj_orig) + + +@pytest.mark.parametrize( + "obj, kwargs", + [(Series([1, 2, 3], name="a"), {}), (DataFrame({"a": [1, 2, 3]}), {"by": "a"})], +) +def test_sort_values_inplace(obj, kwargs): + obj_orig = obj.copy() + view = obj[:] + obj.sort_values(inplace=True, **kwargs) + + assert np.shares_memory(get_array(obj, "a"), get_array(view, "a")) + + # mutating obj triggers a copy-on-write for the column / block + obj.iloc[0] = 0 + assert not np.shares_memory(get_array(obj, "a"), get_array(view, "a")) + tm.assert_equal(view, obj_orig) + + +@pytest.mark.parametrize("decimals", [-1, 0, 1]) +def test_round(decimals): + df = DataFrame({"a": [1, 2], "b": "c"}) + df_orig = df.copy() + df2 = df.round(decimals=decimals) + + assert tm.shares_memory(get_array(df2, "b"), get_array(df, "b")) + # TODO: Make inplace by using out parameter of ndarray.round? + if decimals >= 0 and Version(np.__version__) < Version("2.4.0.dev0"): + # Ensure lazy copy if no-op + # TODO: Cannot rely on Numpy returning view after version 2.3 + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + else: + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert df2.index is not df.index + assert df2.columns is not df.columns + + df2.iloc[0, 1] = "d" + df2.iloc[0, 0] = 4 + assert not np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_reorder_levels(): + index = MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["one", "two"] + ) + df = DataFrame({"a": [1, 2, 3, 4]}, index=index) + df_orig = df.copy() + df2 = df.reorder_levels(order=["two", "one"]) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_series_reorder_levels(): + index = MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["one", "two"] + ) + ser = Series([1, 2, 3, 4], index=index) + ser_orig = ser.copy() + ser2 = ser.reorder_levels(order=["two", "one"]) + assert np.shares_memory(ser2.values, ser.values) + + ser2.iloc[0] = 0 + assert not np.shares_memory(ser2.values, ser.values) + tm.assert_series_equal(ser, ser_orig) + + +@pytest.mark.parametrize("obj", [Series([1, 2, 3]), DataFrame({"a": [1, 2, 3]})]) +def test_swaplevel(obj): + index = MultiIndex.from_tuples([(1, 1), (1, 2), (2, 1)], names=["one", "two"]) + obj.index = index + obj_orig = obj.copy() + obj2 = obj.swaplevel() + assert np.shares_memory(obj2.values, obj.values) + + obj2.iloc[0] = 0 + assert not np.shares_memory(obj2.values, obj.values) + tm.assert_equal(obj, obj_orig) + + +def test_frame_set_axis(): + # GH 49473 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + df2 = df.set_axis(["a", "b", "c"], axis="index") + + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + # mutating df2 triggers a copy-on-write for that column / block + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_series_set_axis(): + # GH 49473 + ser = Series([1, 2, 3]) + ser_orig = ser.copy() + ser2 = ser.set_axis(["a", "b", "c"], axis="index") + assert np.shares_memory(ser, ser2) + + # mutating ser triggers a copy-on-write for the column / block + ser2.iloc[0] = 0 + assert not np.shares_memory(ser2, ser) + tm.assert_series_equal(ser, ser_orig) + + +def test_set_flags(): + ser = Series([1, 2, 3]) + ser_orig = ser.copy() + ser2 = ser.set_flags(allows_duplicate_labels=False) + + assert np.shares_memory(ser, ser2) + + # mutating ser triggers a copy-on-write for the column / block + ser2.iloc[0] = 0 + assert not np.shares_memory(ser2, ser) + tm.assert_series_equal(ser, ser_orig) + + +@pytest.mark.parametrize("kwargs", [{"mapper": "test"}, {"index": "test"}]) +def test_rename_axis(kwargs): + df = DataFrame({"a": [1, 2, 3, 4]}, index=Index([1, 2, 3, 4], name="a")) + df_orig = df.copy() + df2 = df.rename_axis(**kwargs) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + df2.iloc[0, 0] = 0 + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize( + "func, tz", [("tz_convert", "Europe/Berlin"), ("tz_localize", None)] +) +def test_tz_convert_localize(func, tz): + # GH 49473 + ser = Series( + [1, 2], index=date_range(start="2014-08-01 09:00", freq="h", periods=2, tz=tz) + ) + ser_orig = ser.copy() + ser2 = getattr(ser, func)("US/Central") + assert np.shares_memory(ser.values, ser2.values) + + # mutating ser triggers a copy-on-write for the column / block + ser2.iloc[0] = 0 + assert not np.shares_memory(ser2.values, ser.values) + tm.assert_series_equal(ser, ser_orig) + + +def test_droplevel(): + # GH 49473 + index = MultiIndex.from_tuples([(1, 1), (1, 2), (2, 1)], names=["one", "two"]) + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}, index=index) + df_orig = df.copy() + df2 = df.droplevel(0) + + assert np.shares_memory(get_array(df2, "c"), get_array(df, "c")) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + # mutating df2 triggers a copy-on-write for that column / block + df2.iloc[0, 0] = 0 + + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert np.shares_memory(get_array(df2, "b"), get_array(df, "b")) + + tm.assert_frame_equal(df, df_orig) + + +def test_squeeze(): + df = DataFrame({"a": [1, 2, 3]}) + df_orig = df.copy() + series = df.squeeze() + + # Should share memory regardless of CoW since squeeze is just an iloc + assert np.shares_memory(series.values, get_array(df, "a")) + + # mutating squeezed df triggers a copy-on-write for that column/block + series.iloc[0] = 0 + assert not np.shares_memory(series.values, get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_items(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + df_orig = df.copy() + + # Test this twice, since the second time, the item cache will be + # triggered, and we want to make sure it still works then. + for i in range(2): + for name, ser in df.items(): + assert np.shares_memory(get_array(ser, name), get_array(df, name)) + + # mutating df triggers a copy-on-write for that column / block + ser.iloc[0] = 0 + + assert not np.shares_memory(get_array(ser, name), get_array(df, name)) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("dtype", ["int64", "Int64"]) +def test_putmask(dtype): + df = DataFrame({"a": [1, 2], "b": 1, "c": 2}, dtype=dtype) + view = df[:] + df_orig = df.copy() + df[df == df] = 5 + + assert not np.shares_memory(get_array(view, "a"), get_array(df, "a")) + tm.assert_frame_equal(view, df_orig) + + +@pytest.mark.parametrize("dtype", ["int64", "Int64"]) +def test_putmask_no_reference(dtype): + df = DataFrame({"a": [1, 2], "b": 1, "c": 2}, dtype=dtype) + arr_a = get_array(df, "a") + df[df == df] = 5 + assert np.shares_memory(arr_a, get_array(df, "a")) + + +@pytest.mark.parametrize("dtype", ["float64", "Float64"]) +def test_putmask_aligns_rhs_no_reference(dtype): + df = DataFrame({"a": [1.5, 2], "b": 1.5}, dtype=dtype) + arr_a = get_array(df, "a") + df[df == df] = DataFrame({"a": [5.5, 5]}) + assert np.shares_memory(arr_a, get_array(df, "a")) + + +@pytest.mark.parametrize("val, exp, raises", [(5.5, True, True), (5, False, False)]) +def test_putmask_dont_copy_some_blocks(val, exp, raises: bool): + df = DataFrame({"a": [1, 2], "b": 1, "c": 1.5}) + view = df[:] + df_orig = df.copy() + indexer = DataFrame( + [[True, False, False], [True, False, False]], columns=list("abc") + ) + if raises: + with pytest.raises(TypeError, match="Invalid value"): + df[indexer] = val + else: + df[indexer] = val + assert not np.shares_memory(get_array(view, "a"), get_array(df, "a")) + # TODO(CoW): Could split blocks to avoid copying the whole block + assert np.shares_memory(get_array(view, "b"), get_array(df, "b")) is exp + assert np.shares_memory(get_array(view, "c"), get_array(df, "c")) + assert df._mgr._has_no_reference(1) is not exp + assert not df._mgr._has_no_reference(2) + tm.assert_frame_equal(view, df_orig) + + +@pytest.mark.parametrize("dtype", ["int64", "Int64"]) +@pytest.mark.parametrize( + "func", + [ + lambda ser: ser.where(ser > 0, 10), + lambda ser: ser.mask(ser <= 0, 10), + ], +) +def test_where_mask_noop(dtype, func): + ser = Series([1, 2, 3], dtype=dtype) + ser_orig = ser.copy() + + result = func(ser) + assert np.shares_memory(get_array(ser), get_array(result)) + assert result.index is not ser.index + + result.iloc[0] = 10 + assert not np.shares_memory(get_array(ser), get_array(result)) + tm.assert_series_equal(ser, ser_orig) + + +@pytest.mark.parametrize("dtype", ["int64", "Int64"]) +@pytest.mark.parametrize( + "func", + [ + lambda ser: ser.where(ser < 0, 10), + lambda ser: ser.mask(ser >= 0, 10), + ], +) +def test_where_mask(dtype, func): + ser = Series([1, 2, 3], dtype=dtype) + ser_orig = ser.copy() + + result = func(ser) + + assert not np.shares_memory(get_array(ser), get_array(result)) + assert result.index is not ser.index + tm.assert_series_equal(ser, ser_orig) + + +@pytest.mark.parametrize("dtype, val", [("int64", 10.5), ("Int64", 10)]) +@pytest.mark.parametrize( + "func", + [ + lambda df, val: df.where(df < 0, val), + lambda df, val: df.mask(df >= 0, val), + ], +) +def test_where_mask_noop_on_single_column(dtype, val, func): + df = DataFrame({"a": [1, 2, 3], "b": [-4, -5, -6]}, dtype=dtype) + df_orig = df.copy() + + result = func(df, val) + assert np.shares_memory(get_array(df, "b"), get_array(result, "b")) + assert not np.shares_memory(get_array(df, "a"), get_array(result, "a")) + + result.iloc[0, 1] = 10 + assert not np.shares_memory(get_array(df, "b"), get_array(result, "b")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("func", ["mask", "where"]) +def test_chained_where_mask(func): + df = DataFrame({"a": [1, 4, 2], "b": 1}) + df_orig = df.copy() + with tm.raises_chained_assignment_error(): + getattr(df["a"], func)(df["a"] > 2, 5, inplace=True) + tm.assert_frame_equal(df, df_orig) + + with tm.raises_chained_assignment_error(): + getattr(df[["a"]], func)(df["a"] > 2, 5, inplace=True) + tm.assert_frame_equal(df, df_orig) + + +def test_asfreq_noop(): + df = DataFrame( + {"a": [0.0, None, 2.0, 3.0]}, + index=date_range("1/1/2000", periods=4, freq="min"), + ) + df_orig = df.copy() + df2 = df.asfreq(freq="min") + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + # mutating df2 triggers a copy-on-write for that column / block + df2.iloc[0, 0] = 0 + + assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_iterrows(): + df = DataFrame({"a": 0, "b": 1}, index=[1, 2, 3]) + df_orig = df.copy() + + for _, sub in df.iterrows(): + sub.iloc[0] = 100 + tm.assert_frame_equal(df, df_orig) + + +def test_interpolate_creates_copy(): + # GH#51126 + df = DataFrame({"a": [1.5, np.nan, 3]}) + view = df[:] + expected = df.copy() + + df.ffill(inplace=True) + df.iloc[0, 0] = 100.5 + tm.assert_frame_equal(view, expected) + + +def test_isetitem(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + df_orig = df.copy() + df2 = df.copy(deep=False) # Trigger a CoW + df2.isetitem(1, np.array([-1, -2, -3])) # This is inplace + assert np.shares_memory(get_array(df, "c"), get_array(df2, "c")) + assert np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + df2.loc[0, "a"] = 0 + tm.assert_frame_equal(df, df_orig) # Original is unchanged + assert np.shares_memory(get_array(df, "c"), get_array(df2, "c")) + + +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +def test_isetitem_series(dtype): + df = DataFrame({"a": [1, 2, 3], "b": np.array([4, 5, 6], dtype=dtype)}) + ser = Series([7, 8, 9]) + ser_orig = ser.copy() + df.isetitem(0, ser) + + assert np.shares_memory(get_array(df, "a"), get_array(ser)) + assert not df._mgr._has_no_reference(0) + + # mutating dataframe doesn't update series + df.loc[0, "a"] = 0 + tm.assert_series_equal(ser, ser_orig) + + # mutating series doesn't update dataframe + df = DataFrame({"a": [1, 2, 3], "b": np.array([4, 5, 6], dtype=dtype)}) + ser = Series([7, 8, 9]) + df.isetitem(0, ser) + + ser.loc[0] = 0 + expected = DataFrame({"a": [7, 8, 9], "b": np.array([4, 5, 6], dtype=dtype)}) + tm.assert_frame_equal(df, expected) + + +def test_isetitem_frame(): + df = DataFrame({"a": [1, 2, 3], "b": 1, "c": 2}) + rhs = DataFrame({"a": [4, 5, 6], "b": 2}) + df.isetitem([0, 1], rhs) + assert np.shares_memory(get_array(df, "a"), get_array(rhs, "a")) + assert np.shares_memory(get_array(df, "b"), get_array(rhs, "b")) + assert not df._mgr._has_no_reference(0) + expected = df.copy() + rhs.iloc[0, 0] = 100 + rhs.iloc[0, 1] = 100 + tm.assert_frame_equal(df, expected) + + +@pytest.mark.parametrize("key", ["a", ["a"]]) +def test_get(key): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df_orig = df.copy() + + result = df.get(key) + + assert np.shares_memory(get_array(result, "a"), get_array(df, "a")) + result.iloc[0] = 0 + assert not np.shares_memory(get_array(result, "a"), get_array(df, "a")) + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("axis, key", [(0, 0), (1, "a")]) +@pytest.mark.parametrize( + "dtype", ["int64", "float64"], ids=["single-block", "mixed-block"] +) +def test_xs(axis, key, dtype): + single_block = dtype == "int64" + df = DataFrame( + {"a": [1, 2, 3], "b": [4, 5, 6], "c": np.array([7, 8, 9], dtype=dtype)} + ) + df_orig = df.copy() + + result = df.xs(key, axis=axis) + + if axis == 1 or single_block: + assert np.shares_memory(get_array(df, "a"), get_array(result)) + else: + assert result._mgr._has_no_reference(0) + if axis == 0: + assert result.index is not df.columns + else: + assert result.index is not df.index + + result.iloc[0] = 0 + tm.assert_frame_equal(df, df_orig) + + +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.parametrize("key, level", [("l1", 0), (2, 1)]) +def test_xs_multiindex(key, level, axis): + arr = np.arange(18).reshape(6, 3) + index = MultiIndex.from_product([["l1", "l2"], [1, 2, 3]], names=["lev1", "lev2"]) + df = DataFrame(arr, index=index, columns=list("abc")) + if axis == 1: + df = df.transpose().copy() + df_orig = df.copy() + + result = df.xs(key, level=level, axis=axis) + + if level == 0: + assert np.shares_memory( + get_array(df, df.columns[0]), get_array(result, result.columns[0]) + ) + assert result.index is not df.index + assert result.columns is not df.columns + + result.iloc[0, 0] = 0 + tm.assert_frame_equal(df, df_orig) + + +def test_update_frame(): + df1 = DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 5.0, 6.0]}) + df2 = DataFrame({"b": [100.0]}, index=[1]) + df1_orig = df1.copy() + view = df1[:] + df1.update(df2) + + expected = DataFrame({"a": [1.0, 2.0, 3.0], "b": [4.0, 100.0, 6.0]}) + tm.assert_frame_equal(df1, expected) + # df1 is updated, but its view not + tm.assert_frame_equal(view, df1_orig) + assert np.shares_memory(get_array(df1, "a"), get_array(view, "a")) + assert not np.shares_memory(get_array(df1, "b"), get_array(view, "b")) + + +def test_update_series(): + ser1 = Series([1.0, 2.0, 3.0]) + ser2 = Series([100.0], index=[1]) + ser1_orig = ser1.copy() + view = ser1[:] + + ser1.update(ser2) + + expected = Series([1.0, 100.0, 3.0]) + tm.assert_series_equal(ser1, expected) + # ser1 is updated, but its view not + tm.assert_series_equal(view, ser1_orig) + + +def test_update_chained_assignment(): + df = DataFrame({"a": [1, 2, 3]}) + ser2 = Series([100.0], index=[1]) + df_orig = df.copy() + with tm.raises_chained_assignment_error(): + df["a"].update(ser2) + tm.assert_frame_equal(df, df_orig) + + with tm.raises_chained_assignment_error(): + df[["a"]].update(ser2.to_frame()) + tm.assert_frame_equal(df, df_orig) + + +def test_inplace_arithmetic_series(): + ser = Series([1, 2, 3]) + ser_orig = ser.copy() + data = get_array(ser) + ser *= 2 + # https://github.com/pandas-dev/pandas/pull/55745 + # changed to NOT update inplace because there is no benefit (actual + # operation already done non-inplace). This was only for the optics + # of updating the backing array inplace, but we no longer want to make + # that guarantee + assert not np.shares_memory(get_array(ser), data) + tm.assert_numpy_array_equal(data, get_array(ser_orig)) + + +def test_inplace_arithmetic_series_with_reference(): + ser = Series([1, 2, 3]) + ser_orig = ser.copy() + view = ser[:] + ser *= 2 + assert not np.shares_memory(get_array(ser), get_array(view)) + tm.assert_series_equal(ser_orig, view) + + +def test_transpose(): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + df_orig = df.copy() + result = df.transpose() + assert np.shares_memory(get_array(df, "a"), get_array(result, 0)) + + result.iloc[0, 0] = 100 + tm.assert_frame_equal(df, df_orig) + + +def test_transpose_different_dtypes(): + df = DataFrame({"a": [1, 2, 3], "b": 1.5}) + df_orig = df.copy() + result = df.T + + assert not np.shares_memory(get_array(df, "a"), get_array(result, 0)) + result.iloc[0, 0] = 100 + tm.assert_frame_equal(df, df_orig) + + +def test_transpose_ea_single_column(): + df = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + result = df.T + + assert not np.shares_memory(get_array(df, "a"), get_array(result, 0)) + + +def test_transform_frame(): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + df_orig = df.copy() + + def func(ser): + ser.iloc[0] = 100 + return ser + + df.transform(func) + tm.assert_frame_equal(df, df_orig) + + +def test_transform_series(): + ser = Series([1, 2, 3]) + ser_orig = ser.copy() + + def func(ser): + ser.iloc[0] = 100 + return ser + + ser.transform(func) + tm.assert_series_equal(ser, ser_orig) + + +def test_count_read_only_array(): + df = DataFrame({"a": [1, 2], "b": 3}) + result = df.count() + result.iloc[0] = 100 + expected = Series([100, 2], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + +def test_insert_series(): + df = DataFrame({"a": [1, 2, 3]}) + ser = Series([1, 2, 3]) + ser_orig = ser.copy() + df.insert(loc=1, value=ser, column="b") + assert np.shares_memory(get_array(ser), get_array(df, "b")) + assert not df._mgr._has_no_reference(1) + + df.iloc[0, 1] = 100 + tm.assert_series_equal(ser, ser_orig) + + +def test_eval(): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + df_orig = df.copy() + + result = df.eval("c = a+b") + assert np.shares_memory(get_array(df, "a"), get_array(result, "a")) + + result.iloc[0, 0] = 100 + tm.assert_frame_equal(df, df_orig) + + +def test_eval_inplace(): + df = DataFrame({"a": [1, 2, 3], "b": 1}) + df_orig = df.copy() + df_view = df[:] + + df.eval("c = a+b", inplace=True) + assert np.shares_memory(get_array(df, "a"), get_array(df_view, "a")) + + df.iloc[0, 0] = 100 + tm.assert_frame_equal(df_view, df_orig) + + +def test_apply_modify_row(): + # Case: applying a function on each row as a Series object, where the + # function mutates the row object (which needs to trigger CoW if row is a view) + df = DataFrame({"A": [1, 2], "B": [3, 4]}) + df_orig = df.copy() + + def transform(row): + row["B"] = 100 + return row + + df.apply(transform, axis=1) + + tm.assert_frame_equal(df, df_orig) + + # row Series is a copy + df = DataFrame({"A": [1, 2], "B": ["b", "c"]}) + df_orig = df.copy() + + with tm.assert_produces_warning(None): + df.apply(transform, axis=1) + + tm.assert_frame_equal(df, df_orig) + + +def test_reduce(): + df = DataFrame({"a": [1, 2, 3], "b": 1.5}) + + result = df.sum() + assert result.index is not df.columns + + result = df.groupby([0, 0, 1]).sum() + assert result.columns is not df.columns + + result = df.quantile(0.5) + assert result.index is not df.columns + result = df.quantile([0.25, 0.5, 0.75]) + assert result.columns is not df.columns + + +def test_diff(): + df = DataFrame({"a": [1, 2, 3], "b": 1.5}) + + result = df.diff() + assert result.index is not df.index + assert result.columns is not df.columns + + ser = Series([1, 2, 3]) + result = ser.diff() + assert result.index is not ser.index diff --git a/pandas/tests/copy_view/test_replace.py b/pandas/tests/copy_view/test_replace.py new file mode 100644 index 0000000000000000000000000000000000000000..d4838a5e68ab8328d6263289bc73309424edf458 --- /dev/null +++ b/pandas/tests/copy_view/test_replace.py @@ -0,0 +1,356 @@ +import numpy as np +import pytest + +from pandas import ( + Categorical, + DataFrame, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + + +@pytest.mark.parametrize( + "replace_kwargs", + [ + {"to_replace": {"a": 1, "b": 4}, "value": -1}, + # Test CoW splits blocks to avoid copying unchanged columns + {"to_replace": {"a": 1}, "value": -1}, + {"to_replace": {"b": 4}, "value": -1}, + {"to_replace": {"b": {4: 1}}}, + # TODO: Add these in a further optimization + # We would need to see which columns got replaced in the mask + # which could be expensive + # {"to_replace": {"b": 1}}, + # 1 + ], +) +def test_replace(replace_kwargs): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}) + df_orig = df.copy() + + df_replaced = df.replace(**replace_kwargs) + + if (df_replaced["b"] == df["b"]).all(): + assert np.shares_memory(get_array(df_replaced, "b"), get_array(df, "b")) + assert tm.shares_memory(get_array(df_replaced, "c"), get_array(df, "c")) + + # mutating squeezed df triggers a copy-on-write for that column/block + df_replaced.loc[0, "c"] = -1 + assert not np.shares_memory(get_array(df_replaced, "c"), get_array(df, "c")) + + if "a" in replace_kwargs["to_replace"]: + arr = get_array(df_replaced, "a") + df_replaced.loc[0, "a"] = 100 + assert np.shares_memory(get_array(df_replaced, "a"), arr) + tm.assert_frame_equal(df, df_orig) + + +def test_replace_regex_inplace_refs(): + df = DataFrame({"a": ["aaa", "bbb"]}) + df_orig = df.copy() + view = df[:] + arr = get_array(df, "a") + df.replace(to_replace=r"^a.*$", value="new", inplace=True, regex=True) + assert not np.shares_memory(arr, get_array(df, "a")) + assert df._mgr._has_no_reference(0) + tm.assert_frame_equal(view, df_orig) + + +def test_replace_regex_inplace(): + df = DataFrame({"a": ["aaa", "bbb"]}) + arr = get_array(df, "a") + df.replace(to_replace=r"^a.*$", value="new", inplace=True, regex=True) + assert df._mgr._has_no_reference(0) + assert tm.shares_memory(arr, get_array(df, "a")) + + df_orig = df.copy() + df2 = df.replace(to_replace=r"^b.*$", value="new", regex=True) + tm.assert_frame_equal(df_orig, df) + assert not tm.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + +def test_replace_regex_inplace_no_op(): + df = DataFrame({"a": [1, 2]}) + arr = get_array(df, "a") + df.replace(to_replace=r"^a.$", value="new", inplace=True, regex=True) + assert df._mgr._has_no_reference(0) + assert np.shares_memory(arr, get_array(df, "a")) + + df_orig = df.copy() + df2 = df.replace(to_replace=r"^x.$", value="new", regex=True) + tm.assert_frame_equal(df_orig, df) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + + +def test_replace_mask_all_false_second_block(): + df = DataFrame({"a": [1.5, 2, 3], "b": 100.5, "c": 1, "d": 2}) + df_orig = df.copy() + + df2 = df.replace(to_replace=1.5, value=55.5) + + # TODO: Block splitting would allow us to avoid copying b + assert np.shares_memory(get_array(df, "c"), get_array(df2, "c")) + assert not np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + df2.loc[0, "c"] = 1 + tm.assert_frame_equal(df, df_orig) # Original is unchanged + + assert not np.shares_memory(get_array(df, "c"), get_array(df2, "c")) + assert np.shares_memory(get_array(df, "d"), get_array(df2, "d")) + + +def test_replace_coerce_single_column(): + df = DataFrame({"a": [1.5, 2, 3], "b": 100.5}) + df_orig = df.copy() + + df2 = df.replace(to_replace=1.5, value="a") + assert np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + assert not np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + df2.loc[0, "b"] = 0.5 + tm.assert_frame_equal(df, df_orig) # Original is unchanged + assert not np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + + +def test_replace_to_replace_wrong_dtype(): + df = DataFrame({"a": [1.5, 2, 3], "b": 100.5}) + df_orig = df.copy() + + df2 = df.replace(to_replace="xxx", value=1.5) + + assert np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + assert np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + df2.loc[0, "b"] = 0.5 + tm.assert_frame_equal(df, df_orig) # Original is unchanged + assert not np.shares_memory(get_array(df, "b"), get_array(df2, "b")) + + +def test_replace_list_categorical(): + df = DataFrame({"a": ["a", "b", "c"]}, dtype="category") + arr = get_array(df, "a") + + df.replace(["c"], value="a", inplace=True) + assert np.shares_memory(arr.codes, get_array(df, "a").codes) + assert df._mgr._has_no_reference(0) + + df_orig = df.copy() + df.replace(["b"], value="a") + df2 = df.apply(lambda x: x.cat.rename_categories({"b": "d"})) + assert not np.shares_memory(arr.codes, get_array(df2, "a").codes) + + tm.assert_frame_equal(df, df_orig) + + +def test_replace_list_inplace_refs_categorical(): + df = DataFrame({"a": ["a", "b", "c"]}, dtype="category") + view = df[:] + df_orig = df.copy() + df.replace(["c"], value="a", inplace=True) + tm.assert_frame_equal(df_orig, view) + + +@pytest.mark.parametrize("to_replace", [1.5, [1.5], []]) +def test_replace_inplace(to_replace): + df = DataFrame({"a": [1.5, 2, 3]}) + arr_a = get_array(df, "a") + df.replace(to_replace=1.5, value=15.5, inplace=True) + + assert np.shares_memory(get_array(df, "a"), arr_a) + assert df._mgr._has_no_reference(0) + + +@pytest.mark.parametrize("to_replace", [1.5, [1.5]]) +def test_replace_inplace_reference(to_replace): + df = DataFrame({"a": [1.5, 2, 3]}) + arr_a = get_array(df, "a") + view = df[:] + df.replace(to_replace=to_replace, value=15.5, inplace=True) + + assert not np.shares_memory(get_array(df, "a"), arr_a) + assert df._mgr._has_no_reference(0) + assert view._mgr._has_no_reference(0) + + +@pytest.mark.parametrize("to_replace", ["a", 100.5]) +def test_replace_inplace_reference_no_op(to_replace): + df = DataFrame({"a": [1.5, 2, 3]}) + arr_a = get_array(df, "a") + view = df[:] + df.replace(to_replace=to_replace, value=15.5, inplace=True) + + assert np.shares_memory(get_array(df, "a"), arr_a) + assert not df._mgr._has_no_reference(0) + assert not view._mgr._has_no_reference(0) + + +@pytest.mark.parametrize("to_replace", [1, [1]]) +def test_replace_categorical_inplace_reference(to_replace): + df = DataFrame({"a": Categorical([1, 2, 3])}) + df_orig = df.copy() + arr_a = get_array(df, "a") + view = df[:] + df.replace(to_replace=to_replace, value=1, inplace=True) + assert not np.shares_memory(get_array(df, "a").codes, arr_a.codes) + assert df._mgr._has_no_reference(0) + assert view._mgr._has_no_reference(0) + tm.assert_frame_equal(view, df_orig) + + +def test_replace_categorical_inplace(): + df = DataFrame({"a": Categorical([1, 2, 3])}) + arr_a = get_array(df, "a") + df.replace(to_replace=1, value=1, inplace=True) + + assert np.shares_memory(get_array(df, "a").codes, arr_a.codes) + assert df._mgr._has_no_reference(0) + + expected = DataFrame({"a": Categorical([1, 2, 3])}) + tm.assert_frame_equal(df, expected) + + +def test_replace_categorical(): + df = DataFrame({"a": Categorical([1, 2, 3])}) + df_orig = df.copy() + df2 = df.replace(to_replace=1, value=1) + + assert df._mgr._has_no_reference(0) + assert df2._mgr._has_no_reference(0) + assert not np.shares_memory(get_array(df, "a").codes, get_array(df2, "a").codes) + tm.assert_frame_equal(df, df_orig) + + arr_a = get_array(df2, "a").codes + df2.iloc[0, 0] = 2.0 + assert np.shares_memory(get_array(df2, "a").codes, arr_a) + + +@pytest.mark.parametrize("method", ["where", "mask"]) +def test_masking_inplace(method): + df = DataFrame({"a": [1.5, 2, 3]}) + df_orig = df.copy() + arr_a = get_array(df, "a") + view = df[:] + + method = getattr(df, method) + method(df["a"] > 1.6, -1, inplace=True) + + assert not np.shares_memory(get_array(df, "a"), arr_a) + assert df._mgr._has_no_reference(0) + assert view._mgr._has_no_reference(0) + tm.assert_frame_equal(view, df_orig) + + +def test_replace_empty_list(): + df = DataFrame({"a": [1, 2]}) + + df2 = df.replace([], []) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + assert not df._mgr._has_no_reference(0) + arr_a = get_array(df, "a") + df.replace([], []) + assert np.shares_memory(get_array(df, "a"), arr_a) + assert not df._mgr._has_no_reference(0) + assert not df2._mgr._has_no_reference(0) + + +@pytest.mark.parametrize("value", ["d", None]) +def test_replace_object_list_inplace(value): + df = DataFrame({"a": ["a", "b", "c"]}, dtype=object) + arr = get_array(df, "a") + df.replace(["c"], value, inplace=True) + assert np.shares_memory(arr, get_array(df, "a")) + assert df._mgr._has_no_reference(0) + + +def test_replace_list_multiple_elements_inplace(): + df = DataFrame({"a": [1, 2, 3]}) + arr = get_array(df, "a") + df.replace([1, 2], 4, inplace=True) + assert np.shares_memory(arr, get_array(df, "a")) + assert df._mgr._has_no_reference(0) + + +def test_replace_list_none(): + df = DataFrame({"a": ["a", "b", "c"]}) + + df_orig = df.copy() + df2 = df.replace(["b"], value=None) + tm.assert_frame_equal(df, df_orig) + + assert not np.shares_memory(get_array(df, "a"), get_array(df2, "a")) + + # replace multiple values that don't actually replace anything with None + # https://github.com/pandas-dev/pandas/issues/59770 + df3 = df.replace(["d", "e", "f"], value=None) + tm.assert_frame_equal(df3, df_orig) + assert tm.shares_memory(get_array(df, "a"), get_array(df3, "a")) + + +def test_replace_list_none_inplace_refs(): + df = DataFrame({"a": ["a", "b", "c"]}) + arr = get_array(df, "a") + df_orig = df.copy() + view = df[:] + df.replace(["a"], value=None, inplace=True) + assert df._mgr._has_no_reference(0) + assert not np.shares_memory(arr, get_array(df, "a")) + tm.assert_frame_equal(df_orig, view) + + +def test_replace_columnwise_no_op_inplace(): + df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + view = df[:] + df_orig = df.copy() + df.replace({"a": 10}, 100, inplace=True) + assert np.shares_memory(get_array(view, "a"), get_array(df, "a")) + df.iloc[0, 0] = 100 + tm.assert_frame_equal(view, df_orig) + + +def test_replace_columnwise_no_op(): + df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + df_orig = df.copy() + df2 = df.replace({"a": 10}, 100) + assert np.shares_memory(get_array(df2, "a"), get_array(df, "a")) + df2.iloc[0, 0] = 100 + tm.assert_frame_equal(df, df_orig) + + +def test_replace_chained_assignment(): + df = DataFrame({"a": [1, np.nan, 2], "b": 1}) + df_orig = df.copy() + with tm.raises_chained_assignment_error(): + df["a"].replace(1, 100, inplace=True) + tm.assert_frame_equal(df, df_orig) + + with tm.raises_chained_assignment_error(): + df[["a"]].replace(1, 100, inplace=True) + tm.assert_frame_equal(df, df_orig) + + +def test_replace_listlike(): + df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + df_orig = df.copy() + + result = df.replace([200, 201], [11, 11]) + assert np.shares_memory(get_array(result, "a"), get_array(df, "a")) + + result.iloc[0, 0] = 100 + tm.assert_frame_equal(df, df) + + result = df.replace([200, 2], [10, 10]) + assert not np.shares_memory(get_array(df, "a"), get_array(result, "a")) + tm.assert_frame_equal(df, df_orig) + + +def test_replace_listlike_inplace(): + df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + arr = get_array(df, "a") + df.replace([200, 2], [10, 11], inplace=True) + assert np.shares_memory(get_array(df, "a"), arr) + + view = df[:] + df_orig = df.copy() + df.replace([200, 3], [10, 11], inplace=True) + assert not np.shares_memory(get_array(df, "a"), arr) + tm.assert_frame_equal(view, df_orig) diff --git a/pandas/tests/copy_view/test_setitem.py b/pandas/tests/copy_view/test_setitem.py new file mode 100644 index 0000000000000000000000000000000000000000..2f28e9826c7a1bb2b5379e005f0bd2fd57ef4067 --- /dev/null +++ b/pandas/tests/copy_view/test_setitem.py @@ -0,0 +1,142 @@ +import numpy as np + +from pandas import ( + DataFrame, + Index, + MultiIndex, + RangeIndex, + Series, +) +import pandas._testing as tm +from pandas.tests.copy_view.util import get_array + +# ----------------------------------------------------------------------------- +# Copy/view behaviour for the values that are set in a DataFrame + + +def test_set_column_with_array(): + # Case: setting an array as a new column (df[col] = arr) copies that data + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + arr = np.array([1, 2, 3], dtype="int64") + + df["c"] = arr + + # the array data is copied + assert not np.shares_memory(get_array(df, "c"), arr) + # and thus modifying the array does not modify the DataFrame + arr[0] = 0 + tm.assert_series_equal(df["c"], Series([1, 2, 3], name="c")) + + +def test_set_column_with_series(): + # Case: setting a series as a new column (df[col] = s) copies that data + # (with delayed copy with CoW) + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + ser = Series([1, 2, 3]) + + df["c"] = ser + + assert np.shares_memory(get_array(df, "c"), get_array(ser)) + + # and modifying the series does not modify the DataFrame + ser.iloc[0] = 0 + assert ser.iloc[0] == 0 + tm.assert_series_equal(df["c"], Series([1, 2, 3], name="c")) + + +def test_set_column_with_index(): + # Case: setting an index as a new column (df[col] = idx) copies that data + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + idx = Index([1, 2, 3]) + + df["c"] = idx + + # the index data is copied + assert not np.shares_memory(get_array(df, "c"), idx.values) + + idx = RangeIndex(1, 4) + arr = idx.values + + df["d"] = idx + + assert not np.shares_memory(get_array(df, "d"), arr) + + +def test_set_columns_with_dataframe(): + # Case: setting a DataFrame as new columns copies that data + # (with delayed copy with CoW) + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = DataFrame({"c": [7, 8, 9], "d": [10, 11, 12]}) + + df[["c", "d"]] = df2 + + assert np.shares_memory(get_array(df, "c"), get_array(df2, "c")) + # and modifying the set DataFrame does not modify the original DataFrame + df2.iloc[0, 0] = 0 + tm.assert_series_equal(df["c"], Series([7, 8, 9], name="c")) + + +def test_setitem_series_no_copy(): + # Case: setting a Series as column into a DataFrame can delay copying that data + df = DataFrame({"a": [1, 2, 3]}) + rhs = Series([4, 5, 6]) + rhs_orig = rhs.copy() + + # adding a new column + df["b"] = rhs + assert np.shares_memory(get_array(rhs), get_array(df, "b")) + + df.iloc[0, 1] = 100 + tm.assert_series_equal(rhs, rhs_orig) + + +def test_setitem_series_no_copy_single_block(): + # Overwriting an existing column that is a single block + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]}) + rhs = Series([4, 5, 6]) + rhs_orig = rhs.copy() + + df["a"] = rhs + assert np.shares_memory(get_array(rhs), get_array(df, "a")) + + df.iloc[0, 0] = 100 + tm.assert_series_equal(rhs, rhs_orig) + + +def test_setitem_series_no_copy_split_block(): + # Overwriting an existing column that is part of a larger block + df = DataFrame({"a": [1, 2, 3], "b": 1}) + rhs = Series([4, 5, 6]) + rhs_orig = rhs.copy() + + df["b"] = rhs + assert np.shares_memory(get_array(rhs), get_array(df, "b")) + + df.iloc[0, 1] = 100 + tm.assert_series_equal(rhs, rhs_orig) + + +def test_setitem_series_column_midx_broadcasting(): + # Setting a Series to multiple columns will repeat the data + # (currently copying the data eagerly) + df = DataFrame( + [[1, 2, 3], [3, 4, 5]], + columns=MultiIndex.from_arrays([["a", "a", "b"], [1, 2, 3]]), + ) + rhs = Series([10, 11]) + df["a"] = rhs + assert not np.shares_memory(get_array(rhs), df._get_column_array(0)) + assert df._mgr._has_no_reference(0) + + +def test_set_column_with_inplace_operator(): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + # this should not raise any warning + with tm.assert_produces_warning(None): + df["a"] += 1 + + # when it is not in a chain, then it should produce a warning + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + ser = df["a"] + ser += 1 diff --git a/pandas/tests/copy_view/test_util.py b/pandas/tests/copy_view/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..ff55330d70b28c5459a4c0915dd93c8640a91add --- /dev/null +++ b/pandas/tests/copy_view/test_util.py @@ -0,0 +1,14 @@ +import numpy as np + +from pandas import DataFrame +from pandas.tests.copy_view.util import get_array + + +def test_get_array_numpy(): + df = DataFrame({"a": [1, 2, 3]}) + assert np.shares_memory(get_array(df, "a"), get_array(df, "a")) + + +def test_get_array_masked(): + df = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + assert np.shares_memory(get_array(df, "a"), get_array(df, "a")) diff --git a/pandas/tests/copy_view/util.py b/pandas/tests/copy_view/util.py new file mode 100644 index 0000000000000000000000000000000000000000..969334424936559767b0bca87093acfec52f9763 --- /dev/null +++ b/pandas/tests/copy_view/util.py @@ -0,0 +1,30 @@ +from pandas import ( + Categorical, + Index, + Series, +) +from pandas.core.arrays import BaseMaskedArray + + +def get_array(obj, col=None): + """ + Helper method to get array for a DataFrame column or a Series. + + Equivalent of df[col].values, but without going through normal getitem, + which triggers tracking references / CoW (and we might be testing that + this is done by some other operation). + """ + if isinstance(obj, Index): + arr = obj._values + elif isinstance(obj, Series) and (col is None or obj.name == col): + arr = obj._values + else: + assert col is not None + icol = obj.columns.get_loc(col) + assert isinstance(icol, int) + arr = obj._get_column_array(icol) + if isinstance(arr, BaseMaskedArray): + return arr._data + elif isinstance(arr, Categorical): + return arr + return getattr(arr, "_ndarray", arr) diff --git a/pandas/tests/dtypes/__init__.py b/pandas/tests/dtypes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/dtypes/test_common.py b/pandas/tests/dtypes/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d8195321140a0bc9e43473b561c5f2d09a1973 --- /dev/null +++ b/pandas/tests/dtypes/test_common.py @@ -0,0 +1,882 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from pandas.compat import HAS_PYARROW +from pandas.errors import Pandas4Warning +import pandas.util._test_decorators as td + +from pandas.core.dtypes.astype import astype_array +import pandas.core.dtypes.common as com +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + CategoricalDtypeType, + DatetimeTZDtype, + ExtensionDtype, + IntervalDtype, + PeriodDtype, +) +from pandas.core.dtypes.missing import isna + +import pandas as pd +import pandas._testing as tm +from pandas.api.types import pandas_dtype +from pandas.arrays import SparseArray +from pandas.util.version import Version + + +# EA & Actual Dtypes +def to_ea_dtypes(dtypes): + """convert list of string dtypes to EA dtype""" + return [getattr(pd, dt + "Dtype") for dt in dtypes] + + +def to_numpy_dtypes(dtypes): + """convert list of string dtypes to numpy dtype""" + return [getattr(np, dt) for dt in dtypes if isinstance(dt, str)] + + +class TestNumpyEADtype: + # Passing invalid dtype, both as a string or object, must raise TypeError + # Per issue GH15520 + @pytest.mark.parametrize("box", [pd.Timestamp, "pd.Timestamp", list]) + def test_invalid_dtype_error(self, box): + with pytest.raises(TypeError, match="not understood"): + com.pandas_dtype(box) + + @pytest.mark.parametrize( + "dtype", + [ + object, + "float64", + np.object_, + np.dtype("object"), + "O", + np.float64, + float, + np.dtype("float64"), + "object_", + ], + ) + def test_pandas_dtype_valid(self, dtype): + assert com.pandas_dtype(dtype) == dtype + + @pytest.mark.parametrize( + "dtype", ["M8[ns]", "m8[ns]", "object", "float64", "int64"] + ) + def test_numpy_dtype(self, dtype): + assert com.pandas_dtype(dtype) == np.dtype(dtype) + + def test_numpy_string_dtype(self): + # do not parse freq-like string as period dtype + assert com.pandas_dtype("U") == np.dtype("U") + assert com.pandas_dtype("S") == np.dtype("S") + + @pytest.mark.parametrize( + "dtype", + [ + "datetime64[ns, US/Eastern]", + "datetime64[ns, Asia/Tokyo]", + "datetime64[ns, UTC]", + # GH#33885 check that the M8 alias is understood + "M8[ns, US/Eastern]", + "M8[ns, Asia/Tokyo]", + "M8[ns, UTC]", + ], + ) + def test_datetimetz_dtype(self, dtype): + assert com.pandas_dtype(dtype) == DatetimeTZDtype.construct_from_string(dtype) + assert com.pandas_dtype(dtype) == dtype + + def test_categorical_dtype(self): + assert com.pandas_dtype("category") == CategoricalDtype() + + @pytest.mark.parametrize( + "dtype", + [ + "period[D]", + "period[3M]", + "period[us]", + "Period[D]", + "Period[3M]", + "Period[us]", + ], + ) + def test_period_dtype(self, dtype): + assert com.pandas_dtype(dtype) is not PeriodDtype(dtype) + assert com.pandas_dtype(dtype) == PeriodDtype(dtype) + assert com.pandas_dtype(dtype) == dtype + + +dtypes = { + "datetime_tz": com.pandas_dtype("datetime64[ns, US/Eastern]"), + "datetime": com.pandas_dtype("datetime64[ns]"), + "timedelta": com.pandas_dtype("timedelta64[ns]"), + "period": PeriodDtype("D"), + "integer": np.dtype(np.int64), + "float": np.dtype(np.float64), + "object": np.dtype(object), + "category": com.pandas_dtype("category"), + "string": pd.StringDtype("python"), +} + + +@pytest.mark.parametrize("name1,dtype1", list(dtypes.items()), ids=lambda x: str(x)) +@pytest.mark.parametrize("name2,dtype2", list(dtypes.items()), ids=lambda x: str(x)) +def test_dtype_equal(name1, dtype1, name2, dtype2): + # match equal to self, but not equal to other + assert com.is_dtype_equal(dtype1, dtype1) + if name1 != name2: + assert not com.is_dtype_equal(dtype1, dtype2) + + +@pytest.mark.parametrize("name,dtype", list(dtypes.items()), ids=lambda x: str(x)) +def test_pyarrow_string_import_error(name, dtype): + # GH-44276 + assert not com.is_dtype_equal(dtype, "string[pyarrow]") + + +@pytest.mark.parametrize( + "dtype1,dtype2", + [ + (np.int8, np.int64), + (np.int16, np.int64), + (np.int32, np.int64), + (np.float32, np.float64), + (PeriodDtype("D"), PeriodDtype("2D")), # PeriodType + ( + com.pandas_dtype("datetime64[ns, US/Eastern]"), + com.pandas_dtype("datetime64[ns, CET]"), + ), # Datetime + (None, None), # gh-15941: no exception should be raised. + ], +) +def test_dtype_equal_strict(dtype1, dtype2): + assert not com.is_dtype_equal(dtype1, dtype2) + + +def get_is_dtype_funcs(): + """ + Get all functions in pandas.core.dtypes.common that + begin with 'is_' and end with 'dtype' + + """ + fnames = [f for f in dir(com) if (f.startswith("is_") and f.endswith("dtype"))] + fnames.remove("is_string_or_object_np_dtype") # fastpath requires np.dtype obj + return [getattr(com, fname) for fname in fnames] + + +@pytest.mark.filterwarnings( + "ignore:is_categorical_dtype is deprecated:DeprecationWarning" +) +@pytest.mark.parametrize("func", get_is_dtype_funcs(), ids=lambda x: x.__name__) +def test_get_dtype_error_catch(func): + # see gh-15941 + # + # No exception should be raised. + + msg = f"{func.__name__} is deprecated" + warn = None + if ( + func is com.is_int64_dtype + or func is com.is_interval_dtype + or func is com.is_datetime64tz_dtype + or func is com.is_categorical_dtype + or func is com.is_period_dtype + ): + warn = Pandas4Warning + + with tm.assert_produces_warning(warn, match=msg): + assert not func(None) + + +def test_is_object(): + assert com.is_object_dtype(object) + assert com.is_object_dtype(np.array([], dtype=object)) + + assert not com.is_object_dtype(int) + assert not com.is_object_dtype(np.array([], dtype=int)) + assert not com.is_object_dtype([1, 2, 3]) + + +@pytest.mark.parametrize( + "check_scipy", [False, pytest.param(True, marks=td.skip_if_no("scipy"))] +) +def test_is_sparse(check_scipy): + msg = "is_sparse is deprecated" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + assert com.is_sparse(SparseArray([1, 2, 3])) + + assert not com.is_sparse(np.array([1, 2, 3])) + + if check_scipy: + import scipy.sparse + + assert not com.is_sparse(scipy.sparse.bsr_matrix([1, 2, 3])) + + +def test_is_scipy_sparse(): + sp_sparse = pytest.importorskip("scipy.sparse") + + assert com.is_scipy_sparse(sp_sparse.bsr_matrix([1, 2, 3])) + + assert not com.is_scipy_sparse(SparseArray([1, 2, 3])) + + +def test_is_datetime64_dtype(): + assert not com.is_datetime64_dtype(object) + assert not com.is_datetime64_dtype([1, 2, 3]) + assert not com.is_datetime64_dtype(np.array([], dtype=int)) + + assert com.is_datetime64_dtype(np.datetime64) + assert com.is_datetime64_dtype(np.array([], dtype=np.datetime64)) + + +def test_is_datetime64tz_dtype(): + msg = "is_datetime64tz_dtype is deprecated" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + assert not com.is_datetime64tz_dtype(object) + assert not com.is_datetime64tz_dtype([1, 2, 3]) + assert not com.is_datetime64tz_dtype(pd.DatetimeIndex([1, 2, 3])) + assert com.is_datetime64tz_dtype(pd.DatetimeIndex(["2000"], tz="US/Eastern")) + + +def test_custom_ea_kind_M_not_datetime64tz(): + # GH 34986 + class NotTZDtype(ExtensionDtype): + @property + def kind(self) -> str: + return "M" + + not_tz_dtype = NotTZDtype() + msg = "is_datetime64tz_dtype is deprecated" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + assert not com.is_datetime64tz_dtype(not_tz_dtype) + assert not com.needs_i8_conversion(not_tz_dtype) + + +def test_is_timedelta64_dtype(): + assert not com.is_timedelta64_dtype(object) + assert not com.is_timedelta64_dtype(None) + assert not com.is_timedelta64_dtype([1, 2, 3]) + assert not com.is_timedelta64_dtype(np.array([], dtype=np.datetime64)) + assert not com.is_timedelta64_dtype("0 days") + assert not com.is_timedelta64_dtype("0 days 00:00:00") + assert not com.is_timedelta64_dtype(["0 days 00:00:00"]) + assert not com.is_timedelta64_dtype("NO DATE") + + assert com.is_timedelta64_dtype(np.timedelta64) + assert com.is_timedelta64_dtype(pd.Series([], dtype="timedelta64[ns]")) + assert com.is_timedelta64_dtype(pd.to_timedelta(["0 days", "1 days"])) + + +def test_is_period_dtype(): + msg = "is_period_dtype is deprecated" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + assert not com.is_period_dtype(object) + assert not com.is_period_dtype([1, 2, 3]) + assert not com.is_period_dtype(pd.Period("2017-01-01")) + + assert com.is_period_dtype(PeriodDtype(freq="D")) + assert com.is_period_dtype(pd.PeriodIndex([], freq="Y")) + + +def test_is_interval_dtype(): + msg = "is_interval_dtype is deprecated" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + assert not com.is_interval_dtype(object) + assert not com.is_interval_dtype([1, 2, 3]) + + assert com.is_interval_dtype(IntervalDtype()) + + interval = pd.Interval(1, 2, closed="right") + assert not com.is_interval_dtype(interval) + assert com.is_interval_dtype(pd.IntervalIndex([interval])) + + +def test_is_categorical_dtype(): + msg = "is_categorical_dtype is deprecated" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + assert not com.is_categorical_dtype(object) + assert not com.is_categorical_dtype([1, 2, 3]) + + assert com.is_categorical_dtype(CategoricalDtype()) + assert com.is_categorical_dtype(pd.Categorical([1, 2, 3])) + assert com.is_categorical_dtype(pd.CategoricalIndex([1, 2, 3])) + + +@pytest.mark.parametrize( + "dtype, expected", + [ + (int, False), + (pd.Series([1, 2]), False), + (str, True), + (object, True), + (np.array(["a", "b"]), True), + (pd.StringDtype(), True), + (pd.Index([], dtype="O"), True), + ], +) +def test_is_string_dtype(dtype, expected): + # GH#54661 + + result = com.is_string_dtype(dtype) + assert result is expected + + +@pytest.mark.parametrize( + "data", + [[(0, 1), (1, 1)], pd.Categorical([1, 2, 3]), np.array([1, 2], dtype=object)], +) +def test_is_string_dtype_arraylike_with_object_elements_not_strings(data): + # GH 15585 + assert not com.is_string_dtype(pd.Series(data)) + + +def test_is_string_dtype_nullable(nullable_string_dtype): + assert com.is_string_dtype(pd.array(["a", "b"], dtype=nullable_string_dtype)) + + +integer_dtypes: list = [] + + +@pytest.mark.parametrize( + "dtype", + [ + *integer_dtypes, + pd.Series([1, 2]), + *tm.ALL_INT_NUMPY_DTYPES, + *to_numpy_dtypes(tm.ALL_INT_NUMPY_DTYPES), + *tm.ALL_INT_EA_DTYPES, + *to_ea_dtypes(tm.ALL_INT_EA_DTYPES), + ], +) +def test_is_integer_dtype(dtype): + assert com.is_integer_dtype(dtype) + + +@pytest.mark.parametrize( + "dtype", + [ + str, + float, + np.datetime64, + np.timedelta64, + pd.Index([1, 2.0]), + np.array(["a", "b"]), + np.array([], dtype=np.timedelta64), + ], +) +def test_is_not_integer_dtype(dtype): + assert not com.is_integer_dtype(dtype) + + +signed_integer_dtypes: list = [] + + +@pytest.mark.parametrize( + "dtype", + [ + *signed_integer_dtypes, + pd.Series([1, 2]), + *tm.SIGNED_INT_NUMPY_DTYPES, + *to_numpy_dtypes(tm.SIGNED_INT_NUMPY_DTYPES), + *tm.SIGNED_INT_EA_DTYPES, + *to_ea_dtypes(tm.SIGNED_INT_EA_DTYPES), + ], +) +def test_is_signed_integer_dtype(dtype): + assert com.is_integer_dtype(dtype) + + +@pytest.mark.parametrize( + "dtype", + [ + str, + float, + np.datetime64, + np.timedelta64, + pd.Index([1, 2.0]), + np.array(["a", "b"]), + np.array([], dtype=np.timedelta64), + *tm.UNSIGNED_INT_NUMPY_DTYPES, + *to_numpy_dtypes(tm.UNSIGNED_INT_NUMPY_DTYPES), + *tm.UNSIGNED_INT_EA_DTYPES, + *to_ea_dtypes(tm.UNSIGNED_INT_EA_DTYPES), + ], +) +def test_is_not_signed_integer_dtype(dtype): + assert not com.is_signed_integer_dtype(dtype) + + +unsigned_integer_dtypes: list = [] + + +@pytest.mark.parametrize( + "dtype", + [ + *unsigned_integer_dtypes, + pd.Series([1, 2], dtype=np.uint32), + *tm.UNSIGNED_INT_NUMPY_DTYPES, + *to_numpy_dtypes(tm.UNSIGNED_INT_NUMPY_DTYPES), + *tm.UNSIGNED_INT_EA_DTYPES, + *to_ea_dtypes(tm.UNSIGNED_INT_EA_DTYPES), + ], +) +def test_is_unsigned_integer_dtype(dtype): + assert com.is_unsigned_integer_dtype(dtype) + + +@pytest.mark.parametrize( + "dtype", + [ + str, + float, + np.datetime64, + np.timedelta64, + pd.Index([1, 2.0]), + np.array(["a", "b"]), + np.array([], dtype=np.timedelta64), + *tm.SIGNED_INT_NUMPY_DTYPES, + *to_numpy_dtypes(tm.SIGNED_INT_NUMPY_DTYPES), + *tm.SIGNED_INT_EA_DTYPES, + *to_ea_dtypes(tm.SIGNED_INT_EA_DTYPES), + ], +) +def test_is_not_unsigned_integer_dtype(dtype): + assert not com.is_unsigned_integer_dtype(dtype) + + +@pytest.mark.parametrize( + "dtype", [np.int64, np.array([1, 2], dtype=np.int64), "Int64", pd.Int64Dtype] +) +def test_is_int64_dtype(dtype): + msg = "is_int64_dtype is deprecated" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + assert com.is_int64_dtype(dtype) + + +def test_type_comparison_with_numeric_ea_dtype(any_numeric_ea_dtype): + # GH#43038 + assert pandas_dtype(any_numeric_ea_dtype) == any_numeric_ea_dtype + + +def test_type_comparison_with_real_numpy_dtype(any_real_numpy_dtype): + # GH#43038 + assert pandas_dtype(any_real_numpy_dtype) == any_real_numpy_dtype + + +def test_type_comparison_with_signed_int_ea_dtype_and_signed_int_numpy_dtype( + any_signed_int_ea_dtype, any_signed_int_numpy_dtype +): + # GH#43038 + assert not pandas_dtype(any_signed_int_ea_dtype) == any_signed_int_numpy_dtype + + +@pytest.mark.parametrize( + "dtype", + [ + str, + float, + np.int32, + np.uint64, + pd.Index([1, 2.0]), + np.array(["a", "b"]), + np.array([1, 2], dtype=np.uint32), + "int8", + "Int8", + pd.Int8Dtype, + ], +) +def test_is_not_int64_dtype(dtype): + msg = "is_int64_dtype is deprecated" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + assert not com.is_int64_dtype(dtype) + + +def test_is_datetime64_any_dtype(): + assert not com.is_datetime64_any_dtype(int) + assert not com.is_datetime64_any_dtype(str) + assert not com.is_datetime64_any_dtype(np.array([1, 2])) + assert not com.is_datetime64_any_dtype(np.array(["a", "b"])) + + assert com.is_datetime64_any_dtype(np.datetime64) + assert com.is_datetime64_any_dtype(np.array([], dtype=np.datetime64)) + assert com.is_datetime64_any_dtype(DatetimeTZDtype("ns", "US/Eastern")) + assert com.is_datetime64_any_dtype( + pd.DatetimeIndex([1, 2, 3], dtype="datetime64[ns]") + ) + + +def test_is_datetime64_ns_dtype(): + assert not com.is_datetime64_ns_dtype(int) + assert not com.is_datetime64_ns_dtype(str) + assert not com.is_datetime64_ns_dtype(np.datetime64) + assert not com.is_datetime64_ns_dtype(np.array([1, 2])) + assert not com.is_datetime64_ns_dtype(np.array(["a", "b"])) + assert not com.is_datetime64_ns_dtype(np.array([], dtype=np.datetime64)) + + # This datetime array has the wrong unit (ps instead of ns) + assert not com.is_datetime64_ns_dtype(np.array([], dtype="datetime64[ps]")) + + assert com.is_datetime64_ns_dtype(DatetimeTZDtype("ns", "US/Eastern")) + assert com.is_datetime64_ns_dtype( + pd.DatetimeIndex([1, 2, 3], dtype=np.dtype("datetime64[ns]")) + ) + + # non-nano dt64tz + assert not com.is_datetime64_ns_dtype(DatetimeTZDtype("us", "US/Eastern")) + + +def test_is_timedelta64_ns_dtype(): + assert not com.is_timedelta64_ns_dtype(np.dtype("m8[ps]")) + assert not com.is_timedelta64_ns_dtype(np.array([1, 2], dtype=np.timedelta64)) + + assert com.is_timedelta64_ns_dtype(np.dtype("m8[ns]")) + assert com.is_timedelta64_ns_dtype(np.array([1, 2], dtype="m8[ns]")) + + +def test_is_numeric_v_string_like(): + assert not com.is_numeric_v_string_like(np.array([1]), 1) + assert not com.is_numeric_v_string_like(np.array([1]), np.array([2])) + assert not com.is_numeric_v_string_like(np.array(["foo"]), np.array(["foo"])) + + assert com.is_numeric_v_string_like(np.array([1]), "foo") + assert com.is_numeric_v_string_like(np.array([1, 2]), np.array(["foo"])) + assert com.is_numeric_v_string_like(np.array(["foo"]), np.array([1, 2])) + + +def test_needs_i8_conversion(): + assert not com.needs_i8_conversion(str) + assert not com.needs_i8_conversion(np.int64) + assert not com.needs_i8_conversion(pd.Series([1, 2])) + assert not com.needs_i8_conversion(np.array(["a", "b"])) + + assert not com.needs_i8_conversion(np.datetime64) + assert com.needs_i8_conversion(np.dtype(np.datetime64)) + assert not com.needs_i8_conversion(pd.Series([], dtype="timedelta64[ns]")) + assert com.needs_i8_conversion(pd.Series([], dtype="timedelta64[ns]").dtype) + assert not com.needs_i8_conversion(pd.DatetimeIndex(["2000"], tz="US/Eastern")) + assert com.needs_i8_conversion(pd.DatetimeIndex(["2000"], tz="US/Eastern").dtype) + + +def test_is_numeric_dtype(): + assert not com.is_numeric_dtype(str) + assert not com.is_numeric_dtype(np.datetime64) + assert not com.is_numeric_dtype(np.timedelta64) + assert not com.is_numeric_dtype(np.array(["a", "b"])) + assert not com.is_numeric_dtype(np.array([], dtype=np.timedelta64)) + + assert com.is_numeric_dtype(int) + assert com.is_numeric_dtype(float) + assert com.is_numeric_dtype(np.uint64) + assert com.is_numeric_dtype(pd.Series([1, 2])) + assert com.is_numeric_dtype(pd.Index([1, 2.0])) + + class MyNumericDType(ExtensionDtype): + @property + def type(self): + return str + + @property + def name(self): + raise NotImplementedError + + def construct_array_type(self): + raise NotImplementedError + + def _is_numeric(self) -> bool: + return True + + assert com.is_numeric_dtype(MyNumericDType()) + + +def test_is_any_real_numeric_dtype(): + assert not com.is_any_real_numeric_dtype(str) + assert not com.is_any_real_numeric_dtype(bool) + assert not com.is_any_real_numeric_dtype(complex) + assert not com.is_any_real_numeric_dtype(object) + assert not com.is_any_real_numeric_dtype(np.datetime64) + assert not com.is_any_real_numeric_dtype(np.array(["a", "b", complex(1, 2)])) + assert not com.is_any_real_numeric_dtype(pd.DataFrame([complex(1, 2), True])) + + assert com.is_any_real_numeric_dtype(int) + assert com.is_any_real_numeric_dtype(float) + assert com.is_any_real_numeric_dtype(np.array([1, 2.5])) + + +def test_is_float_dtype(): + assert not com.is_float_dtype(str) + assert not com.is_float_dtype(int) + assert not com.is_float_dtype(pd.Series([1, 2])) + assert not com.is_float_dtype(np.array(["a", "b"])) + + assert com.is_float_dtype(float) + assert com.is_float_dtype(pd.Index([1, 2.0])) + + +def test_is_bool_dtype(): + assert not com.is_bool_dtype(int) + assert not com.is_bool_dtype(str) + assert not com.is_bool_dtype(pd.Series([1, 2])) + assert not com.is_bool_dtype(pd.Series(["a", "b"], dtype="category")) + assert not com.is_bool_dtype(np.array(["a", "b"])) + assert not com.is_bool_dtype(pd.Index(["a", "b"])) + assert not com.is_bool_dtype("Int64") + + assert com.is_bool_dtype(bool) + assert com.is_bool_dtype(np.bool_) + assert com.is_bool_dtype(pd.Series([True, False], dtype="category")) + assert com.is_bool_dtype(np.array([True, False])) + assert com.is_bool_dtype(pd.Index([True, False])) + + assert com.is_bool_dtype(pd.BooleanDtype()) + assert com.is_bool_dtype(pd.array([True, False, None], dtype="boolean")) + assert com.is_bool_dtype("boolean") + + +def test_is_bool_dtype_numpy_error(): + # GH39010 + assert not com.is_bool_dtype("0 - Name") + + +@pytest.mark.parametrize( + "check_scipy", [False, pytest.param(True, marks=td.skip_if_no("scipy"))] +) +def test_is_extension_array_dtype(check_scipy): + assert not com.is_extension_array_dtype([1, 2, 3]) + assert not com.is_extension_array_dtype(np.array([1, 2, 3])) + assert not com.is_extension_array_dtype(pd.DatetimeIndex([1, 2, 3])) + + cat = pd.Categorical([1, 2, 3]) + assert com.is_extension_array_dtype(cat) + assert com.is_extension_array_dtype(pd.Series(cat)) + assert com.is_extension_array_dtype(SparseArray([1, 2, 3])) + assert com.is_extension_array_dtype(pd.DatetimeIndex(["2000"], tz="US/Eastern")) + + dtype = DatetimeTZDtype("ns", tz="US/Eastern") + s = pd.Series([], dtype=dtype) + assert com.is_extension_array_dtype(s) + + if check_scipy: + import scipy.sparse + + assert not com.is_extension_array_dtype(scipy.sparse.bsr_matrix([1, 2, 3])) + + +def test_is_complex_dtype(): + assert not com.is_complex_dtype(int) + assert not com.is_complex_dtype(str) + assert not com.is_complex_dtype(pd.Series([1, 2])) + assert not com.is_complex_dtype(np.array(["a", "b"])) + + assert com.is_complex_dtype(np.complex128) + assert com.is_complex_dtype(complex) + assert com.is_complex_dtype(np.array([1 + 1j, 5])) + + +@pytest.mark.parametrize( + "input_param,result", + [ + (int, np.dtype(int)), + ("int32", np.dtype("int32")), + (float, np.dtype(float)), + ("float64", np.dtype("float64")), + (np.dtype("float64"), np.dtype("float64")), + (str, np.dtype(str)), + (pd.Series([1, 2], dtype=np.dtype("int16")), np.dtype("int16")), + (pd.Series(["a", "b"], dtype=object), np.dtype(object)), + (pd.Index([1, 2]), np.dtype("int64")), + (pd.Index(["a", "b"], dtype=object), np.dtype(object)), + ("category", "category"), + (pd.Categorical(["a", "b"]).dtype, CategoricalDtype(["a", "b"])), + (pd.Categorical(["a", "b"]), CategoricalDtype(["a", "b"])), + (pd.CategoricalIndex(["a", "b"]).dtype, CategoricalDtype(["a", "b"])), + (pd.CategoricalIndex(["a", "b"]), CategoricalDtype(["a", "b"])), + (CategoricalDtype(), CategoricalDtype()), + (pd.DatetimeIndex([1, 2]), np.dtype("=M8[ns]")), + (pd.DatetimeIndex([1, 2]).dtype, np.dtype("=M8[ns]")), + (" df.two.sum() + + with tm.assert_produces_warning(None): + # successfully modify column in place + # this should not raise a warning + df.one += 1 + assert df.one.iloc[0] == 2 + + with tm.assert_produces_warning(None): + # successfully add an attribute to a series + # this should not raise a warning + df.two.not_an_index = [1, 2] + + with tm.assert_produces_warning(UserWarning, match="doesn't allow columns"): + # warn when setting column to nonexistent name + df.four = df.two + 2 + assert df.four.sum() > df.two.sum() diff --git a/pandas/tests/dtypes/test_inference.py b/pandas/tests/dtypes/test_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..5cea1276f7f4bd5d5077b29aa698484d324ccc7f --- /dev/null +++ b/pandas/tests/dtypes/test_inference.py @@ -0,0 +1,2155 @@ +""" +These the test the public routines exposed in types/common.py +related to inference and not otherwise tested in types/test_common.py + +""" + +import collections +from collections import namedtuple +from collections.abc import Iterator +from datetime import ( + date, + datetime, + time, + timedelta, + timezone, +) +from decimal import Decimal +from fractions import Fraction +from io import StringIO +import itertools +from numbers import Number +import re +import sys +from typing import ( + Generic, + TypeVar, +) + +import numpy as np +import pytest + +from pandas._libs import ( + lib, + missing as libmissing, + ops as libops, +) +from pandas.compat import PY312 +from pandas.compat.numpy import np_version_gt2 +from pandas.errors import Pandas4Warning + +from pandas.core.dtypes import inference +from pandas.core.dtypes.cast import find_result_type +from pandas.core.dtypes.common import ( + ensure_int32, + is_bool, + is_complex, + is_datetime64_any_dtype, + is_datetime64_dtype, + is_datetime64_ns_dtype, + is_datetime64tz_dtype, + is_float, + is_integer, + is_number, + is_scalar, + is_scipy_sparse, + is_timedelta64_dtype, + is_timedelta64_ns_dtype, +) + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DateOffset, + DatetimeIndex, + Index, + Interval, + Period, + Series, + Timedelta, + TimedeltaIndex, + Timestamp, +) +import pandas._testing as tm +from pandas.core.arrays import ( + BooleanArray, + FloatingArray, + IntegerArray, +) + + +@pytest.fixture(params=[True, False], ids=str) +def coerce(request): + return request.param + + +class MockNumpyLikeArray: + """ + A class which is numpy-like (e.g. Pint's Quantity) but not actually numpy + + The key is that it is not actually a numpy array so + ``util.is_array(mock_numpy_like_array_instance)`` returns ``False``. Other + important properties are that the class defines a :meth:`__iter__` method + (so that ``isinstance(abc.Iterable)`` returns ``True``) and has a + :meth:`ndim` property, as pandas special-cases 0-dimensional arrays in some + cases. + + We expect pandas to behave with respect to such duck arrays exactly as + with real numpy arrays. In particular, a 0-dimensional duck array is *NOT* + a scalar (`is_scalar(np.array(1)) == False`), but it is not list-like either. + """ + + def __init__(self, values) -> None: + self._values = values + + def __iter__(self) -> Iterator: + iter_values = iter(self._values) + + def it_outer(): + yield from iter_values + + return it_outer() + + def __len__(self) -> int: + return len(self._values) + + def __array__(self, dtype=None, copy=None): + return np.asarray(self._values, dtype=dtype) + + @property + def ndim(self): + return self._values.ndim + + @property + def dtype(self): + return self._values.dtype + + @property + def size(self): + return self._values.size + + @property + def shape(self): + return self._values.shape + + +# collect all objects to be tested for list-like-ness; use tuples of objects, +# whether they are list-like or not (special casing for sets), and their ID +ll_params = [ + ([1], True, "list"), + ([], True, "list-empty"), + ((1,), True, "tuple"), + ((), True, "tuple-empty"), + ({"a": 1}, True, "dict"), + ({}, True, "dict-empty"), + ({"a", 1}, "set", "set"), + (set(), "set", "set-empty"), + (frozenset({"a", 1}), "set", "frozenset"), + (frozenset(), "set", "frozenset-empty"), + (iter([1, 2]), True, "iterator"), + (iter([]), True, "iterator-empty"), + ((x for x in [1, 2]), True, "generator"), + ((_ for _ in []), True, "generator-empty"), + (Series([1]), True, "Series"), + (Series([], dtype=object), True, "Series-empty"), + # Series.str will still raise a TypeError if iterated + (Series(["a"]).str, True, "StringMethods"), + (Series([], dtype="O").str, True, "StringMethods-empty"), + (Index([1]), True, "Index"), + (Index([]), True, "Index-empty"), + (DataFrame([[1]]), True, "DataFrame"), + (DataFrame(), True, "DataFrame-empty"), + (np.ndarray((2,) * 1), True, "ndarray-1d"), + (np.array([]), True, "ndarray-1d-empty"), + (np.ndarray((2,) * 2), True, "ndarray-2d"), + (np.array([[]]), True, "ndarray-2d-empty"), + (np.ndarray((2,) * 3), True, "ndarray-3d"), + (np.array([[[]]]), True, "ndarray-3d-empty"), + (np.ndarray((2,) * 4), True, "ndarray-4d"), + (np.array([[[[]]]]), True, "ndarray-4d-empty"), + (np.array(2), False, "ndarray-0d"), + (MockNumpyLikeArray(np.ndarray((2,) * 1)), True, "duck-ndarray-1d"), + (MockNumpyLikeArray(np.array([])), True, "duck-ndarray-1d-empty"), + (MockNumpyLikeArray(np.ndarray((2,) * 2)), True, "duck-ndarray-2d"), + (MockNumpyLikeArray(np.array([[]])), True, "duck-ndarray-2d-empty"), + (MockNumpyLikeArray(np.ndarray((2,) * 3)), True, "duck-ndarray-3d"), + (MockNumpyLikeArray(np.array([[[]]])), True, "duck-ndarray-3d-empty"), + (MockNumpyLikeArray(np.ndarray((2,) * 4)), True, "duck-ndarray-4d"), + (MockNumpyLikeArray(np.array([[[[]]]])), True, "duck-ndarray-4d-empty"), + (MockNumpyLikeArray(np.array(2)), False, "duck-ndarray-0d"), + (1, False, "int"), + (b"123", False, "bytes"), + (b"", False, "bytes-empty"), + ("123", False, "string"), + ("", False, "string-empty"), + (str, False, "string-type"), + (object(), False, "object"), + (np.nan, False, "NaN"), + (None, False, "None"), +] +objs, expected, ids = zip(*ll_params, strict=True) + + +@pytest.fixture(params=zip(objs, expected, strict=True), ids=ids) +def maybe_list_like(request): + return request.param + + +def test_is_list_like(maybe_list_like): + obj, expected = maybe_list_like + expected = True if expected == "set" else expected + assert inference.is_list_like(obj) == expected + + +def test_is_list_like_disallow_sets(maybe_list_like): + obj, expected = maybe_list_like + expected = False if expected == "set" else expected + assert inference.is_list_like(obj, allow_sets=False) == expected + + +def test_is_list_like_recursion(): + # GH 33721 + # interpreter would crash with SIGABRT + def list_like(): + inference.is_list_like([]) + list_like() + + rec_limit = sys.getrecursionlimit() + try: + # Limit to avoid stack overflow on Windows CI + sys.setrecursionlimit(100) + with tm.external_error_raised(RecursionError): + list_like() + finally: + sys.setrecursionlimit(rec_limit) + + +def test_is_list_like_iter_is_none(): + # GH 43373 + # is_list_like was yielding false positives with __iter__ == None + class NotListLike: + def __getitem__(self, item): + return self + + __iter__ = None + + assert not inference.is_list_like(NotListLike()) + + +def test_is_list_like_generic(): + # GH 49649 + # is_list_like was yielding false positives for Generic classes in python 3.11 + T = TypeVar("T") + + class MyDataFrame(DataFrame, Generic[T]): ... + + tstc = MyDataFrame[int] + tst = MyDataFrame[int]({"x": [1, 2, 3]}) + + assert not inference.is_list_like(tstc) + assert isinstance(tst, DataFrame) + assert inference.is_list_like(tst) + + +def test_is_list_like_native_container_types(): + # GH 61565 + # is_list_like was yielding false positives for native container types + assert not inference.is_list_like(list[int]) + assert not inference.is_list_like(list[str]) + assert not inference.is_list_like(tuple[int]) + assert not inference.is_list_like(tuple[str]) + + +def test_is_sequence(): + is_seq = inference.is_sequence + assert is_seq((1, 2)) + assert is_seq([1, 2]) + assert not is_seq("abcd") + assert not is_seq(np.int64) + + class A: + def __getitem__(self, item): + return 1 + + assert not is_seq(A()) + + +def test_is_array_like(): + assert inference.is_array_like(Series([], dtype=object)) + assert inference.is_array_like(Series([1, 2])) + assert inference.is_array_like(np.array(["a", "b"])) + assert inference.is_array_like(Index(["2016-01-01"])) + assert inference.is_array_like(np.array([2, 3])) + assert inference.is_array_like(MockNumpyLikeArray(np.array([2, 3]))) + + class DtypeList(list): + dtype = "special" + + assert inference.is_array_like(DtypeList()) + + assert not inference.is_array_like([1, 2, 3]) + assert not inference.is_array_like(()) + assert not inference.is_array_like("foo") + assert not inference.is_array_like(123) + + +@pytest.mark.parametrize( + "inner", + [ + [], + [1], + (1,), + (1, 2), + {"a": 1}, + {1, "a"}, + Series([1]), + Series([], dtype=object), + Series(["a"]).str, + (x for x in range(5)), + ], +) +@pytest.mark.parametrize("outer", [list, Series, np.array, tuple]) +def test_is_nested_list_like_passes(inner, outer): + result = outer([inner for _ in range(5)]) + assert inference.is_list_like(result) + + +@pytest.mark.parametrize( + "obj", + [ + "abc", + [], + [1], + (1,), + ["a"], + "a", + {"a"}, + [1, 2, 3], + Series([1]), + DataFrame({"A": [1]}), + ([1, 2] for _ in range(5)), + ], +) +def test_is_nested_list_like_fails(obj): + assert not inference.is_nested_list_like(obj) + + +@pytest.mark.parametrize("ll", [{}, {"A": 1}, Series([1]), collections.defaultdict()]) +def test_is_dict_like_passes(ll): + assert inference.is_dict_like(ll) + + +@pytest.mark.parametrize( + "ll", + [ + "1", + 1, + [1, 2], + (1, 2), + range(2), + Index([1]), + dict, + collections.defaultdict, + Series, + ], +) +def test_is_dict_like_fails(ll): + assert not inference.is_dict_like(ll) + + +@pytest.mark.parametrize("has_keys", [True, False]) +@pytest.mark.parametrize("has_getitem", [True, False]) +@pytest.mark.parametrize("has_contains", [True, False]) +def test_is_dict_like_duck_type(has_keys, has_getitem, has_contains): + class DictLike: + def __init__(self, d) -> None: + self.d = d + + if has_keys: + + def keys(self): + return self.d.keys() + + if has_getitem: + + def __getitem__(self, key): + return self.d.__getitem__(key) + + if has_contains: + + def __contains__(self, key) -> bool: + return self.d.__contains__(key) + + d = DictLike({1: 2}) + result = inference.is_dict_like(d) + expected = has_keys and has_getitem and has_contains + + assert result is expected + + +def test_is_file_like(): + class MockFile: + pass + + is_file = inference.is_file_like + + data = StringIO("data") + assert is_file(data) + + # No read / write attributes + # No iterator attributes + m = MockFile() + assert not is_file(m) + + MockFile.write = lambda self: 0 + + # Write attribute but not an iterator + m = MockFile() + assert not is_file(m) + + # gh-16530: Valid iterator just means we have the + # __iter__ attribute for our purposes. + MockFile.__iter__ = lambda self: self + + # Valid write-only file + m = MockFile() + assert is_file(m) + + del MockFile.write + MockFile.read = lambda self: 0 + + # Valid read-only file + m = MockFile() + assert is_file(m) + + # Iterator but no read / write attributes + data = [1, 2, 3] + assert not is_file(data) + + +test_tuple = collections.namedtuple("test_tuple", ["a", "b", "c"]) + + +@pytest.mark.parametrize("ll", [test_tuple(1, 2, 3)]) +def test_is_names_tuple_passes(ll): + assert inference.is_named_tuple(ll) + + +@pytest.mark.parametrize("ll", [(1, 2, 3), "a", Series({"pi": 3.14})]) +def test_is_names_tuple_fails(ll): + assert not inference.is_named_tuple(ll) + + +def test_is_hashable(): + # all new-style classes are hashable by default + class HashableClass: + pass + + class UnhashableClass1: + __hash__ = None + + class UnhashableClass2: + def __hash__(self): + raise TypeError("Not hashable") + + # Temporary helper for Python 3.11 compatibility. + # This can be removed once support for Python 3.11 is dropped. + class HashableSlice: + def __init__(self, start, stop, step=None): + self.slice = slice(start, stop, step) + + def __eq__(self, other): + return isinstance(other, HashableSlice) and self.slice == other.slice + + def __hash__(self): + return hash((self.slice.start, self.slice.stop, self.slice.step)) + + def __repr__(self): + return ( + f"HashableSlice({self.slice.start}, {self.slice.stop}, " + f"{self.slice.step})" + ) + + hashable = (1, 3.14, np.float64(3.14), "a", (), (1,), HashableClass()) + not_hashable = ([], UnhashableClass1()) + abc_hashable_not_really_hashable = (([],), UnhashableClass2()) + hashable_slice = HashableSlice(1, 2) + tuple_with_slice = (slice(1, 2), 3) + + for i in hashable: + assert inference.is_hashable(i) + assert inference.is_hashable(i, allow_slice=True) + assert inference.is_hashable(i, allow_slice=False) + for i in not_hashable: + assert not inference.is_hashable(i) + assert not inference.is_hashable(i, allow_slice=True) + assert not inference.is_hashable(i, allow_slice=False) + for i in abc_hashable_not_really_hashable: + assert not inference.is_hashable(i) + assert not inference.is_hashable(i, allow_slice=True) + assert not inference.is_hashable(i, allow_slice=False) + + assert inference.is_hashable(hashable_slice) + assert inference.is_hashable(hashable_slice, allow_slice=True) + assert inference.is_hashable(hashable_slice, allow_slice=False) + + if PY312: + for obj in [slice(1, 2), tuple_with_slice]: + assert inference.is_hashable(obj) + assert inference.is_hashable(obj, allow_slice=True) + assert not inference.is_hashable(obj, allow_slice=False) + else: + for obj in [slice(1, 2), tuple_with_slice]: + assert not inference.is_hashable(obj) + assert not inference.is_hashable(obj, allow_slice=True) + assert not inference.is_hashable(obj, allow_slice=False) + + # numpy.array is no longer collections.abc.Hashable as of + # https://github.com/numpy/numpy/pull/5326, just test + # is_hashable() + assert not inference.is_hashable(np.array([])) + + +@pytest.mark.parametrize("ll", [re.compile("ad")]) +def test_is_re_passes(ll): + assert inference.is_re(ll) + + +@pytest.mark.parametrize("ll", ["x", 2, 3, object()]) +def test_is_re_fails(ll): + assert not inference.is_re(ll) + + +@pytest.mark.parametrize( + "ll", [r"a", "x", r"asdf", re.compile("adsf"), r"\u2233\s*", re.compile(r"")] +) +def test_is_recompilable_passes(ll): + assert inference.is_re_compilable(ll) + + +@pytest.mark.parametrize("ll", [1, [], object()]) +def test_is_recompilable_fails(ll): + assert not inference.is_re_compilable(ll) + + +class TestInference: + @pytest.mark.parametrize( + "arr", + [ + np.array(list("abc"), dtype="S1"), + np.array(list("abc"), dtype="S1").astype(object), + [b"a", np.nan, b"c"], + ], + ) + def test_infer_dtype_bytes(self, arr): + result = lib.infer_dtype(arr, skipna=True) + assert result == "bytes" + + @pytest.mark.parametrize( + "value, expected", + [ + (float("inf"), True), + (np.inf, True), + (-np.inf, False), + (1, False), + ("a", False), + ], + ) + def test_isposinf_scalar(self, value, expected): + # GH 11352 + result = libmissing.isposinf_scalar(value) + assert result is expected + + @pytest.mark.parametrize( + "value, expected", + [ + (float("-inf"), True), + (-np.inf, True), + (np.inf, False), + (1, False), + ("a", False), + ], + ) + def test_isneginf_scalar(self, value, expected): + result = libmissing.isneginf_scalar(value) + assert result is expected + + @pytest.mark.parametrize( + "convert_to_masked_nullable, exp", + [ + ( + True, + BooleanArray( + np.array([True, False], dtype="bool"), np.array([False, True]) + ), + ), + (False, np.array([True, np.nan], dtype="object")), + ], + ) + def test_maybe_convert_nullable_boolean(self, convert_to_masked_nullable, exp): + # GH 40687 + arr = np.array([True, np.nan], dtype=object) + result = libops.maybe_convert_bool( + arr, set(), convert_to_masked_nullable=convert_to_masked_nullable + ) + if convert_to_masked_nullable: + tm.assert_extension_array_equal(BooleanArray(*result), exp) + else: + result = result[0] + tm.assert_numpy_array_equal(result, exp) + + @pytest.mark.parametrize("convert_to_masked_nullable", [True, False]) + @pytest.mark.parametrize("coerce_numeric", [True, False]) + @pytest.mark.parametrize( + "infinity", ["inf", "inF", "iNf", "Inf", "iNF", "InF", "INf", "INF"] + ) + @pytest.mark.parametrize("prefix", ["", "-", "+"]) + def test_maybe_convert_numeric_infinities( + self, coerce_numeric, infinity, prefix, convert_to_masked_nullable + ): + # see gh-13274 + result, _ = lib.maybe_convert_numeric( + np.array([prefix + infinity], dtype=object), + na_values={"", "NULL", "nan"}, + coerce_numeric=coerce_numeric, + convert_to_masked_nullable=convert_to_masked_nullable, + ) + expected = np.array([np.inf if prefix in ["", "+"] else -np.inf]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("convert_to_masked_nullable", [True, False]) + def test_maybe_convert_numeric_infinities_raises(self, convert_to_masked_nullable): + msg = "Unable to parse string" + with pytest.raises(ValueError, match=msg): + lib.maybe_convert_numeric( + np.array(["foo_inf"], dtype=object), + na_values={"", "NULL", "nan"}, + coerce_numeric=False, + convert_to_masked_nullable=convert_to_masked_nullable, + ) + + @pytest.mark.parametrize("convert_to_masked_nullable", [True, False]) + def test_maybe_convert_numeric_post_floatify_nan( + self, coerce, convert_to_masked_nullable + ): + # see gh-13314 + data = np.array(["1.200", "-999.000", "4.500"], dtype=object) + expected = np.array([1.2, np.nan, 4.5], dtype=np.float64) + nan_values = {-999, -999.0} + + out = lib.maybe_convert_numeric( + data, + nan_values, + coerce, + convert_to_masked_nullable=convert_to_masked_nullable, + ) + if convert_to_masked_nullable: + expected = FloatingArray(expected, np.isnan(expected)) + tm.assert_extension_array_equal(expected, FloatingArray(*out)) + else: + out = out[0] + tm.assert_numpy_array_equal(out, expected) + + def test_convert_infs(self): + arr = np.array(["inf", "inf", "inf"], dtype="O") + result, _ = lib.maybe_convert_numeric(arr, set(), False) + assert result.dtype == np.float64 + + arr = np.array(["-inf", "-inf", "-inf"], dtype="O") + result, _ = lib.maybe_convert_numeric(arr, set(), False) + assert result.dtype == np.float64 + + def test_scientific_no_exponent(self): + # See PR 12215 + arr = np.array(["42E", "2E", "99e", "6e"], dtype="O") + result, _ = lib.maybe_convert_numeric(arr, set(), False, True) + assert np.all(np.isnan(result)) + + def test_convert_non_hashable(self): + # GH13324 + # make sure that we are handing non-hashables + arr = np.array([[10.0, 2], 1.0, "apple"], dtype=object) + result, _ = lib.maybe_convert_numeric(arr, set(), False, True) + tm.assert_numpy_array_equal(result, np.array([np.nan, 1.0, np.nan])) + + def test_convert_numeric_uint64(self): + arr = np.array([2**63], dtype=object) + exp = np.array([2**63], dtype=np.uint64) + tm.assert_numpy_array_equal(lib.maybe_convert_numeric(arr, set())[0], exp) + + arr = np.array([str(2**63)], dtype=object) + exp = np.array([2**63], dtype=np.uint64) + tm.assert_numpy_array_equal(lib.maybe_convert_numeric(arr, set())[0], exp) + + arr = np.array([np.uint64(2**63)], dtype=object) + exp = np.array([2**63], dtype=np.uint64) + tm.assert_numpy_array_equal(lib.maybe_convert_numeric(arr, set())[0], exp) + + @pytest.mark.parametrize( + "arr", + [ + np.array([2**63, np.nan], dtype=object), + np.array([str(2**63), np.nan], dtype=object), + np.array([np.nan, 2**63], dtype=object), + np.array([np.nan, str(2**63)], dtype=object), + ], + ) + def test_convert_numeric_uint64_nan(self, coerce, arr): + expected = arr.astype(float) if coerce else arr.copy() + result, _ = lib.maybe_convert_numeric(arr, set(), coerce_numeric=coerce) + tm.assert_almost_equal(result, expected) + + @pytest.mark.parametrize("convert_to_masked_nullable", [True, False]) + def test_convert_numeric_uint64_nan_values( + self, coerce, convert_to_masked_nullable + ): + arr = np.array([2**63, 2**63 + 1], dtype=object) + na_values = {2**63} + + expected = np.array([np.nan, 2**63 + 1], dtype=float) if coerce else arr.copy() + result = lib.maybe_convert_numeric( + arr, + na_values, + coerce_numeric=coerce, + convert_to_masked_nullable=convert_to_masked_nullable, + ) + if convert_to_masked_nullable and coerce: + expected = IntegerArray( + np.array([0, 2**63 + 1], dtype="u8"), + np.array([True, False], dtype="bool"), + ) + result = IntegerArray(*result) + else: + result = result[0] # discard mask + tm.assert_almost_equal(result, expected) + + @pytest.mark.parametrize( + "case", + [ + np.array([2**63, -1], dtype=object), + np.array([str(2**63), -1], dtype=object), + np.array([str(2**63), str(-1)], dtype=object), + np.array([-1, 2**63], dtype=object), + np.array([-1, str(2**63)], dtype=object), + np.array([str(-1), str(2**63)], dtype=object), + ], + ) + @pytest.mark.parametrize("convert_to_masked_nullable", [True, False]) + def test_convert_numeric_int64_uint64( + self, case, coerce, convert_to_masked_nullable + ): + expected = case.astype(float) if coerce else case.copy() + result, _ = lib.maybe_convert_numeric( + case, + set(), + coerce_numeric=coerce, + convert_to_masked_nullable=convert_to_masked_nullable, + ) + + tm.assert_almost_equal(result, expected) + + @pytest.mark.parametrize("convert_to_masked_nullable", [True, False]) + def test_convert_numeric_string_uint64(self, convert_to_masked_nullable): + # GH32394 + result = lib.maybe_convert_numeric( + np.array(["uint64"], dtype=object), + set(), + coerce_numeric=True, + convert_to_masked_nullable=convert_to_masked_nullable, + ) + if convert_to_masked_nullable: + result = FloatingArray(*result) + else: + result = result[0] + assert np.isnan(result) + + @pytest.mark.parametrize("value", [-(2**63) - 1, 2**64]) + def test_convert_int_overflow(self, value): + # see gh-18584 + arr = np.array([value], dtype=object) + result = lib.maybe_convert_objects(arr) + tm.assert_numpy_array_equal(arr, result) + + @pytest.mark.parametrize( + "value, expected_value", + [ + (-(1 << 65), -(1 << 65)), + (1 << 65, 1 << 65), + (str(1 << 65), 1 << 65), + (f"-{1 << 65}", -(1 << 65)), + ], + ) + @pytest.mark.parametrize("coerce_numeric", [False, True]) + def test_convert_numeric_overflow(self, value, expected_value, coerce_numeric): + arr = np.array([value], dtype=object) + expected = np.array([expected_value], dtype=float if coerce_numeric else object) + result, _ = lib.maybe_convert_numeric( + arr, + set(), + coerce_numeric=coerce_numeric, + ) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("val", [None, np.nan, float("nan")]) + @pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) + def test_maybe_convert_objects_nat_inference(self, val, dtype): + dtype = np.dtype(dtype) + vals = np.array([pd.NaT, val], dtype=object) + result = lib.maybe_convert_objects( + vals, + convert_non_numeric=True, + dtype_if_all_nat=dtype, + ) + assert result.dtype == dtype + assert np.isnat(result).all() + + result = lib.maybe_convert_objects( + vals[::-1], + convert_non_numeric=True, + dtype_if_all_nat=dtype, + ) + assert result.dtype == dtype + assert np.isnat(result).all() + + @pytest.mark.parametrize( + "value, expected_dtype", + [ + # see gh-4471 + ([2**63], np.uint64), + # NumPy bug: can't compare uint64 to int64, as that + # results in both casting to float64, so we should + # make sure that this function is robust against it + ([np.uint64(2**63)], np.uint64), + ([2, -1], np.int64), + ([2**63, -1], object), + # GH#47294 + ([np.uint8(1)], np.uint8), + ([np.uint16(1)], np.uint16), + ([np.uint32(1)], np.uint32), + ([np.uint64(1)], np.uint64), + ([np.uint8(2), np.uint16(1)], np.uint16), + ([np.uint32(2), np.uint16(1)], np.uint32), + ([np.uint32(2), -1], object), + ([np.uint32(2), 1], np.uint64), + ([np.uint32(2), np.int32(1)], object), + ], + ) + def test_maybe_convert_objects_uint(self, value, expected_dtype): + arr = np.array(value, dtype=object) + exp = np.array(value, dtype=expected_dtype) + tm.assert_numpy_array_equal(lib.maybe_convert_objects(arr), exp) + + def test_maybe_convert_objects_datetime(self): + # GH27438 + arr = np.array( + [np.datetime64("2000-01-01"), np.timedelta64(1, "s")], dtype=object + ) + exp = arr.copy() + out = lib.maybe_convert_objects(arr, convert_non_numeric=True) + tm.assert_numpy_array_equal(out, exp) + + arr = np.array([pd.NaT, np.timedelta64(1, "s")], dtype=object) + exp = np.array([np.timedelta64("NaT"), np.timedelta64(1, "s")], dtype="m8[s]") + out = lib.maybe_convert_objects(arr, convert_non_numeric=True) + tm.assert_numpy_array_equal(out, exp) + + # with convert_non_numeric=True, the nan is a valid NA value for td64 + arr = np.array([np.timedelta64(1, "s"), np.nan], dtype=object) + exp = exp[::-1] + out = lib.maybe_convert_objects(arr, convert_non_numeric=True) + tm.assert_numpy_array_equal(out, exp) + + def test_maybe_convert_objects_dtype_if_all_nat(self): + arr = np.array([pd.NaT, pd.NaT], dtype=object) + out = lib.maybe_convert_objects(arr, convert_non_numeric=True) + # no dtype_if_all_nat passed -> we dont guess + tm.assert_numpy_array_equal(out, arr) + + out = lib.maybe_convert_objects( + arr, + convert_non_numeric=True, + dtype_if_all_nat=np.dtype("timedelta64[ns]"), + ) + exp = np.array(["NaT", "NaT"], dtype="timedelta64[ns]") + tm.assert_numpy_array_equal(out, exp) + + out = lib.maybe_convert_objects( + arr, + convert_non_numeric=True, + dtype_if_all_nat=np.dtype("datetime64[ns]"), + ) + exp = np.array(["NaT", "NaT"], dtype="datetime64[ns]") + tm.assert_numpy_array_equal(out, exp) + + def test_maybe_convert_objects_dtype_if_all_nat_invalid(self): + # we accept datetime64[ns], timedelta64[ns], and EADtype + arr = np.array([pd.NaT, pd.NaT], dtype=object) + + with pytest.raises(ValueError, match="int64"): + lib.maybe_convert_objects( + arr, + convert_non_numeric=True, + dtype_if_all_nat=np.dtype("int64"), + ) + + @pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"]) + def test_maybe_convert_objects_datetime_overflow_safe(self, dtype): + stamp = datetime(2363, 10, 4) # Enterprise-D launch date + if dtype == "timedelta64[ns]": + stamp = stamp - datetime(1970, 1, 1) + arr = np.array([stamp], dtype=object) + + out = lib.maybe_convert_objects(arr, convert_non_numeric=True) + # no OutOfBoundsDatetime/OutOfBoundsTimedeltas + if dtype == "datetime64[ns]": + expected = np.array(["2363-10-04"], dtype="M8[us]") + else: + expected = arr.astype("m8[us]") + tm.assert_numpy_array_equal(out, expected) + + def test_maybe_convert_objects_mixed_datetimes(self): + ts = Timestamp("now") + vals = [ts, ts.to_pydatetime(), ts.to_datetime64(), pd.NaT, np.nan, None] + + for data in itertools.permutations(vals): + data = np.array(list(data), dtype=object) + expected = DatetimeIndex(data)._data._ndarray + result = lib.maybe_convert_objects(data, convert_non_numeric=True) + tm.assert_numpy_array_equal(result, expected) + + def test_maybe_convert_objects_timedelta64_nat(self): + obj = np.timedelta64("NaT", "ns") + arr = np.array([obj], dtype=object) + assert arr[0] is obj + + result = lib.maybe_convert_objects(arr, convert_non_numeric=True) + + expected = np.array([obj], dtype="m8[ns]") + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "exp", + [ + IntegerArray(np.array([2, 0], dtype="i8"), np.array([False, True])), + IntegerArray(np.array([2, 0], dtype="int64"), np.array([False, True])), + ], + ) + def test_maybe_convert_objects_nullable_integer(self, exp): + # GH27335 + arr = np.array([2, np.nan], dtype=object) + result = lib.maybe_convert_objects(arr, convert_to_nullable_dtype=True) + + tm.assert_extension_array_equal(result, exp) + + @pytest.mark.parametrize( + "dtype, val", [("int64", 1), ("uint64", np.iinfo(np.int64).max + 1)] + ) + def test_maybe_convert_objects_nullable_none(self, dtype, val): + # GH#50043 + arr = np.array([val, None, 3], dtype="object") + result = lib.maybe_convert_objects(arr, convert_to_nullable_dtype=True) + expected = IntegerArray( + np.array([val, 0, 3], dtype=dtype), np.array([False, True, False]) + ) + tm.assert_extension_array_equal(result, expected) + + @pytest.mark.parametrize( + "convert_to_masked_nullable, exp", + [ + (True, IntegerArray(np.array([2, 0], dtype="i8"), np.array([False, True]))), + (False, np.array([2, np.nan], dtype="float64")), + ], + ) + def test_maybe_convert_numeric_nullable_integer( + self, convert_to_masked_nullable, exp + ): + # GH 40687 + arr = np.array([2, np.nan], dtype=object) + result = lib.maybe_convert_numeric( + arr, set(), convert_to_masked_nullable=convert_to_masked_nullable + ) + if convert_to_masked_nullable: + result = IntegerArray(*result) + tm.assert_extension_array_equal(result, exp) + else: + result = result[0] + tm.assert_numpy_array_equal(result, exp) + + @pytest.mark.parametrize( + "convert_to_masked_nullable, exp", + [ + ( + True, + FloatingArray( + np.array([2.0, 0.0], dtype="float64"), np.array([False, True]) + ), + ), + (False, np.array([2.0, np.nan], dtype="float64")), + ], + ) + def test_maybe_convert_numeric_floating_array( + self, convert_to_masked_nullable, exp + ): + # GH 40687 + arr = np.array([2.0, np.nan], dtype=object) + result = lib.maybe_convert_numeric( + arr, set(), convert_to_masked_nullable=convert_to_masked_nullable + ) + if convert_to_masked_nullable: + tm.assert_extension_array_equal(FloatingArray(*result), exp) + else: + result = result[0] + tm.assert_numpy_array_equal(result, exp) + + def test_maybe_convert_objects_bool_nan(self): + # GH32146 + ind = Index([True, False, np.nan], dtype=object) + exp = np.array([True, False, np.nan], dtype=object) + out = lib.maybe_convert_objects(ind.values, safe=1) + tm.assert_numpy_array_equal(out, exp) + + def test_maybe_convert_objects_nullable_boolean(self): + # GH50047 + arr = np.array([True, False], dtype=object) + exp = BooleanArray._from_sequence([True, False], dtype="boolean") + out = lib.maybe_convert_objects(arr, convert_to_nullable_dtype=True) + tm.assert_extension_array_equal(out, exp) + + arr = np.array([True, False, pd.NaT], dtype=object) + exp = np.array([True, False, pd.NaT], dtype=object) + out = lib.maybe_convert_objects(arr, convert_to_nullable_dtype=True) + tm.assert_numpy_array_equal(out, exp) + + @pytest.mark.parametrize("val", [None, np.nan]) + def test_maybe_convert_objects_nullable_boolean_na(self, val): + # GH50047 + arr = np.array([True, False, val], dtype=object) + exp = BooleanArray( + np.array([True, False, False]), np.array([False, False, True]) + ) + out = lib.maybe_convert_objects(arr, convert_to_nullable_dtype=True) + tm.assert_extension_array_equal(out, exp) + + @pytest.mark.parametrize( + "data0", + [ + True, + 1, + 1.0, + 1.0 + 1.0j, + np.int8(1), + np.int16(1), + np.int32(1), + np.int64(1), + np.float16(1), + np.float32(1), + np.float64(1), + np.complex64(1), + np.complex128(1), + ], + ) + @pytest.mark.parametrize( + "data1", + [ + True, + 1, + 1.0, + 1.0 + 1.0j, + np.int8(1), + np.int16(1), + np.int32(1), + np.int64(1), + np.float16(1), + np.float32(1), + np.float64(1), + np.complex64(1), + np.complex128(1), + ], + ) + def test_maybe_convert_objects_itemsize(self, data0, data1): + # GH 40908 + data = [data0, data1] + arr = np.array(data, dtype="object") + + common_kind = np.result_type(type(data0), type(data1)).kind + kind0 = "python" if not hasattr(data0, "dtype") else data0.dtype.kind + kind1 = "python" if not hasattr(data1, "dtype") else data1.dtype.kind + if kind0 != "python" and kind1 != "python": + kind = common_kind + itemsize = max(data0.dtype.itemsize, data1.dtype.itemsize) + elif is_bool(data0) or is_bool(data1): + kind = "bool" if (is_bool(data0) and is_bool(data1)) else "object" + itemsize = "" + elif is_complex(data0) or is_complex(data1): + kind = common_kind + itemsize = 16 + else: + kind = common_kind + itemsize = 8 + + expected = np.array(data, dtype=f"{kind}{itemsize}") + result = lib.maybe_convert_objects(arr) + tm.assert_numpy_array_equal(result, expected) + + def test_mixed_dtypes_remain_object_array(self): + # GH14956 + arr = np.array([datetime(2015, 1, 1, tzinfo=timezone.utc), 1], dtype=object) + result = lib.maybe_convert_objects(arr, convert_non_numeric=True) + tm.assert_numpy_array_equal(result, arr) + + @pytest.mark.parametrize( + "idx", + [ + pd.IntervalIndex.from_breaks(range(5), closed="both"), + pd.period_range("2016-01-01", periods=3, freq="D"), + ], + ) + def test_maybe_convert_objects_ea(self, idx): + result = lib.maybe_convert_objects( + np.array(idx, dtype=object), + convert_non_numeric=True, + ) + tm.assert_extension_array_equal(result, idx._data) + + +class TestTypeInference: + # Dummy class used for testing with Python objects + class Dummy: + pass + + def test_inferred_dtype_fixture(self, any_skipna_inferred_dtype): + # see pandas/conftest.py + inferred_dtype, values = any_skipna_inferred_dtype + + # make sure the inferred dtype of the fixture is as requested + assert inferred_dtype == lib.infer_dtype(values, skipna=True) + + def test_length_zero(self, skipna): + result = lib.infer_dtype(np.array([], dtype="i4"), skipna=skipna) + assert result == "integer" + + result = lib.infer_dtype([], skipna=skipna) + assert result == "empty" + + # GH 18004 + arr = np.array([np.array([], dtype=object), np.array([], dtype=object)]) + result = lib.infer_dtype(arr, skipna=skipna) + assert result == "empty" + + def test_integers(self): + arr = np.array([1, 2, 3, np.int64(4), np.int32(5)], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "integer" + + arr = np.array([1, 2, 3, np.int64(4), np.int32(5), "foo"], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "mixed-integer" + + arr = np.array([1, 2, 3, 4, 5], dtype="i4") + result = lib.infer_dtype(arr, skipna=True) + assert result == "integer" + + @pytest.mark.parametrize( + "arr, skipna", + [ + ([1, 2, np.nan, np.nan, 3], False), + ([1, 2, np.nan, np.nan, 3], True), + ([1, 2, 3, np.int64(4), np.int32(5), np.nan], False), + ([1, 2, 3, np.int64(4), np.int32(5), np.nan], True), + ], + ) + def test_integer_na(self, arr, skipna): + # GH 27392 + result = lib.infer_dtype(np.array(arr, dtype="O"), skipna=skipna) + expected = "integer" if skipna else "integer-na" + assert result == expected + + def test_infer_dtype_skipna_default(self): + # infer_dtype `skipna` default deprecated in GH#24050, + # changed to True in GH#29876 + arr = np.array([1, 2, 3, np.nan], dtype=object) + + result = lib.infer_dtype(arr) + assert result == "integer" + + def test_bools(self): + arr = np.array([True, False, True, True, True], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "boolean" + + arr = np.array([np.bool_(True), np.bool_(False)], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "boolean" + + arr = np.array([True, False, True, "foo"], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "mixed" + + arr = np.array([True, False, True], dtype=bool) + result = lib.infer_dtype(arr, skipna=True) + assert result == "boolean" + + arr = np.array([True, np.nan, False], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "boolean" + + result = lib.infer_dtype(arr, skipna=False) + assert result == "mixed" + + def test_floats(self): + arr = np.array([1.0, 2.0, 3.0, np.float64(4), np.float32(5)], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "floating" + + arr = np.array([1, 2, 3, np.float64(4), np.float32(5), "foo"], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "mixed-integer" + + arr = np.array([1, 2, 3, 4, 5], dtype="f4") + result = lib.infer_dtype(arr, skipna=True) + assert result == "floating" + + arr = np.array([1, 2, 3, 4, 5], dtype="f8") + result = lib.infer_dtype(arr, skipna=True) + assert result == "floating" + + def test_decimals(self): + # GH15690 + arr = np.array([Decimal(1), Decimal(2), Decimal(3)]) + result = lib.infer_dtype(arr, skipna=True) + assert result == "decimal" + + arr = np.array([1.0, 2.0, Decimal(3)]) + result = lib.infer_dtype(arr, skipna=True) + assert result == "mixed" + + result = lib.infer_dtype(arr[::-1], skipna=True) + assert result == "mixed" + + arr = np.array([Decimal(1), Decimal("NaN"), Decimal(3)]) + result = lib.infer_dtype(arr, skipna=True) + assert result == "decimal" + + arr = np.array([Decimal(1), np.nan, Decimal(3)], dtype="O") + result = lib.infer_dtype(arr, skipna=True) + assert result == "decimal" + + # complex is compatible with nan, so skipna has no effect + def test_complex(self, skipna): + # gets cast to complex on array construction + arr = np.array([1.0, 2.0, 1 + 1j]) + result = lib.infer_dtype(arr, skipna=skipna) + assert result == "complex" + + arr = np.array([1.0, 2.0, 1 + 1j], dtype="O") + result = lib.infer_dtype(arr, skipna=skipna) + assert result == "mixed" + + result = lib.infer_dtype(arr[::-1], skipna=skipna) + assert result == "mixed" + + # gets cast to complex on array construction + arr = np.array([1, np.nan, 1 + 1j]) + result = lib.infer_dtype(arr, skipna=skipna) + assert result == "complex" + + arr = np.array([1.0, np.nan, 1 + 1j], dtype="O") + result = lib.infer_dtype(arr, skipna=skipna) + assert result == "mixed" + + # complex with nans stays complex + arr = np.array([1 + 1j, np.nan, 3 + 3j], dtype="O") + result = lib.infer_dtype(arr, skipna=skipna) + assert result == "complex" + + # test smaller complex dtype; will pass through _try_infer_map fastpath + arr = np.array([1 + 1j, np.nan, 3 + 3j], dtype=np.complex64) + result = lib.infer_dtype(arr, skipna=skipna) + assert result == "complex" + + def test_string(self): + pass + + def test_unicode(self): + arr = ["a", np.nan, "c"] + result = lib.infer_dtype(arr, skipna=False) + # This currently returns "mixed", but it's not clear that's optimal. + # This could also return "string" or "mixed-string" + assert result == "mixed" + + # even though we use skipna, we are only skipping those NAs that are + # considered matching by is_string_array + arr = ["a", np.nan, "c"] + result = lib.infer_dtype(arr, skipna=True) + assert result == "string" + + arr = ["a", pd.NA, "c"] + result = lib.infer_dtype(arr, skipna=True) + assert result == "string" + + arr = ["a", pd.NaT, "c"] + result = lib.infer_dtype(arr, skipna=True) + assert result == "mixed" + + arr = ["a", "c"] + result = lib.infer_dtype(arr, skipna=False) + assert result == "string" + + @pytest.mark.parametrize( + "dtype, missing, skipna, expected", + [ + (float, np.nan, False, "floating"), + (float, np.nan, True, "floating"), + (object, np.nan, False, "floating"), + (object, np.nan, True, "empty"), + (object, None, False, "mixed"), + (object, None, True, "empty"), + ], + ) + @pytest.mark.parametrize("box", [Series, np.array]) + def test_object_empty(self, box, missing, dtype, skipna, expected): + # GH 23421 + arr = box([missing, missing], dtype=dtype) + + result = lib.infer_dtype(arr, skipna=skipna) + assert result == expected + + def test_datetime(self): + dates = [datetime(2012, 1, x) for x in range(1, 20)] + index = Index(dates) + assert index.inferred_type == "datetime64" + + def test_infer_dtype_datetime64(self): + arr = np.array( + [np.datetime64("2011-01-01"), np.datetime64("2011-01-01")], dtype=object + ) + assert lib.infer_dtype(arr, skipna=True) == "datetime64" + + @pytest.mark.parametrize("na_value", [pd.NaT, np.nan]) + def test_infer_dtype_datetime64_with_na(self, na_value): + # starts with nan + arr = np.array([na_value, np.datetime64("2011-01-02")]) + assert lib.infer_dtype(arr, skipna=True) == "datetime64" + + arr = np.array([na_value, np.datetime64("2011-01-02"), na_value]) + assert lib.infer_dtype(arr, skipna=True) == "datetime64" + + @pytest.mark.parametrize( + "arr", + [ + np.array( + [np.timedelta64("nat"), np.datetime64("2011-01-02")], dtype=object + ), + np.array( + [np.datetime64("2011-01-02"), np.timedelta64("nat")], dtype=object + ), + np.array([np.datetime64("2011-01-01"), Timestamp("2011-01-02")]), + np.array([Timestamp("2011-01-02"), np.datetime64("2011-01-01")]), + np.array([np.nan, Timestamp("2011-01-02"), 1.1]), + np.array([np.nan, "2011-01-01", Timestamp("2011-01-02")], dtype=object), + np.array([np.datetime64("nat"), np.timedelta64(1, "D")], dtype=object), + np.array([np.timedelta64(1, "D"), np.datetime64("nat")], dtype=object), + ], + ) + def test_infer_datetimelike_dtype_mixed(self, arr): + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + def test_infer_dtype_mixed_integer(self): + arr = np.array([np.nan, Timestamp("2011-01-02"), 1]) + assert lib.infer_dtype(arr, skipna=True) == "mixed-integer" + + @pytest.mark.parametrize( + "arr", + [ + [Timestamp("2011-01-01"), Timestamp("2011-01-02")], + [datetime(2011, 1, 1), datetime(2012, 2, 1)], + [datetime(2011, 1, 1), Timestamp("2011-01-02")], + ], + ) + def test_infer_dtype_datetime(self, arr): + assert lib.infer_dtype(np.array(arr), skipna=True) == "datetime" + + @pytest.mark.parametrize("na_value", [pd.NaT, np.nan]) + @pytest.mark.parametrize( + "time_stamp", [Timestamp("2011-01-01"), datetime(2011, 1, 1)] + ) + def test_infer_dtype_datetime_with_na(self, na_value, time_stamp): + # starts with nan + arr = np.array([na_value, time_stamp]) + assert lib.infer_dtype(arr, skipna=True) == "datetime" + + arr = np.array([na_value, time_stamp, na_value]) + assert lib.infer_dtype(arr, skipna=True) == "datetime" + + @pytest.mark.parametrize( + "arr", + [ + np.array([Timedelta("1 days"), Timedelta("2 days")]), + np.array([np.timedelta64(1, "D"), np.timedelta64(2, "D")], dtype=object), + np.array([timedelta(1), timedelta(2)]), + ], + ) + def test_infer_dtype_timedelta(self, arr): + assert lib.infer_dtype(arr, skipna=True) == "timedelta" + + @pytest.mark.parametrize("na_value", [pd.NaT, np.nan]) + @pytest.mark.parametrize( + "delta", [Timedelta("1 days"), np.timedelta64(1, "D"), timedelta(1)] + ) + def test_infer_dtype_timedelta_with_na(self, na_value, delta): + # starts with nan + arr = np.array([na_value, delta]) + assert lib.infer_dtype(arr, skipna=True) == "timedelta" + + arr = np.array([na_value, delta, na_value]) + assert lib.infer_dtype(arr, skipna=True) == "timedelta" + + def test_infer_dtype_period(self): + # GH 13664 + arr = np.array([Period("2011-01", freq="D"), Period("2011-02", freq="D")]) + assert lib.infer_dtype(arr, skipna=True) == "period" + + # non-homogeneous freqs -> mixed + arr = np.array([Period("2011-01", freq="D"), Period("2011-02", freq="M")]) + assert lib.infer_dtype(arr, skipna=True) == "mixed" + + def test_infer_dtype_period_array(self, index_or_series_or_array, skipna): + klass = index_or_series_or_array + # https://github.com/pandas-dev/pandas/issues/23553 + values = klass( + [ + Period("2011-01-01", freq="D"), + Period("2011-01-02", freq="D"), + pd.NaT, + ] + ) + assert lib.infer_dtype(values, skipna=skipna) == "period" + + # periods but mixed freq + values = klass( + [ + Period("2011-01-01", freq="D"), + Period("2011-01-02", freq="M"), + pd.NaT, + ] + ) + # with pd.array this becomes NumpyExtensionArray which ends up + # as "unknown-array" + exp = "unknown-array" if klass is pd.array else "mixed" + assert lib.infer_dtype(values, skipna=skipna) == exp + + def test_infer_dtype_period_mixed(self): + arr = np.array( + [Period("2011-01", freq="M"), np.datetime64("nat")], dtype=object + ) + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + arr = np.array( + [np.datetime64("nat"), Period("2011-01", freq="M")], dtype=object + ) + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + @pytest.mark.parametrize("na_value", [pd.NaT, np.nan]) + def test_infer_dtype_period_with_na(self, na_value): + # starts with nan + arr = np.array([na_value, Period("2011-01", freq="D")]) + assert lib.infer_dtype(arr, skipna=True) == "period" + + arr = np.array([na_value, Period("2011-01", freq="D"), na_value]) + assert lib.infer_dtype(arr, skipna=True) == "period" + + @pytest.mark.parametrize("na_value", [pd.NA, np.nan]) + def test_infer_dtype_numeric_with_na(self, na_value): + # GH61621 + ser = Series([1, 2, na_value], dtype=object) + assert lib.infer_dtype(ser, skipna=True) == "integer" + + ser = Series([1.0, 2.0, na_value], dtype=object) + assert lib.infer_dtype(ser, skipna=True) == "floating" + + # GH#61976 + ser = Series([1 + 1j, na_value], dtype=object) + assert lib.infer_dtype(ser, skipna=True) == "complex" + + def test_infer_dtype_all_nan_nat_like(self): + arr = np.array([np.nan, np.nan]) + assert lib.infer_dtype(arr, skipna=True) == "floating" + + # nan and None mix are result in mixed + arr = np.array([np.nan, np.nan, None]) + assert lib.infer_dtype(arr, skipna=True) == "empty" + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + arr = np.array([None, np.nan, np.nan]) + assert lib.infer_dtype(arr, skipna=True) == "empty" + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + # pd.NaT + arr = np.array([pd.NaT]) + assert lib.infer_dtype(arr, skipna=False) == "datetime" + + arr = np.array([pd.NaT, np.nan]) + assert lib.infer_dtype(arr, skipna=False) == "datetime" + + arr = np.array([np.nan, pd.NaT]) + assert lib.infer_dtype(arr, skipna=False) == "datetime" + + arr = np.array([np.nan, pd.NaT, np.nan]) + assert lib.infer_dtype(arr, skipna=False) == "datetime" + + arr = np.array([None, pd.NaT, None]) + assert lib.infer_dtype(arr, skipna=False) == "datetime" + + # np.datetime64(nat) + arr = np.array([np.datetime64("nat")]) + assert lib.infer_dtype(arr, skipna=False) == "datetime64" + + for n in [np.nan, pd.NaT, None]: + arr = np.array([n, np.datetime64("nat"), n]) + assert lib.infer_dtype(arr, skipna=False) == "datetime64" + + arr = np.array([pd.NaT, n, np.datetime64("nat"), n]) + assert lib.infer_dtype(arr, skipna=False) == "datetime64" + + arr = np.array([np.timedelta64("nat")], dtype=object) + assert lib.infer_dtype(arr, skipna=False) == "timedelta" + + for n in [np.nan, pd.NaT, None]: + arr = np.array([n, np.timedelta64("nat"), n]) + assert lib.infer_dtype(arr, skipna=False) == "timedelta" + + arr = np.array([pd.NaT, n, np.timedelta64("nat"), n]) + assert lib.infer_dtype(arr, skipna=False) == "timedelta" + + # datetime / timedelta mixed + arr = np.array([pd.NaT, np.datetime64("nat"), np.timedelta64("nat"), np.nan]) + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + arr = np.array([np.timedelta64("nat"), np.datetime64("nat")], dtype=object) + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + def test_is_datetimelike_array_all_nan_nat_like(self): + arr = np.array([np.nan, pd.NaT, np.datetime64("nat")]) + assert lib.is_datetime_array(arr) + assert lib.is_datetime64_array(arr) + assert not lib.is_timedelta_or_timedelta64_array(arr) + + arr = np.array([np.nan, pd.NaT, np.timedelta64("nat")]) + assert not lib.is_datetime_array(arr) + assert not lib.is_datetime64_array(arr) + assert lib.is_timedelta_or_timedelta64_array(arr) + + arr = np.array([np.nan, pd.NaT, np.datetime64("nat"), np.timedelta64("nat")]) + assert not lib.is_datetime_array(arr) + assert not lib.is_datetime64_array(arr) + assert not lib.is_timedelta_or_timedelta64_array(arr) + + arr = np.array([np.nan, pd.NaT]) + assert lib.is_datetime_array(arr) + assert lib.is_datetime64_array(arr) + assert lib.is_timedelta_or_timedelta64_array(arr) + + arr = np.array([np.nan, np.nan], dtype=object) + assert not lib.is_datetime_array(arr) + assert not lib.is_datetime64_array(arr) + assert not lib.is_timedelta_or_timedelta64_array(arr) + + assert lib.is_datetime_with_singletz_array( + np.array( + [ + Timestamp("20130101", tz="US/Eastern"), + Timestamp("20130102", tz="US/Eastern"), + ], + dtype=object, + ) + ) + assert not lib.is_datetime_with_singletz_array( + np.array( + [ + Timestamp("20130101", tz="US/Eastern"), + Timestamp("20130102", tz="CET"), + ], + dtype=object, + ) + ) + + @pytest.mark.parametrize( + "func", + [ + "is_datetime_array", + "is_datetime64_array", + "is_bool_array", + "is_timedelta_or_timedelta64_array", + "is_date_array", + "is_time_array", + "is_interval_array", + ], + ) + def test_other_dtypes_for_array(self, func): + func = getattr(lib, func) + arr = np.array(["foo", "bar"]) + assert not func(arr) + assert not func(arr.reshape(2, 1)) + + arr = np.array([1, 2]) + assert not func(arr) + assert not func(arr.reshape(2, 1)) + + def test_date(self): + dates = [date(2012, 1, day) for day in range(1, 20)] + index = Index(dates) + assert index.inferred_type == "date" + + dates = [date(2012, 1, day) for day in range(1, 20)] + [np.nan] + result = lib.infer_dtype(dates, skipna=False) + assert result == "mixed" + + result = lib.infer_dtype(dates, skipna=True) + assert result == "date" + + @pytest.mark.parametrize( + "values", + [ + [date(2020, 1, 1), Timestamp("2020-01-01")], + [Timestamp("2020-01-01"), date(2020, 1, 1)], + [date(2020, 1, 1), pd.NaT], + [pd.NaT, date(2020, 1, 1)], + ], + ) + def test_infer_dtype_date_order_invariant(self, values, skipna): + # https://github.com/pandas-dev/pandas/issues/33741 + result = lib.infer_dtype(values, skipna=skipna) + assert result == "date" + + def test_is_numeric_array(self): + assert lib.is_float_array(np.array([1, 2.0])) + assert lib.is_float_array(np.array([1, 2.0, np.nan])) + assert not lib.is_float_array(np.array([1, 2])) + + assert lib.is_integer_array(np.array([1, 2])) + assert not lib.is_integer_array(np.array([1, 2.0])) + + def test_is_string_array(self): + # We should only be accepting pd.NA, np.nan, + # other floating point nans e.g. float('nan')] + # when skipna is True. + assert lib.is_string_array(np.array(["foo", "bar"])) + assert not lib.is_string_array( + np.array(["foo", "bar", pd.NA], dtype=object), skipna=False + ) + assert lib.is_string_array( + np.array(["foo", "bar", pd.NA], dtype=object), skipna=True + ) + # we allow NaN/None in the StringArray constructor, so its allowed here + assert lib.is_string_array( + np.array(["foo", "bar", None], dtype=object), skipna=True + ) + assert lib.is_string_array( + np.array(["foo", "bar", np.nan], dtype=object), skipna=True + ) + # But not e.g. datetimelike or Decimal NAs + assert not lib.is_string_array( + np.array(["foo", "bar", pd.NaT], dtype=object), skipna=True + ) + assert not lib.is_string_array( + np.array(["foo", "bar", np.datetime64("NaT")], dtype=object), skipna=True + ) + assert not lib.is_string_array( + np.array(["foo", "bar", Decimal("NaN")], dtype=object), skipna=True + ) + + assert not lib.is_string_array( + np.array(["foo", "bar", None], dtype=object), skipna=False + ) + assert not lib.is_string_array( + np.array(["foo", "bar", np.nan], dtype=object), skipna=False + ) + assert not lib.is_string_array(np.array([1, 2])) + + def test_is_interval_array_subclass(self): + # GH#46945 + + class TimestampsInterval(Interval): + def __init__(self, left: str, right: str, closed="both") -> None: + super().__init__(Timestamp(left), Timestamp(right), closed) + + @property + def seconds(self) -> float: + return self.length.seconds + + item = TimestampsInterval("1970-01-01 00:00:00", "1970-01-01 00:00:01") + arr = np.array([item], dtype=object) + assert not lib.is_interval_array(arr) + assert lib.infer_dtype(arr) != "interval" + out = Series([item])[0] + assert isinstance(out, TimestampsInterval) + + @pytest.mark.parametrize( + "func", + [ + "is_bool_array", + "is_date_array", + "is_datetime_array", + "is_datetime64_array", + "is_float_array", + "is_integer_array", + "is_interval_array", + "is_string_array", + "is_time_array", + "is_timedelta_or_timedelta64_array", + ], + ) + def test_is_dtype_array_empty_obj(self, func): + # https://github.com/pandas-dev/pandas/pull/60796 + func = getattr(lib, func) + + arr = np.empty((2, 0), dtype=object) + assert not func(arr) + + arr = np.empty((0, 2), dtype=object) + assert not func(arr) + + def test_to_object_array_tuples(self): + r = (5, 6) + values = [r] + lib.to_object_array_tuples(values) + + # make sure record array works + record = namedtuple("record", "x y") + r = record(5, 6) + values = [r] + lib.to_object_array_tuples(values) + + def test_object(self): + # GH 7431 + # cannot infer more than this as only a single element + arr = np.array([None], dtype="O") + result = lib.infer_dtype(arr, skipna=False) + assert result == "mixed" + result = lib.infer_dtype(arr, skipna=True) + assert result == "empty" + + def test_to_object_array_width(self): + # see gh-13320 + rows = [[1, 2, 3], [4, 5, 6]] + + expected = np.array(rows, dtype=object) + out = lib.to_object_array(rows) + tm.assert_numpy_array_equal(out, expected) + + expected = np.array(rows, dtype=object) + out = lib.to_object_array(rows, min_width=1) + tm.assert_numpy_array_equal(out, expected) + + expected = np.array( + [[1, 2, 3, None, None], [4, 5, 6, None, None]], dtype=object + ) + out = lib.to_object_array(rows, min_width=5) + tm.assert_numpy_array_equal(out, expected) + + def test_categorical(self): + # GH 8974 + arr = Categorical(list("abc")) + result = lib.infer_dtype(arr, skipna=True) + assert result == "categorical" + + result = lib.infer_dtype(Series(arr), skipna=True) + assert result == "categorical" + + arr = Categorical([None, None, None], categories=["cegfab"], ordered=True) + result = lib.infer_dtype(arr, skipna=True) + assert result == "categorical" + + result = lib.infer_dtype(Series(arr), skipna=True) + assert result == "categorical" + + @pytest.mark.parametrize("asobject", [True, False]) + def test_interval(self, asobject): + idx = pd.IntervalIndex.from_breaks(range(5), closed="both") + if asobject: + idx = idx.astype(object) + + inferred = lib.infer_dtype(idx, skipna=False) + assert inferred == "interval" + + inferred = lib.infer_dtype(idx._data, skipna=False) + assert inferred == "interval" + + inferred = lib.infer_dtype(Series(idx, dtype=idx.dtype), skipna=False) + assert inferred == "interval" + + @pytest.mark.parametrize("value", [Timestamp(0), Timedelta(0), 0, 0.0]) + def test_interval_mismatched_closed(self, value): + first = Interval(value, value, closed="left") + second = Interval(value, value, closed="right") + + # if closed match, we should infer "interval" + arr = np.array([first, first], dtype=object) + assert lib.infer_dtype(arr, skipna=False) == "interval" + + # if closed dont match, we should _not_ get "interval" + arr2 = np.array([first, second], dtype=object) + assert lib.infer_dtype(arr2, skipna=False) == "mixed" + + def test_interval_mismatched_subtype(self): + first = Interval(0, 1, closed="left") + second = Interval(Timestamp(0), Timestamp(1), closed="left") + third = Interval(Timedelta(0), Timedelta(1), closed="left") + + arr = np.array([first, second]) + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + arr = np.array([second, third]) + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + arr = np.array([first, third]) + assert lib.infer_dtype(arr, skipna=False) == "mixed" + + # float vs int subdtype are compatible + flt_interval = Interval(1.5, 2.5, closed="left") + arr = np.array([first, flt_interval], dtype=object) + assert lib.infer_dtype(arr, skipna=False) == "interval" + + @pytest.mark.parametrize("data", [["a", "b", "c"], ["a", "b", pd.NA]]) + def test_string_dtype( + self, data, skipna, index_or_series_or_array, nullable_string_dtype + ): + # StringArray + val = index_or_series_or_array(data, dtype=nullable_string_dtype) + inferred = lib.infer_dtype(val, skipna=skipna) + assert inferred == "string" + + @pytest.mark.parametrize("data", [[True, False, True], [True, False, pd.NA]]) + def test_boolean_dtype(self, data, skipna, index_or_series_or_array): + # BooleanArray + val = index_or_series_or_array(data, dtype="boolean") + inferred = lib.infer_dtype(val, skipna=skipna) + assert inferred == "boolean" + + +class TestNumberScalar: + def test_is_number(self): + assert is_number(True) + assert is_number(1) + assert is_number(1.1) + assert is_number(1 + 3j) + assert is_number(np.int64(1)) + assert is_number(np.float64(1.1)) + assert is_number(np.complex128(1 + 3j)) + assert is_number(np.nan) + + assert not is_number(None) + assert not is_number("x") + assert not is_number(datetime(2011, 1, 1)) + assert not is_number(np.datetime64("2011-01-01")) + assert not is_number(Timestamp("2011-01-01")) + assert not is_number(Timestamp("2011-01-01", tz="US/Eastern")) + assert not is_number(timedelta(1000)) + assert not is_number(Timedelta("1 days")) + + # questionable + assert not is_number(np.bool_(False)) + assert is_number(np.timedelta64(1, "D")) + + def test_is_bool(self): + assert is_bool(True) + assert is_bool(False) + assert is_bool(np.bool_(False)) + + assert not is_bool(1) + assert not is_bool(1.1) + assert not is_bool(1 + 3j) + assert not is_bool(np.int64(1)) + assert not is_bool(np.float64(1.1)) + assert not is_bool(np.complex128(1 + 3j)) + assert not is_bool(np.nan) + assert not is_bool(None) + assert not is_bool("x") + assert not is_bool(datetime(2011, 1, 1)) + assert not is_bool(np.datetime64("2011-01-01")) + assert not is_bool(Timestamp("2011-01-01")) + assert not is_bool(Timestamp("2011-01-01", tz="US/Eastern")) + assert not is_bool(timedelta(1000)) + assert not is_bool(np.timedelta64(1, "D")) + assert not is_bool(Timedelta("1 days")) + + def test_is_integer(self): + assert is_integer(1) + assert is_integer(np.int64(1)) + + assert not is_integer(True) + assert not is_integer(1.1) + assert not is_integer(1 + 3j) + assert not is_integer(False) + assert not is_integer(np.bool_(False)) + assert not is_integer(np.float64(1.1)) + assert not is_integer(np.complex128(1 + 3j)) + assert not is_integer(np.nan) + assert not is_integer(None) + assert not is_integer("x") + assert not is_integer(datetime(2011, 1, 1)) + assert not is_integer(np.datetime64("2011-01-01")) + assert not is_integer(Timestamp("2011-01-01")) + assert not is_integer(Timestamp("2011-01-01", tz="US/Eastern")) + assert not is_integer(timedelta(1000)) + assert not is_integer(Timedelta("1 days")) + assert not is_integer(np.timedelta64(1, "D")) + + def test_is_float(self): + assert is_float(1.1) + assert is_float(np.float64(1.1)) + assert is_float(np.nan) + + assert not is_float(True) + assert not is_float(1) + assert not is_float(1 + 3j) + assert not is_float(False) + assert not is_float(np.bool_(False)) + assert not is_float(np.int64(1)) + assert not is_float(np.complex128(1 + 3j)) + assert not is_float(None) + assert not is_float("x") + assert not is_float(datetime(2011, 1, 1)) + assert not is_float(np.datetime64("2011-01-01")) + assert not is_float(Timestamp("2011-01-01")) + assert not is_float(Timestamp("2011-01-01", tz="US/Eastern")) + assert not is_float(timedelta(1000)) + assert not is_float(np.timedelta64(1, "D")) + assert not is_float(Timedelta("1 days")) + + def test_is_datetime_dtypes(self): + ts = pd.date_range("20130101", periods=3, unit="ns") + tsa = pd.date_range("20130101", periods=3, tz="US/Eastern", unit="ns") + + msg = "is_datetime64tz_dtype is deprecated" + + assert is_datetime64_dtype("datetime64") + assert is_datetime64_dtype("datetime64[ns]") + assert is_datetime64_dtype(ts) + assert not is_datetime64_dtype(tsa) + + assert not is_datetime64_ns_dtype("datetime64") + assert is_datetime64_ns_dtype("datetime64[ns]") + assert is_datetime64_ns_dtype(ts) + assert is_datetime64_ns_dtype(tsa) + + assert is_datetime64_any_dtype("datetime64") + assert is_datetime64_any_dtype("datetime64[ns]") + assert is_datetime64_any_dtype(ts) + assert is_datetime64_any_dtype(tsa) + + with tm.assert_produces_warning(Pandas4Warning, match=msg): + assert not is_datetime64tz_dtype("datetime64") + assert not is_datetime64tz_dtype("datetime64[ns]") + assert not is_datetime64tz_dtype(ts) + assert is_datetime64tz_dtype(tsa) + + @pytest.mark.parametrize("tz", ["US/Eastern", "UTC"]) + def test_is_datetime_dtypes_with_tz(self, tz): + dtype = f"datetime64[ns, {tz}]" + assert not is_datetime64_dtype(dtype) + + msg = "is_datetime64tz_dtype is deprecated" + with tm.assert_produces_warning(DeprecationWarning, match=msg): + assert is_datetime64tz_dtype(dtype) + assert is_datetime64_ns_dtype(dtype) + assert is_datetime64_any_dtype(dtype) + + def test_is_timedelta(self): + assert is_timedelta64_dtype("timedelta64") + assert is_timedelta64_dtype("timedelta64[ns]") + assert not is_timedelta64_ns_dtype("timedelta64") + assert is_timedelta64_ns_dtype("timedelta64[ns]") + + tdi = TimedeltaIndex([1e14, 2e14], dtype="timedelta64[ns]") + assert is_timedelta64_dtype(tdi) + assert is_timedelta64_ns_dtype(tdi) + assert is_timedelta64_ns_dtype(tdi.astype("timedelta64[ns]")) + + assert not is_timedelta64_ns_dtype(Index([], dtype=np.float64)) + assert not is_timedelta64_ns_dtype(Index([], dtype=np.int64)) + + +class TestIsScalar: + def test_is_scalar_builtin_scalars(self): + assert is_scalar(None) + assert is_scalar(True) + assert is_scalar(False) + assert is_scalar(Fraction()) + assert is_scalar(0.0) + assert is_scalar(1) + assert is_scalar(complex(2)) + assert is_scalar(float("NaN")) + assert is_scalar(np.nan) + assert is_scalar("foobar") + assert is_scalar(b"foobar") + assert is_scalar(datetime(2014, 1, 1)) + assert is_scalar(date(2014, 1, 1)) + assert is_scalar(time(12, 0)) + assert is_scalar(timedelta(hours=1)) + assert is_scalar(pd.NaT) + assert is_scalar(pd.NA) + + def test_is_scalar_builtin_nonscalars(self): + assert not is_scalar({}) + assert not is_scalar([]) + assert not is_scalar([1]) + assert not is_scalar(()) + assert not is_scalar((1,)) + assert not is_scalar(slice(None)) + assert not is_scalar(Ellipsis) + + def test_is_scalar_numpy_array_scalars(self): + assert is_scalar(np.int64(1)) + assert is_scalar(np.float64(1.0)) + assert is_scalar(np.int32(1)) + assert is_scalar(np.complex64(2)) + assert is_scalar(np.object_("foobar")) + assert is_scalar(np.str_("foobar")) + assert is_scalar(np.bytes_(b"foobar")) + assert is_scalar(np.datetime64("2014-01-01")) + assert is_scalar(np.timedelta64(1, "h")) + + @pytest.mark.parametrize( + "zerodim", + [ + 1, + "foobar", + np.datetime64("2014-01-01"), + np.timedelta64(1, "h"), + np.datetime64("NaT"), + ], + ) + def test_is_scalar_numpy_zerodim_arrays(self, zerodim): + zerodim = np.array(zerodim) + assert not is_scalar(zerodim) + assert is_scalar(lib.item_from_zerodim(zerodim)) + + @pytest.mark.parametrize("arr", [np.array([]), np.array([[]])]) + def test_is_scalar_numpy_arrays(self, arr): + assert not is_scalar(arr) + assert not is_scalar(MockNumpyLikeArray(arr)) + + def test_is_scalar_pandas_scalars(self): + assert is_scalar(Timestamp("2014-01-01")) + assert is_scalar(Timedelta(hours=1)) + assert is_scalar(Period("2014-01-01")) + assert is_scalar(Interval(left=0, right=1)) + assert is_scalar(DateOffset(days=1)) + assert is_scalar(pd.offsets.Minute(3)) + + def test_is_scalar_pandas_containers(self): + assert not is_scalar(Series(dtype=object)) + assert not is_scalar(Series([1])) + assert not is_scalar(DataFrame()) + assert not is_scalar(DataFrame([[1]])) + assert not is_scalar(Index([])) + assert not is_scalar(Index([1])) + assert not is_scalar(Categorical([])) + assert not is_scalar(DatetimeIndex([])._data) + assert not is_scalar(TimedeltaIndex([])._data) + assert not is_scalar(DatetimeIndex([])._data.to_period("D")) + assert not is_scalar(pd.array([1, 2, 3])) + + def test_is_scalar_number(self): + # Number() is not recognized by PyNumber_Check, so by extension + # is not recognized by is_scalar, but instances of non-abstract + # subclasses are. + + class Numeric(Number): + def __init__(self, value) -> None: + self.value = value + + def __int__(self) -> int: + return self.value + + num = Numeric(1) + assert is_scalar(num) + + +@pytest.mark.parametrize("unit", ["ms", "us", "ns"]) +def test_datetimeindex_from_empty_datetime64_array(unit): + idx = DatetimeIndex(np.array([], dtype=f"datetime64[{unit}]")) + assert len(idx) == 0 + + +def test_nan_to_nat_conversions(): + df = DataFrame( + {"A": np.asarray(range(10), dtype="float64"), "B": Timestamp("20010101")} + ) + df.iloc[3:6, :] = np.nan + result = df.loc[4, "B"] + assert result is pd.NaT + + s = df["B"].copy() + s[8:9] = np.nan + assert s[8] is pd.NaT + + +@pytest.mark.filterwarnings("ignore::PendingDeprecationWarning") +@pytest.mark.parametrize("spmatrix", ["bsr", "coo", "csc", "csr", "dia", "dok", "lil"]) +def test_is_scipy_sparse(spmatrix): + sparse = pytest.importorskip("scipy.sparse") + + klass = getattr(sparse, spmatrix + "_matrix") + assert is_scipy_sparse(klass([[0, 1]])) + assert not is_scipy_sparse(np.array([1])) + + +def test_ensure_int32(): + values = np.arange(10, dtype=np.int32) + result = ensure_int32(values) + assert result.dtype == np.int32 + + values = np.arange(10, dtype=np.int64) + result = ensure_int32(values) + assert result.dtype == np.int32 + + +@pytest.mark.parametrize( + "right,result", + [ + (0, np.uint8), + (-1, np.int16), + (300, np.uint16), + # For floats, we just upcast directly to float64 instead of trying to + # find a smaller floating dtype + (300.0, np.uint16), # for integer floats, we convert them to ints + (300.1, np.float64), + (np.int16(300), np.int16 if np_version_gt2 else np.uint16), + ], +) +def test_find_result_type_uint_int(right, result): + left_dtype = np.dtype("uint8") + assert find_result_type(left_dtype, right) == result + + +@pytest.mark.parametrize( + "right,result", + [ + (0, np.int8), + (-1, np.int8), + (300, np.int16), + # For floats, we just upcast directly to float64 instead of trying to + # find a smaller floating dtype + (300.0, np.int16), # for integer floats, we convert them to ints + (300.1, np.float64), + (np.int16(300), np.int16), + ], +) +def test_find_result_type_int_int(right, result): + left_dtype = np.dtype("int8") + assert find_result_type(left_dtype, right) == result + + +@pytest.mark.parametrize( + "right,result", + [ + (300.0, np.float64), + (np.float32(300), np.float32), + ], +) +def test_find_result_type_floats(right, result): + left_dtype = np.dtype("float16") + assert find_result_type(left_dtype, right) == result diff --git a/pandas/tests/dtypes/test_missing.py b/pandas/tests/dtypes/test_missing.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b22ac30d820b4d1c2178bf3a532fafc9f917e3 --- /dev/null +++ b/pandas/tests/dtypes/test_missing.py @@ -0,0 +1,876 @@ +from datetime import datetime +from decimal import Decimal + +import numpy as np +import pytest + +from pandas._libs import missing as libmissing +from pandas._libs.tslibs import iNaT + +from pandas.core.dtypes.common import ( + is_float, + is_scalar, + pandas_dtype, +) +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + DatetimeTZDtype, + IntervalDtype, + PeriodDtype, +) +from pandas.core.dtypes.missing import ( + array_equivalent, + is_valid_na_for_dtype, + isna, + isnull, + na_value_for_dtype, + notna, + notnull, +) + +import pandas as pd +from pandas import ( + DatetimeIndex, + Index, + NaT, + Series, + TimedeltaIndex, + date_range, + period_range, +) +import pandas._testing as tm + +fix_now = pd.Timestamp("2021-01-01") +fix_utcnow = pd.Timestamp("2021-01-01", tz="UTC") + + +@pytest.mark.parametrize("notna_f", [notna, notnull]) +def test_notna_notnull(notna_f): + assert notna_f(1.0) + assert not notna_f(None) + assert not notna_f(np.nan) + + +@pytest.mark.parametrize("null_func", [notna, notnull, isna, isnull]) +@pytest.mark.parametrize( + "ser", + [ + Series( + [str(i) for i in range(5)], + index=Index([str(i) for i in range(5)], dtype=object), + dtype=object, + ), + Series(range(5), date_range("2020-01-01", periods=5)), + Series(range(5), period_range("2020-01-01", periods=5)), + ], +) +def test_null_check_is_series(null_func, ser): + assert isinstance(null_func(ser), Series) + + +class TestIsNA: + def test_0d_array(self): + assert isna(np.array(np.nan)) + assert not isna(np.array(0.0)) + assert not isna(np.array(0)) + # test object dtype + assert isna(np.array(np.nan, dtype=object)) + assert not isna(np.array(0.0, dtype=object)) + assert not isna(np.array(0, dtype=object)) + + @pytest.mark.parametrize("shape", [(4, 0), (4,)]) + def test_empty_object(self, shape): + arr = np.empty(shape=shape, dtype=object) + result = isna(arr) + expected = np.ones(shape=shape, dtype=bool) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("isna_f", [isna, isnull]) + def test_isna_isnull(self, isna_f): + assert not isna_f(1.0) + assert isna_f(None) + assert isna_f(np.nan) + assert float("nan") + assert not isna_f(np.inf) + assert not isna_f(-np.inf) + + # type + assert not isna_f(type(Series(dtype=object))) + assert not isna_f(type(Series(dtype=np.float64))) + assert not isna_f(type(pd.DataFrame())) + + @pytest.mark.parametrize("isna_f", [isna, isnull]) + @pytest.mark.parametrize( + "data", + [ + np.arange(4, dtype=float), + [0.0, 1.0, 0.0, 1.0], + Series(list("abcd"), dtype=object), + date_range("2020-01-01", periods=4), + ], + ) + @pytest.mark.parametrize( + "index", + [ + date_range("2020-01-01", periods=4), + range(4), + period_range("2020-01-01", periods=4), + ], + ) + def test_isna_isnull_frame(self, isna_f, data, index): + # frame + df = pd.DataFrame(data, index=index) + result = isna_f(df) + expected = df.apply(isna_f) + tm.assert_frame_equal(result, expected) + + def test_isna_lists(self): + result = isna([[False]]) + exp = np.array([[False]]) + tm.assert_numpy_array_equal(result, exp) + + result = isna([[1], [2]]) + exp = np.array([[False], [False]]) + tm.assert_numpy_array_equal(result, exp) + + # list of strings / unicode + result = isna(["foo", "bar"]) + exp = np.array([False, False]) + tm.assert_numpy_array_equal(result, exp) + + result = isna(["foo", "bar"]) + exp = np.array([False, False]) + tm.assert_numpy_array_equal(result, exp) + + # GH20675 + result = isna([np.nan, "world"]) + exp = np.array([True, False]) + tm.assert_numpy_array_equal(result, exp) + + def test_isna_nat(self): + result = isna([NaT]) + exp = np.array([True]) + tm.assert_numpy_array_equal(result, exp) + + result = isna(np.array([NaT], dtype=object)) + exp = np.array([True]) + tm.assert_numpy_array_equal(result, exp) + + def test_isna_numpy_nat(self): + arr = np.array( + [ + NaT, + np.datetime64("NaT"), + np.timedelta64("NaT"), + np.datetime64("NaT", "s"), + ] + ) + result = isna(arr) + expected = np.array([True] * 4) + tm.assert_numpy_array_equal(result, expected) + + def test_isna_datetime(self): + assert not isna(datetime.now()) + assert notna(datetime.now()) + + idx = date_range("1/1/1990", periods=20) + exp = np.ones(len(idx), dtype=bool) + tm.assert_numpy_array_equal(notna(idx), exp) + + idx = np.asarray(idx) + idx[0] = iNaT + idx = DatetimeIndex(idx) + mask = isna(idx) + assert mask[0] + exp = np.array([True] + [False] * (len(idx) - 1), dtype=bool) + tm.assert_numpy_array_equal(mask, exp) + + # GH 9129 + pidx = idx.to_period(freq="M") + mask = isna(pidx) + assert mask[0] + exp = np.array([True] + [False] * (len(idx) - 1), dtype=bool) + tm.assert_numpy_array_equal(mask, exp) + + mask = isna(pidx[1:]) + exp = np.zeros(len(mask), dtype=bool) + tm.assert_numpy_array_equal(mask, exp) + + def test_isna_old_datetimelike(self): + # isna_old should work for dt64tz, td64, and period, not just tznaive + dti = date_range("2016-01-01", periods=3) + dta = dti._data + dta[-1] = NaT + expected = np.array([False, False, True], dtype=bool) + + objs = [dta, dta.tz_localize("US/Eastern"), dta - dta, dta.to_period("D")] + + for obj in objs: + result = isna(obj) + + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "value, expected", + [ + (np.complex128(np.nan), True), + (np.float64(1), False), + (np.array([1, 1 + 0j, np.nan, 3]), np.array([False, False, True, False])), + ( + np.array([1, 1 + 0j, np.nan, 3], dtype=object), + np.array([False, False, True, False]), + ), + ( + np.array([1, 1 + 0j, np.nan, 3]).astype(object), + np.array([False, False, True, False]), + ), + ], + ) + def test_complex(self, value, expected): + result = isna(value) + if is_scalar(result): + assert result is expected + else: + tm.assert_numpy_array_equal(result, expected) + + def test_datetime_other_units(self): + idx = DatetimeIndex(["2011-01-01", "NaT", "2011-01-02"]) + exp = np.array([False, True, False]) + tm.assert_numpy_array_equal(isna(idx), exp) + tm.assert_numpy_array_equal(notna(idx), ~exp) + tm.assert_numpy_array_equal(isna(idx.values), exp) + tm.assert_numpy_array_equal(notna(idx.values), ~exp) + + @pytest.mark.parametrize( + "dtype", + [ + "datetime64[D]", + "datetime64[h]", + "datetime64[m]", + "datetime64[s]", + "datetime64[ms]", + "datetime64[us]", + "datetime64[ns]", + ], + ) + def test_datetime_other_units_astype(self, dtype): + idx = DatetimeIndex(["2011-01-01", "NaT", "2011-01-02"]) + values = idx.values.astype(dtype) + + exp = np.array([False, True, False]) + tm.assert_numpy_array_equal(isna(values), exp) + tm.assert_numpy_array_equal(notna(values), ~exp) + + exp = Series([False, True, False]) + s = Series(values) + tm.assert_series_equal(isna(s), exp) + tm.assert_series_equal(notna(s), ~exp) + s = Series(values, dtype=object) + tm.assert_series_equal(isna(s), exp) + tm.assert_series_equal(notna(s), ~exp) + + def test_timedelta_other_units(self): + idx = TimedeltaIndex(["1 days", "NaT", "2 days"]) + exp = np.array([False, True, False]) + tm.assert_numpy_array_equal(isna(idx), exp) + tm.assert_numpy_array_equal(notna(idx), ~exp) + tm.assert_numpy_array_equal(isna(idx.values), exp) + tm.assert_numpy_array_equal(notna(idx.values), ~exp) + + @pytest.mark.parametrize( + "dtype", + [ + "timedelta64[D]", + "timedelta64[h]", + "timedelta64[m]", + "timedelta64[s]", + "timedelta64[ms]", + "timedelta64[us]", + "timedelta64[ns]", + ], + ) + def test_timedelta_other_units_dtype(self, dtype): + idx = TimedeltaIndex(["1 days", "NaT", "2 days"]) + values = idx.values.astype(dtype) + + exp = np.array([False, True, False]) + tm.assert_numpy_array_equal(isna(values), exp) + tm.assert_numpy_array_equal(notna(values), ~exp) + + exp = Series([False, True, False]) + s = Series(values) + tm.assert_series_equal(isna(s), exp) + tm.assert_series_equal(notna(s), ~exp) + s = Series(values, dtype=object) + tm.assert_series_equal(isna(s), exp) + tm.assert_series_equal(notna(s), ~exp) + + def test_period(self): + idx = pd.PeriodIndex(["2011-01", "NaT", "2012-01"], freq="M") + exp = np.array([False, True, False]) + tm.assert_numpy_array_equal(isna(idx), exp) + tm.assert_numpy_array_equal(notna(idx), ~exp) + + exp = Series([False, True, False]) + s = Series(idx) + tm.assert_series_equal(isna(s), exp) + tm.assert_series_equal(notna(s), ~exp) + s = Series(idx, dtype=object) + tm.assert_series_equal(isna(s), exp) + tm.assert_series_equal(notna(s), ~exp) + + def test_decimal(self): + # scalars GH#23530 + a = Decimal("1.0") + assert isna(a) is False + assert notna(a) is True + + b = Decimal("NaN") + assert isna(b) is True + assert notna(b) is False + + # array + arr = np.array([a, b]) + expected = np.array([False, True]) + result = isna(arr) + tm.assert_numpy_array_equal(result, expected) + + result = notna(arr) + tm.assert_numpy_array_equal(result, ~expected) + + # series + ser = Series(arr) + expected = Series(expected) + result = isna(ser) + tm.assert_series_equal(result, expected) + + result = notna(ser) + tm.assert_series_equal(result, ~expected) + + # index + idx = Index(arr) + expected = np.array([False, True]) + result = isna(idx) + tm.assert_numpy_array_equal(result, expected) + + result = notna(idx) + tm.assert_numpy_array_equal(result, ~expected) + + +@pytest.mark.parametrize("dtype_equal", [True, False]) +def test_array_equivalent(dtype_equal): + assert array_equivalent( + np.array([np.nan, np.nan]), np.array([np.nan, np.nan]), dtype_equal=dtype_equal + ) + assert array_equivalent( + np.array([np.nan, 1, np.nan]), + np.array([np.nan, 1, np.nan]), + dtype_equal=dtype_equal, + ) + assert array_equivalent( + np.array([np.nan, None], dtype="object"), + np.array([np.nan, None], dtype="object"), + dtype_equal=dtype_equal, + ) + # Check the handling of nested arrays in array_equivalent_object + assert array_equivalent( + np.array([np.array([np.nan, None], dtype="object"), None], dtype="object"), + np.array([np.array([np.nan, None], dtype="object"), None], dtype="object"), + dtype_equal=dtype_equal, + ) + assert array_equivalent( + np.array([np.nan, 1 + 1j], dtype="complex"), + np.array([np.nan, 1 + 1j], dtype="complex"), + dtype_equal=dtype_equal, + ) + assert not array_equivalent( + np.array([np.nan, 1 + 1j], dtype="complex"), + np.array([np.nan, 1 + 2j], dtype="complex"), + dtype_equal=dtype_equal, + ) + assert not array_equivalent( + np.array([np.nan, 1, np.nan]), + np.array([np.nan, 2, np.nan]), + dtype_equal=dtype_equal, + ) + assert not array_equivalent( + np.array(["a", "b", "c", "d"]), np.array(["e", "e"]), dtype_equal=dtype_equal + ) + assert array_equivalent( + Index([0, np.nan]), Index([0, np.nan]), dtype_equal=dtype_equal + ) + assert not array_equivalent( + Index([0, np.nan]), Index([1, np.nan]), dtype_equal=dtype_equal + ) + + +@pytest.mark.parametrize("dtype_equal", [True, False]) +def test_array_equivalent_tdi(dtype_equal): + assert array_equivalent( + TimedeltaIndex([0, np.nan]), + TimedeltaIndex([0, np.nan]), + dtype_equal=dtype_equal, + ) + assert not array_equivalent( + TimedeltaIndex([0, np.nan]), + TimedeltaIndex([1, np.nan]), + dtype_equal=dtype_equal, + ) + + +@pytest.mark.parametrize("dtype_equal", [True, False]) +def test_array_equivalent_dti(dtype_equal): + assert array_equivalent( + DatetimeIndex([0, np.nan]), DatetimeIndex([0, np.nan]), dtype_equal=dtype_equal + ) + assert not array_equivalent( + DatetimeIndex([0, np.nan]), DatetimeIndex([1, np.nan]), dtype_equal=dtype_equal + ) + + dti1 = DatetimeIndex([0, np.nan], tz="US/Eastern") + dti2 = DatetimeIndex([0, np.nan], tz="CET") + dti3 = DatetimeIndex([1, np.nan], tz="US/Eastern") + + assert array_equivalent( + dti1, + dti1, + dtype_equal=dtype_equal, + ) + assert not array_equivalent( + dti1, + dti3, + dtype_equal=dtype_equal, + ) + # The rest are not dtype_equal + assert not array_equivalent(DatetimeIndex([0, np.nan]), dti1) + assert array_equivalent( + dti2, + dti1, + ) + + assert not array_equivalent(DatetimeIndex([0, np.nan]), TimedeltaIndex([0, np.nan])) + + +@pytest.mark.parametrize( + "val", [1, 1.1, 1 + 1j, True, "abc", [1, 2], (1, 2), {1, 2}, {"a": 1}, None] +) +def test_array_equivalent_series(val): + arr = np.array([1, 2]) + assert not array_equivalent(Series([arr, arr]), Series([arr, val])) + + +def test_array_equivalent_array_mismatched_shape(): + # to trigger the motivating bug, the first N elements of the arrays need + # to match + first = np.array([1, 2, 3]) + second = np.array([1, 2]) + + left = Series([first, "a"], dtype=object) + right = Series([second, "a"], dtype=object) + assert not array_equivalent(left, right) + + +def test_array_equivalent_array_mismatched_dtype(): + # same shape, different dtype can still be equivalent + first = np.array([1, 2], dtype=np.float64) + second = np.array([1, 2]) + + left = Series([first, "a"], dtype=object) + right = Series([second, "a"], dtype=object) + assert array_equivalent(left, right) + + +def test_array_equivalent_different_dtype_but_equal(): + # Unclear if this is exposed anywhere in the public-facing API + assert array_equivalent(np.array([1, 2]), np.array([1.0, 2.0])) + + +@pytest.mark.parametrize( + "lvalue, rvalue", + [ + # There are 3 variants for each of lvalue and rvalue. We include all + # three for the tz-naive `now` and exclude the datetim64 variant + # for utcnow because it drops tzinfo. + (fix_now, fix_utcnow), + (fix_now.to_datetime64(), fix_utcnow), + (fix_now.to_pydatetime(), fix_utcnow), + (fix_now.to_datetime64(), fix_utcnow.to_pydatetime()), + (fix_now.to_pydatetime(), fix_utcnow.to_pydatetime()), + ], +) +def test_array_equivalent_tzawareness(lvalue, rvalue): + # we shouldn't raise if comparing tzaware and tznaive datetimes + left = np.array([lvalue], dtype=object) + right = np.array([rvalue], dtype=object) + + assert not array_equivalent(left, right, strict_nan=True) + assert not array_equivalent(left, right, strict_nan=False) + + +def test_array_equivalent_compat(): + # see gh-13388 + m = np.array([(1, 2), (3, 4)], dtype=[("a", int), ("b", float)]) + n = np.array([(1, 2), (3, 4)], dtype=[("a", int), ("b", float)]) + assert array_equivalent(m, n, strict_nan=True) + assert array_equivalent(m, n, strict_nan=False) + + m = np.array([(1, 2), (3, 4)], dtype=[("a", int), ("b", float)]) + n = np.array([(1, 2), (4, 3)], dtype=[("a", int), ("b", float)]) + assert not array_equivalent(m, n, strict_nan=True) + assert not array_equivalent(m, n, strict_nan=False) + + m = np.array([(1, 2), (3, 4)], dtype=[("a", int), ("b", float)]) + n = np.array([(1, 2), (3, 4)], dtype=[("b", int), ("a", float)]) + assert not array_equivalent(m, n, strict_nan=True) + assert not array_equivalent(m, n, strict_nan=False) + + +@pytest.mark.parametrize("dtype", ["O", "S", "U"]) +def test_array_equivalent_str(dtype): + assert array_equivalent( + np.array(["A", "B"], dtype=dtype), np.array(["A", "B"], dtype=dtype) + ) + assert not array_equivalent( + np.array(["A", "B"], dtype=dtype), np.array(["A", "X"], dtype=dtype) + ) + + +@pytest.mark.parametrize("strict_nan", [True, False]) +def test_array_equivalent_nested(strict_nan): + # reached in groupby aggregations, make sure we use np.any when checking + # if the comparison is truthy + left = np.array([np.array([50, 70, 90]), np.array([20, 30])], dtype=object) + right = np.array([np.array([50, 70, 90]), np.array([20, 30])], dtype=object) + + assert array_equivalent(left, right, strict_nan=strict_nan) + assert not array_equivalent(left, right[::-1], strict_nan=strict_nan) + + left = np.empty(2, dtype=object) + left[:] = [np.array([50, 70, 90]), np.array([20, 30, 40])] + right = np.empty(2, dtype=object) + right[:] = [np.array([50, 70, 90]), np.array([20, 30, 40])] + assert array_equivalent(left, right, strict_nan=strict_nan) + assert not array_equivalent(left, right[::-1], strict_nan=strict_nan) + + left = np.array([np.array([50, 50, 50]), np.array([40, 40])], dtype=object) + right = np.array([50, 40]) + assert not array_equivalent(left, right, strict_nan=strict_nan) + + +@pytest.mark.filterwarnings("ignore:elementwise comparison failed:DeprecationWarning") +@pytest.mark.parametrize("strict_nan", [True, False]) +def test_array_equivalent_nested2(strict_nan): + # more than one level of nesting + left = np.array( + [ + np.array([np.array([50, 70]), np.array([90])], dtype=object), + np.array([np.array([20, 30])], dtype=object), + ], + dtype=object, + ) + right = np.array( + [ + np.array([np.array([50, 70]), np.array([90])], dtype=object), + np.array([np.array([20, 30])], dtype=object), + ], + dtype=object, + ) + assert array_equivalent(left, right, strict_nan=strict_nan) + assert not array_equivalent(left, right[::-1], strict_nan=strict_nan) + + left = np.array([np.array([np.array([50, 50, 50])], dtype=object)], dtype=object) + right = np.array([50]) + assert not array_equivalent(left, right, strict_nan=strict_nan) + + +@pytest.mark.parametrize("strict_nan", [True, False]) +def test_array_equivalent_nested_list(strict_nan): + left = np.array([[50, 70, 90], [20, 30]], dtype=object) + right = np.array([[50, 70, 90], [20, 30]], dtype=object) + + assert array_equivalent(left, right, strict_nan=strict_nan) + assert not array_equivalent(left, right[::-1], strict_nan=strict_nan) + + left = np.array([[50, 50, 50], [40, 40]], dtype=object) + right = np.array([50, 40]) + assert not array_equivalent(left, right, strict_nan=strict_nan) + + +@pytest.mark.filterwarnings("ignore:elementwise comparison failed:DeprecationWarning") +@pytest.mark.xfail(reason="failing") +@pytest.mark.parametrize("strict_nan", [True, False]) +def test_array_equivalent_nested_mixed_list(strict_nan): + # mixed arrays / lists in left and right + # https://github.com/pandas-dev/pandas/issues/50360 + left = np.array([np.array([1, 2, 3]), np.array([4, 5])], dtype=object) + right = np.array([[1, 2, 3], [4, 5]], dtype=object) + + assert array_equivalent(left, right, strict_nan=strict_nan) + assert not array_equivalent(left, right[::-1], strict_nan=strict_nan) + + # multiple levels of nesting + left = np.array( + [ + np.array([np.array([1, 2, 3]), np.array([4, 5])], dtype=object), + np.array([np.array([6]), np.array([7, 8]), np.array([9])], dtype=object), + ], + dtype=object, + ) + right = np.array([[[1, 2, 3], [4, 5]], [[6], [7, 8], [9]]], dtype=object) + assert array_equivalent(left, right, strict_nan=strict_nan) + assert not array_equivalent(left, right[::-1], strict_nan=strict_nan) + + # same-length lists + subarr = np.empty(2, dtype=object) + subarr[:] = [ + np.array([None, "b"], dtype=object), + np.array(["c", "d"], dtype=object), + ] + left = np.array([subarr, None], dtype=object) + right = np.array([[[None, "b"], ["c", "d"]], None], dtype=object) + assert array_equivalent(left, right, strict_nan=strict_nan) + assert not array_equivalent(left, right[::-1], strict_nan=strict_nan) + + +@pytest.mark.xfail(reason="failing") +@pytest.mark.parametrize("strict_nan", [True, False]) +def test_array_equivalent_nested_dicts(strict_nan): + left = np.array([{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object) + right = np.array( + [{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object + ) + assert array_equivalent(left, right, strict_nan=strict_nan) + assert not array_equivalent(left, right[::-1], strict_nan=strict_nan) + + right2 = np.array([{"f1": 1, "f2": ["a", "b"]}], dtype=object) + assert array_equivalent(left, right2, strict_nan=strict_nan) + assert not array_equivalent(left, right2[::-1], strict_nan=strict_nan) + + +def test_array_equivalent_index_with_tuples(): + # GH#48446 + idx1 = Index(np.array([(pd.NA, 4), (1, 1)], dtype="object")) + idx2 = Index(np.array([(1, 1), (pd.NA, 4)], dtype="object")) + assert not array_equivalent(idx1, idx2) + assert not idx1.equals(idx2) + assert not array_equivalent(idx2, idx1) + assert not idx2.equals(idx1) + + idx1 = Index(np.array([(4, pd.NA), (1, 1)], dtype="object")) + idx2 = Index(np.array([(1, 1), (4, pd.NA)], dtype="object")) + assert not array_equivalent(idx1, idx2) + assert not idx1.equals(idx2) + assert not array_equivalent(idx2, idx1) + assert not idx2.equals(idx1) + + +@pytest.mark.parametrize( + "dtype, na_value", + [ + # Datetime-like + (np.dtype("M8[ns]"), np.datetime64("NaT", "ns")), + (np.dtype("m8[ns]"), np.timedelta64("NaT", "ns")), + (DatetimeTZDtype.construct_from_string("datetime64[ns, US/Eastern]"), NaT), + (PeriodDtype("M"), NaT), + # Integer + ("u1", 0), + ("u2", 0), + ("u4", 0), + ("u8", 0), + ("i1", 0), + ("i2", 0), + ("i4", 0), + ("i8", 0), + # Bool + ("bool", False), + # Float + ("f2", np.nan), + ("f4", np.nan), + ("f8", np.nan), + # Complex + ("c8", np.nan), + ("c16", np.nan), + # Object + ("O", np.nan), + # Interval + (IntervalDtype(), np.nan), + ], +) +def test_na_value_for_dtype(dtype, na_value): + result = na_value_for_dtype(pandas_dtype(dtype)) + # identify check doesn't work for datetime64/timedelta64("NaT") bc they + # are not singletons + assert result is na_value or ( + isna(result) and isna(na_value) and type(result) is type(na_value) + ) + + +class TestNAObj: + def _check_behavior(self, arr, expected): + result = libmissing.isnaobj(arr) + tm.assert_numpy_array_equal(result, expected) + + arr = np.atleast_2d(arr) + expected = np.atleast_2d(expected) + + result = libmissing.isnaobj(arr) + tm.assert_numpy_array_equal(result, expected) + + # Test fortran order + arr = arr.copy(order="F") + result = libmissing.isnaobj(arr) + tm.assert_numpy_array_equal(result, expected) + + def test_basic(self): + arr = np.array([1, None, "foo", -5.1, NaT, np.nan]) + expected = np.array([False, True, False, False, True, True]) + + self._check_behavior(arr, expected) + + def test_non_obj_dtype(self): + arr = np.array([1, 3, np.nan, 5], dtype=float) + expected = np.array([False, False, True, False]) + + self._check_behavior(arr, expected) + + def test_empty_arr(self): + arr = np.array([]) + expected = np.array([], dtype=bool) + + self._check_behavior(arr, expected) + + def test_empty_str_inp(self): + arr = np.array([""]) # empty but not na + expected = np.array([False]) + + self._check_behavior(arr, expected) + + def test_empty_like(self): + # see gh-13717: no segfaults! + arr = np.empty_like([None]) + expected = np.array([True]) + + self._check_behavior(arr, expected) + + +m8_units = ["as", "ps", "ns", "us", "ms", "s", "m", "h", "D", "W", "M", "Y"] + +na_vals = ( + [ + None, + NaT, + float("NaN"), + complex("NaN"), + np.nan, + np.float64("NaN"), + np.float32("NaN"), + np.complex64(np.nan), + np.complex128(np.nan), + np.datetime64("NaT"), + np.timedelta64("NaT"), + ] + + [np.datetime64("NaT", unit) for unit in m8_units] # type: ignore[call-overload] + + [np.timedelta64("NaT", unit) for unit in m8_units] # type: ignore[call-overload] +) + +inf_vals = [ + float("inf"), + float("-inf"), + complex("inf"), + complex("-inf"), + np.inf, + -np.inf, +] + +int_na_vals = [ + # Values that match iNaT, which we treat as null in specific cases + np.int64(NaT._value), + int(NaT._value), +] + +sometimes_na_vals = [Decimal("NaN")] + +never_na_vals = [ + # float/complex values that when viewed as int64 match iNaT + -0.0, + np.float64("-0.0"), + -0j, + np.complex64(-0j), +] + + +class TestLibMissing: + @pytest.mark.parametrize("func", [libmissing.checknull, isna]) + @pytest.mark.parametrize( + "value", + na_vals + sometimes_na_vals, # type: ignore[operator] + ) + def test_checknull_na_vals(self, func, value): + assert func(value) + + @pytest.mark.parametrize("func", [libmissing.checknull, isna]) + @pytest.mark.parametrize("value", inf_vals) + def test_checknull_inf_vals(self, func, value): + assert not func(value) + + @pytest.mark.parametrize("func", [libmissing.checknull, isna]) + @pytest.mark.parametrize("value", int_na_vals) + def test_checknull_intna_vals(self, func, value): + assert not func(value) + + @pytest.mark.parametrize("func", [libmissing.checknull, isna]) + @pytest.mark.parametrize("value", never_na_vals) + def test_checknull_never_na_vals(self, func, value): + assert not func(value) + + @pytest.mark.parametrize( + "value", + na_vals + sometimes_na_vals, # type: ignore[operator] + ) + def test_checknull_old_na_vals(self, value): + assert libmissing.checknull(value) + + @pytest.mark.parametrize("value", int_na_vals) + def test_checknull_old_intna_vals(self, value): + assert not libmissing.checknull(value) + + def test_is_matching_na(self, nulls_fixture, nulls_fixture2): + left = nulls_fixture + right = nulls_fixture2 + + assert libmissing.is_matching_na(left, left) + + if left is right: + assert libmissing.is_matching_na(left, right) + elif is_float(left) and is_float(right): + # np.nan vs float("NaN") we consider as matching + assert libmissing.is_matching_na(left, right) + elif type(left) is type(right): + # e.g. both Decimal("NaN") + assert libmissing.is_matching_na(left, right) + else: + assert not libmissing.is_matching_na(left, right) + + def test_is_matching_na_nan_matches_none(self): + assert not libmissing.is_matching_na(None, np.nan) + assert not libmissing.is_matching_na(np.nan, None) + + assert libmissing.is_matching_na(None, np.nan, nan_matches_none=True) + assert libmissing.is_matching_na(np.nan, None, nan_matches_none=True) + + +class TestIsValidNAForDtype: + def test_is_valid_na_for_dtype_interval(self): + dtype = IntervalDtype("int64", "left") + assert not is_valid_na_for_dtype(NaT, dtype) + + dtype = IntervalDtype("datetime64[ns]", "both") + assert not is_valid_na_for_dtype(NaT, dtype) + + def test_is_valid_na_for_dtype_categorical(self): + dtype = CategoricalDtype(categories=[0, 1, 2]) + assert is_valid_na_for_dtype(np.nan, dtype) + + assert not is_valid_na_for_dtype(NaT, dtype) + assert not is_valid_na_for_dtype(np.datetime64("NaT", "ns"), dtype) + assert not is_valid_na_for_dtype(np.timedelta64("NaT", "ns"), dtype) diff --git a/pandas/tests/extension/__init__.py b/pandas/tests/extension/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/extension/conftest.py b/pandas/tests/extension/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..1376af5e51a6b6255625c257031e74a33d7b7b77 --- /dev/null +++ b/pandas/tests/extension/conftest.py @@ -0,0 +1,214 @@ +import operator + +import pytest + +from pandas import Series + + +@pytest.fixture +def dtype(): + """A fixture providing the ExtensionDtype to validate.""" + raise NotImplementedError + + +@pytest.fixture +def data(): + """ + Length-10 array for this type. + + * data[0] and data[1] should both be non missing + * data[0] and data[1] should not be equal + """ + raise NotImplementedError + + +@pytest.fixture +def data_for_twos(dtype): + """ + Length-10 array in which all the elements are two. + + Call pytest.skip in your fixture if the dtype does not support divmod. + """ + if not (dtype._is_numeric or dtype.kind == "m"): + # Object-dtypes may want to allow this, but for the most part + # only numeric and timedelta-like dtypes will need to implement this. + pytest.skip(f"{dtype} is not a numeric dtype") + + raise NotImplementedError + + +@pytest.fixture +def data_missing(): + """Length-2 array with [NA, Valid]""" + raise NotImplementedError + + +@pytest.fixture(params=["data", "data_missing"]) +def all_data(request, data, data_missing): + """Parametrized fixture giving 'data' and 'data_missing'""" + if request.param == "data": + return data + elif request.param == "data_missing": + return data_missing + + +@pytest.fixture +def data_repeated(data): + """ + Generate many datasets. + + Parameters + ---------- + data : fixture implementing `data` + + Returns + ------- + Callable[[int], Generator]: + A callable that takes a `count` argument and + returns a generator yielding `count` datasets. + """ + + def gen(count): + for _ in range(count): + yield data + + return gen + + +@pytest.fixture +def data_for_sorting(): + """ + Length-3 array with a known sort order. + + This should be three items [B, C, A] with + A < B < C + + For boolean dtypes (for which there are only 2 values available), + set B=C=True + """ + raise NotImplementedError + + +@pytest.fixture +def data_missing_for_sorting(): + """ + Length-3 array with a known sort order. + + This should be three items [B, NA, A] with + A < B and NA missing. + """ + raise NotImplementedError + + +@pytest.fixture +def na_cmp(): + """ + Binary operator for comparing NA values. + + Should return a function of two arguments that returns + True if both arguments are (scalar) NA for your type. + + By default, uses ``operator.is_`` + """ + return operator.is_ + + +@pytest.fixture +def na_value(dtype): + """ + The scalar missing value for this type. Default dtype.na_value. + + TODO: can be removed in 3.x (see https://github.com/pandas-dev/pandas/pull/54930) + """ + return dtype.na_value + + +@pytest.fixture +def data_for_grouping(): + """ + Data for factorization, grouping, and unique tests. + + Expected to be like [B, B, NA, NA, A, A, B, C] + + Where A < B < C and NA is missing. + + If a dtype has _is_boolean = True, i.e. only 2 unique non-NA entries, + then set C=B. + """ + raise NotImplementedError + + +@pytest.fixture(params=[True, False]) +def box_in_series(request): + """Whether to box the data in a Series""" + return request.param + + +@pytest.fixture( + params=[ + lambda x: 1, + lambda x: [1] * len(x), + lambda x: Series([1] * len(x)), + lambda x: x, + ], + ids=["scalar", "list", "series", "object"], +) +def groupby_apply_op(request): + """ + Functions to test groupby.apply(). + """ + return request.param + + +@pytest.fixture(params=[True, False]) +def as_frame(request): + """ + Boolean fixture to support Series and Series.to_frame() comparison testing. + """ + return request.param + + +@pytest.fixture(params=[True, False]) +def as_series(request): + """ + Boolean fixture to support arr and Series(arr) comparison testing. + """ + return request.param + + +@pytest.fixture(params=[True, False]) +def use_numpy(request): + """ + Boolean fixture to support comparison testing of ExtensionDtype array + and numpy array. + """ + return request.param + + +@pytest.fixture(params=["ffill", "bfill"]) +def fillna_method(request): + """ + Parametrized fixture giving method parameters 'ffill' and 'bfill' for + Series. testing. + """ + return request.param + + +@pytest.fixture(params=[True, False]) +def as_array(request): + """ + Boolean fixture to support ExtensionDtype _from_sequence method testing. + """ + return request.param + + +@pytest.fixture +def invalid_scalar(data): + """ + A scalar that *cannot* be held by this ExtensionArray. + + The default should work for most subclasses, but is not guaranteed. + + If the array can hold any item (i.e. object dtype), then use pytest.skip. + """ + return object.__new__(object) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py new file mode 100644 index 0000000000000000000000000000000000000000..f3388d74447391dc3640c35b0cb12da2267c579c --- /dev/null +++ b/pandas/tests/extension/test_arrow.py @@ -0,0 +1,3951 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. +""" + +from __future__ import annotations + +from datetime import ( + date, + datetime, + time, + timedelta, +) +from decimal import Decimal +from io import ( + BytesIO, + StringIO, +) +import operator +import pickle +import re +import sys + +import numpy as np +import pytest + +from pandas._libs import lib +from pandas._libs.tslibs import timezones +from pandas.compat import ( + PY312, + is_ci_environment, + is_platform_windows, + pa_version_under14p0, + pa_version_under19p0, + pa_version_under20p0, + pa_version_under21p0, +) +from pandas.compat.pyarrow import pa_version_under22p0 +from pandas.errors import Pandas4Warning + +from pandas.core.dtypes.common import pandas_dtype +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + CategoricalDtypeType, +) + +import pandas as pd +import pandas._testing as tm +from pandas.api.extensions import no_default +from pandas.api.types import ( + is_bool_dtype, + is_datetime64_any_dtype, + is_float_dtype, + is_integer_dtype, + is_numeric_dtype, + is_signed_integer_dtype, + is_string_dtype, + is_unsigned_integer_dtype, +) +from pandas.tests.extension import base + +pa = pytest.importorskip("pyarrow") + +from pandas.core.arrays.arrow.array import ArrowExtensionArray +from pandas.core.arrays.arrow.extension_types import ArrowPeriodType + + +def _require_timezone_database(request): + if is_platform_windows() and is_ci_environment() and pa_version_under22p0: + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=( + "TODO: Set ARROW_TIMEZONE_DATABASE environment variable " + "on CI to path to the tzdata for pyarrow." + ), + ) + request.applymarker(mark) + + +@pytest.fixture(params=tm.ALL_PYARROW_DTYPES, ids=str) +def dtype(request): + return ArrowDtype(pyarrow_dtype=request.param) + + +@pytest.fixture +def data(dtype): + pa_dtype = dtype.pyarrow_dtype + if pa.types.is_boolean(pa_dtype): + data = [True, False] * 2 + [None] + [True, False] + [None] + [True, False] + elif pa.types.is_floating(pa_dtype): + data = [1.0, 0.0] * 2 + [None] + [-2.0, -1.0] + [None] + [0.5, 99.5] + elif pa.types.is_signed_integer(pa_dtype): + data = [1, 0] * 2 + [None] + [-2, -1] + [None] + [1, 99] + elif pa.types.is_unsigned_integer(pa_dtype): + data = [1, 0] * 2 + [None] + [2, 1] + [None] + [1, 99] + elif pa.types.is_decimal(pa_dtype): + data = ( + [Decimal("1"), Decimal("0.0")] * 2 + + [None] + + [Decimal("-2.0"), Decimal("-1.0")] + + [None] + + [Decimal("0.5"), Decimal("33.123")] + ) + elif pa.types.is_date(pa_dtype): + data = ( + [date(2022, 1, 1), date(1999, 12, 31)] * 2 + + [None] + + [date(2022, 1, 1), date(2022, 1, 1)] + + [None] + + [date(1999, 12, 31), date(1999, 12, 31)] + ) + elif pa.types.is_timestamp(pa_dtype): + data = ( + [datetime(2020, 1, 1, 1, 1, 1, 1), datetime(1999, 1, 1, 1, 1, 1, 1)] * 2 + + [None] + + [datetime(2020, 1, 1, 1), datetime(1999, 1, 1, 1)] + + [None] + + [datetime(2020, 1, 1), datetime(1999, 1, 1)] + ) + elif pa.types.is_duration(pa_dtype): + data = ( + [timedelta(1), timedelta(1, 1)] * 2 + + [None] + + [timedelta(-1), timedelta(0)] + + [None] + + [timedelta(-10), timedelta(10)] + ) + elif pa.types.is_time(pa_dtype): + data = ( + [time(12, 0), time(0, 12)] * 2 + + [None] + + [time(0, 0), time(1, 1)] + + [None] + + [time(0, 5), time(5, 0)] + ) + elif pa.types.is_string(pa_dtype): + data = ["a", "b"] * 2 + [None] + ["1", "2"] + [None] + ["!", ">"] + elif pa.types.is_binary(pa_dtype): + data = [b"a", b"b"] * 2 + [None] + [b"1", b"2"] + [None] + [b"!", b">"] + else: + raise NotImplementedError + return pd.array(data, dtype=dtype) + + +@pytest.fixture +def data_missing(data): + """Length-2 array with [NA, Valid]""" + return type(data)._from_sequence([None, data[0]], dtype=data.dtype) + + +@pytest.fixture(params=["data", "data_missing"]) +def all_data(request, data, data_missing): + """Parametrized fixture returning 'data' or 'data_missing' integer arrays. + + Used to test dtype conversion with and without missing values. + """ + if request.param == "data": + return data + elif request.param == "data_missing": + return data_missing + + +@pytest.fixture +def data_for_grouping(dtype): + """ + Data for factorization, grouping, and unique tests. + + Expected to be like [B, B, NA, NA, A, A, B, C] + + Where A < B < C and NA is missing + """ + pa_dtype = dtype.pyarrow_dtype + if pa.types.is_boolean(pa_dtype): + A = False + B = True + C = True + elif pa.types.is_floating(pa_dtype): + A = -1.1 + B = 0.0 + C = 1.1 + elif pa.types.is_signed_integer(pa_dtype): + A = -1 + B = 0 + C = 1 + elif pa.types.is_unsigned_integer(pa_dtype): + A = 0 + B = 1 + C = 10 + elif pa.types.is_date(pa_dtype): + A = date(1999, 12, 31) + B = date(2010, 1, 1) + C = date(2022, 1, 1) + elif pa.types.is_timestamp(pa_dtype): + A = datetime(1999, 1, 1, 1, 1, 1, 1) + B = datetime(2020, 1, 1) + C = datetime(2020, 1, 1, 1) + elif pa.types.is_duration(pa_dtype): + A = timedelta(-1) + B = timedelta(0) + C = timedelta(1, 4) + elif pa.types.is_time(pa_dtype): + A = time(0, 0) + B = time(0, 12) + C = time(12, 12) + elif pa.types.is_string(pa_dtype): + A = "a" + B = "b" + C = "c" + elif pa.types.is_binary(pa_dtype): + A = b"a" + B = b"b" + C = b"c" + elif pa.types.is_decimal(pa_dtype): + A = Decimal("-1.1") + B = Decimal("0.0") + C = Decimal("1.1") + else: + raise NotImplementedError + return pd.array([B, B, None, None, A, A, B, C], dtype=dtype) + + +@pytest.fixture +def data_for_sorting(data_for_grouping): + """ + Length-3 array with a known sort order. + + This should be three items [B, C, A] with + A < B < C + """ + return type(data_for_grouping)._from_sequence( + [data_for_grouping[0], data_for_grouping[7], data_for_grouping[4]], + dtype=data_for_grouping.dtype, + ) + + +@pytest.fixture +def data_missing_for_sorting(data_for_grouping): + """ + Length-3 array with a known sort order. + + This should be three items [B, NA, A] with + A < B and NA missing. + """ + return type(data_for_grouping)._from_sequence( + [data_for_grouping[0], data_for_grouping[2], data_for_grouping[4]], + dtype=data_for_grouping.dtype, + ) + + +@pytest.fixture +def data_for_twos(data): + """Length-100 array in which all the elements are two.""" + pa_dtype = data.dtype.pyarrow_dtype + if ( + pa.types.is_integer(pa_dtype) + or pa.types.is_floating(pa_dtype) + or pa.types.is_decimal(pa_dtype) + or pa.types.is_duration(pa_dtype) + ): + return pd.array([2] * 10, dtype=data.dtype) + # tests will be xfailed where 2 is not a valid scalar for pa_dtype + return data + # TODO: skip otherwise? + + +class TestArrowArray(base.ExtensionTests): + def _construct_for_combine_add(self, left, right): + dtype = left.dtype + + # in a couple cases, addition is not dtype-preserving + if dtype == "bool[pyarrow]": + dtype = pandas_dtype("int64[pyarrow]") + elif dtype == "int8[pyarrow]" and isinstance(right, type(left)): + dtype = pandas_dtype("int64[pyarrow]") + + if isinstance(right, type(left)): + return left._from_sequence( + [a + b for (a, b) in zip(list(left), list(right), strict=True)], + dtype=dtype, + ) + else: + return left._from_sequence( + [a + right for a in list(left)], + dtype=dtype, + ) + + def test_compare_scalar(self, data, comparison_op): + ser = pd.Series(data) + self._compare_other(ser, data, comparison_op, data[0]) + + def test_compare_range_len(self, data, comparison_op): + # GH#63429 + ser = pd.Series(data) + range_test = range(len(ser)) + self._compare_other(ser, range_test, comparison_op, range_test) + + @pytest.mark.parametrize("na_action", [None, "ignore"]) + def test_map(self, data_missing, na_action, using_nan_is_na): + if data_missing.dtype.kind in "mM": + result = data_missing.map(lambda x: x, na_action=na_action) + expected = data_missing.to_numpy(dtype=object) + tm.assert_numpy_array_equal(result, expected) + else: + result = data_missing.map(lambda x: x, na_action=na_action) + if data_missing.dtype == "float32[pyarrow]" and using_nan_is_na: + # map roundtrips through objects, which converts to float64 + expected = data_missing.to_numpy(dtype="float64", na_value=np.nan) + else: + expected = data_missing.to_numpy() + tm.assert_numpy_array_equal(result, expected) + + def test_astype_str(self, data, request, using_infer_string): + pa_dtype = data.dtype.pyarrow_dtype + if pa.types.is_binary(pa_dtype): + request.applymarker( + pytest.mark.xfail( + reason=f"For {pa_dtype} .astype(str) decodes.", + ) + ) + elif not using_infer_string and ( + (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None) + or pa.types.is_duration(pa_dtype) + ): + request.applymarker( + pytest.mark.xfail( + reason="pd.Timestamp/pd.Timedelta repr different from numpy repr", + ) + ) + super().test_astype_str(data) + + def test_from_dtype(self, data, request): + pa_dtype = data.dtype.pyarrow_dtype + if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype): + if pa.types.is_string(pa_dtype): + reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')" + else: + reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}" + + request.applymarker( + pytest.mark.xfail( + reason=reason, + ) + ) + super().test_from_dtype(data) + + def test_from_sequence_pa_array(self, data): + # https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784 + # data._pa_array = pa.ChunkedArray + result = type(data)._from_sequence(data._pa_array, dtype=data.dtype) + tm.assert_extension_array_equal(result, data) + assert isinstance(result._pa_array, pa.ChunkedArray) + + result = type(data)._from_sequence( + data._pa_array.combine_chunks(), dtype=data.dtype + ) + tm.assert_extension_array_equal(result, data) + assert isinstance(result._pa_array, pa.ChunkedArray) + + def test_from_sequence_pa_array_notimplemented(self, request): + dtype = ArrowDtype(pa.month_day_nano_interval()) + with pytest.raises(NotImplementedError, match="Converting strings to"): + ArrowExtensionArray._from_sequence_of_strings(["12-1"], dtype=dtype) + + def test_from_sequence_of_strings_pa_array(self, data, request): + pa_dtype = data.dtype.pyarrow_dtype + if pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None: + _require_timezone_database(request) + + pa_array = data._pa_array.cast(pa.string()) + result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype) + tm.assert_extension_array_equal(result, data) + + pa_array = pa_array.combine_chunks() + result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype) + tm.assert_extension_array_equal(result, data) + + def check_accumulate(self, ser, op_name, skipna): + result = getattr(ser, op_name)(skipna=skipna) + + pa_type = ser.dtype.pyarrow_dtype + if pa.types.is_temporal(pa_type): + # Just check that we match the integer behavior. + if pa_type.bit_width == 32: + int_type = "int32[pyarrow]" + else: + int_type = "int64[pyarrow]" + ser = ser.astype(int_type) + result = result.astype(int_type) + + result = result.astype("Float64") + expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna) + tm.assert_series_equal(result, expected, check_dtype=False) + + def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool: + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no + # attribute "pyarrow_dtype" + pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr] + + if pa.types.is_binary(pa_type) or pa.types.is_decimal(pa_type): + if op_name in ["cumsum", "cumprod", "cummax", "cummin"]: + return False + elif pa.types.is_string(pa_type): + if op_name == "cumprod": + return False + elif pa.types.is_boolean(pa_type): + if op_name in ["cumprod", "cummax", "cummin"]: + return False + elif pa.types.is_temporal(pa_type): + if op_name == "cumsum" and not pa.types.is_duration(pa_type): + return False + elif op_name == "cumprod": + return False + return True + + @pytest.mark.parametrize("skipna", [True, False]) + def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request): + pa_type = data.dtype.pyarrow_dtype + op_name = all_numeric_accumulations + + if pa.types.is_string(pa_type) and op_name in ["cumsum", "cummin", "cummax"]: + # https://github.com/pandas-dev/pandas/pull/60633 + # Doesn't fit test structure, tested in series/test_cumulative.py instead. + return + + ser = pd.Series(data) + + if not self._supports_accumulation(ser, op_name): + # The base class test will check that we raise + return super().test_accumulate_series( + data, all_numeric_accumulations, skipna + ) + + if all_numeric_accumulations == "cumsum" and ( + pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type) + ): + request.applymarker( + pytest.mark.xfail( + reason=f"{all_numeric_accumulations} not implemented for {pa_type}", + raises=TypeError, + ) + ) + + self.check_accumulate(ser, op_name, skipna) + + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + if op_name == "kurt" or (pa_version_under20p0 and op_name == "skew"): + return False + + dtype = ser.dtype + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has + # no attribute "pyarrow_dtype" + pa_dtype = dtype.pyarrow_dtype # type: ignore[union-attr] + if pa.types.is_temporal(pa_dtype) and op_name in ["sum", "var", "prod", "skew"]: + if pa.types.is_duration(pa_dtype) and op_name in ["sum"]: + # summing timedeltas is one case that *is* well-defined + pass + else: + return False + elif pa.types.is_binary(pa_dtype) and op_name in ["sum", "skew"]: + return False + elif ( + pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) + ) and op_name in ["mean", "median", "prod", "std", "sem", "var", "skew"]: + return False + + if ( + pa.types.is_temporal(pa_dtype) + and not pa.types.is_duration(pa_dtype) + and op_name in ["any", "all"] + ): + # xref GH#34479 we support this in our non-pyarrow datetime64 dtypes, + # but it isn't obvious we _should_. For now, we keep the pyarrow + # behavior which does not support this. + return False + + if pa.types.is_boolean(pa_dtype) and op_name in [ + "median", + "std", + "var", + "skew", + "kurt", + "sem", + ]: + return False + + return True + + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no + # attribute "pyarrow_dtype" + pa_dtype = ser.dtype.pyarrow_dtype # type: ignore[union-attr] + if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype): + alt = ser.astype("Float64") + else: + # TODO: in the opposite case, aren't we testing... nothing? For + # e.g. date/time dtypes trying to calculate 'expected' by converting + # to object will raise for mean, std etc + alt = ser + + # TODO: in the opposite case, aren't we testing... nothing? + if op_name == "count": + result = getattr(ser, op_name)() + expected = getattr(alt, op_name)() + else: + result = getattr(ser, op_name)(skipna=skipna) + expected = getattr(alt, op_name)(skipna=skipna) + tm.assert_almost_equal(result, expected) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_series_boolean( + self, data, all_boolean_reductions, skipna, na_value, request + ): + pa_dtype = data.dtype.pyarrow_dtype + xfail_mark = pytest.mark.xfail( + raises=TypeError, + reason=( + f"{all_boolean_reductions} is not implemented in " + f"pyarrow={pa.__version__} for {pa_dtype}" + ), + ) + if pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype): + # We *might* want to make this behave like the non-pyarrow cases, + # but have not yet decided. + request.applymarker(xfail_mark) + + return super().test_reduce_series_boolean(data, all_boolean_reductions, skipna) + + def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool): + pa_type = arr._pa_array.type + + if op_name in ["max", "min"]: + cmp_dtype = arr.dtype + elif pa.types.is_temporal(pa_type): + if op_name in ["std", "sem"]: + if pa.types.is_duration(pa_type): + cmp_dtype = arr.dtype + elif pa.types.is_date(pa_type): + cmp_dtype = ArrowDtype(pa.duration("s")) + elif pa.types.is_time(pa_type): + cmp_dtype = ArrowDtype(pa.duration(pa_type.unit)) + else: + cmp_dtype = ArrowDtype(pa.duration(pa_type.unit)) + else: + cmp_dtype = arr.dtype + elif arr.dtype.name == "decimal128(7, 3)[pyarrow]": + if op_name == "sum" and not pa_version_under21p0: + # https://github.com/apache/arrow/pull/44184 + cmp_dtype = ArrowDtype(pa.decimal128(38, 3)) + elif op_name not in ["median", "var", "std", "sem", "skew"]: + cmp_dtype = arr.dtype + else: + cmp_dtype = "float64[pyarrow]" + elif op_name in ["median", "var", "std", "mean", "skew", "sem"]: + cmp_dtype = "float64[pyarrow]" + elif op_name in ["sum", "prod"] and pa.types.is_boolean(pa_type): + cmp_dtype = "uint64[pyarrow]" + elif op_name == "sum" and pa.types.is_string(pa_type): + cmp_dtype = arr.dtype + else: + cmp_dtype = { + "i": "int64[pyarrow]", + "u": "uint64[pyarrow]", + "f": "float64[pyarrow]", + }[arr.dtype.kind] + return cmp_dtype + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request): + if ( + not pa_version_under20p0 + and skipna + and all_numeric_reductions == "skew" + and ( + pa.types.is_integer(data.dtype.pyarrow_dtype) + or pa.types.is_floating(data.dtype.pyarrow_dtype) + ) + ): + request.applymarker( + pytest.mark.xfail( + reason="https://github.com/apache/arrow/issues/45733", + ) + ) + return super().test_reduce_series_numeric(data, all_numeric_reductions, skipna) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_frame(self, data, all_numeric_reductions, skipna, request): + op_name = all_numeric_reductions + if op_name == "skew" and pa_version_under20p0: + if data.dtype._is_numeric: + mark = pytest.mark.xfail(reason="skew not implemented") + request.applymarker(mark) + return super().test_reduce_frame(data, all_numeric_reductions, skipna) + + @pytest.mark.parametrize("typ", ["int64", "uint64", "float64"]) + def test_median_not_approximate(self, typ): + # GH 52679 + result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median() + assert result == 1.5 + + def test_construct_from_string_own_name(self, dtype, request): + pa_dtype = dtype.pyarrow_dtype + if pa.types.is_decimal(pa_dtype): + request.applymarker( + pytest.mark.xfail( + raises=NotImplementedError, + reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", + ) + ) + + if pa.types.is_string(pa_dtype): + # We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) + msg = r"string\[pyarrow\] should be constructed by StringDtype" + with pytest.raises(TypeError, match=msg): + dtype.construct_from_string(dtype.name) + + return + + super().test_construct_from_string_own_name(dtype) + + def test_is_dtype_from_name(self, dtype, request): + pa_dtype = dtype.pyarrow_dtype + if pa.types.is_string(pa_dtype): + # We still support StringDtype('pyarrow') over ArrowDtype(pa.string()) + assert not type(dtype).is_dtype(dtype.name) + else: + if pa.types.is_decimal(pa_dtype): + request.applymarker( + pytest.mark.xfail( + raises=NotImplementedError, + reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}", + ) + ) + super().test_is_dtype_from_name(dtype) + + def test_construct_from_string_another_type_raises(self, dtype): + msg = r"'another_type' must end with '\[pyarrow\]'" + with pytest.raises(TypeError, match=msg): + type(dtype).construct_from_string("another_type") + + def test_get_common_dtype(self, dtype, request): + pa_dtype = dtype.pyarrow_dtype + if ( + pa.types.is_date(pa_dtype) + or pa.types.is_time(pa_dtype) + or (pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None) + or pa.types.is_binary(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): + request.applymarker( + pytest.mark.xfail( + reason=( + f"{pa_dtype} does not have associated numpy " + f"dtype findable by find_common_type" + ) + ) + ) + super().test_get_common_dtype(dtype) + + def test_is_not_string_type(self, dtype): + pa_dtype = dtype.pyarrow_dtype + if pa.types.is_string(pa_dtype): + assert is_string_dtype(dtype) + else: + super().test_is_not_string_type(dtype) + + @pytest.mark.xfail( + reason="GH 45419: pyarrow.ChunkedArray does not support views.", run=False + ) + def test_view(self, data): + super().test_view(data) + + def test_fillna_no_op_returns_copy(self, data): + data = data[~data.isna()] + + valid = data[0] + result = data.fillna(valid) + assert result is not data + tm.assert_extension_array_equal(result, data) + + def test_fillna_readonly(self, data_missing): + data = data_missing.copy() + data._readonly = True + + # by default fillna(copy=True), then this works fine + result = data.fillna(data_missing[1]) + assert result[0] == data_missing[1] + tm.assert_extension_array_equal(data, data_missing) + + # fillna(copy=False) is generally not honored by Arrow-backed array, + # but always returns new data -> same result as above + result = data.fillna(data_missing[1]) + assert result[0] == data_missing[1] + tm.assert_extension_array_equal(data, data_missing) + + @pytest.mark.xfail( + reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False + ) + def test_transpose(self, data): + super().test_transpose(data) + + @pytest.mark.xfail( + reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False + ) + def test_setitem_preserves_views(self, data): + super().test_setitem_preserves_views(data) + + @pytest.mark.parametrize("dtype_backend", ["pyarrow", no_default]) + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_EA_types(self, engine, data, dtype_backend, request, using_nan_is_na): + pa_dtype = data.dtype.pyarrow_dtype + if pa.types.is_decimal(pa_dtype): + request.applymarker( + pytest.mark.xfail( + raises=NotImplementedError, + reason=f"Parameterized types {pa_dtype} not supported.", + ) + ) + elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"): + request.applymarker( + pytest.mark.xfail( + raises=ValueError, + reason="https://github.com/pandas-dev/pandas/issues/49767", + ) + ) + elif pa.types.is_binary(pa_dtype): + request.applymarker( + pytest.mark.xfail(reason="CSV parsers don't correctly handle binary") + ) + df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))}) + if not using_nan_is_na: + csv_output = df.to_csv(index=False, na_rep="NA") + else: + csv_output = df.to_csv(index=False, na_rep=np.nan) + if pa.types.is_binary(pa_dtype): + csv_output = BytesIO(csv_output) + else: + csv_output = StringIO(csv_output) + result = pd.read_csv( + csv_output, + dtype={"with_dtype": str(data.dtype)}, + engine=engine, + dtype_backend=dtype_backend, + ) + expected = df + tm.assert_frame_equal(result, expected) + + def test_invert(self, data, request): + pa_dtype = data.dtype.pyarrow_dtype + if not ( + pa.types.is_boolean(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_string(pa_dtype) + ): + request.applymarker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"pyarrow.compute.invert does support {pa_dtype}", + ) + ) + if PY312 and pa.types.is_boolean(pa_dtype): + with tm.assert_produces_warning( + DeprecationWarning, match="Bitwise inversion", check_stacklevel=False + ): + super().test_invert(data) + else: + super().test_invert(data) + + @pytest.mark.parametrize("periods", [1, -2]) + def test_diff(self, data, periods, request): + pa_dtype = data.dtype.pyarrow_dtype + if pa.types.is_unsigned_integer(pa_dtype) and periods == 1: + request.applymarker( + pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=( + f"diff with {pa_dtype} and periods={periods} will overflow" + ), + ) + ) + super().test_diff(data, periods) + + def test_value_counts_returns_pyarrow_int64(self, data): + # GH 51462 + data = data[:10] + result = data.value_counts() + assert result.dtype == ArrowDtype(pa.int64()) + + _combine_le_expected_dtype = "bool[pyarrow]" + + def get_op_from_name(self, op_name): + short_opname = op_name.strip("_") + if short_opname == "rtruediv": + # use the numpy version that won't raise on division by zero + + def rtruediv(x, y): + return np.divide(y, x) + + return rtruediv + elif short_opname == "rfloordiv": + return lambda x, y: np.floor_divide(y, x) + + return tm.get_op_from_name(op_name) + + # TODO: use EA._cast_pointwise_result, same with other test files that + # override this + def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): + # BaseOpsUtil._combine can upcast expected dtype + # (because it generates expected on python scalars) + # while ArrowExtensionArray maintains original type + expected = pointwise_result + + if op_name in ["eq", "ne", "lt", "le", "gt", "ge"]: + return pointwise_result.astype("boolean[pyarrow]") + + original_dtype = tm.get_dtype(expected) + + was_frame = False + if isinstance(expected, pd.DataFrame): + was_frame = True + expected_data = expected.iloc[:, 0] + else: + expected_data = expected + + # the pointwise method will have retained our original dtype, while + # the op(ser, other) version will have cast to 64bit + if type(other) is int and op_name not in ["__floordiv__"]: + if original_dtype.kind == "f": + return expected.astype("float64[pyarrow]") + else: + return expected.astype("int64[pyarrow]") + elif type(other) is float: + return expected.astype("float64[pyarrow]") + + # error: Item "ExtensionDtype" of "dtype[Any] | ExtensionDtype" has + # no attribute "pyarrow_dtype" + orig_pa_type = original_dtype.pyarrow_dtype # type: ignore[union-attr] + if not was_frame and isinstance(other, pd.Series): + # i.e. test_arith_series_with_array + if not ( + pa.types.is_floating(orig_pa_type) + or ( + pa.types.is_integer(orig_pa_type) + and op_name not in ["__truediv__", "__rtruediv__"] + ) + or pa.types.is_duration(orig_pa_type) + or pa.types.is_timestamp(orig_pa_type) + or pa.types.is_date(orig_pa_type) + or pa.types.is_decimal(orig_pa_type) + ): + # base class _combine always returns int64, while + # ArrowExtensionArray does not upcast + return expected + elif not ( + (op_name == "__floordiv__" and pa.types.is_integer(orig_pa_type)) + or pa.types.is_duration(orig_pa_type) + or pa.types.is_timestamp(orig_pa_type) + or pa.types.is_date(orig_pa_type) + or pa.types.is_decimal(orig_pa_type) + ): + # base class _combine always returns int64, while + # ArrowExtensionArray does not upcast + return expected + + pa_expected = pa.array(expected_data._values) + + if pa.types.is_decimal(pa_expected.type) and pa.types.is_decimal(orig_pa_type): + # decimal precision can resize in the result type depending on data + # just compare the float values + alt = getattr(obj, op_name)(other) + alt_dtype = tm.get_dtype(alt) + assert isinstance(alt_dtype, ArrowDtype) + if op_name == "__pow__" and isinstance(other, Decimal): + # TODO: would it make more sense to retain Decimal here? + alt_dtype = ArrowDtype(pa.float64()) + elif ( + op_name == "__pow__" + and isinstance(other, pd.Series) + and other.dtype == original_dtype + ): + # TODO: would it make more sense to retain Decimal here? + alt_dtype = ArrowDtype(pa.float64()) + else: + assert pa.types.is_decimal(alt_dtype.pyarrow_dtype) + return expected.astype(alt_dtype) + + else: + pa_expected = pa_expected.cast(orig_pa_type) + + pd_expected = type(expected_data._values)(pa_expected) + if was_frame: + expected = pd.DataFrame( + pd_expected, index=expected.index, columns=expected.columns + ) + else: + expected = pd.Series(pd_expected) + return expected + + def _is_temporal_supported(self, opname, pa_dtype): + return ( + ( + opname in ("__add__", "__radd__") + or ( + opname + in ("__truediv__", "__rtruediv__", "__floordiv__", "__rfloordiv__") + and not pa_version_under14p0 + ) + ) + and pa.types.is_duration(pa_dtype) + ) or (opname in ("__sub__", "__rsub__") and pa.types.is_temporal(pa_dtype)) + + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | tuple[type[Exception], ...] | None: + if op_name in ("__divmod__", "__rdivmod__"): + return (NotImplementedError, TypeError) + + exc: type[Exception] | tuple[type[Exception], ...] | None + dtype = tm.get_dtype(obj) + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no + # attribute "pyarrow_dtype" + pa_dtype = dtype.pyarrow_dtype # type: ignore[union-attr] + + arrow_temporal_supported = self._is_temporal_supported(op_name, pa_dtype) + if op_name in { + "__mod__", + "__rmod__", + }: + exc = (NotImplementedError, TypeError) + elif arrow_temporal_supported: + exc = None + elif op_name in ["__add__", "__radd__"] and ( + pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) + ): + exc = None + elif not ( + pa.types.is_floating(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): + exc = TypeError + else: + exc = None + return exc + + def _get_arith_xfail_marker(self, opname, pa_dtype): + mark = None + + arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype) + + if opname == "__rpow__" and ( + pa.types.is_floating(pa_dtype) + or pa.types.is_integer(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): + mark = pytest.mark.xfail( + reason=( + f"GH#29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL " + f"for {pa_dtype}" + ) + ) + elif arrow_temporal_supported and ( + pa.types.is_time(pa_dtype) + or ( + opname + in ("__truediv__", "__rtruediv__", "__floordiv__", "__rfloordiv__") + and pa.types.is_duration(pa_dtype) + ) + ): + mark = pytest.mark.xfail( + raises=TypeError, + reason=( + f"{opname} not supported betweenpd.NA and {pa_dtype} Python scalar" + ), + ) + elif opname == "__rfloordiv__" and ( + pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype) + ): + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason="divide by 0", + ) + elif opname == "__rtruediv__" and pa.types.is_decimal(pa_dtype): + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason="divide by 0", + ) + + return mark + + def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request): + pa_dtype = data.dtype.pyarrow_dtype + + if all_arithmetic_operators == "__rmod__" and pa.types.is_binary(pa_dtype): + pytest.skip("Skip testing Python string formatting") + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.applymarker(mark) + + super().test_arith_series_with_scalar(data, all_arithmetic_operators) + + def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): + pa_dtype = data.dtype.pyarrow_dtype + + if all_arithmetic_operators == "__rmod__" and ( + pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype) + ): + pytest.skip("Skip testing Python string formatting") + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.applymarker(mark) + + super().test_arith_frame_with_scalar(data, all_arithmetic_operators) + + def test_arith_series_with_array(self, data, all_arithmetic_operators, request): + pa_dtype = data.dtype.pyarrow_dtype + + if all_arithmetic_operators in ( + "__sub__", + "__rsub__", + ) and pa.types.is_unsigned_integer(pa_dtype): + request.applymarker( + pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=( + f"Implemented pyarrow.compute.subtract_checked " + f"which raises on overflow for {pa_dtype}" + ), + ) + ) + + mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype) + if mark is not None: + request.applymarker(mark) + + op_name = all_arithmetic_operators + ser = pd.Series(data) + # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray + # since ser.iloc[0] is a python scalar + other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype)) + + self.check_opname(ser, op_name, other) + + def test_add_series_with_extension_array(self, data, request): + pa_dtype = data.dtype.pyarrow_dtype + + if pa_dtype.equals("int8"): + request.applymarker( + pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=f"raises on overflow for {pa_dtype}", + ) + ) + super().test_add_series_with_extension_array(data) + + def test_invalid_other_comp(self, data, comparison_op): + # GH 48833 + with pytest.raises( + NotImplementedError, match=".* not implemented for " + ): + comparison_op(data, object()) + + @pytest.mark.parametrize("masked_dtype", ["boolean", "Int64", "Float64"]) + def test_comp_masked_numpy(self, masked_dtype, comparison_op): + # GH 52625 + data = [1, 0, None] + ser_masked = pd.Series(data, dtype=masked_dtype) + ser_pa = pd.Series(data, dtype=f"{masked_dtype.lower()}[pyarrow]") + result = comparison_op(ser_pa, ser_masked) + if comparison_op in [operator.lt, operator.gt, operator.ne]: + exp = [False, False, None] + else: + exp = [True, True, None] + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data, request): + pa_dtype = data.dtype.pyarrow_dtype + if pa.types.is_date(pa_dtype): + mark = pytest.mark.xfail( + reason="GH#62343 incorrectly casts to timestamp[ms][pyarrow]" + ) + request.applymarker(mark) + super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data) + + +class TestLogicalOps: + """Various Series and DataFrame logical ops methods.""" + + def test_kleene_or(self): + a = pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]") + b = pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]") + result = a | b + expected = pd.Series( + [True, True, True, True, False, None, True, None, None], + dtype="boolean[pyarrow]", + ) + tm.assert_series_equal(result, expected) + + result = b | a + tm.assert_series_equal(result, expected) + + # ensure we haven't mutated anything inplace + tm.assert_series_equal( + a, + pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]"), + ) + tm.assert_series_equal( + b, pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]") + ) + + @pytest.mark.parametrize( + "other, expected", + [ + (None, [True, None, None]), + (pd.NA, [True, None, None]), + (True, [True, True, True]), + (np.bool_(True), [True, True, True]), + (False, [True, False, None]), + (np.bool_(False), [True, False, None]), + ], + ) + def test_kleene_or_scalar(self, other, expected): + a = pd.Series([True, False, None], dtype="boolean[pyarrow]") + result = a | other + expected = pd.Series(expected, dtype="boolean[pyarrow]") + tm.assert_series_equal(result, expected) + + result = other | a + tm.assert_series_equal(result, expected) + + # ensure we haven't mutated anything inplace + tm.assert_series_equal( + a, pd.Series([True, False, None], dtype="boolean[pyarrow]") + ) + + def test_kleene_and(self): + a = pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]") + b = pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]") + result = a & b + expected = pd.Series( + [True, False, None, False, False, False, None, False, None], + dtype="boolean[pyarrow]", + ) + tm.assert_series_equal(result, expected) + + result = b & a + tm.assert_series_equal(result, expected) + + # ensure we haven't mutated anything inplace + tm.assert_series_equal( + a, + pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]"), + ) + tm.assert_series_equal( + b, pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]") + ) + + @pytest.mark.parametrize( + "other, expected", + [ + (None, [None, False, None]), + (pd.NA, [None, False, None]), + (True, [True, False, None]), + (False, [False, False, False]), + (np.bool_(True), [True, False, None]), + (np.bool_(False), [False, False, False]), + ], + ) + def test_kleene_and_scalar(self, other, expected): + a = pd.Series([True, False, None], dtype="boolean[pyarrow]") + result = a & other + expected = pd.Series(expected, dtype="boolean[pyarrow]") + tm.assert_series_equal(result, expected) + + result = other & a + tm.assert_series_equal(result, expected) + + # ensure we haven't mutated anything inplace + tm.assert_series_equal( + a, pd.Series([True, False, None], dtype="boolean[pyarrow]") + ) + + def test_kleene_xor(self): + a = pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]") + b = pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]") + result = a ^ b + expected = pd.Series( + [False, True, None, True, False, None, None, None, None], + dtype="boolean[pyarrow]", + ) + tm.assert_series_equal(result, expected) + + result = b ^ a + tm.assert_series_equal(result, expected) + + # ensure we haven't mutated anything inplace + tm.assert_series_equal( + a, + pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]"), + ) + tm.assert_series_equal( + b, pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]") + ) + + @pytest.mark.parametrize( + "other, expected", + [ + (None, [None, None, None]), + (pd.NA, [None, None, None]), + (True, [False, True, None]), + (np.bool_(True), [False, True, None]), + (np.bool_(False), [True, False, None]), + ], + ) + def test_kleene_xor_scalar(self, other, expected): + a = pd.Series([True, False, None], dtype="boolean[pyarrow]") + result = a ^ other + expected = pd.Series(expected, dtype="boolean[pyarrow]") + tm.assert_series_equal(result, expected) + + result = other ^ a + tm.assert_series_equal(result, expected) + + # ensure we haven't mutated anything inplace + tm.assert_series_equal( + a, pd.Series([True, False, None], dtype="boolean[pyarrow]") + ) + + @pytest.mark.parametrize( + "op, exp", + [ + ["__and__", True], + ["__or__", True], + ["__xor__", False], + ], + ) + def test_logical_masked_numpy(self, op, exp): + # GH 52625 + data = [True, False, None] + ser_masked = pd.Series(data, dtype="boolean") + ser_pa = pd.Series(data, dtype="boolean[pyarrow]") + result = getattr(ser_pa, op)(ser_masked) + expected = pd.Series([exp, False, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("pa_type", tm.ALL_INT_PYARROW_DTYPES) +def test_bitwise(pa_type): + # GH 54495 + dtype = ArrowDtype(pa_type) + left = pd.Series([1, None, 3, 4], dtype=dtype) + right = pd.Series([None, 3, 5, 4], dtype=dtype) + + result = left | right + expected = pd.Series([None, None, 3 | 5, 4 | 4], dtype=dtype) + tm.assert_series_equal(result, expected) + + result = left & right + expected = pd.Series([None, None, 3 & 5, 4 & 4], dtype=dtype) + tm.assert_series_equal(result, expected) + + result = left ^ right + expected = pd.Series([None, None, 3 ^ 5, 4 ^ 4], dtype=dtype) + tm.assert_series_equal(result, expected) + + result = ~left + expected = ~(left.fillna(0).to_numpy()) + expected = pd.Series(expected, dtype=dtype).mask(left.isnull()) + tm.assert_series_equal(result, expected) + + +def test_arrowdtype_construct_from_string_type_with_unsupported_parameters(): + with pytest.raises(NotImplementedError, match="Passing pyarrow type"): + ArrowDtype.construct_from_string("not_a_real_dype[s, tz=UTC][pyarrow]") + + with pytest.raises(NotImplementedError, match="Passing pyarrow type"): + ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]") + + +def test_arrowdtype_construct_from_string_supports_dt64tz(): + # as of GH#50689, timestamptz is supported + dtype = ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]") + expected = ArrowDtype(pa.timestamp("s", "UTC")) + assert dtype == expected + + +def test_arrowdtype_construct_from_string_type_only_one_pyarrow(): + # GH#51225 + invalid = "int64[pyarrow]foobar[pyarrow]" + msg = ( + r"Passing pyarrow type specific parameters \(\[pyarrow\]\) in the " + r"string is not supported\." + ) + with pytest.raises(NotImplementedError, match=msg): + pd.Series(range(3), dtype=invalid) + + +def test_arrow_string_multiplication(): + # GH 56537 + binary = pd.Series(["abc", "defg"], dtype=ArrowDtype(pa.string())) + repeat = pd.Series([2, -2], dtype="int64[pyarrow]") + result = binary * repeat + expected = pd.Series(["abcabc", ""], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + reflected_result = repeat * binary + tm.assert_series_equal(result, reflected_result) + + +def test_arrow_string_multiplication_scalar_repeat(): + binary = pd.Series(["abc", "defg"], dtype=ArrowDtype(pa.string())) + result = binary * 2 + expected = pd.Series(["abcabc", "defgdefg"], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + reflected_result = 2 * binary + tm.assert_series_equal(reflected_result, expected) + + +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"] +) +@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]]) +def test_quantile(data, interpolation, quantile, request): + pa_dtype = data.dtype.pyarrow_dtype + + data = data.take([0, 0, 0]) + ser = pd.Series(data) + + if ( + pa.types.is_string(pa_dtype) + or pa.types.is_binary(pa_dtype) + or pa.types.is_boolean(pa_dtype) + ): + # For string, bytes, and bool, we don't *expect* to have quantile work + # Note this matches the non-pyarrow behavior + msg = r"Function 'quantile' has no kernel matching input types \(.*\)" + with pytest.raises(pa.ArrowNotImplementedError, match=msg): + ser.quantile(q=quantile, interpolation=interpolation) + return + + if ( + pa.types.is_integer(pa_dtype) + or pa.types.is_floating(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): + pass + elif pa.types.is_temporal(data._pa_array.type): + pass + else: + request.applymarker( + pytest.mark.xfail( + raises=pa.ArrowNotImplementedError, + reason=f"quantile not supported by pyarrow for {pa_dtype}", + ) + ) + data = data.take([0, 0, 0]) + ser = pd.Series(data) + result = ser.quantile(q=quantile, interpolation=interpolation) + + if pa.types.is_timestamp(pa_dtype) and interpolation not in ["lower", "higher"]: + # rounding error will make the check below fail + # (e.g. '2020-01-01 01:01:01.000001' vs '2020-01-01 01:01:01.000001024'), + # so we'll check for now that we match the numpy analogue + if pa_dtype.tz: + pd_dtype = f"M8[{pa_dtype.unit}, {pa_dtype.tz}]" + else: + pd_dtype = f"M8[{pa_dtype.unit}]" + ser_np = ser.astype(pd_dtype) + + expected = ser_np.quantile(q=quantile, interpolation=interpolation) + if quantile == 0.5: + if pa_dtype.unit == "us": + expected = expected.to_pydatetime(warn=False) + assert result == expected + else: + if pa_dtype.unit == "us": + expected = expected.dt.floor("us") + tm.assert_series_equal(result, expected.astype(data.dtype)) + return + + if quantile == 0.5: + assert result == data[0] + else: + # Just check the values + expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5]) + if ( + pa.types.is_integer(pa_dtype) + or pa.types.is_floating(pa_dtype) + or pa.types.is_decimal(pa_dtype) + ): + expected = expected.astype("float64[pyarrow]") + result = result.astype("float64[pyarrow]") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "take_idx, exp_idx", + [[[0, 0, 2, 2, 4, 4], [4, 0]], [[0, 0, 0, 2, 4, 4], [0]]], + ids=["multi_mode", "single_mode"], +) +def test_mode_dropna_true(data_for_grouping, take_idx, exp_idx): + data = data_for_grouping.take(take_idx) + ser = pd.Series(data) + result = ser.mode(dropna=True) + expected = pd.Series(data_for_grouping.take(exp_idx)) + tm.assert_series_equal(result, expected) + + +def test_mode_dropna_false_mode_na(data): + # GH 50982 + more_nans = pd.Series([None, None, data[0]], dtype=data.dtype) + result = more_nans.mode(dropna=False) + expected = pd.Series([None], dtype=data.dtype) + tm.assert_series_equal(result, expected) + + expected = pd.Series([data[0], None], dtype=data.dtype) + result = expected.mode(dropna=False) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "arrow_dtype, expected_type", + [ + [pa.binary(), bytes], + [pa.binary(16), bytes], + [pa.large_binary(), bytes], + [pa.large_string(), str], + [pa.list_(pa.int64()), list], + [pa.large_list(pa.int64()), list], + [pa.map_(pa.string(), pa.int64()), list], + [pa.struct([("f1", pa.int8()), ("f2", pa.string())]), dict], + [pa.dictionary(pa.int64(), pa.int64()), CategoricalDtypeType], + ], +) +def test_arrow_dtype_type(arrow_dtype, expected_type): + # GH 51845 + # TODO: Redundant with test_getitem_scalar once arrow_dtype exists in data fixture + assert ArrowDtype(arrow_dtype).type == expected_type + + +def test_is_bool_dtype(): + # GH 22667 + data = ArrowExtensionArray(pa.array([True, False, True])) + assert is_bool_dtype(data) + assert pd.core.common.is_bool_indexer(data) + s = pd.Series(range(len(data))) + result = s[data] + expected = s[np.asarray(data)] + tm.assert_series_equal(result, expected) + + +def test_is_numeric_dtype(data): + # GH 50563 + pa_type = data.dtype.pyarrow_dtype + if ( + pa.types.is_floating(pa_type) + or pa.types.is_integer(pa_type) + or pa.types.is_decimal(pa_type) + ): + assert is_numeric_dtype(data) + else: + assert not is_numeric_dtype(data) + + +def test_is_integer_dtype(data): + # GH 50667 + pa_type = data.dtype.pyarrow_dtype + if pa.types.is_integer(pa_type): + assert is_integer_dtype(data) + else: + assert not is_integer_dtype(data) + + +def test_is_signed_integer_dtype(data): + pa_type = data.dtype.pyarrow_dtype + if pa.types.is_signed_integer(pa_type): + assert is_signed_integer_dtype(data) + else: + assert not is_signed_integer_dtype(data) + + +def test_is_unsigned_integer_dtype(data): + pa_type = data.dtype.pyarrow_dtype + if pa.types.is_unsigned_integer(pa_type): + assert is_unsigned_integer_dtype(data) + else: + assert not is_unsigned_integer_dtype(data) + + +def test_is_datetime64_any_dtype(data): + pa_type = data.dtype.pyarrow_dtype + if pa.types.is_timestamp(pa_type) or pa.types.is_date(pa_type): + assert is_datetime64_any_dtype(data) + else: + assert not is_datetime64_any_dtype(data) + + +def test_is_float_dtype(data): + pa_type = data.dtype.pyarrow_dtype + if pa.types.is_floating(pa_type): + assert is_float_dtype(data) + else: + assert not is_float_dtype(data) + + +def test_pickle_roundtrip(data): + # GH 42600 + expected = pd.Series(data) + expected_sliced = expected.head(2) + full_pickled = pickle.dumps(expected) + sliced_pickled = pickle.dumps(expected_sliced) + + assert len(full_pickled) > len(sliced_pickled) + + result = pickle.loads(full_pickled) + tm.assert_series_equal(result, expected) + + result_sliced = pickle.loads(sliced_pickled) + tm.assert_series_equal(result_sliced, expected_sliced) + + +def test_astype_from_non_pyarrow(data): + # GH49795 + np_arr = data.to_numpy() + pd_array = pd.array(np_arr, dtype=np_arr.dtype) + result = pd_array.astype(data.dtype) + assert not isinstance(pd_array.dtype, ArrowDtype) + assert isinstance(result.dtype, ArrowDtype) + tm.assert_extension_array_equal(result, data) + + +def test_astype_float_from_non_pyarrow_str(): + # GH50430 + ser = pd.Series(["1.0"]) + result = ser.astype("float64[pyarrow]") + expected = pd.Series([1.0], dtype="float64[pyarrow]") + tm.assert_series_equal(result, expected) + + +def test_astype_errors_ignore(): + # GH 55399 + expected = pd.DataFrame({"col": [17000000]}, dtype="int32[pyarrow]") + result = expected.astype("float[pyarrow]", errors="ignore") + tm.assert_frame_equal(result, expected) + + +def test_to_numpy_with_defaults(data, using_nan_is_na): + # GH49973 + result = data.to_numpy() + + pa_type = data._pa_array.type + if pa.types.is_duration(pa_type) or pa.types.is_timestamp(pa_type): + pytest.skip("Tested in test_to_numpy_temporal") + elif pa.types.is_date(pa_type): + expected = np.array(list(data)) + else: + expected = np.array(data._pa_array) + + if data._hasna and (not is_numeric_dtype(data.dtype) or not using_nan_is_na): + expected = expected.astype(object) + expected[pd.isna(data)] = pd.NA + + tm.assert_numpy_array_equal(result, expected) + + +def test_to_numpy_int_with_na(using_nan_is_na): + # GH51227: ensure to_numpy does not convert int to float + data = [1, None] + arr = pd.array(data, dtype="int64[pyarrow]") + result = arr.to_numpy() + if not using_nan_is_na: + expected = np.array([1, pd.NA], dtype=object) + else: + expected = np.array([1, np.nan]) + assert isinstance(result[0], float) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("na_val, exp", [(lib.no_default, np.nan), (1, 1)]) +def test_to_numpy_null_array(na_val, exp): + # GH#52443 + arr = pd.array([pd.NA, pd.NA], dtype="null[pyarrow]") + result = arr.to_numpy(dtype="float64", na_value=na_val) + expected = np.array([exp] * 2, dtype="float64") + tm.assert_numpy_array_equal(result, expected) + + +def test_to_numpy_null_array_no_dtype(): + # GH#52443 + arr = pd.array([pd.NA, pd.NA], dtype="null[pyarrow]") + result = arr.to_numpy(dtype=None) + expected = np.array([pd.NA] * 2, dtype="object") + tm.assert_numpy_array_equal(result, expected) + + +def test_to_numpy_without_dtype(): + # GH 54808 + arr = pd.array([True, pd.NA], dtype="boolean[pyarrow]") + result = arr.to_numpy(na_value=False) + expected = np.array([True, False], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + arr = pd.array([1.0, pd.NA], dtype="float32[pyarrow]") + result = arr.to_numpy(na_value=0.0) + expected = np.array([1.0, 0.0], dtype=np.float32) + tm.assert_numpy_array_equal(result, expected) + + +def test_setitem_null_slice(data): + # GH50248 + orig = data.copy() + + result = orig.copy() + result[:] = data[0] + expected = ArrowExtensionArray._from_sequence( + [data[0]] * len(data), + dtype=data.dtype, + ) + tm.assert_extension_array_equal(result, expected) + + result = orig.copy() + result[:] = data[::-1] + expected = data[::-1] + tm.assert_extension_array_equal(result, expected) + + result = orig.copy() + result[:] = data.tolist() + expected = data + tm.assert_extension_array_equal(result, expected) + + +def test_setitem_invalid_dtype(data): + # GH50248 + pa_type = data._pa_array.type + if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type): + fill_value = 123 + err = TypeError + msg = "Invalid value '123' for dtype" + elif ( + pa.types.is_integer(pa_type) + or pa.types.is_floating(pa_type) + or pa.types.is_boolean(pa_type) + ): + fill_value = "foo" + err = pa.ArrowInvalid + msg = "Could not convert" + else: + fill_value = "foo" + err = TypeError + msg = "Invalid value 'foo' for dtype" + with pytest.raises(err, match=msg): + data[:] = fill_value + + +def test_from_arrow_respecting_given_dtype(): + date_array = pa.array( + [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")], type=pa.date32() + ) + result = date_array.to_pandas( + types_mapper={pa.date32(): ArrowDtype(pa.date64())}.get + ) + expected = pd.Series( + [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")], + dtype=ArrowDtype(pa.date64()), + ) + tm.assert_series_equal(result, expected) + + +def test_from_arrow_respecting_given_dtype_unsafe(): + array = pa.array([1.5, 2.5], type=pa.float64()) + with tm.external_error_raised(pa.ArrowInvalid): + array.to_pandas(types_mapper={pa.float64(): ArrowDtype(pa.int64())}.get) + + +def test_round(): + dtype = "float64[pyarrow]" + + ser = pd.Series([0.0, 1.23, 2.56, pd.NA], dtype=dtype) + result = ser.round(1) + expected = pd.Series([0.0, 1.2, 2.6, pd.NA], dtype=dtype) + tm.assert_series_equal(result, expected) + + ser = pd.Series([123.4, pd.NA, 56.78], dtype=dtype) + result = ser.round(-1) + expected = pd.Series([120.0, pd.NA, 60.0], dtype=dtype) + tm.assert_series_equal(result, expected) + + +def test_searchsorted_with_na_raises(data_for_sorting, as_series): + # GH50447 + b, c, a = data_for_sorting + arr = data_for_sorting.take([2, 0, 1]) # to get [a, b, c] + arr[-1] = pd.NA + + if as_series: + arr = pd.Series(arr) + + msg = ( + "searchsorted requires array to be sorted, " + "which is impossible with NAs present." + ) + with pytest.raises(ValueError, match=msg): + arr.searchsorted(b) + + +def test_sort_values_dictionary(): + df = pd.DataFrame( + { + "a": pd.Series( + ["x", "y"], dtype=ArrowDtype(pa.dictionary(pa.int32(), pa.string())) + ), + "b": [1, 2], + }, + ) + expected = df.copy() + result = df.sort_values(by=["a", "b"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("pat", ["abc", "a[a-z]{2}"]) +def test_str_count(pat): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.count(pat) + expected = pd.Series([1, None], dtype=ArrowDtype(pa.int32())) + tm.assert_series_equal(result, expected) + + +def test_str_count_flags_unsupported(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match="count not"): + ser.str.count("abc", flags=1) + + +@pytest.mark.parametrize( + "side, str_func", [["left", "rjust"], ["right", "ljust"], ["both", "center"]] +) +def test_str_pad(side, str_func): + ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string())) + result = ser.str.pad(width=3, side=side, fillchar="x") + expected = pd.Series( + [getattr("a", str_func)(3, "x"), None], dtype=ArrowDtype(pa.string()) + ) + tm.assert_series_equal(result, expected) + + +def test_str_pad_invalid_side(): + ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(ValueError, match="Invalid side: foo"): + ser.str.pad(3, "foo", "x") + + +@pytest.mark.parametrize( + "pat, case, na, regex, exp", + [ + ["ab", False, None, False, [True, None]], + ["Ab", True, None, False, [False, None]], + ["ab", False, True, False, [True, True]], + ["a[a-z]{1}", False, None, True, [True, None]], + ["A[a-z]{1}", True, None, True, [False, None]], + ], +) +def test_str_contains(pat, case, na, regex, exp): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.contains(pat, case=case, na=na, regex=regex) + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +def test_str_contains_flags_unsupported(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match="contains not"): + ser.str.contains("a", flags=1) + + +def test_str_contains_re2_unicode_escape(): + # GH 63901 + ser = pd.Series(["a", "\u0e01", None], dtype=ArrowDtype(pa.string())) + result = ser.str.contains(r"[\x{0e00}-\x{0e7f}]") + expected = pd.Series([False, True, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "side, pat, na, exp", + [ + ["startswith", "ab", None, [True, None, False]], + ["startswith", "b", False, [False, False, False]], + ["endswith", "b", True, [False, True, False]], + ["endswith", "bc", None, [True, None, False]], + ["startswith", ("a", "e", "g"), None, [True, None, True]], + ["endswith", ("a", "c", "g"), None, [True, None, True]], + ["startswith", (), None, [False, None, False]], + ["endswith", (), None, [False, None, False]], + ], +) +def test_str_start_ends_with(side, pat, na, exp): + ser = pd.Series(["abc", None, "efg"], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, side)(pat, na=na) + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("side", ("startswith", "endswith")) +def test_str_starts_ends_with_all_nulls_empty_tuple(side): + ser = pd.Series([None, None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, side)(()) + + # bool datatype preserved for all nulls. + expected = pd.Series([None, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "arg_name, arg", + [["pat", re.compile("b")], ["repl", str], ["case", False], ["flags", 1]], +) +def test_str_replace_unsupported(arg_name, arg): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + kwargs = {"pat": "b", "repl": "x", "regex": True} + kwargs[arg_name] = arg + with pytest.raises(NotImplementedError, match="replace is not supported"): + ser.str.replace(**kwargs) + + +@pytest.mark.parametrize( + "pat, repl, n, regex, exp", + [ + ["a", "x", -1, False, ["xbxc", None]], + ["a", "x", 1, False, ["xbac", None]], + ["[a-b]", "x", -1, True, ["xxxc", None]], + ], +) +def test_str_replace(pat, repl, n, regex, exp): + ser = pd.Series(["abac", None], dtype=ArrowDtype(pa.string())) + result = ser.str.replace(pat, repl, n=n, regex=regex) + expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_str_replace_re2_unicode_property(): + ser = pd.Series(["Jan", "Feb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.replace(r"\p{Lu}", "U", regex=True) + expected = pd.Series(["Uan", "Ueb", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_str_replace_negative_n(): + # GH 56404 + ser = pd.Series(["abc", "aaaaaa"], dtype=ArrowDtype(pa.string())) + actual = ser.str.replace("a", "", -3, True) + expected = pd.Series(["bc", ""], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(expected, actual) + + # Same bug for pyarrow-backed StringArray GH#59628 + ser2 = ser.astype(pd.StringDtype(storage="pyarrow")) + actual2 = ser2.str.replace("a", "", -3, True) + expected2 = expected.astype(ser2.dtype) + tm.assert_series_equal(expected2, actual2) + + ser3 = ser.astype(pd.StringDtype(storage="pyarrow", na_value=np.nan)) + actual3 = ser3.str.replace("a", "", -3, True) + expected3 = expected.astype(ser3.dtype) + tm.assert_series_equal(expected3, actual3) + + +def test_str_repeat_unsupported(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + with pytest.raises(NotImplementedError, match="repeat is not"): + ser.str.repeat([1, 2]) + + +def test_str_repeat(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.repeat(2) + expected = pd.Series(["abcabc", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pat, case, na, exp", + [ + ["ab", False, None, [True, None]], + ["Ab", True, None, [False, None]], + ["bc", True, None, [False, None]], + ["ab", False, True, [True, True]], + ["a[a-z]{1}", False, None, [True, None]], + ["A[a-z]{1}", True, None, [False, None]], + ], +) +def test_str_match(pat, case, na, exp): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.match(pat, case=case, na=na) + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pat, case, na, exp", + # Note: keep cases in sync with + # pandas/tests/strings/test_find_replace.py::test_str_fullmatch_extra_cases + [ + ["abc", False, None, [True, False, False, None]], + ["Abc", True, None, [False, False, False, None]], + ["bc", True, None, [False, False, False, None]], + ["ab", False, None, [False, False, False, None]], + ["a[a-z]{2}", False, None, [True, False, False, None]], + ["A[a-z]{1}", True, None, [False, False, False, None]], + # GH Issue: #56652 + ["abc$", False, None, [True, False, False, None]], + ["abc\\$", False, None, [False, True, False, None]], + ["Abc$", True, None, [False, False, False, None]], + ["Abc\\$", True, None, [False, False, False, None]], + # https://github.com/pandas-dev/pandas/issues/61072 + ["(abc)|(abx)", True, None, [True, False, False, None]], + ["((abc)|(abx))", True, None, [True, False, False, None]], + ], +) +def test_str_fullmatch(pat, case, na, exp): + ser = pd.Series(["abc", "abc$", "$abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.fullmatch(pat, case=case, na=na) + expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "sub, start, end, exp, exp_type", + [ + ["ab", 0, None, [0, None], pa.int32()], + ["bc", 1, 3, [1, None], pa.int64()], + ["ab", 1, 3, [-1, None], pa.int64()], + ["ab", -3, -3, [-1, None], pa.int64()], + ], +) +def test_str_find(sub, start, end, exp, exp_type): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub, start=start, end=end) + expected = pd.Series(exp, dtype=ArrowDtype(exp_type)) + tm.assert_series_equal(result, expected) + + +def test_str_find_negative_start(): + # GH 56411 + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub="b", start=-1000, end=3) + expected = pd.Series([1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +def test_str_find_no_end(): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find("ab", start=1) + expected = pd.Series([-1, None], dtype="int64[pyarrow]") + tm.assert_series_equal(result, expected) + + +def test_str_find_negative_start_negative_end(): + # GH 56791 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub="d", start=-6, end=-3) + expected = pd.Series([3, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +def test_str_find_large_start(): + # GH 56791 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub="d", start=16) + expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("start", [-15, -3, 0, 1, 15, None]) +@pytest.mark.parametrize("end", [-15, -1, 0, 3, 15, None]) +@pytest.mark.parametrize("sub", ["", "az", "abce", "a", "caa"]) +def test_str_find_e2e(start, end, sub): + s = pd.Series( + ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""], + dtype=ArrowDtype(pa.string()), + ) + object_series = s.astype(pd.StringDtype(storage="python")) + result = s.str.find(sub, start, end) + expected = object_series.str.find(sub, start, end).astype(result.dtype) + tm.assert_series_equal(result, expected) + + arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow")) + result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype) + tm.assert_series_equal(result2, expected) + + +def test_str_find_negative_start_negative_end_no_match(): + # GH 56791 + ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.find(sub="d", start=-3, end=-6) + expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "i, exp", + [ + [1, ["b", "e", None]], + [-1, ["c", "e", None]], + [2, ["c", None, None]], + [-3, ["a", None, None]], + [4, [None, None, None]], + ], +) +def test_str_get(i, exp): + ser = pd.Series(["abc", "de", None], dtype=ArrowDtype(pa.string())) + result = ser.str.get(i) + expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.xfail( + reason="TODO: StringMethods._validate should support Arrow list types", + raises=AttributeError, +) +def test_str_join(): + ser = pd.Series(ArrowExtensionArray(pa.array([list("abc"), list("123"), None]))) + result = ser.str.join("=") + expected = pd.Series(["a=b=c", "1=2=3", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_str_join_string_type(): + ser = pd.Series(ArrowExtensionArray(pa.array(["abc", "123", None]))) + result = ser.str.join("=") + expected = pd.Series(["a=b=c", "1=2=3", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start, stop, step, exp", + [ + [None, 2, None, ["ab", None]], + [None, 2, 1, ["ab", None]], + [1, 3, 1, ["bc", None]], + (None, None, -1, ["dcba", None]), + ], +) +def test_str_slice(start, stop, step, exp): + ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string())) + result = ser.str.slice(start, stop, step) + expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start, stop, repl, exp", + [ + [1, 2, "x", ["axcd", None]], + [None, 2, "x", ["xcd", None]], + [None, 2, None, ["cd", None]], + ], +) +def test_str_slice_replace(start, stop, repl, exp): + ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string())) + result = ser.str.slice_replace(start, stop, repl) + expected = pd.Series(exp, dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "value, method, exp", + [ + ["a1c", "isalnum", True], + ["!|,", "isalnum", False], + ["aaa", "isalpha", True], + ["!!!", "isalpha", False], + ["٠", "isdecimal", True], # noqa: RUF001 + ["~!", "isdecimal", False], + ["2", "isdigit", True], + ["~", "isdigit", False], + ["aaa", "islower", True], + ["aaA", "islower", False], + ["123", "isnumeric", True], + ["11I", "isnumeric", False], + [" ", "isspace", True], + ["", "isspace", False], + ["The That", "istitle", True], + ["the That", "istitle", False], + ["AAA", "isupper", True], + ["AAc", "isupper", False], + ], +) +def test_str_is_functions(value, method, exp): + ser = pd.Series([value, None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)() + expected = pd.Series([exp, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", + [ + ["capitalize", "Abc def"], + ["title", "Abc Def"], + ["swapcase", "AbC Def"], + ["lower", "abc def"], + ["upper", "ABC DEF"], + ["casefold", "abc def"], + ], +) +def test_str_transform_functions(method, exp): + ser = pd.Series(["aBc dEF", None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)() + expected = pd.Series([exp, None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_str_len(): + ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string())) + result = ser.str.len() + expected = pd.Series([4, None], dtype=ArrowDtype(pa.int32())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, to_strip, val", + [ + ["strip", None, " abc "], + ["strip", "x", "xabcx"], + ["lstrip", None, " abc"], + ["lstrip", "x", "xabc"], + ["rstrip", None, "abc "], + ["rstrip", "x", "abcx"], + ], +) +def test_str_strip(method, to_strip, val): + ser = pd.Series([val, None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)(to_strip=to_strip) + expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("val", ["abc123", "abc"]) +def test_str_removesuffix(val): + ser = pd.Series([val, None], dtype=ArrowDtype(pa.string())) + result = ser.str.removesuffix("123") + expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("val", ["123abc", "abc"]) +def test_str_removeprefix(val): + ser = pd.Series([val, None], dtype=ArrowDtype(pa.string())) + result = ser.str.removeprefix("123") + expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("errors", ["ignore", "strict"]) +@pytest.mark.parametrize( + "encoding, exp", + [ + ("utf8", {"little": b"abc", "big": "abc"}), + ( + "utf32", + { + "little": b"\xff\xfe\x00\x00a\x00\x00\x00b\x00\x00\x00c\x00\x00\x00", + "big": b"\x00\x00\xfe\xff\x00\x00\x00a\x00\x00\x00b\x00\x00\x00c", + }, + ), + ], + ids=["utf8", "utf32"], +) +def test_str_encode(errors, encoding, exp): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.encode(encoding, errors) + expected = pd.Series([exp[sys.byteorder], None], dtype=ArrowDtype(pa.binary())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("flags", [0, 2]) +def test_str_findall(flags): + ser = pd.Series(["abc", "efg", None], dtype=ArrowDtype(pa.string())) + result = ser.str.findall("b", flags=flags) + expected = pd.Series([["b"], [], None], dtype=ArrowDtype(pa.list_(pa.string()))) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["index", "rindex"]) +@pytest.mark.parametrize( + "start, end", + [ + [0, None], + [1, 4], + ], +) +def test_str_r_index(method, start, end): + ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)("c", start, end) + expected = pd.Series([2, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + with pytest.raises(ValueError, match="substring not found"): + getattr(ser.str, method)("foo", start, end) + + +@pytest.mark.parametrize("form", ["NFC", "NFKC"]) +def test_str_normalize(form): + ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string())) + result = ser.str.normalize(form) + expected = ser.copy() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start, end", + [ + [0, None], + [1, 4], + ], +) +def test_str_rfind(start, end): + ser = pd.Series(["abcba", "foo", None], dtype=ArrowDtype(pa.string())) + result = ser.str.rfind("c", start, end) + expected = pd.Series([2, -1, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +def test_str_translate(): + ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) + result = ser.str.translate({97: "b"}) + expected = pd.Series(["bbcbb", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_str_wrap(): + ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) + result = ser.str.wrap(3) + expected = pd.Series(["abc\nba", None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_get_dummies(): + ser = pd.Series(["a|b", None, "a|c"], dtype=ArrowDtype(pa.string())) + result = ser.str.get_dummies() + expected = pd.DataFrame( + [[True, True, False], [False, False, False], [True, False, True]], + dtype=ArrowDtype(pa.bool_()), + columns=["a", "b", "c"], + ) + tm.assert_frame_equal(result, expected) + + +def test_str_partition(): + ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string())) + result = ser.str.partition("b") + expected = pd.DataFrame( + [["a", "b", "cba"], [None, None, None]], + dtype=ArrowDtype(pa.string()), + columns=pd.RangeIndex(3), + ) + tm.assert_frame_equal(result, expected, check_column_type=True) + + result = ser.str.partition("b", expand=False) + expected = pd.Series(ArrowExtensionArray(pa.array([["a", "b", "cba"], None]))) + tm.assert_series_equal(result, expected) + + result = ser.str.rpartition("b") + expected = pd.DataFrame( + [["abc", "b", "a"], [None, None, None]], + dtype=ArrowDtype(pa.string()), + columns=pd.RangeIndex(3), + ) + tm.assert_frame_equal(result, expected, check_column_type=True) + + result = ser.str.rpartition("b", expand=False) + expected = pd.Series(ArrowExtensionArray(pa.array([["abc", "b", "a"], None]))) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["rsplit", "split"]) +def test_str_split_pat_none(method): + # GH 56271 + ser = pd.Series(["a1 cbc\nb", None], dtype=ArrowDtype(pa.string())) + result = getattr(ser.str, method)() + expected = pd.Series(ArrowExtensionArray(pa.array([["a1", "cbc", "b"], None]))) + tm.assert_series_equal(result, expected) + + +def test_str_split(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.split("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.split("[1-2]", regex=True, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])), + } + ) + tm.assert_frame_equal(result, expected) + + result = ser.str.split("1", expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", None, None])), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_str_rsplit(): + # GH 52401 + ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string())) + result = ser.str.rsplit("c") + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1) + expected = pd.Series( + ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None])) + ) + tm.assert_series_equal(result, expected) + + result = ser.str.rsplit("c", n=1, expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])), + 1: ArrowExtensionArray(pa.array(["b", "b", None])), + } + ) + tm.assert_frame_equal(result, expected) + + result = ser.str.rsplit("1", expand=True) + expected = pd.DataFrame( + { + 0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])), + 1: ArrowExtensionArray(pa.array(["cbcb", None, None])), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_str_extract_non_symbolic(): + ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string())) + with pytest.raises(ValueError, match="pat=.* must contain a symbolic group name."): + ser.str.extract(r"[ab](\d)") + + +@pytest.mark.parametrize("expand", [True, False]) +def test_str_extract(expand): + ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string())) + result = ser.str.extract(r"(?P[ab])(?P\d)", expand=expand) + expected = pd.DataFrame( + { + "letter": ArrowExtensionArray(pa.array(["a", "b", None])), + "digit": ArrowExtensionArray(pa.array(["1", "2", None])), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_str_extract_expand(): + ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string())) + result = ser.str.extract(r"[ab](?P\d)", expand=True) + expected = pd.DataFrame( + { + "digit": ArrowExtensionArray(pa.array(["1", "2", None])), + } + ) + tm.assert_frame_equal(result, expected) + + result = ser.str.extract(r"[ab](?P\d)", expand=False) + expected = pd.Series(ArrowExtensionArray(pa.array(["1", "2", None])), name="digit") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"]) +def test_duration_from_strings_with_nat(unit): + # GH51175 + strings = ["1000", "NaT"] + pa_type = pa.duration(unit) + dtype = ArrowDtype(pa_type) + result = ArrowExtensionArray._from_sequence_of_strings(strings, dtype=dtype) + expected = ArrowExtensionArray(pa.array([1000, None], type=pa_type)) + tm.assert_extension_array_equal(result, expected) + + +def test_unsupported_dt(data): + pa_dtype = data.dtype.pyarrow_dtype + if not pa.types.is_temporal(pa_dtype): + with pytest.raises( + AttributeError, match="Can only use .dt accessor with datetimelike values" + ): + pd.Series(data).dt + + +@pytest.mark.parametrize( + "prop, expected", + [ + ["year", 2023], + ["day", 2], + ["day_of_week", 0], + ["dayofweek", 0], + ["weekday", 0], + ["day_of_year", 2], + ["dayofyear", 2], + ["hour", 3], + ["minute", 4], + ["is_leap_year", False], + ["microsecond", 2000], + ["month", 1], + ["nanosecond", 6], + ["quarter", 1], + ["second", 7], + ["date", date(2023, 1, 2)], + ["time", time(3, 4, 7, 2000)], + ], +) +def test_dt_properties(prop, expected): + ser = pd.Series( + [ + pd.Timestamp( + year=2023, + month=1, + day=2, + hour=3, + minute=4, + second=7, + microsecond=2000, + nanosecond=6, + ), + None, + ], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + result = getattr(ser.dt, prop) + exp_type = None + if isinstance(expected, date): + exp_type = pa.date32() + elif isinstance(expected, time): + exp_type = pa.time64("ns") + expected = pd.Series(ArrowExtensionArray(pa.array([expected, None], type=exp_type))) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("microsecond", [2000, 5, 0]) +def test_dt_microsecond(microsecond): + # GH 59183 + ser = pd.Series( + [ + pd.Timestamp( + year=2024, + month=7, + day=7, + second=5, + microsecond=microsecond, + nanosecond=6, + ), + None, + ], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + result = ser.dt.microsecond + expected = pd.Series([microsecond, None], dtype="int64[pyarrow]") + tm.assert_series_equal(result, expected) + + +def test_dt_is_month_start_end(): + ser = pd.Series( + [ + datetime(year=2023, month=12, day=2, hour=3), + datetime(year=2023, month=1, day=1, hour=3), + datetime(year=2023, month=3, day=31, hour=3), + None, + ], + dtype=ArrowDtype(pa.timestamp("us")), + ) + result = ser.dt.is_month_start + expected = pd.Series([False, True, False, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + result = ser.dt.is_month_end + expected = pd.Series([False, False, True, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +def test_dt_is_year_start_end(): + ser = pd.Series( + [ + datetime(year=2023, month=12, day=31, hour=3), + datetime(year=2023, month=1, day=1, hour=3), + datetime(year=2023, month=3, day=31, hour=3), + None, + ], + dtype=ArrowDtype(pa.timestamp("us")), + ) + result = ser.dt.is_year_start + expected = pd.Series([False, True, False, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + result = ser.dt.is_year_end + expected = pd.Series([True, False, False, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +def test_dt_is_quarter_start_end(): + ser = pd.Series( + [ + datetime(year=2023, month=11, day=30, hour=3), + datetime(year=2023, month=1, day=1, hour=3), + datetime(year=2023, month=3, day=31, hour=3), + None, + ], + dtype=ArrowDtype(pa.timestamp("us")), + ) + result = ser.dt.is_quarter_start + expected = pd.Series([False, True, False, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + result = ser.dt.is_quarter_end + expected = pd.Series([False, False, True, None], dtype=ArrowDtype(pa.bool_())) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["days_in_month", "daysinmonth"]) +def test_dt_days_in_month(method): + ser = pd.Series( + [ + datetime(year=2023, month=3, day=30, hour=3), + datetime(year=2023, month=4, day=1, hour=3), + datetime(year=2023, month=2, day=3, hour=3), + None, + ], + dtype=ArrowDtype(pa.timestamp("us")), + ) + result = getattr(ser.dt, method) + expected = pd.Series([31, 30, 28, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) + + +def test_dt_normalize(): + ser = pd.Series( + [ + datetime(year=2023, month=3, day=30), + datetime(year=2023, month=4, day=1, hour=3), + datetime(year=2023, month=2, day=3, hour=23, minute=59, second=59), + None, + ], + dtype=ArrowDtype(pa.timestamp("us")), + ) + result = ser.dt.normalize() + expected = pd.Series( + [ + datetime(year=2023, month=3, day=30), + datetime(year=2023, month=4, day=1), + datetime(year=2023, month=2, day=3), + None, + ], + dtype=ArrowDtype(pa.timestamp("us")), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("unit", ["us", "ns"]) +def test_dt_time_preserve_unit(unit): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp(unit)), + ) + assert ser.dt.unit == unit + + result = ser.dt.time + expected = pd.Series( + ArrowExtensionArray(pa.array([time(3, 0), None], type=pa.time64(unit))) + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("tz", [None, "UTC", "US/Pacific"]) +def test_dt_tz(tz): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns", tz=tz)), + ) + result = ser.dt.tz + assert result == timezones.maybe_get_tz(tz) + + +def test_dt_isocalendar(): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + result = ser.dt.isocalendar() + expected = pd.DataFrame( + [[2023, 1, 1], [0, 0, 0]], + columns=["year", "week", "day"], + dtype="int64[pyarrow]", + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "method, exp", [["day_name", "Sunday"], ["month_name", "January"]] +) +def test_dt_day_month_name(method, exp, request): + # GH 52388 + _require_timezone_database(request) + + ser = pd.Series([datetime(2023, 1, 1), None], dtype=ArrowDtype(pa.timestamp("ms"))) + result = getattr(ser.dt, method)() + expected = pd.Series([exp, None], dtype=ArrowDtype(pa.string())) + tm.assert_series_equal(result, expected) + + +def test_dt_strftime(request): + _require_timezone_database(request) + + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + result = ser.dt.strftime("%Y-%m-%dT%H:%M:%S") + expected = pd.Series( + ["2023-01-02T03:00:00.000000000", None], dtype=ArrowDtype(pa.string()) + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["ceil", "floor", "round"]) +def test_dt_roundlike_tz_options_not_supported(method): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + with pytest.raises(NotImplementedError, match="ambiguous is not supported."): + getattr(ser.dt, method)("1h", ambiguous="NaT") + + with pytest.raises(NotImplementedError, match="nonexistent is not supported."): + getattr(ser.dt, method)("1h", nonexistent="NaT") + + +@pytest.mark.parametrize("method", ["ceil", "floor", "round"]) +def test_dt_roundlike_unsupported_freq(method): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + with pytest.raises(ValueError, match="freq='1B' is not supported"): + getattr(ser.dt, method)("1B") + + with pytest.raises(ValueError, match="Must specify a valid frequency: None"): + getattr(ser.dt, method)(None) + + +@pytest.mark.parametrize("freq", ["D", "h", "min", "s", "ms", "us", "ns"]) +@pytest.mark.parametrize("method", ["ceil", "floor", "round"]) +def test_dt_ceil_year_floor(freq, method): + ser = pd.Series( + [datetime(year=2023, month=1, day=1), None], + ) + pa_dtype = ArrowDtype(pa.timestamp("ns")) + expected = getattr(ser.dt, method)(f"1{freq}").astype(pa_dtype) + result = getattr(ser.astype(pa_dtype).dt, method)(f"1{freq}") + tm.assert_series_equal(result, expected) + + +def test_dt_to_pydatetime(): + # GH 51859 + data = [datetime(2022, 1, 1), datetime(2023, 1, 1)] + ser = pd.Series(data, dtype=ArrowDtype(pa.timestamp("ns"))) + result = ser.dt.to_pydatetime() + expected = pd.Series(data, dtype=object) + tm.assert_series_equal(result, expected) + assert all(type(expected.iloc[i]) is datetime for i in range(len(expected))) + + expected = ser.astype("datetime64[ns]").dt.to_pydatetime() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("date_type", [32, 64]) +def test_dt_to_pydatetime_date_error(date_type): + # GH 52812 + ser = pd.Series( + [date(2022, 12, 31)], + dtype=ArrowDtype(getattr(pa, f"date{date_type}")()), + ) + with pytest.raises(ValueError, match="to_pydatetime cannot be called with"): + ser.dt.to_pydatetime() + + +def test_dt_tz_localize_unsupported_tz_options(): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + with pytest.raises(NotImplementedError, match="ambiguous='NaT' is not supported"): + ser.dt.tz_localize("UTC", ambiguous="NaT") + + with pytest.raises(NotImplementedError, match="nonexistent='NaT' is not supported"): + ser.dt.tz_localize("UTC", nonexistent="NaT") + + +def test_dt_tz_localize_none(request): + _require_timezone_database(request) + + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns", tz="US/Pacific")), + ) + result = ser.dt.tz_localize(None) + expected = pd.Series( + [ser[0].tz_localize(None), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("unit", ["us", "ns"]) +def test_dt_tz_localize(unit, request): + _require_timezone_database(request) + + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp(unit)), + ) + result = ser.dt.tz_localize("US/Pacific") + exp_data = pa.array( + [datetime(year=2023, month=1, day=2, hour=3), None], type=pa.timestamp(unit) + ) + exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific") + expected = pd.Series(ArrowExtensionArray(exp_data)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "nonexistent, exp_date", + [ + ["shift_forward", datetime(year=2023, month=3, day=12, hour=3)], + ["shift_backward", pd.Timestamp("2023-03-12 01:59:59.999999999")], + ], +) +def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request): + _require_timezone_database(request) + + ser = pd.Series( + [datetime(year=2023, month=3, day=12, hour=2, minute=30), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + result = ser.dt.tz_localize("US/Pacific", nonexistent=nonexistent) + exp_data = pa.array([exp_date, None], type=pa.timestamp("ns")) + exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific") + expected = pd.Series(ArrowExtensionArray(exp_data)) + tm.assert_series_equal(result, expected) + + +def test_dt_tz_convert_not_tz_raises(): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + with pytest.raises(TypeError, match="Cannot convert tz-naive timestamps"): + ser.dt.tz_convert("UTC") + + +def test_dt_tz_convert_none(): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp("ns", "US/Pacific")), + ) + result = ser.dt.tz_convert(None) + expected = pd.Series( + [ser[0].tz_convert(None), None], + dtype=ArrowDtype(pa.timestamp("ns")), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("unit", ["us", "ns"]) +def test_dt_tz_convert(unit): + ser = pd.Series( + [datetime(year=2023, month=1, day=2, hour=3), None], + dtype=ArrowDtype(pa.timestamp(unit, "US/Pacific")), + ) + result = ser.dt.tz_convert("US/Eastern") + expected = pd.Series( + [ser[0].tz_convert("US/Eastern"), None], + dtype=ArrowDtype(pa.timestamp(unit, "US/Eastern")), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["timestamp[ms][pyarrow]", "duration[ms][pyarrow]"]) +def test_as_unit(dtype): + # GH 52284 + ser = pd.Series([1000, None], dtype=dtype) + result = ser.dt.as_unit("ns") + expected = ser.astype(dtype.replace("ms", "ns")) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "from_unit,to_unit", + [ + ("ns", "us"), + ("ns", "ms"), + ("ns", "s"), + ("us", "ms"), + ("us", "s"), + ("ms", "s"), + ("s", "ms"), + ("s", "us"), + ("s", "ns"), + ("ms", "us"), + ("ms", "ns"), + ("us", "ns"), + ], +) +def test_as_unit_duration_truncation(from_unit, to_unit): + # Test that as_unit truncates correctly (matches NumPy behavior) + # Value with sub-unit precision to test truncation + ser_numpy = pd.Series( + pd.to_timedelta([93784567890123, None], unit="ns").as_unit(from_unit) + ) + ser_arrow = ser_numpy.astype(f"duration[{from_unit}][pyarrow]") + + result = ser_arrow.dt.as_unit(to_unit) + expected = ser_numpy.dt.as_unit(to_unit).astype(f"duration[{to_unit}][pyarrow]") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "from_unit,to_unit", + [ + ("ns", "us"), + ("ns", "ms"), + ("ns", "s"), + ("s", "ns"), + ("ms", "ns"), + ("us", "ns"), + ], +) +def test_as_unit_timestamp(from_unit, to_unit): + # Test timestamp as_unit matches NumPy behavior + # Create Arrow series directly to preserve nulls correctly + ser_arrow = pd.Series( + [pd.Timestamp("2024-01-15 12:30:45.123456789"), None], + dtype=f"timestamp[{from_unit}][pyarrow]", + ) + ser_numpy = ser_arrow.astype(f"datetime64[{from_unit}]") + + result = ser_arrow.dt.as_unit(to_unit) + expected_numpy = ser_numpy.dt.as_unit(to_unit) + # Compare values (excluding null handling differences) + tm.assert_almost_equal( + result.dropna().to_numpy(dtype=f"datetime64[{to_unit}]"), + expected_numpy.dropna().to_numpy(), + ) + # Verify nulls are preserved + assert result.isna().sum() == ser_arrow.isna().sum() + + +@pytest.mark.parametrize("to_unit", ["s", "ms", "us", "ns"]) +def test_as_unit_timestamp_with_timezone(to_unit): + # Test that timezone is preserved + ser_numpy = pd.Series( + pd.to_datetime(["2024-01-15 12:30:45.123456789"]) + .tz_localize("US/Eastern") + .as_unit("ns") + ) + ser_arrow = ser_numpy.astype("timestamp[ns, US/Eastern][pyarrow]") + + result = ser_arrow.dt.as_unit(to_unit) + expected = ser_numpy.dt.as_unit(to_unit).astype( + f"timestamp[{to_unit}, US/Eastern][pyarrow]" + ) + tm.assert_series_equal(result, expected) + assert str(result.dtype) == f"timestamp[{to_unit}, tz=US/Eastern][pyarrow]" + + +def test_as_unit_date_raises(): + # as_unit should raise for date types + ser = pd.Series([1, 2], dtype=ArrowDtype(pa.date32())) + with pytest.raises(NotImplementedError, match="as_unit not implemented"): + ser.dt.as_unit("ns") + + +@pytest.mark.parametrize( + "prop, expected", + [ + ["days", 1], + ["seconds", 2], + ["microseconds", 3], + ["nanoseconds", 4], + ], +) +def test_dt_timedelta_properties(prop, expected): + # GH 52284 + ser = pd.Series( + [ + pd.Timedelta( + days=1, + seconds=2, + microseconds=3, + nanoseconds=4, + ), + None, + ], + dtype=ArrowDtype(pa.duration("ns")), + ) + result = getattr(ser.dt, prop) + expected = pd.Series( + ArrowExtensionArray(pa.array([expected, None], type=pa.int32())) + ) + tm.assert_series_equal(result, expected) + + +def test_dt_timedelta_total_seconds(): + # GH 52284 + ser = pd.Series( + [ + pd.Timedelta( + days=1, + seconds=2, + microseconds=3, + nanoseconds=4, + ), + None, + ], + dtype=ArrowDtype(pa.duration("ns")), + ) + result = ser.dt.total_seconds() + expected = pd.Series( + ArrowExtensionArray(pa.array([86402.000003, None], type=pa.float64())) + ) + tm.assert_series_equal(result, expected) + + +def test_dt_to_pytimedelta(): + # GH 52284 + data = [timedelta(1, 2, 3), timedelta(1, 2, 4)] + ser = pd.Series(data, dtype=ArrowDtype(pa.duration("ns"))) + + msg = "The behavior of ArrowTemporalProperties.to_pytimedelta is deprecated" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + result = ser.dt.to_pytimedelta() + expected = np.array(data, dtype=object) + tm.assert_numpy_array_equal(result, expected) + assert all(type(res) is timedelta for res in result) + + msg = "The behavior of TimedeltaProperties.to_pytimedelta is deprecated" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + expected = ser.astype("timedelta64[ns]").dt.to_pytimedelta() + tm.assert_numpy_array_equal(result, expected) + + +def test_dt_components(): + # GH 52284 + ser = pd.Series( + [ + pd.Timedelta( + days=1, + seconds=2, + microseconds=3, + nanoseconds=4, + ), + None, + ], + dtype=ArrowDtype(pa.duration("ns")), + ) + result = ser.dt.components + expected = pd.DataFrame( + [[1, 0, 0, 2, 0, 3, 4], [pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA]], + columns=[ + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", + "nanoseconds", + ], + dtype="int32[pyarrow]", + ) + tm.assert_frame_equal(result, expected) + + +def test_dt_components_large_values(): + ser = pd.Series( + [ + pd.Timedelta("365 days 23:59:59.999000"), + None, + ], + dtype=ArrowDtype(pa.duration("ns")), + ) + result = ser.dt.components + expected = pd.DataFrame( + [ + [365, 23, 59, 59, 999, 0, 0], + [pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA, pd.NA], + ], + columns=[ + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", + "nanoseconds", + ], + dtype="int32[pyarrow]", + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("skipna", [True, False]) +def test_boolean_reduce_series_all_null(all_boolean_reductions, skipna): + # GH51624 + ser = pd.Series([None], dtype="float64[pyarrow]") + result = getattr(ser, all_boolean_reductions)(skipna=skipna) + if skipna: + expected = all_boolean_reductions == "all" + else: + expected = pd.NA + assert result is expected + + +def test_from_sequence_of_strings_boolean(): + true_strings = ["true", "TRUE", "True", "1", "1.0"] + false_strings = ["false", "FALSE", "False", "0", "0.0"] + nulls = [None] + strings = true_strings + false_strings + nulls + bools = ( + [True] * len(true_strings) + [False] * len(false_strings) + [None] * len(nulls) + ) + + dtype = ArrowDtype(pa.bool_()) + result = ArrowExtensionArray._from_sequence_of_strings(strings, dtype=dtype) + expected = pd.array(bools, dtype="boolean[pyarrow]") + tm.assert_extension_array_equal(result, expected) + + strings = ["True", "foo"] + with pytest.raises(pa.ArrowInvalid, match="Failed to parse"): + ArrowExtensionArray._from_sequence_of_strings(strings, dtype=dtype) + + +def test_concat_empty_arrow_backed_series(dtype): + # GH#51734 + ser = pd.Series([], dtype=dtype) + expected = ser.copy() + result = pd.concat([ser[np.array([], dtype=np.bool_)]]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["string", "string[pyarrow]"]) +def test_series_from_string_array(dtype): + arr = pa.array("the quick brown fox".split()) + ser = pd.Series(arr, dtype=dtype) + expected = pd.Series(ArrowExtensionArray(arr), dtype=dtype) + tm.assert_series_equal(ser, expected) + + +# _data was renamed to _pa_data +class OldArrowExtensionArray(ArrowExtensionArray): + def __getstate__(self): + state = super().__getstate__() + state["_data"] = state.pop("_pa_array") + return state + + +def test_pickle_old_arrowextensionarray(): + data = pa.array([1]) + expected = OldArrowExtensionArray(data) + result = pickle.loads(pickle.dumps(expected)) + tm.assert_extension_array_equal(result, expected) + assert result._pa_array == pa.chunked_array(data) + assert not hasattr(result, "_data") + + +def test_setitem_boolean_replace_with_mask_segfault(): + # GH#52059 + N = 145_000 + arr = ArrowExtensionArray(pa.chunked_array([np.ones((N,), dtype=np.bool_)])) + expected = arr.copy() + arr[np.zeros((N,), dtype=np.bool_)] = False + assert arr._pa_array == expected._pa_array + + +@pytest.mark.parametrize( + "data, arrow_dtype", + [ + ([b"a", b"b"], pa.large_binary()), + (["a", "b"], pa.large_string()), + ], +) +def test_conversion_large_dtypes_from_numpy_array(data, arrow_dtype): + dtype = ArrowDtype(arrow_dtype) + result = pd.array(np.array(data), dtype=dtype) + expected = pd.array(data, dtype=dtype) + tm.assert_extension_array_equal(result, expected) + + +def test_concat_null_array(): + df = pd.DataFrame({"a": [None, None]}, dtype=ArrowDtype(pa.null())) + df2 = pd.DataFrame({"a": [0, 1]}, dtype="int64[pyarrow]") + + result = pd.concat([df, df2], ignore_index=True) + expected = pd.DataFrame({"a": [None, None, 0, 1]}, dtype="int64[pyarrow]") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("pa_type", tm.ALL_INT_PYARROW_DTYPES + tm.FLOAT_PYARROW_DTYPES) +def test_describe_numeric_data(pa_type): + # GH 52470 + data = pd.Series([1, 2, 3], dtype=ArrowDtype(pa_type)) + result = data.describe() + expected = pd.Series( + [3, 2, 1, 1, 1.5, 2.0, 2.5, 3], + dtype=ArrowDtype(pa.float64()), + index=["count", "mean", "std", "min", "25%", "50%", "75%", "max"], + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("pa_type", tm.TIMEDELTA_PYARROW_DTYPES) +def test_describe_timedelta_data(pa_type): + # GH53001 + data = pd.Series(range(1, 10), dtype=ArrowDtype(pa_type)) + result = data.describe() + expected = pd.Series( + [9, *pd.to_timedelta([5, 2, 1, 3, 5, 7, 9], unit=pa_type.unit).tolist()], + dtype=object, + index=["count", "mean", "std", "min", "25%", "50%", "75%", "max"], + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("pa_type", tm.DATETIME_PYARROW_DTYPES) +def test_describe_datetime_data(pa_type): + # GH53001 + data = pd.Series(range(1, 10), dtype=ArrowDtype(pa_type)) + result = data.describe() + expected = pd.Series( + [9] + + [ + pd.Timestamp(v, tz=pa_type.tz, unit=pa_type.unit) + for v in [5, 1, 3, 5, 7, 9] + ], + dtype=object, + index=["count", "mean", "min", "25%", "50%", "75%", "max"], + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES +) +def test_quantile_temporal(pa_type): + # GH52678 + data = [1, 2, 3] + ser = pd.Series(data, dtype=ArrowDtype(pa_type)) + result = ser.quantile(0.1) + expected = ser[0] + assert result == expected + + +def test_date32_repr(): + # GH48238 + arrow_dt = pa.array([date.fromisoformat("2020-01-01")], type=pa.date32()) + ser = pd.Series(arrow_dt, dtype=ArrowDtype(arrow_dt.type)) + assert repr(ser) == "0 2020-01-01\ndtype: date32[day][pyarrow]" + + +def test_duration_overflow_from_ndarray_containing_nat(): + # GH52843 + data_ts = pd.to_datetime([1, None]) + data_td = pd.to_timedelta([1, None]) + ser_ts = pd.Series(data_ts, dtype=ArrowDtype(pa.timestamp("ns"))) + ser_td = pd.Series(data_td, dtype=ArrowDtype(pa.duration("ns"))) + result = ser_ts + ser_td + expected = pd.Series([2, None], dtype=ArrowDtype(pa.timestamp("ns"))) + tm.assert_series_equal(result, expected) + + +def test_infer_dtype_pyarrow_dtype(data, request): + res = lib.infer_dtype(data) + assert res != "unknown-array" + + if data._hasna and res in ["datetime64", "timedelta64"]: + mark = pytest.mark.xfail( + reason="in infer_dtype pd.NA is not ignored in these cases " + "even with skipna=True in the list(data) check below" + ) + request.applymarker(mark) + + assert res == lib.infer_dtype(list(data), skipna=True) + + +@pytest.mark.parametrize( + "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES +) +def test_from_sequence_temporal(pa_type): + # GH 53171 + val = 3 + unit = pa_type.unit + if pa.types.is_duration(pa_type): + seq = [pd.Timedelta(val, unit=unit).as_unit(unit)] + else: + seq = [pd.Timestamp(val, unit=unit, tz=pa_type.tz).as_unit(unit)] + + result = ArrowExtensionArray._from_sequence(seq, dtype=pa_type) + expected = ArrowExtensionArray(pa.array([val], type=pa_type)) + tm.assert_extension_array_equal(result, expected) + + +@pytest.mark.parametrize( + "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES +) +def test_setitem_temporal(pa_type): + # GH 53171 + unit = pa_type.unit + if pa.types.is_duration(pa_type): + val = pd.Timedelta(1, unit=unit).as_unit(unit) + else: + val = pd.Timestamp(1, unit=unit, tz=pa_type.tz).as_unit(unit) + + arr = ArrowExtensionArray(pa.array([1, 2, 3], type=pa_type)) + + result = arr.copy() + result[:] = val + expected = ArrowExtensionArray(pa.array([1, 1, 1], type=pa_type)) + tm.assert_extension_array_equal(result, expected) + + +@pytest.mark.parametrize( + "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES +) +def test_arithmetic_temporal(pa_type, request): + # GH 53171 + arr = ArrowExtensionArray(pa.array([1, 2, 3], type=pa_type)) + unit = pa_type.unit + result = arr - pd.Timedelta(1, unit=unit).as_unit(unit) + expected = ArrowExtensionArray(pa.array([0, 1, 2], type=pa_type)) + tm.assert_extension_array_equal(result, expected) + + +@pytest.mark.parametrize( + "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES +) +def test_comparison_temporal(pa_type): + # GH 53171 + unit = pa_type.unit + if pa.types.is_duration(pa_type): + val = pd.Timedelta(1, unit=unit).as_unit(unit) + else: + val = pd.Timestamp(1, unit=unit, tz=pa_type.tz).as_unit(unit) + + arr = ArrowExtensionArray(pa.array([1, 2, 3], type=pa_type)) + + result = arr > val + expected = ArrowExtensionArray(pa.array([False, True, True], type=pa.bool_())) + tm.assert_extension_array_equal(result, expected) + + +@pytest.mark.parametrize( + "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES +) +def test_getitem_temporal(pa_type): + # GH 53326 + arr = ArrowExtensionArray(pa.array([1, 2, 3], type=pa_type)) + result = arr[1] + if pa.types.is_duration(pa_type): + expected = pd.Timedelta(2, unit=pa_type.unit).as_unit(pa_type.unit) + assert isinstance(result, pd.Timedelta) + else: + expected = pd.Timestamp(2, unit=pa_type.unit, tz=pa_type.tz).as_unit( + pa_type.unit + ) + assert isinstance(result, pd.Timestamp) + assert result.unit == expected.unit + assert result == expected + + +@pytest.mark.parametrize( + "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES +) +def test_iter_temporal(pa_type): + # GH 53326 + arr = ArrowExtensionArray(pa.array([1, None], type=pa_type)) + result = list(arr) + if pa.types.is_duration(pa_type): + expected = [ + pd.Timedelta(1, unit=pa_type.unit).as_unit(pa_type.unit), + pd.NA, + ] + assert isinstance(result[0], pd.Timedelta) + else: + expected = [ + pd.Timestamp(1, unit=pa_type.unit, tz=pa_type.tz).as_unit(pa_type.unit), + pd.NA, + ] + assert isinstance(result[0], pd.Timestamp) + assert result[0].unit == expected[0].unit + assert result == expected + + +def test_groupby_series_size_returns_pa_int(data): + # GH 54132 + ser = pd.Series(data[:3], index=["a", "a", "b"]) + result = ser.groupby(level=0).size() + expected = pd.Series([2, 1], dtype="int64[pyarrow]", index=["a", "b"]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pa_type", tm.DATETIME_PYARROW_DTYPES + tm.TIMEDELTA_PYARROW_DTYPES, ids=repr +) +@pytest.mark.parametrize("dtype", [None, object]) +def test_to_numpy_temporal(pa_type, dtype): + # GH 53326 + # GH 55997: Return datetime64/timedelta64 types with NaT if possible + arr = ArrowExtensionArray(pa.array([1, None], type=pa_type)) + result = arr.to_numpy(dtype=dtype) + if pa.types.is_duration(pa_type): + value = pd.Timedelta(1, unit=pa_type.unit).as_unit(pa_type.unit) + else: + value = pd.Timestamp(1, unit=pa_type.unit, tz=pa_type.tz).as_unit(pa_type.unit) + + if dtype == object or (pa.types.is_timestamp(pa_type) and pa_type.tz is not None): + if dtype == object: + na = pd.NA + else: + na = pd.NaT + expected = np.array([value, na], dtype=object) + assert result[0].unit == value.unit + else: + na = pa_type.to_pandas_dtype().type("nat", pa_type.unit) + value = value.to_numpy() + expected = np.array([value, na]) + assert np.datetime_data(result[0])[0] == pa_type.unit + tm.assert_numpy_array_equal(result, expected) + + +def test_groupby_count_return_arrow_dtype(data_missing): + df = pd.DataFrame({"A": [1, 1], "B": data_missing, "C": data_missing}) + result = df.groupby("A").count() + expected = pd.DataFrame( + [[1, 1]], + index=pd.Index([1], name="A"), + columns=["B", "C"], + dtype="int64[pyarrow]", + ) + tm.assert_frame_equal(result, expected) + + +def test_fixed_size_list(): + # GH#55000 + ser = pd.Series( + [[1, 2], [3, 4]], dtype=ArrowDtype(pa.list_(pa.int64(), list_size=2)) + ) + result = ser.dtype.type + assert result == list + + +def test_arrowextensiondtype_dataframe_repr(): + # GH 54062 + df = pd.DataFrame( + pd.period_range("2012", periods=3), + columns=["col"], + dtype=ArrowDtype(ArrowPeriodType("D")), + ) + result = repr(df) + # TODO: repr value may not be expected; address how + # pyarrow.ExtensionType values are displayed + expected = " col\n0 15340\n1 15341\n2 15342" + assert result == expected + + +def test_pow_missing_operand(): + # GH 55512 + k = pd.Series([2, None], dtype="int64[pyarrow]") + result = k.pow(None, fill_value=3) + expected = pd.Series([8, None], dtype="int64[pyarrow]") + tm.assert_series_equal(result, expected) + + +def test_decimal_parse_raises(): + # GH 56984 + ser = pd.Series(["1.2345"], dtype=ArrowDtype(pa.string())) + with pytest.raises( + pa.lib.ArrowInvalid, match="Rescaling Decimal(128)? value would cause data loss" + ): + ser.astype(ArrowDtype(pa.decimal128(1, 0))) + + +def test_decimal_parse_succeeds(): + # GH 56984 + ser = pd.Series(["1.2345"], dtype=ArrowDtype(pa.string())) + dtype = ArrowDtype(pa.decimal128(5, 4)) + result = ser.astype(dtype) + expected = pd.Series([Decimal("1.2345")], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("pa_type", tm.TIMEDELTA_PYARROW_DTYPES) +def test_duration_fillna_numpy(pa_type): + # GH 54707 + ser1 = pd.Series([None, 2], dtype=ArrowDtype(pa_type)) + ser2 = pd.Series(np.array([1, 3], dtype=f"m8[{pa_type.unit}]")) + result = ser1.fillna(ser2) + expected = pd.Series([1, 2], dtype=ArrowDtype(pa_type)) + tm.assert_series_equal(result, expected) + + +def test_comparison_not_propagating_arrow_error(): + # GH#54944 + a = pd.Series([1 << 63], dtype="uint64[pyarrow]") + b = pd.Series([None], dtype="int64[pyarrow]") + with pytest.raises(pa.lib.ArrowInvalid, match="Integer value"): + a < b + + +def test_factorize_chunked_dictionary(): + # GH 54844 + pa_array = pa.chunked_array( + [pa.array(["a"]).dictionary_encode(), pa.array(["b"]).dictionary_encode()] + ) + ser = pd.Series(ArrowExtensionArray(pa_array)) + res_indices, res_uniques = ser.factorize() + exp_indices = np.array([0, 1], dtype=np.intp) + exp_uniques = pd.Index(ArrowExtensionArray(pa_array.combine_chunks())) + tm.assert_numpy_array_equal(res_indices, exp_indices) + tm.assert_index_equal(res_uniques, exp_uniques) + + +def test_factorize_dictionary_with_na(): + # GH#60567 + arr = pd.array( + ["a1", pd.NA], dtype=ArrowDtype(pa.dictionary(pa.int32(), pa.utf8())) + ) + indices, uniques = arr.factorize(use_na_sentinel=False) + expected_indices = np.array([0, 1], dtype=np.intp) + expected_uniques = pd.array(["a1", None], dtype=ArrowDtype(pa.string())) + tm.assert_numpy_array_equal(indices, expected_indices) + tm.assert_extension_array_equal(uniques, expected_uniques) + + +def test_dictionary_astype_categorical(): + # GH#56672 + arrs = [ + pa.array(np.array(["a", "x", "c", "a"])).dictionary_encode(), + pa.array(np.array(["a", "d", "c"])).dictionary_encode(), + ] + ser = pd.Series(ArrowExtensionArray(pa.chunked_array(arrs))) + result = ser.astype("category") + categories = pd.Index(["a", "x", "c", "d"], dtype=ArrowDtype(pa.string())) + expected = pd.Series( + ["a", "x", "c", "a", "a", "d", "c"], + dtype=pd.CategoricalDtype(categories=categories), + ) + tm.assert_series_equal(result, expected) + + +def test_arrow_floordiv(): + # GH 55561 + a = pd.Series([-7], dtype="int64[pyarrow]") + b = pd.Series([4], dtype="int64[pyarrow]") + expected = pd.Series([-2], dtype="int64[pyarrow]") + result = a // b + tm.assert_series_equal(result, expected) + + +def test_arrow_floordiv_large_values(): + # GH 56645 + a = pd.Series([1425801600000000000], dtype="int64[pyarrow]") + expected = pd.Series([1425801600000], dtype="int64[pyarrow]") + result = a // 1_000_000 + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["int64[pyarrow]", "uint64[pyarrow]"]) +def test_arrow_floordiv_large_integral_result(dtype): + # GH 56676 + a = pd.Series([18014398509481983], dtype=dtype) + result = a // 1 + tm.assert_series_equal(result, a) + + +@pytest.mark.parametrize("pa_type", tm.SIGNED_INT_PYARROW_DTYPES) +def test_arrow_floordiv_larger_divisor(pa_type): + # GH 56676 + dtype = ArrowDtype(pa_type) + a = pd.Series([-23], dtype=dtype) + result = a // 24 + expected = pd.Series([-1], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("pa_type", tm.SIGNED_INT_PYARROW_DTYPES) +def test_arrow_floordiv_integral_invalid(pa_type): + # GH 56676 + min_value = np.iinfo(pa_type.to_pandas_dtype()).min + a = pd.Series([min_value], dtype=ArrowDtype(pa_type)) + with pytest.raises(pa.lib.ArrowInvalid, match="overflow|not in range"): + a // -1 + with pytest.raises(pa.lib.ArrowInvalid, match="divide by zero"): + a // 0 + + +@pytest.mark.parametrize("dtype", tm.FLOAT_PYARROW_DTYPES_STR_REPR) +def test_arrow_floordiv_floating_0_divisor(dtype): + # GH 56676 + a = pd.Series([2], dtype=dtype) + result = a // 0 + expected = pd.Series([float("inf")], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["float64", "datetime64[ns]", "timedelta64[ns]"]) +def test_astype_int_with_null_to_numpy_dtype(dtype): + # GH 57093 + ser = pd.Series([1, None], dtype="int64[pyarrow]") + result = ser.astype(dtype) + expected = pd.Series([1, None], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("pa_type", tm.ALL_INT_PYARROW_DTYPES) +def test_arrow_integral_floordiv_large_values(pa_type): + # GH 56676 + max_value = np.iinfo(pa_type.to_pandas_dtype()).max + dtype = ArrowDtype(pa_type) + a = pd.Series([max_value], dtype=dtype) + b = pd.Series([1], dtype=dtype) + result = a // b + tm.assert_series_equal(result, a) + + +@pytest.mark.parametrize("dtype", ["int64[pyarrow]", "uint64[pyarrow]"]) +def test_arrow_true_division_large_divisor(dtype): + # GH 56706 + a = pd.Series([0], dtype=dtype) + b = pd.Series([18014398509481983], dtype=dtype) + expected = pd.Series([0], dtype="float64[pyarrow]") + result = a / b + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["int64[pyarrow]", "uint64[pyarrow]"]) +def test_arrow_floor_division_large_divisor(dtype): + # GH 56706 + a = pd.Series([0], dtype=dtype) + b = pd.Series([18014398509481983], dtype=dtype) + expected = pd.Series([0], dtype=dtype) + result = a // b + tm.assert_series_equal(result, expected) + + +def test_string_to_datetime_parsing_cast(): + # GH 56266 + string_dates = ["2020-01-01 04:30:00", "2020-01-02 00:00:00", "2020-01-03 00:00:00"] + result = pd.Series(string_dates, dtype="timestamp[s][pyarrow]") + + pd_res = pd.to_datetime(string_dates).as_unit("s") + expected = pd.Series(ArrowExtensionArray(pa.array(pd_res, from_pandas=True))) + tm.assert_series_equal(result, expected) + + +def test_interpolate_not_numeric(data): + if not data.dtype._is_numeric: + ser = pd.Series(data) + msg = re.escape(f"Cannot interpolate with {ser.dtype} dtype") + with pytest.raises(TypeError, match=msg): + pd.Series(data).interpolate() + + +@pytest.mark.parametrize("dtype", ["int64[pyarrow]", "float64[pyarrow]"]) +def test_interpolate_linear(dtype): + ser = pd.Series([None, 1, 2, None, 4, None], dtype=dtype) + result = ser.interpolate() + expected = pd.Series([None, 1, 2, 3, 4, None], dtype=dtype) + tm.assert_series_equal(result, expected) + + +def test_string_to_time_parsing_cast(): + # GH 56463 + string_times = ["11:41:43.076160"] + result = pd.Series(string_times, dtype="time64[us][pyarrow]") + expected = pd.Series( + ArrowExtensionArray(pa.array([time(11, 41, 43, 76160)], from_pandas=True)) + ) + tm.assert_series_equal(result, expected) + + +def test_to_numpy_float(): + # GH#56267 + ser = pd.Series([32, 40, None], dtype="float[pyarrow]") + result = ser.astype("float64") + expected = pd.Series([32, 40, np.nan], dtype="float64") + tm.assert_series_equal(result, expected) + + +def test_to_numpy_timestamp_to_int(): + # GH 55997 + ser = pd.Series(["2020-01-01 04:30:00"], dtype="timestamp[ns][pyarrow]") + result = ser.to_numpy(dtype=np.int64) + expected = np.array([1577853000000000000]) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("arrow_type", [pa.large_string(), pa.string()]) +def test_cast_dictionary_different_value_dtype(arrow_type): + df = pd.DataFrame({"a": ["x", "y"]}, dtype="string[pyarrow]") + data_type = ArrowDtype(pa.dictionary(pa.int32(), arrow_type)) + result = df.astype({"a": data_type}) + assert result.dtypes.iloc[0] == data_type + + +def test_map_numeric_na_action(using_nan_is_na): + ser = pd.Series([32, 40, None], dtype="int64[pyarrow]") + result = ser.map(lambda x: 42, na_action="ignore") + if not using_nan_is_na: + expected = pd.Series([42.0, 42.0, pd.NA], dtype="object") + else: + expected = pd.Series([42.0, 42.0, np.nan], dtype="float64") + tm.assert_series_equal(result, expected) + + +def test_categorical_from_arrow_dictionary(): + # GH 60563 + df = pd.DataFrame( + {"A": ["a1", "a2"]}, dtype=ArrowDtype(pa.dictionary(pa.int32(), pa.utf8())) + ) + result = df.value_counts(dropna=False) + expected = pd.Series( + [1, 1], + index=pd.MultiIndex.from_arrays( + [pd.Index(["a1", "a2"], dtype=ArrowDtype(pa.string()), name="A")] + ), + name="count", + dtype="int64", + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.skipif( + pa_version_under19p0, reason="pa.json_ was introduced in pyarrow v19.0" +) +def test_arrow_json_type(): + # GH 60958 + dtype = ArrowDtype(pa.json_(pa.string())) + result = dtype.type + assert result == str + + +def test_timestamp_dtype_disallows_decimal(): + # GH#61773 constructing with pyarrow timestamp dtype should disallow + # Decimal NaN, just like pd.to_datetime + vals = [pd.Timestamp("2016-01-02 03:04:05"), Decimal("NaN")] + + msg = " is not convertible to datetime" + with pytest.raises(TypeError, match=msg): + # Check that the non-pyarrow version raises as expected + pd.to_datetime(vals) + + with pytest.raises(TypeError, match=msg): + pd.array(vals, dtype=ArrowDtype(pa.timestamp("us"))) + + +def test_timestamp_dtype_matches_to_datetime(): + # GH#61775 + dtype1 = "datetime64[ns, US/Eastern]" + dtype2 = "timestamp[ns, US/Eastern][pyarrow]" + + ts = pd.Timestamp("2025-07-03 18:10") + + result = pd.Series([ts], dtype=dtype2) + expected = pd.Series([ts], dtype=dtype1).convert_dtypes(dtype_backend="pyarrow") + + tm.assert_series_equal(result, expected) + + +def test_timestamp_vs_dt64_comparison(): + # GH#60937 + left = pd.Series(["2016-01-01"], dtype="timestamp[ns][pyarrow]") + right = left.astype("datetime64[ns]") + + result = left == right + expected = pd.Series([True], dtype="bool[pyarrow]") + tm.assert_series_equal(result, expected) + + result = right == left + tm.assert_series_equal(result, expected) + + +# TODO: reuse assert_invalid_comparison? +def test_date_vs_timestamp_scalar_comparison(): + # GH#62157 match non-pyarrow behavior + ser = pd.Series(["2016-01-01"], dtype="date32[pyarrow]") + ser2 = ser.astype("timestamp[ns][pyarrow]") + + ts = ser2[0] + dt = ser[0] + + # date dtype don't match a Timestamp object + assert not (ser == ts).any() + assert not (ts == ser).any() + + # timestamp dtype doesn't match date object + assert not (ser2 == dt).any() + assert not (dt == ser2).any() + + +# TODO: reuse assert_invalid_comparison? +def test_date_vs_timestamp_array_comparison(): + # GH#62157 match non-pyarrow behavior + # GH# + ser = pd.Series(["2016-01-01"], dtype="date32[pyarrow]") + ser2 = ser.astype("timestamp[ns][pyarrow]") + ser3 = ser.astype("datetime64[ns]") + + assert not (ser == ser2).any() + assert not (ser2 == ser).any() + assert (ser != ser2).all() + assert (ser2 != ser).all() + + assert not (ser == ser3).any() + assert not (ser3 == ser).any() + assert (ser != ser3).all() + assert (ser3 != ser).all() + + +def test_ops_with_nan_is_na(using_nan_is_na): + # GH#61732 + ser = pd.Series([-1, 0, 1], dtype="int64[pyarrow]") + + result = ser - np.nan + if using_nan_is_na: + assert result.isna().all() + else: + assert not result.isna().any() + + result = ser * np.nan + if using_nan_is_na: + assert result.isna().all() + else: + assert not result.isna().any() + + result = ser / 0 + if using_nan_is_na: + assert result.isna()[1] + else: + assert not result.isna()[1] + + +def test_setitem_float_nan_is_na(using_nan_is_na): + # GH#61732 + ser = pd.Series([-1, 0, 1], dtype="int64[pyarrow]") + + if using_nan_is_na: + ser[1] = np.nan + assert ser.isna()[1] + else: + msg = "Could not convert nan with type float: tried to convert to int64" + with pytest.raises(pa.lib.ArrowInvalid, match=msg): + ser[1] = np.nan + + ser = pd.Series([-1, np.nan, 1], dtype="float64[pyarrow]") + if using_nan_is_na: + assert ser.isna()[1] + assert ser[1] is pd.NA + + ser[1] = np.nan + assert ser[1] is pd.NA + + else: + assert not ser.isna()[1] + assert isinstance(ser[1], float) + assert np.isnan(ser[1]) + + ser[2] = np.nan + assert isinstance(ser[2], float) + assert np.isnan(ser[2]) + + +def test_pow_with_all_na_float(): + # GH#62520 + + s = pd.Series([None, None], dtype="float64[pyarrow]") + result = s.pow(2) + expected = pd.Series([pd.NA, pd.NA], dtype="float64[pyarrow]") + tm.assert_series_equal(result, expected) + + +def test_mul_numpy_nullable_with_pyarrow_float(): + # GH#58602 + left = pd.Series(range(5), dtype="Float64") + right = pd.Series(range(5), dtype="float64[pyarrow]") + + expected = pd.Series([0, 1, 4, 9, 16], dtype="float64[pyarrow]") + + result = left * right + tm.assert_series_equal(result, expected) + + result2 = right * left + tm.assert_series_equal(result2, expected) + + # while we're here, let's check __eq__ + result3 = left == right + expected3 = pd.Series([True] * 5, dtype="bool[pyarrow]") + tm.assert_series_equal(result3, expected3) + + result4 = right == left + tm.assert_series_equal(result4, expected3) + + +@pytest.mark.parametrize( + "type_name, expected_size", + [ + # Integer types + ("int8", 1), + ("int16", 2), + ("int32", 4), + ("int64", 8), + ("uint8", 1), + ("uint16", 2), + ("uint32", 4), + ("uint64", 8), + # Floating point types + ("float16", 2), + ("float32", 4), + ("float64", 8), + # Boolean + ("bool_", 1), + # Date and timestamp types + ("date32", 4), + ("date64", 8), + ("timestamp", 8), + # Time types + ("time32", 4), + ("time64", 8), + # Decimal types + ("decimal128", 16), + ("decimal256", 32), + ], +) +def test_arrow_dtype_itemsize_fixed_width(type_name, expected_size): + # GH 57948 + + parametric_type_map = { + "timestamp": pa.timestamp("ns"), + "time32": pa.time32("s"), + "time64": pa.time64("ns"), + "decimal128": pa.decimal128(38, 10), + "decimal256": pa.decimal256(76, 10), + } + + if type_name in parametric_type_map: + arrow_type = parametric_type_map.get(type_name) + else: + arrow_type = getattr(pa, type_name)() + dtype = ArrowDtype(arrow_type) + + if type_name == "bool_": + expected_size = dtype.numpy_dtype.itemsize + + assert dtype.itemsize == expected_size, ( + f"{type_name} expected {expected_size}, got {dtype.itemsize} " + f"(bit_width={getattr(dtype.pyarrow_dtype, 'bit_width', 'N/A')})" + ) + + +@pytest.mark.parametrize("type_name", ["string", "binary", "large_string"]) +def test_arrow_dtype_itemsize_variable_width(type_name): + # GH 57948 + + arrow_type = getattr(pa, type_name)() + dtype = ArrowDtype(arrow_type) + + assert dtype.itemsize == dtype.numpy_dtype.itemsize + + +def test_cast_pontwise_result_decimal_nan(): + # GH#62522 we don't want to get back null[pyarrow] here + ser = pd.Series([], dtype="float64[pyarrow]") + arr = ser.array + item = Decimal("NaN") + + result = arr._cast_pointwise_result([item]) + + pa_type = result.dtype.pyarrow_dtype + assert pa.types.is_decimal(pa_type) + + +def test_ufunc_retains_missing(): + # GH#62800 + ser = pd.Series([0.1, pd.NA], dtype="float64[pyarrow]") + + result = np.sin(ser) + + expected = pd.Series([np.sin(0.1), pd.NA], dtype="float64[pyarrow]") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["sum", "min", "max", "mean", "median"]) +def test_duration_reduction_consistency(unit, method): + # GH#63170 + dtype = f"duration[{unit}][pyarrow]" + ser = pd.Series([timedelta(seconds=1), timedelta(seconds=2)], dtype=dtype) + result = getattr(ser, method)() + assert isinstance(result, pd.Timedelta), ( + f"{method} for {unit} returned {type(result)}" + ) + assert result.unit == unit + + +@pytest.mark.parametrize("method", ["min", "max", "median"]) +def test_timestamp_reduction_consistency(unit, method): + # GH#63170 + dtype = f"timestamp[{unit}][pyarrow]" + ser = pd.Series([datetime(2024, 1, 1), datetime(2024, 1, 3)], dtype=dtype) + result = getattr(ser, method)() + assert isinstance(result, pd.Timestamp), ( + f"{method} for {unit} returned {type(result)}" + ) + assert result.unit == unit diff --git a/pandas/tests/extension/test_categorical.py b/pandas/tests/extension/test_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..eb671e74f4b25924121e83cd687530136d1fb54a --- /dev/null +++ b/pandas/tests/extension/test_categorical.py @@ -0,0 +1,192 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. + +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). + +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. + +""" + +import string + +import numpy as np +import pytest + +from pandas._config import using_string_dtype + +import pandas as pd +from pandas import Categorical +import pandas._testing as tm +from pandas.api.types import CategoricalDtype +from pandas.tests.extension import base + + +def make_data(n: int): + while True: + values = np.random.default_rng(2).choice(list(string.ascii_letters), size=n) + # ensure we meet the requirements + # 1. first two not null + # 2. first and second are different + if values[0] != values[1]: + break + return values + + +@pytest.fixture +def dtype(): + return CategoricalDtype() + + +@pytest.fixture +def data(): + """Length-100 array for this type. + + * data[0] and data[1] should both be non missing + * data[0] and data[1] should not be equal + """ + return Categorical(make_data(10)) + + +@pytest.fixture +def data_missing(): + """Length 2 array with [NA, Valid]""" + return Categorical([np.nan, "A"]) + + +@pytest.fixture +def data_for_sorting(): + return Categorical(["A", "B", "C"], categories=["C", "A", "B"], ordered=True) + + +@pytest.fixture +def data_missing_for_sorting(): + return Categorical(["A", None, "B"], categories=["B", "A"], ordered=True) + + +@pytest.fixture +def data_for_grouping(): + return Categorical(["a", "a", None, None, "b", "b", "a", "c"]) + + +class TestCategorical(base.ExtensionTests): + def test_contains(self, data, data_missing): + # GH-37867 + # na value handling in Categorical.__contains__ is deprecated. + # See base.BaseInterFaceTests.test_contains for more details. + + na_value = data.dtype.na_value + # ensure data without missing values + data = data[~data.isna()] + + # first elements are non-missing + assert data[0] in data + assert data_missing[0] in data_missing + + # check the presence of na_value + assert na_value in data_missing + assert na_value not in data + + # Categoricals can contain other nan-likes than na_value + for na_value_obj in tm.NULL_OBJECTS: + if na_value_obj is na_value: + continue + assert na_value_obj not in data + # this section suffers from super method + if not using_string_dtype(): + assert na_value_obj in data_missing + + def test_empty(self, dtype): + cls = dtype.construct_array_type() + result = cls._empty((4,), dtype=dtype) + + assert isinstance(result, cls) + # the dtype we passed is not initialized, so will not match the + # dtype on our result. + assert result.dtype == CategoricalDtype([]) + + @pytest.mark.skip(reason="Backwards compatibility") + def test_getitem_scalar(self, data): + # CategoricalDtype.type isn't "correct" since it should + # be a parent of the elements (object). But don't want + # to break things by changing. + super().test_getitem_scalar(data) + + def test_combine_add(self, data_repeated): + # GH 20825 + # When adding categoricals in combine, result is a string + orig_data1, orig_data2 = data_repeated(2) + s1 = pd.Series(orig_data1) + s2 = pd.Series(orig_data2) + result = s1.combine(s2, lambda x1, x2: x1 + x2) + expected = pd.Series( + [a + b for (a, b) in zip(list(orig_data1), list(orig_data2), strict=True)] + ) + tm.assert_series_equal(result, expected) + + val = s1.iloc[0] + result = s1.combine(val, lambda x1, x2: x1 + x2) + expected = pd.Series([a + val for a in list(orig_data1)]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("na_action", [None, "ignore"]) + def test_map(self, data, na_action): + result = data.map(lambda x: x, na_action=na_action) + tm.assert_extension_array_equal(result, data) + + def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): + # frame & scalar + op_name = all_arithmetic_operators + if op_name == "__rmod__": + request.applymarker( + pytest.mark.xfail( + reason="rmod never called when string is first argument" + ) + ) + super().test_arith_frame_with_scalar(data, op_name) + + def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request): + op_name = all_arithmetic_operators + if op_name == "__rmod__": + request.applymarker( + pytest.mark.xfail( + reason="rmod never called when string is first argument" + ) + ) + super().test_arith_series_with_scalar(data, op_name) + + def _compare_other(self, ser: pd.Series, data, op, other): + op_name = f"__{op.__name__}__" + if op_name not in ["__eq__", "__ne__"]: + msg = "Unordered Categoricals can only compare equality or not" + with pytest.raises(TypeError, match=msg): + op(data, other) + else: + return super()._compare_other(ser, data, op, other) + + @pytest.mark.xfail(reason="Categorical overrides __repr__") + @pytest.mark.parametrize("size", ["big", "small"]) + def test_array_repr(self, data, size): + super().test_array_repr(data, size) + + @pytest.mark.xfail(reason="TBD") + @pytest.mark.parametrize("as_index", [True, False]) + def test_groupby_extension_agg(self, as_index, data_for_grouping): + super().test_groupby_extension_agg(as_index, data_for_grouping) + + +class Test2DCompat(base.NDArrayBacked2DTests): + def test_repr_2d(self, data): + # Categorical __repr__ doesn't include "Categorical", so we need + # to special-case + res = repr(data.reshape(1, -1)) + assert res.count("\nCategories") == 1 + + res = repr(data.reshape(-1, 1)) + assert res.count("\nCategories") == 1 diff --git a/pandas/tests/extension/test_common.py b/pandas/tests/extension/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..aa5c91f478c45c363c34a7de1083154baa8f752b --- /dev/null +++ b/pandas/tests/extension/test_common.py @@ -0,0 +1,110 @@ +import numpy as np +import pytest + +from pandas.core.dtypes import dtypes +from pandas.core.dtypes.common import is_extension_array_dtype + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import ExtensionArray + + +class DummyDtype(dtypes.ExtensionDtype): + pass + + +class DummyArray(ExtensionArray): + def __init__(self, data) -> None: + self.data = data + + def __array__(self, dtype=None, copy=None): + return self.data + + @property + def dtype(self): + return DummyDtype() + + def astype(self, dtype, copy=True): + # we don't support anything but a single dtype + if isinstance(dtype, DummyDtype): + if copy: + return type(self)(self.data) + return self + elif not copy: + return np.asarray(self, dtype=dtype) + else: + return np.array(self, dtype=dtype, copy=copy) + + +class TestExtensionArrayDtype: + @pytest.mark.parametrize( + "values", + [ + pd.Categorical([]), + pd.Categorical([]).dtype, + pd.Series(pd.Categorical([])), + DummyDtype(), + DummyArray(np.array([1, 2])), + ], + ) + def test_is_extension_array_dtype(self, values): + assert is_extension_array_dtype(values) + + @pytest.mark.parametrize("values", [np.array([]), pd.Series(np.array([]))]) + def test_is_not_extension_array_dtype(self, values): + assert not is_extension_array_dtype(values) + + +def test_astype(): + arr = DummyArray(np.array([1, 2, 3])) + expected = np.array([1, 2, 3], dtype=object) + + result = arr.astype(object) + tm.assert_numpy_array_equal(result, expected) + + result = arr.astype("object") + tm.assert_numpy_array_equal(result, expected) + + +def test_astype_no_copy(): + arr = DummyArray(np.array([1, 2, 3], dtype=np.int64)) + result = arr.astype(arr.dtype, copy=False) + + assert arr is result + + result = arr.astype(arr.dtype) + assert arr is not result + + +@pytest.mark.parametrize("dtype", [dtypes.CategoricalDtype(), dtypes.IntervalDtype()]) +def test_is_extension_array_dtype(dtype): + assert isinstance(dtype, dtypes.ExtensionDtype) + assert is_extension_array_dtype(dtype) + + +class CapturingStringArray(pd.arrays.StringArray): + """Extend StringArray to capture arguments to __getitem__""" + + def __getitem__(self, item): + self.last_item_arg = item + return super().__getitem__(item) + + +def test_ellipsis_index(): + # GH#42430 1D slices over extension types turn into N-dimensional slices + # over ExtensionArrays + dtype = pd.StringDtype() + df = pd.DataFrame( + { + "col1": CapturingStringArray( + np.array(["hello", "world"], dtype=object), dtype=dtype + ) + } + ) + _ = df.iloc[:1] + + # String comparison because there's no native way to compare slices. + # Before the fix for GH#42430, last_item_arg would get set to the 2D slice + # (Ellipsis, slice(None, 1, None)) + out = df["col1"]._values.last_item_arg + assert str(out) == "slice(None, 1, None)" diff --git a/pandas/tests/extension/test_datetime.py b/pandas/tests/extension/test_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9eff220914da0b9d33ff81bafc8e644d2290b8 --- /dev/null +++ b/pandas/tests/extension/test_datetime.py @@ -0,0 +1,148 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. + +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). + +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. + +""" + +import numpy as np +import pytest + +from pandas.core.dtypes.dtypes import DatetimeTZDtype + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays import DatetimeArray +from pandas.tests.extension import base + + +@pytest.fixture +def dtype(): + return DatetimeTZDtype(unit="ns", tz="US/Central") + + +@pytest.fixture +def data(dtype): + data = DatetimeArray._from_sequence( + pd.date_range("2000", periods=10, tz=dtype.tz), dtype=dtype + ) + return data + + +@pytest.fixture +def data_missing(dtype): + return DatetimeArray._from_sequence( + np.array(["NaT", "2000-01-01"], dtype="datetime64[ns]"), dtype=dtype + ) + + +@pytest.fixture +def data_for_sorting(dtype): + a = pd.Timestamp("2000-01-01") + b = pd.Timestamp("2000-01-02") + c = pd.Timestamp("2000-01-03") + return DatetimeArray._from_sequence( + np.array([b, c, a], dtype="datetime64[ns]"), dtype=dtype + ) + + +@pytest.fixture +def data_missing_for_sorting(dtype): + a = pd.Timestamp("2000-01-01") + b = pd.Timestamp("2000-01-02") + return DatetimeArray._from_sequence( + np.array([b, "NaT", a], dtype="datetime64[ns]"), dtype=dtype + ) + + +@pytest.fixture +def data_for_grouping(dtype): + """ + Expected to be like [B, B, NA, NA, A, A, B, C] + + Where A < B < C and NA is missing + """ + a = pd.Timestamp("2000-01-01") + b = pd.Timestamp("2000-01-02") + c = pd.Timestamp("2000-01-03") + na = "NaT" + return DatetimeArray._from_sequence( + np.array([b, b, na, na, a, a, b, c], dtype="datetime64[ns]"), dtype=dtype + ) + + +@pytest.fixture +def na_cmp(): + def cmp(a, b): + return a is pd.NaT and a is b + + return cmp + + +# ---------------------------------------------------------------------------- +class TestDatetimeArray(base.ExtensionTests): + def _get_expected_exception(self, op_name, obj, other): + if op_name in ["__sub__", "__rsub__"]: + return None + return super()._get_expected_exception(op_name, obj, other) + + def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool): + if op_name == "std": + return "timedelta64[ns]" + return arr.dtype + + def _supports_accumulation(self, ser, op_name: str) -> bool: + return op_name in ["cummin", "cummax"] + + def _supports_reduction(self, obj, op_name: str) -> bool: + return op_name in ["min", "max", "median", "mean", "std", "any", "all"] + + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_series_boolean(self, data, all_boolean_reductions, skipna): + meth = all_boolean_reductions + msg = f"datetime64 type does not support operation '{meth}'" + with pytest.raises(TypeError, match=msg): + super().test_reduce_series_boolean(data, all_boolean_reductions, skipna) + + def test_series_constructor(self, data): + # Series construction drops any .freq attr + data = data._with_freq(None) + super().test_series_constructor(data) + + @pytest.mark.parametrize("na_action", [None, "ignore"]) + def test_map(self, data, na_action): + result = data.map(lambda x: x, na_action=na_action) + tm.assert_extension_array_equal(result, data) + + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + if op_name in ["median", "mean", "std"]: + alt = ser.astype("int64") + + res_op = getattr(ser, op_name) + exp_op = getattr(alt, op_name) + result = res_op(skipna=skipna) + expected = exp_op(skipna=skipna) + if op_name in ["mean", "median"]: + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" + # has no attribute "tz" + tz = ser.dtype.tz # type: ignore[union-attr] + expected = pd.Timestamp(expected, tz=tz) + else: + expected = pd.Timedelta(expected) + tm.assert_almost_equal(result, expected) + + else: + return super().check_reduce(ser, op_name, skipna) + + +class Test2DCompat(base.NDArrayBacked2DTests): + pass diff --git a/pandas/tests/extension/test_extension.py b/pandas/tests/extension/test_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..456f4863b1c313c624ac42905a3b4117f6ba556d --- /dev/null +++ b/pandas/tests/extension/test_extension.py @@ -0,0 +1,27 @@ +""" +Tests for behavior if an author does *not* implement EA methods. +""" + +import numpy as np +import pytest + +from pandas.core.arrays import ExtensionArray + + +class MyEA(ExtensionArray): + def __init__(self, values) -> None: + self._values = values + + +@pytest.fixture +def data(): + arr = np.arange(10) + return MyEA(arr) + + +class TestExtensionArray: + def test_errors(self, data, all_arithmetic_operators): + # invalid ops + op_name = all_arithmetic_operators + with pytest.raises(AttributeError): + getattr(data, op_name) diff --git a/pandas/tests/extension/test_interval.py b/pandas/tests/extension/test_interval.py new file mode 100644 index 0000000000000000000000000000000000000000..47bc26ba4a7666b67fe774b4287a6e9660cd82c8 --- /dev/null +++ b/pandas/tests/extension/test_interval.py @@ -0,0 +1,147 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. + +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). + +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from pandas.core.dtypes.dtypes import IntervalDtype + +from pandas import Interval +from pandas.core.arrays import IntervalArray +from pandas.tests.extension import base + +if TYPE_CHECKING: + import pandas as pd + + +def make_data(n: int): + left_array = np.random.default_rng(2).uniform(size=n).cumsum() + right_array = left_array + np.random.default_rng(2).uniform(size=n) + return [ + Interval(left, right) + for left, right in zip(left_array, right_array, strict=True) + ] + + +@pytest.fixture +def dtype(): + return IntervalDtype() + + +@pytest.fixture +def data(): + """Length-10 IntervalArray for semantics test.""" + return IntervalArray(make_data(10)) + + +@pytest.fixture +def data_missing(): + """Length 2 array with [NA, Valid]""" + return IntervalArray.from_tuples([None, (0, 1)]) + + +@pytest.fixture +def data_for_twos(): + pytest.skip("Interval is not a numeric dtype") + + +@pytest.fixture +def data_for_sorting(): + return IntervalArray.from_tuples([(1, 2), (2, 3), (0, 1)]) + + +@pytest.fixture +def data_missing_for_sorting(): + return IntervalArray.from_tuples([(1, 2), None, (0, 1)]) + + +@pytest.fixture +def data_for_grouping(): + a = (0, 1) + b = (1, 2) + c = (2, 3) + return IntervalArray.from_tuples([b, b, None, None, a, a, b, c]) + + +class TestIntervalArray(base.ExtensionTests): + divmod_exc = TypeError + + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + return op_name in ["min", "max"] + + def test_fillna_limit_frame(self, data_missing): + # GH#58001 + with pytest.raises(ValueError, match="limit must be None"): + super().test_fillna_limit_frame(data_missing) + + def test_fillna_limit_series(self, data_missing): + # GH#58001 + with pytest.raises(ValueError, match="limit must be None"): + super().test_fillna_limit_frame(data_missing) + + @pytest.mark.xfail( + reason="Raises with incorrect message bc it disallows *all* listlikes " + "instead of just wrong-length listlikes" + ) + def test_fillna_length_mismatch(self, data_missing): + super().test_fillna_length_mismatch(data_missing) + + @pytest.mark.xfail(reason="copy=False is not Implemented") + def test_fillna_readonly(self, data_missing): + super().test_fillna_readonly(data_missing) + + @pytest.mark.filterwarnings( + "ignore:invalid value encountered in cast:RuntimeWarning" + ) + def test_hash_pandas_object(self, data): + super().test_hash_pandas_object(data) + + @pytest.mark.filterwarnings( + "ignore:invalid value encountered in cast:RuntimeWarning" + ) + def test_hash_pandas_object_works(self, data, as_frame): + super().test_hash_pandas_object_works(data, as_frame) + + @pytest.mark.filterwarnings( + "ignore:invalid value encountered in cast:RuntimeWarning" + ) + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_EA_types(self, engine, data, request): + super().test_EA_types(engine, data, request) + + @pytest.mark.filterwarnings( + "ignore:invalid value encountered in cast:RuntimeWarning" + ) + def test_astype_str(self, data): + super().test_astype_str(data) + + @pytest.mark.xfail( + reason="Test is invalid for IntervalDtype, needs to be adapted for " + "this dtype with an index with index._index_as_unique." + ) + def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data): + super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data) + + +# TODO: either belongs in tests.arrays.interval or move into base tests. +def test_fillna_non_scalar_raises(data_missing): + msg = "can only insert Interval objects and NA into an IntervalArray" + with pytest.raises(TypeError, match=msg): + data_missing.fillna([1, 1]) diff --git a/pandas/tests/extension/test_masked.py b/pandas/tests/extension/test_masked.py new file mode 100644 index 0000000000000000000000000000000000000000..fadc51ea714c0413133c1ced5b6666e52fd1cf51 --- /dev/null +++ b/pandas/tests/extension/test_masked.py @@ -0,0 +1,375 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. + +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). + +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. + +""" + +import numpy as np +import pytest + +from pandas.compat import ( + IS64, + is_platform_windows, +) +from pandas.compat.numpy import np_version_gt2 + +from pandas.core.dtypes.common import ( + is_float_dtype, + is_signed_integer_dtype, + is_unsigned_integer_dtype, +) + +import pandas as pd +import pandas._testing as tm +from pandas.core.arrays.boolean import BooleanDtype +from pandas.core.arrays.floating import ( + Float32Dtype, + Float64Dtype, +) +from pandas.core.arrays.integer import ( + Int8Dtype, + Int16Dtype, + Int32Dtype, + Int64Dtype, + UInt8Dtype, + UInt16Dtype, + UInt32Dtype, + UInt64Dtype, +) +from pandas.tests.extension import base + +is_windows_or_32bit = (is_platform_windows() and not np_version_gt2) or not IS64 + +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:invalid value encountered in divide:RuntimeWarning" + ), + pytest.mark.filterwarnings("ignore:Mean of empty slice:RuntimeWarning"), + # overflow only relevant for Floating dtype cases cases + pytest.mark.filterwarnings("ignore:overflow encountered in reduce:RuntimeWarning"), +] + + +def make_data(): + return [1, 2, 3, 4, pd.NA, 10, 11, pd.NA, 99, 100] + + +def make_float_data(): + return [0.1, 0.2, 0.3, 0.4, pd.NA, 1.0, 1.1, pd.NA, 9.9, 10.0] + + +def make_bool_data(): + return [True, False] * 2 + [np.nan] + [True, False] + [np.nan] + [True, False] + + +@pytest.fixture( + params=[ + Int8Dtype, + Int16Dtype, + Int32Dtype, + Int64Dtype, + UInt8Dtype, + UInt16Dtype, + UInt32Dtype, + UInt64Dtype, + Float32Dtype, + Float64Dtype, + BooleanDtype, + ] +) +def dtype(request): + return request.param() + + +@pytest.fixture +def data(dtype): + if dtype.kind == "f": + data = make_float_data() + elif dtype.kind == "b": + data = make_bool_data() + else: + data = make_data() + return pd.array(data, dtype=dtype) + + +@pytest.fixture +def data_for_twos(dtype): + if dtype.kind == "b": + return pd.array(np.ones(10), dtype=dtype) + return pd.array(np.ones(10) * 2, dtype=dtype) + + +@pytest.fixture +def data_missing(dtype): + if dtype.kind == "f": + return pd.array([pd.NA, 0.1], dtype=dtype) + elif dtype.kind == "b": + return pd.array([np.nan, True], dtype=dtype) + return pd.array([pd.NA, 1], dtype=dtype) + + +@pytest.fixture +def data_for_sorting(dtype): + if dtype.kind == "f": + return pd.array([0.1, 0.2, 0.0], dtype=dtype) + elif dtype.kind == "b": + return pd.array([True, True, False], dtype=dtype) + return pd.array([1, 2, 0], dtype=dtype) + + +@pytest.fixture +def data_missing_for_sorting(dtype): + if dtype.kind == "f": + return pd.array([0.1, pd.NA, 0.0], dtype=dtype) + elif dtype.kind == "b": + return pd.array([True, np.nan, False], dtype=dtype) + return pd.array([1, pd.NA, 0], dtype=dtype) + + +@pytest.fixture +def na_cmp(): + # we are pd.NA + return lambda x, y: x is pd.NA and y is pd.NA + + +@pytest.fixture +def data_for_grouping(dtype): + if dtype.kind == "f": + b = 0.1 + a = 0.0 + c = 0.2 + elif dtype.kind == "b": + b = True + a = False + c = b + else: + b = 1 + a = 0 + c = 2 + + na = pd.NA + return pd.array([b, b, na, na, a, a, b, c], dtype=dtype) + + +class TestMaskedArrays(base.ExtensionTests): + _combine_le_expected_dtype = "boolean" + + @pytest.fixture(autouse=True) + def skip_if_doesnt_support_2d(self, dtype, request): + # Override the fixture so that we run these tests. + assert not dtype._supports_2d + # If dtype._supports_2d is ever changed to True, then this fixture + # override becomes unnecessary. + + @pytest.mark.parametrize("na_action", [None, "ignore"]) + def test_map(self, data_missing, na_action, using_nan_is_na): + result = data_missing.map(lambda x: x, na_action=na_action) + if data_missing.dtype == Float32Dtype() and using_nan_is_na: + # map roundtrips through objects, which converts to float64 + expected = data_missing.to_numpy(dtype="float64", na_value=np.nan) + else: + expected = data_missing.to_numpy() + tm.assert_numpy_array_equal(result, expected) + + def test_map_na_action_ignore(self, data_missing_for_sorting, using_nan_is_na): + zero = data_missing_for_sorting[2] + result = data_missing_for_sorting.map(lambda x: zero, na_action="ignore") + if data_missing_for_sorting.dtype.kind == "b": + expected = np.array([False, pd.NA, False], dtype=object) + elif not using_nan_is_na: + # TODO: would we prefer to get NaN in this case to get a non-object? + expected = np.array([zero, pd.NA, zero], dtype=object) + else: + expected = np.array([zero, np.nan, zero]) + tm.assert_numpy_array_equal(result, expected) + + def _get_expected_exception(self, op_name, obj, other): + try: + dtype = tm.get_dtype(obj) + except AttributeError: + # passed arguments reversed + dtype = tm.get_dtype(other) + + if dtype.kind == "b": + if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]: + # match behavior with non-masked bool dtype + return NotImplementedError + elif op_name in ["__sub__", "__rsub__"]: + # exception message would include "numpy boolean subtract"" + return TypeError + return None + return None + + def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): + sdtype = tm.get_dtype(obj) + expected = pointwise_result + + if sdtype.kind == "b": + if op_name in ( + "__mod__", + "__rmod__", + ): + # combine keeps boolean type + expected = expected.astype("Int8") + + return expected + + def test_divmod_series_array(self, data, data_for_twos, request): + if data.dtype.kind == "b": + mark = pytest.mark.xfail( + reason="Inconsistency between floordiv and divmod; we raise for " + "floordiv but not for divmod. This matches what we do for " + "non-masked bool dtype." + ) + request.applymarker(mark) + super().test_divmod_series_array(data, data_for_twos) + + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + if op_name in ["any", "all"] and ser.dtype.kind != "b": + pytest.skip(reason="Tested in tests/reductions/test_reductions.py") + return True + + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + # overwrite to ensure pd.NA is tested instead of np.nan + # https://github.com/pandas-dev/pandas/issues/30958 + + cmp_dtype = "int64" + if ser.dtype.kind == "f": + # Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" has + # no attribute "numpy_dtype" + cmp_dtype = ser.dtype.numpy_dtype # type: ignore[union-attr] + elif ser.dtype.kind == "b": + if op_name in ["min", "max"]: + cmp_dtype = "bool" + + # TODO: prod with integer dtypes does *not* match the result we would + # get if we used object for cmp_dtype. In that cae the object result + # is a large integer while the non-object case overflows and returns 0 + alt = ser.dropna().astype(cmp_dtype) + if op_name == "count": + result = getattr(ser, op_name)() + expected = getattr(alt, op_name)() + else: + result = getattr(ser, op_name)(skipna=skipna) + expected = getattr(alt, op_name)(skipna=skipna) + if not skipna and ser.isna().any() and op_name not in ["any", "all"]: + expected = pd.NA + tm.assert_almost_equal(result, expected) + + def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool): + if is_float_dtype(arr.dtype): + cmp_dtype = arr.dtype.name + elif op_name in ["mean", "median", "var", "std", "skew", "kurt", "sem"]: + cmp_dtype = "Float64" + elif op_name in ["max", "min"]: + cmp_dtype = arr.dtype.name + elif arr.dtype in ["Int64", "UInt64"]: + cmp_dtype = arr.dtype.name + elif is_signed_integer_dtype(arr.dtype): + # TODO: Why does Window Numpy 2.0 dtype depend on skipna? + cmp_dtype = ( + "Int32" + if (is_platform_windows() and (not np_version_gt2 or not skipna)) + or not IS64 + else "Int64" + ) + elif is_unsigned_integer_dtype(arr.dtype): + cmp_dtype = ( + "UInt32" + if (is_platform_windows() and (not np_version_gt2 or not skipna)) + or not IS64 + else "UInt64" + ) + elif arr.dtype.kind == "b": + if op_name in ["min", "max"]: + cmp_dtype = "boolean" + elif op_name in ["sum", "prod"]: + cmp_dtype = ( + "Int32" + if (is_platform_windows() and (not np_version_gt2 or not skipna)) + or not IS64 + else "Int64" + ) + else: + raise TypeError("not supposed to reach this") + else: + raise TypeError("not supposed to reach this") + return cmp_dtype + + def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool: + return True + + def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool): + # overwrite to ensure pd.NA is tested instead of np.nan + # https://github.com/pandas-dev/pandas/issues/30958 + length = 64 + if is_windows_or_32bit: + # Item "ExtensionDtype" of "Union[dtype[Any], ExtensionDtype]" has + # no attribute "itemsize" + if not ser.dtype.itemsize == 8: # type: ignore[union-attr] + length = 32 + + if ser.dtype.name.startswith("U"): + expected_dtype = f"UInt{length}" + elif ser.dtype.name.startswith("I"): + expected_dtype = f"Int{length}" + elif ser.dtype.name.startswith("F"): + # Incompatible types in assignment (expression has type + # "Union[dtype[Any], ExtensionDtype]", variable has type "str") + expected_dtype = ser.dtype # type: ignore[assignment] + elif ser.dtype.kind == "b": + if op_name in ("cummin", "cummax"): + expected_dtype = "boolean" + else: + expected_dtype = f"Int{length}" + + if expected_dtype == "Float32" and op_name == "cumprod" and skipna: + # TODO: xfail? + pytest.skip( + f"Float32 precision lead to large differences with op {op_name} " + f"and skipna={skipna}" + ) + + if op_name == "cumsum": + pass + elif op_name in ["cummax", "cummin"]: + expected_dtype = ser.dtype # type: ignore[assignment] + elif op_name == "cumprod": + ser = ser[:12] + else: + raise NotImplementedError(f"{op_name} not supported") + + result = getattr(ser, op_name)(skipna=skipna) + expected = pd.Series( + pd.array( + getattr(ser.astype("float64"), op_name)(skipna=skipna), + dtype="Float64", + ) + ) + expected[np.isnan(expected)] = pd.NA + expected = expected.astype(expected_dtype) + tm.assert_series_equal(result, expected) + + def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data, request): + super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data) + + +@pytest.mark.parametrize( + "arr", [pd.array([True, False]), pd.array([1, 2]), pd.array([1.0, 2.0])] +) +def test_cast_pointwise_result_all_na_respects_original_dtype(arr): + # GH#62344 + values = [pd.NA, pd.NA] + result = arr._cast_pointwise_result(values) + assert result.dtype == arr.dtype + assert all(x is pd.NA for x in result) diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..c3f619e4263df86162ef6d9dee9698aa7492d83d --- /dev/null +++ b/pandas/tests/extension/test_numpy.py @@ -0,0 +1,439 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. + +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). + +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. + +Note: we do not bother with base.BaseIndexTests because NumpyExtensionArray +will never be held in an Index. +""" + +import numpy as np +import pytest + +from pandas.core.dtypes.dtypes import NumpyEADtype + +import pandas as pd +import pandas._testing as tm +from pandas.api.types import is_object_dtype +from pandas.core.arrays.numpy_ import NumpyExtensionArray +from pandas.tests.extension import base + +orig_assert_attr_equal = tm.assert_attr_equal + + +def _assert_attr_equal(attr: str, left, right, obj: str = "Attributes"): + """ + patch tm.assert_attr_equal so NumpyEADtype("object") is closed enough to + np.dtype("object") + """ + if attr == "dtype": + lattr = getattr(left, "dtype", None) + rattr = getattr(right, "dtype", None) + if isinstance(lattr, NumpyEADtype) and not isinstance(rattr, NumpyEADtype): + left = left.astype(lattr.numpy_dtype) + elif isinstance(rattr, NumpyEADtype) and not isinstance(lattr, NumpyEADtype): + right = right.astype(rattr.numpy_dtype) + + orig_assert_attr_equal(attr, left, right, obj) + + +@pytest.fixture(params=["float", "object"]) +def dtype(request): + return NumpyEADtype(np.dtype(request.param)) + + +@pytest.fixture +def allow_in_pandas(monkeypatch): + """ + A monkeypatch to tells pandas to let us in. + + By default, passing a NumpyExtensionArray to an index / series / frame + constructor will unbox that NumpyExtensionArray to an ndarray, and treat + it as a non-EA column. We don't want people using EAs without + reason. + + The mechanism for this is a check against ABCNumpyExtensionArray + in each constructor. + + But, for testing, we need to allow them in pandas. So we patch + the _typ of NumpyExtensionArray, so that we evade the ABCNumpyExtensionArray + check. + """ + with monkeypatch.context() as m: + m.setattr(NumpyExtensionArray, "_typ", "extension") + m.setattr(tm.asserters, "assert_attr_equal", _assert_attr_equal) + yield + + +@pytest.fixture +def data(allow_in_pandas, dtype): + if dtype.numpy_dtype == "object": + arr = pd.Series([(i,) for i in range(10)])._values + else: + arr = np.arange(1, 11, dtype=dtype._dtype) + return NumpyExtensionArray(arr) + + +@pytest.fixture +def data_missing(allow_in_pandas, dtype): + if dtype.numpy_dtype == "object": + return NumpyExtensionArray(np.array([np.nan, (1,)], dtype=object)) + return NumpyExtensionArray(np.array([np.nan, 1.0])) + + +@pytest.fixture +def na_cmp(): + def cmp(a, b): + return np.isnan(a) and np.isnan(b) + + return cmp + + +@pytest.fixture +def data_for_sorting(allow_in_pandas, dtype): + """Length-3 array with a known sort order. + + This should be three items [B, C, A] with + A < B < C + """ + if dtype.numpy_dtype == "object": + # Use an empty tuple for first element, then remove, + # to disable np.array's shape inference. + return NumpyExtensionArray(np.array([(), (2,), (3,), (1,)], dtype=object)[1:]) + return NumpyExtensionArray(np.array([1, 2, 0])) + + +@pytest.fixture +def data_missing_for_sorting(allow_in_pandas, dtype): + """Length-3 array with a known sort order. + + This should be three items [B, NA, A] with + A < B and NA missing. + """ + if dtype.numpy_dtype == "object": + return NumpyExtensionArray(np.array([(1,), np.nan, (0,)], dtype=object)) + return NumpyExtensionArray(np.array([1, np.nan, 0])) + + +@pytest.fixture +def data_for_grouping(allow_in_pandas, dtype): + """Data for factorization, grouping, and unique tests. + + Expected to be like [B, B, NA, NA, A, A, B, C] + + Where A < B < C and NA is missing + """ + if dtype.numpy_dtype == "object": + a, b, c = (1,), (2,), (3,) + else: + a, b, c = np.arange(3) + return NumpyExtensionArray( + np.array([b, b, np.nan, np.nan, a, a, b, c], dtype=dtype.numpy_dtype) + ) + + +@pytest.fixture +def data_for_twos(dtype): + if dtype.kind == "O": + pytest.skip(f"{dtype} is not a numeric dtype") + arr = np.ones(10) * 2 + return NumpyExtensionArray._from_sequence(arr, dtype=dtype) + + +@pytest.fixture +def skip_numpy_object(dtype, request): + """ + Tests for NumpyExtensionArray with nested data. Users typically won't create + these objects via `pd.array`, but they can show up through `.array` + on a Series with nested data. Many of the base tests fail, as they aren't + appropriate for nested data. + + This fixture allows these tests to be skipped when used as a usefixtures + marker to either an individual test or a test class. + """ + if dtype == "object": + mark = pytest.mark.xfail(reason="Fails for object dtype") + request.applymarker(mark) + + +skip_nested = pytest.mark.usefixtures("skip_numpy_object") + + +class TestNumpyExtensionArray(base.ExtensionTests): + @pytest.mark.skip(reason="We don't register our dtype") + # We don't want to register. This test should probably be split in two. + def test_from_dtype(self, data): + pass + + @skip_nested + def test_series_constructor_scalar_with_index(self, data, dtype): + # ValueError: Length of passed values is 1, index implies 3. + super().test_series_constructor_scalar_with_index(data, dtype) + + def test_check_dtype(self, data, request, using_infer_string): + if data.dtype.numpy_dtype == "object": + request.applymarker( + pytest.mark.xfail( + reason=f"NumpyExtensionArray expectedly clashes with a " + f"NumPy name: {data.dtype.numpy_dtype}" + ) + ) + super().test_check_dtype(data) + + def test_is_not_object_type(self, dtype, request): + if dtype.numpy_dtype == "object": + # Different from BaseDtypeTests.test_is_not_object_type + # because NumpyEADtype(object) is an object type + assert is_object_dtype(dtype) + else: + super().test_is_not_object_type(dtype) + + @skip_nested + def test_getitem_scalar(self, data): + # AssertionError + super().test_getitem_scalar(data) + + @skip_nested + def test_shift_fill_value(self, data): + # np.array shape inference. Shift implementation fails. + super().test_shift_fill_value(data) + + @skip_nested + def test_fillna_limit_frame(self, data_missing): + # GH#58001 + # The "scalar" for this array isn't a scalar. + super().test_fillna_limit_frame(data_missing) + + @skip_nested + def test_fillna_limit_series(self, data_missing): + # GH#58001 + # The "scalar" for this array isn't a scalar. + super().test_fillna_limit_series(data_missing) + + @skip_nested + def test_fillna_copy_frame(self, data_missing): + # The "scalar" for this array isn't a scalar. + super().test_fillna_copy_frame(data_missing) + + @skip_nested + def test_fillna_copy_series(self, data_missing): + # The "scalar" for this array isn't a scalar. + super().test_fillna_copy_series(data_missing) + + @skip_nested + def test_searchsorted(self, data_for_sorting, as_series): + # TODO: NumpyExtensionArray.searchsorted calls ndarray.searchsorted which + # isn't quite what we want in nested data cases. Instead we need to + # adapt something like libindex._bin_search. + super().test_searchsorted(data_for_sorting, as_series) + + @pytest.mark.xfail(reason="NumpyExtensionArray.diff may fail on dtype") + def test_diff(self, data, periods): + return super().test_diff(data, periods) + + def test_insert(self, data, request): + if data.dtype.numpy_dtype == object: + mark = pytest.mark.xfail(reason="Dimension mismatch in np.concatenate") + request.applymarker(mark) + + super().test_insert(data) + + @skip_nested + def test_insert_invalid(self, data, invalid_scalar): + # NumpyExtensionArray[object] can hold anything, so skip + super().test_insert_invalid(data, invalid_scalar) + + divmod_exc = None + series_scalar_exc = None + frame_scalar_exc = None + series_array_exc = None + + def test_divmod(self, data): + divmod_exc = None + if data.dtype.kind == "O": + divmod_exc = TypeError + self.divmod_exc = divmod_exc + super().test_divmod(data) + + def test_divmod_series_array(self, data): + ser = pd.Series(data) + exc = None + if data.dtype.kind == "O": + exc = TypeError + self.divmod_exc = exc + self._check_divmod_op(ser, divmod, data) + + def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request): + opname = all_arithmetic_operators + series_scalar_exc = None + if data.dtype.numpy_dtype == object: + if opname in ["__mul__", "__rmul__"]: + mark = pytest.mark.xfail( + reason="the Series.combine step raises but not the Series method." + ) + request.node.add_marker(mark) + series_scalar_exc = TypeError + self.series_scalar_exc = series_scalar_exc + super().test_arith_series_with_scalar(data, all_arithmetic_operators) + + def test_arith_series_with_array(self, data, all_arithmetic_operators): + opname = all_arithmetic_operators + series_array_exc = None + if data.dtype.numpy_dtype == object and opname not in ["__add__", "__radd__"]: + series_array_exc = TypeError + self.series_array_exc = series_array_exc + super().test_arith_series_with_array(data, all_arithmetic_operators) + + def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): + opname = all_arithmetic_operators + frame_scalar_exc = None + if data.dtype.numpy_dtype == object: + if opname in ["__mul__", "__rmul__"]: + mark = pytest.mark.xfail( + reason="the Series.combine step raises but not the Series method." + ) + request.node.add_marker(mark) + frame_scalar_exc = TypeError + self.frame_scalar_exc = frame_scalar_exc + super().test_arith_frame_with_scalar(data, all_arithmetic_operators) + + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + if ser.dtype.kind == "O": + return op_name in ["sum", "min", "max", "any", "all"] + return True + + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + res_op = getattr(ser, op_name) + # avoid coercing int -> float. Just cast to the actual numpy type. + # error: Item "ExtensionDtype" of "dtype[Any] | ExtensionDtype" has + # no attribute "numpy_dtype" + cmp_dtype = ser.dtype.numpy_dtype # type: ignore[union-attr] + alt = ser.astype(cmp_dtype) + exp_op = getattr(alt, op_name) + if op_name == "count": + result = res_op() + expected = exp_op() + else: + result = res_op(skipna=skipna) + expected = exp_op(skipna=skipna) + tm.assert_almost_equal(result, expected) + + @pytest.mark.skip("TODO: tests not written yet") + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_frame(self, data, all_numeric_reductions, skipna): + pass + + @skip_nested + def test_fillna_series(self, data_missing): + # Non-scalar "scalar" values. + super().test_fillna_series(data_missing) + + @skip_nested + def test_fillna_frame(self, data_missing): + # Non-scalar "scalar" values. + super().test_fillna_frame(data_missing) + + @skip_nested + def test_fillna_readonly(self, data_missing): + # Non-scalar "scalar" values. + super().test_fillna_readonly(data_missing) + + @skip_nested + def test_setitem_invalid(self, data, invalid_scalar): + # object dtype can hold anything, so doesn't raise + super().test_setitem_invalid(data, invalid_scalar) + + @skip_nested + def test_setitem_sequence_broadcasts(self, data, box_in_series): + # ValueError: cannot set using a list-like indexer with a different + # length than the value + super().test_setitem_sequence_broadcasts(data, box_in_series) + + @skip_nested + @pytest.mark.parametrize("setter", ["loc", None]) + def test_setitem_mask_broadcast(self, data, setter): + # ValueError: cannot set using a list-like indexer with a different + # length than the value + super().test_setitem_mask_broadcast(data, setter) + + @skip_nested + def test_setitem_scalar_key_sequence_raise(self, data): + # Failed: DID NOT RAISE + super().test_setitem_scalar_key_sequence_raise(data) + + # TODO: there is some issue with NumpyExtensionArray, therefore, + # skip the setitem test for now, and fix it later (GH 31446) + + @skip_nested + @pytest.mark.parametrize( + "mask", + [ + np.array([True, True, True, False, False]), + pd.array([True, True, True, False, False], dtype="boolean"), + ], + ids=["numpy-array", "boolean-array"], + ) + def test_setitem_mask(self, data, mask, box_in_series): + super().test_setitem_mask(data, mask, box_in_series) + + @skip_nested + @pytest.mark.parametrize( + "idx", + [[0, 1, 2], pd.array([0, 1, 2], dtype="Int64"), np.array([0, 1, 2])], + ids=["list", "integer-array", "numpy-array"], + ) + def test_setitem_integer_array(self, data, idx, box_in_series): + super().test_setitem_integer_array(data, idx, box_in_series) + + @skip_nested + def test_setitem_slice(self, data, box_in_series): + super().test_setitem_slice(data, box_in_series) + + @skip_nested + def test_setitem_loc_iloc_slice(self, data): + super().test_setitem_loc_iloc_slice(data) + + def test_setitem_with_expansion_dataframe_column(self, data, full_indexer): + # https://github.com/pandas-dev/pandas/issues/32395 + df = expected = pd.DataFrame({"data": pd.Series(data)}) + result = pd.DataFrame(index=df.index) + + # because result has object dtype, the attempt to do setting inplace + # is successful, and object dtype is retained + key = full_indexer(df) + result.loc[key, "data"] = df["data"] + + # base class method has expected = df; NumpyExtensionArray behaves oddly because + # we patch _typ for these tests. + if data.dtype.numpy_dtype != object: + if not isinstance(key, slice) or key != slice(None): + expected = pd.DataFrame({"data": data.to_numpy()}) + tm.assert_frame_equal(result, expected, check_column_type=False) + + @pytest.mark.xfail(reason="NumpyEADtype is unpacked") + def test_index_from_listlike_with_dtype(self, data): + super().test_index_from_listlike_with_dtype(data) + + @skip_nested + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_EA_types(self, engine, data, request): + super().test_EA_types(engine, data, request) + + def test_loc_setitem_with_expansion_preserves_ea_index_dtype(self, data, request): + if isinstance(data[-1], tuple): + mark = pytest.mark.xfail(reason="Unpacks tuple") + request.applymarker(mark) + super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data) + + +class Test2DCompat(base.NDArrayBacked2DTests): + pass diff --git a/pandas/tests/extension/test_period.py b/pandas/tests/extension/test_period.py new file mode 100644 index 0000000000000000000000000000000000000000..a3be4e2b4420a569b2b5249c60a570897c847474 --- /dev/null +++ b/pandas/tests/extension/test_period.py @@ -0,0 +1,116 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. + +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). + +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. + +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from pandas._libs import ( + Period, + iNaT, +) + +from pandas.core.dtypes.dtypes import PeriodDtype + +import pandas._testing as tm +from pandas.core.arrays import PeriodArray +from pandas.tests.extension import base + +if TYPE_CHECKING: + import pandas as pd + + +@pytest.fixture(params=["D", "2D"]) +def dtype(request): + return PeriodDtype(freq=request.param) + + +@pytest.fixture +def data(dtype): + return PeriodArray(np.arange(1970, 1980), dtype=dtype) + + +@pytest.fixture +def data_for_sorting(dtype): + return PeriodArray([2018, 2019, 2017], dtype=dtype) + + +@pytest.fixture +def data_missing(dtype): + return PeriodArray([iNaT, 2017], dtype=dtype) + + +@pytest.fixture +def data_missing_for_sorting(dtype): + return PeriodArray([2018, iNaT, 2017], dtype=dtype) + + +@pytest.fixture +def data_for_grouping(dtype): + B = 2018 + NA = iNaT + A = 2017 + C = 2019 + return PeriodArray([B, B, NA, NA, A, A, B, C], dtype=dtype) + + +class TestPeriodArray(base.ExtensionTests): + def _get_expected_exception(self, op_name, obj, other): + if op_name in ("__sub__", "__rsub__"): + return None + return super()._get_expected_exception(op_name, obj, other) + + def _supports_accumulation(self, ser, op_name: str) -> bool: + return op_name in ["cummin", "cummax"] + + def _supports_reduction(self, obj, op_name: str) -> bool: + return op_name in ["min", "max", "median"] + + def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool): + if op_name == "median": + res_op = getattr(ser, op_name) + + alt = ser.astype("int64") + + exp_op = getattr(alt, op_name) + result = res_op(skipna=skipna) + expected = exp_op(skipna=skipna) + # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no + # attribute "freq" + freq = ser.dtype.freq # type: ignore[union-attr] + expected = Period._from_ordinal(int(expected), freq=freq) + tm.assert_almost_equal(result, expected) + + else: + return super().check_reduce(ser, op_name, skipna) + + @pytest.mark.parametrize("periods", [1, -2]) + # NOTE: RuntimeWarning on Windows(non-ARM) platforms (in CI) + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_diff(self, request, data, periods): + super().test_diff(data, periods) + + @pytest.mark.parametrize("na_action", [None, "ignore"]) + def test_map(self, data, na_action): + result = data.map(lambda x: x, na_action=na_action) + tm.assert_extension_array_equal(result, data) + + +class Test2DCompat(base.NDArrayBacked2DTests): + pass diff --git a/pandas/tests/extension/test_sparse.py b/pandas/tests/extension/test_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..59c752cd24163b1a698a9296974e1f595a359fbf --- /dev/null +++ b/pandas/tests/extension/test_sparse.py @@ -0,0 +1,517 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. + +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). + +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. + +""" + +import numpy as np +import pytest + +import pandas as pd +from pandas import SparseDtype +import pandas._testing as tm +from pandas.arrays import SparseArray +from pandas.tests.extension import base + + +def make_data(fill_value, n: int): + rng = np.random.default_rng(2) + if np.isnan(fill_value): + data = rng.uniform(size=n) + else: + data = rng.integers(1, 100, size=n, dtype=int) + if data[0] == data[1]: + data[0] += 1 + + data[2::3] = fill_value + return data + + +@pytest.fixture +def dtype(): + return SparseDtype() + + +@pytest.fixture(params=[0, np.nan]) +def data(request): + """Length-10 SparseArray for semantics test.""" + res = SparseArray(make_data(request.param, 10), fill_value=request.param) + return res + + +@pytest.fixture +def data_for_twos(): + return SparseArray(np.ones(10) * 2) + + +@pytest.fixture(params=[0, np.nan]) +def data_missing(request): + """Length 2 array with [NA, Valid]""" + return SparseArray([np.nan, 1], fill_value=request.param) + + +@pytest.fixture(params=[0, np.nan]) +def data_repeated(request): + """Return different versions of data for count times""" + + def gen(count): + for _ in range(count): + yield SparseArray(make_data(request.param, 10), fill_value=request.param) + + return gen + + +@pytest.fixture(params=[0, np.nan]) +def data_for_sorting(request): + return SparseArray([2, 3, 1], fill_value=request.param) + + +@pytest.fixture(params=[0, np.nan]) +def data_missing_for_sorting(request): + return SparseArray([2, np.nan, 1], fill_value=request.param) + + +@pytest.fixture +def na_cmp(): + return lambda left, right: pd.isna(left) and pd.isna(right) + + +@pytest.fixture(params=[0, np.nan]) +def data_for_grouping(request): + return SparseArray([1, 1, np.nan, np.nan, 2, 2, 1, 3], fill_value=request.param) + + +@pytest.fixture(params=[0, np.nan]) +def data_for_compare(request): + return SparseArray([0, 0, np.nan, -2, -1, 4, 2, 3, 0, 0], fill_value=request.param) + + +class TestSparseArray(base.ExtensionTests): + def _supports_reduction(self, obj, op_name: str) -> bool: + return True + + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, request): + if all_numeric_reductions in [ + "prod", + "median", + "var", + "std", + "sem", + "skew", + "kurt", + ]: + mark = pytest.mark.xfail( + reason="This should be viable but is not implemented" + ) + request.node.add_marker(mark) + elif ( + all_numeric_reductions in ["sum", "max", "min", "mean"] + and data.dtype.kind == "f" + and not skipna + ): + mark = pytest.mark.xfail(reason="getting a non-nan float") + request.node.add_marker(mark) + + super().test_reduce_series_numeric(data, all_numeric_reductions, skipna) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_reduce_frame(self, data, all_numeric_reductions, skipna, request): + if all_numeric_reductions in [ + "prod", + "median", + "var", + "std", + "sem", + "skew", + "kurt", + ]: + mark = pytest.mark.xfail( + reason="This should be viable but is not implemented" + ) + request.node.add_marker(mark) + elif ( + all_numeric_reductions in ["sum", "max", "min", "mean"] + and data.dtype.kind == "f" + and not skipna + ): + mark = pytest.mark.xfail(reason="ExtensionArray NA mask are different") + request.node.add_marker(mark) + + super().test_reduce_frame(data, all_numeric_reductions, skipna) + + def _check_unsupported(self, data): + if data.dtype == SparseDtype(int, 0): + pytest.skip("Can't store nan in int array.") + + def test_concat_mixed_dtypes(self, data): + # https://github.com/pandas-dev/pandas/issues/20762 + # This should be the same, aside from concat([sparse, float]) + df1 = pd.DataFrame({"A": data[:3]}) + df2 = pd.DataFrame({"A": [1, 2, 3]}) + df3 = pd.DataFrame({"A": ["a", "b", "c"]}).astype("category") + dfs = [df1, df2, df3] + + # dataframes + result = pd.concat(dfs) + expected = pd.concat( + [x.apply(lambda s: np.asarray(s).astype(object)) for x in dfs] + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "columns", + [ + ["A", "B"], + pd.MultiIndex.from_tuples( + [("A", "a"), ("A", "b")], names=["outer", "inner"] + ), + ], + ) + @pytest.mark.parametrize("future_stack", [True, False]) + def test_stack(self, data, columns, future_stack): + super().test_stack(data, columns, future_stack) + + def test_concat_columns(self, data, na_value): + self._check_unsupported(data) + super().test_concat_columns(data, na_value) + + def test_concat_extension_arrays_copy_false(self, data, na_value): + self._check_unsupported(data) + super().test_concat_extension_arrays_copy_false(data, na_value) + + def test_align(self, data, na_value): + self._check_unsupported(data) + super().test_align(data, na_value) + + def test_align_frame(self, data, na_value): + self._check_unsupported(data) + super().test_align_frame(data, na_value) + + def test_align_series_frame(self, data, na_value): + self._check_unsupported(data) + super().test_align_series_frame(data, na_value) + + def test_merge(self, data, na_value): + self._check_unsupported(data) + super().test_merge(data, na_value) + + def test_get(self, data): + ser = pd.Series(data, index=[2 * i for i in range(len(data))]) + if np.isnan(ser.values.fill_value): + assert np.isnan(ser.get(4)) and np.isnan(ser.iloc[2]) + else: + assert ser.get(4) == ser.iloc[2] + assert ser.get(2) == ser.iloc[1] + + def test_array_item_with_index(self, data, request): + # TODO https://github.com/pandas-dev/pandas/pull/64183 + request.node.add_marker(pytest.mark.xfail(reason="SparseArray getitem buggy")) + super().test_array_item_with_index(data) + + def test_reindex(self, data, na_value): + self._check_unsupported(data) + super().test_reindex(data, na_value) + + def test_isna(self, data_missing): + sarr = SparseArray(data_missing) + expected_dtype = SparseDtype(bool, pd.isna(data_missing.dtype.fill_value)) + expected = SparseArray([True, False], dtype=expected_dtype) + result = sarr.isna() + tm.assert_sp_array_equal(result, expected) + + # test isna for arr without na + sarr = sarr.fillna(0) + expected_dtype = SparseDtype(bool, pd.isna(data_missing.dtype.fill_value)) + expected = SparseArray([False, False], fill_value=False, dtype=expected_dtype) + tm.assert_equal(sarr.isna(), expected) + + def test_fillna_no_op_returns_copy(self, data, request): + super().test_fillna_no_op_returns_copy(data) + + def test_fillna_readonly(self, data_missing): + # copy keyword is ignored by SparseArray.fillna + # -> copy=True vs False doesn't make a difference + data = data_missing.copy() + data._readonly = True + + result = data.fillna(data_missing[1]) + assert result[0] == data_missing[1] + tm.assert_extension_array_equal(data, data_missing) + + # fillna(copy=False) is ignored -> so same result as above + result = data.fillna(data_missing[1], copy=False) + assert result[0] == data_missing[1] + tm.assert_extension_array_equal(data, data_missing) + + @pytest.mark.xfail(reason="Unsupported") + def test_fillna_series(self, data_missing): + # this one looks doable. + # TODO: this fails bc we do not pass through data_missing. If we did, + # the 0-fill case would xpass + super().test_fillna_series() + + def test_fillna_frame(self, data_missing): + # Have to override to specify that fill_value will change. + fill_value = data_missing[1] + + result = pd.DataFrame({"A": data_missing, "B": [1, 2]}).fillna(fill_value) + + if pd.isna(data_missing.fill_value): + dtype = SparseDtype(data_missing.dtype, fill_value) + else: + dtype = data_missing.dtype + + expected = pd.DataFrame( + { + "A": data_missing._from_sequence([fill_value, fill_value], dtype=dtype), + "B": [1, 2], + } + ) + + tm.assert_frame_equal(result, expected) + + def test_fillna_limit_frame(self, data_missing): + # GH#58001 + with pytest.raises(ValueError, match="limit must be None"): + super().test_fillna_limit_frame(data_missing) + + def test_fillna_limit_series(self, data_missing): + # GH#58001 + with pytest.raises(ValueError, match="limit must be None"): + super().test_fillna_limit_frame(data_missing) + + _combine_le_expected_dtype = "Sparse[bool]" + + def test_fillna_copy_frame(self, data_missing): + arr = data_missing.take([1, 1]) + df = pd.DataFrame({"A": arr}, copy=False) + + filled_val = df.iloc[0, 0] + result = df.fillna(filled_val) + + if hasattr(df._mgr, "blocks"): + assert df.values.base is result.values.base + assert df.A._values.to_dense() is arr.to_dense() + + def test_fillna_copy_series(self, data_missing): + arr = data_missing.take([1, 1]) + ser = pd.Series(arr, copy=False) + + filled_val = ser[0] + result = ser.fillna(filled_val) + + assert ser._values is result._values + assert ser._values.to_dense() is arr.to_dense() + + @pytest.mark.xfail(reason="Not Applicable") + def test_fillna_length_mismatch(self, data_missing): + super().test_fillna_length_mismatch(data_missing) + + def test_where_series(self, data, na_value): + assert data[0] != data[1] + cls = type(data) + a, b = data[:2] + + ser = pd.Series(cls._from_sequence([a, a, b, b], dtype=data.dtype)) + + cond = np.array([True, True, False, False]) + result = ser.where(cond) + + new_dtype = SparseDtype("float", 0.0) + expected = pd.Series( + cls._from_sequence([a, a, na_value, na_value], dtype=new_dtype) + ) + tm.assert_series_equal(result, expected) + + other = cls._from_sequence([a, b, a, b], dtype=data.dtype) + cond = np.array([True, False, True, True]) + result = ser.where(cond, other) + expected = pd.Series(cls._from_sequence([a, b, b, b], dtype=data.dtype)) + tm.assert_series_equal(result, expected) + + def test_searchsorted(self, performance_warning, data_for_sorting, as_series): + with tm.assert_produces_warning(performance_warning, check_stacklevel=False): + super().test_searchsorted(data_for_sorting, as_series) + + def test_shift_0_periods(self, data): + # GH#33856 shifting with periods=0 should return a copy, not same obj + result = data.shift(0) + + data._sparse_values[0] = data._sparse_values[1] + assert result._sparse_values[0] != result._sparse_values[1] + + @pytest.mark.parametrize("method", ["argmax", "argmin"]) + def test_argmin_argmax_all_na(self, method, data, na_value): + # overriding because Sparse[int64, 0] cannot handle na_value + self._check_unsupported(data) + super().test_argmin_argmax_all_na(method, data, na_value) + + @pytest.mark.fails_arm_wheels + @pytest.mark.parametrize("box", [pd.array, pd.Series, pd.DataFrame]) + def test_equals(self, data, na_value, as_series, box): + self._check_unsupported(data) + super().test_equals(data, na_value, as_series, box) + + @pytest.mark.fails_arm_wheels + def test_equals_same_data_different_object(self, data): + super().test_equals_same_data_different_object(data) + + @pytest.mark.parametrize( + "func, na_action, expected", + [ + (lambda x: x, None, SparseArray([1.0, np.nan])), + (lambda x: x, "ignore", SparseArray([1.0, np.nan])), + (str, None, SparseArray(["1.0", "nan"], fill_value="nan")), + (str, "ignore", SparseArray(["1.0", np.nan])), + ], + ) + def test_map(self, func, na_action, expected): + # GH52096 + data = SparseArray([1, np.nan]) + result = data.map(func, na_action=na_action) + tm.assert_extension_array_equal(result, expected) + + @pytest.mark.parametrize("na_action", [None, "ignore"]) + def test_map_raises(self, data, na_action): + # GH52096 + msg = "fill value in the sparse values not supported" + with pytest.raises(ValueError, match=msg): + data.map(lambda x: np.nan, na_action=na_action) + + @pytest.mark.xfail(raises=TypeError, reason="no sparse StringDtype") + def test_astype_string(self, data, nullable_string_dtype): + # TODO: this fails bc we do not pass through nullable_string_dtype; + # If we did, the 0-cases would xpass + super().test_astype_string(data) + + series_scalar_exc = None + frame_scalar_exc = None + divmod_exc = None + series_array_exc = None + + def _skip_if_different_combine(self, data): + if data.fill_value == 0: + # arith ops call on dtype.fill_value so that the sparsity + # is maintained. Combine can't be called on a dtype in + # general, so we can't make the expected. This is tested elsewhere + pytest.skip("Incorrected expected from Series.combine and tested elsewhere") + + def test_arith_series_with_scalar(self, data, all_arithmetic_operators): + self._skip_if_different_combine(data) + super().test_arith_series_with_scalar(data, all_arithmetic_operators) + + def test_arith_series_with_array(self, data, all_arithmetic_operators): + self._skip_if_different_combine(data) + super().test_arith_series_with_array(data, all_arithmetic_operators) + + def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request): + if data.dtype.fill_value != 0: + pass + elif all_arithmetic_operators.strip("_") not in [ + "mul", + "rmul", + "floordiv", + "rfloordiv", + "truediv", + "rtruediv", + "pow", + "mod", + "rmod", + ]: + mark = pytest.mark.xfail(reason="result dtype.fill_value mismatch") + request.applymarker(mark) + super().test_arith_frame_with_scalar(data, all_arithmetic_operators) + + def _compare_other( + self, ser: pd.Series, data_for_compare: SparseArray, comparison_op, other + ): + op = comparison_op + + result = op(data_for_compare, other) + if isinstance(other, pd.Series): + assert isinstance(result, pd.Series) + assert isinstance(result.dtype, SparseDtype) + else: + assert isinstance(result, SparseArray) + assert result.dtype.subtype == np.bool_ + + if isinstance(other, pd.Series): + fill_value = op(data_for_compare.fill_value, other._values.fill_value) + expected = SparseArray( + op(data_for_compare.to_dense(), np.asarray(other)), + fill_value=fill_value, + dtype=np.bool_, + ) + + else: + fill_value = np.all( + op(np.asarray(data_for_compare.fill_value), np.asarray(other)) + ) + + expected = SparseArray( + op(data_for_compare.to_dense(), np.asarray(other)), + fill_value=fill_value, + dtype=np.bool_, + ) + if isinstance(other, pd.Series): + # error: Incompatible types in assignment + expected = pd.Series(expected) # type: ignore[assignment] + tm.assert_equal(result, expected) + + def test_scalar(self, data_for_compare: SparseArray, comparison_op): + ser = pd.Series(data_for_compare) + self._compare_other(ser, data_for_compare, comparison_op, 0) + self._compare_other(ser, data_for_compare, comparison_op, 1) + self._compare_other(ser, data_for_compare, comparison_op, -1) + self._compare_other(ser, data_for_compare, comparison_op, np.nan) + + def test_array(self, data_for_compare: SparseArray, comparison_op, request): + if data_for_compare.dtype.fill_value == 0 and comparison_op.__name__ in [ + "eq", + "ge", + "le", + ]: + mark = pytest.mark.xfail(reason="Wrong fill_value") + request.applymarker(mark) + + arr = np.linspace(-4, 5, 10) + ser = pd.Series(data_for_compare) + self._compare_other(ser, data_for_compare, comparison_op, arr) + + def test_sparse_array(self, data_for_compare: SparseArray, comparison_op, request): + if data_for_compare.dtype.fill_value == 0 and comparison_op.__name__ != "gt": + mark = pytest.mark.xfail(reason="Wrong fill_value") + request.applymarker(mark) + + ser = pd.Series(data_for_compare) + arr = data_for_compare + 1 + self._compare_other(ser, data_for_compare, comparison_op, arr) + arr = data_for_compare * 2 + self._compare_other(ser, data_for_compare, comparison_op, arr) + + @pytest.mark.xfail(reason="Different repr") + def test_array_repr(self, data, size): + super().test_array_repr(data, size) + + @pytest.mark.xfail(reason="result does not match expected") + @pytest.mark.parametrize("as_index", [True, False]) + def test_groupby_extension_agg(self, as_index, data_for_grouping): + super().test_groupby_extension_agg(as_index, data_for_grouping) + + +def test_array_type_with_arg(dtype): + assert dtype.construct_array_type() is SparseArray diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py new file mode 100644 index 0000000000000000000000000000000000000000..07c957beef6522267e393bbfc27e02fcfb487d3c --- /dev/null +++ b/pandas/tests/extension/test_string.py @@ -0,0 +1,308 @@ +""" +This file contains a minimal set of tests for compliance with the extension +array interface test suite, and should contain no other tests. +The test suite for the full functionality of the array is located in +`pandas/tests/arrays/`. + +The tests in this file are inherited from the BaseExtensionTests, and only +minimal tweaks should be applied to get the tests passing (by overwriting a +parent method). + +Additional tests should either be added to one of the BaseExtensionTests +classes (if they are relevant for the extension interface for all dtypes), or +be added to the array-specific tests in `pandas/tests/arrays/`. + +""" + +from __future__ import annotations + +import string +from typing import cast + +import numpy as np +import pytest + +from pandas.compat import HAS_PYARROW + +from pandas.core.dtypes.base import StorageExtensionDtype + +import pandas as pd +import pandas._testing as tm +from pandas.api.types import is_string_dtype +from pandas.core.arrays import ArrowStringArray +from pandas.core.arrays.string_ import StringDtype +from pandas.tests.arithmetic.test_string import string_dtype_highest_priority +from pandas.tests.extension import base + + +def maybe_split_array(arr, chunked): + if not chunked: + return arr + elif arr.dtype.storage != "pyarrow": + return arr + + pa = pytest.importorskip("pyarrow") + + arrow_array = arr._pa_array + split = len(arrow_array) // 2 + arrow_array = pa.chunked_array( + [*arrow_array[:split].chunks, *arrow_array[split:].chunks] + ) + assert arrow_array.num_chunks == 2 + return arr._from_pyarrow_array(arrow_array) + + +@pytest.fixture(params=[True, False]) +def chunked(request): + return request.param + + +@pytest.fixture +def dtype(string_dtype_arguments): + storage, na_value = string_dtype_arguments + return StringDtype(storage=storage, na_value=na_value) + + +@pytest.fixture +def data(dtype, chunked): + strings = np.random.default_rng(2).choice(list(string.ascii_letters), size=10) + while strings[0] == strings[1]: + strings = np.random.default_rng(2).choice(list(string.ascii_letters), size=10) + + arr = dtype.construct_array_type()._from_sequence(strings, dtype=dtype) + return maybe_split_array(arr, chunked) + + +@pytest.fixture +def data_missing(dtype, chunked): + """Length 2 array with [NA, Valid]""" + arr = dtype.construct_array_type()._from_sequence([pd.NA, "A"], dtype=dtype) + return maybe_split_array(arr, chunked) + + +@pytest.fixture +def data_for_sorting(dtype, chunked): + arr = dtype.construct_array_type()._from_sequence(["B", "C", "A"], dtype=dtype) + return maybe_split_array(arr, chunked) + + +@pytest.fixture +def data_missing_for_sorting(dtype, chunked): + arr = dtype.construct_array_type()._from_sequence(["B", pd.NA, "A"], dtype=dtype) + return maybe_split_array(arr, chunked) + + +@pytest.fixture +def data_for_grouping(dtype, chunked): + arr = dtype.construct_array_type()._from_sequence( + ["B", "B", pd.NA, pd.NA, "A", "A", "B", "C"], dtype=dtype + ) + return maybe_split_array(arr, chunked) + + +class TestStringArray(base.ExtensionTests): + def test_combine_le(self, data_repeated): + dtype = next(iter(data_repeated(2))).dtype + if dtype.storage == "pyarrow" and dtype.na_value is pd.NA: + self._combine_le_expected_dtype = "bool[pyarrow]" + else: + self._combine_le_expected_dtype = "bool" + return super().test_combine_le(data_repeated) + + def test_eq_with_str(self, dtype): + super().test_eq_with_str(dtype) + + if dtype.na_value is pd.NA: + # only the NA-variant supports parametrized string alias + assert dtype == f"string[{dtype.storage}]" + elif dtype.storage == "pyarrow": + assert dtype == "str" + + def test_is_not_string_type(self, dtype): + # Different from BaseDtypeTests.test_is_not_string_type + # because StringDtype is a string type + assert is_string_dtype(dtype) + + def test_is_dtype_from_name(self, dtype, using_infer_string): + if dtype.na_value is np.nan and not using_infer_string: + result = type(dtype).is_dtype(dtype.name) + assert result is False + else: + super().test_is_dtype_from_name(dtype) + + def test_construct_from_string_own_name(self, dtype, using_infer_string): + if dtype.na_value is np.nan and not using_infer_string: + with pytest.raises(TypeError, match="Cannot construct a 'StringDtype'"): + dtype.construct_from_string(dtype.name) + else: + super().test_construct_from_string_own_name(dtype) + + def test_view(self, data): + if data.dtype.storage == "pyarrow": + pytest.skip(reason="2D support not implemented for ArrowStringArray") + super().test_view(data) + + def test_from_dtype(self, data): + # base test uses string representation of dtype + pass + + def test_transpose(self, data): + if data.dtype.storage == "pyarrow": + pytest.skip(reason="2D support not implemented for ArrowStringArray") + super().test_transpose(data) + + def test_setitem_preserves_views(self, data): + if data.dtype.storage == "pyarrow": + pytest.skip(reason="2D support not implemented for ArrowStringArray") + super().test_setitem_preserves_views(data) + + def test_dropna_array(self, data_missing): + result = data_missing.dropna() + expected = data_missing[[1]] + tm.assert_extension_array_equal(result, expected) + + def test_fillna_no_op_returns_copy(self, data): + data = data[~data.isna()] + + valid = data[0] + result = data.fillna(valid) + assert result is not data + tm.assert_extension_array_equal(result, data) + + def test_fillna_readonly(self, data_missing): + data = data_missing.copy() + data._readonly = True + + # by default fillna(copy=True), then this works fine + result = data.fillna(data_missing[1]) + assert result[0] == data_missing[1] + tm.assert_extension_array_equal(data, data_missing) + + # fillna(copy=False) is generally not honored by Arrow-backed array, + # but always returns new data -> same result as above + if data.dtype.storage == "pyarrow": + result = data.fillna(data_missing[1]) + assert result[0] == data_missing[1] + else: + with pytest.raises(ValueError, match="Cannot modify read-only array"): + data.fillna(data_missing[1], copy=False) + tm.assert_extension_array_equal(data, data_missing) + + def _get_expected_exception( + self, op_name: str, obj, other + ) -> type[Exception] | tuple[type[Exception], ...] | None: + if op_name in [ + "__mod__", + "__rmod__", + "__divmod__", + "__rdivmod__", + "__pow__", + "__rpow__", + ]: + return TypeError + elif op_name in ["__mul__", "__rmul__"]: + # Can only multiply strings by integers + return TypeError + elif op_name in [ + "__truediv__", + "__rtruediv__", + "__floordiv__", + "__rfloordiv__", + "__sub__", + "__rsub__", + ]: + return TypeError + + return None + + def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool: + return op_name in ["min", "max", "sum"] or ( + ser.dtype.na_value is np.nan # type: ignore[union-attr] + and op_name in ("any", "all") + ) + + def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool: + assert isinstance(ser.dtype, StorageExtensionDtype) + return op_name in ["cummin", "cummax", "cumsum"] + + def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): + dtype = cast(StringDtype, tm.get_dtype(obj)) + if op_name in ["__add__", "__radd__"]: + cast_to = dtype + dtype_other = tm.get_dtype(other) if not isinstance(other, str) else None + if isinstance(dtype_other, StringDtype): + cast_to = string_dtype_highest_priority(dtype, dtype_other) + elif dtype.na_value is np.nan: + cast_to = np.bool_ # type: ignore[assignment] + elif dtype.storage == "pyarrow": + cast_to = "bool[pyarrow]" # type: ignore[assignment] + else: + cast_to = "boolean" # type: ignore[assignment] + return pointwise_result.astype(cast_to) + + def test_compare_scalar(self, data, comparison_op): + ser = pd.Series(data) + self._compare_other(ser, data, comparison_op, "abc") + + def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op): + super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op) + + def test_combine_add(self, data_repeated, using_infer_string, request): + dtype = next(data_repeated(1)).dtype + if not using_infer_string and dtype.storage == "python": + mark = pytest.mark.xfail( + reason="The pointwise operation result will be inferred to " + "string[nan, pyarrow], which does not match the input dtype" + ) + request.applymarker(mark) + super().test_combine_add(data_repeated) + + def test_arith_series_with_array( + self, data, all_arithmetic_operators, using_infer_string, request + ): + dtype = data.dtype + if ( + using_infer_string + and all_arithmetic_operators == "__radd__" + and dtype.na_value is pd.NA + and (HAS_PYARROW or dtype.storage == "pyarrow") + ): + # TODO(infer_string) + mark = pytest.mark.xfail( + reason="The pointwise operation result will be inferred to " + "string[nan, pyarrow], which does not match the input dtype" + ) + request.applymarker(mark) + super().test_arith_series_with_array(data, all_arithmetic_operators) + + def test_loc_setitem_with_expansion_preserves_ea_index_dtype( + self, data, request, using_infer_string + ): + if not using_infer_string and data.dtype.storage == "python": + mark = pytest.mark.xfail(reason="Casts to object") + request.applymarker(mark) + super().test_loc_setitem_with_expansion_preserves_ea_index_dtype(data) + + +class Test2DCompat(base.Dim2CompatTests): + @pytest.fixture(autouse=True) + def arrow_not_supported(self, data): + if isinstance(data, ArrowStringArray): + pytest.skip(reason="2D support not implemented for ArrowStringArray") + + +def test_searchsorted_with_na_raises(data_for_sorting, as_series): + # GH50447 + b, c, a = data_for_sorting + arr = data_for_sorting.take([2, 0, 1]) # to get [a, b, c] + arr[-1] = pd.NA + + if as_series: + arr = pd.Series(arr) + + msg = ( + "searchsorted requires array to be sorted, " + "which is impossible with NAs present." + ) + with pytest.raises(ValueError, match=msg): + arr.searchsorted(b) diff --git a/pandas/tests/frame/__init__.py b/pandas/tests/frame/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/frame/common.py b/pandas/tests/frame/common.py new file mode 100644 index 0000000000000000000000000000000000000000..fc41d7907a240f0dd9dc19e0ae1296bee86be421 --- /dev/null +++ b/pandas/tests/frame/common.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pandas import ( + DataFrame, + concat, +) + +if TYPE_CHECKING: + from pandas._typing import AxisInt + + +def _check_mixed_float(df, dtype=None): + # float16 are most likely to be upcasted to float32 + dtypes = {"A": "float32", "B": "float32", "C": "float16", "D": "float64"} + if isinstance(dtype, str): + dtypes = {k: dtype for k, v in dtypes.items()} + elif isinstance(dtype, dict): + dtypes.update(dtype) + if dtypes.get("A"): + assert df.dtypes["A"] == dtypes["A"] + if dtypes.get("B"): + assert df.dtypes["B"] == dtypes["B"] + if dtypes.get("C"): + assert df.dtypes["C"] == dtypes["C"] + if dtypes.get("D"): + assert df.dtypes["D"] == dtypes["D"] + + +def _check_mixed_int(df, dtype=None): + dtypes = {"A": "int32", "B": "uint64", "C": "uint8", "D": "int64"} + if isinstance(dtype, str): + dtypes = {k: dtype for k, v in dtypes.items()} + elif isinstance(dtype, dict): + dtypes.update(dtype) + if dtypes.get("A"): + assert df.dtypes["A"] == dtypes["A"] + if dtypes.get("B"): + assert df.dtypes["B"] == dtypes["B"] + if dtypes.get("C"): + assert df.dtypes["C"] == dtypes["C"] + if dtypes.get("D"): + assert df.dtypes["D"] == dtypes["D"] + + +def zip_frames(frames: list[DataFrame], axis: AxisInt = 1) -> DataFrame: + """ + take a list of frames, zip them together under the + assumption that these all have the first frames' index/columns. + + Returns + ------- + new_frame : DataFrame + """ + if axis == 1: + columns = frames[0].columns + zipped = [f.loc[:, c] for c in columns for f in frames] + return concat(zipped, axis=1) + else: + index = frames[0].index + zipped = [f.loc[i, :] for i in index for f in frames] + return DataFrame(zipped) diff --git a/pandas/tests/frame/conftest.py b/pandas/tests/frame/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..73b8f08957687a8b4af0d582a93968c1514c96c3 --- /dev/null +++ b/pandas/tests/frame/conftest.py @@ -0,0 +1,100 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + NaT, + date_range, +) + + +@pytest.fixture +def datetime_frame() -> DataFrame: + """ + Fixture for DataFrame of floats with DatetimeIndex + + Columns are ['A', 'B', 'C', 'D'] + """ + return DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD")), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + + +@pytest.fixture +def float_string_frame(): + """ + Fixture for DataFrame of floats and strings with index of unique strings + + Columns are ['A', 'B', 'C', 'D', 'foo']. + """ + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 4)), + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD")), + ) + df["foo"] = "bar" + return df + + +@pytest.fixture +def mixed_float_frame(): + """ + Fixture for DataFrame of different float types with index of unique strings + + Columns are ['A', 'B', 'C', 'D']. + """ + df = DataFrame( + { + col: np.random.default_rng(2).random(30, dtype=dtype) + for col, dtype in zip( + list("ABCD"), ["float32", "float32", "float32", "float64"] + ) + }, + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + ) + # not supported by numpy random + df["C"] = df["C"].astype("float16") + return df + + +@pytest.fixture +def mixed_int_frame(): + """ + Fixture for DataFrame of different int types with index of unique strings + + Columns are ['A', 'B', 'C', 'D']. + """ + return DataFrame( + { + col: np.ones(30, dtype=dtype) + for col, dtype in zip(list("ABCD"), ["int32", "uint64", "uint8", "int64"]) + }, + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + ) + + +@pytest.fixture +def timezone_frame(): + """ + Fixture for DataFrame of date_range Series with different time zones + + Columns are ['A', 'B', 'C']; some entries are missing + + A B C + 0 2013-01-01 2013-01-01 00:00:00-05:00 2013-01-01 00:00:00+01:00 + 1 2013-01-02 NaT NaT + 2 2013-01-03 2013-01-03 00:00:00-05:00 2013-01-03 00:00:00+01:00 + """ + df = DataFrame( + { + "A": date_range("20130101", periods=3, unit="ns"), + "B": date_range("20130101", periods=3, tz="US/Eastern", unit="ns"), + "C": date_range("20130101", periods=3, tz="CET", unit="ns"), + } + ) + df.iloc[1, 1] = NaT + df.iloc[1, 2] = NaT + return df diff --git a/pandas/tests/frame/test_alter_axes.py b/pandas/tests/frame/test_alter_axes.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c16b94fcf8b1ee918dce8a9084e56d00225e7b --- /dev/null +++ b/pandas/tests/frame/test_alter_axes.py @@ -0,0 +1,31 @@ +from datetime import ( + datetime, + timezone, +) + +from pandas import DataFrame +import pandas._testing as tm + + +class TestDataFrameAlterAxes: + # Tests for setting index/columns attributes directly (i.e. __setattr__) + + def test_set_axis_setattr_index(self): + # GH 6785 + # set the index manually + + df = DataFrame([{"ts": datetime(2014, 4, 1, tzinfo=timezone.utc), "foo": 1}]) + expected = df.set_index("ts") + df.index = df["ts"] + df.pop("ts") + tm.assert_frame_equal(df, expected) + + # Renaming + + def test_assign_columns(self, float_frame): + float_frame["hi"] = "there" + + df = float_frame.copy() + df.columns = ["foo", "bar", "baz", "quux", "foo2"] + tm.assert_series_equal(float_frame["C"], df["baz"], check_names=False) + tm.assert_series_equal(float_frame["hi"], df["foo2"], check_names=False) diff --git a/pandas/tests/frame/test_api.py b/pandas/tests/frame/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..f54e7605528254fa18c9267057ee4eff0e2977c8 --- /dev/null +++ b/pandas/tests/frame/test_api.py @@ -0,0 +1,408 @@ +from copy import deepcopy +import inspect +import pydoc + +import numpy as np +import pytest + +from pandas._config import using_string_dtype +from pandas._config.config import option_context + +from pandas.compat import HAS_PYARROW + +import pandas as pd +from pandas import ( + DataFrame, + Series, + date_range, + timedelta_range, +) +import pandas._testing as tm + + +class TestDataFrameMisc: + def test_getitem_pop_assign_name(self, float_frame): + s = float_frame["A"] + assert s.name == "A" + + s = float_frame.pop("A") + assert s.name == "A" + + s = float_frame.loc[:, "B"] + assert s.name == "B" + + s2 = s.loc[:] + assert s2.name == "B" + + def test_get_axis(self, float_frame): + f = float_frame + assert f._get_axis_number(0) == 0 + assert f._get_axis_number(1) == 1 + assert f._get_axis_number("index") == 0 + assert f._get_axis_number("rows") == 0 + assert f._get_axis_number("columns") == 1 + + assert f._get_axis_name(0) == "index" + assert f._get_axis_name(1) == "columns" + assert f._get_axis_name("index") == "index" + assert f._get_axis_name("rows") == "index" + assert f._get_axis_name("columns") == "columns" + + assert f._get_axis(0) is f.index + assert f._get_axis(1) is f.columns + + with pytest.raises(ValueError, match="No axis named"): + f._get_axis_number(2) + + with pytest.raises(ValueError, match="No axis.*foo"): + f._get_axis_name("foo") + + with pytest.raises(ValueError, match="No axis.*None"): + f._get_axis_name(None) + + with pytest.raises(ValueError, match="No axis named"): + f._get_axis_number(None) + + def test_column_contains_raises(self, float_frame): + with pytest.raises(TypeError, match="unhashable type: 'Index'"): + float_frame.columns in float_frame + + def test_tab_completion(self): + # DataFrame whose columns are identifiers shall have them in __dir__. + df = DataFrame([list("abcd"), list("efgh")], columns=list("ABCD")) + for key in list("ABCD"): + assert key in dir(df) + assert isinstance(df.__getitem__("A"), Series) + + # DataFrame whose first-level columns are identifiers shall have + # them in __dir__. + df = DataFrame( + [list("abcd"), list("efgh")], + columns=pd.MultiIndex.from_tuples(list(zip("ABCD", "EFGH"))), + ) + for key in list("ABCD"): + assert key in dir(df) + for key in list("EFGH"): + assert key not in dir(df) + assert isinstance(df.__getitem__("A"), DataFrame) + + def test_display_max_dir_items(self): + # display.max_dir_items increases the number of columns that are in __dir__. + columns = ["a" + str(i) for i in range(420)] + values = [range(420), range(420)] + df = DataFrame(values, columns=columns) + + # The default value for display.max_dir_items is 100 + assert "a99" in dir(df) + assert "a100" not in dir(df) + + with option_context("display.max_dir_items", 300): + df = DataFrame(values, columns=columns) + assert "a299" in dir(df) + assert "a300" not in dir(df) + + with option_context("display.max_dir_items", None): + df = DataFrame(values, columns=columns) + assert "a419" in dir(df) + + def test_not_hashable(self): + empty_frame = DataFrame() + + df = DataFrame([1]) + msg = "unhashable type: 'DataFrame'" + with pytest.raises(TypeError, match=msg): + hash(df) + with pytest.raises(TypeError, match=msg): + hash(empty_frame) + + @pytest.mark.xfail( + using_string_dtype() and HAS_PYARROW, reason="surrogates not allowed" + ) + def test_column_name_contains_unicode_surrogate(self): + # GH 25509 + colname = "\ud83d" + df = DataFrame({colname: []}) + # this should not crash + assert colname not in dir(df) + assert df.columns[0] == colname + + def test_new_empty_index(self): + df1 = DataFrame(np.random.default_rng(2).standard_normal((0, 3))) + df2 = DataFrame(np.random.default_rng(2).standard_normal((0, 3))) + df1.index.name = "foo" + assert df2.index.name is None + + def test_get_agg_axis(self, float_frame): + cols = float_frame._get_agg_axis(0) + assert cols is float_frame.columns + + idx = float_frame._get_agg_axis(1) + assert idx is float_frame.index + + msg = r"Axis must be 0 or 1 \(got 2\)" + with pytest.raises(ValueError, match=msg): + float_frame._get_agg_axis(2) + + def test_empty(self, float_frame, float_string_frame): + empty_frame = DataFrame() + assert empty_frame.empty + + assert not float_frame.empty + assert not float_string_frame.empty + + # corner case + df = DataFrame({"A": [1.0, 2.0, 3.0], "B": ["a", "b", "c"]}, index=np.arange(3)) + del df["A"] + assert not df.empty + + def test_len(self, float_frame): + assert len(float_frame) == len(float_frame.index) + + # single block corner case + arr = float_frame[["A", "B"]].values + expected = float_frame.reindex(columns=["A", "B"]).values + tm.assert_almost_equal(arr, expected) + + def test_axis_aliases(self, float_frame): + f = float_frame + + # reg name + expected = f.sum(axis=0) + result = f.sum(axis="index") + tm.assert_series_equal(result, expected) + + expected = f.sum(axis=1) + result = f.sum(axis="columns") + tm.assert_series_equal(result, expected) + + def test_class_axis(self): + # GH 18147 + # no exception and no empty docstring + assert pydoc.getdoc(DataFrame.index) + assert pydoc.getdoc(DataFrame.columns) + + def test_series_put_names(self, float_string_frame): + series = float_string_frame._series + for k, v in series.items(): + assert v.name == k + + def test_empty_nonzero(self): + df = DataFrame([1, 2, 3]) + assert not df.empty + df = DataFrame(index=[1], columns=[1]) + assert not df.empty + df = DataFrame(index=["a", "b"], columns=["c", "d"]).dropna() + assert df.empty + assert df.T.empty + + @pytest.mark.parametrize( + "df", + [ + DataFrame(), + DataFrame(index=[1]), + DataFrame(columns=[1]), + DataFrame({1: []}), + ], + ) + def test_empty_like(self, df): + assert df.empty + assert df.T.empty + + def test_with_datetimelikes(self): + df = DataFrame( + { + "A": date_range("20130101", periods=10), + "B": timedelta_range("1 day", periods=10), + } + ) + t = df.T + + result = t.dtypes.value_counts() + expected = Series({np.dtype("object"): 10}, name="count") + tm.assert_series_equal(result, expected) + + def test_deepcopy(self, float_frame): + cp = deepcopy(float_frame) + cp.loc[0, "A"] = 10 + assert not float_frame.equals(cp) + + def test_inplace_return_self(self): + # GH 1893 + + data = DataFrame( + {"a": ["foo", "bar", "baz", "qux"], "b": [0, 0, 1, 1], "c": [1, 2, 3, 4]} + ) + + def _check_none(base, f): + result = f(base) + assert result is None + + def _check_return(base, f): + result = f(base) + assert result is base + + # -----DataFrame----- + + # set_index + f = lambda x: x.set_index("a", inplace=True) + _check_none(data.copy(), f) + + # reset_index + f = lambda x: x.reset_index(inplace=True) + _check_none(data.set_index("a"), f) + + # drop_duplicates + f = lambda x: x.drop_duplicates(inplace=True) + _check_none(data.copy(), f) + + # sort + f = lambda x: x.sort_values("b", inplace=True) + _check_none(data.copy(), f) + + # sort_index + f = lambda x: x.sort_index(inplace=True) + _check_none(data.copy(), f) + + # fillna + f = lambda x: x.fillna(0, inplace=True) + _check_return(data.copy(), f) + + # replace + f = lambda x: x.replace(1, 0, inplace=True) + _check_return(data.copy(), f) + + # rename + f = lambda x: x.rename({1: "foo"}, inplace=True) + _check_none(data.copy(), f) + + # -----Series----- + d = data.copy()["c"] + + # reset_index + f = lambda x: x.reset_index(inplace=True, drop=True) + _check_none(data.set_index("a")["c"], f) + + # fillna + f = lambda x: x.fillna(0, inplace=True) + _check_return(d.copy(), f) + + # replace + f = lambda x: x.replace(1, 0, inplace=True) + _check_return(d.copy(), f) + + # rename + f = lambda x: x.rename({1: "foo"}, inplace=True) + _check_none(d.copy(), f) + + def test_tab_complete_warning(self, ip, frame_or_series): + # GH 16409 + pytest.importorskip("IPython", minversion="6.0.0") + from IPython.core.completer import provisionalcompleter + + if frame_or_series is DataFrame: + code = "from pandas import DataFrame; obj = DataFrame()" + else: + code = "from pandas import Series; obj = Series(dtype=object)" + + ip.run_cell(code) + # GH 31324 newer jedi version raises Deprecation warning; + # appears resolved 2021-02-02 + with tm.assert_produces_warning(None, raise_on_extra_warnings=False): + with provisionalcompleter("ignore"): + list(ip.Completer.completions("obj.", 1)) + + def test_attrs(self): + df = DataFrame({"A": [2, 3]}) + assert df.attrs == {} + df.attrs["version"] = 1 + + result = df.rename(columns=str) + assert result.attrs == {"version": 1} + + def test_attrs_is_deepcopy(self): + df = DataFrame({"A": [2, 3]}) + assert df.attrs == {} + df.attrs["tags"] = {"spam", "ham"} + + result = df.rename(columns=str) + assert result.attrs == df.attrs + assert result.attrs["tags"] is not df.attrs["tags"] + + def test_attrs_concat(self): + # concat propagates attrs if all input attrs are equal + df1 = DataFrame({"A": [2, 3]}) + df1.attrs = {"a": 1, "b": 2} + df2 = DataFrame({"A": [4, 5]}) + df2.attrs = df1.attrs.copy() + df3 = DataFrame({"A": [6, 7]}) + df3.attrs = df1.attrs.copy() + assert pd.concat([df1, df2, df3]).attrs == df1.attrs + # concat does not propagate attrs if input attrs are different + df2.attrs = {"c": 3} + assert pd.concat([df1, df2, df3]).attrs == {} + + def test_attrs_merge(self): + # merge propagates attrs if all input attrs are equal + df1 = DataFrame({"key": ["a", "b"], "val1": [1, 2]}) + df1.attrs = {"a": 1, "b": 2} + df2 = DataFrame({"key": ["a", "b"], "val2": [3, 4]}) + df2.attrs = df1.attrs.copy() + assert pd.merge(df1, df2).attrs == df1.attrs + # merge does not propagate attrs if input attrs are different + df2.attrs = {"c": 3} + assert pd.merge(df1, df2).attrs == {} + + @pytest.mark.parametrize("allows_duplicate_labels", [True, False, None]) + def test_set_flags( + self, + allows_duplicate_labels, + frame_or_series, + ): + obj = DataFrame({"A": [1, 2]}) + key = (0, 0) + if frame_or_series is Series: + obj = obj["A"] + key = 0 + + result = obj.set_flags(allows_duplicate_labels=allows_duplicate_labels) + + if allows_duplicate_labels is None: + # We don't update when it's not provided + assert result.flags.allows_duplicate_labels is True + else: + assert result.flags.allows_duplicate_labels is allows_duplicate_labels + + # We made a copy + assert obj is not result + + # We didn't mutate obj + assert obj.flags.allows_duplicate_labels is True + + # But we didn't copy data + if frame_or_series is Series: + assert np.may_share_memory(obj.values, result.values) + else: + assert np.may_share_memory(obj["A"].values, result["A"].values) + + result.iloc[key] = 0 + assert obj.iloc[key] == 1 + + # Now we do copy. + result = obj.set_flags(allows_duplicate_labels=allows_duplicate_labels) + result.iloc[key] = 10 + assert obj.iloc[key] == 1 + + def test_constructor_expanddim(self): + # GH#33628 accessing _constructor_expanddim should not raise NotImplementedError + # GH38782 pandas has no container higher than DataFrame (two-dim), so + # DataFrame._constructor_expand_dim, doesn't make sense, so is removed. + df = DataFrame() + + msg = "'DataFrame' object has no attribute '_constructor_expanddim'" + with pytest.raises(AttributeError, match=msg): + df._constructor_expanddim(np.arange(27).reshape(3, 3, 3)) + + def test_inspect_getmembers(self): + # GH38740 + df = DataFrame() + inspect.getmembers(df) diff --git a/pandas/tests/frame/test_arithmetic.py b/pandas/tests/frame/test_arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..388c28f4015f4e53c6ac42cea8945be60fc414fb --- /dev/null +++ b/pandas/tests/frame/test_arithmetic.py @@ -0,0 +1,2203 @@ +from collections import deque +from datetime import ( + datetime, + timezone, +) +from enum import Enum +import functools +import operator +import re + +import numpy as np +import pytest + +from pandas.compat._optional import import_optional_dependency + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, +) +import pandas._testing as tm +from pandas.core.computation import expressions as expr +from pandas.tests.frame.common import ( + _check_mixed_float, + _check_mixed_int, +) +from pandas.util.version import Version + + +@pytest.fixture +def simple_frame(): + """ + Fixture for simple 3x3 DataFrame + + Columns are ['one', 'two', 'three'], index is ['a', 'b', 'c']. + + one two three + a 1.0 2.0 3.0 + b 4.0 5.0 6.0 + c 7.0 8.0 9.0 + """ + arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + + return DataFrame(arr, columns=["one", "two", "three"], index=["a", "b", "c"]) + + +@pytest.fixture(autouse=True, params=[0, 100], ids=["numexpr", "python"]) +def switch_numexpr_min_elements(request, monkeypatch): + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", request.param) + yield request.param + + +class DummyElement: + def __init__(self, value, dtype) -> None: + self.value = value + self.dtype = np.dtype(dtype) + + def __array__(self, dtype=None, copy=None): + return np.array(self.value, dtype=self.dtype) + + def __str__(self) -> str: + return f"DummyElement({self.value}, {self.dtype})" + + def __repr__(self) -> str: + return str(self) + + def astype(self, dtype, copy=False): + self.dtype = dtype + return self + + def view(self, dtype): + return type(self)(self.value.view(dtype), dtype) + + def any(self, axis=None): + return bool(self.value) + + +# ------------------------------------------------------------------- +# Comparisons + + +class TestFrameComparisons: + # Specifically _not_ flex-comparisons + + def test_comparison_with_categorical_dtype(self): + # GH#12564 + + df = DataFrame({"A": ["foo", "bar", "baz"]}) + exp = DataFrame({"A": [True, False, False]}) + + res = df == "foo" + tm.assert_frame_equal(res, exp) + + # casting to categorical shouldn't affect the result + df["A"] = df["A"].astype("category") + + res = df == "foo" + tm.assert_frame_equal(res, exp) + + def test_frame_in_list(self): + # GH#12689 this should raise at the DataFrame level, not blocks + df = DataFrame( + np.random.default_rng(2).standard_normal((6, 4)), columns=list("ABCD") + ) + msg = "The truth value of a DataFrame is ambiguous" + with pytest.raises(ValueError, match=msg): + df in [None] + + @pytest.mark.parametrize( + "arg, arg2", + [ + [ + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": pd.date_range("20010101", periods=10, unit="ns"), + }, + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": np.random.default_rng(2).integers(10, size=10), + }, + ], + [ + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": np.random.default_rng(2).integers(10, size=10), + }, + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": pd.date_range("20010101", periods=10, unit="ns"), + }, + ], + [ + { + "a": pd.date_range("20010101", periods=10, unit="ns"), + "b": pd.date_range("20010101", periods=10, unit="ns"), + }, + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": np.random.default_rng(2).integers(10, size=10), + }, + ], + [ + { + "a": np.random.default_rng(2).integers(10, size=10), + "b": pd.date_range("20010101", periods=10, unit="ns"), + }, + { + "a": pd.date_range("20010101", periods=10, unit="ns"), + "b": pd.date_range("20010101", periods=10, unit="ns"), + }, + ], + ], + ) + def test_comparison_invalid(self, arg, arg2): + # GH4968 + # invalid date/int comparisons + x = DataFrame(arg) + y = DataFrame(arg2) + # we expect the result to match Series comparisons for + # == and !=, inequalities should raise + result = x == y + expected = DataFrame( + {col: x[col] == y[col] for col in x.columns}, + index=x.index, + columns=x.columns, + ) + tm.assert_frame_equal(result, expected) + + result = x != y + expected = DataFrame( + {col: x[col] != y[col] for col in x.columns}, + index=x.index, + columns=x.columns, + ) + tm.assert_frame_equal(result, expected) + + msgs = [ + r"Invalid comparison between dtype=datetime64\[ns\] and ndarray", + "invalid type promotion", + ( + # npdev 1.20.0 + r"The DTypes and " + r" do not have a common DType." + ), + ] + msg = "|".join(msgs) + with pytest.raises(TypeError, match=msg): + x >= y + with pytest.raises(TypeError, match=msg): + x > y + with pytest.raises(TypeError, match=msg): + x < y + with pytest.raises(TypeError, match=msg): + x <= y + + @pytest.mark.parametrize( + "left, right", + [ + ("gt", "lt"), + ("lt", "gt"), + ("ge", "le"), + ("le", "ge"), + ("eq", "eq"), + ("ne", "ne"), + ], + ) + def test_timestamp_compare(self, left, right): + # make sure we can compare Timestamps on the right AND left hand side + # GH#4982 + df = DataFrame( + { + "dates1": pd.date_range("20010101", periods=10), + "dates2": pd.date_range("20010102", periods=10), + "intcol": np.random.default_rng(2).integers(1000000000, size=10), + "floatcol": np.random.default_rng(2).standard_normal(10), + "stringcol": [chr(100 + i) for i in range(10)], + } + ) + df.loc[np.random.default_rng(2).random(len(df)) > 0.5, "dates2"] = pd.NaT + left_f = getattr(operator, left) + right_f = getattr(operator, right) + + # no nats + if left in ["eq", "ne"]: + expected = left_f(df, pd.Timestamp("20010109")) + result = right_f(pd.Timestamp("20010109"), df) + tm.assert_frame_equal(result, expected) + else: + msg = ( + "'(<|>)=?' not supported between " + "instances of 'numpy.ndarray' and 'Timestamp'" + ) + with pytest.raises(TypeError, match=msg): + left_f(df, pd.Timestamp("20010109")) + with pytest.raises(TypeError, match=msg): + right_f(pd.Timestamp("20010109"), df) + # nats + if left in ["eq", "ne"]: + expected = left_f(df, pd.Timestamp("nat")) + result = right_f(pd.Timestamp("nat"), df) + tm.assert_frame_equal(result, expected) + else: + msg = ( + "'(<|>)=?' not supported between " + "instances of 'numpy.ndarray' and 'NaTType'" + ) + with pytest.raises(TypeError, match=msg): + left_f(df, pd.Timestamp("nat")) + with pytest.raises(TypeError, match=msg): + right_f(pd.Timestamp("nat"), df) + + def test_mixed_comparison(self): + # GH#13128, GH#22163 != datetime64 vs non-dt64 should be False, + # not raise TypeError + # (this appears to be fixed before GH#22163, not sure when) + df = DataFrame([["1989-08-01", 1], ["1989-08-01", 2]]) + other = DataFrame([["a", "b"], ["c", "d"]]) + + result = df == other + assert not result.any().any() + + result = df != other + assert result.all().all() + + def test_df_boolean_comparison_error(self): + # GH#4576, GH#22880 + # comparing DataFrame against list/tuple with len(obj) matching + # len(df.columns) is supported as of GH#22800 + df = DataFrame(np.arange(6).reshape((3, 2))) + + expected = DataFrame([[False, False], [True, False], [False, False]]) + + result = df == (2, 2) + tm.assert_frame_equal(result, expected) + + result = df == [2, 2] + tm.assert_frame_equal(result, expected) + + def test_df_float_none_comparison(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((8, 3)), + index=range(8), + columns=["A", "B", "C"], + ) + + result = df.__eq__(None) + assert not result.any().any() + + def test_df_string_comparison(self): + df = DataFrame([{"a": 1, "b": "foo"}, {"a": 2, "b": "bar"}]) + mask_a = df.a > 1 + tm.assert_frame_equal(df[mask_a], df.loc[1:1, :]) + tm.assert_frame_equal(df[-mask_a], df.loc[0:0, :]) + + mask_b = df.b == "foo" + tm.assert_frame_equal(df[mask_b], df.loc[0:0, :]) + tm.assert_frame_equal(df[-mask_b], df.loc[1:1, :]) + + +class TestFrameFlexComparisons: + # TODO: test_bool_flex_frame needs a better name + def test_bool_flex_frame(self, comparison_op): + data = np.random.default_rng(2).standard_normal((5, 3)) + other_data = np.random.default_rng(2).standard_normal((5, 3)) + df = DataFrame(data) + other = DataFrame(other_data) + ndim_5 = np.ones((*df.shape, 1, 3)) + + # DataFrame + assert df.eq(df).values.all() + assert not df.ne(df).values.any() + f = getattr(df, comparison_op.__name__) + o = comparison_op + # No NAs + tm.assert_frame_equal(f(other), o(df, other)) + # Unaligned + part_o = other.loc[3:, 1:].copy() + rs = f(part_o) + xp = o(df, part_o.reindex(index=df.index, columns=df.columns)) + tm.assert_frame_equal(rs, xp) + # ndarray + tm.assert_frame_equal(f(other.values), o(df, other.values)) + # scalar + tm.assert_frame_equal(f(0), o(df, 0)) + # NAs + msg = "Unable to coerce to Series/DataFrame" + tm.assert_frame_equal(f(np.nan), o(df, np.nan)) + with pytest.raises(ValueError, match=msg): + f(ndim_5) + + @pytest.mark.parametrize("box", [np.array, Series]) + def test_bool_flex_series(self, box): + # Series + # list/tuple + data = np.random.default_rng(2).standard_normal((5, 3)) + df = DataFrame(data) + idx_ser = box(np.random.default_rng(2).standard_normal(5)) + col_ser = box(np.random.default_rng(2).standard_normal(3)) + + idx_eq = df.eq(idx_ser, axis=0) + col_eq = df.eq(col_ser) + idx_ne = df.ne(idx_ser, axis=0) + col_ne = df.ne(col_ser) + tm.assert_frame_equal(col_eq, df == Series(col_ser)) + tm.assert_frame_equal(col_eq, -col_ne) + tm.assert_frame_equal(idx_eq, -idx_ne) + tm.assert_frame_equal(idx_eq, df.T.eq(idx_ser).T) + tm.assert_frame_equal(col_eq, df.eq(list(col_ser))) + tm.assert_frame_equal(idx_eq, df.eq(Series(idx_ser), axis=0)) + tm.assert_frame_equal(idx_eq, df.eq(list(idx_ser), axis=0)) + + idx_gt = df.gt(idx_ser, axis=0) + col_gt = df.gt(col_ser) + idx_le = df.le(idx_ser, axis=0) + col_le = df.le(col_ser) + + tm.assert_frame_equal(col_gt, df > Series(col_ser)) + tm.assert_frame_equal(col_gt, -col_le) + tm.assert_frame_equal(idx_gt, -idx_le) + tm.assert_frame_equal(idx_gt, df.T.gt(idx_ser).T) + + idx_ge = df.ge(idx_ser, axis=0) + col_ge = df.ge(col_ser) + idx_lt = df.lt(idx_ser, axis=0) + col_lt = df.lt(col_ser) + tm.assert_frame_equal(col_ge, df >= Series(col_ser)) + tm.assert_frame_equal(col_ge, -col_lt) + tm.assert_frame_equal(idx_ge, -idx_lt) + tm.assert_frame_equal(idx_ge, df.T.ge(idx_ser).T) + + idx_ser = Series(np.random.default_rng(2).standard_normal(5)) + col_ser = Series(np.random.default_rng(2).standard_normal(3)) + + def test_bool_flex_frame_na(self): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + # NA + df.loc[0, 0] = np.nan + rs = df.eq(df) + assert not rs.loc[0, 0] + rs = df.ne(df) + assert rs.loc[0, 0] + rs = df.gt(df) + assert not rs.loc[0, 0] + rs = df.lt(df) + assert not rs.loc[0, 0] + rs = df.ge(df) + assert not rs.loc[0, 0] + rs = df.le(df) + assert not rs.loc[0, 0] + + def test_bool_flex_frame_complex_dtype(self): + # complex + arr = np.array([np.nan, 1, 6, np.nan]) + arr2 = np.array([2j, np.nan, 7, None]) + df = DataFrame({"a": arr}) + df2 = DataFrame({"a": arr2}) + + msg = "|".join( + [ + "'>' not supported between instances of '.*' and 'complex'", + r"unorderable types: .*complex\(\)", # PY35 + ] + ) + with pytest.raises(TypeError, match=msg): + # inequalities are not well-defined for complex numbers + df.gt(df2) + with pytest.raises(TypeError, match=msg): + # regression test that we get the same behavior for Series + df["a"].gt(df2["a"]) + with pytest.raises(TypeError, match=msg): + # Check that we match numpy behavior here + df.values > df2.values + + rs = df.ne(df2) + assert rs.values.all() + + arr3 = np.array([2j, np.nan, None]) + df3 = DataFrame({"a": arr3}) + + with pytest.raises(TypeError, match=msg): + # inequalities are not well-defined for complex numbers + df3.gt(2j) + with pytest.raises(TypeError, match=msg): + # regression test that we get the same behavior for Series + df3["a"].gt(2j) + with pytest.raises(TypeError, match=msg): + # Check that we match numpy behavior here + df3.values > 2j + + def test_bool_flex_frame_object_dtype(self): + # corner, dtype=object + df1 = DataFrame({"col": ["foo", np.nan, "bar"]}, dtype=object) + df2 = DataFrame({"col": ["foo", datetime.now(), "bar"]}, dtype=object) + result = df1.ne(df2) + exp = DataFrame({"col": [False, True, False]}) + tm.assert_frame_equal(result, exp) + + def test_flex_comparison_nat(self): + # GH 15697, GH 22163 df.eq(pd.NaT) should behave like df == pd.NaT, + # and _definitely_ not be NaN + df = DataFrame([pd.NaT]) + + result = df == pd.NaT + # result.iloc[0, 0] is an np.bool_ object + assert result.iloc[0, 0].item() is False + + result = df.eq(pd.NaT) + assert result.iloc[0, 0].item() is False + + result = df != pd.NaT + assert result.iloc[0, 0].item() is True + + result = df.ne(pd.NaT) + assert result.iloc[0, 0].item() is True + + def test_df_flex_cmp_constant_return_types(self, comparison_op): + # GH 15077, non-empty DataFrame + df = DataFrame({"x": [1, 2, 3], "y": [1.0, 2.0, 3.0]}) + const = 2 + + result = getattr(df, comparison_op.__name__)(const).dtypes.value_counts() + tm.assert_series_equal( + result, Series([2], index=[np.dtype(bool)], name="count") + ) + + def test_df_flex_cmp_constant_return_types_empty(self, comparison_op): + # GH 15077 empty DataFrame + df = DataFrame({"x": [1, 2, 3], "y": [1.0, 2.0, 3.0]}) + const = 2 + + empty = df.iloc[:0] + result = getattr(empty, comparison_op.__name__)(const).dtypes.value_counts() + tm.assert_series_equal( + result, Series([2], index=[np.dtype(bool)], name="count") + ) + + def test_df_flex_cmp_ea_dtype_with_ndarray_series(self): + ii = pd.IntervalIndex.from_breaks([1, 2, 3]) + df = DataFrame({"A": ii, "B": ii}) + + ser = Series([0, 0]) + res = df.eq(ser, axis=0) + + expected = DataFrame({"A": [False, False], "B": [False, False]}) + tm.assert_frame_equal(res, expected) + + ser2 = Series([1, 2], index=["A", "B"]) + res2 = df.eq(ser2, axis=1) + tm.assert_frame_equal(res2, expected) + + +# ------------------------------------------------------------------- +# Arithmetic + + +class TestFrameFlexArithmetic: + def test_floordiv_axis0(self): + # make sure we df.floordiv(ser, axis=0) matches column-wise result + arr = np.arange(3) + ser = Series(arr) + df = DataFrame({"A": ser, "B": ser}) + + result = df.floordiv(ser, axis=0) + + expected = DataFrame({col: df[col] // ser for col in df.columns}) + + tm.assert_frame_equal(result, expected) + + result2 = df.floordiv(ser.values, axis=0) + tm.assert_frame_equal(result2, expected) + + def test_df_add_td64_columnwise(self): + # GH 22534 Check that column-wise addition broadcasts correctly + dti = pd.date_range("2016-01-01", periods=10) + tdi = pd.timedelta_range("1", periods=10) + tser = Series(tdi) + df = DataFrame({0: dti, 1: tdi}) + + result = df.add(tser, axis=0) + expected = DataFrame({0: dti + tdi, 1: tdi + tdi}) + tm.assert_frame_equal(result, expected) + + def test_df_add_flex_filled_mixed_dtypes(self): + # GH 19611 + dti = pd.date_range("2016-01-01", periods=3) + ser = Series(["1 Day", "NaT", "2 Days"], dtype="timedelta64[ns]") + df = DataFrame({"A": dti, "B": ser}) + other = DataFrame({"A": ser, "B": ser}) + fill = pd.Timedelta(days=1).to_timedelta64() + result = df.add(other, fill_value=fill) + + expected = DataFrame( + { + "A": Series( + ["2016-01-02", "2016-01-03", "2016-01-05"], dtype="datetime64[ns]" + ), + "B": ser * 2, + } + ) + tm.assert_frame_equal(result, expected) + + def test_arith_flex_frame( + self, all_arithmetic_operators, float_frame, mixed_float_frame + ): + # one instance of parametrized fixture + op = all_arithmetic_operators + + def f(x, y): + # r-versions not in operator-stdlib; get op without "r" and invert + if op.startswith("__r"): + return getattr(operator, op.replace("__r", "__"))(y, x) + return getattr(operator, op)(x, y) + + result = getattr(float_frame, op)(2 * float_frame) + expected = f(float_frame, 2 * float_frame) + tm.assert_frame_equal(result, expected) + + # vs mix float + result = getattr(mixed_float_frame, op)(2 * mixed_float_frame) + expected = f(mixed_float_frame, 2 * mixed_float_frame) + tm.assert_frame_equal(result, expected) + _check_mixed_float(result, dtype={"C": None}) + + @pytest.mark.parametrize("op", ["__add__", "__sub__", "__mul__"]) + def test_arith_flex_frame_mixed( + self, + op, + int_frame, + mixed_int_frame, + mixed_float_frame, + switch_numexpr_min_elements, + ): + f = getattr(operator, op) + + # vs mix int + result = getattr(mixed_int_frame, op)(2 + mixed_int_frame) + expected = f(mixed_int_frame, 2 + mixed_int_frame) + + # no overflow in the uint + dtype = None + if op in ["__sub__"]: + dtype = {"B": "uint64", "C": None} + elif op in ["__add__", "__mul__"]: + dtype = {"C": None} + if expr.USE_NUMEXPR and switch_numexpr_min_elements == 0: + # when using numexpr, the casting rules are slightly different: + # in the `2 + mixed_int_frame` operation, int32 column becomes + # and int64 column (not preserving dtype in operation with Python + # scalar), and then the int32/int64 combo results in int64 result + dtype["A"] = (2 + mixed_int_frame)["A"].dtype + tm.assert_frame_equal(result, expected) + _check_mixed_int(result, dtype=dtype) + + # vs mix float + result = getattr(mixed_float_frame, op)(2 * mixed_float_frame) + expected = f(mixed_float_frame, 2 * mixed_float_frame) + tm.assert_frame_equal(result, expected) + _check_mixed_float(result, dtype={"C": None}) + + # vs plain int + result = getattr(int_frame, op)(2 * int_frame) + expected = f(int_frame, 2 * int_frame) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dim", range(3, 6)) + def test_arith_flex_frame_raise(self, all_arithmetic_operators, float_frame, dim): + # one instance of parametrized fixture + op = all_arithmetic_operators + + # Check that arrays with dim >= 3 raise + arr = np.ones((1,) * dim) + msg = "Unable to coerce to Series/DataFrame" + with pytest.raises(ValueError, match=msg): + getattr(float_frame, op)(arr) + + def test_arith_flex_frame_corner(self, float_frame): + const_add = float_frame.add(1) + tm.assert_frame_equal(const_add, float_frame + 1) + + # corner cases + result = float_frame.add(float_frame[:0]) + expected = float_frame.sort_index() * np.nan + tm.assert_frame_equal(result, expected) + + result = float_frame[:0].add(float_frame) + expected = float_frame.sort_index() * np.nan + tm.assert_frame_equal(result, expected) + + with pytest.raises(NotImplementedError, match="fill_value"): + float_frame.add(float_frame.iloc[0], fill_value=3) + + with pytest.raises(NotImplementedError, match="fill_value"): + float_frame.add(float_frame.iloc[0], axis="index", fill_value=3) + + @pytest.mark.parametrize("op", ["add", "sub", "mul", "mod"]) + def test_arith_flex_series_ops(self, simple_frame, op): + # after arithmetic refactor, add truediv here + df = simple_frame + + row = df.xs("a") + col = df["two"] + f = getattr(df, op) + op = getattr(operator, op) + tm.assert_frame_equal(f(row), op(df, row)) + tm.assert_frame_equal(f(col, axis=0), op(df.T, col).T) + + def test_arith_flex_series(self, simple_frame): + df = simple_frame + + row = df.xs("a") + col = df["two"] + # special case for some reason + tm.assert_frame_equal(df.add(row, axis=None), df + row) + + # cases which will be refactored after big arithmetic refactor + tm.assert_frame_equal(df.div(row), df / row) + tm.assert_frame_equal(df.div(col, axis=0), (df.T / col).T) + + def test_arith_flex_series_broadcasting(self, any_real_numpy_dtype): + # broadcasting issue in GH 7325 + df = DataFrame(np.arange(3 * 2).reshape((3, 2)), dtype=any_real_numpy_dtype) + expected = DataFrame([[np.nan, np.inf], [1.0, 1.5], [1.0, 1.25]]) + if any_real_numpy_dtype == "float32": + expected = expected.astype(any_real_numpy_dtype) + result = df.div(df[0], axis="index") + tm.assert_frame_equal(result, expected) + + def test_arith_flex_zero_len_raises(self): + # GH 19522 passing fill_value to frame flex arith methods should + # raise even in the zero-length special cases + ser_len0 = Series([], dtype=object) + df_len0 = DataFrame(columns=["A", "B"]) + df = DataFrame([[1, 2], [3, 4]], columns=["A", "B"]) + + with pytest.raises(NotImplementedError, match="fill_value"): + df.add(ser_len0, fill_value="E") + + with pytest.raises(NotImplementedError, match="fill_value"): + df_len0.sub(df["A"], axis=None, fill_value=3) + + def test_flex_add_scalar_fill_value(self): + # GH#12723 + dat = np.array([0, 1, np.nan, 3, 4, 5], dtype="float") + df = DataFrame({"foo": dat}, index=range(6)) + + exp = df.fillna(0).add(2) + res = df.add(2, fill_value=0) + tm.assert_frame_equal(res, exp) + + def test_sub_alignment_with_duplicate_index(self): + # GH#5185 dup aligning operations should work + df1 = DataFrame([1, 2, 3, 4, 5], index=[1, 2, 1, 2, 3]) + df2 = DataFrame([1, 2, 3], index=[1, 2, 3]) + expected = DataFrame([0, 2, 0, 2, 2], index=[1, 1, 2, 2, 3]) + result = df1.sub(df2) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("op", ["__add__", "__mul__", "__sub__", "__truediv__"]) + def test_arithmetic_with_duplicate_columns(self, op): + # operations + df = DataFrame({"A": np.arange(10), "B": np.random.default_rng(2).random(10)}) + expected = getattr(df, op)(df) + expected.columns = ["A", "A"] + df.columns = ["A", "A"] + result = getattr(df, op)(df) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("level", [0, None]) + def test_broadcast_multiindex(self, level): + # GH34388 + df1 = DataFrame({"A": [0, 1, 2], "B": [1, 2, 3]}) + df1.columns = df1.columns.set_names("L1") + + df2 = DataFrame({("A", "C"): [0, 0, 0], ("A", "D"): [0, 0, 0]}) + df2.columns = df2.columns.set_names(["L1", "L2"]) + + result = df1.add(df2, level=level) + expected = DataFrame({("A", "C"): [0, 1, 2], ("A", "D"): [0, 1, 2]}) + expected.columns = expected.columns.set_names(["L1", "L2"]) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations(self): + # GH 43321 + df = DataFrame( + {2010: [1, 2, 3], 2020: [3, 4, 5]}, + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + + series = Series( + [0.4], + index=MultiIndex.from_product([["b"], ["a"]], names=["mod", "scen"]), + ) + + expected = DataFrame( + {2010: [1.4, 2.4, 3.4], 2020: [3.4, 4.4, 5.4]}, + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + result = df.add(series, axis=0) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations_series_index_to_frame_index(self): + # GH 43321 + df = DataFrame( + {2010: [1], 2020: [3]}, + index=MultiIndex.from_product([["a"], ["b"]], names=["scen", "mod"]), + ) + + series = Series( + [10.0, 20.0, 30.0], + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + + expected = DataFrame( + {2010: [11.0, 21, 31.0], 2020: [13.0, 23.0, 33.0]}, + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + result = df.add(series, axis=0) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations_no_align(self): + df = DataFrame( + {2010: [1, 2, 3], 2020: [3, 4, 5]}, + index=MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + + series = Series( + [0.4], + index=MultiIndex.from_product([["c"], ["a"]], names=["mod", "scen"]), + ) + + expected = DataFrame( + {2010: np.nan, 2020: np.nan}, + index=MultiIndex.from_tuples( + [ + ("a", "b", 0), + ("a", "b", 1), + ("a", "b", 2), + ("a", "c", np.nan), + ], + names=["scen", "mod", "id"], + ), + ) + result = df.add(series, axis=0) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations_part_align(self): + df = DataFrame( + {2010: [1, 2, 3], 2020: [3, 4, 5]}, + index=MultiIndex.from_tuples( + [ + ("a", "b", 0), + ("a", "b", 1), + ("a", "c", 2), + ], + names=["scen", "mod", "id"], + ), + ) + + series = Series( + [0.4], + index=MultiIndex.from_product([["b"], ["a"]], names=["mod", "scen"]), + ) + + expected = DataFrame( + {2010: [1.4, 2.4, np.nan], 2020: [3.4, 4.4, np.nan]}, + index=MultiIndex.from_tuples( + [ + ("a", "b", 0), + ("a", "b", 1), + ("a", "c", 2), + ], + names=["scen", "mod", "id"], + ), + ) + result = df.add(series, axis=0) + + tm.assert_frame_equal(result, expected) + + def test_frame_multiindex_operations_part_align_axis1(self): + # GH#61009 Test DataFrame-Series arithmetic operation + # with partly aligned MultiIndex and axis = 1 + df = DataFrame( + [[1, 2, 3], [3, 4, 5]], + index=[2010, 2020], + columns=MultiIndex.from_tuples( + [ + ("a", "b", 0), + ("a", "b", 1), + ("a", "c", 2), + ], + names=["scen", "mod", "id"], + ), + ) + + series = Series( + [0.4], + index=MultiIndex.from_product([["b"], ["a"]], names=["mod", "scen"]), + ) + + expected = DataFrame( + [[1.4, 2.4, np.nan], [3.4, 4.4, np.nan]], + index=[2010, 2020], + columns=MultiIndex.from_tuples( + [ + ("a", "b", 0), + ("a", "b", 1), + ("a", "c", 2), + ], + names=["scen", "mod", "id"], + ), + ) + result = df.add(series, axis=1) + + tm.assert_frame_equal(result, expected) + + +class TestFrameArithmetic: + def test_td64_op_nat_casting(self): + # Make sure we don't accidentally treat timedelta64(NaT) as datetime64 + # when calling dispatch_to_series in DataFrame arithmetic + ser = Series(["NaT", "NaT"], dtype="timedelta64[ns]") + df = DataFrame([[1, 2], [3, 4]]) + + result = df * ser + expected = DataFrame({0: ser, 1: ser}) + tm.assert_frame_equal(result, expected) + + def test_df_add_2d_array_rowlike_broadcasts(self): + # GH#23000 + arr = np.arange(6).reshape(3, 2) + df = DataFrame(arr, columns=[True, False], index=["A", "B", "C"]) + + rowlike = arr[[1], :] # shape --> (1, ncols) + assert rowlike.shape == (1, df.shape[1]) + + expected = DataFrame( + [[2, 4], [4, 6], [6, 8]], + columns=df.columns, + index=df.index, + # specify dtype explicitly to avoid failing + # on 32bit builds + dtype=arr.dtype, + ) + result = df + rowlike + tm.assert_frame_equal(result, expected) + result = rowlike + df + tm.assert_frame_equal(result, expected) + + def test_df_add_2d_array_collike_broadcasts(self): + # GH#23000 + arr = np.arange(6).reshape(3, 2) + df = DataFrame(arr, columns=[True, False], index=["A", "B", "C"]) + + collike = arr[:, [1]] # shape --> (nrows, 1) + assert collike.shape == (df.shape[0], 1) + + expected = DataFrame( + [[1, 2], [5, 6], [9, 10]], + columns=df.columns, + index=df.index, + # specify dtype explicitly to avoid failing + # on 32bit builds + dtype=arr.dtype, + ) + result = df + collike + tm.assert_frame_equal(result, expected) + result = collike + df + tm.assert_frame_equal(result, expected) + + def test_df_arith_2d_array_rowlike_broadcasts( + self, request, all_arithmetic_operators + ): + # GH#23000 + opname = all_arithmetic_operators + arr = np.arange(6).reshape(3, 2) + df = DataFrame(arr, columns=[True, False], index=["A", "B", "C"]) + + rowlike = arr[[1], :] # shape --> (1, ncols) + assert rowlike.shape == (1, df.shape[1]) + + exvals = [ + getattr(df.loc["A"], opname)(rowlike.squeeze()), + getattr(df.loc["B"], opname)(rowlike.squeeze()), + getattr(df.loc["C"], opname)(rowlike.squeeze()), + ] + + expected = DataFrame(exvals, columns=df.columns, index=df.index) + + result = getattr(df, opname)(rowlike) + tm.assert_frame_equal(result, expected) + + def test_df_arith_2d_array_collike_broadcasts( + self, request, all_arithmetic_operators + ): + # GH#23000 + opname = all_arithmetic_operators + arr = np.arange(6).reshape(3, 2) + df = DataFrame(arr, columns=[True, False], index=["A", "B", "C"]) + + collike = arr[:, [1]] # shape --> (nrows, 1) + assert collike.shape == (df.shape[0], 1) + + exvals = { + True: getattr(df[True], opname)(collike.squeeze()), + False: getattr(df[False], opname)(collike.squeeze()), + } + + dtype = None + if opname in ["__rmod__", "__rfloordiv__"]: + # Series ops may return mixed int/float dtypes in cases where + # DataFrame op will return all-float. So we upcast `expected` + dtype = np.common_type(*(x.values for x in exvals.values())) + + expected = DataFrame(exvals, columns=df.columns, index=df.index, dtype=dtype) + + result = getattr(df, opname)(collike) + tm.assert_frame_equal(result, expected) + + def test_df_bool_mul_int(self): + # GH 22047, GH 22163 multiplication by 1 should result in int dtype, + # not object dtype + df = DataFrame([[False, True], [False, False]]) + result = df * 1 + + # On appveyor this comes back as np.int32 instead of np.int64, + # so we check dtype.kind instead of just dtype + kinds = result.dtypes.apply(lambda x: x.kind) + assert (kinds == "i").all() + + result = 1 * df + kinds = result.dtypes.apply(lambda x: x.kind) + assert (kinds == "i").all() + + def test_arith_mixed(self): + left = DataFrame({"A": ["a", "b", "c"], "B": [1, 2, 3]}) + + result = left + left + expected = DataFrame({"A": ["aa", "bb", "cc"], "B": [2, 4, 6]}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("col", ["A", "B"]) + def test_arith_getitem_commute(self, all_arithmetic_functions, col): + df = DataFrame({"A": [1.1, 3.3], "B": [2.5, -3.9]}) + result = all_arithmetic_functions(df, 1)[col] + expected = all_arithmetic_functions(df[col], 1) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "values", [[1, 2], (1, 2), np.array([1, 2]), range(1, 3), deque([1, 2])] + ) + def test_arith_alignment_non_pandas_object(self, values): + # GH#17901 + df = DataFrame({"A": [1, 1], "B": [1, 1]}) + expected = DataFrame({"A": [2, 2], "B": [3, 3]}) + result = df + values + tm.assert_frame_equal(result, expected) + + def test_arith_non_pandas_object(self): + df = DataFrame( + np.arange(1, 10, dtype="f8").reshape(3, 3), + columns=["one", "two", "three"], + index=["a", "b", "c"], + ) + + val1 = df.xs("a").values + added = DataFrame(df.values + val1, index=df.index, columns=df.columns) + tm.assert_frame_equal(df + val1, added) + + added = DataFrame((df.values.T + val1).T, index=df.index, columns=df.columns) + tm.assert_frame_equal(df.add(val1, axis=0), added) + + val2 = list(df["two"]) + + added = DataFrame(df.values + val2, index=df.index, columns=df.columns) + tm.assert_frame_equal(df + val2, added) + + added = DataFrame((df.values.T + val2).T, index=df.index, columns=df.columns) + tm.assert_frame_equal(df.add(val2, axis="index"), added) + + val3 = np.random.default_rng(2).random(df.shape) + added = DataFrame(df.values + val3, index=df.index, columns=df.columns) + tm.assert_frame_equal(df.add(val3), added) + + def test_operations_with_interval_categories_index(self, all_arithmetic_operators): + # GH#27415 + op = all_arithmetic_operators + ind = pd.CategoricalIndex(pd.interval_range(start=0.0, end=2.0)) + data = [1, 2] + df = DataFrame([data], columns=ind) + num = 10 + result = getattr(df, op)(num) + expected = DataFrame([[getattr(n, op)(num) for n in data]], columns=ind) + tm.assert_frame_equal(result, expected) + + def test_frame_with_frame_reindex(self): + # GH#31623 + df = DataFrame( + { + "foo": [pd.Timestamp("2019"), pd.Timestamp("2020")], + "bar": [pd.Timestamp("2018"), pd.Timestamp("2021")], + }, + columns=["foo", "bar"], + dtype="M8[ns]", + ) + df2 = df[["foo"]] + + result = df - df2 + + expected = DataFrame( + {"foo": [pd.Timedelta(0), pd.Timedelta(0)], "bar": [np.nan, np.nan]}, + columns=["bar", "foo"], + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "value, dtype", + [ + (1, "i8"), + (1.0, "f8"), + (2**63, "f8"), + (1j, "complex128"), + (2**63, "complex128"), + (True, "bool"), + (np.timedelta64(20, "ns"), "]=?' not supported between instances of 'str' and 'int'", + "Invalid comparison between dtype=str and int", + ] + ) + with pytest.raises(TypeError, match=msg): + f(df, 0) + + def test_comparison_protected_from_errstate(self): + missing_df = DataFrame( + np.ones((10, 4), dtype=np.float64), + columns=Index(list("ABCD"), dtype=object), + ) + missing_df.loc[missing_df.index[0], "A"] = np.nan + with np.errstate(invalid="ignore"): + expected = missing_df.values < 0 + with np.errstate(invalid="raise"): + result = (missing_df < 0).values + tm.assert_numpy_array_equal(result, expected) + + def test_boolean_comparison(self): + # GH 4576 + # boolean comparisons with a tuple/list give unexpected results + df = DataFrame(np.arange(6).reshape((3, 2))) + b = np.array([2, 2]) + b_r = np.atleast_2d([2, 2]) + b_c = b_r.T + lst = [2, 2, 2] + tup = tuple(lst) + + # gt + expected = DataFrame([[False, False], [False, True], [True, True]]) + result = df > b + tm.assert_frame_equal(result, expected) + + result = df.values > b + tm.assert_numpy_array_equal(result, expected.values) + + msg1d = "Unable to coerce to Series, length must be 2: given 3" + msg2d = "Unable to coerce to DataFrame, shape must be" + msg2db = "operands could not be broadcast together with shapes" + with pytest.raises(ValueError, match=msg1d): + # wrong shape + df > lst + + with pytest.raises(ValueError, match=msg1d): + # wrong shape + df > tup + + # broadcasts like ndarray (GH#23000) + result = df > b_r + tm.assert_frame_equal(result, expected) + + result = df.values > b_r + tm.assert_numpy_array_equal(result, expected.values) + + with pytest.raises(ValueError, match=msg2d): + df > b_c + + with pytest.raises(ValueError, match=msg2db): + df.values > b_c + + # == + expected = DataFrame([[False, False], [True, False], [False, False]]) + result = df == b + tm.assert_frame_equal(result, expected) + + with pytest.raises(ValueError, match=msg1d): + df == lst + + with pytest.raises(ValueError, match=msg1d): + df == tup + + # broadcasts like ndarray (GH#23000) + result = df == b_r + tm.assert_frame_equal(result, expected) + + result = df.values == b_r + tm.assert_numpy_array_equal(result, expected.values) + + with pytest.raises(ValueError, match=msg2d): + df == b_c + + assert df.values.shape != b_c.shape + + # with alignment + df = DataFrame( + np.arange(6).reshape((3, 2)), columns=list("AB"), index=list("abc") + ) + expected.index = df.index + expected.columns = df.columns + + with pytest.raises(ValueError, match=msg1d): + df == lst + + with pytest.raises(ValueError, match=msg1d): + df == tup + + def test_inplace_ops_alignment(self): + # inplace ops / ops alignment + # GH 8511 + + columns = list("abcdefg") + X_orig = DataFrame( + np.arange(10 * len(columns)).reshape(-1, len(columns)), + columns=columns, + index=range(10), + ) + Z = 100 * X_orig.iloc[:, 1:-1].copy() + block1 = list("bedcf") + subs = list("bcdef") + + # add + X = X_orig.copy() + result1 = (X[block1] + Z).reindex(columns=subs) + + X[block1] += Z + result2 = X.reindex(columns=subs) + + X = X_orig.copy() + result3 = (X[block1] + Z[block1]).reindex(columns=subs) + + X[block1] += Z[block1] + result4 = X.reindex(columns=subs) + + tm.assert_frame_equal(result1, result2) + tm.assert_frame_equal(result1, result3) + tm.assert_frame_equal(result1, result4) + + # sub + X = X_orig.copy() + result1 = (X[block1] - Z).reindex(columns=subs) + + X[block1] -= Z + result2 = X.reindex(columns=subs) + + X = X_orig.copy() + result3 = (X[block1] - Z[block1]).reindex(columns=subs) + + X[block1] -= Z[block1] + result4 = X.reindex(columns=subs) + + tm.assert_frame_equal(result1, result2) + tm.assert_frame_equal(result1, result3) + tm.assert_frame_equal(result1, result4) + + def test_inplace_ops_identity(self): + # GH 5104 + # make sure that we are actually changing the object + s_orig = Series([1, 2, 3]) + df_orig = DataFrame( + np.random.default_rng(2).integers(0, 5, size=10).reshape(-1, 5) + ) + + # no dtype change + s = s_orig.copy() + s2 = s + s += 1 + tm.assert_series_equal(s, s2) + tm.assert_series_equal(s_orig + 1, s) + assert s is s2 + assert s._mgr is s2._mgr + + df = df_orig.copy() + df2 = df + df += 1 + tm.assert_frame_equal(df, df2) + tm.assert_frame_equal(df_orig + 1, df) + assert df is df2 + assert df._mgr is df2._mgr + + # dtype change + s = s_orig.copy() + s2 = s + s += 1.5 + tm.assert_series_equal(s, s2) + tm.assert_series_equal(s_orig + 1.5, s) + + df = df_orig.copy() + df2 = df + df += 1.5 + tm.assert_frame_equal(df, df2) + tm.assert_frame_equal(df_orig + 1.5, df) + assert df is df2 + assert df._mgr is df2._mgr + + # mixed dtype + arr = np.random.default_rng(2).integers(0, 10, size=5) + df_orig = DataFrame({"A": arr.copy(), "B": "foo"}) + df = df_orig.copy() + df2 = df + df["A"] += 1 + expected = DataFrame({"A": arr.copy() + 1, "B": "foo"}) + tm.assert_frame_equal(df, expected) + tm.assert_frame_equal(df2, expected) + assert df._mgr is df2._mgr + + df = df_orig.copy() + df2 = df + df["A"] += 1.5 + expected = DataFrame({"A": arr.copy() + 1.5, "B": "foo"}) + tm.assert_frame_equal(df, expected) + tm.assert_frame_equal(df2, expected) + assert df._mgr is df2._mgr + + @pytest.mark.parametrize( + "op", + [ + "add", + "and", + pytest.param( + "div", + marks=pytest.mark.xfail( + raises=AttributeError, reason="__idiv__ not implemented" + ), + ), + "floordiv", + "mod", + "mul", + "or", + "pow", + "sub", + "truediv", + "xor", + ], + ) + def test_inplace_ops_identity2(self, op): + df = DataFrame({"a": [1.0, 2.0, 3.0], "b": [1, 2, 3]}) + + operand = 2 + if op in ("and", "or", "xor"): + # cannot use floats for boolean ops + df["a"] = [True, False, True] + + df_copy = df.copy() + iop = f"__i{op}__" + op = f"__{op}__" + + # no id change and value is correct + getattr(df, iop)(operand) + expected = getattr(df_copy, op)(operand) + tm.assert_frame_equal(df, expected) + expected = id(df) + assert id(df) == expected + + @pytest.mark.parametrize( + "val", + [ + [1, 2, 3], + (1, 2, 3), + np.array([1, 2, 3], dtype=np.int64), + range(1, 4), + ], + ) + def test_alignment_non_pandas(self, val): + index = ["A", "B", "C"] + columns = ["X", "Y", "Z"] + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 3)), + index=index, + columns=columns, + ) + + align = DataFrame._align_for_op + + expected = DataFrame({"X": val, "Y": val, "Z": val}, index=df.index) + tm.assert_frame_equal(align(df, val, axis=0)[1], expected) + + expected = DataFrame( + {"X": [1, 1, 1], "Y": [2, 2, 2], "Z": [3, 3, 3]}, index=df.index + ) + tm.assert_frame_equal(align(df, val, axis=1)[1], expected) + + @pytest.mark.parametrize("val", [[1, 2], (1, 2), np.array([1, 2]), range(1, 3)]) + def test_alignment_non_pandas_length_mismatch(self, val): + index = ["A", "B", "C"] + columns = ["X", "Y", "Z"] + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 3)), + index=index, + columns=columns, + ) + + align = DataFrame._align_for_op + # length mismatch + msg = "Unable to coerce to Series, length must be 3: given 2" + with pytest.raises(ValueError, match=msg): + align(df, val, axis=0) + + with pytest.raises(ValueError, match=msg): + align(df, val, axis=1) + + def test_alignment_non_pandas_index_columns(self): + index = ["A", "B", "C"] + columns = ["X", "Y", "Z"] + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 3)), + index=index, + columns=columns, + ) + + align = DataFrame._align_for_op + val = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + tm.assert_frame_equal( + align(df, val, axis=0)[1], + DataFrame(val, index=df.index, columns=df.columns), + ) + tm.assert_frame_equal( + align(df, val, axis=1)[1], + DataFrame(val, index=df.index, columns=df.columns), + ) + + # shape mismatch + msg = "Unable to coerce to DataFrame, shape must be" + val = np.array([[1, 2, 3], [4, 5, 6]]) + with pytest.raises(ValueError, match=msg): + align(df, val, axis=0) + + with pytest.raises(ValueError, match=msg): + align(df, val, axis=1) + + val = np.zeros((3, 3, 3)) + msg = re.escape( + "Unable to coerce to Series/DataFrame, dimension must be <= 2: (3, 3, 3)" + ) + with pytest.raises(ValueError, match=msg): + align(df, val, axis=0) + with pytest.raises(ValueError, match=msg): + align(df, val, axis=1) + + def test_no_warning(self, all_arithmetic_operators): + df = DataFrame({"A": [0.0, 0.0], "B": [0.0, None]}) + b = df["B"] + with tm.assert_produces_warning(None): + getattr(df, all_arithmetic_operators)(b) + + def test_dunder_methods_binary(self, all_arithmetic_operators): + # GH#??? frame.__foo__ should only accept one argument + df = DataFrame({"A": [0.0, 0.0], "B": [0.0, None]}) + b = df["B"] + with pytest.raises(TypeError, match="takes 2 positional arguments"): + getattr(df, all_arithmetic_operators)(b, 0) + + def test_align_int_fill_bug(self): + # GH#910 + X = np.arange(10 * 10, dtype="float64").reshape(10, 10) + Y = np.ones((10, 1), dtype=int) + + df1 = DataFrame(X) + df1["0.X"] = Y.squeeze() + + df2 = df1.astype(float) + + result = df1 - df1.mean() + expected = df2 - df2.mean() + tm.assert_frame_equal(result, expected) + + +def test_pow_with_realignment(): + # GH#32685 pow has special semantics for operating with null values + left = DataFrame({"A": [0, 1, 2]}) + right = DataFrame(index=[0, 1, 2]) + + result = left**right + expected = DataFrame({"A": [np.nan, 1.0, np.nan]}) + tm.assert_frame_equal(result, expected) + + +def test_dataframe_series_extension_dtypes(): + # https://github.com/pandas-dev/pandas/issues/34311 + df = DataFrame( + np.random.default_rng(2).integers(0, 100, (10, 3)), columns=["a", "b", "c"] + ) + ser = Series([1, 2, 3], index=["a", "b", "c"]) + + expected = df.to_numpy("int64") + ser.to_numpy("int64").reshape(-1, 3) + expected = DataFrame(expected, columns=df.columns, dtype="Int64") + + df_ea = df.astype("Int64") + result = df_ea + ser + tm.assert_frame_equal(result, expected) + result = df_ea + ser.astype("Int64") + tm.assert_frame_equal(result, expected) + + +def test_dataframe_blockwise_slicelike(): + # GH#34367 + arr = np.random.default_rng(2).integers(0, 1000, (100, 10)) + df1 = DataFrame(arr) + # Explicit cast to float to avoid implicit cast when setting nan + df2 = df1.copy().astype({1: "float", 3: "float", 7: "float"}) + df2.iloc[0, [1, 3, 7]] = np.nan + + # Explicit cast to float to avoid implicit cast when setting nan + df3 = df1.copy().astype({5: "float"}) + df3.iloc[0, [5]] = np.nan + + # Explicit cast to float to avoid implicit cast when setting nan + df4 = df1.copy().astype({2: "float", 3: "float", 4: "float"}) + df4.iloc[0, np.arange(2, 5)] = np.nan + # Explicit cast to float to avoid implicit cast when setting nan + df5 = df1.copy().astype({4: "float", 5: "float", 6: "float"}) + df5.iloc[0, np.arange(4, 7)] = np.nan + + for left, right in [(df1, df2), (df2, df3), (df4, df5)]: + res = left + right + + expected = DataFrame({i: left[i] + right[i] for i in left.columns}) + tm.assert_frame_equal(res, expected) + + +@pytest.mark.parametrize( + "df, col_dtype", + [ + (DataFrame([[1.0, 2.0], [4.0, 5.0]], columns=list("ab")), "float64"), + ( + DataFrame([[1.0, "b"], [4.0, "b"]], columns=list("ab")).astype( + {"b": object} + ), + "object", + ), + ], +) +def test_dataframe_operation_with_non_numeric_types(df, col_dtype): + # GH #22663 + expected = DataFrame([[0.0, np.nan], [3.0, np.nan]], columns=list("ab")) + expected = expected.astype({"b": col_dtype}) + result = df + Series([-1.0], index=list("a")) + tm.assert_frame_equal(result, expected) + + +def test_arith_reindex_with_duplicates(): + # https://github.com/pandas-dev/pandas/issues/35194 + df1 = DataFrame(data=[[0]], columns=["second"]) + df2 = DataFrame(data=[[0, 0, 0]], columns=["first", "second", "second"]) + result = df1 + df2 + expected = DataFrame([[np.nan, 0, 0]], columns=["first", "second", "second"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("to_add", [[Series([1, 1])], [Series([1, 1]), Series([1, 1])]]) +def test_arith_list_of_arraylike_raise(to_add): + # GH 36702. Raise when trying to add list of array-like to DataFrame + df = DataFrame({"x": [1, 2], "y": [1, 2]}) + + msg = f"Unable to coerce list of {type(to_add[0])} to Series/DataFrame" + with pytest.raises(ValueError, match=msg): + df + to_add + with pytest.raises(ValueError, match=msg): + to_add + df + + +def test_inplace_arithmetic_series_update(): + # https://github.com/pandas-dev/pandas/issues/36373 + df = DataFrame({"A": [1, 2, 3]}) + df_orig = df.copy() + series = df["A"] + vals = series._values + + series += 1 + assert series._values is not vals + tm.assert_frame_equal(df, df_orig) + + +def test_arithmetic_multiindex_align(): + """ + Regression test for: https://github.com/pandas-dev/pandas/issues/33765 + """ + df1 = DataFrame( + [[1]], + index=["a"], + columns=MultiIndex.from_product([[0], [1]], names=["a", "b"]), + ) + df2 = DataFrame([[1]], index=["a"], columns=Index([0], name="a")) + expected = DataFrame( + [[0]], + index=["a"], + columns=MultiIndex.from_product([[0], [1]], names=["a", "b"]), + ) + result = df1 - df2 + tm.assert_frame_equal(result, expected) + + +def test_arithmetic_multiindex_column_align(): + # GH#60498 + df1 = DataFrame( + data=100, + columns=MultiIndex.from_product( + [["1A", "1B"], ["2A", "2B"]], names=["Lev1", "Lev2"] + ), + index=["C1", "C2"], + ) + df2 = DataFrame( + data=np.array([[0.1, 0.25], [0.2, 0.45]]), + columns=MultiIndex.from_product([["1A", "1B"]], names=["Lev1"]), + index=["C1", "C2"], + ) + expected = DataFrame( + data=np.array([[10.0, 10.0, 25.0, 25.0], [20.0, 20.0, 45.0, 45.0]]), + columns=MultiIndex.from_product( + [["1A", "1B"], ["2A", "2B"]], names=["Lev1", "Lev2"] + ), + index=["C1", "C2"], + ) + result = df1 * df2 + tm.assert_frame_equal(result, expected) + + +def test_arithmetic_multiindex_column_align_with_fillvalue(): + # GH#60903 + df1 = DataFrame( + data=[[1.0, 2.0]], + columns=MultiIndex.from_tuples([("A", "one"), ("A", "two")]), + ) + df2 = DataFrame( + data=[[3.0, 4.0]], + columns=MultiIndex.from_tuples([("B", "one"), ("B", "two")]), + ) + expected = DataFrame( + data=[[1.0, 2.0, 3.0, 4.0]], + columns=MultiIndex.from_tuples( + [("A", "one"), ("A", "two"), ("B", "one"), ("B", "two")] + ), + ) + result = df1.add(df2, fill_value=0) + tm.assert_frame_equal(result, expected) + + +def test_bool_frame_mult_float(): + # GH 18549 + df = DataFrame(True, list("ab"), list("cd")) + result = df * 1.0 + expected = DataFrame(np.ones((2, 2)), list("ab"), list("cd")) + tm.assert_frame_equal(result, expected) + + +def test_frame_sub_nullable_int(any_int_ea_dtype): + # GH 32822 + series1 = Series([1, 2, None], dtype=any_int_ea_dtype) + series2 = Series([1, 2, 3], dtype=any_int_ea_dtype) + expected = DataFrame([0, 0, None], dtype=any_int_ea_dtype) + result = series1.to_frame() - series2.to_frame() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning" +) +def test_frame_op_subclass_nonclass_constructor(): + # GH#43201 subclass._constructor is a function, not the subclass itself + + class SubclassedSeries(Series): + @property + def _constructor(self): + return SubclassedSeries + + @property + def _constructor_expanddim(self): + return SubclassedDataFrame + + class SubclassedDataFrame(DataFrame): + _metadata = ["my_extra_data"] + + def __init__(self, my_extra_data, *args, **kwargs) -> None: + self.my_extra_data = my_extra_data + super().__init__(*args, **kwargs) + + @property + def _constructor(self): + return functools.partial(type(self), self.my_extra_data) + + @property + def _constructor_sliced(self): + return SubclassedSeries + + sdf = SubclassedDataFrame("some_data", {"A": [1, 2, 3], "B": [4, 5, 6]}) + result = sdf * 2 + expected = SubclassedDataFrame("some_data", {"A": [2, 4, 6], "B": [8, 10, 12]}) + tm.assert_frame_equal(result, expected) + + result = sdf + sdf + tm.assert_frame_equal(result, expected) + + +def test_enum_column_equality(): + Cols = Enum("Cols", "col1 col2") + + q1 = DataFrame({Cols.col1: [1, 2, 3]}) + q2 = DataFrame({Cols.col1: [1, 2, 3]}) + + result = q1[Cols.col1] == q2[Cols.col1] + expected = Series([True, True, True], name=Cols.col1) + + tm.assert_series_equal(result, expected) + + +def test_mixed_col_index_dtype(string_dtype_no_object): + # GH 47382 + df1 = DataFrame(columns=list("abc"), data=1.0, index=[0]) + df2 = DataFrame(columns=list("abc"), data=0.0, index=[0]) + df1.columns = df2.columns.astype(string_dtype_no_object) + result = df1 + df2 + expected = DataFrame(columns=list("abc"), data=1.0, index=[0]) + + expected.columns = expected.columns.astype(string_dtype_no_object) + + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/frame/test_arrow_interface.py b/pandas/tests/frame/test_arrow_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..fcebabb434683027d40e1bfca2febe6e733f18f3 --- /dev/null +++ b/pandas/tests/frame/test_arrow_interface.py @@ -0,0 +1,94 @@ +import ctypes + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + +pa = pytest.importorskip("pyarrow") + + +@td.skip_if_no("pyarrow", min_version="14.0") +def test_dataframe_arrow_interface(using_infer_string): + df = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + capsule = df.__arrow_c_stream__() + assert ( + ctypes.pythonapi.PyCapsule_IsValid( + ctypes.py_object(capsule), b"arrow_array_stream" + ) + == 1 + ) + + table = pa.table(df) + string_type = pa.large_string() if using_infer_string else pa.string() + expected = pa.table({"a": [1, 2, 3], "b": pa.array(["a", "b", "c"], string_type)}) + assert table.equals(expected) + + schema = pa.schema([("a", pa.int8()), ("b", pa.string())]) + table = pa.table(df, schema=schema) + expected = expected.cast(schema) + assert table.equals(expected) + + +@td.skip_if_no("pyarrow", min_version="15.0") +def test_dataframe_to_arrow(using_infer_string): + df = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + table = pa.RecordBatchReader.from_stream(df).read_all() + string_type = pa.large_string() if using_infer_string else pa.string() + expected = pa.table({"a": [1, 2, 3], "b": pa.array(["a", "b", "c"], string_type)}) + assert table.equals(expected) + + schema = pa.schema([("a", pa.int8()), ("b", pa.string())]) + table = pa.RecordBatchReader.from_stream(df, schema=schema).read_all() + expected = expected.cast(schema) + assert table.equals(expected) + + +class ArrowArrayWrapper: + def __init__(self, batch): + self.array = batch + + def __arrow_c_array__(self, requested_schema=None): + return self.array.__arrow_c_array__(requested_schema) + + +class ArrowStreamWrapper: + def __init__(self, table): + self.stream = table + + def __arrow_c_stream__(self, requested_schema=None): + return self.stream.__arrow_c_stream__(requested_schema) + + +@td.skip_if_no("pyarrow", min_version="14.0") +def test_dataframe_from_arrow(using_infer_string): + # objects with __arrow_c_stream__ + table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + result = pd.DataFrame.from_arrow(table) + expected = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + if not using_infer_string: + expected["b"] = expected["b"].astype(pd.StringDtype(na_value=np.nan)) + tm.assert_frame_equal(result, expected) + + # not only pyarrow object are supported + result = pd.DataFrame.from_arrow(ArrowStreamWrapper(table)) + tm.assert_frame_equal(result, expected) + + # objects with __arrow_c_array__ + batch = pa.record_batch([[1, 2, 3], ["a", "b", "c"]], names=["a", "b"]) + + result = pd.DataFrame.from_arrow(table) + tm.assert_frame_equal(result, expected) + + result = pd.DataFrame.from_arrow(ArrowArrayWrapper(batch)) + tm.assert_frame_equal(result, expected) + + # only accept actual Arrow objects + with pytest.raises(TypeError, match="Expected an Arrow-compatible tabular object"): + pd.DataFrame.from_arrow({"a": [1, 2, 3], "b": ["a", "b", "c"]}) diff --git a/pandas/tests/frame/test_block_internals.py b/pandas/tests/frame/test_block_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..ac7438ecf492d950e660196f3bb43152addcde52 --- /dev/null +++ b/pandas/tests/frame/test_block_internals.py @@ -0,0 +1,455 @@ +from datetime import ( + datetime, + timedelta, +) +from io import StringIO +import itertools +from textwrap import dedent + +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Series, + Timestamp, + date_range, + option_context, +) +import pandas._testing as tm +from pandas.core.internals.blocks import NumpyBlock + +# Segregated collection of methods that require the BlockManager internal data +# structure + + +class TestDataFrameBlockInternals: + def test_setitem_invalidates_datetime_index_freq(self): + # GH#24096 altering a datetime64tz column inplace invalidates the + # `freq` attribute on the underlying DatetimeIndex + + dti = date_range("20130101", periods=3, tz="US/Eastern") + ts = dti[1] + + df = DataFrame({"B": dti}) + assert df["B"]._values.freq is None + + df.iloc[1, 0] = pd.NaT + assert df["B"]._values.freq is None + + # check that the DatetimeIndex was not altered in place + assert dti.freq == "D" + assert dti[1] == ts + + def test_cast_internals(self, float_frame): + msg = "Passing a BlockManager to DataFrame" + with tm.assert_produces_warning( + Pandas4Warning, match=msg, check_stacklevel=False + ): + casted = DataFrame(float_frame._mgr, dtype=int) + expected = DataFrame(float_frame._series, dtype=int) + tm.assert_frame_equal(casted, expected) + + with tm.assert_produces_warning( + Pandas4Warning, match=msg, check_stacklevel=False + ): + casted = DataFrame(float_frame._mgr, dtype=np.int32) + expected = DataFrame(float_frame._series, dtype=np.int32) + tm.assert_frame_equal(casted, expected) + + def test_consolidate(self, float_frame): + float_frame["E"] = 7.0 + consolidated = float_frame._consolidate() + assert len(consolidated._mgr.blocks) == 1 + + # Ensure copy, do I want this? + recons = consolidated._consolidate() + assert recons is not consolidated + tm.assert_frame_equal(recons, consolidated) + + float_frame["F"] = 8.0 + assert len(float_frame._mgr.blocks) == 3 + + return_value = float_frame._consolidate_inplace() + assert return_value is None + assert len(float_frame._mgr.blocks) == 1 + + def test_consolidate_inplace(self, float_frame): + # triggers in-place consolidation + for letter in range(ord("A"), ord("Z")): + float_frame[chr(letter)] = chr(letter) + + def test_modify_values(self, float_frame): + with pytest.raises(ValueError, match="read-only"): + float_frame.values[5] = 5 + assert (float_frame.values[5] != 5).all() + + def test_boolean_set_uncons(self, float_frame): + float_frame["E"] = 7.0 + + expected = float_frame.values.copy() + expected[expected > 1] = 2 + + float_frame[float_frame > 1] = 2 + tm.assert_almost_equal(expected, float_frame.values) + + def test_constructor_with_convert(self): + # this is actually mostly a test of lib.maybe_convert_objects + # #2845 + df = DataFrame({"A": [2**63 - 1]}) + result = df["A"] + expected = Series(np.asarray([2**63 - 1], np.int64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [2**63]}) + result = df["A"] + expected = Series(np.asarray([2**63], np.uint64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [datetime(2005, 1, 1), True]}) + result = df["A"] + expected = Series( + np.asarray([datetime(2005, 1, 1), True], np.object_), name="A" + ) + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [None, 1]}) + result = df["A"] + expected = Series(np.asarray([np.nan, 1], np.float64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0, 2]}) + result = df["A"] + expected = Series(np.asarray([1.0, 2], np.float64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0 + 2.0j, 3]}) + result = df["A"] + expected = Series(np.asarray([1.0 + 2.0j, 3], np.complex128), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0 + 2.0j, 3.0]}) + result = df["A"] + expected = Series(np.asarray([1.0 + 2.0j, 3.0], np.complex128), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0 + 2.0j, True]}) + result = df["A"] + expected = Series(np.asarray([1.0 + 2.0j, True], np.object_), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0, None]}) + result = df["A"] + expected = Series(np.asarray([1.0, np.nan], np.float64), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [1.0 + 2.0j, None]}) + result = df["A"] + expected = Series(np.asarray([1.0 + 2.0j, np.nan], np.complex128), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [2.0, 1, True, None]}) + result = df["A"] + expected = Series(np.asarray([2.0, 1, True, None], np.object_), name="A") + tm.assert_series_equal(result, expected) + + df = DataFrame({"A": [2.0, 1, datetime(2006, 1, 1), None]}) + result = df["A"] + expected = Series( + np.asarray([2.0, 1, datetime(2006, 1, 1), None], np.object_), name="A" + ) + tm.assert_series_equal(result, expected) + + def test_construction_with_mixed(self, float_string_frame, using_infer_string): + # mixed-type frames + float_string_frame["datetime"] = datetime.now() + float_string_frame["timedelta"] = timedelta(days=1, seconds=1) + assert float_string_frame["datetime"].dtype == "M8[us]" + assert float_string_frame["timedelta"].dtype == "m8[us]" + result = float_string_frame.dtypes + expected = Series( + [np.dtype("float64")] * 4 + + [ + np.dtype("object") + if not using_infer_string + else pd.StringDtype(na_value=np.nan), + np.dtype("datetime64[us]"), + np.dtype("timedelta64[us]"), + ], + index=[*list("ABCD"), "foo", "datetime", "timedelta"], + ) + tm.assert_series_equal(result, expected) + + def test_construction_with_conversions(self): + # convert from a numpy array of non-ns timedelta64; as of 2.0 this does + # *not* convert + arr = np.array([1, 2, 3], dtype="timedelta64[s]") + df = DataFrame({"A": arr}) + expected = DataFrame( + {"A": pd.timedelta_range("00:00:01", periods=3, freq="s")}, index=range(3) + ) + tm.assert_numpy_array_equal(df["A"].to_numpy(), arr) + + expected = DataFrame( + { + "dt1": Timestamp("20130101").as_unit("s"), + "dt2": date_range("20130101", periods=3).astype("M8[s]"), + # 'dt3' : date_range('20130101 00:00:01',periods=3,freq='s'), + # FIXME: don't leave commented-out + }, + index=range(3), + ) + assert expected.dtypes["dt1"] == "M8[s]" + assert expected.dtypes["dt2"] == "M8[s]" + + dt1 = np.datetime64("2013-01-01") + dt2 = np.array( + ["2013-01-01", "2013-01-02", "2013-01-03"], dtype="datetime64[D]" + ) + df = DataFrame({"dt1": dt1, "dt2": dt2}) + + # df['dt3'] = np.array(['2013-01-01 00:00:01','2013-01-01 + # 00:00:02','2013-01-01 00:00:03'],dtype='datetime64[s]') + # FIXME: don't leave commented-out + + tm.assert_frame_equal(df, expected) + + def test_constructor_compound_dtypes(self): + # GH 5191 + # compound dtypes should raise not-implementederror + + def f(dtype): + data = list(itertools.repeat((datetime(2001, 1, 1), "aa", 20), 9)) + return DataFrame(data=data, columns=["A", "B", "C"], dtype=dtype) + + msg = "compound dtypes are not implemented in the DataFrame constructor" + with pytest.raises(NotImplementedError, match=msg): + f([("A", "datetime64[h]"), ("B", "str"), ("C", "int32")]) + + # pre-2.0 these used to work (though results may be unexpected) + with pytest.raises(TypeError, match="argument must be"): + f("int64") + with pytest.raises(TypeError, match="argument must be"): + f("float64") + + # 10822 + msg = "^Unknown datetime string format, unable to parse: aa$" + with pytest.raises(ValueError, match=msg): + f("M8[ns]") + + def test_pickle_float_string_frame(self, float_string_frame, temp_file): + unpickled = tm.round_trip_pickle(float_string_frame, temp_file) + tm.assert_frame_equal(float_string_frame, unpickled) + + # buglet + float_string_frame._mgr.ndim + + def test_pickle_empty(self, temp_file): + empty_frame = DataFrame() + unpickled = tm.round_trip_pickle(empty_frame, temp_file) + repr(unpickled) + + def test_pickle_empty_tz_frame(self, timezone_frame, temp_file): + unpickled = tm.round_trip_pickle(timezone_frame, temp_file) + tm.assert_frame_equal(timezone_frame, unpickled) + + def test_consolidate_datetime64(self): + # numpy vstack bug + + df = DataFrame( + { + "starting": pd.to_datetime( + [ + "2012-06-21 00:00", + "2012-06-23 07:00", + "2012-06-23 16:30", + "2012-06-25 08:00", + "2012-06-26 12:00", + ] + ), + "ending": pd.to_datetime( + [ + "2012-06-23 07:00", + "2012-06-23 16:30", + "2012-06-25 08:00", + "2012-06-26 12:00", + "2012-06-27 08:00", + ] + ), + "measure": [77, 65, 77, 0, 77], + } + ) + + ser_starting = df.starting + ser_starting.index = ser_starting.values + ser_starting = ser_starting.tz_localize("US/Eastern") + ser_starting = ser_starting.tz_convert("UTC") + ser_starting.index.name = "starting" + + ser_ending = df.ending + ser_ending.index = ser_ending.values + ser_ending = ser_ending.tz_localize("US/Eastern") + ser_ending = ser_ending.tz_convert("UTC") + ser_ending.index.name = "ending" + + df.starting = ser_starting.index + df.ending = ser_ending.index + + tm.assert_index_equal(pd.DatetimeIndex(df.starting), ser_starting.index) + tm.assert_index_equal(pd.DatetimeIndex(df.ending), ser_ending.index) + + def test_is_mixed_type(self, float_frame, float_string_frame): + assert not float_frame._is_mixed_type + assert float_string_frame._is_mixed_type + + def test_stale_cached_series_bug_473(self): + # this is chained, but ok + with option_context("chained_assignment", None): + Y = DataFrame( + np.random.default_rng(2).random((4, 4)), + index=("a", "b", "c", "d"), + columns=("e", "f", "g", "h"), + ) + repr(Y) + Y["e"] = Y["e"].astype("object") + with tm.raises_chained_assignment_error(): + Y["g"]["c"] = np.nan + repr(Y) + Y.sum() + Y["g"].sum() + assert not pd.isna(Y["g"]["c"]) + + def test_strange_column_corruption_issue(self, performance_warning): + # TODO(wesm): Unclear how exactly this is related to internal matters + df = DataFrame(index=[0, 1]) + df[0] = np.nan + wasCol = {} + + with tm.assert_produces_warning( + performance_warning, raise_on_extra_warnings=False + ): + for i, dt in enumerate(df.index): + for col in range(100, 200): + if col not in wasCol: + wasCol[col] = 1 + df[col] = np.nan + df.loc[dt, col] = i + + myid = 100 + + first = len(df.loc[pd.isna(df[myid]), [myid]]) + second = len(df.loc[pd.isna(df[myid]), [myid]]) + assert first == second == 0 + + def test_constructor_no_pandas_array(self): + # Ensure that NumpyExtensionArray isn't allowed inside Series + # See https://github.com/pandas-dev/pandas/issues/23995 for more. + arr = Series([1, 2, 3]).array + result = DataFrame({"A": arr}) + expected = DataFrame({"A": [1, 2, 3]}) + tm.assert_frame_equal(result, expected) + assert isinstance(result._mgr.blocks[0], NumpyBlock) + assert result._mgr.blocks[0].is_numeric + + def test_add_column_with_pandas_array(self): + # GH 26390 + df = DataFrame({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "d"]}) + df["c"] = pd.arrays.NumpyExtensionArray(np.array([1, 2, None, 3], dtype=object)) + df2 = DataFrame( + { + "a": [1, 2, 3, 4], + "b": ["a", "b", "c", "d"], + "c": pd.arrays.NumpyExtensionArray( + np.array([1, 2, None, 3], dtype=object) + ), + } + ) + assert type(df["c"]._mgr.blocks[0]) == NumpyBlock + assert df["c"]._mgr.blocks[0].is_object + assert type(df2["c"]._mgr.blocks[0]) == NumpyBlock + assert df2["c"]._mgr.blocks[0].is_object + tm.assert_frame_equal(df, df2) + + +def test_update_inplace_sets_valid_block_values(): + # https://github.com/pandas-dev/pandas/issues/33457 + df = DataFrame({"a": Series([1, 2, None], dtype="category")}) + + # inplace update of a single column + with tm.raises_chained_assignment_error(): + df["a"].fillna(1, inplace=True) + + # check we haven't put a Series into any block.values + assert isinstance(df._mgr.blocks[0].values, Categorical) + + +def get_longley_data(): + # From statsmodels.datasets.longley + # This specific dataset seems to trigger races in Pandas 3.0.0 more readily + # than data frames used elsewhere in the tests + longley_csv = StringIO( + dedent( + """"Obs","GNPDEFL","GNP","UNEMP","ARMED","POP","YEAR" + 1,83,234289,2356,1590,107608,1947 + 2,88.5,259426,2325,1456,108632,1948 + 3,88.2,258054,3682,1616,109773,1949 + 4,89.5,284599,3351,1650,110929,1950 + 5,96.2,328975,2099,3099,112075,1951 + 6,98.1,346999,1932,3594,113270,1952 + 7,99,365385,1870,3547,115094,1953 + 8,100,363112,3578,3350,116219,1954 + 9,101.2,397469,2904,3048,117388,1955 + 10,104.6,419180,2822,2857,118734,1956 + 11,108.4,442769,2936,2798,120445,1957 + 12,110.8,444546,4681,2637,121950,1958 + 13,112.6,482704,3813,2552,123366,1959 + 14,114.2,502601,3931,2514,125368,1960 + 15,115.7,518173,4806,2572,127852,1961 + 16,116.9,554894,4007,2827,130081,1962 + """ + ) + ) + + return pd.read_csv(longley_csv).iloc[:, [1, 2, 3, 4, 5, 6]].astype(float) + + +# See gh-63685, comparisons and copying led to races in statsmodels tests +# +# This test spawns a thread pool, so it shouldn't run under xdist. +# It generates warnings, so it needs warnings to be thread-safe as well +@td.skip_if_thread_unsafe_warnings +@pytest.mark.single_cpu +def test_multithreaded_reading(): + def numpy_assert(data, b): + b.wait() + tm.assert_almost_equal((data + 1) - 1, data.copy()) + + tm.run_multithreaded( + numpy_assert, max_workers=8, arguments=(get_longley_data(),), pass_barrier=True + ) + + def safe_is_const(s): + try: + return np.ptp(s) == 0.0 and np.any(s != 0.0) + except Exception: + return False + + def concat(data, b): + b.wait() + x = data.copy() + nobs = len(x) + trendarr = np.fliplr(np.vander(np.arange(1, nobs + 1, dtype=np.float64), 1)) + x.apply(safe_is_const, 0) + trendarr = DataFrame(trendarr, index=x.index, columns=["const"]) + x = [trendarr, x] + x = pd.concat(x[::1], axis=1) + tm.assert_frame_equal(x, x) + + tm.run_multithreaded( + concat, max_workers=8, arguments=(get_longley_data(),), pass_barrier=True + ) diff --git a/pandas/tests/frame/test_constructors.py b/pandas/tests/frame/test_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..2368f75ec06cd76658d6a18776bbda3355ab301d --- /dev/null +++ b/pandas/tests/frame/test_constructors.py @@ -0,0 +1,3376 @@ +import array +from collections import ( + OrderedDict, + abc, + defaultdict, + namedtuple, +) +from collections.abc import Iterator +from dataclasses import make_dataclass +from datetime import ( + date, + datetime, + timedelta, +) +import functools +import re +import zoneinfo + +import numpy as np +from numpy import ma +from numpy.ma import mrecords +import pytest + +from pandas._libs import lib +from pandas.compat.numpy import np_version_gt2 +from pandas.errors import IntCastingNaNError + +from pandas.core.dtypes.common import is_integer_dtype +from pandas.core.dtypes.dtypes import ( + DatetimeTZDtype, + IntervalDtype, + NumpyEADtype, + PeriodDtype, +) + +import pandas as pd +from pandas import ( + Categorical, + CategoricalIndex, + DataFrame, + DatetimeIndex, + Index, + Interval, + MultiIndex, + Period, + RangeIndex, + Series, + Timedelta, + Timestamp, + cut, + date_range, + isna, +) +import pandas._testing as tm +from pandas.arrays import ( + DatetimeArray, + IntervalArray, + PeriodArray, + SparseArray, + TimedeltaArray, +) + +MIXED_FLOAT_DTYPES = ["float16", "float32", "float64"] +MIXED_INT_DTYPES = [ + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", +] + + +class TestDataFrameConstructors: + def test_constructor_from_ndarray_with_str_dtype(self): + # If we don't ravel/reshape around ensure_str_array, we end up + # with an array of strings each of which is e.g. "[0 1 2]" + arr = np.arange(12).reshape(4, 3) + df = DataFrame(arr, dtype=str) + expected = DataFrame(arr.astype(str), dtype="str") + tm.assert_frame_equal(df, expected) + + def test_constructor_from_2d_datetimearray(self): + dti = date_range("2016-01-01", periods=6, tz="US/Pacific") + dta = dti._data.reshape(3, 2) + + df = DataFrame(dta) + expected = DataFrame({0: dta[:, 0], 1: dta[:, 1]}) + tm.assert_frame_equal(df, expected) + # GH#44724 big performance hit if we de-consolidate + assert len(df._mgr.blocks) == 1 + + def test_constructor_dict_with_tzaware_scalar(self): + # GH#42505 + dt = Timestamp("2019-11-03 01:00:00-0700").tz_convert("America/Los_Angeles") + dt = dt.as_unit("ns") + + df = DataFrame({"dt": dt}, index=[0]) + expected = DataFrame({"dt": [dt]}) + tm.assert_frame_equal(df, expected, check_index_type=False) + + # Non-homogeneous + df = DataFrame({"dt": dt, "value": [1]}) + expected = DataFrame({"dt": [dt], "value": [1]}) + tm.assert_frame_equal(df, expected) + + def test_construct_ndarray_with_nas_and_int_dtype(self): + # GH#26919 match Series by not casting np.nan to meaningless int + arr = np.array([[1, np.nan], [2, 3]]) + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame(arr, dtype="i8") + + # check this matches Series behavior + with pytest.raises(IntCastingNaNError, match=msg): + Series(arr[0], dtype="i8", name=0) + + def test_construct_from_list_of_datetimes(self): + df = DataFrame([datetime.now(), datetime.now()]) + assert df[0].dtype == np.dtype("M8[us]") + + def test_constructor_from_tzaware_datetimeindex(self): + # don't cast a DatetimeIndex WITH a tz, leave as object + # GH#6032 + naive = DatetimeIndex(["2013-1-1 13:00", "2013-1-2 14:00"], name="B") + idx = naive.tz_localize("US/Pacific") + + expected = Series(np.array(idx.tolist(), dtype="object"), name="B") + assert expected.dtype == idx.dtype + + # convert index to series + result = Series(idx) + tm.assert_series_equal(result, expected) + + def test_columns_with_leading_underscore_work_with_to_dict(self): + col_underscore = "_b" + df = DataFrame({"a": [1, 2], col_underscore: [3, 4]}) + d = df.to_dict(orient="records") + + ref_d = [{"a": 1, col_underscore: 3}, {"a": 2, col_underscore: 4}] + + assert ref_d == d + + def test_columns_with_leading_number_and_underscore_work_with_to_dict(self): + col_with_num = "1_b" + df = DataFrame({"a": [1, 2], col_with_num: [3, 4]}) + d = df.to_dict(orient="records") + + ref_d = [{"a": 1, col_with_num: 3}, {"a": 2, col_with_num: 4}] + + assert ref_d == d + + def test_array_of_dt64_nat_with_td64dtype_raises(self, frame_or_series): + # GH#39462 + nat = np.datetime64("NaT", "ns") + arr = np.array([nat], dtype=object) + if frame_or_series is DataFrame: + arr = arr.reshape(1, 1) + + msg = "Invalid type for timedelta scalar: " + with pytest.raises(TypeError, match=msg): + frame_or_series(arr, dtype="m8[ns]") + + @pytest.mark.parametrize("kind", ["m", "M"]) + def test_datetimelike_values_with_object_dtype(self, kind, frame_or_series): + # with dtype=object, we should cast dt64 values to Timestamps, not pydatetimes + if kind == "M": + dtype = "M8[ns]" + scalar_type = Timestamp + else: + dtype = "m8[ns]" + scalar_type = Timedelta + + arr = np.arange(6, dtype="i8").view(dtype).reshape(3, 2) + if frame_or_series is Series: + arr = arr[:, 0] + + obj = frame_or_series(arr, dtype=object) + assert obj._mgr.blocks[0].values.dtype == object + assert isinstance(obj._mgr.blocks[0].values.ravel()[0], scalar_type) + + # go through a different path in internals.construction + obj = frame_or_series(frame_or_series(arr), dtype=object) + assert obj._mgr.blocks[0].values.dtype == object + assert isinstance(obj._mgr.blocks[0].values.ravel()[0], scalar_type) + + obj = frame_or_series(frame_or_series(arr), dtype=NumpyEADtype(object)) + assert obj._mgr.blocks[0].values.dtype == object + assert isinstance(obj._mgr.blocks[0].values.ravel()[0], scalar_type) + + if frame_or_series is DataFrame: + # other paths through internals.construction + sers = [Series(x) for x in arr] + obj = frame_or_series(sers, dtype=object) + assert obj._mgr.blocks[0].values.dtype == object + assert isinstance(obj._mgr.blocks[0].values.ravel()[0], scalar_type) + + def test_series_with_name_not_matching_column(self): + # GH#9232 + x = Series(range(5), name=1) + y = Series(range(5), name=0) + + result = DataFrame(x, columns=[0]) + expected = DataFrame([], columns=[0]) + tm.assert_frame_equal(result, expected) + + result = DataFrame(y, columns=[1]) + expected = DataFrame([], columns=[1]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "constructor", + [ + lambda: DataFrame(), + lambda: DataFrame(None), + lambda: DataFrame(()), + lambda: DataFrame([]), + lambda: DataFrame(_ for _ in []), + lambda: DataFrame(range(0)), + lambda: DataFrame(data=None), + lambda: DataFrame(data=()), + lambda: DataFrame(data=[]), + lambda: DataFrame(data=(_ for _ in [])), + lambda: DataFrame(data=range(0)), + ], + ) + def test_empty_constructor(self, constructor): + expected = DataFrame() + result = constructor() + assert len(result.index) == 0 + assert len(result.columns) == 0 + tm.assert_frame_equal(result, expected) + + def test_empty_constructor_object_index(self): + expected = DataFrame(index=RangeIndex(0), columns=RangeIndex(0)) + result = DataFrame({}) + assert len(result.index) == 0 + assert len(result.columns) == 0 + tm.assert_frame_equal(result, expected, check_index_type=True) + + @pytest.mark.parametrize( + "emptylike,expected_index,expected_columns", + [ + ([[]], RangeIndex(1), RangeIndex(0)), + ([[], []], RangeIndex(2), RangeIndex(0)), + ([(_ for _ in [])], RangeIndex(1), RangeIndex(0)), + ], + ) + def test_emptylike_constructor(self, emptylike, expected_index, expected_columns): + expected = DataFrame(index=expected_index, columns=expected_columns) + result = DataFrame(emptylike) + tm.assert_frame_equal(result, expected) + + def test_constructor_mixed(self, float_string_frame, using_infer_string): + dtype = "str" if using_infer_string else np.object_ + assert float_string_frame["foo"].dtype == dtype + + def test_constructor_cast_failure(self): + # as of 2.0, we raise if we can't respect "dtype", previously we + # silently ignored + msg = "could not convert string to float" + with pytest.raises(ValueError, match=msg): + DataFrame({"a": ["a", "b", "c"]}, dtype=np.float64) + + # GH 3010, constructing with odd arrays + df = DataFrame(np.ones((4, 2))) + + # this is ok + df["foo"] = np.ones((4, 2)).tolist() + + # this is not ok + msg = "Expected a 1D array, got an array with shape \\(4, 2\\)" + with pytest.raises(ValueError, match=msg): + df["test"] = np.ones((4, 2)) + + # this is ok + df["foo2"] = np.ones((4, 2)).tolist() + + def test_constructor_dtype_copy(self): + orig_df = DataFrame({"col1": [1.0], "col2": [2.0], "col3": [3.0]}) + + new_df = DataFrame(orig_df, dtype=float, copy=True) + + new_df["col1"] = 200.0 + assert orig_df["col1"][0] == 1.0 + + def test_constructor_dtype_nocast_view_dataframe(self): + df = DataFrame([[1, 2]]) + should_be_view = DataFrame(df, dtype=df[0].dtype) + should_be_view.iloc[0, 0] = 99 + assert df.values[0, 0] == 1 + + def test_constructor_dtype_nocast_view_2d_array(self): + df = DataFrame([[1, 2], [3, 4]], dtype="int64") + df2 = DataFrame(df.values, dtype=df[0].dtype) + assert df2._mgr.blocks[0].values.flags.c_contiguous + + def test_1d_object_array_does_not_copy(self, using_infer_string): + # https://github.com/pandas-dev/pandas/issues/39272 + arr = np.array(["a", "b"], dtype="object") + df = DataFrame(arr, copy=False) + if using_infer_string: + if df[0].dtype.storage == "pyarrow": + # object dtype strings are converted to arrow memory, + # no numpy arrays to compare + pass + else: + assert np.shares_memory(df[0].to_numpy(), arr) + else: + assert np.shares_memory(df.values, arr) + + df = DataFrame(arr, dtype=object, copy=False) + assert np.shares_memory(df.values, arr) + + def test_2d_object_array_does_not_copy(self, using_infer_string): + # https://github.com/pandas-dev/pandas/issues/39272 + arr = np.array([["a", "b"], ["c", "d"]], dtype="object") + df = DataFrame(arr, copy=False) + if using_infer_string: + if df[0].dtype.storage == "pyarrow": + # object dtype strings are converted to arrow memory, + # no numpy arrays to compare + pass + else: + assert np.shares_memory(df[0].to_numpy(), arr) + else: + assert np.shares_memory(df.values, arr) + + df = DataFrame(arr, dtype=object, copy=False) + assert np.shares_memory(df.values, arr) + + def test_constructor_dtype_list_data(self): + df = DataFrame([[1, "2"], [None, "a"]], dtype=object) + assert df.loc[1, 0] is None + assert df.loc[0, 1] == "2" + + def test_constructor_list_of_2d_raises(self): + # https://github.com/pandas-dev/pandas/issues/32289 + a = DataFrame() + b = np.empty((0, 0)) + with pytest.raises(ValueError, match=r"shape=\(1, 0, 0\)"): + DataFrame([a]) + + with pytest.raises(ValueError, match=r"shape=\(1, 0, 0\)"): + DataFrame([b]) + + a = DataFrame({"A": [1, 2]}) + with pytest.raises(ValueError, match=r"shape=\(2, 2, 1\)"): + DataFrame([a, a]) + + @pytest.mark.parametrize( + "typ, ad", + [ + # mixed floating and integer coexist in the same frame + ["float", {}], + # add lots of types + ["float", {"A": 1, "B": "foo", "C": "bar"}], + # GH 622 + ["int", {}], + ], + ) + def test_constructor_mixed_dtypes(self, typ, ad): + if typ == "int": + dtypes = MIXED_INT_DTYPES + arrays = [ + np.array(np.random.default_rng(2).random(10), dtype=d) for d in dtypes + ] + elif typ == "float": + dtypes = MIXED_FLOAT_DTYPES + arrays = [ + np.array(np.random.default_rng(2).integers(10, size=10), dtype=d) + for d in dtypes + ] + + for d, a in zip(dtypes, arrays): + assert a.dtype == d + ad.update(dict(zip(dtypes, arrays))) + df = DataFrame(ad) + + dtypes = MIXED_FLOAT_DTYPES + MIXED_INT_DTYPES + for d in dtypes: + if d in df: + assert df.dtypes[d] == d + + def test_constructor_complex_dtypes(self): + # GH10952 + a = np.random.default_rng(2).random(10).astype(np.complex64) + b = np.random.default_rng(2).random(10).astype(np.complex128) + + df = DataFrame({"a": a, "b": b}) + assert a.dtype == df.a.dtype + assert b.dtype == df.b.dtype + + def test_constructor_dtype_str_na_values(self, string_dtype): + # https://github.com/pandas-dev/pandas/issues/21083 + df = DataFrame({"A": ["x", None]}, dtype=string_dtype) + result = df.isna() + expected = DataFrame({"A": [False, True]}) + tm.assert_frame_equal(result, expected) + assert df.iloc[1, 0] is None + + df = DataFrame({"A": ["x", np.nan]}, dtype=string_dtype) + assert np.isnan(df.iloc[1, 0]) + + def test_constructor_rec(self, float_frame): + rec = float_frame.to_records(index=False) + rec.dtype.names = list(rec.dtype.names)[::-1] + + index = float_frame.index + + df = DataFrame(rec) + tm.assert_index_equal(df.columns, Index(rec.dtype.names)) + + df2 = DataFrame(rec, index=index) + tm.assert_index_equal(df2.columns, Index(rec.dtype.names)) + tm.assert_index_equal(df2.index, index) + + # case with columns != the ones we would infer from the data + rng = np.arange(len(rec))[::-1] + df3 = DataFrame(rec, index=rng, columns=["C", "B"]) + expected = DataFrame(rec, index=rng).reindex(columns=["C", "B"]) + tm.assert_frame_equal(df3, expected) + + def test_constructor_bool(self): + df = DataFrame({0: np.ones(10, dtype=bool), 1: np.zeros(10, dtype=bool)}) + assert df.values.dtype == np.bool_ + + def test_constructor_overflow_int64(self): + # see gh-14881 + values = np.array([2**64 - i for i in range(1, 10)], dtype=np.uint64) + + result = DataFrame({"a": values}) + assert result["a"].dtype == np.uint64 + + # see gh-2355 + data_scores = [ + (6311132704823138710, 273), + (2685045978526272070, 23), + (8921811264899370420, 45), + (17019687244989530680, 270), + (9930107427299601010, 273), + ] + dtype = [("uid", "u8"), ("score", "u8")] + data = np.zeros((len(data_scores),), dtype=dtype) + data[:] = data_scores + df_crawls = DataFrame(data) + assert df_crawls["uid"].dtype == np.uint64 + + @pytest.mark.parametrize( + "values", + [ + np.array([2**64], dtype=object), + np.array([2**65]), + [2**64 + 1], + np.array([-(2**63) - 4], dtype=object), + np.array([-(2**64) - 1]), + [-(2**65) - 2], + ], + ) + def test_constructor_int_overflow(self, values): + # see gh-18584 + value = values[0] + result = DataFrame(values) + + assert result[0].dtype == object + assert result[0][0] == value + + @pytest.mark.parametrize( + "values", + [ + np.array([1], dtype=np.uint16), + np.array([1], dtype=np.uint32), + np.array([1], dtype=np.uint64), + [np.uint16(1)], + [np.uint32(1)], + [np.uint64(1)], + ], + ) + def test_constructor_numpy_uints(self, values): + # GH#47294 + value = values[0] + result = DataFrame(values) + + assert result[0].dtype == value.dtype + assert result[0][0] == value + + def test_constructor_ordereddict(self): + nitems = 100 + nums = list(range(nitems)) + np.random.default_rng(2).shuffle(nums) + expected = [f"A{i:d}" for i in nums] + df = DataFrame(OrderedDict(zip(expected, [[0]] * nitems))) + assert expected == list(df.columns) + + def test_constructor_dict(self): + datetime_series = Series( + np.arange(30, dtype=np.float64), index=date_range("2020-01-01", periods=30) + ) + # test expects index shifted by 5 + datetime_series_short = datetime_series[5:] + + frame = DataFrame({"col1": datetime_series, "col2": datetime_series_short}) + + # col2 is padded with NaN + assert len(datetime_series) == 30 + assert len(datetime_series_short) == 25 + + tm.assert_series_equal(frame["col1"], datetime_series.rename("col1")) + + exp = Series( + np.concatenate([[np.nan] * 5, datetime_series_short.values]), + index=datetime_series.index, + name="col2", + ) + tm.assert_series_equal(exp, frame["col2"]) + + frame = DataFrame( + {"col1": datetime_series, "col2": datetime_series_short}, + columns=["col2", "col3", "col4"], + ) + + assert len(frame) == len(datetime_series_short) + assert "col1" not in frame + assert isna(frame["col3"]).all() + + # Corner cases + assert len(DataFrame()) == 0 + + # mix dict and array, wrong size - no spec for which error should raise + # first + msg = "Mixing dicts with non-Series may lead to ambiguous ordering." + with pytest.raises(ValueError, match=msg): + DataFrame({"A": {"a": "a", "b": "b"}, "B": ["a", "b", "c"]}) + + def test_constructor_dict_length1(self): + # Length-one dict micro-optimization + frame = DataFrame({"A": {"1": 1, "2": 2}}) + tm.assert_index_equal(frame.index, Index(["1", "2"])) + + def test_constructor_dict_with_index(self): + # empty dict plus index + idx = Index([0, 1, 2]) + frame = DataFrame({}, index=idx) + assert frame.index is idx + + def test_constructor_dict_with_index_and_columns(self): + # empty dict with index and columns + idx = Index([0, 1, 2]) + frame = DataFrame({}, index=idx, columns=idx) + assert frame.index is idx + assert frame.columns is idx + assert len(frame._series) == 3 + + def test_constructor_dict_of_empty_lists(self): + # with dict of empty list and Series + frame = DataFrame({"A": [], "B": []}, columns=["A", "B"]) + tm.assert_index_equal(frame.index, RangeIndex(0), exact=True) + + def test_constructor_dict_with_none(self): + # GH 14381 + # Dict with None value + frame_none = DataFrame({"a": None}, index=[0]) + frame_none_list = DataFrame({"a": [None]}, index=[0]) + assert frame_none._get_value(0, "a") is None + assert frame_none_list._get_value(0, "a") is None + tm.assert_frame_equal(frame_none, frame_none_list) + + def test_constructor_dict_errors(self): + # GH10856 + # dict with scalar values should raise error, even if columns passed + msg = "If using all scalar values, you must pass an index" + with pytest.raises(ValueError, match=msg): + DataFrame({"a": 0.7}) + + with pytest.raises(ValueError, match=msg): + DataFrame({"a": 0.7}, columns=["a"]) + + @pytest.mark.parametrize("scalar", [2, np.nan, None, "D"]) + def test_constructor_invalid_items_unused(self, scalar): + # No error if invalid (scalar) value is in fact not used: + result = DataFrame({"a": scalar}, columns=["b"]) + expected = DataFrame(columns=["b"]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("value", [4, np.nan, None, float("nan")]) + def test_constructor_dict_nan_key(self, value): + # GH 18455 + cols = [1, value, 3] + idx = ["a", value] + values = [[0, 3], [1, 4], [2, 5]] + data = {cols[c]: Series(values[c], index=idx) for c in range(3)} + result = DataFrame(data).sort_values(1).sort_values("a", axis=1) + expected = DataFrame( + np.arange(6, dtype="int64").reshape(2, 3), index=idx, columns=cols + ) + tm.assert_frame_equal(result, expected) + + result = DataFrame(data, index=idx).sort_values("a", axis=1) + tm.assert_frame_equal(result, expected) + + result = DataFrame(data, index=idx, columns=cols) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("value", [np.nan, None, float("nan")]) + def test_constructor_dict_nan_tuple_key(self, value): + # GH 18455 + cols = Index([(11, 21), (value, 22), (13, value)]) + idx = Index([("a", value), (value, 2)]) + values = [[0, 3], [1, 4], [2, 5]] + data = {cols[c]: Series(values[c], index=idx) for c in range(3)} + result = DataFrame(data).sort_values((11, 21)).sort_values(("a", value), axis=1) + expected = DataFrame( + np.arange(6, dtype="int64").reshape(2, 3), index=idx, columns=cols + ) + tm.assert_frame_equal(result, expected) + + result = DataFrame(data, index=idx).sort_values(("a", value), axis=1) + tm.assert_frame_equal(result, expected) + + result = DataFrame(data, index=idx, columns=cols) + tm.assert_frame_equal(result, expected) + + def test_constructor_dict_order_insertion(self): + datetime_series = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + datetime_series_short = datetime_series[:5] + + # GH19018 + # initialization ordering: by insertion order if python>= 3.6 + d = {"b": datetime_series_short, "a": datetime_series} + frame = DataFrame(data=d) + expected = DataFrame(data=d, columns=list("ba")) + tm.assert_frame_equal(frame, expected) + + def test_constructor_dict_nan_key_and_columns(self): + # GH 16894 + result = DataFrame({np.nan: [1, 2], 2: [2, 3]}, columns=[np.nan, 2]) + expected = DataFrame([[1, 2], [2, 3]], columns=[np.nan, 2]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("missing_value", [None, np.nan, pd.NA]) + def test_constructor_list_of_dict_with_str_na_key( + self, missing_value, using_infer_string + ): + # https://github.com/pandas-dev/pandas/issues/63889 + # preserve values when None key is converted to NaN column name + dict_data = [ + {"colA": 1, missing_value: 2}, + {"colA": 3, missing_value: 4}, + ] + result = DataFrame(dict_data) + expected = DataFrame( + [[1, 2], [3, 4]], + columns=["colA", np.nan if using_infer_string else missing_value], + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("missing_value", [None, np.nan, pd.NA]) + def test_constructor_dict_of_dict_with_str_na_key( + self, missing_value, using_infer_string + ): + # https://github.com/pandas-dev/pandas/issues/63889 + dict_data = {"col": {"row1": 1, missing_value: 2, "row3": 3}} + result = DataFrame(dict_data) + expected = DataFrame( + {"col": [1, 2, 3]}, + index=Index( + ["row1", np.nan if using_infer_string else missing_value, "row3"] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_constructor_multi_index(self): + # GH 4078 + # construction error with mi and all-nan frame + tuples = [(2, 3), (3, 3), (3, 3)] + mi = MultiIndex.from_tuples(tuples) + df = DataFrame(index=mi, columns=mi) + assert isna(df).values.ravel().all() + + tuples = [(3, 3), (2, 3), (3, 3)] + mi = MultiIndex.from_tuples(tuples) + df = DataFrame(index=mi, columns=mi) + assert isna(df).values.ravel().all() + + def test_constructor_2d_index(self): + # GH 25416 + # handling of 2d index in construction + df = DataFrame([[1]], columns=[[1]], index=[1, 2]) + expected = DataFrame( + [1, 1], + index=Index([1, 2], dtype="int64"), + columns=MultiIndex(levels=[[1]], codes=[[0]]), + ) + tm.assert_frame_equal(df, expected) + + df = DataFrame([[1]], columns=[[1]], index=[[1, 2]]) + expected = DataFrame( + [1, 1], + index=MultiIndex(levels=[[1, 2]], codes=[[0, 1]]), + columns=MultiIndex(levels=[[1]], codes=[[0]]), + ) + tm.assert_frame_equal(df, expected) + + def test_constructor_error_msgs(self): + msg = "Empty data passed with indices specified." + # passing an empty array with columns specified. + with pytest.raises(ValueError, match=msg): + DataFrame(np.empty(0), index=[1]) + + msg = "Mixing dicts with non-Series may lead to ambiguous ordering." + # mix dict and array, wrong size + with pytest.raises(ValueError, match=msg): + DataFrame({"A": {"a": "a", "b": "b"}, "B": ["a", "b", "c"]}) + + # wrong size ndarray, GH 3105 + msg = r"Shape of passed values is \(4, 3\), indices imply \(3, 3\)" + with pytest.raises(ValueError, match=msg): + DataFrame( + np.arange(12).reshape((4, 3)), + columns=["foo", "bar", "baz"], + index=date_range("2000-01-01", periods=3), + ) + + arr = np.array([[4, 5, 6]]) + msg = r"Shape of passed values is \(1, 3\), indices imply \(1, 4\)" + with pytest.raises(ValueError, match=msg): + DataFrame(index=[0], columns=range(4), data=arr) + + arr = np.array([4, 5, 6]) + msg = r"Shape of passed values is \(3, 1\), indices imply \(1, 4\)" + with pytest.raises(ValueError, match=msg): + DataFrame(index=[0], columns=range(4), data=arr) + + # higher dim raise exception + with pytest.raises(ValueError, match="Must pass 2-d input"): + DataFrame(np.zeros((3, 3, 3)), columns=["A", "B", "C"], index=[1]) + + # wrong size axis labels + msg = r"Shape of passed values is \(2, 3\), indices imply \(1, 3\)" + with pytest.raises(ValueError, match=msg): + DataFrame( + np.random.default_rng(2).random((2, 3)), + columns=["A", "B", "C"], + index=[1], + ) + + msg = r"Shape of passed values is \(2, 3\), indices imply \(2, 2\)" + with pytest.raises(ValueError, match=msg): + DataFrame( + np.random.default_rng(2).random((2, 3)), + columns=["A", "B"], + index=[1, 2], + ) + + # gh-26429 + msg = "2 columns passed, passed data had 10 columns" + with pytest.raises(ValueError, match=msg): + DataFrame((range(10), range(10, 20)), columns=("ones", "twos")) + + msg = "If using all scalar values, you must pass an index" + with pytest.raises(ValueError, match=msg): + DataFrame({"a": False, "b": True}) + + def test_constructor_subclass_dict(self, dict_subclass): + # Test for passing dict subclass to constructor + data = { + "col1": dict_subclass((x, 10.0 * x) for x in range(10)), + "col2": dict_subclass((x, 20.0 * x) for x in range(10)), + } + df = DataFrame(data) + refdf = DataFrame({col: dict(val.items()) for col, val in data.items()}) + tm.assert_frame_equal(refdf, df) + + data = dict_subclass(data.items()) + df = DataFrame(data) + tm.assert_frame_equal(refdf, df) + + def test_constructor_defaultdict(self, float_frame): + # try with defaultdict + data = {} + float_frame.loc[: float_frame.index[10], "B"] = np.nan + + for k, v in float_frame.items(): + dct = defaultdict(dict) + dct.update(v.to_dict()) + data[k] = dct + frame = DataFrame(data) + expected = frame.reindex(index=float_frame.index) + tm.assert_frame_equal(float_frame, expected) + + def test_constructor_dict_block(self): + expected = np.array([[4.0, 3.0, 2.0, 1.0]]) + df = DataFrame( + {"d": [4.0], "c": [3.0], "b": [2.0], "a": [1.0]}, + columns=["d", "c", "b", "a"], + ) + tm.assert_numpy_array_equal(df.values, expected) + + def test_constructor_dict_cast(self, using_infer_string): + # cast float tests + test_data = {"A": {"1": 1, "2": 2}, "B": {"1": "1", "2": "2", "3": "3"}} + frame = DataFrame(test_data, dtype=float) + assert len(frame) == 3 + assert frame["B"].dtype == np.float64 + assert frame["A"].dtype == np.float64 + + frame = DataFrame(test_data) + assert len(frame) == 3 + assert frame["B"].dtype == np.object_ if not using_infer_string else "str" + assert frame["A"].dtype == np.float64 + + def test_constructor_dict_cast2(self): + # can't cast to float + test_data = { + "A": dict(zip(range(20), [f"word_{i}" for i in range(20)])), + "B": dict(zip(range(15), np.random.default_rng(2).standard_normal(15))), + } + with pytest.raises(ValueError, match="could not convert string"): + DataFrame(test_data, dtype=float) + + def test_constructor_dict_dont_upcast(self): + d = {"Col1": {"Row1": "A String", "Row2": np.nan}} + df = DataFrame(d) + assert isinstance(df["Col1"]["Row2"], float) + + def test_constructor_dict_dont_upcast2(self): + dm = DataFrame([[1, 2], ["a", "b"]], index=[1, 2], columns=[1, 2]) + assert isinstance(dm[1][1], int) + + def test_constructor_dict_of_tuples(self): + # GH #1491 + data = {"a": (1, 2, 3), "b": (4, 5, 6)} + + result = DataFrame(data) + expected = DataFrame({k: list(v) for k, v in data.items()}) + tm.assert_frame_equal(result, expected, check_dtype=False) + + def test_constructor_dict_of_ranges(self): + # GH 26356 + data = {"a": range(3), "b": range(3, 6)} + + result = DataFrame(data) + expected = DataFrame({"a": [0, 1, 2], "b": [3, 4, 5]}) + tm.assert_frame_equal(result, expected) + + def test_constructor_dict_of_iterators(self): + # GH 26349 + data = {"a": iter(range(3)), "b": reversed(range(3))} + + result = DataFrame(data) + expected = DataFrame({"a": [0, 1, 2], "b": [2, 1, 0]}) + tm.assert_frame_equal(result, expected) + + def test_constructor_dict_of_generators(self): + # GH 26349 + data = {"a": (i for i in (range(3))), "b": (i for i in reversed(range(3)))} + result = DataFrame(data) + expected = DataFrame({"a": [0, 1, 2], "b": [2, 1, 0]}) + tm.assert_frame_equal(result, expected) + + def test_constructor_dict_multiindex(self): + d = { + ("a", "a"): {("i", "i"): 0, ("i", "j"): 1, ("j", "i"): 2}, + ("b", "a"): {("i", "i"): 6, ("i", "j"): 5, ("j", "i"): 4}, + ("b", "c"): {("i", "i"): 7, ("i", "j"): 8, ("j", "i"): 9}, + } + _d = sorted(d.items()) + df = DataFrame(d) + expected = DataFrame( + [x[1] for x in _d], index=MultiIndex.from_tuples([x[0] for x in _d]) + ).T + expected.index = MultiIndex.from_tuples(expected.index) + tm.assert_frame_equal( + df, + expected, + ) + + d["z"] = {"y": 123.0, ("i", "i"): 111, ("i", "j"): 111, ("j", "i"): 111} + _d.insert(0, ("z", d["z"])) + expected = DataFrame( + [x[1] for x in _d], index=Index([x[0] for x in _d], tupleize_cols=False) + ).T + expected.index = Index(expected.index, tupleize_cols=False) + df = DataFrame(d) + df = df.reindex(columns=expected.columns, index=expected.index) + tm.assert_frame_equal(df, expected) + + def test_constructor_dict_datetime64_index(self): + # GH 10160 + dates_as_str = ["1984-02-19", "1988-11-06", "1989-12-03", "1990-03-15"] + + def create_data(constructor): + return {i: {constructor(s): 2 * i} for i, s in enumerate(dates_as_str)} + + data_datetime64 = create_data(np.datetime64) + data_datetime = create_data(lambda x: datetime.strptime(x, "%Y-%m-%d")) + data_Timestamp = create_data(Timestamp) + + expected = DataFrame( + [ + [0, None, None, None], + [None, 2, None, None], + [None, None, 4, None], + [None, None, None, 6], + ], + index=[Timestamp(dt) for dt in dates_as_str], + ) + + result_datetime64 = DataFrame(data_datetime64) + assert result_datetime64.index.unit == "s" + result_datetime64.index = result_datetime64.index.as_unit("us") + result_datetime = DataFrame(data_datetime) + assert result_datetime.index.unit == "us" + result_Timestamp = DataFrame(data_Timestamp) + tm.assert_frame_equal(result_datetime64, expected) + tm.assert_frame_equal(result_datetime, expected) + tm.assert_frame_equal(result_Timestamp, expected) + + @pytest.mark.parametrize( + "klass,exp_dtype", + [ + (lambda x: np.timedelta64(x, "D"), "m8[s]"), + (lambda x: timedelta(days=x), "m8[us]"), + (lambda x: Timedelta(x, "D"), "m8[s]"), + (lambda x: Timedelta(x, "D").as_unit("ms"), "m8[ms]"), + ], + ) + def test_constructor_dict_timedelta64_index(self, klass, exp_dtype): + # GH 10160 + td_as_int = [1, 2, 3, 4] + + data = {i: {klass(s): 2 * i} for i, s in enumerate(td_as_int)} + + expected = DataFrame( + [ + {0: 0, 1: None, 2: None, 3: None}, + {0: None, 1: 2, 2: None, 3: None}, + {0: None, 1: None, 2: 4, 3: None}, + {0: None, 1: None, 2: None, 3: 6}, + ], + index=[Timedelta(td, "D") for td in td_as_int], + ) + expected.index = expected.index.astype(exp_dtype) + + result = DataFrame(data) + + tm.assert_frame_equal(result, expected) + + def test_constructor_period_dict(self): + # PeriodIndex + a = pd.PeriodIndex(["2012-01", "NaT", "2012-04"], freq="M") + b = pd.PeriodIndex(["2012-02-01", "2012-03-01", "NaT"], freq="D") + df = DataFrame({"a": a, "b": b}) + assert df["a"].dtype == a.dtype + assert df["b"].dtype == b.dtype + + # list of periods + df = DataFrame({"a": a.astype(object).tolist(), "b": b.astype(object).tolist()}) + assert df["a"].dtype == a.dtype + assert df["b"].dtype == b.dtype + + def test_constructor_dict_extension_scalar(self, ea_scalar_and_dtype): + ea_scalar, ea_dtype = ea_scalar_and_dtype + df = DataFrame({"a": ea_scalar}, index=[0]) + assert df["a"].dtype == ea_dtype + + expected = DataFrame(index=[0], columns=["a"], data=ea_scalar) + + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "data,dtype", + [ + (Period("2020-01"), PeriodDtype("M")), + (Interval(left=0, right=5), IntervalDtype("int64", "right")), + ( + Timestamp("2011-01-01", tz="US/Eastern").as_unit("s"), + DatetimeTZDtype(unit="s", tz="US/Eastern"), + ), + ], + ) + def test_constructor_extension_scalar_data(self, data, dtype): + # GH 34832 + df = DataFrame(index=range(2), columns=["a", "b"], data=data) + + assert df["a"].dtype == dtype + assert df["b"].dtype == dtype + + arr = pd.array([data] * 2, dtype=dtype) + expected = DataFrame({"a": arr, "b": arr}) + + tm.assert_frame_equal(df, expected) + + def test_nested_dict_frame_constructor(self): + rng = pd.period_range("1/1/2000", periods=5) + df = DataFrame(np.random.default_rng(2).standard_normal((10, 5)), columns=rng) + + data = {} + for col in df.columns: + for row in df.index: + data.setdefault(col, {})[row] = df._get_value(row, col) + + result = DataFrame(data, columns=rng) + tm.assert_frame_equal(result, df) + + data = {} + for col in df.columns: + for row in df.index: + data.setdefault(row, {})[col] = df._get_value(row, col) + + result = DataFrame(data, index=rng).T + tm.assert_frame_equal(result, df) + + def _check_basic_constructor(self, empty): + # mat: 2d matrix with shape (3, 2) to input. empty - makes sized + # objects + mat = empty((2, 3), dtype=float) + # 2-D input + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + + assert len(frame.index) == 2 + assert len(frame.columns) == 3 + + # 1-D input + frame = DataFrame(empty((3,)), columns=["A"], index=[1, 2, 3]) + assert len(frame.index) == 3 + assert len(frame.columns) == 1 + + if empty is not np.ones: + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame(mat, columns=["A", "B", "C"], index=[1, 2], dtype=np.int64) + return + else: + frame = DataFrame( + mat, columns=["A", "B", "C"], index=[1, 2], dtype=np.int64 + ) + assert frame.values.dtype == np.int64 + + # wrong size axis labels + msg = r"Shape of passed values is \(2, 3\), indices imply \(1, 3\)" + with pytest.raises(ValueError, match=msg): + DataFrame(mat, columns=["A", "B", "C"], index=[1]) + msg = r"Shape of passed values is \(2, 3\), indices imply \(2, 2\)" + with pytest.raises(ValueError, match=msg): + DataFrame(mat, columns=["A", "B"], index=[1, 2]) + + # higher dim raise exception + with pytest.raises(ValueError, match="Must pass 2-d input"): + DataFrame(empty((3, 3, 3)), columns=["A", "B", "C"], index=[1]) + + # automatic labeling + frame = DataFrame(mat) + tm.assert_index_equal(frame.index, Index(range(2)), exact=True) + tm.assert_index_equal(frame.columns, Index(range(3)), exact=True) + + frame = DataFrame(mat, index=[1, 2]) + tm.assert_index_equal(frame.columns, Index(range(3)), exact=True) + + frame = DataFrame(mat, columns=["A", "B", "C"]) + tm.assert_index_equal(frame.index, Index(range(2)), exact=True) + + # 0-length axis + frame = DataFrame(empty((0, 3))) + assert len(frame.index) == 0 + + frame = DataFrame(empty((3, 0))) + assert len(frame.columns) == 0 + + def test_constructor_ndarray(self): + self._check_basic_constructor(np.ones) + + frame = DataFrame(["foo", "bar"], index=[0, 1], columns=["A"]) + assert len(frame) == 2 + + def test_constructor_maskedarray(self): + self._check_basic_constructor(ma.masked_all) + + # Check non-masked values + mat = ma.masked_all((2, 3), dtype=float) + mat[0, 0] = 1.0 + mat[1, 2] = 2.0 + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + assert 1.0 == frame["A"][1] + assert 2.0 == frame["C"][2] + + # what is this even checking?? + mat = ma.masked_all((2, 3), dtype=float) + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + assert np.all(~np.asarray(frame == frame)) + + @pytest.mark.filterwarnings( + "ignore:elementwise comparison failed:DeprecationWarning" + ) + def test_constructor_maskedarray_nonfloat(self): + # masked int promoted to float + mat = ma.masked_all((2, 3), dtype=int) + # 2-D input + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + + assert len(frame.index) == 2 + assert len(frame.columns) == 3 + assert np.all(~np.asarray(frame == frame)) + + # cast type + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2], dtype=np.float64) + assert frame.values.dtype == np.float64 + + # Check non-masked values + mat2 = ma.copy(mat) + mat2[0, 0] = 1 + mat2[1, 2] = 2 + frame = DataFrame(mat2, columns=["A", "B", "C"], index=[1, 2]) + assert 1 == frame["A"][1] + assert 2 == frame["C"][2] + + # masked np.datetime64 stays (use NaT as null) + mat = ma.masked_all((2, 3), dtype="M8[ns]") + # 2-D input + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + + assert len(frame.index) == 2 + assert len(frame.columns) == 3 + assert isna(frame).values.all() + + # cast type + msg = r"datetime64\[ns\] values and dtype=int64 is not supported" + with pytest.raises(TypeError, match=msg): + DataFrame(mat, columns=["A", "B", "C"], index=[1, 2], dtype=np.int64) + + # Check non-masked values + mat2 = ma.copy(mat) + mat2[0, 0] = 1 + mat2[1, 2] = 2 + frame = DataFrame(mat2, columns=["A", "B", "C"], index=[1, 2]) + assert 1 == frame["A"].astype("i8")[1] + assert 2 == frame["C"].astype("i8")[2] + + # masked bool promoted to object + mat = ma.masked_all((2, 3), dtype=bool) + # 2-D input + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2]) + + assert len(frame.index) == 2 + assert len(frame.columns) == 3 + assert np.all(~np.asarray(frame == frame)) + + # cast type + frame = DataFrame(mat, columns=["A", "B", "C"], index=[1, 2], dtype=object) + assert frame.values.dtype == object + + # Check non-masked values + mat2 = ma.copy(mat) + mat2[0, 0] = True + mat2[1, 2] = False + frame = DataFrame(mat2, columns=["A", "B", "C"], index=[1, 2]) + assert frame["A"][1] is True + assert frame["C"][2] is False + + def test_constructor_maskedarray_hardened(self): + # Check numpy masked arrays with hard masks -- from GH24574 + mat_hard = ma.masked_all((2, 2), dtype=float).harden_mask() + result = DataFrame(mat_hard, columns=["A", "B"], index=[1, 2]) + expected = DataFrame( + {"A": [np.nan, np.nan], "B": [np.nan, np.nan]}, + columns=["A", "B"], + index=[1, 2], + dtype=float, + ) + tm.assert_frame_equal(result, expected) + # Check case where mask is hard but no data are masked + mat_hard = ma.ones((2, 2), dtype=float).harden_mask() + result = DataFrame(mat_hard, columns=["A", "B"], index=[1, 2]) + expected = DataFrame( + {"A": [1.0, 1.0], "B": [1.0, 1.0]}, + columns=["A", "B"], + index=[1, 2], + dtype=float, + ) + tm.assert_frame_equal(result, expected) + + def test_constructor_maskedrecarray_dtype(self): + # Ensure constructor honors dtype + data = np.ma.array( + np.ma.zeros(5, dtype=[("date", " None: + self._lst = lst + + def __getitem__(self, n): + return self._lst.__getitem__(n) + + def __len__(self) -> int: + return self._lst.__len__() + + lst_containers = [DummyContainer([1, "a"]), DummyContainer([2, "b"])] + columns = ["num", "str"] + result = DataFrame(lst_containers, columns=columns) + expected = DataFrame([[1, "a"], [2, "b"]], columns=columns) + tm.assert_frame_equal(result, expected, check_dtype=False) + + def test_constructor_stdlib_array(self): + # GH 4297 + # support Array + result = DataFrame({"A": array.array("i", range(10))}) + expected = DataFrame({"A": list(range(10))}) + tm.assert_frame_equal(result, expected, check_dtype=False) + + expected = DataFrame([list(range(10)), list(range(10))]) + result = DataFrame([array.array("i", range(10)), array.array("i", range(10))]) + tm.assert_frame_equal(result, expected, check_dtype=False) + + def test_constructor_range(self): + # GH26342 + result = DataFrame(range(10)) + expected = DataFrame(list(range(10))) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_ranges(self): + result = DataFrame([range(10), range(10)]) + expected = DataFrame([list(range(10)), list(range(10))]) + tm.assert_frame_equal(result, expected) + + def test_constructor_iterable(self): + # GH 21987 + class Iter: + def __iter__(self) -> Iterator: + for i in range(10): + yield [1, 2, 3] + + expected = DataFrame([[1, 2, 3]] * 10) + result = DataFrame(Iter()) + tm.assert_frame_equal(result, expected) + + def test_constructor_iterator(self): + result = DataFrame(iter(range(10))) + expected = DataFrame(list(range(10))) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_iterators(self): + result = DataFrame([iter(range(10)), iter(range(10))]) + expected = DataFrame([list(range(10)), list(range(10))]) + tm.assert_frame_equal(result, expected) + + def test_constructor_generator(self): + # related #2305 + + gen1 = (i for i in range(10)) + gen2 = (i for i in range(10)) + + expected = DataFrame([list(range(10)), list(range(10))]) + result = DataFrame([gen1, gen2]) + tm.assert_frame_equal(result, expected) + + gen = ([i, "a"] for i in range(10)) + result = DataFrame(gen) + expected = DataFrame({0: range(10), 1: "a"}) + tm.assert_frame_equal(result, expected, check_dtype=False) + + def test_constructor_list_of_dicts(self): + result = DataFrame([{}]) + expected = DataFrame(index=RangeIndex(1), columns=[]) + tm.assert_frame_equal(result, expected) + + def test_constructor_ordered_dict_nested_preserve_order(self): + # see gh-18166 + nested1 = OrderedDict([("b", 1), ("a", 2)]) + nested2 = OrderedDict([("b", 2), ("a", 5)]) + data = OrderedDict([("col2", nested1), ("col1", nested2)]) + result = DataFrame(data) + data = {"col2": [1, 2], "col1": [2, 5]} + expected = DataFrame(data=data, index=["b", "a"]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dict_type", [dict, OrderedDict]) + def test_constructor_ordered_dict_preserve_order(self, dict_type): + # see gh-13304 + expected = DataFrame([[2, 1]], columns=["b", "a"]) + + data = dict_type() + data["b"] = [2] + data["a"] = [1] + + result = DataFrame(data) + tm.assert_frame_equal(result, expected) + + data = dict_type() + data["b"] = 2 + data["a"] = 1 + + result = DataFrame([data]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dict_type", [dict, OrderedDict]) + def test_constructor_ordered_dict_conflicting_orders(self, dict_type): + # the first dict element sets the ordering for the DataFrame, + # even if there are conflicting orders from subsequent ones + row_one = dict_type() + row_one["b"] = 2 + row_one["a"] = 1 + + row_two = dict_type() + row_two["a"] = 1 + row_two["b"] = 2 + + row_three = {"b": 2, "a": 1} + + expected = DataFrame([[2, 1], [2, 1]], columns=["b", "a"]) + result = DataFrame([row_one, row_two]) + tm.assert_frame_equal(result, expected) + + expected = DataFrame([[2, 1], [2, 1], [2, 1]], columns=["b", "a"]) + result = DataFrame([row_one, row_two, row_three]) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_series_aligned_index(self): + series = [Series(i, index=["b", "a", "c"], name=str(i)) for i in range(3)] + result = DataFrame(series) + expected = DataFrame( + {"b": [0, 1, 2], "a": [0, 1, 2], "c": [0, 1, 2]}, + columns=["b", "a", "c"], + index=["0", "1", "2"], + ) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_derived_dicts(self): + class CustomDict(dict): + pass + + d = {"a": 1.5, "b": 3} + + data_custom = [CustomDict(d)] + data = [d] + + result_custom = DataFrame(data_custom) + result = DataFrame(data) + tm.assert_frame_equal(result, result_custom) + + def test_constructor_ragged(self): + data = { + "A": np.random.default_rng(2).standard_normal(10), + "B": np.random.default_rng(2).standard_normal(8), + } + with pytest.raises(ValueError, match="All arrays must be of the same length"): + DataFrame(data) + + def test_constructor_scalar(self): + idx = Index(range(3)) + df = DataFrame({"a": 0}, index=idx) + expected = DataFrame({"a": [0, 0, 0]}, index=idx) + tm.assert_frame_equal(df, expected, check_dtype=False) + + def test_constructor_Series_copy_bug(self, float_frame): + df = DataFrame(float_frame["A"], index=float_frame.index, columns=["A"]) + df.copy() + + def test_constructor_mixed_dict_and_Series(self): + data = {} + data["A"] = {"foo": 1, "bar": 2, "baz": 3} + data["B"] = Series([4, 3, 2, 1], index=["bar", "qux", "baz", "foo"]) + + result = DataFrame(data) + assert result.index.is_monotonic_increasing + + # ordering ambiguous, raise exception + with pytest.raises(ValueError, match="ambiguous ordering"): + DataFrame({"A": ["a", "b"], "B": {"a": "a", "b": "b"}}) + + # this is OK though + result = DataFrame({"A": ["a", "b"], "B": Series(["a", "b"], index=["a", "b"])}) + expected = DataFrame({"A": ["a", "b"], "B": ["a", "b"]}, index=["a", "b"]) + tm.assert_frame_equal(result, expected) + + def test_constructor_mixed_type_rows(self): + # Issue 25075 + data = [[1, 2], (3, 4)] + result = DataFrame(data) + expected = DataFrame([[1, 2], [3, 4]]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "tuples,lists", + [ + ((), []), + (((),), [[]]), + (((), ()), [(), ()]), + (((), ()), [[], []]), + (([], []), [[], []]), + (([1], [2]), [[1], [2]]), # GH 32776 + (([1, 2, 3], [4, 5, 6]), [[1, 2, 3], [4, 5, 6]]), + ], + ) + def test_constructor_tuple(self, tuples, lists): + # GH 25691 + result = DataFrame(tuples) + expected = DataFrame(lists) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_tuples(self): + result = DataFrame({"A": [(1, 2), (3, 4)]}) + expected = DataFrame({"A": Series([(1, 2), (3, 4)])}) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_namedtuples(self): + # GH11181 + named_tuple = namedtuple("Pandas", list("ab")) + tuples = [named_tuple(1, 3), named_tuple(2, 4)] + expected = DataFrame({"a": [1, 2], "b": [3, 4]}) + result = DataFrame(tuples) + tm.assert_frame_equal(result, expected) + + # with columns + expected = DataFrame({"y": [1, 2], "z": [3, 4]}) + result = DataFrame(tuples, columns=["y", "z"]) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_dataclasses(self): + # GH21910 + Point = make_dataclass("Point", [("x", int), ("y", int)]) + + data = [Point(0, 3), Point(1, 3)] + expected = DataFrame({"x": [0, 1], "y": [3, 3]}) + result = DataFrame(data) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_dataclasses_with_varying_types(self): + # GH21910 + # varying types + Point = make_dataclass("Point", [("x", int), ("y", int)]) + HLine = make_dataclass("HLine", [("x0", int), ("x1", int), ("y", int)]) + + data = [Point(0, 3), HLine(1, 3, 3)] + + expected = DataFrame( + {"x": [0, np.nan], "y": [3, 3], "x0": [np.nan, 1], "x1": [np.nan, 3]} + ) + result = DataFrame(data) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_of_dataclasses_error_thrown(self): + # GH21910 + Point = make_dataclass("Point", [("x", int), ("y", int)]) + + # expect TypeError + msg = "asdict() should be called on dataclass instances" + with pytest.raises(TypeError, match=re.escape(msg)): + DataFrame([Point(0, 0), {"x": 1, "y": 0}]) + + def test_constructor_list_of_dict_order(self): + # GH10056 + data = [ + {"First": 1, "Second": 4, "Third": 7, "Fourth": 10}, + {"Second": 5, "First": 2, "Fourth": 11, "Third": 8}, + {"Second": 6, "First": 3, "Fourth": 12, "Third": 9, "YYY": 14, "XXX": 13}, + ] + expected = DataFrame( + { + "First": [1, 2, 3], + "Second": [4, 5, 6], + "Third": [7, 8, 9], + "Fourth": [10, 11, 12], + "YYY": [None, None, 14], + "XXX": [None, None, 13], + } + ) + result = DataFrame(data) + tm.assert_frame_equal(result, expected) + + def test_constructor_Series_named(self): + a = Series([1, 2, 3], index=["a", "b", "c"], name="x") + df = DataFrame(a) + assert df.columns[0] == "x" + tm.assert_index_equal(df.index, a.index) + + # ndarray like + arr = np.random.default_rng(2).standard_normal(10) + s = Series(arr, name="x") + df = DataFrame(s) + expected = DataFrame({"x": s}) + tm.assert_frame_equal(df, expected) + + s = Series(arr, index=range(3, 13)) + df = DataFrame(s) + expected = DataFrame({0: s}) + tm.assert_frame_equal(df, expected, check_column_type=False) + + msg = r"Shape of passed values is \(10, 1\), indices imply \(10, 2\)" + with pytest.raises(ValueError, match=msg): + DataFrame(s, columns=[1, 2]) + + # #2234 + a = Series([], name="x", dtype=object) + df = DataFrame(a) + assert df.columns[0] == "x" + + # series with name and w/o + s1 = Series(arr, name="x") + df = DataFrame([s1, arr]).T + expected = DataFrame({"x": s1, "Unnamed 0": arr}, columns=["x", "Unnamed 0"]) + tm.assert_frame_equal(df, expected) + + # this is a bit non-intuitive here; the series collapse down to arrays + df = DataFrame([arr, s1]).T + expected = DataFrame({1: s1, 0: arr}, columns=range(2)) + tm.assert_frame_equal(df, expected) + + def test_constructor_Series_named_and_columns(self): + # GH 9232 validation + + s0 = Series(range(5), name=0) + s1 = Series(range(5), name=1) + + # matching name and column gives standard frame + tm.assert_frame_equal(DataFrame(s0, columns=[0]), s0.to_frame()) + tm.assert_frame_equal(DataFrame(s1, columns=[1]), s1.to_frame()) + + # non-matching produces empty frame + assert DataFrame(s0, columns=[1]).empty + assert DataFrame(s1, columns=[0]).empty + + def test_constructor_Series_differently_indexed(self): + # name + s1 = Series([1, 2, 3], index=["a", "b", "c"], name="x") + + # no name + s2 = Series([1, 2, 3], index=["a", "b", "c"]) + + other_index = Index(["a", "b"]) + + df1 = DataFrame(s1, index=other_index) + exp1 = DataFrame(s1.reindex(other_index)) + assert df1.columns[0] == "x" + tm.assert_frame_equal(df1, exp1) + + df2 = DataFrame(s2, index=other_index) + exp2 = DataFrame(s2.reindex(other_index)) + assert df2.columns[0] == 0 + tm.assert_index_equal(df2.index, other_index) + tm.assert_frame_equal(df2, exp2) + + @pytest.mark.parametrize( + "name_in1,name_in2,name_in3,name_out", + [ + ("idx", "idx", "idx", "idx"), + ("idx", "idx", None, None), + ("idx", None, None, None), + ("idx1", "idx2", None, None), + ("idx1", "idx1", "idx2", None), + ("idx1", "idx2", "idx3", None), + (None, None, None, None), + ], + ) + def test_constructor_index_names(self, name_in1, name_in2, name_in3, name_out): + # GH13475 + indices = [ + Index(["a", "b", "c"], name=name_in1), + Index(["b", "c", "d"], name=name_in2), + Index(["c", "d", "e"], name=name_in3), + ] + series = { + c: Series([0, 1, 2], index=i) for i, c in zip(indices, ["x", "y", "z"]) + } + result = DataFrame(series) + + exp_ind = Index(["a", "b", "c", "d", "e"], name=name_out) + expected = DataFrame( + { + "x": [0, 1, 2, np.nan, np.nan], + "y": [np.nan, 0, 1, 2, np.nan], + "z": [np.nan, np.nan, 0, 1, 2], + }, + index=exp_ind, + ) + + tm.assert_frame_equal(result, expected) + + def test_constructor_manager_resize(self, float_frame): + index = list(float_frame.index[:5]) + columns = list(float_frame.columns[:3]) + + msg = "Passing a BlockManager to DataFrame" + with tm.assert_produces_warning( + DeprecationWarning, match=msg, check_stacklevel=False + ): + result = DataFrame(float_frame._mgr, index=index, columns=columns) + tm.assert_index_equal(result.index, Index(index)) + tm.assert_index_equal(result.columns, Index(columns)) + + def test_constructor_mix_series_nonseries(self, float_frame): + df = DataFrame( + {"A": float_frame["A"], "B": list(float_frame["B"])}, columns=["A", "B"] + ) + tm.assert_frame_equal(df, float_frame.loc[:, ["A", "B"]]) + + msg = "does not match index length" + with pytest.raises(ValueError, match=msg): + DataFrame({"A": float_frame["A"], "B": list(float_frame["B"])[:-2]}) + + def test_constructor_miscast_na_int_dtype(self): + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame([[np.nan, 1], [1, 0]], dtype=np.int64) + + def test_constructor_column_duplicates(self): + # it works! #2079 + df = DataFrame([[8, 5]], columns=["a", "a"]) + edf = DataFrame([[8, 5]]) + edf.columns = ["a", "a"] + + tm.assert_frame_equal(df, edf) + + idf = DataFrame.from_records([(8, 5)], columns=["a", "a"]) + + tm.assert_frame_equal(idf, edf) + + def test_constructor_empty_with_string_dtype(self, using_infer_string): + # GH 9428 + expected = DataFrame(index=[0, 1], columns=[0, 1], dtype=object) + expected_str = DataFrame( + index=[0, 1], columns=[0, 1], dtype=pd.StringDtype(na_value=np.nan) + ) + + df = DataFrame(index=[0, 1], columns=[0, 1], dtype=str) + if using_infer_string: + tm.assert_frame_equal(df, expected_str) + else: + tm.assert_frame_equal(df, expected) + df = DataFrame(index=[0, 1], columns=[0, 1], dtype=np.str_) + tm.assert_frame_equal(df, expected) + df = DataFrame(index=[0, 1], columns=[0, 1], dtype="U5") + tm.assert_frame_equal(df, expected) + + def test_constructor_empty_with_string_extension(self, nullable_string_dtype): + # GH 34915 + expected = DataFrame(columns=["c1"], dtype=nullable_string_dtype) + df = DataFrame(columns=["c1"], dtype=nullable_string_dtype) + tm.assert_frame_equal(df, expected) + + def test_constructor_single_value(self): + # expecting single value upcasting here + df = DataFrame(0.0, index=[1, 2, 3], columns=["a", "b", "c"]) + tm.assert_frame_equal( + df, DataFrame(np.zeros(df.shape).astype("float64"), df.index, df.columns) + ) + + df = DataFrame(0, index=[1, 2, 3], columns=["a", "b", "c"]) + tm.assert_frame_equal( + df, DataFrame(np.zeros(df.shape).astype("int64"), df.index, df.columns) + ) + + df = DataFrame("a", index=[1, 2], columns=["a", "c"]) + tm.assert_frame_equal( + df, + DataFrame( + np.array([["a", "a"], ["a", "a"]], dtype=object), + index=[1, 2], + columns=["a", "c"], + ), + ) + + msg = "DataFrame constructor not properly called!" + with pytest.raises(ValueError, match=msg): + DataFrame("a", [1, 2]) + with pytest.raises(ValueError, match=msg): + DataFrame("a", columns=["a", "c"]) + + msg = "incompatible data and dtype" + with pytest.raises(TypeError, match=msg): + DataFrame("a", [1, 2], ["a", "c"], float) + + def test_constructor_with_datetimes(self, using_infer_string): + intname = np.dtype(int).name + floatname = np.dtype(np.float64).name + objectname = np.dtype(np.object_).name + + # single item + df = DataFrame( + { + "A": 1, + "B": "foo", + "C": "bar", + "D": Timestamp("20010101").as_unit("s"), + "E": datetime(2001, 1, 2, 0, 0), + }, + index=np.arange(10), + ) + result = df.dtypes + expected = Series( + [np.dtype("int64")] + + [ + np.dtype(objectname) + if not using_infer_string + else pd.StringDtype(na_value=np.nan) + ] + * 2 + + [np.dtype("M8[s]"), np.dtype("M8[us]")], + index=list("ABCDE"), + ) + tm.assert_series_equal(result, expected) + + # check with ndarray construction ndim==0 (e.g. we are passing an ndim 0 + # ndarray with a dtype specified) + df = DataFrame( + { + "a": 1.0, + "b": 2, + "c": "foo", + floatname: np.array(1.0, dtype=floatname), + intname: np.array(1, dtype=intname), + }, + index=np.arange(10), + ) + result = df.dtypes + expected = Series( + [ + np.dtype("float64"), + np.dtype("int64"), + np.dtype("object") + if not using_infer_string + else pd.StringDtype(na_value=np.nan), + np.dtype("float64"), + np.dtype(intname), + ], + index=["a", "b", "c", floatname, intname], + ) + tm.assert_series_equal(result, expected) + + # check with ndarray construction ndim>0 + df = DataFrame( + { + "a": 1.0, + "b": 2, + "c": "foo", + floatname: np.array([1.0] * 10, dtype=floatname), + intname: np.array([1] * 10, dtype=intname), + }, + index=np.arange(10), + ) + result = df.dtypes + expected = Series( + [ + np.dtype("float64"), + np.dtype("int64"), + np.dtype("object") + if not using_infer_string + else pd.StringDtype(na_value=np.nan), + np.dtype("float64"), + np.dtype(intname), + ], + index=["a", "b", "c", floatname, intname], + ) + tm.assert_series_equal(result, expected) + + def test_constructor_with_datetimes1(self): + # GH 2809 + ind = date_range(start="2000-01-01", freq="D", periods=10) + datetimes = [ts.to_pydatetime() for ts in ind] + datetime_s = Series(datetimes) + assert datetime_s.dtype == "M8[us]" + + def test_constructor_with_datetimes2(self): + # GH 2810 + ind = date_range(start="2000-01-01", freq="D", periods=10) + datetimes = [ts.to_pydatetime() for ts in ind] + dates = [ts.date() for ts in ind] + df = DataFrame(datetimes, columns=["datetimes"]) + df["dates"] = dates + result = df.dtypes + expected = Series( + [np.dtype("datetime64[us]"), np.dtype("object")], + index=["datetimes", "dates"], + ) + tm.assert_series_equal(result, expected) + + def test_constructor_with_datetimes3(self): + # GH 7594 + # don't coerce tz-aware + dt = datetime(2012, 1, 1, tzinfo=zoneinfo.ZoneInfo("US/Eastern")) + + df = DataFrame({"End Date": dt}, index=[0]) + assert df.iat[0, 0] == dt + tm.assert_series_equal( + df.dtypes, Series({"End Date": "datetime64[us, US/Eastern]"}, dtype=object) + ) + + df = DataFrame([{"End Date": dt}]) + assert df.iat[0, 0] == dt + tm.assert_series_equal( + df.dtypes, Series({"End Date": "datetime64[us, US/Eastern]"}, dtype=object) + ) + + def test_constructor_with_datetimes4(self): + # tz-aware (UTC and other tz's) + # GH 8411 + dr = date_range("20130101", periods=3) + df = DataFrame({"value": dr}) + assert df.iat[0, 0].tz is None + dr = date_range("20130101", periods=3, tz="UTC") + df = DataFrame({"value": dr}) + assert str(df.iat[0, 0].tz) == "UTC" + dr = date_range("20130101", periods=3, tz="US/Eastern") + df = DataFrame({"value": dr}) + assert str(df.iat[0, 0].tz) == "US/Eastern" + + def test_constructor_with_datetimes5(self): + # GH 7822 + # preserver an index with a tz on dict construction + i = date_range("1/1/2011", periods=5, freq="10s", tz="US/Eastern") + + expected = DataFrame({"a": i.to_series().reset_index(drop=True)}) + df = DataFrame() + df["a"] = i + tm.assert_frame_equal(df, expected) + + df = DataFrame({"a": i}) + tm.assert_frame_equal(df, expected) + + def test_constructor_with_datetimes6(self): + # multiples + i = date_range("1/1/2011", periods=5, freq="10s", tz="US/Eastern") + i_no_tz = date_range("1/1/2011", periods=5, freq="10s") + df = DataFrame({"a": i, "b": i_no_tz}) + expected = DataFrame({"a": i.to_series().reset_index(drop=True), "b": i_no_tz}) + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "arr", + [ + np.array([None, None, None, None, datetime.now(), None]), + np.array([None, None, datetime.now(), None]), + [[np.datetime64("NaT")], [None]], + [[np.datetime64("NaT")], [pd.NaT]], + [[None], [np.datetime64("NaT")]], + [[None], [pd.NaT]], + [[pd.NaT], [np.datetime64("NaT")]], + [[pd.NaT], [None]], + ], + ) + def test_constructor_datetimes_with_nulls(self, arr): + # gh-15869, GH#11220 + result = DataFrame(arr).dtypes + unit = "ns" + if isinstance(arr, np.ndarray): + # inferred from a pydatetime object + unit = "us" + elif not any(isinstance(x, np.datetime64) for y in arr for x in y): + # TODO: this condition is not clear about why we have different behavior + unit = "s" + expected = Series([np.dtype(f"datetime64[{unit}]")]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("order", ["K", "A", "C", "F"]) + @pytest.mark.parametrize( + "unit", + ["M", "D", "h", "m", "s", "ms", "us", "ns"], + ) + def test_constructor_datetimes_non_ns(self, order, unit): + dtype = f"datetime64[{unit}]" + na = np.array( + [ + ["2015-01-01", "2015-01-02", "2015-01-03"], + ["2017-01-01", "2017-01-02", "2017-02-03"], + ], + dtype=dtype, + order=order, + ) + df = DataFrame(na) + expected = DataFrame(na.astype("M8[ns]")) + if unit in ["M", "D", "h", "m"]: + with pytest.raises(TypeError, match="Cannot cast"): + expected.astype(dtype) + + # instead the constructor casts to the closest supported reso, i.e. "s" + expected = expected.astype("datetime64[s]") + else: + expected = expected.astype(dtype=dtype) + + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize("order", ["K", "A", "C", "F"]) + @pytest.mark.parametrize( + "unit", + [ + "D", + "h", + "m", + "s", + "ms", + "us", + "ns", + ], + ) + def test_constructor_timedelta_non_ns(self, order, unit): + dtype = f"timedelta64[{unit}]" + na = np.array( + [ + [np.timedelta64(1, "D"), np.timedelta64(2, "D")], + [np.timedelta64(4, "D"), np.timedelta64(5, "D")], + ], + dtype=dtype, + order=order, + ) + df = DataFrame(na) + if unit in ["D", "h", "m"]: + # we get the nearest supported unit, i.e. "s" + exp_unit = "s" + else: + exp_unit = unit + exp_dtype = np.dtype(f"m8[{exp_unit}]") + expected = DataFrame( + [ + [Timedelta(1, "D"), Timedelta(2, "D")], + [Timedelta(4, "D"), Timedelta(5, "D")], + ], + dtype=exp_dtype, + ) + # TODO(2.0): ideally we should get the same 'expected' without passing + # dtype=exp_dtype. + tm.assert_frame_equal(df, expected) + + def test_constructor_for_list_with_dtypes(self, using_infer_string): + # test list of lists/ndarrays + df = DataFrame([np.arange(5) for x in range(5)]) + result = df.dtypes + expected = Series([np.dtype("int")] * 5) + tm.assert_series_equal(result, expected) + + df = DataFrame([np.array(np.arange(5), dtype="int32") for x in range(5)]) + result = df.dtypes + expected = Series([np.dtype("int32")] * 5) + tm.assert_series_equal(result, expected) + + # overflow issue? (we always expected int64 upcasting here) + df = DataFrame({"a": [2**31, 2**31 + 1]}) + assert df.dtypes.iloc[0] == np.dtype("int64") + + # GH #2751 (construction with no index specified), make sure we cast to + # platform values + df = DataFrame([1, 2]) + assert df.dtypes.iloc[0] == np.dtype("int64") + + df = DataFrame([1.0, 2.0]) + assert df.dtypes.iloc[0] == np.dtype("float64") + + df = DataFrame({"a": [1, 2]}) + assert df.dtypes.iloc[0] == np.dtype("int64") + + df = DataFrame({"a": [1.0, 2.0]}) + assert df.dtypes.iloc[0] == np.dtype("float64") + + df = DataFrame({"a": 1}, index=range(3)) + assert df.dtypes.iloc[0] == np.dtype("int64") + + df = DataFrame({"a": 1.0}, index=range(3)) + assert df.dtypes.iloc[0] == np.dtype("float64") + + # with object list + df = DataFrame( + { + "a": [1, 2, 4, 7], + "b": [1.2, 2.3, 5.1, 6.3], + "c": list("abcd"), + "d": [datetime(2000, 1, 1) for i in range(4)], + "e": [1.0, 2, 4.0, 7], + } + ) + result = df.dtypes + expected = Series( + [ + np.dtype("int64"), + np.dtype("float64"), + np.dtype("object") + if not using_infer_string + else pd.StringDtype(na_value=np.nan), + np.dtype("datetime64[us]"), + np.dtype("float64"), + ], + index=list("abcde"), + ) + tm.assert_series_equal(result, expected) + + def test_constructor_frame_copy(self, float_frame): + cop = DataFrame(float_frame, copy=True) + cop["A"] = 5 + assert (cop["A"] == 5).all() + assert not (float_frame["A"] == 5).all() + + def test_constructor_frame_shallow_copy(self, float_frame): + # constructing a DataFrame from DataFrame with copy=False should still + # give a "shallow" copy (share data, not attributes) + # https://github.com/pandas-dev/pandas/issues/49523 + orig = float_frame.copy() + cop = DataFrame(float_frame) + assert cop._mgr is not float_frame._mgr + # Overwriting index of copy doesn't change original + cop.index = np.arange(len(cop)) + tm.assert_frame_equal(float_frame, orig) + + def test_constructor_ndarray_copy(self, float_frame): + arr = float_frame.values.copy() + df = DataFrame(arr) + + arr[5] = 5 + assert not (df.values[5] == 5).all() + df = DataFrame(arr, copy=True) + arr[6] = 6 + assert not (df.values[6] == 6).all() + + def test_constructor_series_copy(self, float_frame): + series = float_frame._series + + df = DataFrame({"A": series["A"]}, copy=True) + # TODO can be replaced with `df.loc[:, "A"] = 5` after deprecation about + # inplace mutation is enforced + df.loc[df.index[0] : df.index[-1], "A"] = 5 + + assert not (series["A"] == 5).all() + + @pytest.mark.parametrize( + "df", + [ + DataFrame([[1, 2, 3], [4, 5, 6]], index=[1, np.nan]), + DataFrame([[1, 2, 3], [4, 5, 6]], columns=[1.1, 2.2, np.nan]), + DataFrame([[0, 1, 2, 3], [4, 5, 6, 7]], columns=[np.nan, 1.1, 2.2, np.nan]), + DataFrame( + [[0.0, 1, 2, 3.0], [4, 5, 6, 7]], columns=[np.nan, 1.1, 2.2, np.nan] + ), + DataFrame([[0.0, 1, 2, 3.0], [4, 5, 6, 7]], columns=[np.nan, 1, 2, 2]), + ], + ) + def test_constructor_with_nas(self, df): + # GH 5016 + # na's in indices + # GH 21428 (non-unique columns) + + for i in range(len(df.columns)): + df.iloc[:, i] + + indexer = np.arange(len(df.columns))[isna(df.columns)] + + # No NaN found -> error + if len(indexer) == 0: + with pytest.raises(KeyError, match="^nan$"): + df.loc[:, np.nan] + # single nan should result in Series + elif len(indexer) == 1: + tm.assert_series_equal(df.iloc[:, indexer[0]], df.loc[:, np.nan]) + # multiple nans should result in DataFrame + else: + tm.assert_frame_equal(df.iloc[:, indexer], df.loc[:, np.nan]) + + def test_constructor_lists_to_object_dtype(self): + # from #1074 + d = DataFrame({"a": [np.nan, False]}) + assert d["a"].dtype == np.object_ + assert not d["a"][1] + + def test_constructor_ndarray_categorical_dtype(self): + cat = Categorical(["A", "B", "C"]) + arr = np.array(cat).reshape(-1, 1) + arr = np.broadcast_to(arr, (3, 4)) + + result = DataFrame(arr, dtype=cat.dtype) + + expected = DataFrame({0: cat, 1: cat, 2: cat, 3: cat}) + tm.assert_frame_equal(result, expected) + + def test_constructor_categorical(self): + # GH8626 + + # dict creation + df = DataFrame({"A": list("abc")}, dtype="category") + expected = Series(list("abc"), dtype="category", name="A") + tm.assert_series_equal(df["A"], expected) + + # to_frame + s = Series(list("abc"), dtype="category") + result = s.to_frame() + expected = Series(list("abc"), dtype="category", name=0) + tm.assert_series_equal(result[0], expected) + result = s.to_frame(name="foo") + expected = Series(list("abc"), dtype="category", name="foo") + tm.assert_series_equal(result["foo"], expected) + + # list-like creation + df = DataFrame(list("abc"), dtype="category") + expected = Series(list("abc"), dtype="category", name=0) + tm.assert_series_equal(df[0], expected) + + def test_construct_from_1item_list_of_categorical(self): + # pre-2.0 this behaved as DataFrame({0: cat}), in 2.0 we remove + # Categorical special case + # ndim != 1 + cat = Categorical(list("abc")) + df = DataFrame([cat]) + expected = DataFrame([cat.astype(object)]) + tm.assert_frame_equal(df, expected) + + def test_construct_from_list_of_categoricals(self): + # pre-2.0 this behaved as DataFrame({0: cat}), in 2.0 we remove + # Categorical special case + + df = DataFrame([Categorical(list("abc")), Categorical(list("abd"))]) + expected = DataFrame([["a", "b", "c"], ["a", "b", "d"]]) + tm.assert_frame_equal(df, expected) + + def test_from_nested_listlike_mixed_types(self): + # pre-2.0 this behaved as DataFrame({0: cat}), in 2.0 we remove + # Categorical special case + # mixed + df = DataFrame([Categorical(list("abc")), list("def")]) + expected = DataFrame([["a", "b", "c"], ["d", "e", "f"]]) + tm.assert_frame_equal(df, expected) + + def test_construct_from_listlikes_mismatched_lengths(self): + df = DataFrame([Categorical(list("abc")), Categorical(list("abdefg"))]) + expected = DataFrame([list("abc"), list("abdefg")]) + tm.assert_frame_equal(df, expected) + + def test_constructor_categorical_series(self): + items = [1, 2, 3, 1] + exp = Series(items).astype("category") + res = Series(items, dtype="category") + tm.assert_series_equal(res, exp) + + items = ["a", "b", "c", "a"] + exp = Series(items).astype("category") + res = Series(items, dtype="category") + tm.assert_series_equal(res, exp) + + # insert into frame with different index + # GH 8076 + index = date_range("20000101", periods=3) + expected = Series( + Categorical(values=[np.nan, np.nan, np.nan], categories=["a", "b", "c"]) + ) + expected.index = index + + expected = DataFrame({"x": expected}) + df = DataFrame({"x": Series(["a", "b", "c"], dtype="category")}, index=index) + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize( + "dtype", + tm.ALL_NUMERIC_DTYPES + + tm.DATETIME64_DTYPES + + tm.TIMEDELTA64_DTYPES + + tm.BOOL_DTYPES, + ) + def test_check_dtype_empty_numeric_column(self, dtype): + # GH24386: Ensure dtypes are set correctly for an empty DataFrame. + # Empty DataFrame is generated via dictionary data with non-overlapping columns. + data = DataFrame({"a": [1, 2]}, columns=["b"], dtype=dtype) + + assert data.b.dtype == dtype + + @pytest.mark.parametrize( + "dtype", tm.STRING_DTYPES + tm.BYTES_DTYPES + tm.OBJECT_DTYPES + ) + def test_check_dtype_empty_string_column(self, dtype): + # GH24386: Ensure dtypes are set correctly for an empty DataFrame. + # Empty DataFrame is generated via dictionary data with non-overlapping columns. + data = DataFrame({"a": [1, 2]}, columns=["b"], dtype=dtype) + assert data.b.dtype.name == "object" + + def test_to_frame_with_falsey_names(self): + # GH 16114 + result = Series(name=0, dtype=object).to_frame().dtypes + expected = Series({0: object}) + tm.assert_series_equal(result, expected) + + result = DataFrame(Series(name=0, dtype=object)).dtypes + tm.assert_series_equal(result, expected) + + @pytest.mark.arm_slow + @pytest.mark.parametrize("dtype", [None, "uint8", "category"]) + def test_constructor_range_dtype(self, dtype): + expected = DataFrame({"A": [0, 1, 2, 3, 4]}, dtype=dtype or "int64") + + # GH 26342 + result = DataFrame(range(5), columns=["A"], dtype=dtype) + tm.assert_frame_equal(result, expected) + + # GH 16804 + result = DataFrame({"A": range(5)}, dtype=dtype) + tm.assert_frame_equal(result, expected) + + def test_frame_from_list_subclass(self): + # GH21226 + class List(list): + pass + + expected = DataFrame([[1, 2, 3], [4, 5, 6]]) + result = DataFrame(List([List([1, 2, 3]), List([4, 5, 6])])) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "extension_arr", + [ + Categorical(list("aabbc")), + SparseArray([1, np.nan, np.nan, np.nan]), + IntervalArray([Interval(0, 1), Interval(1, 5)]), + PeriodArray(pd.period_range(start="1/1/2017", end="1/1/2018", freq="M")), + ], + ) + def test_constructor_with_extension_array(self, extension_arr): + # GH11363 + expected = DataFrame(Series(extension_arr)) + result = DataFrame(extension_arr) + tm.assert_frame_equal(result, expected) + + def test_datetime_date_tuple_columns_from_dict(self): + # GH 10863 + v = date.today() + tup = v, v + result = DataFrame({tup: Series(range(3), index=range(3))}, columns=[tup]) + expected = DataFrame([0, 1, 2], columns=Index(Series([tup]))) + tm.assert_frame_equal(result, expected) + + def test_construct_with_two_categoricalindex_series(self): + # GH 14600 + s1 = Series([39, 6, 4], index=CategoricalIndex(["female", "male", "unknown"])) + s2 = Series( + [2, 152, 2, 242, 150], + index=CategoricalIndex(["f", "female", "m", "male", "unknown"]), + ) + result = DataFrame([s1, s2]) + expected = DataFrame( + np.array([[39, 6, 4, np.nan, np.nan], [152.0, 242.0, 150.0, 2.0, 2.0]]), + columns=["female", "male", "unknown", "f", "m"], + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:invalid value encountered in cast:RuntimeWarning" + ) + def test_constructor_series_nonexact_categoricalindex(self): + # GH 42424 + ser = Series(range(100)) + ser1 = cut(ser, 10).value_counts().head(5) + ser2 = cut(ser, 10).value_counts().tail(5) + result = DataFrame({"1": ser1, "2": ser2}) + index = CategoricalIndex( + [ + Interval(-0.099, 9.9, closed="right"), + Interval(9.9, 19.8, closed="right"), + Interval(19.8, 29.7, closed="right"), + Interval(29.7, 39.6, closed="right"), + Interval(39.6, 49.5, closed="right"), + Interval(49.5, 59.4, closed="right"), + Interval(59.4, 69.3, closed="right"), + Interval(69.3, 79.2, closed="right"), + Interval(79.2, 89.1, closed="right"), + Interval(89.1, 99, closed="right"), + ], + ordered=True, + ) + expected = DataFrame( + {"1": [10] * 5 + [np.nan] * 5, "2": [np.nan] * 5 + [10] * 5}, index=index + ) + tm.assert_frame_equal(expected, result) + + def test_from_M8_structured(self): + dates = [(datetime(2012, 9, 9, 0, 0), datetime(2012, 9, 8, 15, 10))] + arr = np.array(dates, dtype=[("Date", "M8[us]"), ("Forecasting", "M8[us]")]) + df = DataFrame(arr) + + assert df["Date"][0] == dates[0][0] + assert df["Forecasting"][0] == dates[0][1] + + s = Series(arr["Date"]) + assert isinstance(s[0], Timestamp) + assert s[0] == dates[0][0] + + def test_from_datetime_subclass(self): + # GH21142 Verify whether Datetime subclasses are also of dtype datetime + class DatetimeSubclass(datetime): + pass + + data = DataFrame({"datetime": [DatetimeSubclass(2020, 1, 1, 1, 1)]}) + assert data.datetime.dtype == "datetime64[us]" + + def test_with_mismatched_index_length_raises(self): + # GH#33437 + dti = date_range("2016-01-01", periods=3, tz="US/Pacific") + msg = "Shape of passed values|Passed arrays should have the same length" + with pytest.raises(ValueError, match=msg): + DataFrame(dti, index=range(4)) + + def test_frame_ctor_datetime64_column(self): + rng = date_range("1/1/2000 00:00:00", "1/1/2000 1:59:50", freq="10s") + dates = np.asarray(rng) + + df = DataFrame( + {"A": np.random.default_rng(2).standard_normal(len(rng)), "B": dates} + ) + assert np.issubdtype(df["B"].dtype, np.dtype("M8[ns]")) + + def test_dataframe_constructor_infer_multiindex(self): + index_lists = [["a", "a", "b", "b"], ["x", "y", "x", "y"]] + + multi = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=[np.array(x) for x in index_lists], + ) + assert isinstance(multi.index, MultiIndex) + assert not isinstance(multi.columns, MultiIndex) + + multi = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), columns=index_lists + ) + assert isinstance(multi.columns, MultiIndex) + + @pytest.mark.parametrize( + "input_vals", + [ + ([1, 2]), + (["1", "2"]), + (list(date_range("1/1/2011", periods=2, freq="h"))), + (list(date_range("1/1/2011", periods=2, freq="h", tz="US/Eastern"))), + ([Interval(left=0, right=5)]), + ], + ) + def test_constructor_list_str(self, input_vals, string_dtype): + # GH#16605 + # Ensure that data elements are converted to strings when + # dtype is str, 'str', or 'U' + + result = DataFrame({"A": input_vals}, dtype=string_dtype) + expected = DataFrame({"A": input_vals}).astype({"A": string_dtype}) + tm.assert_frame_equal(result, expected) + + def test_constructor_list_str_na(self, string_dtype): + result = DataFrame({"A": [1.0, 2.0, None]}, dtype=string_dtype) + expected = DataFrame({"A": ["1.0", "2.0", None]}, dtype=object) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("copy", [False, True]) + def test_dict_nocopy( + self, + copy, + any_numeric_ea_dtype, + any_numpy_dtype, + ): + a = np.array([1, 2], dtype=any_numpy_dtype) + b = np.array([3, 4], dtype=any_numpy_dtype) + if b.dtype.kind in ["S", "U"]: + # These get cast, making the checks below more cumbersome + pytest.skip(f"{b.dtype} get cast, making the checks below more cumbersome") + + c = pd.array([1, 2], dtype=any_numeric_ea_dtype) + c_orig = c.copy() + df = DataFrame({"a": a, "b": b, "c": c}, copy=copy) + + def get_base(obj): + if isinstance(obj, np.ndarray): + return obj.base + elif isinstance(obj.dtype, np.dtype): + # i.e. DatetimeArray, TimedeltaArray + return obj._ndarray.base + else: + raise TypeError + + def check_views(c_only: bool = False): + # Check that the underlying data behind df["c"] is still `c` + # after setting with iloc. Since we don't know which entry in + # df._mgr.blocks corresponds to df["c"], we just check that exactly + # one of these arrays is `c`. GH#38939 + assert sum(x.values is c for x in df._mgr.blocks) == 1 + if c_only: + # If we ever stop consolidating in setitem_with_indexer, + # this will become unnecessary. + return + + assert ( + sum( + get_base(x.values) is a + for x in df._mgr.blocks + if isinstance(x.values.dtype, np.dtype) + ) + == 1 + ) + assert ( + sum( + get_base(x.values) is b + for x in df._mgr.blocks + if isinstance(x.values.dtype, np.dtype) + ) + == 1 + ) + + if not copy: + # constructor preserves views + check_views() + + # TODO: most of the rest of this test belongs in indexing tests + should_raise = not lib.is_np_dtype(df.dtypes.iloc[0], "fciuO") + if should_raise: + with pytest.raises(TypeError, match="Invalid value"): + df.iloc[0, 0] = 0 + df.iloc[0, 1] = 0 + return + else: + df.iloc[0, 0] = 0 + df.iloc[0, 1] = 0 + if not copy: + check_views(True) + + # FIXME(GH#35417): until GH#35417, iloc.setitem into EA values does not preserve + # view, so we have to check in the other direction + df.iloc[:, 2] = pd.array([45, 46], dtype=c.dtype) + assert df.dtypes.iloc[2] == c.dtype + if copy: + if a.dtype.kind == "M": + assert a[0] == a.dtype.type(1, "ns") + assert b[0] == b.dtype.type(3, "ns") + else: + assert a[0] == a.dtype.type(1) + assert b[0] == b.dtype.type(3) + # FIXME(GH#35417): enable after GH#35417 + assert c[0] == c_orig[0] # i.e. df.iloc[0, 2]=45 did *not* update c + + def test_construct_from_dict_ea_series(self): + # GH#53744 - default of copy=True should also apply for Series with + # extension dtype + ser = Series([1, 2, 3], dtype="Int64") + df = DataFrame({"a": ser}) + assert not np.shares_memory(ser.values._data, df["a"].values._data) + + def test_from_series_with_name_with_columns(self): + # GH 7893 + result = DataFrame(Series(1, name="foo"), columns=["bar"]) + expected = DataFrame(columns=["bar"]) + tm.assert_frame_equal(result, expected) + + def test_nested_list_columns(self): + # GH 14467 + result = DataFrame( + [[1, 2, 3], [4, 5, 6]], columns=[["A", "A", "A"], ["a", "b", "c"]] + ) + expected = DataFrame( + [[1, 2, 3], [4, 5, 6]], + columns=MultiIndex.from_tuples([("A", "a"), ("A", "b"), ("A", "c")]), + ) + tm.assert_frame_equal(result, expected) + + def test_from_2d_object_array_of_periods_or_intervals(self): + # Period analogue to GH#26825 + pi = pd.period_range("2016-04-05", periods=3) + data = pi._data.astype(object).reshape(1, -1) + df = DataFrame(data) + assert df.shape == (1, 3) + assert (df.dtypes == pi.dtype).all() + assert (df == pi).all().all() + + ii = pd.IntervalIndex.from_breaks([3, 4, 5, 6]) + data2 = ii._data.astype(object).reshape(1, -1) + df2 = DataFrame(data2) + assert df2.shape == (1, 3) + assert (df2.dtypes == ii.dtype).all() + assert (df2 == ii).all().all() + + # mixed + data3 = np.r_[data, data2, data, data2].T + df3 = DataFrame(data3) + expected = DataFrame({0: pi, 1: ii, 2: pi, 3: ii}) + tm.assert_frame_equal(df3, expected) + + @pytest.mark.parametrize( + "col_a, col_b", + [ + ([[1], [2]], np.array([[1], [2]])), + (np.array([[1], [2]]), [[1], [2]]), + (np.array([[1], [2]]), np.array([[1], [2]])), + ], + ) + def test_error_from_2darray(self, col_a, col_b): + msg = "Per-column arrays must each be 1-dimensional" + with pytest.raises(ValueError, match=msg): + DataFrame({"a": col_a, "b": col_b}) + + def test_from_dict_with_missing_copy_false(self): + # GH#45369 filled columns should not be views of one another + df = DataFrame(index=[1, 2, 3], columns=["a", "b", "c"], copy=False) + assert not np.shares_memory(df["a"]._values, df["b"]._values) + + df.iloc[0, 0] = 0 + expected = DataFrame( + { + "a": [0, np.nan, np.nan], + "b": [np.nan, np.nan, np.nan], + "c": [np.nan, np.nan, np.nan], + }, + index=[1, 2, 3], + dtype=object, + ) + tm.assert_frame_equal(df, expected) + + def test_construction_empty_array_multi_column_raises(self): + # GH#46822 + msg = r"Shape of passed values is \(0, 1\), indices imply \(0, 2\)" + with pytest.raises(ValueError, match=msg): + DataFrame(data=np.array([]), columns=["a", "b"]) + + def test_construct_with_strings_and_none(self): + # GH#32218 + df = DataFrame(["1", "2", None], columns=["a"], dtype="str") + expected = DataFrame({"a": ["1", "2", None]}, dtype="str") + tm.assert_frame_equal(df, expected) + + def test_frame_string_inference(self): + # GH#54430 + dtype = pd.StringDtype(na_value=np.nan) + expected = DataFrame( + {"a": ["a", "b"]}, dtype=dtype, columns=Index(["a"], dtype=dtype) + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": ["a", "b"]}) + tm.assert_frame_equal(df, expected) + + expected = DataFrame( + {"a": ["a", "b"]}, + dtype=dtype, + columns=Index(["a"], dtype=dtype), + index=Index(["x", "y"], dtype=dtype), + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": ["a", "b"]}, index=["x", "y"]) + tm.assert_frame_equal(df, expected) + + expected = DataFrame( + {"a": ["a", 1]}, dtype="object", columns=Index(["a"], dtype=dtype) + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": ["a", 1]}) + tm.assert_frame_equal(df, expected) + + expected = DataFrame( + {"a": ["a", "b"]}, dtype="object", columns=Index(["a"], dtype=dtype) + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": ["a", "b"]}, dtype="object") + tm.assert_frame_equal(df, expected) + + def test_frame_string_inference_array_string_dtype(self): + # GH#54496 + dtype = pd.StringDtype(na_value=np.nan) + expected = DataFrame( + {"a": ["a", "b"]}, dtype=dtype, columns=Index(["a"], dtype=dtype) + ) + with pd.option_context("future.infer_string", True): + df = DataFrame({"a": np.array(["a", "b"])}) + tm.assert_frame_equal(df, expected) + + expected = DataFrame({0: ["a", "b"], 1: ["c", "d"]}, dtype=dtype) + with pd.option_context("future.infer_string", True): + df = DataFrame(np.array([["a", "c"], ["b", "d"]])) + tm.assert_frame_equal(df, expected) + + expected = DataFrame( + {"a": ["a", "b"], "b": ["c", "d"]}, + dtype=dtype, + columns=Index(["a", "b"], dtype=dtype), + ) + with pd.option_context("future.infer_string", True): + df = DataFrame(np.array([["a", "c"], ["b", "d"]]), columns=["a", "b"]) + tm.assert_frame_equal(df, expected) + + def test_frame_string_inference_block_dim(self): + # GH#55363 + with pd.option_context("future.infer_string", True): + df = DataFrame(np.array([["hello", "goodbye"], ["hello", "Hello"]])) + assert df._mgr.blocks[0].ndim == 2 + + @pytest.mark.parametrize("klass", [Series, Index]) + def test_inference_on_pandas_objects(self, klass): + # GH#56012 + obj = klass([Timestamp("2019-12-31")], dtype=object) + result = DataFrame(obj, columns=["a"]) + assert result.dtypes.iloc[0] == np.object_ + + result = DataFrame({"a": obj}) + assert result.dtypes.iloc[0] == np.object_ + + def test_dict_keys_returns_rangeindex(self): + result = DataFrame({0: [1], 1: [2]}).columns + expected = RangeIndex(2) + tm.assert_index_equal(result, expected, exact=True) + + @pytest.mark.parametrize( + "cons", [Series, Index, DatetimeIndex, DataFrame, pd.array, pd.to_datetime] + ) + def test_construction_datetime_resolution_inference(self, cons): + ts = Timestamp(2999, 1, 1) + ts2 = ts.tz_localize("US/Pacific") + + obj = cons([ts]) + res_dtype = tm.get_dtype(obj) + assert res_dtype == "M8[us]", res_dtype + + obj2 = cons([ts2]) + res_dtype2 = tm.get_dtype(obj2) + assert res_dtype2 == "M8[us, US/Pacific]", res_dtype2 + + def test_construction_nan_value_timedelta64_dtype(self): + # GH#60064 + result = DataFrame([None, 1], dtype="timedelta64[ns]") + expected = DataFrame( + ["NaT", "0 days 00:00:00.000000001"], dtype="timedelta64[ns]" + ) + tm.assert_frame_equal(result, expected) + + def test_dataframe_from_array_like_with_name_attribute(self): + # GH#61443 + class DummyArray(np.ndarray): + def __new__(cls, input_array): + obj = np.asarray(input_array).view(cls) + obj.name = "foo" + return obj + + dummy = DummyArray(np.eye(3)) + df = DataFrame(dummy) + expected = DataFrame(np.eye(3)) + tm.assert_frame_equal(df, expected) + + +class TestDataFrameConstructorIndexInference: + def test_frame_from_dict_of_series_overlapping_monthly_period_indexes(self): + rng1 = pd.period_range("1/1/1999", "1/1/2012", freq="M") + s1 = Series(np.random.default_rng(2).standard_normal(len(rng1)), rng1) + + rng2 = pd.period_range("1/1/1980", "12/1/2001", freq="M") + s2 = Series(np.random.default_rng(2).standard_normal(len(rng2)), rng2) + df = DataFrame({"s1": s1, "s2": s2}) + + exp = pd.period_range("1/1/1980", "1/1/2012", freq="M") + tm.assert_index_equal(df.index, exp) + + def test_frame_from_dict_with_mixed_tzaware_indexes(self): + # GH#44091 + dti = date_range("2016-01-01", periods=3) + + ser1 = Series(range(3), index=dti) + ser2 = Series(range(3), index=dti.tz_localize("UTC")) + ser3 = Series(range(3), index=dti.tz_localize("US/Central")) + ser4 = Series(range(3)) + + # no tz-naive, but we do have mixed tzs and a non-DTI + df1 = DataFrame({"A": ser2, "B": ser3, "C": ser4}) + exp_index = Index( + list(ser2.index) + list(ser3.index) + list(ser4.index), dtype=object + ) + tm.assert_index_equal(df1.index, exp_index) + + df2 = DataFrame({"A": ser2, "C": ser4, "B": ser3}) + exp_index3 = Index( + list(ser2.index) + list(ser4.index) + list(ser3.index), dtype=object + ) + tm.assert_index_equal(df2.index, exp_index3) + + df3 = DataFrame({"B": ser3, "A": ser2, "C": ser4}) + exp_index3 = Index( + list(ser3.index) + list(ser2.index) + list(ser4.index), dtype=object + ) + tm.assert_index_equal(df3.index, exp_index3) + + df4 = DataFrame({"C": ser4, "B": ser3, "A": ser2}) + exp_index4 = Index( + list(ser4.index) + list(ser3.index) + list(ser2.index), dtype=object + ) + tm.assert_index_equal(df4.index, exp_index4) + + # TODO: not clear if these raising is desired (no extant tests), + # but this is de facto behavior 2021-12-22 + msg = "Cannot join tz-naive with tz-aware DatetimeIndex" + with pytest.raises(TypeError, match=msg): + DataFrame({"A": ser2, "B": ser3, "C": ser4, "D": ser1}) + with pytest.raises(TypeError, match=msg): + DataFrame({"A": ser2, "B": ser3, "D": ser1}) + with pytest.raises(TypeError, match=msg): + DataFrame({"D": ser1, "A": ser2, "B": ser3}) + + @pytest.mark.parametrize( + "key_val, col_vals, col_type", + [ + ["3", ["3", "4"], "utf8"], + [3, [3, 4], "int8"], + ], + ) + def test_dict_data_arrow_column_expansion(self, key_val, col_vals, col_type): + # GH 53617 + pa = pytest.importorskip("pyarrow") + cols = pd.arrays.ArrowExtensionArray( + pa.array(col_vals, type=pa.dictionary(pa.int8(), getattr(pa, col_type)())) + ) + result = DataFrame({key_val: [1, 2]}, columns=cols) + expected = DataFrame([[1, np.nan], [2, np.nan]], columns=cols) + expected.isetitem(1, expected.iloc[:, 1].astype(object)) + tm.assert_frame_equal(result, expected) + + +class TestDataFrameConstructorWithDtypeCoercion: + def test_floating_values_integer_dtype(self): + # GH#40110 make DataFrame behavior with arraylike floating data and + # inty dtype match Series behavior + + arr = np.random.default_rng(2).standard_normal((10, 5)) + + # GH#49599 in 2.0 we raise instead of either + # a) silently ignoring dtype and returningfloat (the old Series behavior) or + # b) rounding (the old DataFrame behavior) + msg = "Trying to coerce float values to integers" + with pytest.raises(ValueError, match=msg): + DataFrame(arr, dtype="i8") + + df = DataFrame(arr.round(), dtype="i8") + assert (df.dtypes == "i8").all() + + # with NaNs, we go through a different path with a different warning + arr[0, 0] = np.nan + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame(arr, dtype="i8") + with pytest.raises(IntCastingNaNError, match=msg): + Series(arr[0], dtype="i8") + # The future (raising) behavior matches what we would get via astype: + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + DataFrame(arr).astype("i8") + with pytest.raises(IntCastingNaNError, match=msg): + Series(arr[0]).astype("i8") + + +class TestDataFrameConstructorWithDatetimeTZ: + @pytest.mark.parametrize("tz", ["US/Eastern", "dateutil/US/Eastern"]) + def test_construction_preserves_tzaware_dtypes(self, tz): + # after GH#7822 + # these retain the timezones on dict construction + dr = date_range("2011/1/1", "2012/1/1", freq="W-FRI", unit="ns") + dr_tz = dr.tz_localize(tz) + df = DataFrame({"A": "foo", "B": dr_tz}, index=dr) + tz_expected = DatetimeTZDtype("ns", dr_tz.tzinfo) + assert df["B"].dtype == tz_expected + + # GH#2810 (with timezones) + datetimes_naive = [ts.to_pydatetime() for ts in dr] + datetimes_with_tz = [ts.to_pydatetime() for ts in dr_tz] + df = DataFrame({"dr": dr}) + df["dr_tz"] = dr_tz + df["datetimes_naive"] = datetimes_naive + df["datetimes_with_tz"] = datetimes_with_tz + result = df.dtypes + expected = Series( + [ + np.dtype("datetime64[ns]"), + DatetimeTZDtype(tz=tz), + np.dtype("datetime64[us]"), + DatetimeTZDtype(tz=tz, unit="us"), + ], + index=["dr", "dr_tz", "datetimes_naive", "datetimes_with_tz"], + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("pydt", [True, False]) + def test_constructor_data_aware_dtype_naive(self, tz_aware_fixture, pydt): + # GH#25843, GH#41555, GH#33401 + tz = tz_aware_fixture + ts = Timestamp("2019", tz=tz) + if pydt: + ts = ts.to_pydatetime() + + msg = ( + "Cannot convert timezone-aware data to timezone-naive dtype. " + r"Use pd.Series\(values\).dt.tz_localize\(None\) instead." + ) + with pytest.raises(ValueError, match=msg): + DataFrame({0: [ts]}, dtype="datetime64[ns]") + + msg2 = "Cannot unbox tzaware Timestamp to tznaive dtype" + with pytest.raises(TypeError, match=msg2): + DataFrame({0: ts}, index=[0], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + DataFrame([ts], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + DataFrame(np.array([ts], dtype=object), dtype="datetime64[ns]") + + with pytest.raises(TypeError, match=msg2): + DataFrame(ts, index=[0], columns=[0], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + DataFrame([Series([ts])], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + DataFrame([[ts]], columns=[0], dtype="datetime64[ns]") + + def test_from_dict(self): + # 8260 + # support datetime64 with tz + + idx = Index(date_range("20130101", periods=3, tz="US/Eastern"), name="foo") + dr = date_range("20130110", periods=3) + + # construction + df = DataFrame({"A": idx, "B": dr}) + assert df["A"].dtype, "M8[ns, US/Eastern" + assert df["A"].name == "A" + tm.assert_series_equal(df["A"], Series(idx, name="A")) + tm.assert_series_equal(df["B"], Series(dr, name="B")) + + def test_from_index(self): + # from index + idx2 = date_range("20130101", periods=3, tz="US/Eastern", name="foo") + df2 = DataFrame(idx2) + tm.assert_series_equal(df2["foo"], Series(idx2, name="foo")) + df2 = DataFrame(Series(idx2)) + tm.assert_series_equal(df2["foo"], Series(idx2, name="foo")) + + idx2 = date_range("20130101", periods=3, tz="US/Eastern") + df2 = DataFrame(idx2) + tm.assert_series_equal(df2[0], Series(idx2, name=0)) + df2 = DataFrame(Series(idx2)) + tm.assert_series_equal(df2[0], Series(idx2, name=0)) + + def test_frame_dict_constructor_datetime64_1680(self): + dr = date_range("1/1/2012", periods=10) + s = Series(dr, index=dr) + + # it works! + DataFrame({"a": "foo", "b": s}, index=dr) + DataFrame({"a": "foo", "b": s.values}, index=dr) + + def test_frame_datetime64_mixed_index_ctor_1681(self): + dr = date_range("2011/1/1", "2012/1/1", freq="W-FRI") + ts = Series(dr) + + # it works! + d = DataFrame({"A": "foo", "B": ts}, index=dr) + assert d["B"].isna().all() + + def test_frame_timeseries_column(self): + # GH19157 + dr = date_range( + start="20130101T10:00:00", periods=3, freq="min", tz="US/Eastern", unit="ns" + ) + result = DataFrame(dr, columns=["timestamps"]) + expected = DataFrame( + { + "timestamps": [ + Timestamp("20130101T10:00:00", tz="US/Eastern"), + Timestamp("20130101T10:01:00", tz="US/Eastern"), + Timestamp("20130101T10:02:00", tz="US/Eastern"), + ] + }, + dtype="M8[ns, US/Eastern]", + ) + tm.assert_frame_equal(result, expected) + + def test_nested_dict_construction(self): + # GH22227 + columns = ["Nevada", "Ohio"] + pop = { + "Nevada": {2001: 2.4, 2002: 2.9}, + "Ohio": {2000: 1.5, 2001: 1.7, 2002: 3.6}, + } + result = DataFrame(pop, index=[2001, 2002, 2003], columns=columns) + expected = DataFrame( + [(2.4, 1.7), (2.9, 3.6), (np.nan, np.nan)], + columns=columns, + index=Index([2001, 2002, 2003]), + ) + tm.assert_frame_equal(result, expected) + + def test_from_tzaware_object_array(self): + # GH#26825 2D object array of tzaware timestamps should not raise + dti = date_range("2016-04-05 04:30", periods=3, tz="UTC") + data = dti._data.astype(object).reshape(1, -1) + df = DataFrame(data) + assert df.shape == (1, 3) + assert (df.dtypes == dti.dtype).all() + assert (df == dti).all().all() + + def test_from_tzaware_mixed_object_array(self): + # GH#26825 + arr = np.array( + [ + [ + Timestamp("2013-01-01 00:00:00"), + Timestamp("2013-01-02 00:00:00"), + Timestamp("2013-01-03 00:00:00"), + ], + [ + Timestamp("2013-01-01 00:00:00-0500", tz="US/Eastern"), + pd.NaT, + Timestamp("2013-01-03 00:00:00-0500", tz="US/Eastern"), + ], + [ + Timestamp("2013-01-01 00:00:00+0100", tz="CET"), + pd.NaT, + Timestamp("2013-01-03 00:00:00+0100", tz="CET"), + ], + ], + dtype=object, + ).T + res = DataFrame(arr, columns=["A", "B", "C"]) + + expected_dtypes = [ + "datetime64[us]", + "datetime64[us, US/Eastern]", + "datetime64[us, CET]", + ] + assert (res.dtypes == expected_dtypes).all() + + def test_from_2d_ndarray_with_dtype(self): + # GH#12513 + array_dim2 = np.arange(10).reshape((5, 2)) + df = DataFrame(array_dim2, dtype="datetime64[ns, UTC]") + + expected = DataFrame(array_dim2).astype("datetime64[ns, UTC]") + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize("typ", [set, frozenset]) + def test_construction_from_set_raises(self, typ): + # https://github.com/pandas-dev/pandas/issues/32582 + values = typ({1, 2, 3}) + msg = f"'{typ.__name__}' type is unordered" + with pytest.raises(TypeError, match=msg): + DataFrame({"a": values}) + + with pytest.raises(TypeError, match=msg): + Series(values) + + def test_construction_from_ndarray_datetimelike(self): + # ensure the underlying arrays are properly wrapped as EA when + # constructed from 2D ndarray + arr = np.arange(0, 12, dtype="datetime64[ns]").reshape(4, 3) + df = DataFrame(arr) + assert all(isinstance(block.values, DatetimeArray) for block in df._mgr.blocks) + + def test_construction_from_ndarray_with_eadtype_mismatched_columns(self): + arr = np.random.default_rng(2).standard_normal((10, 2)) + dtype = pd.array([2.0]).dtype + msg = r"len\(arrays\) must match len\(columns\)" + with pytest.raises(ValueError, match=msg): + DataFrame(arr, columns=["foo"], dtype=dtype) + + arr2 = pd.array([2.0, 3.0, 4.0]) + with pytest.raises(ValueError, match=msg): + DataFrame(arr2, columns=["foo", "bar"]) + + def test_columns_indexes_raise_on_sets(self): + # GH 47215 + data = [[1, 2, 3], [4, 5, 6]] + with pytest.raises(ValueError, match="index cannot be a set"): + DataFrame(data, index={"a", "b"}) + with pytest.raises(ValueError, match="columns cannot be a set"): + DataFrame(data, columns={"a", "b", "c"}) + + def test_from_dict_with_columns_na_scalar(self): + result = DataFrame({"a": pd.NaT}, columns=["a"], index=range(2)) + expected = DataFrame({"a": Series([pd.NaT, pd.NaT])}) + tm.assert_frame_equal(result, expected) + + # TODO: make this not cast to object in pandas 3.0 + @pytest.mark.skipif( + not np_version_gt2, reason="StringDType only available in numpy 2 and above" + ) + @pytest.mark.parametrize( + "data", + [ + {"a": ["a", "b", "c"], "b": [1.0, 2.0, 3.0], "c": ["d", "e", "f"]}, + ], + ) + def test_np_string_array_object_cast(self, data): + from numpy.dtypes import StringDType + + data["a"] = np.array(data["a"], dtype=StringDType()) + res = DataFrame(data) + assert res["a"].dtype == np.object_ + assert (res["a"] == data["a"]).all() + + +def get1(obj): # TODO: make a helper in tm? + if isinstance(obj, Series): + return obj.iloc[0] + else: + return obj.iloc[0, 0] + + +class TestFromScalar: + @pytest.fixture(params=[list, dict, None]) + def box(self, request): + return request.param + + @pytest.fixture + def constructor(self, frame_or_series, box): + extra = {"index": range(2)} + if frame_or_series is DataFrame: + extra["columns"] = ["A"] + + if box is None: + return functools.partial(frame_or_series, **extra) + + elif box is dict: + if frame_or_series is Series: + return lambda x, **kwargs: frame_or_series( + {0: x, 1: x}, **extra, **kwargs + ) + else: + return lambda x, **kwargs: frame_or_series({"A": x}, **extra, **kwargs) + elif frame_or_series is Series: + return lambda x, **kwargs: frame_or_series([x, x], **extra, **kwargs) + else: + return lambda x, **kwargs: frame_or_series({"A": [x, x]}, **extra, **kwargs) + + @pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) + def test_from_nat_scalar(self, dtype, constructor): + obj = constructor(pd.NaT, dtype=dtype) + assert np.all(obj.dtypes == dtype) + assert np.all(obj.isna()) + + def test_from_timedelta_scalar_preserves_nanos(self, constructor): + td = Timedelta(1) + + obj = constructor(td, dtype="m8[ns]") + assert get1(obj) == td + + def test_from_timestamp_scalar_preserves_nanos(self, constructor, fixed_now_ts): + ts = fixed_now_ts + Timedelta(1) + + obj = constructor(ts, dtype="M8[ns]") + assert get1(obj) == ts + + def test_from_timedelta64_scalar_object(self, constructor): + td = Timedelta(1) + td64 = td.to_timedelta64() + + obj = constructor(td64, dtype=object) + assert isinstance(get1(obj), np.timedelta64) + + @pytest.mark.parametrize("cls", [np.datetime64, np.timedelta64]) + def test_from_scalar_datetimelike_mismatched(self, constructor, cls): + scalar = cls("NaT", "ns") + dtype = {np.datetime64: "m8[ns]", np.timedelta64: "M8[ns]"}[cls] + + if cls is np.datetime64: + msg1 = "Invalid type for timedelta scalar: " + else: + msg1 = " is not convertible to datetime" + msg = "|".join(["Cannot cast", msg1]) + + with pytest.raises(TypeError, match=msg): + constructor(scalar, dtype=dtype) + + scalar = cls(4, "ns") + with pytest.raises(TypeError, match=msg): + constructor(scalar, dtype=dtype) + + @pytest.mark.parametrize("cls", [datetime, np.datetime64]) + def test_from_out_of_bounds_ns_datetime( + self, constructor, cls, request, box, frame_or_series + ): + # scalar that won't fit in nanosecond dt64, but will fit in microsecond + scalar = datetime(9999, 1, 1) + exp_dtype = "M8[us]" # pydatetime objects default to this reso + + if cls is np.datetime64: + scalar = np.datetime64(scalar, "D") + exp_dtype = "M8[s]" # closest reso to input + result = constructor(scalar) + + item = get1(result) + dtype = tm.get_dtype(result) + + assert type(item) is Timestamp + assert item.asm8.dtype == exp_dtype + assert dtype == exp_dtype + + def test_out_of_s_bounds_datetime64(self, constructor): + scalar = np.datetime64(np.iinfo(np.int64).max, "D") + result = constructor(scalar) + item = get1(result) + assert type(item) is np.datetime64 + dtype = tm.get_dtype(result) + assert dtype == object + + @pytest.mark.parametrize("cls", [timedelta, np.timedelta64]) + def test_from_out_of_bounds_ns_timedelta( + self, constructor, cls, box, frame_or_series + ): + scalar = datetime(9999, 1, 1) - datetime(1970, 1, 1) + exp_dtype = "m8[us]" # smallest reso that fits + if cls is np.timedelta64: + scalar = np.timedelta64(scalar, "D") + exp_dtype = "m8[s]" # closest reso to input + result = constructor(scalar) + + item = get1(result) + dtype = tm.get_dtype(result) + + assert type(item) is Timedelta + assert item.asm8.dtype == exp_dtype + assert dtype == exp_dtype + + @pytest.mark.parametrize("cls", [np.datetime64, np.timedelta64]) + def test_out_of_s_bounds_timedelta64(self, constructor, cls): + scalar = cls(np.iinfo(np.int64).max, "D") + result = constructor(scalar) + item = get1(result) + assert type(item) is cls + dtype = tm.get_dtype(result) + assert dtype == object + + def test_tzaware_data_tznaive_dtype(self, constructor, box, frame_or_series): + tz = "US/Eastern" + ts = Timestamp("2019", tz=tz) + + if box is None or (frame_or_series is DataFrame and box is dict): + msg = "Cannot unbox tzaware Timestamp to tznaive dtype" + err = TypeError + else: + msg = ( + "Cannot convert timezone-aware data to timezone-naive dtype. " + r"Use pd.Series\(values\).dt.tz_localize\(None\) instead." + ) + err = ValueError + + with pytest.raises(err, match=msg): + constructor(ts, dtype="M8[ns]") + + +# TODO: better location for this test? +class TestAllowNonNano: + # Until 2.0, we do not preserve non-nano dt64/td64 when passed as ndarray, + # but do preserve it when passed as DTA/TDA + + @pytest.fixture(params=[True, False]) + def as_td(self, request): + return request.param + + @pytest.fixture + def arr(self, as_td): + values = np.arange(5).astype(np.int64).view("M8[s]") + if as_td: + values = values - values[0] + return TimedeltaArray._simple_new(values, dtype=values.dtype) + else: + return DatetimeArray._simple_new(values, dtype=values.dtype) + + def test_index_allow_non_nano(self, arr): + idx = Index(arr) + assert idx.dtype == arr.dtype + + def test_dti_tdi_allow_non_nano(self, arr, as_td): + if as_td: + idx = pd.TimedeltaIndex(arr) + else: + idx = DatetimeIndex(arr) + assert idx.dtype == arr.dtype + + def test_series_allow_non_nano(self, arr): + ser = Series(arr) + assert ser.dtype == arr.dtype + + def test_frame_allow_non_nano(self, arr): + df = DataFrame(arr) + assert df.dtypes[0] == arr.dtype + + def test_frame_from_dict_allow_non_nano(self, arr): + df = DataFrame({0: arr}) + assert df.dtypes[0] == arr.dtype diff --git a/pandas/tests/frame/test_cumulative.py b/pandas/tests/frame/test_cumulative.py new file mode 100644 index 0000000000000000000000000000000000000000..ab217e1b1332a67d3d17c086232dcf18e91e2a6f --- /dev/null +++ b/pandas/tests/frame/test_cumulative.py @@ -0,0 +1,107 @@ +""" +Tests for DataFrame cumulative operations + +See also +-------- +tests.series.test_cumulative +""" + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + Timestamp, +) +import pandas._testing as tm + + +class TestDataFrameCumulativeOps: + # --------------------------------------------------------------------- + # Cumulative Operations - cumsum, cummax, ... + + def test_cumulative_ops_smoke(self): + # it works + df = DataFrame({"A": np.arange(20)}, index=np.arange(20)) + df.cummax() + df.cummin() + df.cumsum() + + dm = DataFrame(np.arange(20).reshape(4, 5), index=range(4), columns=range(5)) + # TODO(wesm): do something with this? + dm.cumsum() + + def test_cumprod_smoke(self, datetime_frame): + datetime_frame.iloc[5:10, 0] = np.nan + datetime_frame.iloc[10:15, 1] = np.nan + datetime_frame.iloc[15:, 2] = np.nan + + # ints + df = datetime_frame.fillna(0).astype(int) + df.cumprod(0) + df.cumprod(1) + + # ints32 + df = datetime_frame.fillna(0).astype(np.int32) + df.cumprod(0) + df.cumprod(1) + + def test_cumulative_ops_match_series_apply( + self, datetime_frame, all_numeric_accumulations + ): + datetime_frame.iloc[5:10, 0] = np.nan + datetime_frame.iloc[10:15, 1] = np.nan + datetime_frame.iloc[15:, 2] = np.nan + + # axis = 0 + result = getattr(datetime_frame, all_numeric_accumulations)() + expected = datetime_frame.apply(getattr(Series, all_numeric_accumulations)) + tm.assert_frame_equal(result, expected) + + # axis = 1 + result = getattr(datetime_frame, all_numeric_accumulations)(axis=1) + expected = datetime_frame.apply( + getattr(Series, all_numeric_accumulations), axis=1 + ) + tm.assert_frame_equal(result, expected) + + # fix issue TODO: GH ref? + assert np.shape(result) == np.shape(datetime_frame) + + def test_cumsum_preserve_dtypes(self): + # GH#19296 dont incorrectly upcast to object + df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3.0], "C": [True, False, False]}) + + result = df.cumsum() + + expected = DataFrame( + { + "A": Series([1, 3, 6], dtype=np.int64), + "B": Series([1, 3, 6], dtype=np.float64), + "C": df["C"].cumsum(), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("method", ["cumsum", "cumprod", "cummin", "cummax"]) + @pytest.mark.parametrize("axis", [0, 1]) + def test_numeric_only_flag(self, method, axis): + df = DataFrame( + { + "int": [1, 2, 3], + "bool": [True, False, False], + "string": ["a", "b", "c"], + "float": [1.0, 3.5, 4.0], + "datetime": [ + Timestamp(2018, 1, 1), + Timestamp(2019, 1, 1), + Timestamp(2020, 1, 1), + ], + } + ) + df_numeric_only = df.drop(["string", "datetime"], axis=1) + + result = getattr(df, method)(axis=axis, numeric_only=True) + expected = getattr(df_numeric_only, method)(axis) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/frame/test_iteration.py b/pandas/tests/frame/test_iteration.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c23ff05f3e19aca490444216ec295453483e80 --- /dev/null +++ b/pandas/tests/frame/test_iteration.py @@ -0,0 +1,160 @@ +import datetime + +import numpy as np +import pytest + +from pandas.compat import ( + IS64, + is_platform_windows, +) + +from pandas import ( + Categorical, + DataFrame, + Series, + date_range, +) +import pandas._testing as tm + + +class TestIteration: + def test_keys(self, float_frame): + assert float_frame.keys() is float_frame.columns + + def test_iteritems(self): + df = DataFrame([[1, 2, 3], [4, 5, 6]], columns=["a", "a", "b"]) + for k, v in df.items(): + assert isinstance(v, DataFrame._constructor_sliced) + + def test_items(self): + # GH#17213, GH#13918 + cols = ["a", "b", "c"] + df = DataFrame([[1, 2, 3], [4, 5, 6]], columns=cols) + for c, (k, v) in zip(cols, df.items()): + assert c == k + assert isinstance(v, Series) + assert (df[k] == v).all() + + def test_items_names(self, float_string_frame): + for k, v in float_string_frame.items(): + assert v.name == k + + def test_iter(self, float_frame): + assert list(float_frame) == list(float_frame.columns) + + def test_iterrows(self, float_frame, float_string_frame): + for k, v in float_frame.iterrows(): + exp = float_frame.loc[k] + tm.assert_series_equal(v, exp) + + for k, v in float_string_frame.iterrows(): + exp = float_string_frame.loc[k] + tm.assert_series_equal(v, exp) + + def test_iterrows_iso8601(self): + # GH#19671 + s = DataFrame( + { + "non_iso8601": ["M1701", "M1802", "M1903", "M2004"], + "iso8601": date_range("2000-01-01", periods=4, freq="ME"), + } + ) + for k, v in s.iterrows(): + exp = s.loc[k] + tm.assert_series_equal(v, exp) + + def test_iterrows_corner(self): + # GH#12222 + df = DataFrame( + { + "a": [datetime.datetime(2015, 1, 1)], + "b": [None], + "c": [None], + "d": [""], + "e": [[]], + "f": [set()], + "g": [{}], + } + ) + expected = Series( + [datetime.datetime(2015, 1, 1), None, None, "", [], set(), {}], + index=list("abcdefg"), + name=0, + dtype="object", + ) + _, result = next(df.iterrows()) + tm.assert_series_equal(result, expected) + + def test_itertuples(self, float_frame): + for i, tup in enumerate(float_frame.itertuples()): + ser = DataFrame._constructor_sliced(tup[1:]) + ser.name = tup[0] + expected = float_frame.iloc[i, :].reset_index(drop=True) + tm.assert_series_equal(ser, expected) + + def test_itertuples_index_false(self): + df = DataFrame( + {"floats": np.random.default_rng(2).standard_normal(5), "ints": range(5)}, + columns=["floats", "ints"], + ) + + for tup in df.itertuples(index=False): + assert isinstance(tup[1], int) + + def test_itertuples_duplicate_cols(self): + df = DataFrame(data={"a": [1, 2, 3], "b": [4, 5, 6]}) + dfaa = df[["a", "a"]] + + assert list(dfaa.itertuples()) == [(0, 1, 1), (1, 2, 2), (2, 3, 3)] + + # repr with int on 32-bit/windows + if not (is_platform_windows() or not IS64): + assert ( + repr(list(df.itertuples(name=None))) + == "[(0, 1, 4), (1, 2, 5), (2, 3, 6)]" + ) + + def test_itertuples_tuple_name(self): + df = DataFrame(data={"a": [1, 2, 3], "b": [4, 5, 6]}) + tup = next(df.itertuples(name="TestName")) + assert tup._fields == ("Index", "a", "b") + assert (tup.Index, tup.a, tup.b) == tup + assert type(tup).__name__ == "TestName" + + def test_itertuples_disallowed_col_labels(self): + df = DataFrame(data={"def": [1, 2, 3], "return": [4, 5, 6]}) + tup2 = next(df.itertuples(name="TestName")) + assert tup2 == (0, 1, 4) + assert tup2._fields == ("Index", "_1", "_2") + + @pytest.mark.parametrize("limit", [254, 255, 1024]) + @pytest.mark.parametrize("index", [True, False]) + def test_itertuples_py2_3_field_limit_namedtuple(self, limit, index): + # GH#28282 + df = DataFrame([{f"foo_{i}": f"bar_{i}" for i in range(limit)}]) + result = next(df.itertuples(index=index)) + assert isinstance(result, tuple) + assert hasattr(result, "_fields") + + def test_sequence_like_with_categorical(self): + # GH#7839 + # make sure can iterate + df = DataFrame( + {"id": [1, 2, 3, 4, 5, 6], "raw_grade": ["a", "b", "b", "a", "a", "e"]} + ) + df["grade"] = Categorical(df["raw_grade"]) + + # basic sequencing testing + result = list(df.grade.values) + expected = np.array(df.grade.values).tolist() + tm.assert_almost_equal(result, expected) + + # iteration + for t in df.itertuples(index=False): + str(t) + + for row, s in df.iterrows(): + str(s) + + for c, col in df.items(): + str(col) diff --git a/pandas/tests/frame/test_logical_ops.py b/pandas/tests/frame/test_logical_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..251a7407edcdc16877f74ab8024a1c4dc64a730f --- /dev/null +++ b/pandas/tests/frame/test_logical_ops.py @@ -0,0 +1,211 @@ +import operator +import re + +import numpy as np +import pytest + +from pandas import ( + CategoricalIndex, + DataFrame, + Interval, + Series, + isnull, +) +import pandas._testing as tm + + +class TestDataFrameLogicalOperators: + # &, |, ^ + + @pytest.mark.parametrize( + "left, right, op, expected", + [ + ( + [True, False, np.nan], + [True, False, True], + operator.and_, + [True, False, False], + ), + ( + [True, False, True], + [True, False, np.nan], + operator.and_, + [True, False, False], + ), + ( + [True, False, np.nan], + [True, False, True], + operator.or_, + [True, False, False], + ), + ( + [True, False, True], + [True, False, np.nan], + operator.or_, + [True, False, True], + ), + ], + ) + def test_logical_operators_nans(self, left, right, op, expected, frame_or_series): + # GH#13896 + result = op(frame_or_series(left), frame_or_series(right)) + expected = frame_or_series(expected) + + tm.assert_equal(result, expected) + + def test_logical_ops_empty_frame(self): + # GH#5808 + # empty frames, non-mixed dtype + df = DataFrame(index=[1]) + + result = df & df + tm.assert_frame_equal(result, df) + + result = df | df + tm.assert_frame_equal(result, df) + + df2 = DataFrame(index=[1, 2]) + result = df & df2 + tm.assert_frame_equal(result, df2) + + dfa = DataFrame(index=[1], columns=["A"]) + + result = dfa & dfa + expected = DataFrame(False, index=[1], columns=["A"]) + tm.assert_frame_equal(result, expected) + + def test_logical_ops_bool_frame(self): + # GH#5808 + df1a_bool = DataFrame(True, index=[1], columns=["A"]) + + result = df1a_bool & df1a_bool + tm.assert_frame_equal(result, df1a_bool) + + result = df1a_bool | df1a_bool + tm.assert_frame_equal(result, df1a_bool) + + def test_logical_ops_int_frame(self): + # GH#5808 + df1a_int = DataFrame(1, index=[1], columns=["A"]) + df1a_bool = DataFrame(True, index=[1], columns=["A"]) + + result = df1a_int | df1a_bool + tm.assert_frame_equal(result, df1a_bool) + + # Check that this matches Series behavior + res_ser = df1a_int["A"] | df1a_bool["A"] + tm.assert_series_equal(res_ser, df1a_bool["A"]) + + def test_logical_ops_invalid(self, using_infer_string): + # GH#5808 + + df1 = DataFrame(1.0, index=[1], columns=["A"]) + df2 = DataFrame(True, index=[1], columns=["A"]) + msg = re.escape("unsupported operand type(s) for |: 'float' and 'bool'") + with pytest.raises(TypeError, match=msg): + df1 | df2 + + df1 = DataFrame("foo", index=[1], columns=["A"]) + df2 = DataFrame(True, index=[1], columns=["A"]) + if using_infer_string and df1["A"].dtype.storage == "pyarrow": + msg = "operation 'or_' not supported for dtype 'str'" + else: + msg = re.escape("unsupported operand type(s) for |: 'str' and 'bool'") + with pytest.raises(TypeError, match=msg): + df1 | df2 + + def test_logical_operators(self): + def _check_bin_op(op): + result = op(df1, df2) + expected = DataFrame( + op(df1.values, df2.values), index=df1.index, columns=df1.columns + ) + assert result.values.dtype == np.bool_ + tm.assert_frame_equal(result, expected) + + def _check_unary_op(op): + result = op(df1) + expected = DataFrame(op(df1.values), index=df1.index, columns=df1.columns) + assert result.values.dtype == np.bool_ + tm.assert_frame_equal(result, expected) + + df1 = { + "a": {"a": True, "b": False, "c": False, "d": True, "e": True}, + "b": {"a": False, "b": True, "c": False, "d": False, "e": False}, + "c": {"a": False, "b": False, "c": True, "d": False, "e": False}, + "d": {"a": True, "b": False, "c": False, "d": True, "e": True}, + "e": {"a": True, "b": False, "c": False, "d": True, "e": True}, + } + + df2 = { + "a": {"a": True, "b": False, "c": True, "d": False, "e": False}, + "b": {"a": False, "b": True, "c": False, "d": False, "e": False}, + "c": {"a": True, "b": False, "c": True, "d": False, "e": False}, + "d": {"a": False, "b": False, "c": False, "d": True, "e": False}, + "e": {"a": False, "b": False, "c": False, "d": False, "e": True}, + } + + df1 = DataFrame(df1) + df2 = DataFrame(df2) + + _check_bin_op(operator.and_) + _check_bin_op(operator.or_) + _check_bin_op(operator.xor) + + _check_unary_op(operator.inv) # TODO: belongs elsewhere + + def test_logical_with_nas(self): + d = DataFrame({"a": [np.nan, False], "b": [True, True]}) + + # GH4947 + # bool comparisons should return bool + result = d["a"] | d["b"] + expected = Series([False, True]) + tm.assert_series_equal(result, expected) + + # GH4604, automatic casting here + result = d["a"].fillna(False) | d["b"] + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + result = d["a"].fillna(False) | d["b"] + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + def test_logical_ops_categorical_columns(self): + # GH#38367 + intervals = [Interval(1, 2), Interval(3, 4)] + data = DataFrame( + [[1, np.nan], [2, np.nan]], + columns=CategoricalIndex( + intervals, categories=[*intervals, Interval(5, 6)] + ), + ) + mask = DataFrame( + [[False, False], [False, False]], columns=data.columns, dtype=bool + ) + result = mask | isnull(data) + expected = DataFrame( + [[False, True], [False, True]], + columns=CategoricalIndex( + intervals, categories=[*intervals, Interval(5, 6)] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_int_dtype_different_index_not_bool(self): + # GH 52500 + df1 = DataFrame([1, 2, 3], index=[10, 11, 23], columns=["a"]) + df2 = DataFrame([10, 20, 30], index=[11, 10, 23], columns=["a"]) + result = np.bitwise_xor(df1, df2) + expected = DataFrame([21, 8, 29], index=[10, 11, 23], columns=["a"]) + tm.assert_frame_equal(result, expected) + + result = df1 ^ df2 + tm.assert_frame_equal(result, expected) + + def test_different_dtypes_different_index_raises(self): + # GH 52538 + df1 = DataFrame([1, 2], index=["a", "b"]) + df2 = DataFrame([3, 4], index=["b", "c"]) + with pytest.raises(TypeError, match="unsupported operand type"): + df1 & df2 diff --git a/pandas/tests/frame/test_nonunique_indexes.py b/pandas/tests/frame/test_nonunique_indexes.py new file mode 100644 index 0000000000000000000000000000000000000000..1e9aa2325e880d1f6ef651d24f31b87c35bba5f9 --- /dev/null +++ b/pandas/tests/frame/test_nonunique_indexes.py @@ -0,0 +1,336 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Series, + date_range, +) +import pandas._testing as tm + + +class TestDataFrameNonuniqueIndexes: + def test_setattr_columns_vs_construct_with_columns(self): + # assignment + # GH 3687 + arr = np.random.default_rng(2).standard_normal((3, 2)) + idx = list(range(2)) + df = DataFrame(arr, columns=["A", "A"]) + df.columns = idx + expected = DataFrame(arr, columns=idx) + tm.assert_frame_equal(df, expected) + + def test_setattr_columns_vs_construct_with_columns_datetimeindx(self): + idx = date_range("20130101", periods=4, freq="QE-NOV") + df = DataFrame( + [[1, 1, 1, 5], [1, 1, 2, 5], [2, 1, 3, 5]], columns=["a", "a", "a", "a"] + ) + df.columns = idx + expected = DataFrame([[1, 1, 1, 5], [1, 1, 2, 5], [2, 1, 3, 5]], columns=idx) + tm.assert_frame_equal(df, expected) + + def test_insert_with_duplicate_columns(self): + # insert + df = DataFrame( + [[1, 1, 1, 5], [1, 1, 2, 5], [2, 1, 3, 5]], + columns=["foo", "bar", "foo", "hello"], + ) + df["string"] = "bah" + expected = DataFrame( + [[1, 1, 1, 5, "bah"], [1, 1, 2, 5, "bah"], [2, 1, 3, 5, "bah"]], + columns=["foo", "bar", "foo", "hello", "string"], + ) + tm.assert_frame_equal(df, expected) + with pytest.raises(ValueError, match="Length of value"): + df.insert(0, "AnotherColumn", range(len(df.index) - 1)) + + # insert same dtype + df["foo2"] = 3 + expected = DataFrame( + [[1, 1, 1, 5, "bah", 3], [1, 1, 2, 5, "bah", 3], [2, 1, 3, 5, "bah", 3]], + columns=["foo", "bar", "foo", "hello", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # set (non-dup) + df["foo2"] = 4 + expected = DataFrame( + [[1, 1, 1, 5, "bah", 4], [1, 1, 2, 5, "bah", 4], [2, 1, 3, 5, "bah", 4]], + columns=["foo", "bar", "foo", "hello", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + df["foo2"] = 3 + + # delete (non dup) + del df["bar"] + expected = DataFrame( + [[1, 1, 5, "bah", 3], [1, 2, 5, "bah", 3], [2, 3, 5, "bah", 3]], + columns=["foo", "foo", "hello", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # try to delete again (its not consolidated) + del df["hello"] + expected = DataFrame( + [[1, 1, "bah", 3], [1, 2, "bah", 3], [2, 3, "bah", 3]], + columns=["foo", "foo", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # consolidate + df = df._consolidate() + expected = DataFrame( + [[1, 1, "bah", 3], [1, 2, "bah", 3], [2, 3, "bah", 3]], + columns=["foo", "foo", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # insert + df.insert(2, "new_col", 5.0) + expected = DataFrame( + [[1, 1, 5.0, "bah", 3], [1, 2, 5.0, "bah", 3], [2, 3, 5.0, "bah", 3]], + columns=["foo", "foo", "new_col", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # insert a dup + with pytest.raises(ValueError, match="cannot insert"): + df.insert(2, "new_col", 4.0) + + df.insert(2, "new_col", 4.0, allow_duplicates=True) + expected = DataFrame( + [ + [1, 1, 4.0, 5.0, "bah", 3], + [1, 2, 4.0, 5.0, "bah", 3], + [2, 3, 4.0, 5.0, "bah", 3], + ], + columns=["foo", "foo", "new_col", "new_col", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + # delete (dup) + del df["foo"] + expected = DataFrame( + [[4.0, 5.0, "bah", 3], [4.0, 5.0, "bah", 3], [4.0, 5.0, "bah", 3]], + columns=["new_col", "new_col", "string", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + def test_dup_across_dtypes(self): + # dup across dtypes + df = DataFrame( + [[1, 1, 1.0, 5], [1, 1, 2.0, 5], [2, 1, 3.0, 5]], + columns=["foo", "bar", "foo", "hello"], + ) + + df["foo2"] = 7.0 + expected = DataFrame( + [[1, 1, 1.0, 5, 7.0], [1, 1, 2.0, 5, 7.0], [2, 1, 3.0, 5, 7.0]], + columns=["foo", "bar", "foo", "hello", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + result = df["foo"] + expected = DataFrame([[1, 1.0], [1, 2.0], [2, 3.0]], columns=["foo", "foo"]) + tm.assert_frame_equal(result, expected) + + # multiple replacements + df["foo"] = "string" + expected = DataFrame( + [ + ["string", 1, "string", 5, 7.0], + ["string", 1, "string", 5, 7.0], + ["string", 1, "string", 5, 7.0], + ], + columns=["foo", "bar", "foo", "hello", "foo2"], + ) + tm.assert_frame_equal(df, expected) + + del df["foo"] + expected = DataFrame( + [[1, 5, 7.0], [1, 5, 7.0], [1, 5, 7.0]], columns=["bar", "hello", "foo2"] + ) + tm.assert_frame_equal(df, expected) + + def test_column_dups_indexes(self): + # check column dups with index equal and not equal to df's index + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), + index=["a", "b", "c", "d", "e"], + columns=["A", "B", "A"], + ) + for index in [df.index, pd.Index(list("edcba"))]: + this_df = df.copy() + expected_ser = Series(index.values, index=this_df.index) + expected_df = DataFrame( + {"A": expected_ser, "B": this_df["B"]}, + columns=["A", "B", "A"], + ) + this_df["A"] = index + tm.assert_frame_equal(this_df, expected_df) + + def test_changing_dtypes_with_duplicate_columns(self): + # multiple assignments that change dtypes + # the location indexer is a slice + # GH 6120 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=["that", "that"] + ) + expected = DataFrame(1.0, index=range(5), columns=["that", "that"]) + + df["that"] = 1.0 + tm.assert_frame_equal(df, expected) + + df = DataFrame( + np.random.default_rng(2).random((5, 2)), columns=["that", "that"] + ) + expected = DataFrame(1, index=range(5), columns=["that", "that"]) + + df["that"] = 1 + tm.assert_frame_equal(df, expected) + + def test_dup_columns_comparisons(self): + # equality + df1 = DataFrame([[1, 2], [2, np.nan], [3, 4], [4, 4]], columns=["A", "B"]) + df2 = DataFrame([[0, 1], [2, 4], [2, np.nan], [4, 5]], columns=["A", "A"]) + + # not-comparing like-labelled + msg = ( + r"Can only compare identically-labeled \(both index and columns\) " + "DataFrame objects" + ) + with pytest.raises(ValueError, match=msg): + df1 == df2 + + df1r = df1.reindex_like(df2) + result = df1r == df2 + expected = DataFrame( + [[False, True], [True, False], [False, False], [True, False]], + columns=["A", "A"], + ) + tm.assert_frame_equal(result, expected) + + def test_mixed_column_selection(self): + # mixed column selection + # GH 5639 + dfbool = DataFrame( + { + "one": Series([True, True, False], index=["a", "b", "c"]), + "two": Series([False, False, True, False], index=["a", "b", "c", "d"]), + "three": Series([False, True, True, True], index=["a", "b", "c", "d"]), + } + ) + expected = pd.concat([dfbool["one"], dfbool["three"], dfbool["one"]], axis=1) + result = dfbool[["one", "three", "one"]] + tm.assert_frame_equal(result, expected) + + def test_multi_axis_dups(self): + # multi-axis dups + # GH 6121 + df = DataFrame( + np.arange(25.0).reshape(5, 5), + index=["a", "b", "c", "d", "e"], + columns=["A", "B", "C", "D", "E"], + ) + z = df[["A", "C", "A"]].copy() + expected = z.loc[["a", "c", "a"]] + + df = DataFrame( + np.arange(25.0).reshape(5, 5), + index=["a", "b", "c", "d", "e"], + columns=["A", "B", "C", "D", "E"], + ) + z = df[["A", "C", "A"]] + result = z.loc[["a", "c", "a"]] + tm.assert_frame_equal(result, expected) + + def test_columns_with_dups(self): + # GH 3468 related + + # basic + df = DataFrame([[1, 2]], columns=["a", "a"]) + df.columns = ["a", "a.1"] + expected = DataFrame([[1, 2]], columns=["a", "a.1"]) + tm.assert_frame_equal(df, expected) + + df = DataFrame([[1, 2, 3]], columns=["b", "a", "a"]) + df.columns = ["b", "a", "a.1"] + expected = DataFrame([[1, 2, 3]], columns=["b", "a", "a.1"]) + tm.assert_frame_equal(df, expected) + + def test_columns_with_dup_index(self): + # with a dup index + df = DataFrame([[1, 2]], columns=["a", "a"]) + df.columns = ["b", "b"] + expected = DataFrame([[1, 2]], columns=["b", "b"]) + tm.assert_frame_equal(df, expected) + + def test_multi_dtype(self): + # multi-dtype + df = DataFrame( + [[1, 2, 1.0, 2.0, 3.0, "foo", "bar"]], + columns=["a", "a", "b", "b", "d", "c", "c"], + ) + df.columns = list("ABCDEFG") + expected = DataFrame( + [[1, 2, 1.0, 2.0, 3.0, "foo", "bar"]], columns=list("ABCDEFG") + ) + tm.assert_frame_equal(df, expected) + + def test_multi_dtype2(self): + df = DataFrame([[1, 2, "foo", "bar"]], columns=["a", "a", "a", "a"]) + df.columns = ["a", "a.1", "a.2", "a.3"] + expected = DataFrame([[1, 2, "foo", "bar"]], columns=["a", "a.1", "a.2", "a.3"]) + tm.assert_frame_equal(df, expected) + + def test_dups_across_blocks(self): + # dups across blocks + df_float = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), dtype="float64" + ) + df_int = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)).astype("int64") + ) + df_bool = DataFrame(True, index=df_float.index, columns=df_float.columns) + df_object = DataFrame("foo", index=df_float.index, columns=df_float.columns) + df_dt = DataFrame( + pd.Timestamp("20010101"), index=df_float.index, columns=df_float.columns + ) + df = pd.concat([df_float, df_int, df_bool, df_object, df_dt], axis=1) + + assert len(df._mgr.blknos) == len(df.columns) + assert len(df._mgr.blklocs) == len(df.columns) + + # testing iloc + for i in range(len(df.columns)): + df.iloc[:, i] + + def test_dup_columns_across_dtype(self): + # dup columns across dtype GH 2079/2194 + vals = [[1, -1, 2.0], [2, -2, 3.0]] + rs = DataFrame(vals, columns=["A", "A", "B"]) + xp = DataFrame(vals) + xp.columns = ["A", "A", "B"] + tm.assert_frame_equal(rs, xp) + + def test_set_value_by_index(self): + # See gh-12344 + warn = None + msg = "will attempt to set the values inplace" + + df = DataFrame(np.arange(9).reshape(3, 3).T) + df.columns = list("AAA") + expected = df.iloc[:, 2].copy() + + with tm.assert_produces_warning(warn, match=msg): + df.iloc[:, 0] = 3 + tm.assert_series_equal(df.iloc[:, 2], expected) + + df = DataFrame(np.arange(9).reshape(3, 3).T) + df.columns = [2, float(2), str(2)] + expected = df.iloc[:, 1].copy() + + with tm.assert_produces_warning(warn, match=msg): + df.iloc[:, 0] = 3 + tm.assert_series_equal(df.iloc[:, 1], expected) diff --git a/pandas/tests/frame/test_npfuncs.py b/pandas/tests/frame/test_npfuncs.py new file mode 100644 index 0000000000000000000000000000000000000000..e9a241202d15696b0e91b6ad96546fa967471b29 --- /dev/null +++ b/pandas/tests/frame/test_npfuncs.py @@ -0,0 +1,84 @@ +""" +Tests for np.foo applied to DataFrame, not necessarily ufuncs. +""" + +import numpy as np + +from pandas import ( + Categorical, + DataFrame, +) +import pandas._testing as tm + + +class TestAsArray: + def test_asarray_homogeneous(self): + df = DataFrame({"A": Categorical([1, 2]), "B": Categorical([1, 2])}) + result = np.asarray(df) + # may change from object in the future + expected = np.array([[1, 1], [2, 2]], dtype="object") + tm.assert_numpy_array_equal(result, expected) + + def test_np_sqrt(self, float_frame): + with np.errstate(all="ignore"): + result = np.sqrt(float_frame) + assert isinstance(result, type(float_frame)) + assert result.index.is_(float_frame.index) + assert result.columns.is_(float_frame.columns) + + tm.assert_frame_equal(result, float_frame.apply(np.sqrt)) + + def test_sum_axis_behavior(self): + # GH#52042 df.sum(axis=None) now reduces over both axes, which gets + # called when we do np.sum(df) + + arr = np.random.default_rng(2).standard_normal((4, 3)) + df = DataFrame(arr) + + res = np.sum(df) + expected = df.to_numpy().sum(axis=None) + assert res == expected + + def test_np_ravel(self): + # GH26247 + arr = np.array( + [ + [0.11197053, 0.44361564, -0.92589452], + [0.05883648, -0.00948922, -0.26469934], + ] + ) + + result = np.ravel([DataFrame(batch.reshape(1, 3)) for batch in arr]) + expected = np.array( + [ + 0.11197053, + 0.44361564, + -0.92589452, + 0.05883648, + -0.00948922, + -0.26469934, + ] + ) + tm.assert_numpy_array_equal(result, expected) + + result = np.ravel(DataFrame(arr[0].reshape(1, 3), columns=["x1", "x2", "x3"])) + expected = np.array([0.11197053, 0.44361564, -0.92589452]) + tm.assert_numpy_array_equal(result, expected) + + result = np.ravel( + [ + DataFrame(batch.reshape(1, 3), columns=["x1", "x2", "x3"]) + for batch in arr + ] + ) + expected = np.array( + [ + 0.11197053, + 0.44361564, + -0.92589452, + 0.05883648, + -0.00948922, + -0.26469934, + ] + ) + tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/frame/test_query_eval.py b/pandas/tests/frame/test_query_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..8ccc3af674c09ccf82a75fbf3bf3e41f89fd0dea --- /dev/null +++ b/pandas/tests/frame/test_query_eval.py @@ -0,0 +1,1609 @@ +import operator +from tokenize import TokenError + +import numpy as np +import pytest + +from pandas.errors import ( + NumExprClobberingError, + UndefinedVariableError, +) +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + date_range, +) +import pandas._testing as tm +from pandas.core.computation.check import NUMEXPR_INSTALLED + + +@pytest.fixture(params=["python", "pandas"], ids=lambda x: x) +def parser(request): + return request.param + + +@pytest.fixture( + params=["python", pytest.param("numexpr", marks=td.skip_if_no("numexpr"))], + ids=lambda x: x, +) +def engine(request): + return request.param + + +def skip_if_no_pandas_parser(parser): + if parser != "pandas": + pytest.skip(f"cannot evaluate with parser={parser}") + + +class TestCompat: + @pytest.fixture + def df(self): + return DataFrame({"A": [1, 2, 3]}) + + @pytest.fixture + def expected1(self, df): + return df[df.A > 0] + + @pytest.fixture + def expected2(self, df): + return df.A + 1 + + def test_query_default(self, df, expected1, expected2): + # GH 12749 + # this should always work, whether NUMEXPR_INSTALLED or not + result = df.query("A>0") + tm.assert_frame_equal(result, expected1) + result = df.eval("A+1") + tm.assert_series_equal(result, expected2) + + def test_query_None(self, df, expected1, expected2): + result = df.query("A>0", engine=None) + tm.assert_frame_equal(result, expected1) + result = df.eval("A+1", engine=None) + tm.assert_series_equal(result, expected2) + + def test_query_python(self, df, expected1, expected2): + result = df.query("A>0", engine="python") + tm.assert_frame_equal(result, expected1) + result = df.eval("A+1", engine="python") + tm.assert_series_equal(result, expected2) + + def test_query_numexpr(self, df, expected1, expected2): + if NUMEXPR_INSTALLED: + result = df.query("A>0", engine="numexpr") + tm.assert_frame_equal(result, expected1) + result = df.eval("A+1", engine="numexpr") + tm.assert_series_equal(result, expected2) + else: + msg = ( + r"'numexpr' is not installed or an unsupported version. " + r"Cannot use engine='numexpr' for query/eval if 'numexpr' is " + r"not installed" + ) + with pytest.raises(ImportError, match=msg): + df.query("A>0", engine="numexpr") + with pytest.raises(ImportError, match=msg): + df.eval("A+1", engine="numexpr") + + +class TestDataFrameEval: + # smaller hits python, larger hits numexpr + @pytest.mark.parametrize("n", [4, 4000]) + @pytest.mark.parametrize( + "op_str,op,rop", + [ + ("+", "__add__", "__radd__"), + ("-", "__sub__", "__rsub__"), + ("*", "__mul__", "__rmul__"), + ("/", "__truediv__", "__rtruediv__"), + ], + ) + def test_ops(self, op_str, op, rop, n): + # tst ops and reversed ops in evaluation + # GH7198 + + df = DataFrame(1, index=range(n), columns=list("abcd")) + df.iloc[0] = 2 + m = df.mean() + + base = DataFrame( # noqa: F841 + np.tile(m.values, n).reshape(n, -1), columns=list("abcd") + ) + + expected = eval(f"base {op_str} df") + + # ops as strings + result = eval(f"m {op_str} df") + tm.assert_frame_equal(result, expected) + + # these are commutative + if op in ["+", "*"]: + result = getattr(df, op)(m) + tm.assert_frame_equal(result, expected) + + # these are not + elif op in ["-", "/"]: + result = getattr(df, rop)(m) + tm.assert_frame_equal(result, expected) + + def test_dataframe_sub_numexpr_path(self): + # GH7192: Note we need a large number of rows to ensure this + # goes through the numexpr path + df = DataFrame({"A": np.random.default_rng(2).standard_normal(25000)}) + df.iloc[0:5] = np.nan + expected = 1 - np.isnan(df.iloc[0:25]) + result = (1 - np.isnan(df)).iloc[0:25] + tm.assert_frame_equal(result, expected) + + def test_query_non_str(self): + # GH 11485 + df = DataFrame({"A": [1, 2, 3], "B": ["a", "b", "b"]}) + + msg = "expr must be a string to be evaluated" + with pytest.raises(ValueError, match=msg): + df.query(lambda x: x.B == "b") + + with pytest.raises(ValueError, match=msg): + df.query(111) + + def test_query_empty_string(self): + # GH 13139 + df = DataFrame({"A": [1, 2, 3]}) + + msg = "expr cannot be an empty string" + with pytest.raises(ValueError, match=msg): + df.query("") + + def test_query_duplicate_column_name(self, engine, parser): + df = DataFrame({"A": range(3), "B": range(3), "C": range(3)}).rename( + columns={"B": "A"} + ) + + res = df.query("C == 1", engine=engine, parser=parser) + + expect = DataFrame([[1, 1, 1]], columns=["A", "A", "C"], index=[1]) + + tm.assert_frame_equal(res, expect) + + def test_eval_resolvers_as_list(self): + # GH 14095 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("ab") + ) + dict1 = {"a": 1} + dict2 = {"b": 2} + assert df.eval("a + b", resolvers=[dict1, dict2]) == dict1["a"] + dict2["b"] + assert pd.eval("a + b", resolvers=[dict1, dict2]) == dict1["a"] + dict2["b"] + + def test_eval_resolvers_combined(self): + # GH 34966 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("ab") + ) + dict1 = {"c": 2} + + # Both input and default index/column resolvers should be usable + result = df.eval("a + b * c", resolvers=[dict1]) + + expected = df["a"] + df["b"] * dict1["c"] + tm.assert_series_equal(result, expected) + + def test_eval_object_dtype_binop(self): + # GH#24883 + df = DataFrame({"a1": ["Y", "N"]}) + res = df.eval("c = ((a1 == 'Y') & True)") + expected = DataFrame({"a1": ["Y", "N"], "c": [True, False]}) + tm.assert_frame_equal(res, expected) + + def test_using_numpy(self, engine, parser): + # GH 58041 + skip_if_no_pandas_parser(parser) + df = Series([0.2, 1.5, 2.8], name="a").to_frame() + res = df.eval("@np.floor(a)", engine=engine, parser=parser) + expected = np.floor(df["a"]) + tm.assert_series_equal(expected, res) + + def test_eval_simple(self, engine, parser): + df = Series([0.2, 1.5, 2.8], name="a").to_frame() + res = df.eval("a", engine=engine, parser=parser) + expected = df["a"] + tm.assert_series_equal(expected, res) + + def test_extension_array_eval(self, engine, parser, request): + # GH#58748 + if engine == "numexpr": + mark = pytest.mark.xfail( + reason="numexpr does not support extension array dtypes" + ) + request.applymarker(mark) + df = DataFrame({"a": pd.array([1, 2, 3]), "b": pd.array([4, 5, 6])}) + result = df.eval("a / b", engine=engine, parser=parser) + expected = Series(pd.array([0.25, 0.40, 0.50])) + tm.assert_series_equal(result, expected) + + def test_complex_eval(self, engine, parser): + # GH#21374 + df = DataFrame({"a": [1 + 2j], "b": [1 + 1j]}) + result = df.eval("a/b", engine=engine, parser=parser) + expected = Series([1.5 + 0.5j]) + tm.assert_series_equal(result, expected) + + +class TestDataFrameQueryWithMultiIndex: + def test_query_with_named_multiindex(self, parser, engine): + skip_if_no_pandas_parser(parser) + a = np.random.default_rng(2).choice(["red", "green"], size=10) + b = np.random.default_rng(2).choice(["eggs", "ham"], size=10) + index = MultiIndex.from_arrays([a, b], names=["color", "food"]) + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2)), index=index) + ind = Series( + df.index.get_level_values("color").values, index=index, name="color" + ) + + # equality + res1 = df.query('color == "red"', parser=parser, engine=engine) + res2 = df.query('"red" == color', parser=parser, engine=engine) + exp = df[ind == "red"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # inequality + res1 = df.query('color != "red"', parser=parser, engine=engine) + res2 = df.query('"red" != color', parser=parser, engine=engine) + exp = df[ind != "red"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # list equality (really just set membership) + res1 = df.query('color == ["red"]', parser=parser, engine=engine) + res2 = df.query('["red"] == color', parser=parser, engine=engine) + exp = df[ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('color != ["red"]', parser=parser, engine=engine) + res2 = df.query('["red"] != color', parser=parser, engine=engine) + exp = df[~ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # in/not in ops + res1 = df.query('["red"] in color', parser=parser, engine=engine) + res2 = df.query('"red" in color', parser=parser, engine=engine) + exp = df[ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('["red"] not in color', parser=parser, engine=engine) + res2 = df.query('"red" not in color', parser=parser, engine=engine) + exp = df[~ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + def test_query_with_unnamed_multiindex(self, parser, engine): + skip_if_no_pandas_parser(parser) + a = np.random.default_rng(2).choice(["red", "green"], size=10) + b = np.random.default_rng(2).choice(["eggs", "ham"], size=10) + index = MultiIndex.from_arrays([a, b]) + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2)), index=index) + ind = Series(df.index.get_level_values(0).values, index=index) + + res1 = df.query('ilevel_0 == "red"', parser=parser, engine=engine) + res2 = df.query('"red" == ilevel_0', parser=parser, engine=engine) + exp = df[ind == "red"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # inequality + res1 = df.query('ilevel_0 != "red"', parser=parser, engine=engine) + res2 = df.query('"red" != ilevel_0', parser=parser, engine=engine) + exp = df[ind != "red"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # list equality (really just set membership) + res1 = df.query('ilevel_0 == ["red"]', parser=parser, engine=engine) + res2 = df.query('["red"] == ilevel_0', parser=parser, engine=engine) + exp = df[ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('ilevel_0 != ["red"]', parser=parser, engine=engine) + res2 = df.query('["red"] != ilevel_0', parser=parser, engine=engine) + exp = df[~ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # in/not in ops + res1 = df.query('["red"] in ilevel_0', parser=parser, engine=engine) + res2 = df.query('"red" in ilevel_0', parser=parser, engine=engine) + exp = df[ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('["red"] not in ilevel_0', parser=parser, engine=engine) + res2 = df.query('"red" not in ilevel_0', parser=parser, engine=engine) + exp = df[~ind.isin(["red"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # ## LEVEL 1 + ind = Series(df.index.get_level_values(1).values, index=index) + res1 = df.query('ilevel_1 == "eggs"', parser=parser, engine=engine) + res2 = df.query('"eggs" == ilevel_1', parser=parser, engine=engine) + exp = df[ind == "eggs"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # inequality + res1 = df.query('ilevel_1 != "eggs"', parser=parser, engine=engine) + res2 = df.query('"eggs" != ilevel_1', parser=parser, engine=engine) + exp = df[ind != "eggs"] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # list equality (really just set membership) + res1 = df.query('ilevel_1 == ["eggs"]', parser=parser, engine=engine) + res2 = df.query('["eggs"] == ilevel_1', parser=parser, engine=engine) + exp = df[ind.isin(["eggs"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('ilevel_1 != ["eggs"]', parser=parser, engine=engine) + res2 = df.query('["eggs"] != ilevel_1', parser=parser, engine=engine) + exp = df[~ind.isin(["eggs"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + # in/not in ops + res1 = df.query('["eggs"] in ilevel_1', parser=parser, engine=engine) + res2 = df.query('"eggs" in ilevel_1', parser=parser, engine=engine) + exp = df[ind.isin(["eggs"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + res1 = df.query('["eggs"] not in ilevel_1', parser=parser, engine=engine) + res2 = df.query('"eggs" not in ilevel_1', parser=parser, engine=engine) + exp = df[~ind.isin(["eggs"])] + tm.assert_frame_equal(res1, exp) + tm.assert_frame_equal(res2, exp) + + def test_query_with_partially_named_multiindex(self, parser, engine): + skip_if_no_pandas_parser(parser) + a = np.random.default_rng(2).choice(["red", "green"], size=10) + b = np.arange(10) + index = MultiIndex.from_arrays([a, b]) + index.names = [None, "rating"] + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2)), index=index) + res = df.query("rating == 1", parser=parser, engine=engine) + ind = Series( + df.index.get_level_values("rating").values, index=index, name="rating" + ) + exp = df[ind == 1] + tm.assert_frame_equal(res, exp) + + res = df.query("rating != 1", parser=parser, engine=engine) + ind = Series( + df.index.get_level_values("rating").values, index=index, name="rating" + ) + exp = df[ind != 1] + tm.assert_frame_equal(res, exp) + + res = df.query('ilevel_0 == "red"', parser=parser, engine=engine) + ind = Series(df.index.get_level_values(0).values, index=index) + exp = df[ind == "red"] + tm.assert_frame_equal(res, exp) + + res = df.query('ilevel_0 != "red"', parser=parser, engine=engine) + ind = Series(df.index.get_level_values(0).values, index=index) + exp = df[ind != "red"] + tm.assert_frame_equal(res, exp) + + def test_query_multiindex_get_index_resolvers(self): + df = DataFrame( + np.ones((10, 3)), + index=MultiIndex.from_arrays( + [range(10) for _ in range(2)], names=["spam", "eggs"] + ), + ) + resolvers = df._get_index_resolvers() + + def to_series(mi, level): + level_values = mi.get_level_values(level) + s = level_values.to_series() + s.index = mi + return s + + col_series = df.columns.to_series() + expected = { + "index": df.index, + "columns": col_series, + "spam": to_series(df.index, "spam"), + "eggs": to_series(df.index, "eggs"), + "clevel_0": col_series, + } + for k, v in resolvers.items(): + if isinstance(v, Index): + assert v.is_(expected[k]) + elif isinstance(v, Series): + tm.assert_series_equal(v, expected[k]) + else: + raise AssertionError("object must be a Series or Index") + + +@td.skip_if_no("numexpr") +class TestDataFrameQueryNumExprPandas: + @pytest.fixture + def engine(self): + return "numexpr" + + @pytest.fixture + def parser(self): + return "pandas" + + def test_date_query_with_attribute_access(self, engine, parser): + skip_if_no_pandas_parser(parser) + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df["dates1"] = date_range("1/1/2012", periods=5) + df["dates2"] = date_range("1/1/2013", periods=5) + df["dates3"] = date_range("1/1/2014", periods=5) + res = df.query( + "@df.dates1 < 20130101 < @df.dates3", engine=engine, parser=parser + ) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_query_no_attribute_access(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df["dates1"] = date_range("1/1/2012", periods=5) + df["dates2"] = date_range("1/1/2013", periods=5) + df["dates3"] = date_range("1/1/2014", periods=5) + res = df.query("dates1 < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_query_with_NaT(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates2"] = date_range("1/1/2013", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates1"] = pd.NaT + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates3"] = pd.NaT + res = df.query("dates1 < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query("index < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.index < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query_with_NaT(self, engine, parser): + n = 10 + # Cast to object to avoid implicit cast when setting entry to pd.NaT below + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))).astype( + {0: object} + ) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.iloc[0, 0] = pd.NaT + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query("index < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.index < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query_with_NaT_duplicates(self, engine, parser): + n = 10 + d = {} + d["dates1"] = date_range("1/1/2012", periods=n) + d["dates3"] = date_range("1/1/2014", periods=n) + df = DataFrame(d) + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates1"] = pd.NaT + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query("dates1 < 20130101 < dates3", engine=engine, parser=parser) + expec = df[(df.index.to_series() < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_query_with_non_date(self, engine, parser): + n = 10 + df = DataFrame( + { + "dates": date_range("1/1/2012", periods=n, unit="ns"), + "nondate": np.arange(n), + } + ) + + result = df.query("dates == nondate", parser=parser, engine=engine) + assert len(result) == 0 + + result = df.query("dates != nondate", parser=parser, engine=engine) + tm.assert_frame_equal(result, df) + + msg = r"Invalid comparison between dtype=datetime64\[ns\] and ndarray" + for op in ["<", ">", "<=", ">="]: + with pytest.raises(TypeError, match=msg): + df.query(f"dates {op} nondate", parser=parser, engine=engine) + + def test_query_syntax_error(self, engine, parser): + df = DataFrame({"i": range(10), "+": range(3, 13), "r": range(4, 14)}) + msg = "invalid syntax" + with pytest.raises(SyntaxError, match=msg): + df.query("i - +", engine=engine, parser=parser) + + def test_query_scope(self, engine, parser): + skip_if_no_pandas_parser(parser) + + df = DataFrame( + np.random.default_rng(2).standard_normal((20, 2)), columns=list("ab") + ) + + a, b = 1, 2 # noqa: F841 + res = df.query("a > b", engine=engine, parser=parser) + expected = df[df.a > df.b] + tm.assert_frame_equal(res, expected) + + res = df.query("@a > b", engine=engine, parser=parser) + expected = df[a > df.b] + tm.assert_frame_equal(res, expected) + + # no local variable c + with pytest.raises( + UndefinedVariableError, match="local variable 'c' is not defined" + ): + df.query("@a > b > @c", engine=engine, parser=parser) + + # no column named 'c' + with pytest.raises(UndefinedVariableError, match="name 'c' is not defined"): + df.query("@a > b > c", engine=engine, parser=parser) + + def test_query_doesnt_pickup_local(self, engine, parser): + n = m = 10 + df = DataFrame( + np.random.default_rng(2).integers(m, size=(n, 3)), columns=list("abc") + ) + + # we don't pick up the local 'sin' + with pytest.raises(UndefinedVariableError, match="name 'sin' is not defined"): + df.query("sin > 5", engine=engine, parser=parser) + + def test_query_builtin(self, engine, parser): + n = m = 10 + df = DataFrame( + np.random.default_rng(2).integers(m, size=(n, 3)), columns=list("abc") + ) + + df.index.name = "sin" + msg = "Variables in expression.+" + with pytest.raises(NumExprClobberingError, match=msg): + df.query("sin > 5", engine=engine, parser=parser) + + def test_query(self, engine, parser): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), columns=["a", "b", "c"] + ) + + tm.assert_frame_equal( + df.query("a < b", engine=engine, parser=parser), df[df.a < df.b] + ) + tm.assert_frame_equal( + df.query("a + b > b * c", engine=engine, parser=parser), + df[df.a + df.b > df.b * df.c], + ) + + def test_query_index_with_name(self, engine, parser): + df = DataFrame( + np.random.default_rng(2).integers(10, size=(10, 3)), + index=Index(range(10), name="blob"), + columns=["a", "b", "c"], + ) + res = df.query("(blob < 5) & (a < b)", engine=engine, parser=parser) + expec = df[(df.index < 5) & (df.a < df.b)] + tm.assert_frame_equal(res, expec) + + res = df.query("blob < b", engine=engine, parser=parser) + expec = df[df.index < df.b] + + tm.assert_frame_equal(res, expec) + + def test_query_index_without_name(self, engine, parser): + df = DataFrame( + np.random.default_rng(2).integers(10, size=(10, 3)), + index=range(10), + columns=["a", "b", "c"], + ) + + # "index" should refer to the index + res = df.query("index < b", engine=engine, parser=parser) + expec = df[df.index < df.b] + tm.assert_frame_equal(res, expec) + + # test against a scalar + res = df.query("index < 5", engine=engine, parser=parser) + expec = df[df.index < 5] + tm.assert_frame_equal(res, expec) + + def test_nested_scope(self, engine, parser): + skip_if_no_pandas_parser(parser) + + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df2 = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + expected = df[(df > 0) & (df2 > 0)] + + result = df.query("(@df > 0) & (@df2 > 0)", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + result = pd.eval("df[df > 0 and df2 > 0]", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + result = pd.eval( + "df[df > 0 and df2 > 0 and df[df > 0] > 0]", engine=engine, parser=parser + ) + expected = df[(df > 0) & (df2 > 0) & (df[df > 0] > 0)] + tm.assert_frame_equal(result, expected) + + result = pd.eval("df[(df>0) & (df2>0)]", engine=engine, parser=parser) + expected = df.query("(@df>0) & (@df2>0)", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + def test_nested_raises_on_local_self_reference(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + + # can't reference ourself b/c we're a local so @ is necessary + with pytest.raises(UndefinedVariableError, match="name 'df' is not defined"): + df.query("df > 0", engine=engine, parser=parser) + + def test_local_syntax(self, engine, parser): + skip_if_no_pandas_parser(parser) + + df = DataFrame( + np.random.default_rng(2).standard_normal((100, 10)), + columns=list("abcdefghij"), + ) + b = 1 + expect = df[df.a < b] + result = df.query("a < @b", engine=engine, parser=parser) + tm.assert_frame_equal(result, expect) + + expect = df[df.a < df.b] + result = df.query("a < b", engine=engine, parser=parser) + tm.assert_frame_equal(result, expect) + + def test_chained_cmp_and_in(self, engine, parser): + skip_if_no_pandas_parser(parser) + cols = list("abc") + df = DataFrame( + np.random.default_rng(2).standard_normal((100, len(cols))), columns=cols + ) + res = df.query( + "a < b < c and a not in b not in c", engine=engine, parser=parser + ) + ind = (df.a < df.b) & (df.b < df.c) & ~df.b.isin(df.a) & ~df.c.isin(df.b) + expec = df[ind] + tm.assert_frame_equal(res, expec) + + def test_local_variable_with_in(self, engine, parser): + skip_if_no_pandas_parser(parser) + a = Series(np.random.default_rng(2).integers(3, size=15), name="a") + b = Series(np.random.default_rng(2).integers(10, size=15), name="b") + df = DataFrame({"a": a, "b": b}) + + expected = df.loc[(df.b - 1).isin(a)] + result = df.query("b - 1 in a", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + b = Series(np.random.default_rng(2).integers(10, size=15), name="b") + expected = df.loc[(b - 1).isin(a)] + result = df.query("@b - 1 in a", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + def test_at_inside_string(self, engine, parser): + skip_if_no_pandas_parser(parser) + c = 1 # noqa: F841 + df = DataFrame({"a": ["a", "a", "b", "b", "@c", "@c"]}) + result = df.query('a == "@c"', engine=engine, parser=parser) + expected = df[df.a == "@c"] + tm.assert_frame_equal(result, expected) + + def test_query_undefined_local(self): + engine, parser = self.engine, self.parser + skip_if_no_pandas_parser(parser) + + df = DataFrame(np.random.default_rng(2).random((10, 2)), columns=list("ab")) + with pytest.raises( + UndefinedVariableError, match="local variable 'c' is not defined" + ): + df.query("a == @c", engine=engine, parser=parser) + + def test_index_resolvers_come_after_columns_with_the_same_name( + self, engine, parser + ): + n = 1 # noqa: F841 + a = np.r_[20:101:20] + + df = DataFrame( + {"index": a, "b": np.random.default_rng(2).standard_normal(a.size)} + ) + df.index.name = "index" + result = df.query("index > 5", engine=engine, parser=parser) + expected = df[df["index"] > 5] + tm.assert_frame_equal(result, expected) + + df = DataFrame( + {"index": a, "b": np.random.default_rng(2).standard_normal(a.size)} + ) + result = df.query("ilevel_0 > 5", engine=engine, parser=parser) + expected = df.loc[df.index[df.index > 5]] + tm.assert_frame_equal(result, expected) + + df = DataFrame({"a": a, "b": np.random.default_rng(2).standard_normal(a.size)}) + df.index.name = "a" + result = df.query("a > 5", engine=engine, parser=parser) + expected = df[df.a > 5] + tm.assert_frame_equal(result, expected) + + result = df.query("index > 5", engine=engine, parser=parser) + expected = df.loc[df.index[df.index > 5]] + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("op, f", [["==", operator.eq], ["!=", operator.ne]]) + def test_inf(self, op, f, engine, parser): + n = 10 + df = DataFrame( + { + "a": np.random.default_rng(2).random(n), + "b": np.random.default_rng(2).random(n), + } + ) + df.loc[::2, 0] = np.inf + q = f"a {op} inf" + expected = df[f(df.a, np.inf)] + result = df.query(q, engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + def test_check_tz_aware_index_query(self, tz_aware_fixture): + # https://github.com/pandas-dev/pandas/issues/29463 + tz = tz_aware_fixture + df_index = date_range( + start="2019-01-01", freq="1D", periods=10, tz=tz, name="time" + ) + expected = DataFrame(index=df_index) + df = DataFrame(index=df_index) + result = df.query('"2018-01-03 00:00:00+00" < time') + tm.assert_frame_equal(result, expected) + + expected = DataFrame(df_index) + result = df.reset_index().query('"2018-01-03 00:00:00+00" < time') + tm.assert_frame_equal(result, expected) + + def test_method_calls_in_query(self, engine, parser): + # https://github.com/pandas-dev/pandas/issues/22435 + n = 10 + df = DataFrame( + { + "a": 2 * np.random.default_rng(2).random(n), + "b": np.random.default_rng(2).random(n), + } + ) + expected = df[df["a"].astype("int") == 0] + result = df.query("a.astype('int') == 0", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + df = DataFrame( + { + "a": np.where( + np.random.default_rng(2).random(n) < 0.5, + np.nan, + np.random.default_rng(2).standard_normal(n), + ), + "b": np.random.default_rng(2).standard_normal(n), + } + ) + expected = df[df["a"].notnull()] + result = df.query("a.notnull()", engine=engine, parser=parser) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numexpr") +class TestDataFrameQueryNumExprPython(TestDataFrameQueryNumExprPandas): + @pytest.fixture + def engine(self): + return "numexpr" + + @pytest.fixture + def parser(self): + return "python" + + def test_date_query_no_attribute_access(self, engine, parser): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df["dates1"] = date_range("1/1/2012", periods=5) + df["dates2"] = date_range("1/1/2013", periods=5) + df["dates3"] = date_range("1/1/2014", periods=5) + res = df.query( + "(dates1 < 20130101) & (20130101 < dates3)", engine=engine, parser=parser + ) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_query_with_NaT(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates2"] = date_range("1/1/2013", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates1"] = pd.NaT + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates3"] = pd.NaT + res = df.query( + "(dates1 < 20130101) & (20130101 < dates3)", engine=engine, parser=parser + ) + expec = df[(df.dates1 < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query( + "(index < 20130101) & (20130101 < dates3)", engine=engine, parser=parser + ) + expec = df[(df.index < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query_with_NaT(self, engine, parser): + n = 10 + # Cast to object to avoid implicit cast when setting entry to pd.NaT below + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))).astype( + {0: object} + ) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.iloc[0, 0] = pd.NaT + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + res = df.query( + "(index < 20130101) & (20130101 < dates3)", engine=engine, parser=parser + ) + expec = df[(df.index < "20130101") & ("20130101" < df.dates3)] + tm.assert_frame_equal(res, expec) + + def test_date_index_query_with_NaT_duplicates(self, engine, parser): + n = 10 + df = DataFrame(np.random.default_rng(2).standard_normal((n, 3))) + df["dates1"] = date_range("1/1/2012", periods=n) + df["dates3"] = date_range("1/1/2014", periods=n) + df.loc[np.random.default_rng(2).random(n) > 0.5, "dates1"] = pd.NaT + return_value = df.set_index("dates1", inplace=True, drop=True) + assert return_value is None + msg = r"'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + df.query("index < 20130101 < dates3", engine=engine, parser=parser) + + def test_nested_scope(self, engine, parser): + # smoke test + x = 1 # noqa: F841 + result = pd.eval("x + 1", engine=engine, parser=parser) + assert result == 2 + + df = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + df2 = DataFrame(np.random.default_rng(2).standard_normal((5, 3))) + + # don't have the pandas parser + msg = r"The '@' prefix is only supported by the pandas parser" + with pytest.raises(SyntaxError, match=msg): + df.query("(@df>0) & (@df2>0)", engine=engine, parser=parser) + + with pytest.raises(UndefinedVariableError, match="name 'df' is not defined"): + df.query("(df>0) & (df2>0)", engine=engine, parser=parser) + + expected = df[(df > 0) & (df2 > 0)] + result = pd.eval("df[(df > 0) & (df2 > 0)]", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + expected = df[(df > 0) & (df2 > 0) & (df[df > 0] > 0)] + result = pd.eval( + "df[(df > 0) & (df2 > 0) & (df[df > 0] > 0)]", engine=engine, parser=parser + ) + tm.assert_frame_equal(expected, result) + + def test_query_numexpr_with_min_and_max_columns(self): + df = DataFrame({"min": [1, 2, 3], "max": [4, 5, 6]}) + regex_to_match = ( + r"Variables in expression \"\(min\) == \(1\)\" " + r"overlap with builtins: \('min'\)" + ) + with pytest.raises(NumExprClobberingError, match=regex_to_match): + df.query("min == 1") + + regex_to_match = ( + r"Variables in expression \"\(max\) == \(1\)\" " + r"overlap with builtins: \('max'\)" + ) + with pytest.raises(NumExprClobberingError, match=regex_to_match): + df.query("max == 1") + + +class TestDataFrameQueryPythonPandas(TestDataFrameQueryNumExprPandas): + @pytest.fixture + def engine(self): + return "python" + + @pytest.fixture + def parser(self): + return "pandas" + + def test_query_builtin(self, engine, parser): + n = m = 10 + df = DataFrame( + np.random.default_rng(2).integers(m, size=(n, 3)), columns=list("abc") + ) + + df.index.name = "sin" + expected = df[df.index > 5] + result = df.query("sin > 5", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + +class TestDataFrameQueryPythonPython(TestDataFrameQueryNumExprPython): + @pytest.fixture + def engine(self): + return "python" + + @pytest.fixture + def parser(self): + return "python" + + def test_query_builtin(self, engine, parser): + n = m = 10 + df = DataFrame( + np.random.default_rng(2).integers(m, size=(n, 3)), columns=list("abc") + ) + + df.index.name = "sin" + expected = df[df.index > 5] + result = df.query("sin > 5", engine=engine, parser=parser) + tm.assert_frame_equal(expected, result) + + +class TestDataFrameQueryStrings: + def test_str_query_method(self, parser, engine): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 1)), columns=["b"]) + df["strings"] = Series(list("aabbccddee")) + expect = df[df.strings == "a"] + + if parser != "pandas": + col = "strings" + lst = '"a"' + + lhs = [col] * 2 + [lst] * 2 + rhs = lhs[::-1] + + eq, ne = "==", "!=" + ops = 2 * ([eq, ne]) + msg = r"'(Not)?In' nodes are not implemented" + + for lh, op_, rh in zip(lhs, ops, rhs): + ex = f"{lh} {op_} {rh}" + with pytest.raises(NotImplementedError, match=msg): + df.query( + ex, + engine=engine, + parser=parser, + local_dict={"strings": df.strings}, + ) + else: + res = df.query('"a" == strings', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + res = df.query('strings == "a"', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + tm.assert_frame_equal(res, df[df.strings.isin(["a"])]) + + expect = df[df.strings != "a"] + res = df.query('strings != "a"', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + res = df.query('"a" != strings', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + tm.assert_frame_equal(res, df[~df.strings.isin(["a"])]) + + def test_str_list_query_method(self, parser, engine): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 1)), columns=["b"]) + df["strings"] = Series(list("aabbccddee")) + expect = df[df.strings.isin(["a", "b"])] + + if parser != "pandas": + col = "strings" + lst = '["a", "b"]' + + lhs = [col] * 2 + [lst] * 2 + rhs = lhs[::-1] + + eq, ne = "==", "!=" + ops = 2 * ([eq, ne]) + msg = r"'(Not)?In' nodes are not implemented" + + for lh, ops_, rh in zip(lhs, ops, rhs): + ex = f"{lh} {ops_} {rh}" + with pytest.raises(NotImplementedError, match=msg): + df.query(ex, engine=engine, parser=parser) + else: + res = df.query('strings == ["a", "b"]', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + res = df.query('["a", "b"] == strings', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + expect = df[~df.strings.isin(["a", "b"])] + + res = df.query('strings != ["a", "b"]', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + res = df.query('["a", "b"] != strings', engine=engine, parser=parser) + tm.assert_frame_equal(res, expect) + + def test_query_with_string_columns(self, parser, engine): + df = DataFrame( + { + "a": list("aaaabbbbcccc"), + "b": list("aabbccddeeff"), + "c": np.random.default_rng(2).integers(5, size=12), + "d": np.random.default_rng(2).integers(9, size=12), + } + ) + if parser == "pandas": + res = df.query("a in b", parser=parser, engine=engine) + expec = df[df.a.isin(df.b)] + tm.assert_frame_equal(res, expec) + + res = df.query("a in b and c < d", parser=parser, engine=engine) + expec = df[df.a.isin(df.b) & (df.c < df.d)] + tm.assert_frame_equal(res, expec) + else: + msg = r"'(Not)?In' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + df.query("a in b", parser=parser, engine=engine) + + msg = r"'BoolOp' nodes are not implemented" + with pytest.raises(NotImplementedError, match=msg): + df.query("a in b and c < d", parser=parser, engine=engine) + + def test_object_array_eq_ne(self, parser, engine): + df = DataFrame( + { + "a": list("aaaabbbbcccc"), + "b": list("aabbccddeeff"), + "c": np.random.default_rng(2).integers(5, size=12), + "d": np.random.default_rng(2).integers(9, size=12), + } + ) + res = df.query("a == b", parser=parser, engine=engine) + exp = df[df.a == df.b] + tm.assert_frame_equal(res, exp) + + res = df.query("a != b", parser=parser, engine=engine) + exp = df[df.a != df.b] + tm.assert_frame_equal(res, exp) + + def test_query_with_nested_strings(self, parser, engine): + skip_if_no_pandas_parser(parser) + events = [ + f"page {n} {act}" for n in range(1, 4) for act in ["load", "exit"] + ] * 2 + stamps1 = date_range("2014-01-01 0:00:01", freq="30s", periods=6) + stamps2 = date_range("2014-02-01 1:00:01", freq="30s", periods=6) + df = DataFrame( + { + "id": np.arange(1, 7).repeat(2), + "event": events, + "timestamp": stamps1.append(stamps2), + } + ) + + expected = df[df.event == '"page 1 load"'] + res = df.query("""'"page 1 load"' in event""", parser=parser, engine=engine) + tm.assert_frame_equal(expected, res) + + def test_query_with_nested_special_character(self, parser, engine): + skip_if_no_pandas_parser(parser) + df = DataFrame({"a": ["a", "b", "test & test"], "b": [1, 2, 3]}) + res = df.query('a == "test & test"', parser=parser, engine=engine) + expec = df[df.a == "test & test"] + tm.assert_frame_equal(res, expec) + + @pytest.mark.parametrize( + "op, func", + [ + ["<", operator.lt], + [">", operator.gt], + ["<=", operator.le], + [">=", operator.ge], + ], + ) + def test_query_lex_compare_strings(self, parser, engine, op, func): + a = Series(np.random.default_rng(2).choice(list("abcde"), 20)) + b = Series(np.arange(a.size)) + df = DataFrame({"X": a, "Y": b}) + + res = df.query(f'X {op} "d"', engine=engine, parser=parser) + expected = df[func(df.X, "d")] + tm.assert_frame_equal(res, expected) + + def test_query_single_element_booleans(self, parser, engine): + columns = "bid", "bidsize", "ask", "asksize" + data = np.random.default_rng(2).integers(2, size=(1, len(columns))).astype(bool) + df = DataFrame(data, columns=columns) + res = df.query("bid & ask", engine=engine, parser=parser) + expected = df[df.bid & df.ask] + tm.assert_frame_equal(res, expected) + + def test_query_string_scalar_variable(self, parser, engine): + skip_if_no_pandas_parser(parser) + df = DataFrame( + { + "Symbol": ["BUD US", "BUD US", "IBM US", "IBM US"], + "Price": [109.70, 109.72, 183.30, 183.35], + } + ) + e = df[df.Symbol == "BUD US"] + symb = "BUD US" # noqa: F841 + r = df.query("Symbol == @symb", parser=parser, engine=engine) + tm.assert_frame_equal(e, r) + + @pytest.mark.parametrize( + "in_list", + [ + [None, "asdf", "ghjk"], + ["asdf", None, "ghjk"], + ["asdf", "ghjk", None], + [None, None, "asdf"], + ["asdf", None, None], + [None, None, None], + ], + ) + def test_query_string_null_elements(self, in_list): + # GITHUB ISSUE #31516 + parser = "pandas" + engine = "python" + expected = {i: value for i, value in enumerate(in_list) if value == "asdf"} + + df_expected = DataFrame({"a": expected}, dtype="string") + df_expected.index = df_expected.index.astype("int64") + df = DataFrame({"a": in_list}, dtype="string") + df.index = Index(list(df.index), dtype=df.index.dtype) + res1 = df.query("a == 'asdf'", parser=parser, engine=engine) + res2 = df[df["a"] == "asdf"] + res3 = df.query("a <= 'asdf'", parser=parser, engine=engine) + tm.assert_frame_equal(res1, df_expected) + tm.assert_frame_equal(res1, res2) + tm.assert_frame_equal(res1, res3) + tm.assert_frame_equal(res2, res3) + + +class TestDataFrameEvalWithFrame: + @pytest.fixture + def frame(self): + return DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), columns=list("abc") + ) + + def test_simple_expr(self, frame, parser, engine): + res = frame.eval("a + b", engine=engine, parser=parser) + expect = frame.a + frame.b + tm.assert_series_equal(res, expect) + + def test_bool_arith_expr(self, frame, parser, engine): + res = frame.eval("a[a < 1] + b", engine=engine, parser=parser) + expect = frame.a[frame.a < 1] + frame.b + tm.assert_series_equal(res, expect) + + @pytest.mark.parametrize("op", ["+", "-", "*", "/"]) + def test_invalid_type_for_operator_raises(self, parser, engine, op): + df = DataFrame({"a": [1, 2], "b": ["c", "d"]}) + msg = r"unsupported operand type\(s\) for .+: '.+' and '.+'|Cannot" + + with pytest.raises(TypeError, match=msg): + df.eval(f"a {op} b", engine=engine, parser=parser) + + +class TestDataFrameQueryBacktickQuoting: + @pytest.fixture + def df(self): + """ + Yields a dataframe with strings that may or may not need escaping + by backticks. The last two columns cannot be escaped by backticks + and should raise a ValueError. + """ + return DataFrame( + { + "A": [1, 2, 3], + "B B": [3, 2, 1], + "C C": [4, 5, 6], + "C C": [7, 4, 3], + "C_C": [8, 9, 10], + "D_D D": [11, 1, 101], + "E.E": [6, 3, 5], + "F-F": [8, 1, 10], + "1e1": [2, 4, 8], + "def": [10, 11, 2], + "A (x)": [4, 1, 3], + "B(x)": [1, 1, 5], + "B (x)": [2, 7, 4], + " &^ :!€$?(} > <++*'' ": [2, 5, 6], + "": [10, 11, 1], + " A": [4, 7, 9], + " ": [1, 2, 1], + "it's": [6, 3, 1], + "that's": [9, 1, 8], + "☺": [8, 7, 6], + "xy (z)": [1, 2, 3], # noqa: RUF001 + "xy (z\\uff09": [4, 5, 6], # noqa: RUF001 + "foo#bar": [2, 4, 5], + 1: [5, 7, 9], + } + ) + + def test_single_backtick_variable_query(self, df): + res = df.query("1 < `B B`") + expect = df[1 < df["B B"]] + tm.assert_frame_equal(res, expect) + + def test_two_backtick_variables_query(self, df): + res = df.query("1 < `B B` and 4 < `C C`") + expect = df[(1 < df["B B"]) & (4 < df["C C"])] + tm.assert_frame_equal(res, expect) + + def test_single_backtick_variable_expr(self, df): + res = df.eval("A + `B B`") + expect = df["A"] + df["B B"] + tm.assert_series_equal(res, expect) + + def test_two_backtick_variables_expr(self, df): + res = df.eval("`B B` + `C C`") + expect = df["B B"] + df["C C"] + tm.assert_series_equal(res, expect) + + def test_already_underscore_variable(self, df): + res = df.eval("`C_C` + A") + expect = df["C_C"] + df["A"] + tm.assert_series_equal(res, expect) + + def test_same_name_but_underscores(self, df): + res = df.eval("C_C + `C C`") + expect = df["C_C"] + df["C C"] + tm.assert_series_equal(res, expect) + + def test_mixed_underscores_and_spaces(self, df): + res = df.eval("A + `D_D D`") + expect = df["A"] + df["D_D D"] + tm.assert_series_equal(res, expect) + + def test_backtick_quote_name_with_no_spaces(self, df): + res = df.eval("A + `C_C`") + expect = df["A"] + df["C_C"] + tm.assert_series_equal(res, expect) + + def test_special_characters(self, df): + res = df.eval("`E.E` + `F-F` - A") + expect = df["E.E"] + df["F-F"] - df["A"] + tm.assert_series_equal(res, expect) + + def test_start_with_digit(self, df): + res = df.eval("A + `1e1`") + expect = df["A"] + df["1e1"] + tm.assert_series_equal(res, expect) + + def test_keyword(self, df): + res = df.eval("A + `def`") + expect = df["A"] + df["def"] + tm.assert_series_equal(res, expect) + + def test_unneeded_quoting(self, df): + res = df.query("`A` > 2") + expect = df[df["A"] > 2] + tm.assert_frame_equal(res, expect) + + def test_parenthesis(self, df): + res = df.query("`A (x)` > 2") + expect = df[df["A (x)"] > 2] + tm.assert_frame_equal(res, expect) + + def test_empty_string(self, df): + res = df.query("`` > 5") + expect = df[df[""] > 5] + tm.assert_frame_equal(res, expect) + + def test_multiple_spaces(self, df): + res = df.query("`C C` > 5") + expect = df[df["C C"] > 5] + tm.assert_frame_equal(res, expect) + + def test_start_with_spaces(self, df): + res = df.eval("` A` + ` `") + expect = df[" A"] + df[" "] + tm.assert_series_equal(res, expect) + + def test_ints(self, df): + res = df.query("`1` == 7") + expect = df[df[1] == 7] + tm.assert_frame_equal(res, expect) + + def test_lots_of_operators_string(self, df): + res = df.query("` &^ :!€$?(} > <++*'' ` > 4") + expect = df[df[" &^ :!€$?(} > <++*'' "] > 4] + tm.assert_frame_equal(res, expect) + + def test_missing_attribute(self, df): + message = "module 'pandas' has no attribute 'thing'" + with pytest.raises(AttributeError, match=message): + df.eval("@pd.thing") + + def test_quote(self, df): + res = df.query("`it's` > `that's`") + expect = df[df["it's"] > df["that's"]] + tm.assert_frame_equal(res, expect) + + def test_character_outside_range_smiley(self, df): + res = df.query("`☺` > 4") + expect = df[df["☺"] > 4] + tm.assert_frame_equal(res, expect) + + def test_character_outside_range_2_byte_parens(self, df): + # GH 49633 + res = df.query("`xy (z)` == 2") # noqa: RUF001 + expect = df[df["xy (z)"] == 2] # noqa: RUF001 + tm.assert_frame_equal(res, expect) + + def test_character_outside_range_and_actual_backslash(self, df): + # GH 49633 + res = df.query("`xy (z\\uff09` == 2") # noqa: RUF001 + expect = df[df["xy \uff08z\\uff09"] == 2] + tm.assert_frame_equal(res, expect) + + def test_hashtag(self, df): + res = df.query("`foo#bar` > 4") + expect = df[df["foo#bar"] > 4] + tm.assert_frame_equal(res, expect) + + def test_expr_with_column_name_with_hashtag_character(self): + # GH 59285 + df = DataFrame((1, 2, 3), columns=["a#"]) + result = df.query("`a#` < 2") + expected = df[df["a#"] < 2] + tm.assert_frame_equal(result, expected) + + def test_expr_with_comment(self): + # GH 59285 + df = DataFrame((1, 2, 3), columns=["a#"]) + result = df.query("`a#` < 2 # This is a comment") + expected = df[df["a#"] < 2] + tm.assert_frame_equal(result, expected) + + def test_expr_with_column_name_with_backtick_and_hash(self): + # GH 59285 + df = DataFrame((1, 2, 3), columns=["a`#b"]) + result = df.query("`a``#b` < 2") + expected = df[df["a`#b"] < 2] + tm.assert_frame_equal(result, expected) + + def test_expr_with_column_name_with_backtick(self): + # GH 59285 + df = DataFrame({"a`b": (1, 2, 3), "ab": (4, 5, 6)}) + result = df.query("`a``b` < 2") + # Note: Formatting checks may wrongly consider the above ``inline code``. + expected = df[df["a`b"] < 2] + tm.assert_frame_equal(result, expected) + + def test_expr_with_string_with_backticks(self): + # GH 59285 + df = DataFrame(("`", "`````", "``````````"), columns=["#backticks"]) + result = df.query("'```' < `#backticks`") + expected = df["```" < df["#backticks"]] + tm.assert_frame_equal(result, expected) + + def test_expr_with_string_with_backticked_substring_same_as_column_name(self): + # GH 59285 + df = DataFrame(("`", "`````", "``````````"), columns=["#backticks"]) + result = df.query("'`#backticks`' < `#backticks`") + expected = df["`#backticks`" < df["#backticks"]] + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "col1,col2,expr", + [ + ("it's", "that's", "`it's` < `that's`"), + ('it"s', 'that"s', '`it"s` < `that"s`'), + ("it's", 'that\'s "nice"', "`it's` < `that's \"nice\"`"), + ("it's", "that's #cool", "`it's` < `that's #cool` # This is a comment"), + ], + ) + def test_expr_with_column_names_with_special_characters(self, col1, col2, expr): + # GH 59285 + df = DataFrame( + [ + {col1: 1, col2: 2}, + {col1: 3, col2: 4}, + {col1: -1, col2: -2}, + {col1: -3, col2: -4}, + ] + ) + result = df.query(expr) + expected = df[df[col1] < df[col2]] + tm.assert_frame_equal(result, expected) + + def test_expr_with_no_backticks(self): + # GH 59285 + df = DataFrame(("aaa", "vvv", "zzz"), columns=["column_name"]) + result = df.query("'value' < column_name") + expected = df["value" < df["column_name"]] + tm.assert_frame_equal(result, expected) + + def test_expr_with_no_quotes_and_backtick_is_unmatched(self): + # GH 59285 + df = DataFrame((1, 5, 10), columns=["column-name"]) + with pytest.raises((SyntaxError, TokenError), match="invalid syntax"): + df.query("5 < `column-name") + + def test_expr_with_no_quotes_and_backtick_is_matched(self): + # GH 59285 + df = DataFrame((1, 5, 10), columns=["column-name"]) + result = df.query("5 < `column-name`") + expected = df[5 < df["column-name"]] + tm.assert_frame_equal(result, expected) + + def test_expr_with_backtick_opened_before_quote_and_backtick_is_unmatched(self): + # GH 59285 + df = DataFrame((1, 5, 10), columns=["It's"]) + with pytest.raises( + (SyntaxError, TokenError), match="unterminated string literal" + ): + df.query("5 < `It's") + + def test_expr_with_backtick_opened_before_quote_and_backtick_is_matched(self): + # GH 59285 + df = DataFrame((1, 5, 10), columns=["It's"]) + result = df.query("5 < `It's`") + expected = df[5 < df["It's"]] + tm.assert_frame_equal(result, expected) + + def test_expr_with_quote_opened_before_backtick_and_quote_is_unmatched(self): + # GH 59285 + df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"]) + with pytest.raises( + (SyntaxError, TokenError), match="unterminated string literal" + ): + df.query("`column-name` < 'It`s that\\'s \"quote\" #hash") + + def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_at_end(self): + # GH 59285 + df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"]) + result = df.query("`column-name` < 'It`s that\\'s \"quote\" #hash'") + expected = df[df["column-name"] < 'It`s that\'s "quote" #hash'] + tm.assert_frame_equal(result, expected) + + def test_expr_with_quote_opened_before_backtick_and_quote_is_matched_in_mid(self): + # GH 59285 + df = DataFrame(("aaa", "vvv", "zzz"), columns=["column-name"]) + result = df.query("'It`s that\\'s \"quote\" #hash' < `column-name`") + expected = df['It`s that\'s "quote" #hash' < df["column-name"]] + tm.assert_frame_equal(result, expected) + + def test_call_non_named_expression(self, df): + """ + Only attributes and variables ('named functions') can be called. + .__call__() is not an allowed attribute because that would allow + calling anything. + https://github.com/pandas-dev/pandas/pull/32460 + """ + + def func(*_): + return 1 + + funcs = [func] # noqa: F841 + + df.eval("@func()") + + with pytest.raises(TypeError, match="Only named functions are supported"): + df.eval("@funcs[0]()") + + with pytest.raises(TypeError, match="Only named functions are supported"): + df.eval("@funcs[0].__call__()") + + def test_ea_dtypes(self, any_numeric_ea_and_arrow_dtype): + # GH#29618 + df = DataFrame( + [[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype + ) + warning = RuntimeWarning if NUMEXPR_INSTALLED else None + with tm.assert_produces_warning(warning): + result = df.eval("c = b - a") + expected = DataFrame( + [[1, 2, 1], [3, 4, 1]], + columns=["a", "b", "c"], + dtype=any_numeric_ea_and_arrow_dtype, + ) + tm.assert_frame_equal(result, expected) + + def test_ea_dtypes_and_scalar(self): + # GH#29618 + df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"], dtype="Float64") + warning = RuntimeWarning if NUMEXPR_INSTALLED else None + with tm.assert_produces_warning(warning): + result = df.eval("c = b - 1") + expected = DataFrame( + [[1, 2, 1], [3, 4, 3]], columns=["a", "b", "c"], dtype="Float64" + ) + tm.assert_frame_equal(result, expected) + + def test_ea_dtypes_and_scalar_operation(self, any_numeric_ea_and_arrow_dtype): + # GH#29618 + df = DataFrame( + [[1, 2], [3, 4]], columns=["a", "b"], dtype=any_numeric_ea_and_arrow_dtype + ) + result = df.eval("c = 2 - 1") + expected = DataFrame( + { + "a": Series([1, 3], dtype=any_numeric_ea_and_arrow_dtype), + "b": Series([2, 4], dtype=any_numeric_ea_and_arrow_dtype), + "c": Series([1, 1], dtype=result["c"].dtype), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"]) + def test_query_ea_dtypes(self, dtype): + if dtype == "int64[pyarrow]": + pytest.importorskip("pyarrow") + # GH#50261 + df = DataFrame({"a": [1, 2]}, dtype=dtype) + ref = {2} # noqa: F841 + warning = RuntimeWarning if dtype == "Int64" and NUMEXPR_INSTALLED else None + with tm.assert_produces_warning(warning): + result = df.query("a in @ref") + expected = DataFrame({"a": [2]}, index=range(1, 2), dtype=dtype) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("engine", ["python", "numexpr"]) + @pytest.mark.parametrize("dtype", ["int64", "Int64", "int64[pyarrow]"]) + def test_query_ea_equality_comparison(self, dtype, engine): + # GH#50261 + warning = RuntimeWarning if engine == "numexpr" else None + if engine == "numexpr" and not NUMEXPR_INSTALLED: + pytest.skip("numexpr not installed") + if dtype == "int64[pyarrow]": + pytest.importorskip("pyarrow") + df = DataFrame( + {"A": Series([1, 1, 2], dtype="Int64"), "B": Series([1, 2, 2], dtype=dtype)} + ) + with tm.assert_produces_warning(warning): + result = df.query("A == B", engine=engine) + expected = DataFrame( + { + "A": Series([1, 2], dtype="Int64", index=range(0, 4, 2)), + "B": Series([1, 2], dtype=dtype, index=range(0, 4, 2)), + } + ) + tm.assert_frame_equal(result, expected) + + def test_all_nat_in_object(self): + # GH#57068 + now = pd.Timestamp.now("UTC") # noqa: F841 + df = DataFrame({"a": pd.to_datetime([None, None], utc=True)}, dtype=object) + result = df.query("a > @now") + expected = DataFrame({"a": []}, dtype=object) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/frame/test_reductions.py b/pandas/tests/frame/test_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..6c702525156d7e5ba6b4fea1d69565bf7baf719e --- /dev/null +++ b/pandas/tests/frame/test_reductions.py @@ -0,0 +1,2234 @@ +from datetime import timedelta +from decimal import Decimal +import re + +from dateutil.tz import tzlocal +import numpy as np +import pytest + +from pandas.compat import ( + IS64, + is_platform_windows, +) +from pandas.compat.numpy import np_version_gt2 +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Categorical, + CategoricalDtype, + DataFrame, + DatetimeIndex, + Index, + PeriodIndex, + RangeIndex, + Series, + Timestamp, + date_range, + isna, + notna, + to_datetime, + to_timedelta, +) +import pandas._testing as tm +from pandas.core import ( + algorithms, + nanops, +) + +is_windows_np2_or_is32 = (is_platform_windows() and not np_version_gt2) or not IS64 +is_windows_or_is32 = is_platform_windows() or not IS64 + + +def make_skipna_wrapper(alternative, skipna_alternative=None): + """ + Create a function for calling on an array. + + Parameters + ---------- + alternative : function + The function to be called on the array with no NaNs. + Only used when 'skipna_alternative' is None. + skipna_alternative : function + The function to be called on the original array + + Returns + ------- + function + """ + if skipna_alternative: + + def skipna_wrapper(x): + return skipna_alternative(x.values) + + else: + + def skipna_wrapper(x): + nona = x.dropna() + if len(nona) == 0: + return np.nan + return alternative(nona) + + return skipna_wrapper + + +def assert_stat_op_calc( + opname, + alternative, + frame, + has_skipna=True, + check_dtype=True, + check_dates=False, + rtol=1e-5, + atol=1e-8, + skipna_alternative=None, +): + """ + Check that operator opname works as advertised on frame + + Parameters + ---------- + opname : str + Name of the operator to test on frame + alternative : function + Function that opname is tested against; i.e. "frame.opname()" should + equal "alternative(frame)". + frame : DataFrame + The object that the tests are executed on + has_skipna : bool, default True + Whether the method "opname" has the kwarg "skip_na" + check_dtype : bool, default True + Whether the dtypes of the result of "frame.opname()" and + "alternative(frame)" should be checked. + check_dates : bool, default false + Whether opname should be tested on a Datetime Series + rtol : float, default 1e-5 + Relative tolerance. + atol : float, default 1e-8 + Absolute tolerance. + skipna_alternative : function, default None + NaN-safe version of alternative + """ + f = getattr(frame, opname) + + if check_dates: + df = DataFrame({"b": date_range("1/1/2001", periods=2)}) + with tm.assert_produces_warning(None): + result = getattr(df, opname)() + assert isinstance(result, Series) + + df["a"] = range(len(df)) + with tm.assert_produces_warning(None): + result = getattr(df, opname)() + assert isinstance(result, Series) + assert len(result) + + if has_skipna: + + def wrapper(x): + return alternative(x.values) + + skipna_wrapper = make_skipna_wrapper(alternative, skipna_alternative) + result0 = f(axis=0, skipna=False) + result1 = f(axis=1, skipna=False) + tm.assert_series_equal( + result0, frame.apply(wrapper), check_dtype=check_dtype, rtol=rtol, atol=atol + ) + tm.assert_series_equal( + result1, + frame.apply(wrapper, axis=1), + rtol=rtol, + atol=atol, + ) + else: + skipna_wrapper = alternative + + result0 = f(axis=0) + result1 = f(axis=1) + tm.assert_series_equal( + result0, + frame.apply(skipna_wrapper), + check_dtype=check_dtype, + rtol=rtol, + atol=atol, + ) + + if opname in ["sum", "prod"]: + expected = frame.apply(skipna_wrapper, axis=1) + tm.assert_series_equal( + result1, expected, check_dtype=False, rtol=rtol, atol=atol + ) + + # check dtypes + if check_dtype: + lcd_dtype = frame.values.dtype + assert lcd_dtype == result0.dtype + assert lcd_dtype == result1.dtype + + # bad axis + with pytest.raises(ValueError, match="No axis named 2"): + f(axis=2) + + # all NA case + if has_skipna: + all_na = frame * np.nan + r0 = getattr(all_na, opname)(axis=0) + r1 = getattr(all_na, opname)(axis=1) + if opname in ["sum", "prod"]: + unit = 1 if opname == "prod" else 0 # result for empty sum/prod + expected = Series(unit, index=r0.index, dtype=r0.dtype) + tm.assert_series_equal(r0, expected) + expected = Series(unit, index=r1.index, dtype=r1.dtype) + tm.assert_series_equal(r1, expected) + + +@pytest.fixture +def bool_frame_with_na(): + """ + Fixture for DataFrame of booleans with index of unique strings + + Columns are ['A', 'B', 'C', 'D']; some entries are missing + """ + df = DataFrame( + np.concatenate( + [np.ones((15, 4), dtype=bool), np.zeros((15, 4), dtype=bool)], axis=0 + ), + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD"), dtype=object), + dtype=object, + ) + # set some NAs + df.iloc[5:10] = np.nan + df.iloc[15:20, -2:] = np.nan + return df + + +@pytest.fixture +def float_frame_with_na(): + """ + Fixture for DataFrame of floats with index of unique strings + + Columns are ['A', 'B', 'C', 'D']; some entries are missing + """ + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 4)), + index=Index([f"foo_{i}" for i in range(30)], dtype=object), + columns=Index(list("ABCD"), dtype=object), + ) + # set some NAs + df.iloc[5:10] = np.nan + df.iloc[15:20, -2:] = np.nan + return df + + +class TestDataFrameAnalytics: + # --------------------------------------------------------------------- + # Reductions + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.parametrize( + "opname", + [ + "count", + "sum", + "mean", + "product", + "median", + "min", + "max", + "nunique", + "var", + "std", + "sem", + pytest.param("skew", marks=td.skip_if_no("scipy")), + pytest.param("kurt", marks=td.skip_if_no("scipy")), + ], + ) + def test_stat_op_api_float_string_frame(self, float_string_frame, axis, opname): + if (opname in ("sum", "min", "max") and axis == 0) or opname in ( + "count", + "nunique", + ): + getattr(float_string_frame, opname)(axis=axis) + else: + if opname in ["var", "std", "sem", "skew", "kurt"]: + msg = "could not convert string to float: 'bar'" + elif opname == "product": + if axis == 1: + msg = "can't multiply sequence by non-int of type 'float'" + else: + msg = "can't multiply sequence by non-int of type 'str'" + elif opname == "sum": + msg = r"unsupported operand type\(s\) for \+: 'float' and 'str'" + elif opname == "mean": + if axis == 0: + # different message on different builds + msg = "|".join( + [ + r"Could not convert \['.*'\] to numeric", + "Could not convert string '(bar){30}' to numeric", + ] + ) + else: + msg = r"unsupported operand type\(s\) for \+: 'float' and 'str'" + elif opname in ["min", "max"]: + msg = "'[><]=' not supported between instances of 'float' and 'str'" + elif opname == "median": + msg = re.compile( + r"Cannot convert \[.*\] to numeric|does not support|Cannot perform", + flags=re.S, + ) + if not isinstance(msg, re.Pattern): + msg = msg + "|does not support|Cannot perform reduction" + with pytest.raises(TypeError, match=msg): + getattr(float_string_frame, opname)(axis=axis) + if opname != "nunique": + getattr(float_string_frame, opname)(axis=axis, numeric_only=True) + + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.parametrize( + "opname", + [ + "count", + "sum", + "mean", + "product", + "median", + "min", + "max", + "var", + "std", + "sem", + pytest.param("skew", marks=td.skip_if_no("scipy")), + pytest.param("kurt", marks=td.skip_if_no("scipy")), + ], + ) + def test_stat_op_api_float_frame(self, float_frame, axis, opname): + getattr(float_frame, opname)(axis=axis, numeric_only=False) + + def test_stat_op_calc(self, float_frame_with_na, mixed_float_frame): + def count(s): + return notna(s).sum() + + def nunique(s): + return len(algorithms.unique1d(s.dropna())) + + def var(x): + return np.var(x, ddof=1) + + def std(x): + return np.std(x, ddof=1) + + def sem(x): + return np.std(x, ddof=1) / np.sqrt(len(x)) + + assert_stat_op_calc( + "nunique", + nunique, + float_frame_with_na, + has_skipna=False, + check_dtype=False, + check_dates=True, + ) + + # GH#32571: rol needed for flaky CI builds + # mixed types (with upcasting happening) + assert_stat_op_calc( + "sum", + np.sum, + mixed_float_frame.astype("float32"), + check_dtype=False, + rtol=1e-3, + ) + + assert_stat_op_calc( + "sum", np.sum, float_frame_with_na, skipna_alternative=np.nansum + ) + assert_stat_op_calc("mean", np.mean, float_frame_with_na, check_dates=True) + assert_stat_op_calc( + "product", np.prod, float_frame_with_na, skipna_alternative=np.nanprod + ) + + assert_stat_op_calc("var", var, float_frame_with_na) + assert_stat_op_calc("std", std, float_frame_with_na) + assert_stat_op_calc("sem", sem, float_frame_with_na) + + assert_stat_op_calc( + "count", + count, + float_frame_with_na, + has_skipna=False, + check_dtype=False, + check_dates=True, + ) + + def test_stat_op_calc_skew_kurtosis(self, float_frame_with_na): + sp_stats = pytest.importorskip("scipy.stats") + + def skewness(x): + if len(x) < 3: + return np.nan + return sp_stats.skew(x, bias=False) + + def kurt(x): + if len(x) < 4: + return np.nan + return sp_stats.kurtosis(x, bias=False) + + assert_stat_op_calc("skew", skewness, float_frame_with_na) + assert_stat_op_calc("kurt", kurt, float_frame_with_na) + + def test_median(self, float_frame_with_na, int_frame): + def wrapper(x): + if isna(x).any(): + return np.nan + return np.median(x) + + assert_stat_op_calc("median", wrapper, float_frame_with_na, check_dates=True) + assert_stat_op_calc( + "median", wrapper, int_frame, check_dtype=False, check_dates=True + ) + + @pytest.mark.parametrize( + "method", ["sum", "mean", "prod", "var", "std", "skew", "min", "max"] + ) + @pytest.mark.parametrize( + "df", + [ + DataFrame( + { + "a": [ + -0.00049987540199591344, + -0.0016467257772919831, + 0.00067695870775883013, + ], + "b": [-0, -0, 0.0], + "c": [ + 0.00031111847529610595, + 0.0014902627951905339, + -0.00094099200035979691, + ], + }, + index=["foo", "bar", "baz"], + dtype="O", + ), + DataFrame({0: [np.nan, 2], 1: [np.nan, 3], 2: [np.nan, 4]}, dtype=object), + ], + ) + @pytest.mark.filterwarnings("ignore:Mismatched null-like values:FutureWarning") + def test_stat_operators_attempt_obj_array(self, method, df, axis): + # GH#676 + assert df.values.dtype == np.object_ + result = getattr(df, method)(axis=axis) + expected = getattr(df.astype("f8"), method)(axis=axis).astype(object) + if axis in [1, "columns"] and method in ["min", "max"]: + expected[expected.isna()] = None + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("op", ["mean", "std", "var", "skew", "kurt", "sem"]) + def test_mixed_ops(self, op): + # GH#16116 + df = DataFrame( + { + "int": [1, 2, 3, 4], + "float": [1.0, 2.0, 3.0, 4.0], + "str": ["a", "b", "c", "d"], + } + ) + msg = "|".join( + [ + "Could not convert", + "could not convert", + "can't multiply sequence by non-int", + "does not support", + "Cannot perform", + ] + ) + with pytest.raises(TypeError, match=msg): + getattr(df, op)() + + with pd.option_context("use_bottleneck", False): + with pytest.raises(TypeError, match=msg): + getattr(df, op)() + + def test_reduce_mixed_frame(self): + # GH 6806 + df = DataFrame( + { + "bool_data": [True, True, False, False, False], + "int_data": [10, 20, 30, 40, 50], + "string_data": ["a", "b", "c", "d", "e"], + } + ) + df.reindex(columns=["bool_data", "int_data", "string_data"]) + test = df.sum(axis=0) + tm.assert_numpy_array_equal( + test.values, np.array([2, 150, "abcde"], dtype=object) + ) + alt = df.T.sum(axis=1) + tm.assert_series_equal(test, alt) + + def test_nunique(self): + df = DataFrame({"A": [1, 1, 1], "B": [1, 2, 3], "C": [1, np.nan, 3]}) + tm.assert_series_equal(df.nunique(), Series({"A": 1, "B": 3, "C": 2})) + tm.assert_series_equal( + df.nunique(dropna=False), Series({"A": 1, "B": 3, "C": 3}) + ) + tm.assert_series_equal(df.nunique(axis=1), Series([1, 2, 2])) + tm.assert_series_equal(df.nunique(axis=1, dropna=False), Series([1, 3, 2])) + + @pytest.mark.parametrize("tz", [None, "UTC"]) + def test_mean_mixed_datetime_numeric(self, tz): + # https://github.com/pandas-dev/pandas/issues/24752 + df = DataFrame({"A": [1, 1], "B": [Timestamp("2000", tz=tz)] * 2}) + result = df.mean() + expected = Series([1.0, Timestamp("2000", tz=tz)], index=["A", "B"]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("tz", [None, "UTC"]) + def test_mean_includes_datetimes(self, tz): + # https://github.com/pandas-dev/pandas/issues/24752 + # Behavior in 0.24.0rc1 was buggy. + # As of 2.0 with numeric_only=None we do *not* drop datetime columns + df = DataFrame({"A": [Timestamp("2000", tz=tz)] * 2}) + result = df.mean() + + expected = Series([Timestamp("2000", tz=tz)], index=["A"]) + tm.assert_series_equal(result, expected) + + def test_mean_mixed_string_decimal(self): + # GH 11670 + # possible bug when calculating mean of DataFrame? + + d = [ + {"A": 2, "B": None, "C": Decimal("628.00")}, + {"A": 1, "B": None, "C": Decimal("383.00")}, + {"A": 3, "B": None, "C": Decimal("651.00")}, + {"A": 2, "B": None, "C": Decimal("575.00")}, + {"A": 4, "B": None, "C": Decimal("1114.00")}, + {"A": 1, "B": "TEST", "C": Decimal("241.00")}, + {"A": 2, "B": None, "C": Decimal("572.00")}, + {"A": 4, "B": None, "C": Decimal("609.00")}, + {"A": 3, "B": None, "C": Decimal("820.00")}, + {"A": 5, "B": None, "C": Decimal("1223.00")}, + ] + + df = DataFrame(d) + + with pytest.raises( + TypeError, match="unsupported operand type|does not support|Cannot perform" + ): + df.mean() + result = df[["A", "C"]].mean() + expected = Series([2.7, 681.6], index=["A", "C"], dtype=object) + tm.assert_series_equal(result, expected) + + def test_var_std(self, datetime_frame): + result = datetime_frame.std(ddof=4) + expected = datetime_frame.apply(lambda x: x.std(ddof=4)) + tm.assert_almost_equal(result, expected) + + result = datetime_frame.var(ddof=4) + expected = datetime_frame.apply(lambda x: x.var(ddof=4)) + tm.assert_almost_equal(result, expected) + + arr = np.repeat(np.random.default_rng(2).random((1, 1000)), 1000, 0) + result = nanops.nanvar(arr, axis=0) + assert not (result < 0).any() + + with pd.option_context("use_bottleneck", False): + result = nanops.nanvar(arr, axis=0) + assert not (result < 0).any() + + @pytest.mark.parametrize("meth", ["sem", "var", "std"]) + def test_numeric_only_flag(self, meth): + # GH 9201 + df1 = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), + columns=["foo", "bar", "baz"], + ) + # Cast to object to avoid implicit cast when setting entry to "100" below + df1 = df1.astype({"foo": object}) + # set one entry to a number in str format + df1.loc[0, "foo"] = "100" + + df2 = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), + columns=["foo", "bar", "baz"], + ) + # Cast to object to avoid implicit cast when setting entry to "a" below + df2 = df2.astype({"foo": object}) + # set one entry to a non-number str + df2.loc[0, "foo"] = "a" + + result = getattr(df1, meth)(axis=1, numeric_only=True) + expected = getattr(df1[["bar", "baz"]], meth)(axis=1) + tm.assert_series_equal(expected, result) + + result = getattr(df2, meth)(axis=1, numeric_only=True) + expected = getattr(df2[["bar", "baz"]], meth)(axis=1) + tm.assert_series_equal(expected, result) + + # df1 has all numbers, df2 has a letter inside + msg = r"unsupported operand type\(s\) for -: 'float' and 'str'" + with pytest.raises(TypeError, match=msg): + getattr(df1, meth)(axis=1, numeric_only=False) + msg = "could not convert string to float: 'a'" + with pytest.raises(TypeError, match=msg): + getattr(df2, meth)(axis=1, numeric_only=False) + + def test_sem(self, datetime_frame): + result = datetime_frame.sem(ddof=4) + expected = datetime_frame.apply(lambda x: x.std(ddof=4) / np.sqrt(len(x))) + tm.assert_almost_equal(result, expected) + + arr = np.repeat(np.random.default_rng(2).random((1, 1000)), 1000, 0) + result = nanops.nansem(arr, axis=0) + assert not (result < 0).any() + + with pd.option_context("use_bottleneck", False): + result = nanops.nansem(arr, axis=0) + assert not (result < 0).any() + + @pytest.mark.parametrize( + "dropna, expected", + [ + ( + True, + { + "A": [12], + "B": [10.0], + "C": [1.0], + "D": ["a"], + "E": Categorical(["a"], categories=["a"]), + "F": DatetimeIndex(["2000-01-02"], dtype="M8[ns]"), + "G": to_timedelta(["1 days"]), + }, + ), + ( + False, + { + "A": [12], + "B": [10.0], + "C": [np.nan], + "D": Series([np.nan], dtype="str"), + "E": Categorical([np.nan], categories=["a"]), + "F": DatetimeIndex([pd.NaT], dtype="M8[ns]"), + "G": to_timedelta([pd.NaT]).as_unit("us"), + }, + ), + ( + True, + { + "H": [8, 9, np.nan, np.nan], + "I": [8, 9, np.nan, np.nan], + "J": [1, np.nan, np.nan, np.nan], + "K": Categorical(["a", np.nan, np.nan, np.nan], categories=["a"]), + "L": DatetimeIndex( + ["2000-01-02", "NaT", "NaT", "NaT"], dtype="M8[ns]" + ), + "M": to_timedelta(["1 days", "nan", "nan", "nan"]), + "N": [0, 1, 2, 3], + }, + ), + ( + False, + { + "H": [8, 9, np.nan, np.nan], + "I": [8, 9, np.nan, np.nan], + "J": [1, np.nan, np.nan, np.nan], + "K": Categorical([np.nan, "a", np.nan, np.nan], categories=["a"]), + "L": DatetimeIndex( + ["NaT", "2000-01-02", "NaT", "NaT"], dtype="M8[ns]" + ), + "M": to_timedelta(["nan", "1 days", "nan", "nan"]), + "N": [0, 1, 2, 3], + }, + ), + ], + ) + def test_mode_dropna(self, dropna, expected): + df = DataFrame( + { + "A": [12, 12, 19, 11], + "B": [10, 10, np.nan, 3], + "C": [1, np.nan, np.nan, np.nan], + "D": Series([np.nan, np.nan, "a", np.nan], dtype="str"), + "E": Categorical([np.nan, np.nan, "a", np.nan]), + "F": DatetimeIndex(["NaT", "2000-01-02", "NaT", "NaT"], dtype="M8[ns]"), + "G": to_timedelta(["1 days", "nan", "nan", "nan"]), + "H": [8, 8, 9, 9], + "I": [9, 9, 8, 8], + "J": [1, 1, np.nan, np.nan], + "K": Categorical(["a", np.nan, "a", np.nan]), + "L": DatetimeIndex( + ["2000-01-02", "2000-01-02", "NaT", "NaT"], dtype="M8[ns]" + ), + "M": to_timedelta(["1 days", "nan", "1 days", "nan"]), + "N": np.arange(4, dtype="int64"), + } + ) + + result = df[sorted(expected.keys())].mode(dropna=dropna) + expected = DataFrame(expected) + tm.assert_frame_equal(result, expected) + + def test_mode_sort_with_na(self, using_infer_string): + df = DataFrame({"A": [np.nan, np.nan, "a", "a"]}) + expected = DataFrame({"A": ["a", np.nan]}) + result = df.mode(dropna=False) + tm.assert_frame_equal(result, expected) + + def test_mode_empty_df(self): + df = DataFrame([], columns=["a", "b"]) + expected = df.copy() + result = df.mode() + tm.assert_frame_equal(result, expected) + + def test_operators_timedelta64(self): + df = DataFrame( + { + "A": date_range("2012-1-1", periods=3, freq="D", unit="ns"), + "B": date_range("2012-1-2", periods=3, freq="D", unit="ns"), + "C": Timestamp("20120101") - timedelta(minutes=5, seconds=5), + } + ) + + diffs = DataFrame({"A": df["A"] - df["C"], "B": df["A"] - df["B"]}) + + # min + result = diffs.min() + assert result.iloc[0] == diffs.loc[0, "A"] + assert result.iloc[1] == diffs.loc[0, "B"] + + result = diffs.min(axis=1) + assert (result == diffs.loc[0, "B"]).all() + + # max + result = diffs.max() + assert result.iloc[0] == diffs.loc[2, "A"] + assert result.iloc[1] == diffs.loc[2, "B"] + + result = diffs.max(axis=1) + assert (result == diffs["A"]).all() + + # abs + result = diffs.abs() + result2 = abs(diffs) + expected = DataFrame({"A": df["A"] - df["C"], "B": df["B"] - df["A"]}) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + # mixed frame + mixed = diffs.copy() + mixed["C"] = "foo" + mixed["D"] = 1 + mixed["E"] = 1.0 + mixed["F"] = Timestamp("20130101") + + # results in an object array + result = mixed.min() + expected = Series( + [ + pd.Timedelta(timedelta(seconds=5 * 60 + 5)), + pd.Timedelta(timedelta(days=-1)), + "foo", + 1, + 1.0, + Timestamp("20130101"), + ], + index=mixed.columns, + ) + tm.assert_series_equal(result, expected) + + # excludes non-numeric + result = mixed.min(axis=1, numeric_only=True) + expected = Series([1, 1, 1.0]) + tm.assert_series_equal(result, expected) + + # works when only those columns are selected + result = mixed[["A", "B"]].min(axis=1) + expected = Series([timedelta(days=-1)] * 3, dtype="m8[ns]") + tm.assert_series_equal(result, expected) + + result = mixed[["A", "B"]].min() + expected = Series( + [timedelta(seconds=5 * 60 + 5), timedelta(days=-1)], + index=["A", "B"], + dtype="m8[ns]", + ) + tm.assert_series_equal(result, expected) + + # GH 3106 + df = DataFrame( + { + "time": date_range("20130102", periods=5, unit="ns"), + "time2": date_range("20130105", periods=5, unit="ns"), + } + ) + df["off1"] = df["time2"] - df["time"] + assert df["off1"].dtype == "timedelta64[ns]" + + df["off2"] = df["time"] - df["time2"] + df._consolidate_inplace() + assert df["off1"].dtype == "timedelta64[ns]" + assert df["off2"].dtype == "timedelta64[ns]" + + def test_std_timedelta64_skipna_false(self): + # GH#37392 + tdi = pd.timedelta_range("1 Day", periods=10) + df = DataFrame({"A": tdi, "B": tdi}, copy=True) + df.iloc[-2, -1] = pd.NaT + + result = df.std(skipna=False) + expected = Series( + [df["A"].std(), pd.NaT], index=["A", "B"], dtype="timedelta64[us]" + ) + tm.assert_series_equal(result, expected) + + result = df.std(axis=1, skipna=False) + expected = Series( + [pd.Timedelta(0)] * 8 + [pd.NaT, pd.Timedelta(0)], dtype="m8[us]" + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "values", [["2022-01-01", "2022-01-02", pd.NaT, "2022-01-03"], 4 * [pd.NaT]] + ) + def test_std_datetime64_with_nat(self, values, skipna, request, unit): + # GH#51335 + dti = to_datetime(values).as_unit(unit) + df = DataFrame({"a": dti}) + result = df.std(skipna=skipna) + if not skipna or all(value is pd.NaT for value in values): + expected = Series({"a": pd.NaT}, dtype=f"timedelta64[{unit}]") + else: + expected = Series({"a": "1 days"}, dtype=f"timedelta64[{unit}]") + tm.assert_series_equal(result, expected) + + def test_sum_corner(self): + empty_frame = DataFrame() + + axis0 = empty_frame.sum(axis=0) + axis1 = empty_frame.sum(axis=1) + assert isinstance(axis0, Series) + assert isinstance(axis1, Series) + assert len(axis0) == 0 + assert len(axis1) == 0 + + @pytest.mark.parametrize( + "index", + [ + RangeIndex(0), + DatetimeIndex([]), + Index([], dtype=np.int64), + Index([], dtype=np.float64), + DatetimeIndex([], freq="ME"), + PeriodIndex([], freq="D"), + ], + ) + def test_axis_1_empty(self, all_reductions, index): + df = DataFrame(columns=["a"], index=index) + result = getattr(df, all_reductions)(axis=1) + if all_reductions in ("any", "all"): + expected_dtype = "bool" + elif all_reductions == "count": + expected_dtype = "int64" + else: + expected_dtype = "object" + expected = Series([], index=index, dtype=expected_dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("min_count", [0, 1]) + def test_axis_1_sum_na(self, string_dtype_no_object, skipna, min_count): + # https://github.com/pandas-dev/pandas/issues/60229 + dtype = string_dtype_no_object + df = DataFrame({"a": [pd.NA]}, dtype=dtype) + result = df.sum(axis=1, skipna=skipna, min_count=min_count) + value = "" if skipna and min_count == 0 else pd.NA + expected = Series([value], dtype=dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("method, unit", [("sum", 0), ("prod", 1)]) + @pytest.mark.parametrize("numeric_only", [None, True, False]) + def test_sum_prod_nanops(self, method, unit, numeric_only): + idx = ["a", "b", "c"] + df = DataFrame({"a": [unit, unit], "b": [unit, np.nan], "c": [np.nan, np.nan]}) + # The default + result = getattr(df, method)(numeric_only=numeric_only) + expected = Series([unit, unit, unit], index=idx, dtype="float64") + tm.assert_series_equal(result, expected) + + # min_count=1 + result = getattr(df, method)(numeric_only=numeric_only, min_count=1) + expected = Series([unit, unit, np.nan], index=idx) + tm.assert_series_equal(result, expected) + + # min_count=0 + result = getattr(df, method)(numeric_only=numeric_only, min_count=0) + expected = Series([unit, unit, unit], index=idx, dtype="float64") + tm.assert_series_equal(result, expected) + + result = getattr(df.iloc[1:], method)(numeric_only=numeric_only, min_count=1) + expected = Series([unit, np.nan, np.nan], index=idx) + tm.assert_series_equal(result, expected) + + # min_count > 1 + df = DataFrame({"A": [unit] * 10, "B": [unit] * 5 + [np.nan] * 5}) + result = getattr(df, method)(numeric_only=numeric_only, min_count=5) + expected = Series(result, index=["A", "B"]) + tm.assert_series_equal(result, expected) + + result = getattr(df, method)(numeric_only=numeric_only, min_count=6) + expected = Series(result, index=["A", "B"]) + tm.assert_series_equal(result, expected) + + def test_sum_nanops_timedelta(self): + # prod isn't defined on timedeltas + idx = ["a", "b", "c"] + df = DataFrame({"a": [0, 0], "b": [0, np.nan], "c": [np.nan, np.nan]}) + + df2 = df.apply(to_timedelta) + + # 0 by default + result = df2.sum() + expected = Series([0, 0, 0], dtype="m8[ns]", index=idx) + tm.assert_series_equal(result, expected) + + # min_count=0 + result = df2.sum(min_count=0) + tm.assert_series_equal(result, expected) + + # min_count=1 + result = df2.sum(min_count=1) + expected = Series([0, 0, np.nan], dtype="m8[ns]", index=idx) + tm.assert_series_equal(result, expected) + + def test_sum_nanops_min_count(self): + # https://github.com/pandas-dev/pandas/issues/39738 + df = DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + result = df.sum(min_count=10) + expected = Series([np.nan, np.nan], index=["x", "y"]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("float_type", ["float16", "float32", "float64"]) + @pytest.mark.parametrize( + "kwargs, expected_result", + [ + ({"axis": 1, "min_count": 2}, [3.2, 5.3, np.nan]), + ({"axis": 1, "min_count": 3}, [np.nan, np.nan, np.nan]), + ({"axis": 1, "skipna": False}, [3.2, 5.3, np.nan]), + ], + ) + def test_sum_nanops_dtype_min_count(self, float_type, kwargs, expected_result): + # GH#46947 + df = DataFrame({"a": [1.0, 2.3, 4.4], "b": [2.2, 3, np.nan]}, dtype=float_type) + result = df.sum(**kwargs) + expected = Series(expected_result).astype(float_type) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("float_type", ["float16", "float32", "float64"]) + @pytest.mark.parametrize( + "kwargs, expected_result", + [ + ({"axis": 1, "min_count": 2}, [2.0, 4.0, np.nan]), + ({"axis": 1, "min_count": 3}, [np.nan, np.nan, np.nan]), + ({"axis": 1, "skipna": False}, [2.0, 4.0, np.nan]), + ], + ) + def test_prod_nanops_dtype_min_count(self, float_type, kwargs, expected_result): + # GH#46947 + df = DataFrame( + {"a": [1.0, 2.0, 4.4], "b": [2.0, 2.0, np.nan]}, dtype=float_type + ) + result = df.prod(**kwargs) + expected = Series(expected_result).astype(float_type) + tm.assert_series_equal(result, expected) + + def test_sum_object(self, float_frame): + values = float_frame.values.astype(int) + frame = DataFrame(values, index=float_frame.index, columns=float_frame.columns) + deltas = frame * timedelta(1) + deltas.sum() + + def test_sum_bool(self, float_frame): + # ensure this works, bug report + bools = np.isnan(float_frame) + bools.sum(axis=1) + bools.sum(axis=0) + + def test_sum_mixed_datetime(self): + # GH#30886 + df = DataFrame({"A": date_range("2000", periods=4), "B": [1, 2, 3, 4]}).reindex( + [2, 3, 4] + ) + with pytest.raises(TypeError, match="does not support operation 'sum'"): + df.sum() + + def test_mean_corner(self, float_frame, float_string_frame): + # unit test when have object data + msg = "Could not convert|does not support|Cannot perform" + with pytest.raises(TypeError, match=msg): + float_string_frame.mean(axis=0) + + # xs sum mixed type, just want to know it works... + with pytest.raises(TypeError, match="unsupported operand type"): + float_string_frame.mean(axis=1) + + # take mean of boolean column + float_frame["bool"] = float_frame["A"] > 0 + means = float_frame.mean(axis=0) + assert means["bool"] == float_frame["bool"].values.mean() + + def test_mean_datetimelike(self): + # GH#24757 check that datetimelike are excluded by default, handled + # correctly with numeric_only=True + # As of 2.0, datetimelike are *not* excluded with numeric_only=None + + df = DataFrame( + { + "A": np.arange(3), + "B": date_range("2016-01-01", periods=3), + "C": pd.timedelta_range("1D", periods=3), + "D": pd.period_range("2016", periods=3, freq="Y"), + } + ) + result = df.mean(numeric_only=True) + expected = Series({"A": 1.0}) + tm.assert_series_equal(result, expected) + + with pytest.raises(TypeError, match="mean is not implemented for PeriodArray"): + df.mean() + + def test_mean_datetimelike_numeric_only_false(self): + df = DataFrame( + { + "A": np.arange(3), + "B": date_range("2016-01-01", periods=3), + "C": pd.timedelta_range("1D", periods=3), + } + ) + + # datetime(tz) and timedelta work + result = df.mean(numeric_only=False) + expected = Series({"A": 1, "B": df.loc[1, "B"], "C": df.loc[1, "C"]}) + tm.assert_series_equal(result, expected) + + # mean of period is not allowed + df["D"] = pd.period_range("2016", periods=3, freq="Y") + + with pytest.raises(TypeError, match="mean is not implemented for Period"): + df.mean(numeric_only=False) + + def test_mean_extensionarray_numeric_only_true(self): + # https://github.com/pandas-dev/pandas/issues/33256 + arr = np.random.default_rng(2).integers(1000, size=(10, 5)) + df = DataFrame(arr, dtype="Int64") + result = df.mean(numeric_only=True) + expected = DataFrame(arr).mean().astype("Float64") + tm.assert_series_equal(result, expected) + + def test_stats_mixed_type(self, float_string_frame): + with pytest.raises(TypeError, match="could not convert"): + float_string_frame.std(axis=1) + with pytest.raises(TypeError, match="could not convert"): + float_string_frame.var(axis=1) + with pytest.raises(TypeError, match="unsupported operand type"): + float_string_frame.mean(axis=1) + with pytest.raises(TypeError, match="could not convert"): + float_string_frame.skew(axis=1) + + def test_sum_bools(self): + df = DataFrame(index=range(1), columns=range(10)) + bools = isna(df) + assert bools.sum(axis=1)[0] == 10 + + @pytest.mark.parametrize( + "input_data, expected_data", + [ + ({"a": ["483", "3"], "b": ["94", "759"]}, ["48394", "3759"]), + ( + {"a": ["483.948", "3.0"], "b": ["94.2", "759.93"]}, + ["483.94894.2", "3.0759.93"], + ), + ({"a": ["483", "3.0"], "b": ["94.2", "79"]}, ["48394.2", "3.079"]), + ], + ) + def test_sum_string_dtype_coercion(self, input_data, expected_data): + # GH#22642 + # Check that summing numeric strings results in concatenation + # and not conversion to dtype int64 or float64 + df = DataFrame(input_data) + expected = Series(expected_data) + result = df.sum(axis=1) + tm.assert_series_equal(result, expected) + + # ---------------------------------------------------------------------- + # Index of max / min + + @pytest.mark.parametrize("axis", [0, 1]) + def test_idxmin(self, float_frame, int_frame, skipna, axis): + frame = float_frame + frame.iloc[5:10] = np.nan + frame.iloc[15:20, -2:] = np.nan + for df in [frame, int_frame]: + if (not skipna or axis == 1) and df is not int_frame: + if skipna: + msg = "Encountered all NA values" + else: + msg = "Encountered an NA value" + with pytest.raises(ValueError, match=msg): + df.idxmin(axis=axis, skipna=skipna) + with pytest.raises(ValueError, match=msg): + df.idxmin(axis=axis, skipna=skipna) + else: + result = df.idxmin(axis=axis, skipna=skipna) + expected = df.apply(Series.idxmin, axis=axis, skipna=skipna) + expected = expected.astype(df.index.dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_idxmin_empty(self, index, skipna, axis): + # GH53265 + if axis == 0: + frame = DataFrame(index=index) + else: + frame = DataFrame(columns=index) + + result = frame.idxmin(axis=axis, skipna=skipna) + expected = Series(dtype=index.dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("numeric_only", [True, False]) + def test_idxmin_numeric_only(self, numeric_only): + df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")}) + result = df.idxmin(numeric_only=numeric_only) + if numeric_only: + expected = Series([2, 1], index=["a", "b"]) + else: + expected = Series([2, 1, 0], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + def test_idxmin_axis_2(self, float_frame): + frame = float_frame + msg = "No axis named 2 for object type DataFrame" + with pytest.raises(ValueError, match=msg): + frame.idxmin(axis=2) + + @pytest.mark.parametrize("axis", [0, 1]) + def test_idxmax(self, float_frame, int_frame, skipna, axis): + frame = float_frame + frame.iloc[5:10] = np.nan + frame.iloc[15:20, -2:] = np.nan + for df in [frame, int_frame]: + if (skipna is False or axis == 1) and df is frame: + if skipna: + msg = "Encountered all NA values" + else: + msg = "Encountered an NA value" + with pytest.raises(ValueError, match=msg): + df.idxmax(axis=axis, skipna=skipna) + return + + result = df.idxmax(axis=axis, skipna=skipna) + expected = df.apply(Series.idxmax, axis=axis, skipna=skipna) + expected = expected.astype(df.index.dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_idxmax_empty(self, index, skipna, axis): + # GH53265 + if axis == 0: + frame = DataFrame(index=index) + else: + frame = DataFrame(columns=index) + + result = frame.idxmax(axis=axis, skipna=skipna) + expected = Series(dtype=index.dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("numeric_only", [True, False]) + def test_idxmax_numeric_only(self, numeric_only): + df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")}) + result = df.idxmax(numeric_only=numeric_only) + if numeric_only: + expected = Series([1, 0], index=["a", "b"]) + else: + expected = Series([1, 0, 1], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + def test_idxmax_arrow_types(self): + # GH#55368 + pytest.importorskip("pyarrow") + + df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1]}, dtype="int64[pyarrow]") + result = df.idxmax() + expected = Series([1, 0], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + result = df.idxmin() + expected = Series([2, 1], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]") + result = df.idxmax(numeric_only=False) + expected = Series([1], index=["a"]) + tm.assert_series_equal(result, expected) + + result = df.idxmin(numeric_only=False) + expected = Series([2], index=["a"]) + tm.assert_series_equal(result, expected) + + def test_idxmax_axis_2(self, float_frame): + frame = float_frame + msg = "No axis named 2 for object type DataFrame" + with pytest.raises(ValueError, match=msg): + frame.idxmax(axis=2) + + def test_idxmax_mixed_dtype(self): + # don't cast to object, which would raise in nanops + dti = date_range("2016-01-01", periods=3) + df = DataFrame({1: [0, 2, 1], 2: range(3)[::-1], 3: dti}) + + result = df.idxmax() + expected = Series([1, 0, 2], index=range(1, 4)) + tm.assert_series_equal(result, expected) + + result = df.idxmin() + expected = Series([0, 2, 0], index=range(1, 4)) + tm.assert_series_equal(result, expected) + + # with NaTs + df.loc[0, 3] = pd.NaT + result = df.idxmax() + expected = Series([1, 0, 2], index=range(1, 4)) + tm.assert_series_equal(result, expected) + + result = df.idxmin() + expected = Series([0, 2, 1], index=range(1, 4)) + tm.assert_series_equal(result, expected) + + # with multi-column dt64 block + df[4] = dti[::-1] + df._consolidate_inplace() + + result = df.idxmax() + expected = Series([1, 0, 2, 0], index=range(1, 5)) + tm.assert_series_equal(result, expected) + + result = df.idxmin() + expected = Series([0, 2, 1, 2], index=range(1, 5)) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "op, expected_value", + [("idxmax", [0, 4]), ("idxmin", [0, 5])], + ) + def test_idxmax_idxmin_convert_dtypes(self, op, expected_value): + # GH 40346 + df = DataFrame( + { + "ID": [100, 100, 100, 200, 200, 200], + "value": [0, 0, 0, 1, 2, 0], + }, + dtype="Int64", + ) + df = df.groupby("ID") + + result = getattr(df, op)() + expected = DataFrame( + {"value": expected_value}, + index=Index([100, 200], name="ID", dtype="Int64"), + ) + tm.assert_frame_equal(result, expected) + + def test_idxmax_dt64_multicolumn_axis1(self): + dti = date_range("2016-01-01", periods=3) + df = DataFrame({3: dti, 4: dti[::-1]}, copy=True) + df.iloc[0, 0] = pd.NaT + + df._consolidate_inplace() + + result = df.idxmax(axis=1) + expected = Series([4, 3, 3]) + tm.assert_series_equal(result, expected) + + result = df.idxmin(axis=1) + expected = Series([4, 3, 4]) + tm.assert_series_equal(result, expected) + + # ---------------------------------------------------------------------- + # Logical reductions + + @pytest.mark.parametrize("axis", [0, 1]) + @pytest.mark.parametrize("bool_only", [False, True]) + def test_any_all_mixed_float( + self, all_boolean_reductions, axis, bool_only, float_string_frame + ): + # make sure op works on mixed-type frame + mixed = float_string_frame + mixed["_bool_"] = np.random.default_rng(2).standard_normal(len(mixed)) > 0.5 + + getattr(mixed, all_boolean_reductions)(axis=axis, bool_only=bool_only) + + @pytest.mark.parametrize("axis", [0, 1]) + def test_any_all_bool_with_na( + self, all_boolean_reductions, axis, bool_frame_with_na + ): + getattr(bool_frame_with_na, all_boolean_reductions)(axis=axis, bool_only=False) + + def test_any_all_bool_frame(self, all_boolean_reductions, bool_frame_with_na): + # GH#12863: numpy gives back non-boolean data for object type + # so fill NaNs to compare with pandas behavior + frame = bool_frame_with_na.fillna(True) + alternative = getattr(np, all_boolean_reductions) + f = getattr(frame, all_boolean_reductions) + + def skipna_wrapper(x): + nona = x.dropna().values + return alternative(nona) + + def wrapper(x): + return alternative(x.values) + + result0 = f(axis=0, skipna=False) + result1 = f(axis=1, skipna=False) + + tm.assert_series_equal(result0, frame.apply(wrapper)) + tm.assert_series_equal(result1, frame.apply(wrapper, axis=1)) + + result0 = f(axis=0) + result1 = f(axis=1) + + tm.assert_series_equal(result0, frame.apply(skipna_wrapper)) + tm.assert_series_equal( + result1, frame.apply(skipna_wrapper, axis=1), check_dtype=False + ) + + # bad axis + with pytest.raises(ValueError, match="No axis named 2"): + f(axis=2) + + # all NA case + all_na = frame * np.nan + r0 = getattr(all_na, all_boolean_reductions)(axis=0) + r1 = getattr(all_na, all_boolean_reductions)(axis=1) + if all_boolean_reductions == "any": + assert not r0.any() + assert not r1.any() + else: + assert r0.all() + assert r1.all() + + def test_any_all_extra(self, using_python_scalars): + df = DataFrame( + { + "A": [True, False, False], + "B": [True, True, False], + "C": [True, True, True], + }, + index=["a", "b", "c"], + ) + result = df[["A", "B"]].any(axis=1) + expected = Series([True, True, False], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + result = df[["A", "B"]].any(axis=1, bool_only=True) + tm.assert_series_equal(result, expected) + + result = df.all(axis=1) + expected = Series([True, False, False], index=["a", "b", "c"]) + tm.assert_series_equal(result, expected) + + result = df.all(axis=1, bool_only=True) + tm.assert_series_equal(result, expected) + + # Axis is None + result = df.all(axis=None) + if not using_python_scalars: + result = result.item() + assert result is False + + result = df.any(axis=None) + if not using_python_scalars: + result = result.item() + assert result is True + + result = df[["C"]].all(axis=None) + if not using_python_scalars: + result = result.item() + assert result is True + + @pytest.mark.parametrize("axis", [0, 1]) + def test_any_all_object_dtype(self, axis, all_boolean_reductions, skipna): + # GH#35450 + df = DataFrame( + data=[ + [1, np.nan, np.nan, True], + [np.nan, 2, np.nan, True], + [np.nan, np.nan, np.nan, True], + [np.nan, np.nan, "5", np.nan], + ] + ) + result = getattr(df, all_boolean_reductions)(axis=axis, skipna=skipna) + expected = Series([True, True, True, True]) + tm.assert_series_equal(result, expected) + + def test_any_datetime(self): + # GH 23070 + float_data = [1, np.nan, 3, np.nan] + datetime_data = [ + Timestamp("1960-02-15"), + Timestamp("1960-02-16"), + pd.NaT, + pd.NaT, + ] + df = DataFrame({"A": float_data, "B": datetime_data}) + + msg = "datetime64 type does not support operation 'any'" + with pytest.raises(TypeError, match=msg): + df.any(axis=1) + + def test_any_all_bool_only(self): + # GH 25101 + df = DataFrame( + {"col1": [1, 2, 3], "col2": [4, 5, 6], "col3": [None, None, None]}, + columns=Index(["col1", "col2", "col3"], dtype=object), + ) + + result = df.all(bool_only=True) + expected = Series(dtype=np.bool_, index=[]) + tm.assert_series_equal(result, expected) + + df = DataFrame( + { + "col1": [1, 2, 3], + "col2": [4, 5, 6], + "col3": [None, None, None], + "col4": [False, False, True], + } + ) + + result = df.all(bool_only=True) + expected = Series({"col4": False}) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "func, data, expected", + [ + (np.any, {}, False), + (np.all, {}, True), + (np.any, {"A": []}, False), + (np.all, {"A": []}, True), + (np.any, {"A": [False, False]}, False), + (np.all, {"A": [False, False]}, False), + (np.any, {"A": [True, False]}, True), + (np.all, {"A": [True, False]}, False), + (np.any, {"A": [True, True]}, True), + (np.all, {"A": [True, True]}, True), + (np.any, {"A": [False], "B": [False]}, False), + (np.all, {"A": [False], "B": [False]}, False), + (np.any, {"A": [False, False], "B": [False, True]}, True), + (np.all, {"A": [False, False], "B": [False, True]}, False), + # other types + (np.all, {"A": Series([0.0, 1.0], dtype="float")}, False), + (np.any, {"A": Series([0.0, 1.0], dtype="float")}, True), + (np.all, {"A": Series([0, 1], dtype=int)}, False), + (np.any, {"A": Series([0, 1], dtype=int)}, True), + pytest.param(np.all, {"A": Series([0, 1], dtype="M8[ns]")}, False), + pytest.param(np.all, {"A": Series([0, 1], dtype="M8[ns, UTC]")}, False), + pytest.param(np.any, {"A": Series([0, 1], dtype="M8[ns]")}, True), + pytest.param(np.any, {"A": Series([0, 1], dtype="M8[ns, UTC]")}, True), + pytest.param(np.all, {"A": Series([1, 2], dtype="M8[ns]")}, True), + pytest.param(np.all, {"A": Series([1, 2], dtype="M8[ns, UTC]")}, True), + pytest.param(np.any, {"A": Series([1, 2], dtype="M8[ns]")}, True), + pytest.param(np.any, {"A": Series([1, 2], dtype="M8[ns, UTC]")}, True), + pytest.param(np.all, {"A": Series([0, 1], dtype="m8[ns]")}, False), + pytest.param(np.any, {"A": Series([0, 1], dtype="m8[ns]")}, True), + pytest.param(np.all, {"A": Series([1, 2], dtype="m8[ns]")}, True), + pytest.param(np.any, {"A": Series([1, 2], dtype="m8[ns]")}, True), + # np.all on Categorical raises, so the reduction drops the + # column, so all is being done on an empty Series, so is True + (np.all, {"A": Series([0, 1], dtype="category")}, True), + (np.any, {"A": Series([0, 1], dtype="category")}, False), + (np.all, {"A": Series([1, 2], dtype="category")}, True), + (np.any, {"A": Series([1, 2], dtype="category")}, False), + # Mix GH#21484 + pytest.param( + np.all, + { + "A": Series([10, 20], dtype="M8[ns]"), + "B": Series([10, 20], dtype="m8[ns]"), + }, + True, + ), + ], + ) + def test_any_all_np_func(self, func, data, expected, using_python_scalars): + # GH 19976 + data = DataFrame(data) + + if any(isinstance(x, CategoricalDtype) for x in data.dtypes): + with pytest.raises( + TypeError, match=".* dtype category does not support operation" + ): + func(data) + + # method version + with pytest.raises( + TypeError, match=".* dtype category does not support operation" + ): + getattr(DataFrame(data), func.__name__)(axis=None) + if data.dtypes.apply(lambda x: x.kind == "M").any(): + # GH#34479 + msg = "datetime64 type does not support operation '(any|all)'" + with pytest.raises(TypeError, match=msg): + func(data) + + # method version + with pytest.raises(TypeError, match=msg): + getattr(DataFrame(data), func.__name__)(axis=None) + + elif data.dtypes.apply(lambda x: x != "category").any(): + result = func(data) + if using_python_scalars: + assert result is expected + else: + assert isinstance(result, np.bool_) + assert result.item() is expected + + # method version + result = getattr(DataFrame(data), func.__name__)(axis=None) + if using_python_scalars: + assert result is expected + else: + assert isinstance(result, np.bool_) + assert result.item() is expected + + def test_any_all_object(self, using_python_scalars): + # GH 19976 + result = np.all(DataFrame(columns=["a", "b"])) + if not using_python_scalars: + result = result.item() + assert result is True + + result = np.any(DataFrame(columns=["a", "b"])) + if not using_python_scalars: + result = result.item() + assert result is False + + def test_any_all_object_bool_only(self): + df = DataFrame({"A": ["foo", 2], "B": [True, False]}).astype(object) + df._consolidate_inplace() + df["C"] = Series([True, True]) + + # Categorical of bools is _not_ considered booly + df["D"] = df["C"].astype("category") + + # The underlying bug is in DataFrame._get_bool_data, so we check + # that while we're here + res = df._get_bool_data() + expected = df[["C"]] + tm.assert_frame_equal(res, expected) + + res = df.all(bool_only=True, axis=0) + expected = Series([True], index=["C"]) + tm.assert_series_equal(res, expected) + + # operating on a subset of columns should not produce a _larger_ Series + res = df[["B", "C"]].all(bool_only=True, axis=0) + tm.assert_series_equal(res, expected) + + assert df.all(bool_only=True, axis=None) + + res = df.any(bool_only=True, axis=0) + expected = Series([True], index=["C"]) + tm.assert_series_equal(res, expected) + + # operating on a subset of columns should not produce a _larger_ Series + res = df[["C"]].any(bool_only=True, axis=0) + tm.assert_series_equal(res, expected) + + assert df.any(bool_only=True, axis=None) + + # --------------------------------------------------------------------- + # Unsorted + + def test_series_broadcasting(self): + # smoke test for numpy warnings + # GH 16378, GH 16306 + df = DataFrame([1.0, 1.0, 1.0]) + df_nan = DataFrame({"A": [np.nan, 2.0, np.nan]}) + s = Series([1, 1, 1]) + s_nan = Series([np.nan, np.nan, 1]) + + with tm.assert_produces_warning(None): + df_nan.clip(lower=s, axis=0) + for op in ["lt", "le", "gt", "ge", "eq", "ne"]: + getattr(df, op)(s_nan, axis=0) + + +class TestDataFrameReductions: + def test_min_max_dt64_with_NaT(self): + # Both NaT and Timestamp are in DataFrame. + df = DataFrame({"foo": [pd.NaT, pd.NaT, Timestamp("2012-05-01")]}) + + res = df.min() + exp = Series([Timestamp("2012-05-01")], index=["foo"]) + tm.assert_series_equal(res, exp) + + res = df.max() + exp = Series([Timestamp("2012-05-01")], index=["foo"]) + tm.assert_series_equal(res, exp) + + # GH12941, only NaTs are in DataFrame. + df = DataFrame({"foo": [pd.NaT, pd.NaT]}) + + res = df.min() + exp = Series([pd.NaT], index=["foo"]) + tm.assert_series_equal(res, exp) + + res = df.max() + exp = Series([pd.NaT], index=["foo"]) + tm.assert_series_equal(res, exp) + + def test_min_max_dt64_with_NaT_precision(self): + # GH#60646 Make sure the reduction doesn't cast input timestamps to + # float and lose precision. + df = DataFrame( + {"foo": [pd.NaT, pd.NaT, Timestamp("2012-05-01 09:20:00.123456789")]}, + dtype="datetime64[ns]", + ) + + res = df.min(axis=1) + exp = df.foo.rename(None) + tm.assert_series_equal(res, exp) + + res = df.max(axis=1) + exp = df.foo.rename(None) + tm.assert_series_equal(res, exp) + + def test_min_max_td64_with_NaT_precision(self): + # GH#60646 Make sure the reduction doesn't cast input timedeltas to + # float and lose precision. + df = DataFrame( + { + "foo": [ + pd.NaT, + pd.NaT, + to_timedelta("10000 days 06:05:01.123456789"), + ], + }, + dtype="timedelta64[ns]", + ) + + res = df.min(axis=1) + exp = df.foo.rename(None) + tm.assert_series_equal(res, exp) + + res = df.max(axis=1) + exp = df.foo.rename(None) + tm.assert_series_equal(res, exp) + + def test_min_max_dt64_with_NaT_skipna_false(self, request, tz_naive_fixture): + # GH#36907 + tz = tz_naive_fixture + if isinstance(tz, tzlocal) and is_platform_windows(): + pytest.skip( + "GH#37659 OSError raised within tzlocal bc Windows " + "chokes in times before 1970-01-01" + ) + + df = DataFrame( + { + "a": [ + Timestamp("2020-01-01 08:00:00", tz=tz), + Timestamp("1920-02-01 09:00:00", tz=tz), + ], + "b": [Timestamp("2020-02-01 08:00:00", tz=tz), pd.NaT], + } + ) + res = df.min(axis=1, skipna=False) + expected = Series([df.loc[0, "a"], pd.NaT]) + assert expected.dtype == df["a"].dtype + + tm.assert_series_equal(res, expected) + + res = df.max(axis=1, skipna=False) + expected = Series([df.loc[0, "b"], pd.NaT]) + assert expected.dtype == df["a"].dtype + + tm.assert_series_equal(res, expected) + + def test_min_max_dt64_api_consistency_with_NaT(self): + # Calling the following sum functions returned an error for dataframes but + # returned NaT for series. These tests check that the API is consistent in + # min/max calls on empty Series/DataFrames. See GH:33704 for more + # information + df = DataFrame({"x": to_datetime([])}) + expected_dt_series = Series(to_datetime([])) + # check axis 0 + assert (df.min(axis=0).x is pd.NaT) == (expected_dt_series.min() is pd.NaT) + assert (df.max(axis=0).x is pd.NaT) == (expected_dt_series.max() is pd.NaT) + + # check axis 1 + tm.assert_series_equal(df.min(axis=1), expected_dt_series) + tm.assert_series_equal(df.max(axis=1), expected_dt_series) + + def test_min_max_dt64_api_consistency_empty_df(self): + # check DataFrame/Series api consistency when calling min/max on an empty + # DataFrame/Series. + df = DataFrame({"x": []}) + expected_float_series = Series([], dtype=float) + # check axis 0 + assert np.isnan(df.min(axis=0).x) == np.isnan(expected_float_series.min()) + assert np.isnan(df.max(axis=0).x) == np.isnan(expected_float_series.max()) + # check axis 1 + tm.assert_series_equal(df.min(axis=1), expected_float_series) + tm.assert_series_equal(df.min(axis=1), expected_float_series) + + @pytest.mark.parametrize( + "initial", + ["2018-10-08 13:36:45+00:00", "2018-10-08 13:36:45+03:00"], # Non-UTC timezone + ) + @pytest.mark.parametrize("method", ["min", "max"]) + def test_preserve_timezone(self, initial: str, method): + # GH 28552 + initial_dt = to_datetime(initial) + expected = Series([initial_dt]) + df = DataFrame([expected]) + result = getattr(df, method)(axis=1) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("method", ["min", "max"]) + def test_minmax_tzaware_skipna_axis_1(self, method, skipna): + # GH#51242 + val = to_datetime("1900-01-01", utc=True) + df = DataFrame( + {"a": Series([pd.NaT, pd.NaT, val]), "b": Series([pd.NaT, val, val])} + ) + op = getattr(df, method) + result = op(axis=1, skipna=skipna) + if skipna: + expected = Series([pd.NaT, val, val]) + else: + expected = Series([pd.NaT, pd.NaT, val]) + tm.assert_series_equal(result, expected) + + def test_frame_any_with_timedelta(self): + # GH#17667 + df = DataFrame( + { + "a": Series([0, 0]), + "t": Series([to_timedelta(0, "s"), to_timedelta(1, "ms")]), + } + ) + + result = df.any(axis=0) + expected = Series(data=[False, True], index=["a", "t"]) + tm.assert_series_equal(result, expected) + + result = df.any(axis=1) + expected = Series(data=[False, True]) + tm.assert_series_equal(result, expected) + + def test_reductions_skipna_none_raises( + self, request, frame_or_series, all_reductions + ): + if all_reductions == "count": + request.applymarker( + pytest.mark.xfail(reason="Count does not accept skipna") + ) + obj = frame_or_series([1, 2, 3]) + msg = 'For argument "skipna" expected type bool, received type NoneType.' + with pytest.raises(ValueError, match=msg): + getattr(obj, all_reductions)(skipna=None) + + def test_reduction_timestamp_smallest_unit(self): + # GH#52524 + df = DataFrame( + { + "a": Series([Timestamp("2019-12-31")], dtype="datetime64[s]"), + "b": Series( + [Timestamp("2019-12-31 00:00:00.123")], dtype="datetime64[ms]" + ), + } + ) + result = df.max() + expected = Series( + [Timestamp("2019-12-31"), Timestamp("2019-12-31 00:00:00.123")], + dtype="datetime64[ms]", + index=["a", "b"], + ) + tm.assert_series_equal(result, expected) + + def test_reduction_timedelta_smallest_unit(self): + # GH#52524 + df = DataFrame( + { + "a": Series([pd.Timedelta("1 days")], dtype="timedelta64[s]"), + "b": Series([pd.Timedelta("1 days")], dtype="timedelta64[ms]"), + } + ) + result = df.max() + expected = Series( + [pd.Timedelta("1 days"), pd.Timedelta("1 days")], + dtype="timedelta64[ms]", + index=["a", "b"], + ) + tm.assert_series_equal(result, expected) + + +class TestNuisanceColumns: + def test_any_all_categorical_dtype_nuisance_column(self, all_boolean_reductions): + # GH#36076 DataFrame should match Series behavior + ser = Series([0, 1], dtype="category", name="A") + df = ser.to_frame() + + # Double-check the Series behavior is to raise + with pytest.raises(TypeError, match="does not support operation"): + getattr(ser, all_boolean_reductions)() + + with pytest.raises(TypeError, match="does not support operation"): + getattr(np, all_boolean_reductions)(ser) + + with pytest.raises(TypeError, match="does not support operation"): + getattr(df, all_boolean_reductions)(bool_only=False) + + with pytest.raises(TypeError, match="does not support operation"): + getattr(df, all_boolean_reductions)(bool_only=None) + + with pytest.raises(TypeError, match="does not support operation"): + getattr(np, all_boolean_reductions)(df, axis=0) + + def test_median_categorical_dtype_nuisance_column(self): + # GH#21020 DataFrame.median should match Series.median + df = DataFrame({"A": Categorical([1, 2, 2, 2, 3])}) + ser = df["A"] + + # Double-check the Series behavior is to raise + with pytest.raises(TypeError, match="does not support operation"): + ser.median() + + with pytest.raises(TypeError, match="does not support operation"): + df.median(numeric_only=False) + + with pytest.raises(TypeError, match="does not support operation"): + df.median() + + # same thing, but with an additional non-categorical column + df["B"] = df["A"].astype(int) + + with pytest.raises(TypeError, match="does not support operation"): + df.median(numeric_only=False) + + with pytest.raises(TypeError, match="does not support operation"): + df.median() + + # TODO: np.median(df, axis=0) gives np.array([2.0, 2.0]) instead + # of expected.values + + @pytest.mark.parametrize("method", ["min", "max"]) + def test_min_max_categorical_dtype_non_ordered_nuisance_column(self, method): + # GH#28949 DataFrame.min should behave like Series.min + cat = Categorical(["a", "b", "c", "b"], ordered=False) + ser = Series(cat) + df = ser.to_frame("A") + + # Double-check the Series behavior + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(ser, method)() + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(np, method)(ser) + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(df, method)(numeric_only=False) + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(df, method)() + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(np, method)(df, axis=0) + + # same thing, but with an additional non-categorical column + df["B"] = df["A"].astype(object) + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(df, method)() + + with pytest.raises(TypeError, match="is not ordered for operation"): + getattr(np, method)(df, axis=0) + + +class TestEmptyDataFrameReductions: + @pytest.mark.parametrize( + "opname, dtype, exp_value, exp_dtype", + [ + ("sum", np.int8, 0, np.int64), + ("prod", np.int8, 1, np.int_), + ("sum", np.int64, 0, np.int64), + ("prod", np.int64, 1, np.int64), + ("sum", np.uint8, 0, np.uint64), + ("prod", np.uint8, 1, np.uint), + ("sum", np.uint64, 0, np.uint64), + ("prod", np.uint64, 1, np.uint64), + ("sum", np.float32, 0, np.float32), + ("prod", np.float32, 1, np.float32), + ("sum", np.float64, 0, np.float64), + ], + ) + def test_df_empty_min_count_0(self, opname, dtype, exp_value, exp_dtype): + df = DataFrame({0: [], 1: []}, dtype=dtype) + result = getattr(df, opname)(min_count=0) + + expected = Series([exp_value, exp_value], dtype=exp_dtype, index=range(2)) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "opname, dtype, exp_dtype", + [ + ("sum", np.int8, np.float64), + ("prod", np.int8, np.float64), + ("sum", np.int64, np.float64), + ("prod", np.int64, np.float64), + ("sum", np.uint8, np.float64), + ("prod", np.uint8, np.float64), + ("sum", np.uint64, np.float64), + ("prod", np.uint64, np.float64), + ("sum", np.float32, np.float32), + ("prod", np.float32, np.float32), + ("sum", np.float64, np.float64), + ], + ) + def test_df_empty_min_count_1(self, opname, dtype, exp_dtype): + df = DataFrame({0: [], 1: []}, dtype=dtype) + result = getattr(df, opname)(min_count=1) + + expected = Series([np.nan, np.nan], dtype=exp_dtype, index=Index([0, 1])) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "opname, dtype, exp_value, exp_dtype", + [ + ("sum", "Int8", 0, ("Int32" if is_windows_np2_or_is32 else "Int64")), + ("prod", "Int8", 1, ("Int32" if is_windows_np2_or_is32 else "Int64")), + ("sum", "Int64", 0, "Int64"), + ("prod", "Int64", 1, "Int64"), + ("sum", "UInt8", 0, ("UInt32" if is_windows_np2_or_is32 else "UInt64")), + ("prod", "UInt8", 1, ("UInt32" if is_windows_np2_or_is32 else "UInt64")), + ("sum", "UInt64", 0, "UInt64"), + ("prod", "UInt64", 1, "UInt64"), + ("sum", "Float32", 0, "Float32"), + ("prod", "Float32", 1, "Float32"), + ("sum", "Float64", 0, "Float64"), + ], + ) + def test_df_empty_nullable_min_count_0(self, opname, dtype, exp_value, exp_dtype): + df = DataFrame({0: [], 1: []}, dtype=dtype) + result = getattr(df, opname)(min_count=0) + + expected = Series([exp_value, exp_value], dtype=exp_dtype, index=Index([0, 1])) + tm.assert_series_equal(result, expected) + + # TODO: why does min_count=1 impact the resulting Windows dtype + # differently than min_count=0? + @pytest.mark.parametrize( + "opname, dtype, exp_dtype", + [ + ("sum", "Int8", ("Int32" if is_windows_or_is32 else "Int64")), + ("prod", "Int8", ("Int32" if is_windows_or_is32 else "Int64")), + ("sum", "Int64", "Int64"), + ("prod", "Int64", "Int64"), + ("sum", "UInt8", ("UInt32" if is_windows_or_is32 else "UInt64")), + ("prod", "UInt8", ("UInt32" if is_windows_or_is32 else "UInt64")), + ("sum", "UInt64", "UInt64"), + ("prod", "UInt64", "UInt64"), + ("sum", "Float32", "Float32"), + ("prod", "Float32", "Float32"), + ("sum", "Float64", "Float64"), + ], + ) + def test_df_empty_nullable_min_count_1(self, opname, dtype, exp_dtype): + df = DataFrame({0: [], 1: []}, dtype=dtype) + result = getattr(df, opname)(min_count=1) + + expected = Series([pd.NA, pd.NA], dtype=exp_dtype, index=Index([0, 1])) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "data", + [ + {"a": [0, 1, 2], "b": [pd.NaT, pd.NaT, pd.NaT]}, + {"a": [0, 1, 2], "b": [Timestamp("1990-01-01"), pd.NaT, pd.NaT]}, + { + "a": [0, 1, 2], + "b": [ + Timestamp("1990-01-01"), + Timestamp("1991-01-01"), + Timestamp("1992-01-01"), + ], + }, + { + "a": [0, 1, 2], + "b": [pd.Timedelta("1 days"), pd.Timedelta("2 days"), pd.NaT], + }, + { + "a": [0, 1, 2], + "b": [ + pd.Timedelta("1 days"), + pd.Timedelta("2 days"), + pd.Timedelta("3 days"), + ], + }, + ], + ) + def test_df_cov_pd_nat(self, data): + # GH #53115 + df = DataFrame(data) + with pytest.raises(TypeError, match="not supported for cov"): + df.cov() + + +def test_sum_timedelta64_skipna_false(): + # GH#17235 + arr = np.arange(8).astype(np.int64).view("m8[s]").reshape(4, 2) + arr[-1, -1] = "Nat" + + df = DataFrame(arr) + assert (df.dtypes == arr.dtype).all() + + result = df.sum(skipna=False) + expected = Series([pd.Timedelta(seconds=12), pd.NaT], dtype="m8[s]") + tm.assert_series_equal(result, expected) + + result = df.sum(axis=0, skipna=False) + tm.assert_series_equal(result, expected) + + result = df.sum(axis=1, skipna=False) + expected = Series( + [ + pd.Timedelta(seconds=1), + pd.Timedelta(seconds=5), + pd.Timedelta(seconds=9), + pd.NaT, + ], + dtype="m8[s]", + ) + tm.assert_series_equal(result, expected) + + +def test_mixed_frame_with_integer_sum(): + # https://github.com/pandas-dev/pandas/issues/34520 + df = DataFrame([["a", 1]], columns=list("ab")) + df = df.astype({"b": "Int64"}) + result = df.sum() + expected = Series(["a", 1], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("numeric_only", [True, False, None]) +@pytest.mark.parametrize("method", ["min", "max"]) +def test_minmax_extensionarray(method, numeric_only): + # https://github.com/pandas-dev/pandas/issues/32651 + int64_info = np.iinfo("int64") + ser = Series([int64_info.max, None, int64_info.min], dtype=pd.Int64Dtype()) + df = DataFrame({"Int64": ser}) + result = getattr(df, method)(numeric_only=numeric_only) + expected = Series( + [getattr(int64_info, method)], + dtype="Int64", + index=Index(["Int64"]), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("ts_value", [Timestamp("2000-01-01"), pd.NaT]) +def test_frame_mixed_numeric_object_with_timestamp(ts_value): + # GH 13912 + df = DataFrame({"a": [1], "b": [1.1], "c": ["foo"], "d": [ts_value]}) + with pytest.raises(TypeError, match="does not support operation|Cannot perform"): + df.sum() + + +def test_prod_sum_min_count_mixed_object(): + # https://github.com/pandas-dev/pandas/issues/41074 + df = DataFrame([1, "a", True]) + + result = df.prod(axis=0, min_count=1, numeric_only=False) + expected = Series(["a"], dtype=object) + tm.assert_series_equal(result, expected) + + msg = re.escape("unsupported operand type(s) for +: 'int' and 'str'") + with pytest.raises(TypeError, match=msg): + df.sum(axis=0, min_count=1, numeric_only=False) + + +@pytest.mark.parametrize("method", ["min", "max", "mean", "median", "skew", "kurt"]) +@pytest.mark.parametrize("numeric_only", [True, False]) +@pytest.mark.parametrize("dtype", ["float64", "Float64"]) +def test_reduction_axis_none_returns_scalar(method, numeric_only, dtype): + # GH#21597 As of 2.0, axis=None reduces over all axes. + + df = DataFrame(np.random.default_rng(2).standard_normal((4, 4)), dtype=dtype) + + result = getattr(df, method)(axis=None, numeric_only=numeric_only) + np_arr = df.to_numpy(dtype=np.float64) + if method in {"skew", "kurt"}: + comp_mod = pytest.importorskip("scipy.stats") + if method == "kurt": + method = "kurtosis" + expected = getattr(comp_mod, method)(np_arr, bias=False, axis=None) + tm.assert_almost_equal(result, expected) + else: + expected = getattr(np, method)(np_arr, axis=None) + assert result == expected + + +@pytest.mark.parametrize( + "kernel", + [ + "corr", + "corrwith", + "cov", + "idxmax", + "idxmin", + "kurt", + "max", + "mean", + "median", + "min", + "prod", + "quantile", + "sem", + "skew", + "std", + "sum", + "var", + ], +) +def test_fails_on_non_numeric(kernel): + # GH#46852 + df = DataFrame({"a": [1, 2, 3], "b": object}) + args = (df,) if kernel == "corrwith" else () + msg = "|".join( + [ + "not allowed for this dtype", + "argument must be a string or a number", + "not supported between instances of", + "unsupported operand type", + "argument must be a string or a real number", + ] + ) + if kernel == "median": + # slightly different message on different builds + msg1 = ( + r"Cannot convert \[\[ " + r"\]\] to numeric" + ) + msg2 = ( + r"Cannot convert \[ " + r"\] to numeric" + ) + msg = "|".join([msg1, msg2]) + with pytest.raises(TypeError, match=msg): + getattr(df, kernel)(*args) + + +@pytest.mark.parametrize( + "method", + [ + "all", + "any", + "count", + "idxmax", + "idxmin", + "kurt", + "kurtosis", + "max", + "mean", + "median", + "min", + "nunique", + "prod", + "product", + "sem", + "skew", + "std", + "sum", + "var", + ], +) +@pytest.mark.parametrize("min_count", [0, 2]) +def test_numeric_ea_axis_1( + method, skipna, min_count, any_numeric_ea_dtype, using_nan_is_na +): + # GH 54341 + df = DataFrame( + { + "a": Series([0, 1, 2, 3], dtype=any_numeric_ea_dtype), + "b": Series([0, 1, pd.NA, 3], dtype=any_numeric_ea_dtype), + }, + ) + expected_df = DataFrame( + { + "a": [0.0, 1.0, 2.0, 3.0], + "b": [0.0, 1.0, np.nan, 3.0], + }, + ) + if method in ("count", "nunique"): + expected_dtype = "int64" + elif method in ("all", "any"): + expected_dtype = "boolean" + elif method in ( + "kurt", + "kurtosis", + "mean", + "median", + "sem", + "skew", + "std", + "var", + ) and not any_numeric_ea_dtype.startswith("Float"): + expected_dtype = "Float64" + else: + expected_dtype = any_numeric_ea_dtype + + kwargs = {} + if method not in ("count", "nunique", "quantile"): + kwargs["skipna"] = skipna + if method in ("prod", "product", "sum"): + kwargs["min_count"] = min_count + + if not skipna and method in ("idxmax", "idxmin"): + with pytest.raises(ValueError, match="encountered an NA value"): + getattr(df, method)(axis=1, **kwargs) + with pytest.raises(ValueError, match="Encountered an NA value"): + getattr(expected_df, method)(axis=1, **kwargs) + return + result = getattr(df, method)(axis=1, **kwargs) + expected = getattr(expected_df, method)(axis=1, **kwargs) + if method not in ("idxmax", "idxmin"): + if using_nan_is_na: + expected = expected.astype(expected_dtype) + else: + mask = np.isnan(expected) + expected[mask] = 0 + expected = expected.astype(expected_dtype) + expected[mask] = pd.NA + tm.assert_series_equal(result, expected) + + +def test_mean_nullable_int_axis_1(): + # GH##36585 + df = DataFrame( + {"a": [1, 2, 3, 4], "b": Series([1, 2, 4, None], dtype=pd.Int64Dtype())} + ) + + result = df.mean(axis=1, skipna=True) + expected = Series([1.0, 2.0, 3.5, 4.0], dtype="Float64") + tm.assert_series_equal(result, expected) + + result = df.mean(axis=1, skipna=False) + expected = Series([1.0, 2.0, 3.5, pd.NA], dtype="Float64") + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/frame/test_repr.py b/pandas/tests/frame/test_repr.py new file mode 100644 index 0000000000000000000000000000000000000000..73628424725e57fd5c639f0238d229a289d6ea8e --- /dev/null +++ b/pandas/tests/frame/test_repr.py @@ -0,0 +1,498 @@ +from datetime import ( + datetime, + timedelta, +) +from io import StringIO + +import numpy as np +import pytest + +from pandas import ( + NA, + Categorical, + CategoricalIndex, + DataFrame, + IntervalIndex, + MultiIndex, + NaT, + PeriodIndex, + Series, + Timestamp, + date_range, + option_context, + period_range, +) +import pandas._testing as tm + + +class TestDataFrameRepr: + def test_repr_should_return_str(self): + # https://docs.python.org/3/reference/datamodel.html#object.__repr__ + # "...The return value must be a string object." + + # (str on py2.x, str (unicode) on py3) + + data = [8, 5, 3, 5] + index1 = ["\u03c3", "\u03c4", "\u03c5", "\u03c6"] + cols = ["\u03c8"] + df = DataFrame(data, columns=cols, index=index1) + assert type(df.__repr__()) is str + + ser = df[cols[0]] + assert type(ser.__repr__()) is str + + def test_repr_bytes_61_lines(self): + # GH#12857 + lets = list("ACDEFGHIJKLMNOP") + words = np.random.default_rng(2).choice(lets, (1000, 50)) + df = DataFrame(words).astype("U1") + assert (df.dtypes == object).all() + + # smoke tests; at one point this raised with 61 but not 60 + repr(df) + repr(df.iloc[:60, :]) + repr(df.iloc[:61, :]) + + def test_repr_unicode_level_names(self, frame_or_series): + index = MultiIndex.from_tuples([(0, 0), (1, 1)], names=["\u0394", "i1"]) + + obj = DataFrame(np.random.default_rng(2).standard_normal((2, 4)), index=index) + obj = tm.get_obj(obj, frame_or_series) + repr(obj) + + def test_assign_index_sequences(self): + # GH#2200 + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}).set_index( + ["a", "b"] + ) + index = list(df.index) + index[0] = ("faz", "boo") + df.index = index + repr(df) + + # this travels an improper code path + index[0] = ["faz", "boo"] + df.index = index + repr(df) + + def test_repr_with_mi_nat(self): + df = DataFrame({"X": [1, 2]}, index=[[NaT, Timestamp("20130101")], ["a", "b"]]) + result = repr(df) + expected = " X\nNaT a 1\n2013-01-01 b 2" + assert result == expected + + def test_repr_with_different_nulls(self): + # GH45263 + df = DataFrame([1, 2, 3, 4], [True, None, np.nan, NaT]) + result = repr(df) + expected = """ 0 +True 1 +None 2 +NaN 3 +NaT 4""" + assert result == expected + + def test_repr_with_different_nulls_cols(self): + # GH45263 + d = {np.nan: [1, 2], None: [3, 4], NaT: [6, 7], True: [8, 9]} + df = DataFrame(data=d) + result = repr(df) + expected = """ NaN None NaT True +0 1 3 6 8 +1 2 4 7 9""" + assert result == expected + + def test_multiindex_na_repr(self): + # only an issue with long columns + df3 = DataFrame( + { + "A" * 30: {("A", "A0006000", "nuit"): "A0006000"}, + "B" * 30: {("A", "A0006000", "nuit"): np.nan}, + "C" * 30: {("A", "A0006000", "nuit"): np.nan}, + "D" * 30: {("A", "A0006000", "nuit"): np.nan}, + "E" * 30: {("A", "A0006000", "nuit"): "A"}, + "F" * 30: {("A", "A0006000", "nuit"): np.nan}, + } + ) + + idf = df3.set_index(["A" * 30, "C" * 30]) + repr(idf) + + def test_repr_name_coincide(self): + index = MultiIndex.from_tuples( + [("a", 0, "foo"), ("b", 1, "bar")], names=["a", "b", "c"] + ) + + df = DataFrame({"value": [0, 1]}, index=index) + + lines = repr(df).split("\n") + assert lines[2].startswith("a 0 foo") + + def test_repr_to_string( + self, + multiindex_year_month_day_dataframe_random_data, + multiindex_dataframe_random_data, + ): + ymd = multiindex_year_month_day_dataframe_random_data + frame = multiindex_dataframe_random_data + + repr(frame) + repr(ymd) + repr(frame.T) + repr(ymd.T) + + buf = StringIO() + frame.to_string(buf=buf) + ymd.to_string(buf=buf) + frame.T.to_string(buf=buf) + ymd.T.to_string(buf=buf) + + def test_repr_empty(self): + # empty + repr(DataFrame()) + + # empty with index + frame = DataFrame(index=np.arange(1000)) + repr(frame) + + def test_repr_mixed(self, float_string_frame): + # mixed + repr(float_string_frame) + + @pytest.mark.slow + def test_repr_mixed_big(self): + # big mixed + biggie = DataFrame( + { + "A": np.random.default_rng(2).standard_normal(200), + "B": [str(i) for i in range(200)], + }, + index=range(200), + ) + biggie.loc[:20, "A"] = np.nan + biggie.loc[:20, "B"] = np.nan + + repr(biggie) + + def test_repr(self): + # columns but no index + no_index = DataFrame(columns=[0, 1, 3]) + repr(no_index) + + df = DataFrame(["a\n\r\tb"], columns=["a\n\r\td"], index=["a\n\r\tf"]) + assert "\t" not in repr(df) + assert "\r" not in repr(df) + assert "a\n" not in repr(df) + + def test_repr_dimensions(self): + df = DataFrame([[1, 2], [3, 4]]) + with option_context("display.show_dimensions", True): + assert "2 rows x 2 columns" in repr(df) + + with option_context("display.show_dimensions", False): + assert "2 rows x 2 columns" not in repr(df) + + with option_context("display.show_dimensions", "truncate"): + assert "2 rows x 2 columns" not in repr(df) + + @pytest.mark.slow + def test_repr_big(self): + # big one + biggie = DataFrame(np.zeros((200, 4)), columns=range(4), index=range(200)) + repr(biggie) + + def test_repr_unsortable(self): + # columns are not sortable + + unsortable = DataFrame( + { + "foo": [1] * 50, + datetime.today(): [1] * 50, + "bar": ["bar"] * 50, + datetime.today() + timedelta(1): ["bar"] * 50, + }, + index=np.arange(50), + ) + repr(unsortable) + + def test_repr_float_frame_options(self, float_frame): + repr(float_frame) + + with option_context("display.precision", 3): + repr(float_frame) + + with option_context("display.max_rows", 10, "display.max_columns", 2): + repr(float_frame) + + with option_context("display.max_rows", 1000, "display.max_columns", 1000): + repr(float_frame) + + def test_repr_unicode(self): + uval = "\u03c3\u03c3\u03c3\u03c3" + + df = DataFrame({"A": [uval, uval]}) + + result = repr(df) + ex_top = " A" + assert result.split("\n")[0].rstrip() == ex_top + + df = DataFrame({"A": [uval, uval]}) + result = repr(df) + assert result.split("\n")[0].rstrip() == ex_top + + def test_unicode_string_with_unicode(self): + df = DataFrame({"A": ["\u05d0"]}) + str(df) + + def test_repr_unicode_columns(self): + df = DataFrame({"\u05d0": [1, 2, 3], "\u05d1": [4, 5, 6], "c": [7, 8, 9]}) + repr(df.columns) # should not raise UnicodeDecodeError + + def test_str_to_bytes_raises(self): + # GH 26447 + df = DataFrame({"A": ["abc"]}) + msg = "^'str' object cannot be interpreted as an integer$" + with pytest.raises(TypeError, match=msg): + bytes(df) + + def test_very_wide_repr(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 20)), + columns=np.array(["a" * 10] * 20, dtype=object), + ) + repr(df) + + def test_repr_column_name_unicode_truncation_bug(self): + # #1906 + df = DataFrame( + { + "Id": [7117434], + "StringCol": ( + "Is it possible to modify drop plot code" + "so that the output graph is displayed " + "in iphone simulator, Is it possible to " + "modify drop plot code so that the " + "output graph is \xe2\x80\xa8displayed " + "in iphone simulator.Now we are adding " + "the CSV file externally. I want to Call " + "the File through the code.." + ), + } + ) + + with option_context("display.max_columns", 20): + assert "StringCol" in repr(df) + + def test_latex_repr(self): + pytest.importorskip("jinja2") + expected = r"""\begin{tabular}{llll} +\toprule + & 0 & 1 & 2 \\ +\midrule +0 & $\alpha$ & b & c \\ +1 & 1 & 2 & 3 \\ +\bottomrule +\end{tabular} +""" + with option_context( + "styler.format.escape", None, "styler.render.repr", "latex" + ): + df = DataFrame([[r"$\alpha$", "b", "c"], [1, 2, 3]]) + result = df._repr_latex_() + assert result == expected + + # GH 12182 + assert df._repr_latex_() is None + + def test_repr_with_datetimeindex(self): + df = DataFrame({"A": [1, 2, 3]}, index=date_range("2000", periods=3)) + result = repr(df) + expected = " A\n2000-01-01 1\n2000-01-02 2\n2000-01-03 3" + assert result == expected + + def test_repr_with_intervalindex(self): + # https://github.com/pandas-dev/pandas/pull/24134/files + df = DataFrame( + {"A": [1, 2, 3, 4]}, index=IntervalIndex.from_breaks([0, 1, 2, 3, 4]) + ) + result = repr(df) + expected = " A\n(0, 1] 1\n(1, 2] 2\n(2, 3] 3\n(3, 4] 4" + assert result == expected + + def test_repr_with_categorical_index(self): + df = DataFrame({"A": [1, 2, 3]}, index=CategoricalIndex(["a", "b", "c"])) + result = repr(df) + expected = " A\na 1\nb 2\nc 3" + assert result == expected + + def test_repr_categorical_dates_periods(self): + # normal DataFrame + dt = date_range("2011-01-01 09:00", freq="h", periods=5, tz="US/Eastern") + p = period_range("2011-01", freq="M", periods=5) + df = DataFrame({"dt": dt, "p": p}) + exp = """ dt p +0 2011-01-01 09:00:00-05:00 2011-01 +1 2011-01-01 10:00:00-05:00 2011-02 +2 2011-01-01 11:00:00-05:00 2011-03 +3 2011-01-01 12:00:00-05:00 2011-04 +4 2011-01-01 13:00:00-05:00 2011-05""" + + assert repr(df) == exp + + df2 = DataFrame({"dt": Categorical(dt), "p": Categorical(p)}) + assert repr(df2) == exp + + @pytest.mark.parametrize("arg", [np.datetime64, np.timedelta64]) + @pytest.mark.parametrize( + "box, expected", + [[Series, "0 NaT\ndtype: object"], [DataFrame, " 0\n0 NaT"]], + ) + def test_repr_np_nat_with_object(self, arg, box, expected): + # GH 25445 + result = repr(box([arg("NaT")], dtype=object)) + assert result == expected + + def test_frame_datetime64_pre1900_repr(self): + df = DataFrame({"year": date_range("1/1/1700", periods=50, freq="YE-DEC")}) + # it works! + repr(df) + + def test_frame_to_string_with_periodindex(self): + index = PeriodIndex(["2011-1", "2011-2", "2011-3"], freq="M") + frame = DataFrame(np.random.default_rng(2).standard_normal((3, 4)), index=index) + + # it works! + frame.to_string() + + def test_to_string_ea_na_in_multiindex(self): + # GH#47986 + df = DataFrame( + {"a": [1, 2]}, + index=MultiIndex.from_arrays([Series([NA, 1], dtype="Int64")]), + ) + + result = df.to_string() + expected = """ a + 1 +1 2""" + assert result == expected + + def test_datetime64tz_slice_non_truncate(self): + # GH 30263 + df = DataFrame({"x": date_range("2019", periods=10, tz="UTC")}) + expected = repr(df) + df = df.iloc[:, :5] + result = repr(df) + assert result == expected + + def test_to_records_no_typeerror_in_repr(self): + # GH 48526 + df = DataFrame([["a", "b"], ["c", "d"], ["e", "f"]], columns=["left", "right"]) + df["record"] = df[["left", "right"]].to_records() + expected = """ left right record +0 a b [0, a, b] +1 c d [1, c, d] +2 e f [2, e, f]""" + result = repr(df) + assert result == expected + + def test_to_records_with_na_record_value(self): + # GH 48526 + df = DataFrame( + [["a", np.nan], ["c", "d"], ["e", "f"]], columns=["left", "right"] + ) + df["record"] = df[["left", "right"]].to_records() + expected = """ left right record +0 a NaN [0, a, nan] +1 c d [1, c, d] +2 e f [2, e, f]""" + result = repr(df) + assert result == expected + + def test_to_records_with_na_record(self): + # GH 48526 + df = DataFrame( + [["a", "b"], [np.nan, np.nan], ["e", "f"]], columns=[np.nan, "right"] + ) + df["record"] = df[[np.nan, "right"]].to_records() + expected = """ NaN right record +0 a b [0, a, b] +1 NaN NaN [1, nan, nan] +2 e f [2, e, f]""" + result = repr(df) + assert result == expected + + def test_to_records_with_inf_record(self): + # GH 48526 + expected = """ NaN inf record +0 inf b [0, inf, b] +1 NaN NaN [1, nan, nan] +2 e f [2, e, f]""" + df = DataFrame( + [[np.inf, "b"], [np.nan, np.nan], ["e", "f"]], + columns=[np.nan, np.inf], + ) + df["record"] = df[[np.nan, np.inf]].to_records() + result = repr(df) + assert result == expected + + def test_masked_ea_with_formatter(self): + # GH#39336 + df = DataFrame( + { + "a": Series([0.123456789, 1.123456789], dtype="Float64"), + "b": Series([1, 2], dtype="Int64"), + } + ) + result = df.to_string(formatters=["{:.2f}".format, "{:.2f}".format]) + expected = """ a b +0 0.12 1.00 +1 1.12 2.00""" + assert result == expected + + def test_repr_ea_columns(self, any_string_dtype): + # GH#54797 + pytest.importorskip("pyarrow") + df = DataFrame({"long_column_name": [1, 2, 3], "col2": [4, 5, 6]}) + df.columns = df.columns.astype(any_string_dtype) + expected = """ long_column_name col2 +0 1 4 +1 2 5 +2 3 6""" + assert repr(df) == expected + + +@pytest.mark.parametrize( + "data,output", + [ + ([2, complex("nan"), 1], [" 2.0+0.0j", " NaN+0.0j", " 1.0+0.0j"]), + ([2, complex("nan"), -1], [" 2.0+0.0j", " NaN+0.0j", "-1.0+0.0j"]), + ([-2, complex("nan"), -1], ["-2.0+0.0j", " NaN+0.0j", "-1.0+0.0j"]), + ([-1.23j, complex("nan"), -1], ["-0.00-1.23j", " NaN+0.00j", "-1.00+0.00j"]), + ([1.23j, complex("nan"), 1.23], [" 0.00+1.23j", " NaN+0.00j", " 1.23+0.00j"]), + ( + [-1.23j, complex(np.nan, np.nan), 1], + ["-0.00-1.23j", " NaN+ NaNj", " 1.00+0.00j"], + ), + ( + [-1.23j, complex(1.2, np.nan), 1], + ["-0.00-1.23j", " 1.20+ NaNj", " 1.00+0.00j"], + ), + ( + [-1.23j, complex(np.nan, -1.2), 1], + ["-0.00-1.23j", " NaN-1.20j", " 1.00+0.00j"], + ), + ], +) +@pytest.mark.parametrize("as_frame", [True, False]) +def test_repr_with_complex_nans(data, output, as_frame): + # GH#53762, GH#53841 + obj = Series(np.array(data)) + if as_frame: + obj = obj.to_frame(name="val") + reprs = [f"{i} {val}" for i, val in enumerate(output)] + expected = f"{'val': >{len(reprs[0])}}\n" + "\n".join(reprs) + else: + reprs = [f"{i} {val}" for i, val in enumerate(output)] + expected = "\n".join(reprs) + "\ndtype: complex128" + assert str(obj) == expected, f"\n{obj!s}\n\n{expected}" diff --git a/pandas/tests/frame/test_stack_unstack.py b/pandas/tests/frame/test_stack_unstack.py new file mode 100644 index 0000000000000000000000000000000000000000..a6587ff486d8a4eb016fad9fd5cf583441080829 --- /dev/null +++ b/pandas/tests/frame/test_stack_unstack.py @@ -0,0 +1,2781 @@ +from datetime import datetime +import itertools +import re + +import numpy as np +import pytest + +from pandas._libs import lib +from pandas.errors import Pandas4Warning + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Period, + Series, + Timedelta, + date_range, +) +import pandas._testing as tm +from pandas.core.reshape import reshape as reshape_lib + + +@pytest.fixture(params=[True, False]) +def future_stack(request): + return request.param + + +class TestDataFrameReshape: + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unstack(self, float_frame, future_stack): + df = float_frame.copy() + df[:] = np.arange(np.prod(df.shape)).reshape(df.shape) + + stacked = df.stack(future_stack=future_stack) + stacked_df = DataFrame({"foo": stacked, "bar": stacked}) + + unstacked = stacked.unstack() + unstacked_df = stacked_df.unstack() + + tm.assert_frame_equal(unstacked, df) + tm.assert_frame_equal(unstacked_df["bar"], df) + + unstacked_cols = stacked.unstack(0) + unstacked_cols_df = stacked_df.unstack(0) + tm.assert_frame_equal(unstacked_cols.T, df) + tm.assert_frame_equal(unstacked_cols_df["bar"].T, df) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_mixed_level(self, future_stack): + # GH 18310 + levels = [range(3), [3, "a", "b"], [1, 2]] + + # flat columns: + df = DataFrame(1, index=levels[0], columns=levels[1]) + result = df.stack(future_stack=future_stack) + expected = Series(1, index=MultiIndex.from_product(levels[:2])) + tm.assert_series_equal(result, expected) + + # MultiIndex columns: + df = DataFrame(1, index=levels[0], columns=MultiIndex.from_product(levels[1:])) + result = df.stack(1, future_stack=future_stack) + expected = DataFrame( + 1, index=MultiIndex.from_product([levels[0], levels[2]]), columns=levels[1] + ) + tm.assert_frame_equal(result, expected) + + # as above, but used labels in level are actually of homogeneous type + result = df[["a", "b"]].stack(1, future_stack=future_stack) + expected = expected[["a", "b"]] + tm.assert_frame_equal(result, expected) + + def test_unstack_not_consolidated(self): + # Gh#34708 + df = DataFrame({"x": [1, 2, np.nan], "y": [3.0, 4, np.nan]}) + df2 = df[["x"]] + df2["y"] = df["y"] + assert len(df2._mgr.blocks) == 2 + + res = df2.unstack() + expected = df.unstack() + tm.assert_series_equal(res, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_fill(self, future_stack): + # GH #9746: fill_value keyword argument for Series + # and DataFrame unstack + + # From a series + data = Series([1, 2, 4, 5], dtype=np.int16) + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = data.unstack(fill_value=-1) + expected = DataFrame( + {"a": [1, -1, 5], "b": [2, 4, -1]}, index=["x", "y", "z"], dtype=np.int16 + ) + tm.assert_frame_equal(result, expected) + + msg = ( + "Using a fill_value that cannot be held in the existing dtype is deprecated" + ) + with tm.assert_produces_warning(Pandas4Warning, match=msg): + # From a series with incorrect data type for fill_value + result = data.unstack(fill_value=0.5) + expected = DataFrame( + {"a": [1, 0.5, 5], "b": [2, 4, 0.5]}, index=["x", "y", "z"], dtype=float + ) + tm.assert_frame_equal(result, expected) + + # GH #13971: fill_value when unstacking multiple levels: + df = DataFrame( + {"x": ["a", "a", "b"], "y": ["j", "k", "j"], "z": [0, 1, 2], "w": [0, 1, 2]} + ).set_index(["x", "y", "z"]) + unstacked = df.unstack(["x", "y"], fill_value=0) + key = ("w", "b", "j") + expected = unstacked[key] + result = Series([0, 0, 2], index=unstacked.index, name=key) + tm.assert_series_equal(result, expected) + + stacked = unstacked.stack(["x", "y"], future_stack=future_stack) + stacked.index = stacked.index.reorder_levels(df.index.names) + # Workaround for GH #17886 (unnecessarily casts to float): + stacked = stacked.astype(np.int64) + result = stacked.loc[df.index] + tm.assert_frame_equal(result, df) + + # From a series + s = df["w"] + result = s.unstack(["x", "y"], fill_value=0) + expected = unstacked["w"] + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame(self): + # From a dataframe + rows = [[1, 2], [3, 4], [5, 6], [7, 8]] + df = DataFrame(rows, columns=list("AB"), dtype=np.int32) + df.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = df.unstack(fill_value=-1) + + rows = [[1, 3, 2, 4], [-1, 5, -1, 6], [7, -1, 8, -1]] + expected = DataFrame(rows, index=list("xyz"), dtype=np.int32) + expected.columns = MultiIndex.from_tuples( + [("A", "a"), ("A", "b"), ("B", "a"), ("B", "b")] + ) + tm.assert_frame_equal(result, expected) + + # From a mixed type dataframe + df["A"] = df["A"].astype(np.int16) + df["B"] = df["B"].astype(np.float64) + + result = df.unstack(fill_value=-1) + expected["A"] = expected["A"].astype(np.int16) + expected["B"] = expected["B"].astype(np.float64) + tm.assert_frame_equal(result, expected) + + msg = ( + "Using a fill_value that cannot be held in the existing dtype is deprecated" + ) + with tm.assert_produces_warning(Pandas4Warning, match=msg): + # From a dataframe with incorrect data type for fill_value + result = df.unstack(fill_value=0.5) + + rows = [[1, 3, 2, 4], [0.5, 5, 0.5, 6], [7, 0.5, 8, 0.5]] + expected = DataFrame(rows, index=list("xyz"), dtype=float) + expected.columns = MultiIndex.from_tuples( + [("A", "a"), ("A", "b"), ("B", "a"), ("B", "b")] + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame_datetime(self): + # Test unstacking with date times + dv = date_range("2012-01-01", periods=4).values + data = Series(dv) + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = data.unstack() + expected = DataFrame( + {"a": [dv[0], pd.NaT, dv[3]], "b": [dv[1], dv[2], pd.NaT]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + result = data.unstack(fill_value=dv[0]) + expected = DataFrame( + {"a": [dv[0], dv[0], dv[3]], "b": [dv[1], dv[2], dv[0]]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame_timedelta(self): + # Test unstacking with time deltas + td = [Timedelta(days=i) for i in range(4)] + data = Series(td) + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = data.unstack() + expected = DataFrame( + {"a": [td[0], pd.NaT, td[3]], "b": [td[1], td[2], pd.NaT]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + result = data.unstack(fill_value=td[1]) + expected = DataFrame( + {"a": [td[0], td[1], td[3]], "b": [td[1], td[2], td[1]]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame_period(self): + # Test unstacking with period + periods = [ + Period("2012-01"), + Period("2012-02"), + Period("2012-03"), + Period("2012-04"), + ] + data = Series(periods) + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + result = data.unstack() + expected = DataFrame( + {"a": [periods[0], None, periods[3]], "b": [periods[1], periods[2], None]}, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + result = data.unstack(fill_value=periods[1]) + expected = DataFrame( + { + "a": [periods[0], periods[1], periods[3]], + "b": [periods[1], periods[2], periods[1]], + }, + index=["x", "y", "z"], + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_fill_frame_categorical(self): + # Test unstacking with categorical + data = Series(["a", "b", "c", "a"], dtype="category") + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + # By default missing values will be NaN + result = data.unstack() + expected = DataFrame( + { + "a": pd.Categorical(["a", None, "a"], categories=list("abc")), + "b": pd.Categorical(["b", "c", None], categories=list("abc")), + }, + index=list("xyz"), + ) + tm.assert_frame_equal(result, expected) + + # Fill with non-category results in a ValueError + msg = r"Cannot setitem on a Categorical with a new category \(d\)" + with pytest.raises(TypeError, match=msg): + data.unstack(fill_value="d") + + # Fill with category value replaces missing values as expected + result = data.unstack(fill_value="c") + expected = DataFrame( + { + "a": pd.Categorical(list("aca"), categories=list("abc")), + "b": pd.Categorical(list("bcc"), categories=list("abc")), + }, + index=list("xyz"), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_tuplename_in_multiindex(self): + # GH 19966 + idx = MultiIndex.from_product( + [["a", "b", "c"], [1, 2, 3]], names=[("A", "a"), ("B", "b")] + ) + df = DataFrame({"d": [1] * 9, "e": [2] * 9}, index=idx) + result = df.unstack(("A", "a")) + + expected = DataFrame( + [[1, 1, 1, 2, 2, 2], [1, 1, 1, 2, 2, 2], [1, 1, 1, 2, 2, 2]], + columns=MultiIndex.from_tuples( + [ + ("d", "a"), + ("d", "b"), + ("d", "c"), + ("e", "a"), + ("e", "b"), + ("e", "c"), + ], + names=[None, ("A", "a")], + ), + index=Index([1, 2, 3], name=("B", "b")), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "unstack_idx, expected_values, expected_index, expected_columns", + [ + ( + ("A", "a"), + [[1, 1, 2, 2], [1, 1, 2, 2], [1, 1, 2, 2], [1, 1, 2, 2]], + MultiIndex.from_tuples( + [(1, 3), (1, 4), (2, 3), (2, 4)], names=["B", "C"] + ), + MultiIndex.from_tuples( + [("d", "a"), ("d", "b"), ("e", "a"), ("e", "b")], + names=[None, ("A", "a")], + ), + ), + ( + (("A", "a"), "B"), + [[1, 1, 1, 1, 2, 2, 2, 2], [1, 1, 1, 1, 2, 2, 2, 2]], + Index([3, 4], name="C"), + MultiIndex.from_tuples( + [ + ("d", "a", 1), + ("d", "a", 2), + ("d", "b", 1), + ("d", "b", 2), + ("e", "a", 1), + ("e", "a", 2), + ("e", "b", 1), + ("e", "b", 2), + ], + names=[None, ("A", "a"), "B"], + ), + ), + ], + ) + def test_unstack_mixed_type_name_in_multiindex( + self, unstack_idx, expected_values, expected_index, expected_columns + ): + # GH 19966 + idx = MultiIndex.from_product( + [["a", "b"], [1, 2], [3, 4]], names=[("A", "a"), "B", "C"] + ) + df = DataFrame({"d": [1] * 8, "e": [2] * 8}, index=idx) + result = df.unstack(unstack_idx) + + expected = DataFrame( + expected_values, columns=expected_columns, index=expected_index + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_preserve_dtypes(self): + # Checks fix for #11847 + df = DataFrame( + { + "state": ["IL", "MI", "NC"], + "index": ["a", "b", "c"], + "some_categories": Series(["a", "b", "c"]).astype("category"), + "A": np.random.default_rng(2).random(3), + "B": 1, + "C": "foo", + "D": pd.Timestamp("20010102"), + "E": Series([1.0, 50.0, 100.0]).astype("float32"), + "F": Series([3.0, 4.0, 5.0]).astype("float64"), + "G": False, + "H": Series([1, 200, 923442]).astype("int8"), + } + ) + + def unstack_and_compare(df, column_name): + unstacked1 = df.unstack([column_name]) + unstacked2 = df.unstack(column_name) + tm.assert_frame_equal(unstacked1, unstacked2) + + df1 = df.set_index(["state", "index"]) + unstack_and_compare(df1, "index") + + df1 = df.set_index(["state", "some_categories"]) + unstack_and_compare(df1, "some_categories") + + df1 = df.set_index(["F", "C"]) + unstack_and_compare(df1, "F") + + df1 = df.set_index(["G", "B", "state"]) + unstack_and_compare(df1, "B") + + df1 = df.set_index(["E", "A"]) + unstack_and_compare(df1, "E") + + df1 = df.set_index(["state", "index"]) + s = df1["A"] + unstack_and_compare(s, "index") + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_ints(self, future_stack): + columns = MultiIndex.from_tuples(list(itertools.product(range(3), repeat=3))) + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 27)), columns=columns + ) + + tm.assert_frame_equal( + df.stack(level=[1, 2], future_stack=future_stack), + df.stack(level=1, future_stack=future_stack).stack( + level=1, future_stack=future_stack + ), + ) + tm.assert_frame_equal( + df.stack(level=[-2, -1], future_stack=future_stack), + df.stack(level=1, future_stack=future_stack).stack( + level=1, future_stack=future_stack + ), + ) + + df_named = df.copy() + return_value = df_named.columns.set_names(range(3), inplace=True) + assert return_value is None + + tm.assert_frame_equal( + df_named.stack(level=[1, 2], future_stack=future_stack), + df_named.stack(level=1, future_stack=future_stack).stack( + level=1, future_stack=future_stack + ), + ) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_mixed_levels(self, future_stack): + columns = MultiIndex.from_tuples( + [ + ("A", "cat", "long"), + ("B", "cat", "long"), + ("A", "dog", "short"), + ("B", "dog", "short"), + ], + names=["exp", "animal", "hair_length"], + ) + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), columns=columns + ) + + animal_hair_stacked = df.stack( + level=["animal", "hair_length"], future_stack=future_stack + ) + exp_hair_stacked = df.stack( + level=["exp", "hair_length"], future_stack=future_stack + ) + + # GH #8584: Need to check that stacking works when a number + # is passed that is both a level name and in the range of + # the level numbers + df2 = df.copy() + df2.columns.names = ["exp", "animal", 1] + tm.assert_frame_equal( + df2.stack(level=["animal", 1], future_stack=future_stack), + animal_hair_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df2.stack(level=["exp", 1], future_stack=future_stack), + exp_hair_stacked, + check_names=False, + ) + + # When mixed types are passed and the ints are not level + # names, raise + msg = ( + "level should contain all level names or all level numbers, not " + "a mixture of the two" + ) + with pytest.raises(ValueError, match=msg): + df2.stack(level=["animal", 0], future_stack=future_stack) + + # GH #8584: Having 0 in the level names could raise a + # strange error about lexsort depth + df3 = df.copy() + df3.columns.names = ["exp", "animal", 0] + tm.assert_frame_equal( + df3.stack(level=["animal", 0], future_stack=future_stack), + animal_hair_stacked, + check_names=False, + ) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_int_level_names(self, future_stack): + columns = MultiIndex.from_tuples( + [ + ("A", "cat", "long"), + ("B", "cat", "long"), + ("A", "dog", "short"), + ("B", "dog", "short"), + ], + names=["exp", "animal", "hair_length"], + ) + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), columns=columns + ) + + exp_animal_stacked = df.stack( + level=["exp", "animal"], future_stack=future_stack + ) + animal_hair_stacked = df.stack( + level=["animal", "hair_length"], future_stack=future_stack + ) + exp_hair_stacked = df.stack( + level=["exp", "hair_length"], future_stack=future_stack + ) + + df2 = df.copy() + df2.columns.names = [0, 1, 2] + tm.assert_frame_equal( + df2.stack(level=[1, 2], future_stack=future_stack), + animal_hair_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df2.stack(level=[0, 1], future_stack=future_stack), + exp_animal_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df2.stack(level=[0, 2], future_stack=future_stack), + exp_hair_stacked, + check_names=False, + ) + + # Out-of-order int column names + df3 = df.copy() + df3.columns.names = [2, 0, 1] + tm.assert_frame_equal( + df3.stack(level=[0, 1], future_stack=future_stack), + animal_hair_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df3.stack(level=[2, 0], future_stack=future_stack), + exp_animal_stacked, + check_names=False, + ) + tm.assert_frame_equal( + df3.stack(level=[2, 1], future_stack=future_stack), + exp_hair_stacked, + check_names=False, + ) + + def test_unstack_bool(self): + df = DataFrame( + [False, False], + index=MultiIndex.from_arrays([["a", "b"], ["c", "l"]]), + columns=["col"], + ) + rs = df.unstack() + xp = DataFrame( + np.array([[False, np.nan], [np.nan, False]], dtype=object), + index=["a", "b"], + columns=MultiIndex.from_arrays([["col", "col"], ["c", "l"]]), + ) + tm.assert_frame_equal(rs, xp) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_level_binding(self, future_stack): + # GH9856 + mi = MultiIndex( + levels=[["foo", "bar"], ["one", "two"], ["a", "b"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 0]], + names=["first", "second", "third"], + ) + s = Series(0, index=mi) + result = s.unstack([1, 2]).stack(0, future_stack=future_stack) + + expected_mi = MultiIndex( + levels=[["foo", "bar"], ["one", "two"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + names=["first", "second"], + ) + + expected = DataFrame( + np.array( + [[0, np.nan], [np.nan, 0], [0, np.nan], [np.nan, 0]], dtype=np.float64 + ), + index=expected_mi, + columns=Index(["b", "a"], name="third"), + ) + + tm.assert_frame_equal(result, expected) + + def test_unstack_to_series(self, float_frame): + # check reversibility + data = float_frame.unstack() + + assert isinstance(data, Series) + undo = data.unstack().T + tm.assert_frame_equal(undo, float_frame) + + # check NA handling + data = DataFrame({"x": [1, 2, np.nan], "y": [3.0, 4, np.nan]}) + data.index = Index(["a", "b", "c"]) + result = data.unstack() + + midx = MultiIndex( + levels=[["x", "y"], ["a", "b", "c"]], + codes=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + ) + expected = Series([1, 2, np.nan, 3, 4, np.nan], index=midx) + + tm.assert_series_equal(result, expected) + + # check composability of unstack + old_data = data.copy() + for _ in range(4): + data = data.unstack() + tm.assert_frame_equal(old_data, data) + + def test_unstack_dtypes(self, using_infer_string): + # GH 2929 + rows = [[1, 1, 3, 4], [1, 2, 3, 4], [2, 1, 3, 4], [2, 2, 3, 4]] + + df = DataFrame(rows, columns=list("ABCD")) + result = df.dtypes + expected = Series([np.dtype("int64")] * 4, index=list("ABCD")) + tm.assert_series_equal(result, expected) + + # single dtype + df2 = df.set_index(["A", "B"]) + df3 = df2.unstack("B") + result = df3.dtypes + expected = Series( + [np.dtype("int64")] * 4, + index=MultiIndex.from_arrays( + [["C", "C", "D", "D"], [1, 2, 1, 2]], names=(None, "B") + ), + ) + tm.assert_series_equal(result, expected) + + # mixed + df2 = df.set_index(["A", "B"]) + df2["C"] = 3.0 + df3 = df2.unstack("B") + result = df3.dtypes + expected = Series( + [np.dtype("float64")] * 2 + [np.dtype("int64")] * 2, + index=MultiIndex.from_arrays( + [["C", "C", "D", "D"], [1, 2, 1, 2]], names=(None, "B") + ), + ) + tm.assert_series_equal(result, expected) + df2["D"] = "foo" + df3 = df2.unstack("B") + result = df3.dtypes + dtype = ( + pd.StringDtype(na_value=np.nan) + if using_infer_string + else np.dtype("object") + ) + expected = Series( + [np.dtype("float64")] * 2 + [dtype] * 2, + index=MultiIndex.from_arrays( + [["C", "C", "D", "D"], [1, 2, 1, 2]], names=(None, "B") + ), + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "c, d", + ( + (np.zeros(5), np.zeros(5)), + (np.arange(5, dtype="f8"), np.arange(5, 10, dtype="f8")), + ), + ) + def test_unstack_dtypes_mixed_date(self, c, d): + # GH7405 + df = DataFrame( + { + "A": ["a"] * 5, + "C": c, + "D": d, + "B": date_range("2012-01-01", periods=5), + } + ) + + right = df.iloc[:3].copy(deep=True) + + df = df.set_index(["A", "B"]) + df["D"] = df["D"].astype("int64") + + left = df.iloc[:3].unstack(0) + right = right.set_index(["A", "B"]).unstack(0) + right[("D", "a")] = right[("D", "a")].astype("int64") + + assert left.shape == (3, 2) + tm.assert_frame_equal(left, right) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_non_unique_index_names(self, future_stack): + idx = MultiIndex.from_tuples([("a", "b"), ("c", "d")], names=["c1", "c1"]) + df = DataFrame([1, 2], index=idx) + msg = "The name c1 occurs multiple times, use a level number" + with pytest.raises(ValueError, match=msg): + df.unstack("c1") + + with pytest.raises(ValueError, match=msg): + df.T.stack("c1", future_stack=future_stack) + + def test_unstack_unused_levels(self): + # GH 17845: unused codes in index make unstack() cast int to float + idx = MultiIndex.from_product([["a"], ["A", "B", "C", "D"]])[:-1] + df = DataFrame([[1, 0]] * 3, index=idx) + + result = df.unstack() + exp_col = MultiIndex.from_product([range(2), ["A", "B", "C"]]) + expected = DataFrame([[1, 1, 1, 0, 0, 0]], index=["a"], columns=exp_col) + tm.assert_frame_equal(result, expected) + assert (result.columns.levels[1] == idx.levels[1]).all() + + # Unused items on both levels + levels = [range(3), range(4)] + codes = [[0, 0, 1, 1], [0, 2, 0, 2]] + idx = MultiIndex(levels, codes) + block = np.arange(4).reshape(2, 2) + df = DataFrame(np.concatenate([block, block + 4]), index=idx) + result = df.unstack() + expected = DataFrame( + np.concatenate([block * 2, block * 2 + 1], axis=1), columns=idx + ) + tm.assert_frame_equal(result, expected) + assert (result.columns.levels[1] == idx.levels[1]).all() + + @pytest.mark.parametrize( + "level, idces, col_level, idx_level", + ( + (0, [13, 16, 6, 9, 2, 5, 8, 11], [np.nan, "a", 2], [np.nan, 5, 1]), + (1, [8, 11, 1, 4, 12, 15, 13, 16], [np.nan, 5, 1], [np.nan, "a", 2]), + ), + ) + def test_unstack_unused_levels_mixed_with_nan( + self, level, idces, col_level, idx_level + ): + # With mixed dtype and NaN + levels = [["a", 2, "c"], [1, 3, 5, 7]] + codes = [[0, -1, 1, 1], [0, 2, -1, 2]] + idx = MultiIndex(levels, codes) + data = np.arange(8) + df = DataFrame(data.reshape(4, 2), index=idx) + + result = df.unstack(level=level) + exp_data = np.zeros(18) * np.nan + exp_data[idces] = data + cols = MultiIndex.from_product([range(2), col_level]) + expected = DataFrame(exp_data.reshape(3, 6), index=idx_level, columns=cols) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("cols", [["A", "C"], slice(None)]) + def test_unstack_unused_level(self, cols): + # GH 18562 : unused codes on the unstacked level + df = DataFrame([[2010, "a", "I"], [2011, "b", "II"]], columns=["A", "B", "C"]) + + ind = df.set_index(["A", "B", "C"], drop=False) + selection = ind.loc[(slice(None), slice(None), "I"), cols] + result = selection.unstack() + + expected = ind.iloc[[0]][cols] + expected.columns = MultiIndex.from_product( + [expected.columns, ["I"]], names=[None, "C"] + ) + expected.index = expected.index.droplevel("C") + tm.assert_frame_equal(result, expected) + + def test_unstack_long_index(self): + # PH 32624: Error when using a lot of indices to unstack. + # The error occurred only, if a lot of indices are used. + df = DataFrame( + [[1]], + columns=MultiIndex.from_tuples([[0]], names=["c1"]), + index=MultiIndex.from_tuples( + [[0, 0, 1, 0, 0, 0, 1]], + names=["i1", "i2", "i3", "i4", "i5", "i6", "i7"], + ), + ) + result = df.unstack(["i2", "i3", "i4", "i5", "i6", "i7"]) + expected = DataFrame( + [[1]], + columns=MultiIndex.from_tuples( + [[0, 0, 1, 0, 0, 0, 1]], + names=["c1", "i2", "i3", "i4", "i5", "i6", "i7"], + ), + index=Index([0], name="i1"), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_multi_level_cols(self): + # PH 24729: Unstack a df with multi level columns + df = DataFrame( + [[0.0, 0.0], [0.0, 0.0]], + columns=MultiIndex.from_tuples( + [["B", "C"], ["B", "D"]], names=["c1", "c2"] + ), + index=MultiIndex.from_tuples( + [[10, 20, 30], [10, 20, 40]], names=["i1", "i2", "i3"] + ), + ) + assert df.unstack(["i2", "i1"]).columns.names[-2:] == ["i2", "i1"] + + def test_unstack_multi_level_rows_and_cols(self): + # PH 28306: Unstack df with multi level cols and rows + df = DataFrame( + [[1, 2], [3, 4], [-1, -2], [-3, -4]], + columns=MultiIndex.from_tuples([["a", "b", "c"], ["d", "e", "f"]]), + index=MultiIndex.from_tuples( + [ + ["m1", "P3", 222], + ["m1", "A5", 111], + ["m2", "P3", 222], + ["m2", "A5", 111], + ], + names=["i1", "i2", "i3"], + ), + ) + result = df.unstack(["i3", "i2"]) + expected = df.unstack(["i3"]).unstack(["i2"]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("idx", [("jim", "joe"), ("joe", "jim")]) + @pytest.mark.parametrize("lev", list(range(2))) + def test_unstack_nan_index1(self, idx, lev): + # GH7466 + def cast(val): + val_str = "" if val != val else val + return f"{val_str:1}" + + df = DataFrame( + { + "jim": ["a", "b", np.nan, "d"], + "joe": ["w", "x", "y", "z"], + "jolie": ["a.w", "b.x", " .y", "d.z"], + } + ) + + left = df.set_index(["jim", "joe"]).unstack()["jolie"] + right = df.set_index(["joe", "jim"]).unstack()["jolie"].T + tm.assert_frame_equal(left, right) + + mi = df.set_index(list(idx)) + udf = mi.unstack(level=lev) + assert udf.notna().values.sum() == len(df) + mk_list = lambda a: list(a) if isinstance(a, tuple) else [a] + rows, cols = udf["jolie"].notna().values.nonzero() + for i, j in zip(rows, cols): + left = sorted(udf["jolie"].iloc[i, j].split(".")) + right = mk_list(udf["jolie"].index[i]) + mk_list(udf["jolie"].columns[j]) + right = sorted(map(cast, right)) + assert left == right + + @pytest.mark.parametrize("idx", itertools.permutations(["1st", "2nd", "3rd"])) + @pytest.mark.parametrize("lev", list(range(3))) + @pytest.mark.parametrize("col", ["4th", "5th"]) + def test_unstack_nan_index_repeats(self, idx, lev, col): + def cast(val): + val_str = "" if val != val else val + return f"{val_str:1}" + + df = DataFrame( + { + "1st": ["d"] * 3 + + [np.nan] * 5 + + ["a"] * 2 + + ["c"] * 3 + + ["e"] * 2 + + ["b"] * 5, + "2nd": ["y"] * 2 + + ["w"] * 3 + + [np.nan] * 3 + + ["z"] * 4 + + [np.nan] * 3 + + ["x"] * 3 + + [np.nan] * 2, + "3rd": [ + 67, + 39, + 53, + 72, + 57, + 80, + 31, + 18, + 11, + 30, + 59, + 50, + 62, + 59, + 76, + 52, + 14, + 53, + 60, + 51, + ], + } + ) + + df["4th"], df["5th"] = ( + df.apply(lambda r: ".".join(map(cast, r)), axis=1), + df.apply(lambda r: ".".join(map(cast, r.iloc[::-1])), axis=1), + ) + + mi = df.set_index(list(idx)) + udf = mi.unstack(level=lev) + assert udf.notna().values.sum() == 2 * len(df) + mk_list = lambda a: list(a) if isinstance(a, tuple) else [a] + rows, cols = udf[col].notna().values.nonzero() + for i, j in zip(rows, cols): + left = sorted(udf[col].iloc[i, j].split(".")) + right = mk_list(udf[col].index[i]) + mk_list(udf[col].columns[j]) + right = sorted(map(cast, right)) + assert left == right + + def test_unstack_nan_index2(self): + # GH7403 + df = DataFrame({"A": list("aaaabbbb"), "B": range(8), "C": range(8)}) + # Explicit cast to avoid implicit cast when setting to np.nan + df = df.astype({"B": "float"}) + df.iloc[3, 1] = np.nan + left = df.set_index(["A", "B"]).unstack(0) + + vals = [ + [3, 0, 1, 2, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, 4, 5, 6, 7], + ] + vals = list(map(list, zip(*vals))) + idx = Index([np.nan, 0, 1, 2, 4, 5, 6, 7], name="B") + cols = MultiIndex( + levels=[["C"], ["a", "b"]], codes=[[0, 0], [0, 1]], names=[None, "A"] + ) + + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + df = DataFrame({"A": list("aaaabbbb"), "B": list(range(4)) * 2, "C": range(8)}) + # Explicit cast to avoid implicit cast when setting to np.nan + df = df.astype({"B": "float"}) + df.iloc[2, 1] = np.nan + left = df.set_index(["A", "B"]).unstack(0) + + vals = [[2, np.nan], [0, 4], [1, 5], [np.nan, 6], [3, 7]] + cols = MultiIndex( + levels=[["C"], ["a", "b"]], codes=[[0, 0], [0, 1]], names=[None, "A"] + ) + idx = Index([np.nan, 0, 1, 2, 3], name="B") + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + df = DataFrame({"A": list("aaaabbbb"), "B": list(range(4)) * 2, "C": range(8)}) + # Explicit cast to avoid implicit cast when setting to np.nan + df = df.astype({"B": "float"}) + df.iloc[3, 1] = np.nan + left = df.set_index(["A", "B"]).unstack(0) + + vals = [[3, np.nan], [0, 4], [1, 5], [2, 6], [np.nan, 7]] + cols = MultiIndex( + levels=[["C"], ["a", "b"]], codes=[[0, 0], [0, 1]], names=[None, "A"] + ) + idx = Index([np.nan, 0, 1, 2, 3], name="B") + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + def test_unstack_nan_index3(self): + # GH7401 + df = DataFrame( + { + "A": list("aaaaabbbbb"), + "B": (date_range("2012-01-01", periods=5).tolist() * 2), + "C": np.arange(10), + } + ) + + df.iloc[3, 1] = np.nan + left = df.set_index(["A", "B"]).unstack() + + vals = np.array([[3, 0, 1, 2, np.nan, 4], [np.nan, 5, 6, 7, 8, 9]]) + idx = Index(["a", "b"], name="A") + cols = MultiIndex( + levels=[["C"], date_range("2012-01-01", periods=5)], + codes=[[0, 0, 0, 0, 0, 0], [-1, 0, 1, 2, 3, 4]], + names=[None, "B"], + ) + + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + def test_unstack_nan_index4(self): + # GH4862 + vals = [ + ["Hg", np.nan, np.nan, 680585148], + ["U", 0.0, np.nan, 680585148], + ["Pb", 7.07e-06, np.nan, 680585148], + ["Sn", 2.3614e-05, 0.0133, 680607017], + ["Ag", 0.0, 0.0133, 680607017], + ["Hg", -0.00015, 0.0133, 680607017], + ] + df = DataFrame( + vals, + columns=["agent", "change", "dosage", "s_id"], + index=[17263, 17264, 17265, 17266, 17267, 17268], + ) + + left = df.copy().set_index(["s_id", "dosage", "agent"]).unstack() + + vals = [ + [np.nan, np.nan, 7.07e-06, np.nan, 0.0], + [0.0, -0.00015, np.nan, 2.3614e-05, np.nan], + ] + + idx = MultiIndex( + levels=[[680585148, 680607017], [0.0133]], + codes=[[0, 1], [-1, 0]], + names=["s_id", "dosage"], + ) + + cols = MultiIndex( + levels=[["change"], ["Ag", "Hg", "Pb", "Sn", "U"]], + codes=[[0, 0, 0, 0, 0], [0, 1, 2, 3, 4]], + names=[None, "agent"], + ) + + right = DataFrame(vals, columns=cols, index=idx) + tm.assert_frame_equal(left, right) + + left = df.loc[17264:].copy().set_index(["s_id", "dosage", "agent"]) + tm.assert_frame_equal(left.unstack(), right) + + def test_unstack_nan_index5(self): + # GH9497 - multiple unstack with nulls + df = DataFrame( + { + "1st": [1, 2, 1, 2, 1, 2], + "2nd": date_range("2014-02-01", periods=6, freq="D"), + "jim": 100 + np.arange(6), + "joe": (np.random.default_rng(2).standard_normal(6) * 10).round(2), + } + ) + + df["3rd"] = df["2nd"] - pd.Timestamp("2014-02-02") + df.loc[1, "2nd"] = df.loc[3, "2nd"] = np.nan + df.loc[1, "3rd"] = df.loc[4, "3rd"] = np.nan + + left = df.set_index(["1st", "2nd", "3rd"]).unstack(["2nd", "3rd"]) + assert left.notna().values.sum() == 2 * len(df) + + for col in ["jim", "joe"]: + for _, r in df.iterrows(): + key = r["1st"], (col, r["2nd"], r["3rd"]) + assert r[col] == left.loc[key] + + def test_stack_datetime_column_multiIndex(self, future_stack): + # GH 8039 + t = datetime(2014, 1, 1) + df = DataFrame([1, 2, 3, 4], columns=MultiIndex.from_tuples([(t, "A", "B")])) + warn = None if future_stack else Pandas4Warning + msg = "The previous implementation of stack is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = df.stack(future_stack=future_stack) + + eidx = MultiIndex.from_product([range(4), ("B",)]) + ecols = MultiIndex.from_tuples([(t, "A")]) + expected = DataFrame([1, 2, 3, 4], index=eidx, columns=ecols) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "multiindex_columns", + [ + [0, 1, 2, 3, 4], + [0, 1, 2, 3], + [0, 1, 2, 4], + [0, 1, 2], + [1, 2, 3], + [2, 3, 4], + [0, 1], + [0, 2], + [0, 3], + [0], + [2], + [4], + [4, 3, 2, 1, 0], + [3, 2, 1, 0], + [4, 2, 1, 0], + [2, 1, 0], + [3, 2, 1], + [4, 3, 2], + [1, 0], + [2, 0], + [3, 0], + ], + ) + @pytest.mark.parametrize("level", (-1, 0, 1, [0, 1], [1, 0])) + def test_stack_partial_multiIndex(self, multiindex_columns, level, future_stack): + # GH 8844 + dropna = False if not future_stack else lib.no_default + full_multiindex = MultiIndex.from_tuples( + [("B", "x"), ("B", "z"), ("A", "y"), ("C", "x"), ("C", "u")], + names=["Upper", "Lower"], + ) + multiindex = full_multiindex[multiindex_columns] + df = DataFrame( + np.arange(3 * len(multiindex)).reshape(3, len(multiindex)), + columns=multiindex, + ) + result = df.stack(level=level, dropna=dropna, future_stack=future_stack) + + if isinstance(level, int) and not future_stack: + # Stacking a single level should not make any all-NaN rows, + # so df.stack(level=level, dropna=False) should be the same + # as df.stack(level=level, dropna=True). + expected = df.stack(level=level, dropna=True, future_stack=future_stack) + if isinstance(expected, Series): + tm.assert_series_equal(result, expected) + else: + tm.assert_frame_equal(result, expected) + + df.columns = MultiIndex.from_tuples( + df.columns.to_numpy(), names=df.columns.names + ) + expected = df.stack(level=level, dropna=dropna, future_stack=future_stack) + if isinstance(expected, Series): + tm.assert_series_equal(result, expected) + else: + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_full_multiIndex(self, future_stack): + # GH 8844 + full_multiindex = MultiIndex.from_tuples( + [("B", "x"), ("B", "z"), ("A", "y"), ("C", "x"), ("C", "u")], + names=["Upper", "Lower"], + ) + df = DataFrame(np.arange(6).reshape(2, 3), columns=full_multiindex[[0, 1, 3]]) + dropna = False if not future_stack else lib.no_default + result = df.stack(dropna=dropna, future_stack=future_stack) + expected = DataFrame( + [[0, 2], [1, np.nan], [3, 5], [4, np.nan]], + index=MultiIndex( + levels=[range(2), ["u", "x", "y", "z"]], + codes=[[0, 0, 1, 1], [1, 3, 1, 3]], + names=[None, "Lower"], + ), + columns=Index(["B", "C"], name="Upper"), + ) + expected["B"] = expected["B"].astype(df.dtypes.iloc[0]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize("ordered", [False, True]) + def test_stack_preserve_categorical_dtype(self, ordered, future_stack): + # GH13854 + cidx = pd.CategoricalIndex(list("yxz"), categories=list("xyz"), ordered=ordered) + df = DataFrame([[10, 11, 12]], columns=cidx) + result = df.stack(future_stack=future_stack) + + # `MultiIndex.from_product` preserves categorical dtype - + # it's tested elsewhere. + midx = MultiIndex.from_product([df.index, cidx]) + expected = Series([10, 11, 12], index=midx) + + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize("ordered", [False, True]) + @pytest.mark.parametrize( + "labels,data", + [ + (list("xyz"), [10, 11, 12, 13, 14, 15]), + (list("zyx"), [14, 15, 12, 13, 10, 11]), + ], + ) + def test_stack_multi_preserve_categorical_dtype( + self, ordered, labels, data, future_stack + ): + # GH-36991 + cidx = pd.CategoricalIndex(labels, categories=sorted(labels), ordered=ordered) + cidx2 = pd.CategoricalIndex(["u", "v"], ordered=ordered) + midx = MultiIndex.from_product([cidx, cidx2]) + df = DataFrame([sorted(data)], columns=midx) + result = df.stack([0, 1], future_stack=future_stack) + + labels = labels if future_stack else sorted(labels) + s_cidx = pd.CategoricalIndex(labels, ordered=ordered) + expected_data = sorted(data) if future_stack else data + expected = Series( + expected_data, index=MultiIndex.from_product([range(1), s_cidx, cidx2]) + ) + + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_preserve_categorical_dtype_values(self, future_stack): + # GH-23077 + cat = pd.Categorical(["a", "a", "b", "c"]) + df = DataFrame({"A": cat, "B": cat}) + result = df.stack(future_stack=future_stack) + index = MultiIndex.from_product([range(4), ["A", "B"]]) + expected = Series( + pd.Categorical(["a", "a", "a", "a", "b", "b", "c", "c"]), index=index + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "index", + [ + [0, 0, 1, 1], + [0, 0, 2, 3], + [0, 1, 2, 3], + ], + ) + def test_stack_multi_columns_non_unique_index(self, index, future_stack): + # GH-28301 + columns = MultiIndex.from_product([[1, 2], ["a", "b"]]) + df = DataFrame(index=index, columns=columns).fillna(1) + stacked = df.stack(future_stack=future_stack) + new_index = MultiIndex.from_tuples(stacked.index.to_numpy()) + expected = DataFrame( + stacked.to_numpy(), index=new_index, columns=stacked.columns + ) + tm.assert_frame_equal(stacked, expected) + stacked_codes = np.asarray(stacked.index.codes) + expected_codes = np.asarray(new_index.codes) + tm.assert_numpy_array_equal(stacked_codes, expected_codes) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "vals1, vals2, dtype1, dtype2, expected_dtype", + [ + ([1, 2], [3.0, 4.0], "Int64", "Float64", "Float64"), + ([1, 2], ["foo", "bar"], "Int64", "string", "object"), + ], + ) + def test_stack_multi_columns_mixed_extension_types( + self, vals1, vals2, dtype1, dtype2, expected_dtype, future_stack + ): + # GH45740 + df = DataFrame( + { + ("A", 1): Series(vals1, dtype=dtype1), + ("A", 2): Series(vals2, dtype=dtype2), + } + ) + result = df.stack(future_stack=future_stack) + expected = ( + df.astype(object).stack(future_stack=future_stack).astype(expected_dtype) + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("level", [0, 1]) + def test_unstack_mixed_extension_types(self, level): + index = MultiIndex.from_tuples([("A", 0), ("A", 1), ("B", 1)], names=["a", "b"]) + df = DataFrame( + { + "A": pd.array([0, 1, None], dtype="Int64"), + "B": pd.Categorical(["a", "a", "b"]), + }, + index=index, + ) + + result = df.unstack(level=level) + expected = df.astype(object).unstack(level=level) + if level == 0: + expected[("A", "B")] = expected[("A", "B")].fillna(pd.NA) + else: + expected[("A", 0)] = expected[("A", 0)].fillna(pd.NA) + + expected_dtypes = Series( + [df.A.dtype] * 2 + [df.B.dtype] * 2, index=result.columns + ) + tm.assert_series_equal(result.dtypes, expected_dtypes) + tm.assert_frame_equal(result.astype(object), expected) + + @pytest.mark.parametrize("level", [0, "baz"]) + def test_unstack_swaplevel_sortlevel(self, level): + # GH 20994 + mi = MultiIndex.from_product([range(1), ["d", "c"]], names=["bar", "baz"]) + df = DataFrame([[0, 2], [1, 3]], index=mi, columns=["B", "A"]) + df.columns.name = "foo" + + expected = DataFrame( + [[3, 1, 2, 0]], + columns=MultiIndex.from_tuples( + [("c", "A"), ("c", "B"), ("d", "A"), ("d", "B")], names=["baz", "foo"] + ), + ) + expected.index.name = "bar" + + result = df.unstack().swaplevel(axis=1).sort_index(axis=1, level=level) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["float64", "Float64"]) +def test_unstack_sort_false(frame_or_series, dtype): + # GH 15105 + index = MultiIndex.from_tuples( + [("two", "z", "b"), ("two", "y", "a"), ("one", "z", "b"), ("one", "y", "a")] + ) + obj = frame_or_series(np.arange(1.0, 5.0), index=index, dtype=dtype) + + result = obj.unstack(level=0, sort=False) + + if frame_or_series is DataFrame: + expected_columns = MultiIndex.from_tuples([(0, "two"), (0, "one")]) + else: + expected_columns = ["two", "one"] + expected = DataFrame( + [[1.0, 3.0], [2.0, 4.0]], + index=MultiIndex.from_tuples([("z", "b"), ("y", "a")]), + columns=expected_columns, + dtype=dtype, + ) + tm.assert_frame_equal(result, expected) + + result = obj.unstack(level=-1, sort=False) + + if frame_or_series is DataFrame: + expected_columns = MultiIndex( + levels=[range(1), ["b", "a"]], codes=[[0, 0], [0, 1]] + ) + else: + expected_columns = ["b", "a"] + + item = pd.NA if dtype == "Float64" else np.nan + expected = DataFrame( + [[1.0, item], [item, 2.0], [3.0, item], [item, 4.0]], + columns=expected_columns, + index=MultiIndex.from_tuples( + [("two", "z"), ("two", "y"), ("one", "z"), ("one", "y")] + ), + dtype=dtype, + ) + tm.assert_frame_equal(result, expected) + + result = obj.unstack(level=[1, 2], sort=False) + + if frame_or_series is DataFrame: + expected_columns = MultiIndex( + levels=[range(1), ["z", "y"], ["b", "a"]], codes=[[0, 0], [0, 1], [0, 1]] + ) + else: + expected_columns = MultiIndex.from_tuples([("z", "b"), ("y", "a")]) + expected = DataFrame( + [[1.0, 2.0], [3.0, 4.0]], + index=["two", "one"], + columns=expected_columns, + dtype=dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "levels2, expected_columns", + [ + ( + [None, 1, 2, 3], + [("value", np.nan), ("value", 1), ("value", 2), ("value", 3)], + ), + ( + [1, None, 2, 3], + [("value", 1), ("value", np.nan), ("value", 2), ("value", 3)], + ), + ( + [1, 2, None, 3], + [("value", 1), ("value", 2), ("value", np.nan), ("value", 3)], + ), + ( + [1, 2, 3, None], + [("value", 1), ("value", 2), ("value", 3), ("value", np.nan)], + ), + ], + ids=["nan=first", "nan=second", "nan=third", "nan=last"], +) +def test_unstack_sort_false_nan(levels2, expected_columns): + # GH#61221 + levels1 = ["b", "a"] + index = MultiIndex.from_product([levels1, levels2], names=["level1", "level2"]) + df = DataFrame({"value": [0, 1, 2, 3, 4, 5, 6, 7]}, index=index) + result = df.unstack(level="level2", sort=False) + expected_data = [[0, 4], [1, 5], [2, 6], [3, 7]] + expected = DataFrame( + dict(zip(expected_columns, expected_data)), + index=Index(["b", "a"], name="level1"), + columns=MultiIndex.from_tuples(expected_columns, names=[None, "level2"]), + ) + tm.assert_frame_equal(result, expected) + + +def test_unstack_fill_frame_object(): + # GH12815 Test unstacking with object. + data = Series(["a", "b", "c", "a"], dtype="object") + data.index = MultiIndex.from_tuples( + [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] + ) + + # By default missing values will be NaN + result = data.unstack() + expected = DataFrame( + {"a": ["a", np.nan, "a"], "b": ["b", "c", np.nan]}, + index=list("xyz"), + dtype=object, + ) + tm.assert_frame_equal(result, expected) + + # Fill with any value replaces missing values as expected + result = data.unstack(fill_value="d") + expected = DataFrame( + {"a": ["a", "d", "a"], "b": ["b", "c", "d"]}, index=list("xyz"), dtype=object + ) + tm.assert_frame_equal(result, expected) + + +def test_unstack_timezone_aware_values(): + # GH 18338 + df = DataFrame( + { + "timestamp": [pd.Timestamp("2017-08-27 01:00:00.709949+0000", tz="UTC")], + "a": ["a"], + "b": ["b"], + "c": ["c"], + }, + columns=["timestamp", "a", "b", "c"], + ) + result = df.set_index(["a", "b"]).unstack() + expected = DataFrame( + [[pd.Timestamp("2017-08-27 01:00:00.709949+0000", tz="UTC"), "c"]], + index=Index(["a"], name="a"), + columns=MultiIndex( + levels=[["timestamp", "c"], ["b"]], + codes=[[0, 1], [0, 0]], + names=[None, "b"], + ), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +def test_stack_timezone_aware_values(future_stack): + # GH 19420 + ts = date_range(freq="D", start="20180101", end="20180103", tz="America/New_York") + df = DataFrame({"A": ts}, index=["a", "b", "c"]) + result = df.stack(future_stack=future_stack) + expected = Series( + ts, + index=MultiIndex(levels=[["a", "b", "c"], ["A"]], codes=[[0, 1, 2], [0, 0, 0]]), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +@pytest.mark.parametrize("dropna", [True, False, lib.no_default]) +def test_stack_empty_frame(dropna, future_stack): + # GH 36113 + levels = [pd.RangeIndex(0), pd.RangeIndex(0)] + expected = Series(dtype=np.float64, index=MultiIndex(levels=levels, codes=[[], []])) + if future_stack and dropna is not lib.no_default: + with pytest.raises(ValueError, match="dropna must be unspecified"): + DataFrame(dtype=np.float64).stack(dropna=dropna, future_stack=future_stack) + else: + result = DataFrame(dtype=np.float64).stack( + dropna=dropna, future_stack=future_stack + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +@pytest.mark.parametrize("dropna", [True, False, lib.no_default]) +def test_stack_empty_level(dropna, future_stack, int_frame): + # GH 60740 + if future_stack and dropna is not lib.no_default: + with pytest.raises(ValueError, match="dropna must be unspecified"): + DataFrame(dtype=np.int64).stack(dropna=dropna, future_stack=future_stack) + else: + expected = int_frame + result = int_frame.copy().stack( + level=[], dropna=dropna, future_stack=future_stack + ) + tm.assert_frame_equal(result, expected) + + expected = DataFrame() + result = DataFrame().stack(level=[], dropna=dropna, future_stack=future_stack) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +@pytest.mark.parametrize("dropna", [True, False, lib.no_default]) +@pytest.mark.parametrize("fill_value", [None, 0]) +def test_stack_unstack_empty_frame(dropna, fill_value, future_stack): + # GH 36113 + if future_stack and dropna is not lib.no_default: + with pytest.raises(ValueError, match="dropna must be unspecified"): + DataFrame(dtype=np.int64).stack( + dropna=dropna, future_stack=future_stack + ).unstack(fill_value=fill_value) + else: + result = ( + DataFrame(dtype=np.int64) + .stack(dropna=dropna, future_stack=future_stack) + .unstack(fill_value=fill_value) + ) + expected = DataFrame(dtype=np.int64) + tm.assert_frame_equal(result, expected) + + +def test_unstack_single_index_series(): + # GH 36113 + msg = r"index must be a MultiIndex to unstack.*" + with pytest.raises(ValueError, match=msg): + Series(dtype=np.int64).unstack() + + +def test_unstacking_multi_index_df(): + # see gh-30740 + df = DataFrame( + { + "name": ["Alice", "Bob"], + "score": [9.5, 8], + "employed": [False, True], + "kids": [0, 0], + "gender": ["female", "male"], + } + ) + df = df.set_index(["name", "employed", "kids", "gender"]) + df = df.unstack(["gender"], fill_value=0) + expected = df.unstack("employed", fill_value=0).unstack("kids", fill_value=0) + result = df.unstack(["employed", "kids"], fill_value=0) + expected = DataFrame( + [[9.5, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 8.0]], + index=Index(["Alice", "Bob"], name="name"), + columns=MultiIndex.from_tuples( + [ + ("score", "female", False, 0), + ("score", "female", True, 0), + ("score", "male", False, 0), + ("score", "male", True, 0), + ], + names=[None, "gender", "employed", "kids"], + ), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +def test_stack_positional_level_duplicate_column_names(future_stack): + # https://github.com/pandas-dev/pandas/issues/36353 + columns = MultiIndex.from_product([("x", "y"), ("y", "z")], names=["a", "a"]) + df = DataFrame([[1, 1, 1, 1]], columns=columns) + result = df.stack(0, future_stack=future_stack) + + new_columns = Index(["y", "z"], name="a") + new_index = MultiIndex( + levels=[range(1), ["x", "y"]], codes=[[0, 0], [0, 1]], names=[None, "a"] + ) + expected = DataFrame([[1, 1], [1, 1]], index=new_index, columns=new_columns) + + tm.assert_frame_equal(result, expected) + + +def test_unstack_non_slice_like_blocks(): + # Case where the mgr_locs of a DataFrame's underlying blocks are not slice-like + + mi = MultiIndex.from_product([range(5), ["A", "B", "C"]]) + df = DataFrame( + { + 0: np.random.default_rng(2).standard_normal(15), + 1: np.random.default_rng(2).standard_normal(15).astype(np.int64), + 2: np.random.default_rng(2).standard_normal(15), + 3: np.random.default_rng(2).standard_normal(15), + }, + index=mi, + ) + assert any(not x.mgr_locs.is_slice_like for x in df._mgr.blocks) + + res = df.unstack() + + expected = pd.concat([df[n].unstack() for n in range(4)], keys=range(4), axis=1) + tm.assert_frame_equal(res, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +def test_stack_sort_false(future_stack): + # GH 15105 + data = [[1, 2, 3.0, 4.0], [2, 3, 4.0, 5.0], [3, 4, np.nan, np.nan]] + df = DataFrame( + data, + columns=MultiIndex( + levels=[["B", "A"], ["x", "y"]], codes=[[0, 0, 1, 1], [0, 1, 0, 1]] + ), + ) + kwargs = {} if future_stack else {"sort": False} + result = df.stack(level=0, future_stack=future_stack, **kwargs) + if future_stack: + expected = DataFrame( + { + "x": [1.0, 3.0, 2.0, 4.0, 3.0, np.nan], + "y": [2.0, 4.0, 3.0, 5.0, 4.0, np.nan], + }, + index=MultiIndex.from_arrays( + [[0, 0, 1, 1, 2, 2], ["B", "A", "B", "A", "B", "A"]] + ), + ) + else: + expected = DataFrame( + {"x": [1.0, 3.0, 2.0, 4.0, 3.0], "y": [2.0, 4.0, 3.0, 5.0, 4.0]}, + index=MultiIndex.from_arrays([[0, 0, 1, 1, 2], ["B", "A", "B", "A", "B"]]), + ) + tm.assert_frame_equal(result, expected) + + # Codes sorted in this call + df = DataFrame( + data, + columns=MultiIndex.from_arrays([["B", "B", "A", "A"], ["x", "y", "x", "y"]]), + ) + kwargs = {} if future_stack else {"sort": False} + result = df.stack(level=0, future_stack=future_stack, **kwargs) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +def test_stack_sort_false_multi_level(future_stack): + # GH 15105 + idx = MultiIndex.from_tuples([("weight", "kg"), ("height", "m")]) + df = DataFrame([[1.0, 2.0], [3.0, 4.0]], index=["cat", "dog"], columns=idx) + kwargs = {} if future_stack else {"sort": False} + result = df.stack([0, 1], future_stack=future_stack, **kwargs) + expected_index = MultiIndex.from_tuples( + [ + ("cat", "weight", "kg"), + ("cat", "height", "m"), + ("dog", "weight", "kg"), + ("dog", "height", "m"), + ] + ) + expected = Series([1.0, 2.0, 3.0, 4.0], index=expected_index) + tm.assert_series_equal(result, expected) + + +class TestStackUnstackMultiLevel: + def test_unstack(self, multiindex_year_month_day_dataframe_random_data): + # just check that it works for now + ymd = multiindex_year_month_day_dataframe_random_data + + unstacked = ymd.unstack() + unstacked.unstack() + + # test that ints work + ymd.astype(int).unstack() + + # test that int32 work + ymd.astype(np.int32).unstack() + + @pytest.mark.parametrize( + "result_rows,result_columns,index_product,expected_row", + [ + ( + [[1, 1, None, None, 30.0, None], [2, 2, None, None, 30.0, None]], + ["ix1", "ix2", "col1", "col2", "col3", "col4"], + 2, + [None, None, 30.0, None], + ), + ( + [[1, 1, None, None, 30.0], [2, 2, None, None, 30.0]], + ["ix1", "ix2", "col1", "col2", "col3"], + 2, + [None, None, 30.0], + ), + ( + [[1, 1, None, None, 30.0], [2, None, None, None, 30.0]], + ["ix1", "ix2", "col1", "col2", "col3"], + None, + [None, None, 30.0], + ), + ], + ) + def test_unstack_partial( + self, result_rows, result_columns, index_product, expected_row + ): + # check for regressions on this issue: + # https://github.com/pandas-dev/pandas/issues/19351 + # make sure DataFrame.unstack() works when its run on a subset of the DataFrame + # and the Index levels contain values that are not present in the subset + result = DataFrame(result_rows, columns=result_columns).set_index( + ["ix1", "ix2"] + ) + result = result.iloc[1:2].unstack("ix2") + expected = DataFrame( + [expected_row], + columns=MultiIndex.from_product( + [result_columns[2:], [index_product]], names=[None, "ix2"] + ), + index=Index([2], name="ix1"), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_multiple_no_empty_columns(self): + index = MultiIndex.from_tuples( + [(0, "foo", 0), (0, "bar", 0), (1, "baz", 1), (1, "qux", 1)] + ) + + s = Series(np.random.default_rng(2).standard_normal(4), index=index) + + unstacked = s.unstack([1, 2]) + expected = unstacked.dropna(axis=1, how="all") + tm.assert_frame_equal(unstacked, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack(self, multiindex_year_month_day_dataframe_random_data, future_stack): + ymd = multiindex_year_month_day_dataframe_random_data + + # regular roundtrip + unstacked = ymd.unstack() + restacked = unstacked.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked, ymd) + + unlexsorted = ymd.sort_index(level=2) + + unstacked = unlexsorted.unstack(2) + restacked = unstacked.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked.sort_index(level=0), ymd) + + unlexsorted = unlexsorted[::-1] + unstacked = unlexsorted.unstack(1) + restacked = unstacked.stack(future_stack=future_stack).swaplevel(1, 2) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked.sort_index(level=0), ymd) + + unlexsorted = unlexsorted.swaplevel(0, 1) + unstacked = unlexsorted.unstack(0).swaplevel(0, 1, axis=1) + restacked = unstacked.stack(0, future_stack=future_stack).swaplevel(1, 2) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked.sort_index(level=0), ymd) + + # columns unsorted + unstacked = ymd.unstack() + restacked = unstacked.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + tm.assert_frame_equal(restacked, ymd) + + # more than 2 levels in the columns + unstacked = ymd.unstack(1).unstack(1) + + result = unstacked.stack(1, future_stack=future_stack) + expected = ymd.unstack() + tm.assert_frame_equal(result, expected) + + result = unstacked.stack(2, future_stack=future_stack) + expected = ymd.unstack(1) + tm.assert_frame_equal(result, expected) + + result = unstacked.stack(0, future_stack=future_stack) + expected = ymd.stack(future_stack=future_stack).unstack(1).unstack(1) + tm.assert_frame_equal(result, expected) + + # not all levels present in each echelon + unstacked = ymd.unstack(2).loc[:, ::3] + stacked = unstacked.stack(future_stack=future_stack).stack( + future_stack=future_stack + ) + ymd_stacked = ymd.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + stacked = stacked.dropna(how="all") + ymd_stacked = ymd_stacked.dropna(how="all") + tm.assert_series_equal(stacked, ymd_stacked.reindex(stacked.index)) + + # stack with negative number + result = ymd.unstack(0).stack(-2, future_stack=future_stack) + expected = ymd.unstack(0).stack(0, future_stack=future_stack) + tm.assert_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "idx, exp_idx", + [ + [ + list("abab"), + MultiIndex( + levels=[["a", "b"], ["1st", "2nd"]], + codes=[np.tile(np.arange(2).repeat(3), 2), np.tile([0, 1, 0], 4)], + ), + ], + [ + MultiIndex.from_tuples((("a", 2), ("b", 1), ("a", 1), ("b", 2))), + MultiIndex( + levels=[["a", "b"], [1, 2], ["1st", "2nd"]], + codes=[ + np.tile(np.arange(2).repeat(3), 2), + np.repeat([1, 0, 1], [3, 6, 3]), + np.tile([0, 1, 0], 4), + ], + ), + ], + ], + ) + def test_stack_duplicate_index(self, idx, exp_idx, future_stack): + # GH10417 + df = DataFrame( + np.arange(12).reshape(4, 3), + index=idx, + columns=["1st", "2nd", "1st"], + ) + if future_stack: + msg = "Columns with duplicate values are not supported in stack" + with pytest.raises(ValueError, match=msg): + df.stack(future_stack=future_stack) + else: + result = df.stack(future_stack=future_stack) + expected = Series(np.arange(12), index=exp_idx) + tm.assert_series_equal(result, expected) + assert result.index.is_unique is False + li, ri = result.index, expected.index + tm.assert_index_equal(li, ri) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_odd_failure(self, future_stack): + mi = MultiIndex.from_arrays( + [ + ["Fri"] * 4 + ["Sat"] * 2 + ["Sun"] * 2 + ["Thu"] * 3, + ["Dinner"] * 2 + ["Lunch"] * 2 + ["Dinner"] * 5 + ["Lunch"] * 2, + ["No", "Yes"] * 4 + ["No", "No", "Yes"], + ], + names=["day", "time", "smoker"], + ) + df = DataFrame( + { + "sum": np.arange(11, dtype="float64"), + "len": np.arange(11, dtype="float64"), + }, + index=mi, + ) + # it works, #2100 + result = df.unstack(2) + + recons = result.stack(future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + recons = recons.dropna(how="all") + tm.assert_frame_equal(recons, df) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_mixed_dtype(self, multiindex_dataframe_random_data, future_stack): + frame = multiindex_dataframe_random_data + + df = frame.T + df["foo", "four"] = "foo" + df = df.sort_index(level=1, axis=1) + + stacked = df.stack(future_stack=future_stack) + result = df["foo"].stack(future_stack=future_stack).sort_index() + tm.assert_series_equal(stacked["foo"], result, check_names=False) + assert result.name is None + assert stacked["bar"].dtype == np.float64 + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_bug(self, future_stack): + df = DataFrame( + { + "state": ["naive", "naive", "naive", "active", "active", "active"], + "exp": ["a", "b", "b", "b", "a", "a"], + "barcode": [1, 2, 3, 4, 1, 3], + "v": ["hi", "hi", "bye", "bye", "bye", "peace"], + "extra": np.arange(6.0), + } + ) + + result = df.groupby(["state", "exp", "barcode", "v"]).apply(len) + unstacked = result.unstack() + restacked = unstacked.stack(future_stack=future_stack) + tm.assert_series_equal(restacked, result.reindex(restacked.index).astype(float)) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unstack_preserve_names( + self, multiindex_dataframe_random_data, future_stack + ): + frame = multiindex_dataframe_random_data + + unstacked = frame.unstack() + assert unstacked.index.name == "first" + assert unstacked.columns.names == ["exp", "second"] + + restacked = unstacked.stack(future_stack=future_stack) + assert restacked.index.names == frame.index.names + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize("method", ["stack", "unstack"]) + def test_stack_unstack_wrong_level_name( + self, method, multiindex_dataframe_random_data, future_stack + ): + # GH 18303 - wrong level name should raise + frame = multiindex_dataframe_random_data + + # A DataFrame with flat axes: + df = frame.loc["foo"] + + kwargs = {"future_stack": future_stack} if method == "stack" else {} + with pytest.raises(KeyError, match="does not match index name"): + getattr(df, method)("mistake", **kwargs) + + if method == "unstack": + # Same on a Series: + s = df.iloc[:, 0] + with pytest.raises(KeyError, match="does not match index name"): + getattr(s, method)("mistake", **kwargs) + + def test_unstack_level_name(self, multiindex_dataframe_random_data): + frame = multiindex_dataframe_random_data + + result = frame.unstack("second") + expected = frame.unstack(level=1) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_level_name(self, multiindex_dataframe_random_data, future_stack): + frame = multiindex_dataframe_random_data + + unstacked = frame.unstack("second") + result = unstacked.stack("exp", future_stack=future_stack) + expected = frame.unstack().stack(0, future_stack=future_stack) + tm.assert_frame_equal(result, expected) + + result = frame.stack("exp", future_stack=future_stack) + expected = frame.stack(future_stack=future_stack) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unstack_multiple( + self, multiindex_year_month_day_dataframe_random_data, future_stack + ): + ymd = multiindex_year_month_day_dataframe_random_data + + unstacked = ymd.unstack(["year", "month"]) + expected = ymd.unstack("year").unstack("month") + tm.assert_frame_equal(unstacked, expected) + assert unstacked.columns.names == expected.columns.names + + # series + s = ymd["A"] + s_unstacked = s.unstack(["year", "month"]) + tm.assert_frame_equal(s_unstacked, expected["A"]) + + restacked = unstacked.stack(["year", "month"], future_stack=future_stack) + if future_stack: + # NA values in unstacked persist to restacked in version 3 + restacked = restacked.dropna(how="all") + restacked = restacked.swaplevel(0, 1).swaplevel(1, 2) + restacked = restacked.sort_index(level=0) + + tm.assert_frame_equal(restacked, ymd) + assert restacked.index.names == ymd.index.names + + # GH #451 + unstacked = ymd.unstack([1, 2]) + expected = ymd.unstack(1).unstack(1).dropna(axis=1, how="all") + tm.assert_frame_equal(unstacked, expected) + + unstacked = ymd.unstack([2, 1]) + expected = ymd.unstack(2).unstack(1).dropna(axis=1, how="all") + tm.assert_frame_equal(unstacked, expected.loc[:, unstacked.columns]) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_names_and_numbers( + self, multiindex_year_month_day_dataframe_random_data, future_stack + ): + ymd = multiindex_year_month_day_dataframe_random_data + + unstacked = ymd.unstack(["year", "month"]) + + # Can't use mixture of names and numbers to stack + with pytest.raises(ValueError, match="level should contain"): + unstacked.stack([0, "month"], future_stack=future_stack) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_multiple_out_of_bounds( + self, multiindex_year_month_day_dataframe_random_data, future_stack + ): + # nlevels == 3 + ymd = multiindex_year_month_day_dataframe_random_data + + unstacked = ymd.unstack(["year", "month"]) + + with pytest.raises(IndexError, match="Too many levels"): + unstacked.stack([2, 3], future_stack=future_stack) + with pytest.raises(IndexError, match="not a valid level number"): + unstacked.stack([-4, -3], future_stack=future_stack) + + def test_unstack_period_series(self): + # GH4342 + idx1 = pd.PeriodIndex( + ["2013-01", "2013-01", "2013-02", "2013-02", "2013-03", "2013-03"], + freq="M", + name="period", + ) + idx2 = Index(["A", "B"] * 3, name="str") + value = [1, 2, 3, 4, 5, 6] + + idx = MultiIndex.from_arrays([idx1, idx2]) + s = Series(value, index=idx) + + result1 = s.unstack() + result2 = s.unstack(level=1) + result3 = s.unstack(level=0) + + e_idx = pd.PeriodIndex( + ["2013-01", "2013-02", "2013-03"], freq="M", name="period" + ) + expected = DataFrame( + {"A": [1, 3, 5], "B": [2, 4, 6]}, index=e_idx, columns=["A", "B"] + ) + expected.columns.name = "str" + + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + tm.assert_frame_equal(result3, expected.T) + + idx1 = pd.PeriodIndex( + ["2013-01", "2013-01", "2013-02", "2013-02", "2013-03", "2013-03"], + freq="M", + name="period1", + ) + + idx2 = pd.PeriodIndex( + ["2013-12", "2013-11", "2013-10", "2013-09", "2013-08", "2013-07"], + freq="M", + name="period2", + ) + idx = MultiIndex.from_arrays([idx1, idx2]) + s = Series(value, index=idx) + + result1 = s.unstack() + result2 = s.unstack(level=1) + result3 = s.unstack(level=0) + + e_idx = pd.PeriodIndex( + ["2013-01", "2013-02", "2013-03"], freq="M", name="period1" + ) + e_cols = pd.PeriodIndex( + ["2013-07", "2013-08", "2013-09", "2013-10", "2013-11", "2013-12"], + freq="M", + name="period2", + ) + expected = DataFrame( + [ + [np.nan, np.nan, np.nan, np.nan, 2, 1], + [np.nan, np.nan, 4, 3, np.nan, np.nan], + [6, 5, np.nan, np.nan, np.nan, np.nan], + ], + index=e_idx, + columns=e_cols, + ) + + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + tm.assert_frame_equal(result3, expected.T) + + def test_unstack_period_frame(self): + # GH4342 + idx1 = pd.PeriodIndex( + ["2014-01", "2014-02", "2014-02", "2014-02", "2014-01", "2014-01"], + freq="M", + name="period1", + ) + idx2 = pd.PeriodIndex( + ["2013-12", "2013-12", "2014-02", "2013-10", "2013-10", "2014-02"], + freq="M", + name="period2", + ) + value = {"A": [1, 2, 3, 4, 5, 6], "B": [6, 5, 4, 3, 2, 1]} + idx = MultiIndex.from_arrays([idx1, idx2]) + df = DataFrame(value, index=idx) + + result1 = df.unstack() + result2 = df.unstack(level=1) + result3 = df.unstack(level=0) + + e_1 = pd.PeriodIndex(["2014-01", "2014-02"], freq="M", name="period1") + e_2 = pd.PeriodIndex( + ["2013-10", "2013-12", "2014-02", "2013-10", "2013-12", "2014-02"], + freq="M", + name="period2", + ) + e_cols = MultiIndex.from_arrays(["A A A B B B".split(), e_2]) + expected = DataFrame( + [[5, 1, 6, 2, 6, 1], [4, 2, 3, 3, 5, 4]], index=e_1, columns=e_cols + ) + + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + + e_1 = pd.PeriodIndex( + ["2014-01", "2014-02", "2014-01", "2014-02"], freq="M", name="period1" + ) + e_2 = pd.PeriodIndex( + ["2013-10", "2013-12", "2014-02"], freq="M", name="period2" + ) + e_cols = MultiIndex.from_arrays(["A A B B".split(), e_1]) + expected = DataFrame( + [[5, 4, 2, 3], [1, 2, 6, 5], [6, 3, 1, 4]], index=e_2, columns=e_cols + ) + + tm.assert_frame_equal(result3, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_multiple_bug(self, future_stack, using_infer_string): + # bug when some uniques are not present in the data GH#3170 + id_col = ([1] * 3) + ([2] * 3) + name = (["a"] * 3) + (["b"] * 3) + date = pd.to_datetime(["2013-01-03", "2013-01-04", "2013-01-05"] * 2) + var1 = np.random.default_rng(2).integers(0, 100, 6) + df = DataFrame({"ID": id_col, "NAME": name, "DATE": date, "VAR1": var1}) + + multi = df.set_index(["DATE", "ID"]) + multi.columns.name = "Params" + unst = multi.unstack("ID") + msg = re.escape("agg function failed [how->mean,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'mean'" + with pytest.raises(TypeError, match=msg): + unst.resample("W-THU").mean() + down = unst.resample("W-THU").mean(numeric_only=True) + rs = down.stack("ID", future_stack=future_stack) + xp = ( + unst.loc[:, ["VAR1"]] + .resample("W-THU") + .mean() + .stack("ID", future_stack=future_stack) + ) + xp.columns.name = "Params" + tm.assert_frame_equal(rs, xp) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_dropna(self, future_stack): + # GH#3997 + df = DataFrame({"A": ["a1", "a2"], "B": ["b1", "b2"], "C": [1, 1]}) + df = df.set_index(["A", "B"]) + + dropna = False if not future_stack else lib.no_default + stacked = df.unstack().stack(dropna=dropna, future_stack=future_stack) + assert len(stacked) > len(stacked.dropna()) + + if future_stack: + with pytest.raises(ValueError, match="dropna must be unspecified"): + df.unstack().stack(dropna=True, future_stack=future_stack) + else: + stacked = df.unstack().stack(dropna=True, future_stack=future_stack) + tm.assert_frame_equal(stacked, stacked.dropna()) + + def test_unstack_multiple_hierarchical(self, future_stack): + df = DataFrame( + index=[ + [0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 0, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1], + ], + columns=[[0, 0, 1, 1], [0, 1, 0, 1]], + ) + + df.index.names = ["a", "b", "c"] + df.columns.names = ["d", "e"] + + # it works! + df.unstack(["b", "c"]) + + def test_unstack_sparse_keyspace(self): + # memory problems with naive impl GH#2278 + # Generate Long File & Test Pivot + NUM_ROWS = 1000 + + df = DataFrame( + { + "A": np.random.default_rng(2).integers(100, size=NUM_ROWS), + "B": np.random.default_rng(3).integers(300, size=NUM_ROWS), + "C": np.random.default_rng(4).integers(-7, 7, size=NUM_ROWS), + "D": np.random.default_rng(5).integers(-19, 19, size=NUM_ROWS), + "E": np.random.default_rng(6).integers(3000, size=NUM_ROWS), + "F": np.random.default_rng(7).standard_normal(NUM_ROWS), + } + ) + + idf = df.set_index(["A", "B", "C", "D", "E"]) + + # it works! is sufficient + idf.unstack("E") + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_unobserved_keys(self, future_stack): + # related to GH#2278 refactoring + levels = [[0, 1], [0, 1, 2, 3]] + codes = [[0, 0, 1, 1], [0, 2, 0, 2]] + + index = MultiIndex(levels, codes) + + df = DataFrame(np.random.default_rng(2).standard_normal((4, 2)), index=index) + + result = df.unstack() + assert len(result.columns) == 4 + + recons = result.stack(future_stack=future_stack) + tm.assert_frame_equal(recons, df) + + @pytest.mark.slow + def test_unstack_number_of_levels_larger_than_int32_warns( + self, performance_warning, monkeypatch + ): + # GH#20601 + # GH 26314: Change ValueError to PerformanceWarning + + class MockUnstacker(reshape_lib._Unstacker): + def __init__(self, *args, **kwargs) -> None: + # __init__ will raise the warning + super().__init__(*args, **kwargs) + raise Exception("Don't compute final result.") + + def _make_selectors(self) -> None: + pass + + with monkeypatch.context() as m: + m.setattr(reshape_lib, "_Unstacker", MockUnstacker) + df = DataFrame( + np.zeros((2**16, 2)), + index=[np.arange(2**16), np.arange(2**16)], + ) + msg = "The following operation may generate" + with tm.assert_produces_warning(performance_warning, match=msg): + with pytest.raises(Exception, match="Don't compute final result."): + df.unstack() + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + @pytest.mark.parametrize( + "levels", + itertools.chain.from_iterable( + itertools.product(itertools.permutations([0, 1, 2], width), repeat=2) + for width in [2, 3] + ), + ) + @pytest.mark.parametrize("stack_lev", range(2)) + def test_stack_order_with_unsorted_levels( + self, levels, stack_lev, sort, future_stack + ): + # GH#16323 + # deep check for 1-row case + columns = MultiIndex(levels=levels, codes=[[0, 0, 1, 1], [0, 1, 0, 1]]) + df = DataFrame(columns=columns, data=[range(4)]) + kwargs = {} if future_stack else {"sort": sort} + df_stacked = df.stack(stack_lev, future_stack=future_stack, **kwargs) + for row in df.index: + for col in df.columns: + expected = df.loc[row, col] + result_row = row, col[stack_lev] + result_col = col[1 - stack_lev] + result = df_stacked.loc[result_row, result_col] + assert result == expected + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_order_with_unsorted_levels_multi_row(self, future_stack): + # GH#16323 + + # check multi-row case + mi = MultiIndex( + levels=[["A", "C", "B"], ["B", "A", "C"]], + codes=[np.repeat(range(3), 3), np.tile(range(3), 3)], + ) + df = DataFrame( + columns=mi, index=range(5), data=np.arange(5 * len(mi)).reshape(5, -1) + ) + assert all( + df.loc[row, col] + == df.stack(0, future_stack=future_stack).loc[(row, col[0]), col[1]] + for row in df.index + for col in df.columns + ) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_order_with_unsorted_levels_multi_row_2(self, future_stack): + # GH#53636 + levels = ((0, 1), (1, 0)) + stack_lev = 1 + columns = MultiIndex(levels=levels, codes=[[0, 0, 1, 1], [0, 1, 0, 1]]) + df = DataFrame(columns=columns, data=[range(4)], index=[1, 0, 2, 3]) + kwargs = {} if future_stack else {"sort": True} + result = df.stack(stack_lev, future_stack=future_stack, **kwargs) + expected_index = MultiIndex( + levels=[[0, 1, 2, 3], [0, 1]], + codes=[[1, 1, 0, 0, 2, 2, 3, 3], [1, 0, 1, 0, 1, 0, 1, 0]], + ) + expected = DataFrame( + { + 0: [0, 1, 0, 1, 0, 1, 0, 1], + 1: [2, 3, 2, 3, 2, 3, 2, 3], + }, + index=expected_index, + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unstack_unordered_multiindex(self, future_stack): + # GH# 18265 + values = np.arange(5) + data = np.vstack( + [ + [f"b{x}" for x in values], # b0, b1, .. + [f"a{x}" for x in values], # a0, a1, .. + ] + ) + df = DataFrame(data.T, columns=["b", "a"]) + df.columns.name = "first" + second_level_dict = {"x": df} + multi_level_df = pd.concat(second_level_dict, axis=1) + multi_level_df.columns.names = ["second", "first"] + df = multi_level_df.reindex(sorted(multi_level_df.columns), axis=1) + result = df.stack(["first", "second"], future_stack=future_stack).unstack( + ["first", "second"] + ) + expected = DataFrame( + [["a0", "b0"], ["a1", "b1"], ["a2", "b2"], ["a3", "b3"], ["a4", "b4"]], + index=range(5), + columns=MultiIndex.from_tuples( + [("a", "x"), ("b", "x")], names=["first", "second"] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_preserve_types( + self, multiindex_year_month_day_dataframe_random_data, using_infer_string + ): + # GH#403 + ymd = multiindex_year_month_day_dataframe_random_data + ymd["E"] = "foo" + ymd["F"] = 2 + + unstacked = ymd.unstack("month") + assert unstacked["A", 1].dtype == np.float64 + assert ( + unstacked["E", 1].dtype == np.object_ + if not using_infer_string + else "string" + ) + assert unstacked["F", 1].dtype == np.float64 + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_unstack_group_index_overflow(self, future_stack): + codes = np.tile(np.arange(500), 2) + level = np.arange(500) + + index = MultiIndex( + levels=[level] * 8 + [[0, 1]], + codes=[codes] * 8 + [np.arange(2).repeat(500)], + ) + + s = Series(np.arange(1000), index=index) + result = s.unstack() + assert result.shape == (500, 2) + + # test roundtrip + stacked = result.stack(future_stack=future_stack) + tm.assert_series_equal(s, stacked.reindex(s.index)) + + # put it at beginning + index = MultiIndex( + levels=[[0, 1]] + [level] * 8, + codes=[np.arange(2).repeat(500)] + [codes] * 8, + ) + + s = Series(np.arange(1000), index=index) + result = s.unstack(0) + assert result.shape == (500, 2) + + # put it in middle + index = MultiIndex( + levels=[level] * 4 + [[0, 1]] + [level] * 4, + codes=([codes] * 4 + [np.arange(2).repeat(500)] + [codes] * 4), + ) + + s = Series(np.arange(1000), index=index) + result = s.unstack(4) + assert result.shape == (500, 2) + + def test_unstack_with_missing_int_cast_to_float(self): + # https://github.com/pandas-dev/pandas/issues/37115 + df = DataFrame( + { + "a": ["A", "A", "B"], + "b": ["ca", "cb", "cb"], + "v": [10] * 3, + } + ).set_index(["a", "b"]) + + # add another int column to get 2 blocks + df["is_"] = 1 + assert len(df._mgr.blocks) == 2 + + result = df.unstack("b") + result[("is_", "ca")] = result[("is_", "ca")].fillna(0) + + expected = DataFrame( + [[10.0, 10.0, 1.0, 1.0], [np.nan, 10.0, 0.0, 1.0]], + index=Index(["A", "B"], name="a"), + columns=MultiIndex.from_tuples( + [("v", "ca"), ("v", "cb"), ("is_", "ca"), ("is_", "cb")], + names=[None, "b"], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_with_level_has_nan(self): + # GH 37510 + df1 = DataFrame( + { + "L1": [1, 2, 3, 4], + "L2": [3, 4, 1, 2], + "L3": [1, 1, 1, 1], + "x": [1, 2, 3, 4], + } + ) + df1 = df1.set_index(["L1", "L2", "L3"]) + new_levels = ["n1", "n2", "n3", None] + df1.index = df1.index.set_levels(levels=new_levels, level="L1") + df1.index = df1.index.set_levels(levels=new_levels, level="L2") + + result = df1.unstack("L3")[("x", 1)].sort_index().index + expected = MultiIndex( + levels=[["n1", "n2", "n3", None], ["n1", "n2", "n3", None]], + codes=[[0, 1, 2, 3], [2, 3, 0, 1]], + names=["L1", "L2"], + ) + + tm.assert_index_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_nan_in_multiindex_columns(self, future_stack): + # GH#39481 + df = DataFrame( + np.zeros([1, 5]), + columns=MultiIndex.from_tuples( + [ + (0, None, None), + (0, 2, 0), + (0, 2, 1), + (0, 3, 0), + (0, 3, 1), + ], + ), + ) + result = df.stack(2, future_stack=future_stack) + if future_stack: + index = MultiIndex(levels=[[0], [0.0, 1.0]], codes=[[0, 0, 0], [-1, 0, 1]]) + columns = MultiIndex(levels=[[0], [2, 3]], codes=[[0, 0, 0], [-1, 0, 1]]) + else: + index = Index([(0, None), (0, 0), (0, 1)]) + columns = Index([(0, None), (0, 2), (0, 3)]) + expected = DataFrame( + [[0.0, np.nan, np.nan], [np.nan, 0.0, 0.0], [np.nan, 0.0, 0.0]], + index=index, + columns=columns, + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_multi_level_stack_categorical(self, future_stack): + # GH 15239 + midx = MultiIndex.from_arrays( + [ + ["A"] * 2 + ["B"] * 2, + pd.Categorical(list("abab")), + pd.Categorical(list("ccdd")), + ] + ) + df = DataFrame(np.arange(8).reshape(2, 4), columns=midx) + result = df.stack([1, 2], future_stack=future_stack) + if future_stack: + expected = DataFrame( + [ + [0, np.nan], + [1, np.nan], + [np.nan, 2], + [np.nan, 3], + [4, np.nan], + [5, np.nan], + [np.nan, 6], + [np.nan, 7], + ], + columns=["A", "B"], + index=MultiIndex.from_arrays( + [ + [0] * 4 + [1] * 4, + pd.Categorical(list("abababab")), + pd.Categorical(list("ccddccdd")), + ] + ), + ) + else: + expected = DataFrame( + [ + [0, np.nan], + [np.nan, 2], + [1, np.nan], + [np.nan, 3], + [4, np.nan], + [np.nan, 6], + [5, np.nan], + [np.nan, 7], + ], + columns=["A", "B"], + index=MultiIndex.from_arrays( + [ + [0] * 4 + [1] * 4, + pd.Categorical(list("aabbaabb")), + pd.Categorical(list("cdcdcdcd")), + ] + ), + ) + tm.assert_frame_equal(result, expected, check_index_type=False) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_nan_level(self, future_stack): + # GH 9406 + df_nan = DataFrame( + np.arange(4).reshape(2, 2), + columns=MultiIndex.from_tuples( + [("A", np.nan), ("B", "b")], names=["Upper", "Lower"] + ), + index=Index([0, 1], name="Num"), + dtype=np.float64, + ) + result = df_nan.stack(future_stack=future_stack) + if future_stack: + index = MultiIndex( + levels=[[0, 1], [np.nan, "b"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + names=["Num", "Lower"], + ) + else: + index = MultiIndex.from_tuples( + [(0, np.nan), (0, "b"), (1, np.nan), (1, "b")], names=["Num", "Lower"] + ) + expected = DataFrame( + [[0.0, np.nan], [np.nan, 1], [2.0, np.nan], [np.nan, 3.0]], + columns=Index(["A", "B"], name="Upper"), + index=index, + ) + tm.assert_frame_equal(result, expected) + + def test_unstack_categorical_columns(self): + # GH 14018 + idx = MultiIndex.from_product([["A"], [0, 1]]) + df = DataFrame({"cat": pd.Categorical(["a", "b"])}, index=idx) + result = df.unstack() + expected = DataFrame( + { + 0: pd.Categorical(["a"], categories=["a", "b"]), + 1: pd.Categorical(["b"], categories=["a", "b"]), + }, + index=["A"], + ) + expected.columns = MultiIndex.from_tuples([("cat", 0), ("cat", 1)]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_unsorted(self, future_stack): + # GH 16925 + PAE = ["ITA", "FRA"] + VAR = ["A1", "A2"] + TYP = ["CRT", "DBT", "NET"] + MI = MultiIndex.from_product([PAE, VAR, TYP], names=["PAE", "VAR", "TYP"]) + + V = list(range(len(MI))) + DF = DataFrame(data=V, index=MI, columns=["VALUE"]) + + DF = DF.unstack(["VAR", "TYP"]) + DF.columns = DF.columns.droplevel(0) + DF.loc[:, ("A0", "NET")] = 9999 + + result = DF.stack(["VAR", "TYP"], future_stack=future_stack).sort_index() + expected = ( + DF.sort_index(axis=1) + .stack(["VAR", "TYP"], future_stack=future_stack) + .sort_index() + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:The previous implementation of stack is deprecated" + ) + def test_stack_nullable_dtype(self, future_stack): + # GH#43561 + columns = MultiIndex.from_product( + [["54511", "54515"], ["r", "t_mean"]], names=["station", "element"] + ) + index = Index([1, 2, 3], name="time") + + arr = np.array([[50, 226, 10, 215], [10, 215, 9, 220], [305, 232, 111, 220]]) + df = DataFrame(arr, columns=columns, index=index, dtype=pd.Int64Dtype()) + + result = df.stack("station", future_stack=future_stack) + + expected = ( + df.astype(np.int64) + .stack("station", future_stack=future_stack) + .astype(pd.Int64Dtype()) + ) + tm.assert_frame_equal(result, expected) + + # non-homogeneous case + df[df.columns[0]] = df[df.columns[0]].astype(pd.Float64Dtype()) + result = df.stack("station", future_stack=future_stack) + + expected = DataFrame( + { + "r": pd.array( + [50.0, 10.0, 10.0, 9.0, 305.0, 111.0], dtype=pd.Float64Dtype() + ), + "t_mean": pd.array( + [226, 215, 215, 220, 232, 220], dtype=pd.Int64Dtype() + ), + }, + index=MultiIndex.from_product([index, columns.levels[0]]), + ) + expected.columns.name = "element" + tm.assert_frame_equal(result, expected) + + def test_unstack_mixed_level_names(self): + # GH#48763 + arrays = [["a", "a"], [1, 2], ["red", "blue"]] + idx = MultiIndex.from_arrays(arrays, names=("x", 0, "y")) + df = DataFrame({"m": [1, 2]}, index=idx) + result = df.unstack("x") + expected = DataFrame( + [[1], [2]], + columns=MultiIndex.from_tuples([("m", "a")], names=[None, "x"]), + index=MultiIndex.from_tuples([(1, "red"), (2, "blue")], names=[0, "y"]), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:The previous implementation of stack is deprecated") +def test_stack_tuple_columns(future_stack): + # GH#54948 - test stack when the input has a non-MultiIndex with tuples + df = DataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], columns=[("a", 1), ("a", 2), ("b", 1)] + ) + result = df.stack(future_stack=future_stack) + expected = Series( + [1, 2, 3, 4, 5, 6, 7, 8, 9], + index=MultiIndex( + levels=[range(3), [("a", 1), ("a", 2), ("b", 1)]], + codes=[[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]], + ), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype, na_value", + [ + ("float64", np.nan), + ("Float64", np.nan), + ("Float64", pd.NA), + ("Int64", pd.NA), + ], +) +@pytest.mark.parametrize("test_multiindex", [True, False]) +def test_stack_preserves_na(dtype, na_value, test_multiindex): + # GH#56573 + if test_multiindex: + index = MultiIndex.from_arrays(2 * [Index([na_value], dtype=dtype)]) + else: + index = Index([na_value], dtype=dtype) + df = DataFrame({"a": [1]}, index=index) + result = df.stack() + + if test_multiindex: + expected_index = MultiIndex.from_arrays( + [ + Index([na_value], dtype=dtype), + Index([na_value], dtype=dtype), + Index(["a"]), + ] + ) + else: + expected_index = MultiIndex.from_arrays( + [ + Index([na_value], dtype=dtype), + Index(["a"]), + ] + ) + expected = Series(1, index=expected_index) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/frame/test_subclass.py b/pandas/tests/frame/test_subclass.py new file mode 100644 index 0000000000000000000000000000000000000000..c1abbeea80ff3093948cb7c031eed304089a377c --- /dev/null +++ b/pandas/tests/frame/test_subclass.py @@ -0,0 +1,817 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, +) +import pandas._testing as tm + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning" +) + + +class TestDataFrameSubclassing: + def test_no_warning_on_mgr(self): + # GH#57032 + df = tm.SubclassedDataFrame( + {"X": [1, 2, 3], "Y": [1, 2, 3]}, index=["a", "b", "c"] + ) + with tm.assert_produces_warning(None): + # df.isna() goes through _constructor_from_mgr, which we want to + # *not* pass a Manager do __init__ + df.isna() + df["X"].isna() + + def test_frame_subclassing_and_slicing(self): + # Subclass frame and ensure it returns the right class on slicing it + # In reference to PR 9632 + + class CustomSeries(Series): + @property + def _constructor(self): + return CustomSeries + + def custom_series_function(self): + return "OK" + + class CustomDataFrame(DataFrame): + """ + Subclasses pandas DF, fills DF with simulation results, adds some + custom plotting functions. + """ + + def __init__(self, *args, **kw) -> None: + super().__init__(*args, **kw) + + @property + def _constructor(self): + return CustomDataFrame + + _constructor_sliced = CustomSeries + + def custom_frame_function(self): + return "OK" + + data = {"col1": range(10), "col2": range(10)} + cdf = CustomDataFrame(data) + + # Did we get back our own DF class? + assert isinstance(cdf, CustomDataFrame) + + # Do we get back our own Series class after selecting a column? + cdf_series = cdf.col1 + assert isinstance(cdf_series, CustomSeries) + assert cdf_series.custom_series_function() == "OK" + + # Do we get back our own DF class after slicing row-wise? + cdf_rows = cdf[1:5] + assert isinstance(cdf_rows, CustomDataFrame) + assert cdf_rows.custom_frame_function() == "OK" + + # Make sure sliced part of multi-index frame is custom class + mcol = MultiIndex.from_tuples([("A", "A"), ("A", "B")]) + cdf_multi = CustomDataFrame([[0, 1], [2, 3]], columns=mcol) + assert isinstance(cdf_multi["A"], CustomDataFrame) + + mcol = MultiIndex.from_tuples([("A", ""), ("B", "")]) + cdf_multi2 = CustomDataFrame([[0, 1], [2, 3]], columns=mcol) + assert isinstance(cdf_multi2["A"], CustomSeries) + + def test_dataframe_metadata(self, temp_file): + df = tm.SubclassedDataFrame( + {"X": [1, 2, 3], "Y": [1, 2, 3]}, index=["a", "b", "c"] + ) + df.testattr = "XXX" + + assert df.testattr == "XXX" + assert df[["X"]].testattr == "XXX" + assert df.loc[["a", "b"], :].testattr == "XXX" + assert df.iloc[[0, 1], :].testattr == "XXX" + + # see gh-9776 + assert df.iloc[0:1, :].testattr == "XXX" + + # see gh-10553 + unpickled = tm.round_trip_pickle(df, temp_file) + tm.assert_frame_equal(df, unpickled) + assert df._metadata == unpickled._metadata + assert df.testattr == unpickled.testattr + + def test_indexing_sliced(self): + # GH 11559 + df = tm.SubclassedDataFrame( + {"X": [1, 2, 3], "Y": [4, 5, 6], "Z": [7, 8, 9]}, index=["a", "b", "c"] + ) + res = df.loc[:, "X"] + exp = tm.SubclassedSeries([1, 2, 3], index=list("abc"), name="X") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.iloc[:, 1] + exp = tm.SubclassedSeries([4, 5, 6], index=list("abc"), name="Y") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.loc[:, "Z"] + exp = tm.SubclassedSeries([7, 8, 9], index=list("abc"), name="Z") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.loc["a", :] + exp = tm.SubclassedSeries([1, 4, 7], index=list("XYZ"), name="a") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.iloc[1, :] + exp = tm.SubclassedSeries([2, 5, 8], index=list("XYZ"), name="b") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + res = df.loc["c", :] + exp = tm.SubclassedSeries([3, 6, 9], index=list("XYZ"), name="c") + tm.assert_series_equal(res, exp) + assert isinstance(res, tm.SubclassedSeries) + + def test_subclass_attr_err_propagation(self): + # GH 11808 + class A(DataFrame): + @property + def nonexistence(self): + return self.i_dont_exist + + with pytest.raises(AttributeError, match=".*i_dont_exist.*"): + A().nonexistence + + def test_subclass_align(self): + # GH 12983 + df1 = tm.SubclassedDataFrame( + {"a": [1, 3, 5], "b": [1, 3, 5]}, index=list("ACE") + ) + df2 = tm.SubclassedDataFrame( + {"c": [1, 2, 4], "d": [1, 2, 4]}, index=list("ABD") + ) + + res1, res2 = df1.align(df2, axis=0) + exp1 = tm.SubclassedDataFrame( + {"a": [1, np.nan, 3, np.nan, 5], "b": [1, np.nan, 3, np.nan, 5]}, + index=list("ABCDE"), + ) + exp2 = tm.SubclassedDataFrame( + {"c": [1, 2, np.nan, 4, np.nan], "d": [1, 2, np.nan, 4, np.nan]}, + index=list("ABCDE"), + ) + assert isinstance(res1, tm.SubclassedDataFrame) + tm.assert_frame_equal(res1, exp1) + assert isinstance(res2, tm.SubclassedDataFrame) + tm.assert_frame_equal(res2, exp2) + + res1, res2 = df1.a.align(df2.c) + assert isinstance(res1, tm.SubclassedSeries) + tm.assert_series_equal(res1, exp1.a) + assert isinstance(res2, tm.SubclassedSeries) + tm.assert_series_equal(res2, exp2.c) + + def test_subclass_align_combinations(self): + # GH 12983 + df = tm.SubclassedDataFrame({"a": [1, 3, 5], "b": [1, 3, 5]}, index=list("ACE")) + s = tm.SubclassedSeries([1, 2, 4], index=list("ABD"), name="x") + + # frame + series + res1, res2 = df.align(s, axis=0) + exp1 = tm.SubclassedDataFrame( + {"a": [1, np.nan, 3, np.nan, 5], "b": [1, np.nan, 3, np.nan, 5]}, + index=list("ABCDE"), + ) + # name is lost when + exp2 = tm.SubclassedSeries( + [1, 2, np.nan, 4, np.nan], index=list("ABCDE"), name="x" + ) + + assert isinstance(res1, tm.SubclassedDataFrame) + tm.assert_frame_equal(res1, exp1) + assert isinstance(res2, tm.SubclassedSeries) + tm.assert_series_equal(res2, exp2) + + # series + frame + res1, res2 = s.align(df) + assert isinstance(res1, tm.SubclassedSeries) + tm.assert_series_equal(res1, exp2) + assert isinstance(res2, tm.SubclassedDataFrame) + tm.assert_frame_equal(res2, exp1) + + def test_subclass_iterrows(self): + # GH 13977 + df = tm.SubclassedDataFrame({"a": [1]}) + for i, row in df.iterrows(): + assert isinstance(row, tm.SubclassedSeries) + tm.assert_series_equal(row, df.loc[i]) + + def test_subclass_stack(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + index=["a", "b", "c"], + columns=["X", "Y", "Z"], + ) + + res = df.stack() + exp = tm.SubclassedSeries( + [1, 2, 3, 4, 5, 6, 7, 8, 9], index=[list("aaabbbccc"), list("XYZXYZXYZ")] + ) + + tm.assert_series_equal(res, exp) + + def test_subclass_stack_multi(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43]], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + + exp = tm.SubclassedDataFrame( + [ + [10, 12], + [11, 13], + [20, 22], + [21, 23], + [30, 32], + [31, 33], + [40, 42], + [41, 43], + ], + index=MultiIndex.from_tuples( + list(zip(list("AAAABBBB"), list("ccddccdd"), list("yzyzyzyz"))), + names=["aaa", "ccc", "yyy"], + ), + columns=Index(["W", "X"], name="www"), + ) + + res = df.stack() + tm.assert_frame_equal(res, exp) + + res = df.stack("yyy") + tm.assert_frame_equal(res, exp) + + exp = tm.SubclassedDataFrame( + [ + [10, 11], + [12, 13], + [20, 21], + [22, 23], + [30, 31], + [32, 33], + [40, 41], + [42, 43], + ], + index=MultiIndex.from_tuples( + list(zip(list("AAAABBBB"), list("ccddccdd"), list("WXWXWXWX"))), + names=["aaa", "ccc", "www"], + ), + columns=Index(["y", "z"], name="yyy"), + ) + + res = df.stack("www") + tm.assert_frame_equal(res, exp) + + def test_subclass_stack_multi_mixed(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [ + [10, 11, 12.0, 13.0], + [20, 21, 22.0, 23.0], + [30, 31, 32.0, 33.0], + [40, 41, 42.0, 43.0], + ], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + + exp = tm.SubclassedDataFrame( + [ + [10, 12.0], + [11, 13.0], + [20, 22.0], + [21, 23.0], + [30, 32.0], + [31, 33.0], + [40, 42.0], + [41, 43.0], + ], + index=MultiIndex.from_tuples( + list(zip(list("AAAABBBB"), list("ccddccdd"), list("yzyzyzyz"))), + names=["aaa", "ccc", "yyy"], + ), + columns=Index(["W", "X"], name="www"), + ) + + res = df.stack() + tm.assert_frame_equal(res, exp) + + res = df.stack("yyy") + tm.assert_frame_equal(res, exp) + + exp = tm.SubclassedDataFrame( + [ + [10.0, 11.0], + [12.0, 13.0], + [20.0, 21.0], + [22.0, 23.0], + [30.0, 31.0], + [32.0, 33.0], + [40.0, 41.0], + [42.0, 43.0], + ], + index=MultiIndex.from_tuples( + list(zip(list("AAAABBBB"), list("ccddccdd"), list("WXWXWXWX"))), + names=["aaa", "ccc", "www"], + ), + columns=Index(["y", "z"], name="yyy"), + ) + + res = df.stack("www") + tm.assert_frame_equal(res, exp) + + def test_subclass_unstack(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + index=["a", "b", "c"], + columns=["X", "Y", "Z"], + ) + + res = df.unstack() + exp = tm.SubclassedSeries( + [1, 4, 7, 2, 5, 8, 3, 6, 9], index=[list("XXXYYYZZZ"), list("abcabcabc")] + ) + + tm.assert_series_equal(res, exp) + + def test_subclass_unstack_multi(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43]], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + + exp = tm.SubclassedDataFrame( + [[10, 20, 11, 21, 12, 22, 13, 23], [30, 40, 31, 41, 32, 42, 33, 43]], + index=Index(["A", "B"], name="aaa"), + columns=MultiIndex.from_tuples( + list(zip(list("WWWWXXXX"), list("yyzzyyzz"), list("cdcdcdcd"))), + names=["www", "yyy", "ccc"], + ), + ) + + res = df.unstack() + tm.assert_frame_equal(res, exp) + + res = df.unstack("ccc") + tm.assert_frame_equal(res, exp) + + exp = tm.SubclassedDataFrame( + [[10, 30, 11, 31, 12, 32, 13, 33], [20, 40, 21, 41, 22, 42, 23, 43]], + index=Index(["c", "d"], name="ccc"), + columns=MultiIndex.from_tuples( + list(zip(list("WWWWXXXX"), list("yyzzyyzz"), list("ABABABAB"))), + names=["www", "yyy", "aaa"], + ), + ) + + res = df.unstack("aaa") + tm.assert_frame_equal(res, exp) + + def test_subclass_unstack_multi_mixed(self): + # GH 15564 + df = tm.SubclassedDataFrame( + [ + [10, 11, 12.0, 13.0], + [20, 21, 22.0, 23.0], + [30, 31, 32.0, 33.0], + [40, 41, 42.0, 43.0], + ], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + + exp = tm.SubclassedDataFrame( + [ + [10, 20, 11, 21, 12.0, 22.0, 13.0, 23.0], + [30, 40, 31, 41, 32.0, 42.0, 33.0, 43.0], + ], + index=Index(["A", "B"], name="aaa"), + columns=MultiIndex.from_tuples( + list(zip(list("WWWWXXXX"), list("yyzzyyzz"), list("cdcdcdcd"))), + names=["www", "yyy", "ccc"], + ), + ) + + res = df.unstack() + tm.assert_frame_equal(res, exp) + + res = df.unstack("ccc") + tm.assert_frame_equal(res, exp) + + exp = tm.SubclassedDataFrame( + [ + [10, 30, 11, 31, 12.0, 32.0, 13.0, 33.0], + [20, 40, 21, 41, 22.0, 42.0, 23.0, 43.0], + ], + index=Index(["c", "d"], name="ccc"), + columns=MultiIndex.from_tuples( + list(zip(list("WWWWXXXX"), list("yyzzyyzz"), list("ABABABAB"))), + names=["www", "yyy", "aaa"], + ), + ) + + res = df.unstack("aaa") + tm.assert_frame_equal(res, exp) + + def test_subclass_pivot(self): + # GH 15564 + df = tm.SubclassedDataFrame( + { + "index": ["A", "B", "C", "C", "B", "A"], + "columns": ["One", "One", "One", "Two", "Two", "Two"], + "values": [1.0, 2.0, 3.0, 3.0, 2.0, 1.0], + } + ) + + pivoted = df.pivot(index="index", columns="columns", values="values") + + expected = tm.SubclassedDataFrame( + { + "One": {"A": 1.0, "B": 2.0, "C": 3.0}, + "Two": {"A": 1.0, "B": 2.0, "C": 3.0}, + } + ) + + expected.index.name, expected.columns.name = "index", "columns" + + tm.assert_frame_equal(pivoted, expected) + + def test_subclassed_melt(self): + # GH 15564 + cheese = tm.SubclassedDataFrame( + { + "first": ["John", "Mary"], + "last": ["Doe", "Bo"], + "height": [5.5, 6.0], + "weight": [130, 150], + } + ) + + melted = pd.melt(cheese, id_vars=["first", "last"]) + + expected = tm.SubclassedDataFrame( + [ + ["John", "Doe", "height", 5.5], + ["Mary", "Bo", "height", 6.0], + ["John", "Doe", "weight", 130], + ["Mary", "Bo", "weight", 150], + ], + columns=["first", "last", "variable", "value"], + ) + + tm.assert_frame_equal(melted, expected) + + def test_subclassed_wide_to_long(self): + # GH 9762 + + x = np.random.default_rng(2).standard_normal(3) + df = tm.SubclassedDataFrame( + { + "A1970": {0: "a", 1: "b", 2: "c"}, + "A1980": {0: "d", 1: "e", 2: "f"}, + "B1970": {0: 2.5, 1: 1.2, 2: 0.7}, + "B1980": {0: 3.2, 1: 1.3, 2: 0.1}, + "X": dict(zip(range(3), x)), + } + ) + + df["id"] = df.index + exp_data = { + "X": x.tolist() + x.tolist(), + "A": ["a", "b", "c", "d", "e", "f"], + "B": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1], + "year": [1970, 1970, 1970, 1980, 1980, 1980], + "id": [0, 1, 2, 0, 1, 2], + } + expected = tm.SubclassedDataFrame(exp_data) + expected = expected.set_index(["id", "year"])[["X", "A", "B"]] + long_frame = pd.wide_to_long(df, ["A", "B"], i="id", j="year") + + tm.assert_frame_equal(long_frame, expected) + + def test_subclassed_apply(self): + # GH 19822 + + def check_row_subclass(row): + assert isinstance(row, tm.SubclassedSeries) + + def stretch(row): + if row["variable"] == "height": + row["value"] += 0.5 + return row + + df = tm.SubclassedDataFrame( + [ + ["John", "Doe", "height", 5.5], + ["Mary", "Bo", "height", 6.0], + ["John", "Doe", "weight", 130], + ["Mary", "Bo", "weight", 150], + ], + columns=["first", "last", "variable", "value"], + ) + + df.apply(lambda x: check_row_subclass(x)) + df.apply(lambda x: check_row_subclass(x), axis=1) + + expected = tm.SubclassedDataFrame( + [ + ["John", "Doe", "height", 6.0], + ["Mary", "Bo", "height", 6.5], + ["John", "Doe", "weight", 130], + ["Mary", "Bo", "weight", 150], + ], + columns=["first", "last", "variable", "value"], + ) + + result = df.apply(lambda x: stretch(x), axis=1) + assert isinstance(result, tm.SubclassedDataFrame) + tm.assert_frame_equal(result, expected) + + expected = tm.SubclassedDataFrame([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) + + result = df.apply(lambda x: tm.SubclassedSeries([1, 2, 3]), axis=1) + assert isinstance(result, tm.SubclassedDataFrame) + tm.assert_frame_equal(result, expected) + + result = df.apply(lambda x: [1, 2, 3], axis=1, result_type="expand") + assert isinstance(result, tm.SubclassedDataFrame) + tm.assert_frame_equal(result, expected) + + expected = tm.SubclassedSeries([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) + + result = df.apply(lambda x: [1, 2, 3], axis=1) + assert not isinstance(result, tm.SubclassedDataFrame) + tm.assert_series_equal(result, expected) + + def test_subclassed_reductions(self, all_reductions): + # GH 25596 + + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = getattr(df, all_reductions)() + assert isinstance(result, tm.SubclassedSeries) + + def test_subclassed_count(self): + df = tm.SubclassedDataFrame( + { + "Person": ["John", "Myla", "Lewis", "John", "Myla"], + "Age": [24.0, np.nan, 21.0, 33, 26], + "Single": [False, True, True, True, False], + } + ) + result = df.count() + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame({"A": [1, 0, 3], "B": [0, 5, 6], "C": [7, 8, 0]}) + result = df.count() + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame( + [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33], [40, 41, 42, 43]], + index=MultiIndex.from_tuples( + list(zip(list("AABB"), list("cdcd"))), names=["aaa", "ccc"] + ), + columns=MultiIndex.from_tuples( + list(zip(list("WWXX"), list("yzyz"))), names=["www", "yyy"] + ), + ) + result = df.count() + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame() + result = df.count() + assert isinstance(result, tm.SubclassedSeries) + + def test_isin(self): + df = tm.SubclassedDataFrame( + {"num_legs": [2, 4], "num_wings": [2, 0]}, index=["falcon", "dog"] + ) + result = df.isin([0, 2]) + assert isinstance(result, tm.SubclassedDataFrame) + + def test_duplicated(self): + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.duplicated() + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame() + result = df.duplicated() + assert isinstance(result, tm.SubclassedSeries) + + @pytest.mark.parametrize("idx_method", ["idxmax", "idxmin"]) + def test_idx(self, idx_method): + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = getattr(df, idx_method)() + assert isinstance(result, tm.SubclassedSeries) + + def test_dot(self): + df = tm.SubclassedDataFrame([[0, 1, -2, -1], [1, 1, 1, 1]]) + s = tm.SubclassedSeries([1, 1, 2, 1]) + result = df.dot(s) + assert isinstance(result, tm.SubclassedSeries) + + df = tm.SubclassedDataFrame([[0, 1, -2, -1], [1, 1, 1, 1]]) + s = tm.SubclassedDataFrame([1, 1, 2, 1]) + result = df.dot(s) + assert isinstance(result, tm.SubclassedDataFrame) + + def test_memory_usage(self): + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.memory_usage() + assert isinstance(result, tm.SubclassedSeries) + + result = df.memory_usage(index=False) + assert isinstance(result, tm.SubclassedSeries) + + def test_corrwith(self): + pytest.importorskip("scipy") + index = ["a", "b", "c", "d", "e"] + columns = ["one", "two", "three", "four"] + df1 = tm.SubclassedDataFrame( + np.random.default_rng(2).standard_normal((5, 4)), + index=index, + columns=columns, + ) + df2 = tm.SubclassedDataFrame( + np.random.default_rng(2).standard_normal((4, 4)), + index=index[:4], + columns=columns, + ) + correls = df1.corrwith(df2, axis=1, drop=True, method="kendall") + + assert isinstance(correls, (tm.SubclassedSeries)) + + def test_asof(self): + N = 3 + rng = pd.date_range("1/1/1990", periods=N, freq="53s") + df = tm.SubclassedDataFrame( + { + "A": [np.nan, np.nan, np.nan], + "B": [np.nan, np.nan, np.nan], + "C": [np.nan, np.nan, np.nan], + }, + index=rng, + ) + + result = df.asof(rng[-2:]) + assert isinstance(result, tm.SubclassedDataFrame) + + result = df.asof(rng[-2]) + assert isinstance(result, tm.SubclassedSeries) + + result = df.asof("1989-12-31") + assert isinstance(result, tm.SubclassedSeries) + + def test_idxmin_preserves_subclass(self): + # GH 28330 + + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.idxmin() + assert isinstance(result, tm.SubclassedSeries) + + def test_idxmax_preserves_subclass(self): + # GH 28330 + + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.idxmax() + assert isinstance(result, tm.SubclassedSeries) + + def test_convert_dtypes_preserves_subclass(self): + # GH 43668 + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + result = df.convert_dtypes() + assert isinstance(result, tm.SubclassedDataFrame) + + def test_convert_dtypes_preserves_subclass_with_constructor(self): + class SubclassedDataFrame(DataFrame): + @property + def _constructor(self): + return SubclassedDataFrame + + df = SubclassedDataFrame({"a": [1, 2, 3]}) + result = df.convert_dtypes() + assert isinstance(result, SubclassedDataFrame) + + def test_astype_preserves_subclass(self): + # GH#40810 + df = tm.SubclassedDataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]}) + + result = df.astype({"A": np.int64, "B": np.int32, "C": np.float64}) + assert isinstance(result, tm.SubclassedDataFrame) + + def test_equals_subclass(self): + # https://github.com/pandas-dev/pandas/pull/34402 + # allow subclass in both directions + df1 = DataFrame({"a": [1, 2, 3]}) + df2 = tm.SubclassedDataFrame({"a": [1, 2, 3]}) + assert df1.equals(df2) + assert df2.equals(df1) + + +class MySubclassWithMetadata(DataFrame): + _metadata = ["my_metadata"] + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + my_metadata = kwargs.pop("my_metadata", None) + if args and isinstance(args[0], MySubclassWithMetadata): + my_metadata = args[0].my_metadata # type: ignore[has-type] + self.my_metadata = my_metadata + + @property + def _constructor(self): + return MySubclassWithMetadata + + +def test_constructor_with_metadata(): + # https://github.com/pandas-dev/pandas/pull/54922 + # https://github.com/pandas-dev/pandas/issues/55120 + df = MySubclassWithMetadata( + np.random.default_rng(2).random((5, 3)), columns=["A", "B", "C"] + ) + subset = df[["A", "B"]] + assert isinstance(subset, MySubclassWithMetadata) + + +def test_constructor_with_metadata_from_records(): + # GH#57008 + df = MySubclassWithMetadata.from_records([{"a": 1, "b": 2}]) + assert df.my_metadata is None + assert type(df) is MySubclassWithMetadata + + +class SimpleDataFrameSubClass(DataFrame): + """A subclass of DataFrame that does not define a constructor.""" + + +class SimpleSeriesSubClass(Series): + """A subclass of Series that does not define a constructor.""" + + +class TestSubclassWithoutConstructor: + def test_copy_df(self): + expected = DataFrame({"a": [1, 2, 3]}) + result = SimpleDataFrameSubClass(expected).copy() + + assert ( + type(result) is DataFrame + ) # assert_frame_equal only checks isinstance(lhs, type(rhs)) + tm.assert_frame_equal(result, expected) + + def test_copy_series(self): + expected = Series([1, 2, 3]) + result = SimpleSeriesSubClass(expected).copy() + + tm.assert_series_equal(result, expected) + + def test_series_to_frame(self): + orig = Series([1, 2, 3]) + expected = orig.to_frame() + result = SimpleSeriesSubClass(orig).to_frame() + + assert ( + type(result) is DataFrame + ) # assert_frame_equal only checks isinstance(lhs, type(rhs)) + tm.assert_frame_equal(result, expected) + + def test_groupby(self): + df = SimpleDataFrameSubClass(DataFrame({"a": [1, 2, 3]})) + + for _, v in df.groupby("a"): + assert type(v) is DataFrame diff --git a/pandas/tests/frame/test_ufunc.py b/pandas/tests/frame/test_ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..8d5a227652462e17025e1fcfae023bec296ad751 --- /dev/null +++ b/pandas/tests/frame/test_ufunc.py @@ -0,0 +1,312 @@ +from functools import partial +import re + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.api.types import is_extension_array_dtype + +dtypes = [ + "int64", + "Int64", + {"A": "int64", "B": "Int64"}, +] + + +@pytest.mark.parametrize("dtype", dtypes) +def test_unary_unary(dtype): + # unary input, unary output + values = np.array([[-1, -1], [1, 1]], dtype="int64") + df = pd.DataFrame(values, columns=["A", "B"], index=["a", "b"]).astype(dtype=dtype) + result = np.positive(df) + expected = pd.DataFrame( + np.positive(values), index=df.index, columns=df.columns + ).astype(dtype) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_unary_binary(request, dtype): + # unary input, binary output + if is_extension_array_dtype(dtype) or isinstance(dtype, dict): + request.applymarker( + pytest.mark.xfail( + reason="Extension / mixed with multiple outputs not implemented." + ) + ) + + values = np.array([[-1, -1], [1, 1]], dtype="int64") + df = pd.DataFrame(values, columns=["A", "B"], index=["a", "b"]).astype(dtype=dtype) + result_pandas = np.modf(df) + assert isinstance(result_pandas, tuple) + assert len(result_pandas) == 2 + expected_numpy = np.modf(values) + + for result, b in zip(result_pandas, expected_numpy): + expected = pd.DataFrame(b, index=df.index, columns=df.columns) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_binary_input_dispatch_binop(dtype): + # binop ufuncs are dispatched to our dunder methods. + values = np.array([[-1, -1], [1, 1]], dtype="int64") + df = pd.DataFrame(values, columns=["A", "B"], index=["a", "b"]).astype(dtype=dtype) + result = np.add(df, df) + expected = pd.DataFrame( + np.add(values, values), index=df.index, columns=df.columns + ).astype(dtype) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func,arg,expected", + [ + (np.add, 1, [2, 3, 4, 5]), + ( + partial(np.add, where=[[False, True], [True, False]]), + np.array([[1, 1], [1, 1]]), + [0, 3, 4, 0], + ), + (np.power, np.array([[1, 1], [2, 2]]), [1, 2, 9, 16]), + (np.subtract, 2, [-1, 0, 1, 2]), + ( + partial(np.negative, where=np.array([[False, True], [True, False]])), + None, + [0, -2, -3, 0], + ), + ], +) +def test_ufunc_passes_args(func, arg, expected): + # GH#40662 + arr = np.array([[1, 2], [3, 4]]) + df = pd.DataFrame(arr) + result_inplace = np.zeros_like(arr) + # 1-argument ufunc + if arg is None: + result = func(df, out=result_inplace) + else: + result = func(df, arg, out=result_inplace) + + expected = np.array(expected).reshape(2, 2) + tm.assert_numpy_array_equal(result_inplace, expected) + + expected = pd.DataFrame(expected) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype_a", dtypes) +@pytest.mark.parametrize("dtype_b", dtypes) +def test_binary_input_aligns_columns(request, dtype_a, dtype_b): + if ( + is_extension_array_dtype(dtype_a) + or isinstance(dtype_a, dict) + or is_extension_array_dtype(dtype_b) + or isinstance(dtype_b, dict) + ): + request.applymarker( + pytest.mark.xfail( + reason="Extension / mixed with multiple inputs not implemented." + ) + ) + + df1 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}).astype(dtype_a) + + if isinstance(dtype_a, dict) and isinstance(dtype_b, dict): + dtype_b = dtype_b.copy() + dtype_b["C"] = dtype_b.pop("B") + df2 = pd.DataFrame({"A": [1, 2], "C": [3, 4]}).astype(dtype_b) + # As of 2.0, align first before applying the ufunc + result = np.heaviside(df1, df2) + expected = np.heaviside( + np.array([[1, 3, np.nan], [2, 4, np.nan]]), + np.array([[1, np.nan, 3], [2, np.nan, 4]]), + ) + expected = pd.DataFrame(expected, index=[0, 1], columns=["A", "B", "C"]) + tm.assert_frame_equal(result, expected) + + result = np.heaviside(df1, df2.values) + expected = pd.DataFrame([[1.0, 1.0], [1.0, 1.0]], columns=["A", "B"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", dtypes) +def test_binary_input_aligns_index(request, dtype): + if is_extension_array_dtype(dtype) or isinstance(dtype, dict): + request.applymarker( + pytest.mark.xfail( + reason="Extension / mixed with multiple inputs not implemented." + ) + ) + df1 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=["a", "b"]).astype(dtype) + df2 = pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=["a", "c"]).astype(dtype) + result = np.heaviside(df1, df2) + expected = np.heaviside( + np.array([[1, 3], [3, 4], [np.nan, np.nan]]), + np.array([[1, 3], [np.nan, np.nan], [3, 4]]), + ) + # TODO(FloatArray): this will be Float64Dtype. + expected = pd.DataFrame(expected, index=["a", "b", "c"], columns=["A", "B"]) + tm.assert_frame_equal(result, expected) + + result = np.heaviside(df1, df2.values) + expected = pd.DataFrame( + [[1.0, 1.0], [1.0, 1.0]], columns=["A", "B"], index=["a", "b"] + ) + tm.assert_frame_equal(result, expected) + + +def test_binary_frame_series_raises(): + # We don't currently implement + df = pd.DataFrame({"A": [1, 2]}) + with pytest.raises(NotImplementedError, match="logaddexp"): + np.logaddexp(df, df["A"]) + + with pytest.raises(NotImplementedError, match="logaddexp"): + np.logaddexp(df["A"], df) + + +def test_unary_accumulate_axis(): + # https://github.com/pandas-dev/pandas/issues/39259 + df = pd.DataFrame({"a": [1, 3, 2, 4]}) + result = np.maximum.accumulate(df) + expected = pd.DataFrame({"a": [1, 3, 3, 4]}) + tm.assert_frame_equal(result, expected) + + df = pd.DataFrame({"a": [1, 3, 2, 4], "b": [0.1, 4.0, 3.0, 2.0]}) + result = np.maximum.accumulate(df) + # in theory could preserve int dtype for default axis=0 + expected = pd.DataFrame({"a": [1.0, 3.0, 3.0, 4.0], "b": [0.1, 4.0, 4.0, 4.0]}) + tm.assert_frame_equal(result, expected) + + result = np.maximum.accumulate(df, axis=0) + tm.assert_frame_equal(result, expected) + + result = np.maximum.accumulate(df, axis=1) + expected = pd.DataFrame({"a": [1.0, 3.0, 2.0, 4.0], "b": [1.0, 4.0, 3.0, 4.0]}) + tm.assert_frame_equal(result, expected) + + +def test_frame_outer_disallowed(): + df = pd.DataFrame({"A": [1, 2]}) + with pytest.raises(NotImplementedError, match="^$"): + # deprecation enforced in 2.0 + np.subtract.outer(df, df) + + +def test_alignment_deprecation_enforced(): + # Enforced in 2.0 + # https://github.com/pandas-dev/pandas/issues/39184 + df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = pd.DataFrame({"b": [1, 2, 3], "c": [4, 5, 6]}) + s1 = pd.Series([1, 2], index=["a", "b"]) + s2 = pd.Series([1, 2], index=["b", "c"]) + + # binary dataframe / dataframe + expected = pd.DataFrame({"a": [2, 4, 6], "b": [8, 10, 12]}) + + with tm.assert_produces_warning(None): + # aligned -> no warning! + result = np.add(df1, df1) + tm.assert_frame_equal(result, expected) + + result = np.add(df1, df2.values) + tm.assert_frame_equal(result, expected) + + result = np.add(df1, df2) + expected = pd.DataFrame({"a": [np.nan] * 3, "b": [5, 7, 9], "c": [np.nan] * 3}) + tm.assert_frame_equal(result, expected) + + result = np.add(df1.values, df2) + expected = pd.DataFrame({"b": [2, 4, 6], "c": [8, 10, 12]}) + tm.assert_frame_equal(result, expected) + + # binary dataframe / series + expected = pd.DataFrame({"a": [2, 3, 4], "b": [6, 7, 8]}) + + with tm.assert_produces_warning(None): + # aligned -> no warning! + result = np.add(df1, s1) + tm.assert_frame_equal(result, expected) + + result = np.add(df1, s2.values) + tm.assert_frame_equal(result, expected) + + expected = pd.DataFrame( + {"a": [np.nan] * 3, "b": [5.0, 6.0, 7.0], "c": [np.nan] * 3} + ) + result = np.add(df1, s2) + tm.assert_frame_equal(result, expected) + + msg = "Cannot apply ufunc to mixed DataFrame and Series inputs." + with pytest.raises(NotImplementedError, match=msg): + np.add(s2, df1) + + +@pytest.mark.single_cpu +def test_alignment_deprecation_many_inputs_enforced(): + # Enforced in 2.0 + # https://github.com/pandas-dev/pandas/issues/39184 + # test that the deprecation also works with > 2 inputs -> using a numba + # written ufunc for this because numpy itself doesn't have such ufuncs + numba = pytest.importorskip("numba") + + @numba.vectorize([numba.float64(numba.float64, numba.float64, numba.float64)]) + def my_ufunc(x, y, z): + return x + y + z + + df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df2 = pd.DataFrame({"b": [1, 2, 3], "c": [4, 5, 6]}) + df3 = pd.DataFrame({"a": [1, 2, 3], "c": [4, 5, 6]}) + + result = my_ufunc(df1, df2, df3) + expected = pd.DataFrame(np.full((3, 3), np.nan), columns=["a", "b", "c"]) + tm.assert_frame_equal(result, expected) + + # all aligned -> no warning + with tm.assert_produces_warning(None): + result = my_ufunc(df1, df1, df1) + expected = pd.DataFrame([[3.0, 12.0], [6.0, 15.0], [9.0, 18.0]], columns=["a", "b"]) + tm.assert_frame_equal(result, expected) + + # mixed frame / arrays + msg = ( + r"operands could not be broadcast together with shapes \(3,3\) \(3,3\) \(3,2\)" + ) + with pytest.raises(ValueError, match=msg): + my_ufunc(df1, df2, df3.values) + + # single frame -> no warning + with tm.assert_produces_warning(None): + result = my_ufunc(df1, df2.values, df3.values) + tm.assert_frame_equal(result, expected) + + # takes indices of first frame + msg = ( + r"operands could not be broadcast together with shapes \(3,2\) \(3,3\) \(3,3\)" + ) + with pytest.raises(ValueError, match=msg): + my_ufunc(df1.values, df2, df3) + + +def test_array_ufuncs_for_many_arguments(): + # GH39853 + def add3(x, y, z): + return x + y + z + + ufunc = np.frompyfunc(add3, 3, 1) + df = pd.DataFrame([[1, 2], [3, 4]]) + + result = ufunc(df, df, 1) + expected = pd.DataFrame([[3, 5], [7, 9]], dtype=object) + tm.assert_frame_equal(result, expected) + + ser = pd.Series([1, 2]) + msg = ( + "Cannot apply ufunc " + "to mixed DataFrame and Series inputs." + ) + with pytest.raises(NotImplementedError, match=re.escape(msg)): + ufunc(df, df, ser) diff --git a/pandas/tests/frame/test_unary.py b/pandas/tests/frame/test_unary.py new file mode 100644 index 0000000000000000000000000000000000000000..034a43ac40bbafee06eb6cc079d7b820ccedb65b --- /dev/null +++ b/pandas/tests/frame/test_unary.py @@ -0,0 +1,180 @@ +from decimal import Decimal + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +class TestDataFrameUnaryOperators: + # __pos__, __neg__, __invert__ + + @pytest.mark.parametrize( + "df_data,expected_data", + [ + ([-1, 1], [1, -1]), + ([False, True], [True, False]), + (pd.to_timedelta([-1, 1]), pd.to_timedelta([1, -1])), + ], + ) + def test_neg_numeric(self, df_data, expected_data): + df = pd.DataFrame({"a": df_data}) + expected = pd.DataFrame({"a": expected_data}) + tm.assert_frame_equal(-df, expected) + tm.assert_series_equal(-df["a"], expected["a"]) + + @pytest.mark.parametrize( + "df, expected", + [ + (np.array([1, 2], dtype=object), np.array([-1, -2], dtype=object)), + ([Decimal("1.0"), Decimal("2.0")], [Decimal("-1.0"), Decimal("-2.0")]), + ], + ) + def test_neg_object(self, df, expected): + # GH#21380 + df = pd.DataFrame({"a": df}) + expected = pd.DataFrame({"a": expected}) + tm.assert_frame_equal(-df, expected) + tm.assert_series_equal(-df["a"], expected["a"]) + + @pytest.mark.parametrize( + "df_data", + [ + ["a", "b"], + pd.to_datetime(["2017-01-22", "1970-01-01"]), + ], + ) + def test_neg_raises(self, df_data, using_infer_string): + df = pd.DataFrame({"a": df_data}) + msg = ( + "bad operand type for unary -: 'str'|" + r"bad operand type for unary -: 'DatetimeArray'|" + "unary '-' not supported for dtype" + ) + with pytest.raises(TypeError, match=msg): + (-df) + with pytest.raises(TypeError, match=msg): + (-df["a"]) + + def test_invert(self, float_frame): + df = float_frame + + tm.assert_frame_equal(-(df < 0), ~(df < 0)) + + def test_invert_mixed(self): + shape = (10, 5) + df = pd.concat( + [ + pd.DataFrame(np.zeros(shape, dtype="bool")), + pd.DataFrame(np.zeros(shape, dtype=int)), + ], + axis=1, + ignore_index=True, + ) + result = ~df + expected = pd.concat( + [ + pd.DataFrame(np.ones(shape, dtype="bool")), + pd.DataFrame(-np.ones(shape, dtype=int)), + ], + axis=1, + ignore_index=True, + ) + tm.assert_frame_equal(result, expected) + + def test_invert_empty_not_input(self): + # GH#51032 + df = pd.DataFrame() + result = ~df + tm.assert_frame_equal(df, result) + assert df is not result + + @pytest.mark.parametrize( + "df_data", + [ + [-1, 1], + [False, True], + pd.to_timedelta([-1, 1]), + ], + ) + def test_pos_numeric(self, df_data): + # GH#16073 + df = pd.DataFrame({"a": df_data}) + tm.assert_frame_equal(+df, df) + tm.assert_series_equal(+df["a"], df["a"]) + + @pytest.mark.parametrize( + "df_data", + [ + np.array([-1, 2], dtype=object), + [Decimal("-1.0"), Decimal("2.0")], + ], + ) + def test_pos_object(self, df_data): + # GH#21380 + df = pd.DataFrame({"a": df_data}) + tm.assert_frame_equal(+df, df) + tm.assert_series_equal(+df["a"], df["a"]) + + @pytest.mark.filterwarnings("ignore:Applying:DeprecationWarning") + def test_pos_object_raises(self): + # GH#21380 + df = pd.DataFrame({"a": ["a", "b"]}) + with pytest.raises( + TypeError, match=r"^bad operand type for unary \+: \'str\'$" + ): + tm.assert_frame_equal(+df, df) + + def test_pos_raises(self): + df = pd.DataFrame({"a": pd.to_datetime(["2017-01-22", "1970-01-01"])}) + msg = r"bad operand type for unary \+: 'DatetimeArray'" + with pytest.raises(TypeError, match=msg): + (+df) + with pytest.raises(TypeError, match=msg): + (+df["a"]) + + def test_unary_nullable(self): + df = pd.DataFrame( + { + "a": pd.array([1, -2, 3, pd.NA], dtype="Int64"), + "b": pd.array([4.0, -5.0, 6.0, pd.NA], dtype="Float32"), + "c": pd.array([True, False, False, pd.NA], dtype="boolean"), + # include numpy bool to make sure bool-vs-boolean behavior + # is consistent in non-NA locations + "d": np.array([True, False, False, True]), + } + ) + + result = +df + res_ufunc = np.positive(df) + expected = df + # TODO: assert that we have copies? + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(res_ufunc, expected) + + result = -df + res_ufunc = np.negative(df) + expected = pd.DataFrame( + { + "a": pd.array([-1, 2, -3, pd.NA], dtype="Int64"), + "b": pd.array([-4.0, 5.0, -6.0, pd.NA], dtype="Float32"), + "c": pd.array([False, True, True, pd.NA], dtype="boolean"), + "d": np.array([False, True, True, False]), + } + ) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(res_ufunc, expected) + + result = abs(df) + res_ufunc = np.abs(df) + expected = pd.DataFrame( + { + "a": pd.array([1, 2, 3, pd.NA], dtype="Int64"), + "b": pd.array([4.0, 5.0, 6.0, pd.NA], dtype="Float32"), + "c": pd.array([True, False, False, pd.NA], dtype="boolean"), + "d": np.array([True, False, False, True]), + } + ) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(res_ufunc, expected) diff --git a/pandas/tests/frame/test_validate.py b/pandas/tests/frame/test_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..fdeecba29a6177444df6141487505e24d284c285 --- /dev/null +++ b/pandas/tests/frame/test_validate.py @@ -0,0 +1,37 @@ +import pytest + +from pandas.core.frame import DataFrame + + +class TestDataFrameValidate: + """Tests for error handling related to data types of method arguments.""" + + @pytest.mark.parametrize( + "func", + [ + "query", + "eval", + "set_index", + "reset_index", + "dropna", + "drop_duplicates", + "sort_values", + ], + ) + @pytest.mark.parametrize("inplace", [1, "True", [1, 2, 3], 5.0]) + def test_validate_bool_args(self, func, inplace): + dataframe = DataFrame({"a": [1, 2], "b": [3, 4]}) + msg = 'For argument "inplace" expected type bool' + kwargs = {"inplace": inplace} + + if func == "query": + kwargs["expr"] = "a > b" + elif func == "eval": + kwargs["expr"] = "a + b" + elif func == "set_index": + kwargs["keys"] = ["a"] + elif func == "sort_values": + kwargs["by"] = ["a"] + + with pytest.raises(ValueError, match=msg): + getattr(dataframe, func)(**kwargs) diff --git a/pandas/tests/generic/__init__.py b/pandas/tests/generic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/generic/test_duplicate_labels.py b/pandas/tests/generic/test_duplicate_labels.py new file mode 100644 index 0000000000000000000000000000000000000000..4de8c8df852f475d0bbea99ec1f0b8bbdbae117a --- /dev/null +++ b/pandas/tests/generic/test_duplicate_labels.py @@ -0,0 +1,390 @@ +"""Tests dealing with the NDFrame.allows_duplicates.""" + +import operator + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + +not_implemented = pytest.mark.xfail(reason="Not implemented.") + +# ---------------------------------------------------------------------------- +# Preservation + + +class TestPreserves: + @pytest.mark.parametrize( + "cls, data", + [ + (pd.Series, np.array([])), + (pd.Series, [1, 2]), + (pd.DataFrame, {}), + (pd.DataFrame, {"A": [1, 2]}), + ], + ) + def test_construction_ok(self, cls, data): + result = cls(data) + assert result.flags.allows_duplicate_labels is True + + result = cls(data).set_flags(allows_duplicate_labels=False) + assert result.flags.allows_duplicate_labels is False + + @pytest.mark.parametrize( + "func", + [ + operator.itemgetter(["a"]), + operator.methodcaller("add", 1), + operator.methodcaller("rename", str.upper), + operator.methodcaller("rename", "name"), + operator.methodcaller("abs"), + np.abs, + ], + ) + def test_preserved_series(self, func): + s = pd.Series([0, 1], index=["a", "b"]).set_flags(allows_duplicate_labels=False) + assert func(s).flags.allows_duplicate_labels is False + + @pytest.mark.parametrize("index", [["a", "b", "c"], ["a", "b"]]) + # TODO: frame + @not_implemented + def test_align(self, index): + other = pd.Series(0, index=index) + s = pd.Series([0, 1], index=["a", "b"]).set_flags(allows_duplicate_labels=False) + a, b = s.align(other) + assert a.flags.allows_duplicate_labels is False + assert b.flags.allows_duplicate_labels is False + + def test_preserved_frame(self): + df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=["a", "b"]).set_flags( + allows_duplicate_labels=False + ) + assert df.loc[["a"]].flags.allows_duplicate_labels is False + assert df.loc[:, ["A", "B"]].flags.allows_duplicate_labels is False + + def test_to_frame(self): + ser = pd.Series(dtype=float).set_flags(allows_duplicate_labels=False) + assert ser.to_frame().flags.allows_duplicate_labels is False + + @pytest.mark.parametrize("func", ["add", "sub"]) + @pytest.mark.parametrize("frame", [False, True]) + @pytest.mark.parametrize("other", [1, pd.Series([1, 2], name="A")]) + def test_binops(self, func, other, frame): + df = pd.Series([1, 2], name="A", index=["a", "b"]).set_flags( + allows_duplicate_labels=False + ) + if frame: + df = df.to_frame() + if isinstance(other, pd.Series) and frame: + other = other.to_frame() + func = operator.methodcaller(func, other) + assert df.flags.allows_duplicate_labels is False + assert func(df).flags.allows_duplicate_labels is False + + def test_preserve_getitem(self): + df = pd.DataFrame({"A": [1, 2]}).set_flags(allows_duplicate_labels=False) + assert df[["A"]].flags.allows_duplicate_labels is False + assert df["A"].flags.allows_duplicate_labels is False + assert df.loc[0].flags.allows_duplicate_labels is False + assert df.loc[[0]].flags.allows_duplicate_labels is False + assert df.loc[0, ["A"]].flags.allows_duplicate_labels is False + + @pytest.mark.parametrize( + "objs, kwargs", + [ + # Series + ( + [ + pd.Series(1, index=["a", "b"]), + pd.Series(2, index=["c", "d"]), + ], + {}, + ), + ( + [ + pd.Series(1, index=["a", "b"]), + pd.Series(2, index=["a", "b"]), + ], + {"ignore_index": True}, + ), + ( + [ + pd.Series(1, index=["a", "b"]), + pd.Series(2, index=["a", "b"]), + ], + {"axis": 1}, + ), + # Frame + ( + [ + pd.DataFrame({"A": [1, 2]}, index=["a", "b"]), + pd.DataFrame({"A": [1, 2]}, index=["c", "d"]), + ], + {}, + ), + ( + [ + pd.DataFrame({"A": [1, 2]}, index=["a", "b"]), + pd.DataFrame({"A": [1, 2]}, index=["a", "b"]), + ], + {"ignore_index": True}, + ), + ( + [ + pd.DataFrame({"A": [1, 2]}, index=["a", "b"]), + pd.DataFrame({"B": [1, 2]}, index=["a", "b"]), + ], + {"axis": 1}, + ), + # Series / Frame + ( + [ + pd.DataFrame({"A": [1, 2]}, index=["a", "b"]), + pd.Series([1, 2], index=["a", "b"], name="B"), + ], + {"axis": 1}, + ), + ], + ) + def test_concat(self, objs, kwargs): + objs = [x.set_flags(allows_duplicate_labels=False) for x in objs] + result = pd.concat(objs, **kwargs) + assert result.flags.allows_duplicate_labels is False + + @pytest.mark.parametrize( + "left, right, expected", + [ + # false false false + pytest.param( + pd.DataFrame({"A": [0, 1]}, index=["a", "b"]).set_flags( + allows_duplicate_labels=False + ), + pd.DataFrame({"B": [0, 1]}, index=["a", "d"]).set_flags( + allows_duplicate_labels=False + ), + False, + ), + # false true false + pytest.param( + pd.DataFrame({"A": [0, 1]}, index=["a", "b"]).set_flags( + allows_duplicate_labels=False + ), + pd.DataFrame({"B": [0, 1]}, index=["a", "d"]), + False, + ), + # true true true + ( + pd.DataFrame({"A": [0, 1]}, index=["a", "b"]), + pd.DataFrame({"B": [0, 1]}, index=["a", "d"]), + True, + ), + ], + ) + def test_merge(self, left, right, expected): + result = pd.merge(left, right, left_index=True, right_index=True) + assert result.flags.allows_duplicate_labels is expected + + @not_implemented + def test_groupby(self): + # XXX: This is under tested + # TODO: + # - apply + # - transform + # - Should passing a grouper that disallows duplicates propagate? + df = pd.DataFrame({"A": [1, 2, 3]}).set_flags(allows_duplicate_labels=False) + result = df.groupby([0, 0, 1]).agg("count") + assert result.flags.allows_duplicate_labels is False + + @pytest.mark.parametrize("frame", [True, False]) + @not_implemented + def test_window(self, frame): + df = pd.Series( + 1, + index=pd.date_range("2000", periods=12), + name="A", + allows_duplicate_labels=False, + ) + if frame: + df = df.to_frame() + assert df.rolling(3).mean().flags.allows_duplicate_labels is False + assert df.ewm(3).mean().flags.allows_duplicate_labels is False + assert df.expanding(3).mean().flags.allows_duplicate_labels is False + + +# ---------------------------------------------------------------------------- +# Raises + + +class TestRaises: + @pytest.mark.parametrize( + "cls, axes", + [ + (pd.Series, {"index": ["a", "a"], "dtype": float}), + (pd.DataFrame, {"index": ["a", "a"]}), + (pd.DataFrame, {"index": ["a", "a"], "columns": ["b", "b"]}), + (pd.DataFrame, {"columns": ["b", "b"]}), + ], + ) + def test_set_flags_with_duplicates(self, cls, axes): + result = cls(**axes) + assert result.flags.allows_duplicate_labels is True + + msg = "Index has duplicates." + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + cls(**axes).set_flags(allows_duplicate_labels=False) + + @pytest.mark.parametrize( + "data", + [ + pd.Series(index=[0, 0], dtype=float), + pd.DataFrame(index=[0, 0]), + pd.DataFrame(columns=[0, 0]), + ], + ) + def test_setting_allows_duplicate_labels_raises(self, data): + msg = "Index has duplicates." + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + data.flags.allows_duplicate_labels = False + + assert data.flags.allows_duplicate_labels is True + + def test_series_raises(self): + a = pd.Series(0, index=["a", "b"]) + b = pd.Series([0, 1], index=["a", "b"]).set_flags(allows_duplicate_labels=False) + msg = "Index has duplicates." + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + pd.concat([a, b]) + + @pytest.mark.parametrize( + "getter, target", + [ + (operator.itemgetter(["A", "A"]), None), + # loc + (operator.itemgetter(["a", "a"]), "loc"), + pytest.param(operator.itemgetter(("a", ["A", "A"])), "loc"), + (operator.itemgetter((["a", "a"], "A")), "loc"), + # iloc + (operator.itemgetter([0, 0]), "iloc"), + pytest.param(operator.itemgetter((0, [0, 0])), "iloc"), + pytest.param(operator.itemgetter(([0, 0], 0)), "iloc"), + ], + ) + def test_getitem_raises(self, getter, target): + df = pd.DataFrame({"A": [1, 2], "B": [3, 4]}, index=["a", "b"]).set_flags( + allows_duplicate_labels=False + ) + if target: + # df, df.loc, or df.iloc + target = getattr(df, target) + else: + target = df + + msg = "Index has duplicates." + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + getter(target) + + def test_concat_raises(self): + objs = [ + pd.Series(1, index=[0, 1], name="a"), + pd.Series(2, index=[0, 1], name="a"), + ] + objs = [x.set_flags(allows_duplicate_labels=False) for x in objs] + msg = "Index has duplicates." + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + pd.concat(objs, axis=1) + + def test_merge_raises(self): + a = pd.DataFrame({"A": [0, 1, 2]}, index=["a", "b", "c"]).set_flags( + allows_duplicate_labels=False + ) + b = pd.DataFrame({"B": [0, 1, 2]}, index=["a", "b", "b"]) + msg = "Index has duplicates." + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + pd.merge(a, b, left_index=True, right_index=True) + + +@pytest.mark.parametrize( + "idx", + [ + pd.Index([1, 1]), + pd.Index(["a", "a"]), + pd.Index([1.1, 1.1]), + pd.PeriodIndex([pd.Period("2000", "D")] * 2), + pd.DatetimeIndex([pd.Timestamp("2000")] * 2), + pd.TimedeltaIndex([pd.Timedelta("1D")] * 2), + pd.CategoricalIndex(["a", "a"]), + pd.IntervalIndex([pd.Interval(0, 1)] * 2), + pd.MultiIndex.from_tuples([("a", 1), ("a", 1)]), + ], + ids=lambda x: type(x).__name__, +) +def test_raises_basic(idx): + msg = "Index has duplicates." + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + pd.Series(1, index=idx).set_flags(allows_duplicate_labels=False) + + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + pd.DataFrame({"A": [1, 1]}, index=idx).set_flags(allows_duplicate_labels=False) + + with pytest.raises(pd.errors.DuplicateLabelError, match=msg): + pd.DataFrame([[1, 2]], columns=idx).set_flags(allows_duplicate_labels=False) + + +def test_format_duplicate_labels_message(): + idx = pd.Index(["a", "b", "a", "b", "c"]) + result = idx._format_duplicate_message() + expected = pd.DataFrame( + {"positions": [[0, 2], [1, 3]]}, index=pd.Index(["a", "b"], name="label") + ) + tm.assert_frame_equal(result, expected) + + +def test_format_duplicate_labels_message_multi(): + idx = pd.MultiIndex.from_product([["A"], ["a", "b", "a", "b", "c"]]) + result = idx._format_duplicate_message() + expected = pd.DataFrame( + {"positions": [[0, 2], [1, 3]]}, + index=pd.MultiIndex.from_product([["A"], ["a", "b"]]), + ) + tm.assert_frame_equal(result, expected) + + +def test_dataframe_insert_raises(): + df = pd.DataFrame({"A": [1, 2]}).set_flags(allows_duplicate_labels=False) + msg = "Cannot specify" + with pytest.raises(ValueError, match=msg): + df.insert(0, "A", [3, 4], allow_duplicates=True) + + +@pytest.mark.parametrize( + "method, frame_only", + [ + (operator.methodcaller("set_index", "A", inplace=True), True), + (operator.methodcaller("reset_index", inplace=True), True), + (operator.methodcaller("rename", lambda x: x, inplace=True), False), + ], +) +def test_inplace_raises(method, frame_only): + df = pd.DataFrame({"A": [0, 0], "B": [1, 2]}).set_flags( + allows_duplicate_labels=False + ) + s = df["A"] + s.flags.allows_duplicate_labels = False + msg = "Cannot specify" + + with pytest.raises(ValueError, match=msg): + method(df) + if not frame_only: + with pytest.raises(ValueError, match=msg): + method(s) + + +def test_pickle(temp_file): + a = pd.Series([1, 2]).set_flags(allows_duplicate_labels=False) + b = tm.round_trip_pickle(a, temp_file) + tm.assert_series_equal(a, b) + + a = pd.DataFrame({"A": []}).set_flags(allows_duplicate_labels=False) + b = tm.round_trip_pickle(a, temp_file) + tm.assert_frame_equal(a, b) diff --git a/pandas/tests/generic/test_finalize.py b/pandas/tests/generic/test_finalize.py new file mode 100644 index 0000000000000000000000000000000000000000..de972f2f2f9c7a4101d8e00a3db22e17fa464822 --- /dev/null +++ b/pandas/tests/generic/test_finalize.py @@ -0,0 +1,792 @@ +""" +An exhaustive list of pandas methods exercising NDFrame.__finalize__. +""" + +from copy import deepcopy +from datetime import time +import operator +import re + +import numpy as np +import pytest + +from pandas._typing import MergeHow + +import pandas as pd + +# TODO: +# * Binary methods (mul, div, etc.) +# * Binary outputs (align, etc.) +# * top-level methods (concat, merge, get_dummies, etc.) +# * window +# * cumulative reductions + +not_implemented_mark = pytest.mark.xfail(reason="not implemented") + +mi = pd.MultiIndex.from_product([["a", "b"], [0, 1]], names=["A", "B"]) + +frame_data = ({"A": [1]},) +frame_mi_data = ({"A": [1, 2, 3, 4]}, mi) + + +# Tuple of +# - Callable: Constructor (Series, DataFrame) +# - Tuple: Constructor args +# - Callable: pass the constructed value with attrs set to this. + +_all_methods = [ + (pd.Series, ([0],), operator.methodcaller("take", [])), + (pd.Series, ([0],), operator.methodcaller("__getitem__", [True])), + (pd.Series, ([0],), operator.methodcaller("repeat", 2)), + (pd.Series, ([0],), operator.methodcaller("reset_index")), + (pd.Series, ([0],), operator.methodcaller("reset_index", drop=True)), + (pd.Series, ([0],), operator.methodcaller("to_frame")), + (pd.Series, ([0, 0],), operator.methodcaller("drop_duplicates")), + (pd.Series, ([0, 0],), operator.methodcaller("duplicated")), + (pd.Series, ([0, 0],), operator.methodcaller("round")), + (pd.Series, ([0, 0],), operator.methodcaller("rename", lambda x: x + 1)), + (pd.Series, ([0, 0],), operator.methodcaller("rename", "name")), + (pd.Series, ([0, 0],), operator.methodcaller("set_axis", ["a", "b"])), + (pd.Series, ([0, 0],), operator.methodcaller("reindex", [1, 0])), + (pd.Series, ([0, 0],), operator.methodcaller("drop", [0])), + (pd.Series, (pd.array([0, pd.NA]),), operator.methodcaller("fillna", 0)), + (pd.Series, ([0, 0],), operator.methodcaller("replace", {0: 1})), + (pd.Series, ([0, 0],), operator.methodcaller("shift")), + (pd.Series, ([0, 0],), operator.methodcaller("isin", [0, 1])), + (pd.Series, ([0, 0],), operator.methodcaller("between", 0, 2)), + (pd.Series, ([0, 0],), operator.methodcaller("isna")), + (pd.Series, ([0, 0],), operator.methodcaller("isnull")), + (pd.Series, ([0, 0],), operator.methodcaller("notna")), + (pd.Series, ([0, 0],), operator.methodcaller("notnull")), + (pd.Series, ([1],), operator.methodcaller("add", pd.Series([1]))), + # TODO: mul, div, etc. + ( + pd.Series, + ([0], pd.period_range("2000", periods=1)), + operator.methodcaller("to_timestamp"), + ), + ( + pd.Series, + ([0], pd.date_range("2000", periods=1)), + operator.methodcaller("to_period"), + ), + pytest.param( + ( + pd.DataFrame, + frame_data, + operator.methodcaller("dot", pd.DataFrame(index=["A"])), + ), + marks=pytest.mark.xfail(reason="Implement binary finalize"), + ), + (pd.DataFrame, frame_data, operator.methodcaller("transpose")), + (pd.DataFrame, frame_data, operator.methodcaller("__getitem__", "A")), + (pd.DataFrame, frame_data, operator.methodcaller("__getitem__", ["A"])), + (pd.DataFrame, frame_data, operator.methodcaller("__getitem__", np.array([True]))), + (pd.DataFrame, ({("A", "a"): [1]},), operator.methodcaller("__getitem__", ["A"])), + (pd.DataFrame, frame_data, operator.methodcaller("query", "A == 1")), + (pd.DataFrame, frame_data, operator.methodcaller("eval", "A + 1", engine="python")), + (pd.DataFrame, frame_data, operator.methodcaller("select_dtypes", include="int")), + (pd.DataFrame, frame_data, operator.methodcaller("assign", b=1)), + (pd.DataFrame, frame_data, operator.methodcaller("set_axis", ["A"])), + (pd.DataFrame, frame_data, operator.methodcaller("reindex", [0, 1])), + (pd.DataFrame, frame_data, operator.methodcaller("drop", columns=["A"])), + (pd.DataFrame, frame_data, operator.methodcaller("drop", index=[0])), + (pd.DataFrame, frame_data, operator.methodcaller("rename", columns={"A": "a"})), + (pd.DataFrame, frame_data, operator.methodcaller("rename", index=lambda x: x)), + (pd.DataFrame, frame_data, operator.methodcaller("fillna", "A")), + (pd.DataFrame, frame_data, operator.methodcaller("set_index", "A")), + (pd.DataFrame, frame_data, operator.methodcaller("reset_index")), + (pd.DataFrame, frame_data, operator.methodcaller("isna")), + (pd.DataFrame, frame_data, operator.methodcaller("isnull")), + (pd.DataFrame, frame_data, operator.methodcaller("notna")), + (pd.DataFrame, frame_data, operator.methodcaller("notnull")), + (pd.DataFrame, frame_data, operator.methodcaller("dropna")), + (pd.DataFrame, frame_data, operator.methodcaller("drop_duplicates")), + (pd.DataFrame, frame_data, operator.methodcaller("duplicated")), + (pd.DataFrame, frame_data, operator.methodcaller("sort_values", by="A")), + (pd.DataFrame, frame_data, operator.methodcaller("sort_index")), + (pd.DataFrame, frame_data, operator.methodcaller("nlargest", 1, "A")), + (pd.DataFrame, frame_data, operator.methodcaller("nsmallest", 1, "A")), + (pd.DataFrame, frame_mi_data, operator.methodcaller("swaplevel")), + ( + pd.DataFrame, + frame_data, + operator.methodcaller("add", pd.DataFrame(*frame_data)), + ), + # TODO: div, mul, etc. + ( + pd.DataFrame, + frame_data, + operator.methodcaller("combine", pd.DataFrame(*frame_data), operator.add), + ), + ( + pd.DataFrame, + frame_data, + operator.methodcaller("combine_first", pd.DataFrame(*frame_data)), + ), + pytest.param( + ( + pd.DataFrame, + frame_data, + operator.methodcaller("update", pd.DataFrame(*frame_data)), + ), + marks=not_implemented_mark, + ), + (pd.DataFrame, frame_data, operator.methodcaller("pivot", columns="A")), + ( + pd.DataFrame, + ({"A": [1], "B": [1]},), + operator.methodcaller("pivot_table", columns="A"), + ), + ( + pd.DataFrame, + ({"A": [1], "B": [1]},), + operator.methodcaller("pivot_table", columns="A", aggfunc=["mean", "sum"]), + ), + (pd.DataFrame, frame_data, operator.methodcaller("stack")), + (pd.DataFrame, frame_data, operator.methodcaller("explode", "A")), + (pd.DataFrame, frame_mi_data, operator.methodcaller("unstack")), + ( + pd.DataFrame, + ({"A": ["a", "b", "c"], "B": [1, 3, 5], "C": [2, 4, 6]},), + operator.methodcaller("melt", id_vars=["A"], value_vars=["B"]), + ), + (pd.DataFrame, frame_data, operator.methodcaller("map", lambda x: x)), + (pd.DataFrame, frame_data, operator.methodcaller("round", 2)), + (pd.DataFrame, frame_data, operator.methodcaller("corr")), + pytest.param( + (pd.DataFrame, frame_data, operator.methodcaller("cov")), + marks=[ + pytest.mark.filterwarnings("ignore::RuntimeWarning"), + ], + ), + ( + pd.DataFrame, + frame_data, + operator.methodcaller("corrwith", pd.DataFrame(*frame_data)), + ), + (pd.DataFrame, frame_data, operator.methodcaller("count")), + (pd.DataFrame, frame_data, operator.methodcaller("nunique")), + (pd.DataFrame, frame_data, operator.methodcaller("idxmin")), + (pd.DataFrame, frame_data, operator.methodcaller("idxmax")), + (pd.DataFrame, frame_data, operator.methodcaller("mode")), + (pd.Series, [0], operator.methodcaller("mode")), + (pd.DataFrame, frame_data, operator.methodcaller("median")), + ( + pd.DataFrame, + frame_data, + operator.methodcaller("quantile", numeric_only=True), + ), + ( + pd.DataFrame, + frame_data, + operator.methodcaller("quantile", q=[0.25, 0.75], numeric_only=True), + ), + ( + pd.DataFrame, + ({"A": [pd.Timedelta(days=1), pd.Timedelta(days=2)]},), + operator.methodcaller("quantile", numeric_only=False), + ), + ( + pd.DataFrame, + ({"A": [np.datetime64("2022-01-01"), np.datetime64("2022-01-02")]},), + operator.methodcaller("quantile", numeric_only=True), + ), + ( + pd.DataFrame, + ({"A": [1]}, [pd.Period("2000", "D")]), + operator.methodcaller("to_timestamp"), + ), + ( + pd.DataFrame, + ({"A": [1]}, [pd.Timestamp("2000")]), + operator.methodcaller("to_period", freq="D"), + ), + (pd.DataFrame, frame_mi_data, operator.methodcaller("isin", [1])), + (pd.DataFrame, frame_mi_data, operator.methodcaller("isin", pd.Series([1]))), + ( + pd.DataFrame, + frame_mi_data, + operator.methodcaller("isin", pd.DataFrame({"A": [1]})), + ), + (pd.DataFrame, frame_mi_data, operator.methodcaller("droplevel", "A")), + (pd.DataFrame, frame_data, operator.methodcaller("pop", "A")), + # Squeeze on columns, otherwise we'll end up with a scalar + (pd.DataFrame, frame_data, operator.methodcaller("squeeze", axis="columns")), + (pd.Series, ([1, 2],), operator.methodcaller("squeeze")), + (pd.Series, ([1, 2],), operator.methodcaller("rename_axis", index="a")), + (pd.DataFrame, frame_data, operator.methodcaller("rename_axis", columns="a")), + # Unary ops + (pd.DataFrame, frame_data, operator.neg), + (pd.Series, [1], operator.neg), + (pd.DataFrame, frame_data, operator.pos), + (pd.Series, [1], operator.pos), + (pd.DataFrame, frame_data, operator.inv), + (pd.Series, [1], operator.inv), + (pd.DataFrame, frame_data, abs), + (pd.Series, [1], abs), + (pd.DataFrame, frame_data, round), + (pd.Series, [1], round), + (pd.DataFrame, frame_data, operator.methodcaller("take", [0, 0])), + (pd.DataFrame, frame_mi_data, operator.methodcaller("xs", "a")), + (pd.Series, (1, mi), operator.methodcaller("xs", "a")), + (pd.DataFrame, frame_data, operator.methodcaller("get", "A")), + ( + pd.DataFrame, + frame_data, + operator.methodcaller("reindex_like", pd.DataFrame({"A": [1, 2, 3]})), + ), + ( + pd.Series, + frame_data, + operator.methodcaller("reindex_like", pd.Series([0, 1, 2])), + ), + (pd.DataFrame, frame_data, operator.methodcaller("add_prefix", "_")), + (pd.DataFrame, frame_data, operator.methodcaller("add_suffix", "_")), + (pd.Series, (1, ["a", "b"]), operator.methodcaller("add_prefix", "_")), + (pd.Series, (1, ["a", "b"]), operator.methodcaller("add_suffix", "_")), + (pd.Series, ([3, 2],), operator.methodcaller("sort_values")), + (pd.Series, ([1] * 10,), operator.methodcaller("head")), + (pd.DataFrame, ({"A": [1] * 10},), operator.methodcaller("head")), + (pd.Series, ([1] * 10,), operator.methodcaller("tail")), + (pd.DataFrame, ({"A": [1] * 10},), operator.methodcaller("tail")), + (pd.Series, ([1, 2],), operator.methodcaller("sample", n=2, replace=True)), + (pd.DataFrame, (frame_data,), operator.methodcaller("sample", n=2, replace=True)), + (pd.Series, ([1, 2],), operator.methodcaller("astype", float)), + (pd.DataFrame, frame_data, operator.methodcaller("astype", float)), + (pd.Series, ([1, 2],), operator.methodcaller("copy")), + (pd.DataFrame, frame_data, operator.methodcaller("copy")), + (pd.Series, ([1, 2], None, object), operator.methodcaller("infer_objects")), + ( + pd.DataFrame, + ({"A": np.array([1, 2], dtype=object)},), + operator.methodcaller("infer_objects"), + ), + (pd.Series, ([1, 2],), operator.methodcaller("convert_dtypes")), + (pd.DataFrame, frame_data, operator.methodcaller("convert_dtypes")), + (pd.Series, ([1, None, 3],), operator.methodcaller("interpolate")), + (pd.DataFrame, ({"A": [1, None, 3]},), operator.methodcaller("interpolate")), + (pd.Series, ([1, 2],), operator.methodcaller("clip", lower=1)), + (pd.DataFrame, frame_data, operator.methodcaller("clip", lower=1)), + ( + pd.Series, + (1, pd.date_range("2000", periods=4)), + operator.methodcaller("asfreq", "h"), + ), + ( + pd.DataFrame, + ({"A": [1, 1, 1, 1]}, pd.date_range("2000", periods=4)), + operator.methodcaller("asfreq", "h"), + ), + ( + pd.Series, + (1, pd.date_range("2000", periods=4)), + operator.methodcaller("at_time", time(12)), + ), + ( + pd.DataFrame, + ({"A": [1, 1, 1, 1]}, pd.date_range("2000", periods=4)), + operator.methodcaller("at_time", time(12)), + ), + ( + pd.Series, + (1, pd.date_range("2000", periods=4)), + operator.methodcaller("between_time", "12:00", "13:00"), + ), + ( + pd.DataFrame, + ({"A": [1, 1, 1, 1]}, pd.date_range("2000", periods=4)), + operator.methodcaller("between_time", "12:00", "13:00"), + ), + (pd.Series, ([1, 2],), operator.methodcaller("rank")), + (pd.DataFrame, frame_data, operator.methodcaller("rank")), + (pd.Series, ([1, 2],), operator.methodcaller("where", np.array([True, False]))), + (pd.DataFrame, frame_data, operator.methodcaller("where", np.array([[True]]))), + (pd.Series, ([1, 2],), operator.methodcaller("mask", np.array([True, False]))), + (pd.DataFrame, frame_data, operator.methodcaller("mask", np.array([[True]]))), + (pd.Series, ([1, 2],), operator.methodcaller("truncate", before=0)), + (pd.DataFrame, frame_data, operator.methodcaller("truncate", before=0)), + ( + pd.Series, + (1, pd.date_range("2000", periods=4, tz="UTC")), + operator.methodcaller("tz_convert", "CET"), + ), + ( + pd.DataFrame, + ({"A": [1, 1, 1, 1]}, pd.date_range("2000", periods=4, tz="UTC")), + operator.methodcaller("tz_convert", "CET"), + ), + ( + pd.Series, + (1, pd.date_range("2000", periods=4)), + operator.methodcaller("tz_localize", "CET"), + ), + ( + pd.DataFrame, + ({"A": [1, 1, 1, 1]}, pd.date_range("2000", periods=4)), + operator.methodcaller("tz_localize", "CET"), + ), + (pd.Series, ([1, 2],), operator.methodcaller("describe")), + (pd.DataFrame, frame_data, operator.methodcaller("describe")), + (pd.Series, ([1, 2],), operator.methodcaller("pct_change")), + (pd.DataFrame, frame_data, operator.methodcaller("pct_change")), + (pd.Series, ([1],), operator.methodcaller("transform", lambda x: x - x.min())), + ( + pd.DataFrame, + frame_mi_data, + operator.methodcaller("transform", lambda x: x - x.min()), + ), + (pd.Series, ([1],), operator.methodcaller("apply", lambda x: x)), + (pd.DataFrame, frame_mi_data, operator.methodcaller("apply", lambda x: x)), + # Cumulative reductions + (pd.Series, ([1],), operator.methodcaller("cumsum")), + (pd.DataFrame, frame_data, operator.methodcaller("cumsum")), + (pd.Series, ([1],), operator.methodcaller("cummin")), + (pd.DataFrame, frame_data, operator.methodcaller("cummin")), + (pd.Series, ([1],), operator.methodcaller("cummax")), + (pd.DataFrame, frame_data, operator.methodcaller("cummax")), + (pd.Series, ([1],), operator.methodcaller("cumprod")), + (pd.DataFrame, frame_data, operator.methodcaller("cumprod")), + # Reductions + (pd.DataFrame, frame_data, operator.methodcaller("any")), + (pd.DataFrame, frame_data, operator.methodcaller("all")), + (pd.DataFrame, frame_data, operator.methodcaller("min")), + (pd.DataFrame, frame_data, operator.methodcaller("max")), + (pd.DataFrame, frame_data, operator.methodcaller("sum")), + (pd.DataFrame, frame_data, operator.methodcaller("std")), + (pd.DataFrame, frame_data, operator.methodcaller("mean")), + (pd.DataFrame, frame_data, operator.methodcaller("prod")), + (pd.DataFrame, frame_data, operator.methodcaller("sem")), + (pd.DataFrame, frame_data, operator.methodcaller("skew")), + (pd.DataFrame, frame_data, operator.methodcaller("kurt")), +] + + +def idfn(x): + xpr = re.compile(r"'(.*)?'") + m = xpr.search(str(x)) + if m: + return m.group(1) + else: + return str(x) + + +@pytest.mark.parametrize("ndframe_method", _all_methods, ids=lambda x: idfn(x[-1])) +def test_finalize_called(ndframe_method): + cls, init_args, method = ndframe_method + ndframe = cls(*init_args) + + ndframe.attrs = {"a": 1} + result = method(ndframe) + + assert result.attrs == {"a": 1} + + +@not_implemented_mark +def test_finalize_called_eval_numexpr(): + pytest.importorskip("numexpr") + df = pd.DataFrame({"A": [1, 2]}) + df.attrs["A"] = 1 + result = df.eval("A + 1", engine="numexpr") + assert result.attrs == {"A": 1} + + +# ---------------------------------------------------------------------------- +# Binary operations + + +@pytest.mark.parametrize("annotate", ["left", "right", "both"]) +@pytest.mark.parametrize( + "args", + [ + (1, pd.Series([1])), + (1, pd.DataFrame({"A": [1]})), + (pd.Series([1]), 1), + (pd.DataFrame({"A": [1]}), 1), + (pd.Series([1]), pd.Series([1])), + (pd.DataFrame({"A": [1]}), pd.DataFrame({"A": [1]})), + (pd.Series([1]), pd.DataFrame({"A": [1]})), + (pd.DataFrame({"A": [1]}), pd.Series([1])), + ], + ids=lambda x: f"({type(x[0]).__name__},{type(x[1]).__name__})", +) +def test_binops(request, args, annotate, all_binary_operators): + # This generates 624 tests... Is that needed? + left, right = args + if isinstance(left, (pd.DataFrame, pd.Series)): + left.attrs = {} + if isinstance(right, (pd.DataFrame, pd.Series)): + right.attrs = {} + + if annotate == "left" and isinstance(left, int): + pytest.skip("left is an int and doesn't support .attrs") + if annotate == "right" and isinstance(right, int): + pytest.skip("right is an int and doesn't support .attrs") + + if annotate in {"left", "both"} and not isinstance(left, int): + left.attrs = {"a": 1} + if annotate in {"right", "both"} and not isinstance(right, int): + right.attrs = {"a": 1} + + is_cmp = all_binary_operators in [ + operator.eq, + operator.ne, + operator.gt, + operator.ge, + operator.lt, + operator.le, + ] + if is_cmp and isinstance(left, pd.DataFrame) and isinstance(right, pd.Series): + # in 2.0 silent alignment on comparisons was removed xref GH#28759 + left, right = left.align(right, axis=1) + elif is_cmp and isinstance(left, pd.Series) and isinstance(right, pd.DataFrame): + right, left = right.align(left, axis=1) + + result = all_binary_operators(left, right) + assert result.attrs == {"a": 1} + + +@pytest.mark.parametrize("left", [pd.Series, pd.DataFrame]) +@pytest.mark.parametrize("right", [pd.Series, pd.DataFrame]) +def test_attrs_binary_operations(all_binary_operators, left, right): + # GH 51607 + attrs = {"a": 1} + left = left([1]) + left.attrs = attrs + right = right([2]) + assert all_binary_operators(left, right).attrs == attrs + assert all_binary_operators(right, left).attrs == attrs + + +# ---------------------------------------------------------------------------- +# Accessors + + +@pytest.mark.parametrize( + "method", + [ + operator.methodcaller("capitalize"), + operator.methodcaller("casefold"), + operator.methodcaller("cat", ["a"]), + operator.methodcaller("contains", "a"), + operator.methodcaller("count", "a"), + operator.methodcaller("encode", "utf-8"), + operator.methodcaller("endswith", "a"), + operator.methodcaller("extract", r"(\w)(\d)"), + operator.methodcaller("extract", r"(\w)(\d)", expand=False), + operator.methodcaller("find", "a"), + operator.methodcaller("findall", "a"), + operator.methodcaller("get", 0), + operator.methodcaller("index", "a"), + operator.methodcaller("len"), + operator.methodcaller("ljust", 4), + operator.methodcaller("lower"), + operator.methodcaller("lstrip"), + operator.methodcaller("match", r"\w"), + operator.methodcaller("normalize", "NFC"), + operator.methodcaller("pad", 4), + operator.methodcaller("partition", "a"), + operator.methodcaller("repeat", 2), + operator.methodcaller("replace", "a", "b"), + operator.methodcaller("rfind", "a"), + operator.methodcaller("rindex", "a"), + operator.methodcaller("rjust", 4), + operator.methodcaller("rpartition", "a"), + operator.methodcaller("rstrip"), + operator.methodcaller("slice", 4), + operator.methodcaller("slice_replace", 1, repl="a"), + operator.methodcaller("startswith", "a"), + operator.methodcaller("strip"), + operator.methodcaller("swapcase"), + operator.methodcaller("translate", {"a": "b"}), + operator.methodcaller("upper"), + operator.methodcaller("wrap", 4), + operator.methodcaller("zfill", 4), + operator.methodcaller("isalnum"), + operator.methodcaller("isalpha"), + operator.methodcaller("isdigit"), + operator.methodcaller("isspace"), + operator.methodcaller("islower"), + operator.methodcaller("isupper"), + operator.methodcaller("istitle"), + operator.methodcaller("isnumeric"), + operator.methodcaller("isdecimal"), + operator.methodcaller("get_dummies"), + ], + ids=idfn, +) +def test_string_method(method): + s = pd.Series(["a1"]) + s.attrs = {"a": 1} + result = method(s.str) + assert result.attrs == {"a": 1} + + +@pytest.mark.parametrize( + "method", + [ + operator.methodcaller("to_period"), + operator.methodcaller("tz_localize", "CET"), + operator.methodcaller("normalize"), + operator.methodcaller("strftime", "%Y"), + operator.methodcaller("round", "h"), + operator.methodcaller("floor", "h"), + operator.methodcaller("ceil", "h"), + operator.methodcaller("month_name"), + operator.methodcaller("day_name"), + ], + ids=idfn, +) +def test_datetime_method(method): + s = pd.Series(pd.date_range("2000", periods=4)) + s.attrs = {"a": 1} + result = method(s.dt) + assert result.attrs == {"a": 1} + + +@pytest.mark.parametrize( + "attr", + [ + "date", + "time", + "timetz", + "year", + "month", + "day", + "hour", + "minute", + "second", + "microsecond", + "nanosecond", + "dayofweek", + "day_of_week", + "dayofyear", + "day_of_year", + "quarter", + "is_month_start", + "is_month_end", + "is_quarter_start", + "is_quarter_end", + "is_year_start", + "is_year_end", + "is_leap_year", + "daysinmonth", + "days_in_month", + ], +) +def test_datetime_property(attr): + s = pd.Series(pd.date_range("2000", periods=4)) + s.attrs = {"a": 1} + result = getattr(s.dt, attr) + assert result.attrs == {"a": 1} + + +@pytest.mark.parametrize( + "attr", ["days", "seconds", "microseconds", "nanoseconds", "components"] +) +def test_timedelta_property(attr): + s = pd.Series(pd.timedelta_range("2000", periods=4)) + s.attrs = {"a": 1} + result = getattr(s.dt, attr) + assert result.attrs == {"a": 1} + + +@pytest.mark.parametrize("method", [operator.methodcaller("total_seconds")]) +def test_timedelta_methods(method): + s = pd.Series(pd.timedelta_range("2000", periods=4)) + s.attrs = {"a": 1} + result = method(s.dt) + assert result.attrs == {"a": 1} + + +@pytest.mark.parametrize( + "method", + [ + operator.methodcaller("add_categories", ["c"]), + operator.methodcaller("as_ordered"), + operator.methodcaller("as_unordered"), + lambda x: x.codes, + operator.methodcaller("remove_categories", "a"), + operator.methodcaller("remove_unused_categories"), + operator.methodcaller("rename_categories", {"a": "A", "b": "B"}), + operator.methodcaller("reorder_categories", ["b", "a"]), + operator.methodcaller("set_categories", ["A", "B"]), + ], +) +@not_implemented_mark +def test_categorical_accessor(method): + s = pd.Series(["a", "b"], dtype="category") + s.attrs = {"a": 1} + result = method(s.cat) + assert result.attrs == {"a": 1} + + +# ---------------------------------------------------------------------------- +# Groupby + + +@pytest.mark.parametrize( + "obj", [pd.Series([0, 0]), pd.DataFrame({"A": [0, 1], "B": [1, 2]})] +) +@pytest.mark.parametrize( + "method", + [ + operator.methodcaller("sum"), + lambda x: x.apply(lambda y: y), + lambda x: x.agg("sum"), + lambda x: x.agg("mean"), + lambda x: x.agg("median"), + ], +) +def test_groupby_finalize(obj, method): + obj.attrs = {"a": 1} + result = method(obj.groupby([0, 0], group_keys=False)) + assert result.attrs == {"a": 1} + + +@pytest.mark.parametrize( + "obj", [pd.Series([0, 0]), pd.DataFrame({"A": [0, 1], "B": [1, 2]})] +) +@pytest.mark.parametrize( + "method", + [ + lambda x: x.agg(["sum", "count"]), + lambda x: x.agg("std"), + lambda x: x.agg("var"), + lambda x: x.agg("sem"), + lambda x: x.agg("size"), + lambda x: x.agg("ohlc"), + ], +) +@not_implemented_mark +def test_groupby_finalize_not_implemented(obj, method): + obj.attrs = {"a": 1} + result = method(obj.groupby([0, 0])) + assert result.attrs == {"a": 1} + + +def test_finalize_frame_series_name(): + # https://github.com/pandas-dev/pandas/pull/37186/files#r506978889 + # ensure we don't copy the column `name` to the Series. + df = pd.DataFrame({"name": [1, 2]}) + result = pd.Series([1, 2]).__finalize__(df) + assert result.name is None + + +# ---------------------------------------------------------------------------- +# Merge + + +@pytest.mark.parametrize( + ["allow_on_left", "allow_on_right"], + [(False, False), (False, True), (True, False), (True, True)], +) +@pytest.mark.parametrize( + "how", + [ + "left", + "right", + "inner", + "outer", + "left_anti", + "right_anti", + "cross", + ], +) +def test_merge_correctly_sets_duplication_allowance_flag( + how: MergeHow, + allow_on_left: bool, + allow_on_right: bool, +): + left = pd.DataFrame({"test": [1]}).set_flags(allows_duplicate_labels=allow_on_left) + right = pd.DataFrame({"test": [1]}).set_flags( + allows_duplicate_labels=allow_on_right, + ) + + if not how == "cross": + result = left.merge(right, how=how, on="test") + else: + result = left.merge(right, how=how) + + expected_duplication_allowance = allow_on_left and allow_on_right + assert result.flags.allows_duplicate_labels == expected_duplication_allowance + + +@pytest.mark.parametrize( + ["allow_on_left", "allow_on_right"], + [(False, False), (False, True), (True, False), (True, True)], +) +def test_merge_asof_correctly_sets_duplication_allowance_flag( + allow_on_left: bool, + allow_on_right: bool, +): + left = pd.DataFrame({"test": [1]}).set_flags(allows_duplicate_labels=allow_on_left) + right = pd.DataFrame({"test": [1]}).set_flags( + allows_duplicate_labels=allow_on_right, + ) + + result = pd.merge_asof(left, right) + + expected_duplication_allowance = allow_on_left and allow_on_right + assert result.flags.allows_duplicate_labels == expected_duplication_allowance + + +def test_merge_propagates_metadata_from_equal_input_metadata(): + metadata = {"a": [1, 2]} + left = pd.DataFrame({"test": [1]}) + left.attrs = metadata + right = pd.DataFrame({"test": [1]}) + right.attrs = deepcopy(metadata) + + result = left.merge(right, how="inner", on="test") + + assert result.attrs == metadata + + # Verify that merge deep-copies the attr dictionary. + assert result.attrs is not left.attrs + assert result.attrs is not right.attrs + assert result.attrs["a"] is not left.attrs["a"] + assert result.attrs["a"] is not right.attrs["a"] + + +def test_merge_does_not_propagate_metadata_from_unequal_input_metadata(): + left = pd.DataFrame({"test": [1]}) + left.attrs = {"a": 2} + right = pd.DataFrame({"test": [1]}) + right.attrs = {"b": 3} + + result = left.merge(right, how="inner", on="test") + + assert result.attrs == {} + + +@pytest.mark.parametrize( + ["left_has_metadata", "right_has_metadata", "expected"], + [ + (False, True, {}), + (True, False, {}), + (False, False, {}), + ], + ids=["left-empty", "right-empty", "both-empty"], +) +def test_merge_does_not_propagate_metadata_if_one_input_has_no_metadata( + left_has_metadata: bool, + right_has_metadata: bool, + expected: dict, +): + left = pd.DataFrame({"test": [1]}) + right = pd.DataFrame({"test": [1]}) + + if left_has_metadata: + left.attrs = {"a": [1, 2]} + else: + left.attrs = {} + + if right_has_metadata: + right.attrs = {"a": [1, 2]} + else: + right.attrs = {} + + result = left.merge(right, how="inner", on="test") + + assert result.attrs == expected diff --git a/pandas/tests/generic/test_frame.py b/pandas/tests/generic/test_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d24cceeab0c0a241338b9826bd420489e0dbba --- /dev/null +++ b/pandas/tests/generic/test_frame.py @@ -0,0 +1,202 @@ +from copy import deepcopy +from operator import methodcaller +from typing import Literal + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + MultiIndex, + Series, + date_range, +) +import pandas._testing as tm + + +class TestDataFrame: + @pytest.mark.parametrize("func", ["_set_axis_name", "rename_axis"]) + def test_set_axis_name(self, func): + df = DataFrame([[1, 2], [3, 4]]) + + result = methodcaller(func, "foo")(df) + assert df.index.name is None + assert result.index.name == "foo" + + result = methodcaller(func, "cols", axis=1)(df) + assert df.columns.name is None + assert result.columns.name == "cols" + + @pytest.mark.parametrize("func", ["_set_axis_name", "rename_axis"]) + def test_set_axis_name_mi(self, func): + df = DataFrame( + np.empty((3, 3)), + index=MultiIndex.from_tuples([("A", x) for x in list("aBc")]), + columns=MultiIndex.from_tuples([("C", x) for x in list("xyz")]), + ) + + level_names = ["L1", "L2"] + + result = methodcaller(func, level_names)(df) + assert result.index.names == level_names + assert result.columns.names == [None, None] + + result = methodcaller(func, level_names, axis=1)(df) + assert result.columns.names == ["L1", "L2"] + assert result.index.names == [None, None] + + def test_nonzero_single_element(self): + df = DataFrame([[False, False]]) + msg_err = "The truth value of a DataFrame is ambiguous" + with pytest.raises(ValueError, match=msg_err): + bool(df) + + def test_metadata_propagation_indiv_groupby(self): + # groupby + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + } + ) + result = df.groupby("A").sum() + tm.assert_metadata_equivalent(df, result) + + def test_metadata_propagation_indiv_resample(self): + # resample + df = DataFrame( + np.random.default_rng(2).standard_normal((1000, 2)), + index=date_range("20130101", periods=1000, freq="s"), + ) + result = df.resample("1min") + tm.assert_metadata_equivalent(df, result) + + def test_metadata_propagation_indiv(self, monkeypatch): + # merging with override + # GH 6923 + + def finalize( + self: DataFrame, + other: DataFrame, + method: Literal["merge", "concat"] | None = None, + **kwargs, + ): + for name in self._metadata: + if method == "merge": + left, right = other.input_objs + value = getattr(left, name, "") + "|" + getattr(right, name, "") + object.__setattr__(self, name, value) + elif method == "concat": + value = "+".join( + [ + getattr(o, name) + for o in other.input_objs + if getattr(o, name, None) + ] + ) + object.__setattr__(self, name, value) + else: + object.__setattr__(self, name, getattr(other, name, "")) + + return self + + with monkeypatch.context() as m: + m.setattr(DataFrame, "_metadata", ["filename"]) + m.setattr(DataFrame, "__finalize__", finalize) + + df1 = DataFrame( + np.random.default_rng(2).integers(0, 4, (3, 2)), columns=["a", "b"] + ) + df2 = DataFrame( + np.random.default_rng(2).integers(0, 4, (3, 2)), columns=["c", "d"] + ) + DataFrame._metadata = ["filename"] + df1.filename = "fname1.csv" + df2.filename = "fname2.csv" + + result = df1.merge(df2, left_on=["a"], right_on=["c"], how="inner") + assert result.filename == "fname1.csv|fname2.csv" + + # concat + # GH#6927 + df1 = DataFrame( + np.random.default_rng(2).integers(0, 4, (3, 2)), columns=list("ab") + ) + df1.filename = "foo" + + result = pd.concat([df1, df1]) + assert result.filename == "foo+foo" + + def test_set_attribute(self): + # Test for consistent setattr behavior when an attribute and a column + # have the same name (Issue #8994) + df = DataFrame({"x": [1, 2, 3]}) + + df.y = 2 + df["y"] = [2, 4, 6] + df.y = 5 + + assert df.y == 5 + tm.assert_series_equal(df["y"], Series([2, 4, 6], name="y")) + + def test_deepcopy_empty(self): + # This test covers empty frame copying with non-empty column sets + # as reported in issue GH15370 + empty_frame = DataFrame(data=[], index=[], columns=["A"]) + empty_frame_copy = deepcopy(empty_frame) + + tm.assert_frame_equal(empty_frame_copy, empty_frame) + + +# formerly in Generic but only test DataFrame +class TestDataFrame2: + @pytest.mark.parametrize("value", [1, "True", [1, 2, 3], 5.0]) + def test_validate_bool_args(self, value): + df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + msg = 'For argument "inplace" expected type bool, received type' + with pytest.raises(ValueError, match=msg): + df.copy().rename_axis(mapper={"a": "x", "b": "y"}, axis=1, inplace=value) + + with pytest.raises(ValueError, match=msg): + df.copy().drop("a", axis=1, inplace=value) + + with pytest.raises(ValueError, match=msg): + df.copy().fillna(value=0, inplace=value) + + with pytest.raises(ValueError, match=msg): + df.copy().replace(to_replace=1, value=7, inplace=value) + + with pytest.raises(ValueError, match=msg): + df.copy().interpolate(inplace=value) + + with pytest.raises(ValueError, match=msg): + df.copy()._where(cond=df.a > 2, inplace=value) + + with pytest.raises(ValueError, match=msg): + df.copy().mask(cond=df.a > 2, inplace=value) + + def test_unexpected_keyword(self): + # GH8597 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 2)), columns=["jim", "joe"] + ) + ca = pd.Categorical([0, 0, 2, 2, 3, np.nan]) + ts = df["joe"].copy() + ts[2] = np.nan + + msg = "unexpected keyword" + with pytest.raises(TypeError, match=msg): + df.drop("joe", axis=1, in_place=True) + + with pytest.raises(TypeError, match=msg): + df.reindex([1, 0], inplace=True) + + with pytest.raises(TypeError, match=msg): + ca.fillna(0, inplace=True) + + with pytest.raises(TypeError, match=msg): + ts.fillna(0, in_place=True) diff --git a/pandas/tests/generic/test_generic.py b/pandas/tests/generic/test_generic.py new file mode 100644 index 0000000000000000000000000000000000000000..ee6503b6929b615163e42cef167b34e87539c001 --- /dev/null +++ b/pandas/tests/generic/test_generic.py @@ -0,0 +1,494 @@ +from copy import ( + copy, + deepcopy, +) + +import numpy as np +import pytest + +from pandas.core.dtypes.common import is_scalar + +from pandas import ( + DataFrame, + Index, + Series, + date_range, +) +import pandas._testing as tm + +# ---------------------------------------------------------------------- +# Generic types test cases + + +def construct(box, shape, value=None, dtype=None, **kwargs): + """ + construct an object for the given shape + if value is specified use that if its a scalar + if value is an array, repeat it as needed + """ + if isinstance(shape, int): + shape = tuple([shape] * box._AXIS_LEN) + if value is not None: + if is_scalar(value): + if value == "empty": + arr = None + dtype = np.float64 + + # remove the info axis + kwargs.pop(box._info_axis_name, None) + else: + arr = np.empty(shape, dtype=dtype) + arr.fill(value) + else: + fshape = np.prod(shape) + arr = value.ravel() + new_shape = fshape / arr.shape[0] + if fshape % arr.shape[0] != 0: + raise Exception("invalid value passed in construct") + + arr = np.repeat(arr, new_shape).reshape(shape) + else: + arr = np.random.default_rng(2).standard_normal(shape) + return box(arr, dtype=dtype, **kwargs) + + +class TestGeneric: + @pytest.mark.parametrize( + "func", + [ + str.lower, + {x: x.lower() for x in list("ABCD")}, + Series({x: x.lower() for x in list("ABCD")}), + ], + ) + def test_rename(self, frame_or_series, func): + # single axis + idx = list("ABCD") + + for axis in frame_or_series._AXIS_ORDERS: + kwargs = {axis: idx} + obj = construct(frame_or_series, 4, **kwargs) + + # rename a single axis + result = obj.rename(**{axis: func}) + expected = obj.copy() + setattr(expected, axis, list("abcd")) + tm.assert_equal(result, expected) + + def test_get_numeric_data(self, frame_or_series): + n = 4 + kwargs = { + frame_or_series._get_axis_name(i): list(range(n)) + for i in range(frame_or_series._AXIS_LEN) + } + + # get the numeric data + o = construct(frame_or_series, n, **kwargs) + result = o._get_numeric_data() + tm.assert_equal(result, o) + + # non-inclusion + result = o._get_bool_data() + expected = construct(frame_or_series, n, value="empty", **kwargs) + if isinstance(o, DataFrame): + # preserve columns dtype + expected.columns = o.columns[:0] + tm.assert_equal(result, expected) + + # get the bool data + arr = np.array([True, True, False, True]) + o = construct(frame_or_series, n, value=arr, **kwargs) + result = o._get_numeric_data() + tm.assert_equal(result, o) + + def test_get_bool_data_empty_preserve_index(self): + expected = Series([], dtype="bool") + result = expected._get_bool_data() + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_nonzero(self, frame_or_series): + # GH 4633 + # look at the boolean/nonzero behavior for objects + obj = construct(frame_or_series, shape=4) + msg = f"The truth value of a {frame_or_series.__name__} is ambiguous" + with pytest.raises(ValueError, match=msg): + bool(obj == 0) + with pytest.raises(ValueError, match=msg): + bool(obj == 1) + with pytest.raises(ValueError, match=msg): + bool(obj) + + obj = construct(frame_or_series, shape=4, value=1) + with pytest.raises(ValueError, match=msg): + bool(obj == 0) + with pytest.raises(ValueError, match=msg): + bool(obj == 1) + with pytest.raises(ValueError, match=msg): + bool(obj) + + obj = construct(frame_or_series, shape=4, value=np.nan) + with pytest.raises(ValueError, match=msg): + bool(obj == 0) + with pytest.raises(ValueError, match=msg): + bool(obj == 1) + with pytest.raises(ValueError, match=msg): + bool(obj) + + # empty + obj = construct(frame_or_series, shape=0) + with pytest.raises(ValueError, match=msg): + bool(obj) + + # invalid behaviors + + obj1 = construct(frame_or_series, shape=4, value=1) + obj2 = construct(frame_or_series, shape=4, value=1) + + with pytest.raises(ValueError, match=msg): + if obj1: + pass + + with pytest.raises(ValueError, match=msg): + obj1 and obj2 + with pytest.raises(ValueError, match=msg): + obj1 or obj2 + with pytest.raises(ValueError, match=msg): + not obj1 + + def test_frame_or_series_compound_dtypes(self, frame_or_series): + # see gh-5191 + # Compound dtypes should raise NotImplementedError. + + def f(dtype): + return construct(frame_or_series, shape=3, value=1, dtype=dtype) + + msg = ( + "compound dtypes are not implemented " + f"in the {frame_or_series.__name__} constructor" + ) + + with pytest.raises(NotImplementedError, match=msg): + f([("A", "datetime64[h]"), ("B", "str"), ("C", "int32")]) + + # these work (though results may be unexpected) + f("int64") + f("float64") + f("M8[ns]") + + def test_metadata_propagation(self, frame_or_series): + # check that the metadata matches up on the resulting ops + + o = construct(frame_or_series, shape=3) + o.name = "foo" + o2 = construct(frame_or_series, shape=3) + o2.name = "bar" + + # ---------- + # preserving + # ---------- + + # simple ops with scalars + for op in ["__add__", "__sub__", "__truediv__", "__mul__"]: + result = getattr(o, op)(1) + tm.assert_metadata_equivalent(o, result) + + # ops with like + for op in ["__add__", "__sub__", "__truediv__", "__mul__"]: + result = getattr(o, op)(o) + tm.assert_metadata_equivalent(o, result) + + # simple boolean + for op in ["__eq__", "__le__", "__ge__"]: + v1 = getattr(o, op)(o) + tm.assert_metadata_equivalent(o, v1) + tm.assert_metadata_equivalent(o, v1 & v1) + tm.assert_metadata_equivalent(o, v1 | v1) + + # combine_first + result = o.combine_first(o2) + tm.assert_metadata_equivalent(o, result) + + # --------------------------- + # non-preserving (by default) + # --------------------------- + + # add non-like + result = o + o2 + tm.assert_metadata_equivalent(result) + + # simple boolean + for op in ["__eq__", "__le__", "__ge__"]: + # this is a name matching op + v1 = getattr(o, op)(o) + v2 = getattr(o, op)(o2) + tm.assert_metadata_equivalent(v2) + tm.assert_metadata_equivalent(v1 & v2) + tm.assert_metadata_equivalent(v1 | v2) + + def test_size_compat(self, frame_or_series): + # GH8846 + # size property should be defined + + o = construct(frame_or_series, shape=10) + assert o.size == np.prod(o.shape) + assert o.size == 10 ** len(o.axes) + + def test_split_compat(self, frame_or_series): + # xref GH8846 + o = construct(frame_or_series, shape=10) + assert len(np.array_split(o, 5)) == 5 + assert len(np.array_split(o, 2)) == 2 + + # See gh-12301 + def test_stat_unexpected_keyword(self, frame_or_series): + obj = construct(frame_or_series, 5) + starwars = "Star Wars" + errmsg = "unexpected keyword" + + with pytest.raises(TypeError, match=errmsg): + obj.max(epic=starwars) # stat_function + with pytest.raises(TypeError, match=errmsg): + obj.var(epic=starwars) # stat_function_ddof + with pytest.raises(TypeError, match=errmsg): + obj.sum(epic=starwars) # cum_function + with pytest.raises(TypeError, match=errmsg): + obj.any(epic=starwars) # logical_function + + @pytest.mark.parametrize("func", ["sum", "cumsum", "any", "var"]) + def test_api_compat(self, func, frame_or_series): + # GH 12021 + # compat for __name__, __qualname__ + + obj = construct(frame_or_series, 5) + f = getattr(obj, func) + assert f.__name__ == func + assert f.__qualname__.endswith(func) + + def test_stat_non_defaults_args(self, frame_or_series): + obj = construct(frame_or_series, 5) + out = np.array([0]) + errmsg = "the 'out' parameter is not supported" + + with pytest.raises(ValueError, match=errmsg): + obj.max(out=out) # stat_function + with pytest.raises(ValueError, match=errmsg): + obj.var(out=out) # stat_function_ddof + with pytest.raises(ValueError, match=errmsg): + obj.sum(out=out) # cum_function + with pytest.raises(ValueError, match=errmsg): + obj.any(out=out) # logical_function + + def test_truncate_out_of_bounds(self, frame_or_series): + # GH11382 + + # small + shape = [2000] + ([1] * (frame_or_series._AXIS_LEN - 1)) + small = construct(frame_or_series, shape, dtype="int8", value=1) + tm.assert_equal(small.truncate(), small) + tm.assert_equal(small.truncate(before=0, after=3e3), small) + tm.assert_equal(small.truncate(before=-1, after=2e3), small) + + # big + shape = [2_000_000] + ([1] * (frame_or_series._AXIS_LEN - 1)) + big = construct(frame_or_series, shape, dtype="int8", value=1) + tm.assert_equal(big.truncate(), big) + tm.assert_equal(big.truncate(before=0, after=3e6), big) + tm.assert_equal(big.truncate(before=-1, after=2e6), big) + + @pytest.mark.parametrize( + "func", + [copy, deepcopy, lambda x: x.copy(deep=False), lambda x: x.copy(deep=True)], + ) + @pytest.mark.parametrize("shape", [0, 1, 2]) + def test_copy_and_deepcopy(self, frame_or_series, shape, func): + # GH 15444 + obj = construct(frame_or_series, shape) + obj_copy = func(obj) + assert obj_copy is not obj + tm.assert_equal(obj_copy, obj) + + def test_stdlib_copy_shallow_copies(self, frame_or_series): + obj = frame_or_series(range(3)) + obj_copy = copy(obj) + assert tm.shares_memory(obj, obj_copy) + + +class TestNDFrame: + # tests that don't fit elsewhere + + @pytest.mark.parametrize( + "ser", + [ + Series(range(10), dtype=np.float64), + Series([str(i) for i in range(10)], dtype=object), + ], + ) + def test_squeeze_series_noop(self, ser): + # noop + tm.assert_series_equal(ser.squeeze(), ser) + + def test_squeeze_frame_noop(self): + # noop + df = DataFrame(np.eye(2)) + tm.assert_frame_equal(df.squeeze(), df) + + def test_squeeze_frame_reindex(self): + # squeezing + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ).reindex(columns=["A"]) + tm.assert_series_equal(df.squeeze(), df["A"]) + + def test_squeeze_0_len_dim(self): + # don't fail with 0 length dimensions GH11229 & GH8999 + empty_series = Series([], name="five", dtype=np.float64) + empty_frame = DataFrame([empty_series]) + tm.assert_series_equal(empty_series, empty_series.squeeze()) + tm.assert_series_equal(empty_series, empty_frame.squeeze()) + + def test_squeeze_axis(self): + # axis argument + df = DataFrame( + np.random.default_rng(2).standard_normal((1, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=1, freq="B"), + ).iloc[:, :1] + assert df.shape == (1, 1) + tm.assert_series_equal(df.squeeze(axis=0), df.iloc[0]) + tm.assert_series_equal(df.squeeze(axis="index"), df.iloc[0]) + tm.assert_series_equal(df.squeeze(axis=1), df.iloc[:, 0]) + tm.assert_series_equal(df.squeeze(axis="columns"), df.iloc[:, 0]) + assert df.squeeze() == df.iloc[0, 0] + msg = "No axis named 2 for object type DataFrame" + with pytest.raises(ValueError, match=msg): + df.squeeze(axis=2) + msg = "No axis named x for object type DataFrame" + with pytest.raises(ValueError, match=msg): + df.squeeze(axis="x") + + def test_squeeze_axis_len_3(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=3, freq="B"), + ) + tm.assert_frame_equal(df.squeeze(axis=0), df) + + def test_numpy_squeeze(self): + s = Series(range(2), dtype=np.float64) + tm.assert_series_equal(np.squeeze(s), s) + + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ).reindex(columns=["A"]) + tm.assert_series_equal(np.squeeze(df), df["A"]) + + @pytest.mark.parametrize( + "ser", + [ + Series(range(10), dtype=np.float64), + Series([str(i) for i in range(10)], dtype=object), + ], + ) + def test_transpose_series(self, ser): + # calls implementation in pandas/core/base.py + tm.assert_series_equal(ser.transpose(), ser) + + def test_transpose_frame(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + tm.assert_frame_equal(df.transpose().transpose(), df) + + def test_numpy_transpose(self, frame_or_series): + obj = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + obj = tm.get_obj(obj, frame_or_series) + + if frame_or_series is Series: + # 1D -> np.transpose is no-op + tm.assert_series_equal(np.transpose(obj), obj) + + # round-trip preserved + tm.assert_equal(np.transpose(np.transpose(obj)), obj) + + msg = "the 'axes' parameter is not supported" + with pytest.raises(ValueError, match=msg): + np.transpose(obj, axes=1) + + @pytest.mark.parametrize( + "ser", + [ + Series(range(10), dtype=np.float64), + Series([str(i) for i in range(10)], dtype=object), + ], + ) + def test_take_series(self, ser): + indices = [1, 5, -2, 6, 3, -1] + out = ser.take(indices) + expected = Series( + data=ser.values.take(indices), + index=ser.index.take(indices), + dtype=ser.dtype, + ) + tm.assert_series_equal(out, expected) + + def test_take_frame(self): + indices = [1, 5, -2, 6, 3, -1] + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + out = df.take(indices) + expected = DataFrame( + data=df.values.take(indices, axis=0), + index=df.index.take(indices), + columns=df.columns, + ) + tm.assert_frame_equal(out, expected) + + def test_take_invalid_kwargs(self, frame_or_series): + indices = [-3, 2, 0, 1] + + obj = DataFrame(range(5)) + obj = tm.get_obj(obj, frame_or_series) + + msg = r"take\(\) got an unexpected keyword argument 'foo'" + with pytest.raises(TypeError, match=msg): + obj.take(indices, foo=2) + + msg = "the 'out' parameter is not supported" + with pytest.raises(ValueError, match=msg): + obj.take(indices, out=indices) + + msg = "the 'mode' parameter is not supported" + with pytest.raises(ValueError, match=msg): + obj.take(indices, mode="clip") + + def test_axis_classmethods(self, frame_or_series): + box = frame_or_series + obj = box(dtype=object) + values = box._AXIS_TO_AXIS_NUMBER.keys() + for v in values: + assert obj._get_axis_number(v) == box._get_axis_number(v) + assert obj._get_axis_name(v) == box._get_axis_name(v) + assert obj._get_block_manager_axis(v) == box._get_block_manager_axis(v) + + def test_flags_identity(self, frame_or_series): + obj = Series([1, 2]) + if frame_or_series is DataFrame: + obj = obj.to_frame() + + assert obj.flags is obj.flags + obj2 = obj.copy() + assert obj2.flags is not obj.flags diff --git a/pandas/tests/generic/test_label_or_level_utils.py b/pandas/tests/generic/test_label_or_level_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..797f3d488ab18285fd2a829075f5bd06985b25a3 --- /dev/null +++ b/pandas/tests/generic/test_label_or_level_utils.py @@ -0,0 +1,329 @@ +import pytest + +from pandas.core.dtypes.missing import array_equivalent + +import pandas as pd + + +# Fixtures +# ======== +@pytest.fixture +def df(): + """DataFrame with columns 'L1', 'L2', and 'L3'""" + return pd.DataFrame({"L1": [1, 2, 3], "L2": [11, 12, 13], "L3": ["A", "B", "C"]}) + + +@pytest.fixture(params=[[], ["L1"], ["L1", "L2"], ["L1", "L2", "L3"]]) +def df_levels(request, df): + """DataFrame with columns or index levels 'L1', 'L2', and 'L3'""" + levels = request.param + + if levels: + df = df.set_index(levels) + + return df + + +@pytest.fixture +def df_ambig(df): + """DataFrame with levels 'L1' and 'L2' and labels 'L1' and 'L3'""" + df = df.set_index(["L1", "L2"]) + + df["L1"] = df["L3"] + + return df + + +# Test is label/level reference +# ============================= +def get_labels_levels(df_levels): + expected_labels = list(df_levels.columns) + expected_levels = [name for name in df_levels.index.names if name is not None] + return expected_labels, expected_levels + + +def assert_label_reference(frame, labels, axis): + for label in labels: + assert frame._is_label_reference(label, axis=axis) + assert not frame._is_level_reference(label, axis=axis) + assert frame._is_label_or_level_reference(label, axis=axis) + + +def assert_level_reference(frame, levels, axis): + for level in levels: + assert frame._is_level_reference(level, axis=axis) + assert not frame._is_label_reference(level, axis=axis) + assert frame._is_label_or_level_reference(level, axis=axis) + + +# DataFrame +# --------- +def test_is_level_or_label_reference_df_simple(df_levels, axis): + axis = df_levels._get_axis_number(axis) + # Compute expected labels and levels + expected_labels, expected_levels = get_labels_levels(df_levels) + + # Transpose frame if axis == 1 + if axis == 1: + df_levels = df_levels.T + + # Perform checks + assert_level_reference(df_levels, expected_levels, axis=axis) + assert_label_reference(df_levels, expected_labels, axis=axis) + + +def test_is_level_reference_df_ambig(df_ambig, axis): + axis = df_ambig._get_axis_number(axis) + + # Transpose frame if axis == 1 + if axis == 1: + df_ambig = df_ambig.T + + # df has both an on-axis level and off-axis label named L1 + # Therefore L1 should reference the label, not the level + assert_label_reference(df_ambig, ["L1"], axis=axis) + + # df has an on-axis level named L2 and it is not ambiguous + # Therefore L2 is a level reference + assert_level_reference(df_ambig, ["L2"], axis=axis) + + # df has a column named L3 and it is not a level reference + assert_label_reference(df_ambig, ["L3"], axis=axis) + + +# Series +# ------ +def test_is_level_reference_series_simple_axis0(df): + # Make series with L1 as index + s = df.set_index("L1").L2 + assert_level_reference(s, ["L1"], axis=0) + assert not s._is_level_reference("L2") + + # Make series with L1 and L2 as index + s = df.set_index(["L1", "L2"]).L3 + assert_level_reference(s, ["L1", "L2"], axis=0) + assert not s._is_level_reference("L3") + + +def test_is_level_reference_series_axis1_error(df): + # Make series with L1 as index + s = df.set_index("L1").L2 + + with pytest.raises(ValueError, match="No axis named 1"): + s._is_level_reference("L1", axis=1) + + +# Test _check_label_or_level_ambiguity_df +# ======================================= + + +# DataFrame +# --------- +def test_check_label_or_level_ambiguity_df(df_ambig, axis): + axis = df_ambig._get_axis_number(axis) + # Transpose frame if axis == 1 + if axis == 1: + df_ambig = df_ambig.T + msg = "'L1' is both a column level and an index label" + + else: + msg = "'L1' is both an index level and a column label" + # df_ambig has both an on-axis level and off-axis label named L1 + # Therefore, L1 is ambiguous. + with pytest.raises(ValueError, match=msg): + df_ambig._check_label_or_level_ambiguity("L1", axis=axis) + + # df_ambig has an on-axis level named L2,, and it is not ambiguous. + df_ambig._check_label_or_level_ambiguity("L2", axis=axis) + + # df_ambig has an off-axis label named L3, and it is not ambiguous + assert not df_ambig._check_label_or_level_ambiguity("L3", axis=axis) + + +# Series +# ------ +def test_check_label_or_level_ambiguity_series(df): + # A series has no columns and therefore references are never ambiguous + + # Make series with L1 as index + s = df.set_index("L1").L2 + s._check_label_or_level_ambiguity("L1", axis=0) + s._check_label_or_level_ambiguity("L2", axis=0) + + # Make series with L1 and L2 as index + s = df.set_index(["L1", "L2"]).L3 + s._check_label_or_level_ambiguity("L1", axis=0) + s._check_label_or_level_ambiguity("L2", axis=0) + s._check_label_or_level_ambiguity("L3", axis=0) + + +def test_check_label_or_level_ambiguity_series_axis1_error(df): + # Make series with L1 as index + s = df.set_index("L1").L2 + + with pytest.raises(ValueError, match="No axis named 1"): + s._check_label_or_level_ambiguity("L1", axis=1) + + +# Test _get_label_or_level_values +# =============================== +def assert_label_values(frame, labels, axis): + axis = frame._get_axis_number(axis) + for label in labels: + if axis == 0: + expected = frame[label]._values + else: + expected = frame.loc[label]._values + + result = frame._get_label_or_level_values(label, axis=axis) + assert array_equivalent(expected, result) + + +def assert_level_values(frame, levels, axis): + axis = frame._get_axis_number(axis) + for level in levels: + if axis == 0: + expected = frame.index.get_level_values(level=level)._values + else: + expected = frame.columns.get_level_values(level=level)._values + + result = frame._get_label_or_level_values(level, axis=axis) + assert array_equivalent(expected, result) + + +# DataFrame +# --------- +def test_get_label_or_level_values_df_simple(df_levels, axis): + # Compute expected labels and levels + expected_labels, expected_levels = get_labels_levels(df_levels) + + axis = df_levels._get_axis_number(axis) + # Transpose frame if axis == 1 + if axis == 1: + df_levels = df_levels.T + + # Perform checks + assert_label_values(df_levels, expected_labels, axis=axis) + assert_level_values(df_levels, expected_levels, axis=axis) + + +def test_get_label_or_level_values_df_ambig(df_ambig, axis): + axis = df_ambig._get_axis_number(axis) + # Transpose frame if axis == 1 + if axis == 1: + df_ambig = df_ambig.T + + # df has an on-axis level named L2, and it is not ambiguous. + assert_level_values(df_ambig, ["L2"], axis=axis) + + # df has an off-axis label named L3, and it is not ambiguous. + assert_label_values(df_ambig, ["L3"], axis=axis) + + +def test_get_label_or_level_values_df_duplabels(df, axis): + df = df.set_index(["L1"]) + df_duplabels = pd.concat([df, df["L2"]], axis=1) + axis = df_duplabels._get_axis_number(axis) + # Transpose frame if axis == 1 + if axis == 1: + df_duplabels = df_duplabels.T + + # df has unambiguous level 'L1' + assert_level_values(df_duplabels, ["L1"], axis=axis) + + # df has unique label 'L3' + assert_label_values(df_duplabels, ["L3"], axis=axis) + + # df has duplicate labels 'L2' + if axis == 0: + expected_msg = "The column label 'L2' is not unique" + else: + expected_msg = "The index label 'L2' is not unique" + + with pytest.raises(ValueError, match=expected_msg): + assert_label_values(df_duplabels, ["L2"], axis=axis) + + +# Series +# ------ +def test_get_label_or_level_values_series_axis0(df): + # Make series with L1 as index + s = df.set_index("L1").L2 + assert_level_values(s, ["L1"], axis=0) + + # Make series with L1 and L2 as index + s = df.set_index(["L1", "L2"]).L3 + assert_level_values(s, ["L1", "L2"], axis=0) + + +def test_get_label_or_level_values_series_axis1_error(df): + # Make series with L1 as index + s = df.set_index("L1").L2 + + with pytest.raises(ValueError, match="No axis named 1"): + s._get_label_or_level_values("L1", axis=1) + + +# Test _drop_labels_or_levels +# =========================== +def assert_labels_dropped(frame, labels, axis): + axis = frame._get_axis_number(axis) + for label in labels: + df_dropped = frame._drop_labels_or_levels(label, axis=axis) + + if axis == 0: + assert label in frame.columns + assert label not in df_dropped.columns + else: + assert label in frame.index + assert label not in df_dropped.index + + +def assert_levels_dropped(frame, levels, axis): + axis = frame._get_axis_number(axis) + for level in levels: + df_dropped = frame._drop_labels_or_levels(level, axis=axis) + + if axis == 0: + assert level in frame.index.names + assert level not in df_dropped.index.names + else: + assert level in frame.columns.names + assert level not in df_dropped.columns.names + + +# DataFrame +# --------- +def test_drop_labels_or_levels_df(df_levels, axis): + # Compute expected labels and levels + expected_labels, expected_levels = get_labels_levels(df_levels) + + axis = df_levels._get_axis_number(axis) + # Transpose frame if axis == 1 + if axis == 1: + df_levels = df_levels.T + + # Perform checks + assert_labels_dropped(df_levels, expected_labels, axis=axis) + assert_levels_dropped(df_levels, expected_levels, axis=axis) + + with pytest.raises(ValueError, match="not valid labels or levels"): + df_levels._drop_labels_or_levels("L4", axis=axis) + + +# Series +# ------ +def test_drop_labels_or_levels_series(df): + # Make series with L1 as index + s = df.set_index("L1").L2 + assert_levels_dropped(s, ["L1"], axis=0) + + with pytest.raises(ValueError, match="not valid labels or levels"): + s._drop_labels_or_levels("L4", axis=0) + + # Make series with L1 and L2 as index + s = df.set_index(["L1", "L2"]).L3 + assert_levels_dropped(s, ["L1", "L2"], axis=0) + + with pytest.raises(ValueError, match="not valid labels or levels"): + s._drop_labels_or_levels("L4", axis=0) diff --git a/pandas/tests/generic/test_series.py b/pandas/tests/generic/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea4b8be5cf91731107a41b107c28c9e137ab994 --- /dev/null +++ b/pandas/tests/generic/test_series.py @@ -0,0 +1,119 @@ +from operator import methodcaller + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + MultiIndex, + Series, + date_range, +) +import pandas._testing as tm + + +class TestSeries: + @pytest.mark.parametrize("func", ["rename_axis", "_set_axis_name"]) + def test_set_axis_name_mi(self, func): + ser = Series( + [11, 21, 31], + index=MultiIndex.from_tuples( + [("A", x) for x in ["a", "B", "c"]], names=["l1", "l2"] + ), + ) + + result = methodcaller(func, ["L1", "L2"])(ser) + assert ser.index.name is None + assert ser.index.names == ["l1", "l2"] + assert result.index.name is None + assert result.index.names, ["L1", "L2"] + + def test_set_axis_name_raises(self): + ser = Series([1]) + msg = "No axis named 1 for object type Series" + with pytest.raises(ValueError, match=msg): + ser._set_axis_name(name="a", axis=1) + + def test_get_bool_data_preserve_dtype(self): + ser = Series([True, False, True]) + result = ser._get_bool_data() + tm.assert_series_equal(result, ser) + + @pytest.mark.parametrize("data", [np.nan, pd.NaT, True, False]) + def test_nonzero_single_element_raise_1(self, data): + # single item nan to raise + series = Series([data]) + + msg = "The truth value of a Series is ambiguous" + with pytest.raises(ValueError, match=msg): + bool(series) + + @pytest.mark.parametrize("data", [(True, True), (False, False)]) + def test_nonzero_multiple_element_raise(self, data): + # multiple bool are still an error + msg_err = "The truth value of a Series is ambiguous" + series = Series([data]) + with pytest.raises(ValueError, match=msg_err): + bool(series) + + @pytest.mark.parametrize("data", [1, 0, "a", 0.0]) + def test_nonbool_single_element_raise(self, data): + # single non-bool are an error + msg_err1 = "The truth value of a Series is ambiguous" + series = Series([data]) + with pytest.raises(ValueError, match=msg_err1): + bool(series) + + def test_metadata_propagation_indiv_resample(self): + # resample + ts = Series( + np.random.default_rng(2).random(1000), + index=date_range("20130101", periods=1000, freq="s"), + name="foo", + ) + result = ts.resample("1min").mean() + tm.assert_metadata_equivalent(ts, result) + + result = ts.resample("1min").min() + tm.assert_metadata_equivalent(ts, result) + + result = ts.resample("1min").apply(lambda x: x.sum()) + tm.assert_metadata_equivalent(ts, result) + + def test_metadata_propagation_indiv(self, monkeypatch): + # check that the metadata matches up on the resulting ops + + ser = Series(range(3), range(3)) + ser.name = "foo" + ser2 = Series(range(3), range(3)) + ser2.name = "bar" + + result = ser.T + tm.assert_metadata_equivalent(ser, result) + + def finalize(self, other, method=None, **kwargs): + for name in self._metadata: + if method == "concat" and name == "filename": + value = "+".join( + [ + getattr(obj, name) + for obj in other.input_objs + if getattr(obj, name, None) + ] + ) + object.__setattr__(self, name, value) + else: + object.__setattr__(self, name, getattr(other, name, None)) + + return self + + with monkeypatch.context() as m: + m.setattr(Series, "_metadata", ["name", "filename"]) + m.setattr(Series, "__finalize__", finalize) + + ser.filename = "foo" + ser2.filename = "bar" + + result = pd.concat([ser, ser2]) + assert result.filename == "foo+bar" + assert result.name is None diff --git a/pandas/tests/generic/test_to_xarray.py b/pandas/tests/generic/test_to_xarray.py new file mode 100644 index 0000000000000000000000000000000000000000..3aabdb6d7869a3f007fa50ad3adb56544d575077 --- /dev/null +++ b/pandas/tests/generic/test_to_xarray.py @@ -0,0 +1,126 @@ +import numpy as np +import pytest + +from pandas import ( + Categorical, + DataFrame, + MultiIndex, + Series, + date_range, +) +import pandas._testing as tm +from pandas.util.version import Version + +xarray = pytest.importorskip("xarray") + +if xarray is not None and Version(xarray.__version__) < Version("2025.1.0"): + pytestmark = pytest.mark.filterwarnings( + "ignore:Converting non-nanosecond precision:UserWarning" + ) + + +class TestDataFrameToXArray: + @pytest.fixture + def df(self): + return DataFrame( + { + "a": list("abcd"), + "b": list(range(1, 5)), + "c": np.arange(3, 7).astype("u1"), + "d": np.arange(4.0, 8.0, dtype="float64"), + "e": [True, False, True, False], + "f": Categorical(list("abcd")), + "g": date_range("20130101", periods=4), + "h": date_range("20130101", periods=4, tz="US/Eastern"), + } + ) + + def test_to_xarray_index_types(self, index_flat, df, request): + index = index_flat + # MultiIndex is tested in test_to_xarray_with_multiindex + if len(index) == 0: + pytest.skip("Test doesn't make sense for empty index") + if Version(xarray.__version__) < Version("2025.9.0"): + pytest.skip("Xarray bug https://github.com/pydata/xarray/issues/9661") + + df.index = index[:4] + df.index.name = "foo" + df.columns.name = "bar" + result = df.to_xarray() + assert result.sizes["foo"] == 4 + assert len(result.coords) == 1 + assert len(result.data_vars) == 8 + tm.assert_almost_equal(list(result.coords.keys()), ["foo"]) + assert isinstance(result, xarray.Dataset) + + # idempotency + # datetimes w/tz are preserved + # column names are lost + expected = df.copy() + expected.columns.name = None + tm.assert_frame_equal(result.to_dataframe(), expected) + + def test_to_xarray_empty(self, df): + df.index.name = "foo" + result = df[0:0].to_xarray() + assert result.sizes["foo"] == 0 + assert isinstance(result, xarray.Dataset) + + def test_to_xarray_with_multiindex(self, df, using_infer_string): + # MultiIndex + df.index = MultiIndex.from_product([["a"], range(4)], names=["one", "two"]) + result = df.to_xarray() + assert result.sizes["one"] == 1 + assert result.sizes["two"] == 4 + assert len(result.coords) == 2 + assert len(result.data_vars) == 8 + tm.assert_almost_equal(list(result.coords.keys()), ["one", "two"]) + assert isinstance(result, xarray.Dataset) + + result = result.to_dataframe() + expected = df.copy() + expected["f"] = expected["f"].astype( + object if not using_infer_string else "str" + ) + if Version(xarray.__version__) < Version("2025.1.0"): + expected["g"] = expected["g"].astype("M8[ns]") + expected["h"] = expected["h"].astype("M8[ns, US/Eastern]") + expected.columns.name = None + tm.assert_frame_equal(result, expected) + + +class TestSeriesToXArray: + def test_to_xarray_index_types(self, index_flat, request): + # MultiIndex is tested in test_to_xarray_with_multiindex + index = index_flat + + ser = Series(range(len(index)), index=index, dtype="int64") + ser.index.name = "foo" + result = ser.to_xarray() + repr(result) + assert len(result) == len(index) + assert len(result.coords) == 1 + tm.assert_almost_equal(list(result.coords.keys()), ["foo"]) + assert isinstance(result, xarray.DataArray) + + # idempotency + tm.assert_series_equal(result.to_series(), ser) + + def test_to_xarray_empty(self): + ser = Series([], dtype=object) + ser.index.name = "foo" + result = ser.to_xarray() + assert len(result) == 0 + assert len(result.coords) == 1 + tm.assert_almost_equal(list(result.coords.keys()), ["foo"]) + assert isinstance(result, xarray.DataArray) + + def test_to_xarray_with_multiindex(self): + mi = MultiIndex.from_product([["a", "b"], range(3)], names=["one", "two"]) + ser = Series(range(6), dtype="int64", index=mi) + result = ser.to_xarray() + assert len(result) == 2 + tm.assert_almost_equal(list(result.coords.keys()), ["one", "two"]) + assert isinstance(result, xarray.DataArray) + res = result.to_series() + tm.assert_series_equal(res, ser) diff --git a/pandas/tests/groupby/__init__.py b/pandas/tests/groupby/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79046cd7ed415166e0c81ff645174015fc48eaf6 --- /dev/null +++ b/pandas/tests/groupby/__init__.py @@ -0,0 +1,25 @@ +def get_groupby_method_args(name, obj): + """ + Get required arguments for a groupby method. + + When parametrizing a test over groupby methods (e.g. "sum", "mean"), + it is often the case that arguments are required for certain methods. + + Parameters + ---------- + name: str + Name of the method. + obj: Series or DataFrame + pandas object that is being grouped. + + Returns + ------- + A tuple of required arguments for the method. + """ + if name in ("nth", "take"): + return (0,) + if name == "quantile": + return (0.5,) + if name == "corrwith": + return (obj,) + return () diff --git a/pandas/tests/groupby/conftest.py b/pandas/tests/groupby/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..2745f7c2b8d0faf0842d23f1584e9dd079a5ef30 --- /dev/null +++ b/pandas/tests/groupby/conftest.py @@ -0,0 +1,166 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, + date_range, +) +from pandas.core.groupby.base import ( + reduction_kernels, + transformation_kernels, +) + + +@pytest.fixture +def df(): + return DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + } + ) + + +@pytest.fixture +def ts(): + return Series( + np.random.default_rng(2).standard_normal(30), + index=date_range("2000-01-01", periods=30, freq="B"), + ) + + +@pytest.fixture +def tsframe(): + return DataFrame( + np.random.default_rng(2).standard_normal((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=30, freq="B"), + ) + + +@pytest.fixture +def three_group(): + return DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + +@pytest.fixture +def slice_test_df(): + data = [ + [0, "a", "a0_at_0"], + [1, "b", "b0_at_1"], + [2, "a", "a1_at_2"], + [3, "b", "b1_at_3"], + [4, "c", "c0_at_4"], + [5, "a", "a2_at_5"], + [6, "a", "a3_at_6"], + [7, "a", "a4_at_7"], + ] + df = DataFrame(data, columns=["Index", "Group", "Value"]) + return df.set_index("Index") + + +@pytest.fixture +def slice_test_grouped(slice_test_df): + return slice_test_df.groupby("Group", as_index=False) + + +@pytest.fixture(params=sorted(reduction_kernels)) +def reduction_func(request): + """ + yields the string names of all groupby reduction functions, one at a time. + """ + return request.param + + +@pytest.fixture(params=sorted(transformation_kernels)) +def transformation_func(request): + """yields the string names of all groupby transformation functions.""" + return request.param + + +@pytest.fixture(params=sorted(reduction_kernels) + sorted(transformation_kernels)) +def groupby_func(request): + """yields both aggregation and transformation functions.""" + return request.param + + +@pytest.fixture( + params=[ + ("mean", {}), + ("var", {"ddof": 1}), + ("var", {"ddof": 0}), + ("std", {"ddof": 1}), + ("std", {"ddof": 0}), + ("sum", {}), + ("min", {}), + ("max", {}), + ("sum", {"min_count": 2}), + ("min", {"min_count": 2}), + ("max", {"min_count": 2}), + ], + ids=[ + "mean", + "var_1", + "var_0", + "std_1", + "std_0", + "sum", + "min", + "max", + "sum-min_count", + "min-min_count", + "max-min_count", + ], +) +def numba_supported_reductions(request): + """reductions supported with engine='numba'""" + return request.param diff --git a/pandas/tests/groupby/test_all_methods.py b/pandas/tests/groupby/test_all_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..7a012f5da4aa827dbea1350d7a874ca6e0d85233 --- /dev/null +++ b/pandas/tests/groupby/test_all_methods.py @@ -0,0 +1,105 @@ +""" +Tests that apply to all groupby operation methods. + +The only tests that should appear here are those that use the `groupby_func` fixture. +Even if it does use that fixture, prefer a more specific test file if it available +such as: + + - test_categorical + - test_groupby_dropna + - test_groupby_subclass + - test_raises +""" + +import pytest + +from pandas.errors import Pandas4Warning + +import pandas as pd +from pandas import DataFrame +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +def test_multiindex_group_all_columns_when_empty(groupby_func): + # GH 32464 + df = DataFrame({"a": [], "b": [], "c": []}).set_index(["a", "b", "c"]) + gb = df.groupby(["a", "b", "c"], group_keys=True) + method = getattr(gb, groupby_func) + args = get_groupby_method_args(groupby_func, df) + if groupby_func == "corrwith": + warn = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + warn_msg = "" + with tm.assert_produces_warning(warn, match=warn_msg): + result = method(*args).index + expected = df.index + tm.assert_index_equal(result, expected) + + +def test_duplicate_columns(request, groupby_func, as_index): + # GH#50806 + if groupby_func == "corrwith": + msg = "GH#50845 - corrwith fails when there are duplicate columns" + request.applymarker(pytest.mark.xfail(reason=msg)) + df = DataFrame([[1, 3, 6], [1, 4, 7], [2, 5, 8]], columns=list("abb")) + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby("a", as_index=as_index) + result = getattr(gb, groupby_func)(*args) + + expected_df = df.set_axis(["a", "b", "c"], axis=1) + expected_args = get_groupby_method_args(groupby_func, expected_df) + expected_gb = expected_df.groupby("a", as_index=as_index) + expected = getattr(expected_gb, groupby_func)(*expected_args) + if groupby_func not in ("size", "ngroup", "cumcount"): + expected = expected.rename(columns={"c": "b"}) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "idx", + [ + pd.Index(["a", "a"], name="foo"), + pd.MultiIndex.from_tuples((("a", "a"), ("a", "a")), names=["foo", "bar"]), + ], +) +def test_dup_labels_output_shape(groupby_func, idx): + if groupby_func in {"size", "ngroup", "cumcount"}: + pytest.skip(f"Not applicable for {groupby_func}") + + df = DataFrame([[1, 1]], columns=idx) + grp_by = df.groupby([0]) + + args = get_groupby_method_args(groupby_func, df) + if groupby_func == "corrwith": + warn = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + warn_msg = "" + with tm.assert_produces_warning(warn, match=warn_msg): + result = getattr(grp_by, groupby_func)(*args) + + assert result.shape == (1, 2) + tm.assert_index_equal(result.columns, idx) + + +def test_not_c_contiguous_mask(groupby_func): + # https://github.com/pandas-dev/pandas/issues/61031 + if groupby_func == "corrwith": + # corrwith is deprecated + return + df = DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]}, dtype="Int64") + reversed = DataFrame( + {"a": [2, 1, 1], "b": [5, 4, 3]}, dtype="Int64", index=[2, 1, 0] + )[::-1] + assert not reversed["b"].array._mask.flags["C_CONTIGUOUS"] + args = get_groupby_method_args(groupby_func, df) + + gb_reversed = reversed.groupby("a") + result = getattr(gb_reversed, groupby_func)(*args) + gb = df.groupby("a") + expected = getattr(gb, groupby_func)(*args) + tm.assert_equal(result, expected) diff --git a/pandas/tests/groupby/test_api.py b/pandas/tests/groupby/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..215e627abb018f04bff337056dee9c4a60312d62 --- /dev/null +++ b/pandas/tests/groupby/test_api.py @@ -0,0 +1,274 @@ +""" +Tests of the groupby API, including internal consistency and with other pandas objects. + +Tests in this file should only check the existence, names, and arguments of groupby +methods. It should not test the results of any groupby operation. +""" + +import inspect + +import pytest + +from pandas import ( + DataFrame, + Series, +) +from pandas.core.groupby.base import ( + groupby_other_methods, + reduction_kernels, + transformation_kernels, +) +from pandas.core.groupby.generic import ( + DataFrameGroupBy, + SeriesGroupBy, +) + + +def test_tab_completion(multiindex_dataframe_random_data): + grp = multiindex_dataframe_random_data.groupby(level="second") + results = {v for v in dir(grp) if not v.startswith("_")} + expected = { + "A", + "B", + "C", + "agg", + "aggregate", + "apply", + "boxplot", + "filter", + "first", + "get_group", + "groups", + "hist", + "indices", + "last", + "max", + "mean", + "median", + "min", + "ngroups", + "nth", + "ohlc", + "plot", + "prod", + "size", + "std", + "sum", + "transform", + "var", + "sem", + "count", + "nunique", + "head", + "describe", + "cummax", + "quantile", + "rank", + "cumprod", + "tail", + "resample", + "cummin", + "cumsum", + "cumcount", + "ngroup", + "all", + "shift", + "skew", + "kurt", + "take", + "pct_change", + "any", + "corr", + "corrwith", + "cov", + "ndim", + "diff", + "idxmax", + "idxmin", + "ffill", + "bfill", + "rolling", + "expanding", + "pipe", + "sample", + "ewm", + "value_counts", + } + assert results == expected + + +def test_all_methods_categorized(multiindex_dataframe_random_data): + grp = multiindex_dataframe_random_data.groupby( + multiindex_dataframe_random_data.iloc[:, 0] + ) + names = {_ for _ in dir(grp) if not _.startswith("_")} - set( + multiindex_dataframe_random_data.columns + ) + new_names = set(names) + new_names -= reduction_kernels + new_names -= transformation_kernels + new_names -= groupby_other_methods + + assert not reduction_kernels & transformation_kernels + assert not reduction_kernels & groupby_other_methods + assert not transformation_kernels & groupby_other_methods + + # new public method? + if new_names: + msg = f""" +There are uncategorized methods defined on the Grouper class: +{new_names}. + +Was a new method recently added? + +Every public method On Grouper must appear in exactly one the +following three lists defined in pandas.core.groupby.base: +- `reduction_kernels` +- `transformation_kernels` +- `groupby_other_methods` +see the comments in pandas/core/groupby/base.py for guidance on +how to fix this test. + """ + raise AssertionError(msg) + + # removed a public method? + all_categorized = reduction_kernels | transformation_kernels | groupby_other_methods + if names != all_categorized: + msg = f""" +Some methods which are supposed to be on the Grouper class +are missing: +{all_categorized - names}. + +They're still defined in one of the lists that live in pandas/core/groupby/base.py. +If you removed a method, you should update them +""" + raise AssertionError(msg) + + +def test_frame_consistency(groupby_func): + # GH#48028 + if groupby_func in ("first", "last"): + msg = "first and last don't exist for DataFrame anymore" + pytest.skip(reason=msg) + + if groupby_func in ("cumcount", "ngroup"): + assert not hasattr(DataFrame, groupby_func) + return + + frame_method = getattr(DataFrame, groupby_func) + gb_method = getattr(DataFrameGroupBy, groupby_func) + result = set(inspect.signature(gb_method).parameters) + if groupby_func == "size": + # "size" is a method on GroupBy but property on DataFrame: + expected = {"self"} + else: + expected = set(inspect.signature(frame_method).parameters) + + # Exclude certain arguments from result and expected depending on the operation + # Some of these may be purposeful inconsistencies between the APIs + exclude_expected, exclude_result = set(), set() + if groupby_func in ("any", "all"): + exclude_expected = {"kwargs", "bool_only", "axis"} + elif groupby_func in ("count",): + exclude_expected = {"numeric_only", "axis"} + elif groupby_func in ("nunique",): + exclude_expected = {"axis"} + elif groupby_func in ("max", "min"): + exclude_expected = {"axis", "kwargs"} + exclude_result = {"min_count", "engine", "engine_kwargs"} + elif groupby_func in ("sum", "mean", "std", "var"): + exclude_expected = {"axis", "kwargs"} + exclude_result = {"engine", "engine_kwargs"} + elif groupby_func in ("median", "prod", "sem"): + exclude_expected = {"axis", "kwargs"} + elif groupby_func in ("bfill", "ffill"): + exclude_expected = {"inplace", "axis", "limit_area"} + elif groupby_func in ("cummax", "cummin"): + exclude_expected = {"axis", "skipna", "args"} + elif groupby_func in ("cumprod", "cumsum"): + exclude_expected = {"axis", "skipna"} + elif groupby_func in ("pct_change",): + exclude_expected = {"kwargs"} + elif groupby_func in ("rank",): + exclude_expected = {"numeric_only"} + elif groupby_func in ("quantile",): + exclude_expected = {"method", "axis"} + elif groupby_func in ["corrwith"]: + exclude_expected = {"min_periods"} + if groupby_func not in ["pct_change", "size"]: + exclude_expected |= {"axis"} + + # Ensure excluded arguments are actually in the signatures + assert result & exclude_result == exclude_result + assert expected & exclude_expected == exclude_expected + + result -= exclude_result + expected -= exclude_expected + assert result == expected + + +def test_series_consistency(request, groupby_func): + # GH#48028 + if groupby_func in ("first", "last"): + msg = "first and last don't exist for Series anymore" + pytest.skip(msg) + + if groupby_func in ("cumcount", "corrwith", "ngroup"): + assert not hasattr(Series, groupby_func) + return + + series_method = getattr(Series, groupby_func) + gb_method = getattr(SeriesGroupBy, groupby_func) + result = set(inspect.signature(gb_method).parameters) + if groupby_func == "size": + # "size" is a method on GroupBy but property on Series + expected = {"self"} + else: + expected = set(inspect.signature(series_method).parameters) + + # Exclude certain arguments from result and expected depending on the operation + # Some of these may be purposeful inconsistencies between the APIs + exclude_expected, exclude_result = set(), set() + if groupby_func in ("any", "all"): + exclude_expected = {"kwargs", "bool_only", "axis"} + elif groupby_func in ("max", "min"): + exclude_expected = {"axis", "kwargs"} + exclude_result = {"min_count", "engine", "engine_kwargs"} + elif groupby_func in ("sum", "mean", "std", "var"): + exclude_expected = {"axis", "kwargs"} + exclude_result = {"engine", "engine_kwargs"} + elif groupby_func in ("median", "prod", "sem"): + exclude_expected = {"axis", "kwargs"} + elif groupby_func in ("bfill", "ffill"): + exclude_expected = {"inplace", "axis", "limit_area"} + elif groupby_func in ("cummax", "cummin"): + exclude_expected = {"skipna", "args"} + exclude_result = {"numeric_only"} + elif groupby_func in ("cumprod", "cumsum"): + exclude_expected = {"skipna"} + exclude_result = {"numeric_only"} + elif groupby_func in ("pct_change",): + exclude_expected = {"kwargs"} + elif groupby_func in ("rank",): + exclude_expected = {"numeric_only"} + elif groupby_func in ("idxmin", "idxmax"): + exclude_expected = {"args", "kwargs"} + elif groupby_func in ("quantile",): + exclude_result = {"numeric_only"} + if groupby_func not in [ + "diff", + "pct_change", + "count", + "nunique", + "quantile", + "size", + ]: + exclude_expected |= {"axis"} + + # Ensure excluded arguments are actually in the signatures + assert result & exclude_result == exclude_result + assert expected & exclude_expected == exclude_expected + + result -= exclude_result + expected -= exclude_expected + assert result == expected diff --git a/pandas/tests/groupby/test_apply.py b/pandas/tests/groupby/test_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c9168d6003e4ea964b1b41c7059e4afe72c8cb --- /dev/null +++ b/pandas/tests/groupby/test_apply.py @@ -0,0 +1,1543 @@ +from datetime import ( + date, + datetime, +) + +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + bdate_range, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +def test_apply_func_that_appends_group_to_list_without_copy(): + # GH: 17718 + + df = DataFrame(1, index=list(range(10)) * 10, columns=[0]).reset_index() + groups = [] + + def store(group): + groups.append(group) + + df.groupby("index").apply(store) + expected_value = DataFrame({0: [1] * 10}, index=pd.RangeIndex(0, 100, 10)) + expected_value.columns = expected_value.columns.astype(object) + + tm.assert_frame_equal(groups[0], expected_value) + + +def test_apply_index_date(using_infer_string): + # GH 5788 + ts = [ + "2011-05-16 00:00", + "2011-05-16 01:00", + "2011-05-16 02:00", + "2011-05-16 03:00", + "2011-05-17 02:00", + "2011-05-17 03:00", + "2011-05-17 04:00", + "2011-05-17 05:00", + "2011-05-18 02:00", + "2011-05-18 03:00", + "2011-05-18 04:00", + "2011-05-18 05:00", + ] + df = DataFrame( + { + "value": [ + 1.40893, + 1.40760, + 1.40750, + 1.40649, + 1.40893, + 1.40760, + 1.40750, + 1.40649, + 1.40893, + 1.40760, + 1.40750, + 1.40649, + ], + }, + index=Index(pd.to_datetime(ts), name="date_time"), + ) + expected = df.groupby(df.index.date).idxmax() + result = df.groupby(df.index.date).apply(lambda x: x.idxmax()) + tm.assert_frame_equal(result, expected) + + +def test_apply_index_date_object(): + # GH 5789 + # don't auto coerce dates + ts = [ + "2011-05-16 00:00", + "2011-05-16 01:00", + "2011-05-16 02:00", + "2011-05-16 03:00", + "2011-05-17 02:00", + "2011-05-17 03:00", + "2011-05-17 04:00", + "2011-05-17 05:00", + "2011-05-18 02:00", + "2011-05-18 03:00", + "2011-05-18 04:00", + "2011-05-18 05:00", + ] + df = DataFrame([row.split() for row in ts], columns=["date", "time"]) + df["value"] = [ + 1.40893, + 1.40760, + 1.40750, + 1.40649, + 1.40893, + 1.40760, + 1.40750, + 1.40649, + 1.40893, + 1.40760, + 1.40750, + 1.40649, + ] + exp_idx = Index(["2011-05-16", "2011-05-17", "2011-05-18"], name="date") + expected = Series(["00:00", "02:00", "02:00"], index=exp_idx) + result = df.groupby("date").apply(lambda x: x["time"][x["value"].idxmax()]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "df, group_names", + [ + (DataFrame({"a": [1, 1, 1, 2, 3], "b": ["a", "a", "a", "b", "c"]}), [1, 2, 3]), + (DataFrame({"a": [0, 0, 1, 1], "b": [0, 1, 0, 1]}), [0, 1]), + (DataFrame({"a": [1]}), [1]), + (DataFrame({"a": [1, 1, 1, 2, 2, 1, 1, 2], "b": range(8)}), [1, 2]), + (DataFrame({"a": [1, 2, 3, 1, 2, 3], "two": [4, 5, 6, 7, 8, 9]}), [1, 2, 3]), + ( + DataFrame( + { + "a": list("aaabbbcccc"), + "B": [3, 4, 3, 6, 5, 2, 1, 9, 5, 4], + "C": [4, 0, 2, 2, 2, 7, 8, 6, 2, 8], + } + ), + ["a", "b", "c"], + ), + (DataFrame([[1, 2, 3], [2, 2, 3]], columns=["a", "b", "c"]), [1, 2]), + ], + ids=[ + "GH2936", + "GH7739 & GH10519", + "GH10519", + "GH2656", + "GH12155", + "GH20084", + "GH21417", + ], +) +def test_group_apply_once_per_group(df, group_names): + # GH2936, GH7739, GH10519, GH2656, GH12155, GH20084, GH21417 + + # This test should ensure that a function is only evaluated + # once per group. Previously the function has been evaluated twice + # on the first group to check if the Cython index slider is safe to use + # This test ensures that the side effect (append to list) is only triggered + # once per group + + names = [] + # cannot parameterize over the functions since they need external + # `names` to detect side effects + + def f_copy(group): + # this takes the fast apply path + names.append(group.name) + return group.copy() + + def f_nocopy(group): + # this takes the slow apply path + names.append(group.name) + return group + + def f_scalar(group): + # GH7739, GH2656 + names.append(group.name) + return 0 + + def f_none(group): + # GH10519, GH12155, GH21417 + names.append(group.name) + + def f_constant_df(group): + # GH2936, GH20084 + names.append(group.name) + return DataFrame({"a": [1], "b": [1]}) + + for func in [f_copy, f_nocopy, f_scalar, f_none, f_constant_df]: + del names[:] + + df.groupby("a").apply(func) + assert names == group_names + + +def test_group_apply_once_per_group2(capsys): + # GH: 31111 + # groupby-apply need to execute len(set(group_by_columns)) times + + expected = 2 # Number of times `apply` should call a function for the current test + + df = DataFrame( + { + "group_by_column": [0, 0, 0, 0, 1, 1, 1, 1], + "test_column": ["0", "2", "4", "6", "8", "10", "12", "14"], + }, + index=["0", "2", "4", "6", "8", "10", "12", "14"], + ) + + df.groupby("group_by_column", group_keys=False).apply( + lambda df: print("function_called") + ) + + result = capsys.readouterr().out.count("function_called") + # If `groupby` behaves unexpectedly, this test will break + assert result == expected + + +def test_apply_fast_slow_identical(): + # GH 31613 + + df = DataFrame({"A": [0, 0, 1], "b": range(3)}) + + # For simple index structures we check for fast/slow apply using + # an identity check on in/output + def slow(group): + return group + + def fast(group): + return group.copy() + + fast_df = df.groupby("A", group_keys=False).apply(fast) + slow_df = df.groupby("A", group_keys=False).apply(slow) + tm.assert_frame_equal(fast_df, slow_df) + + +def test_apply_fast_slow_identical_index(): + # GH#44803 + df = DataFrame( + { + "name": ["Alice", "Bob", "Carl"], + "age": [20, 21, 20], + } + ).set_index("name") + + grp_by_same_value = df.groupby(["age"], group_keys=False).apply(lambda group: group) + grp_by_copy = df.groupby(["age"], group_keys=False).apply( + lambda group: group.copy() + ) + tm.assert_frame_equal(grp_by_same_value, grp_by_copy) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x, + lambda x: x[:], + lambda x: x.copy(deep=False), + lambda x: x.copy(deep=True), + ], +) +def test_groupby_apply_identity_maybecopy_index_identical(func): + # GH 14927 + # Whether the function returns a copy of the input data or not should not + # have an impact on the index structure of the result since this is not + # transparent to the user + + df = DataFrame({"g": [1, 2, 2, 2], "a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) + result = df.groupby("g", group_keys=False).apply(func) + tm.assert_frame_equal(result, df[["a", "b"]]) + + +def test_apply_with_mixed_dtype(): + # GH3480, apply with mixed dtype on axis=1 breaks in 0.11 + df = DataFrame( + { + "foo1": np.random.default_rng(2).standard_normal(6), + "foo2": ["one", "two", "two", "three", "one", "two"], + } + ) + result = df.apply(lambda x: x, axis=1) + expected = df + tm.assert_frame_equal(result, expected) + + # GH 3610 incorrect dtype conversion with as_index=False + df = DataFrame({"c1": [1, 2, 6, 6, 8]}) + df["c2"] = df.c1 / 2.0 + result1 = df.groupby("c2").mean().reset_index() + result2 = df.groupby("c2", as_index=False).mean() + tm.assert_frame_equal(result1, result2) + + +def test_groupby_as_index_apply(as_index): + # GH #4648 and #3417 + df = DataFrame( + { + "item_id": ["b", "b", "a", "c", "a", "b"], + "user_id": [1, 2, 1, 1, 3, 1], + "time": range(6), + } + ) + gb = df.groupby("user_id", as_index=as_index) + + expected = DataFrame( + { + "item_id": ["b", "b", "a", "a"], + "user_id": [1, 2, 1, 3], + "time": [0, 1, 2, 4], + }, + index=[0, 1, 2, 4], + ) + result = gb.head(2) + tm.assert_frame_equal(result, expected) + + # apply doesn't maintain the original ordering + # changed in GH5610 as the as_index=False returns a MI here + if as_index: + tp = [(1, 0), (1, 2), (2, 1), (3, 4)] + index = MultiIndex.from_tuples(tp, names=["user_id", None]) + else: + index = Index([0, 2, 1, 4]) + expected = DataFrame( + { + "item_id": list("baba"), + "time": [0, 2, 1, 4], + }, + index=index, + ) + result = gb.apply(lambda x: x.head(2)) + tm.assert_frame_equal(result, expected) + + +def test_groupby_as_index_apply_str(): + ind = Index(list("abcde")) + df = DataFrame([[1, 2], [2, 3], [1, 4], [1, 5], [2, 6]], index=ind) + res = df.groupby(0, as_index=False, group_keys=False).apply(lambda x: x).index + tm.assert_index_equal(res, ind) + + +def test_apply_concat_preserve_names(three_group): + grouped = three_group.groupby(["A", "B"]) + + def desc(group): + result = group.describe() + result.index.name = "stat" + return result + + def desc2(group): + result = group.describe() + result.index.name = "stat" + result = result[: len(group)] + # weirdo + return result + + def desc3(group): + result = group.describe() + + # names are different + result.index.name = f"stat_{len(group):d}" + + result = result[: len(group)] + # weirdo + return result + + result = grouped.apply(desc) + assert result.index.names == ("A", "B", "stat") + + result2 = grouped.apply(desc2) + assert result2.index.names == ("A", "B", "stat") + + result3 = grouped.apply(desc3) + assert result3.index.names == ("A", "B", None) + + +def test_apply_series_to_frame(): + def f(piece): + with np.errstate(invalid="ignore"): + logged = np.log(piece) + return DataFrame( + {"value": piece, "demeaned": piece - piece.mean(), "logged": logged} + ) + + dr = bdate_range("1/1/2000", periods=10) + ts = Series(np.random.default_rng(2).standard_normal(10), index=dr) + + grouped = ts.groupby(lambda x: x.month, group_keys=False) + result = grouped.apply(f) + + assert isinstance(result, DataFrame) + assert not hasattr(result, "name") # GH49907 + tm.assert_index_equal(result.index, ts.index) + + +def test_apply_series_yield_constant(df): + result = df.groupby(["A", "B"])["C"].apply(len) + assert result.index.names[:2] == ("A", "B") + + +def test_apply_frame_yield_constant(df): + # GH13568 + result = df.groupby(["A", "B"]).apply(len) + assert isinstance(result, Series) + assert result.name is None + + result = df.groupby(["A", "B"])[["C", "D"]].apply(len) + assert isinstance(result, Series) + assert result.name is None + + +def test_apply_frame_to_series(df): + grouped = df.groupby(["A", "B"]) + result = grouped.apply(len) + expected = grouped.count()["C"] + tm.assert_index_equal(result.index, expected.index) + tm.assert_numpy_array_equal(result.values, expected.values) + + +def test_apply_frame_not_as_index_column_name(df): + # GH 35964 - path within _wrap_applied_output not hit by a test + grouped = df.groupby(["A", "B"], as_index=False) + result = grouped.apply(len) + expected = grouped.count().rename(columns={"C": np.nan}).drop(columns="D") + # TODO(GH#34306): Use assert_frame_equal when column name is not np.nan + tm.assert_index_equal(result.index, expected.index) + tm.assert_numpy_array_equal(result.values, expected.values) + + +def test_apply_frame_concat_series(): + def trans(group): + return group.groupby("B")["C"].sum().sort_values().iloc[:2] + + def trans2(group): + grouped = group.groupby(df.reindex(group.index)["B"]) + return grouped.sum().sort_values().iloc[:2] + + df = DataFrame( + { + "A": np.random.default_rng(2).integers(0, 5, 1000), + "B": np.random.default_rng(2).integers(0, 5, 1000), + "C": np.random.default_rng(2).standard_normal(1000), + } + ) + + result = df.groupby("A").apply(trans) + exp = df.groupby("A")["C"].apply(trans2) + tm.assert_series_equal(result, exp, check_names=False) + assert result.name == "C" + + +def test_apply_transform(ts): + grouped = ts.groupby(lambda x: x.month, group_keys=False) + result = grouped.apply(lambda x: x * 2) + expected = grouped.transform(lambda x: x * 2) + tm.assert_series_equal(result, expected) + + +def test_apply_multikey_corner(tsframe): + grouped = tsframe.groupby([lambda x: x.year, lambda x: x.month]) + + def f(group): + return group.sort_values("A")[-5:] + + result = grouped.apply(f) + for key, group in grouped: + tm.assert_frame_equal(result.loc[key], f(group)) + + +@pytest.mark.parametrize("group_keys", [True, False]) +def test_apply_chunk_view(group_keys): + # Low level tinkering could be unsafe, make sure not + df = DataFrame({"key": [1, 1, 1, 2, 2, 2, 3, 3, 3], "value": range(9)}) + + result = df.groupby("key", group_keys=group_keys).apply(lambda x: x.iloc[:2]) + expected = df[["value"]].take([0, 1, 3, 4, 6, 7]) + if group_keys: + expected.index = MultiIndex.from_arrays( + [[1, 1, 2, 2, 3, 3], expected.index], names=["key", None] + ) + + tm.assert_frame_equal(result, expected) + + +def test_apply_no_name_column_conflict(): + df = DataFrame( + { + "name": [1, 1, 1, 1, 1, 1, 2, 2, 2, 2], + "name2": [0, 0, 0, 1, 1, 1, 0, 0, 1, 1], + "value": range(9, -1, -1), + } + ) + + # it works! #2605 + grouped = df.groupby(["name", "name2"]) + grouped.apply(lambda x: x.sort_values("value", inplace=True)) + + +def test_apply_typecast_fail(): + df = DataFrame( + { + "d": [1.0, 1.0, 1.0, 2.0, 2.0, 2.0], + "c": np.tile(["a", "b", "c"], 2), + "v": np.arange(1.0, 7.0), + } + ) + + def f(group): + v = group["v"] + group["v2"] = (v - v.min()) / (v.max() - v.min()) + return group + + result = df.groupby("d", group_keys=False).apply(f) + + expected = df[["c", "v"]] + expected["v2"] = np.tile([0.0, 0.5, 1], 2) + + tm.assert_frame_equal(result, expected) + + +def test_apply_multiindex_fail(): + index = MultiIndex.from_arrays([[0, 0, 0, 1, 1, 1], [1, 2, 3, 1, 2, 3]]) + df = DataFrame( + { + "d": [1.0, 1.0, 1.0, 2.0, 2.0, 2.0], + "c": np.tile(["a", "b", "c"], 2), + "v": np.arange(1.0, 7.0), + }, + index=index, + ) + + def f(group): + v = group["v"] + group["v2"] = (v - v.min()) / (v.max() - v.min()) + return group + + result = df.groupby("d", group_keys=False).apply(f) + + expected = df[["c", "v"]] + expected["v2"] = np.tile([0.0, 0.5, 1], 2) + tm.assert_frame_equal(result, expected) + + +def test_apply_corner(tsframe): + result = tsframe.groupby(lambda x: x.year, group_keys=False).apply(lambda x: x * 2) + expected = tsframe * 2 + tm.assert_frame_equal(result, expected) + + +def test_apply_without_copy(): + # GH 5545 + # returning a non-copy in an applied function fails + + data = DataFrame( + { + "id_field": [100, 100, 200, 300], + "category": ["a", "b", "c", "c"], + "value": [1, 2, 3, 4], + } + ) + + def filt1(x): + if x.shape[0] == 1: + return x.copy() + else: + return x[x.category == "c"] + + def filt2(x): + if x.shape[0] == 1: + return x + else: + return x[x.category == "c"] + + expected = data.groupby("id_field").apply(filt1) + result = data.groupby("id_field").apply(filt2) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("test_series", [True, False]) +def test_apply_with_duplicated_non_sorted_axis(test_series): + # GH 30667 + df = DataFrame( + [["x", "p"], ["x", "p"], ["x", "o"]], columns=["X", "Y"], index=[1, 2, 2] + ) + if test_series: + ser = df.set_index("Y")["X"] + result = ser.groupby(level=0, group_keys=False).apply(lambda x: x) + expected = ser + tm.assert_series_equal(result, expected) + else: + result = df.groupby("Y", group_keys=False).apply(lambda x: x) + expected = df[["X"]] + tm.assert_frame_equal(result, expected) + + +def test_apply_reindex_values(): + # GH: 26209 + # reindexing from a single column of a groupby object with duplicate indices caused + # a ValueError (cannot reindex from duplicate axis) in 0.24.2, the problem was + # solved in #30679 + values = [1, 2, 3, 4] + indices = [1, 1, 2, 2] + df = DataFrame({"group": ["Group1", "Group2"] * 2, "value": values}, index=indices) + expected = Series(values, index=indices, name="value") + + def reindex_helper(x): + return x.reindex(np.arange(x.index.min(), x.index.max() + 1)) + + # the following group by raised a ValueError + result = df.groupby("group", group_keys=False).value.apply(reindex_helper) + tm.assert_series_equal(expected, result) + + +def test_apply_corner_cases(): + # #535, can't use sliding iterator + + N = 10 + labels = np.random.default_rng(2).integers(0, 100, size=N) + df = DataFrame( + { + "key": labels, + "value1": np.random.default_rng(2).standard_normal(N), + "value2": ["foo", "bar", "baz", "qux", "a"] * (N // 5), + } + ) + + grouped = df.groupby("key", group_keys=False) + + def f(g): + g["value3"] = g["value1"] * 2 + return g + + result = grouped.apply(f) + assert "value3" in result + + +def test_apply_numeric_coercion_when_datetime(): + # In the past, group-by/apply operations have been over-eager + # in converting dtypes to numeric, in the presence of datetime + # columns. Various GH issues were filed, the reproductions + # for which are here. + + # GH 15670 + df = DataFrame( + {"Number": [1, 2], "Date": ["2017-03-02"] * 2, "Str": ["foo", "inf"]} + ) + expected = df.groupby(["Number"]).apply(lambda x: x.iloc[0]) + df.Date = pd.to_datetime(df.Date) + result = df.groupby(["Number"]).apply(lambda x: x.iloc[0]) + tm.assert_series_equal(result["Str"], expected["Str"]) + + +def test_apply_numeric_coercion_when_datetime_getitem(): + # GH 15421 + df = DataFrame( + {"A": [10, 20, 30], "B": ["foo", "3", "4"], "T": [pd.Timestamp("12:31:22")] * 3} + ) + + def get_B(g): + return g.iloc[0][["B"]] + + result = df.groupby("A").apply(get_B)["B"] + expected = df.B + expected.index = df.A + tm.assert_series_equal(result, expected) + + +def test_apply_numeric_coercion_when_datetime_with_nat(): + # GH 14423 + def predictions(tool): + out = Series(index=["p1", "p2", "useTime"], dtype=object) + if "step1" in list(tool.State): + out["p1"] = str(tool[tool.State == "step1"].Machine.values[0]) + if "step2" in list(tool.State): + out["p2"] = str(tool[tool.State == "step2"].Machine.values[0]) + out["useTime"] = str(tool[tool.State == "step2"].oTime.values[0]) + return out + + df1 = DataFrame( + { + "Key": ["B", "B", "A", "A"], + "State": ["step1", "step2", "step1", "step2"], + "oTime": ["", "2016-09-19 05:24:33", "", "2016-09-19 23:59:04"], + "Machine": ["23", "36L", "36R", "36R"], + } + ) + df2 = df1.copy() + df2.oTime = pd.to_datetime(df2.oTime) + expected = df1.groupby("Key").apply(predictions).p1 + result = df2.groupby("Key").apply(predictions).p1 + tm.assert_series_equal(expected, result) + + +def test_apply_aggregating_timedelta_and_datetime(): + # Regression test for GH 15562 + # The following groupby caused ValueErrors and IndexErrors pre 0.20.0 + + df = DataFrame( + { + "clientid": ["A", "B", "C"], + "datetime": [np.datetime64("2017-02-01 00:00:00")] * 3, + } + ) + df["time_delta_zero"] = df.datetime - df.datetime + result = df.groupby("clientid").apply( + lambda ddf: Series( + {"clientid_age": ddf.time_delta_zero.min(), "date": ddf.datetime.min()} + ) + ) + expected = DataFrame( + { + "clientid": ["A", "B", "C"], + "clientid_age": [np.timedelta64(0, "D")] * 3, + "date": [np.datetime64("2017-02-01 00:00:00")] * 3, + } + ).set_index("clientid") + + tm.assert_frame_equal(result, expected) + + +def test_apply_groupby_datetimeindex(): + # GH 26182 + # groupby apply failed on dataframe with DatetimeIndex + + data = [["A", 10], ["B", 20], ["B", 30], ["C", 40], ["C", 50]] + df = DataFrame( + data, columns=["Name", "Value"], index=pd.date_range("2020-09-01", "2020-09-05") + ) + + result = df.groupby("Name").sum() + + expected = DataFrame({"Name": ["A", "B", "C"], "Value": [10, 50, 90]}) + expected.set_index("Name", inplace=True) + + tm.assert_frame_equal(result, expected) + + +def test_time_field_bug(): + # Test a fix for the following error related to GH issue 11324 When + # non-key fields in a group-by dataframe contained time-based fields + # that were not returned by the apply function, an exception would be + # raised. + + df = DataFrame({"a": 1, "b": [datetime.now() for nn in range(10)]}) + + def func_with_no_date(batch): + return Series({"c": 2}) + + def func_with_date(batch): + return Series({"b": datetime(2015, 1, 1), "c": 2}) + + dfg_no_conversion = df.groupby(by=["a"]).apply(func_with_no_date) + dfg_no_conversion_expected = DataFrame({"c": 2}, index=[1]) + dfg_no_conversion_expected.index.name = "a" + + dfg_conversion = df.groupby(by=["a"]).apply(func_with_date) + dfg_conversion_expected = DataFrame( + {"b": pd.Timestamp(2015, 1, 1), "c": 2}, index=[1] + ) + dfg_conversion_expected.index.name = "a" + + tm.assert_frame_equal(dfg_no_conversion, dfg_no_conversion_expected) + tm.assert_frame_equal(dfg_conversion, dfg_conversion_expected) + + +def test_gb_apply_list_of_unequal_len_arrays(): + # GH1738 + df = DataFrame( + { + "group1": ["a", "a", "a", "b", "b", "b", "a", "a", "a", "b", "b", "b"], + "group2": ["c", "c", "d", "d", "d", "e", "c", "c", "d", "d", "d", "e"], + "weight": [1.1, 2, 3, 4, 5, 6, 2, 4, 6, 8, 1, 2], + "value": [7.1, 8, 9, 10, 11, 12, 8, 7, 6, 5, 4, 3], + } + ) + df = df.set_index(["group1", "group2"]) + df_grouped = df.groupby(level=["group1", "group2"], sort=True) + + def noddy(value, weight): + out = np.array(value * weight).repeat(3) + return out + + # the kernel function returns arrays of unequal length + # pandas sniffs the first one, sees it's an array and not + # a list, and assumed the rest are of equal length + # and so tries a vstack + + # don't die + df_grouped.apply(lambda x: noddy(x.value, x.weight)) + + +def test_groupby_apply_all_none(): + # Tests to make sure no errors if apply function returns all None + # values. Issue 9684. + test_df = DataFrame({"groups": [0, 0, 1, 1], "random_vars": [8, 7, 4, 5]}) + + def test_func(x): + pass + + result = test_df.groupby("groups").apply(test_func) + expected = DataFrame(columns=["random_vars"], dtype="int64") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "in_data, out_idx, out_data", + [ + [ + {"groups": [1, 1, 1, 2], "vars": [0, 1, 2, 3]}, + [[1, 1], [0, 2]], + {"vars": [0, 2]}, + ], + [ + {"groups": [1, 2, 2, 2], "vars": [0, 1, 2, 3]}, + [[2, 2], [1, 3]], + {"vars": [1, 3]}, + ], + ], +) +def test_groupby_apply_none_first(in_data, out_idx, out_data): + # GH 12824. Tests if apply returns None first. + test_df1 = DataFrame(in_data) + + def test_func(x): + if x.shape[0] < 2: + return None + return x.iloc[[0, -1]] + + result1 = test_df1.groupby("groups").apply(test_func) + index1 = MultiIndex.from_arrays(out_idx, names=["groups", None]) + expected1 = DataFrame(out_data, index=index1) + tm.assert_frame_equal(result1, expected1) + + +def test_groupby_apply_return_empty_chunk(): + # GH 22221: apply filter which returns some empty groups + df = DataFrame({"value": [0, 1], "group": ["filled", "empty"]}) + groups = df.groupby("group") + result = groups.apply(lambda group: group[group.value != 1]["value"]) + expected = Series( + [0], + name="value", + index=MultiIndex.from_product( + [["empty", "filled"], [0]], names=["group", None] + ).drop("empty"), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("meth", ["apply", "transform"]) +def test_apply_with_mixed_types(meth): + # gh-20949 + df = DataFrame({"A": "a a b".split(), "B": [1, 2, 3], "C": [4, 6, 5]}) + g = df.groupby("A", group_keys=False) + + result = getattr(g, meth)(lambda x: x / x.sum()) + expected = DataFrame({"B": [1 / 3.0, 2 / 3.0, 1], "C": [0.4, 0.6, 1.0]}) + tm.assert_frame_equal(result, expected) + + +def test_func_returns_object(): + # GH 28652 + df = DataFrame({"a": [1, 2]}, index=Index([1, 2])) + result = df.groupby("a").apply(lambda g: g.index) + expected = Series([Index([1]), Index([2])], index=Index([1, 2], name="a")) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "group_column_dtlike", + [datetime.today(), datetime.today().date(), datetime.today().time()], +) +def test_apply_datetime_issue(group_column_dtlike): + # GH-28247 + # groupby-apply throws an error if one of the columns in the DataFrame + # is a datetime object and the column labels are different from + # standard int values in range(len(num_columns)) + + df = DataFrame({"a": ["foo"], "b": [group_column_dtlike]}) + result = df.groupby("a").apply(lambda x: Series(["spam"], index=[42])) + + expected = DataFrame(["spam"], Index(["foo"], dtype="str", name="a"), columns=[42]) + tm.assert_frame_equal(result, expected) + + +def test_apply_series_return_dataframe_groups(): + # GH 10078 + tdf = DataFrame( + { + "day": { + 0: pd.Timestamp("2015-02-24 00:00:00"), + 1: pd.Timestamp("2015-02-24 00:00:00"), + 2: pd.Timestamp("2015-02-24 00:00:00"), + 3: pd.Timestamp("2015-02-24 00:00:00"), + 4: pd.Timestamp("2015-02-24 00:00:00"), + }, + "userAgent": { + 0: "some UA string", + 1: "some UA string", + 2: "some UA string", + 3: "another UA string", + 4: "some UA string", + }, + "userId": { + 0: "17661101", + 1: "17661101", + 2: "17661101", + 3: "17661101", + 4: "17661101", + }, + } + ) + + def most_common_values(df): + return Series({c: s.value_counts().index[0] for c, s in df.items()}) + + result = tdf.groupby("day").apply(most_common_values)["userId"] + expected = Series( + ["17661101"], index=pd.DatetimeIndex(["2015-02-24"], name="day"), name="userId" + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("category", [False, True]) +def test_apply_multi_level_name(category): + # https://github.com/pandas-dev/pandas/issues/31068 + b = [1, 2] * 5 + if category: + b = pd.Categorical(b, categories=[1, 2, 3]) + expected_index = pd.CategoricalIndex([1, 2, 3], categories=[1, 2, 3], name="B") + expected_values = [20, 25, 0] + else: + expected_index = Index([1, 2], name="B") + expected_values = [20, 25] + expected = DataFrame( + {"C": expected_values, "D": expected_values}, index=expected_index + ) + + df = DataFrame( + {"A": np.arange(10), "B": b, "C": list(range(10)), "D": list(range(10))} + ).set_index(["A", "B"]) + result = df.groupby("B", observed=False).apply(lambda x: x.sum()) + tm.assert_frame_equal(result, expected) + assert df.index.names == ["A", "B"] + + +def test_groupby_apply_datetime_result_dtypes(using_infer_string): + # GH 14849 + data = DataFrame.from_records( + [ + (pd.Timestamp(2016, 1, 1), "red", "dark", 1, "8"), + (pd.Timestamp(2015, 1, 1), "green", "stormy", 2, "9"), + (pd.Timestamp(2014, 1, 1), "blue", "bright", 3, "10"), + (pd.Timestamp(2013, 1, 1), "blue", "calm", 4, "potato"), + ], + columns=["observation", "color", "mood", "intensity", "score"], + ) + result = data.groupby("color").apply(lambda g: g.iloc[0]).dtypes + dtype = pd.StringDtype(na_value=np.nan) if using_infer_string else object + expected = Series( + [np.dtype("datetime64[us]"), dtype, np.int64, dtype], + index=["observation", "mood", "intensity", "score"], + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + pd.CategoricalIndex(list("abc")), + pd.interval_range(0, 3), + pd.period_range("2020", periods=3, freq="D"), + MultiIndex.from_tuples([("a", 0), ("a", 1), ("b", 0)]), + ], +) +def test_apply_index_has_complex_internals(index): + # GH 31248 + df = DataFrame({"group": [1, 1, 2], "value": [0, 1, 0]}, index=index) + result = df.groupby("group", group_keys=False).apply(lambda x: x) + tm.assert_frame_equal(result, df[["value"]]) + + +@pytest.mark.parametrize( + "function, expected_values", + [ + (lambda x: x.index.to_list(), [[0, 1], [2, 3]]), + (lambda x: set(x.index.to_list()), [{0, 1}, {2, 3}]), + (lambda x: tuple(x.index.to_list()), [(0, 1), (2, 3)]), + ( + lambda x: dict(enumerate(x.index.to_list())), + [{0: 0, 1: 1}, {0: 2, 1: 3}], + ), + ( + lambda x: [{n: i} for (n, i) in enumerate(x.index.to_list())], + [[{0: 0}, {1: 1}], [{0: 2}, {1: 3}]], + ), + ], +) +def test_apply_function_returns_non_pandas_non_scalar(function, expected_values): + # GH 31441 + df = DataFrame(["A", "A", "B", "B"], columns=["groups"]) + result = df.groupby("groups").apply(function) + expected = Series(expected_values, index=Index(["A", "B"], name="groups")) + tm.assert_series_equal(result, expected) + + +def test_apply_function_returns_numpy_array(): + # GH 31605 + def fct(group): + return group["B"].values.flatten() + + df = DataFrame({"A": ["a", "a", "b", "none"], "B": [1, 2, 3, np.nan]}) + + result = df.groupby("A").apply(fct) + expected = Series( + [[1.0, 2.0], [3.0], [np.nan]], index=Index(["a", "b", "none"], name="A") + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("function", [lambda gr: gr.index, lambda gr: gr.index + 1 - 1]) +def test_apply_function_index_return(function): + # GH: 22541 + df = DataFrame([1, 2, 2, 2, 1, 2, 3, 1, 3, 1], columns=["id"]) + result = df.groupby("id").apply(function) + expected = Series( + [Index([0, 4, 7, 9]), Index([1, 2, 3, 5]), Index([6, 8])], + index=Index([1, 2, 3], name="id"), + ) + tm.assert_series_equal(result, expected) + + +def test_apply_function_with_indexing_return_column(): + # GH#7002, GH#41480, GH#49256 + df = DataFrame( + { + "foo1": ["one", "two", "two", "three", "one", "two"], + "foo2": [1, 2, 4, 4, 5, 6], + } + ) + result = df.groupby("foo1", as_index=False).apply(lambda x: x.mean()) + expected = DataFrame( + { + "foo1": ["one", "three", "two"], + "foo2": [3.0, 4.0, 4.0], + } + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "udf", + [lambda x: x.copy(), lambda x: x.copy().rename(lambda y: y + 1)], +) +@pytest.mark.parametrize("group_keys", [True, False]) +def test_apply_result_type(group_keys, udf): + # https://github.com/pandas-dev/pandas/issues/34809 + # We'd like to control whether the group keys end up in the index + # regardless of whether the UDF happens to be a transform. + df = DataFrame({"A": ["a", "b"], "B": [1, 2]}) + df_result = df.groupby("A", group_keys=group_keys).apply(udf) + series_result = df.B.groupby(df.A, group_keys=group_keys).apply(udf) + + if group_keys: + assert df_result.index.nlevels == 2 + assert series_result.index.nlevels == 2 + else: + assert df_result.index.nlevels == 1 + assert series_result.index.nlevels == 1 + + +def test_result_order_group_keys_false(): + # GH 34998 + # apply result order should not depend on whether index is the same or just equal + df = DataFrame({"A": [2, 1, 2], "B": [1, 2, 3]}) + result = df.groupby("A", group_keys=False).apply(lambda x: x) + expected = df.groupby("A", group_keys=False).apply(lambda x: x.copy()) + tm.assert_frame_equal(result, expected) + + +def test_apply_with_timezones_aware(): + # GH: 27212 + dates = ["2001-01-01"] * 2 + ["2001-01-02"] * 2 + ["2001-01-03"] * 2 + index_no_tz = pd.DatetimeIndex(dates) + index_tz = pd.DatetimeIndex(dates, tz="UTC") + df1 = DataFrame({"x": list(range(2)) * 3, "y": range(6), "t": index_no_tz}) + df2 = DataFrame({"x": list(range(2)) * 3, "y": range(6), "t": index_tz}) + + result1 = df1.groupby("x", group_keys=False).apply(lambda df: df[["y"]].copy()) + result2 = df2.groupby("x", group_keys=False).apply(lambda df: df[["y"]].copy()) + + tm.assert_frame_equal(result1, result2) + + +def test_apply_is_unchanged_when_other_methods_are_called_first(reduction_func): + # GH #34656 + # GH #34271 + df = DataFrame( + { + "a": [99, 99, 99, 88, 88, 88], + "b": [1, 2, 3, 4, 5, 6], + "c": [10, 20, 30, 40, 50, 60], + } + ) + + expected = DataFrame( + {"b": [15, 6], "c": [150, 60]}, + index=Index([88, 99], name="a"), + ) + + # Check output when no other methods are called before .apply() + grp = df.groupby(by="a") + result = grp.apply(np.sum, axis=0) + tm.assert_frame_equal(result, expected) + + # Check output when another method is called before .apply() + grp = df.groupby(by="a") + args = get_groupby_method_args(reduction_func, df) + if reduction_func == "corrwith": + warn = Pandas4Warning + msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + msg = "" + with tm.assert_produces_warning(warn, match=msg): + _ = getattr(grp, reduction_func)(*args) + result = grp.apply(np.sum, axis=0) + tm.assert_frame_equal(result, expected) + + +def test_apply_with_date_in_multiindex_does_not_convert_to_timestamp(): + # GH 29617 + + df = DataFrame( + { + "A": ["a", "a", "a", "b"], + "B": [ + date(2020, 1, 10), + date(2020, 1, 10), + date(2020, 2, 10), + date(2020, 2, 10), + ], + "C": [1, 2, 3, 4], + }, + index=Index([100, 101, 102, 103], name="idx"), + ) + + grp = df.groupby(["A", "B"]) + result = grp.apply(lambda x: x.head(1)) + + expected = df.iloc[[0, 2, 3]] + expected = expected.reset_index() + expected.index = MultiIndex.from_frame(expected[["A", "B", "idx"]]) + expected = expected.drop(columns=["A", "B", "idx"]) + + tm.assert_frame_equal(result, expected) + for val in result.index.levels[1]: + assert type(val) is date + + +def test_apply_dropna_with_indexed_same(dropna): + # GH 38227 + # GH#43205 + df = DataFrame( + { + "col": [1, 2, 3, 4, 5], + "group": ["a", np.nan, np.nan, "b", "b"], + }, + index=list("xxyxz"), + ) + result = df.groupby("group", dropna=dropna, group_keys=False).apply(lambda x: x) + expected = df.dropna()[["col"]] if dropna else df[["col"]].iloc[[0, 3, 1, 2, 4]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "as_index, expected", + [ + [ + False, + DataFrame( + [[1, 1, 1], [2, 2, 1]], columns=Index(["a", "b", None], dtype=object) + ), + ], + [ + True, + Series( + [1, 1], index=MultiIndex.from_tuples([(1, 1), (2, 2)], names=["a", "b"]) + ), + ], + ], +) +def test_apply_as_index_constant_lambda(as_index, expected): + # GH 13217 + df = DataFrame({"a": [1, 1, 2, 2], "b": [1, 1, 2, 2], "c": [1, 1, 1, 1]}) + result = df.groupby(["a", "b"], as_index=as_index).apply(lambda x: 1) + tm.assert_equal(result, expected) + + +def test_sort_index_groups(): + # GH 20420 + df = DataFrame( + {"A": [1, 2, 3, 4, 5], "B": [6, 7, 8, 9, 0], "C": [1, 1, 1, 2, 2]}, + index=range(5), + ) + result = df.groupby("C").apply(lambda x: x.A.sort_index()) + expected = Series( + range(1, 6), + index=MultiIndex.from_tuples( + [(1, 0), (1, 1), (1, 2), (2, 3), (2, 4)], names=["C", None] + ), + name="A", + ) + tm.assert_series_equal(result, expected) + + +def test_positional_slice_groups_datetimelike(): + # GH 21651 + expected = DataFrame( + { + "date": pd.date_range("2010-01-01", freq="12h", periods=5), + "vals": range(5), + "let": list("abcde"), + } + ) + result = expected.groupby( + [expected.let, expected.date.dt.date], group_keys=False + ).apply(lambda x: x.iloc[0:]) + tm.assert_frame_equal(result, expected[["date", "vals"]]) + + +def test_groupby_apply_shape_cache_safety(): + # GH#42702 this fails if we cache_readonly Block.shape + df = DataFrame({"A": ["a", "a", "b"], "B": [1, 2, 3], "C": [4, 6, 5]}) + gb = df.groupby("A") + result = gb[["B", "C"]].apply(lambda x: x.astype(float).max() - x.min()) + + expected = DataFrame( + {"B": [1.0, 0.0], "C": [2.0, 0.0]}, index=Index(["a", "b"], name="A") + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_apply_to_series_name(): + # GH52444 + df = DataFrame.from_dict( + { + "a": ["a", "b", "a", "b"], + "b1": ["aa", "ac", "ac", "ad"], + "b2": ["aa", "aa", "aa", "ac"], + } + ) + grp = df.groupby("a")[["b1", "b2"]] + result = grp.apply(lambda x: x.unstack().value_counts()) + + expected_idx = MultiIndex.from_arrays( + arrays=[["a", "a", "b", "b", "b"], ["aa", "ac", "ac", "ad", "aa"]], + names=["a", None], + ) + expected = Series([3, 1, 2, 1, 1], index=expected_idx, name="count") + tm.assert_series_equal(result, expected) + + +def test_apply_na(dropna): + # GH#28984 + df = DataFrame( + {"grp": [1, 1, 2, 2], "y": [1, 0, 2, 5], "z": [1, 2, np.nan, np.nan]} + ) + dfgrp = df.groupby("grp", dropna=dropna) + result = dfgrp.apply(lambda grp_df: grp_df.nlargest(1, "z")) + expected = dfgrp.apply(lambda x: x.sort_values("z", ascending=False).head(1)) + tm.assert_frame_equal(result, expected) + + +def test_apply_empty_string_nan_coerce_bug(): + # GH#24903 + result = ( + DataFrame( + { + "a": [1, 1, 2, 2], + "b": ["", "", "", ""], + "c": pd.to_datetime([1, 2, 3, 4], unit="s"), + } + ) + .groupby(["a", "b"]) + .apply(lambda df: df.iloc[-1]) + ) + expected = DataFrame( + [[pd.to_datetime(2, unit="s")], [pd.to_datetime(4, unit="s")]], + columns=["c"], + index=MultiIndex.from_tuples([(1, ""), (2, "")], names=["a", "b"]), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("index_values", [[1, 2, 3], [1.0, 2.0, 3.0]]) +def test_apply_index_key_error_bug(index_values): + # GH 44310 + result = DataFrame( + { + "a": ["aa", "a2", "a3"], + "b": [1, 2, 3], + }, + index=Index(index_values), + ) + expected = DataFrame( + { + "b_mean": [2.0, 3.0, 1.0], + }, + index=Index(["a2", "a3", "aa"], name="a"), + ) + result = result.groupby("a").apply( + lambda df: Series([df["b"].mean()], index=["b_mean"]) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "arg,idx", + [ + [ + [ + 1, + 2, + 3, + ], + [ + 0.1, + 0.3, + 0.2, + ], + ], + [ + [ + 1, + 2, + 3, + ], + [ + 0.1, + 0.2, + 0.3, + ], + ], + [ + [ + 1, + 4, + 3, + ], + [ + 0.1, + 0.4, + 0.2, + ], + ], + ], +) +def test_apply_nonmonotonic_float_index(arg, idx): + # GH 34455 + df = DataFrame({"grp": arg, "col": arg}, index=idx) + result = df.groupby("grp", group_keys=False).apply(lambda x: x) + expected = df[["col"]] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("args, kwargs", [([True], {}), ([], {"numeric_only": True})]) +def test_apply_str_with_args(df, args, kwargs): + # GH#46479 + gb = df.groupby("A") + result = gb.apply("sum", *args, **kwargs) + expected = gb.sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("name", ["some_name", None]) +def test_result_name_when_one_group(name): + # GH 46369 + ser = Series([1, 2], name=name) + result = ser.groupby(["a", "a"], group_keys=False).apply(lambda x: x) + expected = Series([1, 2], name=name) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, op", + [ + ("apply", lambda gb: gb.values[-1]), + ("apply", lambda gb: gb["b"].iloc[0]), + ("agg", "skew"), + ("agg", "kurt"), + ("agg", "prod"), + ("agg", "sum"), + ], +) +def test_empty_df(method, op): + # GH 47985 + empty_df = DataFrame({"a": [], "b": []}) + gb = empty_df.groupby("a", group_keys=True) + group = gb.b + + result = getattr(group, method)(op) + expected = Series( + [], name="b", dtype="float64", index=Index([], dtype="float64", name="a") + ) + + tm.assert_series_equal(result, expected) + + +def test_include_groups(): + # GH#7155 + df = DataFrame({"a": [1, 1, 2], "b": [3, 4, 5]}) + gb = df.groupby("a") + with pytest.raises(ValueError, match="include_groups=True is no longer allowed"): + gb.apply(lambda x: x.sum(), include_groups=True) + + +@pytest.mark.parametrize("func, value", [(max, 2), (min, 1), (sum, 3)]) +def test_builtins_apply(func, value): + # GH#8155, GH#53974 + # Builtins act as e.g. sum(group), which sums the column labels of group + df = DataFrame({0: [1, 1, 2], 1: [3, 4, 5], 2: [3, 4, 5]}) + gb = df.groupby(0) + result = gb.apply(func) + + expected = Series([value, value], index=Index([1, 2], name=0)) + tm.assert_series_equal(result, expected) + + +def test_inconsistent_return_type(): + # GH5592 + # inconsistent return type + df = DataFrame( + { + "A": ["Tiger", "Tiger", "Tiger", "Lamb", "Lamb", "Pony", "Pony"], + "B": Series(np.arange(7), dtype="int64"), + "C": pd.date_range("20130101", periods=7), + } + ) + + def f_0(grp): + return grp.iloc[0] + + expected = df.groupby("A").first()[["B"]] + result = df.groupby("A").apply(f_0)[["B"]] + tm.assert_frame_equal(result, expected) + + def f_1(grp): + if grp.name == "Tiger": + return None + return grp.iloc[0] + + result = df.groupby("A").apply(f_1)[["B"]] + e = expected.copy() + e.loc["Tiger"] = np.nan + tm.assert_frame_equal(result, e) + + def f_2(grp): + if grp.name == "Pony": + return None + return grp.iloc[0] + + result = df.groupby("A").apply(f_2)[["B"]] + e = expected.copy() + e.loc["Pony"] = np.nan + tm.assert_frame_equal(result, e) + + # 5592 revisited, with datetimes + def f_3(grp): + if grp.name == "Pony": + return None + return grp.iloc[0] + + result = df.groupby("A").apply(f_3)[["C"]] + e = df.groupby("A").first()[["C"]] + e.loc["Pony"] = pd.NaT + tm.assert_frame_equal(result, e) + + # scalar outputs + def f_4(grp): + if grp.name == "Pony": + return None + return grp.iloc[0].loc["C"] + + result = df.groupby("A").apply(f_4) + e = df.groupby("A").first()["C"].copy() + e.loc["Pony"] = np.nan + e.name = None + tm.assert_series_equal(result, e) + + +def test_nonreducer_nonstransform(): + # GH3380, GH60619 + # Was originally testing mutating in a UDF; now kept as an example + # of using apply with a nonreducer and nontransformer. + df = DataFrame( + { + "cat1": ["a"] * 8 + ["b"] * 6, + "cat2": ["c"] * 2 + + ["d"] * 2 + + ["e"] * 2 + + ["f"] * 2 + + ["c"] * 2 + + ["d"] * 2 + + ["e"] * 2, + "val": np.random.default_rng(2).integers(100, size=14), + } + ) + + def f(x): + x = x.copy() + x["rank"] = x.val.rank(method="min") + return x.groupby("cat2")["rank"].min() + + expected = DataFrame( + { + "cat1": list("aaaabbb"), + "cat2": list("cdefcde"), + "rank": [3.0, 2.0, 5.0, 1.0, 2.0, 4.0, 1.0], + } + ).set_index(["cat1", "cat2"])["rank"] + result = df.groupby("cat1").apply(f) + tm.assert_series_equal(result, expected) + + +def test_groupby_apply_store_copy(): + # GH40673 + rng = np.random.default_rng(seed=42) + + df = DataFrame( + { + "A": rng.normal(10, 12, size=(4,)), + "B": [1, 2, 1, 2], + } + ) + + store = {} + + def addstore(x): + store[len(store)] = x.copy() + + df.groupby("B").apply(addstore) + + expected_out_0 = df.iloc[[0, 2], [0]] + expected_out_1 = df.iloc[[1, 3], [0]] + + tm.assert_frame_equal(store[0], expected_out_0) + tm.assert_frame_equal(store[1], expected_out_1) diff --git a/pandas/tests/groupby/test_bin_groupby.py b/pandas/tests/groupby/test_bin_groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..07d52308e308ad637fad8df5802b8e79985b8127 --- /dev/null +++ b/pandas/tests/groupby/test_bin_groupby.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest + +from pandas._libs import lib + +import pandas as pd +import pandas._testing as tm + + +def assert_block_lengths(x): + assert len(x) == len(x._mgr.blocks[0].mgr_locs) + return 0 + + +def cumsum_max(x): + x.cumsum().max() + return 0 + + +@pytest.mark.parametrize( + "func", + [ + cumsum_max, + assert_block_lengths, + ], +) +def test_mgr_locs_updated(func): + # https://github.com/pandas-dev/pandas/issues/31802 + # Some operations may require creating new blocks, which requires + # valid mgr_locs + df = pd.DataFrame({"A": ["a", "a", "a"], "B": ["a", "b", "b"], "C": [1, 1, 1]}) + result = df.groupby(["A", "B"]).agg(func) + expected = pd.DataFrame( + {"C": [0, 0]}, + index=pd.MultiIndex.from_product([["a"], ["a", "b"]], names=["A", "B"]), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "binner,closed,expected", + [ + ( + [0, 3, 6, 9], + "left", + [2, 5, 6], + ), + ( + [0, 3, 6, 9], + "right", + [3, 6, 6], + ), + ([0, 3, 6], "left", [2, 5]), + ( + [0, 3, 6], + "right", + [3, 6], + ), + ], +) +def test_generate_bins(binner, closed, expected): + values = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64) + result = lib.generate_bins_dt64( + values, np.array(binner, dtype=np.int64), closed=closed + ) + expected = np.array(expected, dtype=np.int64) + tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/groupby/test_categorical.py b/pandas/tests/groupby/test_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..e39052e64e3072ef31edb29bb14648c427e8247c --- /dev/null +++ b/pandas/tests/groupby/test_categorical.py @@ -0,0 +1,2189 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning + +import pandas as pd +from pandas import ( + Categorical, + CategoricalIndex, + DataFrame, + Index, + MultiIndex, + Series, + qcut, +) +import pandas._testing as tm +from pandas.api.typing import SeriesGroupBy +from pandas.tests.groupby import get_groupby_method_args + + +def cartesian_product_for_groupers(result, args, names, fill_value=np.nan): + """Reindex to a cartesian production for the groupers, + preserving the nature (Categorical) of each grouper + """ + + def f(a): + if isinstance(a, (CategoricalIndex, Categorical)): + categories = a.categories + a = Categorical.from_codes( + np.arange(len(categories)), categories=categories, ordered=a.ordered + ) + return a + + index = MultiIndex.from_product(map(f, args), names=names) + if isinstance(fill_value, dict): + # fill_value is a dict mapping column names to fill values + # -> reindex column by column (reindex itself does not support this) + res = {} + for col in result.columns: + res[col] = result[col].reindex(index, fill_value=fill_value[col]) + return DataFrame(res, index=index).sort_index() + + return result.reindex(index, fill_value=fill_value).sort_index() + + +_results_for_groupbys_with_missing_categories = { + # This maps the builtin groupby functions to their expected outputs for + # missing categories when they are called on a categorical grouper with + # observed=False. Some functions are expected to return NaN, some zero. + # These expected values can be used across several tests (i.e. they are + # the same for SeriesGroupBy and DataFrameGroupBy) but they should only be + # hardcoded in one place. + "all": True, + "any": False, + "count": 0, + "corrwith": np.nan, + "first": np.nan, + "idxmax": np.nan, + "idxmin": np.nan, + "last": np.nan, + "max": np.nan, + "mean": np.nan, + "median": np.nan, + "min": np.nan, + "nth": np.nan, + "nunique": 0, + "prod": 1, + "quantile": np.nan, + "sem": np.nan, + "size": 0, + "skew": np.nan, + "kurt": np.nan, + "std": np.nan, + "sum": 0, + "var": np.nan, +} + + +def test_apply_use_categorical_name(df): + cats = qcut(df.C, 4) + + def get_stats(group): + return { + "min": group.min(), + "max": group.max(), + "count": group.count(), + "mean": group.mean(), + } + + result = df.groupby(cats, observed=False).D.apply(get_stats) + assert result.index.names[0] == "C" + + +def test_basic(): + cats = Categorical( + ["a", "a", "a", "b", "b", "b", "c", "c", "c"], + categories=["a", "b", "c", "d"], + ordered=True, + ) + data = DataFrame({"a": [1, 1, 1, 2, 2, 2, 3, 4, 5], "b": cats}) + + exp_index = CategoricalIndex(list("abcd"), name="b", ordered=True) + expected = DataFrame({"a": [1, 2, 4, np.nan]}, index=exp_index) + result = data.groupby("b", observed=False).mean() + tm.assert_frame_equal(result, expected) + + +def test_basic_single_grouper(): + cat1 = Categorical(["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True) + cat2 = Categorical(["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True) + df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]}) + + gb = df.groupby("A", observed=False) + exp_idx = CategoricalIndex(["a", "b", "z"], name="A", ordered=True) + expected = DataFrame({"values": Series([3, 7, 0], index=exp_idx)}) + result = gb.sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + +def test_basic_string(using_infer_string): + # GH 8623 + x = DataFrame( + [[1, "John P. Doe"], [2, "Jane Dove"], [1, "John P. Doe"]], + columns=["person_id", "person_name"], + ) + x["person_name"] = Categorical(x.person_name) + + g = x.groupby(["person_id"], observed=False) + result = g.transform(lambda x: x) + tm.assert_frame_equal(result, x[["person_name"]]) + + result = x.drop_duplicates("person_name") + expected = x.iloc[[0, 1]] + tm.assert_frame_equal(result, expected) + + def f(x): + return x.drop_duplicates("person_name").iloc[0] + + result = g.apply(f) + expected = x[["person_name"]].iloc[[0, 1]] + expected.index = Index([1, 2], name="person_id") + dtype = "str" if using_infer_string else object + expected["person_name"] = expected["person_name"].astype(dtype) + tm.assert_frame_equal(result, expected) + + +def test_basic_monotonic(): + # GH 9921 + df = DataFrame({"a": [5, 15, 25]}) + c = pd.cut(df.a, bins=[0, 10, 20, 30, 40]) + + result = df.a.groupby(c, observed=False).transform(sum) + tm.assert_series_equal(result, df["a"]) + + tm.assert_series_equal( + df.a.groupby(c, observed=False).transform(lambda xs: np.sum(xs)), df["a"] + ) + result = df.groupby(c, observed=False).transform(sum) + expected = df[["a"]] + tm.assert_frame_equal(result, expected) + + gbc = df.groupby(c, observed=False) + result = gbc.transform(lambda xs: np.max(xs, axis=0)) + tm.assert_frame_equal(result, df[["a"]]) + + result2 = gbc.transform(lambda xs: np.max(xs, axis=0)) + result3 = gbc.transform(max) + result4 = gbc.transform(np.maximum.reduce) + result5 = gbc.transform(lambda xs: np.maximum.reduce(xs)) + tm.assert_frame_equal(result2, df[["a"]], check_dtype=False) + tm.assert_frame_equal(result3, df[["a"]], check_dtype=False) + tm.assert_frame_equal(result4, df[["a"]]) + tm.assert_frame_equal(result5, df[["a"]]) + + # Filter + tm.assert_series_equal(df.a.groupby(c, observed=False).filter(np.all), df["a"]) + tm.assert_frame_equal(df.groupby(c, observed=False).filter(np.all), df) + + +def test_basic_non_monotonic(): + df = DataFrame({"a": [5, 15, 25, -5]}) + c = pd.cut(df.a, bins=[-10, 0, 10, 20, 30, 40]) + + result = df.a.groupby(c, observed=False).transform(sum) + tm.assert_series_equal(result, df["a"]) + + tm.assert_series_equal( + df.a.groupby(c, observed=False).transform(lambda xs: np.sum(xs)), df["a"] + ) + result = df.groupby(c, observed=False).transform(sum) + expected = df[["a"]] + tm.assert_frame_equal(result, expected) + + tm.assert_frame_equal( + df.groupby(c, observed=False).transform(lambda xs: np.sum(xs)), df[["a"]] + ) + + +def test_basic_cut_grouping(): + # GH 9603 + df = DataFrame({"a": [1, 0, 0, 0]}) + c = pd.cut(df.a, [0, 1, 2, 3, 4], labels=Categorical(list("abcd"))) + result = df.groupby(c, observed=False).apply(len) + + exp_index = CategoricalIndex(c.values.categories, ordered=c.values.ordered) + expected = Series([1, 0, 0, 0], index=exp_index) + expected.index.name = "a" + tm.assert_series_equal(result, expected) + + +def test_more_basic(): + levels = ["foo", "bar", "baz", "qux"] + codes = np.random.default_rng(2).integers(0, 4, size=10) + + cats = Categorical.from_codes(codes, levels, ordered=True) + + data = DataFrame(np.random.default_rng(2).standard_normal((10, 4))) + + result = data.groupby(cats, observed=False).mean() + + expected = data.groupby(np.asarray(cats), observed=False).mean() + exp_idx = CategoricalIndex(levels, categories=cats.categories, ordered=True) + expected = expected.reindex(exp_idx) + + tm.assert_frame_equal(result, expected) + + grouped = data.groupby(cats, observed=False) + desc_result = grouped.describe() + + idx = cats.codes.argsort() + ord_labels = np.asarray(cats).take(idx) + ord_data = data.take(idx) + + exp_cats = Categorical( + ord_labels, ordered=True, categories=["foo", "bar", "baz", "qux"] + ) + expected = ord_data.groupby(exp_cats, sort=False, observed=False).describe() + tm.assert_frame_equal(desc_result, expected) + + # GH 10460 + expc = Categorical.from_codes(np.arange(4).repeat(8), levels, ordered=True) + exp = CategoricalIndex(expc) + tm.assert_index_equal(desc_result.stack().index.get_level_values(0), exp) + exp = Index(["count", "mean", "std", "min", "25%", "50%", "75%", "max"] * 4) + tm.assert_index_equal(desc_result.stack().index.get_level_values(1), exp) + + +def test_level_get_group(observed): + # GH15155 + df = DataFrame( + data=np.arange(2, 22, 2), + index=MultiIndex( + levels=[CategoricalIndex(["a", "b"]), range(10)], + codes=[[0] * 5 + [1] * 5, range(10)], + names=["Index1", "Index2"], + ), + ) + g = df.groupby(level=["Index1"], observed=observed) + + # expected should equal test.loc[["a"]] + # GH15166 + expected = DataFrame( + data=np.arange(2, 12, 2), + index=MultiIndex( + levels=[CategoricalIndex(["a", "b"]), range(5)], + codes=[[0] * 5, range(5)], + names=["Index1", "Index2"], + ), + ) + result = g.get_group(("a",)) + tm.assert_frame_equal(result, expected) + + +def test_sorting_with_different_categoricals(): + # GH 24271 + df = DataFrame( + { + "group": ["A"] * 6 + ["B"] * 6, + "dose": ["high", "med", "low"] * 4, + "outcomes": np.arange(12.0), + } + ) + + df.dose = Categorical(df.dose, categories=["low", "med", "high"], ordered=True) + + result = df.groupby("group")["dose"].value_counts() + result = result.sort_index(level=0, sort_remaining=True) + index = ["low", "med", "high", "low", "med", "high"] + index = Categorical(index, categories=["low", "med", "high"], ordered=True) + index = [["A", "A", "A", "B", "B", "B"], CategoricalIndex(index)] + index = MultiIndex.from_arrays(index, names=["group", "dose"]) + expected = Series([2] * 6, index=index, name="count") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_apply(ordered): + # GH 10138 + + dense = Categorical(list("abc"), ordered=ordered) + + # 'b' is in the categories but not in the list + missing = Categorical(list("aaa"), categories=["a", "b"], ordered=ordered) + values = np.arange(len(dense)) + df = DataFrame({"missing": missing, "dense": dense, "values": values}) + grouped = df.groupby(["missing", "dense"], observed=True) + + # missing category 'b' should still exist in the output index + idx = MultiIndex.from_arrays([missing, dense], names=["missing", "dense"]) + expected = DataFrame([0, 1, 2.0], index=idx, columns=["values"]) + + result = grouped.apply(lambda x: np.mean(x, axis=0)) + tm.assert_frame_equal(result, expected) + + result = grouped.mean() + tm.assert_frame_equal(result, expected) + + result = grouped.agg(np.mean) + tm.assert_frame_equal(result, expected) + + # but for transform we should still get back the original index + idx = MultiIndex.from_arrays([missing, dense], names=["missing", "dense"]) + expected = Series(1, index=idx) + result = grouped.apply(lambda x: 1) + tm.assert_series_equal(result, expected) + + +def test_observed(observed, using_infer_string): + # multiple groupers, don't re-expand the output space + # of the grouper + # gh-14942 (implement) + # gh-10132 (back-compat) + # gh-8138 (back-compat) + # gh-8869 + + cat1 = Categorical(["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True) + cat2 = Categorical(["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True) + df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]}) + df["C"] = ["foo", "bar"] * 2 + + # multiple groupers with a non-cat + gb = df.groupby(["A", "B", "C"], observed=observed) + exp_index = MultiIndex.from_arrays( + [cat1, cat2, ["foo", "bar"] * 2], names=["A", "B", "C"] + ) + expected = DataFrame({"values": Series([1, 2, 3, 4], index=exp_index)}).sort_index() + result = gb.sum() + if not observed: + expected = cartesian_product_for_groupers( + expected, [cat1, cat2, ["foo", "bar"]], list("ABC"), fill_value=0 + ) + + tm.assert_frame_equal(result, expected) + + gb = df.groupby(["A", "B"], observed=observed) + exp_index = MultiIndex.from_arrays([cat1, cat2], names=["A", "B"]) + expected = DataFrame( + {"values": [1, 2, 3, 4], "C": ["foo", "bar", "foo", "bar"]}, index=exp_index + ) + result = gb.sum() + if not observed: + expected = cartesian_product_for_groupers( + expected, + [cat1, cat2], + list("AB"), + fill_value={"values": 0, "C": ""} if using_infer_string else 0, + ) + + tm.assert_frame_equal(result, expected) + + +def test_observed_single_column(observed): + # https://github.com/pandas-dev/pandas/issues/8138 + d = { + "cat": Categorical( + ["a", "b", "a", "b"], categories=["a", "b", "c"], ordered=True + ), + "ints": [1, 1, 2, 2], + "val": [10, 20, 30, 40], + } + df = DataFrame(d) + + groups_single_key = df.groupby("cat", observed=observed) + result = groups_single_key.mean() + + exp_index = CategoricalIndex( + list("ab"), name="cat", categories=list("abc"), ordered=True + ) + expected = DataFrame({"ints": [1.5, 1.5], "val": [20.0, 30]}, index=exp_index) + if not observed: + index = CategoricalIndex( + list("abc"), name="cat", categories=list("abc"), ordered=True + ) + expected = expected.reindex(index) + + tm.assert_frame_equal(result, expected) + + +def test_observed_two_columns(observed): + # https://github.com/pandas-dev/pandas/issues/8138 + d = { + "cat": Categorical( + ["a", "b", "a", "b"], categories=["a", "b", "c"], ordered=True + ), + "ints": [1, 1, 2, 2], + "val": [10, 20, 30, 40], + } + df = DataFrame(d) + groups_double_key = df.groupby(["cat", "ints"], observed=observed) + result = groups_double_key.agg("mean") + expected = DataFrame( + { + "val": [10.0, 30.0, 20.0, 40.0], + "cat": Categorical( + ["a", "a", "b", "b"], categories=["a", "b", "c"], ordered=True + ), + "ints": [1, 2, 1, 2], + } + ).set_index(["cat", "ints"]) + if not observed: + expected = cartesian_product_for_groupers( + expected, [df.cat.values, [1, 2]], ["cat", "ints"] + ) + + tm.assert_frame_equal(result, expected) + + # GH 10132 + for key in [("a", 1), ("b", 2), ("b", 1), ("a", 2)]: + c, i = key + result = groups_double_key.get_group(key) + expected = df[(df.cat == c) & (df.ints == i)] + tm.assert_frame_equal(result, expected) + + +def test_observed_with_as_index(observed): + # gh-8869 + # with as_index + d = { + "foo": [10, 8, 4, 8, 4, 1, 1], + "bar": [10, 20, 30, 40, 50, 60, 70], + "baz": ["d", "c", "e", "a", "a", "d", "c"], + } + df = DataFrame(d) + cat = pd.cut(df["foo"], np.linspace(0, 10, 3)) + df["range"] = cat + groups = df.groupby(["range", "baz"], as_index=False, observed=observed) + result = groups.agg("mean") + + groups2 = df.groupby(["range", "baz"], as_index=True, observed=observed) + expected = groups2.agg("mean").reset_index() + tm.assert_frame_equal(result, expected) + + +def test_observed_codes_remap(observed): + d = {"C1": [3, 3, 4, 5], "C2": [1, 2, 3, 4], "C3": [10, 100, 200, 34]} + df = DataFrame(d) + values = pd.cut(df["C1"], [1, 2, 3, 6]) + values.name = "cat" + groups_double_key = df.groupby([values, "C2"], observed=observed) + + idx = MultiIndex.from_arrays([values, [1, 2, 3, 4]], names=["cat", "C2"]) + expected = DataFrame( + {"C1": [3.0, 3.0, 4.0, 5.0], "C3": [10.0, 100.0, 200.0, 34.0]}, index=idx + ) + if not observed: + expected = cartesian_product_for_groupers( + expected, [values.values, [1, 2, 3, 4]], ["cat", "C2"] + ) + + result = groups_double_key.agg("mean") + tm.assert_frame_equal(result, expected) + + +def test_observed_perf(): + # we create a cartesian product, so this is + # non-performant if we don't use observed values + # gh-14942 + df = DataFrame( + { + "cat": np.random.default_rng(2).integers(0, 255, size=30000), + "int_id": np.random.default_rng(2).integers(0, 255, size=30000), + "other_id": np.random.default_rng(2).integers(0, 10000, size=30000), + "foo": 0, + } + ) + df["cat"] = df.cat.astype(str).astype("category") + + grouped = df.groupby(["cat", "int_id", "other_id"], observed=True) + result = grouped.count() + assert result.index.levels[0].nunique() == df.cat.nunique() + assert result.index.levels[1].nunique() == df.int_id.nunique() + assert result.index.levels[2].nunique() == df.other_id.nunique() + + +def test_observed_groups(observed): + # gh-20583 + # test that we have the appropriate groups + + cat = Categorical(["a", "c", "a"], categories=["a", "b", "c"]) + df = DataFrame({"cat": cat, "vals": [1, 2, 3]}) + g = df.groupby("cat", observed=observed) + + result = g.groups + if observed: + expected = {"a": Index([0, 2], dtype="int64"), "c": Index([1], dtype="int64")} + else: + expected = { + "a": Index([0, 2], dtype="int64"), + "b": Index([], dtype="int64"), + "c": Index([1], dtype="int64"), + } + + tm.assert_dict_equal(result, expected) + + +def test_groups_na_category(dropna, observed): + # https://github.com/pandas-dev/pandas/issues/61356 + df = DataFrame( + {"cat": Categorical(["a", np.nan, "a"], categories=list("adb"))}, + index=list("xyz"), + ) + g = df.groupby("cat", observed=observed, dropna=dropna) + + result = g.groups + expected = {"a": Index(["x", "z"])} + if not dropna: + expected |= {np.nan: Index(["y"])} + if not observed: + expected |= {"b": Index([]), "d": Index([])} + tm.assert_dict_equal(result, expected) + + +@pytest.mark.parametrize( + "keys, expected_values, expected_index_levels", + [ + ("a", [15, 9, 0], CategoricalIndex([1, 2, 3], name="a")), + ( + ["a", "b"], + [7, 8, 0, 0, 0, 9, 0, 0, 0], + [CategoricalIndex([1, 2, 3], name="a"), Index([4, 5, 6])], + ), + ( + ["a", "a2"], + [15, 0, 0, 0, 9, 0, 0, 0, 0], + [ + CategoricalIndex([1, 2, 3], name="a"), + CategoricalIndex([1, 2, 3], name="a"), + ], + ), + ], +) +@pytest.mark.parametrize("test_series", [True, False]) +def test_unobserved_in_index(keys, expected_values, expected_index_levels, test_series): + # GH#49354 - ensure unobserved cats occur when grouping by index levels + df = DataFrame( + { + "a": Categorical([1, 1, 2], categories=[1, 2, 3]), + "a2": Categorical([1, 1, 2], categories=[1, 2, 3]), + "b": [4, 5, 6], + "c": [7, 8, 9], + } + ).set_index(["a", "a2"]) + if "b" not in keys: + # Only keep b when it is used for grouping for consistent columns in the result + df = df.drop(columns="b") + + gb = df.groupby(keys, observed=False) + if test_series: + gb = gb["c"] + result = gb.sum() + + if len(keys) == 1: + index = expected_index_levels + else: + codes = [[0, 0, 0, 1, 1, 1, 2, 2, 2], 3 * [0, 1, 2]] + index = MultiIndex( + expected_index_levels, + codes=codes, + names=keys, + ) + expected = DataFrame({"c": expected_values}, index=index) + if test_series: + expected = expected["c"] + tm.assert_equal(result, expected) + + +def test_observed_groups_with_nan(observed): + # GH 24740 + df = DataFrame( + { + "cat": Categorical(["a", np.nan, "a"], categories=["a", "b", "d"]), + "vals": [1, 2, 3], + } + ) + g = df.groupby("cat", observed=observed) + result = g.groups + if observed: + expected = {"a": Index([0, 2], dtype="int64")} + else: + expected = { + "a": Index([0, 2], dtype="int64"), + "b": Index([], dtype="int64"), + "d": Index([], dtype="int64"), + } + tm.assert_dict_equal(result, expected) + + +def test_observed_nth(): + # GH 26385 + cat = Categorical(["a", np.nan, np.nan], categories=["a", "b", "c"]) + ser = Series([1, 2, 3]) + df = DataFrame({"cat": cat, "ser": ser}) + + result = df.groupby("cat", observed=False)["ser"].nth(0) + expected = df["ser"].iloc[[0]] + tm.assert_series_equal(result, expected) + + +def test_dataframe_categorical_with_nan(observed): + # GH 21151 + s1 = Categorical([np.nan, "a", np.nan, "a"], categories=["a", "b", "c"]) + s2 = Series([1, 2, 3, 4]) + df = DataFrame({"s1": s1, "s2": s2}) + result = df.groupby("s1", observed=observed).first().reset_index() + if observed: + expected = DataFrame( + {"s1": Categorical(["a"], categories=["a", "b", "c"]), "s2": [2]} + ) + else: + expected = DataFrame( + { + "s1": Categorical(["a", "b", "c"], categories=["a", "b", "c"]), + "s2": [2, np.nan, np.nan], + } + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_dataframe_categorical_ordered_observed_sort(ordered, observed, sort): + # GH 25871: Fix groupby sorting on ordered Categoricals + # GH 25167: Groupby with observed=True doesn't sort + + # Build a dataframe with cat having one unobserved category ('missing'), + # and a Series with identical values + label = Categorical( + ["d", "a", "b", "a", "d", "b"], + categories=["a", "b", "missing", "d"], + ordered=ordered, + ) + val = Series(["d", "a", "b", "a", "d", "b"]) + df = DataFrame({"label": label, "val": val}) + + # aggregate on the Categorical + result = df.groupby("label", observed=observed, sort=sort)["val"].aggregate("first") + + # If ordering works, we expect index labels equal to aggregation results, + # except for 'observed=False': label 'missing' has aggregation None + label = Series(result.index.array, dtype="object") + aggr = Series(result.array) + if not observed: + aggr[aggr.isna()] = "missing" + if not all(label == aggr): + msg = ( + "Labels and aggregation results not consistently sorted\n" + f"for (ordered={ordered}, observed={observed}, sort={sort})\n" + f"Result:\n{result}" + ) + pytest.fail(msg) + + +def test_datetime(): + # GH9049: ensure backward compatibility + levels = pd.date_range("2014-01-01", periods=4) + codes = np.random.default_rng(2).integers(0, 4, size=10) + + cats = Categorical.from_codes(codes, levels, ordered=True) + + data = DataFrame(np.random.default_rng(2).standard_normal((10, 4))) + result = data.groupby(cats, observed=False).mean() + + expected = data.groupby(np.asarray(cats), observed=False).mean() + expected = expected.reindex(levels) + expected.index = CategoricalIndex( + expected.index, categories=expected.index, ordered=True + ) + + tm.assert_frame_equal(result, expected) + + grouped = data.groupby(cats, observed=False) + desc_result = grouped.describe() + + idx = cats.codes.argsort() + ord_labels = cats.take(idx) + ord_data = data.take(idx) + expected = ord_data.groupby(ord_labels, observed=False).describe() + tm.assert_frame_equal(desc_result, expected) + tm.assert_index_equal(desc_result.index, expected.index) + tm.assert_index_equal( + desc_result.index.get_level_values(0), expected.index.get_level_values(0) + ) + + # GH 10460 + expc = Categorical.from_codes(np.arange(4).repeat(8), levels, ordered=True) + exp = CategoricalIndex(expc) + tm.assert_index_equal((desc_result.stack().index.get_level_values(0)), exp) + exp = Index(["count", "mean", "std", "min", "25%", "50%", "75%", "max"] * 4) + tm.assert_index_equal((desc_result.stack().index.get_level_values(1)), exp) + + +def test_categorical_index(): + s = np.random.default_rng(2) + levels = ["foo", "bar", "baz", "qux"] + codes = s.integers(0, 4, size=20) + cats = Categorical.from_codes(codes, levels, ordered=True) + df = DataFrame(np.repeat(np.arange(20), 4).reshape(-1, 4), columns=list("abcd")) + df["cats"] = cats + + # with a cat index + result = df.set_index("cats").groupby(level=0, observed=False).sum() + expected = df[list("abcd")].groupby(cats.codes, observed=False).sum() + expected.index = CategoricalIndex( + Categorical.from_codes([0, 1, 2, 3], levels, ordered=True), name="cats" + ) + tm.assert_frame_equal(result, expected) + + # with a cat column, should produce a cat index + result = df.groupby("cats", observed=False).sum() + expected = df[list("abcd")].groupby(cats.codes, observed=False).sum() + expected.index = CategoricalIndex( + Categorical.from_codes([0, 1, 2, 3], levels, ordered=True), name="cats" + ) + tm.assert_frame_equal(result, expected) + + +def test_describe_categorical_columns(): + # GH 11558 + cats = CategoricalIndex( + ["qux", "foo", "baz", "bar"], + categories=["foo", "bar", "baz", "qux"], + ordered=True, + ) + df = DataFrame(np.random.default_rng(2).standard_normal((20, 4)), columns=cats) + result = df.groupby([1, 2, 3, 4] * 5).describe() + + tm.assert_index_equal(result.stack().columns, cats) + tm.assert_categorical_equal(result.stack().columns.values, cats.values) + + +def test_unstack_categorical(): + # GH11558 (example is taken from the original issue) + df = DataFrame( + {"a": range(10), "medium": ["A", "B"] * 5, "artist": list("XYXXY") * 2} + ) + df["medium"] = df["medium"].astype("category") + + gcat = df.groupby(["artist", "medium"], observed=False)["a"].count().unstack() + result = gcat.describe() + + exp_columns = CategoricalIndex(["A", "B"], ordered=False, name="medium") + tm.assert_index_equal(result.columns, exp_columns) + tm.assert_categorical_equal(result.columns.values, exp_columns.values) + + result = gcat["A"] + gcat["B"] + expected = Series([6, 4], index=Index(["X", "Y"], name="artist")) + tm.assert_series_equal(result, expected) + + +def test_bins_unequal_len(): + # GH3011 + series = Series([np.nan, np.nan, 1, 1, 2, 2, 3, 3, 4, 4]) + bins = pd.cut(series.dropna().values, 4) + + # len(bins) != len(series) here + with pytest.raises(ValueError, match="Grouper and axis must be same length"): + series.groupby(bins).mean() + + +@pytest.mark.parametrize( + ["series", "data"], + [ + # Group a series with length and index equal to those of the grouper. + (Series(range(4)), {"A": [0, 3], "B": [1, 2]}), + # Group a series with length equal to that of the grouper and index unequal to + # that of the grouper. + (Series(range(4)).rename(lambda idx: idx + 1), {"A": [2], "B": [0, 1]}), + # GH44179: Group a series with length unequal to that of the grouper. + (Series(range(7)), {"A": [0, 3], "B": [1, 2]}), + ], +) +def test_categorical_series(series, data): + # Group the given series by a series with categorical data type such that group A + # takes indices 0 and 3 and group B indices 1 and 2, obtaining the values mapped in + # the given data. + groupby = series.groupby(Series(list("ABBA"), dtype="category"), observed=False) + result = groupby.aggregate(list) + expected = Series(data, index=CategoricalIndex(data.keys())) + tm.assert_series_equal(result, expected) + + +def test_as_index(): + # GH13204 + df = DataFrame( + { + "cat": Categorical([1, 2, 2], [1, 2, 3]), + "A": [10, 11, 11], + "B": [101, 102, 103], + } + ) + result = df.groupby(["cat", "A"], as_index=False, observed=True).sum() + expected = DataFrame( + { + "cat": Categorical([1, 2], categories=df.cat.cat.categories), + "A": [10, 11], + "B": [101, 205], + }, + columns=["cat", "A", "B"], + ) + tm.assert_frame_equal(result, expected) + + # function grouper + f = lambda r: df.loc[r, "A"] + result = df.groupby(["cat", f], as_index=False, observed=True).sum() + expected = DataFrame( + { + "cat": Categorical([1, 2], categories=df.cat.cat.categories), + "level_1": [10, 11], + "A": [10, 22], + "B": [101, 205], + }, + ) + tm.assert_frame_equal(result, expected) + + # another not in-axis grouper (conflicting names in index) + s = Series(["a", "b", "b"], name="cat") + result = df.groupby(["cat", s], as_index=False, observed=True).sum() + expected = DataFrame( + { + "cat": ["a", "b"], + "A": [10, 22], + "B": [101, 205], + }, + ) + tm.assert_frame_equal(result, expected) + + # is original index dropped? + group_columns = ["cat", "A"] + expected = DataFrame( + { + "cat": Categorical([1, 2], categories=df.cat.cat.categories), + "A": [10, 11], + "B": [101, 205], + }, + columns=["cat", "A", "B"], + ) + + for name in [None, "X", "B"]: + df.index = Index(list("abc"), name=name) + result = df.groupby(group_columns, as_index=False, observed=True).sum() + + tm.assert_frame_equal(result, expected) + + +def test_preserve_categories(): + # GH-13179 + categories = list("abc") + + # ordered=True + df = DataFrame({"A": Categorical(list("ba"), categories=categories, ordered=True)}) + sort_index = CategoricalIndex(categories, categories, ordered=True, name="A") + nosort_index = CategoricalIndex(list("bac"), categories, ordered=True, name="A") + tm.assert_index_equal( + df.groupby("A", sort=True, observed=False).first().index, sort_index + ) + # GH#42482 - don't sort result when sort=False, even when ordered=True + tm.assert_index_equal( + df.groupby("A", sort=False, observed=False).first().index, nosort_index + ) + + +def test_preserve_categories_ordered_false(): + # GH-13179 + categories = list("abc") + df = DataFrame({"A": Categorical(list("ba"), categories=categories, ordered=False)}) + sort_index = CategoricalIndex(categories, categories, ordered=False, name="A") + # GH#48749 - don't change order of categories + # GH#42482 - don't sort result when sort=False, even when ordered=True + nosort_index = CategoricalIndex(list("bac"), list("abc"), ordered=False, name="A") + tm.assert_index_equal( + df.groupby("A", sort=True, observed=False).first().index, sort_index + ) + tm.assert_index_equal( + df.groupby("A", sort=False, observed=False).first().index, nosort_index + ) + + +@pytest.mark.parametrize("col", ["C1", "C2"]) +def test_preserve_categorical_dtype(col): + # GH13743, GH13854 + df = DataFrame( + { + "A": [1, 2, 1, 1, 2], + "B": [10, 16, 22, 28, 34], + "C1": Categorical(list("abaab"), categories=list("bac"), ordered=False), + "C2": Categorical(list("abaab"), categories=list("bac"), ordered=True), + } + ) + # single grouper + exp_full = DataFrame( + { + "A": [2.0, 1.0, np.nan], + "B": [25.0, 20.0, np.nan], + "C1": Categorical(list("bac"), categories=list("bac"), ordered=False), + "C2": Categorical(list("bac"), categories=list("bac"), ordered=True), + } + ) + result1 = df.groupby(by=col, as_index=False, observed=False).mean(numeric_only=True) + result2 = ( + df.groupby(by=col, as_index=True, observed=False) + .mean(numeric_only=True) + .reset_index() + ) + expected = exp_full.reindex(columns=result1.columns) + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + + +@pytest.mark.parametrize( + "func, values", + [ + ("first", ["second", "first"]), + ("last", ["fourth", "third"]), + ("min", ["fourth", "first"]), + ("max", ["second", "third"]), + ], +) +def test_preserve_on_ordered_ops(func, values): + # gh-18502 + # preserve the categoricals on ops + c = Categorical(["first", "second", "third", "fourth"], ordered=True) + df = DataFrame({"payload": [-1, -2, -1, -2], "col": c}) + g = df.groupby("payload") + result = getattr(g, func)() + expected = DataFrame( + {"payload": [-2, -1], "col": Series(values, dtype=c.dtype)} + ).set_index("payload") + tm.assert_frame_equal(result, expected) + + # we should also preserve categorical for SeriesGroupBy + sgb = df.groupby("payload")["col"] + result = getattr(sgb, func)() + expected = expected["col"] + tm.assert_series_equal(result, expected) + + +def test_categorical_no_compress(): + data = Series(np.random.default_rng(2).standard_normal(9)) + + codes = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]) + cats = Categorical.from_codes(codes, [0, 1, 2], ordered=True) + + result = data.groupby(cats, observed=False).mean() + exp = data.groupby(codes, observed=False).mean() + + exp.index = CategoricalIndex( + exp.index, categories=cats.categories, ordered=cats.ordered + ) + tm.assert_series_equal(result, exp) + + codes = np.array([0, 0, 0, 1, 1, 1, 3, 3, 3]) + cats = Categorical.from_codes(codes, [0, 1, 2, 3], ordered=True) + + result = data.groupby(cats, observed=False).mean() + exp = data.groupby(codes, observed=False).mean().reindex(cats.categories) + exp.index = CategoricalIndex( + exp.index, categories=cats.categories, ordered=cats.ordered + ) + tm.assert_series_equal(result, exp) + + +def test_categorical_no_compress_string(): + cats = Categorical( + ["a", "a", "a", "b", "b", "b", "c", "c", "c"], + categories=["a", "b", "c", "d"], + ordered=True, + ) + data = DataFrame({"a": [1, 1, 1, 2, 2, 2, 3, 4, 5], "b": cats}) + + result = data.groupby("b", observed=False).mean() + result = result["a"].values + exp = np.array([1, 2, 4, np.nan]) + tm.assert_numpy_array_equal(result, exp) + + +def test_groupby_empty_with_category(): + # GH-9614 + # test fix for when group by on None resulted in + # coercion of dtype categorical -> float + df = DataFrame({"A": [None] * 3, "B": Categorical(["train", "train", "test"])}) + result = df.groupby("A").first()["B"] + expected = Series( + Categorical([], categories=["test", "train"]), + index=Series([], dtype="object", name="A"), + name="B", + ) + tm.assert_series_equal(result, expected) + + +def test_sort(): + # https://stackoverflow.com/questions/23814368/sorting-pandas- + # categorical-labels-after-groupby + # This should result in a properly sorted Series so that the plot + # has a sorted x axis + # self.cat.groupby(['value_group'])['value_group'].count().plot(kind='bar') + + df = DataFrame({"value": np.random.default_rng(2).integers(0, 10000, 10)}) + labels = [f"{i} - {i + 499}" for i in range(0, 10000, 500)] + cat_labels = Categorical(labels, labels) + + df = df.sort_values(by=["value"], ascending=True) + df["value_group"] = pd.cut( + df.value, range(0, 10500, 500), right=False, labels=cat_labels + ) + + res = df.groupby(["value_group"], observed=False)["value_group"].count() + exp = res[sorted(res.index, key=lambda x: float(x.split()[0]))] + exp.index = CategoricalIndex(exp.index, name=exp.index.name) + tm.assert_series_equal(res, exp) + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_sort2(sort, ordered): + # dataframe groupby sort was being ignored # GH 8868 + # GH#48749 - don't change order of categories + # GH#42482 - don't sort result when sort=False, even when ordered=True + df = DataFrame( + [ + ["(7.5, 10]", 10, 10], + ["(7.5, 10]", 8, 20], + ["(2.5, 5]", 5, 30], + ["(5, 7.5]", 6, 40], + ["(2.5, 5]", 4, 50], + ["(0, 2.5]", 1, 60], + ["(5, 7.5]", 7, 70], + ], + columns=["range", "foo", "bar"], + ) + df["range"] = Categorical(df["range"], ordered=ordered) + result = df.groupby("range", sort=sort, observed=False).first() + + if sort: + data_values = [[1, 60], [5, 30], [6, 40], [10, 10]] + index_values = ["(0, 2.5]", "(2.5, 5]", "(5, 7.5]", "(7.5, 10]"] + else: + data_values = [[10, 10], [5, 30], [6, 40], [1, 60]] + index_values = ["(7.5, 10]", "(2.5, 5]", "(5, 7.5]", "(0, 2.5]"] + expected = DataFrame( + data_values, + columns=["foo", "bar"], + index=CategoricalIndex(index_values, name="range", ordered=ordered), + ) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_sort_datetimelike(sort, ordered): + # GH10505 + # GH#42482 - don't sort result when sort=False, even when ordered=True + + # use same data as test_groupby_sort_categorical, which category is + # corresponding to datetime.month + df = DataFrame( + { + "dt": [ + datetime(2011, 7, 1), + datetime(2011, 7, 1), + datetime(2011, 2, 1), + datetime(2011, 5, 1), + datetime(2011, 2, 1), + datetime(2011, 1, 1), + datetime(2011, 5, 1), + ], + "foo": [10, 8, 5, 6, 4, 1, 7], + "bar": [10, 20, 30, 40, 50, 60, 70], + }, + columns=["dt", "foo", "bar"], + ) + + # ordered=True + df["dt"] = Categorical(df["dt"], ordered=ordered) + if sort: + data_values = [[1, 60], [5, 30], [6, 40], [10, 10]] + index_values = [ + datetime(2011, 1, 1), + datetime(2011, 2, 1), + datetime(2011, 5, 1), + datetime(2011, 7, 1), + ] + else: + data_values = [[10, 10], [5, 30], [6, 40], [1, 60]] + index_values = [ + datetime(2011, 7, 1), + datetime(2011, 2, 1), + datetime(2011, 5, 1), + datetime(2011, 1, 1), + ] + expected = DataFrame( + data_values, + columns=["foo", "bar"], + index=CategoricalIndex(index_values, name="dt", ordered=ordered), + ) + result = df.groupby("dt", sort=sort, observed=False).first() + tm.assert_frame_equal(result, expected) + + +def test_empty_sum(): + # https://github.com/pandas-dev/pandas/issues/18678 + df = DataFrame( + {"A": Categorical(["a", "a", "b"], categories=["a", "b", "c"]), "B": [1, 2, 1]} + ) + expected_idx = CategoricalIndex(["a", "b", "c"], name="A") + + # 0 by default + result = df.groupby("A", observed=False).B.sum() + expected = Series([3, 1, 0], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count=0 + result = df.groupby("A", observed=False).B.sum(min_count=0) + expected = Series([3, 1, 0], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count=1 + result = df.groupby("A", observed=False).B.sum(min_count=1) + expected = Series([3, 1, np.nan], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count>1 + result = df.groupby("A", observed=False).B.sum(min_count=2) + expected = Series([3, np.nan, np.nan], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + +def test_empty_prod(): + # https://github.com/pandas-dev/pandas/issues/18678 + df = DataFrame( + {"A": Categorical(["a", "a", "b"], categories=["a", "b", "c"]), "B": [1, 2, 1]} + ) + + expected_idx = CategoricalIndex(["a", "b", "c"], name="A") + + # 1 by default + result = df.groupby("A", observed=False).B.prod() + expected = Series([2, 1, 1], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count=0 + result = df.groupby("A", observed=False).B.prod(min_count=0) + expected = Series([2, 1, 1], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + # min_count=1 + result = df.groupby("A", observed=False).B.prod(min_count=1) + expected = Series([2, 1, np.nan], expected_idx, name="B") + tm.assert_series_equal(result, expected) + + +def test_groupby_multiindex_categorical_datetime(): + # https://github.com/pandas-dev/pandas/issues/21390 + + df = DataFrame( + { + "key1": Categorical(list("abcbabcba")), + "key2": Categorical( + list(pd.date_range("2018-06-01 00", freq="1min", periods=3)) * 3 + ), + "values": np.arange(9), + } + ) + result = df.groupby(["key1", "key2"], observed=False).mean() + + idx = MultiIndex.from_product( + [ + Categorical(["a", "b", "c"]), + Categorical(pd.date_range("2018-06-01 00", freq="1min", periods=3)), + ], + names=["key1", "key2"], + ) + expected = DataFrame({"values": [0, 4, 8, 3, 4, 5, 6, np.nan, 2]}, index=idx) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "as_index, expected", + [ + ( + True, + Series( + index=MultiIndex.from_arrays( + [Series([1, 1, 2], dtype="category"), [1, 2, 2]], names=["a", "b"] + ), + data=[1, 2, 3], + name="x", + ), + ), + ( + False, + DataFrame( + { + "a": Series([1, 1, 2], dtype="category"), + "b": [1, 2, 2], + "x": [1, 2, 3], + } + ), + ), + ], +) +def test_groupby_agg_observed_true_single_column(as_index, expected): + # GH-23970 + df = DataFrame( + {"a": Series([1, 1, 2], dtype="category"), "b": [1, 2, 2], "x": [1, 2, 3]} + ) + + result = df.groupby(["a", "b"], as_index=as_index, observed=True)["x"].sum() + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("fill_value", [None, np.nan, pd.NaT]) +def test_shift(fill_value): + ct = Categorical( + ["a", "b", "c", "d"], categories=["a", "b", "c", "d"], ordered=False + ) + expected = Categorical( + [None, "a", "b", "c"], categories=["a", "b", "c", "d"], ordered=False + ) + res = ct.shift(1, fill_value=fill_value) + tm.assert_equal(res, expected) + + +@pytest.fixture +def df_cat(df): + """ + DataFrame with multiple categorical columns and a column of integers. + Shortened so as not to contain all possible combinations of categories. + Useful for testing `observed` kwarg functionality on GroupBy objects. + + Parameters + ---------- + df: DataFrame + Non-categorical, longer DataFrame from another fixture, used to derive + this one + + Returns + ------- + df_cat: DataFrame + """ + df_cat = df.copy()[:4] # leave out some groups + df_cat["A"] = df_cat["A"].astype("category") + df_cat["B"] = df_cat["B"].astype("category") + df_cat["C"] = Series([1, 2, 3, 4]) + df_cat = df_cat.drop(["D"], axis=1) + return df_cat + + +@pytest.mark.parametrize("operation", ["agg", "apply"]) +def test_seriesgroupby_observed_true(df_cat, operation): + # GH#24880 + # GH#49223 - order of results was wrong when grouping by index levels + lev_a = Index(["bar", "bar", "foo", "foo"], dtype=df_cat["A"].dtype, name="A") + lev_b = Index(["one", "three", "one", "two"], dtype=df_cat["B"].dtype, name="B") + index = MultiIndex.from_arrays([lev_a, lev_b]) + expected = Series(data=[2, 4, 1, 3], index=index, name="C").sort_index() + + grouped = df_cat.groupby(["A", "B"], observed=True)["C"] + result = getattr(grouped, operation)(sum) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("operation", ["agg", "apply"]) +@pytest.mark.parametrize("observed", [False, None]) +def test_seriesgroupby_observed_false_or_none(df_cat, observed, operation): + # GH 24880 + # GH#49223 - order of results was wrong when grouping by index levels + index, _ = MultiIndex.from_product( + [ + CategoricalIndex(["bar", "foo"], ordered=False), + CategoricalIndex(["one", "three", "two"], ordered=False), + ], + names=["A", "B"], + ).sortlevel() + + expected = Series(data=[2, 4, 0, 1, 0, 3], index=index, name="C") + grouped = df_cat.groupby(["A", "B"], observed=observed)["C"] + result = getattr(grouped, operation)(sum) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "observed, index, data", + [ + ( + True, + MultiIndex.from_arrays( + [ + Index(["bar"] * 4 + ["foo"] * 4, dtype="category", name="A"), + Index( + ["one", "one", "three", "three", "one", "one", "two", "two"], + dtype="category", + name="B", + ), + Index(["min", "max"] * 4), + ] + ), + [2, 2, 4, 4, 1, 1, 3, 3], + ), + ( + False, + MultiIndex.from_product( + [ + CategoricalIndex(["bar", "foo"], ordered=False), + CategoricalIndex(["one", "three", "two"], ordered=False), + Index(["min", "max"]), + ], + names=["A", "B", None], + ), + [2, 2, 4, 4, np.nan, np.nan, 1, 1, np.nan, np.nan, 3, 3], + ), + ( + None, + MultiIndex.from_product( + [ + CategoricalIndex(["bar", "foo"], ordered=False), + CategoricalIndex(["one", "three", "two"], ordered=False), + Index(["min", "max"]), + ], + names=["A", "B", None], + ), + [2, 2, 4, 4, np.nan, np.nan, 1, 1, np.nan, np.nan, 3, 3], + ), + ], +) +def test_seriesgroupby_observed_apply_dict(df_cat, observed, index, data): + # GH 24880 + expected = Series(data=data, index=index, name="C") + result = df_cat.groupby(["A", "B"], observed=observed)["C"].apply( + lambda x: {"min": x.min(), "max": x.max()} + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_categorical_series_dataframe_consistent(df_cat): + # GH 20416 + expected = df_cat.groupby(["A", "B"], observed=False)["C"].mean() + result = df_cat.groupby(["A", "B"], observed=False).mean()["C"] + tm.assert_series_equal(result, expected) + + +def test_groupby_cat_preserves_structure(observed, ordered): + # GH 28787 + df = DataFrame( + {"Name": Categorical(["Bob", "Greg"], ordered=ordered), "Item": [1, 2]}, + columns=["Name", "Item"], + ) + expected = df.copy() + + result = ( + df.groupby("Name", observed=observed) + .agg(DataFrame.sum, skipna=True) + .reset_index() + ) + + tm.assert_frame_equal(result, expected) + + +def test_get_nonexistent_category(): + # Accessing a Category that is not in the dataframe + df = DataFrame({"var": ["a", "a", "b", "b"], "val": range(4)}) + with pytest.raises(KeyError, match="'vau'"): + df.groupby("var").apply(lambda rows: DataFrame({"val": [rows.iloc[-1]["vau"]]})) + + +def test_series_groupby_on_2_categoricals_unobserved(reduction_func, observed): + # GH 17605 + if reduction_func == "ngroup": + pytest.skip("ngroup is not truly a reduction") + + df = DataFrame( + { + "cat_1": Categorical(list("AABB"), categories=list("ABCD")), + "cat_2": Categorical(list("AB") * 2, categories=list("ABCD")), + "value": [0.1] * 4, + } + ) + args = get_groupby_method_args(reduction_func, df) + + expected_length = 4 if observed else 16 + + series_groupby = df.groupby(["cat_1", "cat_2"], observed=observed)["value"] + + if reduction_func == "corrwith": + # TODO: implemented SeriesGroupBy.corrwith. See GH 32293 + assert not hasattr(series_groupby, reduction_func) + return + + agg = getattr(series_groupby, reduction_func) + + if not observed and reduction_func in ["idxmin", "idxmax"]: + # idxmin and idxmax are designed to fail on empty inputs + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + agg(*args) + return + + result = agg(*args) + + assert len(result) == expected_length + + +def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans( + reduction_func, request +): + # GH 17605 + # Tests whether the unobserved categories in the result contain 0 or NaN + + if reduction_func == "ngroup": + pytest.skip("ngroup is not truly a reduction") + + if reduction_func == "corrwith": # GH 32293 + mark = pytest.mark.xfail( + reason="TODO: implemented SeriesGroupBy.corrwith. See GH 32293" + ) + request.applymarker(mark) + + df = DataFrame( + { + "cat_1": Categorical(list("AABB"), categories=list("ABC")), + "cat_2": Categorical(list("AB") * 2, categories=list("ABC")), + "value": [0.1] * 4, + } + ) + unobserved = [tuple("AC"), tuple("BC"), tuple("CA"), tuple("CB"), tuple("CC")] + args = get_groupby_method_args(reduction_func, df) + + series_groupby = df.groupby(["cat_1", "cat_2"], observed=False)["value"] + agg = getattr(series_groupby, reduction_func) + + if reduction_func in ["idxmin", "idxmax"]: + # idxmin and idxmax are designed to fail on empty inputs + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + agg(*args) + return + + result = agg(*args) + + missing_fillin = _results_for_groupbys_with_missing_categories[reduction_func] + + for idx in unobserved: + val = result.loc[idx] + assert (pd.isna(missing_fillin) and pd.isna(val)) or (val == missing_fillin) + + # If we expect unobserved values to be zero, we also expect the dtype to be int. + # Except for .sum(). If the observed categories sum to dtype=float (i.e. their + # sums have decimals), then the zeros for the missing categories should also be + # floats. + if missing_fillin == 0: + if reduction_func in ["count", "nunique", "size"]: + assert np.issubdtype(result.dtype, np.integer) + else: + assert reduction_func in ["sum", "any"] + + +def test_dataframe_groupby_on_2_categoricals_when_observed_is_true(reduction_func): + # GH 23865 + # GH 27075 + # Ensure that df.groupby, when 'by' is two Categorical variables, + # does not return the categories that are not in df when observed=True + if reduction_func == "ngroup": + pytest.skip("ngroup does not return the Categories on the index") + + df = DataFrame( + { + "cat_1": Categorical(list("AABB"), categories=list("ABC")), + "cat_2": Categorical(list("1111"), categories=list("12")), + "value": [0.1, 0.1, 0.1, 0.1], + } + ) + unobserved_cats = [("A", "2"), ("B", "2"), ("C", "1"), ("C", "2")] + + df_grp = df.groupby(["cat_1", "cat_2"], observed=True) + + args = get_groupby_method_args(reduction_func, df) + if reduction_func == "corrwith": + warn = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + warn_msg = "" + with tm.assert_produces_warning(warn, match=warn_msg): + res = getattr(df_grp, reduction_func)(*args) + + for cat in unobserved_cats: + assert cat not in res.index + + +@pytest.mark.parametrize("observed", [False, None]) +def test_dataframe_groupby_on_2_categoricals_when_observed_is_false( + reduction_func, observed, using_python_scalars +): + # GH 23865 + # GH 27075 + # Ensure that df.groupby, when 'by' is two Categorical variables, + # returns the categories that are not in df when observed=False/None + + if reduction_func == "ngroup": + pytest.skip("ngroup does not return the Categories on the index") + + df = DataFrame( + { + "cat_1": Categorical(list("AABB"), categories=list("ABC")), + "cat_2": Categorical(list("1111"), categories=list("12")), + "value": [0.1, 0.1, 0.1, 0.1], + } + ) + unobserved_cats = [("A", "2"), ("B", "2"), ("C", "1"), ("C", "2")] + + df_grp = df.groupby(["cat_1", "cat_2"], observed=observed) + + args = get_groupby_method_args(reduction_func, df) + + if not observed and reduction_func in ["idxmin", "idxmax"]: + # idxmin and idxmax are designed to fail on empty inputs + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + getattr(df_grp, reduction_func)(*args) + return + + if reduction_func == "corrwith": + warn = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + warn_msg = "" + with tm.assert_produces_warning(warn, match=warn_msg): + res = getattr(df_grp, reduction_func)(*args) + + expected = _results_for_groupbys_with_missing_categories[reduction_func] + + if using_python_scalars and reduction_func == "size": + assert (res.loc[unobserved_cats] == expected).all() is True + elif expected is np.nan: + assert res.loc[unobserved_cats].isnull().all().all() + else: + assert (res.loc[unobserved_cats] == expected).all().all() + + +def test_series_groupby_categorical_aggregation_getitem(): + # GH 8870 + d = {"foo": [10, 8, 4, 1], "bar": [10, 20, 30, 40], "baz": ["d", "c", "d", "c"]} + df = DataFrame(d) + cat = pd.cut(df["foo"], np.linspace(0, 20, 5)) + df["range"] = cat + groups = df.groupby(["range", "baz"], as_index=True, sort=True, observed=False) + result = groups["foo"].agg("mean") + expected = groups.agg("mean")["foo"] + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func, expected_values", + [(Series.nunique, [1, 1, 2]), (Series.count, [1, 2, 2])], +) +def test_groupby_agg_categorical_columns(func, expected_values): + # 31256 + df = DataFrame( + { + "id": [0, 1, 2, 3, 4], + "groups": [0, 1, 1, 2, 2], + "value": Categorical([0, 0, 0, 0, 1]), + } + ).set_index("id") + result = df.groupby("groups").agg(func) + + expected = DataFrame( + {"value": expected_values}, index=Index([0, 1, 2], name="groups") + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_agg_non_numeric(): + df = DataFrame({"A": Categorical(["a", "a", "b"], categories=["a", "b", "c"])}) + expected = DataFrame({"A": [2, 1]}, index=np.array([1, 2])) + + result = df.groupby([1, 2, 1]).agg(Series.nunique) + tm.assert_frame_equal(result, expected) + + result = df.groupby([1, 2, 1]).nunique() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["first", "last"]) +def test_groupby_first_returned_categorical_instead_of_dataframe(func): + # GH 28641: groupby drops index, when grouping over categorical column with + # first/last. Renamed Categorical instead of DataFrame previously. + df = DataFrame({"A": [1997], "B": Series(["b"], dtype="category").cat.as_ordered()}) + df_grouped = df.groupby("A")["B"] + result = getattr(df_grouped, func)() + + # ordered categorical dtype should be preserved + expected = Series( + ["b"], index=Index([1997], name="A"), name="B", dtype=df["B"].dtype + ) + tm.assert_series_equal(result, expected) + + +def test_read_only_category_no_sort(): + # GH33410 + cats = np.array([1, 2]) + cats.flags.writeable = False + df = DataFrame( + {"a": [1, 3, 5, 7], "b": Categorical([1, 1, 2, 2], categories=Index(cats))} + ) + expected = DataFrame(data={"a": [2.0, 6.0]}, index=CategoricalIndex(cats, name="b")) + result = df.groupby("b", sort=False, observed=False).mean() + tm.assert_frame_equal(result, expected) + + +def test_sorted_missing_category_values(): + # GH 28597 + df = DataFrame( + { + "foo": [ + "small", + "large", + "large", + "large", + "medium", + "large", + "large", + "medium", + ], + "bar": ["C", "A", "A", "C", "A", "C", "A", "C"], + } + ) + df["foo"] = ( + df["foo"] + .astype("category") + .cat.set_categories(["tiny", "small", "medium", "large"], ordered=True) + ) + + expected = DataFrame( + { + "tiny": {"A": 0, "C": 0}, + "small": {"A": 0, "C": 1}, + "medium": {"A": 1, "C": 1}, + "large": {"A": 3, "C": 2}, + } + ) + expected = expected.rename_axis("bar", axis="index") + expected.columns = CategoricalIndex( + ["tiny", "small", "medium", "large"], + categories=["tiny", "small", "medium", "large"], + ordered=True, + name="foo", + dtype="category", + ) + + result = df.groupby(["bar", "foo"], observed=False).size().unstack() + + tm.assert_frame_equal(result, expected) + + +def test_agg_cython_category_not_implemented_fallback(): + # https://github.com/pandas-dev/pandas/issues/31450 + df = DataFrame({"col_num": [1, 1, 2, 3]}) + df["col_cat"] = df["col_num"].astype("category") + + result = df.groupby("col_num").col_cat.first() + + # ordered categorical dtype should definitely be preserved; + # this is unordered, so is less-clear case (if anything, it should raise) + expected = Series( + [1, 2, 3], + index=Index([1, 2, 3], name="col_num"), + name="col_cat", + dtype=df["col_cat"].dtype, + ) + tm.assert_series_equal(result, expected) + + result = df.groupby("col_num").agg({"col_cat": "first"}) + expected = expected.to_frame() + tm.assert_frame_equal(result, expected) + + +def test_aggregate_categorical_with_isnan(): + # GH 29837 + df = DataFrame( + { + "A": [1, 1, 1, 1], + "B": [1, 2, 1, 2], + "numerical_col": [0.1, 0.2, np.nan, 0.3], + "object_col": ["foo", "bar", "foo", "fee"], + "categorical_col": ["foo", "bar", "foo", "fee"], + } + ) + + df = df.astype({"categorical_col": "category"}) + + result = df.groupby(["A", "B"]).agg(lambda df: df.isna().sum()) + index = MultiIndex.from_arrays([[1, 1], [1, 2]], names=("A", "B")) + expected = DataFrame( + data={ + "numerical_col": [1, 0], + "object_col": [0, 0], + "categorical_col": [0, 0], + }, + index=index, + ) + tm.assert_frame_equal(result, expected) + + +def test_categorical_transform(): + # GH 29037 + df = DataFrame( + { + "package_id": [1, 1, 1, 2, 2, 3], + "status": [ + "Waiting", + "OnTheWay", + "Delivered", + "Waiting", + "OnTheWay", + "Waiting", + ], + } + ) + + delivery_status_type = pd.CategoricalDtype( + categories=["Waiting", "OnTheWay", "Delivered"], ordered=True + ) + df["status"] = df["status"].astype(delivery_status_type) + df["last_status"] = df.groupby("package_id")["status"].transform(max) + result = df.copy() + + expected = DataFrame( + { + "package_id": [1, 1, 1, 2, 2, 3], + "status": [ + "Waiting", + "OnTheWay", + "Delivered", + "Waiting", + "OnTheWay", + "Waiting", + ], + "last_status": [ + "Waiting", + "Waiting", + "Waiting", + "Waiting", + "Waiting", + "Waiting", + ], + } + ) + + expected["status"] = expected["status"].astype(delivery_status_type) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["first", "last"]) +def test_series_groupby_first_on_categorical_col_grouped_on_2_categoricals( + func: str, observed: bool +): + # GH 34951 + cat = Categorical([0, 0, 1, 1]) + val = [0, 1, 1, 0] + df = DataFrame({"a": cat, "b": cat, "c": val}) + + cat2 = Categorical([0, 1]) + idx = MultiIndex.from_product([cat2, cat2], names=["a", "b"]) + expected_dict = { + "first": Series([0, np.nan, np.nan, 1], idx, name="c"), + "last": Series([1, np.nan, np.nan, 0], idx, name="c"), + } + + expected = expected_dict[func] + if observed: + expected = expected.dropna().astype(np.int64) + + srs_grp = df.groupby(["a", "b"], observed=observed)["c"] + result = getattr(srs_grp, func)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["first", "last"]) +def test_df_groupby_first_on_categorical_col_grouped_on_2_categoricals( + func: str, observed: bool +): + # GH 34951 + cat = Categorical([0, 0, 1, 1]) + val = [0, 1, 1, 0] + df = DataFrame({"a": cat, "b": cat, "c": val}) + + cat2 = Categorical([0, 1]) + idx = MultiIndex.from_product([cat2, cat2], names=["a", "b"]) + expected_dict = { + "first": Series([0, np.nan, np.nan, 1], idx, name="c"), + "last": Series([1, np.nan, np.nan, 0], idx, name="c"), + } + + expected = expected_dict[func].to_frame() + if observed: + expected = expected.dropna().astype(np.int64) + + df_grp = df.groupby(["a", "b"], observed=observed) + result = getattr(df_grp, func)() + tm.assert_frame_equal(result, expected) + + +def test_groupby_categorical_indices_unused_categories(): + # GH#38642 + df = DataFrame( + { + "key": Categorical(["b", "b", "a"], categories=["a", "b", "c"]), + "col": range(3), + } + ) + grouped = df.groupby("key", sort=False, observed=False) + result = grouped.indices + expected = { + "b": np.array([0, 1], dtype="intp"), + "a": np.array([2], dtype="intp"), + "c": np.array([], dtype="intp"), + } + assert result.keys() == expected.keys() + for key in result.keys(): + tm.assert_numpy_array_equal(result[key], expected[key]) + + +@pytest.mark.parametrize("func", ["first", "last"]) +def test_groupby_last_first_preserve_categoricaldtype(func): + # GH#33090 + df = DataFrame({"a": [1, 2, 3]}) + df["b"] = df["a"].astype("category") + result = getattr(df.groupby("a")["b"], func)() + expected = Series( + Categorical([1, 2, 3]), name="b", index=Index([1, 2, 3], name="a") + ) + tm.assert_series_equal(expected, result) + + +def test_groupby_categorical_observed_nunique(): + # GH#45128 + df = DataFrame({"a": [1, 2], "b": [1, 2], "c": [10, 11]}) + df = df.astype(dtype={"a": "category", "b": "category"}) + result = df.groupby(["a", "b"], observed=True).nunique()["c"] + expected = Series( + [1, 1], + index=MultiIndex.from_arrays( + [CategoricalIndex([1, 2], name="a"), CategoricalIndex([1, 2], name="b")] + ), + name="c", + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_categorical_aggregate_functions(): + # GH#37275 + dtype = pd.CategoricalDtype(categories=["small", "big"], ordered=True) + df = DataFrame( + [[1, "small"], [1, "big"], [2, "small"]], columns=["grp", "description"] + ).astype({"description": dtype}) + + result = df.groupby("grp")["description"].max() + expected = Series( + ["big", "small"], + index=Index([1, 2], name="grp"), + name="description", + dtype=pd.CategoricalDtype(categories=["small", "big"], ordered=True), + ) + + tm.assert_series_equal(result, expected) + + +def test_groupby_categorical_dropna(observed, dropna): + # GH#48645 - dropna should have no impact on the result when there are no NA values + cat = Categorical([1, 2], categories=[1, 2, 3]) + df = DataFrame({"x": Categorical([1, 2], categories=[1, 2, 3]), "y": [3, 4]}) + gb = df.groupby("x", observed=observed, dropna=dropna) + result = gb.sum() + + if observed: + expected = DataFrame({"y": [3, 4]}, index=cat) + else: + index = CategoricalIndex([1, 2, 3], [1, 2, 3]) + expected = DataFrame({"y": [3, 4, 0]}, index=index) + expected.index.name = "x" + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +@pytest.mark.parametrize("ordered", [True, False]) +def test_category_order_reducer( + request, as_index, sort, observed, reduction_func, index_kind, ordered +): + # GH#48749 + if reduction_func == "corrwith" and not as_index and index_kind != "single": + msg = "GH#49950 - corrwith with as_index=False may not have grouping column" + request.applymarker(pytest.mark.xfail(reason=msg)) + elif index_kind != "range" and not as_index: + pytest.skip(reason="Result doesn't have categories, nothing to test") + df = DataFrame( + { + "a": Categorical([2, 1, 2, 3], categories=[1, 4, 3, 2], ordered=ordered), + "b": range(4), + } + ) + if index_kind == "range": + keys = ["a"] + elif index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + args = get_groupby_method_args(reduction_func, df) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=observed) + + if not observed and reduction_func in ["idxmin", "idxmax"]: + # idxmin and idxmax are designed to fail on empty inputs + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + getattr(gb, reduction_func)(*args) + return + if reduction_func == "corrwith": + warn = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + warn_msg = "" + with tm.assert_produces_warning(warn, match=warn_msg): + op_result = getattr(gb, reduction_func)(*args) + if as_index: + result = op_result.index.get_level_values("a").categories + else: + result = op_result["a"].cat.categories + expected = Index([1, 4, 3, 2]) + tm.assert_index_equal(result, expected) + + if index_kind == "multi": + result = op_result.index.get_level_values("a2").categories + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["single", "multi"]) +@pytest.mark.parametrize("ordered", [True, False]) +def test_category_order_transformer( + as_index, sort, observed, transformation_func, index_kind, ordered +): + # GH#48749 + df = DataFrame( + { + "a": Categorical([2, 1, 2, 3], categories=[1, 4, 3, 2], ordered=ordered), + "b": range(4), + } + ) + if index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + args = get_groupby_method_args(transformation_func, df) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=observed) + op_result = getattr(gb, transformation_func)(*args) + result = op_result.index.get_level_values("a").categories + expected = Index([1, 4, 3, 2]) + tm.assert_index_equal(result, expected) + + if index_kind == "multi": + result = op_result.index.get_level_values("a2").categories + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +@pytest.mark.parametrize("method", ["head", "tail"]) +@pytest.mark.parametrize("ordered", [True, False]) +def test_category_order_head_tail( + as_index, sort, observed, method, index_kind, ordered +): + # GH#48749 + df = DataFrame( + { + "a": Categorical([2, 1, 2, 3], categories=[1, 4, 3, 2], ordered=ordered), + "b": range(4), + } + ) + if index_kind == "range": + keys = ["a"] + elif index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=observed) + op_result = getattr(gb, method)() + if index_kind == "range": + result = op_result["a"].cat.categories + else: + result = op_result.index.get_level_values("a").categories + expected = Index([1, 4, 3, 2]) + tm.assert_index_equal(result, expected) + + if index_kind == "multi": + result = op_result.index.get_level_values("a2").categories + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +@pytest.mark.parametrize("method", ["apply", "agg", "transform"]) +@pytest.mark.parametrize("ordered", [True, False]) +def test_category_order_apply(as_index, sort, observed, method, index_kind, ordered): + # GH#48749 + if (method == "transform" and index_kind == "range") or ( + not as_index and index_kind != "range" + ): + pytest.skip("No categories in result, nothing to test") + df = DataFrame( + { + "a": Categorical([2, 1, 2, 3], categories=[1, 4, 3, 2], ordered=ordered), + "b": range(4), + } + ) + if index_kind == "range": + keys = ["a"] + elif index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=observed) + op_result = getattr(gb, method)(lambda x: x.sum(numeric_only=True)) + if (method == "transform" or not as_index) and index_kind == "range": + result = op_result["a"].cat.categories + else: + result = op_result.index.get_level_values("a").categories + expected = Index([1, 4, 3, 2]) + tm.assert_index_equal(result, expected) + + if index_kind == "multi": + result = op_result.index.get_level_values("a2").categories + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +def test_many_categories(as_index, sort, index_kind, ordered): + # GH#48749 - Test when the grouper has many categories + if index_kind != "range" and not as_index: + pytest.skip(reason="Result doesn't have categories, nothing to test") + categories = np.arange(9999, -1, -1) + grouper = Categorical([2, 1, 2, 3], categories=categories, ordered=ordered) + df = DataFrame({"a": grouper, "b": range(4)}) + if index_kind == "range": + keys = ["a"] + elif index_kind == "single": + keys = ["a"] + df = df.set_index(keys) + elif index_kind == "multi": + keys = ["a", "a2"] + df["a2"] = df["a"] + df = df.set_index(keys) + gb = df.groupby(keys, as_index=as_index, sort=sort, observed=True) + result = gb.sum() + + # Test is setup so that data and index are the same values + data = [3, 2, 1] if sort else [2, 1, 3] + + index = CategoricalIndex( + data, categories=grouper.categories, ordered=ordered, name="a" + ) + if as_index: + expected = DataFrame({"b": data}) + if index_kind == "multi": + expected.index = MultiIndex.from_frame(DataFrame({"a": index, "a2": index})) + else: + expected.index = index + elif index_kind == "multi": + expected = DataFrame({"a": Series(index), "a2": Series(index), "b": data}) + else: + expected = DataFrame({"a": Series(index), "b": data}) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("test_series", [True, False]) +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) +def test_agg_list(request, as_index, observed, reduction_func, test_series, keys): + # GH#52760 + if test_series and reduction_func == "corrwith": + assert not hasattr(SeriesGroupBy, "corrwith") + pytest.skip("corrwith not implemented for SeriesGroupBy") + elif reduction_func == "corrwith": + msg = "GH#32293: attempts to call SeriesGroupBy.corrwith" + request.applymarker(pytest.mark.xfail(reason=msg)) + + df = DataFrame({"a1": [0, 0, 1], "a2": [2, 3, 3], "b": [4, 5, 6]}) + df = df.astype({"a1": "category", "a2": "category"}) + if "a2" not in keys: + df = df.drop(columns="a2") + gb = df.groupby(by=keys, as_index=as_index, observed=observed) + if test_series: + gb = gb["b"] + args = get_groupby_method_args(reduction_func, df) + + if not observed and reduction_func in ["idxmin", "idxmax"] and keys == ["a1", "a2"]: + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + gb.agg([reduction_func], *args) + return + + result = gb.agg([reduction_func], *args) + expected = getattr(gb, reduction_func)(*args) + + if as_index and (test_series or reduction_func == "size"): + expected = expected.to_frame(reduction_func) + if not test_series: + expected.columns = MultiIndex.from_tuples( + [(ind, "") for ind in expected.columns[:-1]] + [("b", reduction_func)] + ) + elif not as_index: + expected.columns = [*keys, reduction_func] + + tm.assert_equal(result, expected) + + +def test_categorical_with_noncategorical_na(observed, sort): + # https://github.com/pandas-dev/pandas/issues/63920 + df = DataFrame( + { + "dates": list("YXXYY"), + "sector": Categorical( + [2, 1, 2, 1, np.nan], categories=[1, 2, 3], ordered=True + ), + "metric": [1, 2, 3, 4, 5], + } + ) + gb = df.groupby(["dates", "sector"], observed=observed, sort=sort) + # Only testing the ids/result_index, okay to just use one kernel + result = gb.sum() + + if sort and observed: + taker = [0, 1, 2, 3] + elif not sort and observed: + taker = [3, 0, 1, 2] + elif sort and not observed: + taker = [0, 1, 4, 2, 3, 5] + elif not sort and not observed: + taker = [3, 0, 1, 2, 5, 4] + expected = ( + DataFrame( + { + "dates": list("XXYYXY"), + "sector": Categorical( + [1, 2, 1, 2, 3, 3], categories=[1, 2, 3], ordered=True + ), + "metric": [2, 3, 4, 1, 0, 0], + } + ) + .set_index(["dates", "sector"]) + .take(taker) + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/groupby/test_counting.py b/pandas/tests/groupby/test_counting.py new file mode 100644 index 0000000000000000000000000000000000000000..679f7eb7f7f11d842dc36c9c6cb83d2096fb66b2 --- /dev/null +++ b/pandas/tests/groupby/test_counting.py @@ -0,0 +1,394 @@ +from itertools import product +from string import ascii_lowercase + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Period, + Series, + Timedelta, + Timestamp, + date_range, +) +import pandas._testing as tm + + +class TestCounting: + def test_cumcount(self): + df = DataFrame([["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"]) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 1, 2, 0, 3]) + + tm.assert_series_equal(expected, g.cumcount()) + tm.assert_series_equal(expected, sg.cumcount()) + + def test_cumcount_empty(self): + ge = DataFrame().groupby(level=0) + se = Series(dtype=object).groupby(level=0) + + # edge case, as this is usually considered float + e = Series(dtype="int64") + + tm.assert_series_equal(e, ge.cumcount()) + tm.assert_series_equal(e, se.cumcount()) + + def test_cumcount_dupe_index(self): + df = DataFrame( + [["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"], index=[0] * 5 + ) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 1, 2, 0, 3], index=[0] * 5) + + tm.assert_series_equal(expected, g.cumcount()) + tm.assert_series_equal(expected, sg.cumcount()) + + def test_cumcount_mi(self): + mi = MultiIndex.from_tuples([[0, 1], [1, 2], [2, 2], [2, 2], [1, 0]]) + df = DataFrame([["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"], index=mi) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 1, 2, 0, 3], index=mi) + + tm.assert_series_equal(expected, g.cumcount()) + tm.assert_series_equal(expected, sg.cumcount()) + + def test_cumcount_groupby_not_col(self): + df = DataFrame( + [["a"], ["a"], ["a"], ["b"], ["a"]], columns=["A"], index=[0] * 5 + ) + g = df.groupby([0, 0, 0, 1, 0]) + sg = g.A + + expected = Series([0, 1, 2, 0, 3], index=[0] * 5) + + tm.assert_series_equal(expected, g.cumcount()) + tm.assert_series_equal(expected, sg.cumcount()) + + def test_ngroup(self): + df = DataFrame({"A": list("aaaba")}) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 0, 0, 1, 0]) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_distinct(self): + df = DataFrame({"A": list("abcde")}) + g = df.groupby("A") + sg = g.A + + expected = Series(range(5), dtype="int64") + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_one_group(self): + df = DataFrame({"A": [0] * 5}) + g = df.groupby("A") + sg = g.A + + expected = Series([0] * 5) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_empty(self): + ge = DataFrame().groupby(level=0) + se = Series(dtype=object).groupby(level=0) + + # edge case, as this is usually considered float + e = Series(dtype="int64") + + tm.assert_series_equal(e, ge.ngroup()) + tm.assert_series_equal(e, se.ngroup()) + + def test_ngroup_series_matches_frame(self): + df = DataFrame({"A": list("aaaba")}) + s = Series(list("aaaba")) + + tm.assert_series_equal(df.groupby(s).ngroup(), s.groupby(s).ngroup()) + + def test_ngroup_dupe_index(self): + df = DataFrame({"A": list("aaaba")}, index=[0] * 5) + g = df.groupby("A") + sg = g.A + + expected = Series([0, 0, 0, 1, 0], index=[0] * 5) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_mi(self): + mi = MultiIndex.from_tuples([[0, 1], [1, 2], [2, 2], [2, 2], [1, 0]]) + df = DataFrame({"A": list("aaaba")}, index=mi) + g = df.groupby("A") + sg = g.A + expected = Series([0, 0, 0, 1, 0], index=mi) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_groupby_not_col(self): + df = DataFrame({"A": list("aaaba")}, index=[0] * 5) + g = df.groupby([0, 0, 0, 1, 0]) + sg = g.A + + expected = Series([0, 0, 0, 1, 0], index=[0] * 5) + + tm.assert_series_equal(expected, g.ngroup()) + tm.assert_series_equal(expected, sg.ngroup()) + + def test_ngroup_descending(self): + df = DataFrame(["a", "a", "b", "a", "b"], columns=["A"]) + g = df.groupby(["A"]) + + ascending = Series([0, 0, 1, 0, 1]) + descending = Series([1, 1, 0, 1, 0]) + + tm.assert_series_equal(descending, (g.ngroups - 1) - ascending) + tm.assert_series_equal(ascending, g.ngroup(ascending=True)) + tm.assert_series_equal(descending, g.ngroup(ascending=False)) + + def test_ngroup_matches_cumcount(self): + # verify one manually-worked out case works + df = DataFrame( + [["a", "x"], ["a", "y"], ["b", "x"], ["a", "x"], ["b", "y"]], + columns=["A", "X"], + ) + g = df.groupby(["A", "X"]) + g_ngroup = g.ngroup() + g_cumcount = g.cumcount() + expected_ngroup = Series([0, 1, 2, 0, 3]) + expected_cumcount = Series([0, 0, 0, 1, 0]) + + tm.assert_series_equal(g_ngroup, expected_ngroup) + tm.assert_series_equal(g_cumcount, expected_cumcount) + + def test_ngroup_cumcount_pair(self): + # brute force comparison for all small series + for p in product(range(3), repeat=4): + df = DataFrame({"a": p}) + g = df.groupby(["a"]) + + order = sorted(set(p)) + ngroupd = [order.index(val) for val in p] + cumcounted = [p[:i].count(val) for i, val in enumerate(p)] + + tm.assert_series_equal(g.ngroup(), Series(ngroupd)) + tm.assert_series_equal(g.cumcount(), Series(cumcounted)) + + def test_ngroup_respects_groupby_order(self, sort): + df = DataFrame({"a": np.random.default_rng(2).choice(list("abcdef"), 100)}) + g = df.groupby("a", sort=sort) + df["group_id"] = -1 + df["group_index"] = -1 + + for i, (_, group) in enumerate(g): + df.loc[group.index, "group_id"] = i + for j, ind in enumerate(group.index): + df.loc[ind, "group_index"] = j + + tm.assert_series_equal(Series(df["group_id"].values), g.ngroup()) + tm.assert_series_equal(Series(df["group_index"].values), g.cumcount()) + + @pytest.mark.parametrize( + "datetimelike", + [ + [Timestamp(f"2016-05-{i:02d} 20:09:25+00:00") for i in range(1, 4)], + [Timestamp(f"2016-05-{i:02d} 20:09:25") for i in range(1, 4)], + [Timestamp(f"2016-05-{i:02d} 20:09:25", tz="UTC") for i in range(1, 4)], + [Timedelta(x, unit="h") for x in range(1, 4)], + [Period(freq="2W", year=2017, month=x) for x in range(1, 4)], + ], + ) + def test_count_with_datetimelike(self, datetimelike): + # test for #13393, where DataframeGroupBy.count() fails + # when counting a datetimelike column. + + df = DataFrame({"x": ["a", "a", "b"], "y": datetimelike}) + res = df.groupby("x").count() + expected = DataFrame({"y": [2, 1]}, index=["a", "b"]) + expected.index.name = "x" + tm.assert_frame_equal(expected, res) + + def test_count_with_only_nans_in_first_group(self): + # GH21956 + df = DataFrame({"A": [np.nan, np.nan], "B": ["a", "b"], "C": [1, 2]}) + result = df.groupby(["A", "B"]).C.count() + mi = MultiIndex(levels=[[], ["a", "b"]], codes=[[], []], names=["A", "B"]) + expected = Series([], index=mi, dtype=np.int64, name="C") + tm.assert_series_equal(result, expected, check_index_type=False) + + def test_count_groupby_column_with_nan_in_groupby_column(self): + # https://github.com/pandas-dev/pandas/issues/32841 + df = DataFrame({"A": [1, 1, 1, 1, 1], "B": [5, 4, np.nan, 3, 0]}) + res = df.groupby(["B"]).count() + expected = DataFrame( + index=Index([0.0, 3.0, 4.0, 5.0], name="B"), data={"A": [1, 1, 1, 1]} + ) + tm.assert_frame_equal(expected, res) + + def test_groupby_count_dateparseerror(self): + dr = date_range(start="1/1/2012", freq="5min", periods=10) + + # BAD Example, datetimes first + ser = Series(np.arange(10), index=[dr, np.arange(10)]) + grouped = ser.groupby(lambda x: x[1] % 2 == 0) + result = grouped.count() + + ser = Series(np.arange(10), index=[np.arange(10), dr]) + grouped = ser.groupby(lambda x: x[0] % 2 == 0) + expected = grouped.count() + + tm.assert_series_equal(result, expected) + + +def test_groupby_timedelta_cython_count(): + df = DataFrame( + {"g": list("ab" * 2), "delta": np.arange(4).astype("timedelta64[ns]")} + ) + expected = Series([2, 2], index=Index(["a", "b"], name="g"), name="delta") + result = df.groupby("g").delta.count() + tm.assert_series_equal(expected, result) + + +def test_count(): + n = 1 << 15 + dr = date_range("2015-08-30", periods=n // 10, freq="min") + + df = DataFrame( + { + "1st": np.random.default_rng(2).choice(list(ascii_lowercase), n), + "2nd": np.random.default_rng(2).integers(0, 5, n), + "3rd": np.random.default_rng(2).standard_normal(n).round(3), + "4th": np.random.default_rng(2).integers(-10, 10, n), + "5th": np.random.default_rng(2).choice(dr, n), + "6th": np.random.default_rng(2).standard_normal(n).round(3), + "7th": np.random.default_rng(2).standard_normal(n).round(3), + "8th": np.random.default_rng(2).choice(dr, n) + - np.random.default_rng(2).choice(dr, 1), + "9th": np.random.default_rng(2).choice(list(ascii_lowercase), n), + } + ) + + for col in df.columns.drop(["1st", "2nd", "4th"]): + df.loc[np.random.default_rng(2).choice(n, n // 10), col] = np.nan + + df["9th"] = df["9th"].astype("category") + + for key in ["1st", "2nd", ["1st", "2nd"]]: + left = df.groupby(key).count() + right = df.groupby(key).apply(DataFrame.count) + tm.assert_frame_equal(left, right) + + +def test_count_non_nulls(): + # GH#5610 + # count counts non-nulls + df = DataFrame( + [[1, 2, "foo"], [1, np.nan, "bar"], [3, np.nan, np.nan]], + columns=["A", "B", "C"], + ) + + count_as = df.groupby("A").count() + count_not_as = df.groupby("A", as_index=False).count() + + expected = DataFrame([[1, 2], [0, 0]], columns=["B", "C"], index=[1, 3]) + expected.index.name = "A" + tm.assert_frame_equal(count_not_as, expected.reset_index()) + tm.assert_frame_equal(count_as, expected) + + count_B = df.groupby("A")["B"].count() + tm.assert_series_equal(count_B, expected["B"]) + + +def test_count_object(): + df = DataFrame({"a": ["a"] * 3 + ["b"] * 3, "c": [2] * 3 + [3] * 3}) + result = df.groupby("c").a.count() + expected = Series([3, 3], index=Index([2, 3], name="c"), name="a") + tm.assert_series_equal(result, expected) + + +def test_count_object_nan(): + df = DataFrame({"a": ["a", np.nan, np.nan] + ["b"] * 3, "c": [2] * 3 + [3] * 3}) + result = df.groupby("c").a.count() + expected = Series([1, 3], index=Index([2, 3], name="c"), name="a") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("typ", ["object", "float32"]) +def test_count_cross_type(typ): + # GH8169 + # Set float64 dtype to avoid upcast when setting nan below + vals = np.hstack( + ( + np.random.default_rng(2).integers(0, 5, (10, 2)), + np.random.default_rng(2).integers(0, 2, (10, 2)), + ) + ).astype("float64") + + df = DataFrame(vals, columns=["a", "b", "c", "d"]) + df[df == 2] = np.nan + expected = df.groupby(["c", "d"]).count() + + df["a"] = df["a"].astype(typ) + df["b"] = df["b"].astype(typ) + result = df.groupby(["c", "d"]).count() + tm.assert_frame_equal(result, expected) + + +def test_lower_int_prec_count(): + df = DataFrame( + { + "a": np.array([0, 1, 2, 100], np.int8), + "b": np.array([1, 2, 3, 6], np.uint32), + "c": np.array([4, 5, 6, 8], np.int16), + "grp": list("ab" * 2), + } + ) + result = df.groupby("grp").count() + expected = DataFrame( + {"a": [2, 2], "b": [2, 2], "c": [2, 2]}, index=Index(list("ab"), name="grp") + ) + tm.assert_frame_equal(result, expected) + + +def test_count_uses_size_on_exception(): + class RaisingObjectException(Exception): + pass + + class RaisingObject: + def __init__(self, msg="I will raise inside Cython") -> None: + super().__init__() + self.msg = msg + + def __eq__(self, other): + # gets called in Cython to check that raising calls the method + raise RaisingObjectException(self.msg) + + df = DataFrame({"a": [RaisingObject() for _ in range(4)], "grp": list("ab" * 2)}) + result = df.groupby("grp").count() + expected = DataFrame({"a": [2, 2]}, index=Index(list("ab"), name="grp")) + tm.assert_frame_equal(result, expected) + + +def test_count_arrow_string_array(any_string_dtype): + # GH#54751 + pytest.importorskip("pyarrow") + df = DataFrame( + {"a": [1, 2, 3], "b": Series(["a", "b", "a"], dtype=any_string_dtype)} + ) + result = df.groupby("a").count() + expected = DataFrame({"b": 1}, index=Index([1, 2, 3], name="a")) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/groupby/test_cumulative.py b/pandas/tests/groupby/test_cumulative.py new file mode 100644 index 0000000000000000000000000000000000000000..cca4971e930b416fb5cbbacd704d6f171ebdfeee --- /dev/null +++ b/pandas/tests/groupby/test_cumulative.py @@ -0,0 +1,332 @@ +import numpy as np +import pytest + +from pandas.errors import UnsupportedFunctionCall +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm + + +@pytest.fixture( + params=[np.int32, np.int64, np.float32, np.float64, "Int64", "Float64"], + ids=["np.int32", "np.int64", "np.float32", "np.float64", "Int64", "Float64"], +) +def dtypes_for_minmax(request): + """ + Fixture of dtypes with min and max values used for testing + cummin and cummax + """ + dtype = request.param + + np_type = dtype + if dtype == "Int64": + np_type = np.int64 + elif dtype == "Float64": + np_type = np.float64 + + min_val = ( + np.iinfo(np_type).min + if np.dtype(np_type).kind == "i" + else np.finfo(np_type).min + ) + max_val = ( + np.iinfo(np_type).max + if np.dtype(np_type).kind == "i" + else np.finfo(np_type).max + ) + + return (dtype, min_val, max_val) + + +def test_groupby_cumprod(): + # GH 4095 + df = DataFrame({"key": ["b"] * 10, "value": 2}) + + actual = df.groupby("key")["value"].cumprod() + expected = df.groupby("key", group_keys=False)["value"].apply(lambda x: x.cumprod()) + expected.name = "value" + tm.assert_series_equal(actual, expected) + + df = DataFrame({"key": ["b"] * 100, "value": 2}) + df["value"] = df["value"].astype(float) + actual = df.groupby("key")["value"].cumprod() + expected = df.groupby("key", group_keys=False)["value"].apply(lambda x: x.cumprod()) + expected.name = "value" + tm.assert_series_equal(actual, expected) + + +def test_groupby_cumprod_overflow(): + # GH#37493 if we overflow we return garbage consistent with numpy + df = DataFrame({"key": ["b"] * 4, "value": 100_000}) + actual = df.groupby("key")["value"].cumprod() + expected = Series( + [100_000, 10_000_000_000, 1_000_000_000_000_000, 7766279631452241920], + name="value", + ) + tm.assert_series_equal(actual, expected) + + numpy_result = df.groupby("key", group_keys=False)["value"].apply( + lambda x: x.cumprod() + ) + numpy_result.name = "value" + tm.assert_series_equal(actual, numpy_result) + + +def test_groupby_cumprod_nan_influences_other_columns(): + # GH#48064 + df = DataFrame( + { + "a": 1, + "b": [1, np.nan, 2], + "c": [1, 2, 3.0], + } + ) + result = df.groupby("a").cumprod(numeric_only=True, skipna=False) + expected = DataFrame({"b": [1, np.nan, np.nan], "c": [1, 2, 6.0]}) + tm.assert_frame_equal(result, expected) + + +def test_cummin(dtypes_for_minmax): + dtype = dtypes_for_minmax[0] + + # GH 15048 + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [3, 4, 3, 2, 2, 3, 2, 1]}) + expected_mins = [3, 3, 3, 2, 2, 2, 2, 1] + + df = base_df.astype(dtype) + expected = DataFrame({"B": expected_mins}).astype(dtype) + result = df.groupby("A").cummin() + tm.assert_frame_equal(result, expected) + result = df.groupby("A", group_keys=False).B.apply(lambda x: x.cummin()).to_frame() + tm.assert_frame_equal(result, expected) + + +def test_cummin_min_value_for_dtype(dtypes_for_minmax): + dtype = dtypes_for_minmax[0] + min_val = dtypes_for_minmax[1] + + # GH 15048 + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [3, 4, 3, 2, 2, 3, 2, 1]}) + expected_mins = [3, 3, 3, 2, 2, 2, 2, 1] + expected = DataFrame({"B": expected_mins}).astype(dtype) + df = base_df.astype(dtype) + df.loc[[2, 6], "B"] = min_val + df.loc[[1, 5], "B"] = min_val + 1 + expected.loc[[2, 3, 6, 7], "B"] = min_val + expected.loc[[1, 5], "B"] = min_val + 1 # should not be rounded to min_val + result = df.groupby("A").cummin() + tm.assert_frame_equal(result, expected, check_exact=True) + expected = ( + df.groupby("A", group_keys=False).B.apply(lambda x: x.cummin()).to_frame() + ) + tm.assert_frame_equal(result, expected, check_exact=True) + + +def test_cummin_nan_in_some_values(dtypes_for_minmax): + # Explicit cast to float to avoid implicit cast when setting nan + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [3, 4, 3, 2, 2, 3, 2, 1]}) + base_df = base_df.astype({"B": "float"}) + base_df.loc[[0, 2, 4, 6], "B"] = np.nan + expected = DataFrame({"B": [np.nan, 4, np.nan, 2, np.nan, 3, np.nan, 1]}) + result = base_df.groupby("A").cummin() + tm.assert_frame_equal(result, expected) + expected = ( + base_df.groupby("A", group_keys=False).B.apply(lambda x: x.cummin()).to_frame() + ) + tm.assert_frame_equal(result, expected) + + +def test_cummin_datetime(): + # GH 15561 + df = DataFrame({"a": [1], "b": pd.to_datetime(["2001"])}) + expected = Series(pd.to_datetime("2001"), index=[0], name="b") + + result = df.groupby("a")["b"].cummin() + tm.assert_series_equal(expected, result) + + +def test_cummin_getattr_series(): + # GH 15635 + df = DataFrame({"a": [1, 2, 1], "b": [1, 2, 2]}) + result = df.groupby("a").b.cummin() + expected = Series([1, 2, 1], name="b") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["cummin", "cummax"]) +@pytest.mark.parametrize("dtype", ["UInt64", "Int64", "Float64", "float", "boolean"]) +def test_cummin_max_all_nan_column(method, dtype): + item = np.nan if dtype == "float" else pd.NA + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [item] * 8}) + base_df["B"] = base_df["B"].astype(dtype) + grouped = base_df.groupby("A") + + expected = DataFrame({"B": [item] * 8}, dtype=dtype) + result = getattr(grouped, method)() + tm.assert_frame_equal(expected, result) + + result = getattr(grouped["B"], method)().to_frame() + tm.assert_frame_equal(expected, result) + + +def test_cummax(dtypes_for_minmax): + dtype = dtypes_for_minmax[0] + + # GH 15048 + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [3, 4, 3, 2, 2, 3, 2, 1]}) + expected_maxs = [3, 4, 4, 4, 2, 3, 3, 3] + + df = base_df.astype(dtype) + + expected = DataFrame({"B": expected_maxs}).astype(dtype) + result = df.groupby("A").cummax() + tm.assert_frame_equal(result, expected) + result = df.groupby("A", group_keys=False).B.apply(lambda x: x.cummax()).to_frame() + tm.assert_frame_equal(result, expected) + + +def test_cummax_min_value_for_dtype(dtypes_for_minmax): + dtype = dtypes_for_minmax[0] + max_val = dtypes_for_minmax[2] + + # GH 15048 + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [3, 4, 3, 2, 2, 3, 2, 1]}) + expected_maxs = [3, 4, 4, 4, 2, 3, 3, 3] + + df = base_df.astype(dtype) + df.loc[[2, 6], "B"] = max_val + expected = DataFrame({"B": expected_maxs}).astype(dtype) + expected.loc[[2, 3, 6, 7], "B"] = max_val + result = df.groupby("A").cummax() + tm.assert_frame_equal(result, expected) + expected = ( + df.groupby("A", group_keys=False).B.apply(lambda x: x.cummax()).to_frame() + ) + tm.assert_frame_equal(result, expected) + + +def test_cummax_nan_in_some_values(dtypes_for_minmax): + # Test nan in some values + # Explicit cast to float to avoid implicit cast when setting nan + base_df = DataFrame({"A": [1, 1, 1, 1, 2, 2, 2, 2], "B": [3, 4, 3, 2, 2, 3, 2, 1]}) + base_df = base_df.astype({"B": "float"}) + base_df.loc[[0, 2, 4, 6], "B"] = np.nan + expected = DataFrame({"B": [np.nan, 4, np.nan, 4, np.nan, 3, np.nan, 3]}) + result = base_df.groupby("A").cummax() + tm.assert_frame_equal(result, expected) + expected = ( + base_df.groupby("A", group_keys=False).B.apply(lambda x: x.cummax()).to_frame() + ) + tm.assert_frame_equal(result, expected) + + +def test_cummax_datetime(): + # GH 15561 + df = DataFrame({"a": [1], "b": pd.to_datetime(["2001"])}) + expected = Series(pd.to_datetime("2001"), index=[0], name="b") + + result = df.groupby("a")["b"].cummax() + tm.assert_series_equal(expected, result) + + +def test_cummax_getattr_series(): + # GH 15635 + df = DataFrame({"a": [1, 2, 1], "b": [2, 1, 1]}) + result = df.groupby("a").b.cummax() + expected = Series([2, 1, 2], name="b") + tm.assert_series_equal(result, expected) + + +def test_cummax_i8_at_implementation_bound(): + # the minimum value used to be treated as NPY_NAT+1 instead of NPY_NAT + # for int64 dtype GH#46382 + ser = Series([pd.NaT._value + n for n in range(5)]) + df = DataFrame({"A": 1, "B": ser, "C": ser._values.view("M8[ns]")}) + gb = df.groupby("A") + + res = gb.cummax() + exp = df[["B", "C"]] + tm.assert_frame_equal(res, exp) + + +@pytest.mark.parametrize("method", ["cummin", "cummax"]) +@pytest.mark.parametrize("dtype", ["float", "Int64", "Float64"]) +@pytest.mark.parametrize( + "groups,expected_data", + [ + ([1, 1, 1], [1, None, None]), + ([1, 2, 3], [1, None, 2]), + ([1, 3, 3], [1, None, None]), + ], +) +def test_cummin_max_skipna(method, dtype, groups, expected_data): + # GH-34047 + df = DataFrame({"a": Series([1, None, 2], dtype=dtype)}) + orig = df.copy() + gb = df.groupby(groups)["a"] + + result = getattr(gb, method)(skipna=False) + expected = Series(expected_data, dtype=dtype, name="a") + + # check we didn't accidentally alter df + tm.assert_frame_equal(df, orig) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["cummin", "cummax"]) +def test_cummin_max_skipna_multiple_cols(method): + # Ensure missing value in "a" doesn't cause "b" to be nan-filled + df = DataFrame({"a": [np.nan, 2.0, 2.0], "b": [2.0, 2.0, 2.0]}) + gb = df.groupby([1, 1, 1])[["a", "b"]] + + result = getattr(gb, method)(skipna=False) + expected = DataFrame({"a": [np.nan, np.nan, np.nan], "b": [2.0, 2.0, 2.0]}) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["cumprod", "cumsum"]) +def test_numpy_compat(func): + # see gh-12811 + df = DataFrame({"A": [1, 2, 1], "B": [1, 2, 3]}) + g = df.groupby("A") + + msg = "numpy operations are not valid with groupby" + + with pytest.raises(UnsupportedFunctionCall, match=msg): + getattr(g, func)(1, 2, 3) + with pytest.raises(UnsupportedFunctionCall, match=msg): + getattr(g, func)(foo=1) + + +@td.skip_if_32bit +@pytest.mark.parametrize("method", ["cummin", "cummax"]) +@pytest.mark.parametrize( + "dtype,val", [("UInt64", np.iinfo("uint64").max), ("Int64", 2**53 + 1)] +) +def test_nullable_int_not_cast_as_float(method, dtype, val): + data = [val, pd.NA] + df = DataFrame({"grp": [1, 1], "b": data}, dtype=dtype) + grouped = df.groupby("grp") + + result = grouped.transform(method) + expected = DataFrame({"b": data}, dtype=dtype) + + tm.assert_frame_equal(result, expected) + + +def test_cython_api2(as_index): + # this takes the fast apply path + + # cumsum (GH5614) + # GH 5755 - cumsum is a transformer and should ignore as_index + df = DataFrame([[1, 2, np.nan], [1, np.nan, 9], [3, 4, 9]], columns=["A", "B", "C"]) + expected = DataFrame([[2, np.nan], [np.nan, 9], [4, 9]], columns=["B", "C"]) + result = df.groupby("A", as_index=as_index).cumsum() + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/groupby/test_filters.py b/pandas/tests/groupby/test_filters.py new file mode 100644 index 0000000000000000000000000000000000000000..c20fc9e3d62e77949383e9fd47080d0e23f94875 --- /dev/null +++ b/pandas/tests/groupby/test_filters.py @@ -0,0 +1,638 @@ +from string import ascii_lowercase + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Series, + Timestamp, +) +import pandas._testing as tm + + +def test_filter_series(): + s = Series([1, 3, 20, 5, 22, 24, 7]) + expected_odd = Series([1, 3, 5, 7], index=[0, 1, 3, 6]) + expected_even = Series([20, 22, 24], index=[2, 4, 5]) + grouper = s.apply(lambda x: x % 2) + grouped = s.groupby(grouper) + tm.assert_series_equal(grouped.filter(lambda x: x.mean() < 10), expected_odd) + tm.assert_series_equal(grouped.filter(lambda x: x.mean() > 10), expected_even) + # Test dropna=False. + tm.assert_series_equal( + grouped.filter(lambda x: x.mean() < 10, dropna=False), + expected_odd.reindex(s.index), + ) + tm.assert_series_equal( + grouped.filter(lambda x: x.mean() > 10, dropna=False), + expected_even.reindex(s.index), + ) + + +def test_filter_single_column_df(): + df = DataFrame([1, 3, 20, 5, 22, 24, 7]) + expected_odd = DataFrame([1, 3, 5, 7], index=[0, 1, 3, 6]) + expected_even = DataFrame([20, 22, 24], index=[2, 4, 5]) + grouper = df[0].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + tm.assert_frame_equal(grouped.filter(lambda x: x.mean() < 10), expected_odd) + tm.assert_frame_equal(grouped.filter(lambda x: x.mean() > 10), expected_even) + # Test dropna=False. + tm.assert_frame_equal( + grouped.filter(lambda x: x.mean() < 10, dropna=False), + expected_odd.reindex(df.index), + ) + tm.assert_frame_equal( + grouped.filter(lambda x: x.mean() > 10, dropna=False), + expected_even.reindex(df.index), + ) + + +def test_filter_multi_column_df(): + df = DataFrame({"A": [1, 12, 12, 1], "B": [1, 1, 1, 1]}) + grouper = df["A"].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + expected = DataFrame({"A": [12, 12], "B": [1, 1]}, index=[1, 2]) + tm.assert_frame_equal( + grouped.filter(lambda x: x["A"].sum() - x["B"].sum() > 10), expected + ) + + +def test_filter_mixed_df(): + df = DataFrame({"A": [1, 12, 12, 1], "B": "a b c d".split()}) + grouper = df["A"].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + expected = DataFrame({"A": [12, 12], "B": ["b", "c"]}, index=[1, 2]) + tm.assert_frame_equal(grouped.filter(lambda x: x["A"].sum() > 10), expected) + + +def test_filter_out_all_groups(): + s = Series([1, 3, 20, 5, 22, 24, 7]) + grouper = s.apply(lambda x: x % 2) + grouped = s.groupby(grouper) + tm.assert_series_equal(grouped.filter(lambda x: x.mean() > 1000), s[[]]) + df = DataFrame({"A": [1, 12, 12, 1], "B": "a b c d".split()}) + grouper = df["A"].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + tm.assert_frame_equal(grouped.filter(lambda x: x["A"].sum() > 1000), df.loc[[]]) + + +def test_filter_out_no_groups(): + s = Series([1, 3, 20, 5, 22, 24, 7]) + grouper = s.apply(lambda x: x % 2) + grouped = s.groupby(grouper) + filtered = grouped.filter(lambda x: x.mean() > 0) + tm.assert_series_equal(filtered, s) + + +def test_filter_out_no_groups_dataframe(): + df = DataFrame({"A": [1, 12, 12, 1], "B": "a b c d".split()}) + grouper = df["A"].apply(lambda x: x % 2) + grouped = df.groupby(grouper) + filtered = grouped.filter(lambda x: x["A"].mean() > 0) + tm.assert_frame_equal(filtered, df) + + +def test_filter_out_all_groups_in_df(): + # GH12768 + df = DataFrame({"a": [1, 1, 2], "b": [1, 2, 0]}) + res = df.groupby("a") + res = res.filter(lambda x: x["b"].sum() > 5, dropna=False) + expected = DataFrame({"a": [np.nan] * 3, "b": [np.nan] * 3}) + tm.assert_frame_equal(expected, res) + + +def test_filter_out_all_groups_in_df_dropna_true(): + # GH12768 + df = DataFrame({"a": [1, 1, 2], "b": [1, 2, 0]}) + res = df.groupby("a") + res = res.filter(lambda x: x["b"].sum() > 5, dropna=True) + expected = DataFrame({"a": [], "b": []}, dtype="int64") + tm.assert_frame_equal(expected, res) + + +def test_filter_condition_raises(): + def raise_if_sum_is_zero(x): + if x.sum() == 0: + raise ValueError + return x.sum() > 0 + + s = Series([-1, 0, 1, 2]) + grouper = s.apply(lambda x: x % 2) + grouped = s.groupby(grouper) + msg = "the filter must return a boolean result" + with pytest.raises(TypeError, match=msg): + grouped.filter(raise_if_sum_is_zero) + + +def test_filter_bad_shapes(): + df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)}) + s = df["B"] + g_df = df.groupby("B") + g_s = s.groupby(s) + + f = lambda x: x + msg = "filter function returned a DataFrame, but expected a scalar bool" + with pytest.raises(TypeError, match=msg): + g_df.filter(f) + msg = "the filter must return a boolean result" + with pytest.raises(TypeError, match=msg): + g_s.filter(f) + + f = lambda x: x == 1 + msg = "filter function returned a DataFrame, but expected a scalar bool" + with pytest.raises(TypeError, match=msg): + g_df.filter(f) + msg = "the filter must return a boolean result" + with pytest.raises(TypeError, match=msg): + g_s.filter(f) + + f = lambda x: np.outer(x, x) + msg = "can't multiply sequence by non-int of type 'str'" + with pytest.raises(TypeError, match=msg): + g_df.filter(f) + msg = "the filter must return a boolean result" + with pytest.raises(TypeError, match=msg): + g_s.filter(f) + + +def test_filter_nan_is_false(): + df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)}) + s = df["B"] + g_df = df.groupby(df["B"]) + g_s = s.groupby(s) + + f = lambda x: np.nan + tm.assert_frame_equal(g_df.filter(f), df.loc[[]]) + tm.assert_series_equal(g_s.filter(f), s[[]]) + + +def test_filter_pdna_is_false(): + # in particular, dont raise in filter trying to call bool(pd.NA) + df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)}) + ser = df["B"] + g_df = df.groupby(df["B"]) + g_s = ser.groupby(ser) + + func = lambda x: pd.NA + res = g_df.filter(func) + tm.assert_frame_equal(res, df.loc[[]]) + res = g_s.filter(func) + tm.assert_series_equal(res, ser[[]]) + + +def test_filter_against_workaround_ints(): + # Series of ints + s = Series(np.random.default_rng(2).integers(0, 100, 10)) + grouper = s.apply(lambda x: np.round(x, -1)) + grouped = s.groupby(grouper) + f = lambda x: x.mean() > 10 + + old_way = s[grouped.transform(f).astype("bool")] + new_way = grouped.filter(f) + tm.assert_series_equal(new_way.sort_values(), old_way.sort_values()) + + +def test_filter_against_workaround_floats(): + # Series of floats + s = 100 * Series(np.random.default_rng(2).random(10)) + grouper = s.apply(lambda x: np.round(x, -1)) + grouped = s.groupby(grouper) + f = lambda x: x.mean() > 10 + old_way = s[grouped.transform(f).astype("bool")] + new_way = grouped.filter(f) + tm.assert_series_equal(new_way.sort_values(), old_way.sort_values()) + + +def test_filter_against_workaround_dataframe(): + # Set up DataFrame of ints, floats, strings. + letters = np.array(list(ascii_lowercase)) + N = 10 + random_letters = letters.take( + np.random.default_rng(2).integers(0, 26, N, dtype=int) + ) + df = DataFrame( + { + "ints": Series(np.random.default_rng(2).integers(0, 10, N)), + "floats": N / 10 * Series(np.random.default_rng(2).random(N)), + "letters": Series(random_letters), + } + ) + + # Group by ints; filter on floats. + grouped = df.groupby("ints") + old_way = df[grouped.floats.transform(lambda x: x.mean() > N / 2).astype("bool")] + new_way = grouped.filter(lambda x: x["floats"].mean() > N / 2) + tm.assert_frame_equal(new_way, old_way) + + # Group by floats (rounded); filter on strings. + grouper = df.floats.apply(lambda x: np.round(x, -1)) + grouped = df.groupby(grouper) + old_way = df[grouped.letters.transform(lambda x: len(x) < N / 2).astype("bool")] + new_way = grouped.filter(lambda x: len(x.letters) < N / 2) + tm.assert_frame_equal(new_way, old_way) + + # Group by strings; filter on ints. + grouped = df.groupby("letters") + old_way = df[grouped.ints.transform(lambda x: x.mean() > N / 2).astype("bool")] + new_way = grouped.filter(lambda x: x["ints"].mean() > N / 2) + tm.assert_frame_equal(new_way, old_way) + + +def test_filter_using_len(): + # GH 4447 + df = DataFrame({"A": np.arange(8), "B": list("aabbbbcc"), "C": np.arange(8)}) + grouped = df.groupby("B") + actual = grouped.filter(lambda x: len(x) > 2) + expected = DataFrame( + {"A": np.arange(2, 6), "B": list("bbbb"), "C": np.arange(2, 6)}, + index=range(2, 6), + ) + tm.assert_frame_equal(actual, expected) + + actual = grouped.filter(lambda x: len(x) > 4) + expected = df.loc[[]] + tm.assert_frame_equal(actual, expected) + + +def test_filter_using_len_series(): + # GH 4447 + s = Series(list("aabbbbcc"), name="B") + grouped = s.groupby(s) + actual = grouped.filter(lambda x: len(x) > 2) + expected = Series(4 * ["b"], index=range(2, 6), name="B") + tm.assert_series_equal(actual, expected) + + actual = grouped.filter(lambda x: len(x) > 4) + expected = s[[]] + tm.assert_series_equal(actual, expected) + + +@pytest.mark.parametrize( + "index", [range(8), range(7, -1, -1), [0, 2, 1, 3, 4, 6, 5, 7]] +) +def test_filter_maintains_ordering(index): + # GH 4621 + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + s = df["pid"] + grouped = df.groupby("tag") + actual = grouped.filter(lambda x: len(x) > 1) + expected = df.iloc[[1, 2, 4, 7]] + tm.assert_frame_equal(actual, expected) + + grouped = s.groupby(df["tag"]) + actual = grouped.filter(lambda x: len(x) > 1) + expected = s.iloc[[1, 2, 4, 7]] + tm.assert_series_equal(actual, expected) + + +def test_filter_multiple_timestamp(): + # GH 10114 + df = DataFrame( + { + "A": np.arange(5, dtype="int64"), + "B": ["foo", "bar", "foo", "bar", "bar"], + "C": Timestamp("20130101"), + } + ) + + grouped = df.groupby(["B", "C"]) + + result = grouped["A"].filter(lambda x: True) + tm.assert_series_equal(df["A"], result) + + result = grouped["A"].transform(len) + expected = Series([2, 3, 2, 3, 3], name="A") + tm.assert_series_equal(result, expected) + + result = grouped.filter(lambda x: True) + tm.assert_frame_equal(df, result) + + result = grouped.transform("sum") + expected = DataFrame({"A": [2, 8, 2, 8, 8]}) + tm.assert_frame_equal(result, expected) + + result = grouped.transform(len) + expected = DataFrame({"A": [2, 3, 2, 3, 3]}) + tm.assert_frame_equal(result, expected) + + +def test_filter_and_transform_with_non_unique_int_index(): + # GH4620 + index = [1, 1, 1, 2, 1, 1, 0, 1] + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_and_transform_with_multiple_non_unique_int_index(): + # GH4620 + index = [1, 1, 1, 2, 0, 0, 0, 1] + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_and_transform_with_non_unique_float_index(): + # GH4620 + index = np.array([1, 1, 1, 2, 1, 1, 0, 1], dtype=float) + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_and_transform_with_non_unique_timestamp_index(): + # GH4620 + t0 = Timestamp("2013-09-30 00:05:00") + t1 = Timestamp("2013-10-30 00:05:00") + t2 = Timestamp("2013-11-30 00:05:00") + index = [t1, t1, t1, t2, t1, t1, t0, t1] + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_and_transform_with_non_unique_string_index(): + # GH4620 + index = list("bbbcbbab") + df = DataFrame( + {"pid": [1, 1, 1, 2, 2, 3, 3, 3], "tag": [23, 45, 62, 24, 45, 34, 25, 62]}, + index=index, + ) + grouped_df = df.groupby("tag") + ser = df["pid"] + grouped_ser = ser.groupby(df["tag"]) + expected_indexes = [1, 2, 4, 7] + + # Filter DataFrame + actual = grouped_df.filter(lambda x: len(x) > 1) + expected = df.iloc[expected_indexes] + tm.assert_frame_equal(actual, expected) + + actual = grouped_df.filter(lambda x: len(x) > 1, dropna=False) + # Cast to avoid upcast when setting nan below + expected = df.copy().astype("float64") + expected.iloc[[0, 3, 5, 6]] = np.nan + tm.assert_frame_equal(actual, expected) + + # Filter Series + actual = grouped_ser.filter(lambda x: len(x) > 1) + expected = ser.take(expected_indexes) + tm.assert_series_equal(actual, expected) + + actual = grouped_ser.filter(lambda x: len(x) > 1, dropna=False) + expected = Series([np.nan, 1, 1, np.nan, 2, np.nan, np.nan, 3], index, name="pid") + # ^ made manually because this can get confusing! + tm.assert_series_equal(actual, expected) + + # Transform Series + actual = grouped_ser.transform(len) + expected = Series([1, 2, 2, 1, 2, 1, 1, 2], index, name="pid") + tm.assert_series_equal(actual, expected) + + # Transform (a column from) DataFrameGroupBy + actual = grouped_df.pid.transform(len) + tm.assert_series_equal(actual, expected) + + +def test_filter_has_access_to_grouped_cols(): + df = DataFrame([[1, 2], [1, 3], [5, 6]], columns=["A", "B"]) + g = df.groupby("A") + # previously didn't have access to col A #???? + filt = g.filter(lambda x: x["A"].sum() == 2) + tm.assert_frame_equal(filt, df.iloc[[0, 1]]) + + +def test_filter_enforces_scalarness(): + df = DataFrame( + [ + ["best", "a", "x"], + ["worst", "b", "y"], + ["best", "c", "x"], + ["best", "d", "y"], + ["worst", "d", "y"], + ["worst", "d", "y"], + ["best", "d", "z"], + ], + columns=["a", "b", "c"], + ) + with pytest.raises(TypeError, match="filter function returned a.*"): + df.groupby("c").filter(lambda g: g["a"] == "best") + + +def test_filter_non_bool_raises(): + df = DataFrame( + [ + ["best", "a", 1], + ["worst", "b", 1], + ["best", "c", 1], + ["best", "d", 1], + ["worst", "d", 1], + ["worst", "d", 1], + ["best", "d", 1], + ], + columns=["a", "b", "c"], + ) + with pytest.raises(TypeError, match="filter function returned a.*"): + df.groupby("a").filter(lambda g: g.c.mean()) + + +def test_filter_dropna_with_empty_groups(): + # GH 10780 + data = Series(np.random.default_rng(2).random(9), index=np.repeat([1, 2, 3], 3)) + grouped = data.groupby(level=0) + result_false = grouped.filter(lambda x: x.mean() > 1, dropna=False) + expected_false = Series([np.nan] * 9, index=np.repeat([1, 2, 3], 3)) + tm.assert_series_equal(result_false, expected_false) + + result_true = grouped.filter(lambda x: x.mean() > 1, dropna=True) + expected_true = Series(index=pd.Index([], dtype=int), dtype=np.float64) + tm.assert_series_equal(result_true, expected_true) + + +def test_filter_consistent_result_before_after_agg_func(): + # GH 17091 + df = DataFrame({"data": range(6), "key": list("ABCABC")}) + grouper = df.groupby("key") + result = grouper.filter(lambda x: True) + expected = DataFrame({"data": range(6), "key": list("ABCABC")}) + tm.assert_frame_equal(result, expected) + + grouper.sum() + result = grouper.filter(lambda x: True) + tm.assert_frame_equal(result, expected) + + +def test_filter_with_non_values(): + # GH 62501 + df = DataFrame( + [ + [1], + [None], + ], + columns=["a"], + ) + + result = df.groupby("a", dropna=False).filter(lambda x: True) + tm.assert_frame_equal(result, df) + + +def test_filter_with_non_values_multi_index(): + # GH 62501 + df = DataFrame( + [ + [1, 2], + [3, None], + [None, 4], + [None, None], + ], + columns=["a", "b"], + ) + + result = df.groupby(["a", "b"], dropna=False).filter(lambda x: True) + tm.assert_frame_equal(result, df) diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..54716bfff0fbafae2e354a2066ae6324d05f62bd --- /dev/null +++ b/pandas/tests/groupby/test_groupby.py @@ -0,0 +1,3004 @@ +from datetime import datetime +import decimal +from decimal import Decimal +import re + +import numpy as np +import pytest + +from pandas.errors import SpecificationError +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Grouper, + Index, + Interval, + MultiIndex, + RangeIndex, + Series, + Timedelta, + Timestamp, + date_range, + to_datetime, +) +import pandas._testing as tm +from pandas.core.arrays import BooleanArray +import pandas.core.common as com + +pytestmark = pytest.mark.filterwarnings("ignore:Mean of empty slice:RuntimeWarning") + + +def test_repr(): + # GH18203 + result = repr(Grouper(key="A", level="B")) + expected = "Grouper(key='A', level='B', sort=False, dropna=True)" + assert result == expected + + +def test_groupby_nonobject_dtype(multiindex_dataframe_random_data): + key = multiindex_dataframe_random_data.index.codes[0] + grouped = multiindex_dataframe_random_data.groupby(key) + result = grouped.sum() + + expected = multiindex_dataframe_random_data.groupby(key.astype("O")).sum() + assert result.index.dtype == np.int8 + assert expected.index.dtype == np.int64 + tm.assert_frame_equal(result, expected, check_index_type=False) + + +def test_groupby_nonobject_dtype_mixed(): + # GH 3911, mixed frame non-conversion + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.array(np.random.default_rng(2).standard_normal(8), dtype="float32"), + } + ) + df["value"] = range(len(df)) + + def max_value(group): + return group.loc[group["value"].idxmax()] + + applied = df.groupby("A").apply(max_value) + result = applied.dtypes + expected = df.drop(columns="A").dtypes + tm.assert_series_equal(result, expected) + + +def test_pass_args_kwargs(ts): + def f(x, q=None, axis=0): + return np.percentile(x, q, axis=axis) + + g = lambda x: np.percentile(x, 80, axis=0) + + # Series + ts_grouped = ts.groupby(lambda x: x.month) + agg_result = ts_grouped.agg(np.percentile, 80, axis=0) + apply_result = ts_grouped.apply(np.percentile, 80, axis=0) + trans_result = ts_grouped.transform(np.percentile, 80, axis=0) + + agg_expected = ts_grouped.quantile(0.8) + trans_expected = ts_grouped.transform(g) + + tm.assert_series_equal(apply_result, agg_expected) + tm.assert_series_equal(agg_result, agg_expected) + tm.assert_series_equal(trans_result, trans_expected) + + agg_result = ts_grouped.agg(f, q=80) + apply_result = ts_grouped.apply(f, q=80) + trans_result = ts_grouped.transform(f, q=80) + tm.assert_series_equal(agg_result, agg_expected) + tm.assert_series_equal(apply_result, agg_expected) + tm.assert_series_equal(trans_result, trans_expected) + + +def test_pass_args_kwargs_dataframe(tsframe, as_index): + def f(x, q=None, axis=0): + return np.percentile(x, q, axis=axis) + + df_grouped = tsframe.groupby(lambda x: x.month, as_index=as_index) + agg_result = df_grouped.agg(np.percentile, 80, axis=0) + apply_result = df_grouped.apply(DataFrame.quantile, 0.8) + expected = df_grouped.quantile(0.8) + tm.assert_frame_equal(apply_result, expected, check_names=False) + tm.assert_frame_equal(agg_result, expected) + + apply_result = df_grouped.apply(DataFrame.quantile, [0.4, 0.8]) + expected_seq = df_grouped.quantile([0.4, 0.8]) + if not as_index: + # apply treats the op as a transform; .quantile knows it's a reduction + apply_result.index = range(4) + apply_result.insert(loc=0, column="level_0", value=[1, 1, 2, 2]) + apply_result.insert(loc=1, column="level_1", value=[0.4, 0.8, 0.4, 0.8]) + tm.assert_frame_equal(apply_result, expected_seq, check_names=False) + + agg_result = df_grouped.agg(f, q=80) + apply_result = df_grouped.apply(DataFrame.quantile, q=0.8) + tm.assert_frame_equal(agg_result, expected) + tm.assert_frame_equal(apply_result, expected, check_names=False) + + +def test_len(): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + grouped = df.groupby([lambda x: x.year, lambda x: x.month, lambda x: x.day]) + assert len(grouped) == len(df) + + grouped = df.groupby([lambda x: x.year, lambda x: x.month]) + expected = len({(x.year, x.month) for x in df.index}) + assert len(grouped) == expected + + +def test_len_nan_group(): + # issue 11016 + df = DataFrame({"a": [np.nan] * 3, "b": [1, 2, 3]}) + assert len(df.groupby("a")) == 0 + assert len(df.groupby("b")) == 3 + assert len(df.groupby(["a", "b"])) == 0 + + +def test_groupby_timedelta_median(): + # issue 57926 + expected = Series(data=Timedelta("1D"), index=["foo"], dtype="m8[us]") + df = DataFrame({"label": ["foo", "foo"], "timedelta": [pd.NaT, Timedelta("1D")]}) + gb = df.groupby("label")["timedelta"] + actual = gb.median() + tm.assert_series_equal(actual, expected, check_names=False) + + +@pytest.mark.parametrize("keys", [["a"], ["a", "b"]]) +def test_len_categorical(dropna, observed, keys): + # GH#57595 + df = DataFrame( + { + "a": Categorical([1, 1, 2, np.nan], categories=[1, 2, 3]), + "b": Categorical([1, 1, 2, np.nan], categories=[1, 2, 3]), + "c": 1, + } + ) + gb = df.groupby(keys, observed=observed, dropna=dropna) + result = len(gb) + if observed and dropna: + expected = 2 + elif observed and not dropna: + expected = 3 + elif len(keys) == 1: + expected = 3 if dropna else 4 + else: + expected = 9 if dropna else 16 + assert result == expected, f"{result} vs {expected}" + + +def test_basic_regression(): + # regression + result = Series([1.0 * x for x in list(range(1, 10)) * 10]) + + data = np.random.default_rng(2).random(1100) * 10.0 + groupings = Series(data) + + grouped = result.groupby(groupings) + grouped.mean() + + +def test_indices_concatenation_order(): + # GH 2808 + + def f1(x): + y = x[(x.b % 2) == 1] ** 2 + if y.empty: + multiindex = MultiIndex(levels=[[]] * 2, codes=[[]] * 2, names=["b", "c"]) + res = DataFrame(columns=["a"], index=multiindex) + return res + else: + y = y.set_index(["b", "c"]) + return y + + def f2(x): + y = x[(x.b % 2) == 1] ** 2 + if y.empty: + return DataFrame() + else: + y = y.set_index(["b", "c"]) + return y + + def f3(x): + y = x[(x.b % 2) == 1] ** 2 + if y.empty: + multiindex = MultiIndex( + levels=[[]] * 2, codes=[[]] * 2, names=["foo", "bar"] + ) + res = DataFrame(columns=["a", "b"], index=multiindex) + return res + else: + return y + + df = DataFrame({"a": [1, 2, 2, 2], "b": range(4), "c": range(5, 9)}) + + df2 = DataFrame({"a": [3, 2, 2, 2], "b": range(4), "c": range(5, 9)}) + + # correct result + result1 = df.groupby("a").apply(f1) + result2 = df2.groupby("a").apply(f1) + tm.assert_frame_equal(result1, result2) + + # should fail (not the same number of levels) + msg = "Cannot concat indices that do not have the same number of levels" + with pytest.raises(AssertionError, match=msg): + df.groupby("a").apply(f2) + with pytest.raises(AssertionError, match=msg): + df2.groupby("a").apply(f2) + + # should fail (incorrect shape) + with pytest.raises(AssertionError, match=msg): + df.groupby("a").apply(f3) + with pytest.raises(AssertionError, match=msg): + df2.groupby("a").apply(f3) + + +def test_attr_wrapper(ts): + grouped = ts.groupby(lambda x: x.weekday()) + + result = grouped.std() + expected = grouped.agg(lambda x: np.std(x, ddof=1)) + tm.assert_series_equal(result, expected) + + # this is pretty cool + result = grouped.describe() + expected = {name: gp.describe() for name, gp in grouped} + expected = DataFrame(expected).T + tm.assert_frame_equal(result, expected) + + # get attribute + result = grouped.dtype + expected = grouped.agg(lambda x: x.dtype) + tm.assert_series_equal(result, expected) + + # make sure raises error + msg = "'SeriesGroupBy' object has no attribute 'foo'" + with pytest.raises(AttributeError, match=msg): + grouped.foo + + +def test_frame_groupby(tsframe): + grouped = tsframe.groupby(lambda x: x.weekday()) + + # aggregate + aggregated = grouped.aggregate("mean") + assert len(aggregated) == 5 + assert len(aggregated.columns) == 4 + + # by string + tscopy = tsframe.copy() + tscopy["weekday"] = [x.weekday() for x in tscopy.index] + stragged = tscopy.groupby("weekday").aggregate("mean") + tm.assert_frame_equal(stragged, aggregated, check_names=False) + + # transform + grouped = tsframe.head(30).groupby(lambda x: x.weekday()) + transformed = grouped.transform(lambda x: x - x.mean()) + assert len(transformed) == 30 + assert len(transformed.columns) == 4 + + # transform propagate + transformed = grouped.transform(lambda x: x.mean()) + for name, group in grouped: + mean = group.mean() + for idx in group.index: + tm.assert_series_equal(transformed.xs(idx), mean, check_names=False) + + # iterate + for weekday, group in grouped: + assert group.index[0].weekday() == weekday + + # groups / group_indices + groups = grouped.groups + indices = grouped.indices + + for k, v in groups.items(): + samething = tsframe.index.take(indices[k]) + assert (samething == v).all() + + +def test_frame_set_name_single(df): + grouped = df.groupby("A") + + result = grouped.mean(numeric_only=True) + assert result.index.name == "A" + + result = df.groupby("A", as_index=False).mean(numeric_only=True) + assert result.index.name != "A" + + result = grouped[["C", "D"]].agg("mean") + assert result.index.name == "A" + + result = grouped.agg({"C": "mean", "D": "std"}) + assert result.index.name == "A" + + result = grouped["C"].mean() + assert result.index.name == "A" + result = grouped["C"].agg("mean") + assert result.index.name == "A" + result = grouped["C"].agg(["mean", "std"]) + assert result.index.name == "A" + + msg = r"nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + grouped["C"].agg({"foo": "mean", "bar": "std"}) + + +def test_multi_func(df): + col1 = df["A"] + col2 = df["B"] + + grouped = df.groupby([col1.get, col2.get]) + agged = grouped.mean(numeric_only=True) + expected = df.groupby(["A", "B"]).mean() + + # TODO groupby get drops names + tm.assert_frame_equal( + agged.loc[:, ["C", "D"]], expected.loc[:, ["C", "D"]], check_names=False + ) + + # some "groups" with no data + df = DataFrame( + { + "v1": np.random.default_rng(2).standard_normal(6), + "v2": np.random.default_rng(2).standard_normal(6), + "k1": np.array(["b", "b", "b", "a", "a", "a"]), + "k2": np.array(["1", "1", "1", "2", "2", "2"]), + }, + index=["one", "two", "three", "four", "five", "six"], + ) + # only verify that it works for now + grouped = df.groupby(["k1", "k2"]) + grouped.agg("sum") + + +def test_multi_key_multiple_functions(df): + grouped = df.groupby(["A", "B"])["C"] + + agged = grouped.agg(["mean", "std"]) + expected = DataFrame({"mean": grouped.agg("mean"), "std": grouped.agg("std")}) + tm.assert_frame_equal(agged, expected) + + +def test_frame_multi_key_function_list(): + data = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + grouped = data.groupby(["A", "B"]) + funcs = ["mean", "std"] + agged = grouped.agg(funcs) + expected = pd.concat( + [grouped["D"].agg(funcs), grouped["E"].agg(funcs), grouped["F"].agg(funcs)], + keys=["D", "E", "F"], + axis=1, + ) + assert isinstance(agged.index, MultiIndex) + assert isinstance(expected.index, MultiIndex) + tm.assert_frame_equal(agged, expected) + + +def test_frame_multi_key_function_list_partial_failure(using_infer_string): + data = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + grouped = data.groupby(["A", "B"]) + funcs = ["mean", "std"] + msg = re.escape("agg function failed [how->mean,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'mean'" + with pytest.raises(TypeError, match=msg): + grouped.agg(funcs) + + +@pytest.mark.parametrize("op", [lambda x: x.sum(), lambda x: x.mean()]) +def test_groupby_multiple_columns(df, op): + data = df + grouped = data.groupby(["A", "B"]) + + result1 = op(grouped) + + keys = [] + values = [] + for n1, gp1 in data.groupby("A"): + for n2, gp2 in gp1.groupby("B"): + keys.append((n1, n2)) + values.append(op(gp2.loc[:, ["C", "D"]])) + + mi = MultiIndex.from_tuples(keys, names=["A", "B"]) + expected = pd.concat(values, axis=1).T + expected.index = mi + + # a little bit crude + for col in ["C", "D"]: + result_col = op(grouped[col]) + pivoted = result1[col] + exp = expected[col] + tm.assert_series_equal(result_col, exp) + tm.assert_series_equal(pivoted, exp) + + # test single series works the same + result = data["C"].groupby([data["A"], data["B"]]).mean() + expected = data.groupby(["A", "B"]).mean()["C"] + + tm.assert_series_equal(result, expected) + + +def test_as_index_select_column(): + # GH 5764 + df = DataFrame([[1, 2], [1, 4], [5, 6]], columns=["A", "B"]) + result = df.groupby("A", as_index=False)["B"].get_group(1) + expected = Series([2, 4], name="B") + tm.assert_series_equal(result, expected) + + result = df.groupby("A", as_index=False, group_keys=True)["B"].apply( + lambda x: x.cumsum() + ) + expected = Series([2, 6, 6], name="B", index=range(3)) + tm.assert_series_equal(result, expected) + + +def test_groupby_as_index_select_column_sum_empty_df(): + # GH 35246 + df = DataFrame(columns=Index(["A", "B", "C"], name="alpha")) + left = df.groupby(by="A", as_index=False)["B"].sum(numeric_only=False) + + expected = DataFrame(columns=df.columns[:2], index=range(0)) + # GH#50744 - Columns after selection shouldn't retain names + expected.columns.names = [None] + tm.assert_frame_equal(left, expected) + + +def test_ops_not_as_index(reduction_func): + # GH 10355, 21090 + # Using as_index=False should not modify grouped column + + if reduction_func in ("corrwith", "nth", "ngroup"): + pytest.skip(f"GH 5755: Test not applicable for {reduction_func}") + + df = DataFrame( + np.random.default_rng(2).integers(0, 5, size=(100, 2)), columns=["a", "b"] + ) + expected = getattr(df.groupby("a"), reduction_func)() + if reduction_func == "size": + expected = expected.rename("size") + expected = expected.reset_index() + + if reduction_func != "size": + # 32 bit compat -> groupby preserves dtype whereas reset_index casts to int64 + expected["a"] = expected["a"].astype(df["a"].dtype) + + g = df.groupby("a", as_index=False) + + result = getattr(g, reduction_func)() + tm.assert_frame_equal(result, expected) + + result = g.agg(reduction_func) + tm.assert_frame_equal(result, expected) + + result = getattr(g["b"], reduction_func)() + tm.assert_frame_equal(result, expected) + + result = g["b"].agg(reduction_func) + tm.assert_frame_equal(result, expected) + + +def test_as_index_series_return_frame(df): + grouped = df.groupby("A", as_index=False) + grouped2 = df.groupby(["A", "B"], as_index=False) + + result = grouped["C"].agg("sum") + expected = grouped.agg("sum").loc[:, ["A", "C"]] + assert isinstance(result, DataFrame) + tm.assert_frame_equal(result, expected) + + result2 = grouped2["C"].agg("sum") + expected2 = grouped2.agg("sum").loc[:, ["A", "B", "C"]] + assert isinstance(result2, DataFrame) + tm.assert_frame_equal(result2, expected2) + + result = grouped["C"].sum() + expected = grouped.sum().loc[:, ["A", "C"]] + assert isinstance(result, DataFrame) + tm.assert_frame_equal(result, expected) + + result2 = grouped2["C"].sum() + expected2 = grouped2.sum().loc[:, ["A", "B", "C"]] + assert isinstance(result2, DataFrame) + tm.assert_frame_equal(result2, expected2) + + +def test_as_index_series_column_slice_raises(df): + # GH15072 + grouped = df.groupby("A", as_index=False) + msg = r"Column\(s\) C already selected" + + with pytest.raises(IndexError, match=msg): + grouped["C"].__getitem__("D") + + +def test_groupby_as_index_cython(df): + data = df + + # single-key + grouped = data.groupby("A", as_index=False) + result = grouped.mean(numeric_only=True) + expected = data.groupby(["A"]).mean(numeric_only=True) + expected.insert(0, "A", expected.index) + expected.index = RangeIndex(len(expected)) + tm.assert_frame_equal(result, expected) + + # multi-key + grouped = data.groupby(["A", "B"], as_index=False) + result = grouped.mean() + expected = data.groupby(["A", "B"]).mean() + + arrays = list(zip(*expected.index.values, strict=True)) + expected.insert(0, "A", arrays[0]) + expected.insert(1, "B", arrays[1]) + expected.index = RangeIndex(len(expected)) + tm.assert_frame_equal(result, expected) + + +def test_groupby_as_index_series_scalar(df): + grouped = df.groupby(["A", "B"], as_index=False) + + # GH #421 + + result = grouped["C"].agg(len) + expected = grouped.agg(len).loc[:, ["A", "B", "C"]] + tm.assert_frame_equal(result, expected) + + +def test_groupby_multiple_key(): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + grouped = df.groupby([lambda x: x.year, lambda x: x.month, lambda x: x.day]) + agged = grouped.sum() + tm.assert_almost_equal(df.values, agged.values) + + +def test_groupby_multi_corner(df): + # test that having an all-NA column doesn't mess you up + df = df.copy() + df["bad"] = np.nan + agged = df.groupby(["A", "B"]).mean() + + expected = df.groupby(["A", "B"]).mean() + expected["bad"] = np.nan + + tm.assert_frame_equal(agged, expected) + + +def test_raises_on_nuisance(df, using_infer_string): + grouped = df.groupby("A") + msg = re.escape("agg function failed [how->mean,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'mean'" + with pytest.raises(TypeError, match=msg): + grouped.agg("mean") + with pytest.raises(TypeError, match=msg): + grouped.mean() + + df = df.loc[:, ["A", "C", "D"]] + df["E"] = datetime.now() + grouped = df.groupby("A") + msg = "datetime64 type does not support operation 'sum'" + with pytest.raises(TypeError, match=msg): + grouped.agg("sum") + with pytest.raises(TypeError, match=msg): + grouped.sum() + + +@pytest.mark.parametrize( + "agg_function", + ["max", "min"], +) +def test_keep_nuisance_agg(df, agg_function): + # GH 38815 + grouped = df.groupby("A") + result = getattr(grouped, agg_function)() + expected = result.copy() + expected.loc["bar", "B"] = getattr(df.loc[df["A"] == "bar", "B"], agg_function)() + expected.loc["foo", "B"] = getattr(df.loc[df["A"] == "foo", "B"], agg_function)() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg_function", + ["sum", "mean", "prod", "std", "var", "sem", "median"], +) +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_omit_nuisance_agg(df, agg_function, numeric_only, using_infer_string): + # GH 38774, GH 38815 + grouped = df.groupby("A") + + no_drop_nuisance = ("var", "std", "sem", "mean", "prod", "median") + if agg_function in no_drop_nuisance and not numeric_only: + # Added numeric_only as part of GH#46560; these do not drop nuisance + # columns when numeric_only is False + if using_infer_string: + msg = f"dtype 'str' does not support operation '{agg_function}'" + klass = TypeError + elif agg_function in ("std", "sem"): + klass = ValueError + msg = "could not convert string to float: 'one'" + else: + klass = TypeError + msg = re.escape(f"agg function failed [how->{agg_function},dtype->") + with pytest.raises(klass, match=msg): + getattr(grouped, agg_function)(numeric_only=numeric_only) + else: + result = getattr(grouped, agg_function)(numeric_only=numeric_only) + if not numeric_only and agg_function == "sum": + # sum is successful on column B + columns = ["A", "B", "C", "D"] + else: + columns = ["A", "C", "D"] + expected = getattr(df.loc[:, columns].groupby("A"), agg_function)( + numeric_only=numeric_only + ) + tm.assert_frame_equal(result, expected) + + +def test_raise_on_nuisance_python_single(df, using_infer_string): + # GH 38815 + grouped = df.groupby("A") + + err = ValueError + msg = "could not convert" + if using_infer_string: + err = TypeError + msg = "dtype 'str' does not support operation 'skew'" + with pytest.raises(err, match=msg): + grouped.skew() + + +def test_raise_on_nuisance_python_multiple(three_group, using_infer_string): + grouped = three_group.groupby(["A", "B"]) + msg = re.escape("agg function failed [how->mean,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'mean'" + with pytest.raises(TypeError, match=msg): + grouped.agg("mean") + with pytest.raises(TypeError, match=msg): + grouped.mean() + + +def test_empty_groups_corner(multiindex_dataframe_random_data): + # handle empty groups + df = DataFrame( + { + "k1": np.array(["b", "b", "b", "a", "a", "a"]), + "k2": np.array(["1", "1", "1", "2", "2", "2"]), + "k3": ["foo", "bar"] * 3, + "v1": np.random.default_rng(2).standard_normal(6), + "v2": np.random.default_rng(2).standard_normal(6), + } + ) + + grouped = df.groupby(["k1", "k2"]) + result = grouped[["v1", "v2"]].agg("mean") + expected = grouped.mean(numeric_only=True) + tm.assert_frame_equal(result, expected) + + grouped = multiindex_dataframe_random_data[3:5].groupby(level=0) + agged = grouped.apply(lambda x: x.mean()) + agged_A = grouped["A"].apply("mean") + tm.assert_series_equal(agged["A"], agged_A) + assert agged.index.name == "first" + + +def test_nonsense_func(): + df = DataFrame([0]) + msg = r"unsupported operand type\(s\) for \+: 'int' and 'str'" + with pytest.raises(TypeError, match=msg): + df.groupby(lambda x: x + "foo") + + +def test_wrap_aggregated_output_multindex( + multiindex_dataframe_random_data, using_infer_string +): + df = multiindex_dataframe_random_data.T + df["baz", "two"] = "peekaboo" + + keys = [np.array([0, 0, 1]), np.array([0, 0, 1])] + msg = re.escape("agg function failed [how->mean,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'mean'" + with pytest.raises(TypeError, match=msg): + df.groupby(keys).agg("mean") + agged = df.drop(columns=("baz", "two")).groupby(keys).agg("mean") + assert isinstance(agged.columns, MultiIndex) + + def aggfun(ser): + if ser.name == ("foo", "one"): + raise TypeError("Test error message") + return ser.sum() + + with pytest.raises(TypeError, match="Test error message"): + df.groupby(keys).aggregate(aggfun) + + +def test_groupby_level_apply(multiindex_dataframe_random_data): + result = multiindex_dataframe_random_data.groupby(level=0).count() + assert result.index.name == "first" + result = multiindex_dataframe_random_data.groupby(level=1).count() + assert result.index.name == "second" + + result = multiindex_dataframe_random_data["A"].groupby(level=0).count() + assert result.index.name == "first" + + +def test_groupby_level_mapper(multiindex_dataframe_random_data): + deleveled = multiindex_dataframe_random_data.reset_index() + + mapper0 = {"foo": 0, "bar": 0, "baz": 1, "qux": 1} + mapper1 = {"one": 0, "two": 0, "three": 1} + + result0 = multiindex_dataframe_random_data.groupby(mapper0, level=0).sum() + result1 = multiindex_dataframe_random_data.groupby(mapper1, level=1).sum() + + mapped_level0 = np.array( + [mapper0.get(x) for x in deleveled["first"]], dtype=np.int64 + ) + mapped_level1 = np.array( + [mapper1.get(x) for x in deleveled["second"]], dtype=np.int64 + ) + expected0 = multiindex_dataframe_random_data.groupby(mapped_level0).sum() + expected1 = multiindex_dataframe_random_data.groupby(mapped_level1).sum() + expected0.index.name, expected1.index.name = "first", "second" + + tm.assert_frame_equal(result0, expected0) + tm.assert_frame_equal(result1, expected1) + + +def test_groupby_level_nonmulti(): + # GH 1313, GH 13901 + s = Series([1, 2, 3, 10, 4, 5, 20, 6], Index([1, 2, 3, 1, 4, 5, 2, 6], name="foo")) + expected = Series([11, 22, 3, 4, 5, 6], Index(list(range(1, 7)), name="foo")) + + result = s.groupby(level=0).sum() + tm.assert_series_equal(result, expected) + result = s.groupby(level=[0]).sum() + tm.assert_series_equal(result, expected) + result = s.groupby(level=-1).sum() + tm.assert_series_equal(result, expected) + result = s.groupby(level=[-1]).sum() + tm.assert_series_equal(result, expected) + + msg = "level > 0 or level < -1 only valid with MultiIndex" + with pytest.raises(ValueError, match=msg): + s.groupby(level=1) + with pytest.raises(ValueError, match=msg): + s.groupby(level=-2) + msg = "No group keys passed!" + with pytest.raises(ValueError, match=msg): + s.groupby(level=[]) + msg = "multiple levels only valid with MultiIndex" + with pytest.raises(ValueError, match=msg): + s.groupby(level=[0, 0]) + with pytest.raises(ValueError, match=msg): + s.groupby(level=[0, 1]) + msg = "level > 0 or level < -1 only valid with MultiIndex" + with pytest.raises(ValueError, match=msg): + s.groupby(level=[1]) + + +def test_groupby_complex(): + # GH 12902 + a = Series(data=np.arange(4) * (1 + 2j), index=[0, 0, 1, 1]) + expected = Series((1 + 2j, 5 + 10j), index=Index([0, 1])) + + result = a.groupby(level=0).sum() + tm.assert_series_equal(result, expected) + + +def test_groupby_complex_mean(): + # GH 26475 + df = DataFrame( + [ + {"a": 2, "b": 1 + 2j}, + {"a": 1, "b": 1 + 1j}, + {"a": 1, "b": 1 + 2j}, + ] + ) + result = df.groupby("b").mean() + expected = DataFrame( + [[1.0], [1.5]], + index=Index([(1 + 1j), (1 + 2j)], name="b"), + columns=Index(["a"]), + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_complex_numbers(): + # GH 17927 + df = DataFrame( + [ + {"a": 1, "b": 1 + 1j}, + {"a": 1, "b": 1 + 2j}, + {"a": 4, "b": 1}, + ] + ) + expected = DataFrame( + np.array([1, 1, 1], dtype=np.int64), + index=Index([(1 + 1j), (1 + 2j), (1 + 0j)], name="b"), + columns=Index(["a"]), + ) + result = df.groupby("b", sort=False).count() + tm.assert_frame_equal(result, expected) + + # Sorted by the magnitude of the complex numbers + expected.index = Index([(1 + 0j), (1 + 1j), (1 + 2j)], name="b") + result = df.groupby("b", sort=True).count() + tm.assert_frame_equal(result, expected) + + +def test_groupby_series_indexed_differently(): + s1 = Series( + [5.0, -9.0, 4.0, 100.0, -5.0, 55.0, 6.7], + index=Index(["a", "b", "c", "d", "e", "f", "g"]), + ) + s2 = Series( + [1.0, 1.0, 4.0, 5.0, 5.0, 7.0], index=Index(["a", "b", "d", "f", "g", "h"]) + ) + + grouped = s1.groupby(s2) + agged = grouped.mean() + exp = s1.groupby(s2.reindex(s1.index).get).mean() + tm.assert_series_equal(agged, exp) + + +def test_groupby_with_hier_columns(): + tuples = list( + zip( + *[ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ], + strict=True, + ) + ) + index = MultiIndex.from_tuples(tuples) + columns = MultiIndex.from_tuples( + [("A", "cat"), ("B", "dog"), ("B", "cat"), ("A", "dog")] + ) + df = DataFrame( + np.random.default_rng(2).standard_normal((8, 4)), index=index, columns=columns + ) + + result = df.groupby(level=0).mean() + tm.assert_index_equal(result.columns, columns) + + result = df.groupby(level=0).agg("mean") + tm.assert_index_equal(result.columns, columns) + + result = df.groupby(level=0).apply(lambda x: x.mean()) + tm.assert_index_equal(result.columns, columns) + + # add a nuisance column + sorted_columns, _ = columns.sortlevel(0) + df["A", "foo"] = "bar" + result = df.groupby(level=0).mean(numeric_only=True) + tm.assert_index_equal(result.columns, df.columns[:-1]) + + +def test_grouping_ndarray(df): + grouped = df.groupby(df["A"].values) + grouped2 = df.groupby(df["A"].rename(None)) + + result = grouped.sum() + expected = grouped2.sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_wrong_multi_labels(): + index = Index([0, 1, 2, 3, 4], name="index") + data = DataFrame( + { + "foo": ["foo1", "foo1", "foo2", "foo1", "foo3"], + "bar": ["bar1", "bar2", "bar2", "bar1", "bar1"], + "baz": ["baz1", "baz1", "baz1", "baz2", "baz2"], + "spam": ["spam2", "spam3", "spam2", "spam1", "spam1"], + "data": [20, 30, 40, 50, 60], + }, + index=index, + ) + + grouped = data.groupby(["foo", "bar", "baz", "spam"]) + + result = grouped.agg("mean") + expected = grouped.mean() + tm.assert_frame_equal(result, expected) + + +def test_groupby_series_with_name(df): + result = df.groupby(df["A"]).mean(numeric_only=True) + result2 = df.groupby(df["A"], as_index=False).mean(numeric_only=True) + assert result.index.name == "A" + assert "A" in result2 + + result = df.groupby([df["A"], df["B"]]).mean() + result2 = df.groupby([df["A"], df["B"]], as_index=False).mean() + assert result.index.names == ("A", "B") + assert "A" in result2 + assert "B" in result2 + + +def test_seriesgroupby_name_attr(df): + # GH 6265 + result = df.groupby("A")["C"] + assert result.count().name == "C" + assert result.mean().name == "C" + + testFunc = lambda x: np.sum(x) * 2 + assert result.agg(testFunc).name == "C" + + +def test_consistency_name(): + # GH 12363 + + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "two", "two", "two", "one", "two"], + "C": np.random.default_rng(2).standard_normal(8) + 1.0, + "D": np.arange(8), + } + ) + + expected = df.groupby(["A"]).B.count() + result = df.B.groupby(df.A).count() + tm.assert_series_equal(result, expected) + + +def test_groupby_name_propagation(df): + # GH 6124 + def summarize(df, name=None): + return Series({"count": 1, "mean": 2, "omissions": 3}, name=name) + + def summarize_random_name(df): + # Provide a different name for each Series. In this case, groupby + # should not attempt to propagate the Series name since they are + # inconsistent. + return Series({"count": 1, "mean": 2, "omissions": 3}, name=df.iloc[0]["C"]) + + metrics = df.groupby("A").apply(summarize) + assert metrics.columns.name is None + metrics = df.groupby("A").apply(summarize, "metrics") + assert metrics.columns.name == "metrics" + metrics = df.groupby("A").apply(summarize_random_name) + assert metrics.columns.name is None + + +def test_groupby_nonstring_columns(): + df = DataFrame([np.arange(10) for x in range(10)]) + grouped = df.groupby(0) + result = grouped.mean() + expected = df.groupby(df[0]).mean() + tm.assert_frame_equal(result, expected) + + +def test_groupby_mixed_type_columns(): + # GH 13432, unorderable types in py3 + df = DataFrame([[0, 1, 2]], columns=["A", "B", 0]) + expected = DataFrame([[1, 2]], columns=["B", 0], index=Index([0], name="A")) + + result = df.groupby("A").first() + tm.assert_frame_equal(result, expected) + + result = df.groupby("A").sum() + tm.assert_frame_equal(result, expected) + + +def test_cython_grouper_series_bug_noncontig(): + arr = np.empty((100, 100)) + arr.fill(np.nan) + obj = Series(arr[:, 0]) + inds = np.tile(range(10), 10) + + result = obj.groupby(inds).agg(Series.median) + assert result.isna().all() + + +def test_series_grouper_noncontig_index(): + index = Index(["a" * 10] * 100) + + values = Series(np.random.default_rng(2).standard_normal(50), index=index[::2]) + labels = np.random.default_rng(2).integers(0, 5, 50) + + # it works! + grouped = values.groupby(labels) + + # accessing the index elements causes segfault + f = lambda x: len(set(map(id, x.index))) + grouped.agg(f) + + +def test_convert_objects_leave_decimal_alone(): + s = Series(range(5)) + labels = np.array(["a", "b", "c", "d", "e"], dtype="O") + + def convert_fast(x): + return Decimal(str(x.mean())) + + def convert_force_pure(x): + # base will be length 0 + assert len(x.values.base) > 0 + return Decimal(str(x.mean())) + + grouped = s.groupby(labels) + + result = grouped.agg(convert_fast) + assert result.dtype == np.object_ + assert isinstance(result.iloc[0], Decimal) + + result = grouped.agg(convert_force_pure) + assert result.dtype == np.object_ + assert isinstance(result.iloc[0], Decimal) + + +def test_groupby_dtype_inference_empty(): + # GH 6733 + df = DataFrame({"x": [], "range": np.arange(0, dtype="int64")}) + assert df["x"].dtype == np.float64 + + result = df.groupby("x").first() + exp_index = Index([], name="x", dtype=np.float64) + expected = DataFrame({"range": Series([], index=exp_index, dtype="int64")}) + tm.assert_frame_equal(result, expected, by_blocks=True) + + +def test_groupby_unit64_float_conversion(): + # GH: 30859 groupby converts unit64 to floats sometimes + df = DataFrame({"first": [1], "second": [1], "value": [16148277970000000000]}) + result = df.groupby(["first", "second"])["value"].max() + expected = Series( + [16148277970000000000], + MultiIndex.from_product([[1], [1]], names=["first", "second"]), + name="value", + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_list_infer_array_like(df): + result = df.groupby(list(df["A"])).mean(numeric_only=True) + expected = df.groupby(df["A"]).mean(numeric_only=True) + tm.assert_frame_equal(result, expected, check_names=False) + + with pytest.raises(KeyError, match=r"^'foo'$"): + df.groupby(list(df["A"][:-1])) + + # pathological case of ambiguity + df = DataFrame( + { + "foo": [0, 1], + "bar": [3, 4], + "val": np.random.default_rng(2).standard_normal(2), + } + ) + + result = df.groupby(["foo", "bar"]).mean() + expected = df.groupby([df["foo"], df["bar"]]).mean()[["val"]] + + +def test_groupby_keys_same_size_as_index(): + # GH 11185 + freq = "s" + index = date_range( + start=Timestamp("2015-09-29T11:34:44-0700"), periods=2, freq=freq + ) + df = DataFrame([["A", 10], ["B", 15]], columns=["metric", "values"], index=index) + result = df.groupby([Grouper(level=0, freq=freq), "metric"]).mean() + expected = df.set_index([df.index, "metric"]).astype(float) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_one_row(): + # GH 11741 + msg = r"^'Z'$" + df1 = DataFrame( + np.random.default_rng(2).standard_normal((1, 4)), columns=list("ABCD") + ) + with pytest.raises(KeyError, match=msg): + df1.groupby("Z") + df2 = DataFrame( + np.random.default_rng(2).standard_normal((2, 4)), columns=list("ABCD") + ) + with pytest.raises(KeyError, match=msg): + df2.groupby("Z") + + +def test_groupby_nat_exclude(): + # GH 6992 + df = DataFrame( + { + "values": np.random.default_rng(2).standard_normal(8), + "dt": [ + np.nan, + Timestamp("2013-01-01"), + np.nan, + Timestamp("2013-02-01"), + np.nan, + Timestamp("2013-02-01"), + np.nan, + Timestamp("2013-01-01"), + ], + "str": [np.nan, "a", np.nan, "a", np.nan, "a", np.nan, "b"], + } + ) + grouped = df.groupby("dt") + + expected = [ + RangeIndex(start=1, stop=13, step=6), + RangeIndex(start=3, stop=7, step=2), + ] + keys = sorted(grouped.groups.keys()) + assert len(keys) == 2 + for k, e in zip(keys, expected, strict=True): + # grouped.groups keys are np.datetime64 with system tz + # not to be affected by tz, only compare values + tm.assert_index_equal(grouped.groups[k], e) + + # confirm obj is not filtered + tm.assert_frame_equal(grouped._grouper.groupings[0].obj, df) + assert grouped.ngroups == 2 + + expected = { + Timestamp("2013-01-01 00:00:00"): np.array([1, 7], dtype=np.intp), + Timestamp("2013-02-01 00:00:00"): np.array([3, 5], dtype=np.intp), + } + + for k in grouped.indices: + tm.assert_numpy_array_equal(grouped.indices[k], expected[k]) + + tm.assert_frame_equal(grouped.get_group(Timestamp("2013-01-01")), df.iloc[[1, 7]]) + tm.assert_frame_equal(grouped.get_group(Timestamp("2013-02-01")), df.iloc[[3, 5]]) + + with pytest.raises(KeyError, match=r"^NaT$"): + grouped.get_group(pd.NaT) + + nan_df = DataFrame( + {"nan": [np.nan, np.nan, np.nan], "nat": [pd.NaT, pd.NaT, pd.NaT]} + ) + assert nan_df["nan"].dtype == "float64" + assert nan_df["nat"].dtype == "datetime64[s]" + + for key in ["nan", "nat"]: + grouped = nan_df.groupby(key) + assert grouped.groups == {} + assert grouped.ngroups == 0 + assert grouped.indices == {} + with pytest.raises(KeyError, match=r"^nan$"): + grouped.get_group(np.nan) + with pytest.raises(KeyError, match=r"^NaT$"): + grouped.get_group(pd.NaT) + + +def test_groupby_two_group_keys_all_nan(): + # GH #36842: Grouping over two group keys shouldn't raise an error + df = DataFrame({"a": [np.nan, np.nan], "b": [np.nan, np.nan], "c": [1, 2]}) + result = df.groupby(["a", "b"]).indices + assert result == {} + + +def test_groupby_2d_malformed(): + d = DataFrame(index=range(2)) + d["group"] = ["g1", "g2"] + d["zeros"] = [0, 0] + d["ones"] = [1, 1] + d["label"] = ["l1", "l2"] + tmp = d.groupby(["group"]).mean(numeric_only=True) + res_values = np.array([[0.0, 1.0], [0.0, 1.0]]) + tm.assert_index_equal(tmp.columns, Index(["zeros", "ones"])) + tm.assert_numpy_array_equal(tmp.values, res_values) + + +def test_int32_overflow(): + B = np.concatenate((np.arange(10000), np.arange(10000), np.arange(5000))) + A = np.arange(25000) + df = DataFrame( + { + "A": A, + "B": B, + "C": A, + "D": B, + "E": np.random.default_rng(2).standard_normal(25000), + } + ) + + left = df.groupby(["A", "B", "C", "D"]).sum() + right = df.groupby(["D", "C", "B", "A"]).sum() + assert len(left) == len(right) + + +def test_groupby_sort_multi(): + df = DataFrame( + { + "a": ["foo", "bar", "baz"], + "b": [3, 2, 1], + "c": [0, 1, 2], + "d": np.random.default_rng(2).standard_normal(3), + } + ) + + tups = [tuple(row) for row in df[["a", "b", "c"]].values] + tups = com.asarray_tuplesafe(tups) + result = df.groupby(["a", "b", "c"], sort=True).sum() + tm.assert_numpy_array_equal(result.index.values, tups[[1, 2, 0]]) + + tups = [tuple(row) for row in df[["c", "a", "b"]].values] + tups = com.asarray_tuplesafe(tups) + result = df.groupby(["c", "a", "b"], sort=True).sum() + tm.assert_numpy_array_equal(result.index.values, tups) + + tups = [tuple(x) for x in df[["b", "c", "a"]].values] + tups = com.asarray_tuplesafe(tups) + result = df.groupby(["b", "c", "a"], sort=True).sum() + tm.assert_numpy_array_equal(result.index.values, tups[[2, 1, 0]]) + + df = DataFrame( + { + "a": [0, 1, 2, 0, 1, 2], + "b": [0, 0, 0, 1, 1, 1], + "d": np.random.default_rng(2).standard_normal(6), + } + ) + grouped = df.groupby(["a", "b"])["d"] + result = grouped.sum() + + def _check_groupby(df, result, keys, field, f=lambda x: x.sum()): + tups = [tuple(row) for row in df[keys].values] + tups = com.asarray_tuplesafe(tups) + expected = f(df.groupby(tups)[field]) + for k, v in expected.items(): + assert result[k] == v + + _check_groupby(df, result, ["a", "b"], "d") + + +def test_dont_clobber_name_column(): + df = DataFrame( + {"key": ["a", "a", "a", "b", "b", "b"], "name": ["foo", "bar", "baz"] * 2} + ) + + result = df.groupby("key", group_keys=False).apply(lambda x: x) + tm.assert_frame_equal(result, df[["name"]]) + + +def test_skip_group_keys(): + tsf = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + + grouped = tsf.groupby(lambda x: x.month, group_keys=False) + result = grouped.apply(lambda x: x.sort_values(by="A")[:3]) + + pieces = [group.sort_values(by="A")[:3] for key, group in grouped] + + expected = pd.concat(pieces) + tm.assert_frame_equal(result, expected) + + grouped = tsf["A"].groupby(lambda x: x.month, group_keys=False) + result = grouped.apply(lambda x: x.sort_values()[:3]) + + pieces = [group.sort_values()[:3] for key, group in grouped] + + expected = pd.concat(pieces) + tm.assert_series_equal(result, expected) + + +def test_no_nonsense_name(float_frame): + # GH #995 + s = float_frame["C"].copy() + s.name = None + + result = s.groupby(float_frame["A"]).agg("sum") + assert result.name is None + + +def test_multifunc_sum_bug(): + # GH #1065 + x = DataFrame(np.arange(9).reshape(3, 3)) + x["test"] = 0 + x["fl"] = [1.3, 1.5, 1.6] + + grouped = x.groupby("test") + result = grouped.agg({"fl": "sum", 2: "size"}) + assert result["fl"].dtype == np.float64 + + +def test_handle_dict_return_value(df): + def f(group): + return {"max": group.max(), "min": group.min()} + + def g(group): + return Series({"max": group.max(), "min": group.min()}) + + result = df.groupby("A")["C"].apply(f) + expected = df.groupby("A")["C"].apply(g) + + assert isinstance(result, Series) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("grouper", ["A", ["A", "B"]]) +def test_set_group_name(df, grouper): + def f(group): + assert group.name is not None + return group + + def freduce(group): + assert group.name is not None + return group.sum() + + def freducex(x): + return freduce(x) + + grouped = df.groupby(grouper, group_keys=False) + + # make sure all these work + grouped.apply(f) + grouped.aggregate(freduce) + grouped.aggregate({"C": freduce, "D": freduce}) + grouped.transform(f) + + grouped["C"].apply(f) + grouped["C"].aggregate(freduce) + grouped["C"].aggregate([freduce, freducex]) + grouped["C"].transform(f) + + +def test_group_name_available_in_inference_pass(): + # gh-15062 + df = DataFrame({"a": [0, 0, 1, 1, 2, 2], "b": np.arange(6)}) + + names = [] + + def f(group): + names.append(group.name) + return group.copy() + + df.groupby("a", sort=False, group_keys=False).apply(f) + expected_names = [0, 1, 2] + assert names == expected_names + + +def test_no_dummy_key_names(df): + # see gh-1291 + result = df.groupby(df["A"].values).sum() + assert result.index.name is None + + result2 = df.groupby([df["A"].values, df["B"].values]).sum() + assert result2.index.names == (None, None) + + +def test_groupby_sort_multiindex_series(): + # series multiindex groupby sort argument was not being passed through + # _compress_group_index + # GH 9444 + index = MultiIndex( + levels=[[1, 2], [1, 2]], + codes=[[0, 0, 0, 0, 1, 1], [1, 1, 0, 0, 0, 0]], + names=["a", "b"], + ) + mseries = Series([0, 1, 2, 3, 4, 5], index=index) + index = MultiIndex( + levels=[[1, 2], [1, 2]], codes=[[0, 0, 1], [1, 0, 0]], names=["a", "b"] + ) + mseries_result = Series([0, 2, 4], index=index) + + result = mseries.groupby(level=["a", "b"], sort=False).first() + tm.assert_series_equal(result, mseries_result) + result = mseries.groupby(level=["a", "b"], sort=True).first() + tm.assert_series_equal(result, mseries_result.sort_index()) + + +def test_groupby_reindex_inside_function(): + periods = 1000 + ind = date_range(start="2012/1/1", freq="5min", periods=periods) + df = DataFrame({"high": np.arange(periods), "low": np.arange(periods)}, index=ind) + + def agg_before(func, fix=False): + """ + Run an aggregate func on the subset of data. + """ + + def _func(data): + d = data.loc[data.index.map(lambda x: x.hour < 11)].dropna() + if fix: + data[data.index[0]] + if len(d) == 0: + return None + return func(d) + + return _func + + grouped = df.groupby(lambda x: datetime(x.year, x.month, x.day)) + closure_bad = grouped.agg({"high": agg_before(np.max)}) + closure_good = grouped.agg({"high": agg_before(np.max, True)}) + + tm.assert_frame_equal(closure_bad, closure_good) + + +def test_groupby_multiindex_missing_pair(): + # GH9049 + df = DataFrame( + { + "group1": ["a", "a", "a", "b"], + "group2": ["c", "c", "d", "c"], + "value": [1, 1, 1, 5], + } + ) + df = df.set_index(["group1", "group2"]) + df_grouped = df.groupby(level=["group1", "group2"], sort=True) + + res = df_grouped.agg("sum") + idx = MultiIndex.from_tuples( + [("a", "c"), ("a", "d"), ("b", "c")], names=["group1", "group2"] + ) + exp = DataFrame([[2], [1], [5]], index=idx, columns=["value"]) + + tm.assert_frame_equal(res, exp) + + +def test_groupby_multiindex_not_lexsorted(performance_warning): + # GH 11640 + + # define the lexsorted version + lexsorted_mi = MultiIndex.from_tuples( + [("a", ""), ("b1", "c1"), ("b2", "c2")], names=["b", "c"] + ) + lexsorted_df = DataFrame([[1, 3, 4]], columns=lexsorted_mi) + assert lexsorted_df.columns._is_lexsorted() + + # define the non-lexsorted version + not_lexsorted_df = DataFrame( + columns=["a", "b", "c", "d"], data=[[1, "b1", "c1", 3], [1, "b2", "c2", 4]] + ) + not_lexsorted_df = not_lexsorted_df.pivot_table( + index="a", columns=["b", "c"], values="d" + ) + not_lexsorted_df = not_lexsorted_df.reset_index() + assert not not_lexsorted_df.columns._is_lexsorted() + + expected = lexsorted_df.groupby("a").mean() + with tm.assert_produces_warning(performance_warning): + result = not_lexsorted_df.groupby("a").mean() + tm.assert_frame_equal(expected, result) + + # a transforming function should work regardless of sort + # GH 14776 + df = DataFrame( + {"x": ["a", "a", "b", "a"], "y": [1, 1, 2, 2], "z": [1, 2, 3, 4]} + ).set_index(["x", "y"]) + assert not df.index._is_lexsorted() + + for level in [0, 1, [0, 1]]: + for sort in [False, True]: + result = df.groupby(level=level, sort=sort, group_keys=False).apply( + DataFrame.drop_duplicates + ) + expected = df + tm.assert_frame_equal(expected, result) + + result = ( + df.sort_index() + .groupby(level=level, sort=sort, group_keys=False) + .apply(DataFrame.drop_duplicates) + ) + expected = df.sort_index() + tm.assert_frame_equal(expected, result) + + +def test_index_label_overlaps_location(): + # checking we don't have any label/location confusion in the + # wake of GH5375 + df = DataFrame(list("ABCDE"), index=[2, 0, 2, 1, 1]) + g = df.groupby(list("ababb")) + actual = g.filter(lambda x: len(x) > 2) + expected = df.iloc[[1, 3, 4]] + tm.assert_frame_equal(actual, expected) + + ser = df[0] + g = ser.groupby(list("ababb")) + actual = g.filter(lambda x: len(x) > 2) + expected = ser.take([1, 3, 4]) + tm.assert_series_equal(actual, expected) + + # and again, with a generic Index of floats + df.index = df.index.astype(float) + g = df.groupby(list("ababb")) + actual = g.filter(lambda x: len(x) > 2) + expected = df.iloc[[1, 3, 4]] + tm.assert_frame_equal(actual, expected) + + ser = df[0] + g = ser.groupby(list("ababb")) + actual = g.filter(lambda x: len(x) > 2) + expected = ser.take([1, 3, 4]) + tm.assert_series_equal(actual, expected) + + +def test_transform_doesnt_clobber_ints(): + # GH 7972 + n = 6 + x = np.arange(n) + df = DataFrame({"a": x // 2, "b": 2.0 * x, "c": 3.0 * x}) + df2 = DataFrame({"a": x // 2 * 1.0, "b": 2.0 * x, "c": 3.0 * x}) + + gb = df.groupby("a") + result = gb.transform("mean") + + gb2 = df2.groupby("a") + expected = gb2.transform("mean") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "sort_column", + ["ints", "floats", "strings", ["ints", "floats"], ["ints", "strings"]], +) +@pytest.mark.parametrize( + "group_column", ["int_groups", "string_groups", ["int_groups", "string_groups"]] +) +def test_groupby_preserves_sort(sort_column, group_column): + # Test to ensure that groupby always preserves sort order of original + # object. Issue #8588 and #9651 + + df = DataFrame( + { + "int_groups": [3, 1, 0, 1, 0, 3, 3, 3], + "string_groups": ["z", "a", "z", "a", "a", "g", "g", "g"], + "ints": [8, 7, 4, 5, 2, 9, 1, 1], + "floats": [2.3, 5.3, 6.2, -2.4, 2.2, 1.1, 1.1, 5], + "strings": ["z", "d", "a", "e", "word", "word2", "42", "47"], + } + ) + + # Try sorting on different types and with different group types + + df = df.sort_values(by=sort_column) + g = df.groupby(group_column) + + def test_sort(x): + tm.assert_frame_equal(x, x.sort_values(by=sort_column)) + + g.apply(test_sort) + + +def test_pivot_table_values_key_error(): + # This test is designed to replicate the error in issue #14938 + df = DataFrame( + { + "eventDate": date_range(datetime.today(), periods=20, freq="ME").tolist(), + "thename": range(20), + } + ) + + df["year"] = df.set_index("eventDate").index.year + df["month"] = df.set_index("eventDate").index.month + + with pytest.raises(KeyError, match="'badname'"): + df.reset_index().pivot_table( + index="year", columns="month", values="badname", aggfunc="count" + ) + + +@pytest.mark.parametrize("columns", ["C", ["C"]]) +@pytest.mark.parametrize("keys", [["A"], ["A", "B"]]) +@pytest.mark.parametrize( + "values", + [ + [True], + [0], + [0.0], + ["a"], + Categorical([0]), + [to_datetime(0)], + date_range(0, 1, 1, tz="US/Eastern"), + pd.period_range("2016-01-01", periods=3, freq="D"), + pd.array([0], dtype="Int64"), + pd.array([0], dtype="Float64"), + pd.array([False], dtype="boolean"), + ], + ids=[ + "bool", + "int", + "float", + "str", + "cat", + "dt64", + "dt64tz", + "period", + "Int64", + "Float64", + "boolean", + ], +) +@pytest.mark.parametrize("method", ["attr", "agg", "apply"]) +@pytest.mark.parametrize( + "op", ["idxmax", "idxmin", "min", "max", "sum", "prod", "skew", "kurt"] +) +def test_empty_groupby(columns, keys, values, method, op, dropna, using_infer_string): + # GH8093 & GH26411 + override_dtype = None + + if isinstance(values, BooleanArray) and op in ["sum", "prod"]: + # We expect to get Int64 back for these + override_dtype = "Int64" + + if isinstance(values[0], bool) and op in ("prod", "sum"): + # sum/product of bools is an integer + override_dtype = "int64" + + df = DataFrame({"A": values, "B": values, "C": values}, columns=list("ABC")) + + if hasattr(values, "dtype"): + # check that we did the construction right + assert (df.dtypes == values.dtype).all() + + df = df.iloc[:0] + + gb = df.groupby(keys, group_keys=False, dropna=dropna, observed=False)[columns] + + def get_result(**kwargs): + if method == "attr": + return getattr(gb, op)(**kwargs) + else: + return getattr(gb, method)(op, **kwargs) + + def get_categorical_invalid_expected(): + # Categorical is special without 'observed=True', we get a NaN entry + # corresponding to the unobserved group. If we passed observed=True + # to groupby, expected would just be 'df.set_index(keys)[columns]' + # as below + lev = Categorical([0], dtype=values.dtype) + if len(keys) != 1: + idx = MultiIndex.from_product([lev, lev], names=keys) + else: + # all columns are dropped, but we end up with one row + # Categorical is special without 'observed=True' + idx = Index(lev, name=keys[0]) + + if using_infer_string: + columns = Index([], dtype="str") + else: + columns = [] + expected = DataFrame([], columns=columns, index=idx) + return expected + + is_per = isinstance(df.dtypes.iloc[0], pd.PeriodDtype) + is_dt64 = df.dtypes.iloc[0].kind == "M" + is_cat = isinstance(values, Categorical) + is_str = isinstance(df.dtypes.iloc[0], pd.StringDtype) + + if ( + isinstance(values, Categorical) + and not values.ordered + and op in ["min", "max", "idxmin", "idxmax"] + ): + if op in ["min", "max"]: + msg = f"Cannot perform {op} with non-ordered Categorical" + klass = TypeError + else: + msg = f"Can't get {op} of an empty group due to unobserved categories" + klass = ValueError + with pytest.raises(klass, match=msg): + get_result() + + if op in ["min", "max", "idxmin", "idxmax"] and isinstance(columns, list): + # i.e. DataframeGroupBy, not SeriesGroupBy + result = get_result(numeric_only=True) + expected = get_categorical_invalid_expected() + tm.assert_equal(result, expected) + return + + if op in ["prod", "sum", "skew", "kurt"]: + # ops that require more than just ordered-ness + if is_dt64 or is_cat or is_per or (is_str and op != "sum"): + # GH#41291 + # datetime64 -> prod and sum are invalid + if is_dt64: + msg = "datetime64 type does not support" + elif is_per: + msg = "Period type does not support" + elif is_str: + msg = f"dtype 'str' does not support operation '{op}'" + else: + msg = "category type does not support" + if op in ["skew", "kurt"]: + msg = "|".join([msg, f"does not support operation '{op}'"]) + with pytest.raises(TypeError, match=msg): + get_result() + + if not isinstance(columns, list): + # i.e. SeriesGroupBy + return + elif op in ["skew", "kurt"]: + # TODO: test the numeric_only=True case + return + else: + # i.e. op in ["prod", "sum"]: + # i.e. DataFrameGroupBy + # ops that require more than just ordered-ness + # GH#41291 + result = get_result(numeric_only=True) + + # with numeric_only=True, these are dropped, and we get + # an empty DataFrame back + expected = df.set_index(keys)[[]] + if is_cat: + expected = get_categorical_invalid_expected() + tm.assert_equal(result, expected) + return + + result = get_result() + expected = df.set_index(keys)[columns] + if op in ["idxmax", "idxmin"]: + expected = expected.astype(df.index.dtype) + if override_dtype is not None: + expected = expected.astype(override_dtype) + if len(keys) == 1: + expected.index.name = keys[0] + tm.assert_equal(result, expected) + + +def test_empty_groupby_apply_nonunique_columns(): + # GH#44417 + df = DataFrame(np.random.default_rng(2).standard_normal((0, 4))) + df[3] = df[3].astype(np.int64) + df.columns = [0, 1, 2, 0] + gb = df.groupby(df[1], group_keys=False) + res = gb.apply(lambda x: x) + assert (res.dtypes == df.drop(columns=1).dtypes).all() + + +def test_tuple_as_grouping(): + # https://github.com/pandas-dev/pandas/issues/18314 + df = DataFrame( + { + ("a", "b"): [1, 1, 1, 1], + "a": [2, 2, 2, 2], + "b": [2, 2, 2, 2], + "c": [1, 1, 1, 1], + } + ) + + with pytest.raises(KeyError, match=r"('a', 'b')"): + df[["a", "b", "c"]].groupby(("a", "b")) + + result = df.groupby(("a", "b"))["c"].sum() + expected = Series([4], name="c", index=Index([1], name=("a", "b"))) + tm.assert_series_equal(result, expected) + + +def test_tuple_correct_keyerror(): + # https://github.com/pandas-dev/pandas/issues/18798 + df = DataFrame(1, index=range(3), columns=MultiIndex.from_product([[1, 2], [3, 4]])) + with pytest.raises(KeyError, match=r"^\(7, 8\)$"): + df.groupby((7, 8)).mean() + + +def test_groupby_agg_ohlc_non_first(): + # GH 21716 + df = DataFrame( + [[1], [1]], + columns=Index(["foo"], name="mycols"), + index=date_range("2018-01-01", periods=2, freq="D", name="dti"), + ) + + expected = DataFrame( + [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], + columns=MultiIndex.from_tuples( + ( + ("foo", "sum", "foo"), + ("foo", "ohlc", "open"), + ("foo", "ohlc", "high"), + ("foo", "ohlc", "low"), + ("foo", "ohlc", "close"), + ), + names=["mycols", None, None], + ), + index=date_range("2018-01-01", periods=2, freq="D", name="dti"), + ) + + result = df.groupby(Grouper(freq="D")).agg(["sum", "ohlc"]) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_multiindex_nat(): + # GH 9236 + values = [ + (pd.NaT, "a"), + (datetime(2012, 1, 2), "a"), + (datetime(2012, 1, 2), "b"), + (datetime(2012, 1, 3), "a"), + ] + mi = MultiIndex.from_tuples(values, names=["date", None]) + ser = Series([3, 2, 2.5, 4], index=mi) + + result = ser.groupby(level=1).mean() + expected = Series([3.0, 2.5], index=["a", "b"]) + tm.assert_series_equal(result, expected) + + +def test_groupby_empty_list_raises(): + # GH 5289 + values = zip(range(10), range(10), strict=True) + df = DataFrame(values, columns=["apple", "b"]) + msg = "Grouper and axis must be same length" + with pytest.raises(ValueError, match=msg): + df.groupby([[]]) + + +def test_groupby_multiindex_series_keys_len_equal_group_axis(): + # GH 25704 + index_array = [["x", "x"], ["a", "b"], ["k", "k"]] + index_names = ["first", "second", "third"] + ri = MultiIndex.from_arrays(index_array, names=index_names) + s = Series(data=[1, 2], index=ri) + result = s.groupby(["first", "third"]).sum() + + index_array = [["x"], ["k"]] + index_names = ["first", "third"] + ei = MultiIndex.from_arrays(index_array, names=index_names) + expected = Series([3], index=ei) + + tm.assert_series_equal(result, expected) + + +def test_groupby_groups_in_BaseGrouper(): + # GH 26326 + # Test if DataFrame grouped with a pandas.Grouper has correct groups + mi = MultiIndex.from_product([["A", "B"], ["C", "D"]], names=["alpha", "beta"]) + df = DataFrame({"foo": [1, 2, 1, 2], "bar": [1, 2, 3, 4]}, index=mi) + result = df.groupby([Grouper(level="alpha"), "beta"]) + expected = df.groupby(["alpha", "beta"]) + assert result.groups == expected.groups + + result = df.groupby(["beta", Grouper(level="alpha")]) + expected = df.groupby(["beta", "alpha"]) + assert result.groups == expected.groups + + +def test_groups_sort_dropna(sort, dropna): + # GH#56966, GH#56851 + df = DataFrame([[2.0, 1.0], [np.nan, 4.0], [0.0, 3.0]]) + keys = [(2.0, 1.0), (np.nan, 4.0), (0.0, 3.0)] + values = [ + RangeIndex(0, 1), + RangeIndex(1, 2), + RangeIndex(2, 3), + ] + if sort: + taker = [2, 0] if dropna else [2, 0, 1] + else: + taker = [0, 2] if dropna else [0, 1, 2] + expected = {keys[idx]: values[idx] for idx in taker} + + gb = df.groupby([0, 1], sort=sort, dropna=dropna) + result = gb.groups + + for result_key, expected_key in zip(result.keys(), expected.keys(), strict=True): + # Compare as NumPy arrays to handle np.nan + result_key = np.array(result_key) + expected_key = np.array(expected_key) + tm.assert_numpy_array_equal(result_key, expected_key) + for result_value, expected_value in zip( + result.values(), expected.values(), strict=True + ): + tm.assert_index_equal(result_value, expected_value) + + +@pytest.mark.parametrize( + "op, expected", + [ + ( + "shift", + { + "time": [ + None, + None, + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + None, + None, + ] + }, + ), + ( + "bfill", + { + "time": [ + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + Timestamp("2019-01-01 14:00:00"), + Timestamp("2019-01-01 14:30:00"), + Timestamp("2019-01-01 14:00:00"), + Timestamp("2019-01-01 14:30:00"), + ] + }, + ), + ( + "ffill", + { + "time": [ + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + Timestamp("2019-01-01 14:00:00"), + Timestamp("2019-01-01 14:30:00"), + ] + }, + ), + ], +) +def test_shift_bfill_ffill_tz(tz_naive_fixture, op, expected): + # GH19995, GH27992: Check that timezone does not drop in shift, bfill, and ffill + tz = tz_naive_fixture + data = { + "id": ["A", "B", "A", "B", "A", "B"], + "time": [ + Timestamp("2019-01-01 12:00:00"), + Timestamp("2019-01-01 12:30:00"), + None, + None, + Timestamp("2019-01-01 14:00:00"), + Timestamp("2019-01-01 14:30:00"), + ], + } + df = DataFrame(data).assign(time=lambda x: x.time.dt.tz_localize(tz)) + + grouped = df.groupby("id") + result = getattr(grouped, op)() + expected = DataFrame(expected).assign(time=lambda x: x.time.dt.tz_localize(tz)) + tm.assert_frame_equal(result, expected) + + +def test_groupby_only_none_group(): + # see GH21624 + # this was crashing with "ValueError: Length of passed values is 1, index implies 0" + df = DataFrame({"g": [None], "x": 1}) + actual = df.groupby("g")["x"].transform("sum") + expected = Series([np.nan], name="x") + + tm.assert_series_equal(actual, expected) + + +def test_groupby_duplicate_index(): + # GH#29189 the groupby call here used to raise + ser = Series([2, 5, 6, 8], index=[2.0, 4.0, 4.0, 5.0]) + gb = ser.groupby(level=0) + + result = gb.mean() + expected = Series([2, 5.5, 8], index=[2.0, 4.0, 5.0]) + tm.assert_series_equal(result, expected) + + +def test_group_on_empty_multiindex(transformation_func, request): + # GH 47787 + # With one row, those are transforms so the schema should be the same + df = DataFrame( + data=[[1, Timestamp("today"), 3, 4]], + columns=["col_1", "col_2", "col_3", "col_4"], + ) + df["col_3"] = df["col_3"].astype(int) + df["col_4"] = df["col_4"].astype(int) + df = df.set_index(["col_1", "col_2"]) + result = df.iloc[:0].groupby(["col_1"]).transform(transformation_func) + expected = df.groupby(["col_1"]).transform(transformation_func).iloc[:0] + if transformation_func in ("diff", "shift"): + expected = expected.astype(int) + tm.assert_equal(result, expected) + + result = df["col_3"].iloc[:0].groupby(["col_1"]).transform(transformation_func) + expected = df["col_3"].groupby(["col_1"]).transform(transformation_func).iloc[:0] + if transformation_func in ("diff", "shift"): + expected = expected.astype(int) + tm.assert_equal(result, expected) + + +def test_groupby_crash_on_nunique(): + # Fix following 30253 + dti = date_range("2016-01-01", periods=2, name="foo") + df = DataFrame({("A", "B"): [1, 2], ("A", "C"): [1, 3], ("D", "B"): [0, 0]}) + df.columns.names = ("bar", "baz") + df.index = dti + + df = df.T + gb = df.groupby(level=0) + result = gb.nunique() + + expected = DataFrame({"A": [1, 2], "D": [1, 1]}, index=dti) + expected.columns.name = "bar" + expected = expected.T + + tm.assert_frame_equal(result, expected) + + # same thing, but empty columns + gb2 = df[[]].groupby(level=0) + exp = expected[[]] + + res = gb2.nunique() + tm.assert_frame_equal(res, exp) + + +def test_groupby_list_level(): + # GH 9790 + expected = DataFrame(np.arange(0, 9).reshape(3, 3), dtype=float) + result = expected.groupby(level=[0]).mean() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "max_seq_items, expected", + [ + (5, "{0: [0], 1: [1], 2: [2], 3: [3], 4: [4]}"), + (4, "{0: [0], 1: [1], 2: [2], 3: [3], ...}"), + (1, "{0: [0], ...}"), + ], +) +def test_groups_repr_truncates(max_seq_items, expected): + # GH 1135 + df = DataFrame(np.random.default_rng(2).standard_normal((5, 1))) + df["a"] = df.index + + with pd.option_context("display.max_seq_items", max_seq_items): + result = df.groupby("a").groups.__repr__() + assert result == expected + + result = df.groupby(np.array(df.a)).groups.__repr__() + assert result == expected + + +def test_group_on_two_row_multiindex_returns_one_tuple_key(): + # GH 18451 + df = DataFrame([{"a": 1, "b": 2, "c": 99}, {"a": 1, "b": 2, "c": 88}]) + df = df.set_index(["a", "b"]) + + grp = df.groupby(["a", "b"]) + result = grp.indices + expected = {(1, 2): np.array([0, 1], dtype=np.int64)} + + assert len(result) == 1 + key = (1, 2) + assert (result[key] == expected[key]).all() + + +@pytest.mark.parametrize( + "klass, attr, value", + [ + (DataFrame, "level", "a"), + (DataFrame, "as_index", False), + (DataFrame, "sort", False), + (DataFrame, "group_keys", False), + (DataFrame, "observed", True), + (DataFrame, "dropna", False), + (Series, "level", "a"), + (Series, "as_index", False), + (Series, "sort", False), + (Series, "group_keys", False), + (Series, "observed", True), + (Series, "dropna", False), + ], +) +def test_subsetting_columns_keeps_attrs(klass, attr, value): + # GH 9959 - When subsetting columns, don't drop attributes + df = DataFrame({"a": [1], "b": [2], "c": [3]}) + if attr != "axis": + df = df.set_index("a") + + expected = df.groupby("a", **{attr: value}) + result = expected[["b"]] if klass is DataFrame else expected["b"] + assert getattr(result, attr) == getattr(expected, attr) + + +@pytest.mark.parametrize("func", ["sum", "any", "shift"]) +def test_groupby_column_index_name_lost(func): + # GH: 29764 groupby loses index sometimes + expected = Index(["a"], name="idx") + df = DataFrame([[1]], columns=expected) + df_grouped = df.groupby([1]) + result = getattr(df_grouped, func)().columns + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize( + "infer_string", + [ + False, + pytest.param(True, marks=td.skip_if_no("pyarrow")), + ], +) +def test_groupby_duplicate_columns(infer_string): + # GH: 31735 + if infer_string: + pytest.importorskip("pyarrow") + df = DataFrame( + {"A": ["f", "e", "g", "h"], "B": ["a", "b", "c", "d"], "C": [1, 2, 3, 4]} + ).astype(object) + df.columns = ["A", "B", "B"] + with pd.option_context("future.infer_string", infer_string): + result = df.groupby([0, 0, 0, 0]).min() + expected = DataFrame( + [["e", "a", 1]], index=np.array([0]), columns=["A", "B", "B"], dtype=object + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_series_with_tuple_name(): + # GH 37755 + ser = Series([1, 2, 3, 4], index=[1, 1, 2, 2], name=("a", "a")) + ser.index.name = ("b", "b") + result = ser.groupby(level=0).last() + expected = Series([2, 4], index=[1, 2], name=("a", "a")) + expected.index.name = ("b", "b") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func, values", [("sum", [97.0, 98.0]), ("mean", [24.25, 24.5])] +) +def test_groupby_numerical_stability_sum_mean(func, values): + # GH#38778 + data = [1e16, 1e16, 97, 98, -5e15, -5e15, -5e15, -5e15] + df = DataFrame({"group": [1, 2] * 4, "a": data, "b": data}) + result = getattr(df.groupby("group"), func)() + expected = DataFrame({"a": values, "b": values}, index=Index([1, 2], name="group")) + tm.assert_frame_equal(result, expected) + + +def test_groupby_numerical_stability_cumsum(): + # GH#38934 + data = [1e16, 1e16, 97, 98, -5e15, -5e15, -5e15, -5e15] + df = DataFrame({"group": [1, 2] * 4, "a": data, "b": data}) + result = df.groupby("group").cumsum() + exp_data = ( + [1e16] * 2 + [1e16 + 96, 1e16 + 98] + [5e15 + 97, 5e15 + 98] + [97.0, 98.0] + ) + expected = DataFrame({"a": exp_data, "b": exp_data}) + tm.assert_frame_equal(result, expected, check_exact=True) + + +def test_groupby_cumsum_skipna_false(): + # GH#46216 don't propagate np.nan above the diagonal + arr = np.random.default_rng(2).standard_normal((5, 5)) + df = DataFrame(arr) + for i in range(5): + df.iloc[i, i] = np.nan + + df["A"] = 1 + gb = df.groupby("A") + + res = gb.cumsum(skipna=False) + + expected = df[[0, 1, 2, 3, 4]].cumsum(skipna=False) + tm.assert_frame_equal(res, expected) + + +def test_groupby_cumsum_timedelta64(): + # GH#46216 don't ignore is_datetimelike in libgroupby.group_cumsum + dti = date_range("2016-01-01", periods=5, unit="ns") + ser = Series(dti) - dti[0] + ser[2] = pd.NaT + + df = DataFrame({"A": 1, "B": ser}) + gb = df.groupby("A") + + res = gb.cumsum(numeric_only=False, skipna=True) + exp = DataFrame({"B": [ser[0], ser[1], pd.NaT, ser[4], ser[4] * 2]}) + tm.assert_frame_equal(res, exp) + + res = gb.cumsum(numeric_only=False, skipna=False) + exp = DataFrame({"B": [ser[0], ser[1], pd.NaT, pd.NaT, pd.NaT]}) + tm.assert_frame_equal(res, exp) + + +def test_groupby_mean_duplicate_index(rand_series_with_duplicate_datetimeindex): + dups = rand_series_with_duplicate_datetimeindex + result = dups.groupby(level=0).mean() + expected = dups.groupby(dups.index).mean() + tm.assert_series_equal(result, expected) + + +def test_groupby_all_nan_groups_drop(): + # GH 15036 + s = Series([1, 2, 3], [np.nan, np.nan, np.nan]) + result = s.groupby(s.index).sum() + expected = Series([], index=Index([], dtype=np.float64), dtype=np.int64) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_groupby_empty_multi_column(as_index, numeric_only): + # GH 15106 & GH 41998 + df = DataFrame(data=[], columns=["A", "B", "C"]) + gb = df.groupby(["A", "B"], as_index=as_index) + result = gb.sum(numeric_only=numeric_only) + if as_index: + index = MultiIndex([[], []], [[], []], names=["A", "B"]) + columns = ["C"] if not numeric_only else Index([], dtype="str") + else: + index = RangeIndex(0) + columns = ["A", "B", "C"] if not numeric_only else ["A", "B"] + expected = DataFrame([], columns=columns, index=index) + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregation_non_numeric_dtype(): + # GH #43108 + df = DataFrame( + [["M", [1]], ["M", [1]], ["W", [10]], ["W", [20]]], columns=["MW", "v"] + ) + + expected = DataFrame( + { + "v": [[1, 1], [10, 20]], + }, + index=Index(["M", "W"], name="MW"), + ) + + gb = df.groupby(by=["MW"]) + result = gb.sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregation_multi_non_numeric_dtype(): + # GH #42395 + df = DataFrame( + { + "x": [1, 0, 1, 1, 0], + "y": [Timedelta(i, "days") for i in range(1, 6)], + "z": [Timedelta(i * 10, "days") for i in range(1, 6)], + } + ) + + expected = DataFrame( + { + "y": [Timedelta(i, "days") for i in range(7, 9)], + "z": [Timedelta(i * 10, "days") for i in range(7, 9)], + }, + index=Index([0, 1], dtype="int64", name="x"), + ) + + gb = df.groupby(by=["x"]) + result = gb.sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_aggregation_numeric_with_non_numeric_dtype(): + # GH #43108 + df = DataFrame( + { + "x": [1, 0, 1, 1, 0], + "y": [Timedelta(i, "days") for i in range(1, 6)], + "z": list(range(1, 6)), + } + ) + + expected = DataFrame( + {"y": [Timedelta(7, "days"), Timedelta(8, "days")], "z": [7, 8]}, + index=Index([0, 1], dtype="int64", name="x"), + ) + + gb = df.groupby(by=["x"]) + result = gb.sum() + tm.assert_frame_equal(result, expected) + + +def test_groupby_filtered_df_std(): + # GH 16174 + dicts = [ + {"filter_col": False, "groupby_col": True, "bool_col": True, "float_col": 10.5}, + {"filter_col": True, "groupby_col": True, "bool_col": True, "float_col": 20.5}, + {"filter_col": True, "groupby_col": True, "bool_col": True, "float_col": 30.5}, + ] + df = DataFrame(dicts) + + df_filter = df[df["filter_col"] == True] # noqa: E712 + dfgb = df_filter.groupby("groupby_col") + result = dfgb.std() + expected = DataFrame( + [[0.0, 0.0, 7.071068]], + columns=["filter_col", "bool_col", "float_col"], + index=Index([True], name="groupby_col"), + ) + tm.assert_frame_equal(result, expected) + + +def test_datetime_categorical_multikey_groupby_indices(): + # GH 26859 + df = DataFrame( + { + "a": Series(list("abc")), + "b": Series( + to_datetime(["2018-01-01", "2018-02-01", "2018-03-01"]), + dtype="category", + ), + "c": Categorical.from_codes([-1, 0, 1], categories=[0, 1]), + } + ) + result = df.groupby(["a", "b"], observed=False).indices + expected = { + ("a", Timestamp("2018-01-01 00:00:00")): np.array([0]), + ("b", Timestamp("2018-02-01 00:00:00")): np.array([1]), + ("c", Timestamp("2018-03-01 00:00:00")): np.array([2]), + } + assert result == expected + + +def test_rolling_wrong_param_min_period(): + # GH34037 + name_l = ["Alice"] * 5 + ["Bob"] * 5 + val_l = [np.nan, np.nan, 1, 2, 3, np.nan, 1, 2, 3, 4] + test_df = DataFrame([name_l, val_l]).T + test_df.columns = ["name", "val"] + + result_error_msg = ( + r"^[a-zA-Z._]*\(\) got an unexpected keyword argument 'min_period'" + ) + with pytest.raises(TypeError, match=result_error_msg): + test_df.groupby("name")["val"].rolling(window=2, min_period=1).sum() + + +def test_by_column_values_with_same_starting_value(any_string_dtype): + # GH29635 + dtype = any_string_dtype + df = DataFrame( + { + "Name": ["Thomas", "Thomas", "Thomas John"], + "Credit": [1200, 1300, 900], + "Mood": Series(["sad", "happy", "happy"], dtype=dtype), + } + ) + aggregate_details = {"Mood": Series.mode, "Credit": "sum"} + + result = df.groupby(["Name"]).agg(aggregate_details) + expected = DataFrame( + { + "Mood": [["happy", "sad"], "happy"], + "Credit": [2500, 900], + "Name": ["Thomas", "Thomas John"], + }, + ).set_index("Name") + if getattr(dtype, "storage", None) == "pyarrow": + mood_values = pd.array(["happy", "sad"], dtype=dtype) + expected["Mood"] = [mood_values, "happy"] + tm.assert_frame_equal(result, expected) + + +def test_groupby_none_in_first_mi_level(): + # GH#47348 + arr = [[None, 1, 0, 1], [2, 3, 2, 3]] + ser = Series(1, index=MultiIndex.from_arrays(arr, names=["a", "b"])) + result = ser.groupby(level=[0, 1]).sum() + expected = Series( + [1, 2], MultiIndex.from_tuples([(0.0, 2), (1.0, 3)], names=["a", "b"]) + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_none_column_name(using_infer_string): + # GH#47348 + df = DataFrame({None: [1, 1, 2, 2], "b": [1, 1, 2, 3], "c": [4, 5, 6, 7]}) + by = [np.nan] if using_infer_string else [None] + gb = df.groupby(by=by) + result = gb.sum() + expected = DataFrame({"b": [2, 5], "c": [9, 13]}, index=Index([1, 2], name=by[0])) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("selection", [None, "a", ["a"]]) +def test_single_element_list_grouping(selection): + # GH#42795, GH#53500 + df = DataFrame({"a": [1, 2], "b": [np.nan, 5], "c": [np.nan, 2]}, index=["x", "y"]) + grouped = df.groupby(["a"]) if selection is None else df.groupby(["a"])[selection] + result = [key for key, _ in grouped] + + expected = [(1,), (2,)] + assert result == expected + + +def test_groupby_string_dtype(): + # GH 40148 + df = DataFrame({"str_col": ["a", "b", "c", "a"], "num_col": [1, 2, 3, 2]}) + df["str_col"] = df["str_col"].astype("string") + expected = DataFrame( + { + "str_col": [ + "a", + "b", + "c", + ], + "num_col": [1.5, 2.0, 3.0], + } + ) + expected["str_col"] = expected["str_col"].astype("string") + grouped = df.groupby("str_col", as_index=False) + result = grouped.mean() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "level_arg, multiindex", [([0], False), ((0,), False), ([0], True), ((0,), True)] +) +def test_single_element_listlike_level_grouping(level_arg, multiindex): + # GH 51583 + df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) + if multiindex: + df = df.set_index(["a", "b"]) + result = [key for key, _ in df.groupby(level=level_arg)] + expected = [(1,), (2,)] if multiindex else [("x",), ("y",)] + assert result == expected + + +@pytest.mark.parametrize("func", ["sum", "cumsum", "cumprod", "prod"]) +def test_groupby_avoid_casting_to_float(func): + # GH#37493 + val = 922337203685477580 + df = DataFrame({"a": 1, "b": [val]}) + result = getattr(df.groupby("a"), func)() - val + expected = DataFrame({"b": [0]}, index=Index([1], name="a")) + if func in ["cumsum", "cumprod"]: + expected = expected.reset_index(drop=True) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func, val", [("sum", 3), ("prod", 2)]) +def test_groupby_sum_support_mask(any_numeric_ea_dtype, func, val): + # GH#37493 + df = DataFrame({"a": 1, "b": [1, 2, pd.NA]}, dtype=any_numeric_ea_dtype) + result = getattr(df.groupby("a"), func)() + expected = DataFrame( + {"b": [val]}, + index=Index([1], name="a", dtype=any_numeric_ea_dtype), + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("val, dtype", [(111, "int"), (222, "uint")]) +def test_groupby_overflow(val, dtype): + # GH#37493 + df = DataFrame({"a": 1, "b": [val, val]}, dtype=f"{dtype}8") + result = df.groupby("a").sum() + expected = DataFrame( + {"b": [val * 2]}, + index=Index([1], name="a", dtype=f"{dtype}8"), + dtype=f"{dtype}64", + ) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a").cumsum() + expected = DataFrame({"b": [val, val * 2]}, dtype=f"{dtype}64") + tm.assert_frame_equal(result, expected) + + result = df.groupby("a").prod() + expected = DataFrame( + {"b": [val * val]}, + index=Index([1], name="a", dtype=f"{dtype}8"), + dtype=f"{dtype}64", + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("skipna, val", [(True, 3), (False, pd.NA)]) +def test_groupby_cumsum_mask(any_numeric_ea_dtype, skipna, val): + # GH#37493 + df = DataFrame({"a": 1, "b": [1, pd.NA, 2]}, dtype=any_numeric_ea_dtype) + result = df.groupby("a").cumsum(skipna=skipna) + expected = DataFrame( + {"b": [1, pd.NA, val]}, + dtype=any_numeric_ea_dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "val_in, index, val_out", + [ + ( + [1.0, 2.0, 3.0, 4.0, 5.0], + ["foo", "foo", "bar", "baz", "blah"], + [3.0, 4.0, 5.0, 3.0], + ), + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + ["foo", "foo", "bar", "baz", "blah", "blah"], + [3.0, 4.0, 11.0, 3.0], + ), + ], +) +def test_groupby_index_name_in_index_content(val_in, index, val_out): + # GH 48567 + series = Series(data=val_in, name="values", index=Index(index, name="blah")) + result = series.groupby("blah").sum() + expected = Series( + data=val_out, + name="values", + index=Index(["bar", "baz", "blah", "foo"], name="blah"), + ) + tm.assert_series_equal(result, expected) + + result = series.to_frame().groupby("blah").sum() + expected = expected.to_frame() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("n", [1, 10, 32, 100, 1000]) +def test_sum_of_booleans(n): + # GH 50347 + df = DataFrame({"groupby_col": 1, "bool": [True] * n}) + df["bool"] = df["bool"].eq(True) + result = df.groupby("groupby_col").sum() + expected = DataFrame({"bool": [n]}, index=Index([1], name="groupby_col")) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:invalid value encountered in remainder:RuntimeWarning" +) +@pytest.mark.parametrize("method", ["head", "tail", "nth", "first", "last"]) +def test_groupby_method_drop_na(method): + # GH 21755 + df = DataFrame({"A": ["a", np.nan, "b", np.nan, "c"], "B": range(5)}) + + if method == "nth": + result = getattr(df.groupby("A"), method)(n=0) + else: + result = getattr(df.groupby("A"), method)() + + if method in ["first", "last"]: + expected = DataFrame({"B": [0, 2, 4]}).set_index( + Series(["a", "b", "c"], name="A") + ) + else: + expected = DataFrame( + {"A": ["a", "b", "c"], "B": [0, 2, 4]}, index=range(0, 6, 2) + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_reduce_period(): + # GH#51040 + pi = pd.period_range("2016-01-01", periods=100, freq="D") + grps = list(range(10)) * 10 + ser = pi.to_series() + gb = ser.groupby(grps) + + with pytest.raises(TypeError, match="Period type does not support sum operations"): + gb.sum() + with pytest.raises( + TypeError, match="Period type does not support cumsum operations" + ): + gb.cumsum() + with pytest.raises(TypeError, match="Period type does not support prod operations"): + gb.prod() + with pytest.raises( + TypeError, match="Period type does not support cumprod operations" + ): + gb.cumprod() + + res = gb.max() + expected = ser[-10:] + expected.index = Index(range(10), dtype=int) + tm.assert_series_equal(res, expected) + + res = gb.min() + expected = ser[:10] + expected.index = Index(range(10), dtype=int) + tm.assert_series_equal(res, expected) + + +def test_obj_with_exclusions_duplicate_columns(): + # GH#50806 + df = DataFrame([[0, 1, 2, 3]]) + df.columns = [0, 1, 2, 0] + gb = df.groupby(df[1]) + result = gb._obj_with_exclusions + expected = df.take([0, 2, 3], axis=1) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_groupby_numeric_only_std_no_result(numeric_only): + # GH 51080 + dicts_non_numeric = [{"a": "foo", "b": "bar"}, {"a": "car", "b": "dar"}] + df = DataFrame(dicts_non_numeric, dtype=object) + dfgb = df.groupby("a", as_index=False, sort=False) + + if numeric_only: + result = dfgb.std(numeric_only=True) + expected_df = DataFrame(["foo", "car"], columns=["a"]) + tm.assert_frame_equal(result, expected_df) + else: + with pytest.raises( + ValueError, match="could not convert string to float: 'bar'" + ): + dfgb.std(numeric_only=numeric_only) + + +def test_grouping_with_categorical_interval_columns(): + # GH#34164 + df = DataFrame({"x": [0.1, 0.2, 0.3, -0.4, 0.5], "w": ["a", "b", "a", "c", "a"]}) + qq = pd.qcut(df["x"], q=np.linspace(0, 1, 5)) + result = df.groupby([qq, "w"], observed=False)["x"].agg("mean") + categorical_index_level_1 = Categorical( + [ + Interval(-0.401, 0.1, closed="right"), + Interval(0.1, 0.2, closed="right"), + Interval(0.2, 0.3, closed="right"), + Interval(0.3, 0.5, closed="right"), + ], + ordered=True, + ) + index_level_2 = ["a", "b", "c"] + mi = MultiIndex.from_product( + [categorical_index_level_1, index_level_2], names=["x", "w"] + ) + expected = Series( + np.array( + [ + 0.1, + np.nan, + -0.4, + np.nan, + 0.2, + np.nan, + 0.3, + np.nan, + np.nan, + 0.5, + np.nan, + np.nan, + ] + ), + index=mi, + name="x", + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("bug_var", [1, "a"]) +def test_groupby_sum_on_nan_should_return_nan(bug_var): + # GH 24196 + df = DataFrame({"A": [bug_var, bug_var, bug_var, np.nan]}) + if isinstance(bug_var, str): + df = df.astype(object) + dfgb = df.groupby(lambda x: x) + result = dfgb.sum(min_count=1) + + expected_df = DataFrame( + [bug_var, bug_var, bug_var, None], columns=["A"], dtype=df["A"].dtype + ) + tm.assert_frame_equal(result, expected_df) + + +@pytest.mark.parametrize( + "method", + [ + "count", + "corr", + "cummax", + "cummin", + "cumprod", + "describe", + "rank", + "quantile", + "diff", + "shift", + "all", + "any", + "idxmin", + "idxmax", + "ffill", + "bfill", + "pct_change", + ], +) +def test_groupby_selection_with_methods(df, method): + # some methods which require DatetimeIndex + rng = date_range("2014", periods=len(df)) + df.index = rng + + g = df.groupby(["A"])[["C"]] + g_exp = df[["C"]].groupby(df["A"]) + # TODO check groupby with > 1 col ? + + res = getattr(g, method)() + exp = getattr(g_exp, method)() + + # should always be frames! + tm.assert_frame_equal(res, exp) + + +def test_groupby_selection_other_methods(df): + # some methods which require DatetimeIndex + rng = date_range("2014", periods=len(df)) + df.columns.name = "foo" + df.index = rng + + g = df.groupby(["A"])[["C"]] + g_exp = df[["C"]].groupby(df["A"]) + + # methods which aren't just .foo() + tm.assert_frame_equal(g.apply(lambda x: x.sum()), g_exp.apply(lambda x: x.sum())) + + tm.assert_frame_equal(g.resample("D").mean(), g_exp.resample("D").mean()) + tm.assert_frame_equal(g.resample("D").ohlc(), g_exp.resample("D").ohlc()) + + tm.assert_frame_equal( + g.filter(lambda x: len(x) == 3), g_exp.filter(lambda x: len(x) == 3) + ) + + +def test_groupby_with_Time_Grouper(unit): + idx2 = to_datetime( + [ + "2016-08-31 22:08:12.000", + "2016-08-31 22:09:12.200", + "2016-08-31 22:20:12.400", + ] + ).as_unit(unit) + + test_data = DataFrame( + {"quant": [1.0, 1.0, 3.0], "quant2": [1.0, 1.0, 3.0], "time2": idx2} + ) + + time2 = date_range("2016-08-31 22:08:00", periods=13, freq="1min", unit=unit) + expected_output = DataFrame( + { + "time2": time2, + "quant": [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + "quant2": [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + } + ) + + gb = test_data.groupby(Grouper(key="time2", freq="1min")) + result = gb.count().reset_index() + + tm.assert_frame_equal(result, expected_output) + + +def test_groupby_series_with_datetimeindex_month_name(): + # GH 48509 + s = Series([0, 1, 0], index=date_range("2022-01-01", periods=3), name="jan") + result = s.groupby(s).count() + expected = Series([2, 1], name="jan") + expected.index.name = "jan" + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("test_series", [True, False]) +@pytest.mark.parametrize( + "kwarg, value, name, warn", + [ + ("by", "a", 1, None), + ("by", ["a"], (1,), None), + ("level", 0, 1, None), + ("level", [0], (1,), None), + ], +) +def test_get_group_len_1_list_likes(test_series, kwarg, value, name, warn): + # GH#25971 + obj = DataFrame({"b": [3, 4, 5]}, index=Index([1, 1, 2], name="a")) + if test_series: + obj = obj["b"] + gb = obj.groupby(**{kwarg: value}) + result = gb.get_group(name) + if test_series: + expected = Series([3, 4], index=Index([1, 1], name="a"), name="b") + else: + expected = DataFrame({"b": [3, 4]}, index=Index([1, 1], name="a")) + tm.assert_equal(result, expected) + + +def test_groupby_ngroup_with_nan(): + # GH#50100 + df = DataFrame({"a": Categorical([np.nan]), "b": [1]}) + result = df.groupby(["a", "b"], dropna=False, observed=False).ngroup() + expected = Series([0]) + tm.assert_series_equal(result, expected) + + +def test_groupby_ffill_with_duplicated_index(): + # GH#43412 + df = DataFrame({"a": [1, 2, 3, 4, np.nan, np.nan]}, index=[0, 1, 2, 0, 1, 2]) + + result = df.groupby(level=0).ffill() + expected = DataFrame({"a": [1, 2, 3, 4, 2, 3]}, index=[0, 1, 2, 0, 1, 2]) + tm.assert_frame_equal(result, expected, check_dtype=False) + + +@pytest.mark.parametrize("test_series", [True, False]) +def test_decimal_na_sort(test_series): + # GH#54847 + # We catch both TypeError and decimal.InvalidOperation exceptions in safe_sort. + # If this next assert raises, we can just catch TypeError + assert not isinstance(decimal.InvalidOperation, TypeError) + df = DataFrame( + { + "key": [Decimal(1), Decimal(1), None, None], + "value": [Decimal(2), Decimal(3), Decimal(4), Decimal(5)], + } + ) + gb = df.groupby("key", dropna=False) + if test_series: + gb = gb["value"] + result = gb._grouper.result_index + expected = Index([Decimal(1), None], name="key") + tm.assert_index_equal(result, expected) + + +def test_groupby_dropna_with_nunique_unique(): + # GH#42016 + df = [[1, 1, 1, "A"], [1, None, 1, "A"], [1, None, 2, "A"], [1, None, 3, "A"]] + df_dropna = DataFrame(df, columns=["a", "b", "c", "partner"]) + result = df_dropna.groupby(["a", "b", "c"], dropna=False).agg( + {"partner": ["nunique", "unique"]} + ) + + index = MultiIndex.from_tuples( + [(1, 1.0, 1), (1, np.nan, 1), (1, np.nan, 2), (1, np.nan, 3)], + names=["a", "b", "c"], + ) + columns = MultiIndex.from_tuples([("partner", "nunique"), ("partner", "unique")]) + expected = DataFrame( + [(1, ["A"]), (1, ["A"]), (1, ["A"]), (1, ["A"])], index=index, columns=columns + ) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_agg_namedagg_with_duplicate_columns(): + # GH#58446 + df = DataFrame( + { + "col1": [2, 1, 1, 0, 2, 0], + "col2": [4, 5, 36, 7, 4, 5], + "col3": [3.1, 8.0, 12, 10, 4, 1.1], + "col4": [17, 3, 16, 15, 5, 6], + "col5": [-1, 3, -1, 3, -2, -1], + } + ) + + result = df.groupby(by=["col1", "col1", "col2"], as_index=False).agg( + new_col=pd.NamedAgg(column="col1", aggfunc="min"), + new_col1=pd.NamedAgg(column="col1", aggfunc="max"), + new_col2=pd.NamedAgg(column="col2", aggfunc="count"), + ) + + expected = DataFrame( + { + "col1": [0, 0, 1, 1, 2], + "col2": [5, 7, 5, 36, 4], + "new_col": [0, 0, 1, 1, 2], + "new_col1": [0, 0, 1, 1, 2], + "new_col2": [1, 1, 1, 1, 2], + } + ) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_multi_index_codes(): + # GH#54347 + df = DataFrame( + {"A": [1, 2, 3, 4], "B": [1, float("nan"), 2, float("nan")], "C": [2, 4, 6, 8]} + ) + df_grouped = df.groupby(["A", "B"], dropna=False).sum() + + index = df_grouped.index + tm.assert_index_equal(index, MultiIndex.from_frame(index.to_frame())) + + +def test_groupby_datetime_with_nat(): + # GH##35202 + df = DataFrame( + { + "a": [ + to_datetime("2019-02-12"), + to_datetime("2019-02-12"), + to_datetime("2019-02-13"), + pd.NaT, + ], + "b": [1, 2, 3, 4], + } + ) + grouped = df.groupby("a", dropna=False) + result = len(grouped) + assert result == 3 diff --git a/pandas/tests/groupby/test_groupby_dropna.py b/pandas/tests/groupby/test_groupby_dropna.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddeefedc217ff960b957d92a2f7ca169b2f9ba5 --- /dev/null +++ b/pandas/tests/groupby/test_groupby_dropna.py @@ -0,0 +1,692 @@ +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning +import pandas.util._test_decorators as td + +from pandas.core.dtypes.missing import na_value_for_dtype + +import pandas as pd +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +@pytest.mark.parametrize( + "dropna, tuples, outputs", + [ + ( + True, + [["A", "B"], ["B", "A"]], + {"c": [13.0, 123.23], "d": [13.0, 123.0], "e": [13.0, 1.0]}, + ), + ( + False, + [["A", "B"], ["A", np.nan], ["B", "A"]], + { + "c": [13.0, 12.3, 123.23], + "d": [13.0, 233.0, 123.0], + "e": [13.0, 12.0, 1.0], + }, + ), + ], +) +def test_groupby_dropna_multi_index_dataframe_nan_in_one_group( + dropna, tuples, outputs, nulls_fixture +): + # GH 3729 this is to test that NA is in one group + df_list = [ + ["A", "B", 12, 12, 12], + ["A", nulls_fixture, 12.3, 233.0, 12], + ["B", "A", 123.23, 123, 1], + ["A", "B", 1, 1, 1.0], + ] + df = pd.DataFrame(df_list, columns=["a", "b", "c", "d", "e"]) + grouped = df.groupby(["a", "b"], dropna=dropna).sum() + + mi = pd.MultiIndex.from_tuples(tuples, names=list("ab")) + + # Since right now, by default MI will drop NA from levels when we create MI + # via `from_*`, so we need to add NA for level manually afterwards. + if not dropna: + mi = mi.set_levels(["A", "B", np.nan], level="b") + expected = pd.DataFrame(outputs, index=mi) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.parametrize( + "dropna, tuples, outputs", + [ + ( + True, + [["A", "B"], ["B", "A"]], + {"c": [12.0, 123.23], "d": [12.0, 123.0], "e": [12.0, 1.0]}, + ), + ( + False, + [["A", "B"], ["A", np.nan], ["B", "A"], [np.nan, "B"]], + { + "c": [12.0, 13.3, 123.23, 1.0], + "d": [12.0, 234.0, 123.0, 1.0], + "e": [12.0, 13.0, 1.0, 1.0], + }, + ), + ], +) +def test_groupby_dropna_multi_index_dataframe_nan_in_two_groups( + dropna, tuples, outputs, nulls_fixture, nulls_fixture2 +): + # GH 3729 this is to test that NA in different groups with different representations + df_list = [ + ["A", "B", 12, 12, 12], + ["A", nulls_fixture, 12.3, 233.0, 12], + ["B", "A", 123.23, 123, 1], + [nulls_fixture2, "B", 1, 1, 1.0], + ["A", nulls_fixture2, 1, 1, 1.0], + ] + df = pd.DataFrame(df_list, columns=["a", "b", "c", "d", "e"]) + grouped = df.groupby(["a", "b"], dropna=dropna).sum() + + mi = pd.MultiIndex.from_tuples(tuples, names=list("ab")) + + # Since right now, by default MI will drop NA from levels when we create MI + # via `from_*`, so we need to add NA for level manually afterwards. + if not dropna: + mi = mi.set_levels([["A", "B", np.nan], ["A", "B", np.nan]]) + expected = pd.DataFrame(outputs, index=mi) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.parametrize( + "dropna, idx, outputs", + [ + (True, ["A", "B"], {"b": [123.23, 13.0], "c": [123.0, 13.0], "d": [1.0, 13.0]}), + ( + False, + ["A", "B", np.nan], + { + "b": [123.23, 13.0, 12.3], + "c": [123.0, 13.0, 233.0], + "d": [1.0, 13.0, 12.0], + }, + ), + ], +) +def test_groupby_dropna_normal_index_dataframe(dropna, idx, outputs): + # GH 3729 + df_list = [ + ["B", 12, 12, 12], + [None, 12.3, 233.0, 12], + ["A", 123.23, 123, 1], + ["B", 1, 1, 1.0], + ] + df = pd.DataFrame(df_list, columns=["a", "b", "c", "d"]) + grouped = df.groupby("a", dropna=dropna).sum() + + expected = pd.DataFrame(outputs, index=pd.Index(idx, name="a")) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.parametrize( + "dropna, idx, expected", + [ + (True, ["a", "a", "b", np.nan], pd.Series([3, 3], index=["a", "b"])), + ( + False, + ["a", "a", "b", np.nan], + pd.Series([3, 3, 3], index=["a", "b", np.nan]), + ), + ], +) +def test_groupby_dropna_series_level(dropna, idx, expected): + ser = pd.Series([1, 2, 3, 3], index=idx) + + result = ser.groupby(level=0, dropna=dropna).sum() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dropna, expected", + [ + (True, pd.Series([210.0, 350.0], index=["a", "b"], name="Max Speed")), + ( + False, + pd.Series([210.0, 350.0, 20.0], index=["a", "b", np.nan], name="Max Speed"), + ), + ], +) +def test_groupby_dropna_series_by(dropna, expected): + ser = pd.Series( + [390.0, 350.0, 30.0, 20.0], + index=["Falcon", "Falcon", "Parrot", "Parrot"], + name="Max Speed", + ) + + result = ser.groupby(["a", "b", "a", np.nan], dropna=dropna).mean() + tm.assert_series_equal(result, expected) + + +def test_grouper_dropna_propagation(dropna): + # GH 36604 + df = pd.DataFrame({"A": [0, 0, 1, None], "B": [1, 2, 3, None]}) + gb = df.groupby("A", dropna=dropna) + assert gb._grouper.dropna == dropna + + +@pytest.mark.parametrize( + "index", + [ + pd.RangeIndex(0, 4), + list("abcd"), + pd.MultiIndex.from_product([(1, 2), ("R", "B")], names=["num", "col"]), + ], +) +def test_groupby_dataframe_slice_then_transform(dropna, index): + # GH35014 & GH35612 + expected_data = {"B": [2, 2, 1, np.nan if dropna else 1]} + + df = pd.DataFrame({"A": [0, 0, 1, None], "B": [1, 2, 3, None]}, index=index) + gb = df.groupby("A", dropna=dropna) + + result = gb.transform(len) + expected = pd.DataFrame(expected_data, index=index) + tm.assert_frame_equal(result, expected) + + result = gb[["B"]].transform(len) + expected = pd.DataFrame(expected_data, index=index) + tm.assert_frame_equal(result, expected) + + result = gb["B"].transform(len) + expected = pd.Series(expected_data["B"], index=index, name="B") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dropna, tuples, outputs", + [ + ( + True, + [["A", "B"], ["B", "A"]], + {"c": [13.0, 123.23], "d": [12.0, 123.0], "e": [1.0, 1.0]}, + ), + ( + False, + [["A", "B"], ["A", np.nan], ["B", "A"]], + { + "c": [13.0, 12.3, 123.23], + "d": [12.0, 233.0, 123.0], + "e": [1.0, 12.0, 1.0], + }, + ), + ], +) +def test_groupby_dropna_multi_index_dataframe_agg(dropna, tuples, outputs): + # GH 3729 + df_list = [ + ["A", "B", 12, 12, 12], + ["A", None, 12.3, 233.0, 12], + ["B", "A", 123.23, 123, 1], + ["A", "B", 1, 1, 1.0], + ] + df = pd.DataFrame(df_list, columns=["a", "b", "c", "d", "e"]) + agg_dict = {"c": "sum", "d": "max", "e": "min"} + grouped = df.groupby(["a", "b"], dropna=dropna).agg(agg_dict) + + mi = pd.MultiIndex.from_tuples(tuples, names=list("ab")) + + # Since right now, by default MI will drop NA from levels when we create MI + # via `from_*`, so we need to add NA for level manually afterwards. + if not dropna: + mi = mi.set_levels(["A", "B", np.nan], level="b") + expected = pd.DataFrame(outputs, index=mi) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.arm_slow +@pytest.mark.parametrize( + "datetime1, datetime2", + [ + (pd.Timestamp("2020-01-01"), pd.Timestamp("2020-02-01")), + (pd.Timedelta("-2 days"), pd.Timedelta("-1 days")), + (pd.Period("2020-01-01"), pd.Period("2020-02-01")), + ], +) +@pytest.mark.parametrize("dropna, values", [(True, [12, 3]), (False, [12, 3, 6])]) +def test_groupby_dropna_datetime_like_data( + dropna, values, datetime1, datetime2, unique_nulls_fixture, unique_nulls_fixture2 +): + # 3729 + df = pd.DataFrame( + { + "values": [1, 2, 3, 4, 5, 6], + "dt": [ + datetime1, + unique_nulls_fixture, + datetime2, + unique_nulls_fixture2, + datetime1, + datetime1, + ], + } + ) + + if dropna: + indexes = [datetime1, datetime2] + else: + indexes = [datetime1, datetime2, np.nan] + + grouped = df.groupby("dt", dropna=dropna).agg({"values": "sum"}) + expected = pd.DataFrame({"values": values}, index=pd.Index(indexes, name="dt")) + + tm.assert_frame_equal(grouped, expected) + + +@pytest.mark.parametrize( + "dropna, data, selected_data, levels", + [ + pytest.param( + False, + {"groups": ["a", "a", "b", np.nan], "values": [10, 10, 20, 30]}, + {"values": [0, 1, 0, 0]}, + ["a", "b", np.nan], + id="dropna_false_has_nan", + ), + pytest.param( + True, + {"groups": ["a", "a", "b", np.nan], "values": [10, 10, 20, 30]}, + {"values": [0, 1, 0]}, + None, + id="dropna_true_has_nan", + ), + pytest.param( + # no nan in "groups"; dropna=True|False should be same. + False, + {"groups": ["a", "a", "b", "c"], "values": [10, 10, 20, 30]}, + {"values": [0, 1, 0, 0]}, + None, + id="dropna_false_no_nan", + ), + pytest.param( + # no nan in "groups"; dropna=True|False should be same. + True, + {"groups": ["a", "a", "b", "c"], "values": [10, 10, 20, 30]}, + {"values": [0, 1, 0, 0]}, + None, + id="dropna_true_no_nan", + ), + ], +) +def test_groupby_apply_with_dropna_for_multi_index(dropna, data, selected_data, levels): + # GH 35889 + + df = pd.DataFrame(data) + gb = df.groupby("groups", dropna=dropna) + result = gb.apply(lambda grp: pd.DataFrame({"values": range(len(grp))})) + + mi_tuples = tuple(zip(data["groups"], selected_data["values"], strict=False)) + mi = pd.MultiIndex.from_tuples(mi_tuples, names=["groups", None]) + # Since right now, by default MI will drop NA from levels when we create MI + # via `from_*`, so we need to add NA for level manually afterwards. + if not dropna and levels: + mi = mi.set_levels(levels, level="groups") + + expected = pd.DataFrame(selected_data, index=mi) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("input_index", [None, ["a"], ["a", "b"]]) +@pytest.mark.parametrize("keys", [["a"], ["a", "b"]]) +@pytest.mark.parametrize("series", [True, False]) +def test_groupby_dropna_with_multiindex_input(input_index, keys, series): + # GH#46783 + obj = pd.DataFrame( + { + "a": [1, np.nan], + "b": [1, 1], + "c": [2, 3], + } + ) + + expected = obj.set_index(keys) + if series: + expected = expected["c"] + elif input_index == ["a", "b"] and keys == ["a"]: + # Column b should not be aggregated + expected = expected[["c"]] + + if input_index is not None: + obj = obj.set_index(input_index) + gb = obj.groupby(keys, dropna=False) + if series: + gb = gb["c"] + result = gb.sum() + + tm.assert_equal(result, expected) + + +def test_groupby_nan_included(): + # GH 35646 + data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]} + df = pd.DataFrame(data) + grouped = df.groupby("group", dropna=False) + result = grouped.indices + dtype = np.intp + expected = { + "g1": np.array([0, 2], dtype=dtype), + "g2": np.array([3], dtype=dtype), + np.nan: np.array([1, 4], dtype=dtype), + } + for result_values, expected_values in zip( + result.values(), expected.values(), strict=True + ): + tm.assert_numpy_array_equal(result_values, expected_values) + assert np.isnan(list(result.keys())[2]) + assert list(result.keys())[0:2] == ["g1", "g2"] + + +def test_groupby_drop_nan_with_multi_index(): + # GH 39895 + df = pd.DataFrame([[np.nan, 0, 1]], columns=["a", "b", "c"]) + df = df.set_index(["a", "b"]) + result = df.groupby(["a", "b"], dropna=False).first() + expected = df + tm.assert_frame_equal(result, expected) + + +# y >x and z is the missing value +@pytest.mark.parametrize( + "sequence", + [ + "xyzy", + "xxyz", + "yzxz", + "zzzz", + "zyzx", + "yyyy", + "zzxy", + "xyxy", + ], +) +@pytest.mark.parametrize( + "dtype", + [ + None, + "UInt8", + "Int8", + "UInt16", + "Int16", + "UInt32", + "Int32", + "UInt64", + "Int64", + "Float32", + "Float64", + "category", + "string", + pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")), + "datetime64[ns]", + "period[D]", + "Sparse[float]", + ], +) +@pytest.mark.parametrize("test_series", [True, False]) +def test_no_sort_keep_na(sequence, dtype, test_series, as_index): + # GH#46584, GH#48794 + + # Unique values to use for grouper, depends on dtype + if dtype in ("string", "string[pyarrow]"): + uniques = {"x": "x", "y": "y", "z": pd.NA} + elif dtype in ("datetime64[ns]", "period[D]"): + uniques = {"x": "2016-01-01", "y": "2017-01-01", "z": pd.NA} + elif dtype is not None and dtype.startswith(("I", "U", "F")): + uniques = {"x": 1, "y": 2, "z": pd.NA} + else: + uniques = {"x": 1, "y": 2, "z": np.nan} + + df = pd.DataFrame( + { + "key": pd.Series([uniques[label] for label in sequence], dtype=dtype), + "a": [0, 1, 2, 3], + } + ) + gb = df.groupby("key", dropna=False, sort=False, as_index=as_index, observed=False) + if test_series: + gb = gb["a"] + result = gb.sum() + + # Manually compute the groupby sum, use the labels "x", "y", and "z" to avoid + # issues with hashing np.nan + summed = {} + for idx, label in enumerate(sequence): + summed[label] = summed.get(label, 0) + idx + if dtype == "category": + index = pd.CategoricalIndex( + [uniques[e] for e in summed], + df["key"].cat.categories, + name="key", + ) + elif isinstance(dtype, str) and dtype.startswith("Sparse"): + index = pd.Index( + pd.array([uniques[label] for label in summed], dtype=dtype), name="key" + ) + else: + index = pd.Index([uniques[label] for label in summed], dtype=dtype, name="key") + expected = pd.Series(summed.values(), index=index, name="a", dtype=None) + if not test_series: + expected = expected.to_frame() + if not as_index: + expected = expected.reset_index() + if dtype is not None and dtype.startswith("Sparse"): + expected["key"] = expected["key"].astype(dtype) + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("test_series", [True, False]) +@pytest.mark.parametrize("dtype", [object, None]) +def test_null_is_null_for_dtype( + sort, dtype, nulls_fixture, nulls_fixture2, test_series +): + # GH#48506 - groups should always result in using the null for the dtype + df = pd.DataFrame({"a": [1, 2]}) + groups = pd.Series([nulls_fixture, nulls_fixture2], dtype=dtype) + obj = df["a"] if test_series else df + gb = obj.groupby(groups, dropna=False, sort=sort) + result = gb.sum() + index = pd.Index([na_value_for_dtype(groups.dtype)]) + expected = pd.DataFrame({"a": [3]}, index=index) + if test_series: + tm.assert_series_equal(result, expected["a"]) + else: + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("index_kind", ["range", "single", "multi"]) +def test_categorical_reducers(reduction_func, observed, sort, as_index, index_kind): + # Ensure there is at least one null value by appending to the end + values = np.append(np.random.default_rng(2).choice([1, 2, None], size=19), None) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(20)} + ) + + # Strategy: Compare to dropna=True by filling null values with a new code + df_filled = df.copy() + df_filled["x"] = pd.Categorical(values, categories=[1, 2, 3, 4]).fillna(4) + + if index_kind == "range": + keys = ["x"] + elif index_kind == "single": + keys = ["x"] + df = df.set_index("x") + df_filled = df_filled.set_index("x") + else: + keys = ["x", "x2"] + df["x2"] = df["x"] + df = df.set_index(["x", "x2"]) + df_filled["x2"] = df_filled["x"] + df_filled = df_filled.set_index(["x", "x2"]) + args = get_groupby_method_args(reduction_func, df) + args_filled = get_groupby_method_args(reduction_func, df_filled) + if reduction_func == "corrwith" and index_kind == "range": + # Don't include the grouping columns so we can call reset_index + args = (args[0].drop(columns=keys),) + args_filled = (args_filled[0].drop(columns=keys),) + + gb_keepna = df.groupby( + keys, dropna=False, observed=observed, sort=sort, as_index=as_index + ) + + if not observed and reduction_func in ["idxmin", "idxmax"]: + with pytest.raises( + ValueError, match="empty group due to unobserved categories" + ): + getattr(gb_keepna, reduction_func)(*args) + return + + gb_filled = df_filled.groupby(keys, observed=observed, sort=sort, as_index=True) + if reduction_func == "corrwith": + warn = Pandas4Warning + msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + msg = "" + with tm.assert_produces_warning(warn, match=msg): + expected = getattr(gb_filled, reduction_func)(*args_filled).reset_index() + expected["x"] = expected["x"].cat.remove_categories([4]) + if index_kind == "multi": + expected["x2"] = expected["x2"].cat.remove_categories([4]) + if as_index: + if index_kind == "multi": + expected = expected.set_index(["x", "x2"]) + else: + expected = expected.set_index("x") + if reduction_func in ("idxmax", "idxmin") and index_kind != "range": + # expected was computed with a RangeIndex; need to translate to index values + values = expected["y"].values.tolist() + if index_kind == "single": + values = [np.nan if e == 4 else e for e in values] + expected["y"] = pd.Categorical(values, categories=[1, 2, 3]) + else: + values = [(np.nan, np.nan) if e == (4, 4) else e for e in values] + expected["y"] = values + if reduction_func == "size": + # size, unlike other methods, has the desired behavior in GH#49519 + expected = expected.rename(columns={0: "size"}) + if as_index: + expected = expected["size"].rename(None) + + if reduction_func == "corrwith": + warn = Pandas4Warning + msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + msg = "" + with tm.assert_produces_warning(warn, match=msg): + result = getattr(gb_keepna, reduction_func)(*args) + + # size will return a Series, others are DataFrame + tm.assert_equal(result, expected) + + +def test_categorical_transformers(transformation_func, observed, sort, as_index): + # GH#36327 + values = np.append(np.random.default_rng(2).choice([1, 2, None], size=19), None) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(20)} + ) + args = get_groupby_method_args(transformation_func, df) + + # Compute result for null group + null_group_values = df[df["x"].isnull()]["y"] + if transformation_func == "cumcount": + null_group_data = list(range(len(null_group_values))) + elif transformation_func == "ngroup": + if sort: + if observed: + na_group = df["x"].nunique(dropna=False) - 1 + else: + # TODO: Should this be 3? + na_group = df["x"].nunique(dropna=False) - 1 + else: + na_group = df.iloc[: null_group_values.index[0]]["x"].nunique() + null_group_data = len(null_group_values) * [na_group] + else: + null_group_data = getattr(null_group_values, transformation_func)(*args) + null_group_result = pd.DataFrame({"y": null_group_data}) + + gb_keepna = df.groupby( + "x", dropna=False, observed=observed, sort=sort, as_index=as_index + ) + gb_dropna = df.groupby("x", dropna=True, observed=observed, sort=sort) + + result = getattr(gb_keepna, transformation_func)(*args) + expected = getattr(gb_dropna, transformation_func)(*args) + + for iloc, value in zip( + df[df["x"].isnull()].index.tolist(), + null_group_result.values.ravel(), + strict=True, + ): + if expected.ndim == 1: + expected.iloc[iloc] = value + else: + expected.iloc[iloc, 0] = value + if transformation_func == "ngroup": + expected[df["x"].notnull() & expected.ge(na_group)] += 1 + if transformation_func not in ("rank", "diff", "pct_change", "shift"): + expected = expected.astype("int64") + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("method", ["head", "tail"]) +def test_categorical_head_tail(method, observed, sort, as_index): + # GH#36327 + values = np.random.default_rng(2).choice([1, 2, None], 30) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(len(values))} + ) + gb = df.groupby("x", dropna=False, observed=observed, sort=sort, as_index=as_index) + result = getattr(gb, method)() + + if method == "tail": + values = values[::-1] + # Take the top 5 values from each group + mask = ( + ((values == 1) & ((values == 1).cumsum() <= 5)) + | ((values == 2) & ((values == 2).cumsum() <= 5)) + # flake8 doesn't like the vectorized check for None, thinks we should use `is` + | ((values == None) & ((values == None).cumsum() <= 5)) # noqa: E711 + ) + if method == "tail": + mask = mask[::-1] + expected = df[mask] + + tm.assert_frame_equal(result, expected) + + +def test_categorical_agg(): + # GH#36327 + values = np.random.default_rng(2).choice([1, 2, None], 30) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(len(values))} + ) + gb = df.groupby("x", dropna=False, observed=False) + result = gb.agg(lambda x: x.sum()) + expected = gb.sum() + tm.assert_frame_equal(result, expected) + + +def test_categorical_transform(): + # GH#36327 + values = np.random.default_rng(2).choice([1, 2, None], 30) + df = pd.DataFrame( + {"x": pd.Categorical(values, categories=[1, 2, 3]), "y": range(len(values))} + ) + gb = df.groupby("x", dropna=False, observed=False) + result = gb.transform(lambda x: x.sum()) + expected = gb.transform("sum") + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/groupby/test_groupby_subclass.py b/pandas/tests/groupby/test_groupby_subclass.py new file mode 100644 index 0000000000000000000000000000000000000000..e1dfb3aabdaf03d120058793e0c9866be2255db6 --- /dev/null +++ b/pandas/tests/groupby/test_groupby_subclass.py @@ -0,0 +1,152 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning + +from pandas import ( + DataFrame, + Index, + Series, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning" +) + + +@pytest.mark.parametrize( + "obj", + [ + tm.SubclassedDataFrame({"A": np.arange(0, 10)}), + tm.SubclassedSeries(np.arange(0, 10), name="A"), + ], +) +def test_groupby_preserves_subclass(obj, groupby_func): + # GH28330 -- preserve subclass through groupby operations + + if isinstance(obj, Series) and groupby_func in {"corrwith"}: + pytest.skip(f"Not applicable for Series and {groupby_func}") + + grouped = obj.groupby(np.arange(0, 10)) + + # Groups should preserve subclass type + assert isinstance(grouped.get_group(0), type(obj)) + + args = get_groupby_method_args(groupby_func, obj) + + warn = Pandas4Warning if groupby_func == "corrwith" else None + msg = f"{type(grouped).__name__}.corrwith is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result1 = getattr(grouped, groupby_func)(*args) + with tm.assert_produces_warning(warn, match=msg): + result2 = grouped.agg(groupby_func, *args) + + # Reduction or transformation kernels should preserve type + slices = {"ngroup", "cumcount", "size"} + if isinstance(obj, DataFrame) and groupby_func in slices: + assert isinstance(result1, tm.SubclassedSeries) + else: + assert isinstance(result1, type(obj)) + + # Confirm .agg() groupby operations return same results + if isinstance(result1, DataFrame): + tm.assert_frame_equal(result1, result2) + else: + tm.assert_series_equal(result1, result2) + + +def test_groupby_preserves_metadata(): + # GH-37343 + custom_df = tm.SubclassedDataFrame({"a": [1, 2, 3], "b": [1, 1, 2], "c": [7, 8, 9]}) + assert "testattr" in custom_df._metadata + custom_df.testattr = "hello" + for _, group_df in custom_df.groupby("c"): + assert group_df.testattr == "hello" + + # GH-45314 + def func(group): + assert isinstance(group, tm.SubclassedDataFrame) + assert hasattr(group, "testattr") + assert group.testattr == "hello" + return group.testattr + + result = custom_df.groupby("c").apply(func) + expected = tm.SubclassedSeries(["hello"] * 3, index=Index([7, 8, 9], name="c")) + tm.assert_series_equal(result, expected) + + result = custom_df.groupby("c").apply(func) + tm.assert_series_equal(result, expected) + + # https://github.com/pandas-dev/pandas/pull/56761 + result = custom_df.groupby("c")[["a", "b"]].apply(func) + tm.assert_series_equal(result, expected) + + def func2(group): + assert isinstance(group, tm.SubclassedSeries) + assert hasattr(group, "testattr") + return group.testattr + + custom_series = tm.SubclassedSeries([1, 2, 3]) + custom_series.testattr = "hello" + result = custom_series.groupby(custom_df["c"]).apply(func2) + tm.assert_series_equal(result, expected) + result = custom_series.groupby(custom_df["c"]).agg(func2) + tm.assert_series_equal(result, expected) + + +def test_groupby_apply_preserves_metadata(): + # GH#62134 - Test that apply() preserves metadata when returning DataFrames/Series + custom_df = tm.SubclassedDataFrame({"a": [1, 2, 3], "b": [1, 1, 2], "c": [7, 8, 9]}) + custom_df.testattr = "hello" + + def sum_func(group): + assert isinstance(group, tm.SubclassedDataFrame) + assert hasattr(group, "testattr") + assert group.testattr == "hello" + return group.sum() + + result = custom_df.groupby("c").apply(sum_func) + assert hasattr(result, "testattr"), "DataFrame apply() should preserve metadata" + assert result.testattr == "hello" + + custom_series = tm.SubclassedSeries([1, 2, 3]) + custom_series.testattr = "hello" + + def sum_series_func(group): + assert isinstance(group, tm.SubclassedSeries) + assert hasattr(group, "testattr") + assert group.testattr == "hello" + return group.sum() + + result = custom_series.groupby(custom_df["c"]).apply(sum_series_func) + assert hasattr(result, "testattr"), "Series apply() should preserve metadata" + assert result.testattr == "hello" + + +@pytest.mark.parametrize("obj", [DataFrame, tm.SubclassedDataFrame]) +def test_groupby_resample_preserves_subclass(obj): + # GH28330 -- preserve subclass through groupby.resample() + + df = obj( + { + "Buyer": Series("Carl Carl Carl Carl Joe Carl".split(), dtype=object), + "Quantity": [18, 3, 5, 1, 9, 3], + "Date": [ + datetime(2013, 9, 1, 13, 0), + datetime(2013, 9, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 3, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 9, 2, 14, 0), + ], + } + ) + df = df.set_index("Date") + + # Confirm groupby.resample() preserves dataframe type + result = df.groupby("Buyer").resample("5D").sum() + assert isinstance(result, obj) diff --git a/pandas/tests/groupby/test_grouping.py b/pandas/tests/groupby/test_grouping.py new file mode 100644 index 0000000000000000000000000000000000000000..6450b1108d240b3eff5a26b6cb33fab8d0a07b90 --- /dev/null +++ b/pandas/tests/groupby/test_grouping.py @@ -0,0 +1,1216 @@ +""" +test where we are determining what we are grouping, or getting groups +""" + +from datetime import ( + date, + timedelta, +) + +import numpy as np +import pytest + +from pandas.errors import ( + Pandas4Warning, + SpecificationError, +) + +import pandas as pd +from pandas import ( + CategoricalIndex, + DataFrame, + Grouper, + Index, + MultiIndex, + Series, + Timestamp, + date_range, + period_range, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouping + +# selection +# -------------------------------- + + +class TestSelection: + def test_select_bad_cols(self): + df = DataFrame([[1, 2]], columns=["A", "B"]) + g = df.groupby("A") + with pytest.raises(KeyError, match="\"Columns not found: 'C'\""): + g[["C"]] + + with pytest.raises(KeyError, match="^[^A]+$"): + # A should not be referenced as a bad column... + # will have to rethink regex if you change message! + g[["A", "C"]] + + def test_groupby_duplicated_column_errormsg(self): + # GH7511 + df = DataFrame( + columns=["A", "B", "A", "C"], data=[range(4), range(2, 6), range(0, 8, 2)] + ) + + msg = "Grouper for 'A' not 1-dimensional" + with pytest.raises(ValueError, match=msg): + df.groupby("A") + with pytest.raises(ValueError, match=msg): + df.groupby(["A", "B"]) + + grouped = df.groupby("B") + c = grouped.count() + assert c.columns.nlevels == 1 + assert c.columns.size == 3 + + def test_column_select_via_attr(self, df): + result = df.groupby("A").C.sum() + expected = df.groupby("A")["C"].sum() + tm.assert_series_equal(result, expected) + + df["mean"] = 1.5 + result = df.groupby("A").mean(numeric_only=True) + expected = df.groupby("A")[["C", "D", "mean"]].agg("mean") + tm.assert_frame_equal(result, expected) + + def test_getitem_list_of_columns(self): + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + "E": np.random.default_rng(2).standard_normal(8), + } + ) + + result = df.groupby("A")[["C", "D"]].mean() + result2 = df.groupby("A")[df.columns[2:4]].mean() + + expected = df.loc[:, ["A", "C", "D"]].groupby("A").mean() + + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + def test_getitem_numeric_column_names(self): + # GH #13731 + df = DataFrame( + { + 0: list("abcd") * 2, + 2: np.random.default_rng(2).standard_normal(8), + 4: np.random.default_rng(2).standard_normal(8), + 6: np.random.default_rng(2).standard_normal(8), + } + ) + result = df.groupby(0)[df.columns[1:3]].mean() + result2 = df.groupby(0)[[2, 4]].mean() + + expected = df.loc[:, [0, 2, 4]].groupby(0).mean() + + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + # per GH 23566 enforced deprecation raises a ValueError + with pytest.raises(ValueError, match="Cannot subset columns with a tuple"): + df.groupby(0)[2, 4].mean() + + def test_getitem_single_tuple_of_columns_raises(self, df): + # per GH 23566 enforced deprecation raises a ValueError + with pytest.raises(ValueError, match="Cannot subset columns with a tuple"): + df.groupby("A")["C", "D"].mean() + + def test_getitem_single_column(self): + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": ["one", "one", "two", "three", "two", "two", "one", "three"], + "C": np.random.default_rng(2).standard_normal(8), + "D": np.random.default_rng(2).standard_normal(8), + "E": np.random.default_rng(2).standard_normal(8), + } + ) + + result = df.groupby("A")["C"].mean() + + as_frame = df.loc[:, ["A", "C"]].groupby("A").mean() + as_series = as_frame.iloc[:, 0] + expected = as_series + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "func", [lambda x: x.sum(), lambda x: x.agg(lambda y: y.sum())] + ) + def test_getitem_from_grouper(self, func): + # GH 50383 + df = DataFrame({"a": [1, 1, 2], "b": 3, "c": 4, "d": 5}) + gb = df.groupby(["a", "b"])[["a", "c"]] + + idx = MultiIndex.from_tuples([(1, 3), (2, 3)], names=["a", "b"]) + expected = DataFrame({"a": [2, 2], "c": [8, 4]}, index=idx) + result = func(gb) + + tm.assert_frame_equal(result, expected) + + def test_indices_grouped_by_tuple_with_lambda(self): + # GH 36158 + df = DataFrame( + { + "Tuples": ( + (x, y) + for x in [0, 1] + for y in np.random.default_rng(2).integers(3, 5, 5) + ) + } + ) + + gb = df.groupby("Tuples") + gb_lambda = df.groupby(lambda x: df.iloc[x, 0]) + + expected = gb.indices + result = gb_lambda.indices + + tm.assert_dict_equal(result, expected) + + +# grouping +# -------------------------------- + + +class TestGrouping: + @pytest.mark.parametrize( + "index", + [ + Index(list("abcde")), + Index(np.arange(5)), + Index(np.arange(5, dtype=float)), + date_range("2020-01-01", periods=5), + period_range("2020-01-01", periods=5), + ], + ) + def test_grouper_index_types(self, index): + # related GH5375 + # groupby misbehaving when using a Floatlike index + df = DataFrame(np.arange(10).reshape(5, 2), columns=list("AB"), index=index) + + df.groupby(list("abcde"), group_keys=False).apply(lambda x: x) + + df.index = df.index[::-1] + df.groupby(list("abcde"), group_keys=False).apply(lambda x: x) + + def test_grouper_multilevel_freq(self): + # GH 7885 + # with level and freq specified in a Grouper + d0 = date.today() - timedelta(days=14) + dates = date_range(d0, date.today()) + date_index = MultiIndex.from_product([dates, dates], names=["foo", "bar"]) + df = DataFrame(np.random.default_rng(2).integers(0, 100, 225), index=date_index) + + # Check string level + expected = ( + df.reset_index() + .groupby([Grouper(key="foo", freq="W"), Grouper(key="bar", freq="W")]) + .sum() + ) + # reset index changes columns dtype to object + expected.columns = Index([0], dtype="int64") + + result = df.groupby( + [Grouper(level="foo", freq="W"), Grouper(level="bar", freq="W")] + ).sum() + tm.assert_frame_equal(result, expected) + + # Check integer level + result = df.groupby( + [Grouper(level=0, freq="W"), Grouper(level=1, freq="W")] + ).sum() + tm.assert_frame_equal(result, expected) + + def test_grouper_creation_bug(self): + # GH 8795 + df = DataFrame({"A": [0, 0, 1, 1, 2, 2], "B": [1, 2, 3, 4, 5, 6]}) + g = df.groupby("A") + expected = g.sum() + + g = df.groupby(Grouper(key="A")) + result = g.sum() + tm.assert_frame_equal(result, expected) + + result = g.apply(lambda x: x.sum()) + tm.assert_frame_equal(result, expected) + + def test_grouper_creation_bug2(self): + # GH14334 + # Grouper(key=...) may be passed in a list + df = DataFrame( + {"A": [0, 0, 0, 1, 1, 1], "B": [1, 1, 2, 2, 3, 3], "C": [1, 2, 3, 4, 5, 6]} + ) + # Group by single column + expected = df.groupby("A").sum() + g = df.groupby([Grouper(key="A")]) + result = g.sum() + tm.assert_frame_equal(result, expected) + + # Group by two columns + # using a combination of strings and Grouper objects + expected = df.groupby(["A", "B"]).sum() + + # Group with two Grouper objects + g = df.groupby([Grouper(key="A"), Grouper(key="B")]) + result = g.sum() + tm.assert_frame_equal(result, expected) + + # Group with a string and a Grouper object + g = df.groupby(["A", Grouper(key="B")]) + result = g.sum() + tm.assert_frame_equal(result, expected) + + # Group with a Grouper object and a string + g = df.groupby([Grouper(key="A"), "B"]) + result = g.sum() + tm.assert_frame_equal(result, expected) + + def test_grouper_creation_bug3(self, unit): + # GH8866 + dti = date_range("20130101", periods=2, unit=unit) + mi = MultiIndex.from_product( + [list("ab"), range(2), dti], + names=["one", "two", "three"], + ) + ser = Series( + np.arange(8, dtype="int64"), + index=mi, + ) + result = ser.groupby(Grouper(level="three", freq="ME")).sum() + exp_dti = pd.DatetimeIndex( + [Timestamp("2013-01-31")], freq="ME", name="three" + ).as_unit(unit) + expected = Series( + [28], + index=exp_dti, + ) + tm.assert_series_equal(result, expected) + + # just specifying a level breaks + result = ser.groupby(Grouper(level="one")).sum() + expected = ser.groupby(level="one").sum() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("func", [False, True]) + def test_grouper_returning_tuples(self, func): + # GH 22257 , both with dict and with callable + df = DataFrame({"X": ["A", "B", "A", "B"], "Y": [1, 4, 3, 2]}) + mapping = dict(zip(range(4), [("C", 5), ("D", 6)] * 2, strict=True)) + + if func: + gb = df.groupby(by=lambda idx: mapping[idx], sort=False) + else: + gb = df.groupby(by=mapping, sort=False) + + name, expected = next(iter(gb)) + assert name == ("C", 5) + result = gb.get_group(name) + + tm.assert_frame_equal(result, expected) + + def test_grouper_column_and_index(self): + # GH 14327 + + # Grouping a multi-index frame by a column and an index level should + # be equivalent to resetting the index and grouping by two columns + idx = MultiIndex.from_tuples( + [("a", 1), ("a", 2), ("a", 3), ("b", 1), ("b", 2), ("b", 3)] + ) + idx.names = ["outer", "inner"] + df_multi = DataFrame( + {"A": np.arange(6), "B": ["one", "one", "two", "two", "one", "one"]}, + index=idx, + ) + result = df_multi.groupby(["B", Grouper(level="inner")]).mean(numeric_only=True) + expected = ( + df_multi.reset_index().groupby(["B", "inner"]).mean(numeric_only=True) + ) + tm.assert_frame_equal(result, expected) + + # Test the reverse grouping order + result = df_multi.groupby([Grouper(level="inner"), "B"]).mean(numeric_only=True) + expected = ( + df_multi.reset_index().groupby(["inner", "B"]).mean(numeric_only=True) + ) + tm.assert_frame_equal(result, expected) + + # Grouping a single-index frame by a column and the index should + # be equivalent to resetting the index and grouping by two columns + df_single = df_multi.reset_index("outer") + result = df_single.groupby(["B", Grouper(level="inner")]).mean( + numeric_only=True + ) + expected = ( + df_single.reset_index().groupby(["B", "inner"]).mean(numeric_only=True) + ) + tm.assert_frame_equal(result, expected) + + # Test the reverse grouping order + result = df_single.groupby([Grouper(level="inner"), "B"]).mean( + numeric_only=True + ) + expected = ( + df_single.reset_index().groupby(["inner", "B"]).mean(numeric_only=True) + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_levels_and_columns(self): + # GH9344, GH9049 + idx_names = ["x", "y"] + idx = MultiIndex.from_tuples([(1, 1), (1, 2), (3, 4), (5, 6)], names=idx_names) + df = DataFrame(np.arange(12).reshape(-1, 3), index=idx) + + by_levels = df.groupby(level=idx_names).mean() + # reset_index changes columns dtype to object + by_columns = df.reset_index().groupby(idx_names).mean() + + # without casting, by_columns.columns is object-dtype + by_columns.columns = by_columns.columns.astype(np.int64) + tm.assert_frame_equal(by_levels, by_columns) + + def test_groupby_categorical_index_and_columns(self, observed): + # GH18432, adapted for GH25871 + columns = ["A", "B", "A", "B"] + categories = ["B", "A"] + data = np.array( + [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 1, 2]], int + ) + cat_columns = CategoricalIndex(columns, categories=categories, ordered=True) + expected_data = np.array([[4, 2], [4, 2], [4, 2], [4, 2], [4, 2]], int) + expected_columns = CategoricalIndex( + categories, categories=categories, ordered=True + ) + + # test transposed version + df = DataFrame(data.T, index=cat_columns) + result = df.groupby(level=0, observed=observed).sum() + expected = DataFrame(data=expected_data.T, index=expected_columns) + tm.assert_frame_equal(result, expected) + + def test_grouper_getting_correct_binner(self): + # GH 10063 + # using a non-time-based grouper and a time-based grouper + # and specifying levels + df = DataFrame( + {"A": 1}, + index=MultiIndex.from_product( + [list("ab"), date_range("20130101", periods=80)], names=["one", "two"] + ), + ) + result = df.groupby( + [Grouper(level="one"), Grouper(level="two", freq="ME")] + ).sum() + expected = DataFrame( + {"A": [31, 28, 21, 31, 28, 21]}, + index=MultiIndex.from_product( + [list("ab"), date_range("20130101", freq="ME", periods=3)], + names=["one", "two"], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_grouper_iter(self, df): + gb = df.groupby("A") + grouper = gb._grouper + result = sorted(grouper) + expected = ["bar", "foo"] + assert result == expected + + def test_empty_groups(self, df): + # see gh-1048 + with pytest.raises(ValueError, match="No group keys passed!"): + df.groupby([]) + + def test_groupby_grouper(self, df): + grouped = df.groupby("A") + grouper = grouped._grouper + result = df.groupby(grouper).mean(numeric_only=True) + expected = grouped.mean(numeric_only=True) + tm.assert_frame_equal(result, expected) + + def test_groupby_dict_mapping(self): + # GH #679 + s = Series({"T1": 5}) + result = s.groupby({"T1": "T2"}).agg("sum") + expected = s.groupby(["T2"]).agg("sum") + tm.assert_series_equal(result, expected) + + s = Series([1.0, 2.0, 3.0, 4.0], index=list("abcd")) + mapping = {"a": 0, "b": 0, "c": 1, "d": 1} + + result = s.groupby(mapping).mean() + result2 = s.groupby(mapping).agg("mean") + exp_key = np.array([0, 0, 1, 1], dtype=np.int64) + expected = s.groupby(exp_key).mean() + expected2 = s.groupby(exp_key).mean() + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result, result2) + tm.assert_series_equal(result, expected2) + + @pytest.mark.parametrize( + "index", + [ + [0, 1, 2, 3], + ["a", "b", "c", "d"], + [Timestamp(2021, 7, 28 + i) for i in range(4)], + ], + ) + def test_groupby_series_named_with_tuple(self, frame_or_series, index): + # GH 42731 + obj = frame_or_series([1, 2, 3, 4], index=index) + groups = Series([1, 0, 1, 0], index=index, name=("a", "a")) + result = obj.groupby(groups).last() + expected = frame_or_series([4, 3]) + expected.index.name = ("a", "a") + tm.assert_equal(result, expected) + + def test_groupby_grouper_f_sanity_checked(self): + dates = date_range("01-Jan-2013", periods=12, freq="MS") + ts = Series(np.random.default_rng(2).standard_normal(12), index=dates) + + # GH51979 + # simple check that the passed function doesn't operates on the whole index + msg = "'Timestamp' object is not subscriptable" + with pytest.raises(TypeError, match=msg): + ts.groupby(lambda key: key[0:6]) + + result = ts.groupby(lambda x: x).sum() + expected = ts.groupby(ts.index).sum() + expected.index.freq = None + tm.assert_series_equal(result, expected) + + def test_groupby_with_datetime_key(self): + # GH 51158 + df = DataFrame( + { + "id": ["a", "b"] * 3, + "b": date_range("2000-01-01", "2000-01-03", freq="9h"), + } + ) + grouper = Grouper(key="b", freq="D") + gb = df.groupby([grouper, "id"]) + + # test number of groups + expected = { + (Timestamp("2000-01-01"), "a"): [0, 2], + (Timestamp("2000-01-01"), "b"): [1], + (Timestamp("2000-01-02"), "a"): [4], + (Timestamp("2000-01-02"), "b"): [3, 5], + } + tm.assert_dict_equal(gb.groups, expected) + + # test number of group keys + assert len(gb.groups.keys()) == 4 + + def test_grouping_error_on_multidim_input(self, df): + msg = "Grouper for '' not 1-dimensional" + with pytest.raises(ValueError, match=msg): + Grouping(df.index, df[["A", "A"]]) + + def test_multiindex_negative_level(self, multiindex_dataframe_random_data): + # GH 13901 + result = multiindex_dataframe_random_data.groupby(level=-1).sum() + expected = multiindex_dataframe_random_data.groupby(level="second").sum() + tm.assert_frame_equal(result, expected) + + result = multiindex_dataframe_random_data.groupby(level=-2).sum() + expected = multiindex_dataframe_random_data.groupby(level="first").sum() + tm.assert_frame_equal(result, expected) + + result = multiindex_dataframe_random_data.groupby(level=[-2, -1]).sum() + expected = multiindex_dataframe_random_data.sort_index() + tm.assert_frame_equal(result, expected) + + result = multiindex_dataframe_random_data.groupby(level=[-1, "first"]).sum() + expected = multiindex_dataframe_random_data.groupby( + level=["second", "first"] + ).sum() + tm.assert_frame_equal(result, expected) + + def test_agg_with_dict_raises(self, df): + df.columns = np.arange(len(df.columns)) + msg = "nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + df.groupby(1, as_index=False)[2].agg({"Q": np.mean}) + + def test_multiindex_columns_empty_level(self): + lst = [["count", "values"], ["to filter", ""]] + midx = MultiIndex.from_tuples(lst) + + df = DataFrame([[1, "A"]], columns=midx) + + msg = "In a future version, the keys" + grouped = df.groupby("to filter").groups + assert grouped["A"] == [0] + + with tm.assert_produces_warning(Pandas4Warning, match=msg): + grouped = df.groupby([("to filter", "")]).groups + assert grouped["A"] == [0] + + df = DataFrame([[1, "A"], [2, "B"]], columns=midx) + + expected = df.groupby("to filter").groups + with tm.assert_produces_warning(Pandas4Warning, match=msg): + result = df.groupby([("to filter", "")]).groups + assert result == expected + + df = DataFrame([[1, "A"], [2, "A"]], columns=midx) + + expected = df.groupby("to filter").groups + with tm.assert_produces_warning(Pandas4Warning, match=msg): + result = df.groupby([("to filter", "")]).groups + tm.assert_dict_equal(result, expected) + + def test_groupby_multiindex_tuple(self): + # GH 17979, GH#59179 + df = DataFrame( + [[1, 2, 3, 4], [3, 4, 5, 6], [1, 4, 2, 3]], + columns=MultiIndex.from_arrays([["a", "b", "b", "c"], [1, 1, 2, 2]]), + ) + + msg = "In a future version, the keys" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + expected = df.groupby([("b", 1)]).groups + result = df.groupby(("b", 1)).groups + tm.assert_dict_equal(expected, result) + + df2 = DataFrame( + df.values, + columns=MultiIndex.from_arrays( + [["a", "b", "b", "c"], ["d", "d", "e", "e"]] + ), + ) + + with tm.assert_produces_warning(Pandas4Warning, match=msg): + expected = df2.groupby([("b", "d")]).groups + result = df.groupby(("b", 1)).groups + tm.assert_dict_equal(expected, result) + + df3 = DataFrame(df.values, columns=[("a", "d"), ("b", "d"), ("b", "e"), "c"]) + + with tm.assert_produces_warning(Pandas4Warning, match=msg): + expected = df3.groupby([("b", "d")]).groups + result = df.groupby(("b", 1)).groups + tm.assert_dict_equal(expected, result) + + def test_groupby_multiindex_partial_indexing_equivalence(self): + # GH 17977, GH#59179 + df = DataFrame( + [[1, 2, 3, 4], [3, 4, 5, 6], [1, 4, 2, 3]], + columns=MultiIndex.from_arrays([["a", "b", "b", "c"], [1, 1, 2, 2]]), + ) + + expected_mean = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].mean() + result_mean = df.groupby([("a", 1)])["b"].mean() + tm.assert_frame_equal(expected_mean, result_mean) + + expected_sum = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].sum() + result_sum = df.groupby([("a", 1)])["b"].sum() + tm.assert_frame_equal(expected_sum, result_sum) + + expected_count = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].count() + result_count = df.groupby([("a", 1)])["b"].count() + tm.assert_frame_equal(expected_count, result_count) + + expected_min = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].min() + result_min = df.groupby([("a", 1)])["b"].min() + tm.assert_frame_equal(expected_min, result_min) + + expected_max = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].max() + result_max = df.groupby([("a", 1)])["b"].max() + tm.assert_frame_equal(expected_max, result_max) + + msg = "In a future version, the keys" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + expected_groups = df.groupby([("a", 1)])[[("b", 1), ("b", 2)]].groups + result_groups = df.groupby([("a", 1)])["b"].groups + tm.assert_dict_equal(expected_groups, result_groups) + + def test_groupby_level(self, sort, multiindex_dataframe_random_data, df): + # GH 17537 + frame = multiindex_dataframe_random_data + deleveled = frame.reset_index() + + result0 = frame.groupby(level=0, sort=sort).sum() + result1 = frame.groupby(level=1, sort=sort).sum() + + expected0 = frame.groupby(deleveled["first"].values, sort=sort).sum() + expected1 = frame.groupby(deleveled["second"].values, sort=sort).sum() + + expected0.index.name = "first" + expected1.index.name = "second" + + assert result0.index.name == "first" + assert result1.index.name == "second" + + tm.assert_frame_equal(result0, expected0) + tm.assert_frame_equal(result1, expected1) + assert result0.index.name == frame.index.names[0] + assert result1.index.name == frame.index.names[1] + + # groupby level name + result0 = frame.groupby(level="first", sort=sort).sum() + result1 = frame.groupby(level="second", sort=sort).sum() + tm.assert_frame_equal(result0, expected0) + tm.assert_frame_equal(result1, expected1) + + # raise exception for non-MultiIndex + msg = "level > 0 or level < -1 only valid with MultiIndex" + with pytest.raises(ValueError, match=msg): + df.groupby(level=1) + + def test_groupby_level_index_names(self): + # GH4014 this used to raise ValueError since 'exp'>1 (in py2) + df = DataFrame({"exp": ["A"] * 3 + ["B"] * 3, "var1": range(6)}).set_index( + "exp" + ) + df.groupby(level="exp") + msg = "level name foo is not the name of the index" + with pytest.raises(ValueError, match=msg): + df.groupby(level="foo") + + def test_groupby_level_with_nas(self, sort): + # GH 17537 + index = MultiIndex( + levels=[[1, 0], [0, 1, 2, 3]], + codes=[[1, 1, 1, 1, 0, 0, 0, 0], [0, 1, 2, 3, 0, 1, 2, 3]], + ) + + # factorizing doesn't confuse things + s = Series(np.arange(8.0), index=index) + result = s.groupby(level=0, sort=sort).sum() + expected = Series([6.0, 22.0], index=[0, 1]) + tm.assert_series_equal(result, expected) + + index = MultiIndex( + levels=[[1, 0], [0, 1, 2, 3]], + codes=[[1, 1, 1, 1, -1, 0, 0, 0], [0, 1, 2, 3, 0, 1, 2, 3]], + ) + + # factorizing doesn't confuse things + s = Series(np.arange(8.0), index=index) + result = s.groupby(level=0, sort=sort).sum() + expected = Series([6.0, 18.0], index=[0.0, 1.0]) + tm.assert_series_equal(result, expected) + + def test_groupby_args(self, multiindex_dataframe_random_data): + # PR8618 and issue 8015 + frame = multiindex_dataframe_random_data + + msg = "You have to supply one of 'by' and 'level'" + with pytest.raises(TypeError, match=msg): + frame.groupby() + + msg = "You have to supply one of 'by' and 'level'" + with pytest.raises(TypeError, match=msg): + frame.groupby(by=None, level=None) + + @pytest.mark.parametrize( + "sort,labels", + [ + [True, [2, 2, 2, 0, 0, 1, 1, 3, 3, 3]], + [False, [0, 0, 0, 1, 1, 2, 2, 3, 3, 3]], + ], + ) + def test_level_preserve_order(self, sort, labels, multiindex_dataframe_random_data): + # GH 17537 + grouped = multiindex_dataframe_random_data.groupby(level=0, sort=sort) + exp_labels = np.array(labels, np.intp) + tm.assert_almost_equal(grouped._grouper.ids, exp_labels) + + def test_grouping_labels(self, multiindex_dataframe_random_data): + grouped = multiindex_dataframe_random_data.groupby( + multiindex_dataframe_random_data.index.get_level_values(0) + ) + exp_labels = np.array([2, 2, 2, 0, 0, 1, 1, 3, 3, 3], dtype=np.intp) + tm.assert_almost_equal(grouped._grouper.codes[0], exp_labels) + + def test_list_grouper_with_nat(self): + # GH 14715, GH#59179 + df = DataFrame({"date": date_range("1/1/2011", periods=365, freq="D")}) + df.iloc[-1] = pd.NaT + grouper = Grouper(key="date", freq="YS") + msg = "In a future version, the keys" + + # Grouper in a list grouping + gb = df.groupby([grouper]) + expected = {Timestamp("2011-01-01"): Index(list(range(364)))} + with tm.assert_produces_warning(Pandas4Warning, match=msg): + result = gb.groups + tm.assert_dict_equal(result, expected) + + # Test case without a list + result = df.groupby(grouper) + expected = {Timestamp("2011-01-01"): 365} + tm.assert_dict_equal(result.groups, expected) + + @pytest.mark.parametrize( + "func,expected", + [ + ( + "transform", + Series(name=2, dtype=np.float64), + ), + ( + "agg", + Series( + name=2, dtype=np.float64, index=Index([], dtype=np.float64, name=1) + ), + ), + ( + "apply", + Series( + name=2, dtype=np.float64, index=Index([], dtype=np.float64, name=1) + ), + ), + ], + ) + def test_evaluate_with_empty_groups(self, func, expected): + # 26208 + # test transform'ing empty groups + # (not testing other agg fns, because they return + # different index objects. + df = DataFrame({1: [], 2: []}) + g = df.groupby(1, group_keys=True) + result = getattr(g[2], func)(lambda x: x) + tm.assert_series_equal(result, expected) + + def test_groupby_apply_empty_with_group_keys_false(self): + # 60471 + # test apply'ing empty groups with group_keys False + # (not testing other agg fns, because they return + # different index objects. + df = DataFrame({"A": [], "B": [], "C": []}) + g = df.groupby("A", group_keys=False) + result = g.apply(lambda x: x / x.sum()) + expected = DataFrame({"B": [], "C": []}, index=None) + tm.assert_frame_equal(result, expected) + + def test_groupby_empty(self): + # https://github.com/pandas-dev/pandas/issues/27190 + s = Series([], name="name", dtype="float64") + gr = s.groupby([]) + + result = gr.mean() + expected = s.set_axis(Index([], dtype=np.intp)) + tm.assert_series_equal(result, expected) + + # check group properties + assert len(gr._grouper.groupings) == 1 + tm.assert_numpy_array_equal( + gr._grouper.ids, np.array([], dtype=np.dtype(np.intp)) + ) + + assert gr._grouper.ngroups == 0 + + # check name + gb = s.groupby(s) + grouper = gb._grouper + result = grouper.names + expected = ["name"] + assert result == expected + + def test_groupby_level_index_value_all_na(self): + # issue 20519 + df = DataFrame( + [["x", np.nan, 10], [None, np.nan, 20]], columns=["A", "B", "C"] + ).set_index(["A", "B"]) + result = df.groupby(level=["A", "B"]).sum() + expected = DataFrame( + data=[], + index=MultiIndex( + levels=[Index(["x"], dtype="str"), Index([], dtype="float64")], + codes=[[], []], + names=["A", "B"], + ), + columns=["C"], + dtype="int64", + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_multiindex_level_empty(self): + # https://github.com/pandas-dev/pandas/issues/31670 + df = DataFrame( + [[123, "a", 1.0], [123, "b", 2.0]], columns=["id", "category", "value"] + ) + df = df.set_index(["id", "category"]) + empty = df[df.value < 0] + result = empty.groupby("id").sum() + expected = DataFrame( + dtype="float64", + columns=["value"], + index=Index([], dtype=np.int64, name="id"), + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_tuple_keys_handle_multiindex(self): + # https://github.com/pandas-dev/pandas/issues/21340 + df = DataFrame( + { + "num1": [0, 8, 9, 4, 3, 3, 5, 9, 3, 6], + "num2": [3, 8, 6, 4, 9, 2, 1, 7, 0, 9], + "num3": [6, 5, 7, 8, 5, 1, 1, 10, 7, 8], + "category_tuple": [ + (0, 1), + (0, 1), + (0, 1), + (0, 4), + (2, 3), + (2, 3), + (2, 3), + (2, 3), + (5,), + (6,), + ], + "category_string": list("aaabbbbcde"), + } + ) + expected = df.sort_values(by=["category_tuple", "num1"]) + result = df.groupby("category_tuple").apply(lambda x: x.sort_values(by="num1")) + expected = expected[result.columns] + tm.assert_frame_equal(result.reset_index(drop=True), expected) + + def test_groupby_grouper_immutable_list_item(self): + # GH 26564 - prevent 'ValueError: all keys need to be the same shape' + # when reusing a list of groupers + df1 = DataFrame([["05/29/2019"], ["05/28/2019"]], columns=["date"]).assign( + date=lambda df: pd.to_datetime(df["date"]) + ) + df2 = DataFrame(columns=["date"]).assign( + date=lambda df: pd.to_datetime(df["date"]) + ) + + groupers = [Grouper(key="date", freq="1D")] + + df1.groupby(groupers).head() + # no error + df2.groupby(groupers).head() + + +# get_group +# -------------------------------- + + +class TestGetGroup: + def test_get_group(self): + # GH 5267 + # be datelike friendly + df = DataFrame( + { + "DATE": pd.to_datetime( + [ + "10-Oct-2013", + "10-Oct-2013", + "10-Oct-2013", + "11-Oct-2013", + "11-Oct-2013", + "11-Oct-2013", + ] + ), + "label": ["foo", "foo", "bar", "foo", "foo", "bar"], + "VAL": [1, 2, 3, 4, 5, 6], + } + ) + + g = df.groupby("DATE") + key = next(iter(g.groups)) + result1 = g.get_group(key) + result2 = g.get_group(Timestamp(key).to_pydatetime()) + result3 = g.get_group(str(Timestamp(key))) + tm.assert_frame_equal(result1, result2) + tm.assert_frame_equal(result1, result3) + + g = df.groupby(["DATE", "label"]) + + key = next(iter(g.groups)) + result1 = g.get_group(key) + result2 = g.get_group((Timestamp(key[0]).to_pydatetime(), key[1])) + result3 = g.get_group((str(Timestamp(key[0])), key[1])) + tm.assert_frame_equal(result1, result2) + tm.assert_frame_equal(result1, result3) + + # must pass a same-length tuple with multiple keys + msg = "must supply a tuple to get_group with multiple grouping keys" + with pytest.raises(ValueError, match=msg): + g.get_group("foo") + with pytest.raises(ValueError, match=msg): + g.get_group("foo") + msg = "must supply a same-length tuple to get_group with multiple grouping keys" + with pytest.raises(ValueError, match=msg): + g.get_group(("foo", "bar", "baz")) + + def test_get_group_empty_bins(self, observed): + d = DataFrame([3, 1, 7, 6]) + bins = [0, 5, 10, 15] + g = d.groupby(pd.cut(d[0], bins), observed=observed) + + # TODO: should prob allow a str of Interval work as well + # IOW '(0, 5]' + result = g.get_group(pd.Interval(0, 5)) + expected = DataFrame([3, 1], index=[0, 1]) + tm.assert_frame_equal(result, expected) + + msg = r"Interval\(10, 15, closed='right'\)" + with pytest.raises(KeyError, match=msg): + g.get_group(pd.Interval(10, 15)) + + def test_get_group_grouped_by_tuple(self): + # GH 8121 + df = DataFrame([[(1,), (1, 2), (1,), (1, 2)]], index=["ids"]).T + gr = df.groupby("ids") + expected = DataFrame({"ids": [(1,), (1,)]}, index=[0, 2]) + result = gr.get_group((1,)) + tm.assert_frame_equal(result, expected) + + dt = pd.to_datetime(["2010-01-01", "2010-01-02", "2010-01-01", "2010-01-02"]) + df = DataFrame({"ids": [(x,) for x in dt]}) + gr = df.groupby("ids") + result = gr.get_group(("2010-01-01",)) + expected = DataFrame({"ids": [(dt[0],), (dt[0],)]}, index=[0, 2]) + tm.assert_frame_equal(result, expected) + + def test_get_group_grouped_by_tuple_with_lambda(self): + # GH 36158 + df = DataFrame( + { + "Tuples": ( + (x, y) + for x in [0, 1] + for y in np.random.default_rng(2).integers(3, 5, 5) + ) + } + ) + + gb = df.groupby("Tuples") + gb_lambda = df.groupby(lambda x: df.iloc[x, 0]) + + expected = gb.get_group(next(iter(gb.groups.keys()))) + result = gb_lambda.get_group(next(iter(gb_lambda.groups.keys()))) + + tm.assert_frame_equal(result, expected) + + def test_groupby_with_empty(self): + index = pd.DatetimeIndex(()) + data = () + series = Series(data, index, dtype=object) + grouper = Grouper(freq="D") + grouped = series.groupby(grouper) + assert next(iter(grouped), None) is None + + def test_groupby_with_single_column(self): + df = DataFrame({"a": list("abssbab")}) + tm.assert_frame_equal(df.groupby("a").get_group("a"), df.iloc[[0, 5]]) + # GH 13530 + exp = DataFrame( + index=Index(["a", "b", "s"], name="a"), columns=Index([], dtype="str") + ) + tm.assert_frame_equal(df.groupby("a").count(), exp) + tm.assert_frame_equal(df.groupby("a").sum(), exp) + + exp = df.iloc[[3, 4, 5]] + tm.assert_frame_equal(df.groupby("a").nth(1), exp) + + def test_gb_key_len_equal_axis_len(self): + # GH16843 + # test ensures that index and column keys are recognized correctly + # when number of keys equals axis length of groupby + df = DataFrame( + [["foo", "bar", "B", 1], ["foo", "bar", "B", 2], ["foo", "baz", "C", 3]], + columns=["first", "second", "third", "one"], + ) + df = df.set_index(["first", "second"]) + df = df.groupby(["first", "second", "third"]).size() + assert df.loc[("foo", "bar", "B")] == 2 + assert df.loc[("foo", "baz", "C")] == 1 + + +# groups & iteration +# -------------------------------- + + +class TestIteration: + def test_groups(self, df): + grouped = df.groupby(["A"]) + msg = "In a future version, the keys" + + with tm.assert_produces_warning(Pandas4Warning, match=msg): + groups = grouped.groups + assert groups is grouped.groups # caching works + + for k, v in groups.items(): + assert (df.loc[v]["A"] == k).all() + + grouped = df.groupby(["A", "B"]) + groups = grouped.groups + assert groups is grouped.groups # caching works + + for k, v in groups.items(): + assert (df.loc[v]["A"] == k[0]).all() + assert (df.loc[v]["B"] == k[1]).all() + + def test_grouping_is_iterable(self, tsframe): + # this code path isn't used anywhere else + # not sure it's useful + grouped = tsframe.groupby([lambda x: x.weekday(), lambda x: x.year]) + + # test it works + for g in grouped._grouper.groupings[0]: + pass + + def test_multi_iter(self): + s = Series(np.arange(6)) + k1 = np.array(["a", "a", "a", "b", "b", "b"]) + k2 = np.array(["1", "2", "1", "2", "1", "2"]) + + grouped = s.groupby([k1, k2]) + + iterated = list(grouped) + expected = [ + ("a", "1", s[[0, 2]]), + ("a", "2", s[[1]]), + ("b", "1", s[[4]]), + ("b", "2", s[[3, 5]]), + ] + for i, ((one, two), three) in enumerate(iterated): + e1, e2, e3 = expected[i] + assert e1 == one + assert e2 == two + tm.assert_series_equal(three, e3) + + def test_multi_iter_frame(self, three_group): + k1 = np.array(["b", "b", "b", "a", "a", "a"]) + k2 = np.array(["1", "2", "1", "2", "1", "2"]) + df = DataFrame( + { + "v1": np.random.default_rng(2).standard_normal(6), + "v2": np.random.default_rng(2).standard_normal(6), + "k1": k1, + "k2": k2, + }, + index=["one", "two", "three", "four", "five", "six"], + ) + + grouped = df.groupby(["k1", "k2"]) + + # things get sorted! + iterated = list(grouped) + idx = df.index + expected = [ + ("a", "1", df.loc[idx[[4]]]), + ("a", "2", df.loc[idx[[3, 5]]]), + ("b", "1", df.loc[idx[[0, 2]]]), + ("b", "2", df.loc[idx[[1]]]), + ] + for i, ((one, two), three) in enumerate(iterated): + e1, e2, e3 = expected[i] + assert e1 == one + assert e2 == two + tm.assert_frame_equal(three, e3) + + # don't iterate through groups with no data + df["k1"] = np.array(["b", "b", "b", "a", "a", "a"]) + df["k2"] = np.array(["1", "1", "1", "2", "2", "2"]) + grouped = df.groupby(["k1", "k2"]) + # calling `dict` on a DataFrameGroupBy leads to a TypeError, + # we need to use a dictionary comprehension here + groups = {key: gp for key, gp in grouped} # noqa: C416 + assert len(groups) == 2 + + def test_dictify(self, df): + dict(iter(df.groupby("A"))) + dict(iter(df.groupby(["A", "B"]))) + dict(iter(df["C"].groupby(df["A"]))) + dict(iter(df["C"].groupby([df["A"], df["B"]]))) + dict(iter(df.groupby("A")["C"])) + dict(iter(df.groupby(["A", "B"])["C"])) + + def test_groupby_with_small_elem(self): + # GH 8542 + # length=2 + df = DataFrame( + {"event": ["start", "start"], "change": [1234, 5678]}, + index=pd.DatetimeIndex(["2014-09-10", "2013-10-10"]), + ) + grouped = df.groupby([Grouper(freq="ME"), "event"]) + assert len(grouped.groups) == 2 + assert grouped.ngroups == 2 + assert (Timestamp("2014-09-30"), "start") in grouped.groups + assert (Timestamp("2013-10-31"), "start") in grouped.groups + + res = grouped.get_group((Timestamp("2014-09-30"), "start")) + tm.assert_frame_equal(res, df.iloc[[0], :]) + res = grouped.get_group((Timestamp("2013-10-31"), "start")) + tm.assert_frame_equal(res, df.iloc[[1], :]) + + df = DataFrame( + {"event": ["start", "start", "start"], "change": [1234, 5678, 9123]}, + index=pd.DatetimeIndex(["2014-09-10", "2013-10-10", "2014-09-15"]), + ) + grouped = df.groupby([Grouper(freq="ME"), "event"]) + assert len(grouped.groups) == 2 + assert grouped.ngroups == 2 + assert (Timestamp("2014-09-30"), "start") in grouped.groups + assert (Timestamp("2013-10-31"), "start") in grouped.groups + + res = grouped.get_group((Timestamp("2014-09-30"), "start")) + tm.assert_frame_equal(res, df.iloc[[0, 2], :]) + res = grouped.get_group((Timestamp("2013-10-31"), "start")) + tm.assert_frame_equal(res, df.iloc[[1], :]) + + # length=3 + df = DataFrame( + {"event": ["start", "start", "start"], "change": [1234, 5678, 9123]}, + index=pd.DatetimeIndex(["2014-09-10", "2013-10-10", "2014-08-05"]), + ) + grouped = df.groupby([Grouper(freq="ME"), "event"]) + assert len(grouped.groups) == 3 + assert grouped.ngroups == 3 + assert (Timestamp("2014-09-30"), "start") in grouped.groups + assert (Timestamp("2013-10-31"), "start") in grouped.groups + assert (Timestamp("2014-08-31"), "start") in grouped.groups + + res = grouped.get_group((Timestamp("2014-09-30"), "start")) + tm.assert_frame_equal(res, df.iloc[[0], :]) + res = grouped.get_group((Timestamp("2013-10-31"), "start")) + tm.assert_frame_equal(res, df.iloc[[1], :]) + res = grouped.get_group((Timestamp("2014-08-31"), "start")) + tm.assert_frame_equal(res, df.iloc[[2], :]) + + def test_grouping_string_repr(self): + # GH 13394 + mi = MultiIndex.from_arrays([list("AAB"), list("aba")]) + df = DataFrame([[1, 2, 3]], columns=mi) + gr = df.groupby(df[("A", "a")]) + + result = gr._grouper.groupings[0].__repr__() + expected = "Grouping(('A', 'a'))" + assert result == expected + + +def test_grouping_by_key_is_in_axis(): + # GH#50413 - Groupers specified by key are in-axis + df = DataFrame({"a": [1, 1, 2], "b": [1, 1, 2], "c": [3, 4, 5]}).set_index("a") + gb = df.groupby([Grouper(level="a"), Grouper(key="b")], as_index=False) + assert not gb._grouper.groupings[0].in_axis + assert gb._grouper.groupings[1].in_axis + + result = gb.sum() + expected = DataFrame({"a": [1, 2], "b": [1, 2], "c": [7, 5]}) + tm.assert_frame_equal(result, expected) + + +def test_groupby_any_with_timedelta(): + # GH#59712 + df = DataFrame({"value": [pd.Timedelta(1), pd.NaT]}) + + result = df.groupby(np.array([0, 1], dtype=np.int64))["value"].any() + + expected = Series({0: True, 1: False}, name="value", dtype=bool) + expected.index = expected.index.astype(np.int64) + + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/groupby/test_index_as_string.py b/pandas/tests/groupby/test_index_as_string.py new file mode 100644 index 0000000000000000000000000000000000000000..743db7e70b14b7f8c2d047e403884bdbf1b878a4 --- /dev/null +++ b/pandas/tests/groupby/test_index_as_string.py @@ -0,0 +1,72 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +@pytest.mark.parametrize( + "key_strs,groupers", + [ + ("inner", pd.Grouper(level="inner")), # Index name + (["inner"], [pd.Grouper(level="inner")]), # List of index name + (["B", "inner"], ["B", pd.Grouper(level="inner")]), # Column and index + (["inner", "B"], [pd.Grouper(level="inner"), "B"]), # Index and column + ], +) +@pytest.mark.parametrize("levels", [["inner"], ["inner", "outer"]]) +def test_grouper_index_level_as_string(levels, key_strs, groupers): + frame = pd.DataFrame( + { + "outer": ["a", "a", "a", "b", "b", "b"], + "inner": [1, 2, 3, 1, 2, 3], + "A": np.arange(6), + "B": ["one", "one", "two", "two", "one", "one"], + } + ) + frame = frame.set_index(levels) + if "B" not in key_strs or "outer" in frame.columns: + result = frame.groupby(key_strs).mean(numeric_only=True) + expected = frame.groupby(groupers).mean(numeric_only=True) + else: + result = frame.groupby(key_strs).mean() + expected = frame.groupby(groupers).mean() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "levels", + [ + "inner", + "outer", + "B", + ["inner"], + ["outer"], + ["B"], + ["inner", "outer"], + ["outer", "inner"], + ["inner", "outer", "B"], + ["B", "outer", "inner"], + ], +) +def test_grouper_index_level_as_string_series(levels): + # Compute expected result + df = pd.DataFrame( + { + "outer": ["a", "a", "a", "b", "b", "b"], + "inner": [1, 2, 3, 1, 2, 3], + "A": np.arange(6), + "B": ["one", "one", "two", "two", "one", "one"], + } + ) + series = df.set_index(["outer", "inner", "B"])["A"] + if isinstance(levels, list): + groupers = [pd.Grouper(level=lv) for lv in levels] + else: + groupers = pd.Grouper(level=levels) + + expected = series.groupby(groupers).mean() + + # Compute and check result + result = series.groupby(levels).mean() + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/groupby/test_indexing.py b/pandas/tests/groupby/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d3f509e186ac8099dd5fd1c50e8265753fe16c --- /dev/null +++ b/pandas/tests/groupby/test_indexing.py @@ -0,0 +1,310 @@ +# Test GroupBy._positional_selector positional grouped indexing GH#42864 + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + + +@pytest.mark.parametrize( + "arg, expected_rows", + [ + [0, [0, 1, 4]], + [2, [5]], + [5, []], + [-1, [3, 4, 7]], + [-2, [1, 6]], + [-6, []], + ], +) +def test_int(slice_test_df, slice_test_grouped, arg, expected_rows): + # Test single integer + result = slice_test_grouped._positional_selector[arg] + expected = slice_test_df.iloc[expected_rows] + + tm.assert_frame_equal(result, expected) + + +def test_slice(slice_test_df, slice_test_grouped): + # Test single slice + result = slice_test_grouped._positional_selector[0:3:2] + expected = slice_test_df.iloc[[0, 1, 4, 5]] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "arg, expected_rows", + [ + [[0, 2], [0, 1, 4, 5]], + [[0, 2, -1], [0, 1, 3, 4, 5, 7]], + [range(0, 3, 2), [0, 1, 4, 5]], + [{0, 2}, [0, 1, 4, 5]], + ], + ids=[ + "list", + "negative", + "range", + "set", + ], +) +def test_list(slice_test_df, slice_test_grouped, arg, expected_rows): + # Test lists of integers and integer valued iterables + result = slice_test_grouped._positional_selector[arg] + expected = slice_test_df.iloc[expected_rows] + + tm.assert_frame_equal(result, expected) + + +def test_ints(slice_test_df, slice_test_grouped): + # Test tuple of ints + result = slice_test_grouped._positional_selector[0, 2, -1] + expected = slice_test_df.iloc[[0, 1, 3, 4, 5, 7]] + + tm.assert_frame_equal(result, expected) + + +def test_slices(slice_test_df, slice_test_grouped): + # Test tuple of slices + result = slice_test_grouped._positional_selector[:2, -2:] + expected = slice_test_df.iloc[[0, 1, 2, 3, 4, 6, 7]] + + tm.assert_frame_equal(result, expected) + + +def test_mix(slice_test_df, slice_test_grouped): + # Test mixed tuple of ints and slices + result = slice_test_grouped._positional_selector[0, 1, -2:] + expected = slice_test_df.iloc[[0, 1, 2, 3, 4, 6, 7]] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "arg, expected_rows", + [ + [0, [0, 1, 4]], + [[0, 2, -1], [0, 1, 3, 4, 5, 7]], + [(slice(None, 2), slice(-2, None)), [0, 1, 2, 3, 4, 6, 7]], + ], +) +def test_as_index(slice_test_df, arg, expected_rows): + # Test the default as_index behaviour + result = slice_test_df.groupby("Group", sort=False)._positional_selector[arg] + expected = slice_test_df.iloc[expected_rows] + + tm.assert_frame_equal(result, expected) + + +def test_doc_examples(): + # Test the examples in the documentation + df = pd.DataFrame( + [["a", 1], ["a", 2], ["a", 3], ["b", 4], ["b", 5]], columns=["A", "B"] + ) + + grouped = df.groupby("A", as_index=False) + + result = grouped._positional_selector[1:2] + expected = pd.DataFrame([["a", 2], ["b", 5]], columns=["A", "B"], index=[1, 4]) + + tm.assert_frame_equal(result, expected) + + result = grouped._positional_selector[1, -1] + expected = pd.DataFrame( + [["a", 2], ["a", 3], ["b", 5]], columns=["A", "B"], index=[1, 2, 4] + ) + + tm.assert_frame_equal(result, expected) + + +def test_multiindex(): + # Test the multiindex mentioned as the use-case in the documentation + + def _make_df_from_data(data): + rows = {} + for date in data: + for level in data[date]: + rows[(date, level[0])] = {"A": level[1], "B": level[2]} + + df = pd.DataFrame.from_dict(rows, orient="index") + df.index.names = ("Date", "Item") + return df + + rng = np.random.default_rng(2) + ndates = 100 + nitems = 20 + dates = pd.date_range("20130101", periods=ndates, freq="D") + items = [f"item {i}" for i in range(nitems)] + + multiindex_data = {} + for date in dates: + nitems_for_date = nitems - rng.integers(0, 12) + levels = [ + (item, rng.integers(0, 10000) / 100, rng.integers(0, 10000) / 100) + for item in items[:nitems_for_date] + ] + levels.sort(key=lambda x: x[1]) + multiindex_data[date] = levels + + df = _make_df_from_data(multiindex_data) + result = df.groupby("Date", as_index=False).nth(slice(3, -3)) + + sliced = {date: values[3:-3] for date, values in multiindex_data.items()} + expected = _make_df_from_data(sliced) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("arg", [1, 5, 30, 1000, -1, -5, -30, -1000]) +@pytest.mark.parametrize("method", ["head", "tail"]) +@pytest.mark.parametrize("simulated", [True, False]) +def test_against_head_and_tail(arg, method, simulated): + # Test gives the same results as grouped head and tail + n_groups = 100 + n_rows_per_group = 30 + + data = { + "group": [ + f"group {g}" for j in range(n_rows_per_group) for g in range(n_groups) + ], + "value": [ + f"group {g} row {j}" + for j in range(n_rows_per_group) + for g in range(n_groups) + ], + } + df = pd.DataFrame(data) + grouped = df.groupby("group", as_index=False) + size = arg if arg >= 0 else n_rows_per_group + arg + + if method == "head": + result = grouped._positional_selector[:arg] + + if simulated: + indices = [ + j * n_groups + i + for j in range(size) + for i in range(n_groups) + if j * n_groups + i < n_groups * n_rows_per_group + ] + expected = df.iloc[indices] + + else: + expected = grouped.head(arg) + + else: + result = grouped._positional_selector[-arg:] + + if simulated: + indices = [ + (n_rows_per_group + j - size) * n_groups + i + for j in range(size) + for i in range(n_groups) + if (n_rows_per_group + j - size) * n_groups + i >= 0 + ] + expected = df.iloc[indices] + + else: + expected = grouped.tail(arg) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("start", [None, 0, 1, 10, -1, -10]) +@pytest.mark.parametrize("stop", [None, 0, 1, 10, -1, -10]) +@pytest.mark.parametrize("step", [None, 1, 5]) +def test_against_df_iloc(start, stop, step): + # Test that a single group gives the same results as DataFrame.iloc + n_rows = 30 + + data = { + "group": ["group 0"] * n_rows, + "value": list(range(n_rows)), + } + df = pd.DataFrame(data) + grouped = df.groupby("group", as_index=False) + + result = grouped._positional_selector[start:stop:step] + expected = df.iloc[start:stop:step] + + tm.assert_frame_equal(result, expected) + + +def test_series(): + # Test grouped Series + ser = pd.Series([1, 2, 3, 4, 5], index=["a", "a", "a", "b", "b"]) + grouped = ser.groupby(level=0) + result = grouped._positional_selector[1:2] + expected = pd.Series([2, 5], index=["a", "b"]) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("step", [1, 2, 3, 4, 5]) +def test_step(step): + # Test slice with various step values + data = [["x", f"x{i}"] for i in range(5)] + data += [["y", f"y{i}"] for i in range(4)] + data += [["z", f"z{i}"] for i in range(3)] + df = pd.DataFrame(data, columns=["A", "B"]) + + grouped = df.groupby("A", as_index=False) + + result = grouped._positional_selector[::step] + + data = [["x", f"x{i}"] for i in range(0, 5, step)] + data += [["y", f"y{i}"] for i in range(0, 4, step)] + data += [["z", f"z{i}"] for i in range(0, 3, step)] + + index = [0 + i for i in range(0, 5, step)] + index += [5 + i for i in range(0, 4, step)] + index += [9 + i for i in range(0, 3, step)] + + expected = pd.DataFrame(data, columns=["A", "B"], index=index) + + tm.assert_frame_equal(result, expected) + + +def test_columns_on_iter(): + # GitHub issue #44821 + df = pd.DataFrame({k: range(10) for k in "ABC"}) + + # Group-by and select columns + cols = ["A", "B"] + for _, dg in df.groupby(df.A < 4)[cols]: + tm.assert_index_equal(dg.columns, pd.Index(cols)) + assert "C" not in dg.columns + + +@pytest.mark.parametrize("func", [list, pd.Index, pd.Series, np.array]) +def test_groupby_duplicated_columns(func): + # GH#44924 + df = pd.DataFrame( + { + "A": [1, 2], + "B": [3, 3], + "C": ["G", "G"], + } + ) + result = df.groupby("C")[func(["A", "B", "A"])].mean() + expected = pd.DataFrame( + [[1.5, 3.0, 1.5]], columns=["A", "B", "A"], index=pd.Index(["G"], name="C") + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_get_nonexisting_groups(): + # GH#32492 + df = pd.DataFrame( + data={ + "A": ["a1", "a2", None], + "B": ["b1", "b2", "b1"], + "val": [1, 2, 3], + } + ) + grps = df.groupby(by=["A", "B"]) + + msg = "('a2', 'b1')" + with pytest.raises(KeyError, match=msg): + grps.get_group(("a2", "b1")) diff --git a/pandas/tests/groupby/test_libgroupby.py b/pandas/tests/groupby/test_libgroupby.py new file mode 100644 index 0000000000000000000000000000000000000000..60095663b5a7d0c6fa2896f85b49373c1afff37a --- /dev/null +++ b/pandas/tests/groupby/test_libgroupby.py @@ -0,0 +1,344 @@ +import numpy as np +import pytest + +from pandas._libs import groupby as libgroupby +from pandas._libs.groupby import ( + group_cumprod, + group_cumsum, + group_mean, + group_sum, + group_var, +) + +from pandas.core.dtypes.common import ensure_platform_int + +from pandas import isna +import pandas._testing as tm + + +@pytest.mark.parametrize("dtype, rtol", [("float32", 1e-2), ("float64", 1e-5)]) +class TestGroupVar: + def test_group_var_generic_1d(self, dtype, rtol): + prng = np.random.default_rng(2) + + out = (np.nan * np.ones((5, 1))).astype(dtype) + counts = np.zeros(5, dtype="int64") + values = 10 * prng.random((15, 1)).astype(dtype) + labels = np.tile(np.arange(5), (3,)).astype("intp") + + expected_out = ( + np.squeeze(values).reshape((5, 3), order="F").std(axis=1, ddof=1) ** 2 + )[:, np.newaxis] + expected_counts = counts + 3 + + group_var(out, counts, values, labels) + assert np.allclose(out, expected_out, rtol) + tm.assert_numpy_array_equal(counts, expected_counts) + + def test_group_var_generic_1d_flat_labels(self, dtype, rtol): + prng = np.random.default_rng(2) + + out = (np.nan * np.ones((1, 1))).astype(dtype) + counts = np.zeros(1, dtype="int64") + values = 10 * prng.random((5, 1)).astype(dtype) + labels = np.zeros(5, dtype="intp") + + expected_out = np.array([[values.std(ddof=1) ** 2]]) + expected_counts = counts + 5 + + group_var(out, counts, values, labels) + + assert np.allclose(out, expected_out, rtol) + tm.assert_numpy_array_equal(counts, expected_counts) + + def test_group_var_generic_2d_all_finite(self, dtype, rtol): + prng = np.random.default_rng(2) + + out = (np.nan * np.ones((5, 2))).astype(dtype) + counts = np.zeros(5, dtype="int64") + values = 10 * prng.random((10, 2)).astype(dtype) + labels = np.tile(np.arange(5), (2,)).astype("intp") + + expected_out = np.std(values.reshape(2, 5, 2), ddof=1, axis=0) ** 2 + expected_counts = counts + 2 + + group_var(out, counts, values, labels) + assert np.allclose(out, expected_out, rtol) + tm.assert_numpy_array_equal(counts, expected_counts) + + def test_group_var_generic_2d_some_nan(self, dtype, rtol): + prng = np.random.default_rng(2) + + out = (np.nan * np.ones((5, 2))).astype(dtype) + counts = np.zeros(5, dtype="int64") + values = 10 * prng.random((10, 2)).astype(dtype) + values[:, 1] = np.nan + labels = np.tile(np.arange(5), (2,)).astype("intp") + + expected_out = np.vstack( + [ + values[:, 0].reshape(5, 2, order="F").std(ddof=1, axis=1) ** 2, + np.nan * np.ones(5), + ] + ).T.astype(dtype) + expected_counts = counts + 2 + + group_var(out, counts, values, labels) + tm.assert_almost_equal(out, expected_out, rtol=0.5e-06) + tm.assert_numpy_array_equal(counts, expected_counts) + + def test_group_var_constant(self, dtype, rtol): + # Regression test from GH 10448. + + out = np.array([[np.nan]], dtype=dtype) + counts = np.array([0], dtype="int64") + values = 0.832845131556193 * np.ones((3, 1), dtype=dtype) + labels = np.zeros(3, dtype="intp") + + group_var(out, counts, values, labels) + + assert counts[0] == 3 + assert out[0, 0] >= 0 + tm.assert_almost_equal(out[0, 0], 0.0) + + +def test_group_var_large_inputs(): + dtype = np.float64 + prng = np.random.default_rng(2) + + out = np.array([[np.nan]], dtype=dtype) + counts = np.array([0], dtype="int64") + values = (prng.random((10**6, 1)) + 10**12).astype(dtype) + labels = np.zeros(10**6, dtype="intp") + + group_var(out, counts, values, labels) + + assert counts[0] == 10**6 + tm.assert_almost_equal(out[0, 0], 1.0 / 12, rtol=0.5e-3) + + +@pytest.mark.parametrize("dtype", ["float32", "float64"]) +def test_group_ohlc(dtype): + obj = np.array(np.random.default_rng(2).standard_normal(20), dtype=dtype) + + bins = np.array([6, 12, 20]) + out = np.zeros((3, 4), dtype) + counts = np.zeros(len(out), dtype=np.int64) + labels = ensure_platform_int(np.repeat(np.arange(3), np.diff(np.r_[0, bins]))) + + func = libgroupby.group_ohlc + func(out, counts, obj[:, None], labels) + + def _ohlc(group): + if isna(group).all(): + return np.repeat(np.nan, 4) + return [group[0], group.max(), group.min(), group[-1]] + + expected = np.array([_ohlc(obj[:6]), _ohlc(obj[6:12]), _ohlc(obj[12:])]) + + tm.assert_almost_equal(out, expected) + tm.assert_numpy_array_equal(counts, np.array([6, 6, 8], dtype=np.int64)) + + obj[:6] = np.nan + func(out, counts, obj[:, None], labels) + expected[0] = np.nan + tm.assert_almost_equal(out, expected) + + +@pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float32, np.float64]) +@pytest.mark.parametrize( + "pd_op, np_op", + [ + (group_cumsum, np.cumsum), + (group_cumprod, np.cumprod), + ], +) +def test_cython_group_transform(dtype, pd_op, np_op): + # see gh-4095 + is_datetimelike = False + + data = np.array([[1], [2], [3], [4]], dtype=dtype) + answer = np.zeros_like(data) + + labels = np.array([0, 0, 0, 0], dtype=np.intp) + ngroups = 1 + pd_op(answer, data, labels, ngroups, is_datetimelike) + + tm.assert_numpy_array_equal(np_op(data), answer[:, 0], check_dtype=False) + + +def test_cython_group_transform_algos(): + # see gh-4095 + is_datetimelike = False + + # with nans + labels = np.array([0, 0, 0, 0, 0], dtype=np.intp) + ngroups = 1 + + data = np.array([[1], [2], [3], [np.nan], [4]], dtype="float64") + actual = np.zeros_like(data) + actual.fill(np.nan) + group_cumprod(actual, data, labels, ngroups, is_datetimelike) + expected = np.array([1, 2, 6, np.nan, 24], dtype="float64") + tm.assert_numpy_array_equal(actual[:, 0], expected) + + actual = np.zeros_like(data) + actual.fill(np.nan) + group_cumsum(actual, data, labels, ngroups, is_datetimelike) + expected = np.array([1, 3, 6, np.nan, 10], dtype="float64") + tm.assert_numpy_array_equal(actual[:, 0], expected) + + # timedelta + is_datetimelike = True + data = np.array([np.timedelta64(1, "ns")] * 5, dtype="m8[ns]")[:, None] + actual = np.zeros_like(data, dtype="int64") + group_cumsum(actual, data.view("int64"), labels, ngroups, is_datetimelike) + expected = np.array( + [ + np.timedelta64(1, "ns"), + np.timedelta64(2, "ns"), + np.timedelta64(3, "ns"), + np.timedelta64(4, "ns"), + np.timedelta64(5, "ns"), + ] + ) + tm.assert_numpy_array_equal(actual[:, 0].view("m8[ns]"), expected) + + +def test_cython_group_mean_datetimelike(): + actual = np.zeros(shape=(1, 1), dtype="float64") + counts = np.array([0], dtype="int64") + data = ( + np.array( + [np.timedelta64(2, "ns"), np.timedelta64(4, "ns"), np.timedelta64("NaT")], + dtype="m8[ns]", + )[:, None] + .view("int64") + .astype("float64") + ) + labels = np.zeros(len(data), dtype=np.intp) + + group_mean(actual, counts, data, labels, is_datetimelike=True) + + tm.assert_numpy_array_equal(actual[:, 0], np.array([3], dtype="float64")) + + +def test_cython_group_mean_wrong_min_count(): + actual = np.zeros(shape=(1, 1), dtype="float64") + counts = np.zeros(1, dtype="int64") + data = np.zeros(1, dtype="float64")[:, None] + labels = np.zeros(1, dtype=np.intp) + + with pytest.raises(AssertionError, match="min_count"): + group_mean(actual, counts, data, labels, is_datetimelike=True, min_count=0) + + +def test_cython_group_mean_not_datetimelike_but_has_NaT_values(): + actual = np.zeros(shape=(1, 1), dtype="float64") + counts = np.array([0], dtype="int64") + data = ( + np.array( + [np.timedelta64("NaT"), np.timedelta64("NaT")], + dtype="m8[ns]", + )[:, None] + .view("int64") + .astype("float64") + ) + labels = np.zeros(len(data), dtype=np.intp) + + group_mean(actual, counts, data, labels, is_datetimelike=False) + + tm.assert_numpy_array_equal( + actual[:, 0], np.array(np.divide(np.add(data[0], data[1]), 2), dtype="float64") + ) + + +def test_cython_group_mean_Inf_at_beginning_and_end(): + # GH 50367 + actual = np.array([[np.nan, np.nan], [np.nan, np.nan]], dtype="float64") + counts = np.array([0, 0], dtype="int64") + data = np.array( + [[np.inf, 1.0], [1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5, np.inf]], + dtype="float64", + ) + labels = np.array([0, 1, 0, 1, 0, 1], dtype=np.intp) + + group_mean(actual, counts, data, labels, is_datetimelike=False) + + expected = np.array([[np.inf, 3], [3, np.inf]], dtype="float64") + + tm.assert_numpy_array_equal( + actual, + expected, + ) + + +@pytest.mark.parametrize( + "values, out", + [ + ([[np.inf], [np.inf], [np.inf]], [[np.inf], [np.inf]]), + ([[np.inf], [np.inf], [-np.inf]], [[np.inf], [np.nan]]), + ([[np.inf], [-np.inf], [np.inf]], [[np.inf], [np.nan]]), + ([[np.inf], [-np.inf], [-np.inf]], [[np.inf], [-np.inf]]), + ], +) +def test_cython_group_sum_Inf_at_beginning_and_end(values, out): + # GH #53606 + actual = np.array([[np.nan], [np.nan]], dtype="float64") + counts = np.array([0, 0], dtype="int64") + data = np.array(values, dtype="float64") + labels = np.array([0, 1, 1], dtype=np.intp) + + group_sum(actual, counts, data, labels, None, is_datetimelike=False) + + expected = np.array(out, dtype="float64") + + tm.assert_numpy_array_equal( + actual, + expected, + ) + + +@pytest.mark.parametrize( + "values, expected_values", + [ + (np.finfo(np.float64).max, [[np.inf]]), + (np.finfo(np.float64).min, [[-np.inf]]), + ( + np.complex128(np.finfo(np.float64).min + np.finfo(np.float64).max * 1j), + [[complex(-np.inf, np.inf)]], + ), + ( + np.complex128(np.finfo(np.float64).max + np.finfo(np.float64).min * 1j), + [[complex(np.inf, -np.inf)]], + ), + ( + np.complex128(np.finfo(np.float64).max + np.finfo(np.float64).max * 1j), + [[complex(np.inf, np.inf)]], + ), + ( + np.complex128(np.finfo(np.float64).min + np.finfo(np.float64).min * 1j), + [[complex(-np.inf, -np.inf)]], + ), + ( + np.complex128(3.0 + np.finfo(np.float64).min * 1j), + [[complex(9.0, -np.inf)]], + ), + ( + np.complex128(np.finfo(np.float64).max + 3 * 1j), + [[complex(np.inf, 9.0)]], + ), + ], +) +def test_cython_group_sum_overflow(values, expected_values): + # GH-60303 + data = np.array([[values] for _ in range(3)]) + labels = np.array([0, 0, 0], dtype=np.intp) + counts = np.array([0], dtype="int64") + + expected = np.array(expected_values, dtype=values.dtype) + actual = np.zeros_like(expected) + + group_sum(actual, counts, data, labels, None, is_datetimelike=False) + + tm.assert_numpy_array_equal(actual, expected) diff --git a/pandas/tests/groupby/test_missing.py b/pandas/tests/groupby/test_missing.py new file mode 100644 index 0000000000000000000000000000000000000000..2b590c50371e9e6f22247c73e94785b16678a2d4 --- /dev/null +++ b/pandas/tests/groupby/test_missing.py @@ -0,0 +1,89 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, +) +import pandas._testing as tm + + +@pytest.mark.parametrize("func", ["ffill", "bfill"]) +def test_groupby_column_index_name_lost_fill_funcs(func): + # GH: 29764 groupby loses index sometimes + df = DataFrame( + [[1, 1.0, -1.0], [1, np.nan, np.nan], [1, 2.0, -2.0]], + columns=Index(["type", "a", "b"], name="idx"), + ) + df_grouped = df.groupby(["type"])[["a", "b"]] + result = getattr(df_grouped, func)().columns + expected = Index(["a", "b"], name="idx") + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("func", ["ffill", "bfill"]) +def test_groupby_fill_duplicate_column_names(func): + # GH: 25610 ValueError with duplicate column names + df1 = DataFrame({"field1": [1, 3, 4], "field2": [1, 3, 4]}) + df2 = DataFrame({"field1": [1, np.nan, 4]}) + df_grouped = pd.concat([df1, df2], axis=1).groupby(by=["field2"]) + expected = DataFrame( + [[1, 1.0], [3, np.nan], [4, 4.0]], columns=["field1", "field1"] + ) + result = getattr(df_grouped, func)() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("method", ["ffill", "bfill"]) +@pytest.mark.parametrize("has_nan_group", [True, False]) +def test_ffill_handles_nan_groups(dropna, method, has_nan_group): + # GH 34725 + + df_without_nan_rows = DataFrame([(1, 0.1), (2, 0.2)]) + + ridx = [-1, 0, -1, -1, 1, -1] + df = df_without_nan_rows.reindex(ridx).reset_index(drop=True) + + group_b = np.nan if has_nan_group else "b" + df["group_col"] = pd.Series(["a"] * 3 + [group_b] * 3) + + grouped = df.groupby(by="group_col", dropna=dropna) + result = getattr(grouped, method)(limit=None) + + expected_rows = { + ("ffill", True, True): [-1, 0, 0, -1, -1, -1], + ("ffill", True, False): [-1, 0, 0, -1, 1, 1], + ("ffill", False, True): [-1, 0, 0, -1, 1, 1], + ("ffill", False, False): [-1, 0, 0, -1, 1, 1], + ("bfill", True, True): [0, 0, -1, -1, -1, -1], + ("bfill", True, False): [0, 0, -1, 1, 1, -1], + ("bfill", False, True): [0, 0, -1, 1, 1, -1], + ("bfill", False, False): [0, 0, -1, 1, 1, -1], + } + + ridx = expected_rows.get((method, dropna, has_nan_group)) + expected = df_without_nan_rows.reindex(ridx).reset_index(drop=True) + # columns are a 'take' on df.columns, which are object dtype + expected.columns = expected.columns.astype(object) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("min_count, value", [(2, np.nan), (-1, 1.0)]) +@pytest.mark.parametrize("func", ["first", "last", "max", "min"]) +def test_min_count(func, min_count, value): + # GH#37821 + df = DataFrame({"a": [1] * 3, "b": [1, np.nan, np.nan], "c": [np.nan] * 3}) + result = getattr(df.groupby("a"), func)(min_count=min_count) + expected = DataFrame({"b": [value], "c": [np.nan]}, index=Index([1], name="a")) + tm.assert_frame_equal(result, expected) + + +def test_indices_with_missing(): + # GH 9304 + df = DataFrame({"a": [1, 1, np.nan], "b": [2, 3, 4], "c": [5, 6, 7]}) + g = df.groupby(["a", "b"]) + result = g.indices + expected = {(1.0, 2): np.array([0]), (1.0, 3): np.array([1])} + assert result == expected diff --git a/pandas/tests/groupby/test_numba.py b/pandas/tests/groupby/test_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..082319d8479f02b9c61d08042269e77d56156898 --- /dev/null +++ b/pandas/tests/groupby/test_numba.py @@ -0,0 +1,82 @@ +import pytest + +from pandas.compat import is_platform_arm + +from pandas import ( + DataFrame, + Series, + option_context, +) +import pandas._testing as tm +from pandas.util.version import Version + +pytestmark = [pytest.mark.single_cpu] + +numba = pytest.importorskip("numba") +pytestmark.append( + pytest.mark.skipif( + Version(numba.__version__) == Version("0.61") and is_platform_arm(), + reason=f"Segfaults on ARM platforms with numba {numba.__version__}", + ) +) + + +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +class TestEngine: + def test_cython_vs_numba_frame( + self, sort, nogil, parallel, nopython, numba_supported_reductions + ): + func, kwargs = numba_supported_reductions + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + gb = df.groupby("a", sort=sort) + result = getattr(gb, func)( + engine="numba", engine_kwargs=engine_kwargs, **kwargs + ) + expected = getattr(gb, func)(**kwargs) + tm.assert_frame_equal(result, expected) + + def test_cython_vs_numba_getitem( + self, sort, nogil, parallel, nopython, numba_supported_reductions + ): + func, kwargs = numba_supported_reductions + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + gb = df.groupby("a", sort=sort)["c"] + result = getattr(gb, func)( + engine="numba", engine_kwargs=engine_kwargs, **kwargs + ) + expected = getattr(gb, func)(**kwargs) + tm.assert_series_equal(result, expected) + + def test_cython_vs_numba_series( + self, sort, nogil, parallel, nopython, numba_supported_reductions + ): + func, kwargs = numba_supported_reductions + ser = Series(range(3), index=[1, 2, 1], name="foo") + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + gb = ser.groupby(level=0, sort=sort) + result = getattr(gb, func)( + engine="numba", engine_kwargs=engine_kwargs, **kwargs + ) + expected = getattr(gb, func)(**kwargs) + tm.assert_series_equal(result, expected) + + def test_as_index_false_unsupported(self, numba_supported_reductions): + func, kwargs = numba_supported_reductions + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + gb = df.groupby("a", as_index=False) + with pytest.raises(NotImplementedError, match="as_index=False"): + getattr(gb, func)(engine="numba", **kwargs) + + def test_no_engine_doesnt_raise(self): + # GH55520 + df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)}) + gb = df.groupby("a") + # Make sure behavior of functions w/out engine argument don't raise + # when the global use_numba option is set + with option_context("compute.use_numba", True): + res = gb.agg({"b": "first"}) + expected = gb.agg({"b": "first"}) + tm.assert_frame_equal(res, expected) diff --git a/pandas/tests/groupby/test_numeric_only.py b/pandas/tests/groupby/test_numeric_only.py new file mode 100644 index 0000000000000000000000000000000000000000..b79ca8bf1ee3ae0e62b550d9c4322d7e10d7cf3a --- /dev/null +++ b/pandas/tests/groupby/test_numeric_only.py @@ -0,0 +1,445 @@ +import re + +import pytest + +from pandas._libs import lib +from pandas.errors import Pandas4Warning + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + Timestamp, + date_range, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +class TestNumericOnly: + # make sure that we are passing thru kwargs to our agg functions + + @pytest.fixture + def df(self): + # GH3668 + # GH5724 + df = DataFrame( + { + "group": [1, 1, 2], + "int": [1, 2, 3], + "float": [4.0, 5.0, 6.0], + "string": Series(["a", "b", "c"], dtype="str"), + "object": Series(["a", "b", "c"], dtype=object), + "category_string": Series(list("abc")).astype("category"), + "category_int": [7, 8, 9], + "datetime": date_range("20130101", periods=3), + "datetimetz": date_range("20130101", periods=3, tz="US/Eastern"), + "timedelta": pd.timedelta_range("1 s", periods=3, freq="s"), + }, + columns=[ + "group", + "int", + "float", + "string", + "object", + "category_string", + "category_int", + "datetime", + "datetimetz", + "timedelta", + ], + ) + return df + + @pytest.mark.parametrize("method", ["mean", "median"]) + def test_averages(self, df, method): + # mean / median + expected_columns_numeric = Index(["int", "float", "category_int"]) + + gb = df.groupby("group") + expected = DataFrame( + { + "category_int": [7.5, 9], + "float": [4.5, 6.0], + "timedelta": [pd.Timedelta("1.5s"), pd.Timedelta("3s")], + "int": [1.5, 3], + "datetime": [ + Timestamp("2013-01-01 12:00:00"), + Timestamp("2013-01-03 00:00:00"), + ], + "datetimetz": [ + Timestamp("2013-01-01 12:00:00", tz="US/Eastern"), + Timestamp("2013-01-03 00:00:00", tz="US/Eastern"), + ], + }, + index=Index([1, 2], name="group"), + columns=[ + "int", + "float", + "category_int", + ], + ) + + result = getattr(gb, method)(numeric_only=True) + tm.assert_frame_equal(result.reindex_like(expected), expected) + + expected_columns = expected.columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["min", "max"]) + def test_extrema(self, df, method): + # TODO: min, max *should* handle + # categorical (ordered) dtype + + expected_columns = Index( + [ + "int", + "float", + "string", + "category_int", + "datetime", + "datetimetz", + "timedelta", + ] + ) + expected_columns_numeric = expected_columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["first", "last"]) + def test_first_last(self, df, method): + expected_columns = Index( + [ + "int", + "float", + "string", + "object", + "category_string", + "category_int", + "datetime", + "datetimetz", + "timedelta", + ] + ) + expected_columns_numeric = expected_columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["sum", "cumsum"]) + def test_sum_cumsum(self, df, method): + expected_columns_numeric = Index(["int", "float", "category_int"]) + expected_columns = Index( + ["int", "float", "string", "category_int", "timedelta"] + ) + if method == "cumsum": + # cumsum loses string + expected_columns = Index(["int", "float", "category_int", "timedelta"]) + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["prod", "cumprod"]) + def test_prod_cumprod(self, df, method): + expected_columns = Index(["int", "float", "category_int"]) + expected_columns_numeric = expected_columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + @pytest.mark.parametrize("method", ["cummin", "cummax"]) + def test_cummin_cummax(self, df, method): + # like min, max, but don't include strings + expected_columns = Index( + ["int", "float", "category_int", "datetime", "datetimetz", "timedelta"] + ) + + # GH#15561: numeric_only=False set by default like min/max + expected_columns_numeric = expected_columns + + self._check(df, method, expected_columns, expected_columns_numeric) + + def _check(self, df, method, expected_columns, expected_columns_numeric): + gb = df.groupby("group") + + # object dtypes for transformations are not implemented in Cython and + # have no Python fallback + exception = ( + (NotImplementedError, TypeError) if method.startswith("cum") else TypeError + ) + + if method in ("min", "max", "cummin", "cummax", "cumsum", "cumprod"): + # The methods default to numeric_only=False and raise TypeError + msg = "|".join( + [ + "Categorical is not ordered", + f"Cannot perform {method} with non-ordered Categorical", + re.escape(f"agg function failed [how->{method},dtype->object]"), + # cumsum/cummin/cummax/cumprod + "function is not implemented for this dtype", + f"dtype 'str' does not support operation '{method}'", + ] + ) + with pytest.raises(exception, match=msg): + getattr(gb, method)() + elif method in ("sum", "mean", "median", "prod"): + msg = "|".join( + [ + "category type does not support sum operations", + re.escape(f"agg function failed [how->{method},dtype->object]"), + re.escape(f"agg function failed [how->{method},dtype->string]"), + f"dtype 'str' does not support operation '{method}'", + ] + ) + with pytest.raises(exception, match=msg): + getattr(gb, method)() + else: + result = getattr(gb, method)() + tm.assert_index_equal(result.columns, expected_columns_numeric) + + if method not in ("first", "last"): + msg = "|".join( + [ + "Categorical is not ordered", + "category type does not support", + "function is not implemented for this dtype", + f"Cannot perform {method} with non-ordered Categorical", + re.escape(f"agg function failed [how->{method},dtype->object]"), + re.escape(f"agg function failed [how->{method},dtype->string]"), + f"dtype 'str' does not support operation '{method}'", + ] + ) + with pytest.raises(exception, match=msg): + getattr(gb, method)(numeric_only=False) + else: + result = getattr(gb, method)(numeric_only=False) + tm.assert_index_equal(result.columns, expected_columns) + + +@pytest.mark.parametrize( + "kernel, has_arg", + [ + ("all", False), + ("any", False), + ("bfill", False), + ("corr", True), + ("corrwith", True), + ("cov", True), + ("cummax", True), + ("cummin", True), + ("cumprod", True), + ("cumsum", True), + ("diff", False), + ("ffill", False), + ("first", True), + ("idxmax", True), + ("idxmin", True), + ("last", True), + ("max", True), + ("mean", True), + ("median", True), + ("min", True), + ("nth", False), + ("nunique", False), + ("pct_change", False), + ("prod", True), + ("quantile", True), + ("sem", True), + ("skew", True), + ("kurt", True), + ("std", True), + ("sum", True), + ("var", True), + ], +) +@pytest.mark.parametrize("numeric_only", [True, False, lib.no_default]) +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) +def test_numeric_only(kernel, has_arg, numeric_only, keys): + # GH#46072 + # drops_nuisance: Whether the op drops nuisance columns even when numeric_only=False + # has_arg: Whether the op has a numeric_only arg + df = DataFrame({"a1": [1, 1], "a2": [2, 2], "a3": [5, 6], "b": 2 * [object]}) + + args = get_groupby_method_args(kernel, df) + kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only} + + gb = df.groupby(keys) + method = getattr(gb, kernel) + if has_arg and numeric_only is True: + # Cases where b does not appear in the result + if kernel == "corrwith": + warn = Pandas4Warning + msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + msg = "" + with tm.assert_produces_warning(warn, match=msg): + result = method(*args, **kwargs) + assert "b" not in result.columns + elif ( + # kernels that work on any dtype and have numeric_only arg + kernel in ("first", "last") + or ( + # kernels that work on any dtype and don't have numeric_only arg + kernel in ("any", "all", "bfill", "ffill", "nth", "nunique") + and numeric_only is lib.no_default + ) + ): + result = method(*args, **kwargs) + assert "b" in result.columns + elif has_arg: + assert numeric_only is not True + # kernels that are successful on any dtype were above; this will fail + + # object dtypes for transformations are not implemented in Cython and + # have no Python fallback + exception = NotImplementedError if kernel.startswith("cum") else TypeError + + msg = "|".join( + [ + "not allowed for this dtype", + "cannot be performed against 'object' dtypes", + "must be a string or a real number", + "unsupported operand type", + "function is not implemented for this dtype", + re.escape(f"agg function failed [how->{kernel},dtype->object]"), + ] + ) + if kernel == "quantile": + msg = "dtype 'object' does not support operation 'quantile'" + elif kernel == "idxmin": + msg = "'<' not supported between instances of 'type' and 'type'" + elif kernel == "idxmax": + msg = "'>' not supported between instances of 'type' and 'type'" + with pytest.raises(exception, match=msg): + if kernel == "corrwith": + warn = Pandas4Warning + msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn = None + msg = "" + with tm.assert_produces_warning(warn, match=msg): + method(*args, **kwargs) + elif not has_arg and numeric_only is not lib.no_default: + with pytest.raises( + TypeError, match="got an unexpected keyword argument 'numeric_only'" + ): + method(*args, **kwargs) + else: + assert kernel in ("diff", "pct_change") + assert numeric_only is lib.no_default + # Doesn't have numeric_only argument and fails on nuisance columns + with pytest.raises(TypeError, match=r"unsupported operand type"): + method(*args, **kwargs) + + +@pytest.mark.parametrize("dtype", [bool, int, float, object]) +def test_deprecate_numeric_only_series(dtype, groupby_func, request): + # GH#46560 + grouper = [0, 0, 1] + + ser = Series([1, 0, 0], dtype=dtype) + gb = ser.groupby(grouper) + + if groupby_func == "corrwith": + # corrwith is not implemented on SeriesGroupBy + assert not hasattr(gb, groupby_func) + return + + method = getattr(gb, groupby_func) + + expected_ser = Series([1, 0, 0]) + expected_gb = expected_ser.groupby(grouper) + expected_method = getattr(expected_gb, groupby_func) + + args = get_groupby_method_args(groupby_func, ser) + + fails_on_numeric_object = ( + "corr", + "cov", + "cummax", + "cummin", + "cumprod", + "cumsum", + "quantile", + ) + # ops that give an object result on object input + obj_result = ( + "first", + "last", + "nth", + "bfill", + "ffill", + "shift", + "sum", + "diff", + "pct_change", + "var", + "mean", + "median", + "min", + "max", + "prod", + "skew", + "kurt", + ) + + # Test default behavior; kernels that fail may be enabled in the future but kernels + # that succeed should not be allowed to fail (without deprecation, at least) + if groupby_func in fails_on_numeric_object and dtype is object: + if groupby_func == "quantile": + msg = "dtype 'object' does not support operation 'quantile'" + else: + msg = "is not supported for object dtype" + with pytest.raises(TypeError, match=msg): + method(*args) + elif dtype is object: + result = method(*args) + expected = expected_method(*args) + if groupby_func in obj_result: + expected = expected.astype(object) + tm.assert_series_equal(result, expected) + + has_numeric_only = ( + "first", + "last", + "max", + "mean", + "median", + "min", + "prod", + "quantile", + "sem", + "skew", + "kurt", + "std", + "sum", + "var", + "cummax", + "cummin", + "cumprod", + "cumsum", + ) + if groupby_func not in has_numeric_only: + msg = "got an unexpected keyword argument 'numeric_only'" + with pytest.raises(TypeError, match=msg): + method(*args, numeric_only=True) + elif dtype is object: + msg = "|".join( + [ + "SeriesGroupBy.sem called with numeric_only=True and dtype object", + "Series.skew does not allow numeric_only=True with non-numeric", + "cum(sum|prod|min|max) is not supported for object dtype", + r"Cannot use numeric_only=True with SeriesGroupBy\..* and non-numeric", + ] + ) + with pytest.raises(TypeError, match=msg): + method(*args, numeric_only=True) + elif dtype == bool and groupby_func == "quantile": + msg = "Cannot use quantile with bool dtype" + with pytest.raises(TypeError, match=msg): + # GH#51424 + method(*args, numeric_only=False) + else: + result = method(*args, numeric_only=True) + expected = method(*args, numeric_only=False) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/groupby/test_pipe.py b/pandas/tests/groupby/test_pipe.py new file mode 100644 index 0000000000000000000000000000000000000000..ee59a93695bcf84bcfcd8f1add8120e2c04004f5 --- /dev/null +++ b/pandas/tests/groupby/test_pipe.py @@ -0,0 +1,80 @@ +import numpy as np + +import pandas as pd +from pandas import ( + DataFrame, + Index, +) +import pandas._testing as tm + + +def test_pipe(): + # Test the pipe method of DataFrameGroupBy. + # Issue #17871 + + random_state = np.random.default_rng(2) + + df = DataFrame( + { + "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"], + "B": random_state.standard_normal(8), + "C": random_state.standard_normal(8), + } + ) + + def f(dfgb): + return dfgb.B.max() - dfgb.C.min().min() + + def square(srs): + return srs**2 + + # Note that the transformations are + # GroupBy -> Series + # Series -> Series + # This then chains the GroupBy.pipe and the + # NDFrame.pipe methods + result = df.groupby("A").pipe(f).pipe(square) + + index = Index(["bar", "foo"], name="A") + expected = pd.Series([3.749306591013693, 6.717707873081384], name="B", index=index) + + tm.assert_series_equal(expected, result) + + +def test_pipe_args(): + # Test passing args to the pipe method of DataFrameGroupBy. + # Issue #17871 + + df = DataFrame( + { + "group": ["A", "A", "B", "B", "C"], + "x": [1.0, 2.0, 3.0, 2.0, 5.0], + "y": [10.0, 100.0, 1000.0, -100.0, -1000.0], + } + ) + + def f(dfgb, arg1): + filtered = dfgb.filter(lambda grp: grp.y.mean() > arg1, dropna=False) + return filtered.groupby("group") + + def g(dfgb, arg2): + return dfgb.sum() / dfgb.sum().sum() + arg2 + + def h(df, arg3): + return df.x + df.y - arg3 + + result = df.groupby("group").pipe(f, 0).pipe(g, 10).pipe(h, 100) + + # Assert the results here + index = Index(["A", "B"], name="group") + expected = pd.Series([-79.5160891089, -78.4839108911], index=index) + + tm.assert_series_equal(result, expected) + + # test SeriesGroupby.pipe + ser = pd.Series([1, 1, 2, 2, 3, 3]) + result = ser.groupby(ser).pipe(lambda grp: grp.sum() * grp.count()) + + expected = pd.Series([4, 8, 12], index=Index([1, 2, 3], dtype=np.int64)) + + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/groupby/test_raises.py b/pandas/tests/groupby/test_raises.py new file mode 100644 index 0000000000000000000000000000000000000000..652c3808f5f1d13318a577533365ee552151f32d --- /dev/null +++ b/pandas/tests/groupby/test_raises.py @@ -0,0 +1,741 @@ +# Only tests that raise an error and have no better location should go here. +# Tests for specific groupby methods should go in their respective +# test file. + +import datetime +import re + +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning + +from pandas import ( + Categorical, + DataFrame, + Grouper, + Series, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args + + +@pytest.fixture( + params=[ + "a", + ["a"], + ["a", "b"], + Grouper(key="a"), + lambda x: x % 2, + [0, 0, 0, 1, 2, 2, 2, 3, 3], + np.array([0, 0, 0, 1, 2, 2, 2, 3, 3]), + dict(zip(range(9), [0, 0, 0, 1, 2, 2, 2, 3, 3], strict=True)), + Series([1, 1, 1, 1, 1, 2, 2, 2, 2]), + [Series([1, 1, 1, 1, 1, 2, 2, 2, 2]), Series([3, 3, 4, 4, 4, 4, 4, 3, 3])], + ] +) +def by(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def groupby_series(request): + return request.param + + +@pytest.fixture +def df_with_string_col(): + df = DataFrame( + { + "a": [1, 1, 1, 1, 1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4, 4, 4, 3, 3], + "c": range(9), + "d": list("xyzwtyuio"), + } + ) + return df + + +@pytest.fixture +def df_with_datetime_col(): + df = DataFrame( + { + "a": [1, 1, 1, 1, 1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4, 4, 4, 3, 3], + "c": range(9), + "d": datetime.datetime(2005, 1, 1, 10, 30, 23, 540000), + } + ) + return df + + +@pytest.fixture +def df_with_cat_col(): + df = DataFrame( + { + "a": [1, 1, 1, 1, 1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4, 4, 4, 3, 3], + "c": range(9), + "d": Categorical( + ["a", "a", "a", "a", "b", "b", "b", "b", "c"], + categories=["a", "b", "c", "d"], + ordered=True, + ), + } + ) + return df + + +def _call_and_check( + klass, msg, how, gb, groupby_func, args, warn_category=None, warn_msg="" +): + with tm.assert_produces_warning( + warn_category, match=warn_msg, check_stacklevel=False + ): + if klass is None: + if how == "method": + getattr(gb, groupby_func)(*args) + elif how == "agg": + gb.agg(groupby_func, *args) + else: + gb.transform(groupby_func, *args) + else: + with pytest.raises(klass, match=msg): + if how == "method": + getattr(gb, groupby_func)(*args) + elif how == "agg": + gb.agg(groupby_func, *args) + else: + gb.transform(groupby_func, *args) + + +@pytest.mark.parametrize("how", ["method", "agg", "transform"]) +def test_groupby_raises_string( + how, by, groupby_series, groupby_func, df_with_string_col, using_infer_string +): + df = df_with_string_col + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + if groupby_func == "corrwith": + assert not hasattr(gb, "corrwith") + return + + klass, msg = { + "all": (None, ""), + "any": (None, ""), + "bfill": (None, ""), + "corrwith": (TypeError, "Could not convert"), + "count": (None, ""), + "cumcount": (None, ""), + "cummax": ( + (NotImplementedError, TypeError), + "(function|cummax) is not (implemented|supported) for (this|object) dtype", + ), + "cummin": ( + (NotImplementedError, TypeError), + "(function|cummin) is not (implemented|supported) for (this|object) dtype", + ), + "cumprod": ( + (NotImplementedError, TypeError), + "(function|cumprod) is not (implemented|supported) for (this|object) dtype", + ), + "cumsum": ( + (NotImplementedError, TypeError), + "(function|cumsum) is not (implemented|supported) for (this|object) dtype", + ), + "diff": (TypeError, "unsupported operand type"), + "ffill": (None, ""), + "first": (None, ""), + "idxmax": (None, ""), + "idxmin": (None, ""), + "last": (None, ""), + "max": (None, ""), + "mean": ( + TypeError, + re.escape("agg function failed [how->mean,dtype->object]"), + ), + "median": ( + TypeError, + re.escape("agg function failed [how->median,dtype->object]"), + ), + "min": (None, ""), + "ngroup": (None, ""), + "nunique": (None, ""), + "pct_change": (TypeError, "unsupported operand type"), + "prod": ( + TypeError, + re.escape("agg function failed [how->prod,dtype->object]"), + ), + "quantile": (TypeError, "dtype 'object' does not support operation 'quantile'"), + "rank": (None, ""), + "sem": (ValueError, "could not convert string to float"), + "shift": (None, ""), + "size": (None, ""), + "skew": (ValueError, "could not convert string to float"), + "kurt": (ValueError, "could not convert string to float"), + "std": (ValueError, "could not convert string to float"), + "sum": (None, ""), + "var": ( + TypeError, + re.escape("agg function failed [how->var,dtype->"), + ), + }[groupby_func] + + if using_infer_string: + if groupby_func in [ + "prod", + "mean", + "median", + "cumsum", + "cumprod", + "std", + "sem", + "var", + "skew", + "kurt", + "quantile", + ]: + msg = f"dtype 'str' does not support operation '{groupby_func}'" + if groupby_func in ["sem", "std", "skew", "kurt"]: + # The object-dtype raises ValueError when trying to convert to numeric. + klass = TypeError + elif groupby_func == "pct_change" and df["d"].dtype.storage == "pyarrow": + # This doesn't go through EA._groupby_op so the message isn't controlled + # there. + msg = "operation 'truediv' not supported for dtype 'str' with dtype 'str'" + elif groupby_func == "diff" and df["d"].dtype.storage == "pyarrow": + # This doesn't go through EA._groupby_op so the message isn't controlled + # there. + msg = "operation 'sub' not supported for dtype 'str' with dtype 'str'" + + elif groupby_func in ["cummin", "cummax"]: + msg = msg.replace("object", "str") + elif groupby_func == "corrwith": + msg = "Cannot perform reduction 'mean' with string dtype" + + if groupby_func == "corrwith": + warn_category = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn_category = None + warn_msg = "" + _call_and_check(klass, msg, how, gb, groupby_func, args, warn_category, warn_msg) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +def test_groupby_raises_string_udf(how, by, groupby_series, df_with_string_col): + df = df_with_string_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + def func(x): + raise TypeError("Test error message") + + with pytest.raises(TypeError, match="Test error message"): + getattr(gb, how)(func) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +@pytest.mark.parametrize("groupby_func_np", [np.sum, np.mean]) +def test_groupby_raises_string_np( + how, + by, + groupby_series, + groupby_func_np, + df_with_string_col, + using_infer_string, +): + # GH#50749 + df = df_with_string_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + klass, msg = { + np.sum: (None, ""), + np.mean: ( + TypeError, + "Could not convert string .* to numeric|" + "Cannot perform reduction 'mean' with string dtype", + ), + }[groupby_func_np] + + if using_infer_string: + if groupby_func_np is np.mean: + klass = TypeError + msg = f"Cannot perform reduction '{groupby_func_np.__name__}' with string dtype" + + _call_and_check(klass, msg, how, gb, groupby_func_np, ()) + + +@pytest.mark.parametrize("how", ["method", "agg", "transform"]) +def test_groupby_raises_datetime( + how, by, groupby_series, groupby_func, df_with_datetime_col +): + df = df_with_datetime_col + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + if groupby_func == "corrwith": + assert not hasattr(gb, "corrwith") + return + + klass, msg = { + "all": (TypeError, "'all' with datetime64 dtypes is no longer supported"), + "any": (TypeError, "'any' with datetime64 dtypes is no longer supported"), + "bfill": (None, ""), + "corrwith": (TypeError, "cannot perform __mul__ with this index type"), + "count": (None, ""), + "cumcount": (None, ""), + "cummax": (None, ""), + "cummin": (None, ""), + "cumprod": (TypeError, "datetime64 type does not support operation 'cumprod'"), + "cumsum": (TypeError, "datetime64 type does not support operation 'cumsum'"), + "diff": (None, ""), + "ffill": (None, ""), + "first": (None, ""), + "idxmax": (None, ""), + "idxmin": (None, ""), + "last": (None, ""), + "max": (None, ""), + "mean": (None, ""), + "median": (None, ""), + "min": (None, ""), + "ngroup": (None, ""), + "nunique": (None, ""), + "pct_change": (TypeError, "cannot perform __truediv__ with this index type"), + "prod": (TypeError, "datetime64 type does not support operation 'prod'"), + "quantile": (None, ""), + "rank": (None, ""), + "sem": (None, ""), + "shift": (None, ""), + "size": (None, ""), + "skew": ( + TypeError, + "|".join( + [ + r"dtype datetime64\[ns\] does not support operation", + "datetime64 type does not support operation 'skew'", + ] + ), + ), + "kurt": ( + TypeError, + "|".join( + [ + r"dtype datetime64\[ns\] does not support operation", + "datetime64 type does not support operation 'kurt'", + ] + ), + ), + "std": (None, ""), + "sum": (TypeError, "datetime64 type does not support operation 'sum"), + "var": (TypeError, "datetime64 type does not support operation 'var'"), + }[groupby_func] + + if groupby_func == "corrwith": + warn_category = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn_category = None + warn_msg = "" + _call_and_check(klass, msg, how, gb, groupby_func, args, warn_category, warn_msg) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +def test_groupby_raises_datetime_udf(how, by, groupby_series, df_with_datetime_col): + df = df_with_datetime_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + def func(x): + raise TypeError("Test error message") + + with pytest.raises(TypeError, match="Test error message"): + getattr(gb, how)(func) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +@pytest.mark.parametrize("groupby_func_np", [np.sum, np.mean]) +def test_groupby_raises_datetime_np( + how, by, groupby_series, groupby_func_np, df_with_datetime_col +): + # GH#50749 + df = df_with_datetime_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + klass, msg = { + np.sum: ( + TypeError, + re.escape("datetime64[us] does not support operation 'sum'"), + ), + np.mean: (None, ""), + }[groupby_func_np] + _call_and_check(klass, msg, how, gb, groupby_func_np, ()) + + +@pytest.mark.parametrize("func", ["prod", "cumprod", "skew", "kurt", "var"]) +def test_groupby_raises_timedelta(func): + df = DataFrame( + { + "a": [1, 1, 1, 1, 1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4, 4, 4, 3, 3], + "c": range(9), + "d": datetime.timedelta(days=1), + } + ) + gb = df.groupby(by="a") + + _call_and_check( + TypeError, + "timedelta64 type does not support .* operations", + "method", + gb, + func, + [], + ) + + +@pytest.mark.parametrize("how", ["method", "agg", "transform"]) +def test_groupby_raises_category( + how, by, groupby_series, groupby_func, df_with_cat_col +): + # GH#50749 + df = df_with_cat_col + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + if groupby_func == "corrwith": + assert not hasattr(gb, "corrwith") + return + + klass, msg = { + "all": (None, ""), + "any": (None, ""), + "bfill": (None, ""), + "corrwith": ( + TypeError, + r"unsupported operand type\(s\) for \*: 'Categorical' and 'int'", + ), + "count": (None, ""), + "cumcount": (None, ""), + "cummax": ( + (NotImplementedError, TypeError), + "(category type does not support cummax operations|" + "category dtype not supported|" + "cummax is not supported for category dtype)", + ), + "cummin": ( + (NotImplementedError, TypeError), + "(category type does not support cummin operations|" + "category dtype not supported|" + "cummin is not supported for category dtype)", + ), + "cumprod": ( + (NotImplementedError, TypeError), + "(category type does not support cumprod operations|" + "category dtype not supported|" + "cumprod is not supported for category dtype)", + ), + "cumsum": ( + (NotImplementedError, TypeError), + "(category type does not support cumsum operations|" + "category dtype not supported|" + "cumsum is not supported for category dtype)", + ), + "diff": ( + TypeError, + r"unsupported operand type\(s\) for -: 'Categorical' and 'Categorical'", + ), + "ffill": (None, ""), + "first": (None, ""), + "idxmax": (None, ""), + "idxmin": (None, ""), + "last": (None, ""), + "max": (None, ""), + "mean": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support operation 'mean'", + "category dtype does not support aggregation 'mean'", + ] + ), + ), + "median": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support operation 'median'", + "category dtype does not support aggregation 'median'", + ] + ), + ), + "min": (None, ""), + "ngroup": (None, ""), + "nunique": (None, ""), + "pct_change": ( + TypeError, + r"unsupported operand type\(s\) for /: 'Categorical' and 'Categorical'", + ), + "prod": (TypeError, "category type does not support prod operations"), + "quantile": (TypeError, "No matching signature found"), + "rank": (None, ""), + "sem": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support operation 'sem'", + "category dtype does not support aggregation 'sem'", + ] + ), + ), + "shift": (None, ""), + "size": (None, ""), + "skew": ( + TypeError, + "|".join( + [ + "dtype category does not support operation 'skew'", + "category type does not support skew operations", + ] + ), + ), + "kurt": ( + TypeError, + "|".join( + [ + "dtype category does not support operation 'kurt'", + "category type does not support kurt operations", + ] + ), + ), + "std": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support operation 'std'", + "category dtype does not support aggregation 'std'", + ] + ), + ), + "sum": (TypeError, "category type does not support sum operations"), + "var": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support operation 'var'", + "category dtype does not support aggregation 'var'", + ] + ), + ), + }[groupby_func] + + if groupby_func == "corrwith": + warn_category = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn_category = None + warn_msg = "" + _call_and_check(klass, msg, how, gb, groupby_func, args, warn_category, warn_msg) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +def test_groupby_raises_category_udf(how, by, groupby_series, df_with_cat_col): + # GH#50749 + df = df_with_cat_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + def func(x): + raise TypeError("Test error message") + + with pytest.raises(TypeError, match="Test error message"): + getattr(gb, how)(func) + + +@pytest.mark.parametrize("how", ["agg", "transform"]) +@pytest.mark.parametrize("groupby_func_np", [np.sum, np.mean]) +def test_groupby_raises_category_np( + how, by, groupby_series, groupby_func_np, df_with_cat_col +): + # GH#50749 + df = df_with_cat_col + gb = df.groupby(by=by) + + if groupby_series: + gb = gb["d"] + + klass, msg = { + np.sum: (TypeError, "dtype category does not support operation 'sum'"), + np.mean: ( + TypeError, + "dtype category does not support operation 'mean'", + ), + }[groupby_func_np] + _call_and_check(klass, msg, how, gb, groupby_func_np, ()) + + +@pytest.mark.filterwarnings("ignore:In a future version, the keys") +@pytest.mark.parametrize("how", ["method", "agg", "transform"]) +def test_groupby_raises_category_on_category( + how, + by, + groupby_series, + groupby_func, + observed, + df_with_cat_col, +): + # GH#50749 + df = df_with_cat_col + df["a"] = Categorical( + ["a", "a", "a", "a", "b", "b", "b", "b", "c"], + categories=["a", "b", "c", "d"], + ordered=True, + ) + args = get_groupby_method_args(groupby_func, df) + gb = df.groupby(by=by, observed=observed) + + if groupby_series: + gb = gb["d"] + + if groupby_func == "corrwith": + assert not hasattr(gb, "corrwith") + return + + empty_groups = not observed and any(group.empty for group in gb.groups.values()) + if how == "transform": + # empty groups will be ignored + empty_groups = False + + klass, msg = { + "all": (None, ""), + "any": (None, ""), + "bfill": (None, ""), + "corrwith": ( + TypeError, + r"unsupported operand type\(s\) for \*: 'Categorical' and 'int'", + ), + "count": (None, ""), + "cumcount": (None, ""), + "cummax": ( + (NotImplementedError, TypeError), + "(cummax is not supported for category dtype|" + "category dtype not supported|" + "category type does not support cummax operations)", + ), + "cummin": ( + (NotImplementedError, TypeError), + "(cummin is not supported for category dtype|" + "category dtype not supported|" + "category type does not support cummin operations)", + ), + "cumprod": ( + (NotImplementedError, TypeError), + "(cumprod is not supported for category dtype|" + "category dtype not supported|" + "category type does not support cumprod operations)", + ), + "cumsum": ( + (NotImplementedError, TypeError), + "(cumsum is not supported for category dtype|" + "category dtype not supported|" + "category type does not support cumsum operations)", + ), + "diff": (TypeError, "unsupported operand type"), + "ffill": (None, ""), + "first": (None, ""), + "idxmax": (ValueError, "empty group due to unobserved categories") + if empty_groups + else (None, ""), + "idxmin": (ValueError, "empty group due to unobserved categories") + if empty_groups + else (None, ""), + "last": (None, ""), + "max": (None, ""), + "mean": (TypeError, "category dtype does not support aggregation 'mean'"), + "median": (TypeError, "category dtype does not support aggregation 'median'"), + "min": (None, ""), + "ngroup": (None, ""), + "nunique": (None, ""), + "pct_change": (TypeError, "unsupported operand type"), + "prod": (TypeError, "category type does not support prod operations"), + "quantile": (TypeError, "No matching signature found"), + "rank": (None, ""), + "sem": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support operation 'sem'", + "category dtype does not support aggregation 'sem'", + ] + ), + ), + "shift": (None, ""), + "size": (None, ""), + "skew": ( + TypeError, + "|".join( + [ + "category type does not support skew operations", + "dtype category does not support operation 'skew'", + ] + ), + ), + "kurt": ( + TypeError, + "|".join( + [ + "category type does not support kurt operations", + "dtype category does not support operation 'kurt'", + ] + ), + ), + "std": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support operation 'std'", + "category dtype does not support aggregation 'std'", + ] + ), + ), + "sum": (TypeError, "category type does not support sum operations"), + "var": ( + TypeError, + "|".join( + [ + "'Categorical' .* does not support operation 'var'", + "category dtype does not support aggregation 'var'", + ] + ), + ), + }[groupby_func] + + if groupby_func == "corrwith": + warn_category = Pandas4Warning + warn_msg = "DataFrameGroupBy.corrwith is deprecated" + else: + warn_category = None + warn_msg = "" + _call_and_check(klass, msg, how, gb, groupby_func, args, warn_category, warn_msg) diff --git a/pandas/tests/groupby/test_reductions.py b/pandas/tests/groupby/test_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..890ce4e398f0104209799201a6c2a5ea4973edcb --- /dev/null +++ b/pandas/tests/groupby/test_reductions.py @@ -0,0 +1,1538 @@ +import builtins +import datetime as dt +from string import ascii_lowercase + +import numpy as np +import pytest + +from pandas._libs.tslibs import iNaT + +from pandas.core.dtypes.common import pandas_dtype +from pandas.core.dtypes.missing import na_value_for_dtype + +import pandas as pd +from pandas import ( + DataFrame, + MultiIndex, + Series, + Timestamp, + date_range, + isna, +) +import pandas._testing as tm +from pandas.tests.groupby import get_groupby_method_args +from pandas.util import _test_decorators as td + + +@pytest.mark.parametrize("dtype", ["int64", "int32", "float64", "float32"]) +def test_basic_aggregations(dtype): + data = Series(np.arange(9) // 3, index=np.arange(9), dtype=dtype) + + index = np.arange(9) + np.random.default_rng(2).shuffle(index) + data = data.reindex(index) + + grouped = data.groupby(lambda x: x // 3, group_keys=False) + + for k, v in grouped: + assert len(v) == 3 + + agged = grouped.aggregate(np.mean) + assert agged[1] == 1 + + expected = grouped.agg(np.mean) + tm.assert_series_equal(agged, expected) # shorthand + tm.assert_series_equal(agged, grouped.mean()) + result = grouped.sum() + expected = grouped.agg(np.sum) + if dtype == "int32": + # NumPy's sum returns int64 + expected = expected.astype("int32") + tm.assert_series_equal(result, expected) + + expected = grouped.apply(lambda x: x * x.sum()) + transformed = grouped.transform(lambda x: x * x.sum()) + assert transformed[7] == 12 + tm.assert_series_equal(transformed, expected) + + value_grouped = data.groupby(data) + result = value_grouped.aggregate(np.mean) + tm.assert_series_equal(result, agged, check_index_type=False) + + # complex agg + agged = grouped.aggregate([np.mean, np.std]) + + msg = r"nested renamer is not supported" + with pytest.raises(pd.errors.SpecificationError, match=msg): + grouped.aggregate({"one": np.mean, "two": np.std}) + + # corner cases + msg = "Must produce aggregated value" + # exception raised is type Exception + with pytest.raises(Exception, match=msg): + grouped.aggregate(lambda x: x * 2) + + +@pytest.mark.parametrize( + "vals", + [ + ["foo", "bar", "baz"], + ["foo", "", ""], + ["", "", ""], + [1, 2, 3], + [1, 0, 0], + [0, 0, 0], + [1.0, 2.0, 3.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [True, True, True], + [True, False, False], + [False, False, False], + [np.nan, np.nan, np.nan], + ], +) +def test_groupby_bool_aggs(skipna, all_boolean_reductions, vals): + df = DataFrame({"key": ["a"] * 3 + ["b"] * 3, "val": vals * 2}) + + # Figure out expectation using Python builtin + exp = getattr(builtins, all_boolean_reductions)(vals) + + # edge case for missing data with skipna and 'any' + if skipna and all(isna(vals)) and all_boolean_reductions == "any": + exp = False + + expected = DataFrame( + [exp] * 2, columns=["val"], index=pd.Index(["a", "b"], name="key") + ) + result = getattr(df.groupby("key"), all_boolean_reductions)(skipna=skipna) + tm.assert_frame_equal(result, expected) + + +def test_any(): + df = DataFrame( + [[1, 2, "foo"], [1, np.nan, "bar"], [3, np.nan, "baz"]], + columns=["A", "B", "C"], + ) + expected = DataFrame( + [[True, True], [False, True]], columns=["B", "C"], index=[1, 3] + ) + expected.index.name = "A" + result = df.groupby("A").any() + tm.assert_frame_equal(result, expected) + + +def test_bool_aggs_dup_column_labels(all_boolean_reductions): + # GH#21668 + df = DataFrame([[True, True]], columns=["a", "a"]) + grp_by = df.groupby([0]) + result = getattr(grp_by, all_boolean_reductions)() + + expected = df.set_axis(np.array([0])) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + [False, False, False], + [True, True, True], + [pd.NA, pd.NA, pd.NA], + [False, pd.NA, False], + [True, pd.NA, True], + [True, pd.NA, False], + ], +) +def test_masked_kleene_logic(all_boolean_reductions, skipna, data): + # GH#37506 + ser = Series(data, dtype="boolean") + + # The result should match aggregating on the whole series. Correctness + # there is verified in test_reductions.py::test_any_all_boolean_kleene_logic + expected_data = getattr(ser, all_boolean_reductions)(skipna=skipna) + expected = Series(expected_data, index=np.array([0]), dtype="boolean") + + result = ser.groupby([0, 0, 0]).agg(all_boolean_reductions, skipna=skipna) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype1,dtype2,exp_col1,exp_col2", + [ + ( + "float", + "Float64", + np.array([True], dtype=bool), + pd.array([pd.NA], dtype="boolean"), + ), + ( + "Int64", + "float", + pd.array([pd.NA], dtype="boolean"), + np.array([True], dtype=bool), + ), + ( + "Int64", + "Int64", + pd.array([pd.NA], dtype="boolean"), + pd.array([pd.NA], dtype="boolean"), + ), + ( + "Float64", + "boolean", + pd.array([pd.NA], dtype="boolean"), + pd.array([pd.NA], dtype="boolean"), + ), + ], +) +def test_masked_mixed_types(dtype1, dtype2, exp_col1, exp_col2): + # GH#37506 + data1 = [1.0, np.nan] if dtype1.startswith("f") else [1.0, pd.NA] + data2 = [1.0, np.nan] if dtype2.startswith("f") else [1.0, pd.NA] + df = DataFrame( + {"col1": pd.array(data1, dtype=dtype1), "col2": pd.array(data2, dtype=dtype2)} + ) + result = df.groupby([1, 1]).agg("all", skipna=False) + + expected = DataFrame({"col1": exp_col1, "col2": exp_col2}, index=np.array([1])) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"]) +def test_masked_bool_aggs_skipna( + all_boolean_reductions, dtype, skipna, frame_or_series +): + # GH#40585 + obj = frame_or_series([pd.NA, 1], dtype=dtype) + expected_res = True + if not skipna and all_boolean_reductions == "all": + expected_res = pd.NA + expected = frame_or_series([expected_res], index=np.array([1]), dtype="boolean") + + result = obj.groupby([1, 1]).agg(all_boolean_reductions, skipna=skipna) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "bool_agg_func,data,expected_res", + [ + ("any", [pd.NA, np.nan], False), + ("any", [pd.NA, 1, np.nan], True), + ("all", [pd.NA, pd.NaT], True), + ("all", [pd.NA, False, pd.NaT], False), + ], +) +def test_object_type_missing_vals(bool_agg_func, data, expected_res, frame_or_series): + # GH#37501 + obj = frame_or_series(data, dtype=object) + result = obj.groupby([1] * len(data)).agg(bool_agg_func) + expected = frame_or_series([expected_res], index=np.array([1]), dtype="bool") + tm.assert_equal(result, expected) + + +def test_object_NA_raises_with_skipna_false(all_boolean_reductions): + # GH#37501 + ser = Series([pd.NA], dtype=object) + with pytest.raises(TypeError, match="boolean value of NA is ambiguous"): + ser.groupby([1]).agg(all_boolean_reductions, skipna=False) + + +def test_empty(frame_or_series, all_boolean_reductions): + # GH 45231 + kwargs = {"columns": ["a"]} if frame_or_series is DataFrame else {"name": "a"} + obj = frame_or_series(**kwargs, dtype=object) + result = getattr(obj.groupby(obj.index), all_boolean_reductions)() + expected = frame_or_series(**kwargs, dtype=bool) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize("how", ["idxmin", "idxmax"]) +def test_idxmin_idxmax_extremes(how, any_real_numpy_dtype): + # GH#57040 + if any_real_numpy_dtype is int or any_real_numpy_dtype is float: + # No need to test + return + info = np.iinfo if "int" in any_real_numpy_dtype else np.finfo + min_value = info(any_real_numpy_dtype).min + max_value = info(any_real_numpy_dtype).max + df = DataFrame( + {"a": [2, 1, 1, 2], "b": [min_value, max_value, max_value, min_value]}, + dtype=any_real_numpy_dtype, + ) + gb = df.groupby("a") + result = getattr(gb, how)() + expected = DataFrame( + {"b": [1, 0]}, index=pd.Index([1, 2], name="a", dtype=any_real_numpy_dtype) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("how", ["idxmin", "idxmax"]) +def test_idxmin_idxmax_extremes_skipna(skipna, how, float_numpy_dtype): + # GH#57040 + min_value = np.finfo(float_numpy_dtype).min + max_value = np.finfo(float_numpy_dtype).max + df = DataFrame( + { + "a": Series(np.repeat(range(1, 5), repeats=2), dtype="intp"), + "b": Series( + [ + np.nan, + min_value, + np.nan, + max_value, + min_value, + np.nan, + max_value, + np.nan, + ], + dtype=float_numpy_dtype, + ), + }, + ) + gb = df.groupby("a") + + if not skipna: + msg = f"{how} with skipna=False" + with pytest.raises(ValueError, match=msg): + getattr(gb, how)(skipna=skipna) + return + result = getattr(gb, how)(skipna=skipna) + expected = DataFrame( + {"b": [1, 3, 4, 6]}, index=pd.Index(range(1, 5), name="a", dtype="intp") + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func, values", + [ + ("idxmin", {"c_int": [0, 2], "c_float": [1, 3], "c_date": [1, 2]}), + ("idxmax", {"c_int": [1, 3], "c_float": [0, 2], "c_date": [0, 3]}), + ], +) +@pytest.mark.parametrize("numeric_only", [True, False]) +def test_idxmin_idxmax_returns_int_types(func, values, numeric_only): + # GH 25444 + df = DataFrame( + { + "name": ["A", "A", "B", "B"], + "c_int": [1, 2, 3, 4], + "c_float": [4.02, 3.03, 2.04, 1.05], + "c_date": ["2019", "2018", "2016", "2017"], + } + ) + df["c_date"] = pd.to_datetime(df["c_date"]) + df["c_date_tz"] = df["c_date"].dt.tz_localize("US/Pacific") + df["c_timedelta"] = df["c_date"] - df["c_date"].iloc[0] + df["c_period"] = df["c_date"].dt.to_period("W") + df["c_Integer"] = df["c_int"].astype("Int64") + df["c_Floating"] = df["c_float"].astype("Float64") + + result = getattr(df.groupby("name"), func)(numeric_only=numeric_only) + + expected = DataFrame(values, index=pd.Index(["A", "B"], name="name")) + if numeric_only: + expected = expected.drop(columns=["c_date"]) + else: + expected["c_date_tz"] = expected["c_date"] + expected["c_timedelta"] = expected["c_date"] + expected["c_period"] = expected["c_date"] + expected["c_Integer"] = expected["c_int"] + expected["c_Floating"] = expected["c_float"] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + ( + Timestamp("2011-01-15 12:50:28.502376"), + Timestamp("2011-01-20 12:50:28.593448"), + ), + (24650000000000001, 24650000000000002), + ], +) +@pytest.mark.parametrize("method", ["count", "min", "max", "first", "last"]) +def test_groupby_non_arithmetic_agg_int_like_precision(method, data): + # GH#6620, GH#9311 + df = DataFrame({"a": [1, 1], "b": data}) + + grouped = df.groupby("a") + result = getattr(grouped, method)() + if method == "count": + expected_value = 2 + elif method == "first": + expected_value = data[0] + elif method == "last": + expected_value = data[1] + else: + expected_value = getattr(df["b"], method)() + expected = DataFrame({"b": [expected_value]}, index=pd.Index([1], name="a")) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("how", ["first", "last"]) +def test_first_last_skipna(any_real_nullable_dtype, sort, skipna, how): + # GH#57019 + na_value = na_value_for_dtype(pandas_dtype(any_real_nullable_dtype)) + df = DataFrame( + { + "a": [2, 1, 1, 2, 3, 3], + # TODO: test that has mixed na_value and NaN either working for + # float or raising for int? + "b": [na_value, 3.0, na_value, 4.0, na_value, na_value], + "c": [na_value, 3.0, na_value, 4.0, na_value, na_value], + }, + dtype=any_real_nullable_dtype, + ) + gb = df.groupby("a", sort=sort) + method = getattr(gb, how) + result = method(skipna=skipna) + + ilocs = { + ("first", True): [3, 1, 4], + ("first", False): [0, 1, 4], + ("last", True): [3, 1, 5], + ("last", False): [3, 2, 5], + }[how, skipna] + expected = df.iloc[ilocs].set_index("a") + if sort: + expected = expected.sort_index() + tm.assert_frame_equal(result, expected) + + +def test_groupby_mean_no_overflow(): + # Regression test for (#22487) + df = DataFrame( + { + "user": ["A", "A", "A", "A", "A"], + "connections": [4970, 4749, 4719, 4704, 18446744073699999744], + } + ) + assert df.groupby("user")["connections"].mean()["A"] == 3689348814740003840 + + +def test_mean_on_timedelta(): + # GH 17382 + df = DataFrame({"time": pd.to_timedelta(range(10)), "cat": ["A", "B"] * 5}) + result = df.groupby("cat")["time"].mean() + expected = Series( + pd.to_timedelta([4, 5]), name="time", index=pd.Index(["A", "B"], name="cat") + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "values, dtype, result_dtype", + [ + ([0, 1, np.nan, 3, 4, 5, 6, 7, 8, 9], "float64", "float64"), + ([0, 1, pd.NA, 3, 4, 5, 6, 7, 8, 9], "Float64", "Float64"), + ([0, 1, pd.NA, 3, 4, 5, 6, 7, 8, 9], "Int64", "Float64"), + ([0, 1, np.nan, 3, 4, 5, 6, 7, 8, 9], "timedelta64[ns]", "timedelta64[ns]"), + ( + pd.to_datetime( + [ + "2019-05-09", + pd.NaT, + "2019-05-11", + "2019-05-12", + "2019-05-13", + "2019-05-14", + "2019-05-15", + "2019-05-16", + "2019-05-17", + "2019-05-18", + ] + ), + "datetime64[ns]", + "datetime64[ns]", + ), + ], +) +def test_mean_skipna(values, dtype, result_dtype, skipna): + # GH#15675 + df = DataFrame( + { + "val": values, + "cat": ["A", "B"] * 5, + } + ).astype({"val": dtype}) + # We need to recast the expected values to the result_dtype because + # Series.mean() changes the dtype to float64/object depending on the input dtype + expected = ( + df.groupby("cat")["val"] + .apply(lambda x: x.mean(skipna=skipna)) + .astype(result_dtype) + ) + result = df.groupby("cat")["val"].mean(skipna=skipna) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "values, dtype", + [ + ([0, 1, np.nan, 3, 4, 5, 6, 7, 8, 9], "float64"), + ([0, 1, pd.NA, 3, 4, 5, 6, 7, 8, 9], "Float64"), + ([0, 1, pd.NA, 3, 4, 5, 6, 7, 8, 9], "Int64"), + ([0, 1, np.nan, 3, 4, 5, 6, 7, 8, 9], "timedelta64[ns]"), + ], +) +def test_sum_skipna(values, dtype, skipna): + # GH#15675 + df = DataFrame( + { + "val": values, + "cat": ["A", "B"] * 5, + } + ).astype({"val": dtype}) + # We need to recast the expected values to the original dtype because + # Series.sum() changes the dtype + expected = ( + df.groupby("cat")["val"].apply(lambda x: x.sum(skipna=skipna)).astype(dtype) + ) + result = df.groupby("cat")["val"].sum(skipna=skipna) + tm.assert_series_equal(result, expected) + + +def test_sum_skipna_object(skipna): + # GH#15675 + df = DataFrame( + { + "val": ["a", "b", np.nan, "d", "e", "f", "g", "h", "i", "j"], + "cat": ["A", "B"] * 5, + } + ).astype({"val": object}) + if skipna: + expected = Series( + ["aegi", "bdfhj"], index=pd.Index(["A", "B"], name="cat"), name="val" + ).astype(object) + else: + expected = Series( + [np.nan, "bdfhj"], index=pd.Index(["A", "B"], name="cat"), name="val" + ).astype(object) + result = df.groupby("cat")["val"].sum(skipna=skipna) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func, values, dtype, result_dtype", + [ + ("prod", [0, 1, 3, np.nan, 4, 5, 6, 7, -8, 9], "float64", "float64"), + ("prod", [0, -1, 3, 4, 5, pd.NA, 6, 7, 8, 9], "Float64", "Float64"), + ("prod", [0, 1, 3, -4, 5, 6, 7, -8, pd.NA, 9], "Int64", "Int64"), + ("prod", [np.nan] * 10, "float64", "float64"), + ("prod", [pd.NA] * 10, "Float64", "Float64"), + ("prod", [pd.NA] * 10, "Int64", "Int64"), + ("var", [0, -1, 3, 4, np.nan, 5, 6, 7, 8, 9], "float64", "float64"), + ("var", [0, 1, 3, -4, 5, 6, 7, -8, 9, pd.NA], "Float64", "Float64"), + ("var", [0, -1, 3, 4, 5, -6, 7, pd.NA, 8, 9], "Int64", "Float64"), + ("var", [np.nan] * 10, "float64", "float64"), + ("var", [pd.NA] * 10, "Float64", "Float64"), + ("var", [pd.NA] * 10, "Int64", "Float64"), + ("std", [0, 1, 3, -4, 5, 6, 7, -8, np.nan, 9], "float64", "float64"), + ("std", [0, -1, 3, 4, 5, -6, 7, pd.NA, 8, 9], "Float64", "Float64"), + ("std", [0, 1, 3, -4, 5, 6, 7, -8, 9, pd.NA], "Int64", "Float64"), + ("std", [np.nan] * 10, "float64", "float64"), + ("std", [pd.NA] * 10, "Float64", "Float64"), + ("std", [pd.NA] * 10, "Int64", "Float64"), + ("sem", [0, -1, 3, 4, 5, -6, 7, np.nan, 8, 9], "float64", "float64"), + ("sem", [0, 1, 3, -4, 5, 6, 7, -8, pd.NA, 9], "Float64", "Float64"), + ("sem", [0, -1, 3, 4, 5, -6, 7, 8, 9, pd.NA], "Int64", "Float64"), + ("sem", [np.nan] * 10, "float64", "float64"), + ("sem", [pd.NA] * 10, "Float64", "Float64"), + ("sem", [pd.NA] * 10, "Int64", "Float64"), + ("min", [0, -1, 3, 4, 5, -6, 7, np.nan, 8, 9], "float64", "float64"), + ("min", [0, 1, 3, -4, 5, 6, 7, -8, pd.NA, 9], "Float64", "Float64"), + ("min", [0, -1, 3, 4, 5, -6, 7, 8, 9, pd.NA], "Int64", "Int64"), + ( + "min", + [0, 1, np.nan, 3, 4, 5, 6, 7, 8, 9], + "timedelta64[ns]", + "timedelta64[ns]", + ), + ( + "min", + pd.to_datetime( + [ + "2019-05-09", + pd.NaT, + "2019-05-11", + "2019-05-12", + "2019-05-13", + "2019-05-14", + "2019-05-15", + "2019-05-16", + "2019-05-17", + "2019-05-18", + ] + ), + "datetime64[ns]", + "datetime64[ns]", + ), + ("min", [np.nan] * 10, "float64", "float64"), + ("min", [pd.NA] * 10, "Float64", "Float64"), + ("min", [pd.NA] * 10, "Int64", "Int64"), + ("max", [0, -1, 3, 4, 5, -6, 7, np.nan, 8, 9], "float64", "float64"), + ("max", [0, 1, 3, -4, 5, 6, 7, -8, pd.NA, 9], "Float64", "Float64"), + ("max", [0, -1, 3, 4, 5, -6, 7, 8, 9, pd.NA], "Int64", "Int64"), + ( + "max", + [0, 1, np.nan, 3, 4, 5, 6, 7, 8, 9], + "timedelta64[ns]", + "timedelta64[ns]", + ), + ( + "max", + pd.to_datetime( + [ + "2019-05-09", + pd.NaT, + "2019-05-11", + "2019-05-12", + "2019-05-13", + "2019-05-14", + "2019-05-15", + "2019-05-16", + "2019-05-17", + "2019-05-18", + ] + ), + "datetime64[ns]", + "datetime64[ns]", + ), + ("max", [np.nan] * 10, "float64", "float64"), + ("max", [pd.NA] * 10, "Float64", "Float64"), + ("max", [pd.NA] * 10, "Int64", "Int64"), + ("median", [0, -1, 3, 4, 5, -6, 7, np.nan, 8, 9], "float64", "float64"), + ("median", [0, 1, 3, -4, 5, 6, 7, -8, pd.NA, 9], "Float64", "Float64"), + ("median", [0, -1, 3, 4, 5, -6, 7, 8, 9, pd.NA], "Int64", "Float64"), + ( + "median", + [0, 1, np.nan, 3, 4, 5, 6, 7, 8, 9], + "timedelta64[ns]", + "timedelta64[ns]", + ), + ( + "median", + pd.to_datetime( + [ + "2019-05-09", + pd.NaT, + "2019-05-11", + "2019-05-12", + "2019-05-13", + "2019-05-14", + "2019-05-15", + "2019-05-16", + "2019-05-17", + "2019-05-18", + ] + ), + "datetime64[ns]", + "datetime64[ns]", + ), + ("median", [np.nan] * 10, "float64", "float64"), + ("median", [pd.NA] * 10, "Float64", "Float64"), + ("median", [pd.NA] * 10, "Int64", "Float64"), + ], +) +def test_multifunc_skipna(func, values, dtype, result_dtype, skipna): + # GH#15675 + df = DataFrame( + { + "val": values, + "cat": ["A", "B"] * 5, + } + ).astype({"val": dtype}) + # We need to recast the expected values to the result_dtype as some operations + # change the dtype + expected = ( + df.groupby("cat")["val"] + .apply(lambda x: getattr(x, func)(skipna=skipna)) + .astype(result_dtype) + ) + result = getattr(df.groupby("cat")["val"], func)(skipna=skipna) + tm.assert_series_equal(result, expected) + + +def test_cython_median(): + arr = np.random.default_rng(2).standard_normal(1000) + arr[::2] = np.nan + df = DataFrame(arr) + + labels = np.random.default_rng(2).integers(0, 50, size=1000).astype(float) + labels[::17] = np.nan + + result = df.groupby(labels).median() + exp = df.groupby(labels).agg(np.nanmedian) + tm.assert_frame_equal(result, exp) + + df = DataFrame(np.random.default_rng(2).standard_normal((1000, 5))) + rs = df.groupby(labels).agg(np.median) + xp = df.groupby(labels).median() + tm.assert_frame_equal(rs, xp) + + +def test_median_empty_bins(observed): + df = DataFrame(np.random.default_rng(2).integers(0, 44, 500)) + + grps = range(0, 55, 5) + bins = pd.cut(df[0], grps) + + result = df.groupby(bins, observed=observed).median() + expected = df.groupby(bins, observed=observed).agg(lambda x: x.median()) + tm.assert_frame_equal(result, expected) + + +def test_max_min_non_numeric(): + # #2700 + aa = DataFrame({"nn": [11, 11, 22, 22], "ii": [1, 2, 3, 4], "ss": 4 * ["mama"]}) + + result = aa.groupby("nn").max() + assert "ss" in result + + result = aa.groupby("nn").max(numeric_only=False) + assert "ss" in result + + result = aa.groupby("nn").min() + assert "ss" in result + + result = aa.groupby("nn").min(numeric_only=False) + assert "ss" in result + + +def test_max_min_object_multiple_columns(using_infer_string): + # GH#41111 case where the aggregation is valid for some columns but not + # others; we split object blocks column-wise, consistent with + # DataFrame._reduce + + df = DataFrame( + { + "A": [1, 1, 2, 2, 3], + "B": [1, "foo", 2, "bar", False], + "C": ["a", "b", "c", "d", "e"], + } + ) + df._consolidate_inplace() # should already be consolidate, but double-check + assert len(df._mgr.blocks) == 3 if using_infer_string else 2 + + gb = df.groupby("A") + + result = gb[["C"]].max() + # "max" is valid for column "C" but not for "B" + ei = pd.Index([1, 2, 3], name="A") + expected = DataFrame({"C": ["b", "d", "e"]}, index=ei) + tm.assert_frame_equal(result, expected) + + result = gb[["C"]].min() + # "min" is valid for column "C" but not for "B" + ei = pd.Index([1, 2, 3], name="A") + expected = DataFrame({"C": ["a", "c", "e"]}, index=ei) + tm.assert_frame_equal(result, expected) + + +def test_min_date_with_nans(): + # GH26321 + dates = pd.to_datetime( + Series(["2019-05-09", "2019-05-09", "2019-05-09"]), format="%Y-%m-%d" + ).dt.date + df = DataFrame({"a": [np.nan, "1", np.nan], "b": [0, 1, 1], "c": dates}) + + result = df.groupby("b", as_index=False)["c"].min()["c"] + expected = pd.to_datetime( + Series(["2019-05-09", "2019-05-09"], name="c"), format="%Y-%m-%d" + ).dt.date + tm.assert_series_equal(result, expected) + + result = df.groupby("b")["c"].min() + expected.index.name = "b" + tm.assert_series_equal(result, expected) + + +def test_max_inat(): + # GH#40767 dont interpret iNaT as NaN + ser = Series([1, iNaT]) + key = np.array([1, 1], dtype=np.int64) + gb = ser.groupby(key) + + result = gb.max(min_count=2) + expected = Series({1: 1}, dtype=np.int64) + tm.assert_series_equal(result, expected, check_exact=True) + + result = gb.min(min_count=2) + expected = Series({1: iNaT}, dtype=np.int64) + tm.assert_series_equal(result, expected, check_exact=True) + + # not enough entries -> gets masked to NaN + result = gb.min(min_count=3) + expected = Series({1: np.nan}) + tm.assert_series_equal(result, expected, check_exact=True) + + +def test_max_inat_not_all_na(): + # GH#40767 dont interpret iNaT as NaN + + # make sure we dont round iNaT+1 to iNaT + ser = Series([1, iNaT, 2, iNaT + 1]) + gb = ser.groupby([1, 2, 3, 3]) + result = gb.min(min_count=2) + + # Note: in converting to float64, the iNaT + 1 maps to iNaT, i.e. is lossy + expected = Series({1: np.nan, 2: np.nan, 3: iNaT + 1}) + expected.index = expected.index.astype(int) + tm.assert_series_equal(result, expected, check_exact=True) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_groupby_aggregate_period_column(func): + # GH 31471 + groups = [1, 2] + periods = pd.period_range("2020", periods=2, freq="Y") + df = DataFrame({"a": groups, "b": periods}) + + result = getattr(df.groupby("a")["b"], func)() + idx = pd.Index([1, 2], name="a") + expected = Series(periods, index=idx, name="b") + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_groupby_aggregate_period_frame(func): + # GH 31471 + groups = [1, 2] + periods = pd.period_range("2020", periods=2, freq="Y") + df = DataFrame({"a": groups, "b": periods}) + + result = getattr(df.groupby("a"), func)() + idx = pd.Index([1, 2], name="a") + expected = DataFrame({"b": periods}, index=idx) + + tm.assert_frame_equal(result, expected) + + +def test_aggregate_numeric_object_dtype(): + # https://github.com/pandas-dev/pandas/issues/39329 + # simplified case: multiple object columns where one is all-NaN + # -> gets split as the all-NaN is inferred as float + df = DataFrame( + {"key": ["A", "A", "B", "B"], "col1": list("abcd"), "col2": [np.nan] * 4}, + ).astype(object) + result = df.groupby("key").min() + expected = ( + DataFrame( + {"key": ["A", "B"], "col1": ["a", "c"], "col2": [np.nan, np.nan]}, + ) + .set_index("key") + .astype(object) + ) + tm.assert_frame_equal(result, expected) + + # same but with numbers + df = DataFrame( + {"key": ["A", "A", "B", "B"], "col1": list("abcd"), "col2": range(4)}, + ).astype(object) + result = df.groupby("key").min() + expected = ( + DataFrame({"key": ["A", "B"], "col1": ["a", "c"], "col2": [0, 2]}) + .set_index("key") + .astype(object) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_aggregate_categorical_lost_index(func: str): + # GH: 28641 groupby drops index, when grouping over categorical column with min/max + ds = Series(["b"], dtype="category").cat.as_ordered() + df = DataFrame({"A": [1997], "B": ds}) + result = df.groupby("A").agg({"B": func}) + expected = DataFrame({"B": ["b"]}, index=pd.Index([1997], name="A")) + + # ordered categorical dtype should be preserved + expected["B"] = expected["B"].astype(ds.dtype) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", ["Int64", "Int32", "Float64", "Float32", "boolean"]) +def test_groupby_min_max_nullable(dtype): + if dtype == "Int64": + # GH#41743 avoid precision loss + ts = 1618556707013635762 + elif dtype == "boolean": + ts = 0 + else: + ts = 4.0 + + df = DataFrame({"id": [2, 2], "ts": [ts, ts + 1]}) + df["ts"] = df["ts"].astype(dtype) + + gb = df.groupby("id") + + result = gb.min() + expected = df.iloc[:1].set_index("id") + tm.assert_frame_equal(result, expected) + + res_max = gb.max() + expected_max = df.iloc[1:].set_index("id") + tm.assert_frame_equal(res_max, expected_max) + + result2 = gb.min(min_count=3) + expected2 = DataFrame({"ts": [pd.NA]}, index=expected.index, dtype=dtype) + tm.assert_frame_equal(result2, expected2) + + res_max2 = gb.max(min_count=3) + tm.assert_frame_equal(res_max2, expected2) + + # Case with NA values + df2 = DataFrame({"id": [2, 2, 2], "ts": [ts, pd.NA, ts + 1]}) + df2["ts"] = df2["ts"].astype(dtype) + gb2 = df2.groupby("id") + + result3 = gb2.min() + tm.assert_frame_equal(result3, expected) + + res_max3 = gb2.max() + tm.assert_frame_equal(res_max3, expected_max) + + result4 = gb2.min(min_count=100) + tm.assert_frame_equal(result4, expected2) + + res_max4 = gb2.max(min_count=100) + tm.assert_frame_equal(res_max4, expected2) + + +def test_min_max_nullable_uint64_empty_group(): + # don't raise NotImplementedError from libgroupby + cat = pd.Categorical([0] * 10, categories=[0, 1]) + df = DataFrame({"A": cat, "B": pd.array(np.arange(10, dtype=np.uint64))}) + gb = df.groupby("A", observed=False) + + res = gb.min() + + idx = pd.CategoricalIndex([0, 1], dtype=cat.dtype, name="A") + expected = DataFrame({"B": pd.array([0, pd.NA], dtype="UInt64")}, index=idx) + tm.assert_frame_equal(res, expected) + + res = gb.max() + expected.iloc[0, 0] = 9 + tm.assert_frame_equal(res, expected) + + +@pytest.mark.parametrize("func", ["first", "last", "min", "max"]) +def test_groupby_min_max_categorical(func): + # GH: 52151 + df = DataFrame( + { + "col1": pd.Categorical(["A"], categories=list("AB"), ordered=True), + "col2": pd.Categorical([1], categories=[1, 2], ordered=True), + "value": 0.1, + } + ) + result = getattr(df.groupby("col1", observed=False), func)() + + idx = pd.CategoricalIndex(data=["A", "B"], name="col1", ordered=True) + expected = DataFrame( + { + "col2": pd.Categorical([1, None], categories=[1, 2], ordered=True), + "value": [0.1, None], + }, + index=idx, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_min_empty_string_dtype(func, string_dtype_no_object): + # GH#55619 + dtype = string_dtype_no_object + df = DataFrame({"a": ["a"], "b": "a", "c": "a"}, dtype=dtype).iloc[:0] + result = getattr(df.groupby("a"), func)() + expected = DataFrame( + columns=["b", "c"], dtype=dtype, index=pd.Index([], dtype=dtype, name="a") + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("min_count", [0, 1]) +@pytest.mark.parametrize("test_series", [True, False]) +def test_string_dtype_all_na( + string_dtype_no_object, reduction_func, skipna, min_count, test_series +): + # https://github.com/pandas-dev/pandas/issues/60985 + if reduction_func == "corrwith": + # corrwith is deprecated. + return + + dtype = string_dtype_no_object + + if reduction_func in [ + "any", + "all", + "idxmin", + "idxmax", + "mean", + "median", + "std", + "var", + ]: + kwargs = {"skipna": skipna} + elif reduction_func in ["kurt"]: + kwargs = {"min_count": min_count} + elif reduction_func in ["count", "nunique", "quantile", "sem", "size"]: + kwargs = {} + else: + kwargs = {"skipna": skipna, "min_count": min_count} + + expected_dtype, expected_value = dtype, pd.NA + if reduction_func in ["all", "any"]: + expected_dtype = "bool" + # TODO: For skipna=False, bool(pd.NA) raises; should groupby? + expected_value = not skipna if reduction_func == "any" else True + elif reduction_func in ["count", "nunique", "size"]: + # TODO: Should be more consistent - return Int64 when dtype.na_value is pd.NA? + if ( + test_series + and reduction_func == "size" + and dtype.storage == "pyarrow" + and dtype.na_value is pd.NA + ): + expected_dtype = "Int64" + else: + expected_dtype = "int64" + expected_value = 1 if reduction_func == "size" else 0 + elif not skipna or min_count > 0: + expected_value = pd.NA + elif reduction_func == "sum": + # https://github.com/pandas-dev/pandas/pull/60936 + expected_value = "" + + df = DataFrame({"a": ["x"], "b": [pd.NA]}, dtype=dtype) + obj = df["b"] if test_series else df + args = get_groupby_method_args(reduction_func, obj) + gb = obj.groupby(df["a"]) + method = getattr(gb, reduction_func) + + if reduction_func in [ + "mean", + "median", + "kurt", + "prod", + "quantile", + "sem", + "skew", + "std", + "var", + ]: + msg = f"dtype '{dtype}' does not support operation '{reduction_func}'" + with pytest.raises(TypeError, match=msg): + method(*args, **kwargs) + return + elif reduction_func in ["idxmin", "idxmax"]: + if skipna: + msg = f"{reduction_func} with skipna=True encountered all NA values" + else: + msg = f"{reduction_func} with skipna=False encountered an NA value." + with pytest.raises(ValueError, match=msg): + method(*args, **kwargs) + return + + result = method(*args, **kwargs) + index = pd.Index(["x"], name="a", dtype=dtype) + if test_series or reduction_func == "size": + name = None if not test_series and reduction_func == "size" else "b" + expected = Series(expected_value, index=index, dtype=expected_dtype, name=name) + else: + expected = DataFrame({"b": expected_value}, index=index, dtype=expected_dtype) + tm.assert_equal(result, expected) + + +def test_max_nan_bug(): + df = DataFrame( + { + "Unnamed: 0": ["-04-23", "-05-06", "-05-07"], + "Date": [ + "2013-04-23 00:00:00", + "2013-05-06 00:00:00", + "2013-05-07 00:00:00", + ], + "app": Series([np.nan, np.nan, "OE"]), + "File": ["log080001.log", "log.log", "xlsx"], + } + ) + gb = df.groupby("Date") + r = gb[["File"]].max() + e = gb["File"].max().to_frame() + tm.assert_frame_equal(r, e) + assert not r["File"].isna().any() + + +@pytest.mark.slow +@pytest.mark.parametrize("with_nan", [True, False]) +@pytest.mark.parametrize("keys", [["joe"], ["joe", "jim"]]) +def test_series_groupby_nunique(sort, dropna, as_index, with_nan, keys): + n = 100 + m = 10 + days = date_range("2015-08-23", periods=10) + df = DataFrame( + { + "jim": np.random.default_rng(2).choice(list(ascii_lowercase), n), + "joe": np.random.default_rng(2).choice(days, n), + "julie": np.random.default_rng(2).integers(0, m, n), + } + ) + if with_nan: + df = df.astype({"julie": float}) # Explicit cast to avoid implicit cast below + df.loc[1::17, "jim"] = None + df.loc[3::37, "joe"] = None + df.loc[7::19, "julie"] = None + df.loc[8::19, "julie"] = None + df.loc[9::19, "julie"] = None + original_df = df.copy() + gr = df.groupby(keys, as_index=as_index, sort=sort) + left = gr["julie"].nunique(dropna=dropna) + + gr = df.groupby(keys, as_index=as_index, sort=sort) + right = gr["julie"].apply(Series.nunique, dropna=dropna) + if not as_index: + right = right.reset_index(drop=True) + + if as_index: + tm.assert_series_equal(left, right, check_names=False) + else: + tm.assert_frame_equal(left, right, check_names=False) + tm.assert_frame_equal(df, original_df) + + +def test_nunique(): + df = DataFrame({"A": list("abbacc"), "B": list("abxacc"), "C": list("abbacx")}) + + expected = DataFrame({"A": list("abc"), "B": [1, 2, 1], "C": [1, 1, 2]}) + result = df.groupby("A", as_index=False).nunique() + tm.assert_frame_equal(result, expected) + + # as_index + expected.index = list("abc") + expected.index.name = "A" + expected = expected.drop(columns="A") + result = df.groupby("A").nunique() + tm.assert_frame_equal(result, expected) + + # with na + result = df.replace({"x": None}).groupby("A").nunique(dropna=False) + tm.assert_frame_equal(result, expected) + + # dropna + expected = DataFrame({"B": [1] * 3, "C": [1] * 3}, index=list("abc")) + expected.index.name = "A" + result = df.replace({"x": None}).groupby("A").nunique() + tm.assert_frame_equal(result, expected) + + +def test_nunique_with_object(): + # GH 11077 + data = DataFrame( + [ + [100, 1, "Alice"], + [200, 2, "Bob"], + [300, 3, "Charlie"], + [-400, 4, "Dan"], + [500, 5, "Edith"], + ], + columns=["amount", "id", "name"], + ) + + result = data.groupby(["id", "amount"])["name"].nunique() + index = MultiIndex.from_arrays([data.id, data.amount]) + expected = Series([1] * 5, name="name", index=index) + tm.assert_series_equal(result, expected) + + +def test_nunique_with_empty_series(): + # GH 12553 + data = Series(name="name", dtype=object) + result = data.groupby(level=0).nunique() + expected = Series(name="name", dtype="int64") + tm.assert_series_equal(result, expected) + + +def test_nunique_with_timegrouper(): + # GH 13453 + test = DataFrame( + { + "time": [ + Timestamp("2016-06-28 09:35:35"), + Timestamp("2016-06-28 16:09:30"), + Timestamp("2016-06-28 16:46:28"), + ], + "data": ["1", "2", "3"], + } + ).set_index("time") + result = test.groupby(pd.Grouper(freq="h"))["data"].nunique() + expected = test.groupby(pd.Grouper(freq="h"))["data"].apply(Series.nunique) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "key, data, dropna, expected", + [ + ( + ["x", "x", "x"], + [Timestamp("2019-01-01"), pd.NaT, Timestamp("2019-01-01")], + True, + Series([1], index=pd.Index(["x"], name="key"), name="data"), + ), + ( + ["x", "x", "x"], + [dt.date(2019, 1, 1), pd.NaT, dt.date(2019, 1, 1)], + True, + Series([1], index=pd.Index(["x"], name="key"), name="data"), + ), + ( + ["x", "x", "x", "y", "y"], + [ + dt.date(2019, 1, 1), + pd.NaT, + dt.date(2019, 1, 1), + pd.NaT, + dt.date(2019, 1, 1), + ], + False, + Series([2, 2], index=pd.Index(["x", "y"], name="key"), name="data"), + ), + ( + ["x", "x", "x", "x", "y"], + [ + dt.date(2019, 1, 1), + pd.NaT, + dt.date(2019, 1, 1), + pd.NaT, + dt.date(2019, 1, 1), + ], + False, + Series([2, 1], index=pd.Index(["x", "y"], name="key"), name="data"), + ), + ], +) +def test_nunique_with_NaT(key, data, dropna, expected): + # GH 27951 + df = DataFrame({"key": key, "data": data}) + result = df.groupby(["key"])["data"].nunique(dropna=dropna) + tm.assert_series_equal(result, expected) + + +def test_nunique_preserves_column_level_names(): + # GH 23222 + test = DataFrame([1, 2, 2], columns=pd.Index(["A"], name="level_0")) + result = test.groupby([0, 0, 0]).nunique() + expected = DataFrame([2], index=np.array([0]), columns=test.columns) + tm.assert_frame_equal(result, expected) + + +def test_nunique_transform_with_datetime(): + # GH 35109 - transform with nunique on datetimes results in integers + df = DataFrame(date_range("2008-12-31", "2009-01-02"), columns=["date"]) + result = df.groupby([0, 0, 1])["date"].transform("nunique") + expected = Series([2, 2, 1], name="date") + tm.assert_series_equal(result, expected) + + +def test_empty_categorical(observed): + # GH#21334 + cat = Series([1]).astype("category") + ser = cat[:0] + gb = ser.groupby(ser, observed=observed) + result = gb.nunique() + if observed: + expected = Series([], index=cat[:0], dtype="int64") + else: + expected = Series([0], index=cat, dtype="int64") + tm.assert_series_equal(result, expected) + + +def test_intercept_builtin_sum(): + s = Series([1.0, 2.0, np.nan, 3.0]) + grouped = s.groupby([0, 1, 2, 2]) + + # GH#53425 + result = grouped.agg(builtins.sum) + # GH#53425 + result2 = grouped.apply(builtins.sum) + expected = Series([1.0, 2.0, np.nan], index=np.array([0, 1, 2])) + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + +@pytest.mark.parametrize("min_count", [0, 10]) +def test_groupby_sum_mincount_boolean(min_count): + b = True + a = False + na = np.nan + dfg = pd.array([b, b, na, na, a, a, b], dtype="boolean") + + df = DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": dfg}) + result = df.groupby("A").sum(min_count=min_count) + if min_count == 0: + expected = DataFrame( + {"B": pd.array([3, 0, 0], dtype="Int64")}, + index=pd.Index([1, 2, 3], name="A"), + ) + tm.assert_frame_equal(result, expected) + else: + expected = DataFrame( + {"B": pd.array([pd.NA] * 3, dtype="Int64")}, + index=pd.Index([1, 2, 3], name="A"), + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_sum_below_mincount_nullable_integer(): + # https://github.com/pandas-dev/pandas/issues/32861 + df = DataFrame({"a": [0, 1, 2], "b": [0, 1, 2], "c": [0, 1, 2]}, dtype="Int64") + grouped = df.groupby("a") + idx = pd.Index([0, 1, 2], name="a", dtype="Int64") + + result = grouped["b"].sum(min_count=2) + expected = Series([pd.NA] * 3, dtype="Int64", index=idx, name="b") + tm.assert_series_equal(result, expected) + + result = grouped.sum(min_count=2) + expected = DataFrame({"b": [pd.NA] * 3, "c": [pd.NA] * 3}, dtype="Int64", index=idx) + tm.assert_frame_equal(result, expected) + + +def test_groupby_sum_timedelta_with_nat(): + # GH#42659 + df = DataFrame( + { + "a": [1, 1, 2, 2], + "b": [pd.Timedelta("1D"), pd.Timedelta("2D"), pd.Timedelta("3D"), pd.NaT], + } + ) + td3 = pd.Timedelta(days=3).as_unit("us") + + gb = df.groupby("a") + + res = gb.sum() + expected = DataFrame({"b": [td3, td3]}, index=pd.Index([1, 2], name="a")) + tm.assert_frame_equal(res, expected) + + res = gb["b"].sum() + tm.assert_series_equal(res, expected["b"]) + + res = gb["b"].sum(min_count=2) + expected = Series([td3, pd.NaT], dtype="m8[us]", name="b", index=expected.index) + tm.assert_series_equal(res, expected) + + +@pytest.mark.parametrize( + "dtype", ["int8", "int16", "int32", "int64", "float32", "float64", "uint64"] +) +@pytest.mark.parametrize( + "method,data", + [ + ("first", {"df": [{"a": 1, "b": 1}, {"a": 2, "b": 3}]}), + ("last", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}), + ("min", {"df": [{"a": 1, "b": 1}, {"a": 2, "b": 3}]}), + ("max", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 4}]}), + ("count", {"df": [{"a": 1, "b": 2}, {"a": 2, "b": 2}], "out_type": "int64"}), + ], +) +def test_groupby_non_arithmetic_agg_types(dtype, method, data): + # GH9311, GH6620 + df = DataFrame( + [{"a": 1, "b": 1}, {"a": 1, "b": 2}, {"a": 2, "b": 3}, {"a": 2, "b": 4}] + ) + + df["b"] = df.b.astype(dtype) + + if "args" not in data: + data["args"] = [] + + if "out_type" in data: + out_type = data["out_type"] + else: + out_type = dtype + + exp = data["df"] + df_out = DataFrame(exp) + + df_out["b"] = df_out.b.astype(out_type) + df_out.set_index("a", inplace=True) + + grpd = df.groupby("a") + t = getattr(grpd, method)(*data["args"]) + tm.assert_frame_equal(t, df_out) + + +def scipy_sem(*args, **kwargs): + from scipy.stats import sem + + return sem(*args, ddof=1, **kwargs) + + +@pytest.mark.parametrize( + "op,targop", + [ + ("mean", np.mean), + ("median", np.median), + ("std", np.std), + ("var", np.var), + ("sum", np.sum), + ("prod", np.prod), + ("min", np.min), + ("max", np.max), + ("first", lambda x: x.iloc[0]), + ("last", lambda x: x.iloc[-1]), + ("count", np.size), + pytest.param("sem", scipy_sem, marks=td.skip_if_no("scipy")), + ], +) +def test_ops_general(op, targop): + df = DataFrame(np.random.default_rng(2).standard_normal(1000)) + labels = np.random.default_rng(2).integers(0, 50, size=1000).astype(float) + + result = getattr(df.groupby(labels), op)() + kwargs = {"ddof": 1, "axis": 0} if op in ["std", "var"] else {} + expected = df.groupby(labels).agg(targop, **kwargs) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "values", + [ + { + "a": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "b": [1, pd.NA, 2, 1, pd.NA, 2, 1, pd.NA, 2], + }, + {"a": [1, 1, 2, 2, 3, 3], "b": [1, 2, 1, 2, 1, 2]}, + ], +) +@pytest.mark.parametrize("function", ["mean", "median", "var"]) +def test_apply_to_nullable_integer_returns_float(values, function): + # https://github.com/pandas-dev/pandas/issues/32219 + output = 0.5 if function == "var" else 1.5 + arr = np.array([output] * 3, dtype=float) + idx = pd.Index([1, 2, 3], name="a", dtype="Int64") + expected = DataFrame({"b": arr}, index=idx).astype("Float64") + + groups = DataFrame(values, dtype="Int64").groupby("a") + + result = getattr(groups, function)() + tm.assert_frame_equal(result, expected) + + result = groups.agg(function) + tm.assert_frame_equal(result, expected) + + result = groups.agg([function]) + expected.columns = MultiIndex.from_tuples([("b", function)]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "op", + [ + "sum", + "prod", + "min", + "max", + "median", + "mean", + "skew", + "kurt", + "std", + "var", + "sem", + ], +) +def test_regression_allowlist_methods(op, skipna, sort): + # GH6944 + # GH 17537 + # explicitly test the allowlist methods + frame = DataFrame([0]) + + grouped = frame.groupby(level=0, sort=sort) + + if op in ["skew", "kurt", "sum", "mean"]: + # skew, kurt, sum, mean have skipna + result = getattr(grouped, op)(skipna=skipna) + expected = frame.groupby(level=0).apply(lambda h: getattr(h, op)(skipna=skipna)) + if sort: + expected = expected.sort_index() + tm.assert_frame_equal(result, expected) + else: + result = getattr(grouped, op)() + expected = frame.groupby(level=0).apply(lambda h: getattr(h, op)()) + if sort: + expected = expected.sort_index() + tm.assert_frame_equal(result, expected) + + +def test_groupby_prod_with_int64_dtype(): + # GH#46573 + data = [ + [1, 11], + [1, 41], + [1, 17], + [1, 37], + [1, 7], + [1, 29], + [1, 31], + [1, 2], + [1, 3], + [1, 43], + [1, 5], + [1, 47], + [1, 19], + [1, 88], + ] + df = DataFrame(data, columns=["A", "B"], dtype="int64") + result = df.groupby(["A"]).prod().reset_index() + expected = DataFrame({"A": [1], "B": [180970905912331920]}, dtype="int64") + tm.assert_frame_equal(result, expected) + + +def test_groupby_std_datetimelike(): + # GH#48481 + tdi = pd.timedelta_range("1 Day", periods=10000, unit="ns") + ser = Series(tdi) + ser[::5] *= 2 # get different std for different groups + + df = ser.to_frame("A").copy() + + df["B"] = ser + Timestamp(0) + df["C"] = ser + Timestamp(0, tz="UTC") + df.iloc[-1] = pd.NaT # last group includes NaTs + + gb = df.groupby(list(range(5)) * 2000) + + result = gb.std() + + # Note: this does not _exactly_ match what we would get if we did + # [gb.get_group(i).std() for i in gb.groups] + # but it _does_ match the floating point error we get doing the + # same operation on int64 data xref GH#51332 + td1 = pd.Timedelta("2887 days 11:21:02.326710176") + td4 = pd.Timedelta("2886 days 00:42:34.664668096") + exp_ser = Series([td1 * 2, td1, td1, td1, td4], index=np.arange(5)) + expected = DataFrame({"A": exp_ser, "B": exp_ser, "C": exp_ser}) + tm.assert_frame_equal(result, expected) + + +def test_mean_numeric_only_validates_bool(): + # GH#62778 + + df = DataFrame({"A": range(5), "B": range(5)}) + + msg = "numeric_only accepts only Boolean values" + with pytest.raises(ValueError, match=msg): + df.groupby(["A"]).mean(["B"]) + + with pytest.raises(ValueError, match=msg): + df.groupby(["A"]).mean(numeric_only="True") + + with pytest.raises(ValueError, match=msg): + df.groupby(["A"]).mean(numeric_only=1) diff --git a/pandas/tests/groupby/test_timegrouper.py b/pandas/tests/groupby/test_timegrouper.py new file mode 100644 index 0000000000000000000000000000000000000000..b60947e61fb23acf3cd0c3314f0dbfe124c22188 --- /dev/null +++ b/pandas/tests/groupby/test_timegrouper.py @@ -0,0 +1,984 @@ +""" +test with the TimeGrouper / grouping with datetimes +""" + +from datetime import ( + datetime, + timedelta, + timezone, +) + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + MultiIndex, + Series, + Timestamp, + date_range, + offsets, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouper +from pandas.core.groupby.ops import BinGrouper + + +@pytest.fixture +def frame_for_truncated_bingrouper(): + """ + DataFrame used by groupby_with_truncated_bingrouper, made into + a separate fixture for easier reuse in + test_groupby_apply_timegrouper_with_nat_apply_squeeze + """ + df = DataFrame( + { + "Quantity": [18, 3, 5, 1, 9, 3], + "Date": [ + Timestamp(2013, 9, 1, 13, 0), + Timestamp(2013, 9, 1, 13, 5), + Timestamp(2013, 10, 1, 20, 0), + Timestamp(2013, 10, 3, 10, 0), + pd.NaT, + Timestamp(2013, 9, 2, 14, 0), + ], + } + ) + return df + + +@pytest.fixture +def groupby_with_truncated_bingrouper(frame_for_truncated_bingrouper): + """ + GroupBy object such that gb._grouper is a BinGrouper and + len(gb._grouper.result_index) < len(gb._grouper.group_keys_seq) + + Aggregations on this groupby should have + + dti = date_range("2013-09-01", "2013-10-01", freq="5D", name="Date") + + As either the index or an index level. + """ + df = frame_for_truncated_bingrouper + + tdg = Grouper(key="Date", freq="5D") + gb = df.groupby(tdg) + + # check we're testing the case we're interested in + assert len(gb._grouper.result_index) != len(gb._grouper.codes) + + return gb + + +class TestGroupBy: + def test_groupby_with_timegrouper(self, using_infer_string): + # GH 4161 + # TimeGrouper requires a sorted index + # also verifies that the resultant index has the correct name + df_original = DataFrame( + { + "Buyer": "Carl Carl Carl Carl Joe Carl".split(), + "Quantity": [18, 3, 5, 1, 9, 3], + "Date": [ + datetime(2013, 9, 1, 13, 0), + datetime(2013, 9, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 3, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 9, 2, 14, 0), + ], + } + ) + + # GH 6908 change target column's order + df_reordered = df_original.sort_values(by="Quantity") + + for df in [df_original, df_reordered]: + df = df.set_index(["Date"]) + + exp_dti = date_range( + "20130901", + "20131205", + freq="5D", + name="Date", + inclusive="left", + unit=df.index.unit, + ) + expected = DataFrame( + {"Buyer": "" if using_infer_string else 0, "Quantity": 0}, + index=exp_dti, + ) + # Cast to object/str to avoid implicit cast when setting + # entry to "CarlCarlCarl" + expected = expected.astype({"Buyer": object}) + if using_infer_string: + expected = expected.astype({"Buyer": "str"}) + expected.iloc[0, 0] = "CarlCarlCarl" + expected.iloc[6, 0] = "CarlCarl" + expected.iloc[18, 0] = "Joe" + expected.iloc[[0, 6, 18], 1] = np.array([24, 6, 9], dtype="int64") + + result1 = df.resample("5D").sum() + tm.assert_frame_equal(result1, expected) + + df_sorted = df.sort_index() + result2 = df_sorted.groupby(Grouper(freq="5D")).sum() + tm.assert_frame_equal(result2, expected) + + result3 = df.groupby(Grouper(freq="5D")).sum() + tm.assert_frame_equal(result3, expected) + + @pytest.mark.parametrize("should_sort", [True, False]) + def test_groupby_with_timegrouper_methods(self, should_sort): + # GH 3881 + # make sure API of timegrouper conforms + + df = DataFrame( + { + "Branch": "A A A A A B".split(), + "Buyer": "Carl Mark Carl Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 8, 9, 3], + "Date": [ + datetime(2013, 1, 1, 13, 0), + datetime(2013, 1, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 12, 2, 14, 0), + ], + } + ) + + if should_sort: + df = df.sort_values(by="Quantity", ascending=False) + + df = df.set_index("Date", drop=False) + g = df.groupby(Grouper(freq="6ME")) + assert g.group_keys + + assert isinstance(g._grouper, BinGrouper) + groups = g.groups + assert isinstance(groups, dict) + assert len(groups) == 3 + + def test_timegrouper_with_reg_groups(self): + # GH 3794 + # allow combination of timegrouper/reg groups + + df_original = DataFrame( + { + "Branch": "A A A A A A A B".split(), + "Buyer": "Carl Mark Carl Carl Joe Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 1, 8, 1, 9, 3], + "Date": [ + datetime(2013, 1, 1, 13, 0), + datetime(2013, 1, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 12, 2, 14, 0), + ], + } + ).set_index("Date") + + df_sorted = df_original.sort_values(by="Quantity", ascending=False) + + for df in [df_original, df_sorted]: + expected = DataFrame( + { + "Buyer": "Carl Joe Mark".split(), + "Quantity": [10, 18, 3], + "Date": [ + datetime(2013, 12, 31, 0, 0), + datetime(2013, 12, 31, 0, 0), + datetime(2013, 12, 31, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + + msg = "The default value of numeric_only" + result = df.groupby([Grouper(freq="YE"), "Buyer"]).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + expected = DataFrame( + { + "Buyer": "Carl Mark Carl Joe".split(), + "Quantity": [1, 3, 9, 18], + "Date": [ + datetime(2013, 1, 1, 0, 0), + datetime(2013, 1, 1, 0, 0), + datetime(2013, 7, 1, 0, 0), + datetime(2013, 7, 1, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + result = df.groupby([Grouper(freq="6MS"), "Buyer"]).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + df_original = DataFrame( + { + "Branch": "A A A A A A A B".split(), + "Buyer": "Carl Mark Carl Carl Joe Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 1, 8, 1, 9, 3], + "Date": [ + datetime(2013, 10, 1, 13, 0), + datetime(2013, 10, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 10, 2, 12, 0), + datetime(2013, 10, 2, 14, 0), + ], + } + ).set_index("Date") + + df_sorted = df_original.sort_values(by="Quantity", ascending=False) + for df in [df_original, df_sorted]: + expected = DataFrame( + { + "Buyer": "Carl Joe Mark Carl Joe".split(), + "Quantity": [6, 8, 3, 4, 10], + "Date": [ + datetime(2013, 10, 1, 0, 0), + datetime(2013, 10, 1, 0, 0), + datetime(2013, 10, 1, 0, 0), + datetime(2013, 10, 2, 0, 0), + datetime(2013, 10, 2, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + + result = df.groupby([Grouper(freq="1D"), "Buyer"]).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + result = df.groupby([Grouper(freq="1ME"), "Buyer"]).sum(numeric_only=True) + expected = DataFrame( + { + "Buyer": "Carl Joe Mark".split(), + "Quantity": [10, 18, 3], + "Date": [ + datetime(2013, 10, 31, 0, 0), + datetime(2013, 10, 31, 0, 0), + datetime(2013, 10, 31, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + tm.assert_frame_equal(result, expected) + + # passing the name + df = df.reset_index() + result = df.groupby([Grouper(freq="1ME", key="Date"), "Buyer"]).sum( + numeric_only=True + ) + tm.assert_frame_equal(result, expected) + + with pytest.raises(KeyError, match="'The grouper name foo is not found'"): + df.groupby([Grouper(freq="1ME", key="foo"), "Buyer"]).sum() + + # passing the level + df = df.set_index("Date") + result = df.groupby([Grouper(freq="1ME", level="Date"), "Buyer"]).sum( + numeric_only=True + ) + tm.assert_frame_equal(result, expected) + result = df.groupby([Grouper(freq="1ME", level=0), "Buyer"]).sum( + numeric_only=True + ) + tm.assert_frame_equal(result, expected) + + with pytest.raises(ValueError, match="The level foo is not valid"): + df.groupby([Grouper(freq="1ME", level="foo"), "Buyer"]).sum() + + # multi names + df = df.copy() + df["Date"] = df.index + offsets.MonthEnd(2) + result = df.groupby([Grouper(freq="1ME", key="Date"), "Buyer"]).sum( + numeric_only=True + ) + expected = DataFrame( + { + "Buyer": "Carl Joe Mark".split(), + "Quantity": [10, 18, 3], + "Date": [ + datetime(2013, 11, 30, 0, 0), + datetime(2013, 11, 30, 0, 0), + datetime(2013, 11, 30, 0, 0), + ], + } + ).set_index(["Date", "Buyer"]) + tm.assert_frame_equal(result, expected) + + # error as we have both a level and a name! + msg = "The Grouper cannot specify both a key and a level!" + with pytest.raises(ValueError, match=msg): + df.groupby( + [Grouper(freq="1ME", key="Date", level="Date"), "Buyer"] + ).sum() + + # single groupers + expected = DataFrame( + [[31]], + columns=["Quantity"], + index=DatetimeIndex( + [datetime(2013, 10, 31, 0, 0)], freq=offsets.MonthEnd(), name="Date" + ), + ) + result = df.groupby(Grouper(freq="1ME")).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + result = df.groupby([Grouper(freq="1ME")]).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + expected.index = expected.index.shift(1) + assert expected.index.freq == offsets.MonthEnd() + result = df.groupby(Grouper(freq="1ME", key="Date")).sum(numeric_only=True) + tm.assert_frame_equal(result, expected) + + result = df.groupby([Grouper(freq="1ME", key="Date")]).sum( + numeric_only=True + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("freq", ["D", "ME", "YE", "QE-APR"]) + def test_timegrouper_with_reg_groups_freq(self, freq): + # GH 6764 multiple grouping with/without sort + df = DataFrame( + { + "date": pd.to_datetime( + [ + "20121002", + "20121007", + "20130130", + "20130202", + "20130305", + "20121002", + "20121207", + "20130130", + "20130202", + "20130305", + "20130202", + "20130305", + ] + ), + "user_id": [1, 1, 1, 1, 1, 3, 3, 3, 5, 5, 5, 5], + "whole_cost": [ + 1790, + 364, + 280, + 259, + 201, + 623, + 90, + 312, + 359, + 301, + 359, + 801, + ], + "cost1": [12, 15, 10, 24, 39, 1, 0, 90, 45, 34, 1, 12], + } + ).set_index("date") + + expected = ( + df.groupby("user_id")["whole_cost"] + .resample(freq) + .sum(min_count=1) # XXX + .dropna() + .reorder_levels(["date", "user_id"]) + .sort_index() + .astype("int64") + ) + expected.name = "whole_cost" + + result1 = ( + df.sort_index().groupby([Grouper(freq=freq), "user_id"])["whole_cost"].sum() + ) + tm.assert_series_equal(result1, expected) + + result2 = df.groupby([Grouper(freq=freq), "user_id"])["whole_cost"].sum() + tm.assert_series_equal(result2, expected) + + def test_timegrouper_get_group(self): + # GH 6914 + + df_original = DataFrame( + { + "Buyer": "Carl Joe Joe Carl Joe Carl".split(), + "Quantity": [18, 3, 5, 1, 9, 3], + "Date": [ + datetime(2013, 9, 1, 13, 0), + datetime(2013, 9, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 3, 10, 0), + datetime(2013, 12, 2, 12, 0), + datetime(2013, 9, 2, 14, 0), + ], + } + ) + df_reordered = df_original.sort_values(by="Quantity") + + # single grouping + expected_list = [ + df_original.iloc[[0, 1, 5]], + df_original.iloc[[2, 3]], + df_original.iloc[[4]], + ] + dt_list = ["2013-09-30", "2013-10-31", "2013-12-31"] + + for df in [df_original, df_reordered]: + grouped = df.groupby(Grouper(freq="ME", key="Date")) + for t, expected in zip(dt_list, expected_list, strict=True): + dt = Timestamp(t) + result = grouped.get_group(dt) + tm.assert_frame_equal(result, expected) + + # multiple grouping + expected_list = [ + df_original.iloc[[1]], + df_original.iloc[[3]], + df_original.iloc[[4]], + ] + g_list = [("Joe", "2013-09-30"), ("Carl", "2013-10-31"), ("Joe", "2013-12-31")] + + for df in [df_original, df_reordered]: + grouped = df.groupby(["Buyer", Grouper(freq="ME", key="Date")]) + for (b, t), expected in zip(g_list, expected_list, strict=True): + dt = Timestamp(t) + result = grouped.get_group((b, dt)) + tm.assert_frame_equal(result, expected) + + # with index + df_original = df_original.set_index("Date") + df_reordered = df_original.sort_values(by="Quantity") + + expected_list = [ + df_original.iloc[[0, 1, 5]], + df_original.iloc[[2, 3]], + df_original.iloc[[4]], + ] + + for df in [df_original, df_reordered]: + grouped = df.groupby(Grouper(freq="ME")) + for t, expected in zip(dt_list, expected_list, strict=True): + dt = Timestamp(t) + result = grouped.get_group(dt) + tm.assert_frame_equal(result, expected) + + def test_timegrouper_apply_return_type_series(self): + # Using `apply` with the `TimeGrouper` should give the + # same return type as an `apply` with a `Grouper`. + # Issue #11742 + df = DataFrame({"date": ["10/10/2000", "11/10/2000"], "value": [10, 13]}) + df_dt = df.copy() + df_dt["date"] = pd.to_datetime(df_dt["date"]) + + def sumfunc_series(x): + return Series([x["value"].sum()], ("sum",)) + + expected = df.groupby(Grouper(key="date")).apply(sumfunc_series) + result = df_dt.groupby(Grouper(freq="ME", key="date")).apply(sumfunc_series) + tm.assert_frame_equal( + result.reset_index(drop=True), expected.reset_index(drop=True) + ) + + def test_timegrouper_apply_return_type_value(self): + # Using `apply` with the `TimeGrouper` should give the + # same return type as an `apply` with a `Grouper`. + # Issue #11742 + df = DataFrame({"date": ["10/10/2000", "11/10/2000"], "value": [10, 13]}) + df_dt = df.copy() + df_dt["date"] = pd.to_datetime(df_dt["date"]) + + def sumfunc_value(x): + return x.value.sum() + + expected = df.groupby(Grouper(key="date")).apply(sumfunc_value) + result = df_dt.groupby(Grouper(freq="ME", key="date")).apply(sumfunc_value) + tm.assert_series_equal( + result.reset_index(drop=True), expected.reset_index(drop=True) + ) + + def test_groupby_groups_datetimeindex(self): + # GH#1430 + periods = 1000 + ind = date_range(start="2012/1/1", freq="5min", periods=periods) + df = DataFrame( + {"high": np.arange(periods), "low": np.arange(periods)}, index=ind + ) + grouped = df.groupby(lambda x: datetime(x.year, x.month, x.day)) + + # it works! + groups = grouped.groups + assert isinstance(next(iter(groups.keys())), datetime) + + def test_groupby_groups_datetimeindex2(self): + # GH#11442 + index = date_range("2015/01/01", periods=5, name="date") + df = DataFrame({"A": [5, 6, 7, 8, 9], "B": [1, 2, 3, 4, 5]}, index=index) + result = df.groupby(level="date").groups + dates = ["2015-01-05", "2015-01-04", "2015-01-03", "2015-01-02", "2015-01-01"] + expected = { + Timestamp(date): DatetimeIndex([date], name="date") for date in dates + } + tm.assert_dict_equal(result, expected) + + grouped = df.groupby(level="date") + for date in dates: + result = grouped.get_group(date) + data = [[df.loc[date, "A"], df.loc[date, "B"]]] + expected_index = DatetimeIndex( + [date], name="date", freq="D", dtype=index.dtype + ) + expected = DataFrame(data, columns=list("AB"), index=expected_index) + tm.assert_frame_equal(result, expected) + + def test_groupby_groups_datetimeindex_tz(self): + # GH 3950 + dates = [ + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + ] + df = DataFrame( + { + "label": ["a", "a", "a", "b", "b", "b"], + "datetime": dates, + "value1": np.arange(6, dtype="int64"), + "value2": [1, 2] * 3, + } + ) + df["datetime"] = df["datetime"].apply(lambda d: Timestamp(d, tz="US/Pacific")) + + exp_idx1 = DatetimeIndex( + [ + "2011-07-19 07:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 09:00:00", + ], + tz="US/Pacific", + name="datetime", + ) + exp_idx2 = Index(["a", "b"] * 3, name="label") + exp_idx = MultiIndex.from_arrays([exp_idx1, exp_idx2]) + expected = DataFrame( + {"value1": [0, 3, 1, 4, 2, 5], "value2": [1, 2, 2, 1, 1, 2]}, + index=exp_idx, + columns=["value1", "value2"], + ) + + result = df.groupby(["datetime", "label"]).sum() + tm.assert_frame_equal(result, expected) + + # by level + didx = DatetimeIndex(dates, tz="Asia/Tokyo") + df = DataFrame( + {"value1": np.arange(6, dtype="int64"), "value2": [1, 2, 3, 1, 2, 3]}, + index=didx, + ) + + exp_idx = DatetimeIndex( + ["2011-07-19 07:00:00", "2011-07-19 08:00:00", "2011-07-19 09:00:00"], + tz="Asia/Tokyo", + ) + expected = DataFrame( + {"value1": [3, 5, 7], "value2": [2, 4, 6]}, + index=exp_idx, + columns=["value1", "value2"], + ) + + result = df.groupby(level=0).sum() + tm.assert_frame_equal(result, expected) + + def test_frame_datetime64_handling_groupby(self): + # it works! + df = DataFrame( + [(3, np.datetime64("2012-07-03")), (3, np.datetime64("2012-07-04"))], + columns=["a", "date"], + ) + result = df.groupby("a").first() + assert result["date"][3] == Timestamp("2012-07-03") + + def test_groupby_multi_timezone(self): + # combining multiple / different timezones yields UTC + df = DataFrame( + { + "value": range(5), + "date": [ + "2000-01-28 16:47:00", + "2000-01-29 16:48:00", + "2000-01-30 16:49:00", + "2000-01-31 16:50:00", + "2000-01-01 16:50:00", + ], + "tz": [ + "America/Chicago", + "America/Chicago", + "America/Los_Angeles", + "America/Chicago", + "America/New_York", + ], + } + ) + + result = df.groupby("tz", group_keys=False).date.apply( + lambda x: pd.to_datetime(x).dt.tz_localize(x.name) + ) + + expected = Series( + [ + Timestamp("2000-01-28 16:47:00-0600", tz="America/Chicago"), + Timestamp("2000-01-29 16:48:00-0600", tz="America/Chicago"), + Timestamp("2000-01-30 16:49:00-0800", tz="America/Los_Angeles"), + Timestamp("2000-01-31 16:50:00-0600", tz="America/Chicago"), + Timestamp("2000-01-01 16:50:00-0500", tz="America/New_York"), + ], + name="date", + dtype=object, + ) + tm.assert_series_equal(result, expected) + + tz = "America/Chicago" + res_values = df.groupby("tz").date.get_group(tz) + result = pd.to_datetime(res_values).dt.tz_localize(tz) + exp_values = Series( + ["2000-01-28 16:47:00", "2000-01-29 16:48:00", "2000-01-31 16:50:00"], + index=[0, 1, 3], + name="date", + ) + expected = pd.to_datetime(exp_values).dt.tz_localize(tz) + tm.assert_series_equal(result, expected) + + def test_groupby_groups_periods(self): + dates = [ + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + ] + df = DataFrame( + { + "label": ["a", "a", "a", "b", "b", "b"], + "period": [pd.Period(d, freq="h") for d in dates], + "value1": np.arange(6, dtype="int64"), + "value2": [1, 2] * 3, + } + ) + + exp_idx1 = pd.PeriodIndex( + [ + "2011-07-19 07:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 09:00:00", + ], + freq="h", + name="period", + ) + exp_idx2 = Index(["a", "b"] * 3, name="label") + exp_idx = MultiIndex.from_arrays([exp_idx1, exp_idx2]) + expected = DataFrame( + {"value1": [0, 3, 1, 4, 2, 5], "value2": [1, 2, 2, 1, 1, 2]}, + index=exp_idx, + columns=["value1", "value2"], + ) + + result = df.groupby(["period", "label"]).sum() + tm.assert_frame_equal(result, expected) + + # by level + didx = pd.PeriodIndex(dates, freq="h") + df = DataFrame( + {"value1": np.arange(6, dtype="int64"), "value2": [1, 2, 3, 1, 2, 3]}, + index=didx, + ) + + exp_idx = pd.PeriodIndex( + ["2011-07-19 07:00:00", "2011-07-19 08:00:00", "2011-07-19 09:00:00"], + freq="h", + ) + expected = DataFrame( + {"value1": [3, 5, 7], "value2": [2, 4, 6]}, + index=exp_idx, + columns=["value1", "value2"], + ) + + result = df.groupby(level=0).sum() + tm.assert_frame_equal(result, expected) + + def test_groupby_first_datetime64(self): + df = DataFrame([(1, 1351036800000000000), (2, 1351036800000000000)]) + df[1] = df[1].astype("M8[ns]") + + assert issubclass(df[1].dtype.type, np.datetime64) + + result = df.groupby(level=0).first() + got_dt = result[1].dtype + assert issubclass(got_dt.type, np.datetime64) + + result = df[1].groupby(level=0).first() + got_dt = result.dtype + assert issubclass(got_dt.type, np.datetime64) + + def test_groupby_max_datetime64(self): + # GH 5869 + # datetimelike dtype conversion from int + df = DataFrame({"A": Timestamp("20130101").as_unit("s"), "B": np.arange(5)}) + # TODO: can we retain second reso in .apply here? + expected = df.groupby("A")["A"].apply(lambda x: x.max()).astype("M8[s]") + result = df.groupby("A")["A"].max() + tm.assert_series_equal(result, expected) + + def test_groupby_datetime64_32_bit(self): + # GH 6410 / numpy 4328 + # 32-bit under 1.9-dev indexing issue + + df = DataFrame({"A": range(2), "B": [Timestamp("2000-01-1")] * 2}) + result = df.groupby("A")["B"].transform("min") + expected = Series([Timestamp("2000-01-1")] * 2, name="B") + tm.assert_series_equal(result, expected) + + def test_groupby_with_timezone_selection(self): + # GH 11616 + # Test that column selection returns output in correct timezone. + + df = DataFrame( + { + "factor": np.random.default_rng(2).integers(0, 3, size=60), + "time": date_range("01/01/2000 00:00", periods=60, freq="s", tz="UTC"), + } + ) + df1 = df.groupby("factor").max()["time"] + df2 = df.groupby("factor")["time"].max() + tm.assert_series_equal(df1, df2) + + def test_timezone_info(self): + # see gh-11682: Timezone info lost when broadcasting + # scalar datetime to DataFrame + utc = timezone.utc + df = DataFrame({"a": [1], "b": [datetime.now(utc)]}) + assert df["b"][0].tzinfo == utc + df = DataFrame({"a": [1, 2, 3]}) + df["b"] = datetime.now(utc) + assert df["b"][0].tzinfo == utc + + def test_datetime_count(self): + df = DataFrame( + {"a": [1, 2, 3] * 2, "dates": date_range("now", periods=6, freq="min")} + ) + result = df.groupby("a").dates.count() + expected = Series([2, 2, 2], index=Index([1, 2, 3], name="a"), name="dates") + tm.assert_series_equal(result, expected) + + def test_first_last_max_min_on_time_data(self): + # GH 10295 + # Verify that NaT is not in the result of max, min, first and last on + # Dataframe with datetime or timedelta values. + df_test = DataFrame( + { + "dt": [ + np.nan, + "2015-07-24 10:10", + "2015-07-25 11:11", + "2015-07-23 12:12", + np.nan, + ], + "td": [ + np.nan, + timedelta(days=1), + timedelta(days=2), + timedelta(days=3), + np.nan, + ], + } + ) + df_test.dt = pd.to_datetime(df_test.dt) + df_test["group"] = "A" + df_ref = df_test[df_test.dt.notna()] + + grouped_test = df_test.groupby("group") + grouped_ref = df_ref.groupby("group") + + tm.assert_frame_equal(grouped_ref.max(), grouped_test.max()) + tm.assert_frame_equal(grouped_ref.min(), grouped_test.min()) + tm.assert_frame_equal(grouped_ref.first(), grouped_test.first()) + tm.assert_frame_equal(grouped_ref.last(), grouped_test.last()) + + def test_nunique_with_timegrouper_and_nat(self): + # GH 17575 + test = DataFrame( + { + "time": [ + Timestamp("2016-06-28 09:35:35"), + pd.NaT, + Timestamp("2016-06-28 16:46:28"), + ], + "data": ["1", "2", "3"], + } + ) + + grouper = Grouper(key="time", freq="h") + result = test.groupby(grouper)["data"].nunique() + expected = test[test.time.notnull()].groupby(grouper)["data"].nunique() + expected.index = expected.index._with_freq(None) + tm.assert_series_equal(result, expected) + + def test_scalar_call_versus_list_call(self): + # Issue: 17530 + data_frame = { + "location": ["shanghai", "beijing", "shanghai"], + "time": Series( + ["2017-08-09 13:32:23", "2017-08-11 23:23:15", "2017-08-11 22:23:15"], + dtype="datetime64[ns]", + ), + "value": [1, 2, 3], + } + data_frame = DataFrame(data_frame).set_index("time") + grouper = Grouper(freq="D") + + grouped = data_frame.groupby(grouper) + result = grouped.count() + grouped = data_frame.groupby([grouper]) + expected = grouped.count() + + tm.assert_frame_equal(result, expected) + + def test_grouper_period_index(self): + # GH 32108 + periods = 2 + index = pd.period_range( + start="2018-01", periods=periods, freq="M", name="Month" + ) + period_series = Series(range(periods), index=index) + result = period_series.groupby(period_series.index.month).sum() + + expected = Series( + range(periods), index=Index(range(1, periods + 1), name=index.name) + ) + tm.assert_series_equal(result, expected) + + def test_groupby_apply_timegrouper_with_nat_dict_returns( + self, groupby_with_truncated_bingrouper + ): + # GH#43500 case where gb._grouper.result_index and gb._grouper.group_keys_seq + # have different lengths that goes through the `isinstance(values[0], dict)` + # path + gb = groupby_with_truncated_bingrouper + + res = gb["Quantity"].apply(lambda x: {"foo": len(x)}) + + df = gb.obj + unit = df["Date"]._values.unit + dti = date_range("2013-09-01", "2013-10-01", freq="5D", name="Date", unit=unit) + mi = MultiIndex.from_arrays([dti, ["foo"] * len(dti)]) + expected = Series([3, 0, 0, 0, 0, 0, 2], index=mi, name="Quantity") + tm.assert_series_equal(res, expected) + + def test_groupby_apply_timegrouper_with_nat_scalar_returns( + self, groupby_with_truncated_bingrouper + ): + # GH#43500 Previously raised ValueError bc used index with incorrect + # length in wrap_applied_result + gb = groupby_with_truncated_bingrouper + + res = gb["Quantity"].apply(lambda x: x.iloc[0] if len(x) else np.nan) + + df = gb.obj + unit = df["Date"]._values.unit + dti = date_range("2013-09-01", "2013-10-01", freq="5D", name="Date", unit=unit) + expected = Series( + [18, np.nan, np.nan, np.nan, np.nan, np.nan, 5], + index=dti._with_freq(None), + name="Quantity", + ) + + tm.assert_series_equal(res, expected) + + def test_groupby_apply_timegrouper_with_nat_apply_squeeze( + self, frame_for_truncated_bingrouper + ): + df = frame_for_truncated_bingrouper + + # We need to create a GroupBy object with only one non-NaT group, + # so use a huge freq so that all non-NaT dates will be grouped together + tdg = Grouper(key="Date", freq="100YE") + gb = df.groupby(tdg) + + # check that we will go through the singular_series path + # in _wrap_applied_output_series + assert gb.ngroups == 1 + assert gb._selected_obj.index.nlevels == 1 + + # function that returns a Series + res = gb.apply(lambda x: x["Quantity"] * 2) + + dti = Index([Timestamp("2013-12-31")], dtype=df["Date"].dtype, name="Date") + expected = DataFrame( + [[36, 6, 6, 10, 2]], + index=dti, + columns=Index([0, 1, 5, 2, 3], name="Quantity"), + ) + tm.assert_frame_equal(res, expected) + + @pytest.mark.single_cpu + def test_groupby_agg_numba_timegrouper_with_nat( + self, groupby_with_truncated_bingrouper + ): + pytest.importorskip("numba") + + # See discussion in GH#43487 + gb = groupby_with_truncated_bingrouper + + result = gb["Quantity"].aggregate( + lambda values, index: np.nanmean(values), engine="numba" + ) + + expected = gb["Quantity"].aggregate("mean") + tm.assert_series_equal(result, expected) + + result_df = gb[["Quantity"]].aggregate( + lambda values, index: np.nanmean(values), engine="numba" + ) + expected_df = gb[["Quantity"]].aggregate("mean") + tm.assert_frame_equal(result_df, expected_df) + + @td.skip_if_no("pyarrow") + def test_pyarrow_index_retention(self): + # https://github.com/pandas-dev/pandas/issues/63518 + df = DataFrame( + { + "a": [1, 2, 3], + }, + index=Index( + [ + Timestamp("2013-01-01"), + Timestamp("2013-01-01"), + Timestamp("2013-01-02"), + ], + dtype="timestamp[ns, America/Denver][pyarrow]", + ), + ) + gb = df.groupby(Grouper(freq="D")) + result = gb._grouper.result_index + expected = Index( + [Timestamp("2013-01-01"), Timestamp("2013-01-02")], + dtype="timestamp[ns, America/Denver][pyarrow]", + ) + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/interchange/__init__.py b/pandas/tests/interchange/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/interchange/test_impl.py b/pandas/tests/interchange/test_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..4e6fa4e2c83e1e539a638044684c0f9f4060d2b5 --- /dev/null +++ b/pandas/tests/interchange/test_impl.py @@ -0,0 +1,699 @@ +from datetime import ( + datetime, + timezone, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs import iNaT +from pandas.compat import ( + is_ci_environment, + is_platform_windows, +) +from pandas.compat.pyarrow import pa_version_under22p0 + +import pandas as pd +import pandas._testing as tm +from pandas.core.interchange.column import PandasColumn +from pandas.core.interchange.dataframe_protocol import ( + ColumnNullType, + DtypeKind, +) +from pandas.core.interchange.from_dataframe import from_dataframe +from pandas.core.interchange.utils import ArrowCTypes + + +@pytest.mark.parametrize("data", [("ordered", True), ("unordered", False)]) +def test_categorical_dtype(data): + data_categorical = { + "ordered": pd.Categorical(list("testdata") * 30, ordered=True), + "unordered": pd.Categorical(list("testdata") * 30, ordered=False), + } + df = pd.DataFrame({"A": (data_categorical[data[0]])}) + + with tm.assert_produces_warning(match="Interchange"): + col = df.__dataframe__().get_column_by_name("A") + assert col.dtype[0] == DtypeKind.CATEGORICAL + assert col.null_count == 0 + assert col.describe_null == (ColumnNullType.USE_SENTINEL, -1) + assert col.num_chunks() == 1 + desc_cat = col.describe_categorical + assert desc_cat["is_ordered"] == data[1] + assert desc_cat["is_dictionary"] is True + assert isinstance(desc_cat["categories"], PandasColumn) + tm.assert_series_equal( + desc_cat["categories"]._col, pd.Series(["a", "d", "e", "s", "t"]) + ) + + with tm.assert_produces_warning(match="Interchange"): + tm.assert_frame_equal(df, from_dataframe(df.__dataframe__())) + + +def test_categorical_pyarrow(): + # GH 49889 + pa = pytest.importorskip("pyarrow", "11.0.0") + + arr = ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"] + table = pa.table({"weekday": pa.array(arr).dictionary_encode()}) + exchange_df = table.__dataframe__() + with tm.assert_produces_warning(match="Interchange"): + result = from_dataframe(exchange_df) + weekday = pd.Categorical( + arr, categories=["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] + ) + expected = pd.DataFrame({"weekday": weekday}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.filterwarnings( + "ignore:Constructing a Categorical with a dtype and values containing" +) +def test_empty_categorical_pyarrow(): + # https://github.com/pandas-dev/pandas/issues/53077 + pa = pytest.importorskip("pyarrow", "11.0.0") + + arr = [None] + table = pa.table({"arr": pa.array(arr, "float64").dictionary_encode()}) + exchange_df = table.__dataframe__() + with tm.assert_produces_warning(match="Interchange"): + result = pd.api.interchange.from_dataframe(exchange_df) + expected = pd.DataFrame({"arr": pd.Categorical([np.nan])}) + tm.assert_frame_equal(result, expected) + + +def test_large_string_pyarrow(): + # GH 52795 + pa = pytest.importorskip("pyarrow", "11.0.0") + + arr = ["Mon", "Tue"] + table = pa.table({"weekday": pa.array(arr, "large_string")}) + exchange_df = table.__dataframe__() + with tm.assert_produces_warning(match="Interchange"): + result = from_dataframe(exchange_df) + expected = pd.DataFrame({"weekday": ["Mon", "Tue"]}) + tm.assert_frame_equal(result, expected) + + # check round-trip + # Don't check stacklevel as PyArrow calls the deprecated `__dataframe__` method. + with tm.assert_produces_warning(match="Interchange", check_stacklevel=False): + assert pa.Table.equals(pa.interchange.from_dataframe(result), table) + + +@pytest.mark.parametrize( + ("offset", "length", "expected_values"), + [ + (0, None, [3.3, float("nan"), 2.1]), + (1, None, [float("nan"), 2.1]), + (2, None, [2.1]), + (0, 2, [3.3, float("nan")]), + (0, 1, [3.3]), + (1, 1, [float("nan")]), + ], +) +def test_bitmasks_pyarrow(offset, length, expected_values): + # GH 52795 + pa = pytest.importorskip("pyarrow", "11.0.0") + + arr = [3.3, None, 2.1] + table = pa.table({"arr": arr}).slice(offset, length) + exchange_df = table.__dataframe__() + with tm.assert_produces_warning(match="Interchange"): + result = from_dataframe(exchange_df) + expected = pd.DataFrame({"arr": expected_values}) + tm.assert_frame_equal(result, expected) + + # check round-trip + # Don't check stacklevel as PyArrow calls the deprecated `__dataframe__` method. + with tm.assert_produces_warning(match="Interchange", check_stacklevel=False): + assert pa.Table.equals(pa.interchange.from_dataframe(result), table) + + +@pytest.mark.parametrize( + "data", + [ + lambda: np.random.default_rng(2).integers(-100, 100), + lambda: np.random.default_rng(2).integers(1, 100), + lambda: np.random.default_rng(2).random(), + lambda: np.random.default_rng(2).choice([True, False]), + lambda: datetime( + year=np.random.default_rng(2).integers(1900, 2100), + month=np.random.default_rng(2).integers(1, 12), + day=np.random.default_rng(2).integers(1, 20), + ), + ], +) +def test_dataframe(data): + NCOLS, NROWS = 10, 20 + data = { + f"col{int((i - NCOLS / 2) % NCOLS + 1)}": [data() for _ in range(NROWS)] + for i in range(NCOLS) + } + df = pd.DataFrame(data) + + with tm.assert_produces_warning(match="Interchange"): + df2 = df.__dataframe__() + + assert df2.num_columns() == NCOLS + assert df2.num_rows() == NROWS + + assert list(df2.column_names()) == list(data.keys()) + + indices = (0, 2) + names = tuple(list(data.keys())[idx] for idx in indices) + + with tm.assert_produces_warning(match="Interchange"): + result = from_dataframe(df2.select_columns(indices)) + expected = from_dataframe(df2.select_columns_by_name(names)) + tm.assert_frame_equal(result, expected) + + assert isinstance(result.attrs["_INTERCHANGE_PROTOCOL_BUFFERS"], list) + assert isinstance(expected.attrs["_INTERCHANGE_PROTOCOL_BUFFERS"], list) + + +def test_missing_from_masked(): + df = pd.DataFrame( + { + "x": np.array([1.0, 2.0, 3.0, 4.0, 0.0]), + "y": np.array([1.5, 2.5, 3.5, 4.5, 0]), + "z": np.array([1.0, 0.0, 1.0, 1.0, 1.0]), + } + ) + + rng = np.random.default_rng(2) + dict_null = {col: rng.integers(low=0, high=len(df)) for col in df.columns} + for col, num_nulls in dict_null.items(): + null_idx = df.index[ + rng.choice(np.arange(len(df)), size=num_nulls, replace=False) + ] + df.loc[null_idx, col] = None + + with tm.assert_produces_warning(match="Interchange"): + df2 = df.__dataframe__() + + assert df2.get_column_by_name("x").null_count == dict_null["x"] + assert df2.get_column_by_name("y").null_count == dict_null["y"] + assert df2.get_column_by_name("z").null_count == dict_null["z"] + + +@pytest.mark.parametrize( + "data", + [ + {"x": [1.5, 2.5, 3.5], "y": [9.2, 10.5, 11.8]}, + {"x": [1, 2, 0], "y": [9.2, 10.5, 11.8]}, + { + "x": np.array([True, True, False]), + "y": np.array([1, 2, 0]), + "z": np.array([9.2, 10.5, 11.8]), + }, + ], +) +def test_mixed_data(data): + df = pd.DataFrame(data) + with tm.assert_produces_warning(match="Interchange"): + df2 = df.__dataframe__() + + for col_name in df.columns: + assert df2.get_column_by_name(col_name).null_count == 0 + + +def test_mixed_missing(): + df = pd.DataFrame( + { + "x": np.array([True, None, False, None, True]), + "y": np.array([None, 2, None, 1, 2]), + "z": np.array([9.2, 10.5, None, 11.8, None]), + } + ) + + with tm.assert_produces_warning(match="Interchange"): + df2 = df.__dataframe__() + + for col_name in df.columns: + assert df2.get_column_by_name(col_name).null_count == 2 + + +def test_string(): + string_data = { + "separator data": [ + "abC|DeF,Hik", + "234,3245.67", + "gSaf,qWer|Gre", + "asd3,4sad|", + np.nan, + ] + } + test_str_data = string_data["separator data"] + [""] + df = pd.DataFrame({"A": test_str_data}) + with tm.assert_produces_warning(match="Interchange"): + col = df.__dataframe__().get_column_by_name("A") + + assert col.size() == 6 + assert col.null_count == 1 + assert col.dtype[0] == DtypeKind.STRING + assert col.describe_null == (ColumnNullType.USE_BYTEMASK, 0) + + df_sliced = df[1:] + with tm.assert_produces_warning(match="Interchange"): + col = df_sliced.__dataframe__().get_column_by_name("A") + assert col.size() == 5 + assert col.null_count == 1 + assert col.dtype[0] == DtypeKind.STRING + assert col.describe_null == (ColumnNullType.USE_BYTEMASK, 0) + + +def test_nonstring_object(): + df = pd.DataFrame({"A": ["a", 10, 1.0, ()]}) + with tm.assert_produces_warning(match="Interchange"): + col = df.__dataframe__().get_column_by_name("A") + with pytest.raises(NotImplementedError, match="not supported yet"): + col.dtype + + +def test_datetime(): + df = pd.DataFrame({"A": [pd.Timestamp("2022-01-01"), pd.NaT]}) + with tm.assert_produces_warning(match="Interchange"): + col = df.__dataframe__().get_column_by_name("A") + + assert col.size() == 2 + assert col.null_count == 1 + assert col.dtype[0] == DtypeKind.DATETIME + assert col.describe_null == (ColumnNullType.USE_SENTINEL, iNaT) + + with tm.assert_produces_warning(match="Interchange"): + tm.assert_frame_equal(df, from_dataframe(df.__dataframe__())) + + +def test_categorical_to_numpy_dlpack(): + # https://github.com/pandas-dev/pandas/issues/48393 + df = pd.DataFrame({"A": pd.Categorical(["a", "b", "a"])}) + with tm.assert_produces_warning(match="Interchange"): + col = df.__dataframe__().get_column_by_name("A") + result = np.from_dlpack(col.get_buffers()["data"][0]) + expected = np.array([0, 1, 0], dtype="int8") + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("data", [{}, {"a": []}]) +def test_empty_pyarrow(data): + # GH 53155 + pytest.importorskip("pyarrow", "14.0.0") + from pyarrow.interchange import from_dataframe as pa_from_dataframe + + expected = pd.DataFrame(data) + # Don't check stacklevel as PyArrow calls the deprecated `__dataframe__` method. + with tm.assert_produces_warning(match="Interchange", check_stacklevel=False): + arrow_df = pa_from_dataframe(expected) + result = from_dataframe(arrow_df) + tm.assert_frame_equal(result, expected, check_column_type=False) + + +def test_multi_chunk_pyarrow() -> None: + pa = pytest.importorskip("pyarrow", "14.0.0") + n_legs = pa.chunked_array([[2, 2, 4], [4, 5, 100]]) + names = ["n_legs"] + table = pa.table([n_legs], names=names) + with pytest.raises( + RuntimeError, + match="Cannot do zero copy conversion into multi-column DataFrame block", + ): + pd.api.interchange.from_dataframe(table, allow_copy=False) + + +def test_multi_chunk_column() -> None: + pytest.importorskip("pyarrow", "11.0.0") + ser = pd.Series([1, 2, None], dtype="Int64[pyarrow]") + df = pd.concat([ser, ser], ignore_index=True).to_frame("a") + df_orig = df.copy() + + with tm.assert_produces_warning(match="Interchange"): + with pytest.raises( + RuntimeError, + match="Found multi-chunk pyarrow array, but `allow_copy` is False", + ): + pd.api.interchange.from_dataframe(df.__dataframe__(allow_copy=False)) + with tm.assert_produces_warning(match="Interchange"): + result = pd.api.interchange.from_dataframe(df.__dataframe__(allow_copy=True)) + # Interchange protocol defaults to creating numpy-backed columns, so currently this + # is 'float64'. + expected = pd.DataFrame({"a": [1.0, 2.0, None, 1.0, 2.0, None]}, dtype="float64") + tm.assert_frame_equal(result, expected) + + # Check that the rechunking we did didn't modify the original DataFrame. + tm.assert_frame_equal(df, df_orig) + assert len(df["a"].array._pa_array.chunks) == 2 + assert len(df_orig["a"].array._pa_array.chunks) == 2 + + +def test_timestamp_ns_pyarrow(): + # GH 56712 + pytest.importorskip("pyarrow", "11.0.0") + timestamp_args = { + "year": 2000, + "month": 1, + "day": 1, + "hour": 1, + "minute": 1, + "second": 1, + } + df = pd.Series( + [datetime(**timestamp_args)], + dtype="timestamp[ns][pyarrow]", + name="col0", + ).to_frame() + + with tm.assert_produces_warning(match="Interchange"): + dfi = df.__dataframe__() + result = pd.api.interchange.from_dataframe(dfi)["col0"].item() + + expected = pd.Timestamp(**timestamp_args) + assert result == expected + + +@pytest.mark.parametrize("tz", ["UTC", "US/Pacific"]) +def test_datetimetzdtype(tz, unit): + # GH 54239 + tz_data = ( + pd.date_range("2018-01-01", periods=5, freq="D").tz_localize(tz).as_unit(unit) + ) + df = pd.DataFrame({"ts_tz": tz_data}) + with tm.assert_produces_warning(match="Interchange"): + tm.assert_frame_equal(df, from_dataframe(df.__dataframe__())) + + +def test_interchange_from_non_pandas_tz_aware(request): + # GH 54239, 54287 + pa = pytest.importorskip("pyarrow", "11.0.0") + import pyarrow.compute as pc + + if is_platform_windows() and is_ci_environment() and pa_version_under22p0: + mark = pytest.mark.xfail( + raises=pa.ArrowInvalid, + reason=( + "TODO: Set ARROW_TIMEZONE_DATABASE environment variable " + "on CI to path to the tzdata for pyarrow." + ), + ) + request.applymarker(mark) + + arr = pa.array([datetime(2020, 1, 1), None, datetime(2020, 1, 2)]) + arr = pc.assume_timezone(arr, "Asia/Kathmandu") + table = pa.table({"arr": arr}) + exchange_df = table.__dataframe__() + with tm.assert_produces_warning(match="Interchange"): + result = from_dataframe(exchange_df) + + expected = pd.DataFrame( + ["2020-01-01 00:00:00+05:45", "NaT", "2020-01-02 00:00:00+05:45"], + columns=["arr"], + dtype="datetime64[us, Asia/Kathmandu]", + ) + tm.assert_frame_equal(expected, result) + + +def test_interchange_from_corrected_buffer_dtypes(monkeypatch) -> None: + # https://github.com/pandas-dev/pandas/issues/54781 + with tm.assert_produces_warning(match="Interchange"): + df = pd.DataFrame({"a": ["foo", "bar"]}).__dataframe__() + interchange = df.__dataframe__() + column = interchange.get_column_by_name("a") + buffers = column.get_buffers() + buffers_data = buffers["data"] + buffer_dtype = buffers_data[1] + buffer_dtype = ( + DtypeKind.UINT, + 8, + ArrowCTypes.UINT8, + buffer_dtype[3], + ) + buffers["data"] = (buffers_data[0], buffer_dtype) + column.get_buffers = lambda: buffers + interchange.get_column_by_name = lambda _: column + monkeypatch.setattr(df, "__dataframe__", lambda allow_copy: interchange) + with tm.assert_produces_warning(match="Interchange"): + pd.api.interchange.from_dataframe(df) + + +def test_empty_string_column(): + # https://github.com/pandas-dev/pandas/issues/56703 + df = pd.DataFrame({"a": []}, dtype=str) + with tm.assert_produces_warning(match="Interchange"): + df2 = df.__dataframe__() + result = pd.api.interchange.from_dataframe(df2) + tm.assert_frame_equal(df, result) + + +def test_large_string(): + # GH#56702 + pytest.importorskip("pyarrow") + df = pd.DataFrame({"a": ["x"]}, dtype="large_string[pyarrow]") + # Don't check stacklevel as PyArrow calls the deprecated `__dataframe__` method. + with tm.assert_produces_warning(match="Interchange", check_stacklevel=False): + result = pd.api.interchange.from_dataframe(df.__dataframe__()) + expected = pd.DataFrame({"a": ["x"]}, dtype="str") + tm.assert_frame_equal(result, expected) + + +def test_non_str_names(): + # https://github.com/pandas-dev/pandas/issues/56701 + df = pd.Series([1, 2, 3], name=0).to_frame() + with tm.assert_produces_warning(match="Interchange"): + names = df.__dataframe__().column_names() + assert names == ["0"] + + +def test_non_str_names_w_duplicates(): + # https://github.com/pandas-dev/pandas/issues/56701 + df = pd.DataFrame({"0": [1, 2, 3], 0: [4, 5, 6]}) + with tm.assert_produces_warning(match="Interchange"): + dfi = df.__dataframe__() + with tm.assert_produces_warning(match="Interchange"): + with pytest.raises( + TypeError, + match=( + "Expected a Series, got a DataFrame. This likely happened because you " + "called __dataframe__ on a DataFrame which, after converting column " + r"names to string, resulted in duplicated names: Index\(\['0', '0'\], " + r"dtype='(str|object)'\). Please rename these columns before using the " + "interchange protocol." + ), + ): + pd.api.interchange.from_dataframe(dfi, allow_copy=False) + + +@pytest.mark.parametrize( + ("data", "dtype", "expected_dtype"), + [ + ([1, 2, None], "Int64", "int64"), + ([1, 2, None], "Int64[pyarrow]", "int64"), + ([1, 2, None], "Int8", "int8"), + ([1, 2, None], "Int8[pyarrow]", "int8"), + ( + [1, 2, None], + "UInt64", + "uint64", + ), + ( + [1, 2, None], + "UInt64[pyarrow]", + "uint64", + ), + ([1.0, 2.25, None], "Float32", "float32"), + ([1.0, 2.25, None], "Float32[pyarrow]", "float32"), + ([True, False, None], "boolean", "bool"), + ([True, False, None], "boolean[pyarrow]", "bool"), + (["much ado", "about", None], pd.StringDtype(na_value=np.nan), "large_string"), + (["much ado", "about", None], "string[pyarrow]", "large_string"), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), None], + "timestamp[ns][pyarrow]", + "timestamp[ns]", + ), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), None], + "timestamp[us][pyarrow]", + "timestamp[us]", + ), + ( + [ + datetime(2020, 1, 1, tzinfo=timezone.utc), + datetime(2020, 1, 2, tzinfo=timezone.utc), + None, + ], + "timestamp[us, Asia/Kathmandu][pyarrow]", + "timestamp[us, tz=Asia/Kathmandu]", + ), + ], +) +def test_pandas_nullable_with_missing_values( + data: list, dtype: str, expected_dtype: str +) -> None: + # https://github.com/pandas-dev/pandas/issues/57643 + # https://github.com/pandas-dev/pandas/issues/57664 + pa = pytest.importorskip("pyarrow", "14.0.0") + import pyarrow.interchange as pai + + if expected_dtype == "timestamp[us, tz=Asia/Kathmandu]": + expected_dtype = pa.timestamp("us", "Asia/Kathmandu") + + df = pd.DataFrame({"a": data}, dtype=dtype) + with tm.assert_produces_warning(match="Interchange"): + result = pai.from_dataframe(df.__dataframe__())["a"] + assert result.type == expected_dtype + assert result[0].as_py() == data[0] + assert result[1].as_py() == data[1] + assert result[2].as_py() is None + + +@pytest.mark.parametrize( + ("data", "dtype", "expected_dtype"), + [ + ([1, 2, 3], "Int64", "int64"), + ([1, 2, 3], "Int64[pyarrow]", "int64"), + ([1, 2, 3], "Int8", "int8"), + ([1, 2, 3], "Int8[pyarrow]", "int8"), + ( + [1, 2, 3], + "UInt64", + "uint64", + ), + ( + [1, 2, 3], + "UInt64[pyarrow]", + "uint64", + ), + ([1.0, 2.25, 5.0], "Float32", "float32"), + ([1.0, 2.25, 5.0], "Float32[pyarrow]", "float32"), + ([True, False, False], "boolean", "bool"), + ([True, False, False], "boolean[pyarrow]", "bool"), + ( + ["much ado", "about", "nothing"], + pd.StringDtype(na_value=np.nan), + "large_string", + ), + (["much ado", "about", "nothing"], "string[pyarrow]", "large_string"), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)], + "timestamp[ns][pyarrow]", + "timestamp[ns]", + ), + ( + [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 3)], + "timestamp[us][pyarrow]", + "timestamp[us]", + ), + ( + [ + datetime(2020, 1, 1, tzinfo=timezone.utc), + datetime(2020, 1, 2, tzinfo=timezone.utc), + datetime(2020, 1, 3, tzinfo=timezone.utc), + ], + "timestamp[us, Asia/Kathmandu][pyarrow]", + "timestamp[us, tz=Asia/Kathmandu]", + ), + ], +) +def test_pandas_nullable_without_missing_values( + data: list, dtype: str, expected_dtype: str +) -> None: + # https://github.com/pandas-dev/pandas/issues/57643 + pa = pytest.importorskip("pyarrow", "14.0.0") + import pyarrow.interchange as pai + + if expected_dtype == "timestamp[us, tz=Asia/Kathmandu]": + expected_dtype = pa.timestamp("us", "Asia/Kathmandu") + + df = pd.DataFrame({"a": data}, dtype=dtype) + with tm.assert_produces_warning(match="Interchange"): + result = pai.from_dataframe(df.__dataframe__())["a"] + assert result.type == expected_dtype + assert result[0].as_py() == data[0] + assert result[1].as_py() == data[1] + assert result[2].as_py() == data[2] + + +def test_string_validity_buffer() -> None: + # https://github.com/pandas-dev/pandas/issues/57761 + pytest.importorskip("pyarrow", "11.0.0") + df = pd.DataFrame({"a": ["x"]}, dtype="large_string[pyarrow]") + with tm.assert_produces_warning(match="Interchange"): + result = df.__dataframe__().get_column_by_name("a").get_buffers()["validity"] + assert result is None + + +def test_string_validity_buffer_no_missing() -> None: + # https://github.com/pandas-dev/pandas/issues/57762 + pytest.importorskip("pyarrow", "11.0.0") + df = pd.DataFrame({"a": ["x", None]}, dtype="large_string[pyarrow]") + with tm.assert_produces_warning(match="Interchange"): + validity = df.__dataframe__().get_column_by_name("a").get_buffers()["validity"] + assert validity is not None + result = validity[1] + expected = (DtypeKind.BOOL, 1, ArrowCTypes.BOOL, "=") + assert result == expected + + +def test_empty_dataframe(): + # https://github.com/pandas-dev/pandas/issues/56700 + df = pd.DataFrame({"a": []}, dtype="int8") + with tm.assert_produces_warning(match="Interchange"): + dfi = df.__dataframe__() + result = pd.api.interchange.from_dataframe(dfi, allow_copy=False) + expected = pd.DataFrame({"a": []}, dtype="int8") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("data", "expected_dtype", "expected_buffer_dtype"), + [ + ( + pd.Series(["a", "b", "a"], dtype="category"), + (DtypeKind.CATEGORICAL, 8, "c", "="), + (DtypeKind.INT, 8, "c", "|"), + ), + ( + pd.Series( + [datetime(2022, 1, 1), datetime(2022, 1, 2), datetime(2022, 1, 3)], + dtype="M8[ns]", + ), + (DtypeKind.DATETIME, 64, "tsn:", "="), + (DtypeKind.INT, 64, ArrowCTypes.INT64, "="), + ), + ( + pd.Series(["a", "bc", None]), + (DtypeKind.STRING, 8, ArrowCTypes.STRING, "="), + (DtypeKind.UINT, 8, ArrowCTypes.UINT8, "="), + ), + ( + pd.Series([1, 2, 3]), + (DtypeKind.INT, 64, ArrowCTypes.INT64, "="), + (DtypeKind.INT, 64, ArrowCTypes.INT64, "="), + ), + ( + pd.Series([1.5, 2, 3]), + (DtypeKind.FLOAT, 64, ArrowCTypes.FLOAT64, "="), + (DtypeKind.FLOAT, 64, ArrowCTypes.FLOAT64, "="), + ), + ], +) +def test_buffer_dtype_categorical( + data: pd.Series, + expected_dtype: tuple[DtypeKind, int, str, str], + expected_buffer_dtype: tuple[DtypeKind, int, str, str], +) -> None: + # https://github.com/pandas-dev/pandas/issues/54781 + df = pd.DataFrame({"data": data}) + with tm.assert_produces_warning(match="Interchange"): + dfi = df.__dataframe__() + col = dfi.get_column_by_name("data") + assert col.dtype == expected_dtype + assert col.get_buffers()["data"][1] == expected_buffer_dtype + + +def test_from_dataframe_list_dtype(): + pa = pytest.importorskip("pyarrow", "14.0.0") + data = {"a": [[1, 2], [4, 5, 6]]} + tbl = pa.table(data) + result = from_dataframe(tbl) + expected = pd.DataFrame(data) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/interchange/test_spec_conformance.py b/pandas/tests/interchange/test_spec_conformance.py new file mode 100644 index 0000000000000000000000000000000000000000..04e19b290f886a4cd37e8c41ef2f2e65d7e678e8 --- /dev/null +++ b/pandas/tests/interchange/test_spec_conformance.py @@ -0,0 +1,187 @@ +""" +A verbatim copy (vendored) of the spec tests. +Taken from https://github.com/data-apis/dataframe-api +""" + +import ctypes +import math + +import pytest + +import pandas as pd +import pandas._testing as tm + + +@pytest.fixture +def df_from_dict(): + def maker(dct, is_categorical=False): + df = pd.DataFrame(dct) + return df.astype("category") if is_categorical else df + + return maker + + +@pytest.mark.parametrize( + "test_data", + [ + {"a": ["foo", "bar"], "b": ["baz", "qux"]}, + {"a": [1.5, 2.5, 3.5], "b": [9.2, 10.5, 11.8]}, + {"A": [1, 2, 3, 4], "B": [1, 2, 3, 4]}, + ], + ids=["str_data", "float_data", "int_data"], +) +def test_only_one_dtype(test_data, df_from_dict): + columns = list(test_data.keys()) + df = df_from_dict(test_data) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + + column_size = len(test_data[columns[0]]) + for column in columns: + null_count = dfX.get_column_by_name(column).null_count + assert null_count == 0 + assert isinstance(null_count, int) + assert dfX.get_column_by_name(column).size() == column_size + assert dfX.get_column_by_name(column).offset == 0 + + +def test_mixed_dtypes(df_from_dict): + df = df_from_dict( + { + "a": [1, 2, 3], # dtype kind INT = 0 + "b": [3, 4, 5], # dtype kind INT = 0 + "c": [1.5, 2.5, 3.5], # dtype kind FLOAT = 2 + "d": [9, 10, 11], # dtype kind INT = 0 + "e": [True, False, True], # dtype kind BOOLEAN = 20 + "f": ["a", "", "c"], # dtype kind STRING = 21 + } + ) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + # for meanings of dtype[0] see the spec; we cannot import the spec here as this + # file is expected to be vendored *anywhere*; + # values for dtype[0] are explained above + columns = {"a": 0, "b": 0, "c": 2, "d": 0, "e": 20, "f": 21} + + for column, kind in columns.items(): + colX = dfX.get_column_by_name(column) + assert colX.null_count == 0 + assert isinstance(colX.null_count, int) + assert colX.size() == 3 + assert colX.offset == 0 + + assert colX.dtype[0] == kind + + assert dfX.get_column_by_name("c").dtype[1] == 64 + + +def test_na_float(df_from_dict): + df = df_from_dict({"a": [1.0, math.nan, 2.0]}) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + colX = dfX.get_column_by_name("a") + assert colX.null_count == 1 + assert isinstance(colX.null_count, int) + + +def test_noncategorical(df_from_dict): + df = df_from_dict({"a": [1, 2, 3]}) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + colX = dfX.get_column_by_name("a") + with pytest.raises(TypeError, match=".*categorical.*"): + colX.describe_categorical + + +def test_categorical(df_from_dict): + df = df_from_dict( + {"weekday": ["Mon", "Tue", "Mon", "Wed", "Mon", "Thu", "Fri", "Sat", "Sun"]}, + is_categorical=True, + ) + + with tm.assert_produces_warning(match="Interchange"): + colX = df.__dataframe__().get_column_by_name("weekday") + categorical = colX.describe_categorical + assert isinstance(categorical["is_ordered"], bool) + assert isinstance(categorical["is_dictionary"], bool) + + +def test_dataframe(df_from_dict): + df = df_from_dict( + {"x": [True, True, False], "y": [1, 2, 0], "z": [9.2, 10.5, 11.8]} + ) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + + assert dfX.num_columns() == 3 + assert dfX.num_rows() == 3 + assert dfX.num_chunks() == 1 + assert list(dfX.column_names()) == ["x", "y", "z"] + assert list(dfX.select_columns((0, 2)).column_names()) == list( + dfX.select_columns_by_name(("x", "z")).column_names() + ) + + +@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)]) +def test_df_get_chunks(size, n_chunks, df_from_dict): + df = df_from_dict({"x": list(range(size))}) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + chunks = list(dfX.get_chunks(n_chunks)) + assert len(chunks) == n_chunks + assert sum(chunk.num_rows() for chunk in chunks) == size + + +@pytest.mark.parametrize(["size", "n_chunks"], [(10, 3), (12, 3), (12, 5)]) +def test_column_get_chunks(size, n_chunks, df_from_dict): + df = df_from_dict({"x": list(range(size))}) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + chunks = list(dfX.get_column(0).get_chunks(n_chunks)) + assert len(chunks) == n_chunks + assert sum(chunk.size() for chunk in chunks) == size + + +def test_get_columns(df_from_dict): + df = df_from_dict({"a": [0, 1], "b": [2.5, 3.5]}) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + for colX in dfX.get_columns(): + assert colX.size() == 2 + assert colX.num_chunks() == 1 + # for meanings of dtype[0] see the spec; we cannot import the spec here as this + # file is expected to be vendored *anywhere* + assert dfX.get_column(0).dtype[0] == 0 # INT + assert dfX.get_column(1).dtype[0] == 2 # FLOAT + + +def test_buffer(df_from_dict): + arr = [0, 1, -1] + df = df_from_dict({"a": arr}) + with tm.assert_produces_warning(match="Interchange"): + dfX = df.__dataframe__() + colX = dfX.get_column(0) + bufX = colX.get_buffers() + + dataBuf, dataDtype = bufX["data"] + + assert dataBuf.bufsize > 0 + assert dataBuf.ptr != 0 + device, _ = dataBuf.__dlpack_device__() + + # for meanings of dtype[0] see the spec; we cannot import the spec here as this + # file is expected to be vendored *anywhere* + assert dataDtype[0] == 0 # INT + + if device == 1: # CPU-only as we're going to directly read memory here + bitwidth = dataDtype[1] + ctype = { + 8: ctypes.c_int8, + 16: ctypes.c_int16, + 32: ctypes.c_int32, + 64: ctypes.c_int64, + }[bitwidth] + + for idx, truth in enumerate(arr): + val = ctype.from_address(dataBuf.ptr + idx * (bitwidth // 8)).value + assert val == truth, f"Buffer at index {idx} mismatch" diff --git a/pandas/tests/interchange/test_utils.py b/pandas/tests/interchange/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a47bc2752ff32f5eb7630a3960e7611242cb73e3 --- /dev/null +++ b/pandas/tests/interchange/test_utils.py @@ -0,0 +1,89 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas.core.interchange.utils import dtype_to_arrow_c_fmt + +# TODO: use ArrowSchema to get reference C-string. +# At the time, there is no way to access ArrowSchema holding a type format string +# from python. The only way to access it is to export the structure to a C-pointer, +# see DataType._export_to_c() method defined in +# https://github.com/apache/arrow/blob/master/python/pyarrow/types.pxi + + +@pytest.mark.parametrize( + "pandas_dtype, c_string", + [ + (np.dtype("bool"), "b"), + (np.dtype("int8"), "c"), + (np.dtype("uint8"), "C"), + (np.dtype("int16"), "s"), + (np.dtype("uint16"), "S"), + (np.dtype("int32"), "i"), + (np.dtype("uint32"), "I"), + (np.dtype("int64"), "l"), + (np.dtype("uint64"), "L"), + (np.dtype("float16"), "e"), + (np.dtype("float32"), "f"), + (np.dtype("float64"), "g"), + (pd.Series(["a"]).dtype, "u"), + ( + pd.Series([0]).astype("datetime64[ns]").dtype, + "tsn:", + ), + (pd.CategoricalDtype(["a"]), "l"), + (np.dtype("O"), "u"), + ], +) +def test_dtype_to_arrow_c_fmt(pandas_dtype, c_string): # PR01 + """Test ``dtype_to_arrow_c_fmt`` utility function.""" + assert dtype_to_arrow_c_fmt(pandas_dtype) == c_string + + +@pytest.mark.parametrize( + "pa_dtype, args_kwargs, c_string", + [ + ["null", {}, "n"], + ["bool_", {}, "b"], + ["uint8", {}, "C"], + ["uint16", {}, "S"], + ["uint32", {}, "I"], + ["uint64", {}, "L"], + ["int8", {}, "c"], + ["int16", {}, "S"], + ["int32", {}, "i"], + ["int64", {}, "l"], + ["float16", {}, "e"], + ["float32", {}, "f"], + ["float64", {}, "g"], + ["string", {}, "u"], + ["binary", {}, "z"], + ["time32", ("s",), "tts"], + ["time32", ("ms",), "ttm"], + ["time64", ("us",), "ttu"], + ["time64", ("ns",), "ttn"], + ["date32", {}, "tdD"], + ["date64", {}, "tdm"], + ["timestamp", {"unit": "s"}, "tss:"], + ["timestamp", {"unit": "ms"}, "tsm:"], + ["timestamp", {"unit": "us"}, "tsu:"], + ["timestamp", {"unit": "ns"}, "tsn:"], + ["timestamp", {"unit": "ns", "tz": "UTC"}, "tsn:UTC"], + ["duration", ("s",), "tDs"], + ["duration", ("ms",), "tDm"], + ["duration", ("us",), "tDu"], + ["duration", ("ns",), "tDn"], + ["decimal128", {"precision": 4, "scale": 2}, "d:4,2"], + ], +) +def test_dtype_to_arrow_c_fmt_arrowdtype(pa_dtype, args_kwargs, c_string): + # GH 52323 + pa = pytest.importorskip("pyarrow") + if not args_kwargs: + pa_type = getattr(pa, pa_dtype)() + elif isinstance(args_kwargs, tuple): + pa_type = getattr(pa, pa_dtype)(*args_kwargs) + else: + pa_type = getattr(pa, pa_dtype)(**args_kwargs) + arrow_type = pd.ArrowDtype(pa_type) + assert dtype_to_arrow_c_fmt(arrow_type) == c_string diff --git a/pandas/tests/internals/__init__.py b/pandas/tests/internals/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/internals/test_api.py b/pandas/tests/internals/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..9368105a1fa5b192e7b223529e527390b9d6d7bf --- /dev/null +++ b/pandas/tests/internals/test_api.py @@ -0,0 +1,178 @@ +""" +Tests for the pseudo-public API implemented in internals/api.py and exposed +in core.internals +""" + +import datetime + +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning + +import pandas as pd +import pandas._testing as tm +from pandas.api.internals import create_dataframe_from_blocks +from pandas.core import internals +from pandas.core.internals import api + + +def test_internals_api(): + assert internals.make_block is api.make_block + + +def test_namespace(): + # SUBJECT TO CHANGE + + modules = [ + "blocks", + "concat", + "managers", + "construction", + "api", + "ops", + ] + expected = [ + "make_block", + "BlockManager", + "SingleBlockManager", + "concatenate_managers", + ] + + result = [x for x in dir(internals) if not x.startswith("__")] + assert set(result) == set(expected + modules), set(result) ^ set(expected + modules) + + +@pytest.mark.parametrize( + "name", + [ + "Block", + "ExtensionBlock", + ], +) +def test_deprecations(name): + # GH#55139 + msg = f"{name} is deprecated.* Use public APIs instead" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + getattr(internals, name) + + +def test_make_block_2d_with_dti(): + # GH#41168 + dti = pd.date_range("2012", periods=3, tz="UTC") + + msg = "make_block is deprecated" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + blk = api.make_block(dti, placement=[0]) + + assert blk.shape == (1, 3) + assert blk.values.shape == (1, 3) + + +def test_create_block_manager_from_blocks_deprecated(): + # GH#33892 + # If they must, downstream packages should get this from internals.api, + # not internals. + msg = ( + "create_block_manager_from_blocks is deprecated and will be " + "removed in a future version. Use public APIs instead" + ) + with tm.assert_produces_warning(Pandas4Warning, match=msg): + internals.create_block_manager_from_blocks + + +def test_maybe_infer_ndim_deprecated(): + # GH#40226 + msg = "maybe_infer_ndim is deprecated and will be removed in a future version." + arr = np.arange(5) + bp = pd._libs.internals.BlockPlacement([1]) + with tm.assert_produces_warning(DeprecationWarning, match=msg): + internals.api.maybe_infer_ndim(arr, bp, 1) + + +def test_create_dataframe_from_blocks(float_frame): + block = float_frame._mgr.blocks[0] + index = float_frame.index.copy() + columns = float_frame.columns.copy() + + result = create_dataframe_from_blocks( + [(block.values, block.mgr_locs.as_array)], index=index, columns=columns + ) + tm.assert_frame_equal(result, float_frame) + + +def test_create_dataframe_from_blocks_types(): + df = pd.DataFrame( + { + "int": list(range(1, 4)), + "uint": np.arange(3, 6).astype("uint8"), + "float": [2.0, np.nan, 3.0], + "bool": np.array([True, False, True]), + "boolean": pd.array([True, False, None], dtype="boolean"), + "string": list("abc"), + "datetime": pd.date_range("20130101", periods=3), + "datetimetz": pd.date_range("20130101", periods=3).tz_localize( + "Europe/Brussels" + ), + "timedelta": pd.timedelta_range("1 day", periods=3), + "period": pd.period_range("2012-01-01", periods=3, freq="D"), + "categorical": pd.Categorical(["a", "b", "a"]), + "interval": pd.IntervalIndex.from_tuples([(0, 1), (1, 2), (3, 4)]), + } + ) + + result = create_dataframe_from_blocks( + [(block.values, block.mgr_locs.as_array) for block in df._mgr.blocks], + index=df.index, + columns=df.columns, + ) + tm.assert_frame_equal(result, df) + + +def test_create_dataframe_from_blocks_datetimelike(): + # extension dtypes that have an exact matching numpy dtype can also be + # be passed as a numpy array + index, columns = pd.RangeIndex(3), pd.Index(["a", "b", "c", "d"]) + + block_array1 = np.arange( + datetime.datetime(2020, 1, 1), + datetime.datetime(2020, 1, 7), + step=datetime.timedelta(1), + ).reshape((2, 3)) + block_array2 = np.arange( + datetime.timedelta(1), datetime.timedelta(7), step=datetime.timedelta(1) + ).reshape((2, 3)) + result = create_dataframe_from_blocks( + [(block_array1, np.array([0, 2])), (block_array2, np.array([1, 3]))], + index=index, + columns=columns, + ) + expected = pd.DataFrame( + { + "a": pd.date_range("2020-01-01", periods=3, unit="us"), + "b": pd.timedelta_range("1 days", periods=3, unit="us"), + "c": pd.date_range("2020-01-04", periods=3, unit="us"), + "d": pd.timedelta_range("4 days", periods=3, unit="us"), + } + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "array", + [ + pd.date_range("2020-01-01", periods=3), + pd.date_range("2020-01-01", periods=3, tz="UTC"), + pd.period_range("2012-01-01", periods=3, freq="D"), + pd.timedelta_range("1 day", periods=3), + ], +) +def test_create_dataframe_from_blocks_1dEA(array): + # ExtensionArrays can be passed as 1D even if stored under the hood as 2D + df = pd.DataFrame({"a": array}) + + block = df._mgr.blocks[0] + result = create_dataframe_from_blocks( + [(block.values[0], block.mgr_locs.as_array)], index=df.index, columns=df.columns + ) + tm.assert_frame_equal(result, df) diff --git a/pandas/tests/internals/test_internals.py b/pandas/tests/internals/test_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..8852ae81bddf2032f3cf03105917dde9fb6637bf --- /dev/null +++ b/pandas/tests/internals/test_internals.py @@ -0,0 +1,1421 @@ +from datetime import ( + date, + datetime, +) +import itertools +import re + +import numpy as np +import pytest + +from pandas._libs.internals import BlockPlacement +from pandas.compat import IS64 +from pandas.errors import Pandas4Warning + +from pandas.core.dtypes.common import is_scalar + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DatetimeIndex, + Index, + IntervalIndex, + Series, + Timedelta, + Timestamp, + period_range, +) +import pandas._testing as tm +import pandas.core.algorithms as algos +from pandas.core.arrays import ( + DatetimeArray, + SparseArray, + TimedeltaArray, +) +from pandas.core.internals import ( + BlockManager, + SingleBlockManager, + make_block, +) +from pandas.core.internals.blocks import ( + ensure_block_shape, + maybe_coerce_values, + new_block, +) + + +@pytest.fixture(params=[new_block, make_block]) +def block_maker(request): + """ + Fixture to test both the internal new_block and pseudo-public make_block. + """ + return request.param + + +@pytest.fixture +def mgr(): + return create_mgr( + "a: f8; b: object; c: f8; d: object; e: f8;" + "f: bool; g: i8; h: complex; i: datetime-1; j: datetime-2;" + "k: M8[ns, US/Eastern]; l: M8[ns, CET];" + ) + + +def assert_block_equal(left, right): + tm.assert_numpy_array_equal(left.values, right.values) + assert left.dtype == right.dtype + assert isinstance(left.mgr_locs, BlockPlacement) + assert isinstance(right.mgr_locs, BlockPlacement) + tm.assert_numpy_array_equal(left.mgr_locs.as_array, right.mgr_locs.as_array) + + +def get_numeric_mat(shape): + arr = np.arange(shape[0]) + return np.lib.stride_tricks.as_strided( + x=arr, shape=shape, strides=(arr.itemsize,) + (0,) * (len(shape) - 1) + ).copy() + + +N = 10 + + +def create_block(typestr, placement, item_shape=None, num_offset=0, maker=new_block): + """ + Supported typestr: + + * float, f8, f4, f2 + * int, i8, i4, i2, i1 + * uint, u8, u4, u2, u1 + * complex, c16, c8 + * bool + * object, string, O + * datetime, dt, M8[ns], M8[ns, tz] + * timedelta, td, m8[ns] + * sparse (SparseArray with fill_value=0.0) + * sparse_na (SparseArray with fill_value=np.nan) + * category, category2 + + """ + placement = BlockPlacement(placement) + num_items = len(placement) + + if item_shape is None: + item_shape = (N,) + + shape = (num_items, *item_shape) + + mat = get_numeric_mat(shape) + + if typestr in ( + "float", + "f8", + "f4", + "f2", + "int", + "i8", + "i4", + "i2", + "i1", + "uint", + "u8", + "u4", + "u2", + "u1", + ): + values = mat.astype(typestr) + num_offset + elif typestr in ("complex", "c16", "c8"): + values = 1.0j * (mat.astype(typestr) + num_offset) + elif typestr in ("object", "string", "O"): + values = np.reshape([f"A{i:d}" for i in mat.ravel() + num_offset], shape) + elif typestr in ("b", "bool"): + values = np.ones(shape, dtype=np.bool_) + elif typestr in ("datetime", "dt", "M8[ns]"): + values = (mat * 1e9).astype("M8[ns]") + elif typestr.startswith("M8[ns"): + # datetime with tz + m = re.search(r"M8\[ns,\s*(\w+\/?\w*)\]", typestr) + assert m is not None, f"incompatible typestr -> {typestr}" + tz = m.groups()[0] + assert num_items == 1, "must have only 1 num items for a tz-aware" + values = DatetimeIndex(np.arange(N) * 10**9, tz=tz)._data + values = ensure_block_shape(values, ndim=len(shape)) + elif typestr in ("timedelta", "td", "m8[ns]"): + values = (mat * 1).astype("m8[ns]") + elif typestr in ("category",): + values = Categorical([1, 1, 2, 2, 3, 3, 3, 3, 4, 4]) + elif typestr in ("category2",): + values = Categorical(["a", "a", "a", "a", "b", "b", "c", "c", "c", "d"]) + elif typestr in ("sparse", "sparse_na"): + if shape[-1] != 10: + # We also are implicitly assuming this in the category cases above + raise NotImplementedError + + assert all(s == 1 for s in shape[:-1]) + if typestr.endswith("_na"): + fill_value = np.nan + else: + fill_value = 0.0 + values = SparseArray( + [fill_value, fill_value, 1, 2, 3, fill_value, 4, 5, fill_value, 6], + fill_value=fill_value, + ) + arr = values.sp_values.view() + arr += num_offset - 1 + else: + raise ValueError(f'Unsupported typestr: "{typestr}"') + + values = maybe_coerce_values(values) + return maker(values, placement=placement, ndim=len(shape)) + + +def create_single_mgr(typestr, num_rows=None): + if num_rows is None: + num_rows = N + + return SingleBlockManager( + create_block(typestr, placement=slice(0, num_rows), item_shape=()), + Index(np.arange(num_rows)), + ) + + +def create_mgr(descr, item_shape=None): + """ + Construct BlockManager from string description. + + String description syntax looks similar to np.matrix initializer. It looks + like this:: + + a,b,c: f8; d,e,f: i8 + + Rules are rather simple: + + * see list of supported datatypes in `create_block` method + * components are semicolon-separated + * each component is `NAME,NAME,NAME: DTYPE_ID` + * whitespace around colons & semicolons are removed + * components with same DTYPE_ID are combined into single block + * to force multiple blocks with same dtype, use '-SUFFIX':: + + "a:f8-1; b:f8-2; c:f8-foobar" + + """ + if item_shape is None: + item_shape = (N,) + + offset = 0 + mgr_items = [] + block_placements = {} + for d in descr.split(";"): + d = d.strip() + if not len(d): + continue + names, blockstr = d.partition(":")[::2] + blockstr = blockstr.strip() + names = names.strip().split(",") + + mgr_items.extend(names) + placement = list(np.arange(len(names)) + offset) + try: + block_placements[blockstr].extend(placement) + except KeyError: + block_placements[blockstr] = placement + offset += len(names) + + mgr_items = Index(mgr_items) + + blocks = [] + num_offset = 0 + for blockstr, placement in block_placements.items(): + typestr = blockstr.split("-")[0] + blocks.append( + create_block( + typestr, placement, item_shape=item_shape, num_offset=num_offset + ) + ) + num_offset += len(placement) + + sblocks = sorted(blocks, key=lambda b: b.mgr_locs[0]) + return BlockManager( + tuple(sblocks), + [mgr_items] + [Index(np.arange(n)) for n in item_shape], + ) + + +@pytest.fixture +def fblock(): + return create_block("float", [0, 2, 4]) + + +class TestBlock: + def test_constructor(self): + int32block = create_block("i4", [0]) + assert int32block.dtype == np.int32 + + @pytest.mark.parametrize( + "typ, data", + [ + ["float", [0, 2, 4]], + ["complex", [7]], + ["object", [1, 3]], + ["bool", [5]], + ], + ) + def test_pickle(self, typ, data, temp_file): + blk = create_block(typ, data) + assert_block_equal(tm.round_trip_pickle(blk, temp_file), blk) + + def test_mgr_locs(self, fblock): + assert isinstance(fblock.mgr_locs, BlockPlacement) + tm.assert_numpy_array_equal( + fblock.mgr_locs.as_array, np.array([0, 2, 4], dtype=np.intp) + ) + + def test_attrs(self, fblock): + assert fblock.shape == fblock.values.shape + assert fblock.dtype == fblock.values.dtype + assert len(fblock) == len(fblock.values) + + def test_copy(self, fblock): + cop = fblock.copy(deep=True) + assert cop is not fblock + assert_block_equal(fblock, cop) + + def test_delete(self, fblock): + newb = fblock.copy(deep=True) + locs = newb.mgr_locs + nb = newb.delete(0)[0] + assert newb.mgr_locs is locs + + assert nb is not newb + + tm.assert_numpy_array_equal( + nb.mgr_locs.as_array, np.array([2, 4], dtype=np.intp) + ) + assert not (newb.values[0] == 1).all() + assert (nb.values[0] == 1).all() + + newb = fblock.copy(deep=True) + locs = newb.mgr_locs + nb = newb.delete(1) + assert len(nb) == 2 + assert newb.mgr_locs is locs + + tm.assert_numpy_array_equal( + nb[0].mgr_locs.as_array, np.array([0], dtype=np.intp) + ) + tm.assert_numpy_array_equal( + nb[1].mgr_locs.as_array, np.array([4], dtype=np.intp) + ) + assert not (newb.values[1] == 2).all() + assert (nb[1].values[0] == 2).all() + + newb = fblock.copy(deep=True) + nb = newb.delete(2) + assert len(nb) == 1 + tm.assert_numpy_array_equal( + nb[0].mgr_locs.as_array, np.array([0, 2], dtype=np.intp) + ) + assert (nb[0].values[1] == 1).all() + + newb = fblock.copy(deep=True) + + with pytest.raises(IndexError, match=None): + newb.delete(3) + + def test_delete_datetimelike(self): + # dont use np.delete on values, as that will coerce from DTA/TDA to ndarray + arr = np.arange(20, dtype="i8").reshape(5, 4).view("m8[ns]") + df = DataFrame(arr) + blk = df._mgr.blocks[0] + assert isinstance(blk.values, TimedeltaArray) + + nb = blk.delete(1) + assert len(nb) == 2 + assert isinstance(nb[0].values, TimedeltaArray) + assert isinstance(nb[1].values, TimedeltaArray) + + df = DataFrame(arr.view("M8[ns]")) + blk = df._mgr.blocks[0] + assert isinstance(blk.values, DatetimeArray) + + nb = blk.delete([1, 3]) + assert len(nb) == 2 + assert isinstance(nb[0].values, DatetimeArray) + assert isinstance(nb[1].values, DatetimeArray) + + def test_split(self): + # GH#37799 + values = np.random.default_rng(2).standard_normal((3, 4)) + blk = new_block(values, placement=BlockPlacement([3, 1, 6]), ndim=2) + result = list(blk._split()) + + # check that we get views, not copies + values[:] = -9999 + assert (blk.values == -9999).all() + + assert len(result) == 3 + expected = [ + new_block(values[[0]], placement=BlockPlacement([3]), ndim=2), + new_block(values[[1]], placement=BlockPlacement([1]), ndim=2), + new_block(values[[2]], placement=BlockPlacement([6]), ndim=2), + ] + for res, exp in zip(result, expected): + assert_block_equal(res, exp) + + +class TestBlockManager: + def test_attrs(self): + mgr = create_mgr("a,b,c: f8-1; d,e,f: f8-2") + assert mgr.nblocks == 2 + assert len(mgr) == 6 + + def test_duplicate_ref_loc_failure(self): + tmp_mgr = create_mgr("a:bool; a: f8") + + axes, blocks = tmp_mgr.axes, tmp_mgr.blocks + + blocks[0].mgr_locs = BlockPlacement(np.array([0])) + blocks[1].mgr_locs = BlockPlacement(np.array([0])) + + # test trying to create block manager with overlapping ref locs + + msg = "Gaps in blk ref_locs" + + mgr = BlockManager(blocks, axes) + with pytest.raises(AssertionError, match=msg): + mgr._rebuild_blknos_and_blklocs() + + blocks[0].mgr_locs = BlockPlacement(np.array([0])) + blocks[1].mgr_locs = BlockPlacement(np.array([1])) + mgr = BlockManager(blocks, axes) + mgr.iget(1) + + def test_pickle(self, mgr, temp_file): + mgr2 = tm.round_trip_pickle(mgr, temp_file) + tm.assert_frame_equal( + DataFrame._from_mgr(mgr, axes=mgr.axes), + DataFrame._from_mgr(mgr2, axes=mgr2.axes), + ) + + # GH2431 + assert hasattr(mgr2, "_is_consolidated") + assert hasattr(mgr2, "_known_consolidated") + + # reset to False on load + assert not mgr2._is_consolidated + assert not mgr2._known_consolidated + + @pytest.mark.parametrize("mgr_string", ["a,a,a:f8", "a: f8; a: i8"]) + def test_non_unique_pickle(self, mgr_string, temp_file): + mgr = create_mgr(mgr_string) + mgr2 = tm.round_trip_pickle(mgr, temp_file) + tm.assert_frame_equal( + DataFrame._from_mgr(mgr, axes=mgr.axes), + DataFrame._from_mgr(mgr2, axes=mgr2.axes), + ) + + def test_categorical_block_pickle(self, temp_file): + mgr = create_mgr("a: category") + mgr2 = tm.round_trip_pickle(mgr, temp_file) + tm.assert_frame_equal( + DataFrame._from_mgr(mgr, axes=mgr.axes), + DataFrame._from_mgr(mgr2, axes=mgr2.axes), + ) + + smgr = create_single_mgr("category") + smgr2 = tm.round_trip_pickle(smgr, temp_file) + tm.assert_series_equal( + Series()._constructor_from_mgr(smgr, axes=smgr.axes), + Series()._constructor_from_mgr(smgr2, axes=smgr2.axes), + ) + + def test_iget(self): + cols = Index(list("abc")) + values = np.random.default_rng(2).random((3, 3)) + block = new_block( + values=values.copy(), + placement=BlockPlacement(np.arange(3, dtype=np.intp)), + ndim=values.ndim, + ) + mgr = BlockManager(blocks=(block,), axes=[cols, Index(np.arange(3))]) + + tm.assert_almost_equal(mgr.iget(0).internal_values(), values[0]) + tm.assert_almost_equal(mgr.iget(1).internal_values(), values[1]) + tm.assert_almost_equal(mgr.iget(2).internal_values(), values[2]) + + def test_set(self): + mgr = create_mgr("a,b,c: int", item_shape=(3,)) + + mgr.insert(len(mgr.items), "d", np.array(["foo"] * 3)) + mgr.iset(1, np.array(["bar"] * 3)) + tm.assert_numpy_array_equal(mgr.iget(0).internal_values(), np.array([0] * 3)) + tm.assert_numpy_array_equal( + mgr.iget(1).internal_values(), np.array(["bar"] * 3, dtype=np.object_) + ) + tm.assert_numpy_array_equal(mgr.iget(2).internal_values(), np.array([2] * 3)) + tm.assert_numpy_array_equal( + mgr.iget(3).internal_values(), np.array(["foo"] * 3, dtype=np.object_) + ) + + def test_set_change_dtype(self, mgr): + mgr.insert(len(mgr.items), "baz", np.zeros(N, dtype=bool)) + + mgr.iset(mgr.items.get_loc("baz"), np.repeat("foo", N)) + idx = mgr.items.get_loc("baz") + assert mgr.iget(idx).dtype == np.object_ + + mgr2 = mgr.consolidate() + mgr2.iset(mgr2.items.get_loc("baz"), np.repeat("foo", N)) + idx = mgr2.items.get_loc("baz") + assert mgr2.iget(idx).dtype == np.object_ + + mgr2.insert( + len(mgr2.items), + "quux", + np.random.default_rng(2).standard_normal(N).astype(int), + ) + idx = mgr2.items.get_loc("quux") + assert mgr2.iget(idx).dtype == np.dtype(int) + + mgr2.iset( + mgr2.items.get_loc("quux"), np.random.default_rng(2).standard_normal(N) + ) + assert mgr2.iget(idx).dtype == np.float64 + + def test_copy(self, mgr): + cp = mgr.copy(deep=False) + for blk, cp_blk in zip(mgr.blocks, cp.blocks): + # view assertion + tm.assert_equal(cp_blk.values, blk.values) + if isinstance(blk.values, np.ndarray): + assert cp_blk.values.base.base is blk.values.base + else: + # DatetimeTZBlock has DatetimeIndex values + assert cp_blk.values._ndarray.base is blk.values._ndarray.base + + # copy(deep=True) consolidates, so the block-wise assertions will + # fail is mgr is not consolidated + mgr._consolidate_inplace() + cp = mgr.copy(deep=True) + for blk, cp_blk in zip(mgr.blocks, cp.blocks): + bvals = blk.values + cpvals = cp_blk.values + + tm.assert_equal(cpvals, bvals) + + if isinstance(cpvals, np.ndarray): + lbase = cpvals.base + rbase = bvals.base + else: + lbase = cpvals._ndarray.base + rbase = bvals._ndarray.base + + # copy assertion we either have a None for a base or in case of + # some blocks it is an array (e.g. datetimetz), but was copied + if isinstance(cpvals, DatetimeArray): + assert (lbase is None and rbase is None) or (lbase is not rbase) + elif not isinstance(cpvals, np.ndarray): + assert lbase is not rbase + else: + assert lbase is None and rbase is None + + def test_sparse(self): + mgr = create_mgr("a: sparse-1; b: sparse-2") + assert mgr.as_array().dtype == np.float64 + + def test_sparse_mixed(self): + mgr = create_mgr("a: sparse-1; b: sparse-2; c: f8") + assert len(mgr.blocks) == 3 + assert isinstance(mgr, BlockManager) + + @pytest.mark.parametrize( + "mgr_string, dtype", + [("c: f4; d: f2", np.float32), ("c: f4; d: f2; e: f8", np.float64)], + ) + def test_as_array_float(self, mgr_string, dtype): + mgr = create_mgr(mgr_string) + assert mgr.as_array().dtype == dtype + + @pytest.mark.parametrize( + "mgr_string, dtype", + [ + ("a: bool-1; b: bool-2", np.bool_), + ("a: i8-1; b: i8-2; c: i4; d: i2; e: u1", np.int64), + ("c: i4; d: i2; e: u1", np.int32), + ], + ) + def test_as_array_int_bool(self, mgr_string, dtype): + mgr = create_mgr(mgr_string) + assert mgr.as_array().dtype == dtype + + def test_as_array_datetime(self): + mgr = create_mgr("h: datetime-1; g: datetime-2") + assert mgr.as_array().dtype == "M8[ns]" + + def test_as_array_datetime_tz(self): + mgr = create_mgr("h: M8[ns, US/Eastern]; g: M8[ns, CET]") + assert mgr.iget(0).dtype == "datetime64[ns, US/Eastern]" + assert mgr.iget(1).dtype == "datetime64[ns, CET]" + assert mgr.as_array().dtype == "object" + + @pytest.mark.parametrize("t", ["float16", "float32", "float64", "int32", "int64"]) + def test_astype(self, t): + # coerce all + mgr = create_mgr("c: f4; d: f2; e: f8") + + t = np.dtype(t) + tmgr = mgr.astype(t) + assert tmgr.iget(0).dtype.type == t + assert tmgr.iget(1).dtype.type == t + assert tmgr.iget(2).dtype.type == t + + # mixed + mgr = create_mgr("a,b: object; c: bool; d: datetime; e: f4; f: f2; g: f8") + + t = np.dtype(t) + tmgr = mgr.astype(t, errors="ignore") + assert tmgr.iget(2).dtype.type == t + assert tmgr.iget(4).dtype.type == t + assert tmgr.iget(5).dtype.type == t + assert tmgr.iget(6).dtype.type == t + + assert tmgr.iget(0).dtype.type == np.object_ + assert tmgr.iget(1).dtype.type == np.object_ + if t != np.int64: + assert tmgr.iget(3).dtype.type == np.datetime64 + else: + assert tmgr.iget(3).dtype.type == t + + def test_convert(self, using_infer_string): + def _compare(old_mgr, new_mgr): + """compare the blocks, numeric compare ==, object don't""" + old_blocks = set(old_mgr.blocks) + new_blocks = set(new_mgr.blocks) + assert len(old_blocks) == len(new_blocks) + + # compare non-numeric + for b in old_blocks: + found = False + for nb in new_blocks: + if (b.values == nb.values).all(): + found = True + break + assert found + + for b in new_blocks: + found = False + for ob in old_blocks: + if (b.values == ob.values).all(): + found = True + break + assert found + + # noops + mgr = create_mgr("f: i8; g: f8") + new_mgr = mgr.convert() + _compare(mgr, new_mgr) + + # convert + mgr = create_mgr("a,b,foo: object; f: i8; g: f8") + mgr.iset(0, np.array(["1"] * N, dtype=np.object_)) + mgr.iset(1, np.array(["2."] * N, dtype=np.object_)) + mgr.iset(2, np.array(["foo."] * N, dtype=np.object_)) + new_mgr = mgr.convert() + dtype = "str" if using_infer_string else np.object_ + assert new_mgr.iget(0).dtype == dtype + assert new_mgr.iget(1).dtype == dtype + assert new_mgr.iget(2).dtype == dtype + assert new_mgr.iget(3).dtype == np.int64 + assert new_mgr.iget(4).dtype == np.float64 + + mgr = create_mgr( + "a,b,foo: object; f: i4; bool: bool; dt: datetime; i: i8; g: f8; h: f2" + ) + mgr.iset(0, np.array(["1"] * N, dtype=np.object_)) + mgr.iset(1, np.array(["2."] * N, dtype=np.object_)) + mgr.iset(2, np.array(["foo."] * N, dtype=np.object_)) + new_mgr = mgr.convert() + assert new_mgr.iget(0).dtype == dtype + assert new_mgr.iget(1).dtype == dtype + assert new_mgr.iget(2).dtype == dtype + assert new_mgr.iget(3).dtype == np.int32 + assert new_mgr.iget(4).dtype == np.bool_ + assert new_mgr.iget(5).dtype.type, np.datetime64 + assert new_mgr.iget(6).dtype == np.int64 + assert new_mgr.iget(7).dtype == np.float64 + assert new_mgr.iget(8).dtype == np.float16 + + def test_interleave(self): + # self + for dtype in ["f8", "i8", "object", "bool", "complex", "M8[ns]", "m8[ns]"]: + mgr = create_mgr(f"a: {dtype}") + assert mgr.as_array().dtype == dtype + mgr = create_mgr(f"a: {dtype}; b: {dtype}") + assert mgr.as_array().dtype == dtype + + @pytest.mark.parametrize( + "mgr_string, dtype", + [ + ("a: category", "i8"), + ("a: category; b: category", "i8"), + ("a: category; b: category2", "object"), + ("a: category2", "object"), + ("a: category2; b: category2", "object"), + ("a: f8", "f8"), + ("a: f8; b: i8", "f8"), + ("a: f4; b: i8", "f8"), + ("a: f4; b: i8; d: object", "object"), + ("a: bool; b: i8", "object"), + ("a: complex", "complex"), + ("a: f8; b: category", "object"), + ("a: M8[ns]; b: category", "object"), + ("a: M8[ns]; b: bool", "object"), + ("a: M8[ns]; b: i8", "object"), + ("a: m8[ns]; b: bool", "object"), + ("a: m8[ns]; b: i8", "object"), + ("a: M8[ns]; b: m8[ns]", "object"), + ], + ) + def test_interleave_dtype(self, mgr_string, dtype): + # will be converted according the actual dtype of the underlying + mgr = create_mgr("a: category") + assert mgr.as_array().dtype == "i8" + mgr = create_mgr("a: category; b: category2") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: category2") + assert mgr.as_array().dtype == "object" + + # combinations + mgr = create_mgr("a: f8") + assert mgr.as_array().dtype == "f8" + mgr = create_mgr("a: f8; b: i8") + assert mgr.as_array().dtype == "f8" + mgr = create_mgr("a: f4; b: i8") + assert mgr.as_array().dtype == "f8" + mgr = create_mgr("a: f4; b: i8; d: object") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: bool; b: i8") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: complex") + assert mgr.as_array().dtype == "complex" + mgr = create_mgr("a: f8; b: category") + assert mgr.as_array().dtype == "f8" + mgr = create_mgr("a: M8[ns]; b: category") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: M8[ns]; b: bool") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: M8[ns]; b: i8") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: m8[ns]; b: bool") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: m8[ns]; b: i8") + assert mgr.as_array().dtype == "object" + mgr = create_mgr("a: M8[ns]; b: m8[ns]") + assert mgr.as_array().dtype == "object" + + def test_consolidate_ordering_issues(self, mgr): + mgr.iset(mgr.items.get_loc("f"), np.random.default_rng(2).standard_normal(N)) + mgr.iset(mgr.items.get_loc("d"), np.random.default_rng(2).standard_normal(N)) + mgr.iset(mgr.items.get_loc("b"), np.random.default_rng(2).standard_normal(N)) + mgr.iset(mgr.items.get_loc("g"), np.random.default_rng(2).standard_normal(N)) + mgr.iset(mgr.items.get_loc("h"), np.random.default_rng(2).standard_normal(N)) + + # we have datetime/tz blocks in mgr + cons = mgr.consolidate() + assert cons.nblocks == 4 + cons = mgr.consolidate().get_numeric_data() + assert cons.nblocks == 1 + assert isinstance(cons.blocks[0].mgr_locs, BlockPlacement) + tm.assert_numpy_array_equal( + cons.blocks[0].mgr_locs.as_array, np.arange(len(cons.items), dtype=np.intp) + ) + + def test_reindex_items(self): + # mgr is not consolidated, f8 & f8-2 blocks + mgr = create_mgr("a: f8; b: i8; c: f8; d: i8; e: f8; f: bool; g: f8-2") + + reindexed = mgr.reindex_axis(["g", "c", "a", "d"], axis=0) + assert not reindexed.is_consolidated() + + tm.assert_index_equal(reindexed.items, Index(["g", "c", "a", "d"])) + tm.assert_almost_equal( + mgr.iget(6).internal_values(), reindexed.iget(0).internal_values() + ) + tm.assert_almost_equal( + mgr.iget(2).internal_values(), reindexed.iget(1).internal_values() + ) + tm.assert_almost_equal( + mgr.iget(0).internal_values(), reindexed.iget(2).internal_values() + ) + tm.assert_almost_equal( + mgr.iget(3).internal_values(), reindexed.iget(3).internal_values() + ) + + def test_get_numeric_data(self): + mgr = create_mgr( + "int: int; float: float; complex: complex;" + "str: object; bool: bool; obj: object; dt: datetime", + item_shape=(3,), + ) + mgr.iset(5, np.array([1, 2, 3], dtype=np.object_)) + + numeric = mgr.get_numeric_data() + tm.assert_index_equal(numeric.items, Index(["int", "float", "complex", "bool"])) + tm.assert_almost_equal( + mgr.iget(mgr.items.get_loc("float")).internal_values(), + numeric.iget(numeric.items.get_loc("float")).internal_values(), + ) + + # Check sharing + numeric.iset( + numeric.items.get_loc("float"), + np.array([100.0, 200.0, 300.0]), + inplace=True, + ) + tm.assert_almost_equal( + mgr.iget(mgr.items.get_loc("float")).internal_values(), + np.array([1.0, 1.0, 1.0]), + ) + + def test_get_bool_data(self): + mgr = create_mgr( + "int: int; float: float; complex: complex;" + "str: object; bool: bool; obj: object; dt: datetime", + item_shape=(3,), + ) + mgr.iset(6, np.array([True, False, True], dtype=np.object_)) + + bools = mgr.get_bool_data() + tm.assert_index_equal(bools.items, Index(["bool"])) + tm.assert_almost_equal( + mgr.iget(mgr.items.get_loc("bool")).internal_values(), + bools.iget(bools.items.get_loc("bool")).internal_values(), + ) + + bools.iset(0, np.array([True, False, True]), inplace=True) + tm.assert_numpy_array_equal( + mgr.iget(mgr.items.get_loc("bool")).internal_values(), + np.array([True, True, True]), + ) + + def test_unicode_repr_doesnt_raise(self): + repr(create_mgr("b,\u05d0: object")) + + @pytest.mark.parametrize( + "mgr_string", ["a,b,c: i8-1; d,e,f: i8-2", "a,a,a: i8-1; b,b,b: i8-2"] + ) + def test_equals(self, mgr_string): + # unique items + bm1 = create_mgr(mgr_string) + bm2 = BlockManager(bm1.blocks[::-1], bm1.axes) + assert bm1.equals(bm2) + + @pytest.mark.parametrize( + "mgr_string", + [ + "a:i8;b:f8", # basic case + "a:i8;b:f8;c:c8;d:b", # many types + "a:i8;e:dt;f:td;g:string", # more types + "a:i8;b:category;c:category2", # categories + "c:sparse;d:sparse_na;b:f8", # sparse + ], + ) + def test_equals_block_order_different_dtypes(self, mgr_string): + # GH 9330 + bm = create_mgr(mgr_string) + block_perms = itertools.permutations(bm.blocks) + for bm_perm in block_perms: + bm_this = BlockManager(tuple(bm_perm), bm.axes) + assert bm.equals(bm_this) + assert bm_this.equals(bm) + + def test_single_mgr_ctor(self): + mgr = create_single_mgr("f8", num_rows=5) + assert mgr.external_values().tolist() == [0.0, 1.0, 2.0, 3.0, 4.0] + + @pytest.mark.parametrize("value", [1, "True", [1, 2, 3], 5.0]) + def test_validate_bool_args(self, value): + bm1 = create_mgr("a,b,c: i8-1; d,e,f: i8-2") + + msg = ( + 'For argument "inplace" expected type bool, ' + f"received type {type(value).__name__}." + ) + with pytest.raises(ValueError, match=msg): + bm1.replace_list([1], [2], inplace=value) + + def test_iset_split_block(self): + bm = create_mgr("a,b,c: i8; d: f8") + bm._iset_split_block(0, np.array([0])) + tm.assert_numpy_array_equal( + bm.blklocs, np.array([0, 0, 1, 0], dtype="int64" if IS64 else "int32") + ) + # First indexer currently does not have a block associated with it in case + tm.assert_numpy_array_equal( + bm.blknos, np.array([0, 0, 0, 1], dtype="int64" if IS64 else "int32") + ) + assert len(bm.blocks) == 2 + + def test_iset_split_block_values(self): + bm = create_mgr("a,b,c: i8; d: f8") + bm._iset_split_block(0, np.array([0]), np.array([list(range(10))])) + tm.assert_numpy_array_equal( + bm.blklocs, np.array([0, 0, 1, 0], dtype="int64" if IS64 else "int32") + ) + # First indexer currently does not have a block associated with it in case + tm.assert_numpy_array_equal( + bm.blknos, np.array([0, 2, 2, 1], dtype="int64" if IS64 else "int32") + ) + assert len(bm.blocks) == 3 + + +def _as_array(mgr): + if mgr.ndim == 1: + return mgr.external_values() + return mgr.as_array().T + + +class TestIndexing: + # Nosetests-style data-driven tests. + # + # This test applies different indexing routines to block managers and + # compares the outcome to the result of same operations on np.ndarray. + # + # NOTE: sparse (SparseBlock with fill_value != np.nan) fail a lot of tests + # and are disabled. + + MANAGERS = [ + create_single_mgr("f8", N), + create_single_mgr("i8", N), + # 2-dim + create_mgr("a,b,c,d,e,f: f8", item_shape=(N,)), + create_mgr("a,b,c,d,e,f: i8", item_shape=(N,)), + create_mgr("a,b: f8; c,d: i8; e,f: string", item_shape=(N,)), + create_mgr("a,b: f8; c,d: i8; e,f: f8", item_shape=(N,)), + ] + + @pytest.mark.parametrize("mgr", MANAGERS) + def test_get_slice(self, mgr): + def assert_slice_ok(mgr, axis, slobj): + mat = _as_array(mgr) + + # we maybe using an ndarray to test slicing and + # might not be the full length of the axis + if isinstance(slobj, np.ndarray): + ax = mgr.axes[axis] + if len(ax) and len(slobj) and len(slobj) != len(ax): + slobj = np.concatenate( + [slobj, np.zeros(len(ax) - len(slobj), dtype=bool)] + ) + + if isinstance(slobj, slice): + sliced = mgr.get_slice(slobj, axis=axis) + elif ( + mgr.ndim == 1 + and axis == 0 + and isinstance(slobj, np.ndarray) + and slobj.dtype == bool + ): + sliced = mgr.get_rows_with_mask(slobj) + else: + # BlockManager doesn't support non-slice, SingleBlockManager + # doesn't support axis > 0 + raise TypeError(slobj) + + mat_slobj = (slice(None),) * axis + (slobj,) + tm.assert_numpy_array_equal( + mat[mat_slobj], _as_array(sliced), check_dtype=False + ) + tm.assert_index_equal(mgr.axes[axis][slobj], sliced.axes[axis]) + + assert mgr.ndim <= 2, mgr.ndim + for ax in range(mgr.ndim): + # slice + assert_slice_ok(mgr, ax, slice(None)) + assert_slice_ok(mgr, ax, slice(3)) + assert_slice_ok(mgr, ax, slice(100)) + assert_slice_ok(mgr, ax, slice(1, 4)) + assert_slice_ok(mgr, ax, slice(3, 0, -2)) + + if mgr.ndim < 2: + # 2D only support slice objects + + # boolean mask + assert_slice_ok(mgr, ax, np.ones(mgr.shape[ax], dtype=np.bool_)) + assert_slice_ok(mgr, ax, np.zeros(mgr.shape[ax], dtype=np.bool_)) + + if mgr.shape[ax] >= 3: + assert_slice_ok(mgr, ax, np.arange(mgr.shape[ax]) % 3 == 0) + assert_slice_ok( + mgr, ax, np.array([True, True, False], dtype=np.bool_) + ) + + @pytest.mark.parametrize("mgr", MANAGERS) + def test_take(self, mgr): + def assert_take_ok(mgr, axis, indexer): + mat = _as_array(mgr) + taken = mgr.take(indexer, axis) + tm.assert_numpy_array_equal( + np.take(mat, indexer, axis), _as_array(taken), check_dtype=False + ) + tm.assert_index_equal(mgr.axes[axis].take(indexer), taken.axes[axis]) + + for ax in range(mgr.ndim): + # take/fancy indexer + assert_take_ok(mgr, ax, indexer=np.array([], dtype=np.intp)) + assert_take_ok(mgr, ax, indexer=np.array([0, 0, 0], dtype=np.intp)) + assert_take_ok( + mgr, ax, indexer=np.array(list(range(mgr.shape[ax])), dtype=np.intp) + ) + + if mgr.shape[ax] >= 3: + assert_take_ok(mgr, ax, indexer=np.array([0, 1, 2], dtype=np.intp)) + assert_take_ok(mgr, ax, indexer=np.array([-1, -2, -3], dtype=np.intp)) + + @pytest.mark.parametrize("mgr", MANAGERS) + @pytest.mark.parametrize("fill_value", [None, np.nan, 100.0]) + def test_reindex_axis(self, fill_value, mgr): + def assert_reindex_axis_is_ok(mgr, axis, new_labels, fill_value): + mat = _as_array(mgr) + indexer = mgr.axes[axis].get_indexer_for(new_labels) + + reindexed = mgr.reindex_axis(new_labels, axis, fill_value=fill_value) + tm.assert_numpy_array_equal( + algos.take_nd(mat, indexer, axis, fill_value=fill_value), + _as_array(reindexed), + check_dtype=False, + ) + tm.assert_index_equal(reindexed.axes[axis], new_labels) + + for ax in range(mgr.ndim): + assert_reindex_axis_is_ok(mgr, ax, Index([]), fill_value) + assert_reindex_axis_is_ok(mgr, ax, mgr.axes[ax], fill_value) + assert_reindex_axis_is_ok(mgr, ax, mgr.axes[ax][[0, 0, 0]], fill_value) + assert_reindex_axis_is_ok(mgr, ax, Index(["foo", "bar", "baz"]), fill_value) + assert_reindex_axis_is_ok( + mgr, ax, Index(["foo", mgr.axes[ax][0], "baz"]), fill_value + ) + + if mgr.shape[ax] >= 3: + assert_reindex_axis_is_ok(mgr, ax, mgr.axes[ax][:-3], fill_value) + assert_reindex_axis_is_ok(mgr, ax, mgr.axes[ax][-3::-1], fill_value) + assert_reindex_axis_is_ok( + mgr, ax, mgr.axes[ax][[0, 1, 2, 0, 1, 2]], fill_value + ) + + @pytest.mark.parametrize("mgr", MANAGERS) + @pytest.mark.parametrize("fill_value", [None, np.nan, 100.0]) + def test_reindex_indexer(self, fill_value, mgr): + def assert_reindex_indexer_is_ok(mgr, axis, new_labels, indexer, fill_value): + mat = _as_array(mgr) + reindexed_mat = algos.take_nd(mat, indexer, axis, fill_value=fill_value) + reindexed = mgr.reindex_indexer( + new_labels, indexer, axis, fill_value=fill_value + ) + tm.assert_numpy_array_equal( + reindexed_mat, _as_array(reindexed), check_dtype=False + ) + tm.assert_index_equal(reindexed.axes[axis], new_labels) + + for ax in range(mgr.ndim): + assert_reindex_indexer_is_ok( + mgr, ax, Index([]), np.array([], dtype=np.intp), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, ax, mgr.axes[ax], np.arange(mgr.shape[ax]), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, + ax, + Index(["foo"] * mgr.shape[ax]), + np.arange(mgr.shape[ax]), + fill_value, + ) + assert_reindex_indexer_is_ok( + mgr, ax, mgr.axes[ax][::-1], np.arange(mgr.shape[ax]), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, ax, mgr.axes[ax], np.arange(mgr.shape[ax])[::-1], fill_value + ) + assert_reindex_indexer_is_ok( + mgr, ax, Index(["foo", "bar", "baz"]), np.array([0, 0, 0]), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, ax, Index(["foo", "bar", "baz"]), np.array([-1, 0, -1]), fill_value + ) + assert_reindex_indexer_is_ok( + mgr, + ax, + Index(["foo", mgr.axes[ax][0], "baz"]), + np.array([-1, -1, -1]), + fill_value, + ) + + if mgr.shape[ax] >= 3: + assert_reindex_indexer_is_ok( + mgr, + ax, + Index(["foo", "bar", "baz"]), + np.array([0, 1, 2]), + fill_value, + ) + + +class TestBlockPlacement: + @pytest.mark.parametrize( + "slc, expected", + [ + (slice(0, 4), 4), + (slice(0, 4, 2), 2), + (slice(0, 3, 2), 2), + (slice(0, 1, 2), 1), + (slice(1, 0, -1), 1), + ], + ) + def test_slice_len(self, slc, expected): + assert len(BlockPlacement(slc)) == expected + + @pytest.mark.parametrize("slc", [slice(1, 1, 0), slice(1, 2, 0)]) + def test_zero_step_raises(self, slc): + msg = "slice step cannot be zero" + with pytest.raises(ValueError, match=msg): + BlockPlacement(slc) + + def test_slice_canonize_negative_stop(self): + # GH#37524 negative stop is OK with negative step and positive start + slc = slice(3, -1, -2) + + bp = BlockPlacement(slc) + assert bp.indexer == slice(3, None, -2) + + @pytest.mark.parametrize( + "slc", + [ + slice(None, None), + slice(10, None), + slice(None, None, -1), + slice(None, 10, -1), + # These are "unbounded" because negative index will + # change depending on container shape. + slice(-1, None), + slice(None, -1), + slice(-1, -1), + slice(-1, None, -1), + slice(None, -1, -1), + slice(-1, -1, -1), + ], + ) + def test_unbounded_slice_raises(self, slc): + msg = "unbounded slice" + with pytest.raises(ValueError, match=msg): + BlockPlacement(slc) + + @pytest.mark.parametrize( + "slc", + [ + slice(0, 0), + slice(100, 0), + slice(100, 100), + slice(100, 100, -1), + slice(0, 100, -1), + ], + ) + def test_not_slice_like_slices(self, slc): + assert not BlockPlacement(slc).is_slice_like + + @pytest.mark.parametrize( + "arr, slc", + [ + ([0], slice(0, 1, 1)), + ([100], slice(100, 101, 1)), + ([0, 1, 2], slice(0, 3, 1)), + ([0, 5, 10], slice(0, 15, 5)), + ([0, 100], slice(0, 200, 100)), + ([2, 1], slice(2, 0, -1)), + ], + ) + def test_array_to_slice_conversion(self, arr, slc): + assert BlockPlacement(arr).as_slice == slc + + @pytest.mark.parametrize( + "arr", + [ + [], + [-1], + [-1, -2, -3], + [-10], + [-1, 0, 1, 2], + [-2, 0, 2, 4], + [1, 0, -1], + [1, 1, 1], + ], + ) + def test_not_slice_like_arrays(self, arr): + assert not BlockPlacement(arr).is_slice_like + + @pytest.mark.parametrize( + "slc, expected", + [(slice(0, 3), [0, 1, 2]), (slice(0, 0), []), (slice(3, 0), [])], + ) + def test_slice_iter(self, slc, expected): + assert list(BlockPlacement(slc)) == expected + + @pytest.mark.parametrize( + "slc, arr", + [ + (slice(0, 3), [0, 1, 2]), + (slice(0, 0), []), + (slice(3, 0), []), + (slice(3, 0, -1), [3, 2, 1]), + ], + ) + def test_slice_to_array_conversion(self, slc, arr): + tm.assert_numpy_array_equal( + BlockPlacement(slc).as_array, np.asarray(arr, dtype=np.intp) + ) + + def test_blockplacement_add(self): + bpl = BlockPlacement(slice(0, 5)) + assert bpl.add(1).as_slice == slice(1, 6, 1) + assert bpl.add(np.arange(5)).as_slice == slice(0, 10, 2) + assert list(bpl.add(np.arange(5, 0, -1))) == [5, 5, 5, 5, 5] + + @pytest.mark.parametrize( + "val, inc, expected", + [ + (slice(0, 0), 0, []), + (slice(1, 4), 0, [1, 2, 3]), + (slice(3, 0, -1), 0, [3, 2, 1]), + ([1, 2, 4], 0, [1, 2, 4]), + (slice(0, 0), 10, []), + (slice(1, 4), 10, [11, 12, 13]), + (slice(3, 0, -1), 10, [13, 12, 11]), + ([1, 2, 4], 10, [11, 12, 14]), + (slice(0, 0), -1, []), + (slice(1, 4), -1, [0, 1, 2]), + ([1, 2, 4], -1, [0, 1, 3]), + ], + ) + def test_blockplacement_add_int(self, val, inc, expected): + assert list(BlockPlacement(val).add(inc)) == expected + + @pytest.mark.parametrize("val", [slice(1, 4), [1, 2, 4]]) + def test_blockplacement_add_int_raises(self, val): + msg = "iadd causes length change" + with pytest.raises(ValueError, match=msg): + BlockPlacement(val).add(-10) + + +class TestCanHoldElement: + @pytest.fixture( + params=[ + lambda x: x, + lambda x: x.to_series(), + lambda x: x._data, + lambda x: list(x), + lambda x: x.astype(object), + lambda x: np.asarray(x), + lambda x: x[0], + lambda x: x[:0], + ] + ) + def element(self, request): + """ + Functions that take an Index and return an element that should have + blk._can_hold_element(element) for a Block with this index's dtype. + """ + return request.param + + def test_datetime_block_can_hold_element(self): + block = create_block("datetime", [0]) + + assert block._can_hold_element([]) + + # We will check that block._can_hold_element iff arr.__setitem__ works + arr = pd.array(block.values.ravel()) + + # coerce None + assert block._can_hold_element(None) + arr[0] = None + assert arr[0] is pd.NaT + + # coerce different types of datetime objects + vals = [np.datetime64("2010-10-10"), datetime(2010, 10, 10)] + for val in vals: + assert block._can_hold_element(val) + arr[0] = val + + val = date(2010, 10, 10) + assert not block._can_hold_element(val) + + msg = ( + "value should be a 'Timestamp', 'NaT', " + "or array of those. Got 'date' instead." + ) + with pytest.raises(TypeError, match=msg): + arr[0] = val + + @pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float64]) + def test_interval_can_hold_element_emptylist(self, dtype, element): + arr = np.array([1, 3, 4], dtype=dtype) + ii = IntervalIndex.from_breaks(arr) + blk = new_block(ii._data, BlockPlacement([1]), ndim=2) + + assert blk._can_hold_element([]) + # TODO: check this holds for all blocks + + @pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.float64]) + def test_interval_can_hold_element(self, dtype, element): + arr = np.array([1, 3, 4, 9], dtype=dtype) + ii = IntervalIndex.from_breaks(arr) + blk = new_block(ii._data, BlockPlacement([1]), ndim=2) + + elem = element(ii) + self.check_series_setitem(elem, ii, True) + assert blk._can_hold_element(elem) + + # Careful: to get the expected Series-inplace behavior we need + # `elem` to not have the same length as `arr` + ii2 = IntervalIndex.from_breaks(arr[:-1], closed="neither") + elem = element(ii2) + with pytest.raises(TypeError, match="Invalid value"): + self.check_series_setitem(elem, ii, False) + assert not blk._can_hold_element(elem) + + ii3 = IntervalIndex.from_breaks([Timestamp(1), Timestamp(3), Timestamp(4)]) + elem = element(ii3) + with pytest.raises(TypeError, match="Invalid value"): + self.check_series_setitem(elem, ii, False) + assert not blk._can_hold_element(elem) + + ii4 = IntervalIndex.from_breaks([Timedelta(1), Timedelta(3), Timedelta(4)]) + elem = element(ii4) + with pytest.raises(TypeError, match="Invalid value"): + self.check_series_setitem(elem, ii, False) + assert not blk._can_hold_element(elem) + + def test_period_can_hold_element_emptylist(self): + pi = period_range("2016", periods=3, freq="Y") + blk = new_block(pi._data.reshape(1, 3), BlockPlacement([1]), ndim=2) + + assert blk._can_hold_element([]) + + def test_period_can_hold_element(self, element): + pi = period_range("2016", periods=3, freq="Y") + + elem = element(pi) + self.check_series_setitem(elem, pi, True) + + # Careful: to get the expected Series-inplace behavior we need + # `elem` to not have the same length as `arr` + pi2 = pi.asfreq("D")[:-1] + elem = element(pi2) + with pytest.raises(TypeError, match="Invalid value"): + self.check_series_setitem(elem, pi, False) + + dti = pi.to_timestamp("s")[:-1] + elem = element(dti) + with pytest.raises(TypeError, match="Invalid value"): + self.check_series_setitem(elem, pi, False) + + def test_period_reindex_axis(self): + # GH#60273 Test reindexing of block with PeriodDtype + pi = period_range("2020", periods=5, freq="Y") + blk = new_block(pi._data.reshape(5, 1), BlockPlacement(slice(5)), ndim=2) + mgr = BlockManager(blocks=(blk,), axes=[Index(np.arange(5)), Index(["a"])]) + reindexed = mgr.reindex_axis(Index([0, 2, 4]), axis=0) + result = DataFrame._from_mgr(reindexed, axes=reindexed.axes) + expected = DataFrame([[pi[0], pi[2], pi[4]]], columns=[0, 2, 4], index=["a"]) + tm.assert_frame_equal(result, expected) + + def check_can_hold_element(self, obj, elem, inplace: bool): + blk = obj._mgr.blocks[0] + if inplace: + assert blk._can_hold_element(elem) + else: + assert not blk._can_hold_element(elem) + + def check_series_setitem(self, elem, index: Index, inplace: bool): + arr = index._data.copy() + ser = Series(arr, copy=False) + + self.check_can_hold_element(ser, elem, inplace) + + if is_scalar(elem): + ser[0] = elem + else: + ser[: len(elem)] = elem + + if inplace: + assert ser._values is arr # i.e. setting was done inplace + else: + assert ser.dtype == object + + +class TestShouldStore: + def test_should_store_categorical(self): + cat = Categorical(["A", "B", "C"]) + df = DataFrame(cat) + blk = df._mgr.blocks[0] + + # matching dtype + assert blk.should_store(cat) + assert blk.should_store(cat[:-1]) + + # different dtype + assert not blk.should_store(cat.as_ordered()) + + # ndarray instead of Categorical + assert not blk.should_store(np.asarray(cat)) + + +def test_validate_ndim(): + values = np.array([1.0, 2.0]) + placement = BlockPlacement(slice(2)) + msg = r"Wrong number of dimensions. values.ndim != ndim \[1 != 2\]" + + depr_msg = "make_block is deprecated" + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning(Pandas4Warning, match=depr_msg): + make_block(values, placement, ndim=2) + + +def test_block_shape(): + idx = Index([0, 1, 2, 3, 4]) + a = Series([1, 2, 3]).reindex(idx) + b = Series(Categorical([1, 2, 3])).reindex(idx) + + assert a._mgr.blocks[0].mgr_locs.indexer == b._mgr.blocks[0].mgr_locs.indexer + + +def test_make_block_no_pandas_array(block_maker): + # https://github.com/pandas-dev/pandas/pull/24866 + arr = pd.arrays.NumpyExtensionArray(np.array([1, 2])) + + depr_msg = "make_block is deprecated" + warn = DeprecationWarning if block_maker is make_block else None + + # NumpyExtensionArray, no dtype + with tm.assert_produces_warning(warn, match=depr_msg): + result = block_maker(arr, BlockPlacement(slice(len(arr))), ndim=arr.ndim) + assert result.dtype.kind in ["i", "u"] + + if block_maker is make_block: + # new_block requires caller to unwrap NumpyExtensionArray + assert result.is_extension is False + + # NumpyExtensionArray, NumpyEADtype + with tm.assert_produces_warning(warn, match=depr_msg): + result = block_maker(arr, slice(len(arr)), dtype=arr.dtype, ndim=arr.ndim) + assert result.dtype.kind in ["i", "u"] + assert result.is_extension is False + + # new_block no longer accepts dtype keyword + # ndarray, NumpyEADtype + with tm.assert_produces_warning(warn, match=depr_msg): + result = block_maker( + arr.to_numpy(), slice(len(arr)), dtype=arr.dtype, ndim=arr.ndim + ) + assert result.dtype.kind in ["i", "u"] + assert result.is_extension is False diff --git a/pandas/tests/io/__init__.py b/pandas/tests/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/io/conftest.py b/pandas/tests/io/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce44e87570a77aa31b2cc534ced2c2797e66c74 --- /dev/null +++ b/pandas/tests/io/conftest.py @@ -0,0 +1,197 @@ +import uuid + +import pytest + +from pandas.compat import ( + is_ci_environment, + is_platform_arm, + is_platform_mac, + is_platform_windows, +) +import pandas.util._test_decorators as td + +import pandas.io.common as icom +from pandas.io.parsers import read_csv + + +@pytest.fixture +def compression_to_extension(): + return {value: key for key, value in icom.extension_to_compression.items()} + + +@pytest.fixture +def tips_file(datapath): + """Path to the tips dataset""" + return datapath("io", "data", "csv", "tips.csv") + + +@pytest.fixture +def jsonl_file(datapath): + """Path to a JSONL dataset""" + return datapath("io", "parser", "data", "items.jsonl") + + +@pytest.fixture +def salaries_table(datapath): + """DataFrame with the salaries dataset""" + return read_csv(datapath("io", "parser", "data", "salaries.csv"), sep="\t") + + +@pytest.fixture +def feather_file(datapath): + return datapath("io", "data", "feather", "feather-0_3_1.feather") + + +@pytest.fixture +def xml_file(datapath): + return datapath("io", "data", "xml", "books.xml") + + +@pytest.fixture(scope="session") +def aws_credentials(monkeysession): + """Mocked AWS Credentials for moto.""" + monkeysession.setenv("AWS_ACCESS_KEY_ID", "testing") + monkeysession.setenv("AWS_SECRET_ACCESS_KEY", "testing") + monkeysession.setenv("AWS_SECURITY_TOKEN", "testing") + monkeysession.setenv("AWS_SESSION_AWS_SESSION_TOKEN", "testing") + monkeysession.setenv("AWS_DEFAULT_REGION", "us-east-1") + + +@pytest.fixture(scope="session") +def moto_server(aws_credentials): + # use service container for Linux on GitHub Actions + if is_ci_environment() and not ( + is_platform_mac() or is_platform_arm() or is_platform_windows() + ): + yield "http://localhost:5000" + else: + moto_server = pytest.importorskip("moto.server") + server = moto_server.ThreadedMotoServer(port=0) + server.start() + host, port = server.get_host_and_port() + yield f"http://{host}:{port}" + server.stop() + + +@pytest.fixture +def moto_s3_resource(moto_server): + boto3 = pytest.importorskip("boto3") + s3 = boto3.resource("s3", endpoint_url=moto_server) + return s3 + + +@pytest.fixture(scope="session") +def s3so(moto_server): + return { + "client_kwargs": { + "endpoint_url": moto_server, + } + } + + +@pytest.fixture +def s3_bucket_public(moto_s3_resource): + """ + Create a public S3 bucket using moto. + """ + bucket_name = f"pandas-test-{uuid.uuid4()}" + bucket = moto_s3_resource.Bucket(bucket_name) + bucket.create(ACL="public-read") + yield bucket + bucket.objects.delete() + bucket.delete() + + +@pytest.fixture +def s3_bucket_private(moto_s3_resource): + """ + Create a private S3 bucket using moto. + """ + bucket_name = f"cant_get_it-{uuid.uuid4()}" + bucket = moto_s3_resource.Bucket(bucket_name) + bucket.create(ACL="private") + yield bucket + bucket.objects.delete() + bucket.delete() + + +@pytest.fixture +def s3_bucket_public_with_data( + s3_bucket_public, tips_file, jsonl_file, feather_file, xml_file +): + """ + The following datasets + are loaded. + + - tips.csv + - tips.csv.gz + - tips.csv.bz2 + - items.jsonl + """ + test_s3_files = [ + ("tips#1.csv", tips_file), + ("tips.csv", tips_file), + ("tips.csv.gz", tips_file + ".gz"), + ("tips.csv.bz2", tips_file + ".bz2"), + ("items.jsonl", jsonl_file), + ("simple_dataset.feather", feather_file), + ("books.xml", xml_file), + ] + for s3_key, file_name in test_s3_files: + with open(file_name, "rb") as f: + s3_bucket_public.put_object(Key=s3_key, Body=f) + return s3_bucket_public + + +@pytest.fixture +def s3_bucket_private_with_data( + s3_bucket_private, tips_file, jsonl_file, feather_file, xml_file +): + """ + The following datasets + are loaded. + + - tips.csv + - tips.csv.gz + - tips.csv.bz2 + - items.jsonl + """ + test_s3_files = [ + ("tips#1.csv", tips_file), + ("tips.csv", tips_file), + ("tips.csv.gz", tips_file + ".gz"), + ("tips.csv.bz2", tips_file + ".bz2"), + ("items.jsonl", jsonl_file), + ("simple_dataset.feather", feather_file), + ("books.xml", xml_file), + ] + for s3_key, file_name in test_s3_files: + with open(file_name, "rb") as f: + s3_bucket_private.put_object(Key=s3_key, Body=f) + return s3_bucket_private + + +_compression_formats_params = [ + (".no_compress", None), + ("", None), + (".gz", "gzip"), + (".GZ", "gzip"), + (".bz2", "bz2"), + (".BZ2", "bz2"), + (".zip", "zip"), + (".ZIP", "zip"), + (".xz", "xz"), + (".XZ", "xz"), + pytest.param((".zst", "zstd"), marks=td.skip_if_no("zstandard")), + pytest.param((".ZST", "zstd"), marks=td.skip_if_no("zstandard")), +] + + +@pytest.fixture(params=_compression_formats_params[1:]) +def compression_format(request): + return request.param + + +@pytest.fixture(params=_compression_formats_params) +def compression_ext(request): + return request.param[0] diff --git a/pandas/tests/io/generate_legacy_storage_files.py b/pandas/tests/io/generate_legacy_storage_files.py new file mode 100644 index 0000000000000000000000000000000000000000..04f176a550edf0b927727788e958421b422c26d0 --- /dev/null +++ b/pandas/tests/io/generate_legacy_storage_files.py @@ -0,0 +1,421 @@ +""" +self-contained to write legacy storage pickle files + +To use this script. Create an environment where you want +generate pickles, say its for 0.20.3, with your pandas clone +in ~/pandas + +. activate pandas_0.20.3 +cd ~/pandas/pandas + +$ python -m tests.io.generate_legacy_storage_files \ + tests/io/data/legacy_pickle/0.20.3/ pickle + +This script generates a storage file for the current arch, system, +and python version + pandas version: 0.20.3 + output dir : pandas/pandas/tests/io/data/legacy_pickle/0.20.3/ + storage format: pickle +created pickle file: 0.20.3_x86_64_darwin_3.5.2.pickle + +The idea here is you are using the *current* version of the +generate_legacy_storage_files with an *older* version of pandas to +generate a pickle file. We will then check this file into a current +branch, and test using test_pickle.py. This will load the *older* +pickles and test versus the current data that is generated +(with main). These are then compared. + +If we have cases where we changed the signature (e.g. we renamed +offset -> freq in Timestamp). Then we have to conditionally execute +in the generate_legacy_storage_files.py to make it +run under the older AND the newer version. + +""" + +from datetime import timedelta +import os +import pickle +import platform as pl +import sys + +# Remove script directory from path, otherwise Python will try to +# import the JSON test directory as the json module +sys.path.pop(0) + +import numpy as np + +import pandas +from pandas import ( + Categorical, + DataFrame, + Index, + MultiIndex, + NaT, + Period, + RangeIndex, + Series, + Timestamp, + bdate_range, + date_range, + interval_range, + period_range, + timedelta_range, +) +from pandas.arrays import SparseArray + +from pandas.tseries.offsets import ( + FY5253, + BusinessDay, + BusinessHour, + CustomBusinessDay, + DateOffset, + Day, + Easter, + Hour, + LastWeekOfMonth, + Minute, + MonthBegin, + MonthEnd, + QuarterBegin, + QuarterEnd, + SemiMonthBegin, + SemiMonthEnd, + Week, + WeekOfMonth, + YearBegin, + YearEnd, +) + + +def _create_sp_series(): + nan = np.nan + + # nan-based + arr = np.arange(15, dtype=np.float64) + arr[7:12] = nan + arr[-1:] = nan + + bseries = Series(SparseArray(arr, kind="block")) + bseries.name = "bseries" + return bseries + + +def _create_sp_tsseries(): + nan = np.nan + + # nan-based + arr = np.arange(15, dtype=np.float64) + arr[7:12] = nan + arr[-1:] = nan + + date_index = bdate_range("1/1/2011", periods=len(arr)) + bseries = Series(SparseArray(arr, kind="block"), index=date_index) + bseries.name = "btsseries" + return bseries + + +def _create_sp_frame(): + nan = np.nan + + data = { + "A": [nan, nan, nan, 0, 1, 2, 3, 4, 5, 6], + "B": [0, 1, 2, nan, nan, nan, 3, 4, 5, 6], + "C": np.arange(10).astype(np.int64), + "D": [0, 1, 2, 3, 4, 5, nan, nan, nan, nan], + } + + dates = bdate_range("1/1/2011", periods=10) + return DataFrame(data, index=dates).apply(SparseArray) + + +def create_pickle_data(test: bool = True): + """create the pickle data""" + data = { + "A": [0.0, 1.0, 2.0, 3.0, np.nan], + "B": [0, 1, 0, 1, 0], + "C": ["foo1", "foo2", "foo3", "foo4", "foo5"], + "D": date_range("1/1/2009", periods=5), + "E": [0.0, 1, Timestamp("20100101"), "foo", 2.0], + } + + scalars = {"timestamp": Timestamp("20130101"), "period": Period("2012", "M")} + + index = { + "int": Index(np.arange(10)), + "date": date_range("20130101", periods=10), + "period": period_range("2013-01-01", freq="M", periods=10), + "float": Index(np.arange(10, dtype=np.float64)), + "uint": Index(np.arange(10, dtype=np.uint64)), + "timedelta": timedelta_range("00:00:00", freq="30min", periods=10), + "string": Index(["foo", "bar", "baz", "qux", "quux"], dtype="string"), + } + + index["range"] = RangeIndex(10) + + index["interval"] = interval_range(0, periods=10) + + mi = { + "reg2": MultiIndex.from_tuples( + tuple( + zip( + *[ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] + ) + ), + names=["first", "second"], + ) + } + + series = { + "float": Series(data["A"]), + "int": Series(data["B"]), + "mixed": Series(data["E"]), + "ts": Series( + np.arange(10).astype(np.int64), index=date_range("20130101", periods=10) + ), + "mi": Series( + np.arange(5).astype(np.float64), + index=MultiIndex.from_tuples( + tuple(zip(*[[1, 1, 2, 2, 2], [3, 4, 3, 4, 5]])), names=["one", "two"] + ), + ), + "dup": Series(np.arange(5).astype(np.float64), index=["A", "B", "C", "D", "A"]), + "cat": Series(Categorical(["foo", "bar", "baz"])), + "dt": Series(date_range("20130101", periods=5)), + "dt_tz": Series(date_range("20130101", periods=5, tz="US/Eastern")), + "period": Series([Period("2000Q1")] * 5), + "string": Series(["foo", "bar", "baz", "qux", "quux"], dtype="string"), + } + + mixed_dup_df = DataFrame(data) + mixed_dup_df.columns = list("ABCDA") + frame = { + "float": DataFrame({"A": series["float"], "B": series["float"] + 1}), + "int": DataFrame({"A": series["int"], "B": series["int"] + 1}), + "mixed": DataFrame({k: data[k] for k in ["A", "B", "C", "D"]}), + "mi": DataFrame( + {"A": np.arange(5).astype(np.float64), "B": np.arange(5).astype(np.int64)}, + index=MultiIndex.from_tuples( + tuple( + zip( + *[ + ["bar", "bar", "baz", "baz", "baz"], + ["one", "two", "one", "two", "three"], + ] + ) + ), + names=["first", "second"], + ), + ), + "dup": DataFrame( + np.arange(15).reshape(5, 3).astype(np.float64), columns=["A", "B", "A"] + ), + "cat_onecol": DataFrame({"A": Categorical(["foo", "bar"])}), + "cat_and_float": DataFrame( + { + "A": Categorical(["foo", "bar", "baz"]), + "B": np.arange(3).astype(np.int64), + } + ), + "mixed_dup": mixed_dup_df, + "dt_mixed_tzs": DataFrame( + { + "A": Timestamp("20130102", tz="US/Eastern"), + "B": Timestamp("20130603", tz="CET"), + }, + index=range(5), + ), + "dt_mixed2_tzs": DataFrame( + { + "A": Timestamp("20130102", tz="US/Eastern"), + "B": Timestamp("20130603", tz="CET"), + "C": Timestamp("20130603", tz="UTC"), + }, + index=range(5), + ), + "string": DataFrame( + { + "A": Series(["foo", "bar", "baz", "qux", "quux"], dtype="string"), + "B": Series(["one", "two", "one", "two", "three"], dtype="string"), + } + ), + } + + cat = { + "int8": Categorical(list("abcdefg")), + "int16": Categorical(np.arange(1000)), + "int32": Categorical(np.arange(10000)), + } + + timestamp = { + "normal": Timestamp("2011-01-01"), + "nat": NaT, + "tz": Timestamp("2011-01-01", tz="US/Eastern"), + } + if test: + # kept because those are present in the legacy pickles (<= 1.4) + timestamp["freq"] = Timestamp("2011-01-01") + timestamp["both"] = Timestamp("2011-01-01", tz="Asia/Tokyo") + + off = { + "DateOffset": DateOffset(years=1), + "DateOffset_h_ns": DateOffset(hour=6, nanoseconds=5824), + "BusinessDay": BusinessDay(offset=timedelta(seconds=9)), + "BusinessHour": BusinessHour(normalize=True, n=6, end="15:14"), + "CustomBusinessDay": CustomBusinessDay(weekmask="Mon Fri"), + "SemiMonthBegin": SemiMonthBegin(day_of_month=9), + "SemiMonthEnd": SemiMonthEnd(day_of_month=24), + "MonthBegin": MonthBegin(1), + "MonthEnd": MonthEnd(1), + "QuarterBegin": QuarterBegin(1), + "QuarterEnd": QuarterEnd(1), + "Day": Day(1), + "YearBegin": YearBegin(1), + "YearEnd": YearEnd(1), + "Week": Week(1), + "Week_Tues": Week(2, normalize=False, weekday=1), + "WeekOfMonth": WeekOfMonth(week=3, weekday=4), + "LastWeekOfMonth": LastWeekOfMonth(n=1, weekday=3), + "FY5253": FY5253(n=2, weekday=6, startingMonth=7, variation="last"), + "Easter": Easter(), + "Hour": Hour(1), + "Minute": Minute(1), + } + + return { + "series": series, + "frame": frame, + "index": index, + "scalars": scalars, + "mi": mi, + "sp_series": {"float": _create_sp_series(), "ts": _create_sp_tsseries()}, + "sp_frame": {"float": _create_sp_frame()}, + "cat": cat, + "timestamp": timestamp, + "offsets": off, + } + + +def create_dataframe_all_types(): + timestamps = Series( + [ + Timestamp("2013-01-01"), + NaT, + Timestamp("2013-01-03"), + Timestamp("2013-01-04"), + Timestamp("2013-01-05"), + ] + ) + timedeltas = timestamps - timestamps[0] + + data = { + # "string": Series( + # ["a", "b", "c", None, "e"], dtype=StringDtype(na_value=np.nan) + # ), + # "object": Series(["a", "b", "c", None, "e"], dtype=object), + # "object_nan": Series(["a", "b", "c", np.nan, "e"], dtype=object), + "int": list(range(1, 6)), + "uint64": np.arange(3, 8).astype("uint64"), + "float": [0.1, 0.2, 0.3, 0.4, np.nan], + "float32": Series([0.1, 0.2, 0.3, 0.4, np.nan], dtype="float32"), + "bool": [True, False, True, False, True], + "datetime_ns": timestamps.dt.as_unit("ns"), + "datetime_us": timestamps.dt.as_unit("us"), + "datetime_ms": timestamps.dt.as_unit("ms"), + "datetime_s": timestamps.dt.as_unit("s"), + "datetimetz_ns": timestamps.dt.tz_localize("US/Eastern").dt.as_unit("ns"), + "datetimetz_us": timestamps.dt.tz_localize("US/Eastern").dt.as_unit("us"), + "timedelta_ns": timedeltas.dt.as_unit("ns"), + "timedelta_us": timedeltas.dt.as_unit("us"), + "timedelta_ms": timedeltas.dt.as_unit("ms"), + "timedelta_s": timedeltas.dt.as_unit("s"), + # "categorical": Categorical( + # Series( + # ["foo", "bar", "baz",np.nan,"foo"],dtype=StringDtype(na_value=np.nan) + # ) + # ), + # "categorical_object": Categorical( + # Series(["foo", "bar", "baz", np.nan, "foo"], dtype=object) + # ), + "categorical_int": Categorical([1, 2, 3, np.nan, 1]), + } + return DataFrame(data) + + +def platform_name(): + return "_".join( + [ + str(pandas.__version__), + str(pl.machine()), + str(pl.system().lower()), + str(pl.python_version()), + ] + ) + + +def write_legacy_pickles(output_dir): + pth = f"{platform_name()}.pickle" + + with open(os.path.join(output_dir, pth), "wb") as fh: + pickle.dump(create_pickle_data(test=False), fh, pickle.DEFAULT_PROTOCOL) + + print(f"created pickle file: {pth}") + + +def write_legacy_hdf(output_dir, format): + import tables + + pth = f"{platform_name()}_pytables-{tables.__version__}_{format}.h5" + + df = create_dataframe_all_types() + if format == "fixed": + # df = df.drop(columns=["categorical", "categorical_object", "categorical_int"]) + df = df.drop(columns=["categorical_int"]) + complevel = 9 if format == "table" else None + df.to_hdf( + os.path.join(output_dir, pth), + key="df_alltypes", + format=format, + complevel=complevel, + ) + + print(f"created hdf file: {pth}") + + +def write_legacy_file(): + # force our cwd to be the first searched + sys.path.insert(0, "") + + if not 3 <= len(sys.argv) <= 4: + sys.exit( + "Specify output directory and storage type: generate_legacy_" + "storage_files.py " + ) + + output_dir = str(sys.argv[1]) + storage_type = str(sys.argv[2]) + + print( + "This script generates a storage file for the current arch, system, " + "and python version" + ) + print(f" pandas version: {pandas.__version__}") + print(f" output dir : {output_dir}") + print(f" storage format: {storage_type}") + + if not os.path.exists(output_dir): + os.mkdir(output_dir) + + if storage_type == "pickle": + write_legacy_pickles(output_dir=output_dir) + elif storage_type == "hdf": + write_legacy_hdf(output_dir=output_dir, format="fixed") + write_legacy_hdf(output_dir=output_dir, format="table") + else: + sys.exit("storage_type must be one of {'pickle', 'hdf'}") + + +if __name__ == "__main__": + write_legacy_file() diff --git a/pandas/tests/io/test_clipboard.py b/pandas/tests/io/test_clipboard.py new file mode 100644 index 0000000000000000000000000000000000000000..25834c47c09c67c3feadd7817a4902e6bdeff378 --- /dev/null +++ b/pandas/tests/io/test_clipboard.py @@ -0,0 +1,402 @@ +from textwrap import dedent + +import numpy as np +import pytest + +from pandas.errors import ( + PyperclipException, + PyperclipWindowsException, +) + +import pandas as pd +from pandas import ( + NA, + DataFrame, + Series, + get_option, + read_clipboard, +) +import pandas._testing as tm + +from pandas.io.clipboard import ( + CheckedCall, + _stringifyText, + init_qt_clipboard, +) + + +def build_kwargs(sep, excel): + kwargs = {} + if excel != "default": + kwargs["excel"] = excel + if sep != "default": + kwargs["sep"] = sep + return kwargs + + +@pytest.fixture( + params=[ + "delims", + "utf8", + "utf16", + "string", + "long", + "nonascii", + "colwidth", + "mixed", + "float", + "int", + ] +) +def df(request): + data_type = request.param + + if data_type == "delims": + return DataFrame({"a": ['"a,\t"b|c', "d\tef`"], "b": ["hi'j", "k''lm"]}) + elif data_type == "utf8": + return DataFrame({"a": ["µasd", "Ωœ∑`"], "b": ["øπ∆˚¬", "œ∑`®"]}) + elif data_type == "utf16": + return DataFrame( + {"a": ["\U0001f44d\U0001f44d", "\U0001f44d\U0001f44d"], "b": ["abc", "def"]} + ) + elif data_type == "string": + return DataFrame( + np.array([f"i-{i}" for i in range(15)]).reshape(5, 3), columns=list("abc") + ) + elif data_type == "long": + max_rows = get_option("display.max_rows") + return DataFrame( + np.random.default_rng(2).integers(0, 10, size=(max_rows + 1, 3)), + columns=list("abc"), + ) + elif data_type == "nonascii": + return DataFrame({"en": "in English".split(), "es": "en español".split()}) + elif data_type == "colwidth": + _cw = get_option("display.max_colwidth") + 1 + return DataFrame( + np.array(["x" * _cw for _ in range(15)]).reshape(5, 3), columns=list("abc") + ) + elif data_type == "mixed": + return DataFrame( + { + "a": np.arange(1.0, 6.0) + 0.01, + "b": np.arange(1, 6).astype(np.int64), + "c": list("abcde"), + } + ) + elif data_type == "float": + return DataFrame(np.random.default_rng(2).random((5, 3)), columns=list("abc")) + elif data_type == "int": + return DataFrame( + np.random.default_rng(2).integers(0, 10, (5, 3)), columns=list("abc") + ) + else: + raise ValueError + + +@pytest.fixture +def mock_ctypes(monkeypatch): + """ + Mocks WinError to help with testing the clipboard. + """ + + def _mock_win_error(): + return "Window Error" + + # Set raising to False because WinError won't exist on non-windows platforms + with monkeypatch.context() as m: + m.setattr("ctypes.WinError", _mock_win_error, raising=False) + yield + + +@pytest.mark.usefixtures("mock_ctypes") +def test_checked_call_with_bad_call(monkeypatch): + """ + Give CheckCall a function that returns a falsey value and + mock get_errno so it returns false so an exception is raised. + """ + + def _return_false(): + return False + + monkeypatch.setattr("pandas.io.clipboard.get_errno", lambda: True) + msg = f"Error calling {_return_false.__name__} \\(Window Error\\)" + + with pytest.raises(PyperclipWindowsException, match=msg): + CheckedCall(_return_false)() + + +@pytest.mark.usefixtures("mock_ctypes") +def test_checked_call_with_valid_call(monkeypatch): + """ + Give CheckCall a function that returns a truthy value and + mock get_errno so it returns true so an exception is not raised. + The function should return the results from _return_true. + """ + + def _return_true(): + return True + + monkeypatch.setattr("pandas.io.clipboard.get_errno", lambda: False) + + # Give CheckedCall a callable that returns a truthy value s + checked_call = CheckedCall(_return_true) + assert checked_call() is True + + +@pytest.mark.parametrize( + "text", + [ + "String_test", + True, + 1, + 1.0, + 1j, + ], +) +def test_stringify_text(text): + valid_types = (str, int, float, bool) + + if isinstance(text, valid_types): + result = _stringifyText(text) + assert result == str(text) + else: + msg = ( + "only str, int, float, and bool values " + f"can be copied to the clipboard, not {type(text).__name__}" + ) + with pytest.raises(PyperclipException, match=msg): + _stringifyText(text) + + +@pytest.fixture +def set_pyqt_clipboard(monkeypatch): + qt_cut, qt_paste = init_qt_clipboard() + with monkeypatch.context() as m: + m.setattr(pd.io.clipboard, "clipboard_set", qt_cut) + m.setattr(pd.io.clipboard, "clipboard_get", qt_paste) + yield + + +@pytest.fixture +def clipboard(qapp): + clip = qapp.clipboard() + yield clip + clip.clear() + + +@pytest.mark.single_cpu +@pytest.mark.clipboard +@pytest.mark.usefixtures("set_pyqt_clipboard") +@pytest.mark.usefixtures("clipboard") +class TestClipboard: + # Test that default arguments copy as tab delimited + # Test that explicit delimiters are respected + @pytest.mark.parametrize("sep", [None, "\t", ",", "|"]) + @pytest.mark.parametrize("encoding", [None, "UTF-8", "utf-8", "utf8"]) + def test_round_trip_frame_sep(self, df, sep, encoding): + df.to_clipboard(excel=None, sep=sep, encoding=encoding) + result = read_clipboard(sep=sep or "\t", index_col=0, encoding=encoding) + tm.assert_frame_equal(df, result) + + # Test white space separator + def test_round_trip_frame_string(self, df): + df.to_clipboard(excel=False, sep=None) + result = read_clipboard() + assert df.to_string() == result.to_string() + assert df.shape == result.shape + + # Two character separator is not supported in to_clipboard + # Test that multi-character separators are not silently passed + def test_excel_sep_warning(self, df): + with tm.assert_produces_warning( + UserWarning, + match="to_clipboard in excel mode requires a single character separator.", + check_stacklevel=False, + ): + df.to_clipboard(excel=True, sep=r"\t") + + # Separator is ignored when excel=False and should produce a warning + def test_copy_delim_warning(self, df): + with tm.assert_produces_warning(UserWarning, match="ignores the sep argument"): + df.to_clipboard(excel=False, sep="\t") + + # Tests that the default behavior of to_clipboard is tab + # delimited and excel="True" + @pytest.mark.parametrize("sep", ["\t", None, "default"]) + @pytest.mark.parametrize("excel", [True, None, "default"]) + def test_clipboard_copy_tabs_default(self, sep, excel, df, clipboard): + kwargs = build_kwargs(sep, excel) + df.to_clipboard(**kwargs) + assert clipboard.text() == df.to_csv(sep="\t") + + # Tests reading of white space separated tables + @pytest.mark.parametrize("sep", [None, "default"]) + def test_clipboard_copy_strings(self, sep, df): + kwargs = build_kwargs(sep, False) + df.to_clipboard(**kwargs) + result = read_clipboard(sep=r"\s+") + assert result.to_string() == df.to_string() + assert df.shape == result.shape + + def test_read_clipboard_infer_excel(self, clipboard): + # gh-19010: avoid warnings + clip_kwargs = {"engine": "python"} + + text = dedent( + """ + John James\tCharlie Mingus + 1\t2 + 4\tHarry Carney + """.strip() + ) + clipboard.setText(text) + df = read_clipboard(**clip_kwargs) + + # excel data is parsed correctly + assert df.iloc[1, 1] == "Harry Carney" + + # having diff tab counts doesn't trigger it + text = dedent( + """ + a\t b + 1 2 + 3 4 + """.strip() + ) + clipboard.setText(text) + res = read_clipboard(**clip_kwargs) + + text = dedent( + """ + a b + 1 2 + 3 4 + """.strip() + ) + clipboard.setText(text) + exp = read_clipboard(**clip_kwargs) + + tm.assert_frame_equal(res, exp) + + def test_infer_excel_with_nulls(self, clipboard): + # GH41108 + text = "col1\tcol2\n1\tred\n\tblue\n2\tgreen" + + clipboard.setText(text) + df = read_clipboard() + df_expected = DataFrame( + data={"col1": [1, None, 2], "col2": ["red", "blue", "green"]} + ) + + # excel data is parsed correctly + tm.assert_frame_equal(df, df_expected) + + @pytest.mark.parametrize( + "multiindex", + [ + ( # Can't use `dedent` here as it will remove the leading `\t` + "\n".join( + [ + "\t\t\tcol1\tcol2", + "A\t0\tTrue\t1\tred", + "A\t1\tTrue\t\tblue", + "B\t0\tFalse\t2\tgreen", + ] + ), + [["A", "A", "B"], [0, 1, 0], [True, True, False]], + ), + ( + "\n".join( + ["\t\tcol1\tcol2", "A\t0\t1\tred", "A\t1\t\tblue", "B\t0\t2\tgreen"] + ), + [["A", "A", "B"], [0, 1, 0]], + ), + ], + ) + def test_infer_excel_with_multiindex(self, clipboard, multiindex): + # GH41108 + + clipboard.setText(multiindex[0]) + df = read_clipboard() + df_expected = DataFrame( + data={"col1": [1, None, 2], "col2": ["red", "blue", "green"]}, + index=multiindex[1], + ) + + # excel data is parsed correctly + tm.assert_frame_equal(df, df_expected) + + def test_invalid_encoding(self, df): + msg = "clipboard only supports utf-8 encoding" + # test case for testing invalid encoding + with pytest.raises(ValueError, match=msg): + df.to_clipboard(encoding="ascii") + with pytest.raises(NotImplementedError, match=msg): + read_clipboard(encoding="ascii") + + @pytest.mark.parametrize("data", ["\U0001f44d...", "Ωœ∑`...", "abcd..."]) + def test_raw_roundtrip(self, data): + # PR #25040 wide unicode wasn't copied correctly on PY3 on windows + df = DataFrame({"data": [data]}) + df.to_clipboard() + result = read_clipboard() + tm.assert_frame_equal(df, result) + + @pytest.mark.parametrize("engine", ["c", "python"]) + def test_read_clipboard_dtype_backend( + self, clipboard, string_storage, dtype_backend, engine, using_infer_string + ): + # GH#50502 + if dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + string_dtype = pd.ArrowDtype(pa.string()) + else: + string_dtype = pd.StringDtype(string_storage) + + text = """a,b,c,d,e,f,g,h,i +x,1,4.0,x,2,4.0,,True,False +y,2,5.0,,,,,False,""" + clipboard.setText(text) + + with pd.option_context("mode.string_storage", string_storage): + result = read_clipboard(sep=",", dtype_backend=dtype_backend, engine=engine) + + expected = DataFrame( + { + "a": Series(["x", "y"], dtype=string_dtype), + "b": Series([1, 2], dtype="Int64"), + "c": Series([4.0, 5.0], dtype="Float64"), + "d": Series(["x", None], dtype=string_dtype), + "e": Series([2, NA], dtype="Int64"), + "f": Series([4.0, NA], dtype="Float64"), + "g": Series([NA, NA], dtype="Int64"), + "h": Series([True, False], dtype="boolean"), + "i": Series([False, NA], dtype="boolean"), + } + ) + if dtype_backend == "pyarrow": + from pandas.arrays import ArrowExtensionArray + + expected = DataFrame( + { + col: ArrowExtensionArray(pa.array(expected[col], from_pandas=True)) + for col in expected.columns + } + ) + expected["g"] = ArrowExtensionArray(pa.array([None, None])) + + if using_infer_string: + expected.columns = expected.columns.astype( + pd.StringDtype(string_storage, na_value=np.nan) + ) + + tm.assert_frame_equal(result, expected) + + def test_invalid_dtype_backend(self): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + read_clipboard(dtype_backend="numpy") diff --git a/pandas/tests/io/test_common.py b/pandas/tests/io/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..a5081109d2299799f9c982d8117f65157049d3ad --- /dev/null +++ b/pandas/tests/io/test_common.py @@ -0,0 +1,688 @@ +""" +Tests for the pandas.io.common functionalities +""" + +import codecs +import errno +from functools import partial +from io import ( + BytesIO, + StringIO, + UnsupportedOperation, +) +import mmap +import os +from pathlib import Path +import pickle +import tempfile + +import numpy as np +import pytest + +from pandas.compat import ( + WASM, + is_platform_windows, +) +from pandas.compat.pyarrow import pa_version_under19p0 +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + +import pandas.io.common as icom + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +class CustomFSPath: + """For testing fspath on unknown objects""" + + def __init__(self, path) -> None: + self.path = path + + def __fspath__(self): + return self.path + + +HERE = os.path.abspath(os.path.dirname(__file__)) + + +# https://github.com/cython/cython/issues/1720 +class TestCommonIOCapabilities: + data1 = """index,A,B,C,D +foo,2,3,4,5 +bar,7,8,9,10 +baz,12,13,14,15 +qux,12,13,14,15 +foo2,12,13,14,15 +bar2,12,13,14,15 +""" + + def test_expand_user(self): + filename = "~/sometest" + expanded_name = icom._expand_user(filename) + + assert expanded_name != filename + assert os.path.isabs(expanded_name) + assert os.path.expanduser(filename) == expanded_name + + def test_expand_user_normal_path(self): + filename = "/somefolder/sometest" + expanded_name = icom._expand_user(filename) + + assert expanded_name == filename + assert os.path.expanduser(filename) == expanded_name + + def test_stringify_path_pathlib(self): + rel_path = icom.stringify_path(Path(".")) + assert rel_path == "." + redundant_path = icom.stringify_path(Path("foo//bar")) + assert redundant_path == os.path.join("foo", "bar") + + def test_stringify_path_fspath(self): + p = CustomFSPath("foo/bar.csv") + result = icom.stringify_path(p) + assert result == "foo/bar.csv" + + def test_stringify_file_and_path_like(self, temp_file): + # GH 38125: do not stringify file objects that are also path-like + fsspec = pytest.importorskip("fsspec") + with fsspec.open(f"file://{temp_file}", mode="wb") as fsspec_obj: + assert fsspec_obj == icom.stringify_path(fsspec_obj) + + @pytest.mark.parametrize("path_type", [str, CustomFSPath, Path]) + def test_infer_compression_from_path(self, compression_format, path_type): + extension, expected = compression_format + path = path_type("foo/bar.csv" + extension) + compression = icom.infer_compression(path, compression="infer") + assert compression == expected + + @pytest.mark.parametrize("path_type", [str, CustomFSPath, Path]) + def test_get_handle_with_path(self, path_type): + with tempfile.TemporaryDirectory(dir=Path.home()) as tmp: + filename = path_type("~/" + Path(tmp).name + "/sometest") + with icom.get_handle(filename, "w") as handles: + assert Path(handles.handle.name).is_absolute() + assert os.path.expanduser(filename) == handles.handle.name + + def test_get_handle_with_buffer(self): + with StringIO() as input_buffer: + with icom.get_handle(input_buffer, "r") as handles: + assert handles.handle == input_buffer + assert not input_buffer.closed + assert input_buffer.closed + + # Test that BytesIOWrapper(get_handle) returns correct amount of bytes every time + def test_bytesiowrapper_returns_correct_bytes(self): + # Test latin1, ucs-2, and ucs-4 chars + data = """a,b,c +1,2,3 +©,®,® +Look,a snake,🐍""" + with icom.get_handle(StringIO(data), "rb", is_text=False) as handles: + result = b"" + chunksize = 5 + while True: + chunk = handles.handle.read(chunksize) + # Make sure each chunk is correct amount of bytes + assert len(chunk) <= chunksize + if len(chunk) < chunksize: + # Can be less amount of bytes, but only at EOF + # which happens when read returns empty + assert len(handles.handle.read()) == 0 + result += chunk + break + result += chunk + assert result == data.encode("utf-8") + + # Test that pyarrow can handle a file opened with get_handle + def test_get_handle_pyarrow_compat(sel, using_infer_string): + pa_csv = pytest.importorskip("pyarrow.csv") + + # Test latin1, ucs-2, and ucs-4 chars + data = """a,b,c +1,2,3 +©,®,® +Look,a snake,🐍""" + expected = pd.DataFrame( + {"a": ["1", "©", "Look"], "b": ["2", "®", "a snake"], "c": ["3", "®", "🐍"]} + ) + s = StringIO(data) + with icom.get_handle(s, "rb", is_text=False) as handles: + df = pa_csv.read_csv(handles.handle).to_pandas() + if pa_version_under19p0: + expected = expected.astype("object") + elif not using_infer_string: + expected = expected.astype(pd.StringDtype(na_value=np.nan)) + tm.assert_frame_equal(df, expected) + assert not s.closed + + def test_iterator(self): + with pd.read_csv(StringIO(self.data1), chunksize=1) as reader: + result = pd.concat(reader, ignore_index=True) + expected = pd.read_csv(StringIO(self.data1)) + tm.assert_frame_equal(result, expected) + + # GH12153 + with pd.read_csv(StringIO(self.data1), chunksize=1) as it: + first = next(it) + tm.assert_frame_equal(first, expected.iloc[[0]]) + tm.assert_frame_equal(pd.concat(it), expected.iloc[1:]) + + @pytest.mark.skipif(WASM, reason="limited file system access on WASM") + @pytest.mark.parametrize( + "reader, module, error_class, fn_ext", + [ + (pd.read_csv, "os", FileNotFoundError, "csv"), + (pd.read_fwf, "os", FileNotFoundError, "txt"), + (pd.read_excel, "xlrd", FileNotFoundError, "xlsx"), + (pd.read_feather, "pyarrow", OSError, "feather"), + (pd.read_hdf, "tables", FileNotFoundError, "h5"), + (pd.read_stata, "os", FileNotFoundError, "dta"), + (pd.read_sas, "os", FileNotFoundError, "sas7bdat"), + (pd.read_json, "os", FileNotFoundError, "json"), + (pd.read_pickle, "os", FileNotFoundError, "pickle"), + ], + ) + def test_read_non_existent(self, reader, module, error_class, fn_ext): + pytest.importorskip(module) + + path = os.path.join(HERE, "data", "does_not_exist." + fn_ext) + msg1 = rf"File (b')?.+does_not_exist\.{fn_ext}'? does not exist" + msg2 = rf"\[Errno 2\] No such file or directory: '.+does_not_exist\.{fn_ext}'" + msg3 = "Expected object or value" + msg4 = "path_or_buf needs to be a string file path or file-like" + msg5 = ( + rf"\[Errno 2\] File .+does_not_exist\.{fn_ext} does not exist: " + rf"'.+does_not_exist\.{fn_ext}'" + ) + msg6 = rf"\[Errno 2\] 没有那个文件或目录: '.+does_not_exist\.{fn_ext}'" + msg7 = ( + rf"\[Errno 2\] File o directory non esistente: '.+does_not_exist\.{fn_ext}'" + ) + msg8 = rf"Failed to open local file.+does_not_exist\.{fn_ext}" + + with pytest.raises( + error_class, + match=rf"({msg1}|{msg2}|{msg3}|{msg4}|{msg5}|{msg6}|{msg7}|{msg8})", + ): + reader(path) + + @pytest.mark.parametrize( + "method, module, error_class, fn_ext", + [ + (pd.DataFrame.to_csv, "os", OSError, "csv"), + (pd.DataFrame.to_html, "os", OSError, "html"), + (pd.DataFrame.to_excel, "xlrd", OSError, "xlsx"), + (pd.DataFrame.to_feather, "pyarrow", OSError, "feather"), + (pd.DataFrame.to_parquet, "pyarrow", OSError, "parquet"), + (pd.DataFrame.to_stata, "os", OSError, "dta"), + (pd.DataFrame.to_json, "os", OSError, "json"), + (pd.DataFrame.to_pickle, "os", OSError, "pickle"), + ], + ) + # NOTE: Missing parent directory for pd.DataFrame.to_hdf is handled by PyTables + def test_write_missing_parent_directory(self, method, module, error_class, fn_ext): + pytest.importorskip(module) + + dummy_frame = pd.DataFrame({"a": [1, 2, 3], "b": [2, 3, 4], "c": [3, 4, 5]}) + + path = os.path.join(HERE, "data", "missing_folder", "does_not_exist." + fn_ext) + + with pytest.raises( + error_class, + match=r"Cannot save file into a non-existent directory: .*missing_folder", + ): + method(dummy_frame, path) + + @pytest.mark.skipif(WASM, reason="limited file system access on WASM") + @pytest.mark.parametrize( + "reader, module, error_class, fn_ext", + [ + (pd.read_csv, "os", FileNotFoundError, "csv"), + (pd.read_table, "os", FileNotFoundError, "csv"), + (pd.read_fwf, "os", FileNotFoundError, "txt"), + (pd.read_excel, "xlrd", FileNotFoundError, "xlsx"), + (pd.read_feather, "pyarrow", OSError, "feather"), + (pd.read_hdf, "tables", FileNotFoundError, "h5"), + (pd.read_stata, "os", FileNotFoundError, "dta"), + (pd.read_sas, "os", FileNotFoundError, "sas7bdat"), + (pd.read_json, "os", FileNotFoundError, "json"), + (pd.read_pickle, "os", FileNotFoundError, "pickle"), + ], + ) + def test_read_expands_user_home_dir( + self, reader, module, error_class, fn_ext, monkeypatch + ): + pytest.importorskip(module) + + path = os.path.join("~", "does_not_exist." + fn_ext) + monkeypatch.setattr(icom, "_expand_user", lambda x: os.path.join("foo", x)) + + msg1 = rf"File (b')?.+does_not_exist\.{fn_ext}'? does not exist" + msg2 = rf"\[Errno 2\] No such file or directory: '.+does_not_exist\.{fn_ext}'" + msg3 = "Unexpected character found when decoding 'false'" + msg4 = "path_or_buf needs to be a string file path or file-like" + msg5 = ( + rf"\[Errno 2\] File .+does_not_exist\.{fn_ext} does not exist: " + rf"'.+does_not_exist\.{fn_ext}'" + ) + msg6 = rf"\[Errno 2\] 没有那个文件或目录: '.+does_not_exist\.{fn_ext}'" + msg7 = ( + rf"\[Errno 2\] File o directory non esistente: '.+does_not_exist\.{fn_ext}'" + ) + msg8 = rf"Failed to open local file.+does_not_exist\.{fn_ext}" + + with pytest.raises( + error_class, + match=rf"({msg1}|{msg2}|{msg3}|{msg4}|{msg5}|{msg6}|{msg7}|{msg8})", + ): + reader(path) + + @pytest.mark.parametrize( + "reader, module, path", + [ + (pd.read_csv, "os", ("io", "data", "csv", "iris.csv")), + (pd.read_table, "os", ("io", "data", "csv", "iris.csv")), + ( + pd.read_fwf, + "os", + ("io", "data", "fixed_width", "fixed_width_format.txt"), + ), + (pd.read_excel, "xlrd", ("io", "data", "excel", "test1.xlsx")), + ( + pd.read_feather, + "pyarrow", + ("io", "data", "feather", "feather-0_3_1.feather"), + ), + ( + pd.read_hdf, + "tables", + ("io", "data", "legacy_hdf", "pytables_native2.h5"), + ), + (pd.read_stata, "os", ("io", "data", "stata", "stata10_115.dta")), + (pd.read_sas, "os", ("io", "sas", "data", "test1.sas7bdat")), + (pd.read_json, "os", ("io", "json", "data", "tsframe_v012.json")), + ( + pd.read_pickle, + "os", + ("io", "data", "pickle", "categorical.0.25.0.pickle"), + ), + ], + ) + def test_read_fspath_all(self, reader, module, path, datapath): + pytest.importorskip(module) + path = datapath(*path) + + mypath = CustomFSPath(path) + result = reader(mypath) + expected = reader(path) + + if path.endswith(".pickle"): + # categorical + tm.assert_categorical_equal(result, expected) + else: + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "writer_name, writer_kwargs, module", + [ + ("to_csv", {}, "os"), + ("to_excel", {"engine": "openpyxl"}, "openpyxl"), + ("to_feather", {}, "pyarrow"), + ("to_html", {}, "os"), + ("to_json", {}, "os"), + ("to_latex", {}, "os"), + ("to_pickle", {}, "os"), + ("to_stata", {"time_stamp": pd.to_datetime("2019-01-01 00:00")}, "os"), + ], + ) + def test_write_fspath_all(self, writer_name, writer_kwargs, module, tmp_path): + if writer_name in ["to_latex"]: # uses Styler implementation + pytest.importorskip("jinja2") + string = str(tmp_path / "string") + fspath = str(tmp_path / "fspath") + df = pd.DataFrame({"A": [1, 2]}) + + pytest.importorskip(module) + mypath = CustomFSPath(fspath) + writer = getattr(df, writer_name) + + writer(string, **writer_kwargs) + writer(mypath, **writer_kwargs) + with open(string, "rb") as f_str, open(fspath, "rb") as f_path: + if writer_name == "to_excel": + # binary representation of excel contains time creation + # data that causes flaky CI failures + result = pd.read_excel(f_str, **writer_kwargs) + expected = pd.read_excel(f_path, **writer_kwargs) + tm.assert_frame_equal(result, expected) + else: + result = f_str.read() + expected = f_path.read() + assert result == expected + + def test_write_fspath_hdf5(self, tmp_path): + # Same test as write_fspath_all, except HDF5 files aren't + # necessarily byte-for-byte identical for a given dataframe, so we'll + # have to read and compare equality + pytest.importorskip("tables") + + df = pd.DataFrame({"A": [1, 2]}) + string = str(tmp_path / "string") + fspath = str(tmp_path / "fspath") + + mypath = CustomFSPath(fspath) + df.to_hdf(mypath, key="bar") + df.to_hdf(string, key="bar") + + result = pd.read_hdf(fspath, key="bar") + expected = pd.read_hdf(string, key="bar") + + tm.assert_frame_equal(result, expected) + + +@pytest.fixture +def mmap_file(datapath): + return datapath("io", "data", "csv", "test_mmap.csv") + + +class TestMMapWrapper: + @pytest.mark.skipif(WASM, reason="limited file system access on WASM") + def test_constructor_bad_file(self, mmap_file): + non_file = StringIO("I am not a file") + non_file.fileno = lambda: -1 + + # the error raised is different on Windows + if is_platform_windows(): + msg = "The parameter is incorrect" + err = OSError + else: + msg = "[Errno 22]" + err = mmap.error + + with pytest.raises(err, match=msg): + icom._maybe_memory_map(non_file, True) + + with open(mmap_file, encoding="utf-8") as target: + pass + + msg = "I/O operation on closed file" + with pytest.raises(ValueError, match=msg): + icom._maybe_memory_map(target, True) + + @pytest.mark.skipif(WASM, reason="limited file system access on WASM") + def test_next(self, mmap_file): + with open(mmap_file, encoding="utf-8") as target: + lines = target.readlines() + + with icom.get_handle( + target, "r", is_text=True, memory_map=True + ) as wrappers: + wrapper = wrappers.handle + assert isinstance(wrapper.buffer.buffer, mmap.mmap) + + for line in lines: + next_line = next(wrapper) + assert next_line.strip() == line.strip() + + with pytest.raises(StopIteration, match=r"^$"): + next(wrapper) + + def test_unknown_engine(self, temp_file): + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + df.to_csv(temp_file) + with pytest.raises(ValueError, match="Unknown engine"): + pd.read_csv(temp_file, engine="pyt") + + def test_binary_mode(self, temp_file): + """ + 'encoding' shouldn't be passed to 'open' in binary mode. + + GH 35058 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + df.to_csv(temp_file, mode="w+b") + tm.assert_frame_equal(df, pd.read_csv(temp_file, index_col=0)) + + @pytest.mark.parametrize("encoding", ["utf-16", "utf-32"]) + @pytest.mark.parametrize("compression_", ["bz2", "xz"]) + def test_warning_missing_utf_bom(self, encoding, compression_, temp_file): + """ + bz2 and xz do not write the byte order mark (BOM) for utf-16/32. + + https://stackoverflow.com/questions/55171439 + + GH 35681 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + with tm.assert_produces_warning(UnicodeWarning, match="byte order mark"): + df.to_csv(temp_file, compression=compression_, encoding=encoding) + + # reading should fail (otherwise we wouldn't need the warning) + msg = ( + r"UTF-\d+ stream does not start with BOM|" + r"'utf-\d+' codec can't decode byte" + ) + with pytest.raises(UnicodeError, match=msg): + pd.read_csv(temp_file, compression=compression_, encoding=encoding) + + +def test_is_fsspec_url(): + assert icom.is_fsspec_url("gcs://pandas/somethingelse.com") + assert icom.is_fsspec_url("gs://pandas/somethingelse.com") + # the following is the only remote URL that is handled without fsspec + assert not icom.is_fsspec_url("http://pandas/somethingelse.com") + assert not icom.is_fsspec_url("random:pandas/somethingelse.com") + assert not icom.is_fsspec_url("/local/path") + assert not icom.is_fsspec_url("relative/local/path") + # fsspec URL in string should not be recognized + assert not icom.is_fsspec_url("this is not fsspec://url") + assert not icom.is_fsspec_url("{'url': 'gs://pandas/somethingelse.com'}") + # accept everything that conforms to RFC 3986 schema + assert icom.is_fsspec_url("RFC-3986+compliant.spec://something") + + +def test_is_fsspec_url_chained(): + # GH#48978 Support chained fsspec URLs + # See https://filesystem-spec.readthedocs.io/en/latest/features.html#url-chaining. + assert icom.is_fsspec_url("filecache::s3://pandas/test.csv") + assert icom.is_fsspec_url("zip://test.csv::filecache::gcs://bucket/file.zip") + assert icom.is_fsspec_url("filecache::zip://test.csv::gcs://bucket/file.zip") + assert icom.is_fsspec_url("filecache::dask::s3://pandas/test.csv") + assert not icom.is_fsspec_url("filecache:s3://pandas/test.csv") + assert not icom.is_fsspec_url("filecache:::s3://pandas/test.csv") + assert not icom.is_fsspec_url("filecache::://pandas/test.csv") + + +@pytest.mark.parametrize("format", ["csv", "json"]) +def test_codecs_encoding(format, temp_file): + # GH39247 + expected = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + with open(temp_file, mode="w", encoding="utf-8") as handle: + getattr(expected, f"to_{format}")(handle) + with open(temp_file, encoding="utf-8") as handle: + if format == "csv": + df = pd.read_csv(handle, index_col=0) + else: + df = pd.read_json(handle) + tm.assert_frame_equal(expected, df) + + +def test_codecs_get_writer_reader(temp_file): + # GH39247 + expected = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + with open(temp_file, "wb") as handle: + with codecs.getwriter("utf-8")(handle) as encoded: + expected.to_csv(encoded) + with open(temp_file, "rb") as handle: + with codecs.getreader("utf-8")(handle) as encoded: + df = pd.read_csv(encoded, index_col=0) + tm.assert_frame_equal(expected, df) + + +@pytest.mark.parametrize( + "io_class,mode,msg", + [ + (BytesIO, "t", "a bytes-like object is required, not 'str'"), + (StringIO, "b", "string argument expected, got 'bytes'"), + ], +) +def test_explicit_encoding(io_class, mode, msg): + # GH39247; this test makes sure that if a user provides mode="*t" or "*b", + # it is used. In the case of this test it leads to an error as intentionally the + # wrong mode is requested + expected = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + with io_class() as buffer: + with pytest.raises(TypeError, match=msg): + expected.to_csv(buffer, mode=f"w{mode}") + + +@pytest.mark.parametrize("encoding_errors", ["strict", "replace"]) +@pytest.mark.parametrize("format", ["csv", "json"]) +def test_encoding_errors(encoding_errors, format, temp_file): + # GH39450 + msg = "'utf-8' codec can't decode byte" + bad_encoding = b"\xe4" + + if format == "csv": + content = b"," + bad_encoding + b"\n" + bad_encoding * 2 + b"," + bad_encoding + reader = partial(pd.read_csv, index_col=0) + else: + content = ( + b'{"' + + bad_encoding * 2 + + b'": {"' + + bad_encoding + + b'":"' + + bad_encoding + + b'"}}' + ) + reader = partial(pd.read_json, orient="index") + file = temp_file + file.write_bytes(content) + + if encoding_errors != "replace": + with pytest.raises(UnicodeDecodeError, match=msg): + reader(temp_file, encoding_errors=encoding_errors) + else: + df = reader(temp_file, encoding_errors=encoding_errors) + decoded = bad_encoding.decode(errors=encoding_errors) + expected = pd.DataFrame({decoded: [decoded]}, index=[decoded * 2]) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.parametrize("encoding_errors", [0, None]) +def test_encoding_errors_badtype(encoding_errors): + # GH 59075 + content = StringIO("A,B\n1,2\n3,4\n") + reader = partial(pd.read_csv, encoding_errors=encoding_errors) + expected_error = "encoding_errors must be a string, got " + expected_error += f"{type(encoding_errors).__name__}" + with pytest.raises(ValueError, match=expected_error): + reader(content) + + +def test_bad_encdoing_errors(temp_file): + # GH 39777 + with pytest.raises(LookupError, match="unknown error handler name"): + icom.get_handle(temp_file, "w", errors="bad") + + +@pytest.mark.skipif(WASM, reason="limited file system access on WASM") +def test_errno_attribute(): + # GH 13872 + with pytest.raises(FileNotFoundError, match="\\[Errno 2\\]") as err: + pd.read_csv("doesnt_exist") + assert err.errno == errno.ENOENT + + +def test_fail_mmap(): + with pytest.raises(UnsupportedOperation, match="fileno"): + with BytesIO() as buffer: + icom.get_handle(buffer, "rb", memory_map=True) + + +def test_close_on_error(): + # GH 47136 + class TestError: + def close(self): + raise OSError("test") + + with pytest.raises(OSError, match="test"): + with BytesIO() as buffer: + with icom.get_handle(buffer, "rb") as handles: + handles.created_handles.append(TestError()) + + +@td.skip_if_no("fsspec") +@pytest.mark.parametrize("compression", [None, "infer"]) +def test_read_csv_chained_url_no_error(datapath, compression): + # GH 60100 + tar_file_path = datapath("io", "data", "tar", "test-csv.tar") + chained_file_url = f"tar://test.csv::file://{tar_file_path}" + + result = pd.read_csv(chained_file_url, compression=compression, sep=";") + expected = pd.DataFrame({"1": {0: 3}, "2": {0: 4}}) + + tm.assert_frame_equal(expected, result) + + +@pytest.mark.parametrize( + "reader", + [ + pd.read_csv, + pd.read_fwf, + pd.read_excel, + pd.read_feather, + pd.read_hdf, + pd.read_stata, + pd.read_sas, + pd.read_json, + pd.read_pickle, + ], +) +def test_pickle_reader(reader): + # GH 22265 + with BytesIO() as buffer: + pickle.dump(reader, buffer) + + +@td.skip_if_no("pyarrow") +def test_pyarrow_read_csv_datetime_dtype(): + # GH 59904 + data = '"date"\n"20/12/2025"\n""\n"31/12/2020"' + result = pd.read_csv( + StringIO(data), parse_dates=["date"], dayfirst=True, dtype_backend="pyarrow" + ) + + expect_data = pd.to_datetime(["20/12/2025", pd.NaT, "31/12/2020"], dayfirst=True) + expect = pd.DataFrame({"date": expect_data}) + + tm.assert_frame_equal(expect, result) diff --git a/pandas/tests/io/test_compression.py b/pandas/tests/io/test_compression.py new file mode 100644 index 0000000000000000000000000000000000000000..97b64a29a7f2cff7bfd5c1311fab9b40dcdf1a0a --- /dev/null +++ b/pandas/tests/io/test_compression.py @@ -0,0 +1,382 @@ +import gzip +import io +import os +import subprocess +import sys +import tarfile +import textwrap +import zipfile + +import numpy as np +import pytest + +from pandas.compat import is_platform_windows + +import pandas as pd +import pandas._testing as tm + +import pandas.io.common as icom + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"]) +def test_compression_size(obj, method, compression_only, temp_file): + if compression_only == "tar": + compression_only = {"method": "tar", "mode": "w:gz"} + + path = temp_file + getattr(obj, method)(path, compression=compression_only) + compressed_size = os.path.getsize(path) + getattr(obj, method)(path, compression=None) + uncompressed_size = os.path.getsize(path) + assert uncompressed_size > compressed_size + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_csv", "to_json"]) +def test_compression_size_fh(obj, method, compression_only, temp_file): + path = temp_file + with icom.get_handle( + path, + "w:gz" if compression_only == "tar" else "w", + compression=compression_only, + ) as handles: + getattr(obj, method)(handles.handle) + assert not handles.handle.closed + compressed_size = os.path.getsize(path) + + # Create a new temporary file for uncompressed comparison + path2 = temp_file.parent / f"{temp_file.stem}_uncompressed{temp_file.suffix}" + path2.touch() + with icom.get_handle(path2, "w", compression=None) as handles: + getattr(obj, method)(handles.handle) + assert not handles.handle.closed + uncompressed_size = os.path.getsize(path2) + assert uncompressed_size > compressed_size + + +@pytest.mark.parametrize( + "write_method, write_kwargs, read_method", + [ + ("to_csv", {"index": False}, pd.read_csv), + ("to_json", {}, pd.read_json), + ("to_pickle", {}, pd.read_pickle), + ], +) +def test_dataframe_compression_defaults_to_infer( + write_method, + write_kwargs, + read_method, + compression_only, + compression_to_extension, + temp_file, +): + # GH22004 + input = pd.DataFrame([[1.0, 0, -4], [3.4, 5, 2]], columns=["X", "Y", "Z"]) + extension = compression_to_extension[compression_only] + path = temp_file.parent / f"compressed{extension}" + getattr(input, write_method)(path, **write_kwargs) + output = read_method(path, compression=compression_only) + tm.assert_frame_equal(output, input) + + +@pytest.mark.parametrize( + "write_method,write_kwargs,read_method,read_kwargs", + [ + ("to_csv", {"index": False, "header": True}, pd.read_csv, {"squeeze": True}), + ("to_json", {}, pd.read_json, {"typ": "series"}), + ("to_pickle", {}, pd.read_pickle, {}), + ], +) +def test_series_compression_defaults_to_infer( + write_method, + write_kwargs, + read_method, + read_kwargs, + compression_only, + compression_to_extension, + temp_file, +): + # GH22004 + input = pd.Series([0, 5, -2, 10], name="X") + extension = compression_to_extension[compression_only] + path = temp_file.parent / f"compressed{extension}" + getattr(input, write_method)(path, **write_kwargs) + if "squeeze" in read_kwargs: + kwargs = read_kwargs.copy() + del kwargs["squeeze"] + output = read_method(path, compression=compression_only, **kwargs).squeeze( + "columns" + ) + else: + output = read_method(path, compression=compression_only, **read_kwargs) + tm.assert_series_equal(output, input, check_names=False) + + +def test_compression_warning(compression_only, temp_file): + # Assert that passing a file object to to_csv while explicitly specifying a + # compression protocol triggers a RuntimeWarning, as per GH21227. + df = pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ) + path = temp_file + with icom.get_handle(path, "w", compression=compression_only) as handles: + with tm.assert_produces_warning(RuntimeWarning, match="has no effect"): + df.to_csv(handles.handle, compression=compression_only) + + +def test_compression_binary(compression_only, temp_file): + """ + Binary file handles support compression. + + GH22555 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + + # with a file + path = temp_file + with open(path, mode="wb") as file: + df.to_csv(file, mode="wb", compression=compression_only) + file.seek(0) # file shouldn't be closed + tm.assert_frame_equal( + df, pd.read_csv(path, index_col=0, compression=compression_only) + ) + + # with BytesIO + file = io.BytesIO() + df.to_csv(file, mode="wb", compression=compression_only) + file.seek(0) # file shouldn't be closed + tm.assert_frame_equal( + df, pd.read_csv(file, index_col=0, compression=compression_only) + ) + + +def test_gzip_reproducibility_file_name(temp_file): + """ + Gzip should create reproducible archives with mtime. + + Note: Archives created with different filenames will still be different! + + GH 28103 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + compression_options = {"method": "gzip", "mtime": 1} + + # test for filename + path = temp_file + df.to_csv(path, compression=compression_options) + output = path.read_bytes() + df.to_csv(path, compression=compression_options) + assert output == path.read_bytes() + + +def test_gzip_reproducibility_file_object(): + """ + Gzip should create reproducible archives with mtime. + + GH 28103 + """ + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + compression_options = {"method": "gzip", "mtime": 1} + + # test for file object + buffer = io.BytesIO() + df.to_csv(buffer, compression=compression_options, mode="wb") + output = buffer.getvalue() + buffer = io.BytesIO() + df.to_csv(buffer, compression=compression_options, mode="wb") + assert output == buffer.getvalue() + + +@pytest.mark.single_cpu +def test_with_missing_lzma(): + """Tests if import pandas works when lzma is not present.""" + # https://github.com/pandas-dev/pandas/issues/27575 + code = textwrap.dedent( + """\ + import sys + sys.modules['lzma'] = None + import pandas + """ + ) + subprocess.check_output([sys.executable, "-c", code], stderr=subprocess.PIPE) + + +@pytest.mark.single_cpu +def test_with_missing_lzma_runtime(): + """Tests if ModuleNotFoundError is hit when calling lzma without + having the module available. + """ + code = textwrap.dedent( + """ + import sys + import pytest + sys.modules['lzma'] = None + import pandas as pd + df = pd.DataFrame() + with pytest.raises(ModuleNotFoundError, match='import of lzma'): + df.to_csv('foo.csv', compression='xz') + """ + ) + subprocess.check_output([sys.executable, "-c", code], stderr=subprocess.PIPE) + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"]) +def test_gzip_compression_level(obj, method, temp_file): + # GH33196 + path = temp_file + getattr(obj, method)(path, compression="gzip") + compressed_size_default = os.path.getsize(path) + getattr(obj, method)(path, compression={"method": "gzip", "compresslevel": 1}) + compressed_size_fast = os.path.getsize(path) + assert compressed_size_default < compressed_size_fast + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"]) +def test_xz_compression_level_read(obj, method, temp_file): + path = temp_file + getattr(obj, method)(path, compression="xz") + compressed_size_default = os.path.getsize(path) + getattr(obj, method)(path, compression={"method": "xz", "preset": 1}) + compressed_size_fast = os.path.getsize(path) + assert compressed_size_default < compressed_size_fast + if method == "to_csv": + pd.read_csv(path, compression="xz") + + +@pytest.mark.parametrize( + "obj", + [ + pd.DataFrame( + 100 * [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + columns=["X", "Y", "Z"], + ), + pd.Series(100 * [0.123456, 0.234567, 0.567567], name="X"), + ], +) +@pytest.mark.parametrize("method", ["to_pickle", "to_json", "to_csv"]) +def test_bzip_compression_level(obj, method, temp_file): + """GH33196 bzip needs file size > 100k to show a size difference between + compression levels, so here we just check if the call works when + compression is passed as a dict. + """ + path = temp_file + getattr(obj, method)(path, compression={"method": "bz2", "compresslevel": 1}) + + +@pytest.mark.parametrize( + "suffix,archive", + [ + (".zip", zipfile.ZipFile), + (".tar", tarfile.TarFile), + ], +) +def test_empty_archive_zip(suffix, archive, temp_file): + path = temp_file.parent / f"archive{suffix}" + with archive(path, "w"): + pass + with pytest.raises(ValueError, match="Zero files found"): + pd.read_csv(path) + + +def test_ambiguous_archive_zip(temp_file): + path = temp_file.parent / "archive.zip" + with zipfile.ZipFile(path, "w") as file: + file.writestr("a.csv", "foo,bar") + file.writestr("b.csv", "foo,bar") + with pytest.raises(ValueError, match="Multiple files found in ZIP file"): + pd.read_csv(path) + + +def test_ambiguous_archive_tar(tmp_path): + csvAPath = tmp_path / "a.csv" + with open(csvAPath, "w", encoding="utf-8") as a: + a.write("foo,bar\n") + csvBPath = tmp_path / "b.csv" + with open(csvBPath, "w", encoding="utf-8") as b: + b.write("foo,bar\n") + + tarpath = tmp_path / "archive.tar" + with tarfile.TarFile(tarpath, "w") as tar: + tar.add(csvAPath, "a.csv") + tar.add(csvBPath, "b.csv") + + with pytest.raises(ValueError, match="Multiple files found in TAR archive"): + pd.read_csv(tarpath) + + +def test_tar_gz_to_different_filename(temp_file): + file = temp_file.parent / "archive.foo" + pd.DataFrame( + [["1", "2"]], + columns=["foo", "bar"], + ).to_csv(file, compression={"method": "tar", "mode": "w:gz"}, index=False) + with gzip.open(file) as uncompressed: + with tarfile.TarFile(fileobj=uncompressed) as archive: + members = archive.getmembers() + assert len(members) == 1 + content = archive.extractfile(members[0]).read().decode("utf8") + + if is_platform_windows(): + expected = "foo,bar\r\n1,2\r\n" + else: + expected = "foo,bar\n1,2\n" + + assert content == expected + + +def test_tar_no_error_on_close(): + with io.BytesIO() as buffer: + with icom._BytesTarFile(fileobj=buffer, mode="w"): + pass diff --git a/pandas/tests/io/test_feather.py b/pandas/tests/io/test_feather.py new file mode 100644 index 0000000000000000000000000000000000000000..6351a9760b773e2ffdd1547061b7f1918ce325e4 --- /dev/null +++ b/pandas/tests/io/test_feather.py @@ -0,0 +1,291 @@ +"""test feather-format compat""" + +from datetime import datetime +import zoneinfo + +import numpy as np +import pytest + +from pandas.compat.pyarrow import ( + pa_version_under18p0, + pa_version_under19p0, +) + +import pandas as pd +import pandas._testing as tm + +from pandas.io.feather_format import read_feather, to_feather # isort:skip + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +pa = pytest.importorskip("pyarrow") + + +@pytest.mark.single_cpu +class TestFeather: + def check_error_on_write(self, df, exc, err_msg, temp_file): + # check that we are raising the exception + # on writing + + with pytest.raises(exc, match=err_msg): + to_feather(df, temp_file) + + def check_external_error_on_write(self, df, temp_file): + # check that we are raising the exception + # on writing + + with tm.external_error_raised(Exception): + to_feather(df, temp_file) + + def check_round_trip( + self, df, temp_file, expected=None, write_kwargs=None, **read_kwargs + ): + if write_kwargs is None: + write_kwargs = {} + if expected is None: + expected = df.copy() + + to_feather(df, temp_file, **write_kwargs) + + result = read_feather(temp_file, **read_kwargs) + + tm.assert_frame_equal(result, expected) + + def test_error(self, temp_file): + msg = "feather only support IO with DataFrames" + for obj in [ + pd.Series([1, 2, 3]), + 1, + "foo", + pd.Timestamp("20130101"), + np.array([1, 2, 3]), + ]: + self.check_error_on_write(obj, ValueError, msg, temp_file) + + def test_basic(self, temp_file): + tz = zoneinfo.ZoneInfo("US/Eastern") + df = pd.DataFrame( + { + "string": list("abc"), + "int": list(range(1, 4)), + "uint": np.arange(3, 6).astype("u1"), + "float": np.arange(4.0, 7.0, dtype="float64"), + "float_with_null": [1.0, np.nan, 3], + "bool": [True, False, True], + "bool_with_null": [True, np.nan, False], + "cat": pd.Categorical(list("abc")), + "dt": pd.DatetimeIndex( + list(pd.date_range("20130101", periods=3)), freq=None + ), + "dttz": pd.DatetimeIndex( + list(pd.date_range("20130101", periods=3, tz=tz)), + freq=None, + ), + "dt_with_null": [ + pd.Timestamp("20130101"), + pd.NaT, + pd.Timestamp("20130103"), + ], + "dtns": pd.DatetimeIndex( + list(pd.date_range("20130101", periods=3, freq="ns")), freq=None + ), + } + ) + df["periods"] = pd.period_range("2013", freq="M", periods=3) + df["timedeltas"] = pd.timedelta_range("1 day", periods=3) + df["intervals"] = pd.interval_range(0, 3, 3) + + assert df.dttz.dtype.tz.key == "US/Eastern" + + expected = df.copy() + expected.loc[1, "bool_with_null"] = None + self.check_round_trip(df, temp_file, expected=expected) + + def test_duplicate_columns(self, temp_file): + # https://github.com/wesm/feather/issues/53 + # not currently able to handle duplicate columns + df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy() + self.check_external_error_on_write(df, temp_file) + + def test_read_columns(self, temp_file): + # GH 24025 + df = pd.DataFrame( + { + "col1": list("abc"), + "col2": list(range(1, 4)), + "col3": list("xyz"), + "col4": list(range(4, 7)), + } + ) + columns = ["col1", "col3"] + self.check_round_trip(df, temp_file, expected=df[columns], columns=columns) + + def test_read_columns_different_order(self, temp_file): + # GH 33878 + df = pd.DataFrame({"A": [1, 2], "B": ["x", "y"], "C": [True, False]}) + expected = df[["B", "A"]] + self.check_round_trip(df, temp_file, expected, columns=["B", "A"]) + + def test_unsupported_other(self, temp_file): + # mixed python objects + df = pd.DataFrame({"a": ["a", 1, 2.0]}) + self.check_external_error_on_write(df, temp_file) + + def test_rw_use_threads(self, temp_file): + df = pd.DataFrame({"A": np.arange(100000)}) + self.check_round_trip(df, temp_file, use_threads=True) + self.check_round_trip(df, temp_file, use_threads=False) + + def test_path_pathlib(self, temp_file): + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ).reset_index() + result = tm.round_trip_pathlib(df.to_feather, read_feather, temp_file) + tm.assert_frame_equal(df, result) + + def test_passthrough_keywords(self, temp_file): + df = pd.DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ).reset_index() + self.check_round_trip(df, temp_file, write_kwargs={"version": 1}) + + @pytest.mark.network + @pytest.mark.single_cpu + def test_http_path(self, feather_file, httpserver): + # GH 29055 + expected = read_feather(feather_file) + with open(feather_file, "rb") as f: + httpserver.serve_content(content=f.read()) + res = read_feather(httpserver.url) + tm.assert_frame_equal(expected, res) + + def test_read_feather_dtype_backend( + self, string_storage, dtype_backend, using_infer_string, temp_file + ): + # GH#50765 + df = pd.DataFrame( + { + "a": pd.Series([1, pd.NA, 3], dtype="Int64"), + "b": pd.Series([1, 2, 3], dtype="Int64"), + "c": pd.Series([1.5, pd.NA, 2.5], dtype="Float64"), + "d": pd.Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": [True, False, None], + "f": [True, False, True], + "g": ["a", "b", "c"], + "h": ["a", "b", None], + } + ) + + to_feather(df, temp_file) + with pd.option_context("mode.string_storage", string_storage): + result = read_feather(temp_file, dtype_backend=dtype_backend) + + if dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + if using_infer_string: + string_dtype = pd.ArrowDtype(pa.large_string()) + else: + string_dtype = pd.ArrowDtype(pa.string()) + else: + string_dtype = pd.StringDtype(string_storage) + + expected = pd.DataFrame( + { + "a": pd.Series([1, pd.NA, 3], dtype="Int64"), + "b": pd.Series([1, 2, 3], dtype="Int64"), + "c": pd.Series([1.5, pd.NA, 2.5], dtype="Float64"), + "d": pd.Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": pd.Series([True, False, pd.NA], dtype="boolean"), + "f": pd.Series([True, False, True], dtype="boolean"), + "g": pd.Series(["a", "b", "c"], dtype=string_dtype), + "h": pd.Series(["a", "b", None], dtype=string_dtype), + } + ) + + if dtype_backend == "pyarrow": + from pandas.arrays import ArrowExtensionArray + + expected = pd.DataFrame( + { + col: ArrowExtensionArray(pa.array(expected[col], from_pandas=True)) + for col in expected.columns + } + ) + + if using_infer_string: + expected.columns = expected.columns.astype( + pd.StringDtype(string_storage, na_value=np.nan) + ) + tm.assert_frame_equal(result, expected) + + def test_int_columns_and_index(self, temp_file): + df = pd.DataFrame({"a": [1, 2, 3]}, index=pd.Index([3, 4, 5], name="test")) + self.check_round_trip(df, temp_file) + + def test_invalid_dtype_backend(self, temp_file): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + df = pd.DataFrame({"int": list(range(1, 4))}) + df.to_feather(temp_file) + with pytest.raises(ValueError, match=msg): + read_feather(temp_file, dtype_backend="numpy") + + def test_string_inference(self, temp_file, using_infer_string): + # GH#54431 + df = pd.DataFrame(data={"a": ["x", "y"]}) + df.to_feather(temp_file) + with pd.option_context("future.infer_string", True): + result = read_feather(temp_file) + dtype = pd.StringDtype(na_value=np.nan) + expected = pd.DataFrame( + data={"a": ["x", "y"]}, dtype=pd.StringDtype(na_value=np.nan) + ) + expected = pd.DataFrame( + data={"a": ["x", "y"]}, + dtype=dtype, + columns=pd.Index( + ["a"], + dtype=object + if pa_version_under19p0 and not using_infer_string + else dtype, + ), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.skipif(pa_version_under18p0, reason="not supported before 18.0") + def test_string_inference_string_view_type(self, temp_file): + # GH#54798 + import pyarrow as pa + from pyarrow import feather + + table = pa.table({"a": pa.array([None, "b", "c"], pa.string_view())}) + feather.write_feather(table, temp_file) + + with pd.option_context("future.infer_string", True): + result = read_feather(temp_file) + + expected = pd.DataFrame( + data={"a": [None, "b", "c"]}, dtype=pd.StringDtype(na_value=np.nan) + ) + tm.assert_frame_equal(result, expected) + + def test_out_of_bounds_datetime_to_feather(self, temp_file): + # GH#47832 + df = pd.DataFrame( + { + "date": [ + datetime.fromisoformat("1654-01-01"), + datetime.fromisoformat("1920-01-01"), + ], + } + ) + self.check_round_trip(df, temp_file) diff --git a/pandas/tests/io/test_fsspec.py b/pandas/tests/io/test_fsspec.py new file mode 100644 index 0000000000000000000000000000000000000000..5d76a622d29148caa47deb98c3e7687226ea5b1f --- /dev/null +++ b/pandas/tests/io/test_fsspec.py @@ -0,0 +1,348 @@ +import io + +import numpy as np +import pytest + +from pandas._config import using_string_dtype + +from pandas.compat import HAS_PYARROW +from pandas.compat.pyarrow import pa_version_under14p0 + +from pandas import ( + DataFrame, + date_range, + read_csv, + read_excel, + read_feather, + read_json, + read_parquet, + read_pickle, + read_stata, + read_table, +) +import pandas._testing as tm +from pandas.util import _test_decorators as td + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +@pytest.fixture +def fsspectest(): + pytest.importorskip("fsspec") + from fsspec import register_implementation + from fsspec.implementations.memory import MemoryFileSystem + from fsspec.registry import _registry as registry + + class TestMemoryFS(MemoryFileSystem): + protocol = "testmem" + test = [None] + + def __init__(self, **kwargs) -> None: + self.test[0] = kwargs.pop("test", None) + super().__init__(**kwargs) + + register_implementation("testmem", TestMemoryFS, clobber=True) + yield TestMemoryFS() + registry.pop("testmem", None) + TestMemoryFS.test[0] = None + TestMemoryFS.store.clear() + + +@pytest.fixture +def df1(): + return DataFrame( + { + "int": [1, 3], + "float": [2.0, np.nan], + "str": ["t", "s"], + "dt": date_range("2018-06-18", periods=2), + } + ) + + +@pytest.fixture +def cleared_fs(): + fsspec = pytest.importorskip("fsspec") + + memfs = fsspec.filesystem("memory") + yield memfs + memfs.store.clear() + + +def test_read_csv(cleared_fs, df1): + text = str(df1.to_csv(index=False)).encode() + with cleared_fs.open("test/test.csv", "wb") as w: + w.write(text) + df2 = read_csv("memory://test/test.csv", parse_dates=["dt"]) + + expected = df1.copy() + expected["dt"] = expected["dt"].astype("M8[us]") + tm.assert_frame_equal(df2, expected) + + +def test_reasonable_error(monkeypatch, cleared_fs): + from fsspec.registry import known_implementations + + with pytest.raises(ValueError, match="nosuchprotocol"): + read_csv("nosuchprotocol://test/test.csv") + err_msg = "test error message" + monkeypatch.setitem( + known_implementations, + "couldexist", + {"class": "unimportable.CouldExist", "err": err_msg}, + ) + with pytest.raises(ImportError, match=err_msg): + read_csv("couldexist://test/test.csv") + + +def test_to_csv(cleared_fs, df1): + df1.to_csv("memory://test/test.csv", index=True) + + df2 = read_csv("memory://test/test.csv", parse_dates=["dt"], index_col=0) + + expected = df1.copy() + expected["dt"] = expected["dt"].astype("M8[us]") + tm.assert_frame_equal(df2, expected) + + +def test_to_excel(cleared_fs, df1): + pytest.importorskip("openpyxl") + ext = "xlsx" + path = f"memory://test/test.{ext}" + df1.to_excel(path, index=True) + + df2 = read_excel(path, parse_dates=["dt"], index_col=0) + + expected = df1.copy() + expected["dt"] = expected["dt"].astype("M8[us]") + tm.assert_frame_equal(df2, expected) + + +@pytest.mark.parametrize("binary_mode", [False, True]) +def test_to_csv_fsspec_object(cleared_fs, binary_mode, df1): + fsspec = pytest.importorskip("fsspec") + + path = "memory://test/test.csv" + mode = "wb" if binary_mode else "w" + with fsspec.open(path, mode=mode).open() as fsspec_object: + df1.to_csv(fsspec_object, index=True) + assert not fsspec_object.closed + + mode = mode.replace("w", "r") + with fsspec.open(path, mode=mode) as fsspec_object: + df2 = read_csv( + fsspec_object, + parse_dates=["dt"], + index_col=0, + ) + assert not fsspec_object.closed + + expected = df1.copy() + expected["dt"] = expected["dt"].astype("M8[us]") + tm.assert_frame_equal(df2, expected) + + +def test_csv_options(fsspectest): + df = DataFrame({"a": [0]}) + df.to_csv( + "testmem://test/test.csv", storage_options={"test": "csv_write"}, index=False + ) + assert fsspectest.test[0] == "csv_write" + read_csv("testmem://test/test.csv", storage_options={"test": "csv_read"}) + assert fsspectest.test[0] == "csv_read" + + +def test_read_table_options(fsspectest): + # GH #39167 + df = DataFrame({"a": [0]}) + df.to_csv( + "testmem://test/test.csv", storage_options={"test": "csv_write"}, index=False + ) + assert fsspectest.test[0] == "csv_write" + read_table("testmem://test/test.csv", storage_options={"test": "csv_read"}) + assert fsspectest.test[0] == "csv_read" + + +def test_excel_options(fsspectest): + pytest.importorskip("openpyxl") + extension = "xlsx" + + df = DataFrame({"a": [0]}) + + path = f"testmem://test/test.{extension}" + + df.to_excel(path, storage_options={"test": "write"}, index=False) + assert fsspectest.test[0] == "write" + read_excel(path, storage_options={"test": "read"}) + assert fsspectest.test[0] == "read" + + +@pytest.mark.xfail( + using_string_dtype() and HAS_PYARROW and not pa_version_under14p0, + reason="TODO(infer_string) fastparquet", +) +def test_to_parquet_new_file(cleared_fs, df1): + """Regression test for writing to a not-yet-existent GCS Parquet file.""" + pytest.importorskip("fastparquet") + + df1.to_parquet( + "memory://test/test.csv", index=True, engine="fastparquet", compression=None + ) + + +def test_arrowparquet_options(fsspectest): + """Regression test for writing to a not-yet-existent GCS Parquet file.""" + pytest.importorskip("pyarrow") + df = DataFrame({"a": [0]}) + df.to_parquet( + "testmem://test/test.csv", + engine="pyarrow", + compression=None, + storage_options={"test": "parquet_write"}, + ) + assert fsspectest.test[0] == "parquet_write" + read_parquet( + "testmem://test/test.csv", + engine="pyarrow", + storage_options={"test": "parquet_read"}, + ) + assert fsspectest.test[0] == "parquet_read" + + +def test_fastparquet_options(fsspectest): + """Regression test for writing to a not-yet-existent GCS Parquet file.""" + pytest.importorskip("fastparquet") + + df = DataFrame({"a": [0]}) + df.to_parquet( + "testmem://test/test.csv", + engine="fastparquet", + compression=None, + storage_options={"test": "parquet_write"}, + ) + assert fsspectest.test[0] == "parquet_write" + read_parquet( + "testmem://test/test.csv", + engine="fastparquet", + storage_options={"test": "parquet_read"}, + ) + assert fsspectest.test[0] == "parquet_read" + + +@pytest.mark.single_cpu +@pytest.mark.parametrize("compression_suffix", ["", ".gz", ".bz2"]) +def test_from_s3_csv(s3_bucket_public_with_data, s3so, tips_file, compression_suffix): + pytest.importorskip("s3fs") + df_from_s3 = read_csv( + f"s3://{s3_bucket_public_with_data.name}/tips.csv{compression_suffix}", + storage_options=s3so, + ) + df_from_local = read_csv(tips_file) + tm.assert_equal(df_from_s3, df_from_local) + + +@pytest.mark.single_cpu +@pytest.mark.parametrize("protocol", ["s3", "s3a", "s3n"]) +def test_s3_protocols(s3_bucket_public_with_data, s3so, tips_file, protocol): + pytest.importorskip("s3fs") + df_from_s3 = read_csv( + f"{protocol}://{s3_bucket_public_with_data.name}/tips.csv", + storage_options=s3so, + ) + df_from_local = read_csv(tips_file) + tm.assert_equal(df_from_s3, df_from_local) + + +@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string) fastparquet") +@pytest.mark.single_cpu +def test_s3_parquet(s3_bucket_public, s3so, df1): + pytest.importorskip("fastparquet") + pytest.importorskip("s3fs") + + fn = f"s3://{s3_bucket_public.name}/test.parquet" + df1.to_parquet( + fn, index=False, engine="fastparquet", compression=None, storage_options=s3so + ) + df2 = read_parquet(fn, engine="fastparquet", storage_options=s3so) + tm.assert_equal(df1, df2) + + +@td.skip_if_installed("fsspec") +def test_not_present_exception(): + msg = "`Import fsspec` failed. Use pip or conda to install the fsspec package." + with pytest.raises(ImportError, match=msg): + read_csv("memory://test/test.csv") + + +def test_feather_options(fsspectest): + pytest.importorskip("pyarrow") + df = DataFrame({"a": [0]}) + df.to_feather("testmem://mockfile", storage_options={"test": "feather_write"}) + assert fsspectest.test[0] == "feather_write" + out = read_feather("testmem://mockfile", storage_options={"test": "feather_read"}) + assert fsspectest.test[0] == "feather_read" + tm.assert_frame_equal(df, out) + + +def test_pickle_options(fsspectest): + df = DataFrame({"a": [0]}) + df.to_pickle("testmem://mockfile", storage_options={"test": "pickle_write"}) + assert fsspectest.test[0] == "pickle_write" + out = read_pickle("testmem://mockfile", storage_options={"test": "pickle_read"}) + assert fsspectest.test[0] == "pickle_read" + tm.assert_frame_equal(df, out) + + +def test_json_options(fsspectest, compression): + df = DataFrame({"a": [0]}) + df.to_json( + "testmem://mockfile", + compression=compression, + storage_options={"test": "json_write"}, + ) + assert fsspectest.test[0] == "json_write" + out = read_json( + "testmem://mockfile", + compression=compression, + storage_options={"test": "json_read"}, + ) + assert fsspectest.test[0] == "json_read" + tm.assert_frame_equal(df, out) + + +def test_stata_options(fsspectest): + df = DataFrame({"a": [0]}) + df.to_stata( + "testmem://mockfile", storage_options={"test": "stata_write"}, write_index=False + ) + assert fsspectest.test[0] == "stata_write" + out = read_stata("testmem://mockfile", storage_options={"test": "stata_read"}) + assert fsspectest.test[0] == "stata_read" + tm.assert_frame_equal(df, out.astype("int64")) + + +def test_markdown_options(fsspectest): + pytest.importorskip("tabulate") + df = DataFrame({"a": [0]}) + df.to_markdown("testmem://mockfile", storage_options={"test": "md_write"}) + assert fsspectest.test[0] == "md_write" + assert fsspectest.cat("testmem://mockfile") + + +def test_non_fsspec_options(): + pytest.importorskip("pyarrow") + with pytest.raises(ValueError, match="storage_options"): + read_csv("localfile", storage_options={"a": True}) + with pytest.raises(ValueError, match="storage_options"): + # separate test for parquet, which has a different code path + read_parquet("localfile", storage_options={"a": True}) + by = io.BytesIO() + + with pytest.raises(ValueError, match="storage_options"): + read_csv(by, storage_options={"a": True}) + + df = DataFrame({"a": [0]}) + with pytest.raises(ValueError, match="storage_options"): + df.to_parquet("nonfsspecpath", storage_options={"a": True}) diff --git a/pandas/tests/io/test_gcs.py b/pandas/tests/io/test_gcs.py new file mode 100644 index 0000000000000000000000000000000000000000..022fd89c1f555551a47d6361f0368dcbb3d28ab8 --- /dev/null +++ b/pandas/tests/io/test_gcs.py @@ -0,0 +1,233 @@ +from io import BytesIO +import os +import pathlib +import tarfile +import zipfile + +import numpy as np +import pytest + +from pandas.compat.pyarrow import pa_version_under17p0 + +from pandas import ( + DataFrame, + Index, + date_range, + read_csv, + read_excel, + read_json, + read_parquet, +) +import pandas._testing as tm +from pandas.util import _test_decorators as td + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +@pytest.fixture +def gcs_buffer(): + """Emulate GCS using a binary buffer.""" + pytest.importorskip("gcsfs") + fsspec = pytest.importorskip("fsspec") + + gcs_buffer = BytesIO() + gcs_buffer.close = lambda: True + + class MockGCSFileSystem(fsspec.AbstractFileSystem): + @staticmethod + def open(*args, **kwargs): + gcs_buffer.seek(0) + return gcs_buffer + + def ls(self, path, **kwargs): + # needed for pyarrow + return [{"name": path, "type": "file"}] + + # Overwrites the default implementation from gcsfs to our mock class + fsspec.register_implementation("gs", MockGCSFileSystem, clobber=True) + + return gcs_buffer + + +# Patches pyarrow; other processes should not pick up change +@pytest.mark.single_cpu +@pytest.mark.parametrize("format", ["csv", "json", "parquet", "excel", "markdown"]) +def test_to_read_gcs(gcs_buffer, format, monkeypatch, capsys, request): + """ + Test that many to/read functions support GCS. + + GH 33987 + """ + + df1 = DataFrame( + { + "int": [1, 3], + "float": [2.0, np.nan], + "str": ["t", "s"], + "dt": date_range("2018-06-18", periods=2, unit="ns"), + } + ) + + path = f"gs://test/test.{format}" + + if format == "csv": + df1.to_csv(path, index=True) + df2 = read_csv(path, parse_dates=["dt"], index_col=0) + elif format == "excel": + path = "gs://test/test.xlsx" + df1.to_excel(path) + df2 = read_excel(path, parse_dates=["dt"], index_col=0) + elif format == "json": + df1.to_json(path, date_format="iso") + df2 = read_json(path, convert_dates=["dt"]) + elif format == "parquet": + pytest.importorskip("pyarrow") + pa_fs = pytest.importorskip("pyarrow.fs") + + class MockFileSystem(pa_fs.FileSystem): + @staticmethod + def from_uri(path): + print("Using pyarrow filesystem") + to_local = pathlib.Path(path.replace("gs://", "")).absolute().as_uri() + return pa_fs.LocalFileSystem(to_local) + + request.applymarker( + pytest.mark.xfail( + not pa_version_under17p0, + raises=TypeError, + reason="pyarrow 17 broke the mocked filesystem", + ) + ) + with monkeypatch.context() as m: + m.setattr(pa_fs, "FileSystem", MockFileSystem) + df1.to_parquet(path) + df2 = read_parquet(path) + captured = capsys.readouterr() + assert captured.out == "Using pyarrow filesystem\nUsing pyarrow filesystem\n" + elif format == "markdown": + pytest.importorskip("tabulate") + df1.to_markdown(path) + df2 = df1 + + expected = df1[:] + if format in ["csv", "excel", "json"]: + expected["dt"] = expected["dt"].dt.as_unit("us") + + tm.assert_frame_equal(df2, expected) + + +def assert_equal_zip_safe(result: bytes, expected: bytes, compression: str): + """ + For zip compression, only compare the CRC-32 checksum of the file contents + to avoid checking the time-dependent last-modified timestamp which + in some CI builds is off-by-one + + See https://en.wikipedia.org/wiki/ZIP_(file_format)#File_headers + """ + if compression == "zip": + # Only compare the CRC checksum of the file contents + with ( + zipfile.ZipFile(BytesIO(result)) as exp, + zipfile.ZipFile(BytesIO(expected)) as res, + ): + for res_info, exp_info in zip(res.infolist(), exp.infolist()): + assert res_info.CRC == exp_info.CRC + elif compression == "tar": + with ( + tarfile.open(fileobj=BytesIO(result)) as tar_exp, + tarfile.open(fileobj=BytesIO(expected)) as tar_res, + ): + for tar_res_info, tar_exp_info in zip( + tar_res.getmembers(), tar_exp.getmembers() + ): + actual_file = tar_res.extractfile(tar_res_info) + expected_file = tar_exp.extractfile(tar_exp_info) + assert (actual_file is None) == (expected_file is None) + if actual_file is not None and expected_file is not None: + assert actual_file.read() == expected_file.read() + else: + assert result == expected + + +@pytest.mark.parametrize("encoding", ["utf-8", "cp1251"]) +def test_to_csv_compression_encoding_gcs( + gcs_buffer, compression_only, encoding, compression_to_extension +): + """ + Compression and encoding should with GCS. + + GH 35677 (to_csv, compression), GH 26124 (to_csv, encoding), and + GH 32392 (read_csv, encoding) + """ + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD")), + index=Index([f"i-{i}" for i in range(30)]), + ) + + # reference of compressed and encoded file + compression = {"method": compression_only} + if compression_only == "gzip": + compression["mtime"] = 1 # be reproducible + buffer = BytesIO() + df.to_csv(buffer, compression=compression, encoding=encoding, mode="wb") + + # write compressed file with explicit compression + path_gcs = "gs://test/test.csv" + df.to_csv(path_gcs, compression=compression, encoding=encoding) + res = gcs_buffer.getvalue() + expected = buffer.getvalue() + assert_equal_zip_safe(res, expected, compression_only) + + read_df = read_csv( + path_gcs, index_col=0, compression=compression_only, encoding=encoding + ) + tm.assert_frame_equal(df, read_df) + + # write compressed file with implicit compression + file_ext = compression_to_extension[compression_only] + compression["method"] = "infer" + path_gcs += f".{file_ext}" + df.to_csv(path_gcs, compression=compression, encoding=encoding) + + res = gcs_buffer.getvalue() + expected = buffer.getvalue() + assert_equal_zip_safe(res, expected, compression_only) + + read_df = read_csv(path_gcs, index_col=0, compression="infer", encoding=encoding) + tm.assert_frame_equal(df, read_df) + + +def test_to_parquet_gcs_new_file(monkeypatch, tmpdir): + """Regression test for writing to a not-yet-existent GCS Parquet file.""" + pytest.importorskip("fastparquet") + pytest.importorskip("gcsfs") + + from fsspec import AbstractFileSystem + + df1 = DataFrame( + { + "int": [1, 3], + "float": [2.0, np.nan], + "dt": date_range("2018-06-18", periods=2), + } + ) + + class MockGCSFileSystem(AbstractFileSystem): + def open(self, path, mode="r", *args): + if "w" not in mode: + raise FileNotFoundError + return open(os.path.join(tmpdir, "test.parquet"), mode, encoding="utf-8") + + monkeypatch.setattr("gcsfs.GCSFileSystem", MockGCSFileSystem) + df1.to_parquet( + "gs://test/test.csv", index=True, engine="fastparquet", compression=None + ) + + +@td.skip_if_installed("gcsfs") +def test_gcs_not_present_exception(): + with tm.external_error_raised(ImportError): + read_csv("gs://test/test.csv") diff --git a/pandas/tests/io/test_html.py b/pandas/tests/io/test_html.py new file mode 100644 index 0000000000000000000000000000000000000000..abc0cfbb36332daeb3ea606450a1df75d455ed50 --- /dev/null +++ b/pandas/tests/io/test_html.py @@ -0,0 +1,1669 @@ +from collections.abc import Iterator +from functools import partial +from io import ( + BytesIO, + StringIO, +) +import os +from pathlib import Path +import re +import threading +from urllib.error import URLError + +import numpy as np +import pytest + +from pandas.compat import is_platform_windows +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + NA, + DataFrame, + MultiIndex, + Series, + Timestamp, + date_range, + read_csv, + read_html, + to_datetime, +) +import pandas._testing as tm + +from pandas.io.common import file_path_to_url + + +@pytest.fixture( + params=[ + "chinese_utf-16.html", + "chinese_utf-32.html", + "chinese_utf-8.html", + "letz_latin1.html", + ] +) +def html_encoding_file(request, datapath): + """Parametrized fixture for HTML encoding test filenames.""" + return datapath("io", "data", "html_encoding", request.param) + + +def assert_framelist_equal(list1, list2, *args, **kwargs): + assert len(list1) == len(list2), ( + "lists are not of equal size " + f"len(list1) == {len(list1)}, " + f"len(list2) == {len(list2)}" + ) + msg = "not all list elements are DataFrames" + both_frames = all( + map( + lambda x, y: isinstance(x, DataFrame) and isinstance(y, DataFrame), + list1, + list2, + ) + ) + assert both_frames, msg + for frame_i, frame_j in zip(list1, list2): + tm.assert_frame_equal(frame_i, frame_j, *args, **kwargs) + assert not frame_i.empty, "frames are both empty" + + +def test_bs4_version_fails(monkeypatch, datapath): + bs4 = pytest.importorskip("bs4") + pytest.importorskip("html5lib") + + monkeypatch.setattr(bs4, "__version__", "4.2") + with pytest.raises(ImportError, match="Pandas requires version"): + read_html(datapath("io", "data", "html", "spam.html"), flavor="bs4") + + +def test_invalid_flavor(): + url = "google.com" + flavor = "invalid flavor" + msg = r"\{" + flavor + r"\} is not a valid set of flavors" + + with pytest.raises(ValueError, match=msg): + read_html(StringIO(url), match="google", flavor=flavor) + + +def test_same_ordering(datapath): + pytest.importorskip("bs4") + pytest.importorskip("lxml") + pytest.importorskip("html5lib") + + filename = datapath("io", "data", "html", "valid_markup.html") + dfs_lxml = read_html(filename, index_col=0, flavor=["lxml"]) + dfs_bs4 = read_html(filename, index_col=0, flavor=["bs4"]) + assert_framelist_equal(dfs_lxml, dfs_bs4) + + +@pytest.fixture( + params=[ + pytest.param("bs4", marks=[td.skip_if_no("bs4"), td.skip_if_no("html5lib")]), + pytest.param("lxml", marks=td.skip_if_no("lxml")), + ], +) +def flavor_read_html(request): + return partial(read_html, flavor=request.param) + + +class TestReadHtml: + def test_literal_html_deprecation(self, flavor_read_html): + # GH 53785 + msg = r"\[Errno 2\] No such file or director" + + with pytest.raises(FileNotFoundError, match=msg): + flavor_read_html( + """ + + + + + + + + + + + + + + + + + + +
AB
12
34
""" + ) + + @pytest.fixture + def spam_data(self, datapath): + return datapath("io", "data", "html", "spam.html") + + @pytest.fixture + def banklist_data(self, datapath): + return datapath("io", "data", "html", "banklist.html") + + def test_to_html_compat(self, flavor_read_html): + df = ( + DataFrame( + np.random.default_rng(2).random((4, 3)), + columns=pd.Index(list("abc")), + ) + .map("{:.3f}".format) + .astype(float) + ) + out = df.to_html() + res = flavor_read_html( + StringIO(out), attrs={"class": "dataframe"}, index_col=0 + )[0] + tm.assert_frame_equal(res, df) + + def test_dtype_backend(self, string_storage, dtype_backend, flavor_read_html): + # GH#50286 + df = DataFrame( + { + "a": Series([1, NA, 3], dtype="Int64"), + "b": Series([1, 2, 3], dtype="Int64"), + "c": Series([1.5, NA, 2.5], dtype="Float64"), + "d": Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": [True, False, None], + "f": [True, False, True], + "g": ["a", "b", "c"], + "h": ["a", "b", None], + } + ) + + out = df.to_html(index=False) + with pd.option_context("mode.string_storage", string_storage): + result = flavor_read_html(StringIO(out), dtype_backend=dtype_backend)[0] + + if dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + string_dtype = pd.ArrowDtype(pa.string()) + else: + string_dtype = pd.StringDtype(string_storage) + + expected = DataFrame( + { + "a": Series([1, NA, 3], dtype="Int64"), + "b": Series([1, 2, 3], dtype="Int64"), + "c": Series([1.5, NA, 2.5], dtype="Float64"), + "d": Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": Series([True, False, NA], dtype="boolean"), + "f": Series([True, False, True], dtype="boolean"), + "g": Series(["a", "b", "c"], dtype=string_dtype), + "h": Series(["a", "b", None], dtype=string_dtype), + } + ) + + if dtype_backend == "pyarrow": + import pyarrow as pa + + from pandas.arrays import ArrowExtensionArray + + expected = DataFrame( + { + col: ArrowExtensionArray(pa.array(expected[col], from_pandas=True)) + for col in expected.columns + } + ) + + # the storage of the str columns' Index is also affected by the + # string_storage setting -> ignore that for checking the result + tm.assert_frame_equal(result, expected, check_column_type=False) + + @pytest.mark.network + @pytest.mark.single_cpu + def test_banklist_url(self, httpserver, banklist_data, flavor_read_html): + with open(banklist_data, encoding="utf-8") as f: + httpserver.serve_content(content=f.read()) + df1 = flavor_read_html( + # lxml cannot find attrs leave out for now + httpserver.url, + match="First Federal Bank of Florida", # attrs={"class": "dataTable"} + ) + # lxml cannot find attrs leave out for now + df2 = flavor_read_html( + httpserver.url, + match="Metcalf Bank", + ) # attrs={"class": "dataTable"}) + + assert_framelist_equal(df1, df2) + + @pytest.mark.network + @pytest.mark.single_cpu + def test_spam_url(self, httpserver, spam_data, flavor_read_html): + with open(spam_data, encoding="utf-8") as f: + httpserver.serve_content(content=f.read()) + df1 = flavor_read_html(httpserver.url, match=".*Water.*") + df2 = flavor_read_html(httpserver.url, match="Unit") + + assert_framelist_equal(df1, df2) + + @pytest.mark.slow + def test_banklist(self, banklist_data, flavor_read_html): + df1 = flavor_read_html( + banklist_data, match=".*Florida.*", attrs={"id": "table"} + ) + df2 = flavor_read_html( + banklist_data, match="Metcalf Bank", attrs={"id": "table"} + ) + + assert_framelist_equal(df1, df2) + + def test_spam(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*") + df2 = flavor_read_html(spam_data, match="Unit") + assert_framelist_equal(df1, df2) + + assert df1[0].iloc[0, 0] == "Proximates" + assert df1[0].columns[0] == "Nutrient" + + def test_spam_no_match(self, spam_data, flavor_read_html): + dfs = flavor_read_html(spam_data) + for df in dfs: + assert isinstance(df, DataFrame) + + def test_banklist_no_match(self, banklist_data, flavor_read_html): + dfs = flavor_read_html(banklist_data, attrs={"id": "table"}) + for df in dfs: + assert isinstance(df, DataFrame) + + def test_spam_header(self, spam_data, flavor_read_html): + df = flavor_read_html(spam_data, match=".*Water.*", header=2)[0] + assert df.columns[0] == "Proximates" + assert not df.empty + + def test_skiprows_int(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", skiprows=1) + df2 = flavor_read_html(spam_data, match="Unit", skiprows=1) + + assert_framelist_equal(df1, df2) + + def test_skiprows_range(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", skiprows=range(2)) + df2 = flavor_read_html(spam_data, match="Unit", skiprows=range(2)) + + assert_framelist_equal(df1, df2) + + def test_skiprows_list(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", skiprows=[1, 2]) + df2 = flavor_read_html(spam_data, match="Unit", skiprows=[2, 1]) + + assert_framelist_equal(df1, df2) + + def test_skiprows_set(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", skiprows={1, 2}) + df2 = flavor_read_html(spam_data, match="Unit", skiprows={2, 1}) + + assert_framelist_equal(df1, df2) + + def test_skiprows_slice(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", skiprows=1) + df2 = flavor_read_html(spam_data, match="Unit", skiprows=1) + + assert_framelist_equal(df1, df2) + + def test_skiprows_slice_short(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", skiprows=slice(2)) + df2 = flavor_read_html(spam_data, match="Unit", skiprows=slice(2)) + + assert_framelist_equal(df1, df2) + + def test_skiprows_slice_long(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", skiprows=slice(2, 5)) + df2 = flavor_read_html(spam_data, match="Unit", skiprows=slice(4, 1, -1)) + + assert_framelist_equal(df1, df2) + + def test_skiprows_ndarray(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", skiprows=np.arange(2)) + df2 = flavor_read_html(spam_data, match="Unit", skiprows=np.arange(2)) + + assert_framelist_equal(df1, df2) + + def test_skiprows_invalid(self, spam_data, flavor_read_html): + with pytest.raises(TypeError, match=("is not a valid type for skipping rows")): + flavor_read_html(spam_data, match=".*Water.*", skiprows="asdf") + + def test_index(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", index_col=0) + df2 = flavor_read_html(spam_data, match="Unit", index_col=0) + assert_framelist_equal(df1, df2) + + def test_header_and_index_no_types(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", header=1, index_col=0) + df2 = flavor_read_html(spam_data, match="Unit", header=1, index_col=0) + assert_framelist_equal(df1, df2) + + def test_header_and_index_with_types(self, spam_data, flavor_read_html): + df1 = flavor_read_html(spam_data, match=".*Water.*", header=1, index_col=0) + df2 = flavor_read_html(spam_data, match="Unit", header=1, index_col=0) + assert_framelist_equal(df1, df2) + + def test_infer_types(self, spam_data, flavor_read_html): + # 10892 infer_types removed + df1 = flavor_read_html(spam_data, match=".*Water.*", index_col=0) + df2 = flavor_read_html(spam_data, match="Unit", index_col=0) + assert_framelist_equal(df1, df2) + + def test_string_io(self, spam_data, flavor_read_html): + with open(spam_data, encoding="UTF-8") as f: + data1 = StringIO(f.read()) + + with open(spam_data, encoding="UTF-8") as f: + data2 = StringIO(f.read()) + + df1 = flavor_read_html(data1, match=".*Water.*") + df2 = flavor_read_html(data2, match="Unit") + assert_framelist_equal(df1, df2) + + def test_string(self, spam_data, flavor_read_html): + with open(spam_data, encoding="UTF-8") as f: + data = f.read() + + df1 = flavor_read_html(StringIO(data), match=".*Water.*") + df2 = flavor_read_html(StringIO(data), match="Unit") + + assert_framelist_equal(df1, df2) + + def test_file_like(self, spam_data, flavor_read_html): + with open(spam_data, encoding="UTF-8") as f: + df1 = flavor_read_html(f, match=".*Water.*") + + with open(spam_data, encoding="UTF-8") as f: + df2 = flavor_read_html(f, match="Unit") + + assert_framelist_equal(df1, df2) + + @pytest.mark.network + @pytest.mark.single_cpu + def test_bad_url_protocol(self, httpserver, flavor_read_html): + httpserver.serve_content("urlopen error unknown url type: git", code=404) + with pytest.raises(URLError, match="urlopen error unknown url type: git"): + flavor_read_html("git://github.com", match=".*Water.*") + + @pytest.mark.slow + @pytest.mark.network + @pytest.mark.single_cpu + def test_invalid_url(self, httpserver, flavor_read_html): + httpserver.serve_content("Name or service not known", code=404) + try: + with pytest.raises( + (URLError, ValueError), match="HTTP Error 404: NOT FOUND" + ) as err: + flavor_read_html(httpserver.url, match=".*Water.*") + finally: + if isinstance(err.value, URLError): + # Has a file-like handle that we can close + # https://docs.python.org/3/library/urllib.error.html#urllib.error.HTTPError + err.value.close() + + @pytest.mark.slow + def test_file_url(self, banklist_data, flavor_read_html): + url = banklist_data + dfs = flavor_read_html( + file_path_to_url(os.path.abspath(url)), match="First", attrs={"id": "table"} + ) + assert isinstance(dfs, list) + for df in dfs: + assert isinstance(df, DataFrame) + + @pytest.mark.slow + def test_invalid_table_attrs(self, banklist_data, flavor_read_html): + url = banklist_data + with pytest.raises(ValueError, match="No tables found"): + flavor_read_html( + url, match="First Federal Bank of Florida", attrs={"id": "tasdfable"} + ) + + @pytest.mark.slow + def test_multiindex_header(self, banklist_data, flavor_read_html): + df = flavor_read_html( + banklist_data, match="Metcalf", attrs={"id": "table"}, header=[0, 1] + )[0] + assert isinstance(df.columns, MultiIndex) + + @pytest.mark.slow + def test_multiindex_index(self, banklist_data, flavor_read_html): + df = flavor_read_html( + banklist_data, match="Metcalf", attrs={"id": "table"}, index_col=[0, 1] + )[0] + assert isinstance(df.index, MultiIndex) + + @pytest.mark.slow + def test_multiindex_header_index(self, banklist_data, flavor_read_html): + df = flavor_read_html( + banklist_data, + match="Metcalf", + attrs={"id": "table"}, + header=[0, 1], + index_col=[0, 1], + )[0] + assert isinstance(df.columns, MultiIndex) + assert isinstance(df.index, MultiIndex) + + @pytest.mark.slow + def test_multiindex_header_skiprows_tuples(self, banklist_data, flavor_read_html): + df = flavor_read_html( + banklist_data, + match="Metcalf", + attrs={"id": "table"}, + header=[0, 1], + skiprows=1, + )[0] + assert isinstance(df.columns, MultiIndex) + + @pytest.mark.slow + def test_multiindex_header_skiprows(self, banklist_data, flavor_read_html): + df = flavor_read_html( + banklist_data, + match="Metcalf", + attrs={"id": "table"}, + header=[0, 1], + skiprows=1, + )[0] + assert isinstance(df.columns, MultiIndex) + + @pytest.mark.slow + def test_multiindex_header_index_skiprows(self, banklist_data, flavor_read_html): + df = flavor_read_html( + banklist_data, + match="Metcalf", + attrs={"id": "table"}, + header=[0, 1], + index_col=[0, 1], + skiprows=1, + )[0] + assert isinstance(df.index, MultiIndex) + assert isinstance(df.columns, MultiIndex) + + @pytest.mark.slow + def test_regex_idempotency(self, banklist_data, flavor_read_html): + url = banklist_data + dfs = flavor_read_html( + file_path_to_url(os.path.abspath(url)), + match=re.compile(re.compile("Florida")), + attrs={"id": "table"}, + ) + assert isinstance(dfs, list) + for df in dfs: + assert isinstance(df, DataFrame) + + def test_negative_skiprows(self, spam_data, flavor_read_html): + msg = r"\(you passed a negative value\)" + with pytest.raises(ValueError, match=msg): + flavor_read_html(spam_data, match="Water", skiprows=-1) + + @pytest.fixture + def python_docs(self): + return """ + + +
+ + + + + + + + + + + + +
+ +

Indices and tables:

+ + +
+ + + + + + +
+ """ # noqa: E501 + + @pytest.mark.network + @pytest.mark.single_cpu + def test_multiple_matches(self, python_docs, httpserver, flavor_read_html): + httpserver.serve_content(content=python_docs) + dfs = flavor_read_html(httpserver.url, match="Python") + assert len(dfs) > 1 + + @pytest.mark.network + @pytest.mark.single_cpu + def test_python_docs_table(self, python_docs, httpserver, flavor_read_html): + httpserver.serve_content(content=python_docs) + dfs = flavor_read_html(httpserver.url, match="Python") + zz = [df.iloc[0, 0][0:4] for df in dfs] + assert sorted(zz) == ["Pyth", "What"] + + def test_empty_tables(self, flavor_read_html): + """ + Make sure that read_html ignores empty tables. + """ + html = """ + + + + + + + + + + + + + +
AB
12
+ + + +
+ """ + result = flavor_read_html(StringIO(html)) + assert len(result) == 1 + + def test_multiple_tbody(self, flavor_read_html): + # GH-20690 + # Read all tbody tags within a single table. + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + + + + + + + +
AB
12
34
""" + ) + )[0] + + expected = DataFrame(data=[[1, 2], [3, 4]], columns=["A", "B"]) + + tm.assert_frame_equal(result, expected) + + def test_header_and_one_column(self, flavor_read_html): + """ + Don't fail with bs4 when there is a header and only one column + as described in issue #9178 + """ + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + +
Header
first
""" + ) + )[0] + + expected = DataFrame(data={"Header": "first"}, index=[0]) + + tm.assert_frame_equal(result, expected) + + def test_thead_without_tr(self, flavor_read_html): + """ + Ensure parser adds
+ + + + + + + + + + + + + + +
CountryMunicipalityYear
UkraineOdessa1944
""" + ) + )[0] + + expected = DataFrame( + data=[["Ukraine", "Odessa", 1944]], + columns=["Country", "Municipality", "Year"], + ) + + tm.assert_frame_equal(result, expected) + + def test_tfoot_read(self, flavor_read_html): + """ + Make sure that read_html reads tfoot, containing td or th. + Ignores empty tfoot + """ + data_template = """ + + + + + + + + + + + + + + {footer} + +
AB
bodyAbodyB
""" + + expected1 = DataFrame(data=[["bodyA", "bodyB"]], columns=["A", "B"]) + + expected2 = DataFrame( + data=[["bodyA", "bodyB"], ["footA", "footB"]], columns=["A", "B"] + ) + + data1 = data_template.format(footer="") + data2 = data_template.format(footer="
footAfootB
+ + + + + + + + +
SI
text1944
+ """ + ), + header=0, + )[0] + + expected = DataFrame([["text", 1944]], columns=("S", "I")) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.slow + def test_banklist_header(self, banklist_data, datapath, flavor_read_html): + from pandas.io.html import _remove_whitespace + + def try_remove_ws(x): + try: + return _remove_whitespace(x) + except AttributeError: + return x + + df = flavor_read_html(banklist_data, match="Metcalf", attrs={"id": "table"})[0] + ground_truth = read_csv( + datapath("io", "data", "csv", "banklist.csv"), + converters={"Updated Date": Timestamp, "Closing Date": Timestamp}, + ) + # html is a truncated version of banklist since bs4 is slow to parse it + assert df.shape == (len(df), ground_truth.shape[1]) + old = [ + "First Vietnamese American Bank In Vietnamese", + "Westernbank Puerto Rico En Espanol", + "R-G Premier Bank of Puerto Rico En Espanol", + "Eurobank En Espanol", + "Sanderson State Bank En Espanol", + "Washington Mutual Bank (Including its subsidiary Washington " + "Mutual Bank FSB)", + "Silver State Bank En Espanol", + "AmTrade International Bank En Espanol", + "Hamilton Bank, NA En Espanol", + "The Citizens Savings Bank Pioneer Community Bank, Inc.", + ] + new = [ + "First Vietnamese American Bank", + "Westernbank Puerto Rico", + "R-G Premier Bank of Puerto Rico", + "Eurobank", + "Sanderson State Bank", + "Washington Mutual Bank", + "Silver State Bank", + "AmTrade International Bank", + "Hamilton Bank, NA", + "The Citizens Savings Bank", + ] + dfnew = df.map(try_remove_ws).replace(old, new) + gtnew = ground_truth.map(try_remove_ws) + converted = dfnew + date_cols = ["Closing Date", "Updated Date"] + converted[date_cols] = converted[date_cols].apply(to_datetime) + gtnew = gtnew[gtnew["Bank Name"].isin(converted["Bank Name"])].reset_index( + drop=True + ) + tm.assert_frame_equal(converted, gtnew) + + @pytest.mark.slow + def test_heartland_bank(self, banklist_data, flavor_read_html): + gc = "Heartland Bank" + with open(banklist_data, encoding="utf-8") as f: + raw_text = f.read() + + assert gc in raw_text + df = flavor_read_html(banklist_data, match=gc, attrs={"id": "table"})[0] + assert gc in df.to_string() + + def test_different_number_of_cols(self, flavor_read_html): + expected = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
C_l0_g0C_l0_g1C_l0_g2C_l0_g3C_l0_g4
R_l0_g0 0.763 0.233 nan nan nan
R_l0_g1 0.244 0.285 0.392 0.137 0.222
""" + ), + index_col=0, + )[0] + + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + + + + + + + + + + + + + + +
C_l0_g0C_l0_g1C_l0_g2C_l0_g3C_l0_g4
R_l0_g0 0.763 0.233
R_l0_g1 0.244 0.285 0.392 0.137 0.222
""" + ), + index_col=0, + )[0] + + tm.assert_frame_equal(result, expected) + + def test_colspan_rowspan_1(self, flavor_read_html): + # GH17054 + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + +
ABC
abc
+ """ + ) + )[0] + + expected = DataFrame([["a", "b", "c"]], columns=["A", "B", "C"]) + + tm.assert_frame_equal(result, expected) + + def test_colspan_rowspan_copy_values(self, flavor_read_html): + # GH17054 + + # In ASCII, with lowercase letters being copies: + # + # X x Y Z W + # A B b z C + + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + +
XYZW
ABC
+ """ + ), + header=0, + )[0] + + expected = DataFrame( + data=[["A", "B", "B", "Z", "C"]], columns=["X", "X.1", "Y", "Z", "W"] + ) + + tm.assert_frame_equal(result, expected) + + def test_colspan_rowspan_both_not_1(self, flavor_read_html): + # GH17054 + + # In ASCII, with lowercase letters being copies: + # + # A B b b C + # a b b b D + + result = flavor_read_html( + StringIO( + """ + + + + + + + + + +
ABC
D
+ """ + ), + header=0, + )[0] + + expected = DataFrame( + data=[["A", "B", "B", "B", "D"]], columns=["A", "B", "B.1", "B.2", "C"] + ) + + tm.assert_frame_equal(result, expected) + + def test_rowspan_at_end_of_row(self, flavor_read_html): + # GH17054 + + # In ASCII, with lowercase letters being copies: + # + # A B + # C b + + result = flavor_read_html( + StringIO( + """ + + + + + + + + +
AB
C
+ """ + ), + header=0, + )[0] + + expected = DataFrame(data=[["C", "B"]], columns=["A", "B"]) + + tm.assert_frame_equal(result, expected) + + def test_rowspan_only_rows(self, flavor_read_html): + # GH17054 + + result = flavor_read_html( + StringIO( + """ + + + + + +
AB
+ """ + ), + header=0, + )[0] + + expected = DataFrame(data=[["A", "B"], ["A", "B"]], columns=["A", "B"]) + + tm.assert_frame_equal(result, expected) + + def test_rowspan_in_header_overflowing_to_body(self, flavor_read_html): + # GH60210 + + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + +
AB
1
C2
+ """ + ) + )[0] + + expected = DataFrame(data=[["A", 1], ["C", 2]], columns=["A", "B"]) + + tm.assert_frame_equal(result, expected) + + def test_header_inferred_from_rows_with_only_th(self, flavor_read_html): + # GH17054 + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + + +
AB
ab
12
+ """ + ) + )[0] + + columns = MultiIndex(levels=[["A", "B"], ["a", "b"]], codes=[[0, 1], [0, 1]]) + expected = DataFrame(data=[[1, 2]], columns=columns) + + tm.assert_frame_equal(result, expected) + + def test_parse_dates_list(self, flavor_read_html): + df = DataFrame({"date": date_range("1/1/2001", periods=10)}) + + expected = df[:] + expected["date"] = expected["date"].dt.as_unit("us") + + str_df = df.to_html() + res = flavor_read_html(StringIO(str_df), parse_dates=[1], index_col=0) + tm.assert_frame_equal(expected, res[0]) + res = flavor_read_html(StringIO(str_df), parse_dates=["date"], index_col=0) + tm.assert_frame_equal(expected, res[0]) + + def test_wikipedia_states_table(self, datapath, flavor_read_html): + data = datapath("io", "data", "html", "wikipedia_states.html") + assert os.path.isfile(data), f"{data!r} is not a file" + assert os.path.getsize(data), f"{data!r} is an empty file" + result = flavor_read_html(data, match="Arizona", header=1)[0] + assert result.shape == (60, 12) + assert "Unnamed" in result.columns[-1] + assert result["sq mi"].dtype == np.dtype("float64") + assert np.allclose(result.loc[0, "sq mi"], 665384.04) + + def test_wikipedia_states_multiindex(self, datapath, flavor_read_html): + data = datapath("io", "data", "html", "wikipedia_states.html") + result = flavor_read_html(data, match="Arizona", index_col=0)[0] + assert result.shape == (60, 11) + assert "Unnamed" in result.columns[-1][1] + assert result.columns.nlevels == 2 + assert np.allclose(result.loc["Alaska", ("Total area[2]", "sq mi")], 665384.04) + + def test_parser_error_on_empty_header_row(self, flavor_read_html): + result = flavor_read_html( + StringIO( + """ + + + + + + + + +
AB
ab
+ """ + ), + header=[0, 1], + ) + expected = DataFrame( + [["a", "b"]], + columns=MultiIndex.from_tuples( + [("Unnamed: 0_level_0", "A"), ("Unnamed: 1_level_0", "B")] + ), + ) + tm.assert_frame_equal(result[0], expected) + + def test_decimal_rows(self, flavor_read_html): + # GH 12907 + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + +
Header
1100#101
+ + """ + ), + decimal="#", + )[0] + + expected = DataFrame(data={"Header": 1100.101}, index=[0]) + + assert result["Header"].dtype == np.dtype("float64") + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("arg", [True, False]) + def test_bool_header_arg(self, spam_data, arg, flavor_read_html): + # GH 6114 + msg = re.escape( + "Passing a bool to header is invalid. Use header=None for no header or " + "header=int or list-like of ints to specify the row(s) making up the " + "column names" + ) + with pytest.raises(TypeError, match=msg): + flavor_read_html(spam_data, header=arg) + + def test_converters(self, flavor_read_html): + # GH 13461 + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + + +
a
0.763
0.244
""" + ), + converters={"a": str}, + )[0] + + expected = DataFrame({"a": ["0.763", "0.244"]}) + + tm.assert_frame_equal(result, expected) + + def test_na_values(self, flavor_read_html): + # GH 13461 + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + + +
a
0.763
0.244
""" + ), + na_values=[0.244], + )[0] + + expected = DataFrame({"a": [0.763, np.nan]}) + + tm.assert_frame_equal(result, expected) + + def test_keep_default_na(self, flavor_read_html): + html_data = """ + + + + + + + + + + + + + +
a
N/A
NA
""" + + expected_df = DataFrame({"a": ["N/A", "NA"]}) + html_df = flavor_read_html(StringIO(html_data), keep_default_na=False)[0] + tm.assert_frame_equal(expected_df, html_df) + + expected_df = DataFrame({"a": [np.nan, np.nan]}) + html_df = flavor_read_html(StringIO(html_data), keep_default_na=True)[0] + tm.assert_frame_equal(expected_df, html_df) + + def test_preserve_empty_rows(self, flavor_read_html): + result = flavor_read_html( + StringIO( + """ + + + + + + + + + + + + + +
AB
ab
+ """ + ) + )[0] + + expected = DataFrame(data=[["a", "b"], [np.nan, np.nan]], columns=["A", "B"]) + + tm.assert_frame_equal(result, expected) + + def test_ignore_empty_rows_when_inferring_header(self, flavor_read_html): + result = flavor_read_html( + StringIO( + """ + + + + + + + + + +
AB
ab
12
+ """ + ) + )[0] + + columns = MultiIndex(levels=[["A", "B"], ["a", "b"]], codes=[[0, 1], [0, 1]]) + expected = DataFrame(data=[[1, 2]], columns=columns) + + tm.assert_frame_equal(result, expected) + + def test_multiple_header_rows(self, flavor_read_html): + # Issue #13434 + expected_df = DataFrame( + data=[("Hillary", 68, "D"), ("Bernie", 74, "D"), ("Donald", 69, "R")] + ) + expected_df.columns = [ + ["Unnamed: 0_level_0", "Age", "Party"], + ["Name", "Unnamed: 1_level_1", "Unnamed: 2_level_1"], + ] + html = expected_df.to_html(index=False) + html_df = flavor_read_html(StringIO(html))[0] + tm.assert_frame_equal(expected_df, html_df) + + def test_works_on_valid_markup(self, datapath, flavor_read_html): + filename = datapath("io", "data", "html", "valid_markup.html") + dfs = flavor_read_html(filename, index_col=0) + assert isinstance(dfs, list) + assert isinstance(dfs[0], DataFrame) + + @pytest.mark.slow + def test_fallback_success(self, datapath, flavor_read_html): + banklist_data = datapath("io", "data", "html", "banklist.html") + + flavor_read_html(banklist_data, match=".*Water.*", flavor=["lxml", "html5lib"]) + + def test_to_html_timestamp(self): + rng = date_range("2000-01-01", periods=10) + df = DataFrame(np.random.default_rng(2).standard_normal((10, 4)), index=rng) + + result = df.to_html() + assert "2000-01-01" in result + + def test_to_html_borderless(self): + df = DataFrame([{"A": 1, "B": 2}]) + out_border_default = df.to_html() + out_border_true = df.to_html(border=True) + out_border_explicit_default = df.to_html(border=1) + out_border_nondefault = df.to_html(border=2) + out_border_zero = df.to_html(border=0) + + out_border_false = df.to_html(border=False) + + assert ' border="1"' in out_border_default + assert out_border_true == out_border_default + assert out_border_default == out_border_explicit_default + assert out_border_default != out_border_nondefault + assert ' border="2"' in out_border_nondefault + assert ' border="0"' not in out_border_zero + assert " border" not in out_border_false + assert out_border_zero == out_border_false + + @pytest.mark.parametrize( + "displayed_only,exp0,exp1", + [ + (True, ["foo"], None), + (False, ["foo bar baz qux"], DataFrame(["foo"])), + ], + ) + def test_displayed_only(self, displayed_only, exp0, exp1, flavor_read_html): + # GH 20027 + data = """ + + + + + +
+ foo + bar + baz + qux +
+ + + + +
foo
+ + """ + + exp0 = DataFrame(exp0) + dfs = flavor_read_html(StringIO(data), displayed_only=displayed_only) + tm.assert_frame_equal(dfs[0], exp0) + + if exp1 is not None: + tm.assert_frame_equal(dfs[1], exp1) + else: + assert len(dfs) == 1 # Should not parse hidden table + + @pytest.mark.parametrize("displayed_only", [True, False]) + def test_displayed_only_with_many_elements(self, displayed_only, flavor_read_html): + html_table = """ + + + + + + + + + + + + + +
AB
12
45
+ """ + result = flavor_read_html(StringIO(html_table), displayed_only=displayed_only)[ + 0 + ] + expected = DataFrame({"A": [1, 4], "B": [2, 5]}) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:You provided Unicode markup but also provided a value for " + "from_encoding.*:UserWarning" + ) + def test_encode(self, html_encoding_file, flavor_read_html): + base_path = os.path.basename(html_encoding_file) + root = os.path.splitext(base_path)[0] + _, encoding = root.split("_") + + try: + with open(html_encoding_file, "rb") as fobj: + from_string = flavor_read_html( + BytesIO(fobj.read()), encoding=encoding, index_col=0 + ).pop() + + with open(html_encoding_file, "rb") as fobj: + from_file_like = flavor_read_html( + BytesIO(fobj.read()), encoding=encoding, index_col=0 + ).pop() + + from_filename = flavor_read_html( + html_encoding_file, encoding=encoding, index_col=0 + ).pop() + tm.assert_frame_equal(from_string, from_file_like) + tm.assert_frame_equal(from_string, from_filename) + except Exception: + # seems utf-16/32 fail on windows + if is_platform_windows(): + if "16" in encoding or "32" in encoding: + pytest.skip() + raise + + def test_parse_failure_unseekable(self, flavor_read_html): + # Issue #17975 + + if flavor_read_html.keywords.get("flavor") == "lxml": + pytest.skip("Not applicable for lxml") + + class UnseekableStringIO(StringIO): + def seekable(self): + return False + + bad = UnseekableStringIO( + """ +
spameggs
""" + ) + + assert flavor_read_html(bad) + + with pytest.raises(ValueError, match="passed a non-rewindable file object"): + flavor_read_html(bad) + + def test_parse_failure_rewinds(self, flavor_read_html): + # Issue #17975 + + class MockFile: + def __init__(self, data) -> None: + self.data = data + self.at_end = False + + def read(self, size=None): + data = "" if self.at_end else self.data + self.at_end = True + return data + + def seek(self, offset): + self.at_end = False + + def seekable(self): + return True + + def __next__(self): ... + + def __iter__(self) -> Iterator: + # `is_file_like` depends on the presence of + # the __iter__ attribute. + return self + + good = MockFile("
spam
eggs
") + bad = MockFile("
spameggs
") + + assert flavor_read_html(good) + assert flavor_read_html(bad) + + @pytest.mark.slow + @pytest.mark.single_cpu + def test_importcheck_thread_safety(self, datapath, flavor_read_html): + # see gh-16928 + + class ErrorThread(threading.Thread): + def run(self): + try: + super().run() + except Exception as err: + self.err = err + else: + self.err = None + + filename = datapath("io", "data", "html", "valid_markup.html") + helper_thread1 = ErrorThread(target=flavor_read_html, args=(filename,)) + helper_thread2 = ErrorThread(target=flavor_read_html, args=(filename,)) + + helper_thread1.start() + helper_thread2.start() + + while helper_thread1.is_alive() or helper_thread2.is_alive(): + pass + assert None is helper_thread1.err is helper_thread2.err + + def test_parse_path_object(self, datapath, flavor_read_html): + # GH 37705 + file_path_string = datapath("io", "data", "html", "spam.html") + file_path = Path(file_path_string) + df1 = flavor_read_html(file_path_string)[0] + df2 = flavor_read_html(file_path)[0] + tm.assert_frame_equal(df1, df2) + + def test_parse_br_as_space(self, flavor_read_html): + # GH 29528: pd.read_html() convert
to space + result = flavor_read_html( + StringIO( + """ + + + + + + + +
A
word1
word2
+ """ + ) + )[0] + + expected = DataFrame(data=[["word1 word2"]], columns=["A"]) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("arg", ["all", "body", "header", "footer"]) + def test_extract_links(self, arg, flavor_read_html): + gh_13141_data = """ + + + + + + + + + + + + + + + + + +
HTTPFTPLinkless
WikipediaSURROUNDING Debian TEXTLinkless
Footer + Multiple links: Only first captured. +
+ """ + + gh_13141_expected = { + "head_ignore": ["HTTP", "FTP", "Linkless"], + "head_extract": [ + ("HTTP", None), + ("FTP", None), + ("Linkless", "https://en.wiktionary.org/wiki/linkless"), + ], + "body_ignore": ["Wikipedia", "SURROUNDING Debian TEXT", "Linkless"], + "body_extract": [ + ("Wikipedia", "https://en.wikipedia.org/"), + ("SURROUNDING Debian TEXT", "ftp://ftp.us.debian.org/"), + ("Linkless", None), + ], + "footer_ignore": [ + "Footer", + "Multiple links: Only first captured.", + None, + ], + "footer_extract": [ + ("Footer", "https://en.wikipedia.org/wiki/Page_footer"), + ("Multiple links: Only first captured.", "1"), + None, + ], + } + + data_exp = gh_13141_expected["body_ignore"] + foot_exp = gh_13141_expected["footer_ignore"] + head_exp = gh_13141_expected["head_ignore"] + if arg == "all": + data_exp = gh_13141_expected["body_extract"] + foot_exp = gh_13141_expected["footer_extract"] + head_exp = gh_13141_expected["head_extract"] + elif arg == "body": + data_exp = gh_13141_expected["body_extract"] + elif arg == "footer": + foot_exp = gh_13141_expected["footer_extract"] + elif arg == "header": + head_exp = gh_13141_expected["head_extract"] + + result = flavor_read_html(StringIO(gh_13141_data), extract_links=arg)[0] + expected = DataFrame([data_exp, foot_exp], columns=head_exp) + expected = expected.fillna(np.nan) + tm.assert_frame_equal(result, expected) + + def test_extract_links_bad(self, spam_data): + msg = ( + "`extract_links` must be one of " + '{None, "header", "footer", "body", "all"}, got "incorrect"' + ) + with pytest.raises(ValueError, match=msg): + read_html(spam_data, extract_links="incorrect") + + def test_extract_links_all_no_header(self, flavor_read_html): + # GH 48316 + data = """ + + + + +
+ Google.com +
+ """ + result = flavor_read_html(StringIO(data), extract_links="all")[0] + expected = DataFrame([[("Google.com", "https://google.com")]]) + tm.assert_frame_equal(result, expected) + + def test_invalid_dtype_backend(self): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + read_html("test", dtype_backend="numpy") + + def test_style_tag(self, flavor_read_html): + # GH 48316 + data = """ + + + + + + + + + + + + + +
+ + A + B
A1B1
A2B2
+ """ + result = flavor_read_html(StringIO(data))[0] + expected = DataFrame(data=[["A1", "B1"], ["A2", "B2"]], columns=["A", "B"]) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/io/test_http_headers.py b/pandas/tests/io/test_http_headers.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9c8769ad9dc909f99d346395d5cd2113984992 --- /dev/null +++ b/pandas/tests/io/test_http_headers.py @@ -0,0 +1,174 @@ +""" +Tests for the pandas custom headers in http(s) requests +""" + +from functools import partial +import gzip +from io import BytesIO + +import pytest + +from pandas._config import using_string_dtype + +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + +pytestmark = [ + pytest.mark.single_cpu, + pytest.mark.network, + pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" + ), +] + + +def gzip_bytes(response_bytes): + with BytesIO() as bio: + with gzip.GzipFile(fileobj=bio, mode="w") as zipper: + zipper.write(response_bytes) + return bio.getvalue() + + +def csv_responder(df): + return df.to_csv(index=False).encode("utf-8") + + +def gz_csv_responder(df): + return gzip_bytes(csv_responder(df)) + + +def json_responder(df): + return df.to_json().encode("utf-8") + + +def gz_json_responder(df): + return gzip_bytes(json_responder(df)) + + +def html_responder(df): + return df.to_html(index=False).encode("utf-8") + + +def parquetpyarrow_reponder(df): + return df.to_parquet(index=False, engine="pyarrow") + + +def parquetfastparquet_responder(df): + # the fastparquet engine doesn't like to write to a buffer + # it can do it via the open_with function being set appropriately + # however it automatically calls the close method and wipes the buffer + # so just overwrite that attribute on this instance to not do that + + # protected by an importorskip in the respective test + import fsspec + + df.to_parquet( + "memory://fastparquet_user_agent.parquet", + index=False, + engine="fastparquet", + compression=None, + ) + with fsspec.open("memory://fastparquet_user_agent.parquet", "rb") as f: + return f.read() + + +def pickle_respnder(df): + with BytesIO() as bio: + df.to_pickle(bio) + return bio.getvalue() + + +def stata_responder(df): + with BytesIO() as bio: + df.to_stata(bio, write_index=False) + return bio.getvalue() + + +@pytest.mark.parametrize( + "responder, read_method", + [ + (csv_responder, pd.read_csv), + (json_responder, pd.read_json), + ( + html_responder, + lambda *args, **kwargs: pd.read_html(*args, **kwargs)[0], + ), + pytest.param( + parquetpyarrow_reponder, + partial(pd.read_parquet, engine="pyarrow"), + marks=td.skip_if_no("pyarrow"), + ), + pytest.param( + parquetfastparquet_responder, + partial(pd.read_parquet, engine="fastparquet"), + marks=[ + td.skip_if_no("fastparquet"), + td.skip_if_no("fsspec"), + pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string"), + ], + ), + (pickle_respnder, pd.read_pickle), + (stata_responder, pd.read_stata), + (gz_csv_responder, pd.read_csv), + (gz_json_responder, pd.read_json), + ], +) +@pytest.mark.parametrize( + "storage_options", + [ + None, + {"User-Agent": "foo"}, + {"User-Agent": "foo", "Auth": "bar"}, + ], +) +def test_request_headers(responder, read_method, httpserver, storage_options): + expected = pd.DataFrame({"a": ["b"]}) + default_headers = ["Accept-Encoding", "Host", "Connection", "User-Agent"] + if "gz" in responder.__name__: + extra = {"Content-Encoding": "gzip"} + if storage_options is None: + storage_options = extra + else: + storage_options |= extra + else: + extra = None + expected_headers = set(default_headers).union( + storage_options.keys() if storage_options else [] + ) + httpserver.serve_content(content=responder(expected), headers=extra) + result = read_method(httpserver.url, storage_options=storage_options) + tm.assert_frame_equal(result, expected) + + request_headers = dict(httpserver.requests[0].headers) + for header in expected_headers: + exp = request_headers.pop(header) + if storage_options and header in storage_options: + assert exp == storage_options[header] + # No extra headers added + assert not request_headers + + +@pytest.mark.parametrize( + "engine", + [ + "pyarrow", + "fastparquet", + ], +) +def test_to_parquet_to_disk_with_storage_options(engine): + headers = { + "User-Agent": "custom", + "Auth": "other_custom", + } + + pytest.importorskip(engine) + + true_df = pd.DataFrame({"column_name": ["column_value"]}) + msg = ( + "storage_options passed with file object or non-fsspec file path|" + "storage_options passed with buffer, or non-supported URL" + ) + with pytest.raises(ValueError, match=msg): + true_df.to_parquet("/tmp/junk.parquet", storage_options=headers, engine=engine) diff --git a/pandas/tests/io/test_iceberg.py b/pandas/tests/io/test_iceberg.py new file mode 100644 index 0000000000000000000000000000000000000000..689eddb1985e6344d72c26fbb30a5237489cdadd --- /dev/null +++ b/pandas/tests/io/test_iceberg.py @@ -0,0 +1,222 @@ +""" +Tests for the Apache Iceberg format. + +Tests in this file use a simple Iceberg catalog based on SQLite, with the same +data used for Parquet tests (``pandas/tests/io/data/parquet/simple.parquet``). +""" + +import collections +import importlib +import pathlib + +import pytest + +import pandas as pd +import pandas._testing as tm + +from pandas.io.iceberg import read_iceberg + +pytestmark = pytest.mark.single_cpu + +pyiceberg = pytest.importorskip("pyiceberg") +pyiceberg_catalog = pytest.importorskip("pyiceberg.catalog") +pq = pytest.importorskip("pyarrow.parquet") + +Catalog = collections.namedtuple("Catalog", ["name", "uri", "warehouse"]) + + +@pytest.fixture +def catalog(request, tmp_path): + # the catalog stores the full path of data files, so the catalog needs to be + # created dynamically, and not saved in pandas/tests/io/data as other formats + uri = f"sqlite:///{tmp_path}/catalog.sqlite" + warehouse = f"file://{tmp_path}" + catalog_name = request.param if hasattr(request, "param") else None + catalog = pyiceberg_catalog.load_catalog( + catalog_name or "default", + type="sql", + uri=uri, + warehouse=warehouse, + ) + catalog.create_namespace("ns") + + df = pq.read_table( + pathlib.Path(__file__).parent / "data" / "parquet" / "simple.parquet" + ) + table = catalog.create_table("ns.my_table", schema=df.schema) + table.append(df) + + if catalog_name is not None: + config_path = pathlib.Path.home() / ".pyiceberg.yaml" + with open(config_path, "w", encoding="utf-8") as f: + f.write(f"""\ +catalog: + {catalog_name}: + type: sql + uri: {uri} + warehouse: {warehouse}""") + + importlib.reload(pyiceberg_catalog) # needed to reload the config file + + yield Catalog(name=catalog_name or "default", uri=uri, warehouse=warehouse) + + if catalog_name is not None: + config_path.unlink() + + +class TestIceberg: + def test_read(self, catalog): + expected = pd.DataFrame( + { + "A": [1, 2, 3], + "B": ["foo", "foo", "foo"], + } + ) + result = read_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("catalog", ["default", "pandas_tests"], indirect=True) + def test_read_by_catalog_name(self, catalog): + expected = pd.DataFrame( + { + "A": [1, 2, 3], + "B": ["foo", "foo", "foo"], + } + ) + result = read_iceberg( + "ns.my_table", + catalog_name=catalog.name, + ) + tm.assert_frame_equal(result, expected) + + def test_read_with_row_filter(self, catalog): + expected = pd.DataFrame( + { + "A": [2, 3], + "B": ["foo", "foo"], + } + ) + result = read_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + row_filter="A > 1", + ) + tm.assert_frame_equal(result, expected) + + def test_read_with_case_sensitive(self, catalog): + expected = pd.DataFrame( + { + "A": [1, 2, 3], + } + ) + result = read_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + columns=["a"], + case_sensitive=False, + ) + tm.assert_frame_equal(result, expected) + + with pytest.raises(ValueError, match="^Could not find column"): + read_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + columns=["a"], + case_sensitive=True, + ) + + def test_read_with_limit(self, catalog): + expected = pd.DataFrame( + { + "A": [1, 2], + "B": ["foo", "foo"], + } + ) + result = read_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + limit=2, + ) + tm.assert_frame_equal(result, expected) + + def test_write(self, catalog): + df = pd.DataFrame( + { + "A": [1, 2, 3], + "B": ["foo", "foo", "foo"], + } + ) + df.to_iceberg( + "ns.new_table", + catalog_properties={"uri": catalog.uri}, + location=catalog.warehouse, + ) + result = read_iceberg( + "ns.new_table", + catalog_properties={"uri": catalog.uri}, + ) + tm.assert_frame_equal(result, df) + + @pytest.mark.parametrize("catalog", ["default", "pandas_tests"], indirect=True) + def test_write_by_catalog_name(self, catalog): + df = pd.DataFrame( + { + "A": [1, 2, 3], + "B": ["foo", "foo", "foo"], + } + ) + df.to_iceberg( + "ns.new_table", + catalog_name=catalog.name, + ) + result = read_iceberg( + "ns.new_table", + catalog_name=catalog.name, + ) + tm.assert_frame_equal(result, df) + + def test_write_existing_table_with_append_true(self, catalog): + original = read_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + ) + new = pd.DataFrame( + { + "A": [1, 2, 3], + "B": ["foo", "foo", "foo"], + } + ) + expected = pd.concat([original, new], ignore_index=True) + new.to_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + location=catalog.warehouse, + append=True, + ) + result = read_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + ) + tm.assert_frame_equal(result, expected) + + def test_write_existing_table_with_append_false(self, catalog): + df = pd.DataFrame( + { + "A": [1, 2, 3], + "B": ["foo", "foo", "foo"], + } + ) + df.to_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + location=catalog.warehouse, + append=False, + ) + result = read_iceberg( + "ns.my_table", + catalog_properties={"uri": catalog.uri}, + ) + tm.assert_frame_equal(result, df) diff --git a/pandas/tests/io/test_orc.py b/pandas/tests/io/test_orc.py new file mode 100644 index 0000000000000000000000000000000000000000..2e61494103355c5836f3ad0e9187d712093610dc --- /dev/null +++ b/pandas/tests/io/test_orc.py @@ -0,0 +1,432 @@ +"""test orc compat""" + +import datetime +from decimal import Decimal +from io import BytesIO +import os + +import numpy as np +import pytest + +import pandas as pd +from pandas import read_orc +import pandas._testing as tm + +pytest.importorskip("pyarrow.orc") + +import pyarrow as pa + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" +) + + +@pytest.fixture +def dirpath(datapath): + return datapath("io", "data", "orc") + + +def test_orc_reader_empty(dirpath, using_infer_string): + columns = [ + "boolean1", + "byte1", + "short1", + "int1", + "long1", + "float1", + "double1", + "bytes1", + "string1", + ] + dtypes = [ + "bool", + "int8", + "int16", + "int32", + "int64", + "float32", + "float64", + "object", + "str" if using_infer_string else "object", + ] + expected = pd.DataFrame(index=pd.RangeIndex(0)) + for colname, dtype in zip(columns, dtypes, strict=True): + expected[colname] = pd.Series(dtype=dtype) + expected.columns = expected.columns.astype("str") + + inputfile = os.path.join(dirpath, "TestOrcFile.emptyFile.orc") + got = read_orc(inputfile, columns=columns) + + tm.assert_equal(expected, got) + + +def test_orc_reader_basic(dirpath): + data = { + "boolean1": np.array([False, True], dtype="bool"), + "byte1": np.array([1, 100], dtype="int8"), + "short1": np.array([1024, 2048], dtype="int16"), + "int1": np.array([65536, 65536], dtype="int32"), + "long1": np.array([9223372036854775807, 9223372036854775807], dtype="int64"), + "float1": np.array([1.0, 2.0], dtype="float32"), + "double1": np.array([-15.0, -5.0], dtype="float64"), + "bytes1": np.array([b"\x00\x01\x02\x03\x04", b""], dtype="object"), + "string1": np.array(["hi", "bye"], dtype="object"), + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.test1.orc") + got = read_orc(inputfile, columns=data.keys()) + + tm.assert_equal(expected, got) + + +def test_orc_reader_decimal(dirpath): + # Only testing the first 10 rows of data + data = { + "_col0": np.array( + [ + Decimal("-1000.50000"), + Decimal("-999.60000"), + Decimal("-998.70000"), + Decimal("-997.80000"), + Decimal("-996.90000"), + Decimal("-995.10000"), + Decimal("-994.11000"), + Decimal("-993.12000"), + Decimal("-992.13000"), + Decimal("-991.14000"), + ], + dtype="object", + ) + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.decimal.orc") + got = read_orc(inputfile).iloc[:10] + + tm.assert_equal(expected, got) + + +def test_orc_reader_date_low(dirpath): + data = { + "time": np.array( + [ + "1900-05-05 12:34:56.100000", + "1900-05-05 12:34:56.100100", + "1900-05-05 12:34:56.100200", + "1900-05-05 12:34:56.100300", + "1900-05-05 12:34:56.100400", + "1900-05-05 12:34:56.100500", + "1900-05-05 12:34:56.100600", + "1900-05-05 12:34:56.100700", + "1900-05-05 12:34:56.100800", + "1900-05-05 12:34:56.100900", + ], + dtype="datetime64[ns]", + ), + "date": np.array( + [ + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + datetime.date(1900, 12, 25), + ], + dtype="object", + ), + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.testDate1900.orc") + got = read_orc(inputfile).iloc[:10] + + tm.assert_equal(expected, got) + + +def test_orc_reader_date_high(dirpath): + data = { + "time": np.array( + [ + "2038-05-05 12:34:56.100000", + "2038-05-05 12:34:56.100100", + "2038-05-05 12:34:56.100200", + "2038-05-05 12:34:56.100300", + "2038-05-05 12:34:56.100400", + "2038-05-05 12:34:56.100500", + "2038-05-05 12:34:56.100600", + "2038-05-05 12:34:56.100700", + "2038-05-05 12:34:56.100800", + "2038-05-05 12:34:56.100900", + ], + dtype="datetime64[ns]", + ), + "date": np.array( + [ + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + datetime.date(2038, 12, 25), + ], + dtype="object", + ), + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.testDate2038.orc") + got = read_orc(inputfile).iloc[:10] + + tm.assert_equal(expected, got) + + +def test_orc_reader_snappy_compressed(dirpath): + data = { + "int1": np.array( + [ + -1160101563, + 1181413113, + 2065821249, + -267157795, + 172111193, + 1752363137, + 1406072123, + 1911809390, + -1308542224, + -467100286, + ], + dtype="int32", + ), + "string1": np.array( + [ + "f50dcb8", + "382fdaaa", + "90758c6", + "9e8caf3f", + "ee97332b", + "d634da1", + "2bea4396", + "d67d89e8", + "ad71007e", + "e8c82066", + ], + dtype="object", + ), + } + expected = pd.DataFrame.from_dict(data) + + inputfile = os.path.join(dirpath, "TestOrcFile.testSnappy.orc") + got = read_orc(inputfile).iloc[:10] + + tm.assert_equal(expected, got) + + +def test_orc_roundtrip_file(dirpath, temp_file): + # GH44554 + # PyArrow gained ORC write support with the current argument order + pytest.importorskip("pyarrow") + + data = { + "boolean1": np.array([False, True], dtype="bool"), + "byte1": np.array([1, 100], dtype="int8"), + "short1": np.array([1024, 2048], dtype="int16"), + "int1": np.array([65536, 65536], dtype="int32"), + "long1": np.array([9223372036854775807, 9223372036854775807], dtype="int64"), + "float1": np.array([1.0, 2.0], dtype="float32"), + "double1": np.array([-15.0, -5.0], dtype="float64"), + "bytes1": np.array([b"\x00\x01\x02\x03\x04", b""], dtype="object"), + "string1": np.array(["hi", "bye"], dtype="object"), + } + expected = pd.DataFrame.from_dict(data) + + expected.to_orc(temp_file) + got = read_orc(temp_file) + + tm.assert_equal(expected, got) + + +def test_orc_roundtrip_bytesio(): + # GH44554 + # PyArrow gained ORC write support with the current argument order + pytest.importorskip("pyarrow") + + data = { + "boolean1": np.array([False, True], dtype="bool"), + "byte1": np.array([1, 100], dtype="int8"), + "short1": np.array([1024, 2048], dtype="int16"), + "int1": np.array([65536, 65536], dtype="int32"), + "long1": np.array([9223372036854775807, 9223372036854775807], dtype="int64"), + "float1": np.array([1.0, 2.0], dtype="float32"), + "double1": np.array([-15.0, -5.0], dtype="float64"), + "bytes1": np.array([b"\x00\x01\x02\x03\x04", b""], dtype="object"), + "string1": np.array(["hi", "bye"], dtype="object"), + } + expected = pd.DataFrame.from_dict(data) + + bytes = expected.to_orc() + got = read_orc(BytesIO(bytes)) + + tm.assert_equal(expected, got) + + +@pytest.mark.parametrize( + "orc_writer_dtypes_not_supported", + [ + np.array([1, 20], dtype="uint64"), + pd.Series(["a", "b", "a"], dtype="category"), + [pd.Interval(left=0, right=2), pd.Interval(left=0, right=5)], + [pd.Period("2022-01-03", freq="D"), pd.Period("2022-01-04", freq="D")], + ], +) +def test_orc_writer_dtypes_not_supported(orc_writer_dtypes_not_supported): + # GH44554 + # PyArrow gained ORC write support with the current argument order + pytest.importorskip("pyarrow") + + df = pd.DataFrame({"unimpl": orc_writer_dtypes_not_supported}) + msg = "The dtype of one or more columns is not supported yet." + with pytest.raises(NotImplementedError, match=msg): + df.to_orc() + + +def test_orc_dtype_backend_pyarrow(using_infer_string): + pytest.importorskip("pyarrow") + df = pd.DataFrame( + { + "string": list("abc"), + "string_with_nan": ["a", np.nan, "c"], + "string_with_none": ["a", None, "c"], + "bytes": [b"foo", b"bar", None], + "int": list(range(1, 4)), + "float": np.arange(4.0, 7.0, dtype="float64"), + "float_with_nan": [2.0, np.nan, 3.0], + "bool": [True, False, True], + "bool_with_na": [True, False, None], + "datetime": pd.date_range("20130101", periods=3, unit="ns"), + "datetime_with_nat": [ + pd.Timestamp("20130101"), + pd.NaT, + pd.Timestamp("20130103"), + ], + } + ) + # FIXME: without casting to ns we do not round-trip correctly + df["datetime_with_nat"] = df["datetime_with_nat"].astype("M8[ns]") + + bytes_data = df.copy().to_orc() + result = read_orc(BytesIO(bytes_data), dtype_backend="pyarrow") + + expected = pd.DataFrame( + { + col: pd.arrays.ArrowExtensionArray(pa.array(df[col], from_pandas=True)) + for col in df.columns + } + ) + if using_infer_string: + # ORC does not preserve distinction between string and large string + # -> the default large string comes back as string + string_dtype = pd.ArrowDtype(pa.string()) + expected["string"] = expected["string"].astype(string_dtype) + expected["string_with_nan"] = expected["string_with_nan"].astype(string_dtype) + expected["string_with_none"] = expected["string_with_none"].astype(string_dtype) + + tm.assert_frame_equal(result, expected) + + +def test_orc_dtype_backend_numpy_nullable(): + # GH#50503 + pytest.importorskip("pyarrow") + df = pd.DataFrame( + { + "string": list("abc"), + "string_with_nan": ["a", np.nan, "c"], + "string_with_none": ["a", None, "c"], + "int": list(range(1, 4)), + "int_with_nan": pd.Series([1, pd.NA, 3], dtype="Int64"), + "na_only": pd.Series([pd.NA, pd.NA, pd.NA], dtype="Int64"), + "float": np.arange(4.0, 7.0, dtype="float64"), + "float_with_nan": [2.0, np.nan, 3.0], + "bool": [True, False, True], + "bool_with_na": [True, False, None], + } + ) + + bytes_data = df.copy().to_orc() + result = read_orc(BytesIO(bytes_data), dtype_backend="numpy_nullable") + + expected = pd.DataFrame( + { + "string": pd.array(["a", "b", "c"], dtype=pd.StringDtype()), + "string_with_nan": pd.array(["a", pd.NA, "c"], dtype=pd.StringDtype()), + "string_with_none": pd.array(["a", pd.NA, "c"], dtype=pd.StringDtype()), + "int": pd.Series([1, 2, 3], dtype="Int64"), + "int_with_nan": pd.Series([1, pd.NA, 3], dtype="Int64"), + "na_only": pd.Series([pd.NA, pd.NA, pd.NA], dtype="Int64"), + "float": pd.Series([4.0, 5.0, 6.0], dtype="Float64"), + "float_with_nan": pd.Series([2.0, pd.NA, 3.0], dtype="Float64"), + "bool": pd.Series([True, False, True], dtype="boolean"), + "bool_with_na": pd.Series([True, False, pd.NA], dtype="boolean"), + } + ) + + tm.assert_frame_equal(result, expected) + + +def test_orc_uri_path(temp_file): + expected = pd.DataFrame({"int": list(range(1, 4))}) + expected.to_orc(temp_file) + uri = temp_file.as_uri() + result = read_orc(uri) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + pd.RangeIndex(start=2, stop=5, step=1), + pd.RangeIndex(start=0, stop=3, step=1, name="non-default"), + pd.Index([1, 2, 3]), + ], +) +def test_to_orc_non_default_index(index): + df = pd.DataFrame({"a": [1, 2, 3]}, index=index) + msg = ( + "orc does not support serializing a non-default index|" + "orc does not serialize index meta-data" + ) + with pytest.raises(ValueError, match=msg): + df.to_orc() + + +def test_invalid_dtype_backend(temp_file): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + df = pd.DataFrame({"int": list(range(1, 4))}) + df.to_orc(temp_file) + with pytest.raises(ValueError, match=msg): + read_orc(temp_file, dtype_backend="numpy") + + +def test_string_inference(temp_file): + # GH#54431 + df = pd.DataFrame(data={"a": ["x", "y"]}) + df.to_orc(temp_file) + with pd.option_context("future.infer_string", True): + result = read_orc(temp_file) + expected = pd.DataFrame( + data={"a": ["x", "y"]}, + dtype=pd.StringDtype(na_value=np.nan), + columns=pd.Index(["a"], dtype=pd.StringDtype(na_value=np.nan)), + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py new file mode 100644 index 0000000000000000000000000000000000000000..6b7f49e18f5490c0d74ba41e4581a8edc336ddc6 --- /dev/null +++ b/pandas/tests/io/test_parquet.py @@ -0,0 +1,1465 @@ +"""test parquet compat""" + +import datetime +from decimal import Decimal +from io import BytesIO +import os +import pathlib + +import numpy as np +import pytest + +from pandas._config import using_string_dtype + +from pandas.compat import is_platform_windows +from pandas.compat.pyarrow import ( + pa_version_under15p0, + pa_version_under17p0, + pa_version_under19p0, + pa_version_under20p0, +) + +import pandas as pd +import pandas._testing as tm +from pandas.util.version import Version + +from pandas.io.parquet import ( + FastParquetImpl, + PyArrowImpl, + get_engine, + read_parquet, + to_parquet, +) + +try: + import pyarrow + + _HAVE_PYARROW = True +except ImportError: + _HAVE_PYARROW = False + +try: + import fastparquet + + _HAVE_FASTPARQUET = True +except ImportError: + _HAVE_FASTPARQUET = False + + +pytestmark = [ + pytest.mark.filterwarnings("ignore:DataFrame._data is deprecated:FutureWarning"), + pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" + ), +] + + +# setup engines & skips +@pytest.fixture( + params=[ + pytest.param( + "fastparquet", + marks=[ + pytest.mark.skipif( + not _HAVE_FASTPARQUET, + reason="fastparquet is not installed", + ), + pytest.mark.xfail( + using_string_dtype(), + reason="TODO(infer_string) fastparquet", + strict=False, + ), + ], + ), + pytest.param( + "pyarrow", + marks=pytest.mark.skipif( + not _HAVE_PYARROW, reason="pyarrow is not installed" + ), + ), + ] +) +def engine(request): + return request.param + + +@pytest.fixture +def pa(): + if not _HAVE_PYARROW: + pytest.skip("pyarrow is not installed") + return "pyarrow" + + +@pytest.fixture +def fp(request): + if not _HAVE_FASTPARQUET: + pytest.skip("fastparquet is not installed") + if using_string_dtype(): + request.applymarker( + pytest.mark.xfail(reason="TODO(infer_string) fastparquet", strict=False) + ) + return "fastparquet" + + +@pytest.fixture +def df_compat(): + return pd.DataFrame({"A": [1, 2, 3], "B": "foo"}, columns=pd.Index(["A", "B"])) + + +@pytest.fixture +def df_cross_compat(): + df = pd.DataFrame( + { + "a": list("abc"), + "b": list(range(1, 4)), + # 'c': np.arange(3, 6).astype('u1'), + "d": np.arange(4.0, 7.0, dtype="float64"), + "e": [True, False, True], + "f": pd.date_range("20130101", periods=3), + # 'g': pd.date_range('20130101', periods=3, + # tz='US/Eastern'), + # 'h': pd.date_range('20130101', periods=3, freq='ns') + } + ) + return df + + +@pytest.fixture +def df_full(): + return pd.DataFrame( + { + "string": list("abc"), + "string_with_nan": ["a", np.nan, "c"], + "string_with_none": ["a", None, "c"], + "bytes": [b"foo", b"bar", b"baz"], + "unicode": ["foo", "bar", "baz"], + "int": list(range(1, 4)), + "uint": np.arange(3, 6).astype("u1"), + "float": np.arange(4.0, 7.0, dtype="float64"), + "float_with_nan": [2.0, np.nan, 3.0], + "bool": [True, False, True], + "datetime": pd.date_range("20130101", periods=3, unit="ns"), + "datetime_with_nat": [ + pd.Timestamp("20130101"), + pd.NaT, + pd.Timestamp("20130103"), + ], + } + ) + + +@pytest.fixture( + params=[ + datetime.datetime.now(datetime.UTC), + datetime.datetime.now(datetime.timezone.min), + datetime.datetime.now(datetime.timezone.max), + datetime.datetime.strptime("2019-01-04T16:41:24+0200", "%Y-%m-%dT%H:%M:%S%z"), + datetime.datetime.strptime("2019-01-04T16:41:24+0215", "%Y-%m-%dT%H:%M:%S%z"), + datetime.datetime.strptime("2019-01-04T16:41:24-0200", "%Y-%m-%dT%H:%M:%S%z"), + datetime.datetime.strptime("2019-01-04T16:41:24-0215", "%Y-%m-%dT%H:%M:%S%z"), + ] +) +def timezone_aware_date_list(request): + return request.param + + +def check_round_trip( + df, + temp_file, + engine=None, + path=None, + write_kwargs=None, + read_kwargs=None, + expected=None, + check_names=True, + check_like=False, + check_dtype=True, + repeat=2, +): + """Verify parquet serializer and deserializer produce the same results. + + Performs a pandas to disk and disk to pandas round trip, + then compares the 2 resulting DataFrames to verify equality. + + Parameters + ---------- + df: Dataframe + engine: str, optional + 'pyarrow' or 'fastparquet' + path: str, optional + write_kwargs: dict of str:str, optional + read_kwargs: dict of str:str, optional + expected: DataFrame, optional + Expected deserialization result, otherwise will be equal to `df` + check_names: list of str, optional + Closed set of column names to be compared + check_like: bool, optional + If True, ignore the order of index & columns. + repeat: int, optional + How many times to repeat the test + """ + if not isinstance(temp_file, pathlib.Path): + raise ValueError("temp_file must be a pathlib.Path") + write_kwargs = write_kwargs or {"compression": None} + read_kwargs = read_kwargs or {} + + if expected is None: + expected = df + + if engine: + write_kwargs["engine"] = engine + read_kwargs["engine"] = engine + + def compare(repeat): + for _ in range(repeat): + df.to_parquet(path, **write_kwargs) + actual = read_parquet(path, **read_kwargs) + + if "string_with_nan" in expected: + expected.loc[1, "string_with_nan"] = None + tm.assert_frame_equal( + expected, + actual, + check_names=check_names, + check_like=check_like, + check_dtype=check_dtype, + ) + + if path is None: + path = temp_file + compare(repeat) + else: + compare(repeat) + + +def check_partition_names(path, expected): + """Check partitions of a parquet file are as expected. + + Parameters + ---------- + path: str + Path of the dataset. + expected: iterable of str + Expected partition names. + """ + import pyarrow.dataset as ds + + dataset = ds.dataset(path, partitioning="hive") + assert dataset.partitioning.schema.names == expected + + +def test_invalid_engine(df_compat, temp_file): + msg = "engine must be one of 'pyarrow', 'fastparquet'" + with pytest.raises(ValueError, match=msg): + check_round_trip(df_compat, temp_file, "foo", "bar") + + +def test_options_py(df_compat, pa, using_infer_string, temp_file): + # use the set option + if using_infer_string and not pa_version_under19p0: + df_compat.columns = df_compat.columns.astype("str") + + with pd.option_context("io.parquet.engine", "pyarrow"): + check_round_trip(df_compat, temp_file) + + +def test_options_fp(df_compat, fp, temp_file): + # use the set option + + with pd.option_context("io.parquet.engine", "fastparquet"): + check_round_trip(df_compat, temp_file) + + +def test_options_auto(df_compat, fp, pa, temp_file): + # use the set option + + with pd.option_context("io.parquet.engine", "auto"): + check_round_trip(df_compat, temp_file) + + +def test_options_get_engine(fp, pa): + assert isinstance(get_engine("pyarrow"), PyArrowImpl) + assert isinstance(get_engine("fastparquet"), FastParquetImpl) + + with pd.option_context("io.parquet.engine", "pyarrow"): + assert isinstance(get_engine("auto"), PyArrowImpl) + assert isinstance(get_engine("pyarrow"), PyArrowImpl) + assert isinstance(get_engine("fastparquet"), FastParquetImpl) + + with pd.option_context("io.parquet.engine", "fastparquet"): + assert isinstance(get_engine("auto"), FastParquetImpl) + assert isinstance(get_engine("pyarrow"), PyArrowImpl) + assert isinstance(get_engine("fastparquet"), FastParquetImpl) + + with pd.option_context("io.parquet.engine", "auto"): + assert isinstance(get_engine("auto"), PyArrowImpl) + assert isinstance(get_engine("pyarrow"), PyArrowImpl) + assert isinstance(get_engine("fastparquet"), FastParquetImpl) + + +def test_get_engine_auto_error_message(): + # Expect different error messages from get_engine(engine="auto") + # if engines aren't installed vs. are installed but bad version + from pandas.compat._optional import VERSIONS + + # Do we have engines installed, but a bad version of them? + pa_min_ver = VERSIONS.get("pyarrow") + fp_min_ver = VERSIONS.get("fastparquet") + have_pa_bad_version = ( + False + if not _HAVE_PYARROW + else Version(pyarrow.__version__) < Version(pa_min_ver) + ) + have_fp_bad_version = ( + False + if not _HAVE_FASTPARQUET + else Version(fastparquet.__version__) < Version(fp_min_ver) + ) + # Do we have usable engines installed? + have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version + have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version + + if not have_usable_pa and not have_usable_fp: + # No usable engines found. + if have_pa_bad_version: + match = f"Pandas requires version .{pa_min_ver}. or newer of .pyarrow." + with pytest.raises(ImportError, match=match): + get_engine("auto") + else: + match = "Unable to find a usable engine; tried using: 'pyarrow'" + with pytest.raises(ImportError, match=match): + get_engine("auto") + + if have_fp_bad_version: + match = f"Pandas requires version .{fp_min_ver}. or newer of .fastparquet." + with pytest.raises(ImportError, match=match): + get_engine("auto") + else: + match = "Use pip or conda to install the fastparquet package" + with pytest.raises(ImportError, match=match): + get_engine("auto") + + +def test_cross_engine_pa_fp(df_cross_compat, pa, fp, temp_file): + # cross-compat with differing reading/writing engines + + df = df_cross_compat + df.to_parquet(temp_file, engine=pa, compression=None) + + result = read_parquet(temp_file, engine=fp) + tm.assert_frame_equal(result, df) + + result = read_parquet(temp_file, engine=fp, columns=["a", "d"]) + tm.assert_frame_equal(result, df[["a", "d"]]) + + +def test_cross_engine_fp_pa(df_cross_compat, pa, fp, temp_file): + # cross-compat with differing reading/writing engines + df = df_cross_compat + + df.to_parquet(temp_file, engine=fp, compression=None) + + result = read_parquet(temp_file, engine=pa) + tm.assert_frame_equal(result, df) + + result = read_parquet(temp_file, engine=pa, columns=["a", "d"]) + tm.assert_frame_equal(result, df[["a", "d"]]) + + +class Base: + def check_error_on_write(self, df, engine, exc, err_msg, temp_file_path): + # check that we are raising the exception on writing + with pytest.raises(exc, match=err_msg): + to_parquet(df, temp_file_path, engine, compression=None) + + def check_external_error_on_write(self, df, engine, exc, temp_file_path): + # check that an external library is raising the exception on writing + with tm.external_error_raised(exc): + to_parquet(df, temp_file_path, engine, compression=None) + + +class TestBasic(Base): + def test_error(self, engine, temp_file): + for obj in [ + pd.Series([1, 2, 3]), + 1, + "foo", + pd.Timestamp("20130101"), + np.array([1, 2, 3]), + ]: + msg = "to_parquet only supports IO with DataFrames" + self.check_error_on_write(obj, engine, ValueError, msg, temp_file) + + def test_columns_dtypes(self, engine, temp_file): + df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))}) + + # unicode + df.columns = ["foo", "bar"] + check_round_trip(df, temp_file, engine) + + @pytest.mark.parametrize("compression", [None, "gzip", "snappy", "brotli"]) + def test_compression(self, engine, compression, temp_file): + df = pd.DataFrame({"A": [1, 2, 3]}) + check_round_trip( + df, temp_file, engine, write_kwargs={"compression": compression} + ) + + def test_read_columns(self, engine, temp_file): + # GH18154 + df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))}) + + expected = pd.DataFrame({"string": list("abc")}) + check_round_trip( + df, + temp_file, + engine, + expected=expected, + read_kwargs={"columns": ["string"]}, + ) + + def test_read_filters(self, engine, tmp_path): + df = pd.DataFrame( + { + "int": list(range(4)), + "part": list("aabb"), + } + ) + + expected = pd.DataFrame({"int": [0, 1]}) + check_round_trip( + df, + tmp_path, + engine, + expected=expected, + write_kwargs={"partition_cols": ["part"]}, + read_kwargs={"filters": [("part", "==", "a")], "columns": ["int"]}, + repeat=1, + ) + + def test_write_index(self, temp_file): + pytest.importorskip("pyarrow") + df = pd.DataFrame({"A": [1, 2, 3]}) + check_round_trip(df, temp_file, "pyarrow") + + indexes = [ + [2, 3, 4], + pd.date_range("20130101", periods=3, unit="ns"), + list("abc"), + [1, 3, 4], + ] + # non-default index + for index in indexes: + df.index = index + if isinstance(index, pd.DatetimeIndex): + df.index = df.index._with_freq(None) # freq doesn't round-trip + check_round_trip(df, temp_file, "pyarrow") + + # index with meta-data + df.index = [0, 1, 2] + df.index.name = "foo" + check_round_trip(df, temp_file, "pyarrow") + + def test_write_multiindex(self, pa, temp_file): + # Not supported in fastparquet as of 0.1.3 or older pyarrow version + engine = pa + + df = pd.DataFrame({"A": [1, 2, 3]}) + index = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)]) + df.index = index + check_round_trip(df, temp_file, engine) + + def test_multiindex_with_columns(self, pa, temp_file): + engine = pa + dates = pd.date_range("01-Jan-2018", "01-Dec-2018", freq="MS", unit="ns") + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((2 * len(dates), 3)), + columns=list("ABC"), + ) + index1 = pd.MultiIndex.from_product( + [["Level1", "Level2"], dates], names=["level", "date"] + ) + index2 = index1.copy(names=None) + for index in [index1, index2]: + df.index = index + + check_round_trip(df, temp_file, engine) + check_round_trip( + df, + temp_file, + engine, + read_kwargs={"columns": ["A", "B"]}, + expected=df[["A", "B"]], + ) + + def test_write_ignoring_index(self, engine, temp_file): + # ENH 20768 + # Ensure index=False omits the index from the written Parquet file. + df = pd.DataFrame({"a": [1, 2, 3], "b": ["q", "r", "s"]}) + + write_kwargs = {"compression": None, "index": False} + + # Because we're dropping the index, we expect the loaded dataframe to + # have the default integer index. + expected = df.reset_index(drop=True) + + check_round_trip( + df, temp_file, engine, write_kwargs=write_kwargs, expected=expected + ) + + # Ignore custom index + df = pd.DataFrame( + {"a": [1, 2, 3], "b": ["q", "r", "s"]}, index=["zyx", "wvu", "tsr"] + ) + + check_round_trip( + df, temp_file, engine, write_kwargs=write_kwargs, expected=expected + ) + + # Ignore multi-indexes as well. + arrays = [ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] + df = pd.DataFrame( + {"one": list(range(8)), "two": [-i for i in range(8)]}, index=arrays + ) + + expected = df.reset_index(drop=True) + check_round_trip( + df, temp_file, engine, write_kwargs=write_kwargs, expected=expected + ) + + def test_write_column_multiindex(self, engine, temp_file): + # Not able to write column multi-indexes with non-string column names. + mi_columns = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)]) + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((4, 3)), columns=mi_columns + ) + + if engine == "fastparquet": + self.check_error_on_write( + df, engine, TypeError, "Column name must be a string", temp_file + ) + elif engine == "pyarrow": + check_round_trip(df, temp_file, engine) + + def test_write_column_multiindex_nonstring(self, engine, temp_file): + # GH #34777 + + # Not able to write column multi-indexes with non-string column names + arrays = [ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + [1, 2, 1, 2, 1, 2, 1, 2], + ] + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((8, 8)), columns=arrays + ) + df.columns.names = ["Level1", "Level2"] + if engine == "fastparquet": + self.check_error_on_write(df, engine, ValueError, "Column name", temp_file) + elif engine == "pyarrow": + check_round_trip(df, temp_file, engine) + + def test_write_column_multiindex_string(self, pa, temp_file): + # GH #34777 + # Not supported in fastparquet as of 0.1.3 + engine = pa + + # Write column multi-indexes with string column names + arrays = [ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((8, 8)), columns=arrays + ) + df.columns.names = ["ColLevel1", "ColLevel2"] + + check_round_trip(df, temp_file, engine) + + def test_write_column_index_string(self, pa, temp_file): + # GH #34777 + # Not supported in fastparquet as of 0.1.3 + engine = pa + + # Write column indexes with string column names + arrays = ["bar", "baz", "foo", "qux"] + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((8, 4)), columns=arrays + ) + df.columns.name = "StringCol" + + check_round_trip(df, temp_file, engine) + + def test_write_column_index_nonstring(self, engine, temp_file): + # GH #34777 + + # Write column indexes with string column names + arrays = [1, 2, 3, 4] + df = pd.DataFrame( + np.random.default_rng(2).standard_normal((8, 4)), columns=arrays + ) + df.columns.name = "NonStringCol" + if engine == "fastparquet": + self.check_error_on_write( + df, engine, TypeError, "Column name must be a string", temp_file + ) + else: + check_round_trip(df, temp_file, engine) + + def test_dtype_backend(self, engine, request, temp_file): + pq = pytest.importorskip("pyarrow.parquet") + + if engine == "fastparquet": + # We are manually disabling fastparquet's + # nullable dtype support pending discussion + mark = pytest.mark.xfail( + reason="Fastparquet nullable dtype support is disabled" + ) + request.applymarker(mark) + + table = pyarrow.table( + { + "a": pyarrow.array([1, 2, 3, None], "int64"), + "b": pyarrow.array([1, 2, 3, None], "uint8"), + "c": pyarrow.array(["a", "b", "c", None]), + "d": pyarrow.array([True, False, True, None]), + # Test that nullable dtypes used even in absence of nulls + "e": pyarrow.array([1, 2, 3, 4], "int64"), + # GH 45694 + "f": pyarrow.array([1.0, 2.0, 3.0, None], "float32"), + "g": pyarrow.array([1.0, 2.0, 3.0, None], "float64"), + } + ) + # write manually with pyarrow to write integers + pq.write_table(table, temp_file) + result1 = read_parquet(temp_file, engine=engine) + result2 = read_parquet(temp_file, engine=engine, dtype_backend="numpy_nullable") + + assert result1["a"].dtype == np.dtype("float64") + expected = pd.DataFrame( + { + "a": pd.array([1, 2, 3, None], dtype="Int64"), + "b": pd.array([1, 2, 3, None], dtype="UInt8"), + "c": pd.array(["a", "b", "c", None], dtype="string"), + "d": pd.array([True, False, True, None], dtype="boolean"), + "e": pd.array([1, 2, 3, 4], dtype="Int64"), + "f": pd.array([1.0, 2.0, 3.0, None], dtype="Float32"), + "g": pd.array([1.0, 2.0, 3.0, None], dtype="Float64"), + } + ) + if engine == "fastparquet": + # Fastparquet doesn't support string columns yet + # Only int and boolean + result2 = result2.drop("c", axis=1) + expected = expected.drop("c", axis=1) + tm.assert_frame_equal(result2, expected) + + @pytest.mark.parametrize( + "dtype", + [ + "Int64", + "UInt8", + "boolean", + "object", + "datetime64[ns, UTC]", + "float", + "period[D]", + "Float64", + "string", + ], + ) + def test_read_empty_array(self, pa, dtype, temp_file): + # GH #41241 + df = pd.DataFrame( + { + "value": pd.array([], dtype=dtype), + } + ) + pytest.importorskip("pyarrow", "11.0.0") + # GH 45694 + expected = None + if dtype == "float": + expected = pd.DataFrame( + { + "value": pd.array([], dtype="Float64"), + } + ) + check_round_trip( + df, + temp_file, + pa, + read_kwargs={"dtype_backend": "numpy_nullable"}, + expected=expected, + ) + + @pytest.mark.network + @pytest.mark.single_cpu + def test_parquet_read_from_url(self, httpserver, datapath, df_compat, engine): + if engine != "auto": + pytest.importorskip(engine) + with open(datapath("io", "data", "parquet", "simple.parquet"), mode="rb") as f: + httpserver.serve_content(content=f.read()) + df = read_parquet(httpserver.url, engine=engine) + + expected = df_compat + if pa_version_under19p0: + expected.columns = expected.columns.astype(object) + tm.assert_frame_equal(df, expected) + + +class TestParquetPyArrow(Base): + def test_basic(self, pa, df_full, temp_file): + df = df_full + pytest.importorskip("pyarrow", "11.0.0") + + # additional supported types for pyarrow + dti = pd.date_range("20130101", periods=3, tz="Europe/Brussels") + dti = dti._with_freq(None) # freq doesn't round-trip + df["datetime_tz"] = dti + df["bool_with_none"] = [True, None, True] + + check_round_trip(df, temp_file, pa) + + def test_basic_subset_columns(self, pa, df_full, temp_file): + # GH18628 + + df = df_full + # additional supported types for pyarrow + df["datetime_tz"] = pd.date_range("20130101", periods=3, tz="Europe/Brussels") + + check_round_trip( + df, + temp_file, + pa, + expected=df[["string", "int"]], + read_kwargs={"columns": ["string", "int"]}, + ) + + def test_to_bytes_without_path_or_buf_provided(self, pa, df_full): + # GH 37105 + buf_bytes = df_full.to_parquet(engine=pa) + assert isinstance(buf_bytes, bytes) + + buf_stream = BytesIO(buf_bytes) + res = read_parquet(buf_stream) + + expected = df_full.copy() + expected.loc[1, "string_with_nan"] = None + expected["datetime_with_nat"] = expected["datetime_with_nat"].astype("M8[us]") + tm.assert_frame_equal(res, expected) + + def test_duplicate_columns(self, pa, temp_file): + # not currently able to handle duplicate columns + df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy() + self.check_error_on_write( + df, pa, ValueError, "Duplicate column names found", temp_file + ) + + def test_timedelta(self, pa, temp_file): + df = pd.DataFrame({"a": pd.timedelta_range("1 day", periods=3)}) + check_round_trip(df, temp_file, pa) + + def test_unsupported(self, pa, temp_file): + # mixed python objects + df = pd.DataFrame({"a": ["a", 1, 2.0]}) + # pyarrow 0.11 raises ArrowTypeError + # older pyarrows raise ArrowInvalid + self.check_external_error_on_write(df, pa, pyarrow.ArrowException, temp_file) + + def test_unsupported_float16(self, pa, temp_file): + # #44847, #44914 + # Not able to write float 16 column using pyarrow. + data = np.arange(2, 10, dtype=np.float16) + df = pd.DataFrame(data=data, columns=["fp16"]) + if pa_version_under15p0: + self.check_external_error_on_write( + df, pa, pyarrow.ArrowException, temp_file + ) + else: + check_round_trip(df, temp_file, pa) + + @pytest.mark.xfail( + is_platform_windows(), + reason=( + "PyArrow does not cleanup of partial files dumps when unsupported " + "dtypes are passed to_parquet function in windows" + ), + ) + @pytest.mark.skipif(not pa_version_under15p0, reason="float16 works on 15") + @pytest.mark.parametrize("path_type", [str, pathlib.Path]) + def test_unsupported_float16_cleanup(self, pa, path_type, temp_file): + # #44847, #44914 + # Not able to write float 16 column using pyarrow. + # Tests cleanup by pyarrow in case of an error + data = np.arange(2, 10, dtype=np.float16) + df = pd.DataFrame(data=data, columns=["fp16"]) + + path = path_type(temp_file) + with tm.external_error_raised(pyarrow.ArrowException): + df.to_parquet(path=path, engine=pa) + assert not os.path.isfile(path) + + def test_categorical(self, pa, temp_file): + # supported in >= 0.7.0 + df = pd.DataFrame( + { + "a": pd.Categorical(list("abcdef")), + # test for null, out-of-order values, and unobserved category + "b": pd.Categorical( + ["bar", "foo", "foo", "bar", None, "bar"], + dtype=pd.CategoricalDtype(["foo", "bar", "baz"]), + ), + # test for ordered flag + "c": pd.Categorical( + [None, "b", "c", None, "c", "b"], + categories=["b", "c", "d"], + ordered=True, + ), + } + ) + + check_round_trip(df, temp_file, pa) + + @pytest.mark.single_cpu + def test_s3_roundtrip_explicit_fs( + self, df_compat, s3_bucket_public, s3so, pa, temp_file + ): + s3fs = pytest.importorskip("s3fs") + s3 = s3fs.S3FileSystem(**s3so) + kw = {"filesystem": s3} + check_round_trip( + df_compat, + temp_file, + pa, + path=f"{s3_bucket_public.name}/pyarrow.parquet", + read_kwargs=kw, + write_kwargs=kw, + ) + + @pytest.mark.single_cpu + def test_s3_roundtrip(self, df_compat, s3_bucket_public, s3so, pa, temp_file): + # GH #19134 + s3so = {"storage_options": s3so} + check_round_trip( + df_compat, + temp_file, + pa, + path=f"s3://{s3_bucket_public.name}/pyarrow.parquet", + read_kwargs=s3so, + write_kwargs=s3so, + ) + + @pytest.mark.single_cpu + @pytest.mark.parametrize("partition_col", [["A"], []]) + def test_s3_roundtrip_for_dir( + self, df_compat, s3_bucket_public, pa, partition_col, s3so, temp_file + ): + pytest.importorskip("s3fs") + # GH #26388 + expected_df = df_compat.copy() + + # GH #35791 + if partition_col: + expected_df = expected_df.astype(dict.fromkeys(partition_col, np.int32)) + partition_col_type = "category" + + expected_df[partition_col] = expected_df[partition_col].astype( + partition_col_type + ) + + check_round_trip( + df_compat, + temp_file, + pa, + expected=expected_df, + path=f"s3://{s3_bucket_public.name}/parquet_dir", + read_kwargs={"storage_options": s3so}, + write_kwargs={ + "partition_cols": partition_col, + "compression": None, + "storage_options": s3so, + }, + check_like=True, + repeat=1, + ) + + def test_read_file_like_obj_support(self, df_compat, using_infer_string): + pytest.importorskip("pyarrow") + buffer = BytesIO() + df_compat.to_parquet(buffer) + df_from_buf = read_parquet(buffer) + if using_infer_string and not pa_version_under19p0: + df_compat.columns = df_compat.columns.astype("str") + tm.assert_frame_equal(df_compat, df_from_buf) + + def test_expand_user(self, df_compat, monkeypatch): + pytest.importorskip("pyarrow") + monkeypatch.setenv("HOME", "TestingUser") + monkeypatch.setenv("USERPROFILE", "TestingUser") + with pytest.raises(OSError, match=r".*TestingUser.*"): + read_parquet("~/file.parquet") + with pytest.raises(OSError, match=r".*TestingUser.*"): + df_compat.to_parquet("~/file.parquet") + + def test_partition_cols_supported(self, tmp_path, pa, df_full): + # GH #23283 + partition_cols = ["bool", "int"] + df = df_full + df.to_parquet(tmp_path, partition_cols=partition_cols, compression=None) + check_partition_names(tmp_path, partition_cols) + assert read_parquet(tmp_path).shape == df.shape + + def test_partition_cols_string(self, tmp_path, pa, df_full): + # GH #27117 + partition_cols = "bool" + partition_cols_list = [partition_cols] + df = df_full + df.to_parquet(tmp_path, partition_cols=partition_cols, compression=None) + check_partition_names(tmp_path, partition_cols_list) + assert read_parquet(tmp_path).shape == df.shape + + @pytest.mark.parametrize( + "path_type", [str, lambda x: x], ids=["string", "pathlib.Path"] + ) + def test_partition_cols_pathlib(self, tmp_path, pa, df_compat, path_type): + # GH 35902 + + partition_cols = "B" + partition_cols_list = [partition_cols] + df = df_compat + + path = path_type(tmp_path) + df.to_parquet(path, partition_cols=partition_cols_list) + assert read_parquet(path).shape == df.shape + + def test_empty_dataframe(self, pa, temp_file): + # GH #27339 + df = pd.DataFrame(index=[], columns=[]) + check_round_trip(df, temp_file, pa) + + def test_write_with_schema(self, pa, temp_file): + import pyarrow + + df = pd.DataFrame({"x": [0, 1]}) + schema = pyarrow.schema([pyarrow.field("x", type=pyarrow.bool_())]) + out_df = df.astype(bool) + check_round_trip( + df, temp_file, pa, write_kwargs={"schema": schema}, expected=out_df + ) + + def test_additional_extension_arrays(self, pa, using_infer_string, temp_file): + # test additional ExtensionArrays that are supported through the + # __arrow_array__ protocol + pytest.importorskip("pyarrow") + df = pd.DataFrame( + { + "a": pd.Series([1, 2, 3], dtype="Int64"), + "b": pd.Series([1, 2, 3], dtype="UInt32"), + "c": pd.Series(["a", None, "c"], dtype="string"), + } + ) + if using_infer_string and pa_version_under19p0: + check_round_trip(df, temp_file, pa, expected=df.astype({"c": "str"})) + else: + check_round_trip(df, temp_file, pa) + + df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")}) + check_round_trip(df, temp_file, pa) + + def test_pyarrow_backed_string_array( + self, pa, string_storage, using_infer_string, temp_file + ): + # test ArrowStringArray supported through the __arrow_array__ protocol + pytest.importorskip("pyarrow") + df = pd.DataFrame({"a": pd.Series(["a", None, "c"], dtype="string[pyarrow]")}) + with pd.option_context("string_storage", string_storage): + if using_infer_string: + if pa_version_under19p0: + expected = df.astype("str") + else: + expected = df.astype(f"string[{string_storage}]") + expected.columns = expected.columns.astype("str") + else: + expected = df.astype(f"string[{string_storage}]") + check_round_trip(df, temp_file, pa, expected=expected) + + def test_additional_extension_types(self, pa, temp_file): + # test additional ExtensionArrays that are supported through the + # __arrow_array__ protocol + by defining a custom ExtensionType + pytest.importorskip("pyarrow") + df = pd.DataFrame( + { + "c": pd.IntervalIndex.from_tuples([(0, 1), (1, 2), (3, 4)]), + "d": pd.period_range("2012-01-01", periods=3, freq="D"), + # GH-45881 issue with interval with datetime64[ns] subtype + "e": pd.IntervalIndex.from_breaks( + pd.date_range("2012-01-01", periods=4, freq="D") + ), + } + ) + check_round_trip(df, temp_file, pa) + + def test_timestamp_nanoseconds(self, pa, temp_file): + # with version 2.6, pyarrow defaults to writing the nanoseconds, so + # this should work without error, even for pyarrow < 13 + ver = "2.6" + df = pd.DataFrame({"a": pd.date_range("2017-01-01", freq="1ns", periods=10)}) + check_round_trip(df, temp_file, pa, write_kwargs={"version": ver}) + + def test_timezone_aware_index(self, pa, timezone_aware_date_list, temp_file): + idx = 5 * [timezone_aware_date_list] + df = pd.DataFrame(index=idx, data={"index_as_col": idx}) + + # see gh-36004 + # compare time(zone) values only, skip their class: + # pyarrow always creates fixed offset timezones using pytz.FixedOffset() + # even if it was datetime.timezone() originally + # + # technically they are the same: + # they both implement datetime.tzinfo + # they both wrap datetime.timedelta() + # this use-case sets the resolution to 1 minute + + expected = df[:] + if timezone_aware_date_list.tzinfo != datetime.UTC: + # pyarrow returns pytz.FixedOffset while pandas constructs datetime.timezone + # https://github.com/pandas-dev/pandas/issues/37286 + try: + import pytz + except ImportError: + pass + else: + offset = df.index.tz.utcoffset(timezone_aware_date_list) + tz = pytz.FixedOffset(offset.total_seconds() / 60) + expected.index = expected.index.tz_convert(tz) + expected["index_as_col"] = expected["index_as_col"].dt.tz_convert(tz) + check_round_trip(df, temp_file, pa, check_dtype=False, expected=expected) + + def test_filter_row_groups(self, pa, temp_file): + # https://github.com/pandas-dev/pandas/issues/26551 + pytest.importorskip("pyarrow") + df = pd.DataFrame({"a": list(range(3))}) + df.to_parquet(temp_file, engine=pa) + result = read_parquet(temp_file, pa, filters=[("a", "==", 0)]) + assert len(result) == 1 + + @pytest.mark.filterwarnings("ignore:make_block is deprecated:DeprecationWarning") + def test_read_dtype_backend_pyarrow_config(self, pa, df_full, temp_file): + import pyarrow + + df = df_full + + # additional supported types for pyarrow + dti = pd.date_range("20130101", periods=3, tz="Europe/Brussels", unit="ns") + dti = dti._with_freq(None) # freq doesn't round-trip + df["datetime_tz"] = dti + df["bool_with_none"] = [True, None, True] + + pa_table = pyarrow.Table.from_pandas(df) + expected = pa_table.to_pandas(types_mapper=pd.ArrowDtype) + expected["datetime_with_nat"] = expected["datetime_with_nat"].astype( + "timestamp[us][pyarrow]" + ) + + check_round_trip( + df, + temp_file, + engine=pa, + read_kwargs={"dtype_backend": "pyarrow"}, + expected=expected, + ) + + def test_read_dtype_backend_pyarrow_config_index(self, pa, temp_file): + df = pd.DataFrame( + {"a": [1, 2]}, index=pd.Index([3, 4], name="test"), dtype="int64[pyarrow]" + ) + expected = df.copy() + + expected.index = expected.index.astype("int64[pyarrow]") + check_round_trip( + df, + temp_file, + engine=pa, + read_kwargs={"dtype_backend": "pyarrow"}, + expected=expected, + ) + + @pytest.mark.parametrize( + "columns", + [ + [0, 1], + pytest.param( + [b"foo", b"bar"], + marks=pytest.mark.xfail( + pa_version_under20p0, + raises=NotImplementedError, + reason="https://github.com/apache/arrow/pull/44171", + ), + ), + pytest.param( + [ + datetime.datetime(2011, 1, 1, 0, 0), + datetime.datetime(2011, 1, 1, 1, 1), + ], + marks=pytest.mark.xfail( + pa_version_under17p0, + reason="pa.pandas_compat passes 'datetime64' to .astype", + ), + ), + ], + ) + def test_columns_dtypes_not_invalid(self, pa, columns, temp_file): + df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))}) + + df.columns = columns + check_round_trip(df, temp_file, pa) + + def test_empty_columns(self, pa, temp_file): + # GH 52034 + df = pd.DataFrame(index=pd.Index(["a", "b", "c"], name="custom name")) + check_round_trip(df, temp_file, pa) + + def test_df_attrs_persistence(self, temp_file, pa): + df = pd.DataFrame(data={1: [1]}) + df.attrs = {"test_attribute": 1} + df.to_parquet(temp_file, engine=pa) + new_df = read_parquet(temp_file, engine=pa) + assert new_df.attrs == df.attrs + + def test_string_inference(self, temp_file, pa, using_infer_string): + # GH#54431 + df = pd.DataFrame(data={"a": ["x", "y"]}, index=["a", "b"]) + df.to_parquet(temp_file, engine=pa) + with pd.option_context("future.infer_string", True): + result = read_parquet(temp_file, engine=pa) + dtype = pd.StringDtype(na_value=np.nan) + expected = pd.DataFrame( + data={"a": ["x", "y"]}, + dtype=dtype, + index=pd.Index(["a", "b"], dtype=dtype), + columns=pd.Index( + ["a"], + dtype=( + object if pa_version_under19p0 and not using_infer_string else dtype + ), + ), + ) + tm.assert_frame_equal(result, expected) + + def test_roundtrip_decimal(self, temp_file, pa): + # GH#54768 + import pyarrow as pa + + df = pd.DataFrame({"a": [Decimal("123.00")]}, dtype="string[pyarrow]") + df.to_parquet(temp_file, schema=pa.schema([("a", pa.decimal128(5))])) + result = read_parquet(temp_file) + if pa_version_under19p0: + expected = pd.DataFrame({"a": ["123"]}, dtype="string") + else: + expected = pd.DataFrame({"a": [Decimal("123.00")]}, dtype="object") + tm.assert_frame_equal(result, expected) + + def test_infer_string_large_string_type(self, temp_file, pa): + # GH#54798 + import pyarrow as pa + import pyarrow.parquet as pq + + table = pa.table({"a": pa.array([None, "b", "c"], pa.large_string())}) + pq.write_table(table, temp_file) + + with pd.option_context("future.infer_string", True): + result = read_parquet(temp_file) + expected = pd.DataFrame( + data={"a": [None, "b", "c"]}, + dtype=pd.StringDtype(na_value=np.nan), + columns=pd.Index(["a"], dtype=pd.StringDtype(na_value=np.nan)), + ) + tm.assert_frame_equal(result, expected) + + # NOTE: this test is not run by default, because it requires a lot of memory (>5GB) + # @pytest.mark.slow + # def test_string_column_above_2GB(self, tmp_path, pa): + # # https://github.com/pandas-dev/pandas/issues/55606 + # # above 2GB of string data + # v1 = b"x" * 100000000 + # v2 = b"x" * 147483646 + # df = pd.DataFrame({"strings": [v1] * 20 + [v2] + ["x"] * 20}, dtype="string") + # df.to_parquet(tmp_path / "test.parquet") + # result = read_parquet(tmp_path / "test.parquet") + # assert result["strings"].dtype == "string" + # FIXME: don't leave commented-out + + def test_non_nanosecond_timestamps(self, temp_file): + # GH#49236 + pa = pytest.importorskip("pyarrow", "13.0.0") + pq = pytest.importorskip("pyarrow.parquet") + + arr = pa.array([datetime.datetime(1600, 1, 1)], type=pa.timestamp("us")) + table = pa.table([arr], names=["timestamp"]) + pq.write_table(table, temp_file) + result = read_parquet(temp_file) + expected = pd.DataFrame( + data={"timestamp": [datetime.datetime(1600, 1, 1)]}, + dtype="datetime64[us]", + ) + tm.assert_frame_equal(result, expected) + + def test_maps_as_pydicts(self, pa, temp_file): + pyarrow = pytest.importorskip("pyarrow", "13.0.0") + + schema = pyarrow.schema( + [("foo", pyarrow.map_(pyarrow.string(), pyarrow.int64()))] + ) + df = pd.DataFrame([{"foo": {"A": 1}}, {"foo": {"B": 2}}]) + check_round_trip( + df, + temp_file, + pa, + write_kwargs={"schema": schema}, + read_kwargs={"to_pandas_kwargs": {"maps_as_pydicts": "strict"}}, + ) + + +class TestParquetFastParquet(Base): + def test_basic(self, fp, df_full, request, temp_file): + pytz = pytest.importorskip("pytz") + + tz = pytz.timezone("US/Eastern") + df = df_full + + dti = pd.date_range("20130101", periods=3, tz=tz) + dti = dti._with_freq(None) # freq doesn't round-trip + df["datetime_tz"] = dti + df["timedelta"] = pd.timedelta_range("1 day", periods=3) + check_round_trip(df, temp_file, fp) + + def test_columns_dtypes_invalid(self, fp, temp_file): + df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))}) + + err = TypeError + msg = "Column name must be a string" + + # numeric + df.columns = [0, 1] + self.check_error_on_write(df, fp, err, msg, temp_file) + + # bytes + df.columns = [b"foo", b"bar"] + self.check_error_on_write(df, fp, err, msg, temp_file) + + # python object + df.columns = [ + datetime.datetime(2011, 1, 1, 0, 0), + datetime.datetime(2011, 1, 1, 1, 1), + ] + self.check_error_on_write(df, fp, err, msg, temp_file) + + def test_duplicate_columns(self, fp, temp_file): + # not currently able to handle duplicate columns + df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy() + msg = "Cannot create parquet dataset with duplicate column names" + self.check_error_on_write(df, fp, ValueError, msg, temp_file) + + def test_bool_with_none(self, fp, request, temp_file): + df = pd.DataFrame({"a": [True, None, False]}) + expected = pd.DataFrame({"a": [1.0, np.nan, 0.0]}, dtype="float16") + # Fastparquet bug in 0.7.1 makes it so that this dtype becomes + # float64 + check_round_trip(df, temp_file, fp, expected=expected, check_dtype=False) + + def test_unsupported(self, fp, temp_file): + # period + df = pd.DataFrame({"a": pd.period_range("2013", freq="M", periods=3)}) + # error from fastparquet -> don't check exact error message + self.check_error_on_write(df, fp, ValueError, None, temp_file) + + # mixed + df = pd.DataFrame({"a": ["a", 1, 2.0]}) + msg = "Can't infer object conversion type" + self.check_error_on_write(df, fp, ValueError, msg, temp_file) + + def test_categorical(self, fp, temp_file): + df = pd.DataFrame({"a": pd.Categorical(list("abc"))}) + check_round_trip(df, temp_file, fp) + + def test_filter_row_groups(self, fp, temp_file): + d = {"a": list(range(3))} + df = pd.DataFrame(d) + df.to_parquet(temp_file, engine=fp, compression=None, row_group_offsets=1) + result = read_parquet(temp_file, fp, filters=[("a", "==", 0)]) + assert len(result) == 1 + + @pytest.mark.single_cpu + def test_s3_roundtrip(self, df_compat, s3_bucket_public, s3so, fp, temp_file): + # GH #19134 + check_round_trip( + df_compat, + temp_file, + fp, + path=f"s3://{s3_bucket_public.name}/fastparquet.parquet", + read_kwargs={"storage_options": s3so}, + write_kwargs={"compression": None, "storage_options": s3so}, + ) + + def test_partition_cols_supported(self, tmp_path, fp, df_full): + # GH #23283 + partition_cols = ["bool", "int"] + df = df_full + df.to_parquet( + tmp_path, + engine="fastparquet", + partition_cols=partition_cols, + compression=None, + ) + assert os.path.exists(tmp_path) + import fastparquet + + actual_partition_cols = fastparquet.ParquetFile(str(tmp_path), False).cats + assert len(actual_partition_cols) == 2 + + def test_partition_cols_string(self, tmp_path, fp, df_full): + # GH #27117 + partition_cols = "bool" + df = df_full + df.to_parquet( + tmp_path, + engine="fastparquet", + partition_cols=partition_cols, + compression=None, + ) + assert os.path.exists(tmp_path) + import fastparquet + + actual_partition_cols = fastparquet.ParquetFile(str(tmp_path), False).cats + assert len(actual_partition_cols) == 1 + + def test_partition_on_supported(self, tmp_path, fp, df_full): + # GH #23283 + partition_cols = ["bool", "int"] + df = df_full + df.to_parquet( + tmp_path, + engine="fastparquet", + compression=None, + partition_on=partition_cols, + ) + assert os.path.exists(tmp_path) + import fastparquet + + actual_partition_cols = fastparquet.ParquetFile(str(tmp_path), False).cats + assert len(actual_partition_cols) == 2 + + def test_error_on_using_partition_cols_and_partition_on( + self, tmp_path, fp, df_full + ): + # GH #23283 + partition_cols = ["bool", "int"] + df = df_full + msg = ( + "Cannot use both partition_on and partition_cols. Use partition_cols for " + "partitioning data" + ) + with pytest.raises(ValueError, match=msg): + df.to_parquet( + tmp_path, + engine="fastparquet", + compression=None, + partition_on=partition_cols, + partition_cols=partition_cols, + ) + + def test_empty_dataframe(self, fp, temp_file): + # GH #27339 + df = pd.DataFrame() + expected = df.copy() + check_round_trip(df, temp_file, fp, expected=expected) + + def test_timezone_aware_index( + self, fp, timezone_aware_date_list, request, temp_file + ): + idx = 5 * [timezone_aware_date_list] + + df = pd.DataFrame(index=idx, data={"index_as_col": idx}) + + expected = df.copy() + expected.index.name = "index" + check_round_trip(df, temp_file, fp, expected=expected) + + def test_close_file_handle_on_read_error(self, temp_file): + pathlib.Path(temp_file).write_bytes(b"breakit") + with tm.external_error_raised(Exception): # Not important which exception + read_parquet(temp_file, engine="fastparquet") + # The next line raises an error on Windows if the file is still open + pathlib.Path(temp_file).unlink(missing_ok=False) + + def test_bytes_file_name(self, engine, temp_file): + # GH#48944 + df = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}) + with open(temp_file, "wb") as f: + df.to_parquet(f) + + result = read_parquet(temp_file, engine=engine) + tm.assert_frame_equal(result, df) + + def test_filesystem_notimplemented(self, temp_file): + pytest.importorskip("fastparquet") + df = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}) + with pytest.raises(NotImplementedError, match="filesystem is not implemented"): + df.to_parquet(temp_file, engine="fastparquet", filesystem="foo") + + pathlib.Path(temp_file).write_bytes(b"foo") + with pytest.raises(NotImplementedError, match="filesystem is not implemented"): + read_parquet(temp_file, engine="fastparquet", filesystem="foo") + + def test_invalid_filesystem(self, temp_file): + pytest.importorskip("pyarrow") + df = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}) + + with pytest.raises( + ValueError, match="filesystem must be a pyarrow or fsspec FileSystem" + ): + df.to_parquet(temp_file, engine="pyarrow", filesystem="foo") + + pathlib.Path(temp_file).write_bytes(b"foo") + with pytest.raises( + ValueError, match="filesystem must be a pyarrow or fsspec FileSystem" + ): + read_parquet(temp_file, engine="pyarrow", filesystem="foo") + + def test_unsupported_pa_filesystem_storage_options(self, temp_file): + pa_fs = pytest.importorskip("pyarrow.fs") + df = pd.DataFrame(data={"A": [0, 1], "B": [1, 0]}) + + with pytest.raises( + NotImplementedError, + match="storage_options not supported with a pyarrow FileSystem.", + ): + df.to_parquet( + temp_file, + engine="pyarrow", + filesystem=pa_fs.LocalFileSystem(), + storage_options={"foo": "bar"}, + ) + + pathlib.Path(temp_file).write_bytes(b"foo") + with pytest.raises( + NotImplementedError, + match="storage_options not supported with a pyarrow FileSystem.", + ): + read_parquet( + temp_file, + engine="pyarrow", + filesystem=pa_fs.LocalFileSystem(), + storage_options={"foo": "bar"}, + ) + + def test_invalid_dtype_backend(self, engine, temp_file): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + df = pd.DataFrame({"int": list(range(1, 4))}) + df.to_parquet(temp_file) + with pytest.raises(ValueError, match=msg): + read_parquet(temp_file, dtype_backend="numpy") diff --git a/pandas/tests/io/test_pickle.py b/pandas/tests/io/test_pickle.py new file mode 100644 index 0000000000000000000000000000000000000000..7754c58a88ef137825906a543a3783f2e21c71b0 --- /dev/null +++ b/pandas/tests/io/test_pickle.py @@ -0,0 +1,590 @@ +""" +manage legacy pickle tests + +How to add pickle tests: + +1. Install pandas version intended to output the pickle. + +2. Execute "generate_legacy_storage_files.py" to create the pickle. +$ python generate_legacy_storage_files.py pickle + +3. Move the created pickle to "data/legacy_pickle/" directory. +""" + +from __future__ import annotations + +import bz2 +import datetime +import functools +from functools import partial +import gzip +import io +import os +from pathlib import Path +import pickle +import shutil +import tarfile +from typing import Any +import uuid +import zipfile + +import numpy as np +import pytest + +from pandas.compat import is_platform_little_endian +from pandas.compat._optional import import_optional_dependency + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + period_range, +) +import pandas._testing as tm +from pandas.tests.io.generate_legacy_storage_files import create_pickle_data +from pandas.util.version import Version + +import pandas.io.common as icom +from pandas.tseries.offsets import ( + Day, + MonthEnd, +) + + +# --------------------- +# comparison functions +# --------------------- +def compare_element(result, expected, typ): + if isinstance(expected, Index): + tm.assert_index_equal(result, expected) + return + + if typ.startswith("sp_"): + tm.assert_equal(result, expected) + elif typ == "timestamp": + if expected is pd.NaT: + assert result is pd.NaT + else: + assert result == expected + else: + comparator = getattr(tm, f"assert_{typ}_equal", tm.assert_almost_equal) + comparator(result, expected) + + +# --------------------- +# tests +# --------------------- + + +def test_pickles(datapath): + pytest.importorskip("pytz") + if not is_platform_little_endian(): + pytest.skip("known failure on non-little endian") + + current_data = create_pickle_data() + + # For loop for compat with --strict-data-files + for legacy_pickle in Path(__file__).parent.glob("data/legacy_pickle/*/*.p*kl*"): + legacy_version = Version(legacy_pickle.parent.name) + legacy_pickle = datapath(legacy_pickle) + + data = pd.read_pickle(legacy_pickle) + + for typ, dv in data.items(): + for dt, result in dv.items(): + expected = current_data[typ][dt] + + if ( + typ == "timestamp" + and dt in ("tz", "both") + and legacy_version < Version("1.3.0") + ): + # convert to wall time + # (bug since pandas 2.0 that tz gets dropped for older pickle files) + expected = expected.tz_convert(None) + + if legacy_version < Version("3.0.0.dev0"): + # before 3.0, we had: + # - object dtype instead of string + # - ns instead of us as the default unit + if typ in ("frame", "sp_frame"): + expected.columns = expected.columns.astype("object") + if dt in ("mixed", "mixed_dup"): + expected["C"] = expected["C"].astype(object) + expected["D"] = expected["D"].dt.as_unit("ns") + elif dt in ("cat_onecol", "cat_and_float"): + expected["A"] = expected["A"].astype( + pd.CategoricalDtype( + expected["A"].cat.categories.astype(object) + ) + ) + elif typ == "sp_frame" and dt == "float": + expected.index = expected.index.as_unit("ns") + elif dt == "mi": + expected.index = expected.index.set_levels( + [ + level.astype("object") + for level in expected.index.levels + ], + ) + elif typ in ("series", "sp_series"): + if dt == "ts": + expected.index = expected.index.as_unit("ns") + elif dt in ("dt", "dt_tz"): + expected = expected.dt.as_unit("ns") + elif dt == "cat": + expected = expected.astype( + pd.CategoricalDtype( + expected.cat.categories.astype(object) + ) + ) + elif dt == "dup": + expected.index = expected.index.astype(object) + elif typ == "index" and dt in ("date", "timedelta"): + expected = expected.as_unit("ns") + elif typ == "mi": + expected = expected.set_levels( + [level.astype("object") for level in expected.levels], + ) + if dt == "string": + # we switched from python to pyarrow as default storage in 3.0 + expected = expected.astype(pd.StringDtype("python")) + + if dt in ("dt_mixed_tzs", "dt_mixed2_tzs"): + if legacy_version < Version("2.1"): + # in pandas < 2.0, Timestamp() unit defaulted to 'ns' + expected_unit = "ns" + elif Version("2.1") <= legacy_version < Version("3.0.0.dev0"): + # in pandas 2.x, Timestamp() unit depended on input + expected_unit = "s" + else: + expected_unit = "us" + for col in expected.columns: + expected[col] = expected[col].dt.as_unit(expected_unit) + if typ == "index" and dt == "int" and "windows" in legacy_pickle: + expected = expected.astype(np.int32) + + if typ == "series" and dt == "ts": + # GH 7748 + tm.assert_series_equal(result, expected) + assert result.index.freq == expected.index.freq + assert not result.index.freq.normalize + tm.assert_series_equal(result > 0, expected > 0) + + # GH 9291 + freq = result.index.freq + assert freq + Day(1) == Day(2) + + res = freq + pd.Timedelta(hours=1) + assert isinstance(res, pd.Timedelta) + assert res == pd.Timedelta(days=1, hours=1) + + res = freq + pd.Timedelta(nanoseconds=1) + assert isinstance(res, pd.Timedelta) + assert res == pd.Timedelta(days=1, nanoseconds=1) + elif typ == "index" and dt == "period": + tm.assert_index_equal(result, expected) + assert isinstance(result.freq, MonthEnd) + assert result.freq == MonthEnd() + assert result.freqstr == "M" + tm.assert_index_equal(result.shift(2), expected.shift(2)) + elif typ == "series" and dt in ("dt_tz", "cat"): + tm.assert_series_equal(result, expected) + elif typ == "frame" and dt in ( + "dt_mixed_tzs", + "cat_onecol", + "cat_and_float", + ): + tm.assert_frame_equal(result, expected) + else: + compare_element(result, expected, typ) + + +def python_pickler(obj, path): + with open(path, "wb") as fh: + pickle.dump(obj, fh, protocol=-1) + + +def python_unpickler(path): + with open(path, "rb") as fh: + fh.seek(0) + return pickle.load(fh) + + +def flatten(data: dict) -> list[tuple[str, Any]]: + """Flatten create_pickle_data""" + return [ + (typ, example) + for typ, examples in data.items() + for example in examples.values() + ] + + +@pytest.mark.parametrize( + "pickle_writer", + [ + pytest.param(python_pickler, id="python"), + pytest.param(pd.to_pickle, id="pandas_proto_default"), + pytest.param( + functools.partial(pd.to_pickle, protocol=pickle.HIGHEST_PROTOCOL), + id="pandas_proto_highest", + ), + pytest.param(functools.partial(pd.to_pickle, protocol=4), id="pandas_proto_4"), + pytest.param( + functools.partial(pd.to_pickle, protocol=5), + id="pandas_proto_5", + ), + ], +) +@pytest.mark.parametrize("writer", [pd.to_pickle, python_pickler]) +@pytest.mark.parametrize("typ, expected", flatten(create_pickle_data())) +def test_round_trip_current(typ, expected, pickle_writer, writer, temp_file): + path = temp_file + # test writing with each pickler + pickle_writer(expected, path) + + # test reading with each unpickler + result = pd.read_pickle(path) + compare_element(result, expected, typ) + + result = python_unpickler(path) + compare_element(result, expected, typ) + + # and the same for file objects (GH 35679) + with open(path, mode="wb") as handle: + writer(expected, path) + handle.seek(0) # shouldn't close file handle + with open(path, mode="rb") as handle: + result = pd.read_pickle(handle) + handle.seek(0) # shouldn't close file handle + compare_element(result, expected, typ) + + +def test_pickle_path_pathlib(temp_file): + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + result = tm.round_trip_pathlib(df.to_pickle, pd.read_pickle, temp_file) + tm.assert_frame_equal(df, result) + + +# --------------------- +# test pickle compression +# --------------------- + + +@pytest.fixture +def get_random_path(): + return f"__{uuid.uuid4()}__.pickle" + + +class TestCompression: + _extension_to_compression = icom.extension_to_compression + + def compress_file(self, src_path, dest_path, compression): + if compression is None: + shutil.copyfile(src_path, dest_path) + return + + if compression == "gzip": + f = gzip.open(dest_path, "w") + elif compression == "bz2": + f = bz2.BZ2File(dest_path, "w") + elif compression == "zip": + with zipfile.ZipFile(dest_path, "w", compression=zipfile.ZIP_DEFLATED) as f: + f.write(src_path, os.path.basename(src_path)) + elif compression == "tar": + with open(src_path, "rb") as fh: + with tarfile.open(dest_path, mode="w") as tar: + tarinfo = tar.gettarinfo(src_path, os.path.basename(src_path)) + tar.addfile(tarinfo, fh) + elif compression == "xz": + import lzma + + f = lzma.LZMAFile(dest_path, "w") + elif compression == "zstd": + f = import_optional_dependency("zstandard").open(dest_path, "wb") + else: + msg = f"Unrecognized compression type: {compression}" + raise ValueError(msg) + + if compression not in ["zip", "tar"]: + with open(src_path, "rb") as fh: + with f: + f.write(fh.read()) + + def test_write_explicit(self, compression, get_random_path, temp_file): + p1 = temp_file.parent / f"{temp_file.stem}.compressed" + p2 = temp_file.parent / f"{temp_file.stem}.raw" + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # write to compressed file + df.to_pickle(p1, compression=compression) + + # decompress + with tm.decompress_file(p1, compression=compression) as f: + with open(p2, "wb") as fh: + fh.write(f.read()) + + # read decompressed file + df2 = pd.read_pickle(p2, compression=None) + + tm.assert_frame_equal(df, df2) + + @pytest.mark.parametrize("compression", ["", "None", "bad", "7z"]) + def test_write_explicit_bad(self, compression, get_random_path, temp_file): + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + path = temp_file + with pytest.raises(ValueError, match="Unrecognized compression type"): + df.to_pickle(path, compression=compression) + + def test_write_infer(self, compression_ext, get_random_path, temp_file): + p1 = temp_file.parent / f"{temp_file.stem}{compression_ext}" + p2 = temp_file.parent / f"{temp_file.stem}.raw" + compression = self._extension_to_compression.get(compression_ext.lower()) + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # write to compressed file by inferred compression method + df.to_pickle(p1) + + # decompress + with tm.decompress_file(p1, compression=compression) as f: + with open(p2, "wb") as fh: + fh.write(f.read()) + + # read decompressed file + df2 = pd.read_pickle(p2, compression=None) + + tm.assert_frame_equal(df, df2) + + def test_read_explicit(self, compression, get_random_path, temp_file): + p1 = temp_file.parent / f"{temp_file.stem}.raw" + p2 = temp_file.parent / f"{temp_file.stem}.compressed" + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # write to uncompressed file + df.to_pickle(p1, compression=None) + + # compress + self.compress_file(p1, p2, compression=compression) + + # read compressed file + df2 = pd.read_pickle(p2, compression=compression) + tm.assert_frame_equal(df, df2) + + def test_read_infer(self, compression_ext, get_random_path, temp_file): + p1 = temp_file.parent / f"{temp_file.stem}.raw" + p2 = temp_file.parent / f"{temp_file.stem}{compression_ext}" + compression = self._extension_to_compression.get(compression_ext.lower()) + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # write to uncompressed file + df.to_pickle(p1, compression=None) + + # compress + self.compress_file(p1, p2, compression=compression) + + # read compressed file by inferred compression method + df2 = pd.read_pickle(p2) + tm.assert_frame_equal(df, df2) + + +# --------------------- +# test pickle compression +# --------------------- + + +class TestProtocol: + @pytest.mark.parametrize("protocol", [-1, 0, 1, 2]) + def test_read(self, protocol, get_random_path, temp_file): + path = temp_file + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.to_pickle(path, protocol=protocol) + df2 = pd.read_pickle(path) + tm.assert_frame_equal(df, df2) + + +def test_pickle_buffer_roundtrip(temp_file): + path = temp_file + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + with open(path, "wb") as fh: + df.to_pickle(fh) + with open(path, "rb") as fh: + result = pd.read_pickle(fh) + tm.assert_frame_equal(df, result) + + +def test_pickle_fsspec_roundtrip(temp_file): + pytest.importorskip("fsspec") + # Using temp_file for context, but fsspec uses memory URL + mockurl = "memory://mockfile" + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + df.to_pickle(mockurl) + result = pd.read_pickle(mockurl) + tm.assert_frame_equal(df, result) + + +class MyTz(datetime.tzinfo): + def __init__(self) -> None: + pass + + +def test_read_pickle_with_subclass(temp_file): + # GH 12163 + expected = Series(dtype=object), MyTz() + result = tm.round_trip_pickle(expected, temp_file) + + tm.assert_series_equal(result[0], expected[0]) + assert isinstance(result[1], MyTz) + + +def test_pickle_binary_object_compression(compression, temp_file): + """ + Read/write from binary file-objects w/wo compression. + + GH 26237, GH 29054, and GH 29570 + """ + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(30)], dtype=object), + ) + + # reference for compression + path = temp_file + df.to_pickle(path, compression=compression) + reference = path.read_bytes() + + # write + buffer = io.BytesIO() + df.to_pickle(buffer, compression=compression) + buffer.seek(0) + + # gzip and zip safe the filename: cannot compare the compressed content + assert buffer.getvalue() == reference or compression in ("gzip", "zip", "tar") + + # read + read_df = pd.read_pickle(buffer, compression=compression) + buffer.seek(0) + tm.assert_frame_equal(df, read_df) + + +def test_pickle_dataframe_with_multilevel_index( + multiindex_year_month_day_dataframe_random_data, + multiindex_dataframe_random_data, + temp_file, +): + ymd = multiindex_year_month_day_dataframe_random_data + frame = multiindex_dataframe_random_data + + def _test_roundtrip(frame, temp_file): + unpickled = tm.round_trip_pickle(frame, temp_file) + tm.assert_frame_equal(frame, unpickled) + + _test_roundtrip(frame, temp_file) + _test_roundtrip(frame.T, temp_file) + _test_roundtrip(ymd, temp_file) + _test_roundtrip(ymd.T, temp_file) + + +def test_pickle_timeseries_periodindex(temp_file): + # GH#2891 + prng = period_range("1/1/2011", "1/1/2012", freq="M") + ts = Series(np.random.default_rng(2).standard_normal(len(prng)), prng) + new_ts = tm.round_trip_pickle(ts, temp_file) + assert new_ts.index.freqstr == "M" + + +@pytest.mark.parametrize( + "name", [777, 777.0, "name", datetime.datetime(2001, 11, 11), (1, 2)] +) +def test_pickle_preserve_name(name, temp_file): + unpickled = tm.round_trip_pickle( + Series(np.arange(10, dtype=np.float64), name=name), temp_file + ) + assert unpickled.name == name + + +def test_pickle_datetimes(datetime_series, temp_file): + unp_ts = tm.round_trip_pickle(datetime_series, temp_file) + tm.assert_series_equal(unp_ts, datetime_series) + + +def test_pickle_strings(string_series, temp_file): + unp_series = tm.round_trip_pickle(string_series, temp_file) + tm.assert_series_equal(unp_series, string_series) + + +def test_pickle_preserves_block_ndim(temp_file): + # GH#37631 + ser = Series(list("abc")).astype("category").iloc[[0]] + res = tm.round_trip_pickle(ser, temp_file) + + assert res._mgr.blocks[0].ndim == 1 + assert res._mgr.blocks[0].shape == (1,) + + # GH#37631 OP issue was about indexing, underlying problem was pickle + tm.assert_series_equal(res[[True]], ser) + + +@pytest.mark.parametrize("protocol", [pickle.DEFAULT_PROTOCOL, pickle.HIGHEST_PROTOCOL]) +def test_pickle_big_dataframe_compression(protocol, compression, temp_file): + # GH#39002 + df = DataFrame(range(100000)) + result = tm.round_trip_pathlib( + partial(df.to_pickle, protocol=protocol, compression=compression), + partial(pd.read_pickle, compression=compression), + temp_file, + ) + tm.assert_frame_equal(df, result) + + +def test_pickle_frame_v124_unpickle_130(datapath): + # GH#42345 DataFrame created in 1.2.x, unpickle in 1.3.x + path = datapath( + Path(__file__).parent, + "data", + "legacy_pickle", + "1.2.4", + "empty_frame_v1_2_4-GH#42345.pkl", + ) + with open(path, "rb") as fd: + df = pickle.load(fd) + + expected = DataFrame(index=[], columns=[]) + tm.assert_frame_equal(df, expected) diff --git a/pandas/tests/io/test_s3.py b/pandas/tests/io/test_s3.py new file mode 100644 index 0000000000000000000000000000000000000000..31d22223b0a33218db58120cea1db393827ff699 --- /dev/null +++ b/pandas/tests/io/test_s3.py @@ -0,0 +1,33 @@ +from io import BytesIO + +import pytest + +from pandas import read_csv + + +@pytest.mark.parametrize("data", [b"foo,bar,baz\n1,2,3\n4,5,6\n", b"just,the,header\n"]) +def test_streaming_s3_objects(data): + # GH 17135 + # botocore gained iteration support in 1.10.47, can now be used in read_* + pytest.importorskip("botocore", minversion="1.10.47") + from botocore.response import StreamingBody + + body = StreamingBody(BytesIO(data), content_length=len(data)) + read_csv(body) + + +@pytest.mark.single_cpu +@pytest.mark.parametrize("header", ["infer", None]) +def test_read_with_and_without_creds_from_pub_bucket( + s3_bucket_public_with_data, s3so, header +): + # GH 34626 + pytest.importorskip("s3fs") + nrows = 5 + df = read_csv( + f"s3://{s3_bucket_public_with_data.name}/tips.csv", + nrows=nrows, + header=header, + storage_options=s3so, + ) + assert len(df) == nrows diff --git a/pandas/tests/io/test_spss.py b/pandas/tests/io/test_spss.py new file mode 100644 index 0000000000000000000000000000000000000000..6210c0289a160e0ec853cb810dd80e404e11c37e --- /dev/null +++ b/pandas/tests/io/test_spss.py @@ -0,0 +1,169 @@ +import datetime +from pathlib import Path + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + +pyreadstat = pytest.importorskip("pyreadstat") + + +# TODO(CoW) - detection of chained assignment in cython +# https://github.com/pandas-dev/pandas/issues/51315 +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +@pytest.mark.parametrize("path_klass", [lambda p: p, Path]) +def test_spss_labelled_num(path_klass, datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = path_klass(datapath("io", "data", "spss", "labelled-num.sav")) + + df = pd.read_spss(fname, convert_categoricals=True) + expected = pd.DataFrame({"VAR00002": "This is one"}, index=[0]) + expected["VAR00002"] = pd.Categorical(expected["VAR00002"]) + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False) + expected = pd.DataFrame({"VAR00002": 1.0}, index=[0]) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +def test_spss_labelled_num_na(datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "labelled-num-na.sav") + + df = pd.read_spss(fname, convert_categoricals=True) + expected = pd.DataFrame({"VAR00002": ["This is one", None]}) + expected["VAR00002"] = pd.Categorical(expected["VAR00002"]) + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False) + expected = pd.DataFrame({"VAR00002": [1.0, np.nan]}) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +def test_spss_labelled_str(datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "labelled-str.sav") + + df = pd.read_spss(fname, convert_categoricals=True) + expected = pd.DataFrame({"gender": ["Male", "Female"]}) + expected["gender"] = pd.Categorical(expected["gender"]) + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False) + expected = pd.DataFrame({"gender": ["M", "F"]}) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +def test_spss_kwargs(datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "labelled-str.sav") + + df = pd.read_spss(fname, convert_categoricals=True, row_limit=1) + expected = pd.DataFrame({"gender": ["Male"]}, dtype="category") + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False, row_offset=1) + expected = pd.DataFrame({"gender": ["F"]}) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +def test_spss_umlauts(datapath): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "umlauts.sav") + + df = pd.read_spss(fname, convert_categoricals=True) + expected = pd.DataFrame( + {"var1": ["the ä umlaut", "the ü umlaut", "the ä umlaut", "the ö umlaut"]} + ) + expected["var1"] = pd.Categorical(expected["var1"]) + tm.assert_frame_equal(df, expected) + + df = pd.read_spss(fname, convert_categoricals=False) + expected = pd.DataFrame({"var1": [1.0, 2.0, 1.0, 3.0]}) + tm.assert_frame_equal(df, expected) + + +def test_spss_usecols(datapath): + # usecols must be list-like + fname = datapath("io", "data", "spss", "labelled-num.sav") + + with pytest.raises(TypeError, match="usecols must be list-like."): + pd.read_spss(fname, usecols="VAR00002") + + +def test_spss_umlauts_dtype_backend(datapath, dtype_backend): + # test file from the Haven project (https://haven.tidyverse.org/) + # Licence at LICENSES/HAVEN_LICENSE, LICENSES/HAVEN_MIT + fname = datapath("io", "data", "spss", "umlauts.sav") + + df = pd.read_spss(fname, convert_categoricals=False, dtype_backend=dtype_backend) + expected = pd.DataFrame({"var1": [1.0, 2.0, 1.0, 3.0]}, dtype="Int64") + + if dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + + from pandas.arrays import ArrowExtensionArray + + expected = pd.DataFrame( + { + col: ArrowExtensionArray(pa.array(expected[col], from_pandas=True)) + for col in expected.columns + } + ) + + tm.assert_frame_equal(df, expected) + + +def test_invalid_dtype_backend(): + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + pd.read_spss("test", dtype_backend="numpy") + + +@pytest.mark.filterwarnings("ignore::pandas.errors.ChainedAssignmentError") +def test_spss_metadata(datapath): + # GH 54264 + fname = datapath("io", "data", "spss", "labelled-num.sav") + + df = pd.read_spss(fname) + metadata = { + "column_names": ["VAR00002"], + "column_labels": [None], + "column_names_to_labels": {"VAR00002": None}, + "file_encoding": "UTF-8", + "number_columns": 1, + "number_rows": 1, + "variable_value_labels": {"VAR00002": {1.0: "This is one"}}, + "value_labels": {"labels0": {1.0: "This is one"}}, + "variable_to_label": {"VAR00002": "labels0"}, + "notes": [], + "original_variable_types": {"VAR00002": "F8.0"}, + "readstat_variable_types": {"VAR00002": "double"}, + "table_name": None, + "missing_ranges": {}, + "missing_user_values": {}, + "variable_storage_width": {"VAR00002": 8}, + "variable_display_width": {"VAR00002": 8}, + "variable_alignment": {"VAR00002": "unknown"}, + "variable_measure": {"VAR00002": "unknown"}, + "file_label": None, + "file_format": "sav/zsav", + "creation_time": datetime.datetime(2015, 2, 6, 14, 33, 36), + "modification_time": datetime.datetime(2015, 2, 6, 14, 33, 36), + "mr_sets": {}, + } + tm.assert_dict_equal(df.attrs, metadata) diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e06b204379986e0af0b5edcde4182a091e5fb9 --- /dev/null +++ b/pandas/tests/io/test_sql.py @@ -0,0 +1,4398 @@ +from __future__ import annotations + +import contextlib +import csv +from datetime import ( + date, + datetime, + time, + timedelta, +) +from decimal import Decimal +from io import StringIO +from pathlib import Path +import sqlite3 +from typing import TYPE_CHECKING +import uuid + +import numpy as np +import pytest + +from pandas._config import using_string_dtype + +from pandas._libs import lib +from pandas.compat import pa_version_under14p1 +from pandas.compat._optional import import_optional_dependency +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + Timestamp, + concat, + date_range, + isna, + to_datetime, + to_timedelta, +) +import pandas._testing as tm +from pandas.util.version import Version + +from pandas.io import sql +from pandas.io.sql import ( + SQLAlchemyEngine, + SQLDatabase, + SQLiteDatabase, + get_engine, + pandasSQL_builder, + read_sql_query, + read_sql_table, +) + +if TYPE_CHECKING: + import sqlalchemy + + +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:Passing a BlockManager to DataFrame:DeprecationWarning" + ), + pytest.mark.single_cpu, +] + + +@pytest.fixture +def sql_strings(): + return { + "read_parameters": { + "sqlite": "SELECT * FROM iris WHERE Name=? AND SepalLength=?", + "mysql": "SELECT * FROM iris WHERE `Name`=%s AND `SepalLength`=%s", + "postgresql": 'SELECT * FROM iris WHERE "Name"=%s AND "SepalLength"=%s', + }, + "read_named_parameters": { + "sqlite": """ + SELECT * FROM iris WHERE Name=:name AND SepalLength=:length + """, + "mysql": """ + SELECT * FROM iris WHERE + `Name`=%(name)s AND `SepalLength`=%(length)s + """, + "postgresql": """ + SELECT * FROM iris WHERE + "Name"=%(name)s AND "SepalLength"=%(length)s + """, + }, + "read_no_parameters_with_percent": { + "sqlite": "SELECT * FROM iris WHERE Name LIKE '%'", + "mysql": "SELECT * FROM iris WHERE `Name` LIKE '%'", + "postgresql": "SELECT * FROM iris WHERE \"Name\" LIKE '%'", + }, + } + + +def iris_table_metadata(): + import sqlalchemy + from sqlalchemy import ( + Column, + Double, + Float, + MetaData, + String, + Table, + ) + + dtype = Double if Version(sqlalchemy.__version__) >= Version("2.0.0") else Float + metadata = MetaData() + iris = Table( + "iris", + metadata, + Column("SepalLength", dtype), + Column("SepalWidth", dtype), + Column("PetalLength", dtype), + Column("PetalWidth", dtype), + Column("Name", String(200)), + ) + return iris + + +def create_and_load_iris_sqlite3(conn, iris_file: Path): + stmt = """CREATE TABLE iris ( + "SepalLength" REAL, + "SepalWidth" REAL, + "PetalLength" REAL, + "PetalWidth" REAL, + "Name" TEXT + )""" + + cur = conn.cursor() + cur.execute(stmt) + with iris_file.open(newline=None, encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + next(reader) + stmt = "INSERT INTO iris VALUES(?, ?, ?, ?, ?)" + # ADBC requires explicit types - no implicit str -> float conversion + records = [] + records = [ + ( + float(row[0]), + float(row[1]), + float(row[2]), + float(row[3]), + row[4], + ) + for row in reader + ] + + cur.executemany(stmt, records) + cur.close() + + conn.commit() + + +def create_and_load_iris_postgresql(conn, iris_file: Path): + stmt = """CREATE TABLE iris ( + "SepalLength" DOUBLE PRECISION, + "SepalWidth" DOUBLE PRECISION, + "PetalLength" DOUBLE PRECISION, + "PetalWidth" DOUBLE PRECISION, + "Name" TEXT + )""" + with conn.cursor() as cur: + cur.execute(stmt) + with iris_file.open(newline=None, encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + next(reader) + stmt = "INSERT INTO iris VALUES($1, $2, $3, $4, $5)" + # ADBC requires explicit types - no implicit str -> float conversion + records = [ + ( + float(row[0]), + float(row[1]), + float(row[2]), + float(row[3]), + row[4], + ) + for row in reader + ] + + cur.executemany(stmt, records) + + conn.commit() + + +def create_and_load_iris(conn, iris_file: Path): + from sqlalchemy import insert + + iris = iris_table_metadata() + + with iris_file.open(newline=None, encoding="utf-8") as csvfile: + reader = csv.reader(csvfile) + header = next(reader) + params = [dict(zip(header, row)) for row in reader] + stmt = insert(iris).values(params) + with conn.begin() as con: + iris.drop(con, checkfirst=True) + iris.create(bind=con) + con.execute(stmt) + + +def create_and_load_iris_view(conn): + stmt = "CREATE VIEW iris_view AS SELECT * FROM iris" + if isinstance(conn, sqlite3.Connection): + cur = conn.cursor() + cur.execute(stmt) + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(conn, adbc.Connection): + with conn.cursor() as cur: + cur.execute(stmt) + conn.commit() + else: + from sqlalchemy import text + + stmt = text(stmt) + with conn.begin() as con: + con.execute(stmt) + + +def types_table_metadata(dialect: str): + from sqlalchemy import ( + TEXT, + Boolean, + Column, + DateTime, + Float, + Integer, + MetaData, + Table, + ) + + date_type = TEXT if dialect == "sqlite" else DateTime + bool_type = Integer if dialect == "sqlite" else Boolean + metadata = MetaData() + types = Table( + "types", + metadata, + Column("TextCol", TEXT), + # error: Cannot infer type argument 1 of "Column" + Column("DateCol", date_type), # type: ignore[misc] + Column("IntDateCol", Integer), + Column("IntDateOnlyCol", Integer), + Column("FloatCol", Float), + Column("IntCol", Integer), + # error: Cannot infer type argument 1 of "Column" + Column("BoolCol", bool_type), # type: ignore[misc] + Column("IntColWithNull", Integer), + # error: Cannot infer type argument 1 of "Column" + Column("BoolColWithNull", bool_type), # type: ignore[misc] + ) + return types + + +def create_and_load_types_sqlite3(conn, types_data: list[dict]): + stmt = """CREATE TABLE types ( + "TextCol" TEXT, + "DateCol" TEXT, + "IntDateCol" INTEGER, + "IntDateOnlyCol" INTEGER, + "FloatCol" REAL, + "IntCol" INTEGER, + "BoolCol" INTEGER, + "IntColWithNull" INTEGER, + "BoolColWithNull" INTEGER + )""" + + ins_stmt = """ + INSERT INTO types + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?) + """ + + if isinstance(conn, sqlite3.Connection): + cur = conn.cursor() + cur.execute(stmt) + cur.executemany(ins_stmt, types_data) + else: + with conn.cursor() as cur: + cur.execute(stmt) + cur.executemany(ins_stmt, types_data) + + conn.commit() + + +def create_and_load_types_postgresql(conn, types_data: list[dict]): + with conn.cursor() as cur: + stmt = """CREATE TABLE types ( + "TextCol" TEXT, + "DateCol" TIMESTAMP, + "IntDateCol" INTEGER, + "IntDateOnlyCol" INTEGER, + "FloatCol" DOUBLE PRECISION, + "IntCol" INTEGER, + "BoolCol" BOOLEAN, + "IntColWithNull" INTEGER, + "BoolColWithNull" BOOLEAN + )""" + cur.execute(stmt) + + stmt = """ + INSERT INTO types + VALUES($1, $2::timestamp, $3, $4, $5, $6, $7, $8, $9) + """ + + cur.executemany(stmt, types_data) + + conn.commit() + + +def create_and_load_types(conn, types_data: list[dict], dialect: str): + from sqlalchemy import insert + from sqlalchemy.engine import Engine + + types = types_table_metadata(dialect) + + stmt = insert(types).values(types_data) + if isinstance(conn, Engine): + with conn.connect() as conn: + with conn.begin(): + types.drop(conn, checkfirst=True) + types.create(bind=conn) + conn.execute(stmt) + else: + with conn.begin(): + types.drop(conn, checkfirst=True) + types.create(bind=conn) + conn.execute(stmt) + + +def create_and_load_postgres_datetz(conn): + from sqlalchemy import ( + Column, + DateTime, + MetaData, + Table, + insert, + ) + from sqlalchemy.engine import Engine + + metadata = MetaData() + datetz = Table("datetz", metadata, Column("DateColWithTz", DateTime(timezone=True))) + datetz_data = [ + { + "DateColWithTz": "2000-01-01 00:00:00-08:00", + }, + { + "DateColWithTz": "2000-06-01 00:00:00-07:00", + }, + ] + stmt = insert(datetz).values(datetz_data) + if isinstance(conn, Engine): + with conn.connect() as conn: + with conn.begin(): + datetz.drop(conn, checkfirst=True) + datetz.create(bind=conn) + conn.execute(stmt) + else: + with conn.begin(): + datetz.drop(conn, checkfirst=True) + datetz.create(bind=conn) + conn.execute(stmt) + + # "2000-01-01 00:00:00-08:00" should convert to + # "2000-01-01 08:00:00" + # "2000-06-01 00:00:00-07:00" should convert to + # "2000-06-01 07:00:00" + # GH 6415 + expected_data = [ + Timestamp("2000-01-01 08:00:00", tz="UTC"), + Timestamp("2000-06-01 07:00:00", tz="UTC"), + ] + return Series(expected_data, name="DateColWithTz").astype("M8[us, UTC]") + + +def check_iris_frame(frame: DataFrame): + pytype = frame.dtypes.iloc[0].type + row = frame.iloc[0] + assert issubclass(pytype, np.floating) + tm.assert_series_equal( + row, Series([5.1, 3.5, 1.4, 0.2, "Iris-setosa"], index=frame.columns, name=0) + ) + assert frame.shape in ((150, 5), (8, 5)) + + +def count_rows(conn, table_name: str): + stmt = f"SELECT count(*) AS count_1 FROM {table_name}" + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if isinstance(conn, sqlite3.Connection): + cur = conn.cursor() + return cur.execute(stmt).fetchone()[0] + elif adbc and isinstance(conn, adbc.Connection): + with conn.cursor() as cur: + cur.execute(stmt) + return cur.fetchone()[0] + else: + from sqlalchemy import create_engine + from sqlalchemy.engine import Engine + + if isinstance(conn, str): + try: + engine = create_engine(conn) + with engine.connect() as conn: + return conn.exec_driver_sql(stmt).scalar_one() + finally: + engine.dispose() + elif isinstance(conn, Engine): + with conn.connect() as conn: + return conn.exec_driver_sql(stmt).scalar_one() + else: + return conn.exec_driver_sql(stmt).scalar_one() + + +@pytest.fixture +def iris_path(datapath): + iris_path = datapath("io", "data", "csv", "iris.csv") + return Path(iris_path) + + +@pytest.fixture +def types_data(): + return [ + { + "TextCol": "first", + "DateCol": "2000-01-03 00:00:00", + "IntDateCol": 535852800, + "IntDateOnlyCol": 20101010, + "FloatCol": 10.10, + "IntCol": 1, + "BoolCol": False, + "IntColWithNull": 1, + "BoolColWithNull": False, + }, + { + "TextCol": "first", + "DateCol": "2000-01-04 00:00:00", + "IntDateCol": 1356998400, + "IntDateOnlyCol": 20101212, + "FloatCol": 10.10, + "IntCol": 1, + "BoolCol": False, + "IntColWithNull": None, + "BoolColWithNull": None, + }, + ] + + +@pytest.fixture +def types_data_frame(types_data): + dtypes = { + "TextCol": "str", + "DateCol": "str", + "IntDateCol": "int64", + "IntDateOnlyCol": "int64", + "FloatCol": "float", + "IntCol": "int64", + "BoolCol": "int64", + "IntColWithNull": "float", + "BoolColWithNull": "float", + } + df = DataFrame(types_data) + return df[dtypes.keys()].astype(dtypes) + + +@pytest.fixture +def test_frame1(): + columns = ["index", "A", "B", "C", "D"] + data = [ + ( + "2000-01-03 00:00:00", + 0.980268513777, + 3.68573087906, + -0.364216805298, + -1.15973806169, + ), + ( + "2000-01-04 00:00:00", + 1.04791624281, + -0.0412318367011, + -0.16181208307, + 0.212549316967, + ), + ( + "2000-01-05 00:00:00", + 0.498580885705, + 0.731167677815, + -0.537677223318, + 1.34627041952, + ), + ( + "2000-01-06 00:00:00", + 1.12020151869, + 1.56762092543, + 0.00364077397681, + 0.67525259227, + ), + ] + return DataFrame(data, columns=columns) + + +@pytest.fixture +def test_frame3(): + columns = ["index", "A", "B"] + data = [ + ("2000-01-03 00:00:00", 2**31 - 1, -1.987670), + ("2000-01-04 00:00:00", -29, -0.0412318367011), + ("2000-01-05 00:00:00", 20000, 0.731167677815), + ("2000-01-06 00:00:00", -290867, 1.56762092543), + ] + return DataFrame(data, columns=columns) + + +def get_all_views(conn): + if isinstance(conn, sqlite3.Connection): + c = conn.execute("SELECT name FROM sqlite_master WHERE type='view'") + return [view[0] for view in c.fetchall()] + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(conn, adbc.Connection): + results = [] + info = conn.adbc_get_objects().read_all().to_pylist() + for catalog in info: + catalog["catalog_name"] + for schema in catalog["catalog_db_schemas"]: + schema["db_schema_name"] + for table in schema["db_schema_tables"]: + if table["table_type"] == "view": + view_name = table["table_name"] + results.append(view_name) + + return results + else: + from sqlalchemy import inspect + + return inspect(conn).get_view_names() + + +def get_all_tables(conn): + if isinstance(conn, sqlite3.Connection): + c = conn.execute("SELECT name FROM sqlite_master WHERE type='table'") + return [table[0] for table in c.fetchall()] + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + + if adbc and isinstance(conn, adbc.Connection): + results = [] + info = conn.adbc_get_objects().read_all().to_pylist() + for catalog in info: + for schema in catalog["catalog_db_schemas"]: + for table in schema["db_schema_tables"]: + if table["table_type"] == "table": + table_name = table["table_name"] + results.append(table_name) + + return results + else: + from sqlalchemy import inspect + + return inspect(conn).get_table_names() + + +def drop_table( + table_name: str, + conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, +): + if isinstance(conn, sqlite3.Connection): + conn.execute(f"DROP TABLE IF EXISTS {sql._get_valid_sqlite_name(table_name)}") + conn.commit() + + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(conn, adbc.Connection): + with conn.cursor() as cur: + cur.execute(f'DROP TABLE IF EXISTS "{table_name}"') + else: + with conn.begin() as con: + with sql.SQLDatabase(con) as db: + db.drop_table(table_name) + + +def drop_view( + view_name: str, + conn: sqlite3.Connection | sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, +): + import sqlalchemy + + if isinstance(conn, sqlite3.Connection): + conn.execute(f"DROP VIEW IF EXISTS {sql._get_valid_sqlite_name(view_name)}") + conn.commit() + else: + adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore") + if adbc and isinstance(conn, adbc.Connection): + with conn.cursor() as cur: + cur.execute(f'DROP VIEW IF EXISTS "{view_name}"') + else: + quoted_view = conn.engine.dialect.identifier_preparer.quote_identifier( + view_name + ) + stmt = sqlalchemy.text(f"DROP VIEW IF EXISTS {quoted_view}") + with conn.begin() as con: + con.execute(stmt) # type: ignore[union-attr] + + +@pytest.fixture +def mysql_pymysql_engine(): + sqlalchemy = pytest.importorskip("sqlalchemy") + pymysql = pytest.importorskip("pymysql") + engine = sqlalchemy.create_engine( + "mysql+pymysql://root@localhost:3306/pandas", + connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS}, + poolclass=sqlalchemy.pool.NullPool, + ) + yield engine + for view in get_all_views(engine): + drop_view(view, engine) + for tbl in get_all_tables(engine): + drop_table(tbl, engine) + engine.dispose() + + +@pytest.fixture +def mysql_pymysql_engine_iris(mysql_pymysql_engine, iris_path): + create_and_load_iris(mysql_pymysql_engine, iris_path) + create_and_load_iris_view(mysql_pymysql_engine) + return mysql_pymysql_engine + + +@pytest.fixture +def mysql_pymysql_engine_types(mysql_pymysql_engine, types_data): + create_and_load_types(mysql_pymysql_engine, types_data, "mysql") + return mysql_pymysql_engine + + +@pytest.fixture +def mysql_pymysql_conn(mysql_pymysql_engine): + with mysql_pymysql_engine.connect() as conn: + yield conn + + +@pytest.fixture +def mysql_pymysql_conn_iris(mysql_pymysql_engine_iris): + with mysql_pymysql_engine_iris.connect() as conn: + yield conn + + +@pytest.fixture +def mysql_pymysql_conn_types(mysql_pymysql_engine_types): + with mysql_pymysql_engine_types.connect() as conn: + yield conn + + +@pytest.fixture +def postgresql_psycopg2_engine(): + sqlalchemy = pytest.importorskip("sqlalchemy") + pytest.importorskip("psycopg2") + engine = sqlalchemy.create_engine( + "postgresql+psycopg2://postgres:postgres@localhost:5432/pandas", + poolclass=sqlalchemy.pool.NullPool, + ) + yield engine + for view in get_all_views(engine): + drop_view(view, engine) + for tbl in get_all_tables(engine): + drop_table(tbl, engine) + engine.dispose() + + +@pytest.fixture +def postgresql_psycopg2_engine_iris(postgresql_psycopg2_engine, iris_path): + create_and_load_iris(postgresql_psycopg2_engine, iris_path) + create_and_load_iris_view(postgresql_psycopg2_engine) + return postgresql_psycopg2_engine + + +@pytest.fixture +def postgresql_psycopg2_engine_types(postgresql_psycopg2_engine, types_data): + create_and_load_types(postgresql_psycopg2_engine, types_data, "postgres") + return postgresql_psycopg2_engine + + +@pytest.fixture +def postgresql_psycopg2_conn(postgresql_psycopg2_engine): + with postgresql_psycopg2_engine.connect() as conn: + yield conn + + +@pytest.fixture +def postgresql_adbc_conn(): + pytest.importorskip("pyarrow") + pytest.importorskip("adbc_driver_postgresql") + from adbc_driver_postgresql import dbapi + + uri = "postgresql://postgres:postgres@localhost:5432/pandas" + with dbapi.connect(uri) as conn: + yield conn + for view in get_all_views(conn): + drop_view(view, conn) + for tbl in get_all_tables(conn): + drop_table(tbl, conn) + conn.commit() + + +@pytest.fixture +def postgresql_adbc_iris(postgresql_adbc_conn, iris_path): + import adbc_driver_manager as mgr + + conn = postgresql_adbc_conn + + try: + conn.adbc_get_table_schema("iris") + except mgr.ProgrammingError: + conn.rollback() + create_and_load_iris_postgresql(conn, iris_path) + try: + conn.adbc_get_table_schema("iris_view") + except mgr.ProgrammingError: # note arrow-adbc issue 1022 + conn.rollback() + create_and_load_iris_view(conn) + return conn + + +@pytest.fixture +def postgresql_adbc_types(postgresql_adbc_conn, types_data): + import adbc_driver_manager as mgr + + conn = postgresql_adbc_conn + + try: + conn.adbc_get_table_schema("types") + except mgr.ProgrammingError: + conn.rollback() + new_data = [tuple(entry.values()) for entry in types_data] + + create_and_load_types_postgresql(conn, new_data) + + return conn + + +@pytest.fixture +def postgresql_psycopg2_conn_iris(postgresql_psycopg2_engine_iris): + with postgresql_psycopg2_engine_iris.connect() as conn: + yield conn + + +@pytest.fixture +def postgresql_psycopg2_conn_types(postgresql_psycopg2_engine_types): + with postgresql_psycopg2_engine_types.connect() as conn: + yield conn + + +@pytest.fixture +def sqlite_str(temp_file): + pytest.importorskip("sqlalchemy") + return f"sqlite:///{temp_file}" + + +@pytest.fixture +def sqlite_engine(sqlite_str): + sqlalchemy = pytest.importorskip("sqlalchemy") + engine = sqlalchemy.create_engine(sqlite_str, poolclass=sqlalchemy.pool.NullPool) + yield engine + for view in get_all_views(engine): + drop_view(view, engine) + for tbl in get_all_tables(engine): + drop_table(tbl, engine) + engine.dispose() + + +@pytest.fixture +def sqlite_conn(sqlite_engine): + with sqlite_engine.connect() as conn: + yield conn + + +@pytest.fixture +def sqlite_str_iris(sqlite_str, iris_path): + sqlalchemy = pytest.importorskip("sqlalchemy") + engine = sqlalchemy.create_engine(sqlite_str) + create_and_load_iris(engine, iris_path) + create_and_load_iris_view(engine) + engine.dispose() + return sqlite_str + + +@pytest.fixture +def sqlite_engine_iris(sqlite_engine, iris_path): + create_and_load_iris(sqlite_engine, iris_path) + create_and_load_iris_view(sqlite_engine) + return sqlite_engine + + +@pytest.fixture +def sqlite_conn_iris(sqlite_engine_iris): + with sqlite_engine_iris.connect() as conn: + yield conn + + +@pytest.fixture +def sqlite_str_types(sqlite_str, types_data): + sqlalchemy = pytest.importorskip("sqlalchemy") + engine = sqlalchemy.create_engine(sqlite_str) + create_and_load_types(engine, types_data, "sqlite") + engine.dispose() + return sqlite_str + + +@pytest.fixture +def sqlite_engine_types(sqlite_engine, types_data): + create_and_load_types(sqlite_engine, types_data, "sqlite") + return sqlite_engine + + +@pytest.fixture +def sqlite_conn_types(sqlite_engine_types): + with sqlite_engine_types.connect() as conn: + yield conn + + +@pytest.fixture +def sqlite_adbc_conn(temp_file): + pytest.importorskip("pyarrow") + pytest.importorskip("adbc_driver_sqlite") + from adbc_driver_sqlite import dbapi + + uri = f"file:{temp_file}" + with dbapi.connect(uri) as conn: + yield conn + for view in get_all_views(conn): + drop_view(view, conn) + for tbl in get_all_tables(conn): + drop_table(tbl, conn) + conn.commit() + + +@pytest.fixture +def sqlite_adbc_iris(sqlite_adbc_conn, iris_path): + import adbc_driver_manager as mgr + + conn = sqlite_adbc_conn + try: + conn.adbc_get_table_schema("iris") + except mgr.ProgrammingError: + conn.rollback() + create_and_load_iris_sqlite3(conn, iris_path) + try: + conn.adbc_get_table_schema("iris_view") + except mgr.ProgrammingError: + conn.rollback() + create_and_load_iris_view(conn) + return conn + + +@pytest.fixture +def sqlite_adbc_types(sqlite_adbc_conn, types_data): + import adbc_driver_manager as mgr + + conn = sqlite_adbc_conn + try: + conn.adbc_get_table_schema("types") + except mgr.ProgrammingError: + conn.rollback() + new_data = [] + for entry in types_data: + entry["BoolCol"] = int(entry["BoolCol"]) + if entry["BoolColWithNull"] is not None: + entry["BoolColWithNull"] = int(entry["BoolColWithNull"]) + new_data.append(tuple(entry.values())) + + create_and_load_types_sqlite3(conn, new_data) + conn.commit() + + return conn + + +@pytest.fixture +def sqlite_buildin(): + with contextlib.closing(sqlite3.connect(":memory:")) as closing_conn: + with closing_conn as conn: + yield conn + + +@pytest.fixture +def sqlite_buildin_iris(sqlite_buildin, iris_path): + create_and_load_iris_sqlite3(sqlite_buildin, iris_path) + create_and_load_iris_view(sqlite_buildin) + return sqlite_buildin + + +@pytest.fixture +def sqlite_buildin_types(sqlite_buildin, types_data): + types_data = [tuple(entry.values()) for entry in types_data] + create_and_load_types_sqlite3(sqlite_buildin, types_data) + return sqlite_buildin + + +mysql_connectable = [ + pytest.param("mysql_pymysql_engine", marks=pytest.mark.db), + pytest.param("mysql_pymysql_conn", marks=pytest.mark.db), +] + +mysql_connectable_iris = [ + pytest.param("mysql_pymysql_engine_iris", marks=pytest.mark.db), + pytest.param("mysql_pymysql_conn_iris", marks=pytest.mark.db), +] + +mysql_connectable_types = [ + pytest.param("mysql_pymysql_engine_types", marks=pytest.mark.db), + pytest.param("mysql_pymysql_conn_types", marks=pytest.mark.db), +] + +postgresql_connectable = [ + pytest.param("postgresql_psycopg2_engine", marks=pytest.mark.db), + pytest.param("postgresql_psycopg2_conn", marks=pytest.mark.db), +] + +postgresql_connectable_iris = [ + pytest.param("postgresql_psycopg2_engine_iris", marks=pytest.mark.db), + pytest.param("postgresql_psycopg2_conn_iris", marks=pytest.mark.db), +] + +postgresql_connectable_types = [ + pytest.param("postgresql_psycopg2_engine_types", marks=pytest.mark.db), + pytest.param("postgresql_psycopg2_conn_types", marks=pytest.mark.db), +] + +sqlite_connectable = [ + "sqlite_engine", + "sqlite_conn", + "sqlite_str", +] + +sqlite_connectable_iris = [ + "sqlite_engine_iris", + "sqlite_conn_iris", + "sqlite_str_iris", +] + +sqlite_connectable_types = [ + "sqlite_engine_types", + "sqlite_conn_types", + "sqlite_str_types", +] + +sqlalchemy_connectable = mysql_connectable + postgresql_connectable + sqlite_connectable + +sqlalchemy_connectable_iris = ( + mysql_connectable_iris + postgresql_connectable_iris + sqlite_connectable_iris +) + +sqlalchemy_connectable_types = ( + mysql_connectable_types + postgresql_connectable_types + sqlite_connectable_types +) + +adbc_connectable = [ + "sqlite_adbc_conn", + pytest.param("postgresql_adbc_conn", marks=pytest.mark.db), +] + +adbc_connectable_iris = [ + pytest.param("postgresql_adbc_iris", marks=pytest.mark.db), + "sqlite_adbc_iris", +] + +adbc_connectable_types = [ + pytest.param("postgresql_adbc_types", marks=pytest.mark.db), + "sqlite_adbc_types", +] + + +all_connectable = [*sqlalchemy_connectable, "sqlite_buildin", *adbc_connectable] + +all_connectable_iris = [ + *sqlalchemy_connectable_iris, + "sqlite_buildin_iris", + *adbc_connectable_iris, +] + +all_connectable_types = [ + *sqlalchemy_connectable_types, + "sqlite_buildin_types", + *adbc_connectable_types, +] + + +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql(conn, test_frame1, request): + # GH 51086 if conn is sqlite_engine + conn = request.getfixturevalue(conn) + test_frame1.to_sql(name="test", con=conn, if_exists="append", index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql_empty(conn, test_frame1, request): + if conn == "postgresql_adbc_conn" and not using_string_dtype(): + request.node.add_marker( + pytest.mark.xfail( + reason="postgres ADBC driver < 1.2 cannot insert index with null type", + ) + ) + + # GH 51086 if conn is sqlite_engine + conn = request.getfixturevalue(conn) + empty_df = test_frame1.iloc[:0] + empty_df.to_sql(name="test", con=conn, if_exists="append", index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql_arrow_dtypes(conn, request): + # GH 52046 + pytest.importorskip("pyarrow") + df = DataFrame( + { + "int": pd.array([1], dtype="int8[pyarrow]"), + "datetime": pd.array( + [datetime(2023, 1, 1)], dtype="timestamp[ns][pyarrow]" + ), + "date": pd.array([date(2023, 1, 1)], dtype="date32[day][pyarrow]"), + "timedelta": pd.array([timedelta(1)], dtype="duration[ns][pyarrow]"), + "string": pd.array(["a"], dtype="string[pyarrow]"), + } + ) + + if "adbc" in conn: + if conn == "sqlite_adbc_conn": + df = df.drop(columns=["timedelta"]) + if pa_version_under14p1: + exp_warning = DeprecationWarning + msg = "is_sparse is deprecated" + else: + exp_warning = None + msg = "" + else: + exp_warning = UserWarning + msg = "the 'timedelta'" + + conn = request.getfixturevalue(conn) + with tm.assert_produces_warning(exp_warning, match=msg, check_stacklevel=False): + df.to_sql(name="test_arrow", con=conn, if_exists="replace", index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql_arrow_dtypes_missing(conn, request, nulls_fixture): + # GH 52046 + pytest.importorskip("pyarrow") + if isinstance(nulls_fixture, Decimal): + pytest.skip( + # GH#61773 + reason="Decimal('NaN') not supported in constructor for timestamp dtype" + ) + + df = DataFrame( + { + "datetime": pd.array( + [datetime(2023, 1, 1), nulls_fixture], dtype="timestamp[ns][pyarrow]" + ), + } + ) + conn = request.getfixturevalue(conn) + df.to_sql(name="test_arrow", con=conn, if_exists="replace", index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("method", [None, "multi"]) +def test_to_sql(conn, method, test_frame1, request): + if method == "multi" and "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'method' not implemented for ADBC drivers", strict=True + ) + ) + + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", method=method) + assert pandasSQL.has_table("test_frame") + assert count_rows(conn, "test_frame") == len(test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize( + "mode, num_row_coef", [("replace", 1), ("append", 2), ("delete_rows", 1)] +) +def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request): + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") + pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode) + assert pandasSQL.has_table("test_frame") + assert count_rows(conn, "test_frame") == num_row_coef * len(test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_to_sql_exist_fail(conn, test_frame1, request): + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") + assert pandasSQL.has_table("test_frame") + + msg = "Table 'test_frame' already exists" + with pytest.raises(ValueError, match=msg): + pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_iris_query(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = read_sql_query("SELECT * FROM iris", conn) + check_iris_frame(iris_frame) + iris_frame = pd.read_sql("SELECT * FROM iris", conn) + check_iris_frame(iris_frame) + iris_frame = pd.read_sql("SELECT * FROM iris where 0=1", conn) + assert iris_frame.shape == (0, 5) + assert "SepalWidth" in iris_frame.columns + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_iris_query_chunksize(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'chunksize' not implemented for ADBC drivers", + strict=True, + ) + ) + conn = request.getfixturevalue(conn) + iris_frame = concat(read_sql_query("SELECT * FROM iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("SELECT * FROM iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("SELECT * FROM iris where 0=1", conn, chunksize=7)) + assert iris_frame.shape == (0, 5) + assert "SepalWidth" in iris_frame.columns + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_query_expression_with_parameter(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'chunksize' not implemented for ADBC drivers", + strict=True, + ) + ) + conn = request.getfixturevalue(conn) + from sqlalchemy import ( + MetaData, + Table, + create_engine, + select, + ) + + metadata = MetaData() + autoload_con = create_engine(conn) if isinstance(conn, str) else conn + iris = Table("iris", metadata, autoload_with=autoload_con) + iris_frame = read_sql_query( + select(iris), conn, params={"name": "Iris-setosa", "length": 5.1} + ) + check_iris_frame(iris_frame) + if isinstance(conn, str): + autoload_con.dispose() + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_iris_query_string_with_parameter(conn, request, sql_strings): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'chunksize' not implemented for ADBC drivers", + strict=True, + ) + ) + + for db, query in sql_strings["read_parameters"].items(): + if db in conn: + break + else: + raise KeyError(f"No part of {conn} found in sql_strings['read_parameters']") + conn = request.getfixturevalue(conn) + iris_frame = read_sql_query(query, conn, params=("Iris-setosa", 5.1)) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_table(conn, request): + # GH 51015 if conn = sqlite_iris_str + conn = request.getfixturevalue(conn) + iris_frame = read_sql_table("iris", conn) + check_iris_frame(iris_frame) + iris_frame = pd.read_sql("iris", conn) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_table_chunksize(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + iris_frame = concat(read_sql_table("iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + iris_frame = concat(pd.read_sql("iris", conn, chunksize=7)) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_to_sql_callable(conn, test_frame1, request): + conn = request.getfixturevalue(conn) + + check = [] # used to double check function below is really being used + + def sample(pd_table, conn, keys, data_iter): + check.append(1) + data = [dict(zip(keys, row)) for row in data_iter] + conn.execute(pd_table.table.insert(), data) + + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: + pandasSQL.to_sql(test_frame1, "test_frame", method=sample) + assert pandasSQL.has_table("test_frame") + assert check == [1] + assert count_rows(conn, "test_frame") == len(test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable_types) +def test_default_type_conversion(conn, request): + conn_name = conn + if conn_name == "sqlite_buildin_types": + request.applymarker( + pytest.mark.xfail( + reason="sqlite_buildin connection does not implement read_sql_table" + ) + ) + + conn = request.getfixturevalue(conn) + df = sql.read_sql_table("types", conn) + + assert issubclass(df.FloatCol.dtype.type, np.floating) + assert issubclass(df.IntCol.dtype.type, np.integer) + + # MySQL/sqlite has no real BOOL type + if "postgresql" in conn_name: + assert issubclass(df.BoolCol.dtype.type, np.bool_) + else: + assert issubclass(df.BoolCol.dtype.type, np.integer) + + # Int column with NA values stays as float + assert issubclass(df.IntColWithNull.dtype.type, np.floating) + + # Bool column with NA = int column with NA values => becomes float + if "postgresql" in conn_name: + assert issubclass(df.BoolColWithNull.dtype.type, object) + else: + assert issubclass(df.BoolColWithNull.dtype.type, np.floating) + + +@pytest.mark.parametrize("conn", mysql_connectable) +def test_read_procedure(conn, request): + conn = request.getfixturevalue(conn) + + # GH 7324 + # Although it is more an api test, it is added to the + # mysql tests as sqlite does not have stored procedures + from sqlalchemy import text + from sqlalchemy.engine import Engine + + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]}) + df.to_sql(name="test_frame", con=conn, index=False) + + proc = """DROP PROCEDURE IF EXISTS get_testdb; + + CREATE PROCEDURE get_testdb () + + BEGIN + SELECT * FROM test_frame; + END""" + proc = text(proc) + if isinstance(conn, Engine): + with conn.connect() as engine_conn: + with engine_conn.begin(): + engine_conn.execute(proc) + else: + with conn.begin(): + conn.execute(proc) + + res1 = sql.read_sql_query("CALL get_testdb();", conn) + tm.assert_frame_equal(df, res1) + + # test delegation to read_sql_query + res2 = sql.read_sql("CALL get_testdb();", conn) + tm.assert_frame_equal(df, res2) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +@pytest.mark.parametrize("expected_count", [2, "Success!"]) +def test_copy_from_callable_insertion_method(conn, expected_count, request): + # GH 8953 + # Example in io.rst found under _io.sql.method + # not available in sqlite, mysql + def psql_insert_copy(table, conn, keys, data_iter): + # gets a DBAPI connection that can provide a cursor + dbapi_conn = conn.connection + with dbapi_conn.cursor() as cur: + s_buf = StringIO() + writer = csv.writer(s_buf) + writer.writerows(data_iter) + s_buf.seek(0) + + columns = ", ".join([f'"{k}"' for k in keys]) + if table.schema: + table_name = f"{table.schema}.{table.name}" + else: + table_name = table.name + + sql_query = f"COPY {table_name} ({columns}) FROM STDIN WITH CSV" + cur.copy_expert(sql=sql_query, file=s_buf) + return expected_count + + conn = request.getfixturevalue(conn) + expected = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]}) + result_count = expected.to_sql( + name="test_frame", con=conn, index=False, method=psql_insert_copy + ) + # GH 46891 + if expected_count is None: + assert result_count is None + else: + assert result_count == expected_count + result = sql.read_sql_table("test_frame", conn) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_insertion_method_on_conflict_do_nothing(conn, request): + # GH 15988: Example in to_sql docstring + conn = request.getfixturevalue(conn) + + from sqlalchemy.dialects.postgresql import insert + from sqlalchemy.engine import Engine + from sqlalchemy.sql import text + + def insert_on_conflict(table, conn, keys, data_iter): + data = [dict(zip(keys, row)) for row in data_iter] + stmt = ( + insert(table.table) + .values(data) + .on_conflict_do_nothing(index_elements=["a"]) + ) + result = conn.execute(stmt) + return result.rowcount + + create_sql = text( + """ + CREATE TABLE test_insert_conflict ( + a integer PRIMARY KEY, + b numeric, + c text + ); + """ + ) + if isinstance(conn, Engine): + with conn.connect() as con: + with con.begin(): + con.execute(create_sql) + else: + with conn.begin(): + conn.execute(create_sql) + + expected = DataFrame([[1, 2.1, "a"]], columns=list("abc")) + expected.to_sql( + name="test_insert_conflict", con=conn, if_exists="append", index=False + ) + + df_insert = DataFrame([[1, 3.2, "b"]], columns=list("abc")) + inserted = df_insert.to_sql( + name="test_insert_conflict", + con=conn, + index=False, + if_exists="append", + method=insert_on_conflict, + ) + result = sql.read_sql_table("test_insert_conflict", conn) + tm.assert_frame_equal(result, expected) + assert inserted == 0 + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_insert_conflict") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_to_sql_on_public_schema(conn, request): + if "sqlite" in conn or "mysql" in conn: + request.applymarker( + pytest.mark.xfail( + reason="test for public schema only specific to postgresql" + ) + ) + + conn = request.getfixturevalue(conn) + + test_data = DataFrame([[1, 2.1, "a"], [2, 3.1, "b"]], columns=list("abc")) + test_data.to_sql( + name="test_public_schema", + con=conn, + if_exists="append", + index=False, + schema="public", + ) + + df_out = sql.read_sql_table("test_public_schema", conn, schema="public") + tm.assert_frame_equal(test_data, df_out) + + +@pytest.mark.parametrize("conn", mysql_connectable) +def test_insertion_method_on_conflict_update(conn, request): + # GH 14553: Example in to_sql docstring + conn = request.getfixturevalue(conn) + + from sqlalchemy.dialects.mysql import insert + from sqlalchemy.engine import Engine + from sqlalchemy.sql import text + + def insert_on_conflict(table, conn, keys, data_iter): + data = [dict(zip(keys, row)) for row in data_iter] + stmt = insert(table.table).values(data) + stmt = stmt.on_duplicate_key_update(b=stmt.inserted.b, c=stmt.inserted.c) + result = conn.execute(stmt) + return result.rowcount + + create_sql = text( + """ + CREATE TABLE test_insert_conflict ( + a INT PRIMARY KEY, + b FLOAT, + c VARCHAR(10) + ); + """ + ) + if isinstance(conn, Engine): + with conn.connect() as con: + with con.begin(): + con.execute(create_sql) + else: + with conn.begin(): + conn.execute(create_sql) + + df = DataFrame([[1, 2.1, "a"]], columns=list("abc")) + df.to_sql(name="test_insert_conflict", con=conn, if_exists="append", index=False) + + expected = DataFrame([[1, 3.2, "b"]], columns=list("abc")) + inserted = expected.to_sql( + name="test_insert_conflict", + con=conn, + index=False, + if_exists="append", + method=insert_on_conflict, + ) + result = sql.read_sql_table("test_insert_conflict", conn) + tm.assert_frame_equal(result, expected) + assert inserted == 2 + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_insert_conflict") + + +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_read_view_postgres(conn, request): + # GH 52969 + conn = request.getfixturevalue(conn) + + from sqlalchemy.engine import Engine + from sqlalchemy.sql import text + + table_name = f"group_{uuid.uuid4().hex}" + view_name = f"group_view_{uuid.uuid4().hex}" + + sql_stmt = text( + f""" + CREATE TABLE {table_name} ( + group_id INTEGER, + name TEXT + ); + INSERT INTO {table_name} VALUES + (1, 'name'); + CREATE VIEW {view_name} + AS + SELECT * FROM {table_name}; + """ + ) + if isinstance(conn, Engine): + with conn.connect() as con: + with con.begin(): + con.execute(sql_stmt) + else: + with conn.begin(): + conn.execute(sql_stmt) + result = read_sql_table(view_name, conn) + expected = DataFrame({"group_id": [1], "name": "name"}) + tm.assert_frame_equal(result, expected) + + +def test_read_view_sqlite(sqlite_buildin): + # GH 52969 + create_table = """ +CREATE TABLE groups ( + group_id INTEGER, + name TEXT +); +""" + insert_into = """ +INSERT INTO groups VALUES + (1, 'name'); +""" + create_view = """ +CREATE VIEW group_view +AS +SELECT * FROM groups; +""" + sqlite_buildin.execute(create_table) + sqlite_buildin.execute(insert_into) + sqlite_buildin.execute(create_view) + result = pd.read_sql("SELECT * FROM group_view", sqlite_buildin) + expected = DataFrame({"group_id": [1], "name": "name"}) + tm.assert_frame_equal(result, expected) + + +def flavor(conn_name): + if "postgresql" in conn_name: + return "postgresql" + elif "sqlite" in conn_name: + return "sqlite" + elif "mysql" in conn_name: + return "mysql" + + raise ValueError(f"unsupported connection: {conn_name}") + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_sql_iris_parameter(conn, request, sql_strings): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'params' not implemented for ADBC drivers", + strict=True, + ) + ) + conn_name = conn + conn = request.getfixturevalue(conn) + query = sql_strings["read_parameters"][flavor(conn_name)] + params = ("Iris-setosa", 5.1) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_frame = pandasSQL.read_query(query, params=params) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_sql_iris_named_parameter(conn, request, sql_strings): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'params' not implemented for ADBC drivers", + strict=True, + ) + ) + + conn_name = conn + conn = request.getfixturevalue(conn) + query = sql_strings["read_named_parameters"][flavor(conn_name)] + params = {"name": "Iris-setosa", "length": 5.1} + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_frame = pandasSQL.read_query(query, params=params) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_sql_iris_no_parameter_with_percent(conn, request, sql_strings): + if "mysql" in conn or ("postgresql" in conn and "adbc" not in conn): + request.applymarker(pytest.mark.xfail(reason="broken test")) + + conn_name = conn + conn = request.getfixturevalue(conn) + + query = sql_strings["read_no_parameters_with_percent"][flavor(conn_name)] + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_frame = pandasSQL.read_query(query, params=None) + check_iris_frame(iris_frame) + + +# ----------------------------------------------------------------------------- +# -- Testing the public API + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_api_read_sql_view(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = sql.read_sql_query("SELECT * FROM iris_view", conn) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_api_read_sql_with_chunksize_no_result(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + query = 'SELECT * FROM iris_view WHERE "SepalLength" < 0.0' + with_batch = sql.read_sql_query(query, conn, chunksize=5) + without_batch = sql.read_sql_query(query, conn) + tm.assert_frame_equal(concat(with_batch), without_batch) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql(conn, request, test_frame1): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame1", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame1") + + sql.to_sql(test_frame1, "test_frame1", conn) + assert sql.has_table("test_frame1", conn) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_fail(conn, request, test_frame1): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame2", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame2") + + sql.to_sql(test_frame1, "test_frame2", conn, if_exists="fail") + assert sql.has_table("test_frame2", conn) + + msg = "Table 'test_frame2' already exists" + with pytest.raises(ValueError, match=msg): + sql.to_sql(test_frame1, "test_frame2", conn, if_exists="fail") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_replace(conn, request, test_frame1): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame3", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame3") + + sql.to_sql(test_frame1, "test_frame3", conn, if_exists="fail") + # Add to table again + sql.to_sql(test_frame1, "test_frame3", conn, if_exists="replace") + assert sql.has_table("test_frame3", conn) + + num_entries = len(test_frame1) + num_rows = count_rows(conn, "test_frame3") + + assert num_rows == num_entries + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_append(conn, request, test_frame1): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame4", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame4") + + assert sql.to_sql(test_frame1, "test_frame4", conn, if_exists="fail") == 4 + + # Add to table again + assert sql.to_sql(test_frame1, "test_frame4", conn, if_exists="append") == 4 + assert sql.has_table("test_frame4", conn) + + num_entries = 2 * len(test_frame1) + num_rows = count_rows(conn, "test_frame4") + + assert num_rows == num_entries + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_type_mapping(conn, request, test_frame3): + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame5", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame5") + + sql.to_sql(test_frame3, "test_frame5", conn, index=False) + result = sql.read_sql("SELECT * FROM test_frame5", conn) + + tm.assert_frame_equal(test_frame3, result) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_series(conn, request): + conn = request.getfixturevalue(conn) + if sql.has_table("test_series", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_series") + + s = Series(np.arange(5, dtype="int64"), name="series") + sql.to_sql(s, "test_series", conn, index=False) + s2 = sql.read_sql_query("SELECT * FROM test_series", conn) + tm.assert_frame_equal(s.to_frame(), s2) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_roundtrip(conn, request, test_frame1): + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame_roundtrip", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame_roundtrip") + + sql.to_sql(test_frame1, "test_frame_roundtrip", con=conn) + result = sql.read_sql_query("SELECT * FROM test_frame_roundtrip", con=conn) + + # HACK! + if "adbc" in conn_name: + result = result.drop(columns="__index_level_0__") + else: + result = result.drop(columns="level_0") + tm.assert_frame_equal(result, test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_roundtrip_chunksize(conn, request, test_frame1): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + if sql.has_table("test_frame_roundtrip", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_frame_roundtrip") + + sql.to_sql( + test_frame1, + "test_frame_roundtrip", + con=conn, + index=False, + chunksize=2, + ) + result = sql.read_sql_query("SELECT * FROM test_frame_roundtrip", con=conn) + tm.assert_frame_equal(result, test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_api_execute_sql(conn, request): + # drop_sql = "DROP TABLE IF EXISTS test" # should already be done + conn = request.getfixturevalue(conn) + with sql.pandasSQL_builder(conn) as pandas_sql: + iris_results = pandas_sql.execute("SELECT * FROM iris") + row = iris_results.fetchone() + iris_results.close() + assert list(row) == [5.1, 3.5, 1.4, 0.2, "Iris-setosa"] + + +@pytest.mark.parametrize("conn", all_connectable_types) +def test_api_date_parsing(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + # Test date parsing in read_sql + # No Parsing + df = sql.read_sql_query("SELECT * FROM types", conn) + if not ("mysql" in conn_name or "postgres" in conn_name): + assert not issubclass(df.DateCol.dtype.type, np.datetime64) + + df = sql.read_sql_query("SELECT * FROM types", conn, parse_dates=["DateCol"]) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + assert df.DateCol.tolist() == [ + Timestamp(2000, 1, 3, 0, 0, 0), + Timestamp(2000, 1, 4, 0, 0, 0), + ] + + df = sql.read_sql_query( + "SELECT * FROM types", + conn, + parse_dates={"DateCol": "%Y-%m-%d %H:%M:%S"}, + ) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + assert df.DateCol.tolist() == [ + Timestamp(2000, 1, 3, 0, 0, 0), + Timestamp(2000, 1, 4, 0, 0, 0), + ] + + df = sql.read_sql_query("SELECT * FROM types", conn, parse_dates=["IntDateCol"]) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + assert df.IntDateCol.tolist() == [ + Timestamp(1986, 12, 25, 0, 0, 0), + Timestamp(2013, 1, 1, 0, 0, 0), + ] + + df = sql.read_sql_query( + "SELECT * FROM types", conn, parse_dates={"IntDateCol": "s"} + ) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + assert df.IntDateCol.tolist() == [ + Timestamp(1986, 12, 25, 0, 0, 0), + Timestamp(2013, 1, 1, 0, 0, 0), + ] + + df = sql.read_sql_query( + "SELECT * FROM types", + conn, + parse_dates={"IntDateOnlyCol": "%Y%m%d"}, + ) + assert issubclass(df.IntDateOnlyCol.dtype.type, np.datetime64) + assert df.IntDateOnlyCol.tolist() == [ + Timestamp("2010-10-10"), + Timestamp("2010-12-12"), + ] + + +@pytest.mark.parametrize("conn", all_connectable_types) +@pytest.mark.parametrize("error", ["raise", "coerce"]) +@pytest.mark.parametrize( + "read_sql, text, mode", + [ + (sql.read_sql, "SELECT * FROM types", ("sqlalchemy", "fallback")), + (sql.read_sql, "types", ("sqlalchemy")), + ( + sql.read_sql_query, + "SELECT * FROM types", + ("sqlalchemy", "fallback"), + ), + (sql.read_sql_table, "types", ("sqlalchemy")), + ], +) +def test_api_custom_dateparsing_error( + conn, request, read_sql, text, mode, error, types_data_frame +): + conn_name = conn + conn = request.getfixturevalue(conn) + if text == "types" and conn_name == "sqlite_buildin_types": + request.applymarker( + pytest.mark.xfail(reason="failing combination of arguments") + ) + + expected = types_data_frame.astype({"DateCol": "datetime64[us]"}) + + result = read_sql( + text, + con=conn, + parse_dates={ + "DateCol": {"errors": error}, + }, + ) + if "postgres" in conn_name: + # TODO: clean up types_data_frame fixture + result["BoolCol"] = result["BoolCol"].astype(int) + result["BoolColWithNull"] = result["BoolColWithNull"].astype(float) + + if conn_name == "postgresql_adbc_types": + expected = expected.astype( + { + "IntDateCol": "int32", + "IntDateOnlyCol": "int32", + "IntCol": "int32", + } + ) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable_types) +def test_api_date_and_index(conn, request): + # Test case where same column appears in parse_date and index_col + conn = request.getfixturevalue(conn) + df = sql.read_sql_query( + "SELECT * FROM types", + conn, + index_col="DateCol", + parse_dates=["DateCol", "IntDateCol"], + ) + + assert issubclass(df.index.dtype.type, np.datetime64) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_timedelta(conn, request): + # see #6921 + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("test_timedelta", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_timedelta") + + df = to_timedelta(Series(["00:00:01", "00:00:03"], name="foo")).to_frame() + + if conn_name == "sqlite_adbc_conn": + request.node.add_marker( + pytest.mark.xfail( + reason="sqlite ADBC driver doesn't implement timedelta", + ) + ) + + if "adbc" in conn_name: + if pa_version_under14p1: + exp_warning = DeprecationWarning + else: + exp_warning = None + else: + exp_warning = UserWarning + + with tm.assert_produces_warning(exp_warning, check_stacklevel=False): + result_count = df.to_sql(name="test_timedelta", con=conn) + assert result_count == 2 + result = sql.read_sql_query("SELECT * FROM test_timedelta", conn) + + if conn_name == "postgresql_adbc_conn": + # TODO: Postgres stores an INTERVAL, which ADBC reads as a Month-Day-Nano + # Interval; the default pandas type mapper maps this to a DateOffset + # but maybe we should try and restore the timedelta here? + expected = Series( + [ + pd.DateOffset(months=0, days=0, microseconds=1000000, nanoseconds=0), + pd.DateOffset(months=0, days=0, microseconds=3000000, nanoseconds=0), + ], + name="foo", + ) + else: + expected = df["foo"].astype("int64") + tm.assert_series_equal(result["foo"], expected) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_complex_raises(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame({"a": [1 + 1j, 2j]}) + + if "adbc" in conn_name: + msg = "datatypes not supported" + else: + msg = "Complex datatypes not supported" + with pytest.raises(ValueError, match=msg): + assert df.to_sql("test_complex", con=conn) is None + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize( + "index_name,index_label,expected", + [ + # no index name, defaults to 'index' + (None, None, "index"), + # specifying index_label + (None, "other_label", "other_label"), + # using the index name + ("index_name", None, "index_name"), + # has index name, but specifying index_label + ("index_name", "other_label", "other_label"), + # index name is integer + (0, None, "0"), + # index name is None but index label is integer + (None, 0, "0"), + ], +) +def test_api_to_sql_index_label(conn, request, index_name, index_label, expected): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="index_label argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + if sql.has_table("test_index_label", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_index_label") + + temp_frame = DataFrame({"col1": range(4)}) + temp_frame.index.name = index_name + query = "SELECT * FROM test_index_label" + sql.to_sql(temp_frame, "test_index_label", conn, index_label=index_label) + frame = sql.read_sql_query(query, conn) + assert frame.columns[0] == expected + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_to_sql_index_label_multiindex(conn, request): + conn_name = conn + if "mysql" in conn_name: + request.applymarker( + pytest.mark.xfail( + reason="MySQL can fail using TEXT without length as key", strict=False + ) + ) + elif "adbc" in conn_name: + request.node.add_marker( + pytest.mark.xfail(reason="index_label argument NotImplemented with ADBC") + ) + + conn = request.getfixturevalue(conn) + if sql.has_table("test_index_label", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_index_label") + + expected_row_count = 4 + temp_frame = DataFrame( + {"col1": range(4)}, + index=MultiIndex.from_product([("A0", "A1"), ("B0", "B1")]), + ) + + # no index name, defaults to 'level_0' and 'level_1' + result = sql.to_sql(temp_frame, "test_index_label", conn) + assert result == expected_row_count + frame = sql.read_sql_query("SELECT * FROM test_index_label", conn) + assert frame.columns[0] == "level_0" + assert frame.columns[1] == "level_1" + + # specifying index_label + result = sql.to_sql( + temp_frame, + "test_index_label", + conn, + if_exists="replace", + index_label=["A", "B"], + ) + assert result == expected_row_count + frame = sql.read_sql_query("SELECT * FROM test_index_label", conn) + assert frame.columns[:2].tolist() == ["A", "B"] + + # using the index name + temp_frame.index.names = ["A", "B"] + result = sql.to_sql(temp_frame, "test_index_label", conn, if_exists="replace") + assert result == expected_row_count + frame = sql.read_sql_query("SELECT * FROM test_index_label", conn) + assert frame.columns[:2].tolist() == ["A", "B"] + + # has index name, but specifying index_label + result = sql.to_sql( + temp_frame, + "test_index_label", + conn, + if_exists="replace", + index_label=["C", "D"], + ) + assert result == expected_row_count + frame = sql.read_sql_query("SELECT * FROM test_index_label", conn) + assert frame.columns[:2].tolist() == ["C", "D"] + + msg = "Length of 'index_label' should match number of levels, which is 2" + with pytest.raises(ValueError, match=msg): + sql.to_sql( + temp_frame, + "test_index_label", + conn, + if_exists="replace", + index_label="C", + ) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_multiindex_roundtrip(conn, request): + conn = request.getfixturevalue(conn) + if sql.has_table("test_multiindex_roundtrip", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_multiindex_roundtrip") + + df = DataFrame.from_records( + [(1, 2.1, "line1"), (2, 1.5, "line2")], + columns=["A", "B", "C"], + index=["A", "B"], + ) + + df.to_sql(name="test_multiindex_roundtrip", con=conn) + result = sql.read_sql_query( + "SELECT * FROM test_multiindex_roundtrip", conn, index_col=["A", "B"] + ) + tm.assert_frame_equal(df, result, check_index_type=True) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize( + "dtype", + [ + None, + int, + float, + {"A": int, "B": float}, + ], +) +def test_api_dtype_argument(conn, request, dtype): + # GH10285 Add dtype argument to read_sql_query + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("test_dtype_argument", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_dtype_argument") + + df = DataFrame([[1.2, 3.4], [5.6, 7.8]], columns=["A", "B"]) + assert df.to_sql(name="test_dtype_argument", con=conn) == 2 + + expected = df.astype(dtype) + + if "postgres" in conn_name: + query = 'SELECT "A", "B" FROM test_dtype_argument' + else: + query = "SELECT A, B FROM test_dtype_argument" + result = sql.read_sql_query(query, con=conn, dtype=dtype) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_integer_col_names(conn, request): + conn = request.getfixturevalue(conn) + df = DataFrame([[1, 2], [3, 4]], columns=[0, 1]) + sql.to_sql(df, "test_frame_integer_col_names", conn, if_exists="replace") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_get_schema(conn, request, test_frame1): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'get_schema' not implemented for ADBC drivers", + strict=True, + ) + ) + conn = request.getfixturevalue(conn) + create_sql = sql.get_schema(test_frame1, "test", con=conn) + assert "CREATE" in create_sql + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_get_schema_with_schema(conn, request, test_frame1): + # GH28486 + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'get_schema' not implemented for ADBC drivers", + strict=True, + ) + ) + conn = request.getfixturevalue(conn) + create_sql = sql.get_schema(test_frame1, "test", con=conn, schema="pypi") + assert "CREATE TABLE pypi." in create_sql + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_get_schema_dtypes(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'get_schema' not implemented for ADBC drivers", + strict=True, + ) + ) + conn_name = conn + conn = request.getfixturevalue(conn) + float_frame = DataFrame({"a": [1.1, 1.2], "b": [2.1, 2.2]}) + + if conn_name == "sqlite_buildin": + dtype = "INTEGER" + else: + from sqlalchemy import Integer + + dtype = Integer + create_sql = sql.get_schema(float_frame, "test", con=conn, dtype={"b": dtype}) + assert "CREATE" in create_sql + assert "INTEGER" in create_sql + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_get_schema_keys(conn, request, test_frame1): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="'get_schema' not implemented for ADBC drivers", + strict=True, + ) + ) + conn_name = conn + conn = request.getfixturevalue(conn) + frame = DataFrame({"Col1": [1.1, 1.2], "Col2": [2.1, 2.2]}) + create_sql = sql.get_schema(frame, "test", con=conn, keys="Col1") + + if "mysql" in conn_name: + constraint_sentence = "CONSTRAINT test_pk PRIMARY KEY (`Col1`)" + else: + constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("Col1")' + assert constraint_sentence in create_sql + + # multiple columns as key (GH10385) + create_sql = sql.get_schema(test_frame1, "test", con=conn, keys=["A", "B"]) + if "mysql" in conn_name: + constraint_sentence = "CONSTRAINT test_pk PRIMARY KEY (`A`, `B`)" + else: + constraint_sentence = 'CONSTRAINT test_pk PRIMARY KEY ("A", "B")' + assert constraint_sentence in create_sql + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_chunksize_read(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("test_chunksize", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_chunksize") + + df = DataFrame( + np.random.default_rng(2).standard_normal((22, 5)), columns=list("abcde") + ) + df.to_sql(name="test_chunksize", con=conn, index=False) + + # reading the query in one time + res1 = sql.read_sql_query("select * from test_chunksize", conn) + + # reading the query in chunks with read_sql_query + res2 = DataFrame() + i = 0 + sizes = [5, 5, 5, 5, 2] + + for chunk in sql.read_sql_query("select * from test_chunksize", conn, chunksize=5): + res2 = concat([res2, chunk], ignore_index=True) + assert len(chunk) == sizes[i] + i += 1 + + tm.assert_frame_equal(res1, res2) + + # reading the query in chunks with read_sql_query + if conn_name == "sqlite_buildin": + with pytest.raises(NotImplementedError, match="^$"): + sql.read_sql_table("test_chunksize", conn, chunksize=5) + else: + res3 = DataFrame() + i = 0 + sizes = [5, 5, 5, 5, 2] + + for chunk in sql.read_sql_table("test_chunksize", conn, chunksize=5): + res3 = concat([res3, chunk], ignore_index=True) + assert len(chunk) == sizes[i] + i += 1 + + tm.assert_frame_equal(res1, res3) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_categorical(conn, request): + if conn == "postgresql_adbc_conn": + adbc = import_optional_dependency("adbc_driver_postgresql", errors="ignore") + if adbc is not None and Version(adbc.__version__) < Version("0.9.0"): + request.node.add_marker( + pytest.mark.xfail( + reason="categorical dtype not implemented for ADBC postgres driver", + strict=True, + ) + ) + # GH8624 + # test that categorical gets written correctly as dense column + conn = request.getfixturevalue(conn) + if sql.has_table("test_categorical", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_categorical") + + df = DataFrame( + { + "person_id": [1, 2, 3], + "person_name": ["John P. Doe", "Jane Dove", "John P. Doe"], + } + ) + df2 = df.copy() + df2["person_name"] = df2["person_name"].astype("category") + + df2.to_sql(name="test_categorical", con=conn, index=False) + res = sql.read_sql_query("SELECT * FROM test_categorical", conn) + + tm.assert_frame_equal(res, df) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_unicode_column_name(conn, request): + # GH 11431 + conn = request.getfixturevalue(conn) + if sql.has_table("test_unicode", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_unicode") + + df = DataFrame([[1, 2], [3, 4]], columns=["\xe9", "b"]) + df.to_sql(name="test_unicode", con=conn, index=False) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_escaped_table_name(conn, request): + # GH 13206 + conn_name = conn + conn = request.getfixturevalue(conn) + if sql.has_table("d1187b08-4943-4c8d-a7f6", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("d1187b08-4943-4c8d-a7f6") + + df = DataFrame({"A": [0, 1, 2], "B": [0.2, np.nan, 5.6]}) + df.to_sql(name="d1187b08-4943-4c8d-a7f6", con=conn, index=False) + + if "postgres" in conn_name: + query = 'SELECT * FROM "d1187b08-4943-4c8d-a7f6"' + else: + query = "SELECT * FROM `d1187b08-4943-4c8d-a7f6`" + res = sql.read_sql_query(query, conn) + + tm.assert_frame_equal(res, df) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_api_read_sql_duplicate_columns(conn, request): + # GH#53117 + if "adbc" in conn: + pa = pytest.importorskip("pyarrow") + if not ( + Version(pa.__version__) >= Version("16.0") + and conn in ["sqlite_adbc_conn", "postgresql_adbc_conn"] + ): + request.node.add_marker( + pytest.mark.xfail( + reason="pyarrow->pandas throws ValueError", strict=True + ) + ) + conn = request.getfixturevalue(conn) + if sql.has_table("test_table", conn): + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("test_table") + + df = DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3], "c": 1}) + df.to_sql(name="test_table", con=conn, index=False) + + result = pd.read_sql("SELECT a, b, a +1 as a, c FROM test_table", conn) + expected = DataFrame( + [[1, 0.1, 2, 1], [2, 0.2, 3, 1], [3, 0.3, 4, 1]], + columns=["a", "b", "a", "c"], + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_read_table_columns(conn, request, test_frame1): + # test columns argument in read_table + conn_name = conn + if conn_name == "sqlite_buildin": + request.applymarker(pytest.mark.xfail(reason="Not Implemented")) + + conn = request.getfixturevalue(conn) + sql.to_sql(test_frame1, "test_frame", conn) + + cols = ["A", "B"] + + result = sql.read_sql_table("test_frame", conn, columns=cols) + assert result.columns.tolist() == cols + + +@pytest.mark.parametrize("conn", all_connectable) +def test_read_table_index_col(conn, request, test_frame1): + # test columns argument in read_table + conn_name = conn + if conn_name == "sqlite_buildin": + request.applymarker(pytest.mark.xfail(reason="Not Implemented")) + + conn = request.getfixturevalue(conn) + sql.to_sql(test_frame1, "test_frame", conn) + + result = sql.read_sql_table("test_frame", conn, index_col="index") + assert result.index.names == ["index"] + + result = sql.read_sql_table("test_frame", conn, index_col=["A", "B"]) + assert result.index.names == ["A", "B"] + + result = sql.read_sql_table( + "test_frame", conn, index_col=["A", "B"], columns=["C", "D"] + ) + assert result.index.names == ["A", "B"] + assert result.columns.tolist() == ["C", "D"] + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_sql_delegate(conn, request): + if conn == "sqlite_buildin_iris": + request.applymarker( + pytest.mark.xfail( + reason="sqlite_buildin connection does not implement read_sql_table" + ) + ) + + conn = request.getfixturevalue(conn) + iris_frame1 = sql.read_sql_query("SELECT * FROM iris", conn) + iris_frame2 = sql.read_sql("SELECT * FROM iris", conn) + tm.assert_frame_equal(iris_frame1, iris_frame2) + + iris_frame1 = sql.read_sql_table("iris", conn) + iris_frame2 = sql.read_sql("iris", conn) + tm.assert_frame_equal(iris_frame1, iris_frame2) + + +def test_not_reflect_all_tables(sqlite_conn): + conn = sqlite_conn + from sqlalchemy import text + from sqlalchemy.engine import Engine + + # create invalid table + query_list = [ + text("CREATE TABLE invalid (x INTEGER, y UNKNOWN);"), + text("CREATE TABLE other_table (x INTEGER, y INTEGER);"), + ] + + for query in query_list: + if isinstance(conn, Engine): + with conn.connect() as conn: + with conn.begin(): + conn.execute(query) + else: + with conn.begin(): + conn.execute(query) + + with tm.assert_produces_warning(None): + sql.read_sql_table("other_table", conn) + sql.read_sql_query("SELECT * FROM other_table", conn) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_warning_case_insensitive_table_name(conn, request, test_frame1): + conn_name = conn + if conn_name == "sqlite_buildin" or "adbc" in conn_name: + request.applymarker(pytest.mark.xfail(reason="Does not raise warning")) + + conn = request.getfixturevalue(conn) + # see gh-7815 + with tm.assert_produces_warning( + UserWarning, + match=( + r"The provided table name 'TABLE1' is not found exactly as such in " + r"the database after writing the table, possibly due to case " + r"sensitivity issues. Consider using lower case table names." + ), + ): + with sql.SQLDatabase(conn) as db: + db.check_case_sensitive("TABLE1", "") + + # Test that the warning is certainly NOT triggered in a normal case. + with tm.assert_produces_warning(None): + test_frame1.to_sql(name="CaseSensitive", con=conn) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_sqlalchemy_type_mapping(conn, request): + conn = request.getfixturevalue(conn) + from sqlalchemy import TIMESTAMP + + # Test Timestamp objects (no datetime64 because of timezone) (GH9085) + df = DataFrame( + {"time": to_datetime(["2014-12-12 01:54", "2014-12-11 02:54"], utc=True)} + ) + with sql.SQLDatabase(conn) as db: + table = sql.SQLTable("test_type", db, frame=df) + # GH 9086: TIMESTAMP is the suggested type for datetimes with timezones + assert isinstance(table.table.c["time"].type, TIMESTAMP) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +@pytest.mark.parametrize( + "integer, expected", + [ + ("int8", "SMALLINT"), + ("Int8", "SMALLINT"), + ("uint8", "SMALLINT"), + ("UInt8", "SMALLINT"), + ("int16", "SMALLINT"), + ("Int16", "SMALLINT"), + ("uint16", "INTEGER"), + ("UInt16", "INTEGER"), + ("int32", "INTEGER"), + ("Int32", "INTEGER"), + ("uint32", "BIGINT"), + ("UInt32", "BIGINT"), + ("int64", "BIGINT"), + ("Int64", "BIGINT"), + (int, "BIGINT" if np.dtype(int).name == "int64" else "INTEGER"), + ], +) +def test_sqlalchemy_integer_mapping(conn, request, integer, expected): + # GH35076 Map pandas integer to optimal SQLAlchemy integer type + conn = request.getfixturevalue(conn) + df = DataFrame([0, 1], columns=["a"], dtype=integer) + with sql.SQLDatabase(conn) as db: + table = sql.SQLTable("test_type", db, frame=df) + + result = str(table.table.c.a.type) + assert result == expected + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +@pytest.mark.parametrize("integer", ["uint64", "UInt64"]) +def test_sqlalchemy_integer_overload_mapping(conn, request, integer): + conn = request.getfixturevalue(conn) + # GH35076 Map pandas integer to optimal SQLAlchemy integer type + df = DataFrame([0, 1], columns=["a"], dtype=integer) + with sql.SQLDatabase(conn) as db: + with pytest.raises( + ValueError, match="Unsigned 64 bit integer datatype is not supported" + ): + sql.SQLTable("test_type", db, frame=df) + + +def test_database_uri_string(temp_file, request, test_frame1): + pytest.importorskip("sqlalchemy") + # Test read_sql and .to_sql method with a database URI (GH10654) + # db_uri = 'sqlite:///:memory:' # raises + # sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) near + # "iris": syntax error [SQL: 'iris'] + name = str(temp_file) + db_uri = "sqlite:///" + name + table = "iris" + test_frame1.to_sql(name=table, con=db_uri, if_exists="replace", index=False) + test_frame2 = sql.read_sql(table, db_uri) + test_frame3 = sql.read_sql_table(table, db_uri) + query = "SELECT * FROM iris" + test_frame4 = sql.read_sql_query(query, db_uri) + tm.assert_frame_equal(test_frame1, test_frame2) + tm.assert_frame_equal(test_frame1, test_frame3) + tm.assert_frame_equal(test_frame1, test_frame4) + + +@td.skip_if_installed("pg8000") +def test_pg8000_sqlalchemy_passthrough_error(request): + pytest.importorskip("sqlalchemy") + # using driver that will not be installed on CI to trigger error + # in sqlalchemy.create_engine -> test passing of this error to user + db_uri = "postgresql+pg8000://user:pass@host/dbname" + with pytest.raises(ImportError, match="pg8000"): + sql.read_sql("select * from table", db_uri) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_query_by_text_obj(conn, request): + # WIP : GH10846 + conn_name = conn + conn = request.getfixturevalue(conn) + from sqlalchemy import text + + if "postgres" in conn_name: + name_text = text('select * from iris where "Name"=:name') + else: + name_text = text("select * from iris where name=:name") + iris_df = sql.read_sql(name_text, conn, params={"name": "Iris-versicolor"}) + all_names = set(iris_df["Name"]) + assert all_names == {"Iris-versicolor"} + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_query_by_select_obj(conn, request): + conn = request.getfixturevalue(conn) + # WIP : GH10846 + from sqlalchemy import ( + bindparam, + select, + ) + + iris = iris_table_metadata() + name_select = select(iris).where(iris.c.Name == bindparam("name")) + iris_df = sql.read_sql(name_select, conn, params={"name": "Iris-setosa"}) + all_names = set(iris_df["Name"]) + assert all_names == {"Iris-setosa"} + + +@pytest.mark.parametrize("conn", all_connectable) +def test_column_with_percentage(conn, request): + # GH 37157 + conn_name = conn + if conn_name == "sqlite_buildin": + request.applymarker(pytest.mark.xfail(reason="Not Implemented")) + + conn = request.getfixturevalue(conn) + df = DataFrame({"A": [0, 1, 2], "%_variation": [3, 4, 5]}) + df.to_sql(name="test_column_percentage", con=conn, index=False) + + res = sql.read_sql_table("test_column_percentage", conn) + + tm.assert_frame_equal(res, df) + + +def test_sql_open_close(temp_file, test_frame3): + # Test if the IO in the database still work if the connection closed + # between the writing and reading (as in many real situations). + + with contextlib.closing(sqlite3.connect(temp_file)) as conn: + assert sql.to_sql(test_frame3, "test_frame3_legacy", conn, index=False) == 4 + + with contextlib.closing(sqlite3.connect(temp_file)) as conn: + result = sql.read_sql_query("SELECT * FROM test_frame3_legacy;", conn) + + tm.assert_frame_equal(test_frame3, result) + + +@td.skip_if_installed("sqlalchemy") +def test_con_string_import_error(): + conn = "mysql://root@localhost/pandas" + msg = "Using URI string without sqlalchemy installed" + with pytest.raises(ImportError, match=msg): + sql.read_sql("SELECT * FROM iris", conn) + + +@td.skip_if_installed("sqlalchemy") +def test_con_unknown_dbapi2_class_does_not_error_without_sql_alchemy_installed(): + class MockSqliteConnection: + def __init__(self, *args, **kwargs) -> None: + self.conn = sqlite3.Connection(*args, **kwargs) + + def __getattr__(self, name): + return getattr(self.conn, name) + + def close(self): + self.conn.close() + + with contextlib.closing(MockSqliteConnection(":memory:")) as conn: + with tm.assert_produces_warning(UserWarning, match="only supports SQLAlchemy"): + sql.read_sql("SELECT 1", conn) + + +def test_sqlite_read_sql_delegate(sqlite_buildin_iris): + conn = sqlite_buildin_iris + iris_frame1 = sql.read_sql_query("SELECT * FROM iris", conn) + iris_frame2 = sql.read_sql("SELECT * FROM iris", conn) + tm.assert_frame_equal(iris_frame1, iris_frame2) + + msg = "Execution failed on sql 'iris': near \"iris\": syntax error" + with pytest.raises(sql.DatabaseError, match=msg): + sql.read_sql("iris", conn) + + +def test_get_schema2(test_frame1): + # without providing a connection object (available for backwards comp) + create_sql = sql.get_schema(test_frame1, "test") + assert "CREATE" in create_sql + + +def test_sqlite_type_mapping(sqlite_buildin): + # Test Timestamp objects (no datetime64 because of timezone) (GH9085) + conn = sqlite_buildin + df = DataFrame( + {"time": to_datetime(["2014-12-12 01:54", "2014-12-11 02:54"], utc=True)} + ) + db = sql.SQLiteDatabase(conn) + table = sql.SQLiteTable("test_type", db, frame=df) + schema = table.sql_schema() + for col in schema.split("\n"): + if col.split()[0].strip('"') == "time": + assert col.split()[1] == "TIMESTAMP" + + +# ----------------------------------------------------------------------------- +# -- Database flavor specific tests + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_create_table(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import inspect + + temp_frame = DataFrame({"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}) + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 + + insp = inspect(conn) + assert insp.has_table("temp_frame") + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("temp_frame") + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_drop_table(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import inspect + + temp_frame = DataFrame({"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}) + with sql.SQLDatabase(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 + + insp = inspect(conn) + assert insp.has_table("temp_frame") + + with pandasSQL.run_transaction(): + pandasSQL.drop_table("temp_frame") + try: + insp.clear_cache() # needed with SQLAlchemy 2.0, unavailable prior + except AttributeError: + pass + assert not insp.has_table("temp_frame") + + +@pytest.mark.parametrize("conn_name", all_connectable) +def test_delete_rows_success(conn_name, test_frame1, request): + table_name = "temp_delete_rows_frame" + conn = request.getfixturevalue(conn_name) + + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, table_name) == test_frame1.shape[0] + + with pandasSQL.run_transaction(): + assert pandasSQL.delete_rows(table_name) is None + + assert count_rows(conn, table_name) == 0 + assert pandasSQL.has_table(table_name) + + +@pytest.mark.parametrize("conn_name", all_connectable) +def test_delete_rows_is_atomic(conn_name, request): + sqlalchemy = pytest.importorskip("sqlalchemy") + + table_name = "temp_delete_rows_atomic_frame" + table_stmt = f"CREATE TABLE {table_name} (a INTEGER, b INTEGER UNIQUE NOT NULL)" + + if conn_name != "sqlite_buildin" and "adbc" not in conn_name: + table_stmt = sqlalchemy.text(table_stmt) + + # setting dtype is mandatory for adbc related tests + original_df = DataFrame({"a": [1, 2], "b": [3, 4]}, dtype="int32") + replacing_df = DataFrame({"a": [5, 6, 7], "b": [8, 8, 8]}, dtype="int32") + + conn = request.getfixturevalue(conn_name) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction() as cur: + cur.execute(table_stmt) + + with pandasSQL.run_transaction(): + pandasSQL.to_sql(original_df, table_name, if_exists="append", index=False) + + # inserting duplicated values in a UNIQUE constraint column + with pytest.raises(pd.errors.DatabaseError): + with pandasSQL.run_transaction(): + pandasSQL.to_sql( + replacing_df, table_name, if_exists="delete_rows", index=False + ) + + # failed "delete_rows" is rolled back preserving original data + with pandasSQL.run_transaction(): + result_df = pandasSQL.read_query( + f"SELECT * FROM {table_name}", dtype="int32" + ) + tm.assert_frame_equal(result_df, original_df) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_roundtrip(conn, request, test_frame1): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn_name = conn + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame_roundtrip") == 4 + result = pandasSQL.read_query("SELECT * FROM test_frame_roundtrip") + + if "adbc" in conn_name: + result = result.rename(columns={"__index_level_0__": "level_0"}) + result.set_index("level_0", inplace=True) + # result.index.astype(int) + + result.index.name = None + + tm.assert_frame_equal(result, test_frame1) + + +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_execute_sql(conn, request): + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + iris_results = pandasSQL.execute("SELECT * FROM iris") + row = iris_results.fetchone() + iris_results.close() + assert list(row) == [5.1, 3.5, 1.4, 0.2, "Iris-setosa"] + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_sqlalchemy_read_table(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = sql.read_sql_table("iris", con=conn) + check_iris_frame(iris_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_sqlalchemy_read_table_columns(conn, request): + conn = request.getfixturevalue(conn) + iris_frame = sql.read_sql_table( + "iris", con=conn, columns=["SepalLength", "SepalLength"] + ) + tm.assert_index_equal(iris_frame.columns, Index(["SepalLength", "SepalLength__1"])) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_table_absent_raises(conn, request): + conn = request.getfixturevalue(conn) + msg = "Table this_doesnt_exist not found" + with pytest.raises(ValueError, match=msg): + sql.read_sql_table("this_doesnt_exist", con=conn) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_types) +def test_sqlalchemy_default_type_conversion(conn, request): + conn_name = conn + if conn_name == "sqlite_str": + pytest.skip("types tables not created in sqlite_str fixture") + elif "mysql" in conn_name or "sqlite" in conn_name: + request.applymarker( + pytest.mark.xfail(reason="boolean dtype not inferred properly") + ) + + conn = request.getfixturevalue(conn) + df = sql.read_sql_table("types", conn) + + assert issubclass(df.FloatCol.dtype.type, np.floating) + assert issubclass(df.IntCol.dtype.type, np.integer) + assert issubclass(df.BoolCol.dtype.type, np.bool_) + + # Int column with NA values stays as float + assert issubclass(df.IntColWithNull.dtype.type, np.floating) + # Bool column with NA values becomes object + assert issubclass(df.BoolColWithNull.dtype.type, object) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_bigint(conn, request): + # int64 should be converted to BigInteger, GH7433 + conn = request.getfixturevalue(conn) + df = DataFrame(data={"i64": [2**62]}) + assert df.to_sql(name="test_bigint", con=conn, index=False) == 1 + result = sql.read_sql_table("test_bigint", conn) + + tm.assert_frame_equal(df, result) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_types) +def test_default_date_load(conn, request): + conn_name = conn + if conn_name == "sqlite_str": + pytest.skip("types tables not created in sqlite_str fixture") + elif "sqlite" in conn_name: + request.applymarker( + pytest.mark.xfail(reason="sqlite does not read date properly") + ) + + conn = request.getfixturevalue(conn) + df = sql.read_sql_table("types", conn) + + assert issubclass(df.DateCol.dtype.type, np.datetime64) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +@pytest.mark.parametrize("parse_dates", [None, ["DateColWithTz"]]) +def test_datetime_with_timezone_query(conn, request, parse_dates): + # edge case that converts postgresql datetime with time zone types + # to datetime64[ns,psycopg2.tz.FixedOffsetTimezone..], which is ok + # but should be more natural, so coerce to datetime64[ns] for now + conn = request.getfixturevalue(conn) + expected = create_and_load_postgres_datetz(conn) + + # GH11216 + df = read_sql_query("select * from datetz", conn, parse_dates=parse_dates) + col = df.DateColWithTz + tm.assert_series_equal(col, expected) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_datetime_with_timezone_query_chunksize(conn, request): + conn = request.getfixturevalue(conn) + expected = create_and_load_postgres_datetz(conn) + + df = concat( + list(read_sql_query("select * from datetz", conn, chunksize=1)), + ignore_index=True, + ) + col = df.DateColWithTz + tm.assert_series_equal(col, expected) + + +@pytest.mark.parametrize("conn", postgresql_connectable) +def test_datetime_with_timezone_table(conn, request): + conn = request.getfixturevalue(conn) + expected = create_and_load_postgres_datetz(conn) + result = sql.read_sql_table("datetz", conn) + + exp_frame = expected.to_frame() + tm.assert_frame_equal(result, exp_frame) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime_with_timezone_roundtrip(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + # GH 9086 + # Write datetimetz data to a db and read it back + # For dbs that support timestamps with timezones, should get back UTC + # otherwise naive data should be returned + expected = DataFrame( + {"A": date_range("2013-01-01 09:00:00", periods=3, tz="US/Pacific", unit="us")} + ) + assert expected.to_sql(name="test_datetime_tz", con=conn, index=False) == 3 + + if "postgresql" in conn_name: + # SQLAlchemy "timezones" (i.e. offsets) are coerced to UTC + expected["A"] = expected["A"].dt.tz_convert("UTC") + else: + # Otherwise, timestamps are returned as local, naive + expected["A"] = expected["A"].dt.tz_localize(None) + + result = sql.read_sql_table("test_datetime_tz", conn) + tm.assert_frame_equal(result, expected) + + result = sql.read_sql_query("SELECT * FROM test_datetime_tz", conn) + if "sqlite" in conn_name: + # read_sql_query does not return datetime type like read_sql_table + assert isinstance(result.loc[0, "A"], str) + result["A"] = to_datetime(result["A"]).dt.as_unit("us") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_out_of_bounds_datetime(conn, request): + # GH 26761 + conn = request.getfixturevalue(conn) + data = DataFrame({"date": datetime(9999, 1, 1)}, index=[0]) + assert data.to_sql(name="test_datetime_obb", con=conn, index=False) == 1 + result = sql.read_sql_table("test_datetime_obb", conn) + expected = DataFrame( + np.array([datetime(9999, 1, 1)], dtype="M8[us]"), columns=["date"] + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_naive_datetimeindex_roundtrip(conn, request): + # GH 23510 + # Ensure that a naive DatetimeIndex isn't converted to UTC + conn = request.getfixturevalue(conn) + dates = date_range("2018-01-01", periods=5, freq="6h", unit="us")._with_freq(None) + expected = DataFrame({"nums": range(5)}, index=dates) + assert expected.to_sql(name="foo_table", con=conn, index_label="info_date") == 5 + result = sql.read_sql_table("foo_table", conn, index_col="info_date") + # result index with gain a name from a set_index operation; expected + tm.assert_frame_equal(result, expected, check_names=False) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable_types) +def test_date_parsing(conn, request): + # No Parsing + conn_name = conn + conn = request.getfixturevalue(conn) + df = sql.read_sql_table("types", conn) + expected_type = object if "sqlite" in conn_name else np.datetime64 + assert issubclass(df.DateCol.dtype.type, expected_type) + + df = sql.read_sql_table("types", conn, parse_dates=["DateCol"]) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table("types", conn, parse_dates={"DateCol": "%Y-%m-%d %H:%M:%S"}) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table( + "types", + conn, + parse_dates={"DateCol": {"format": "%Y-%m-%d %H:%M:%S"}}, + ) + assert issubclass(df.DateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table("types", conn, parse_dates=["IntDateCol"]) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table("types", conn, parse_dates={"IntDateCol": "s"}) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + + df = sql.read_sql_table("types", conn, parse_dates={"IntDateCol": {"unit": "s"}}) + assert issubclass(df.IntDateCol.dtype.type, np.datetime64) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame( + {"A": date_range("2013-01-01 09:00:00", periods=3), "B": np.arange(3.0)} + ) + assert df.to_sql(name="test_datetime", con=conn) == 3 + + # with read_table -> type information from schema used + result = sql.read_sql_table("test_datetime", conn) + result = result.drop("index", axis=1) + + expected = df[:] + expected["A"] = expected["A"].astype("M8[us]") + tm.assert_frame_equal(result, expected) + + # with read_sql -> no type information -> sqlite has no native + result = sql.read_sql_query("SELECT * FROM test_datetime", conn) + result = result.drop("index", axis=1) + if "sqlite" in conn_name: + assert isinstance(result.loc[0, "A"], str) + result["A"] = to_datetime(result["A"]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime_NaT(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame( + {"A": date_range("2013-01-01 09:00:00", periods=3), "B": np.arange(3.0)} + ) + df.loc[1, "A"] = np.nan + assert df.to_sql(name="test_datetime", con=conn, index=False) == 3 + + # with read_table -> type information from schema used + result = sql.read_sql_table("test_datetime", conn) + expected = df[:] + expected["A"] = expected["A"].astype("M8[us]") + tm.assert_frame_equal(result, expected) + + # with read_sql -> no type information -> sqlite has no native + result = sql.read_sql_query("SELECT * FROM test_datetime", conn) + if "sqlite" in conn_name: + assert isinstance(result.loc[0, "A"], str) + result["A"] = to_datetime(result["A"], errors="coerce") + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime_date(conn, request): + # test support for datetime.date + conn = request.getfixturevalue(conn) + df = DataFrame([date(2014, 1, 1), date(2014, 1, 2)], columns=["a"]) + assert df.to_sql(name="test_date", con=conn, index=False) == 2 + res = read_sql_table("test_date", conn) + result = res["a"] + expected = to_datetime(df["a"]) + # comes back as datetime64 + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_datetime_time(conn, request, sqlite_buildin): + # test support for datetime.time + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame([time(9, 0, 0), time(9, 1, 30)], columns=["a"]) + assert df.to_sql(name="test_time", con=conn, index=False) == 2 + res = read_sql_table("test_time", conn) + tm.assert_frame_equal(res, df) + + # GH8341 + # first, use the fallback to have the sqlite adapter put in place + sqlite_conn = sqlite_buildin + assert sql.to_sql(df, "test_time2", sqlite_conn, index=False) == 2 + res = sql.read_sql_query("SELECT * FROM test_time2", sqlite_conn) + ref = df.map(lambda _: _.strftime("%H:%M:%S.%f")) + tm.assert_frame_equal(ref, res) # check if adapter is in place + # then test if sqlalchemy is unaffected by the sqlite adapter + assert sql.to_sql(df, "test_time3", conn, index=False) == 2 + if "sqlite" in conn_name: + res = sql.read_sql_query("SELECT * FROM test_time3", conn) + ref = df.map(lambda _: _.strftime("%H:%M:%S.%f")) + tm.assert_frame_equal(ref, res) + res = sql.read_sql_table("test_time3", conn) + tm.assert_frame_equal(df, res) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_mixed_dtype_insert(conn, request): + # see GH6509 + conn = request.getfixturevalue(conn) + s1 = Series(2**25 + 1, dtype=np.int32) + s2 = Series(0.0, dtype=np.float32) + df = DataFrame({"s1": s1, "s2": s2}) + + # write and read again + assert df.to_sql(name="test_read_write", con=conn, index=False) == 1 + df2 = sql.read_sql_table("test_read_write", conn) + + tm.assert_frame_equal(df, df2, check_dtype=False, check_exact=True) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_nan_numeric(conn, request): + # NaNs in numeric float column + conn = request.getfixturevalue(conn) + df = DataFrame({"A": [0, 1, 2], "B": [0.2, np.nan, 5.6]}) + assert df.to_sql(name="test_nan", con=conn, index=False) == 3 + + # with read_table + result = sql.read_sql_table("test_nan", conn) + tm.assert_frame_equal(result, df) + + # with read_sql + result = sql.read_sql_query("SELECT * FROM test_nan", conn) + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_nan_fullcolumn(conn, request): + # full NaN column (numeric float column) + conn = request.getfixturevalue(conn) + df = DataFrame({"A": [0, 1, 2], "B": [np.nan, np.nan, np.nan]}) + assert df.to_sql(name="test_nan", con=conn, index=False) == 3 + + # with read_table + result = sql.read_sql_table("test_nan", conn) + tm.assert_frame_equal(result, df) + + # with read_sql -> not type info from table -> stays None + df["B"] = df["B"].astype("object") + df["B"] = None + result = sql.read_sql_query("SELECT * FROM test_nan", conn) + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_nan_string(conn, request): + # NaNs in string column + conn = request.getfixturevalue(conn) + df = DataFrame({"A": [0, 1, 2], "B": ["a", "b", np.nan]}) + assert df.to_sql(name="test_nan", con=conn, index=False) == 3 + + # NaNs are coming back as None + df.loc[2, "B"] = None + + # with read_table + result = sql.read_sql_table("test_nan", conn) + tm.assert_frame_equal(result, df) + + # with read_sql + result = sql.read_sql_query("SELECT * FROM test_nan", conn) + tm.assert_frame_equal(result, df) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_to_sql_save_index(conn, request): + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail( + reason="ADBC implementation does not create index", strict=True + ) + ) + conn_name = conn + conn = request.getfixturevalue(conn) + df = DataFrame.from_records( + [(1, 2.1, "line1"), (2, 1.5, "line2")], columns=["A", "B", "C"], index=["A"] + ) + + tbl_name = "test_to_sql_saves_index" + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(df, tbl_name) == 2 + + if conn_name in {"sqlite_buildin", "sqlite_str"}: + ixs = sql.read_sql_query( + "SELECT * FROM sqlite_master WHERE type = 'index' " + f"AND tbl_name = '{tbl_name}'", + conn, + ) + ix_cols = [] + for ix_name in ixs.name: + ix_info = sql.read_sql_query(f"PRAGMA index_info({ix_name})", conn) + ix_cols.append(ix_info.name.tolist()) + else: + from sqlalchemy import inspect + + insp = inspect(conn) + + ixs = insp.get_indexes(tbl_name) + ix_cols = [i["column_names"] for i in ixs] + + assert ix_cols == [["A"]] + + +@pytest.mark.parametrize("conn", all_connectable) +def test_transactions(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + + stmt = "CREATE TABLE test_trans (A INT, B TEXT)" + if conn_name != "sqlite_buildin" and "adbc" not in conn_name: + from sqlalchemy import text + + stmt = text(stmt) + + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction() as trans: + trans.execute(stmt) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_transaction_rollback(conn, request): + conn_name = conn + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction() as trans: + stmt = "CREATE TABLE test_trans (A INT, B TEXT)" + if "adbc" in conn_name or isinstance(pandasSQL, SQLiteDatabase): + trans.execute(stmt) + else: + from sqlalchemy import text + + stmt = text(stmt) + trans.execute(stmt) + + class DummyException(Exception): + pass + + # Make sure when transaction is rolled back, no rows get inserted + ins_sql = "INSERT INTO test_trans (A,B) VALUES (1, 'blah')" + if isinstance(pandasSQL, SQLDatabase): + from sqlalchemy import text + + ins_sql = text(ins_sql) + try: + with pandasSQL.run_transaction() as trans: + trans.execute(ins_sql) + raise DummyException("error") + except DummyException: + # ignore raised exception + pass + with pandasSQL.run_transaction(): + res = pandasSQL.read_query("SELECT * FROM test_trans") + assert len(res) == 0 + + # Make sure when transaction is committed, rows do get inserted + with pandasSQL.run_transaction() as trans: + trans.execute(ins_sql) + res2 = pandasSQL.read_query("SELECT * FROM test_trans") + assert len(res2) == 1 + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_get_schema_create_table(conn, request, test_frame3): + # Use a dataframe without a bool column, since MySQL converts bool to + # TINYINT (which read_sql_table returns as an int and causes a dtype + # mismatch) + if conn == "sqlite_str": + request.applymarker( + pytest.mark.xfail(reason="test does not support sqlite_str fixture") + ) + + conn = request.getfixturevalue(conn) + + from sqlalchemy import text + from sqlalchemy.engine import Engine + + tbl = "test_get_schema_create_table" + create_sql = sql.get_schema(test_frame3, tbl, con=conn) + blank_test_df = test_frame3.iloc[:0] + + create_sql = text(create_sql) + if isinstance(conn, Engine): + with conn.connect() as newcon: + with newcon.begin(): + newcon.execute(create_sql) + else: + conn.execute(create_sql) + returned_df = sql.read_sql_table(tbl, conn) + tm.assert_frame_equal(returned_df, blank_test_df, check_index_type=False) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_dtype(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import ( + TEXT, + String, + ) + from sqlalchemy.schema import MetaData + + cols = ["A", "B"] + data = [(0.8, True), (0.9, None)] + df = DataFrame(data, columns=cols) + assert df.to_sql(name="dtype_test", con=conn) == 2 + assert df.to_sql(name="dtype_test2", con=conn, dtype={"B": TEXT}) == 2 + meta = MetaData() + meta.reflect(bind=conn) + sqltype = meta.tables["dtype_test2"].columns["B"].type + assert isinstance(sqltype, TEXT) + msg = "The type of B is not a SQLAlchemy type" + with pytest.raises(ValueError, match=msg): + df.to_sql(name="error", con=conn, dtype={"B": str}) + + # GH9083 + assert df.to_sql(name="dtype_test3", con=conn, dtype={"B": String(10)}) == 2 + meta.reflect(bind=conn) + sqltype = meta.tables["dtype_test3"].columns["B"].type + assert isinstance(sqltype, String) + assert sqltype.length == 10 + + # single dtype + assert df.to_sql(name="single_dtype_test", con=conn, dtype=TEXT) == 2 + meta.reflect(bind=conn) + sqltypea = meta.tables["single_dtype_test"].columns["A"].type + sqltypeb = meta.tables["single_dtype_test"].columns["B"].type + assert isinstance(sqltypea, TEXT) + assert isinstance(sqltypeb, TEXT) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_notna_dtype(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn_name = conn + conn = request.getfixturevalue(conn) + + from sqlalchemy import ( + Boolean, + DateTime, + Float, + Integer, + ) + from sqlalchemy.schema import MetaData + + cols = { + "Bool": Series([True, None]), + "Date": Series([datetime(2012, 5, 1), None]), + "Int": Series([1, None], dtype="object"), + "Float": Series([1.1, None]), + } + df = DataFrame(cols) + + tbl = "notna_dtype_test" + assert df.to_sql(name=tbl, con=conn) == 2 + _ = sql.read_sql_table(tbl, conn) + meta = MetaData() + meta.reflect(bind=conn) + my_type = Integer if "mysql" in conn_name else Boolean + col_dict = meta.tables[tbl].columns + assert isinstance(col_dict["Bool"].type, my_type) + assert isinstance(col_dict["Date"].type, DateTime) + assert isinstance(col_dict["Int"].type, Integer) + assert isinstance(col_dict["Float"].type, Float) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_double_precision(conn, request): + if conn == "sqlite_str": + pytest.skip("sqlite_str has no inspection system") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import ( + BigInteger, + Float, + Integer, + ) + from sqlalchemy.schema import MetaData + + V = 1.23456789101112131415 + + df = DataFrame( + { + "f32": Series([V], dtype="float32"), + "f64": Series([V], dtype="float64"), + "f64_as_f32": Series([V], dtype="float64"), + "i32": Series([5], dtype="int32"), + "i64": Series([5], dtype="int64"), + } + ) + + assert ( + df.to_sql( + name="test_dtypes", + con=conn, + index=False, + if_exists="replace", + dtype={"f64_as_f32": Float(precision=23)}, + ) + == 1 + ) + res = sql.read_sql_table("test_dtypes", conn) + + # check precision of float64 + assert np.round(df["f64"].iloc[0], 14) == np.round(res["f64"].iloc[0], 14) + + # check sql types + meta = MetaData() + meta.reflect(bind=conn) + col_dict = meta.tables["test_dtypes"].columns + assert str(col_dict["f32"].type) == str(col_dict["f64_as_f32"].type) + assert isinstance(col_dict["f32"].type, Float) + assert isinstance(col_dict["f64"].type, Float) + assert isinstance(col_dict["i32"].type, Integer) + assert isinstance(col_dict["i64"].type, BigInteger) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_connectable_issue_example(conn, request): + conn = request.getfixturevalue(conn) + + # This tests the example raised in issue + # https://github.com/pandas-dev/pandas/issues/10104 + from sqlalchemy.engine import Engine + + def test_select(connection): + query = "SELECT test_foo_data FROM test_foo_data" + return sql.read_sql_query(query, con=connection) + + def test_append(connection, data): + data.to_sql(name="test_foo_data", con=connection, if_exists="append") + + def test_connectable(conn): + # https://github.com/sqlalchemy/sqlalchemy/commit/ + # 00b5c10846e800304caa86549ab9da373b42fa5d#r48323973 + foo_data = test_select(conn) + test_append(conn, foo_data) + + def main(connectable): + if isinstance(connectable, Engine): + with connectable.connect() as conn: + with conn.begin(): + test_connectable(conn) + else: + test_connectable(connectable) + + assert ( + DataFrame({"test_foo_data": [0, 1, 2]}).to_sql(name="test_foo_data", con=conn) + == 3 + ) + main(conn) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +@pytest.mark.parametrize( + "input", + [{"foo": [np.inf]}, {"foo": [-np.inf]}, {"foo": [-np.inf], "infe0": ["bar"]}], +) +def test_to_sql_with_negative_npinf(conn, request, input): + # GH 34431 + + df = DataFrame(input) + conn_name = conn + conn = request.getfixturevalue(conn) + + if "mysql" in conn_name: + # GH 36465 + # The input {"foo": [-np.inf], "infe0": ["bar"]} does not raise any error + # for pymysql version >= 0.10 + msg = "Execution failed on sql" + with pytest.raises(pd.errors.DatabaseError, match=msg): + df.to_sql(name="foobar", con=conn, index=False) + else: + assert df.to_sql(name="foobar", con=conn, index=False) == 1 + res = sql.read_sql_table("foobar", conn) + tm.assert_equal(df, res) + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_temporary_table(conn, request): + if conn == "sqlite_str": + pytest.skip("test does not work with str connection") + + conn = request.getfixturevalue(conn) + + from sqlalchemy import ( + Column, + Integer, + Unicode, + select, + ) + from sqlalchemy.orm import ( + Session, + declarative_base, + ) + + test_data = "Hello, World!" + expected = DataFrame({"spam": [test_data]}) + Base = declarative_base() + + class Temporary(Base): + __tablename__ = "temp_test" + __table_args__ = {"prefixes": ["TEMPORARY"]} + id = Column(Integer, primary_key=True) + spam = Column(Unicode(30), nullable=False) + + with Session(conn) as session: + with session.begin(): + conn = session.connection() + Temporary.__table__.create(conn) + session.add(Temporary(spam=test_data)) + session.flush() + df = sql.read_sql_query(sql=select(Temporary.spam), con=conn) + tm.assert_frame_equal(df, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +def test_invalid_engine(conn, request, test_frame1): + if conn == "sqlite_buildin" or "adbc" in conn: + request.applymarker( + pytest.mark.xfail( + reason="SQLiteDatabase/ADBCDatabase does not raise for bad engine" + ) + ) + + conn = request.getfixturevalue(conn) + msg = "engine must be one of 'auto', 'sqlalchemy'" + with pandasSQL_builder(conn) as pandasSQL: + with pytest.raises(ValueError, match=msg): + pandasSQL.to_sql(test_frame1, "test_frame1", engine="bad_engine") + + +@pytest.mark.parametrize("conn", all_connectable) +def test_to_sql_with_sql_engine(conn, request, test_frame1): + """`to_sql` with the `engine` param""" + # mostly copied from this class's `_to_sql()` method + conn = request.getfixturevalue(conn) + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1", engine="auto") == 4 + assert pandasSQL.has_table("test_frame1") + + num_entries = len(test_frame1) + num_rows = count_rows(conn, "test_frame1") + assert num_rows == num_entries + + +@pytest.mark.parametrize("conn", sqlalchemy_connectable) +def test_options_sqlalchemy(conn, request, test_frame1): + # use the set option + conn = request.getfixturevalue(conn) + with pd.option_context("io.sql.engine", "sqlalchemy"): + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1") == 4 + assert pandasSQL.has_table("test_frame1") + + num_entries = len(test_frame1) + num_rows = count_rows(conn, "test_frame1") + assert num_rows == num_entries + + +@pytest.mark.parametrize("conn", all_connectable) +def test_options_auto(conn, request, test_frame1): + # use the set option + conn = request.getfixturevalue(conn) + with pd.option_context("io.sql.engine", "auto"): + with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(test_frame1, "test_frame1") == 4 + assert pandasSQL.has_table("test_frame1") + + num_entries = len(test_frame1) + num_rows = count_rows(conn, "test_frame1") + assert num_rows == num_entries + + +def test_options_get_engine(): + pytest.importorskip("sqlalchemy") + assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine) + + with pd.option_context("io.sql.engine", "sqlalchemy"): + assert isinstance(get_engine("auto"), SQLAlchemyEngine) + assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine) + + with pd.option_context("io.sql.engine", "auto"): + assert isinstance(get_engine("auto"), SQLAlchemyEngine) + assert isinstance(get_engine("sqlalchemy"), SQLAlchemyEngine) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("func", ["read_sql", "read_sql_query"]) +def test_read_sql_dtype_backend( + conn, + request, + string_storage, + func, + dtype_backend, + dtype_backend_data, + dtype_backend_expected, +): + # GH#50048 + conn_name = conn + conn = request.getfixturevalue(conn) + table = "test" + df = dtype_backend_data + df.to_sql(name=table, con=conn, index=False, if_exists="replace") + + with pd.option_context("mode.string_storage", string_storage): + result = getattr(pd, func)( + f"Select * from {table}", conn, dtype_backend=dtype_backend + ) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + + tm.assert_frame_equal(result, expected) + + if "adbc" in conn_name: + # adbc does not support chunksize argument + request.applymarker( + pytest.mark.xfail(reason="adbc does not support chunksize argument") + ) + + with pd.option_context("mode.string_storage", string_storage): + iterator = getattr(pd, func)( + f"Select * from {table}", + con=conn, + dtype_backend=dtype_backend, + chunksize=3, + ) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + for result in iterator: + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("func", ["read_sql", "read_sql_table"]) +def test_read_sql_dtype_backend_table( + conn, + request, + string_storage, + func, + dtype_backend, + dtype_backend_data, + dtype_backend_expected, +): + if "sqlite" in conn and "adbc" not in conn: + request.applymarker( + pytest.mark.xfail( + reason=( + "SQLite actually returns proper boolean values via " + "read_sql_table, but before pytest refactor was skipped" + ) + ) + ) + # GH#50048 + conn_name = conn + conn = request.getfixturevalue(conn) + table = "test" + df = dtype_backend_data + df.to_sql(name=table, con=conn, index=False, if_exists="replace") + + with pd.option_context("mode.string_storage", string_storage): + result = getattr(pd, func)(table, conn, dtype_backend=dtype_backend) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + tm.assert_frame_equal(result, expected) + + if "adbc" in conn_name: + # adbc does not support chunksize argument + return + + with pd.option_context("mode.string_storage", string_storage): + iterator = getattr(pd, func)( + table, + conn, + dtype_backend=dtype_backend, + chunksize=3, + ) + expected = dtype_backend_expected(string_storage, dtype_backend, conn_name) + for result in iterator: + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("func", ["read_sql", "read_sql_table", "read_sql_query"]) +def test_read_sql_invalid_dtype_backend_table(conn, request, func, dtype_backend_data): + conn = request.getfixturevalue(conn) + table = "test" + df = dtype_backend_data + df.to_sql(name=table, con=conn, index=False, if_exists="replace") + + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + getattr(pd, func)(table, conn, dtype_backend="numpy") + + +@pytest.fixture +def dtype_backend_data() -> DataFrame: + return DataFrame( + { + "a": Series([1, pd.NA, 3], dtype="Int64"), + "b": Series([1, 2, 3], dtype="Int64"), + "c": Series([1.5, pd.NA, 2.5], dtype="Float64"), + "d": Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": [True, False, None], + "f": [True, False, True], + "g": ["a", "b", "c"], + "h": ["a", "b", None], + } + ) + + +@pytest.fixture +def dtype_backend_expected(): + def func(string_storage, dtype_backend, conn_name) -> DataFrame: + string_dtype: pd.StringDtype | pd.ArrowDtype + if dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + string_dtype = pd.ArrowDtype(pa.string()) + else: + string_dtype = pd.StringDtype(string_storage) + + df = DataFrame( + { + "a": Series([1, pd.NA, 3], dtype="Int64"), + "b": Series([1, 2, 3], dtype="Int64"), + "c": Series([1.5, pd.NA, 2.5], dtype="Float64"), + "d": Series([1.5, 2.0, 2.5], dtype="Float64"), + "e": Series([True, False, pd.NA], dtype="boolean"), + "f": Series([True, False, True], dtype="boolean"), + "g": Series(["a", "b", "c"], dtype=string_dtype), + "h": Series(["a", "b", None], dtype=string_dtype), + } + ) + if dtype_backend == "pyarrow": + pa = pytest.importorskip("pyarrow") + + from pandas.arrays import ArrowExtensionArray + + df = DataFrame( + { + col: ArrowExtensionArray(pa.array(df[col], from_pandas=True)) + for col in df.columns + } + ) + + if "mysql" in conn_name or "sqlite" in conn_name: + if dtype_backend == "numpy_nullable": + df = df.astype({"e": "Int64", "f": "Int64"}) + else: + df = df.astype({"e": "int64[pyarrow]", "f": "int64[pyarrow]"}) + + return df + + return func + + +@pytest.mark.parametrize("conn", all_connectable) +def test_chunksize_empty_dtypes(conn, request): + # GH#50245 + if "adbc" in conn: + request.node.add_marker( + pytest.mark.xfail(reason="chunksize argument NotImplemented with ADBC") + ) + conn = request.getfixturevalue(conn) + dtypes = {"a": "int64", "b": "object"} + df = DataFrame(columns=["a", "b"]).astype(dtypes) + expected = df.copy() + df.to_sql(name="test", con=conn, index=False, if_exists="replace") + + for result in read_sql_query( + "SELECT * FROM test", + conn, + dtype=dtypes, + chunksize=1, + ): + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("conn", all_connectable) +@pytest.mark.parametrize("dtype_backend", [lib.no_default, "numpy_nullable"]) +@pytest.mark.parametrize("func", ["read_sql", "read_sql_query"]) +def test_read_sql_dtype(conn, request, func, dtype_backend): + # GH#50797 + conn = request.getfixturevalue(conn) + table = "test" + df = DataFrame({"a": [1, 2, 3], "b": 5}) + df.to_sql(name=table, con=conn, index=False, if_exists="replace") + + result = getattr(pd, func)( + f"Select * from {table}", + conn, + dtype={"a": np.float64}, + dtype_backend=dtype_backend, + ) + expected = DataFrame( + { + "a": Series([1, 2, 3], dtype=np.float64), + "b": Series( + [5, 5, 5], + dtype="int64" if not dtype_backend == "numpy_nullable" else "Int64", + ), + } + ) + tm.assert_frame_equal(result, expected) + + +def test_bigint_warning(sqlite_engine): + conn = sqlite_engine + # test no warning for BIGINT (to support int64) is raised (GH7433) + df = DataFrame({"a": [1, 2]}, dtype="int64") + assert df.to_sql(name="test_bigintwarning", con=conn, index=False) == 2 + + with tm.assert_produces_warning(None): + sql.read_sql_table("test_bigintwarning", conn) + + +def test_valueerror_exception(sqlite_engine): + conn = sqlite_engine + df = DataFrame({"col1": [1, 2], "col2": [3, 4]}) + with pytest.raises(ValueError, match="Empty table name specified"): + df.to_sql(name="", con=conn, if_exists="replace", index=False) + + +def test_row_object_is_named_tuple(sqlite_engine): + conn = sqlite_engine + # GH 40682 + # Test for the is_named_tuple() function + # Placed here due to its usage of sqlalchemy + + from sqlalchemy import ( + Column, + Integer, + String, + ) + from sqlalchemy.orm import ( + declarative_base, + sessionmaker, + ) + + BaseModel = declarative_base() + + class Test(BaseModel): + __tablename__ = "test_frame" + id = Column(Integer, primary_key=True) + string_column = Column(String(50)) + + with conn.begin(): + BaseModel.metadata.create_all(conn) + Session = sessionmaker(bind=conn) + with Session() as session: + df = DataFrame({"id": [0, 1], "string_column": ["hello", "world"]}) + assert ( + df.to_sql(name="test_frame", con=conn, index=False, if_exists="replace") + == 2 + ) + session.commit() + test_query = session.query(Test.id, Test.string_column) + df = DataFrame(test_query) + + assert list(df.columns) == ["id", "string_column"] + + +def test_read_sql_string_inference(sqlite_engine): + conn = sqlite_engine + # GH#54430 + table = "test" + df = DataFrame({"a": ["x", "y"]}) + df.to_sql(table, con=conn, index=False, if_exists="replace") + + with pd.option_context("future.infer_string", True): + result = read_sql_table(table, conn) + + dtype = pd.StringDtype(na_value=np.nan) + expected = DataFrame( + {"a": ["x", "y"]}, dtype=dtype, columns=Index(["a"], dtype=dtype) + ) + + tm.assert_frame_equal(result, expected) + + +def test_roundtripping_datetimes(sqlite_engine): + conn = sqlite_engine + # GH#54877 + df = DataFrame({"t": [datetime(2020, 12, 31, 12)]}, dtype="datetime64[ns]") + df.to_sql("test", conn, if_exists="replace", index=False) + result = pd.read_sql("select * from test", conn).iloc[0, 0] + assert result == "2020-12-31 12:00:00.000000" + + +@pytest.fixture +def sqlite_builtin_detect_types(): + with contextlib.closing( + sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES) + ) as closing_conn: + with closing_conn as conn: + yield conn + + +def test_roundtripping_datetimes_detect_types(sqlite_builtin_detect_types): + # https://github.com/pandas-dev/pandas/issues/55554 + conn = sqlite_builtin_detect_types + df = DataFrame({"t": [datetime(2020, 12, 31, 12)]}, dtype="datetime64[ns]") + df.to_sql("test", conn, if_exists="replace", index=False) + result = pd.read_sql("select * from test", conn).iloc[0, 0] + assert result == Timestamp("2020-12-31 12:00:00.000000") + + +@pytest.mark.db +def test_psycopg2_schema_support(postgresql_psycopg2_engine): + conn = postgresql_psycopg2_engine + + # only test this for postgresql (schema's not supported in + # mysql/sqlite) + df = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]}) + + # create a schema + with conn.connect() as con: + with con.begin(): + con.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;") + con.exec_driver_sql("CREATE SCHEMA other;") + + # write dataframe to different schema's + assert df.to_sql(name="test_schema_public", con=conn, index=False) == 2 + assert ( + df.to_sql( + name="test_schema_public_explicit", + con=conn, + index=False, + schema="public", + ) + == 2 + ) + assert ( + df.to_sql(name="test_schema_other", con=conn, index=False, schema="other") == 2 + ) + + # read dataframes back in + res1 = sql.read_sql_table("test_schema_public", conn) + tm.assert_frame_equal(df, res1) + res2 = sql.read_sql_table("test_schema_public_explicit", conn) + tm.assert_frame_equal(df, res2) + res3 = sql.read_sql_table("test_schema_public_explicit", conn, schema="public") + tm.assert_frame_equal(df, res3) + res4 = sql.read_sql_table("test_schema_other", conn, schema="other") + tm.assert_frame_equal(df, res4) + msg = "Table test_schema_other not found" + with pytest.raises(ValueError, match=msg): + sql.read_sql_table("test_schema_other", conn, schema="public") + + # different if_exists options + + # create a schema + with conn.connect() as con: + with con.begin(): + con.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;") + con.exec_driver_sql("CREATE SCHEMA other;") + + # write dataframe with different if_exists options + assert ( + df.to_sql(name="test_schema_other", con=conn, schema="other", index=False) == 2 + ) + df.to_sql( + name="test_schema_other", + con=conn, + schema="other", + index=False, + if_exists="replace", + ) + assert ( + df.to_sql( + name="test_schema_other", + con=conn, + schema="other", + index=False, + if_exists="append", + ) + == 2 + ) + res = sql.read_sql_table("test_schema_other", conn, schema="other") + tm.assert_frame_equal(concat([df, df], ignore_index=True), res) + + +@pytest.mark.db +def test_self_join_date_columns(postgresql_psycopg2_engine): + # GH 44421 + conn = postgresql_psycopg2_engine + from sqlalchemy.sql import text + + create_table = text( + """ + CREATE TABLE person + ( + id serial constraint person_pkey primary key, + created_dt timestamp with time zone + ); + + INSERT INTO person + VALUES (1, '2021-01-01T00:00:00Z'); + """ + ) + with conn.connect() as con: + with con.begin(): + con.execute(create_table) + + sql_query = ( + 'SELECT * FROM "person" AS p1 INNER JOIN "person" AS p2 ON p1.id = p2.id;' + ) + result = pd.read_sql(sql_query, conn) + expected = DataFrame( + [[1, Timestamp("2021", tz="UTC")] * 2], columns=["id", "created_dt"] * 2 + ) + expected["created_dt"] = expected["created_dt"].astype("M8[us, UTC]") + tm.assert_frame_equal(result, expected) + + # Cleanup + with sql.SQLDatabase(conn, need_transaction=True) as pandasSQL: + pandasSQL.drop_table("person") + + +def test_create_and_drop_table(sqlite_engine): + conn = sqlite_engine + temp_frame = DataFrame({"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}) + with sql.SQLDatabase(conn) as pandasSQL: + with pandasSQL.run_transaction(): + assert pandasSQL.to_sql(temp_frame, "drop_test_frame") == 4 + + assert pandasSQL.has_table("drop_test_frame") + + with pandasSQL.run_transaction(): + pandasSQL.drop_table("drop_test_frame") + + assert not pandasSQL.has_table("drop_test_frame") + + +def test_sqlite_datetime_date(sqlite_buildin): + conn = sqlite_buildin + df = DataFrame([date(2014, 1, 1), date(2014, 1, 2)], columns=["a"]) + assert df.to_sql(name="test_date", con=conn, index=False) == 2 + res = read_sql_query("SELECT * FROM test_date", conn) + # comes back as strings + tm.assert_frame_equal(res, df.astype(str)) + + +@pytest.mark.parametrize("tz_aware", [False, True]) +def test_sqlite_datetime_time(tz_aware, sqlite_buildin): + conn = sqlite_buildin + # test support for datetime.time, GH #8341 + if not tz_aware: + tz_times = [time(9, 0, 0), time(9, 1, 30)] + else: + tz_dt = date_range("2013-01-01 09:00:00", periods=2, tz="US/Pacific") + tz_times = Series(tz_dt.to_pydatetime()).map(lambda dt: dt.timetz()) + + df = DataFrame(tz_times, columns=["a"]) + + assert df.to_sql(name="test_time", con=conn, index=False) == 2 + res = read_sql_query("SELECT * FROM test_time", conn) + # comes back as strings + expected = df.map(lambda _: _.strftime("%H:%M:%S.%f")) + tm.assert_frame_equal(res, expected) + + +def get_sqlite_column_type(conn, table, column): + recs = conn.execute(f"PRAGMA table_info({table})") + for cid, name, ctype, not_null, default, pk in recs: + if name == column: + return ctype + raise ValueError(f"Table {table}, column {column} not found") + + +def test_sqlite_test_dtype(sqlite_buildin): + conn = sqlite_buildin + cols = ["A", "B"] + data = [(0.8, True), (0.9, None)] + df = DataFrame(data, columns=cols) + assert df.to_sql(name="dtype_test", con=conn) == 2 + assert df.to_sql(name="dtype_test2", con=conn, dtype={"B": "STRING"}) == 2 + + # sqlite stores Boolean values as INTEGER + assert get_sqlite_column_type(conn, "dtype_test", "B") == "INTEGER" + + assert get_sqlite_column_type(conn, "dtype_test2", "B") == "STRING" + msg = r"B \(\) not a string" + with pytest.raises(ValueError, match=msg): + df.to_sql(name="error", con=conn, dtype={"B": bool}) + + # single dtype + assert df.to_sql(name="single_dtype_test", con=conn, dtype="STRING") == 2 + assert get_sqlite_column_type(conn, "single_dtype_test", "A") == "STRING" + assert get_sqlite_column_type(conn, "single_dtype_test", "B") == "STRING" + + +def test_sqlite_notna_dtype(sqlite_buildin): + conn = sqlite_buildin + cols = { + "Bool": Series([True, None]), + "Date": Series([datetime(2012, 5, 1), None]), + "Int": Series([1, None], dtype="object"), + "Float": Series([1.1, None]), + } + df = DataFrame(cols) + + tbl = "notna_dtype_test" + assert df.to_sql(name=tbl, con=conn) == 2 + + assert get_sqlite_column_type(conn, tbl, "Bool") == "INTEGER" + assert get_sqlite_column_type(conn, tbl, "Date") == "TIMESTAMP" + assert get_sqlite_column_type(conn, tbl, "Int") == "INTEGER" + assert get_sqlite_column_type(conn, tbl, "Float") == "REAL" + + +def test_sqlite_illegal_names(sqlite_buildin): + # For sqlite, these should work fine + conn = sqlite_buildin + df = DataFrame([[1, 2], [3, 4]], columns=["a", "b"]) + + msg = "Empty table or column name specified" + with pytest.raises(ValueError, match=msg): + df.to_sql(name="", con=conn) + + for ndx, weird_name in enumerate( + [ + "test_weird_name]", + "test_weird_name[", + "test_weird_name`", + 'test_weird_name"', + "test_weird_name'", + "_b.test_weird_name_01-30", + '"_b.test_weird_name_01-30"', + "99beginswithnumber", + "12345", + "\xe9", + ] + ): + assert df.to_sql(name=weird_name, con=conn) == 2 + sql.table_exists(weird_name, conn) + + df2 = DataFrame([[1, 2], [3, 4]], columns=["a", weird_name]) + c_tbl = f"test_weird_col_name{ndx:d}" + assert df2.to_sql(name=c_tbl, con=conn) == 2 + sql.table_exists(c_tbl, conn) + + +def format_query(sql, *args): + _formatters = { + datetime: "'{}'".format, + str: "'{}'".format, + np.str_: "'{}'".format, + bytes: "'{}'".format, + float: "{:.8f}".format, + int: "{:d}".format, + type(None): lambda x: "NULL", + np.float64: "{:.10f}".format, + bool: "'{!s}'".format, + } + processed_args = [] + for arg in args: + if isinstance(arg, float) and isna(arg): + arg = None + + formatter = _formatters[type(arg)] + processed_args.append(formatter(arg)) + + return sql % tuple(processed_args) + + +def tquery(query, con=None): + """Replace removed sql.tquery function""" + with sql.pandasSQL_builder(con) as pandas_sql: + res = pandas_sql.execute(query).fetchall() + return None if res is None else list(res) + + +def test_xsqlite_basic(sqlite_buildin): + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD")), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + assert sql.to_sql(frame, name="test_table", con=sqlite_buildin, index=False) == 10 + result = sql.read_sql("select * from test_table", sqlite_buildin) + + # HACK! Change this once indexes are handled properly. + result.index = frame.index + + expected = frame + tm.assert_frame_equal(result, frame) + + frame["txt"] = ["a"] * len(frame) + frame2 = frame.copy() + new_idx = Index(np.arange(len(frame2)), dtype=np.int64) + 10 + frame2["Idx"] = new_idx.copy() + assert sql.to_sql(frame2, name="test_table2", con=sqlite_buildin, index=False) == 10 + result = sql.read_sql("select * from test_table2", sqlite_buildin, index_col="Idx") + expected = frame.copy() + expected.index = new_idx + expected.index.name = "Idx" + tm.assert_frame_equal(expected, result) + + +def test_xsqlite_write_row_by_row(sqlite_buildin): + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD")), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + frame.iloc[0, 0] = np.nan + create_sql = sql.get_schema(frame, "test") + cur = sqlite_buildin.cursor() + cur.execute(create_sql) + + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for _, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + tquery(fmt_sql, con=sqlite_buildin) + + sqlite_buildin.commit() + + result = sql.read_sql("select * from test", con=sqlite_buildin) + result.index = frame.index + tm.assert_frame_equal(result, frame, rtol=1e-3) + + +def test_xsqlite_execute(sqlite_buildin): + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD")), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + create_sql = sql.get_schema(frame, "test") + cur = sqlite_buildin.cursor() + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (?, ?, ?, ?)" + + row = frame.iloc[0] + with sql.pandasSQL_builder(sqlite_buildin) as pandas_sql: + pandas_sql.execute(ins, tuple(row)) + sqlite_buildin.commit() + + result = sql.read_sql("select * from test", sqlite_buildin) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + +def test_xsqlite_schema(sqlite_buildin): + frame = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD")), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + create_sql = sql.get_schema(frame, "test") + lines = create_sql.splitlines() + for line in lines: + tokens = line.split(" ") + if len(tokens) == 2 and tokens[0] == "A": + assert tokens[1] == "DATETIME" + + create_sql = sql.get_schema(frame, "test", keys=["A", "B"]) + lines = create_sql.splitlines() + assert 'PRIMARY KEY ("A", "B")' in create_sql + cur = sqlite_buildin.cursor() + cur.execute(create_sql) + + +def test_xsqlite_execute_fail(sqlite_buildin): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = sqlite_buildin.cursor() + cur.execute(create_sql) + + with sql.pandasSQL_builder(sqlite_buildin) as pandas_sql: + pandas_sql.execute("INSERT INTO test VALUES('foo', 'bar', 1.234)") + pandas_sql.execute("INSERT INTO test VALUES('foo', 'baz', 2.567)") + + with pytest.raises(sql.DatabaseError, match="Execution failed on sql"): + pandas_sql.execute("INSERT INTO test VALUES('foo', 'bar', 7)") + + +def test_xsqlite_execute_closed_connection(): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + with contextlib.closing(sqlite3.connect(":memory:")) as conn: + cur = conn.cursor() + cur.execute(create_sql) + + with sql.pandasSQL_builder(conn) as pandas_sql: + pandas_sql.execute("INSERT INTO test VALUES('foo', 'bar', 1.234)") + + msg = "Cannot operate on a closed database." + with pytest.raises(sqlite3.ProgrammingError, match=msg): + tquery("select * from test", con=conn) + + +def test_xsqlite_keyword_as_column_names(sqlite_buildin): + df = DataFrame({"From": np.ones(5)}) + assert sql.to_sql(df, con=sqlite_buildin, name="testkeywords", index=False) == 5 + + +def test_xsqlite_onecolumn_of_integer(sqlite_buildin): + # GH 3628 + # a column_of_integers dataframe should transfer well to sql + + mono_df = DataFrame([1, 2], columns=["c0"]) + assert sql.to_sql(mono_df, con=sqlite_buildin, name="mono_df", index=False) == 2 + # computing the sum via sql + con_x = sqlite_buildin + the_sum = sum(my_c0[0] for my_c0 in con_x.execute("select * from mono_df")) + # it should not fail, and gives 3 ( Issue #3628 ) + assert the_sum == 3 + + result = sql.read_sql("select * from mono_df", con_x) + tm.assert_frame_equal(result, mono_df) + + +def test_xsqlite_if_exists(sqlite_buildin): + df_if_exists_1 = DataFrame({"col1": [1, 2], "col2": ["A", "B"]}) + df_if_exists_2 = DataFrame({"col1": [3, 4, 5], "col2": ["C", "D", "E"]}) + table_name = "table_if_exists" + sql_select = f"SELECT * FROM {table_name}" + + msg = "'notvalidvalue' is not valid for if_exists" + with pytest.raises(ValueError, match=msg): + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="notvalidvalue", + ) + drop_table(table_name, sqlite_buildin) + + # test if_exists='fail' + sql.to_sql( + frame=df_if_exists_1, con=sqlite_buildin, name=table_name, if_exists="fail" + ) + msg = "Table 'table_if_exists' already exists" + with pytest.raises(ValueError, match=msg): + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="fail", + ) + # test if_exists='replace' + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="replace", + index=False, + ) + assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")] + assert ( + sql.to_sql( + frame=df_if_exists_2, + con=sqlite_buildin, + name=table_name, + if_exists="replace", + index=False, + ) + == 3 + ) + assert tquery(sql_select, con=sqlite_buildin) == [(3, "C"), (4, "D"), (5, "E")] + drop_table(table_name, sqlite_buildin) + + # test if_exists='append' + assert ( + sql.to_sql( + frame=df_if_exists_1, + con=sqlite_buildin, + name=table_name, + if_exists="fail", + index=False, + ) + == 2 + ) + assert tquery(sql_select, con=sqlite_buildin) == [(1, "A"), (2, "B")] + assert ( + sql.to_sql( + frame=df_if_exists_2, + con=sqlite_buildin, + name=table_name, + if_exists="append", + index=False, + ) + == 3 + ) + assert tquery(sql_select, con=sqlite_buildin) == [ + (1, "A"), + (2, "B"), + (3, "C"), + (4, "D"), + (5, "E"), + ] + drop_table(table_name, sqlite_buildin) diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py new file mode 100644 index 0000000000000000000000000000000000000000..f69ec1f6105605f80eadccd0072c7c0683158b49 --- /dev/null +++ b/pandas/tests/io/test_stata.py @@ -0,0 +1,2624 @@ +import bz2 +import datetime as dt +from datetime import datetime +import gzip +import io +import itertools +import os +import string +import struct +import tarfile +import zipfile + +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import CategoricalDtype +import pandas._testing as tm +from pandas.core.frame import ( + DataFrame, + Series, +) + +from pandas.io.parsers import read_csv +from pandas.io.stata import ( + CategoricalConversionWarning, + InvalidColumnName, + PossiblePrecisionLoss, + StataMissingValue, + StataReader, + StataWriter, + StataWriterUTF8, + ValueLabelTypeMismatch, + read_stata, +) + + +@pytest.fixture +def mixed_frame(): + return DataFrame( + { + "a": [1, 2, 3, 4], + "b": [1.0, 3.0, 27.0, 81.0], + "c": ["Atlanta", "Birmingham", "Cincinnati", "Detroit"], + } + ) + + +@pytest.fixture +def parsed_114(datapath): + dta14_114 = datapath("io", "data", "stata", "stata5_114.dta") + parsed_114 = read_stata(dta14_114, convert_dates=True) + parsed_114.index.name = "index" + return parsed_114 + + +class TestStata: + def read_dta(self, file): + # Legacy default reader configuration + return read_stata(file, convert_dates=True) + + def read_csv(self, file): + return read_csv(file, parse_dates=True) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_empty_dta(self, version, temp_file): + empty_ds = DataFrame(columns=["unit"]) + # GH 7369, make sure can read a 0-obs dta file + path = temp_file + empty_ds.to_stata(path, write_index=False, version=version) + empty_ds2 = read_stata(path) + tm.assert_frame_equal(empty_ds, empty_ds2) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_empty_dta_with_dtypes(self, version, temp_file): + # GH 46240 + # Fixing above bug revealed that types are not correctly preserved when + # writing empty DataFrames + empty_df_typed = DataFrame( + { + "i8": np.array([0], dtype=np.int8), + "i16": np.array([0], dtype=np.int16), + "i32": np.array([0], dtype=np.int32), + "i64": np.array([0], dtype=np.int64), + "u8": np.array([0], dtype=np.uint8), + "u16": np.array([0], dtype=np.uint16), + "u32": np.array([0], dtype=np.uint32), + "u64": np.array([0], dtype=np.uint64), + "f32": np.array([0], dtype=np.float32), + "f64": np.array([0], dtype=np.float64), + } + ) + # GH 7369, make sure can read a 0-obs dta file + path = temp_file + empty_df_typed.to_stata(path, write_index=False, version=version) + empty_reread = read_stata(path) + + expected = empty_df_typed + # No uint# support. Downcast since values in range for int# + expected["u8"] = expected["u8"].astype(np.int8) + expected["u16"] = expected["u16"].astype(np.int16) + expected["u32"] = expected["u32"].astype(np.int32) + # No int64 supported at all. Downcast since values in range for int32 + expected["u64"] = expected["u64"].astype(np.int32) + expected["i64"] = expected["i64"].astype(np.int32) + + tm.assert_frame_equal(expected, empty_reread) + tm.assert_series_equal(expected.dtypes, empty_reread.dtypes) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_index_col_none(self, version, temp_file): + df = DataFrame({"a": range(5), "b": ["b1", "b2", "b3", "b4", "b5"]}) + # GH 7369, make sure can read a 0-obs dta file + path = temp_file + df.to_stata(path, write_index=False, version=version) + read_df = read_stata(path) + + assert isinstance(read_df.index, pd.RangeIndex) + expected = df + expected["a"] = expected["a"].astype(np.int32) + tm.assert_frame_equal(read_df, expected, check_index_type=True) + + @pytest.mark.parametrize( + "version", [102, 103, 104, 105, 108, 110, 111, 113, 114, 115, 117, 118, 119] + ) + def test_read_dta1(self, version, datapath): + file = datapath("io", "data", "stata", f"stata1_{version}.dta") + parsed = self.read_dta(file) + + # Pandas uses np.nan as missing value. + # Thus, all columns will be of type float, regardless of their name. + expected = DataFrame( + [(np.nan, np.nan, np.nan, np.nan, np.nan)], + columns=["float_miss", "double_miss", "byte_miss", "int_miss", "long_miss"], + ) + + # this is an oddity as really the nan should be float64, but + # the casting doesn't fail so need to match stata here + expected["float_miss"] = expected["float_miss"].astype(np.float32) + + # Column names too long for older Stata formats + if version <= 108: + expected = expected.rename( + columns={ + "float_miss": "f_miss", + "double_miss": "d_miss", + "byte_miss": "b_miss", + "int_miss": "i_miss", + "long_miss": "l_miss", + } + ) + + tm.assert_frame_equal(parsed, expected) + + def test_read_dta2(self, datapath): + expected = DataFrame.from_records( + [ + ( + datetime(2006, 11, 19, 23, 13, 20), + 1479596223000, + datetime(2010, 1, 20), + datetime(2010, 1, 8), + datetime(2010, 1, 1), + datetime(1974, 7, 1), + datetime(2010, 1, 1), + datetime(2010, 1, 1), + ), + ( + datetime(1959, 12, 31, 20, 3, 20), + -1479590, + datetime(1953, 10, 2), + datetime(1948, 6, 10), + datetime(1955, 1, 1), + datetime(1955, 7, 1), + datetime(1955, 1, 1), + datetime(2, 1, 1), + ), + (pd.NaT, pd.NaT, pd.NaT, pd.NaT, pd.NaT, pd.NaT, pd.NaT, pd.NaT), + ], + columns=[ + "datetime_c", + "datetime_big_c", + "date", + "weekly_date", + "monthly_date", + "quarterly_date", + "half_yearly_date", + "yearly_date", + ], + ) + # TODO(GH#55564): just pass M8[s] to the constructor + expected["datetime_c"] = expected["datetime_c"].astype("M8[ms]") + expected["date"] = expected["date"].astype("M8[s]") + expected["weekly_date"] = expected["weekly_date"].astype("M8[s]") + expected["monthly_date"] = expected["monthly_date"].astype("M8[s]") + expected["quarterly_date"] = expected["quarterly_date"].astype("M8[s]") + expected["half_yearly_date"] = expected["half_yearly_date"].astype("M8[s]") + expected["yearly_date"] = expected["yearly_date"].astype("M8[s]") + + path1 = datapath("io", "data", "stata", "stata2_114.dta") + path2 = datapath("io", "data", "stata", "stata2_115.dta") + path3 = datapath("io", "data", "stata", "stata2_117.dta") + + msg = "Leaving in Stata Internal Format" + with tm.assert_produces_warning(UserWarning, match=msg): + parsed_114 = self.read_dta(path1) + with tm.assert_produces_warning(UserWarning, match=msg): + parsed_115 = self.read_dta(path2) + with tm.assert_produces_warning(UserWarning, match=msg): + parsed_117 = self.read_dta(path3) + # FIXME: don't leave commented-out + # 113 is buggy due to limits of date format support in Stata + # parsed_113 = self.read_dta( + # datapath("io", "data", "stata", "stata2_113.dta") + # ) + + # FIXME: don't leave commented-out + # buggy test because of the NaT comparison on certain platforms + # Format 113 test fails since it does not support tc and tC formats + # tm.assert_frame_equal(parsed_113, expected) + tm.assert_frame_equal(parsed_114, expected) + tm.assert_frame_equal(parsed_115, expected) + tm.assert_frame_equal(parsed_117, expected) + + @pytest.mark.parametrize( + "file", ["stata3_113", "stata3_114", "stata3_115", "stata3_117"] + ) + def test_read_dta3(self, file, datapath): + file = datapath("io", "data", "stata", f"{file}.dta") + parsed = self.read_dta(file) + + # match stata here + expected = self.read_csv(datapath("io", "data", "stata", "stata3.csv")) + expected = expected.astype(np.float32) + expected["year"] = expected["year"].astype(np.int16) + expected["quarter"] = expected["quarter"].astype(np.int8) + + tm.assert_frame_equal(parsed, expected) + + @pytest.mark.parametrize("version", [110, 111, 113, 114, 115, 117]) + def test_read_dta4(self, version, datapath): + file = datapath("io", "data", "stata", f"stata4_{version}.dta") + parsed = self.read_dta(file) + + expected = DataFrame.from_records( + [ + ["one", "ten", "one", "one", "one"], + ["two", "nine", "two", "two", "two"], + ["three", "eight", "three", "three", "three"], + ["four", "seven", 4, "four", "four"], + ["five", "six", 5, np.nan, "five"], + ["six", "five", 6, np.nan, "six"], + ["seven", "four", 7, np.nan, "seven"], + ["eight", "three", 8, np.nan, "eight"], + ["nine", "two", 9, np.nan, "nine"], + ["ten", "one", "ten", np.nan, "ten"], + ], + columns=[ + "fully_labeled", + "fully_labeled2", + "incompletely_labeled", + "labeled_with_missings", + "float_labelled", + ], + ) + + # these are all categoricals + for col in expected: + orig = expected[col].copy() + + categories = np.asarray(expected["fully_labeled"][orig.notna()]) + if col == "incompletely_labeled": + categories = orig + + cat = orig.astype("category")._values + cat = cat.set_categories(categories, ordered=True) + cat.categories.rename(None, inplace=True) + + expected[col] = cat + + # stata doesn't save .category metadata + tm.assert_frame_equal(parsed, expected) + + @pytest.mark.parametrize("version", [102, 103, 104, 105, 108]) + def test_readold_dta4(self, version, datapath): + # This test is the same as test_read_dta4 above except that the columns + # had to be renamed to match the restrictions in older file format + file = datapath("io", "data", "stata", f"stata4_{version}.dta") + parsed = self.read_dta(file) + + expected = DataFrame.from_records( + [ + ["one", "ten", "one", "one", "one"], + ["two", "nine", "two", "two", "two"], + ["three", "eight", "three", "three", "three"], + ["four", "seven", 4, "four", "four"], + ["five", "six", 5, np.nan, "five"], + ["six", "five", 6, np.nan, "six"], + ["seven", "four", 7, np.nan, "seven"], + ["eight", "three", 8, np.nan, "eight"], + ["nine", "two", 9, np.nan, "nine"], + ["ten", "one", "ten", np.nan, "ten"], + ], + columns=[ + "fulllab", + "fulllab2", + "incmplab", + "misslab", + "floatlab", + ], + ) + + # these are all categoricals + for col in expected: + orig = expected[col].copy() + + categories = np.asarray(expected["fulllab"][orig.notna()]) + if col == "incmplab": + categories = orig + + cat = orig.astype("category")._values + cat = cat.set_categories(categories, ordered=True) + cat.categories.rename(None, inplace=True) + + expected[col] = cat + + # stata doesn't save .category metadata + tm.assert_frame_equal(parsed, expected) + + # File containing strls + @pytest.mark.parametrize( + "file", + [ + "stata12_117", + "stata12_be_117", + "stata12_118", + "stata12_be_118", + "stata12_119", + "stata12_be_119", + ], + ) + def test_read_dta_strl(self, file, datapath): + parsed = self.read_dta(datapath("io", "data", "stata", f"{file}.dta")) + expected = DataFrame.from_records( + [ + [1, "abc", "abcdefghi"], + [3, "cba", "qwertywertyqwerty"], + [93, "", "strl"], + ], + columns=["x", "y", "z"], + ) + + tm.assert_frame_equal(parsed, expected, check_dtype=False) + + # 117 is not included in this list as it uses ASCII strings + @pytest.mark.parametrize( + "file", + [ + "stata14_118", + "stata14_be_118", + "stata14_119", + "stata14_be_119", + ], + ) + def test_read_dta118_119(self, file, datapath): + parsed_118 = self.read_dta(datapath("io", "data", "stata", f"{file}.dta")) + parsed_118["Bytes"] = parsed_118["Bytes"].astype("O") + expected = DataFrame.from_records( + [ + ["Cat", "Bogota", "Bogotá", 1, 1.0, "option b Ünicode", 1.0], + ["Dog", "Boston", "Uzunköprü", np.nan, np.nan, np.nan, np.nan], + ["Plane", "Rome", "Tromsø", 0, 0.0, "option a", 0.0], + ["Potato", "Tokyo", "Elâzığ", -4, 4.0, 4, 4], # noqa: RUF001 + ["", "", "", 0, 0.3332999, "option a", 1 / 3.0], + ], + columns=[ + "Things", + "Cities", + "Unicode_Cities_Strl", + "Ints", + "Floats", + "Bytes", + "Longs", + ], + ) + expected["Floats"] = expected["Floats"].astype(np.float32) + for col in parsed_118.columns: + tm.assert_almost_equal(parsed_118[col], expected[col]) + + with StataReader(datapath("io", "data", "stata", f"{file}.dta")) as rdr: + vl = rdr.variable_labels() + vl_expected = { + "Unicode_Cities_Strl": "Here are some strls with Ünicode chars", + "Longs": "long data", + "Things": "Here are some things", + "Bytes": "byte data", + "Ints": "int data", + "Cities": "Here are some cities", + "Floats": "float data", + } + tm.assert_dict_equal(vl, vl_expected) + + assert rdr.data_label == "This is a Ünicode data label" + + def test_read_write_dta5(self, temp_file): + original = DataFrame( + [(np.nan, np.nan, np.nan, np.nan, np.nan)], + columns=["float_miss", "double_miss", "byte_miss", "int_miss", "long_miss"], + ) + original.index.name = "index" + + path = temp_file + original.to_stata(path, convert_dates=None) + written_and_read_again = self.read_dta(path) + + expected = original + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_write_dta6(self, datapath, temp_file): + original = self.read_csv(datapath("io", "data", "stata", "stata3.csv")) + original.index.name = "index" + original.index = original.index.astype(np.int32) + original["year"] = original["year"].astype(np.int32) + original["quarter"] = original["quarter"].astype(np.int32) + + path = temp_file + original.to_stata(path, convert_dates=None) + written_and_read_again = self.read_dta(path) + tm.assert_frame_equal( + written_and_read_again.set_index("index"), + original, + check_index_type=False, + ) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_write_dta10(self, version, temp_file, using_infer_string): + original = DataFrame( + data=[["string", "object", 1, 1.1, np.datetime64("2003-12-25")]], + columns=["string", "object", "integer", "floating", "datetime"], + ) + original["object"] = Series(original["object"], dtype=object) + original.index.name = "index" + original.index = original.index.astype(np.int32) + original["integer"] = original["integer"].astype(np.int32) + + path = temp_file + original.to_stata(path, convert_dates={"datetime": "tc"}, version=version) + written_and_read_again = self.read_dta(path) + + expected = original.copy() + # "tc" convert_dates means we store in ms + expected["datetime"] = expected["datetime"].astype("M8[ms]") + if using_infer_string: + expected["object"] = expected["object"].astype("str") + + tm.assert_frame_equal( + written_and_read_again.set_index("index"), + expected, + ) + + def test_stata_doc_examples(self, temp_file): + path = temp_file + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("AB") + ) + df.to_stata(path) + + def test_write_preserves_original(self, temp_file): + # 9795 + + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 4)), columns=list("abcd") + ) + df.loc[2, "a":"c"] = np.nan + df_copy = df.copy() + path = temp_file + df.to_stata(path, write_index=False) + tm.assert_frame_equal(df, df_copy) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_encoding(self, version, datapath, temp_file): + # GH 4626, proper encoding handling + raw = read_stata(datapath("io", "data", "stata", "stata1_encoding.dta")) + encoded = read_stata(datapath("io", "data", "stata", "stata1_encoding.dta")) + result = encoded.kreis1849[0] + + expected = raw.kreis1849[0] + assert result == expected + assert isinstance(result, str) + + path = temp_file + encoded.to_stata(path, write_index=False, version=version) + reread_encoded = read_stata(path) + tm.assert_frame_equal(encoded, reread_encoded) + + def test_read_write_dta11(self, temp_file): + original = DataFrame( + [(1, 2, 3, 4)], + columns=[ + "good", + "b\u00e4d", + "8number", + "astringwithmorethan32characters______", + ], + ) + formatted = DataFrame( + [(1, 2, 3, 4)], + columns=["good", "b_d", "_8number", "astringwithmorethan32characters_"], + ) + formatted.index.name = "index" + formatted = formatted.astype(np.int32) + + path = temp_file + msg = "Not all pandas column names were valid Stata variable names" + with tm.assert_produces_warning(InvalidColumnName, match=msg): + original.to_stata(path, convert_dates=None) + + written_and_read_again = self.read_dta(path) + + expected = formatted + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_read_write_dta12(self, version, temp_file): + original = DataFrame( + [(1, 2, 3, 4, 5, 6)], + columns=[ + "astringwithmorethan32characters_1", + "astringwithmorethan32characters_2", + "+", + "-", + "short", + "delete", + ], + ) + formatted = DataFrame( + [(1, 2, 3, 4, 5, 6)], + columns=[ + "astringwithmorethan32characters_", + "_0astringwithmorethan32character", + "_", + "_1_", + "_short", + "_delete", + ], + ) + formatted.index.name = "index" + formatted = formatted.astype(np.int32) + + path = temp_file + msg = "Not all pandas column names were valid Stata variable names" + with tm.assert_produces_warning(InvalidColumnName, match=msg): + original.to_stata(path, convert_dates=None, version=version) + # should get a warning for that format. + + written_and_read_again = self.read_dta(path) + + expected = formatted + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_read_write_dta13(self, temp_file): + s1 = Series(2**9, dtype=np.int16) + s2 = Series(2**17, dtype=np.int32) + s3 = Series(2**33, dtype=np.int64) + original = DataFrame({"int16": s1, "int32": s2, "int64": s3}) + original.index.name = "index" + + formatted = original + formatted["int64"] = formatted["int64"].astype(np.float64) + + path = temp_file + original.to_stata(path) + written_and_read_again = self.read_dta(path) + + expected = formatted + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + @pytest.mark.parametrize( + "file", ["stata5_113", "stata5_114", "stata5_115", "stata5_117"] + ) + def test_read_write_reread_dta14( + self, file, parsed_114, version, datapath, temp_file + ): + file = datapath("io", "data", "stata", f"{file}.dta") + parsed = self.read_dta(file) + parsed.index.name = "index" + + tm.assert_frame_equal(parsed_114, parsed) + + path = temp_file + parsed_114.to_stata(path, convert_dates={"date_td": "td"}, version=version) + written_and_read_again = self.read_dta(path) + + expected = parsed_114.copy() + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize( + "file", ["stata6_113", "stata6_114", "stata6_115", "stata6_117"] + ) + def test_read_write_reread_dta15(self, file, datapath): + expected = self.read_csv(datapath("io", "data", "stata", "stata6.csv")) + expected["byte_"] = expected["byte_"].astype(np.int8) + expected["int_"] = expected["int_"].astype(np.int16) + expected["long_"] = expected["long_"].astype(np.int32) + expected["float_"] = expected["float_"].astype(np.float32) + expected["double_"] = expected["double_"].astype(np.float64) + + # TODO(GH#55564): directly cast to M8[s] + arr = expected["date_td"].astype("Period[D]")._values.asfreq("s", how="S") + expected["date_td"] = arr.view("M8[s]") + + file = datapath("io", "data", "stata", f"{file}.dta") + parsed = self.read_dta(file) + + tm.assert_frame_equal(expected, parsed) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_timestamp_and_label(self, version, temp_file): + original = DataFrame([(1,)], columns=["variable"]) + time_stamp = datetime(2000, 2, 29, 14, 21) + data_label = "This is a data file." + path = temp_file + original.to_stata( + path, time_stamp=time_stamp, data_label=data_label, version=version + ) + + with StataReader(path) as reader: + assert reader.time_stamp == "29 Feb 2000 14:21" + assert reader.data_label == data_label + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_invalid_timestamp(self, version, temp_file): + original = DataFrame([(1,)], columns=["variable"]) + time_stamp = "01 Jan 2000, 00:00:00" + path = temp_file + msg = "time_stamp should be datetime type" + with pytest.raises(ValueError, match=msg): + original.to_stata(path, time_stamp=time_stamp, version=version) + assert not os.path.isfile(path) + + def test_numeric_column_names(self, temp_file): + original = DataFrame(np.reshape(np.arange(25.0), (5, 5))) + original.index.name = "index" + path = temp_file + # should get a warning for that format. + msg = "Not all pandas column names were valid Stata variable names" + with tm.assert_produces_warning(InvalidColumnName, match=msg): + original.to_stata(path) + + written_and_read_again = self.read_dta(path) + + written_and_read_again = written_and_read_again.set_index("index") + columns = list(written_and_read_again.columns) + convert_col_name = lambda x: int(x[1]) + written_and_read_again.columns = map(convert_col_name, columns) + + expected = original + tm.assert_frame_equal(expected, written_and_read_again) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_nan_to_missing_value(self, version, temp_file): + s1 = Series(np.arange(4.0), dtype=np.float32) + s2 = Series(np.arange(4.0), dtype=np.float64) + s1[::2] = np.nan + s2[1::2] = np.nan + original = DataFrame({"s1": s1, "s2": s2}) + original.index.name = "index" + + path = temp_file + original.to_stata(path, version=version) + written_and_read_again = self.read_dta(path) + + written_and_read_again = written_and_read_again.set_index("index") + expected = original + tm.assert_frame_equal(written_and_read_again, expected) + + def test_no_index(self, temp_file): + columns = ["x", "y"] + original = DataFrame(np.reshape(np.arange(10.0), (5, 2)), columns=columns) + original.index.name = "index_not_written" + path = temp_file + original.to_stata(path, write_index=False) + written_and_read_again = self.read_dta(path) + with pytest.raises(KeyError, match=original.index.name): + written_and_read_again["index_not_written"] + + def test_string_no_dates(self, temp_file): + s1 = Series(["a", "A longer string"]) + s2 = Series([1.0, 2.0], dtype=np.float64) + original = DataFrame({"s1": s1, "s2": s2}) + original.index.name = "index" + path = temp_file + original.to_stata(path) + written_and_read_again = self.read_dta(path) + + expected = original + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_large_value_conversion(self, temp_file): + s0 = Series([1, 99], dtype=np.int8) + s1 = Series([1, 127], dtype=np.int8) + s2 = Series([1, 2**15 - 1], dtype=np.int16) + s3 = Series([1, 2**63 - 1], dtype=np.int64) + original = DataFrame({"s0": s0, "s1": s1, "s2": s2, "s3": s3}) + original.index.name = "index" + path = temp_file + with tm.assert_produces_warning(PossiblePrecisionLoss, match="from int64 to"): + original.to_stata(path) + + written_and_read_again = self.read_dta(path) + + modified = original + modified["s1"] = Series(modified["s1"], dtype=np.int16) + modified["s2"] = Series(modified["s2"], dtype=np.int32) + modified["s3"] = Series(modified["s3"], dtype=np.float64) + tm.assert_frame_equal(written_and_read_again.set_index("index"), modified) + + def test_dates_invalid_column(self, temp_file): + original = DataFrame([datetime(2006, 11, 19, 23, 13, 20)]) + original.index.name = "index" + path = temp_file + msg = "Not all pandas column names were valid Stata variable names" + with tm.assert_produces_warning(InvalidColumnName, match=msg): + original.to_stata(path, convert_dates={0: "tc"}) + + written_and_read_again = self.read_dta(path) + + expected = original.copy() + expected.columns = ["_0"] + expected.index = original.index.astype(np.int32) + expected["_0"] = expected["_0"].astype("M8[ms]") + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_105(self, datapath): + # Data obtained from: + # http://go.worldbank.org/ZXY29PVJ21 + dpath = datapath("io", "data", "stata", "S4_EDUC1.dta") + df = read_stata(dpath) + df0 = [[1, 1, 3, -2], [2, 1, 2, -2], [4, 1, 1, -2]] + df0 = DataFrame(df0) + df0.columns = ["clustnum", "pri_schl", "psch_num", "psch_dis"] + df0["clustnum"] = df0["clustnum"].astype(np.int16) + df0["pri_schl"] = df0["pri_schl"].astype(np.int8) + df0["psch_num"] = df0["psch_num"].astype(np.int8) + df0["psch_dis"] = df0["psch_dis"].astype(np.float32) + tm.assert_frame_equal(df.head(3), df0) + + def test_value_labels_old_format(self, datapath): + # GH 19417 + # + # Test that value_labels() returns an empty dict if the file format + # predates supporting value labels. + dpath = datapath("io", "data", "stata", "S4_EDUC1.dta") + with StataReader(dpath) as reader: + assert reader.value_labels() == {} + + def test_date_export_formats(self, temp_file): + columns = ["tc", "td", "tw", "tm", "tq", "th", "ty"] + conversions = {c: c for c in columns} + data = [datetime(2006, 11, 20, 23, 13, 20)] * len(columns) + original = DataFrame([data], columns=columns) + original.index.name = "index" + expected_values = [ + datetime(2006, 11, 20, 23, 13, 20), # Time + datetime(2006, 11, 20), # Day + datetime(2006, 11, 19), # Week + datetime(2006, 11, 1), # Month + datetime(2006, 10, 1), # Quarter year + datetime(2006, 7, 1), # Half year + datetime(2006, 1, 1), + ] # Year + + expected = DataFrame( + [expected_values], + index=pd.Index([0], dtype=np.int32, name="index"), + columns=columns, + dtype="M8[s]", + ) + expected["tc"] = expected["tc"].astype("M8[ms]") + + path = temp_file + original.to_stata(path, convert_dates=conversions) + written_and_read_again = self.read_dta(path) + + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + def test_write_missing_strings(self, temp_file): + original = DataFrame([["1"], [None]], columns=["foo"]) + + expected = DataFrame( + [["1"], [""]], + index=pd.RangeIndex(2, name="index"), + columns=["foo"], + ) + + path = temp_file + original.to_stata(path) + written_and_read_again = self.read_dta(path) + + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + @pytest.mark.parametrize("byteorder", [">", "<"]) + def test_bool_uint(self, byteorder, version, temp_file): + s0 = Series([0, 1, True], dtype=np.bool_) + s1 = Series([0, 1, 100], dtype=np.uint8) + s2 = Series([0, 1, 255], dtype=np.uint8) + s3 = Series([0, 1, 2**15 - 100], dtype=np.uint16) + s4 = Series([0, 1, 2**16 - 1], dtype=np.uint16) + s5 = Series([0, 1, 2**31 - 100], dtype=np.uint32) + s6 = Series([0, 1, 2**32 - 1], dtype=np.uint32) + + original = DataFrame( + {"s0": s0, "s1": s1, "s2": s2, "s3": s3, "s4": s4, "s5": s5, "s6": s6} + ) + original.index.name = "index" + + path = temp_file + original.to_stata(path, byteorder=byteorder, version=version) + written_and_read_again = self.read_dta(path) + + written_and_read_again = written_and_read_again.set_index("index") + + expected = original + expected_types = ( + np.int8, + np.int8, + np.int16, + np.int16, + np.int32, + np.int32, + np.float64, + ) + for c, t in zip(expected.columns, expected_types): + expected[c] = expected[c].astype(t) + + tm.assert_frame_equal(written_and_read_again, expected) + + def test_variable_labels(self, datapath): + with StataReader(datapath("io", "data", "stata", "stata7_115.dta")) as rdr: + sr_115 = rdr.variable_labels() + with StataReader(datapath("io", "data", "stata", "stata7_117.dta")) as rdr: + sr_117 = rdr.variable_labels() + keys = ("var1", "var2", "var3") + labels = ("label1", "label2", "label3") + for k, v in sr_115.items(): + assert k in sr_117 + assert v == sr_117[k] + assert k in keys + assert v in labels + + def test_minimal_size_col(self, temp_file): + str_lens = (1, 100, 244) + s = {} + for str_len in str_lens: + s["s" + str(str_len)] = Series( + ["a" * str_len, "b" * str_len, "c" * str_len] + ) + original = DataFrame(s) + path = temp_file + original.to_stata(path, write_index=False) + + with StataReader(path) as sr: + sr._ensure_open() # The `_*list` variables are initialized here + for variable, fmt, typ in zip(sr._varlist, sr._fmtlist, sr._typlist): + assert int(variable[1:]) == int(fmt[1:-1]) + assert int(variable[1:]) == typ + + def test_excessively_long_string(self, temp_file): + str_lens = (1, 244, 500) + s = {} + for str_len in str_lens: + s["s" + str(str_len)] = Series( + ["a" * str_len, "b" * str_len, "c" * str_len] + ) + original = DataFrame(s) + msg = ( + r"Fixed width strings in Stata \.dta files are limited to 244 " + r"\(or fewer\)\ncharacters\. Column 's500' does not satisfy " + r"this restriction\. Use the\n'version=117' parameter to write " + r"the newer \(Stata 13 and later\) format\." + ) + with pytest.raises(ValueError, match=msg): + path = temp_file + original.to_stata(path) + + def test_missing_value_generator(self, temp_file): + types = ("b", "h", "l") + df = DataFrame([[0.0]], columns=["float_"]) + path = temp_file + df.to_stata(path) + with StataReader(path) as rdr: + valid_range = rdr.VALID_RANGE + expected_values = ["." + chr(97 + i) for i in range(26)] + expected_values.insert(0, ".") + for t in types: + offset = valid_range[t][1] + for i in range(27): + val = StataMissingValue(offset + 1 + i) + assert val.string == expected_values[i] + + # Test extremes for floats + val = StataMissingValue(struct.unpack(" DataFrame: + """ + Emulate the categorical casting behavior we expect from roundtripping. + """ + for col in from_frame: + ser = from_frame[col] + if isinstance(ser.dtype, CategoricalDtype): + cat = ser._values.remove_unused_categories() + if cat.categories.dtype == object: + categories = pd.Index._with_infer( + cat.categories._values, copy=False + ) + cat = cat.set_categories(categories) + elif cat.categories.dtype == "string" and len(cat.categories) == 0: + # if the read categories are empty, it comes back as object dtype + categories = cat.categories.astype(object) + cat = cat.set_categories(categories) + from_frame[col] = cat + return from_frame + + def test_iterator(self, datapath): + fname = datapath("io", "data", "stata", "stata12_117.dta") + + parsed = read_stata(fname) + expected = parsed.iloc[0:5, :] + + with read_stata(fname, iterator=True) as itr: + chunk = itr.read(5) + tm.assert_frame_equal(expected, chunk) + + with read_stata(fname, chunksize=5) as itr: + chunk = next(itr) + tm.assert_frame_equal(expected, chunk) + + with read_stata(fname, iterator=True) as itr: + chunk = itr.get_chunk(5) + tm.assert_frame_equal(expected, chunk) + + with read_stata(fname, chunksize=5) as itr: + chunk = itr.get_chunk() + tm.assert_frame_equal(expected, chunk) + + # GH12153 + with read_stata(fname, chunksize=4) as itr: + from_chunks = pd.concat(itr) + tm.assert_frame_equal(parsed, from_chunks) + + @pytest.mark.filterwarnings("ignore::UserWarning") + @pytest.mark.parametrize( + "file", + [ + "stata2_115", + "stata3_115", + "stata4_115", + "stata5_115", + "stata6_115", + "stata7_115", + "stata8_115", + "stata9_115", + "stata10_115", + "stata11_115", + ], + ) + @pytest.mark.parametrize("chunksize", [1, 2]) + @pytest.mark.parametrize("convert_categoricals", [False, True]) + @pytest.mark.parametrize("convert_dates", [False, True]) + def test_read_chunks_115( + self, file, chunksize, convert_categoricals, convert_dates, datapath + ): + fname = datapath("io", "data", "stata", f"{file}.dta") + + # Read the whole file + parsed = read_stata( + fname, + convert_categoricals=convert_categoricals, + convert_dates=convert_dates, + ) + + # Compare to what we get when reading by chunk + with read_stata( + fname, + iterator=True, + convert_dates=convert_dates, + convert_categoricals=convert_categoricals, + ) as itr: + pos = 0 + for j in range(5): + try: + chunk = itr.read(chunksize) + except StopIteration: + break + from_frame = parsed.iloc[pos : pos + chunksize, :].copy() + from_frame = self._convert_categorical(from_frame) + tm.assert_frame_equal( + from_frame, + chunk, + check_dtype=False, + ) + pos += chunksize + + def test_read_chunks_columns(self, datapath): + fname = datapath("io", "data", "stata", "stata3_117.dta") + columns = ["quarter", "cpi", "m1"] + chunksize = 2 + + parsed = read_stata(fname, columns=columns) + with read_stata(fname, iterator=True) as itr: + pos = 0 + for j in range(5): + chunk = itr.read(chunksize, columns=columns) + if chunk is None: + break + from_frame = parsed.iloc[pos : pos + chunksize, :] + tm.assert_frame_equal(from_frame, chunk, check_dtype=False) + pos += chunksize + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_write_variable_labels(self, version, mixed_frame, temp_file): + # GH 13631, add support for writing variable labels + mixed_frame.index.name = "index" + variable_labels = {"a": "City Rank", "b": "City Exponent", "c": "City"} + path = temp_file + mixed_frame.to_stata(path, variable_labels=variable_labels, version=version) + with StataReader(path) as sr: + read_labels = sr.variable_labels() + expected_labels = { + "index": "", + "a": "City Rank", + "b": "City Exponent", + "c": "City", + } + assert read_labels == expected_labels + + variable_labels["index"] = "The Index" + path = temp_file + mixed_frame.to_stata(path, variable_labels=variable_labels, version=version) + with StataReader(path) as sr: + read_labels = sr.variable_labels() + assert read_labels == variable_labels + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_invalid_variable_labels(self, version, mixed_frame, temp_file): + mixed_frame.index.name = "index" + variable_labels = {"a": "very long" * 10, "b": "City Exponent", "c": "City"} + path = temp_file + msg = "Variable labels must be 80 characters or fewer" + with pytest.raises(ValueError, match=msg): + mixed_frame.to_stata(path, variable_labels=variable_labels, version=version) + + @pytest.mark.parametrize("version", [114, 117]) + def test_invalid_variable_label_encoding(self, version, mixed_frame, temp_file): + mixed_frame.index.name = "index" + variable_labels = {"a": "very long" * 10, "b": "City Exponent", "c": "City"} + variable_labels["a"] = "invalid character Œ" + path = temp_file + with pytest.raises( + ValueError, match="Variable labels must contain only characters" + ): + mixed_frame.to_stata(path, variable_labels=variable_labels, version=version) + + def test_write_variable_label_errors(self, mixed_frame, temp_file): + values = ["\u03a1", "\u0391", "\u039d", "\u0394", "\u0391", "\u03a3"] + + variable_labels_utf8 = { + "a": "City Rank", + "b": "City Exponent", + "c": "".join(values), + } + + msg = ( + "Variable labels must contain only characters that can be " + "encoded in Latin-1" + ) + with pytest.raises(ValueError, match=msg): + path = temp_file + mixed_frame.to_stata(path, variable_labels=variable_labels_utf8) + + variable_labels_long = { + "a": "City Rank", + "b": "City Exponent", + "c": "A very, very, very long variable label " + "that is too long for Stata which means " + "that it has more than 80 characters", + } + + msg = "Variable labels must be 80 characters or fewer" + with pytest.raises(ValueError, match=msg): + path = temp_file + mixed_frame.to_stata(path, variable_labels=variable_labels_long) + + def test_default_date_conversion(self, temp_file): + # GH 12259 + dates = [ + dt.datetime(1999, 12, 31, 12, 12, 12, 12000), + dt.datetime(2012, 12, 21, 12, 21, 12, 21000), + dt.datetime(1776, 7, 4, 7, 4, 7, 4000), + ] + original = DataFrame( + { + "nums": [1.0, 2.0, 3.0], + "strs": ["apple", "banana", "cherry"], + "dates": dates, + } + ) + + expected = original[:] + # "tc" for convert_dates below stores with "ms" resolution + expected["dates"] = expected["dates"].astype("M8[ms]") + + path = temp_file + original.to_stata(path, write_index=False) + reread = read_stata(path, convert_dates=True) + tm.assert_frame_equal(expected, reread) + + original.to_stata(path, write_index=False, convert_dates={"dates": "tc"}) + direct = read_stata(path, convert_dates=True) + tm.assert_frame_equal(reread, direct) + + dates_idx = original.columns.tolist().index("dates") + original.to_stata(path, write_index=False, convert_dates={dates_idx: "tc"}) + direct = read_stata(path, convert_dates=True) + tm.assert_frame_equal(reread, direct) + + def test_unsupported_type(self, temp_file): + original = DataFrame({"a": [1 + 2j, 2 + 4j]}) + + msg = "Data type complex128 not supported" + with pytest.raises(NotImplementedError, match=msg): + path = temp_file + original.to_stata(path) + + def test_unsupported_datetype(self, temp_file): + dates = [ + dt.datetime(1999, 12, 31, 12, 12, 12, 12000), + dt.datetime(2012, 12, 21, 12, 21, 12, 21000), + dt.datetime(1776, 7, 4, 7, 4, 7, 4000), + ] + original = DataFrame( + { + "nums": [1.0, 2.0, 3.0], + "strs": ["apple", "banana", "cherry"], + "dates": dates, + } + ) + + msg = "Format %tC not implemented" + with pytest.raises(NotImplementedError, match=msg): + path = temp_file + original.to_stata(path, convert_dates={"dates": "tC"}) + + dates = pd.date_range("1-1-1990", periods=3, tz="Asia/Hong_Kong") + original = DataFrame( + { + "nums": [1.0, 2.0, 3.0], + "strs": ["apple", "banana", "cherry"], + "dates": dates, + } + ) + with pytest.raises(NotImplementedError, match="Data type datetime64"): + path = temp_file + original.to_stata(path) + + def test_repeated_column_labels(self, datapath): + # GH 13923, 25772 + msg = """ +Value labels for column ethnicsn are not unique. These cannot be converted to +pandas categoricals. + +Either read the file with `convert_categoricals` set to False or use the +low level interface in `StataReader` to separately read the values and the +value_labels. + +The repeated labels are:\n-+\nwolof +""" + with pytest.raises(ValueError, match=msg): + read_stata( + datapath("io", "data", "stata", "stata15.dta"), + convert_categoricals=True, + ) + + def test_stata_111(self, datapath): + # 111 is an old version but still used by current versions of + # SAS when exporting to Stata format. We do not know of any + # on-line documentation for this version. + df = read_stata(datapath("io", "data", "stata", "stata7_111.dta")) + original = DataFrame( + { + "y": [1, 1, 1, 1, 1, 0, 0, np.nan, 0, 0], + "x": [1, 2, 1, 3, np.nan, 4, 3, 5, 1, 6], + "w": [2, np.nan, 5, 2, 4, 4, 3, 1, 2, 3], + "z": ["a", "b", "c", "d", "e", "", "g", "h", "i", "j"], + } + ) + original = original[["y", "x", "w", "z"]] + tm.assert_frame_equal(original, df) + + def test_out_of_range_double(self, temp_file): + # GH 14618 + df = DataFrame( + { + "ColumnOk": [0.0, np.finfo(np.double).eps, 4.49423283715579e307], + "ColumnTooBig": [0.0, np.finfo(np.double).eps, np.finfo(np.double).max], + } + ) + msg = ( + r"Column ColumnTooBig has a maximum value \(.+\) outside the range " + r"supported by Stata \(.+\)" + ) + with pytest.raises(ValueError, match=msg): + path = temp_file + df.to_stata(path) + + def test_out_of_range_float(self, temp_file): + original = DataFrame( + { + "ColumnOk": [ + 0.0, + np.finfo(np.float32).eps, + np.finfo(np.float32).max / 10.0, + ], + "ColumnTooBig": [ + 0.0, + np.finfo(np.float32).eps, + np.finfo(np.float32).max, + ], + } + ) + original.index.name = "index" + for col in original: + original[col] = original[col].astype(np.float32) + + path = temp_file + original.to_stata(path) + reread = read_stata(path) + + original["ColumnTooBig"] = original["ColumnTooBig"].astype(np.float64) + expected = original + tm.assert_frame_equal(reread.set_index("index"), expected) + + @pytest.mark.parametrize("infval", [np.inf, -np.inf]) + def test_inf(self, infval, temp_file): + # GH 45350 + df = DataFrame({"WithoutInf": [0.0, 1.0], "WithInf": [2.0, infval]}) + msg = ( + "Column WithInf contains infinity or -infinity" + "which is outside the range supported by Stata." + ) + with pytest.raises(ValueError, match=msg): + path = temp_file + df.to_stata(path) + + def test_path_pathlib(self, temp_file): + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + df.index.name = "index" + reader = lambda x: read_stata(x).set_index("index") + result = tm.round_trip_pathlib(df.to_stata, reader, temp_file) + tm.assert_frame_equal(df, result) + + @pytest.mark.parametrize("write_index", [True, False]) + def test_value_labels_iterator(self, write_index, temp_file): + # GH 16923 + d = {"A": ["B", "E", "C", "A", "E"]} + df = DataFrame(data=d) + df["A"] = df["A"].astype("category") + path = temp_file + df.to_stata(path, write_index=write_index) + + with read_stata(path, iterator=True) as dta_iter: + value_labels = dta_iter.value_labels() + assert value_labels == {"A": {0: "A", 1: "B", 2: "C", 3: "E"}} + + def test_set_index(self, temp_file): + # GH 17328 + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + df.index.name = "index" + path = temp_file + df.to_stata(path) + reread = read_stata(path, index_col="index") + tm.assert_frame_equal(df, reread) + + @pytest.mark.parametrize( + "column", ["ms", "day", "week", "month", "qtr", "half", "yr"] + ) + def test_date_parsing_ignores_format_details(self, column, datapath): + # GH 17797 + # + # Test that display formats are ignored when determining if a numeric + # column is a date value. + # + # All date types are stored as numbers and format associated with the + # column denotes both the type of the date and the display format. + # + # STATA supports 9 date types which each have distinct units. We test 7 + # of the 9 types, ignoring %tC and %tb. %tC is a variant of %tc that + # accounts for leap seconds and %tb relies on STATAs business calendar. + df = read_stata(datapath("io", "data", "stata", "stata13_dates.dta")) + unformatted = df.loc[0, column] + formatted = df.loc[0, column + "_fmt"] + assert unformatted == formatted + + @pytest.mark.parametrize("byteorder", ["little", "big"]) + def test_writer_117(self, byteorder, temp_file, using_infer_string): + original = DataFrame( + data=[ + [ + "string", + "object", + 1, + 1, + 1, + 1.1, + 1.1, + np.datetime64("2003-12-25"), + "a", + "a" * 2045, + "a" * 5000, + "a", + ], + [ + "string-1", + "object-1", + 1, + 1, + 1, + 1.1, + 1.1, + np.datetime64("2003-12-26"), + "b", + "b" * 2045, + "", + "", + ], + ], + columns=[ + "string", + "object", + "int8", + "int16", + "int32", + "float32", + "float64", + "datetime", + "s1", + "s2045", + "srtl", + "forced_strl", + ], + ) + original["object"] = Series(original["object"], dtype=object) + original["int8"] = Series(original["int8"], dtype=np.int8) + original["int16"] = Series(original["int16"], dtype=np.int16) + original["int32"] = original["int32"].astype(np.int32) + original["float32"] = Series(original["float32"], dtype=np.float32) + original.index.name = "index" + copy = original.copy() + path = temp_file + original.to_stata( + path, + convert_dates={"datetime": "tc"}, + byteorder=byteorder, + convert_strl=["forced_strl"], + version=117, + ) + written_and_read_again = self.read_dta(path) + + expected = original[:] + # "tc" for convert_dates means we store with "ms" resolution + expected["datetime"] = expected["datetime"].astype("M8[ms]") + if using_infer_string: + # object dtype (with only strings/None) comes back as string dtype + expected["object"] = expected["object"].astype("str") + + tm.assert_frame_equal( + written_and_read_again.set_index("index"), + expected, + ) + tm.assert_frame_equal(original, copy) + + def test_convert_strl_name_swap(self, temp_file): + original = DataFrame( + [["a" * 3000, "A", "apple"], ["b" * 1000, "B", "banana"]], + columns=["long1" * 10, "long", 1], + ) + original.index.name = "index" + + msg = "Not all pandas column names were valid Stata variable names" + with tm.assert_produces_warning(InvalidColumnName, match=msg): + path = temp_file + original.to_stata(path, convert_strl=["long", 1], version=117) + reread = self.read_dta(path) + reread = reread.set_index("index") + reread.columns = original.columns + tm.assert_frame_equal(reread, original, check_index_type=False) + + def test_invalid_date_conversion(self, temp_file): + # GH 12259 + dates = [ + dt.datetime(1999, 12, 31, 12, 12, 12, 12000), + dt.datetime(2012, 12, 21, 12, 21, 12, 21000), + dt.datetime(1776, 7, 4, 7, 4, 7, 4000), + ] + original = DataFrame( + { + "nums": [1.0, 2.0, 3.0], + "strs": ["apple", "banana", "cherry"], + "dates": dates, + } + ) + + path = temp_file + msg = "convert_dates key must be a column or an integer" + with pytest.raises(ValueError, match=msg): + original.to_stata(path, convert_dates={"wrong_name": "tc"}) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_nonfile_writing(self, version, temp_file): + # GH 21041 + bio = io.BytesIO() + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + df.index.name = "index" + path = temp_file + df.to_stata(bio, version=version) + bio.seek(0) + with open(path, "wb") as dta: + dta.write(bio.read()) + reread = read_stata(path, index_col="index") + tm.assert_frame_equal(df, reread) + + def test_gzip_writing(self, temp_file): + # writing version 117 requires seek and cannot be used with gzip + df = DataFrame( + 1.1 * np.arange(120).reshape((30, 4)), + columns=pd.Index(list("ABCD")), + index=pd.Index([f"i-{i}" for i in range(30)]), + ) + df.index.name = "index" + path = temp_file + with gzip.GzipFile(path, "wb") as gz: + df.to_stata(gz, version=114) + with gzip.GzipFile(path, "rb") as gz: + reread = read_stata(gz, index_col="index") + tm.assert_frame_equal(df, reread) + + # 117 is not included in this list as it uses ASCII strings + @pytest.mark.parametrize( + "file", + [ + "stata16_118", + "stata16_be_118", + "stata16_119", + "stata16_be_119", + ], + ) + def test_unicode_dta_118_119(self, file, datapath): + unicode_df = self.read_dta(datapath("io", "data", "stata", f"{file}.dta")) + + columns = ["utf8", "latin1", "ascii", "utf8_strl", "ascii_strl"] + values = [ + ["ραηδας", "PÄNDÄS", "p", "ραηδας", "p"], + ["ƤĀńĐąŜ", "Ö", "a", "ƤĀńĐąŜ", "a"], + ["ᴘᴀᴎᴅᴀS", "Ü", "n", "ᴘᴀᴎᴅᴀS", "n"], + [" ", " ", "d", " ", "d"], + [" ", "", "a", " ", "a"], + ["", "", "s", "", "s"], + ["", "", " ", "", " "], + ] + expected = DataFrame(values, columns=columns) + + tm.assert_frame_equal(unicode_df, expected) + + def test_mixed_string_strl(self, temp_file, using_infer_string): + # GH 23633 + output = [{"mixed": "string" * 500, "number": 0}, {"mixed": None, "number": 1}] + output = DataFrame(output) + output.number = output.number.astype("int32") + + path = temp_file + output.to_stata(path, write_index=False, version=117) + reread = read_stata(path) + expected = output.fillna("") + tm.assert_frame_equal(reread, expected) + + # Check strl supports all None (null) + output["mixed"] = None + output.to_stata(path, write_index=False, convert_strl=["mixed"], version=117) + reread = read_stata(path) + expected = output.fillna("") + if using_infer_string: + expected["mixed"] = expected["mixed"].astype("str") + tm.assert_frame_equal(reread, expected) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_all_none_exception(self, version, temp_file): + output = [{"none": "none", "number": 0}, {"none": None, "number": 1}] + output = DataFrame(output) + output["none"] = None + with pytest.raises(ValueError, match="Column `none` cannot be exported"): + output.to_stata(temp_file, version=version) + + @pytest.mark.parametrize("version", [114, 117, 118, 119, None]) + def test_invalid_file_not_written(self, version, temp_file): + content = "Here is one __�__ Another one __·__ Another one __½__" + df = DataFrame([content], columns=["invalid"]) + msg1 = ( + r"'latin-1' codec can't encode character '\\ufffd' " + r"in position 14: ordinal not in range\(256\)" + ) + msg2 = ( + "'ascii' codec can't decode byte 0xef in position 14: " + r"ordinal not in range\(128\)" + ) + with pytest.raises(UnicodeEncodeError, match=f"{msg1}|{msg2}"): + df.to_stata(temp_file) + + def test_strl_latin1(self, temp_file): + # GH 23573, correct GSO data to reflect correct size + output = DataFrame( + [["pandas"] * 2, ["þâÑÐŧ"] * 2], columns=["var_str", "var_strl"] + ) + + output.to_stata(temp_file, version=117, convert_strl=["var_strl"]) + with open(temp_file, "rb") as reread: + content = reread.read() + expected = "þâÑÐŧ" + assert expected.encode("latin-1") in content + assert expected.encode("utf-8") in content + gsos = content.split(b"strls")[1][1:-2] + for gso in gsos.split(b"GSO")[1:]: + val = gso.split(b"\x00")[-2] + size = gso[gso.find(b"\x82") + 1] + assert len(val) == size - 1 + + def test_encoding_latin1_118(self, datapath): + # GH 25960 + msg = """ +One or more strings in the dta file could not be decoded using utf-8, and +so the fallback encoding of latin-1 is being used. This can happen when a file +has been incorrectly encoded by Stata or some other software. You should verify +the string values returned are correct.""" + # Move path outside of read_stata, or else assert_produces_warning + # will block pytests skip mechanism from triggering (failing the test) + # if the path is not present + path = datapath("io", "data", "stata", "stata1_encoding_118.dta") + with tm.assert_produces_warning(UnicodeWarning, filter_level="once") as w: + encoded = read_stata(path) + # with filter_level="always", produces 151 warnings which can be slow + assert len(w) == 1 + assert w[0].message.args[0] == msg + + expected = DataFrame([["Düsseldorf"]] * 151, columns=["kreis1849"]) + tm.assert_frame_equal(encoded, expected) + + @pytest.mark.slow + def test_stata_119(self, datapath): + # Gzipped since contains 32,999 variables and uncompressed is 20MiB + # Just validate that the reader reports correct number of variables + # to avoid high peak memory + with gzip.open( + datapath("io", "data", "stata", "stata1_119.dta.gz"), "rb" + ) as gz: + with StataReader(gz) as reader: + reader._ensure_open() + assert reader._nvar == 32999 + + @pytest.mark.parametrize("version", [118, 119, None]) + @pytest.mark.parametrize("byteorder", ["little", "big"]) + def test_utf8_writer(self, version, byteorder, temp_file): + cat = pd.Categorical(["a", "β", "ĉ"], ordered=True) + data = DataFrame( + [ + [1.0, 1, "ᴬ", "ᴀ relatively long ŝtring"], + [2.0, 2, "ᴮ", ""], + [3.0, 3, "ᴰ", None], + ], + columns=["Å", "β", "ĉ", "strls"], + ) + data["ᴐᴬᵀ"] = cat + variable_labels = { + "Å": "apple", + "β": "ᵈᵉᵊ", + "ĉ": "ᴎტჄႲႳႴႶႺ", + "strls": "Long Strings", + "ᴐᴬᵀ": "", + } + data_label = "ᴅaᵀa-label" + value_labels = {"β": {1: "label", 2: "æøå", 3: "ŋot valid latin-1"}} + data["β"] = data["β"].astype(np.int32) + writer = StataWriterUTF8( + temp_file, + data, + data_label=data_label, + convert_strl=["strls"], + variable_labels=variable_labels, + write_index=False, + byteorder=byteorder, + version=version, + value_labels=value_labels, + ) + writer.write_file() + reread_encoded = read_stata(temp_file) + # Missing is intentionally converted to empty strl + data["strls"] = data["strls"].fillna("") + # Variable with value labels is reread as categorical + data["β"] = ( + data["β"].replace(value_labels["β"]).astype("category").cat.as_ordered() + ) + tm.assert_frame_equal(data, reread_encoded) + with StataReader(temp_file) as reader: + assert reader.data_label == data_label + assert reader.variable_labels() == variable_labels + + data.to_stata(temp_file, version=version, write_index=False) + reread_to_stata = read_stata(temp_file) + tm.assert_frame_equal(data, reread_to_stata) + + def test_writer_118_exceptions(self, temp_file): + df = DataFrame(np.zeros((1, 33000), dtype=np.int8)) + with pytest.raises(ValueError, match="version must be either 118 or 119."): + StataWriterUTF8(temp_file, df, version=117) + with pytest.raises(ValueError, match="You must use version 119"): + StataWriterUTF8(temp_file, df, version=118) + + @pytest.mark.parametrize( + "dtype_backend", + ["numpy_nullable", pytest.param("pyarrow", marks=td.skip_if_no("pyarrow"))], + ) + def test_read_write_ea_dtypes(self, dtype_backend, temp_file, tmp_path): + dtype = "Int64" if dtype_backend == "numpy_nullable" else "int64[pyarrow]" + df = DataFrame( + { + "a": pd.array([1, 2, None], dtype=dtype), + "b": ["a", "b", "c"], + "c": [True, False, None], + "d": [1.5, 2.5, 3.5], + "e": pd.date_range("2020-12-31", periods=3, freq="D"), + }, + index=pd.Index([0, 1, 2], name="index"), + ) + df = df.convert_dtypes(dtype_backend=dtype_backend) + stata_path = tmp_path / "test_stata.dta" + df.to_stata(stata_path, version=118) + + df.to_stata(temp_file) + written_and_read_again = self.read_dta(temp_file) + + expected = DataFrame( + { + "a": [1, 2, np.nan], + "b": ["a", "b", "c"], + "c": [1.0, 0, np.nan], + "d": [1.5, 2.5, 3.5], + # stata stores with ms unit, so unit does not round-trip exactly + "e": pd.date_range("2020-12-31", periods=3, freq="D", unit="ms"), + }, + index=pd.RangeIndex(range(3), name="index"), + ) + + tm.assert_frame_equal(written_and_read_again.set_index("index"), expected) + + @pytest.mark.parametrize("version", [113, 114, 115, 117, 118, 119]) + def test_read_data_int_validranges(self, version, datapath): + expected = DataFrame( + { + "byte": np.array([-127, 100], dtype=np.int8), + "int": np.array([-32767, 32740], dtype=np.int16), + "long": np.array([-2147483647, 2147483620], dtype=np.int32), + } + ) + + parsed = read_stata( + datapath("io", "data", "stata", f"stata_int_validranges_{version}.dta") + ) + tm.assert_frame_equal(parsed, expected) + + @pytest.mark.parametrize("version", [104, 105, 108, 110, 111]) + def test_read_data_int_validranges_compat(self, version, datapath): + expected = DataFrame( + { + "byte": np.array([-128, 126], dtype=np.int8), + "int": np.array([-32768, 32766], dtype=np.int16), + "long": np.array([-2147483648, 2147483646], dtype=np.int32), + } + ) + + parsed = read_stata( + datapath("io", "data", "stata", f"stata_int_validranges_{version}.dta") + ) + tm.assert_frame_equal(parsed, expected) + + # The byte type was not supported prior to the 104 format + @pytest.mark.parametrize("version", [102, 103]) + def test_read_data_int_validranges_compat_nobyte(self, version, datapath): + expected = DataFrame( + { + "byte": np.array([-128, 126], dtype=np.int16), + "int": np.array([-32768, 32766], dtype=np.int16), + "long": np.array([-2147483648, 2147483646], dtype=np.int32), + } + ) + + parsed = read_stata( + datapath("io", "data", "stata", f"stata_int_validranges_{version}.dta") + ) + tm.assert_frame_equal(parsed, expected) + + +@pytest.mark.parametrize("version", [105, 108, 110, 111, 113, 114]) +def test_backward_compat(version, datapath): + data_base = datapath("io", "data", "stata") + ref = os.path.join(data_base, "stata-compat-118.dta") + old = os.path.join(data_base, f"stata-compat-{version}.dta") + expected = read_stata(ref) + old_dta = read_stata(old) + tm.assert_frame_equal(old_dta, expected, check_dtype=False) + + +@pytest.mark.parametrize("version", [103, 104]) +def test_backward_compat_nodateconversion(version, datapath): + # The Stata data format prior to 105 did not support a date format + # so read the raw values for comparison + data_base = datapath("io", "data", "stata") + ref = os.path.join(data_base, "stata-compat-118.dta") + old = os.path.join(data_base, f"stata-compat-{version}.dta") + expected = read_stata(ref, convert_dates=False) + old_dta = read_stata(old, convert_dates=False) + tm.assert_frame_equal(old_dta, expected, check_dtype=False) + + +@pytest.mark.parametrize("version", [102]) +def test_backward_compat_nostring(version, datapath): + # The Stata data format prior to 105 did not support a date format + # so read the raw values for comparison + ref = datapath("io", "data", "stata", "stata-compat-118.dta") + old = datapath("io", "data", "stata", f"stata-compat-{version}.dta") + expected = read_stata(ref, convert_dates=False) + # The Stata data format prior to 103 did not support string data + expected = expected.drop(columns=["s10"]) + old_dta = read_stata(old, convert_dates=False) + tm.assert_frame_equal(old_dta, expected, check_dtype=False) + + +@pytest.mark.parametrize("version", [105, 108, 110, 111, 113, 114, 118]) +def test_bigendian(version, datapath): + ref = datapath("io", "data", "stata", f"stata-compat-{version}.dta") + big = datapath("io", "data", "stata", f"stata-compat-be-{version}.dta") + expected = read_stata(ref) + big_dta = read_stata(big) + tm.assert_frame_equal(big_dta, expected) + + +# Note: 102 format does not support big-endian byte order +@pytest.mark.parametrize("version", [103, 104]) +def test_bigendian_nodateconversion(version, datapath): + # The Stata data format prior to 105 did not support a date format + # so read the raw values for comparison + ref = datapath("io", "data", "stata", f"stata-compat-{version}.dta") + big = datapath("io", "data", "stata", f"stata-compat-be-{version}.dta") + expected = read_stata(ref, convert_dates=False) + big_dta = read_stata(big, convert_dates=False) + tm.assert_frame_equal(big_dta, expected) + + +def test_direct_read(datapath, monkeypatch): + file_path = datapath("io", "data", "stata", "stata-compat-118.dta") + + # Test that opening a file path doesn't buffer the file. + with StataReader(file_path) as reader: + # Must not have been buffered to memory + assert not reader.read().empty + assert not isinstance(reader._path_or_buf, io.BytesIO) + + # Test that we use a given fp exactly, if possible. + with open(file_path, "rb") as fp: + with StataReader(fp) as reader: + assert not reader.read().empty + assert reader._path_or_buf is fp + + # Test that we use a given BytesIO exactly, if possible. + with open(file_path, "rb") as fp: + with io.BytesIO(fp.read()) as bio: + with StataReader(bio) as reader: + assert not reader.read().empty + assert reader._path_or_buf is bio + + +@pytest.mark.parametrize("version", [114, 117, 118, 119, None]) +@pytest.mark.parametrize("use_dict", [True, False]) +@pytest.mark.parametrize("infer", [True, False]) +def test_compression( + compression, version, use_dict, infer, compression_to_extension, tmp_path +): + file_name = "dta_inferred_compression.dta" + if compression: + if use_dict: + file_ext = compression + else: + file_ext = compression_to_extension[compression] + file_name += f".{file_ext}" + compression_arg = compression + if infer: + compression_arg = "infer" + if use_dict: + compression_arg = {"method": compression} + + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("AB") + ) + df.index.name = "index" + path = tmp_path / file_name + path.touch() + df.to_stata(path, version=version, compression=compression_arg) + if compression == "gzip": + with gzip.open(path, "rb") as comp: + fp = io.BytesIO(comp.read()) + elif compression == "zip": + with zipfile.ZipFile(path, "r") as comp: + fp = io.BytesIO(comp.read(comp.filelist[0])) + elif compression == "tar": + with tarfile.open(path) as tar: + fp = io.BytesIO(tar.extractfile(tar.getnames()[0]).read()) + elif compression == "bz2": + with bz2.open(path, "rb") as comp: + fp = io.BytesIO(comp.read()) + elif compression == "zstd": + zstd = pytest.importorskip("zstandard") + with zstd.open(path, "rb") as comp: + fp = io.BytesIO(comp.read()) + elif compression == "xz": + lzma = pytest.importorskip("lzma") + with lzma.open(path, "rb") as comp: + fp = io.BytesIO(comp.read()) + elif compression is None: + fp = path + reread = read_stata(fp, index_col="index") + + expected = df + tm.assert_frame_equal(reread, expected) + + +@pytest.mark.parametrize("method", ["zip", "infer"]) +@pytest.mark.parametrize("file_ext", [None, "dta", "zip"]) +def test_compression_dict(method, file_ext, tmp_path): + file_name = f"test.{file_ext}" + archive_name = "test.dta" + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), columns=list("AB") + ) + df.index.name = "index" + compression = {"method": method, "archive_name": archive_name} + path = tmp_path / file_name + path.touch() + df.to_stata(path, compression=compression) + if method == "zip" or file_ext == "zip": + with zipfile.ZipFile(path, "r") as zp: + assert len(zp.filelist) == 1 + assert zp.filelist[0].filename == archive_name + fp = io.BytesIO(zp.read(zp.filelist[0])) + else: + fp = path + reread = read_stata(fp, index_col="index") + + expected = df + tm.assert_frame_equal(reread, expected) + + +@pytest.mark.parametrize("version", [114, 117, 118, 119, None]) +def test_chunked_categorical(version, temp_file): + df = DataFrame({"cats": Series(["a", "b", "a", "b", "c"], dtype="category")}) + df.index.name = "index" + + expected = df.copy() + + df.to_stata(temp_file, version=version) + with StataReader(temp_file, chunksize=2, order_categoricals=False) as reader: + for i, block in enumerate(reader): + block = block.set_index("index") + assert "cats" in block + tm.assert_series_equal( + block.cats, + expected.cats.iloc[2 * i : 2 * (i + 1)], + check_index_type=len(block) > 1, + ) + + +def test_chunked_categorical_partial(datapath): + dta_file = datapath("io", "data", "stata", "stata-dta-partially-labeled.dta") + values = ["a", "b", "a", "b", 3.0] + msg = "series with value labels are not fully labeled" + with StataReader(dta_file, chunksize=2) as reader: + with tm.assert_produces_warning(CategoricalConversionWarning, match=msg): + for i, block in enumerate(reader): + assert list(block.cats) == values[2 * i : 2 * (i + 1)] + if i < 2: + idx = pd.Index(["a", "b"]) + else: + idx = pd.Index([3.0], dtype="float64") + tm.assert_index_equal(block.cats.cat.categories, idx) + with tm.assert_produces_warning(CategoricalConversionWarning, match=msg): + with StataReader(dta_file, chunksize=5) as reader: + large_chunk = reader.__next__() + direct = read_stata(dta_file) + tm.assert_frame_equal(direct, large_chunk) + + +@pytest.mark.parametrize("chunksize", (-1, 0, "apple")) +def test_iterator_errors(datapath, chunksize): + dta_file = datapath("io", "data", "stata", "stata-dta-partially-labeled.dta") + with pytest.raises(ValueError, match="chunksize must be a positive"): + with StataReader(dta_file, chunksize=chunksize): + pass + + +def test_iterator_value_labels(temp_file): + # GH 31544 + values = ["c_label", "b_label"] + ["a_label"] * 500 + df = DataFrame({f"col{k}": pd.Categorical(values, ordered=True) for k in range(2)}) + df.to_stata(temp_file, write_index=False) + expected = pd.Index(["a_label", "b_label", "c_label"]) + with read_stata(temp_file, chunksize=100) as reader: + for j, chunk in enumerate(reader): + for i in range(2): + tm.assert_index_equal(chunk.dtypes.iloc[i].categories, expected) + tm.assert_frame_equal(chunk, df.iloc[j * 100 : (j + 1) * 100]) + + +def test_precision_loss(temp_file): + df = DataFrame( + [[sum(2**i for i in range(60)), sum(2**i for i in range(52))]], + columns=["big", "little"], + ) + with tm.assert_produces_warning( + PossiblePrecisionLoss, match="Column converted from int64 to float64" + ): + df.to_stata(temp_file, write_index=False) + reread = read_stata(temp_file) + expected_dt = Series([np.float64, np.float64], index=["big", "little"]) + tm.assert_series_equal(reread.dtypes, expected_dt) + assert reread.loc[0, "little"] == df.loc[0, "little"] + assert reread.loc[0, "big"] == float(df.loc[0, "big"]) + + +def test_compression_roundtrip(compression, temp_file): + df = DataFrame( + [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + index=["A", "B"], + columns=["X", "Y", "Z"], + ) + df.index.name = "index" + + df.to_stata(temp_file, compression=compression) + reread = read_stata(temp_file, compression=compression, index_col="index") + tm.assert_frame_equal(df, reread) + + # explicitly ensure file was compressed. + with tm.decompress_file(temp_file, compression) as fh: + contents = io.BytesIO(fh.read()) + reread = read_stata(contents, index_col="index") + tm.assert_frame_equal(df, reread) + + +@pytest.mark.parametrize("to_infer", [True, False]) +@pytest.mark.parametrize("read_infer", [True, False]) +def test_stata_compression( + compression_only, read_infer, to_infer, compression_to_extension, tmp_path +): + compression = compression_only + + ext = compression_to_extension[compression] + filename = f"test.{ext}" + + df = DataFrame( + [[0.123456, 0.234567, 0.567567], [12.32112, 123123.2, 321321.2]], + index=["A", "B"], + columns=["X", "Y", "Z"], + ) + df.index.name = "index" + + to_compression = "infer" if to_infer else compression + read_compression = "infer" if read_infer else compression + + path = tmp_path / filename + path.touch() + df.to_stata(path, compression=to_compression) + result = read_stata(path, compression=read_compression, index_col="index") + tm.assert_frame_equal(result, df) + + +def test_non_categorical_value_labels(temp_file): + data = DataFrame( + { + "fully_labelled": [1, 2, 3, 3, 1], + "partially_labelled": [1.0, 2.0, np.nan, 9.0, np.nan], + "Y": [7, 7, 9, 8, 10], + "Z": pd.Categorical(["j", "k", "l", "k", "j"]), + } + ) + + path = temp_file + value_labels = { + "fully_labelled": {1: "one", 2: "two", 3: "three"}, + "partially_labelled": {1.0: "one", 2.0: "two"}, + } + expected = {**value_labels, "Z": {0: "j", 1: "k", 2: "l"}} + + writer = StataWriter(path, data, value_labels=value_labels) + writer.write_file() + + with StataReader(path) as reader: + reader_value_labels = reader.value_labels() + assert reader_value_labels == expected + + msg = "Can't create value labels for notY, it wasn't found in the dataset." + value_labels = {"notY": {7: "label1", 8: "label2"}} + with pytest.raises(KeyError, match=msg): + StataWriter(path, data, value_labels=value_labels) + + msg = ( + "Can't create value labels for Z, value labels " + "can only be applied to numeric columns." + ) + value_labels = {"Z": {1: "a", 2: "k", 3: "j", 4: "i"}} + with pytest.raises(ValueError, match=msg): + StataWriter(path, data, value_labels=value_labels) + + +def test_non_categorical_value_label_name_conversion(temp_file): + # Check conversion of invalid variable names + data = DataFrame( + { + "invalid~!": [1, 1, 2, 3, 5, 8], # Only alphanumeric and _ + "6_invalid": [1, 1, 2, 3, 5, 8], # Must start with letter or _ + "invalid_name_longer_than_32_characters": [8, 8, 9, 9, 8, 8], # Too long + "aggregate": [2, 5, 5, 6, 6, 9], # Reserved words + (1, 2): [1, 2, 3, 4, 5, 6], # Hashable non-string + } + ) + + value_labels = { + "invalid~!": {1: "label1", 2: "label2"}, + "6_invalid": {1: "label1", 2: "label2"}, + "invalid_name_longer_than_32_characters": {8: "eight", 9: "nine"}, + "aggregate": {5: "five"}, + (1, 2): {3: "three"}, + } + + expected = { + "invalid__": {1: "label1", 2: "label2"}, + "_6_invalid": {1: "label1", 2: "label2"}, + "invalid_name_longer_than_32_char": {8: "eight", 9: "nine"}, + "_aggregate": {5: "five"}, + "_1__2_": {3: "three"}, + } + + msg = "Not all pandas column names were valid Stata variable names" + with tm.assert_produces_warning(InvalidColumnName, match=msg): + data.to_stata(temp_file, value_labels=value_labels) + + with StataReader(temp_file) as reader: + reader_value_labels = reader.value_labels() + assert reader_value_labels == expected + + +def test_non_categorical_value_label_convert_categoricals_error(temp_file): + # Mapping more than one value to the same label is valid for Stata + # labels, but can't be read with convert_categoricals=True + value_labels = { + "repeated_labels": {10: "Ten", 20: "More than ten", 40: "More than ten"} + } + + data = DataFrame( + { + "repeated_labels": [10, 10, 20, 20, 40, 40], + } + ) + + data.to_stata(temp_file, value_labels=value_labels) + + with StataReader(temp_file, convert_categoricals=False) as reader: + reader_value_labels = reader.value_labels() + assert reader_value_labels == value_labels + + col = "repeated_labels" + repeats = "-" * 80 + "\n" + "\n".join(["More than ten"]) + + msg = f""" +Value labels for column {col} are not unique. These cannot be converted to +pandas categoricals. + +Either read the file with `convert_categoricals` set to False or use the +low level interface in `StataReader` to separately read the values and the +value_labels. + +The repeated labels are: +{repeats} +""" + with pytest.raises(ValueError, match=msg): + read_stata(temp_file, convert_categoricals=True) + + +@pytest.mark.parametrize("version", [114, 117, 118, 119, None]) +@pytest.mark.parametrize( + "dtype", + [ + pd.BooleanDtype, + pd.Int8Dtype, + pd.Int16Dtype, + pd.Int32Dtype, + pd.Int64Dtype, + pd.UInt8Dtype, + pd.UInt16Dtype, + pd.UInt32Dtype, + pd.UInt64Dtype, + ], +) +def test_nullable_support(dtype, version, temp_file): + df = DataFrame( + { + "a": Series([1.0, 2.0, 3.0]), + "b": Series([1, pd.NA, pd.NA], dtype=dtype.name), + "c": Series(["a", "b", None]), + } + ) + dtype_name = df.b.dtype.numpy_dtype.name + # Only use supported names: no uint, bool or int64 + dtype_name = dtype_name.replace("u", "") + if dtype_name == "int64": + dtype_name = "int32" + elif dtype_name == "bool": + dtype_name = "int8" + value = StataMissingValue.BASE_MISSING_VALUES[dtype_name] + smv = StataMissingValue(value) + expected_b = Series([1, smv, smv], dtype=object, name="b") + expected_c = Series(["a", "b", ""], name="c") + df.to_stata(temp_file, write_index=False, version=version) + reread = read_stata(temp_file, convert_missing=True) + tm.assert_series_equal(df.a, reread.a) + tm.assert_series_equal(reread.b, expected_b) + tm.assert_series_equal(reread.c, expected_c) + + +def test_empty_frame(temp_file): + # GH 46240 + # create an empty DataFrame with int64 and float64 dtypes + df = DataFrame(data={"a": range(3), "b": [1.0, 2.0, 3.0]}).head(0) + path = temp_file + df.to_stata(path, write_index=False, version=117) + # Read entire dataframe + df2 = read_stata(path) + assert "b" in df2 + # Dtypes don't match since no support for int32 + dtypes = Series({"a": np.dtype("int32"), "b": np.dtype("float64")}) + tm.assert_series_equal(df2.dtypes, dtypes) + # read one column of empty .dta file + df3 = read_stata(path, columns=["a"]) + assert "b" not in df3 + tm.assert_series_equal(df3.dtypes, dtypes.loc[["a"]]) + + +@pytest.mark.parametrize("version", [114, 117, 118, 119, None]) +def test_many_strl(temp_file, version): + n = 65534 + df = DataFrame(np.arange(n), columns=["col"]) + lbls = ["".join(v) for v in itertools.product(*([string.ascii_letters] * 3))] + value_labels = {"col": {i: lbls[i] for i in range(n)}} + df.to_stata(temp_file, value_labels=value_labels, version=version) + + +@pytest.mark.parametrize("version", [117, 118, 119, None]) +def test_strl_missings(temp_file, version): + # GH 23633 + # Check that strl supports None and pd.NA + df = DataFrame( + [ + {"str1": "string" * 500, "number": 0}, + {"str1": None, "number": 1}, + {"str1": pd.NA, "number": 1}, + ] + ) + df.to_stata(temp_file, version=version) + + +@pytest.mark.parametrize("version", [117, 118, 119, None]) +def test_ascii_error(temp_file, version): + # GH #61583 + # Check that 2 byte long unicode characters doesn't cause export error + df = DataFrame({"doubleByteCol": ["§" * 1500]}) + df.to_stata(temp_file, write_index=0, version=version) + df_input = read_stata(temp_file) + tm.assert_frame_equal(df, df_input) diff --git a/pandas/tests/libs/__init__.py b/pandas/tests/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/libs/test_hashtable.py b/pandas/tests/libs/test_hashtable.py new file mode 100644 index 0000000000000000000000000000000000000000..1f24f87348595b7416e6717f929b92f0c9df22c4 --- /dev/null +++ b/pandas/tests/libs/test_hashtable.py @@ -0,0 +1,782 @@ +from collections import namedtuple +from collections.abc import Generator +from contextlib import contextmanager +import re +import struct +import tracemalloc + +import numpy as np +import pytest + +from pandas._libs import hashtable as ht + +import pandas as pd +import pandas._testing as tm +from pandas.core.algorithms import isin + + +@contextmanager +def activated_tracemalloc() -> Generator[None, None, None]: + tracemalloc.start() + try: + yield + finally: + tracemalloc.stop() + + +def get_allocated_khash_memory(): + snapshot = tracemalloc.take_snapshot() + snapshot = snapshot.filter_traces( + (tracemalloc.DomainFilter(True, ht.get_hashtable_trace_domain()),) + ) + return sum(x.size for x in snapshot.traces) + + +@pytest.mark.parametrize( + "table_type, dtype", + [ + (ht.PyObjectHashTable, np.object_), + (ht.Complex128HashTable, np.complex128), + (ht.Int64HashTable, np.int64), + (ht.UInt64HashTable, np.uint64), + (ht.Float64HashTable, np.float64), + (ht.Complex64HashTable, np.complex64), + (ht.Int32HashTable, np.int32), + (ht.UInt32HashTable, np.uint32), + (ht.Float32HashTable, np.float32), + (ht.Int16HashTable, np.int16), + (ht.UInt16HashTable, np.uint16), + (ht.Int8HashTable, np.int8), + (ht.UInt8HashTable, np.uint8), + (ht.IntpHashTable, np.intp), + ], +) +class TestHashTable: + def test_get_set_contains_len(self, table_type, dtype): + index = 5 + table = table_type(55) + assert len(table) == 0 + assert index not in table + + table.set_item(index, 42) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 42 + + table.set_item(index + 1, 41) + assert index in table + assert index + 1 in table + assert len(table) == 2 + assert table.get_item(index) == 42 + assert table.get_item(index + 1) == 41 + + table.set_item(index, 21) + assert index in table + assert index + 1 in table + assert len(table) == 2 + assert table.get_item(index) == 21 + assert table.get_item(index + 1) == 41 + assert index + 2 not in table + + table.set_item(index + 1, 21) + assert index in table + assert index + 1 in table + assert len(table) == 2 + assert table.get_item(index) == 21 + assert table.get_item(index + 1) == 21 + + with pytest.raises(KeyError, match=str(index + 2)): + table.get_item(index + 2) + + def test_get_set_contains_len_mask(self, table_type, dtype): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + index = 5 + table = table_type(55, uses_mask=True) + assert len(table) == 0 + assert index not in table + + table.set_item(index, 42) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 42 + with pytest.raises(KeyError, match="NA"): + table.get_na() + + table.set_item(index + 1, 41) + table.set_na(41) + assert pd.NA in table + assert index in table + assert index + 1 in table + assert len(table) == 3 + assert table.get_item(index) == 42 + assert table.get_item(index + 1) == 41 + assert table.get_na() == 41 + + table.set_na(21) + assert index in table + assert index + 1 in table + assert len(table) == 3 + assert table.get_item(index + 1) == 41 + assert table.get_na() == 21 + assert index + 2 not in table + + with pytest.raises(KeyError, match=str(index + 2)): + table.get_item(index + 2) + + def test_map_keys_to_values(self, table_type, dtype, writable): + # only Int64HashTable has this method + if table_type == ht.Int64HashTable: + N = 77 + table = table_type() + keys = np.arange(N).astype(dtype) + vals = np.arange(N).astype(np.int64) + N + keys.flags.writeable = writable + vals.flags.writeable = writable + table.map_keys_to_values(keys, vals) + for i in range(N): + assert table.get_item(keys[i]) == i + N + + def test_map_locations(self, table_type, dtype, writable): + N = 8 + table = table_type() + keys = (np.arange(N) + N).astype(dtype) + keys.flags.writeable = writable + table.map_locations(keys) + for i in range(N): + assert table.get_item(keys[i]) == i + + def test_map_locations_mask(self, table_type, dtype, writable): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + N = 129 # must be > 128 to test GH#58924 + table = table_type(uses_mask=True) + keys = (np.arange(N) + N).astype(dtype) + keys.flags.writeable = writable + mask = np.concatenate([np.repeat(False, N - 1), [True]], axis=0) + table.map_locations(keys, mask) + for i in range(N - 1): + assert table.get_item(keys[i]) == i + + with pytest.raises(KeyError, match=re.escape(str(keys[N - 1]))): + table.get_item(keys[N - 1]) + + assert table.get_na() == N - 1 + + def test_lookup(self, table_type, dtype, writable): + N = 3 + table = table_type() + keys = (np.arange(N) + N).astype(dtype) + keys.flags.writeable = writable + table.map_locations(keys) + result = table.lookup(keys) + expected = np.arange(N) + tm.assert_numpy_array_equal(result.astype(np.int64), expected.astype(np.int64)) + + def test_lookup_wrong(self, table_type, dtype): + if dtype in (np.int8, np.uint8): + N = 100 + else: + N = 512 + table = table_type() + keys = (np.arange(N) + N).astype(dtype) + table.map_locations(keys) + wrong_keys = np.arange(N).astype(dtype) + result = table.lookup(wrong_keys) + assert np.all(result == -1) + + def test_lookup_mask(self, table_type, dtype, writable): + if table_type == ht.PyObjectHashTable: + pytest.skip("Mask not supported for object") + N = 3 + table = table_type(uses_mask=True) + keys = (np.arange(N) + N).astype(dtype) + mask = np.array([False, True, False]) + keys.flags.writeable = writable + table.map_locations(keys, mask) + result = table.lookup(keys, mask) + expected = np.arange(N) + tm.assert_numpy_array_equal(result.astype(np.int64), expected.astype(np.int64)) + + result = table.lookup(np.array([1 + N]).astype(dtype), np.array([False])) + tm.assert_numpy_array_equal( + result.astype(np.int64), np.array([-1], dtype=np.int64) + ) + + def test_unique(self, table_type, dtype, writable): + if dtype in (np.int8, np.uint8): + N = 88 + else: + N = 1000 + table = table_type() + expected = (np.arange(N) + N).astype(dtype) + keys = np.repeat(expected, 5) + keys.flags.writeable = writable + unique = table.unique(keys) + tm.assert_numpy_array_equal(unique, expected) + + def test_tracemalloc_works(self, table_type, dtype): + if dtype in (np.int8, np.uint8): + N = 256 + else: + N = 30000 + keys = np.arange(N).astype(dtype) + with activated_tracemalloc(): + table = table_type() + table.map_locations(keys) + used = get_allocated_khash_memory() + my_size = table.sizeof() + assert used == my_size + del table + assert get_allocated_khash_memory() == 0 + + def test_tracemalloc_for_empty(self, table_type, dtype): + with activated_tracemalloc(): + table = table_type() + used = get_allocated_khash_memory() + my_size = table.sizeof() + assert used == my_size + del table + assert get_allocated_khash_memory() == 0 + + def test_get_state(self, table_type, dtype): + table = table_type(1000) + state = table.get_state() + assert state["size"] == 0 + assert state["n_occupied"] == 0 + assert "n_buckets" in state + assert "upper_bound" in state + + @pytest.mark.parametrize("N", range(1, 110, 4)) + def test_no_reallocation(self, table_type, dtype, N): + keys = np.arange(N).astype(dtype) + preallocated_table = table_type(N) + n_buckets_start = preallocated_table.get_state()["n_buckets"] + preallocated_table.map_locations(keys) + n_buckets_end = preallocated_table.get_state()["n_buckets"] + # original number of buckets was enough: + assert n_buckets_start == n_buckets_end + # check with clean table (not too much preallocated) + clean_table = table_type() + clean_table.map_locations(keys) + assert n_buckets_start == clean_table.get_state()["n_buckets"] + + +class TestHashTableUnsorted: + # TODO: moved from test_algos; may be redundancies with other tests + def test_string_hashtable_set_item_signature(self): + # GH#30419 fix typing in StringHashTable.set_item to prevent segfault + tbl = ht.StringHashTable() + + tbl.set_item("key", 1) + assert tbl.get_item("key") == 1 + + with pytest.raises(TypeError, match="'key' has incorrect type"): + # key arg typed as string, not object + tbl.set_item(4, 6) + with pytest.raises(TypeError, match="'val' has incorrect type"): + tbl.get_item(4) + + def test_lookup_nan(self, writable): + # GH#21688 ensure we can deal with readonly memory views + xs = np.array([2.718, 3.14, np.nan, -7, 5, 2, 3]) + xs.setflags(write=writable) + m = ht.Float64HashTable() + m.map_locations(xs) + tm.assert_numpy_array_equal(m.lookup(xs), np.arange(len(xs), dtype=np.intp)) + + def test_add_signed_zeros(self): + # GH#21866 inconsistent hash-function for float64 + # default hash-function would lead to different hash-buckets + # for 0.0 and -0.0 if there are more than 2^30 hash-buckets + # but this would mean 16GB + N = 4 # 12 * 10**8 would trigger the error, if you have enough memory + m = ht.Float64HashTable(N) + m.set_item(0.0, 0) + m.set_item(-0.0, 0) + assert len(m) == 1 # 0.0 and -0.0 are equivalent + + def test_add_different_nans(self): + # GH#21866 inconsistent hash-function for float64 + # create different nans from bit-patterns: + NAN1 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000000))[0] + NAN2 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000001))[0] + assert NAN1 != NAN1 + assert NAN2 != NAN2 + # default hash function would lead to different hash-buckets + # for NAN1 and NAN2 even if there are only 4 buckets: + m = ht.Float64HashTable() + m.set_item(NAN1, 0) + m.set_item(NAN2, 0) + assert len(m) == 1 # NAN1 and NAN2 are equivalent + + def test_lookup_overflow(self, writable): + xs = np.array([1, 2, 2**63], dtype=np.uint64) + # GH 21688 ensure we can deal with readonly memory views + xs.setflags(write=writable) + m = ht.UInt64HashTable() + m.map_locations(xs) + tm.assert_numpy_array_equal(m.lookup(xs), np.arange(len(xs), dtype=np.intp)) + + @pytest.mark.parametrize("nvals", [0, 10]) # resizing to 0 is special case + @pytest.mark.parametrize( + "htable, uniques, dtype, safely_resizes", + [ + (ht.PyObjectHashTable, ht.ObjectVector, "object", False), + (ht.StringHashTable, ht.ObjectVector, "object", True), + (ht.Float64HashTable, ht.Float64Vector, "float64", False), + (ht.Int64HashTable, ht.Int64Vector, "int64", False), + (ht.Int32HashTable, ht.Int32Vector, "int32", False), + (ht.UInt64HashTable, ht.UInt64Vector, "uint64", False), + ], + ) + def test_vector_resize( + self, writable, htable, uniques, dtype, safely_resizes, nvals + ): + # Test for memory errors after internal vector + # reallocations (GH 7157) + # Changed from using np.random.default_rng(2).rand to range + # which could cause flaky CI failures when safely_resizes=False + vals = np.array(range(1000), dtype=dtype) + + # GH 21688 ensures we can deal with read-only memory views + vals.setflags(write=writable) + + # initialise instances; cannot initialise in parametrization, + # as otherwise external views would be held on the array (which is + # one of the things this test is checking) + htable = htable() + uniques = uniques() + + # get_labels may append to uniques + htable.get_labels(vals[:nvals], uniques, 0, -1) + # to_array() sets an external_view_exists flag on uniques. + tmp = uniques.to_array() + oldshape = tmp.shape + + # subsequent get_labels() calls can no longer append to it + # (except for StringHashTables + ObjectVector) + if safely_resizes: + htable.get_labels(vals, uniques, 0, -1) + else: + with pytest.raises(ValueError, match="external reference.*"): + htable.get_labels(vals, uniques, 0, -1) + + uniques.to_array() # should not raise here + assert tmp.shape == oldshape + + @pytest.mark.parametrize( + "hashtable", + [ + ht.PyObjectHashTable, + ht.StringHashTable, + ht.Float64HashTable, + ht.Int64HashTable, + ht.Int32HashTable, + ht.UInt64HashTable, + ], + ) + def test_hashtable_large_sizehint(self, hashtable): + # GH#22729 smoketest for not raising when passing a large size_hint + size_hint = np.iinfo(np.uint32).max + 1 + hashtable(size_hint=size_hint) + + +class TestPyObjectHashTableWithNans: + def test_nan_float(self): + nan1 = float("nan") + nan2 = float("nan") + assert nan1 is not nan2 + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + + def test_nan_complex_both(self): + nan1 = complex(float("nan"), float("nan")) + nan2 = complex(float("nan"), float("nan")) + assert nan1 is not nan2 + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + + def test_nan_complex_real(self): + nan1 = complex(float("nan"), 1) + nan2 = complex(float("nan"), 1) + other = complex(float("nan"), 2) + assert nan1 is not nan2 + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + with pytest.raises(KeyError, match=re.escape(repr(other))): + table.get_item(other) + + def test_nan_complex_imag(self): + nan1 = complex(1, float("nan")) + nan2 = complex(1, float("nan")) + other = complex(2, float("nan")) + assert nan1 is not nan2 + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + with pytest.raises(KeyError, match=re.escape(repr(other))): + table.get_item(other) + + def test_nan_in_tuple(self): + nan1 = (float("nan"),) + nan2 = (float("nan"),) + assert nan1[0] is not nan2[0] + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + + def test_nan_in_nested_tuple(self): + nan1 = (1, (2, (float("nan"),))) + nan2 = (1, (2, (float("nan"),))) + other = (1, 2) + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + with pytest.raises(KeyError, match=re.escape(repr(other))): + table.get_item(other) + + def test_nan_in_namedtuple(self): + T = namedtuple("T", ["x"]) + nan1 = T(float("nan")) + nan2 = T(float("nan")) + assert nan1.x is not nan2.x + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + + def test_nan_in_nested_namedtuple(self): + T = namedtuple("T", ["x", "y"]) + nan1 = T(1, (2, (float("nan"),))) + nan2 = T(1, (2, (float("nan"),))) + other = T(1, 2) + table = ht.PyObjectHashTable() + table.set_item(nan1, 42) + assert table.get_item(nan2) == 42 + with pytest.raises(KeyError, match=re.escape(repr(other))): + table.get_item(other) + + +def test_hash_equal_tuple_with_nans(): + a = (float("nan"), (float("nan"), float("nan"))) + b = (float("nan"), (float("nan"), float("nan"))) + assert ht.object_hash(a) == ht.object_hash(b) + assert ht.objects_are_equal(a, b) + + +def test_hash_equal_namedtuple_with_nans(): + T = namedtuple("T", ["x", "y"]) + a = T(float("nan"), (float("nan"), float("nan"))) + b = T(float("nan"), (float("nan"), float("nan"))) + assert ht.object_hash(a) == ht.object_hash(b) + assert ht.objects_are_equal(a, b) + + +def test_hash_equal_namedtuple_and_tuple(): + T = namedtuple("T", ["x", "y"]) + a = T(1, (2, 3)) + b = (1, (2, 3)) + assert ht.object_hash(a) == ht.object_hash(b) + assert ht.objects_are_equal(a, b) + + +def test_get_labels_groupby_for_Int64(writable): + table = ht.Int64HashTable() + vals = np.array([1, 2, -1, 2, 1, -1], dtype=np.int64) + vals.flags.writeable = writable + arr, unique = table.get_labels_groupby(vals) + expected_arr = np.array([0, 1, -1, 1, 0, -1], dtype=np.intp) + expected_unique = np.array([1, 2], dtype=np.int64) + tm.assert_numpy_array_equal(arr, expected_arr) + tm.assert_numpy_array_equal(unique, expected_unique) + + +def test_tracemalloc_works_for_StringHashTable(): + N = 1000 + keys = np.arange(N).astype(np.str_).astype(np.object_) + with activated_tracemalloc(): + table = ht.StringHashTable() + table.map_locations(keys) + used = get_allocated_khash_memory() + my_size = table.sizeof() + assert used == my_size + del table + assert get_allocated_khash_memory() == 0 + + +def test_tracemalloc_for_empty_StringHashTable(): + with activated_tracemalloc(): + table = ht.StringHashTable() + used = get_allocated_khash_memory() + my_size = table.sizeof() + assert used == my_size + del table + assert get_allocated_khash_memory() == 0 + + +@pytest.mark.parametrize("N", range(1, 110, 4)) +def test_no_reallocation_StringHashTable(N): + keys = np.arange(N).astype(np.str_).astype(np.object_) + preallocated_table = ht.StringHashTable(N) + n_buckets_start = preallocated_table.get_state()["n_buckets"] + preallocated_table.map_locations(keys) + n_buckets_end = preallocated_table.get_state()["n_buckets"] + # original number of buckets was enough: + assert n_buckets_start == n_buckets_end + # check with clean table (not too much preallocated) + clean_table = ht.StringHashTable() + clean_table.map_locations(keys) + assert n_buckets_start == clean_table.get_state()["n_buckets"] + + +@pytest.mark.parametrize( + "table_type, dtype", + [ + (ht.Float64HashTable, np.float64), + (ht.Float32HashTable, np.float32), + (ht.Complex128HashTable, np.complex128), + (ht.Complex64HashTable, np.complex64), + ], +) +class TestHashTableWithNans: + def test_get_set_contains_len(self, table_type, dtype): + index = float("nan") + table = table_type() + assert index not in table + + table.set_item(index, 42) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 42 + + table.set_item(index, 41) + assert len(table) == 1 + assert index in table + assert table.get_item(index) == 41 + + def test_map_locations(self, table_type, dtype): + N = 10 + table = table_type() + keys = np.full(N, np.nan, dtype=dtype) + table.map_locations(keys) + assert len(table) == 1 + assert table.get_item(np.nan) == N - 1 + + def test_unique(self, table_type, dtype): + N = 1020 + table = table_type() + keys = np.full(N, np.nan, dtype=dtype) + unique = table.unique(keys) + assert np.all(np.isnan(unique)) and len(unique) == 1 + + +def test_unique_for_nan_objects_floats(): + table = ht.PyObjectHashTable() + keys = np.array([float("nan") for i in range(50)], dtype=np.object_) + unique = table.unique(keys) + assert len(unique) == 1 + + +def test_unique_for_nan_objects_complex(): + table = ht.PyObjectHashTable() + keys = np.array([complex(float("nan"), 1.0) for i in range(50)], dtype=np.object_) + unique = table.unique(keys) + assert len(unique) == 1 + + +def test_unique_for_nan_objects_tuple(): + table = ht.PyObjectHashTable() + keys = np.array( + [1] + [(1.0, (float("nan"), 1.0)) for i in range(50)], dtype=np.object_ + ) + unique = table.unique(keys) + assert len(unique) == 2 + + +@pytest.mark.parametrize( + "dtype", + [ + np.object_, + np.complex128, + np.int64, + np.uint64, + np.float64, + np.complex64, + np.int32, + np.uint32, + np.float32, + np.int16, + np.uint16, + np.int8, + np.uint8, + np.intp, + ], +) +class TestHelpFunctions: + def test_value_count(self, dtype, writable): + N = 43 + expected = (np.arange(N) + N).astype(dtype) + values = np.repeat(expected, 5) + values.flags.writeable = writable + keys, counts, _ = ht.value_count(values, False) + tm.assert_numpy_array_equal(np.sort(keys), expected) + assert np.all(counts == 5) + + def test_value_count_mask(self, dtype): + if dtype == np.object_: + pytest.skip("mask not implemented for object dtype") + values = np.array([1] * 5, dtype=dtype) + mask = np.zeros((5,), dtype=np.bool_) + mask[1] = True + mask[4] = True + keys, counts, na_counter = ht.value_count(values, False, mask=mask) + assert len(keys) == 2 + assert na_counter == 2 + + def test_value_count_stable(self, dtype, writable): + # GH12679 + values = np.array([2, 1, 5, 22, 3, -1, 8]).astype(dtype) + values.flags.writeable = writable + keys, counts, _ = ht.value_count(values, False) + tm.assert_numpy_array_equal(keys, values) + assert np.all(counts == 1) + + def test_duplicated_first(self, dtype, writable): + N = 100 + values = np.repeat(np.arange(N).astype(dtype), 5) + values.flags.writeable = writable + result = ht.duplicated(values) + expected = np.ones_like(values, dtype=np.bool_) + expected[::5] = False + tm.assert_numpy_array_equal(result, expected) + + def test_ismember_yes(self, dtype, writable): + N = 127 + arr = np.arange(N).astype(dtype) + values = np.arange(N).astype(dtype) + arr.flags.writeable = writable + values.flags.writeable = writable + result = ht.ismember(arr, values) + expected = np.ones_like(values, dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + def test_ismember_no(self, dtype): + N = 17 + arr = np.arange(N).astype(dtype) + values = (np.arange(N) + N).astype(dtype) + result = ht.ismember(arr, values) + expected = np.zeros_like(values, dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + def test_mode(self, dtype, writable): + if dtype in (np.int8, np.uint8): + N = 53 + else: + N = 11111 + values = np.repeat(np.arange(N).astype(dtype), 5) + values[0] = 42 + values.flags.writeable = writable + result = ht.mode(values, False)[0] + assert result == 42 + + def test_mode_stable(self, dtype, writable): + values = np.array([2, 1, 5, 22, 3, -1, 8]).astype(dtype) + values.flags.writeable = writable + keys = ht.mode(values, False)[0] + tm.assert_numpy_array_equal(keys, values) + + +def test_modes_with_nans(): + # GH42688, nans aren't mangled + nulls = [pd.NA, np.nan, pd.NaT, None] + values = np.array([True] + nulls * 2, dtype=np.object_) + modes = ht.mode(values, False)[0] + assert modes.size == len(nulls) + + +def test_unique_label_indices_intp(writable): + keys = np.array([1, 2, 2, 2, 1, 3], dtype=np.intp) + keys.flags.writeable = writable + result = ht.unique_label_indices(keys) + expected = np.array([0, 1, 5], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + +def test_unique_label_indices(): + a = np.random.default_rng(2).integers(1, 1 << 10, 1 << 15).astype(np.intp) + + left = ht.unique_label_indices(a) + right = np.unique(a, return_index=True)[1] + + tm.assert_numpy_array_equal(left, right, check_dtype=False) + + a[np.random.default_rng(2).choice(len(a), 10)] = -1 + left = ht.unique_label_indices(a) + right = np.unique(a, return_index=True)[1][1:] + tm.assert_numpy_array_equal(left, right, check_dtype=False) + + +@pytest.mark.parametrize( + "dtype", + [ + np.float64, + np.float32, + np.complex128, + np.complex64, + ], +) +class TestHelpFunctionsWithNans: + def test_value_count(self, dtype): + values = np.array([np.nan, np.nan, np.nan], dtype=dtype) + keys, counts, _ = ht.value_count(values, True) + assert len(keys) == 0 + keys, counts, _ = ht.value_count(values, False) + assert len(keys) == 1 and np.all(np.isnan(keys)) + assert counts[0] == 3 + + def test_duplicated_first(self, dtype): + values = np.array([np.nan, np.nan, np.nan], dtype=dtype) + result = ht.duplicated(values) + expected = np.array([False, True, True]) + tm.assert_numpy_array_equal(result, expected) + + def test_ismember_yes(self, dtype): + arr = np.array([np.nan, np.nan, np.nan], dtype=dtype) + values = np.array([np.nan, np.nan], dtype=dtype) + result = ht.ismember(arr, values) + expected = np.array([True, True, True], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + def test_ismember_no(self, dtype): + arr = np.array([np.nan, np.nan, np.nan], dtype=dtype) + values = np.array([1], dtype=dtype) + result = ht.ismember(arr, values) + expected = np.array([False, False, False], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + def test_mode(self, dtype): + values = np.array([42, np.nan, np.nan, np.nan], dtype=dtype) + assert ht.mode(values, True)[0] == 42 + assert np.isnan(ht.mode(values, False)[0]) + + +def test_ismember_tuple_with_nans(): + # GH-41836 + values = np.empty(2, dtype=object) + values[:] = [("a", float("nan")), ("b", 1)] + comps = [("a", float("nan"))] + + result = isin(values, comps) + expected = np.array([True, False], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) + + +def test_float_complex_int_are_equal_as_objects(): + values = ["a", 5, 5.0, 5.0 + 0j] + comps = list(range(129)) + result = isin(np.array(values, dtype=object), np.asarray(comps)) + expected = np.array([False, True, True, True], dtype=np.bool_) + tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/libs/test_join.py b/pandas/tests/libs/test_join.py new file mode 100644 index 0000000000000000000000000000000000000000..bf8b4fabc54cbab338e9b4a3a061eb09b19714da --- /dev/null +++ b/pandas/tests/libs/test_join.py @@ -0,0 +1,388 @@ +import numpy as np +import pytest + +from pandas._libs import join as libjoin +from pandas._libs.join import ( + inner_join, + left_outer_join, +) + +import pandas._testing as tm + + +class TestIndexer: + @pytest.mark.parametrize( + "dtype", ["int32", "int64", "float32", "float64", "object"] + ) + def test_outer_join_indexer(self, dtype): + indexer = libjoin.outer_join_indexer + + left = np.arange(3, dtype=dtype) + right = np.arange(2, 5, dtype=dtype) + empty = np.array([], dtype=dtype) + + result, lindexer, rindexer = indexer(left, right) + assert isinstance(result, np.ndarray) + assert isinstance(lindexer, np.ndarray) + assert isinstance(rindexer, np.ndarray) + tm.assert_numpy_array_equal(result, np.arange(5, dtype=dtype)) + exp = np.array([0, 1, 2, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(lindexer, exp) + exp = np.array([-1, -1, 0, 1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(rindexer, exp) + + result, lindexer, rindexer = indexer(empty, right) + tm.assert_numpy_array_equal(result, right) + exp = np.array([-1, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(lindexer, exp) + exp = np.array([0, 1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(rindexer, exp) + + result, lindexer, rindexer = indexer(left, empty) + tm.assert_numpy_array_equal(result, left) + exp = np.array([0, 1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(lindexer, exp) + exp = np.array([-1, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(rindexer, exp) + + def test_cython_left_outer_join(self): + left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp) + right = np.array([1, 1, 0, 4, 2, 2, 1], dtype=np.intp) + max_group = 5 + + ls, rs = left_outer_join(left, right, max_group) + + exp_ls = left.argsort(kind="mergesort") + exp_rs = right.argsort(kind="mergesort") + + exp_li = np.array([0, 1, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10]) + exp_ri = np.array( + [0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 4, 5, 4, 5, -1, -1] + ) + + exp_ls = exp_ls.take(exp_li) + exp_ls[exp_li == -1] = -1 + + exp_rs = exp_rs.take(exp_ri) + exp_rs[exp_ri == -1] = -1 + + tm.assert_numpy_array_equal(ls, exp_ls, check_dtype=False) + tm.assert_numpy_array_equal(rs, exp_rs, check_dtype=False) + + def test_cython_right_outer_join(self): + left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp) + right = np.array([1, 1, 0, 4, 2, 2, 1], dtype=np.intp) + max_group = 5 + + rs, ls = left_outer_join(right, left, max_group) + + exp_ls = left.argsort(kind="mergesort") + exp_rs = right.argsort(kind="mergesort") + + # 0 1 1 1 + exp_li = np.array( + [ + 0, + 1, + 2, + 3, + 4, + 5, + 3, + 4, + 5, + 3, + 4, + 5, + # 2 2 4 + 6, + 7, + 8, + 6, + 7, + 8, + -1, + ] + ) + exp_ri = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6]) + + exp_ls = exp_ls.take(exp_li) + exp_ls[exp_li == -1] = -1 + + exp_rs = exp_rs.take(exp_ri) + exp_rs[exp_ri == -1] = -1 + + tm.assert_numpy_array_equal(ls, exp_ls) + tm.assert_numpy_array_equal(rs, exp_rs) + + def test_cython_inner_join(self): + left = np.array([0, 1, 2, 1, 2, 0, 0, 1, 2, 3, 3], dtype=np.intp) + right = np.array([1, 1, 0, 4, 2, 2, 1, 4], dtype=np.intp) + max_group = 5 + + ls, rs = inner_join(left, right, max_group) + + exp_ls = left.argsort(kind="mergesort") + exp_rs = right.argsort(kind="mergesort") + + exp_li = np.array([0, 1, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8]) + exp_ri = np.array([0, 0, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 4, 5, 4, 5]) + + exp_ls = exp_ls.take(exp_li) + exp_ls[exp_li == -1] = -1 + + exp_rs = exp_rs.take(exp_ri) + exp_rs[exp_ri == -1] = -1 + + tm.assert_numpy_array_equal(ls, exp_ls) + tm.assert_numpy_array_equal(rs, exp_rs) + + +def test_left_join_indexer_unique(writable): + a = np.array([1, 2, 3, 4, 5], dtype=np.int64) + b = np.array([2, 2, 3, 4, 4], dtype=np.int64) + # GH#37312, GH#37264 + a.setflags(write=writable) + b.setflags(write=writable) + + result = libjoin.left_join_indexer_unique(b, a) + expected = np.array([1, 1, 2, 3, 3], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + +def test_left_outer_join_bug(): + left = np.array( + [ + 0, + 1, + 0, + 1, + 1, + 2, + 3, + 1, + 0, + 2, + 1, + 2, + 0, + 1, + 1, + 2, + 3, + 2, + 3, + 2, + 1, + 1, + 3, + 0, + 3, + 2, + 3, + 0, + 0, + 2, + 3, + 2, + 0, + 3, + 1, + 3, + 0, + 1, + 3, + 0, + 0, + 1, + 0, + 3, + 1, + 0, + 1, + 0, + 1, + 1, + 0, + 2, + 2, + 2, + 2, + 2, + 0, + 3, + 1, + 2, + 0, + 0, + 3, + 1, + 3, + 2, + 2, + 0, + 1, + 3, + 0, + 2, + 3, + 2, + 3, + 3, + 2, + 3, + 3, + 1, + 3, + 2, + 0, + 0, + 3, + 1, + 1, + 1, + 0, + 2, + 3, + 3, + 1, + 2, + 0, + 3, + 1, + 2, + 0, + 2, + ], + dtype=np.intp, + ) + + right = np.array([3, 1], dtype=np.intp) + max_groups = 4 + + lidx, ridx = libjoin.left_outer_join(left, right, max_groups, sort=False) + + exp_lidx = np.arange(len(left), dtype=np.intp) + exp_ridx = -np.ones(len(left), dtype=np.intp) + + exp_ridx[left == 1] = 1 + exp_ridx[left == 3] = 0 + + tm.assert_numpy_array_equal(lidx, exp_lidx) + tm.assert_numpy_array_equal(ridx, exp_ridx) + + +def test_inner_join_indexer(): + a = np.array([1, 2, 3, 4, 5], dtype=np.int64) + b = np.array([0, 3, 5, 7, 9], dtype=np.int64) + + index, ares, bres = libjoin.inner_join_indexer(a, b) + + index_exp = np.array([3, 5], dtype=np.int64) + tm.assert_almost_equal(index, index_exp) + + aexp = np.array([2, 4], dtype=np.intp) + bexp = np.array([1, 2], dtype=np.intp) + tm.assert_almost_equal(ares, aexp) + tm.assert_almost_equal(bres, bexp) + + a = np.array([5], dtype=np.int64) + b = np.array([5], dtype=np.int64) + + index, ares, bres = libjoin.inner_join_indexer(a, b) + tm.assert_numpy_array_equal(index, np.array([5], dtype=np.int64)) + tm.assert_numpy_array_equal(ares, np.array([0], dtype=np.intp)) + tm.assert_numpy_array_equal(bres, np.array([0], dtype=np.intp)) + + +def test_outer_join_indexer(): + a = np.array([1, 2, 3, 4, 5], dtype=np.int64) + b = np.array([0, 3, 5, 7, 9], dtype=np.int64) + + index, ares, bres = libjoin.outer_join_indexer(a, b) + + index_exp = np.array([0, 1, 2, 3, 4, 5, 7, 9], dtype=np.int64) + tm.assert_almost_equal(index, index_exp) + + aexp = np.array([-1, 0, 1, 2, 3, 4, -1, -1], dtype=np.intp) + bexp = np.array([0, -1, -1, 1, -1, 2, 3, 4], dtype=np.intp) + tm.assert_almost_equal(ares, aexp) + tm.assert_almost_equal(bres, bexp) + + a = np.array([5], dtype=np.int64) + b = np.array([5], dtype=np.int64) + + index, ares, bres = libjoin.outer_join_indexer(a, b) + tm.assert_numpy_array_equal(index, np.array([5], dtype=np.int64)) + tm.assert_numpy_array_equal(ares, np.array([0], dtype=np.intp)) + tm.assert_numpy_array_equal(bres, np.array([0], dtype=np.intp)) + + +def test_left_join_indexer(): + a = np.array([1, 2, 3, 4, 5], dtype=np.int64) + b = np.array([0, 3, 5, 7, 9], dtype=np.int64) + + index, ares, bres = libjoin.left_join_indexer(a, b) + + tm.assert_almost_equal(index, a) + + aexp = np.array([0, 1, 2, 3, 4], dtype=np.intp) + bexp = np.array([-1, -1, 1, -1, 2], dtype=np.intp) + tm.assert_almost_equal(ares, aexp) + tm.assert_almost_equal(bres, bexp) + + a = np.array([5], dtype=np.int64) + b = np.array([5], dtype=np.int64) + + index, ares, bres = libjoin.left_join_indexer(a, b) + tm.assert_numpy_array_equal(index, np.array([5], dtype=np.int64)) + tm.assert_numpy_array_equal(ares, np.array([0], dtype=np.intp)) + tm.assert_numpy_array_equal(bres, np.array([0], dtype=np.intp)) + + +def test_left_join_indexer2(): + idx = np.array([1, 1, 2, 5], dtype=np.int64) + idx2 = np.array([1, 2, 5, 7, 9], dtype=np.int64) + + res, lidx, ridx = libjoin.left_join_indexer(idx2, idx) + + exp_res = np.array([1, 1, 2, 5, 7, 9], dtype=np.int64) + tm.assert_almost_equal(res, exp_res) + + exp_lidx = np.array([0, 0, 1, 2, 3, 4], dtype=np.intp) + tm.assert_almost_equal(lidx, exp_lidx) + + exp_ridx = np.array([0, 1, 2, 3, -1, -1], dtype=np.intp) + tm.assert_almost_equal(ridx, exp_ridx) + + +def test_outer_join_indexer2(): + idx = np.array([1, 1, 2, 5], dtype=np.int64) + idx2 = np.array([1, 2, 5, 7, 9], dtype=np.int64) + + res, lidx, ridx = libjoin.outer_join_indexer(idx2, idx) + + exp_res = np.array([1, 1, 2, 5, 7, 9], dtype=np.int64) + tm.assert_almost_equal(res, exp_res) + + exp_lidx = np.array([0, 0, 1, 2, 3, 4], dtype=np.intp) + tm.assert_almost_equal(lidx, exp_lidx) + + exp_ridx = np.array([0, 1, 2, 3, -1, -1], dtype=np.intp) + tm.assert_almost_equal(ridx, exp_ridx) + + +def test_inner_join_indexer2(): + idx = np.array([1, 1, 2, 5], dtype=np.int64) + idx2 = np.array([1, 2, 5, 7, 9], dtype=np.int64) + + res, lidx, ridx = libjoin.inner_join_indexer(idx2, idx) + + exp_res = np.array([1, 1, 2, 5], dtype=np.int64) + tm.assert_almost_equal(res, exp_res) + + exp_lidx = np.array([0, 0, 1, 2], dtype=np.intp) + tm.assert_almost_equal(lidx, exp_lidx) + + exp_ridx = np.array([0, 1, 2, 3], dtype=np.intp) + tm.assert_almost_equal(ridx, exp_ridx) diff --git a/pandas/tests/libs/test_lib.py b/pandas/tests/libs/test_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..792b2ef121cf2f0bd8266050217031b1c04f06c2 --- /dev/null +++ b/pandas/tests/libs/test_lib.py @@ -0,0 +1,358 @@ +import pickle + +import numpy as np +import pytest + +from pandas._libs import ( + Timedelta, + lib, + writers as libwriters, +) +from pandas.compat import IS64 + +from pandas import Index +import pandas._testing as tm + + +class TestMisc: + def test_max_len_string_array(self): + arr = a = np.array(["foo", "b", np.nan], dtype="object") + assert libwriters.max_len_string_array(arr) == 3 + + # unicode + arr = a.astype("U").astype(object) + assert libwriters.max_len_string_array(arr) == 3 + + # bytes for python3 + arr = a.astype("S").astype(object) + assert libwriters.max_len_string_array(arr) == 3 + + # raises + msg = "No matching signature found" + with pytest.raises(TypeError, match=msg): + libwriters.max_len_string_array(arr.astype("U")) + + def test_fast_unique_multiple_list_gen_sort(self): + keys = [["p", "a"], ["n", "d"], ["a", "s"]] + + gen = (key for key in keys) + expected = np.array(["a", "d", "n", "p", "s"]) + out = lib.fast_unique_multiple_list_gen(gen, sort=True) + tm.assert_numpy_array_equal(np.array(out), expected) + + gen = (key for key in keys) + expected = np.array(["p", "a", "n", "d", "s"]) + out = lib.fast_unique_multiple_list_gen(gen, sort=False) + tm.assert_numpy_array_equal(np.array(out), expected) + + def test_fast_multiget_timedelta_resos(self): + # This will become relevant for test_constructor_dict_timedelta64_index + # once Timedelta constructor preserves reso when passed a + # np.timedelta64 object + td = Timedelta(days=1) + + mapping1 = {td: 1} + mapping2 = {td.as_unit("s"): 1} + + oindex = Index([td * n for n in range(3)])._values.astype(object) + + expected = lib.fast_multiget(mapping1, oindex) + result = lib.fast_multiget(mapping2, oindex) + tm.assert_numpy_array_equal(result, expected) + + # case that can't be cast to td64ns + td = Timedelta(np.timedelta64(146000, "D")) + assert hash(td) == hash(td.as_unit("ms")) + assert hash(td) == hash(td.as_unit("us")) + mapping1 = {td: 1} + mapping2 = {td.as_unit("ms"): 1} + + oindex = Index([td * n for n in range(3)])._values.astype(object) + + expected = lib.fast_multiget(mapping1, oindex) + result = lib.fast_multiget(mapping2, oindex) + tm.assert_numpy_array_equal(result, expected) + + +class TestIndexing: + def test_maybe_indices_to_slice_left_edge(self): + target = np.arange(100) + + # slice + indices = np.array([], dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("end", [1, 2, 5, 20, 99]) + @pytest.mark.parametrize("step", [1, 2, 4]) + def test_maybe_indices_to_slice_left_edge_not_slice_end_steps(self, end, step): + target = np.arange(100) + indices = np.arange(0, end, step, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + # reverse + indices = indices[::-1] + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize( + "case", [[2, 1, 2, 0], [2, 2, 1, 0], [0, 1, 2, 1], [-2, 0, 2], [2, 0, -2]] + ) + def test_maybe_indices_to_slice_left_edge_not_slice(self, case): + # not slice + target = np.arange(100) + indices = np.array(case, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("start", [0, 2, 5, 20, 97, 98]) + @pytest.mark.parametrize("step", [1, 2, 4]) + def test_maybe_indices_to_slice_right_edge(self, start, step): + target = np.arange(100) + + # slice + indices = np.arange(start, 99, step, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + # reverse + indices = indices[::-1] + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + def test_maybe_indices_to_slice_right_edge_not_slice(self): + # not slice + target = np.arange(100) + indices = np.array([97, 98, 99, 100], dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + + msg = "index 100 is out of bounds for axis (0|1) with size 100" + + with pytest.raises(IndexError, match=msg): + target[indices] + with pytest.raises(IndexError, match=msg): + target[maybe_slice] + + indices = np.array([100, 99, 98, 97], dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + + with pytest.raises(IndexError, match=msg): + target[indices] + with pytest.raises(IndexError, match=msg): + target[maybe_slice] + + @pytest.mark.parametrize( + "case", [[99, 97, 99, 96], [99, 99, 98, 97], [98, 98, 97, 96]] + ) + def test_maybe_indices_to_slice_right_edge_cases(self, case): + target = np.arange(100) + indices = np.array(case, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("step", [1, 2, 4, 5, 8, 9]) + def test_maybe_indices_to_slice_both_edges(self, step): + target = np.arange(10) + + # slice + indices = np.arange(0, 9, step, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + # reverse + indices = indices[::-1] + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("case", [[4, 2, 0, -2], [2, 2, 1, 0], [0, 1, 2, 1]]) + def test_maybe_indices_to_slice_both_edges_not_slice(self, case): + # not slice + target = np.arange(10) + indices = np.array(case, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize("start, end", [(2, 10), (5, 25), (65, 97)]) + @pytest.mark.parametrize("step", [1, 2, 4, 20]) + def test_maybe_indices_to_slice_middle(self, start, end, step): + target = np.arange(100) + + # slice + indices = np.arange(start, end, step, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + # reverse + indices = indices[::-1] + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + @pytest.mark.parametrize( + "case", [[14, 12, 10, 12], [12, 12, 11, 10], [10, 11, 12, 11]] + ) + def test_maybe_indices_to_slice_middle_not_slice(self, case): + # not slice + target = np.arange(100) + indices = np.array(case, dtype=np.intp) + maybe_slice = lib.maybe_indices_to_slice(indices, len(target)) + + assert not isinstance(maybe_slice, slice) + tm.assert_numpy_array_equal(maybe_slice, indices) + tm.assert_numpy_array_equal(target[indices], target[maybe_slice]) + + def test_maybe_booleans_to_slice(self): + arr = np.array([0, 0, 1, 1, 1, 0, 1], dtype=np.uint8) + result = lib.maybe_booleans_to_slice(arr) + assert result.dtype == np.bool_ + + result = lib.maybe_booleans_to_slice(arr[:0]) + assert result == slice(0, 0) + + def test_get_reverse_indexer(self): + indexer = np.array([-1, -1, 1, 2, 0, -1, 3, 4], dtype=np.intp) + result = lib.get_reverse_indexer(indexer, 5) + expected = np.array([4, 2, 3, 6, 7], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["int64", "int32"]) + def test_is_range_indexer(self, dtype): + # GH#50592 + left = np.arange(0, 100, dtype=dtype) + assert lib.is_range_indexer(left, 100) + + @pytest.mark.skipif( + not IS64, + reason="2**31 is too big for Py_ssize_t on 32-bit. " + "It doesn't matter though since you cannot create an array that long on 32-bit", + ) + @pytest.mark.parametrize("dtype", ["int64", "int32"]) + def test_is_range_indexer_big_n(self, dtype): + # GH53616 + left = np.arange(0, 100, dtype=dtype) + + assert not lib.is_range_indexer(left, 2**31) + + @pytest.mark.parametrize("dtype", ["int64", "int32"]) + def test_is_range_indexer_not_equal(self, dtype): + # GH#50592 + left = np.array([1, 2], dtype=dtype) + assert not lib.is_range_indexer(left, 2) + + @pytest.mark.parametrize("dtype", ["int64", "int32"]) + def test_is_range_indexer_not_equal_shape(self, dtype): + # GH#50592 + left = np.array([0, 1, 2], dtype=dtype) + assert not lib.is_range_indexer(left, 2) + + +def test_cache_readonly_preserve_docstrings(): + # GH18197 + assert Index.hasnans.__doc__ is not None + + +def test_no_default_pickle(temp_file): + # GH#40397 + obj = tm.round_trip_pickle(lib.no_default, temp_file) + assert obj is lib.no_default + + +def test_ensure_string_array_copy(): + # ensure the original array is not modified in case of copy=False with + # pickle-roundtripped object dtype array + # https://github.com/pandas-dev/pandas/issues/54654 + arr = np.array(["a", None], dtype=object) + arr = pickle.loads(pickle.dumps(arr)) + result = lib.ensure_string_array(arr, copy=False) + assert not np.shares_memory(arr, result) + assert arr[1] is None + assert result[1] is np.nan + + +def test_ensure_string_array_list_of_lists(): + # GH#61155: ensure list of lists doesn't get converted to string + arr = [list("test"), list("word")] + result = lib.ensure_string_array(arr) + + # Each item in result should still be a list, not a stringified version + expected = np.array(["['t', 'e', 's', 't']", "['w', 'o', 'r', 'd']"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +def test_item_from_zerodim_for_subclasses(): + # GH#62981 Ensure item_from_zerodim preserves subclasses of ndarray + # Define a custom ndarray subclass + class TestArray(np.ndarray): + def __new__(cls, input_array): + return np.asarray(input_array).view(cls) + + def __array_finalize__(self, obj) -> None: + self._is_test_array = True + + # Define test data + val_0_dim = 1 + val_1_dim = [1, 2, 3] + + # 0-dim and 1-dim numpy arrays + arr_0_dim = np.array(val_0_dim) + arr_1_dim = np.array(val_1_dim) + + # 0-dim and 1-dim TestArray arrays + test_arr_0_dim = TestArray(val_0_dim) + test_arr_1_dim = TestArray(val_1_dim) + + # Check that behavior did not change for regular numpy arrays + # Test with regular numpy 0-dim array + result = lib.item_from_zerodim(arr_0_dim) + expected = val_0_dim + assert result == expected + assert np.isscalar(result) + + # Test with regular numpy 1-dim array + result = lib.item_from_zerodim(arr_1_dim) + expected = arr_1_dim + tm.assert_numpy_array_equal(result, expected) + assert isinstance(result, np.ndarray) + + # Check that behaviour for subclasses now is as expected + # Test with TestArray 0-dim array + result = lib.item_from_zerodim(test_arr_0_dim) + expected = test_arr_0_dim + assert result == expected + assert isinstance(result, TestArray) + + # Test with TestArray 1-dim array + result = lib.item_from_zerodim(test_arr_1_dim) + expected = test_arr_1_dim + assert np.all(result == expected) + assert isinstance(result, TestArray) diff --git a/pandas/tests/libs/test_libalgos.py b/pandas/tests/libs/test_libalgos.py new file mode 100644 index 0000000000000000000000000000000000000000..42d09c72aab2baa9636093d172d864cbe0e41b12 --- /dev/null +++ b/pandas/tests/libs/test_libalgos.py @@ -0,0 +1,162 @@ +from datetime import datetime +from itertools import permutations + +import numpy as np + +from pandas._libs import algos as libalgos + +import pandas._testing as tm + + +def test_ensure_platform_int(): + arr = np.arange(100, dtype=np.intp) + + result = libalgos.ensure_platform_int(arr) + assert result is arr + + +def test_is_lexsorted(): + failure = [ + np.array( + ([3] * 32) + ([2] * 32) + ([1] * 32) + ([0] * 32), + dtype="int64", + ), + np.array( + list(range(31))[::-1] * 4, + dtype="int64", + ), + ] + + assert not libalgos.is_lexsorted(failure) + + +def test_groupsort_indexer(): + a = np.random.default_rng(2).integers(0, 1000, 100).astype(np.intp) + b = np.random.default_rng(2).integers(0, 1000, 100).astype(np.intp) + + result = libalgos.groupsort_indexer(a, 1000)[0] + + # need to use a stable sort + # np.argsort returns int, groupsort_indexer + # always returns intp + expected = np.argsort(a, kind="mergesort") + expected = expected.astype(np.intp) + + tm.assert_numpy_array_equal(result, expected) + + # compare with lexsort + # np.lexsort returns int, groupsort_indexer + # always returns intp + key = a * 1000 + b + result = libalgos.groupsort_indexer(key, 1000000)[0] + expected = np.lexsort((b, a)) + expected = expected.astype(np.intp) + + tm.assert_numpy_array_equal(result, expected) + + +class TestPadBackfill: + def test_backfill(self): + old = np.array([1, 5, 10], dtype=np.int64) + new = np.array(list(range(12)), dtype=np.int64) + + filler = libalgos.backfill["int64_t"](old, new) + + expect_filler = np.array([0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, -1], dtype=np.intp) + tm.assert_numpy_array_equal(filler, expect_filler) + + # corner case + old = np.array([1, 4], dtype=np.int64) + new = np.array(list(range(5, 10)), dtype=np.int64) + filler = libalgos.backfill["int64_t"](old, new) + + expect_filler = np.array([-1, -1, -1, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(filler, expect_filler) + + def test_pad(self): + old = np.array([1, 5, 10], dtype=np.int64) + new = np.array(list(range(12)), dtype=np.int64) + + filler = libalgos.pad["int64_t"](old, new) + + expect_filler = np.array([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2], dtype=np.intp) + tm.assert_numpy_array_equal(filler, expect_filler) + + # corner case + old = np.array([5, 10], dtype=np.int64) + new = np.arange(5, dtype=np.int64) + filler = libalgos.pad["int64_t"](old, new) + expect_filler = np.array([-1, -1, -1, -1, -1], dtype=np.intp) + tm.assert_numpy_array_equal(filler, expect_filler) + + def test_pad_backfill_object_segfault(self): + old = np.array([], dtype="O") + new = np.array([datetime(2010, 12, 31)], dtype="O") + + result = libalgos.pad["object"](old, new) + expected = np.array([-1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + result = libalgos.pad["object"](new, old) + expected = np.array([], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + result = libalgos.backfill["object"](old, new) + expected = np.array([-1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + result = libalgos.backfill["object"](new, old) + expected = np.array([], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + + +class TestInfinity: + def test_infinity_sort(self): + # GH#13445 + # numpy's argsort can be unhappy if something is less than + # itself. Instead, let's give our infinities a self-consistent + # ordering, but outside the float extended real line. + + Inf = libalgos.Infinity() + NegInf = libalgos.NegInfinity() + + ref_nums = [NegInf, float("-inf"), -1e100, 0, 1e100, float("inf"), Inf] + + assert all(Inf >= x for x in ref_nums) + assert all(Inf > x or x is Inf for x in ref_nums) + assert Inf >= Inf and Inf == Inf + assert not Inf < Inf and not Inf > Inf + assert libalgos.Infinity() == libalgos.Infinity() + assert not libalgos.Infinity() != libalgos.Infinity() + + assert all(NegInf <= x for x in ref_nums) + assert all(NegInf < x or x is NegInf for x in ref_nums) + assert NegInf <= NegInf and NegInf == NegInf + assert not NegInf < NegInf and not NegInf > NegInf + assert libalgos.NegInfinity() == libalgos.NegInfinity() + assert not libalgos.NegInfinity() != libalgos.NegInfinity() + + for perm in permutations(ref_nums): + assert sorted(perm) == ref_nums + + # smoke tests + np.array([libalgos.Infinity()] * 32).argsort() + np.array([libalgos.NegInfinity()] * 32).argsort() + + def test_infinity_against_nan(self): + Inf = libalgos.Infinity() + NegInf = libalgos.NegInfinity() + + assert not Inf > np.nan + assert not Inf >= np.nan + assert not Inf < np.nan + assert not Inf <= np.nan + assert not Inf == np.nan + assert Inf != np.nan + + assert not NegInf > np.nan + assert not NegInf >= np.nan + assert not NegInf < np.nan + assert not NegInf <= np.nan + assert not NegInf == np.nan + assert NegInf != np.nan diff --git a/pandas/tests/plotting/__init__.py b/pandas/tests/plotting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/plotting/common.py b/pandas/tests/plotting/common.py new file mode 100644 index 0000000000000000000000000000000000000000..588bbf88e856243ce539a5378b22468228bcf280 --- /dev/null +++ b/pandas/tests/plotting/common.py @@ -0,0 +1,579 @@ +""" +Module consolidating common testing functions for checking plotting. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from pandas.core.dtypes.api import is_list_like + +import pandas as pd +from pandas import Series +import pandas._testing as tm + +if TYPE_CHECKING: + from collections.abc import Sequence + + from matplotlib.axes import Axes + + +def _check_legend_labels(axes, labels=None, visible=True): + """ + Check each axes has expected legend labels + + Parameters + ---------- + axes : matplotlib Axes object, or its list-like + labels : list-like + expected legend labels + visible : bool + expected legend visibility. labels are checked only when visible is + True + """ + if visible and (labels is None): + raise ValueError("labels must be specified when visible is True") + axes = _flatten_visible(axes) + for ax in axes: + if visible: + assert ax.get_legend() is not None + _check_text_labels(ax.get_legend().get_texts(), labels) + else: + assert ax.get_legend() is None + + +def _check_legend_marker(ax, expected_markers=None, visible=True): + """ + Check ax has expected legend markers + + Parameters + ---------- + ax : matplotlib Axes object + expected_markers : list-like + expected legend markers + visible : bool + expected legend visibility. labels are checked only when visible is + True + """ + if visible and (expected_markers is None): + raise ValueError("Markers must be specified when visible is True") + if visible: + handles, _ = ax.get_legend_handles_labels() + markers = [handle.get_marker() for handle in handles] + assert markers == expected_markers + else: + assert ax.get_legend() is None + + +def _check_data(xp, rs): + """ + Check each axes has identical lines + + Parameters + ---------- + xp : matplotlib Axes object + rs : matplotlib Axes object + """ + xp_lines = xp.get_lines() + rs_lines = rs.get_lines() + + assert len(xp_lines) == len(rs_lines) + for xpl, rsl in zip(xp_lines, rs_lines, strict=True): + xpdata = xpl.get_xydata() + rsdata = rsl.get_xydata() + tm.assert_almost_equal(xpdata, rsdata) + + +def _check_visible(collections, visible=True): + """ + Check each artist is visible or not + + Parameters + ---------- + collections : matplotlib Artist or its list-like + target Artist or its list or collection + visible : bool + expected visibility + """ + from matplotlib.collections import Collection + + if not isinstance(collections, Collection) and not is_list_like(collections): + collections = [collections] + + for patch in collections: + assert patch.get_visible() == visible + + +def _check_patches_all_filled(axes: Axes | Sequence[Axes], filled: bool = True) -> None: + """ + Check for each artist whether it is filled or not + + Parameters + ---------- + axes : matplotlib Axes object, or its list-like + filled : bool + expected filling + """ + + axes = _flatten_visible(axes) + for ax in axes: + for patch in ax.patches: + assert patch.fill == filled + + +def _get_colors_mapped(series, colors): + unique = series.unique() + # unique and colors length can be differed + # depending on slice value + mapped = dict(zip(unique, colors)) + return [mapped[v] for v in series.values] + + +def _check_colors(collections, linecolors=None, facecolors=None, mapping=None): + """ + Check each artist has expected line colors and face colors + + Parameters + ---------- + collections : list-like + list or collection of target artist + linecolors : list-like which has the same length as collections + list of expected line colors + facecolors : list-like which has the same length as collections + list of expected face colors + mapping : Series + Series used for color grouping key + used for andrew_curves, parallel_coordinates, radviz test + """ + from matplotlib import colors + from matplotlib.collections import ( + Collection, + LineCollection, + PolyCollection, + ) + from matplotlib.lines import Line2D + + conv = colors.ColorConverter + if linecolors is not None: + if mapping is not None: + linecolors = _get_colors_mapped(mapping, linecolors) + linecolors = linecolors[: len(collections)] + + assert len(collections) == len(linecolors) + for patch, color in zip(collections, linecolors, strict=True): + if isinstance(patch, Line2D): + result = patch.get_color() + # Line2D may contains string color expression + result = conv.to_rgba(result) + elif isinstance(patch, (PolyCollection, LineCollection)): + result = tuple(patch.get_edgecolor()[0]) + else: + result = patch.get_edgecolor() + + expected = conv.to_rgba(color) + assert result == expected + + if facecolors is not None: + if mapping is not None: + facecolors = _get_colors_mapped(mapping, facecolors) + facecolors = facecolors[: len(collections)] + + assert len(collections) == len(facecolors) + for patch, color in zip(collections, facecolors, strict=True): + if isinstance(patch, Collection): + # returned as list of np.array + result = patch.get_facecolor()[0] + else: + result = patch.get_facecolor() + + if isinstance(result, np.ndarray): + result = tuple(result) + + expected = conv.to_rgba(color) + assert result == expected + + +def _check_text_labels(texts, expected): + """ + Check each text has expected labels + + Parameters + ---------- + texts : matplotlib Text object, or its list-like + target text, or its list + expected : str or list-like which has the same length as texts + expected text label, or its list + """ + if not is_list_like(texts): + assert texts.get_text() == expected + else: + labels = [t.get_text() for t in texts] + assert len(labels) == len(expected) + for label, e in zip(labels, expected, strict=True): + assert label == e + + +def _check_ticks_props(axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None): + """ + Check each axes has expected tick properties + + Parameters + ---------- + axes : matplotlib Axes object, or its list-like + xlabelsize : number + expected xticks font size + xrot : number + expected xticks rotation + ylabelsize : number + expected yticks font size + yrot : number + expected yticks rotation + """ + from matplotlib.ticker import NullFormatter + + axes = _flatten_visible(axes) + for ax in axes: + if xlabelsize is not None or xrot is not None: + if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter): + # If minor ticks has NullFormatter, rot / fontsize are not + # retained + labels = ax.get_xticklabels() + else: + labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True) + + for label in labels: + if xlabelsize is not None: + tm.assert_almost_equal(label.get_fontsize(), xlabelsize) + if xrot is not None: + tm.assert_almost_equal(label.get_rotation(), xrot) + + if ylabelsize is not None or yrot is not None: + if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter): + labels = ax.get_yticklabels() + else: + labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True) + + for label in labels: + if ylabelsize is not None: + tm.assert_almost_equal(label.get_fontsize(), ylabelsize) + if yrot is not None: + tm.assert_almost_equal(label.get_rotation(), yrot) + + +def _check_ax_scales(axes, xaxis="linear", yaxis="linear"): + """ + Check each axes has expected scales + + Parameters + ---------- + axes : matplotlib Axes object, or its list-like + xaxis : {'linear', 'log'} + expected xaxis scale + yaxis : {'linear', 'log'} + expected yaxis scale + """ + axes = _flatten_visible(axes) + for ax in axes: + assert ax.xaxis.get_scale() == xaxis + assert ax.yaxis.get_scale() == yaxis + + +def _check_axes_shape(axes, axes_num=None, layout=None, figsize=None): + """ + Check expected number of axes is drawn in expected layout + + Parameters + ---------- + axes : matplotlib Axes object, or its list-like + axes_num : number + expected number of axes. Unnecessary axes should be set to + invisible. + layout : tuple + expected layout, (expected number of rows , columns) + figsize : tuple + expected figsize. default is matplotlib default + """ + from pandas.plotting._matplotlib.tools import flatten_axes + + if figsize is None: + figsize = (6.4, 4.8) + visible_axes = _flatten_visible(axes) + + if axes_num is not None: + assert len(visible_axes) == axes_num + for ax in visible_axes: + # check something drawn on visible axes + assert len(ax.get_children()) > 0 + + if layout is not None: + x_set = set() + y_set = set() + for ax in flatten_axes(axes): + # check axes coordinates to estimate layout + points = ax.get_position().get_points() + x_set.add(points[0][0]) + y_set.add(points[0][1]) + result = (len(y_set), len(x_set)) + assert result == layout + + tm.assert_numpy_array_equal( + visible_axes[0].figure.get_size_inches(), + np.array(figsize, dtype=np.float64), + ) + + +def _flatten_visible(axes: Axes | Sequence[Axes]) -> Sequence[Axes]: + """ + Flatten axes, and filter only visible + + Parameters + ---------- + axes : matplotlib Axes object, or its list-like + + """ + from pandas.plotting._matplotlib.tools import flatten_axes + + axes_ndarray = flatten_axes(axes) + axes = [ax for ax in axes_ndarray if ax.get_visible()] + return axes + + +def _check_has_errorbars(axes, xerr=0, yerr=0): + """ + Check axes has expected number of errorbars + + Parameters + ---------- + axes : matplotlib Axes object, or its list-like + xerr : number + expected number of x errorbar + yerr : number + expected number of y errorbar + """ + axes = _flatten_visible(axes) + for ax in axes: + containers = ax.containers + xerr_count = 0 + yerr_count = 0 + for c in containers: + has_xerr = getattr(c, "has_xerr", False) + has_yerr = getattr(c, "has_yerr", False) + if has_xerr: + xerr_count += 1 + if has_yerr: + yerr_count += 1 + assert xerr == xerr_count + assert yerr == yerr_count + + +def _check_box_return_type( + returned, return_type, expected_keys=None, check_ax_title=True +): + """ + Check box returned type is correct + + Parameters + ---------- + returned : object to be tested, returned from boxplot + return_type : str + return_type passed to boxplot + expected_keys : list-like, optional + group labels in subplot case. If not passed, + the function checks assuming boxplot uses single ax + check_ax_title : bool + Whether to check the ax.title is the same as expected_key + Intended to be checked by calling from ``boxplot``. + Normal ``plot`` doesn't attach ``ax.title``, it must be disabled. + """ + from matplotlib.axes import Axes + + types = {"dict": dict, "axes": Axes, "both": tuple} + if expected_keys is None: + # should be fixed when the returning default is changed + if return_type is None: + return_type = "dict" + + assert isinstance(returned, types[return_type]) + if return_type == "both": + assert isinstance(returned.ax, Axes) + assert isinstance(returned.lines, dict) + else: + # should be fixed when the returning default is changed + if return_type is None: + for r in _flatten_visible(returned): + assert isinstance(r, Axes) + return + + assert isinstance(returned, Series) + + assert sorted(returned.keys()) == sorted(expected_keys) + for key, value in returned.items(): + assert isinstance(value, types[return_type]) + # check returned dict has correct mapping + if return_type == "axes": + if check_ax_title: + assert value.get_title() == key + elif return_type == "both": + if check_ax_title: + assert value.ax.get_title() == key + assert isinstance(value.ax, Axes) + assert isinstance(value.lines, dict) + elif return_type == "dict": + line = value["medians"][0] + axes = line.axes + if check_ax_title: + assert axes.get_title() == key + else: + raise AssertionError + + +def _check_grid_settings(obj, kinds, kws=None): + # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792 + + import matplotlib as mpl + + def is_grid_on(): + xticks = mpl.pyplot.gca().xaxis.get_major_ticks() + yticks = mpl.pyplot.gca().yaxis.get_major_ticks() + xoff = all(not g.gridline.get_visible() for g in xticks) + yoff = all(not g.gridline.get_visible() for g in yticks) + + return not (xoff and yoff) + + if kws is None: + kws = {} + spndx = 1 + for kind in kinds: + mpl.pyplot.subplot(1, 4 * len(kinds), spndx) + spndx += 1 + mpl.rc("axes", grid=False) + obj.plot(kind=kind, **kws) + assert not is_grid_on() + mpl.pyplot.clf() + + mpl.pyplot.subplot(1, 4 * len(kinds), spndx) + spndx += 1 + mpl.rc("axes", grid=True) + obj.plot(kind=kind, grid=False, **kws) + assert not is_grid_on() + mpl.pyplot.clf() + + if kind not in ["pie", "hexbin", "scatter"]: + mpl.pyplot.subplot(1, 4 * len(kinds), spndx) + spndx += 1 + mpl.rc("axes", grid=True) + obj.plot(kind=kind, **kws) + assert is_grid_on() + mpl.pyplot.clf() + + mpl.pyplot.subplot(1, 4 * len(kinds), spndx) + spndx += 1 + mpl.rc("axes", grid=False) + obj.plot(kind=kind, grid=True, **kws) + assert is_grid_on() + mpl.pyplot.clf() + + +def _unpack_cycler(rcParams, field="color"): + """ + Auxiliary function for correctly unpacking cycler after MPL >= 1.5 + """ + return [v[field] for v in rcParams["axes.prop_cycle"]] + + +def get_x_axis(ax): + return ax._shared_axes["x"] + + +def get_y_axis(ax): + return ax._shared_axes["y"] + + +def assert_is_valid_plot_return_object(objs) -> None: + from matplotlib.artist import Artist + from matplotlib.axes import Axes + + if isinstance(objs, (Series, np.ndarray)): + if isinstance(objs, Series): + objs = objs._values + for el in objs.reshape(-1): + msg = ( + "one of 'objs' is not a matplotlib Axes instance, " + f"type encountered {type(el).__name__!r}" + ) + assert isinstance(el, (Axes, dict)), msg + else: + msg = ( + "objs is neither an ndarray of Artist instances nor a single " + "ArtistArtist instance, tuple, or dict, 'objs' is a " + f"{type(objs).__name__!r}" + ) + assert isinstance(objs, (Artist, tuple, dict)), msg + + +def _check_plot_works(f, default_axes=False, **kwargs): + """ + Create plot and ensure that plot return object is valid. + + Parameters + ---------- + f : func + Plotting function. + default_axes : bool, optional + If False (default): + - If `ax` not in `kwargs`, then create subplot(211) and plot there + - Create new subplot(212) and plot there as well + - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`) + If True: + - Simply run plotting function with kwargs provided + - All required axes instances will be created automatically + - It is recommended to use it when the plotting function + creates multiple axes itself. It helps avoid warnings like + 'UserWarning: To output multiple subplots, + the figure containing the passed axes is being cleared' + **kwargs + Keyword arguments passed to the plotting function. + + Returns + ------- + Plot object returned by the last plotting. + """ + import matplotlib.pyplot as plt + + if default_axes: + gen_plots = _gen_default_plot + else: + gen_plots = _gen_two_subplots + + ret = None + fig = kwargs.get("figure", plt.gcf()) + fig.clf() + + for ret in gen_plots(f, fig, **kwargs): + assert_is_valid_plot_return_object(ret) + + return ret + + +def _gen_default_plot(f, fig, **kwargs): + """ + Create plot in a default way. + """ + yield f(**kwargs) + + +def _gen_two_subplots(f, fig, **kwargs): + """ + Create plot on two subplots forcefully created. + """ + if "ax" not in kwargs: + fig.add_subplot(211) + yield f(**kwargs) + + if f is pd.plotting.bootstrap_plot: + assert "ax" not in kwargs + else: + kwargs["ax"] = fig.add_subplot(212) + yield f(**kwargs) diff --git a/pandas/tests/plotting/conftest.py b/pandas/tests/plotting/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..eb5a1f1f6382e3fe48fd0fa4050a673160807a9a --- /dev/null +++ b/pandas/tests/plotting/conftest.py @@ -0,0 +1,39 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + to_datetime, +) + + +@pytest.fixture(autouse=True) +def autouse_mpl_cleanup(mpl_cleanup): + pass + + +@pytest.fixture +def hist_df(): + n = 50 + rng = np.random.default_rng(10) + gender = rng.choice(["Male", "Female"], size=n) + classroom = rng.choice(["A", "B", "C"], size=n) + + hist_df = DataFrame( + { + "gender": gender, + "classroom": classroom, + "height": rng.normal(66, 4, size=n), + "weight": rng.normal(161, 32, size=n), + "category": rng.integers(4, size=n), + "datetime": to_datetime( + rng.integers( + 812419200000000000, + 819331200000000000, + size=n, + dtype=np.int64, + ) + ), + } + ) + return hist_df diff --git a/pandas/tests/plotting/test_backend.py b/pandas/tests/plotting/test_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..683bfcfe54f3ed45567c77b636119973ac910f38 --- /dev/null +++ b/pandas/tests/plotting/test_backend.py @@ -0,0 +1,100 @@ +import sys +import types + +import pytest + +import pandas.util._test_decorators as td + +import pandas + +pytestmark = pytest.mark.single_cpu + + +@pytest.fixture +def dummy_backend(): + db = types.ModuleType("pandas_dummy_backend") + setattr(db, "plot", lambda *args, **kwargs: "used_dummy") + return db + + +@pytest.fixture +def restore_backend(): + """Restore the plotting backend to matplotlib""" + with pandas.option_context("plotting.backend", "matplotlib"): + yield + + +def test_backend_is_not_module(): + msg = "Could not find plotting backend 'not_an_existing_module'." + with pytest.raises(ValueError, match=msg): + pandas.set_option("plotting.backend", "not_an_existing_module") + + assert pandas.options.plotting.backend == "matplotlib" + + +def test_backend_is_correct(monkeypatch, restore_backend, dummy_backend): + monkeypatch.setitem(sys.modules, "pandas_dummy_backend", dummy_backend) + + pandas.set_option("plotting.backend", "pandas_dummy_backend") + assert pandas.get_option("plotting.backend") == "pandas_dummy_backend" + assert ( + pandas.plotting._core._get_plot_backend("pandas_dummy_backend") is dummy_backend + ) + + +def test_backend_can_be_set_in_plot_call(monkeypatch, restore_backend, dummy_backend): + monkeypatch.setitem(sys.modules, "pandas_dummy_backend", dummy_backend) + df = pandas.DataFrame([1, 2, 3]) + + assert pandas.get_option("plotting.backend") == "matplotlib" + assert df.plot(backend="pandas_dummy_backend") == "used_dummy" + + +def test_register_entrypoint(restore_backend, tmp_path, monkeypatch, dummy_backend): + monkeypatch.syspath_prepend(tmp_path) + monkeypatch.setitem(sys.modules, "pandas_dummy_backend", dummy_backend) + + dist_info = tmp_path / "my_backend-0.0.0.dist-info" + dist_info.mkdir() + # entry_point name should not match module name - otherwise pandas will + # fall back to backend lookup by module name + (dist_info / "entry_points.txt").write_bytes( + b"[pandas_plotting_backends]\nmy_ep_backend = pandas_dummy_backend\n" + ) + + assert pandas.plotting._core._get_plot_backend("my_ep_backend") is dummy_backend + + with pandas.option_context("plotting.backend", "my_ep_backend"): + assert pandas.plotting._core._get_plot_backend() is dummy_backend + + +def test_setting_backend_without_plot_raises(monkeypatch): + # GH-28163 + module = types.ModuleType("pandas_plot_backend") + monkeypatch.setitem(sys.modules, "pandas_plot_backend", module) + + assert pandas.options.plotting.backend == "matplotlib" + with pytest.raises( + ValueError, match="Could not find plotting backend 'pandas_plot_backend'." + ): + pandas.set_option("plotting.backend", "pandas_plot_backend") + + assert pandas.options.plotting.backend == "matplotlib" + + +@td.skip_if_installed("matplotlib") +def test_no_matplotlib_ok(): + msg = ( + 'matplotlib is required for plotting when the default backend "matplotlib" is ' + "selected." + ) + with pytest.raises(ImportError, match=msg): + pandas.plotting._core._get_plot_backend("matplotlib") + + +def test_extra_kinds_ok(monkeypatch, restore_backend, dummy_backend): + # https://github.com/pandas-dev/pandas/pull/28647 + monkeypatch.setitem(sys.modules, "pandas_dummy_backend", dummy_backend) + pandas.set_option("plotting.backend", "pandas_dummy_backend") + df = pandas.DataFrame({"A": [1, 2, 3]}) + df.plot(kind="not a real kind") diff --git a/pandas/tests/plotting/test_boxplot_method.py b/pandas/tests/plotting/test_boxplot_method.py new file mode 100644 index 0000000000000000000000000000000000000000..3554f1549e4889660340eb0ea2c41149e55df9a0 --- /dev/null +++ b/pandas/tests/plotting/test_boxplot_method.py @@ -0,0 +1,774 @@ +"""Test cases for .boxplot method""" + +from __future__ import annotations + +import itertools +import string + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + MultiIndex, + Series, + date_range, + plotting, + timedelta_range, +) +import pandas._testing as tm +from pandas.tests.plotting.common import ( + _check_axes_shape, + _check_box_return_type, + _check_plot_works, + _check_ticks_props, + _check_visible, +) +from pandas.util.version import Version + +from pandas.io.formats.printing import pprint_thing + +mpl = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") + + +def _check_ax_limits(col, ax): + y_min, y_max = ax.get_ylim() + assert y_min <= col.min() + assert y_max >= col.max() + + +if Version(mpl.__version__) < Version("3.10"): + verts: list[dict[str, bool | str]] = [{"vert": False}, {"vert": True}] +else: + verts = [{"orientation": "horizontal"}, {"orientation": "vertical"}] + + +@pytest.fixture(params=verts) +def vert(request): + return request.param + + +class TestDataFramePlots: + def test_stacked_boxplot_set_axis(self): + # GH2980 + n = 30 + df = DataFrame( + { + "Clinical": np.random.default_rng(2).choice([0, 1, 2, 3], n), + "Confirmed": np.random.default_rng(2).choice([0, 1, 2, 3], n), + "Discarded": np.random.default_rng(2).choice([0, 1, 2, 3], n), + }, + index=np.arange(0, n), + ) + ax = df.plot(kind="bar", stacked=True) + assert [int(x.get_text()) for x in ax.get_xticklabels()] == df.index.to_list() + ax.set_xticks(np.arange(0, n, 10)) + plt.draw() # Update changes + assert [int(x.get_text()) for x in ax.get_xticklabels()] == list( + np.arange(0, n, 10) + ) + + @pytest.mark.slow + @pytest.mark.parametrize( + "kwargs, warn", + [ + [{"return_type": "dict"}, None], + [{"column": ["one", "two"]}, None], + [{"column": ["one", "two"], "by": "indic"}, UserWarning], + [{"column": ["one"], "by": ["indic", "indic2"]}, None], + [{"by": "indic"}, UserWarning], + [{"by": ["indic", "indic2"]}, UserWarning], + [{"notch": 1}, None], + [{"by": "indic", "notch": 1}, UserWarning], + ], + ) + def test_boxplot_legacy1(self, kwargs, warn): + df = DataFrame( + np.random.default_rng(2).standard_normal((6, 4)), + index=list(string.ascii_letters[:6]), + columns=["one", "two", "three", "four"], + ) + df["indic"] = ["foo", "bar"] * 3 + df["indic2"] = ["foo", "bar", "foo"] * 2 + + # _check_plot_works can add an ax so catch warning. see GH #13188 + with tm.assert_produces_warning(warn, check_stacklevel=False): + _check_plot_works(df.boxplot, **kwargs) + + def test_boxplot_legacy1_series(self): + ser = Series(np.random.default_rng(2).standard_normal(6)) + _check_plot_works(plotting._core.boxplot, data=ser, return_type="dict") + + def test_boxplot_legacy2(self): + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=["Col1", "Col2"] + ) + df["X"] = Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + df["Y"] = Series(["A"] * 10) + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + _check_plot_works(df.boxplot, by="X") + + def test_boxplot_legacy2_with_ax(self): + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=["Col1", "Col2"] + ) + df["X"] = Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + df["Y"] = Series(["A"] * 10) + # When ax is supplied and required number of axes is 1, + # passed ax should be used: + _, ax = mpl.pyplot.subplots() + axes = df.boxplot("Col1", by="X", ax=ax) + ax_axes = ax.axes + assert ax_axes is axes + + def test_boxplot_legacy2_with_ax_return_type(self): + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=["Col1", "Col2"] + ) + df["X"] = Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + df["Y"] = Series(["A"] * 10) + fig, ax = mpl.pyplot.subplots() + axes = df.groupby("Y").boxplot(ax=ax, return_type="axes") + ax_axes = ax.axes + assert ax_axes is axes["A"] + + def test_boxplot_legacy2_with_multi_col(self): + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=["Col1", "Col2"] + ) + df["X"] = Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + df["Y"] = Series(["A"] * 10) + # Multiple columns with an ax argument should use same figure + fig, ax = mpl.pyplot.subplots() + msg = "the figure containing the passed axes is being cleared" + with tm.assert_produces_warning(UserWarning, match=msg): + axes = df.boxplot( + column=["Col1", "Col2"], by="X", ax=ax, return_type="axes" + ) + assert axes["Col1"].get_figure() is fig + + def test_boxplot_legacy2_by_none(self): + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=["Col1", "Col2"] + ) + df["X"] = Series(["A", "A", "A", "A", "A", "B", "B", "B", "B", "B"]) + df["Y"] = Series(["A"] * 10) + # When by is None, check that all relevant lines are present in the + # dict + _, ax = mpl.pyplot.subplots() + d = df.boxplot(ax=ax, return_type="dict") + lines = list(itertools.chain.from_iterable(d.values())) + assert len(ax.get_lines()) == len(lines) + + def test_boxplot_return_type_none(self, hist_df): + # GH 12216; return_type=None & by=None -> axes + result = hist_df.boxplot() + assert isinstance(result, mpl.pyplot.Axes) + + def test_boxplot_return_type_legacy(self): + # API change in https://github.com/pandas-dev/pandas/pull/7096 + + df = DataFrame( + np.random.default_rng(2).standard_normal((6, 4)), + index=list(string.ascii_letters[:6]), + columns=["one", "two", "three", "four"], + ) + msg = "return_type must be {'axes', 'dict', 'both'}" + with pytest.raises(ValueError, match=msg): + df.boxplot(return_type="NOT_A_TYPE") + + result = df.boxplot() + _check_box_return_type(result, "axes") + + @pytest.mark.parametrize("return_type", ["dict", "axes", "both"]) + def test_boxplot_return_type_legacy_return_type(self, return_type): + # API change in https://github.com/pandas-dev/pandas/pull/7096 + + df = DataFrame( + np.random.default_rng(2).standard_normal((6, 4)), + index=list(string.ascii_letters[:6]), + columns=["one", "two", "three", "four"], + ) + with tm.assert_produces_warning(False): + result = df.boxplot(return_type=return_type) + _check_box_return_type(result, return_type) + + def test_boxplot_axis_limits(self, hist_df): + df = hist_df.copy() + df["age"] = np.random.default_rng(2).integers(1, 20, df.shape[0]) + # One full row + height_ax, weight_ax = df.boxplot(["height", "weight"], by="category") + _check_ax_limits(df["height"], height_ax) + _check_ax_limits(df["weight"], weight_ax) + assert weight_ax._sharey == height_ax + + def test_boxplot_axis_limits_two_rows(self, hist_df): + df = hist_df.copy() + df["age"] = np.random.default_rng(2).integers(1, 20, df.shape[0]) + # Two rows, one partial + p = df.boxplot(["height", "weight", "age"], by="category") + height_ax, weight_ax, age_ax = p[0, 0], p[0, 1], p[1, 0] + dummy_ax = p[1, 1] + + _check_ax_limits(df["height"], height_ax) + _check_ax_limits(df["weight"], weight_ax) + _check_ax_limits(df["age"], age_ax) + assert weight_ax._sharey == height_ax + assert age_ax._sharey == height_ax + assert dummy_ax._sharey is None + + def test_boxplot_empty_column(self): + df = DataFrame(np.random.default_rng(2).standard_normal((20, 4))) + df.loc[:, 0] = np.nan + _check_plot_works(df.boxplot, return_type="axes") + + def test_figsize(self): + df = DataFrame( + np.random.default_rng(2).random((10, 5)), columns=["A", "B", "C", "D", "E"] + ) + result = df.boxplot(return_type="axes", figsize=(12, 8)) + assert result.figure.bbox_inches.width == 12 + assert result.figure.bbox_inches.height == 8 + + def test_fontsize(self): + df = DataFrame({"a": [1, 2, 3, 4, 5, 6]}) + _check_ticks_props(df.boxplot("a", fontsize=16), xlabelsize=16, ylabelsize=16) + + def test_boxplot_numeric_data(self): + # GH 22799 + df = DataFrame( + { + "a": date_range("2012-01-01", periods=10), + "b": np.random.default_rng(2).standard_normal(10), + "c": np.random.default_rng(2).standard_normal(10) + 2, + "d": date_range("2012-01-01", periods=10).astype(str), + "e": date_range("2012-01-01", periods=10, tz="UTC"), + "f": timedelta_range("1 days", periods=10), + } + ) + ax = df.plot(kind="box") + assert [x.get_text() for x in ax.get_xticklabels()] == ["b", "c"] + + @pytest.mark.parametrize( + "colors_kwd, expected", + [ + ( + {"boxes": "r", "whiskers": "b", "medians": "g", "caps": "c"}, + {"boxes": "r", "whiskers": "b", "medians": "g", "caps": "c"}, + ), + ({"boxes": "r"}, {"boxes": "r"}), + ("r", {"boxes": "r", "whiskers": "r", "medians": "r", "caps": "r"}), + ], + ) + def test_color_kwd(self, colors_kwd, expected): + # GH: 26214 + df = DataFrame(np.random.default_rng(2).random((10, 2))) + result = df.boxplot(color=colors_kwd, return_type="dict") + for k, v in expected.items(): + assert result[k][0].get_color() == v + + @pytest.mark.parametrize( + "scheme,expected", + [ + ( + "dark_background", + { + "boxes": "#8dd3c7", + "whiskers": "#8dd3c7", + "medians": "#bfbbd9", + "caps": "#8dd3c7", + }, + ), + ( + "default", + { + "boxes": "#1f77b4", + "whiskers": "#1f77b4", + "medians": "#2ca02c", + "caps": "#1f77b4", + }, + ), + ], + ) + def test_colors_in_theme(self, scheme, expected): + # GH: 40769 + df = DataFrame(np.random.default_rng(2).random((10, 2))) + plt.style.use(scheme) + result = df.plot.box(return_type="dict") + for k, v in expected.items(): + assert result[k][0].get_color() == v + + @pytest.mark.parametrize( + "dict_colors, msg", + [({"boxes": "r", "invalid_key": "r"}, "invalid key 'invalid_key'")], + ) + def test_color_kwd_errors(self, dict_colors, msg): + # GH: 26214 + df = DataFrame(np.random.default_rng(2).random((10, 2))) + with pytest.raises(ValueError, match=msg): + df.boxplot(color=dict_colors, return_type="dict") + + @pytest.mark.parametrize( + "props, expected", + [ + ("boxprops", "boxes"), + ("whiskerprops", "whiskers"), + ("capprops", "caps"), + ("medianprops", "medians"), + ], + ) + def test_specified_props_kwd(self, props, expected): + # GH 30346 + df = DataFrame({k: np.random.default_rng(2).random(10) for k in "ABC"}) + kwd = {props: {"color": "C1"}} + result = df.boxplot(return_type="dict", **kwd) + + assert result[expected][0].get_color() == "C1" + + @pytest.mark.filterwarnings("ignore:set_ticklabels:UserWarning") + def test_plot_xlabel_ylabel(self, vert): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + "group": np.random.default_rng(2).choice(["group1", "group2"], 10), + } + ) + xlabel, ylabel = "x", "y" + ax = df.plot(kind="box", xlabel=xlabel, ylabel=ylabel, **vert) + assert ax.get_xlabel() == xlabel + assert ax.get_ylabel() == ylabel + + @pytest.mark.filterwarnings("ignore:set_ticklabels:UserWarning") + def test_plot_box(self, vert): + # GH 54941 + rng = np.random.default_rng(2) + df1 = DataFrame(rng.integers(0, 100, size=(10, 4)), columns=list("ABCD")) + df2 = DataFrame(rng.integers(0, 100, size=(10, 4)), columns=list("ABCD")) + + xlabel, ylabel = "x", "y" + _, axs = plt.subplots(ncols=2, figsize=(10, 7), sharey=True) + df1.plot.box(ax=axs[0], xlabel=xlabel, ylabel=ylabel, **vert) + df2.plot.box(ax=axs[1], xlabel=xlabel, ylabel=ylabel, **vert) + for ax in axs: + assert ax.get_xlabel() == xlabel + assert ax.get_ylabel() == ylabel + + @pytest.mark.filterwarnings("ignore:set_ticklabels:UserWarning") + def test_boxplot_xlabel_ylabel(self, vert): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + "group": np.random.default_rng(2).choice(["group1", "group2"], 10), + } + ) + xlabel, ylabel = "x", "y" + ax = df.boxplot(xlabel=xlabel, ylabel=ylabel, **vert) + assert ax.get_xlabel() == xlabel + assert ax.get_ylabel() == ylabel + + @pytest.mark.filterwarnings("ignore:set_ticklabels:UserWarning") + def test_boxplot_group_xlabel_ylabel(self, vert): + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + "group": np.random.default_rng(2).choice(["group1", "group2"], 10), + } + ) + xlabel, ylabel = "x", "y" + ax = df.boxplot(by="group", xlabel=xlabel, ylabel=ylabel, **vert) + for subplot in ax: + assert subplot.get_xlabel() == xlabel + assert subplot.get_ylabel() == ylabel + + @pytest.mark.filterwarnings("ignore:set_ticklabels:UserWarning") + def test_boxplot_group_no_xlabel_ylabel(self, vert, request): + if Version(mpl.__version__) >= Version("3.10") and vert == { + "orientation": "horizontal" + }: + request.applymarker( + pytest.mark.xfail(reason=f"{vert} fails starting with matplotlib 3.10") + ) + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + "group": np.random.default_rng(2).choice(["group1", "group2"], 10), + } + ) + ax = df.boxplot(by="group", **vert) + for subplot in ax: + target_label = ( + subplot.get_xlabel() + if vert in ({"vert": True}, {"orientation": "vertical"}) + else subplot.get_ylabel() + ) + assert target_label == pprint_thing(["group"]) + + +class TestDataFrameGroupByPlots: + def test_boxplot_legacy1(self, hist_df): + grouped = hist_df.groupby(by="gender") + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works(grouped.boxplot, return_type="axes") + _check_axes_shape(list(axes.values), axes_num=2, layout=(1, 2)) + + def test_boxplot_legacy1_return_type(self, hist_df): + grouped = hist_df.groupby(by="gender") + axes = _check_plot_works(grouped.boxplot, subplots=False, return_type="axes") + _check_axes_shape(axes, axes_num=1, layout=(1, 1)) + + @pytest.mark.slow + def test_boxplot_legacy2(self): + tuples = zip(string.ascii_letters[:10], range(10), strict=True) + df = DataFrame( + np.random.default_rng(2).random((10, 3)), + index=MultiIndex.from_tuples(tuples), + ) + grouped = df.groupby(level=1) + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works(grouped.boxplot, return_type="axes") + _check_axes_shape(list(axes.values), axes_num=10, layout=(4, 3)) + + @pytest.mark.slow + def test_boxplot_legacy2_return_type(self): + tuples = zip(string.ascii_letters[:10], range(10), strict=True) + df = DataFrame( + np.random.default_rng(2).random((10, 3)), + index=MultiIndex.from_tuples(tuples), + ) + grouped = df.groupby(level=1) + axes = _check_plot_works(grouped.boxplot, subplots=False, return_type="axes") + _check_axes_shape(axes, axes_num=1, layout=(1, 1)) + + def test_grouped_plot_fignums(self): + n = 10 + weight = Series(np.random.default_rng(2).normal(166, 20, size=n)) + height = Series(np.random.default_rng(2).normal(60, 10, size=n)) + gender = np.random.default_rng(2).choice(["male", "female"], size=n) + df = DataFrame({"height": height, "weight": weight, "gender": gender}) + gb = df.groupby("gender") + + res = gb.plot() + assert len(mpl.pyplot.get_fignums()) == 2 + assert len(res) == 2 + plt.close("all") + + res = gb.boxplot(return_type="axes") + assert len(mpl.pyplot.get_fignums()) == 1 + assert len(res) == 2 + + def test_grouped_plot_fignums_excluded_col(self): + n = 10 + weight = Series(np.random.default_rng(2).normal(166, 20, size=n)) + height = Series(np.random.default_rng(2).normal(60, 10, size=n)) + gender = np.random.default_rng(2).choice(["male", "female"], size=n) + df = DataFrame({"height": height, "weight": weight, "gender": gender}) + # now works with GH 5610 as gender is excluded + df.groupby("gender").hist() + + @pytest.mark.slow + def test_grouped_box_return_type(self, hist_df): + df = hist_df + + # old style: return_type=None + result = df.boxplot(by="gender") + assert isinstance(result, np.ndarray) + _check_box_return_type( + result, None, expected_keys=["height", "weight", "category"] + ) + + @pytest.mark.slow + def test_grouped_box_return_type_groupby(self, hist_df): + df = hist_df + # now for groupby + result = df.groupby("gender").boxplot(return_type="dict") + _check_box_return_type(result, "dict", expected_keys=["Male", "Female"]) + + @pytest.mark.slow + @pytest.mark.parametrize("return_type", ["dict", "axes", "both"]) + def test_grouped_box_return_type_arg(self, hist_df, return_type): + df = hist_df + + returned = df.groupby("classroom").boxplot(return_type=return_type) + _check_box_return_type(returned, return_type, expected_keys=["A", "B", "C"]) + + returned = df.boxplot(by="classroom", return_type=return_type) + _check_box_return_type( + returned, return_type, expected_keys=["height", "weight", "category"] + ) + + @pytest.mark.slow + @pytest.mark.parametrize("return_type", ["dict", "axes", "both"]) + def test_grouped_box_return_type_arg_duplcate_cats(self, return_type): + columns2 = "X B C D A".split() + df2 = DataFrame( + np.random.default_rng(2).standard_normal((6, 5)), columns=columns2 + ) + categories2 = "A B".split() + df2["category"] = categories2 * 3 + + returned = df2.groupby("category").boxplot(return_type=return_type) + _check_box_return_type(returned, return_type, expected_keys=categories2) + + returned = df2.boxplot(by="category", return_type=return_type) + _check_box_return_type(returned, return_type, expected_keys=columns2) + + @pytest.mark.slow + def test_grouped_box_layout_too_small(self, hist_df): + df = hist_df + + msg = "Layout of 1x1 must be larger than required size 2" + with pytest.raises(ValueError, match=msg): + df.boxplot(column=["weight", "height"], by=df.gender, layout=(1, 1)) + + @pytest.mark.slow + def test_grouped_box_layout_needs_by(self, hist_df): + df = hist_df + msg = "The 'layout' keyword is not supported when 'by' is None" + with pytest.raises(ValueError, match=msg): + df.boxplot( + column=["height", "weight", "category"], + layout=(2, 1), + return_type="dict", + ) + + @pytest.mark.slow + def test_grouped_box_layout_positive_layout(self, hist_df): + df = hist_df + msg = "At least one dimension of layout must be positive" + with pytest.raises(ValueError, match=msg): + df.boxplot(column=["weight", "height"], by=df.gender, layout=(-1, -1)) + + @pytest.mark.slow + @pytest.mark.parametrize( + "gb_key, axes_num, rows", + [["gender", 2, 1], ["category", 4, 2], ["classroom", 3, 2]], + ) + def test_grouped_box_layout_positive_layout_axes( + self, hist_df, gb_key, axes_num, rows + ): + df = hist_df + # _check_plot_works adds an ax so catch warning. see GH #13188 GH 6769 + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + _check_plot_works( + df.groupby(gb_key).boxplot, column="height", return_type="dict" + ) + _check_axes_shape(mpl.pyplot.gcf().axes, axes_num=axes_num, layout=(rows, 2)) + + @pytest.mark.slow + @pytest.mark.parametrize( + "col, visible", [["height", False], ["weight", True], ["category", True]] + ) + def test_grouped_box_layout_visible(self, hist_df, col, visible): + df = hist_df + # GH 5897 + axes = df.boxplot( + column=["height", "weight", "category"], by="gender", return_type="axes" + ) + _check_axes_shape(mpl.pyplot.gcf().axes, axes_num=3, layout=(2, 2)) + ax = axes[col] + _check_visible(ax.get_xticklabels(), visible=visible) + _check_visible([ax.xaxis.get_label()], visible=visible) + + @pytest.mark.slow + def test_grouped_box_layout_shape(self, hist_df): + df = hist_df + df.groupby("classroom").boxplot( + column=["height", "weight", "category"], return_type="dict" + ) + _check_axes_shape(mpl.pyplot.gcf().axes, axes_num=3, layout=(2, 2)) + + @pytest.mark.slow + @pytest.mark.parametrize("cols", [2, -1]) + def test_grouped_box_layout_works(self, hist_df, cols): + df = hist_df + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + _check_plot_works( + df.groupby("category").boxplot, + column="height", + layout=(3, cols), + return_type="dict", + ) + _check_axes_shape(mpl.pyplot.gcf().axes, axes_num=4, layout=(3, 2)) + + @pytest.mark.slow + @pytest.mark.parametrize("rows, res", [[4, 4], [-1, 3]]) + def test_grouped_box_layout_axes_shape_rows(self, hist_df, rows, res): + df = hist_df + df.boxplot( + column=["height", "weight", "category"], by="gender", layout=(rows, 1) + ) + _check_axes_shape(mpl.pyplot.gcf().axes, axes_num=3, layout=(res, 1)) + + @pytest.mark.slow + @pytest.mark.parametrize("cols, res", [[4, 4], [-1, 3]]) + def test_grouped_box_layout_axes_shape_cols_groupby(self, hist_df, cols, res): + df = hist_df + df.groupby("classroom").boxplot( + column=["height", "weight", "category"], + layout=(1, cols), + return_type="dict", + ) + _check_axes_shape(mpl.pyplot.gcf().axes, axes_num=3, layout=(1, res)) + + @pytest.mark.slow + def test_grouped_box_multiple_axes(self, hist_df): + # GH 6970, GH 7069 + df = hist_df + + # check warning to ignore sharex / sharey + # this check should be done in the first function which + # passes multiple axes to plot, hist or boxplot + # location should be changed if other test is added + # which has earlier alphabetical order + with tm.assert_produces_warning(UserWarning, match="sharex and sharey"): + _, axes = mpl.pyplot.subplots(2, 2) + df.groupby("category").boxplot(column="height", return_type="axes", ax=axes) + _check_axes_shape(mpl.pyplot.gcf().axes, axes_num=4, layout=(2, 2)) + + @pytest.mark.slow + def test_grouped_box_multiple_axes_on_fig(self, hist_df): + # GH 6970, GH 7069 + df = hist_df + fig, axes = mpl.pyplot.subplots(2, 3) + with tm.assert_produces_warning(UserWarning, match="sharex and sharey"): + returned = df.boxplot( + column=["height", "weight", "category"], + by="gender", + return_type="axes", + ax=axes[0], + ) + returned = np.array(list(returned.values)) + _check_axes_shape(returned, axes_num=3, layout=(1, 3)) + tm.assert_numpy_array_equal(returned, axes[0]) + assert returned[0].figure is fig + + # draw on second row + with tm.assert_produces_warning(UserWarning, match="sharex and sharey"): + returned = df.groupby("classroom").boxplot( + column=["height", "weight", "category"], return_type="axes", ax=axes[1] + ) + returned = np.array(list(returned.values)) + _check_axes_shape(returned, axes_num=3, layout=(1, 3)) + tm.assert_numpy_array_equal(returned, axes[1]) + assert returned[0].figure is fig + + @pytest.mark.slow + def test_grouped_box_multiple_axes_ax_error(self, hist_df): + # GH 6970, GH 7069 + df = hist_df + msg = "The number of passed axes must be 3, the same as the output plot" + _, axes = mpl.pyplot.subplots(2, 3) + with pytest.raises(ValueError, match=msg): + # pass different number of axes from required + with tm.assert_produces_warning(UserWarning, match="sharex and sharey"): + axes = df.groupby("classroom").boxplot(ax=axes) + + def test_fontsize(self): + df = DataFrame({"a": [1, 2, 3, 4, 5, 6], "b": [0, 0, 0, 1, 1, 1]}) + _check_ticks_props( + df.boxplot("a", by="b", fontsize=16), xlabelsize=16, ylabelsize=16 + ) + + @pytest.mark.parametrize( + "col, expected_xticklabel", + [ + ("v", ["(a, v)", "(b, v)", "(c, v)", "(d, v)", "(e, v)"]), + (["v"], ["(a, v)", "(b, v)", "(c, v)", "(d, v)", "(e, v)"]), + ("v1", ["(a, v1)", "(b, v1)", "(c, v1)", "(d, v1)", "(e, v1)"]), + ( + ["v", "v1"], + [ + "(a, v)", + "(a, v1)", + "(b, v)", + "(b, v1)", + "(c, v)", + "(c, v1)", + "(d, v)", + "(d, v1)", + "(e, v)", + "(e, v1)", + ], + ), + ( + None, + [ + "(a, v)", + "(a, v1)", + "(b, v)", + "(b, v1)", + "(c, v)", + "(c, v1)", + "(d, v)", + "(d, v1)", + "(e, v)", + "(e, v1)", + ], + ), + ], + ) + def test_groupby_boxplot_subplots_false(self, col, expected_xticklabel): + # GH 16748 + df = DataFrame( + { + "cat": np.random.default_rng(2).choice(list("abcde"), 100), + "v": np.random.default_rng(2).random(100), + "v1": np.random.default_rng(2).random(100), + } + ) + grouped = df.groupby("cat") + + axes = _check_plot_works( + grouped.boxplot, subplots=False, column=col, return_type="axes" + ) + + result_xticklabel = [x.get_text() for x in axes.get_xticklabels()] + assert expected_xticklabel == result_xticklabel + + def test_groupby_boxplot_object(self, hist_df): + # GH 43480 + df = hist_df.astype("object") + grouped = df.groupby("gender") + msg = "boxplot method requires numerical columns, nothing to plot" + with pytest.raises(ValueError, match=msg): + _check_plot_works(grouped.boxplot, subplots=False) + + def test_boxplot_multiindex_column(self): + # GH 16748 + arrays = [ + ["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"], + ["one", "two", "one", "two", "one", "two", "one", "two"], + ] + tuples = list(zip(*arrays, strict=True)) + index = MultiIndex.from_tuples(tuples, names=["first", "second"]) + df = DataFrame( + np.random.default_rng(2).standard_normal((3, 8)), + index=["A", "B", "C"], + columns=index, + ) + + col = [("bar", "one"), ("bar", "two")] + axes = _check_plot_works(df.boxplot, column=col, return_type="axes") + + expected_xticklabel = ["(bar, one)", "(bar, two)"] + result_xticklabel = [x.get_text() for x in axes.get_xticklabels()] + assert expected_xticklabel == result_xticklabel + + @pytest.mark.parametrize("group", ["X", ["X", "Y"]]) + def test_boxplot_multi_groupby_groups(self, group): + # GH 14701 + rows = 20 + df = DataFrame( + np.random.default_rng(12).normal(size=(rows, 2)), columns=["Col1", "Col2"] + ) + df["X"] = Series(np.repeat(["A", "B"], int(rows / 2))) + df["Y"] = Series(np.tile(["C", "D"], int(rows / 2))) + grouped = df.groupby(group) + _check_plot_works(df.boxplot, by=group, default_axes=True) + _check_plot_works(df.plot.box, by=group, default_axes=True) + _check_plot_works(grouped.boxplot, default_axes=True) diff --git a/pandas/tests/plotting/test_common.py b/pandas/tests/plotting/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..20daf5935624843af3224f991497f84fa6639a0d --- /dev/null +++ b/pandas/tests/plotting/test_common.py @@ -0,0 +1,60 @@ +import pytest + +from pandas import DataFrame +from pandas.tests.plotting.common import ( + _check_plot_works, + _check_ticks_props, + _gen_two_subplots, +) + +plt = pytest.importorskip("matplotlib.pyplot") + + +class TestCommon: + def test__check_ticks_props(self): + # GH 34768 + df = DataFrame({"b": [0, 1, 0], "a": [1, 2, 3]}) + ax = _check_plot_works(df.plot, rot=30) + ax.yaxis.set_tick_params(rotation=30) + msg = "expected 0.00000 but got " + with pytest.raises(AssertionError, match=msg): + _check_ticks_props(ax, xrot=0) + with pytest.raises(AssertionError, match=msg): + _check_ticks_props(ax, xlabelsize=0) + with pytest.raises(AssertionError, match=msg): + _check_ticks_props(ax, yrot=0) + with pytest.raises(AssertionError, match=msg): + _check_ticks_props(ax, ylabelsize=0) + + def test__gen_two_subplots_with_ax(self): + fig = plt.gcf() + gen = _gen_two_subplots(f=lambda **kwargs: None, fig=fig, ax="test") + # On the first yield, no subplot should be added since ax was passed + next(gen) + assert fig.get_axes() == [] + # On the second, the one axis should match fig.subplot(2, 1, 2) + next(gen) + axes = fig.get_axes() + assert len(axes) == 1 + subplot_geometry = list(axes[0].get_subplotspec().get_geometry()[:-1]) + subplot_geometry[-1] += 1 + assert subplot_geometry == [2, 1, 2] + + def test_colorbar_layout(self): + fig = plt.figure() + + axes = fig.subplot_mosaic( + """ + AB + CC + """ + ) + + x = [1, 2, 3] + y = [1, 2, 3] + + cs0 = axes["A"].scatter(x, y) + axes["B"].scatter(x, y) + + fig.colorbar(cs0, ax=[axes["A"], axes["B"]], location="right") + DataFrame(x).plot(ax=axes["C"]) diff --git a/pandas/tests/plotting/test_converter.py b/pandas/tests/plotting/test_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..e33e91ccf6c6e2c508d4b528dbf1c194cd7918fa --- /dev/null +++ b/pandas/tests/plotting/test_converter.py @@ -0,0 +1,391 @@ +from datetime import ( + date, + datetime, +) +import subprocess +import sys + +import numpy as np +import pytest + +import pandas._config.config as cf + +from pandas._libs.tslibs import to_offset + +from pandas import ( + Index, + Period, + PeriodIndex, + Series, + Timestamp, + arrays, + date_range, +) +import pandas._testing as tm + +from pandas.plotting import ( + deregister_matplotlib_converters, + register_matplotlib_converters, +) +from pandas.tseries.offsets import ( + Day, + Micro, + Milli, + Second, +) + +plt = pytest.importorskip("matplotlib.pyplot") +dates = pytest.importorskip("matplotlib.dates") +units = pytest.importorskip("matplotlib.units") + +from pandas.plotting._matplotlib import converter + + +@pytest.mark.single_cpu +def test_registry_mpl_resets(): + # Check that Matplotlib converters are properly reset (see issue #27481) + code = ( + "import matplotlib.units as units; " + "import matplotlib.dates as mdates; " + "n_conv = len(units.registry); " + "import pandas as pd; " + "pd.plotting.register_matplotlib_converters(); " + "pd.plotting.deregister_matplotlib_converters(); " + "assert len(units.registry) == n_conv" + ) + call = [sys.executable, "-c", code] + subprocess.check_output(call) + + +def test_timtetonum_accepts_unicode(): + assert converter.time2num("00:01") == converter.time2num("00:01") + + +class TestRegistration: + @pytest.mark.single_cpu + def test_dont_register_by_default(self): + # Run in subprocess to ensure a clean state + code = ( + "import matplotlib.units; " + "import pandas as pd; " + "units = dict(matplotlib.units.registry); " + "assert pd.Timestamp not in units" + ) + call = [sys.executable, "-c", code] + assert subprocess.check_call(call) == 0 + + def test_registering_no_warning(self): + s = Series(range(12), index=date_range("2017", periods=12)) + _, ax = plt.subplots() + + # Set to the "warn" state, in case this isn't the first test run + register_matplotlib_converters() + ax.plot(s.index, s.values) + + def test_pandas_plots_register(self): + s = Series(range(12), index=date_range("2017", periods=12)) + # Set to the "warn" state, in case this isn't the first test run + with tm.assert_produces_warning(None) as w: + s.plot() + + assert len(w) == 0 + + def test_matplotlib_formatters(self): + # Can't make any assertion about the start state. + # We we check that toggling converters off removes it, and toggling it + # on restores it. + + with cf.option_context("plotting.matplotlib.register_converters", True): + with cf.option_context("plotting.matplotlib.register_converters", False): + assert Timestamp not in units.registry + assert Timestamp in units.registry + + def test_option_no_warning(self): + s = Series(range(12), index=date_range("2017", periods=12)) + _, ax = plt.subplots() + + # Test without registering first, no warning + with cf.option_context("plotting.matplotlib.register_converters", False): + ax.plot(s.index, s.values) + + # Now test with registering + register_matplotlib_converters() + with cf.option_context("plotting.matplotlib.register_converters", False): + ax.plot(s.index, s.values) + + def test_registry_resets(self): + # make a copy, to reset to + original = dict(units.registry) + + try: + # get to a known state + units.registry.clear() + date_converter = dates.DateConverter() + units.registry[datetime] = date_converter + units.registry[date] = date_converter + + register_matplotlib_converters() + assert units.registry[date] is not date_converter + deregister_matplotlib_converters() + assert units.registry[date] is date_converter + + finally: + # restore original stater + units.registry.clear() + for k, v in original.items(): + units.registry[k] = v + + +class TestDateTimeConverter: + @pytest.fixture + def dtc(self): + return converter.DatetimeConverter() + + def test_convert_accepts_unicode(self, dtc): + r1 = dtc.convert("2000-01-01 12:22", None, None) + r2 = dtc.convert("2000-01-01 12:22", None, None) + assert r1 == r2, "DatetimeConverter.convert should accept unicode" + + def test_conversion(self, dtc): + rs = dtc.convert(["2012-1-1"], None, None)[0] + xp = dates.date2num(datetime(2012, 1, 1)) + assert rs == xp + + rs = dtc.convert("2012-1-1", None, None) + assert rs == xp + + rs = dtc.convert(date(2012, 1, 1), None, None) + assert rs == xp + + rs = dtc.convert("2012-1-1", None, None) + assert rs == xp + + rs = dtc.convert(Timestamp("2012-1-1"), None, None) + assert rs == xp + + # also testing datetime64 dtype (GH8614) + rs = dtc.convert("2012-01-01", None, None) + assert rs == xp + + rs = dtc.convert("2012-01-01 00:00:00+0000", None, None) + assert rs == xp + + rs = dtc.convert( + np.array(["2012-01-01 00:00:00+0000", "2012-01-02 00:00:00+0000"]), + None, + None, + ) + assert rs[0] == xp + + # we have a tz-aware date (constructed to that when we turn to utc it + # is the same as our sample) + ts = Timestamp("2012-01-01").tz_localize("UTC").tz_convert("US/Eastern") + rs = dtc.convert(ts, None, None) + assert rs == xp + + rs = dtc.convert(ts.to_pydatetime(), None, None) + assert rs == xp + + rs = dtc.convert(Index([ts - Day(1), ts]), None, None) + assert rs[1] == xp + + rs = dtc.convert(Index([ts - Day(1), ts]).to_pydatetime(), None, None) + assert rs[1] == xp + + def test_conversion_float(self, dtc): + rtol = 0.5 * 10**-9 + + rs = dtc.convert(Timestamp("2012-1-1 01:02:03", tz="UTC"), None, None) + xp = dates.date2num(Timestamp("2012-1-1 01:02:03", tz="UTC")) + tm.assert_almost_equal(rs, xp, rtol=rtol) + + rs = dtc.convert( + Timestamp("2012-1-1 09:02:03", tz="Asia/Hong_Kong"), None, None + ) + tm.assert_almost_equal(rs, xp, rtol=rtol) + + rs = dtc.convert(datetime(2012, 1, 1, 1, 2, 3), None, None) + tm.assert_almost_equal(rs, xp, rtol=rtol) + + @pytest.mark.parametrize( + "values", + [ + [date(1677, 1, 1), date(1677, 1, 2)], + [datetime(1677, 1, 1, 12), datetime(1677, 1, 2, 12)], + ], + ) + def test_conversion_outofbounds_datetime(self, dtc, values): + # 2579 + rs = dtc.convert(values, None, None) + xp = dates.date2num(values) + tm.assert_numpy_array_equal(rs, xp) + rs = dtc.convert(values[0], None, None) + xp = dates.date2num(values[0]) + assert rs == xp + + @pytest.mark.parametrize( + "time,format_expected", + [ + (0, "00:00"), # time2num(datetime.time.min) + (86399.999999, "23:59:59.999999"), # time2num(datetime.time.max) + (90000, "01:00"), + (3723, "01:02:03"), + (39723.2, "11:02:03.200"), + ], + ) + def test_time_formatter(self, time, format_expected): + # issue 18478 + result = converter.TimeFormatter(None)(time) + assert result == format_expected + + @pytest.mark.parametrize("freq", ("B", "ms", "s")) + def test_dateindex_conversion(self, freq, dtc): + rtol = 10**-9 + dateindex = date_range("2020-01-01", periods=10, freq=freq) + rs = dtc.convert(dateindex, None, None) + xp = dates.date2num(dateindex._mpl_repr()) + tm.assert_almost_equal(rs, xp, rtol=rtol) + + @pytest.mark.parametrize("offset", [Second(), Milli(), Micro(50)]) + def test_resolution(self, offset, dtc): + # Matplotlib's time representation using floats cannot distinguish + # intervals smaller than ~10 microsecond in the common range of years. + ts1 = Timestamp("2012-1-1") + ts2 = ts1 + offset + val1 = dtc.convert(ts1, None, None) + val2 = dtc.convert(ts2, None, None) + if not val1 < val2: + raise AssertionError(f"{val1} is not less than {val2}.") + + def test_convert_nested(self, dtc): + inner = [Timestamp("2017-01-01"), Timestamp("2017-01-02")] + data = [inner, inner] + result = dtc.convert(data, None, None) + expected = [dtc.convert(x, None, None) for x in data] + assert (np.array(result) == expected).all() + + +class TestPeriodConverter: + @pytest.fixture + def pc(self): + return converter.PeriodConverter() + + @pytest.fixture + def axis(self): + class Axis: + pass + + axis = Axis() + axis.freq = "D" + return axis + + def test_convert_accepts_unicode(self, pc, axis): + r1 = pc.convert("2012-1-1", None, axis) + r2 = pc.convert("2012-1-1", None, axis) + assert r1 == r2 + + def test_conversion(self, pc, axis): + rs = pc.convert(["2012-1-1"], None, axis)[0] + xp = Period("2012-1-1").ordinal + assert rs == xp + + rs = pc.convert("2012-1-1", None, axis) + assert rs == xp + + rs = pc.convert([date(2012, 1, 1)], None, axis)[0] + assert rs == xp + + rs = pc.convert(date(2012, 1, 1), None, axis) + assert rs == xp + + rs = pc.convert([Timestamp("2012-1-1")], None, axis)[0] + assert rs == xp + + rs = pc.convert(Timestamp("2012-1-1"), None, axis) + assert rs == xp + + rs = pc.convert("2012-01-01", None, axis) + assert rs == xp + + rs = pc.convert("2012-01-01 00:00:00+0000", None, axis) + assert rs == xp + + rs = pc.convert( + np.array( + ["2012-01-01 00:00:00", "2012-01-02 00:00:00"], + dtype="datetime64[ns]", + ), + None, + axis, + ) + assert rs[0] == xp + + def test_integer_passthrough(self, pc, axis): + # GH9012 + rs = pc.convert([0, 1], None, axis) + xp = [0, 1] + assert rs == xp + + def test_convert_nested(self, pc, axis): + data = ["2012-1-1", "2012-1-2"] + r1 = pc.convert([data, data], None, axis) + r2 = [pc.convert(data, None, axis) for _ in range(2)] + assert r1 == r2 + + +class TestTimeDeltaConverter: + """Test timedelta converter""" + + @pytest.mark.parametrize( + "x, decimal, format_expected", + [ + (0.0, 0, "00:00:00"), + (3972320000000, 1, "01:06:12.3"), + (713233432000000, 2, "8 days 06:07:13.43"), + (32423432000000, 4, "09:00:23.4320"), + ], + ) + def test_format_timedelta_ticks(self, x, decimal, format_expected): + tdc = converter.TimeSeries_TimedeltaFormatter + result = tdc.format_timedelta_ticks(x, pos=None, n_decimals=decimal, exp=9) + assert result == format_expected + + @pytest.mark.parametrize("view_interval", [(1, 2), (2, 1)]) + def test_call_w_different_view_intervals(self, view_interval, monkeypatch): + # previously broke on reversed xlmits; see GH37454 + class mock_axis: + def get_view_interval(self): + return view_interval + + tdc = converter.TimeSeries_TimedeltaFormatter() + monkeypatch.setattr(tdc, "axis", mock_axis()) + tdc(0.0, 0) + + +@pytest.mark.parametrize("year_span", [11.25, 30, 80, 150, 400, 800, 1500, 2500, 3500]) +# The range is limited to 11.25 at the bottom by if statements in +# the _quarterly_finder() function +def test_quarterly_finder(year_span): + vmin = -1000 + vmax = vmin + year_span * 4 + span = vmax - vmin + 1 + if span < 45: + pytest.skip("the quarterly finder is only invoked if the span is >= 45") + nyears = span / 4 + (min_anndef, maj_anndef) = converter._get_default_annual_spacing(nyears) + result = converter._quarterly_finder(vmin, vmax, to_offset("QE")) + quarters = PeriodIndex( + arrays.PeriodArray(np.array([x[0] for x in result]), dtype="period[Q]") + ) + majors = np.array([x[1] for x in result]) + minors = np.array([x[2] for x in result]) + major_quarters = quarters[majors] + minor_quarters = quarters[minors] + check_major_years = major_quarters.year % maj_anndef == 0 + check_minor_years = minor_quarters.year % min_anndef == 0 + check_major_quarters = major_quarters.quarter == 1 + check_minor_quarters = minor_quarters.quarter == 1 + assert np.all(check_major_years) + assert np.all(check_minor_years) + assert np.all(check_major_quarters) + assert np.all(check_minor_quarters) diff --git a/pandas/tests/plotting/test_datetimelike.py b/pandas/tests/plotting/test_datetimelike.py new file mode 100644 index 0000000000000000000000000000000000000000..fb845c6e6d71d0895c4424b2b54eb333c7b320ce --- /dev/null +++ b/pandas/tests/plotting/test_datetimelike.py @@ -0,0 +1,1721 @@ +"""Test cases for time series specific (freq conversion, etc)""" + +from datetime import ( + date, + datetime, + time, + timedelta, +) +import pickle + +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + BaseOffset, + to_offset, +) + +from pandas.core.dtypes.dtypes import PeriodDtype + +from pandas import ( + DataFrame, + Index, + NaT, + Series, + concat, + isna, + to_datetime, +) +import pandas._testing as tm +from pandas.core.indexes.datetimes import ( + DatetimeIndex, + bdate_range, + date_range, +) +from pandas.core.indexes.period import ( + Period, + PeriodIndex, + period_range, +) +from pandas.core.indexes.timedeltas import timedelta_range +from pandas.tests.plotting.common import _check_ticks_props + +from pandas.tseries.offsets import WeekOfMonth + +mpl = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") + +import pandas.plotting._matplotlib.converter as conv + + +class TestTSPlot: + @pytest.mark.filterwarnings("ignore::UserWarning") + def test_ts_plot_with_tz(self, tz_aware_fixture): + # GH2877, GH17173, GH31205, GH31580 + tz = tz_aware_fixture + index = date_range("1/1/2011", periods=2, freq="h", tz=tz) + ts = Series([188.5, 328.25], index=index) + _check_plot_works(ts.plot) + ax = ts.plot() + xdata = next(iter(ax.get_lines())).get_xdata() + # Check first and last points' labels are correct + assert (xdata[0].hour, xdata[0].minute) == (0, 0) + assert (xdata[-1].hour, xdata[-1].minute) == (1, 0) + + def test_fontsize_set_correctly(self): + # For issue #8765 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 9)), index=range(10) + ) + _, ax = mpl.pyplot.subplots() + df.plot(fontsize=2, ax=ax) + for label in ax.get_xticklabels() + ax.get_yticklabels(): + assert label.get_fontsize() == 2 + + def test_frame_inferred(self): + # inferred freq + idx = date_range("1/1/1987", freq="MS", periods=10) + idx = DatetimeIndex(idx.values, freq=None) + + df = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 3)), index=idx + ) + _check_plot_works(df.plot) + + # axes freq + idx = idx[0:4].union(idx[6:]) + df2 = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 3)), index=idx + ) + _check_plot_works(df2.plot) + + def test_frame_inferred_n_gt_1(self): + # N > 1 + idx = date_range("2008-1-1 00:15:00", freq="15min", periods=10) + idx = DatetimeIndex(idx.values, freq=None) + df = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 3)), index=idx + ) + _check_plot_works(df.plot) + + def test_is_error_nozeroindex(self): + # GH11858 + i = np.array([1, 2, 3]) + a = DataFrame(i, index=i) + _check_plot_works(a.plot, xerr=a) + _check_plot_works(a.plot, yerr=a) + + def test_nonnumeric_exclude(self): + idx = date_range("1/1/1987", freq="YE", periods=3) + df = DataFrame({"A": ["x", "y", "z"], "B": [1, 2, 3]}, idx) + + fig, ax = mpl.pyplot.subplots() + df.plot(ax=ax) # it works + assert len(ax.get_lines()) == 1 # B was plotted + + def test_nonnumeric_exclude_error(self): + idx = date_range("1/1/1987", freq="YE", periods=3) + df = DataFrame({"A": ["x", "y", "z"], "B": [1, 2, 3]}, idx) + msg = "no numeric data to plot" + with pytest.raises(TypeError, match=msg): + df["A"].plot() + + @pytest.mark.parametrize("freq", ["s", "min", "h", "D", "W", "M", "Q", "Y"]) + def test_tsplot_period(self, freq): + idx = period_range("12/31/1999", freq=freq, periods=10) + ser = Series(np.random.default_rng(2).standard_normal(len(idx)), idx) + _, ax = mpl.pyplot.subplots() + _check_plot_works(ser.plot, ax=ax) + + @pytest.mark.parametrize( + "freq", ["s", "min", "h", "D", "W", "ME", "QE-DEC", "YE", "1B30Min"] + ) + def test_tsplot_datetime(self, freq): + idx = date_range("12/31/1999", freq=freq, periods=10) + ser = Series(np.random.default_rng(2).standard_normal(len(idx)), idx) + _, ax = mpl.pyplot.subplots() + _check_plot_works(ser.plot, ax=ax) + + def test_tsplot(self): + ts = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + _, ax = mpl.pyplot.subplots() + ts.plot(style="k", ax=ax) + color = (0.0, 0.0, 0.0, 1) + assert color == ax.get_lines()[0].get_color() + + @pytest.mark.parametrize("index", [None, date_range("2020-01-01", periods=10)]) + def test_both_style_and_color(self, index): + ts = Series(np.arange(10, dtype=np.float64), index=index) + msg = ( + "Cannot pass 'style' string with a color symbol and 'color' " + "keyword argument. Please use one or the other or pass 'style' " + "without a color symbol" + ) + with pytest.raises(ValueError, match=msg): + ts.plot(style="b-", color="#000099") + + @pytest.mark.parametrize("freq", ["ms", "us"]) + def test_high_freq(self, freq): + _, ax = mpl.pyplot.subplots() + rng = date_range("1/1/2012", periods=10, freq=freq) + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _check_plot_works(ser.plot, ax=ax) + + def test_get_datevalue(self): + assert conv._get_datevalue(None, "D") is None + assert conv._get_datevalue(1987, "Y") == 1987 + assert ( + conv._get_datevalue(Period(1987, "Y"), "M") + == Period("1987-12", "M").ordinal + ) + assert conv._get_datevalue("1/1/1987", "D") == Period("1987-1-1", "D").ordinal + + @pytest.mark.parametrize( + "freq, expected_string", + [["YE-DEC", "t = 2014 y = 1.000000"], ["D", "t = 2014-01-01 y = 1.000000"]], + ) + def test_ts_plot_format_coord(self, freq, expected_string): + ser = Series(1, index=date_range("2014-01-01", periods=3, freq=freq)) + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + first_line = ax.get_lines()[0] + first_x = first_line.get_xdata()[0].ordinal + first_y = first_line.get_ydata()[0] + assert expected_string == ax.format_coord(first_x, first_y) + + @pytest.mark.parametrize("freq", ["s", "min", "h", "D", "W", "M", "Q", "Y"]) + def test_line_plot_period_series(self, freq): + idx = period_range("12/31/1999", freq=freq, periods=10) + ser = Series(np.random.default_rng(2).standard_normal(len(idx)), idx) + _check_plot_works(ser.plot, ser.index.freq) + + @pytest.mark.parametrize( + "frqncy", ["1s", "3s", "5min", "7h", "4D", "8W", "11M", "3Y"] + ) + def test_line_plot_period_mlt_series(self, frqncy): + # test period index line plot for series with multiples (`mlt`) of the + # frequency (`frqncy`) rule code. tests resolution of issue #14763 + idx = period_range("12/31/1999", freq=frqncy, periods=10) + s = Series(np.random.default_rng(2).standard_normal(len(idx)), idx) + _check_plot_works(s.plot, s.index.freq.rule_code) + + @pytest.mark.parametrize( + "freq", ["s", "min", "h", "D", "W", "ME", "QE-DEC", "YE", "1B30Min"] + ) + def test_line_plot_datetime_series(self, freq): + idx = date_range("12/31/1999", freq=freq, periods=10) + ser = Series(np.random.default_rng(2).standard_normal(len(idx)), idx) + _check_plot_works(ser.plot, ser.index.freq.rule_code) + + @pytest.mark.parametrize("freq", ["s", "min", "h", "D", "W", "ME", "QE", "YE"]) + def test_line_plot_period_frame(self, freq): + idx = date_range("12/31/1999", freq=freq, periods=10) + df = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 3)), + index=idx, + columns=["A", "B", "C"], + ) + _check_plot_works(df.plot, df.index.freq) + + @pytest.mark.parametrize( + "frqncy", ["1s", "3s", "5min", "7h", "4D", "8W", "11M", "3Y"] + ) + def test_line_plot_period_mlt_frame(self, frqncy): + # test period index line plot for DataFrames with multiples (`mlt`) + # of the frequency (`frqncy`) rule code. tests resolution of issue + # #14763 + idx = period_range("12/31/1999", freq=frqncy, periods=10) + df = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 3)), + index=idx, + columns=["A", "B", "C"], + ) + freq = df.index.freq.rule_code + _check_plot_works(df.plot, freq) + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + @pytest.mark.parametrize( + "freq", ["s", "min", "h", "D", "W", "ME", "QE-DEC", "YE", "1B30Min"] + ) + def test_line_plot_datetime_frame(self, freq): + idx = date_range("12/31/1999", freq=freq, periods=10) + df = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 3)), + index=idx, + columns=["A", "B", "C"], + ) + freq = PeriodDtype(df.index.freq)._freqstr + freq = df.index.to_period(freq).freq + _check_plot_works(df.plot, freq) + + @pytest.mark.parametrize( + "freq", ["s", "min", "h", "D", "W", "ME", "QE-DEC", "YE", "1B30Min"] + ) + def test_line_plot_inferred_freq(self, freq): + idx = date_range("12/31/1999", freq=freq, periods=10) + ser = Series(np.random.default_rng(2).standard_normal(len(idx)), idx) + ser = Series(ser.values, Index(np.asarray(ser.index))) + _check_plot_works(ser.plot, ser.index.inferred_freq) + + ser = ser.iloc[[0, 3, 5, 6]] + _check_plot_works(ser.plot) + + def test_fake_inferred_business(self): + _, ax = mpl.pyplot.subplots() + rng = date_range("2001-1-1", "2001-1-10") + ts = Series(range(len(rng)), index=rng) + ts = concat([ts[:3], ts[5:]]) + ts.plot(ax=ax) + assert not hasattr(ax, "freq") + + def test_plot_offset_freq(self): + ser = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + _check_plot_works(ser.plot) + + def test_plot_offset_freq_business(self): + dr = date_range("2023-01-01", freq="BQS", periods=10) + ser = Series(np.random.default_rng(2).standard_normal(len(dr)), index=dr) + _check_plot_works(ser.plot) + + def test_plot_multiple_inferred_freq(self): + dr = Index([datetime(2000, 1, 1), datetime(2000, 1, 6), datetime(2000, 1, 11)]) + ser = Series(np.random.default_rng(2).standard_normal(len(dr)), index=dr) + _check_plot_works(ser.plot) + + def test_irreg_hf(self): + idx = date_range("2012-6-22 21:59:51", freq="s", periods=10) + df = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 2)), index=idx + ) + + irreg = df.iloc[[0, 1, 3, 4]] + _, ax = mpl.pyplot.subplots() + irreg.plot(ax=ax) + diffs = Series(ax.get_lines()[0].get_xydata()[:, 0]).diff() + + sec = 1.0 / 24 / 60 / 60 + assert (np.fabs(diffs[1:] - [sec, sec * 2, sec]) < 1e-8).all() + + def test_irreg_hf_object(self): + idx = date_range("2012-6-22 21:59:51", freq="s", periods=10) + df2 = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 2)), index=idx + ) + _, ax = mpl.pyplot.subplots() + df2.index = df2.index.astype(object) + df2.plot(ax=ax) + diffs = Series(ax.get_lines()[0].get_xydata()[:, 0]).diff() + sec = 1.0 / 24 / 60 / 60 + assert (np.fabs(diffs[1:] - sec) < 1e-8).all() + + def test_irregular_datetime64_repr_bug(self): + ser = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + ser = ser.iloc[[0, 1, 2, 7]] + + _, ax = mpl.pyplot.subplots() + + ret = ser.plot(ax=ax) + assert ret is not None + + for rs, xp in zip(ax.get_lines()[0].get_xdata(), ser.index, strict=True): + assert rs == xp + + def test_business_freq(self): + bts = Series(range(5), period_range("2020-01-01", periods=5)) + msg = r"PeriodDtype\[B\] is deprecated" + dt = bts.index[0].to_timestamp() + with tm.assert_produces_warning(FutureWarning, match=msg): + bts.index = period_range(start=dt, periods=len(bts), freq="B") + _, ax = mpl.pyplot.subplots() + bts.plot(ax=ax) + assert ax.get_lines()[0].get_xydata()[0, 0] == bts.index[0].ordinal + idx = ax.get_lines()[0].get_xdata() + with tm.assert_produces_warning(FutureWarning, match=msg): + assert PeriodIndex(data=idx).freqstr == "B" + + def test_business_freq_convert(self): + bts = Series( + np.arange(50, dtype=np.float64), + index=date_range("2020-01-01", periods=50, freq="B"), + ).asfreq("BME") + ts = bts.to_period("M") + _, ax = mpl.pyplot.subplots() + bts.plot(ax=ax) + assert ax.get_lines()[0].get_xydata()[0, 0] == ts.index[0].ordinal + idx = ax.get_lines()[0].get_xdata() + assert PeriodIndex(data=idx).freqstr == "M" + + def test_freq_with_no_period_alias(self): + # GH34487 + freq = WeekOfMonth() + bts = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ).asfreq(freq) + _, ax = mpl.pyplot.subplots() + bts.plot(ax=ax) + + idx = ax.get_lines()[0].get_xdata() + msg = "freq not specified and cannot be inferred" + with pytest.raises(ValueError, match=msg): + PeriodIndex(data=idx) + + def test_nonzero_base(self): + # GH2571 + idx = date_range("2012-12-20", periods=24, freq="h") + timedelta(minutes=30) + df = DataFrame(np.arange(24), index=idx) + _, ax = mpl.pyplot.subplots() + df.plot(ax=ax) + rs = ax.get_lines()[0].get_xdata() + assert not Index(rs).is_normalized + + def test_dataframe(self): + bts = DataFrame( + { + "a": Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + ) + } + ) + _, ax = mpl.pyplot.subplots() + bts.plot(ax=ax) + idx = ax.get_lines()[0].get_xdata() + tm.assert_index_equal(bts.index.to_period(), PeriodIndex(idx)) + + @pytest.mark.filterwarnings( + "ignore:Period with BDay freq is deprecated:FutureWarning" + ) + @pytest.mark.parametrize( + "obj", + [ + Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + ), + DataFrame( + { + "a": Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + ), + "b": Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + ) + + 1, + } + ), + ], + ) + def test_axis_limits(self, obj): + _, ax = mpl.pyplot.subplots() + obj.plot(ax=ax) + xlim = ax.get_xlim() + ax.set_xlim(xlim[0] - 5, xlim[1] + 10) + result = ax.get_xlim() + assert result[0] == xlim[0] - 5 + assert result[1] == xlim[1] + 10 + + # string + expected = (Period("1/1/2000", ax.freq), Period("4/1/2000", ax.freq)) + ax.set_xlim("1/1/2000", "4/1/2000") + result = ax.get_xlim() + assert int(result[0]) == expected[0].ordinal + assert int(result[1]) == expected[1].ordinal + + # datetime + expected = (Period("1/1/2000", ax.freq), Period("4/1/2000", ax.freq)) + ax.set_xlim(datetime(2000, 1, 1), datetime(2000, 4, 1)) + result = ax.get_xlim() + assert int(result[0]) == expected[0].ordinal + assert int(result[1]) == expected[1].ordinal + + def test_get_finder(self): + assert conv.get_finder(to_offset("B")) == conv._daily_finder + assert conv.get_finder(to_offset("D")) == conv._daily_finder + assert conv.get_finder(to_offset("ME")) == conv._monthly_finder + assert conv.get_finder(to_offset("QE")) == conv._quarterly_finder + assert conv.get_finder(to_offset("YE")) == conv._annual_finder + assert conv.get_finder(to_offset("W")) == conv._daily_finder + + def test_finder_daily(self): + day_lst = [10, 40, 252, 400, 950, 2750, 10000] + + msg = "Period with BDay freq is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + xpl1 = xpl2 = [Period("1999-1-1", freq="B").ordinal] * len(day_lst) + rs1 = [] + rs2 = [] + for n in day_lst: + rng = bdate_range("1999-1-1", periods=n) + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + xaxis = ax.get_xaxis() + rs1.append(xaxis.get_majorticklocs()[0]) + + vmin, vmax = ax.get_xlim() + ax.set_xlim(vmin + 0.9, vmax) + rs2.append(xaxis.get_majorticklocs()[0]) + mpl.pyplot.close(ax.get_figure()) + + assert rs1 == xpl1 + assert rs2 == xpl2 + + def test_finder_quarterly(self): + yrs = [3.5, 11] + + xpl1 = xpl2 = [Period("1988Q1").ordinal] * len(yrs) + rs1 = [] + rs2 = [] + for n in yrs: + rng = period_range("1987Q2", periods=int(n * 4), freq="Q") + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + xaxis = ax.get_xaxis() + rs1.append(xaxis.get_majorticklocs()[0]) + + (vmin, vmax) = ax.get_xlim() + ax.set_xlim(vmin + 0.9, vmax) + rs2.append(xaxis.get_majorticklocs()[0]) + mpl.pyplot.close(ax.get_figure()) + + assert rs1 == xpl1 + assert rs2 == xpl2 + + def test_finder_monthly(self): + yrs = [1.15, 2.5, 4, 11] + + xpl1 = xpl2 = [Period("Jan 1988").ordinal] * len(yrs) + rs1 = [] + rs2 = [] + for n in yrs: + rng = period_range("1987Q2", periods=int(n * 12), freq="M") + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + xaxis = ax.get_xaxis() + rs1.append(xaxis.get_majorticklocs()[0]) + + vmin, vmax = ax.get_xlim() + ax.set_xlim(vmin + 0.9, vmax) + rs2.append(xaxis.get_majorticklocs()[0]) + mpl.pyplot.close(ax.get_figure()) + + assert rs1 == xpl1 + assert rs2 == xpl2 + + def test_finder_monthly_long(self): + rng = period_range("1988Q1", periods=24 * 12, freq="M") + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + xaxis = ax.get_xaxis() + rs = xaxis.get_majorticklocs()[0] + xp = Period("1989Q1", "M").ordinal + assert rs == xp + + def test_finder_annual(self): + xp = [1987, 1988, 1990, 1990, 1995, 2020, 2070, 2170] + xp = [Period(x, freq="Y").ordinal for x in xp] + rs = [] + for nyears in [5, 10, 19, 49, 99, 199, 599, 1001]: + rng = period_range("1987", periods=nyears, freq="Y") + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + xaxis = ax.get_xaxis() + rs.append(xaxis.get_majorticklocs()[0]) + mpl.pyplot.close(ax.get_figure()) + + assert rs == xp + + @pytest.mark.slow + def test_finder_minutely(self): + nminutes = 1 * 24 * 60 + rng = date_range("1/1/1999", freq="Min", periods=nminutes) + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + xaxis = ax.get_xaxis() + rs = xaxis.get_majorticklocs()[0] + xp = Period("1/1/1999", freq="Min").ordinal + + assert rs == xp + + def test_finder_hourly(self): + nhours = 23 + rng = date_range("1/1/1999", freq="h", periods=nhours) + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + xaxis = ax.get_xaxis() + rs = xaxis.get_majorticklocs()[0] + xp = Period("1/1/1999", freq="h").ordinal + + assert rs == xp + + def test_gaps(self): + ts = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + ts.iloc[5:7] = np.nan + _, ax = mpl.pyplot.subplots() + ts.plot(ax=ax) + lines = ax.get_lines() + assert len(lines) == 1 + line = lines[0] + data = line.get_xydata() + + data = np.ma.MaskedArray(data, mask=isna(data), fill_value=np.nan) + + assert isinstance(data, np.ma.core.MaskedArray) + mask = data.mask + assert mask[5:7, 1].all() + + def test_gaps_irregular(self): + # irregular + ts = Series( + np.arange(30, dtype=np.float64), index=date_range("2020-01-01", periods=30) + ) + ts = ts.iloc[[0, 1, 2, 5, 7, 9, 12, 15, 20]] + ts.iloc[2:5] = np.nan + _, ax = mpl.pyplot.subplots() + ax = ts.plot(ax=ax) + lines = ax.get_lines() + assert len(lines) == 1 + line = lines[0] + data = line.get_xydata() + + data = np.ma.MaskedArray(data, mask=isna(data), fill_value=np.nan) + + assert isinstance(data, np.ma.core.MaskedArray) + mask = data.mask + assert mask[2:5, 1].all() + + def test_gaps_non_ts(self): + # non-ts + idx = [0, 1, 2, 5, 7, 9, 12, 15, 20] + ser = Series(np.random.default_rng(2).standard_normal(len(idx)), idx) + ser.iloc[2:5] = np.nan + _, ax = mpl.pyplot.subplots() + ser.plot(ax=ax) + lines = ax.get_lines() + assert len(lines) == 1 + line = lines[0] + data = line.get_xydata() + data = np.ma.MaskedArray(data, mask=isna(data), fill_value=np.nan) + + assert isinstance(data, np.ma.core.MaskedArray) + mask = data.mask + assert mask[2:5, 1].all() + + def test_gap_upsample(self): + low = Series( + np.arange(10, dtype=np.float64), index=date_range("2020-01-01", periods=10) + ) + low.iloc[5:7] = np.nan + _, ax = mpl.pyplot.subplots() + low.plot(ax=ax) + + idxh = date_range(low.index[0], low.index[-1], freq="12h") + s = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + s.plot(secondary_y=True) + lines = ax.get_lines() + assert len(lines) == 1 + assert len(ax.right_ax.get_lines()) == 1 + + line = lines[0] + data = line.get_xydata() + data = np.ma.MaskedArray(data, mask=isna(data), fill_value=np.nan) + + assert isinstance(data, np.ma.core.MaskedArray) + mask = data.mask + assert mask[5:7, 1].all() + + def test_secondary_y(self): + ser = Series(np.random.default_rng(2).standard_normal(10)) + fig, _ = mpl.pyplot.subplots() + ax = ser.plot(secondary_y=True) + assert hasattr(ax, "left_ax") + assert not hasattr(ax, "right_ax") + axes = fig.get_axes() + line = ax.get_lines()[0] + xp = Series(line.get_ydata(), line.get_xdata()) + tm.assert_series_equal(ser, xp) + assert ax.get_yaxis().get_ticks_position() == "right" + assert not axes[0].get_yaxis().get_visible() + + def test_secondary_y_yaxis(self): + Series(np.random.default_rng(2).standard_normal(10)) + ser2 = Series(np.random.default_rng(2).standard_normal(10)) + _, ax2 = mpl.pyplot.subplots() + ser2.plot(ax=ax2) + assert ax2.get_yaxis().get_ticks_position() == "left" + + def test_secondary_both(self): + ser = Series(np.random.default_rng(2).standard_normal(10)) + ser2 = Series(np.random.default_rng(2).standard_normal(10)) + ax = ser2.plot() + ax2 = ser.plot(secondary_y=True) + assert ax.get_yaxis().get_visible() + assert not hasattr(ax, "left_ax") + assert hasattr(ax, "right_ax") + assert hasattr(ax2, "left_ax") + assert not hasattr(ax2, "right_ax") + + def test_secondary_y_ts(self): + idx = date_range("1/1/2000", periods=10, unit="ns") + ser = Series(np.random.default_rng(2).standard_normal(10), idx) + fig, _ = mpl.pyplot.subplots() + ax = ser.plot(secondary_y=True) + assert hasattr(ax, "left_ax") + assert not hasattr(ax, "right_ax") + axes = fig.get_axes() + line = ax.get_lines()[0] + xp = Series(line.get_ydata(), line.get_xdata()).to_timestamp() + xp.index = xp.index.as_unit("ns") + tm.assert_series_equal(ser, xp) + assert ax.get_yaxis().get_ticks_position() == "right" + assert not axes[0].get_yaxis().get_visible() + + def test_secondary_y_ts_yaxis(self): + idx = date_range("1/1/2000", periods=10) + ser2 = Series(np.random.default_rng(2).standard_normal(10), idx) + _, ax2 = mpl.pyplot.subplots() + ser2.plot(ax=ax2) + assert ax2.get_yaxis().get_ticks_position() == "left" + + def test_secondary_y_ts_visible(self): + idx = date_range("1/1/2000", periods=10) + ser2 = Series(np.random.default_rng(2).standard_normal(10), idx) + ax = ser2.plot() + assert ax.get_yaxis().get_visible() + + def test_secondary_kde(self): + pytest.importorskip("scipy") + ser = Series(np.random.default_rng(2).standard_normal(10)) + fig, ax = mpl.pyplot.subplots() + ax = ser.plot(secondary_y=True, kind="density", ax=ax) + assert hasattr(ax, "left_ax") + assert not hasattr(ax, "right_ax") + axes = fig.get_axes() + assert axes[1].get_yaxis().get_ticks_position() == "right" + + def test_secondary_bar(self): + ser = Series(np.random.default_rng(2).standard_normal(10)) + fig, ax = mpl.pyplot.subplots() + ser.plot(secondary_y=True, kind="bar", ax=ax) + axes = fig.get_axes() + assert axes[1].get_yaxis().get_ticks_position() == "right" + + def test_secondary_frame(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), columns=["a", "b", "c"] + ) + axes = df.plot(secondary_y=["a", "c"], subplots=True) + assert axes[0].get_yaxis().get_ticks_position() == "right" + assert axes[1].get_yaxis().get_ticks_position() == "left" + assert axes[2].get_yaxis().get_ticks_position() == "right" + + def test_secondary_bar_frame(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 3)), columns=["a", "b", "c"] + ) + axes = df.plot(kind="bar", secondary_y=["a", "c"], subplots=True) + assert axes[0].get_yaxis().get_ticks_position() == "right" + assert axes[1].get_yaxis().get_ticks_position() == "left" + assert axes[2].get_yaxis().get_ticks_position() == "right" + + def test_mixed_freq_regular_first(self): + # TODO + s1 = Series( + np.arange(20, dtype=np.float64), + index=date_range("2020-01-01", periods=20, freq="B"), + ) + s2 = s1.iloc[[0, 5, 10, 11, 12, 13, 14, 15]] + + # it works! + _, ax = mpl.pyplot.subplots() + s1.plot(ax=ax) + + ax2 = s2.plot(style="g", ax=ax) + lines = ax2.get_lines() + msg = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + idx1 = PeriodIndex(lines[0].get_xdata()) + idx2 = PeriodIndex(lines[1].get_xdata()) + + tm.assert_index_equal(idx1, s1.index.to_period("B")) + tm.assert_index_equal(idx2, s2.index.to_period("B")) + + left, right = ax2.get_xlim() + pidx = s1.index.to_period() + assert left <= pidx[0].ordinal + assert right >= pidx[-1].ordinal + + def test_mixed_freq_irregular_first(self): + s1 = Series( + np.arange(20, dtype=np.float64), index=date_range("2020-01-01", periods=20) + ) + s2 = s1.iloc[[0, 5, 10, 11, 12, 13, 14, 15]] + _, ax = mpl.pyplot.subplots() + s2.plot(style="g", ax=ax) + s1.plot(ax=ax) + assert not hasattr(ax, "freq") + lines = ax.get_lines() + x1 = lines[0].get_xdata() + tm.assert_numpy_array_equal(x1, s2.index.astype(object).values) + x2 = lines[1].get_xdata() + tm.assert_numpy_array_equal(x2, s1.index.astype(object).values) + + def test_mixed_freq_regular_first_df(self): + # GH 9852 + s1 = Series( + np.arange(20, dtype=np.float64), + index=date_range("2020-01-01", periods=20, freq="B"), + ).to_frame() + s2 = s1.iloc[[0, 5, 10, 11, 12, 13, 14, 15], :] + _, ax = mpl.pyplot.subplots() + s1.plot(ax=ax) + ax2 = s2.plot(style="g", ax=ax) + lines = ax2.get_lines() + msg = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + idx1 = PeriodIndex(lines[0].get_xdata()) + idx2 = PeriodIndex(lines[1].get_xdata()) + assert idx1.equals(s1.index.to_period("B")) + assert idx2.equals(s2.index.to_period("B")) + left, right = ax2.get_xlim() + pidx = s1.index.to_period() + assert left <= pidx[0].ordinal + assert right >= pidx[-1].ordinal + + def test_mixed_freq_irregular_first_df(self): + # GH 9852 + s1 = Series( + np.arange(20, dtype=np.float64), index=date_range("2020-01-01", periods=20) + ).to_frame() + s2 = s1.iloc[[0, 5, 10, 11, 12, 13, 14, 15], :] + _, ax = mpl.pyplot.subplots() + s2.plot(style="g", ax=ax) + s1.plot(ax=ax) + assert not hasattr(ax, "freq") + lines = ax.get_lines() + x1 = lines[0].get_xdata() + tm.assert_numpy_array_equal(x1, s2.index.astype(object).values) + x2 = lines[1].get_xdata() + tm.assert_numpy_array_equal(x2, s1.index.astype(object).values) + + def test_mixed_freq_hf_first(self): + idxh = date_range("1/1/1999", periods=365, freq="D") + idxl = date_range("1/1/1999", periods=12, freq="ME") + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + _, ax = mpl.pyplot.subplots() + high.plot(ax=ax) + low.plot(ax=ax) + for line in ax.get_lines(): + assert PeriodIndex(data=line.get_xdata()).freq == "D" + + def test_mixed_freq_alignment(self): + ts_ind = date_range("2012-01-01 13:00", "2012-01-02", freq="h") + ts_data = np.random.default_rng(2).standard_normal(12) + + ts = Series(ts_data, index=ts_ind) + ts2 = ts.asfreq("min").interpolate() + + _, ax = mpl.pyplot.subplots() + ax = ts.plot(ax=ax) + ts2.plot(style="r", ax=ax) + + assert ax.lines[0].get_xdata()[0] == ax.lines[1].get_xdata()[0] + + def test_mixed_freq_lf_first(self): + idxh = date_range("1/1/1999", periods=365, freq="D") + idxl = date_range("1/1/1999", periods=12, freq="ME") + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + _, ax = mpl.pyplot.subplots() + low.plot(legend=True, ax=ax) + high.plot(legend=True, ax=ax) + for line in ax.get_lines(): + assert PeriodIndex(data=line.get_xdata()).freq == "D" + leg = ax.get_legend() + assert len(leg.texts) == 2 + mpl.pyplot.close(ax.get_figure()) + + def test_mixed_freq_lf_first_hourly(self): + idxh = date_range("1/1/1999", periods=240, freq="min") + idxl = date_range("1/1/1999", periods=4, freq="h") + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + _, ax = mpl.pyplot.subplots() + low.plot(ax=ax) + high.plot(ax=ax) + for line in ax.get_lines(): + assert PeriodIndex(data=line.get_xdata()).freq == "min" + + @pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning") + def test_mixed_freq_irreg_period(self): + ts = Series( + np.arange(30, dtype=np.float64), index=date_range("2020-01-01", periods=30) + ) + irreg = ts.iloc[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 16, 17, 18, 29]] + msg = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(FutureWarning, match=msg): + rng = period_range("1/3/2000", periods=30, freq="B") + ps = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + irreg.plot(ax=ax) + ps.plot(ax=ax) + + def test_mixed_freq_shared_ax(self): + # GH13341, using sharex=True + idx1 = date_range("2015-01-01", periods=3, freq="ME") + idx2 = idx1[:1].union(idx1[2:]) + s1 = Series(range(len(idx1)), idx1) + s2 = Series(range(len(idx2)), idx2) + + _, (ax1, ax2) = mpl.pyplot.subplots(nrows=2, sharex=True) + s1.plot(ax=ax1) + s2.plot(ax=ax2) + + assert ax1.freq == "M" + assert ax2.freq == "M" + assert ax1.lines[0].get_xydata()[0, 0] == ax2.lines[0].get_xydata()[0, 0] + + def test_mixed_freq_shared_ax_twin_x(self): + # GH13341, using sharex=True + idx1 = date_range("2015-01-01", periods=3, freq="ME") + idx2 = idx1[:1].union(idx1[2:]) + s1 = Series(range(len(idx1)), idx1) + s2 = Series(range(len(idx2)), idx2) + # using twinx + _, ax1 = mpl.pyplot.subplots() + ax2 = ax1.twinx() + s1.plot(ax=ax1) + s2.plot(ax=ax2) + + assert ax1.lines[0].get_xydata()[0, 0] == ax2.lines[0].get_xydata()[0, 0] + + @pytest.mark.xfail(reason="TODO (GH14330, GH14322)") + def test_mixed_freq_shared_ax_twin_x_irregular_first(self): + # GH13341, using sharex=True + idx1 = date_range("2015-01-01", periods=3, freq="ME") + idx2 = idx1[:1].union(idx1[2:]) + s1 = Series(range(len(idx1)), idx1) + s2 = Series(range(len(idx2)), idx2) + _, ax1 = mpl.pyplot.subplots() + ax2 = ax1.twinx() + s2.plot(ax=ax1) + s1.plot(ax=ax2) + assert ax1.lines[0].get_xydata()[0, 0] == ax2.lines[0].get_xydata()[0, 0] + + def test_nat_handling(self): + _, ax = mpl.pyplot.subplots() + + dti = DatetimeIndex(["2015-01-01", NaT, "2015-01-03"]) + s = Series(range(len(dti)), dti) + s.plot(ax=ax) + xdata = ax.get_lines()[0].get_xdata() + # plot x data is bounded by index values + assert s.index.min() <= Series(xdata).min() + assert Series(xdata).max() <= s.index.max() + + def test_to_weekly_resampling_disallow_how_kwd(self): + idxh = date_range("1/1/1999", periods=52, freq="W") + idxl = date_range("1/1/1999", periods=12, freq="ME") + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + _, ax = mpl.pyplot.subplots() + high.plot(ax=ax) + + msg = ( + "'how' is not a valid keyword for plotting functions. If plotting " + "multiple objects on shared axes, resample manually first." + ) + with pytest.raises(ValueError, match=msg): + low.plot(ax=ax, how="foo") + + def test_to_weekly_resampling(self): + idxh = date_range("1/1/1999", periods=52, freq="W") + idxl = date_range("1/1/1999", periods=12, freq="ME") + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + _, ax = mpl.pyplot.subplots() + high.plot(ax=ax) + low.plot(ax=ax) + for line in ax.get_lines(): + assert PeriodIndex(data=line.get_xdata()).freq == idxh.freq + + def test_from_weekly_resampling(self): + idxh = date_range("1/1/1999", periods=52, freq="W") + idxl = date_range("1/1/1999", periods=12, freq="ME") + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + _, ax = mpl.pyplot.subplots() + low.plot(ax=ax) + high.plot(ax=ax) + + expected_h = idxh.to_period().asi8.astype(np.float64) + expected_l = np.array( + [1514, 1519, 1523, 1527, 1531, 1536, 1540, 1544, 1549, 1553, 1558, 1562], + dtype=np.float64, + ) + for line in ax.get_lines(): + assert PeriodIndex(data=line.get_xdata()).freq == idxh.freq + xdata = line.get_xdata(orig=False) + if len(xdata) == 12: # idxl lines + tm.assert_numpy_array_equal(xdata, expected_l) + else: + tm.assert_numpy_array_equal(xdata, expected_h) + + @pytest.mark.parametrize("kind1, kind2", [("line", "area"), ("area", "line")]) + def test_from_resampling_area_line_mixed(self, kind1, kind2): + idxh = date_range("1/1/1999", periods=52, freq="W") + idxl = date_range("1/1/1999", periods=12, freq="ME") + high = DataFrame( + np.random.default_rng(2).random((len(idxh), 3)), + index=idxh, + columns=[0, 1, 2], + ) + low = DataFrame( + np.random.default_rng(2).random((len(idxl), 3)), + index=idxl, + columns=[0, 1, 2], + ) + + _, ax = mpl.pyplot.subplots() + low.plot(kind=kind1, stacked=True, ax=ax) + high.plot(kind=kind2, stacked=True, ax=ax) + + # check low dataframe result + expected_x = np.array( + [ + 1514, + 1519, + 1523, + 1527, + 1531, + 1536, + 1540, + 1544, + 1549, + 1553, + 1558, + 1562, + ], + dtype=np.float64, + ) + expected_y = np.zeros(len(expected_x), dtype=np.float64) + for i in range(3): + line = ax.lines[i] + assert PeriodIndex(line.get_xdata()).freq == idxh.freq + tm.assert_numpy_array_equal(line.get_xdata(orig=False), expected_x) + # check stacked values are correct + expected_y += low[i].values + tm.assert_numpy_array_equal(line.get_ydata(orig=False), expected_y) + + # check high dataframe result + expected_x = idxh.to_period().asi8.astype(np.float64) + expected_y = np.zeros(len(expected_x), dtype=np.float64) + for i in range(3): + line = ax.lines[3 + i] + assert PeriodIndex(data=line.get_xdata()).freq == idxh.freq + tm.assert_numpy_array_equal(line.get_xdata(orig=False), expected_x) + expected_y += high[i].values + tm.assert_numpy_array_equal(line.get_ydata(orig=False), expected_y) + + @pytest.mark.parametrize("kind1, kind2", [("line", "area"), ("area", "line")]) + def test_from_resampling_area_line_mixed_high_to_low(self, kind1, kind2): + idxh = date_range("1/1/1999", periods=52, freq="W") + idxl = date_range("1/1/1999", periods=12, freq="ME") + high = DataFrame( + np.random.default_rng(2).random((len(idxh), 3)), + index=idxh, + columns=[0, 1, 2], + ) + low = DataFrame( + np.random.default_rng(2).random((len(idxl), 3)), + index=idxl, + columns=[0, 1, 2], + ) + _, ax = mpl.pyplot.subplots() + high.plot(kind=kind1, stacked=True, ax=ax) + low.plot(kind=kind2, stacked=True, ax=ax) + + # check high dataframe result + expected_x = idxh.to_period().asi8.astype(np.float64) + expected_y = np.zeros(len(expected_x), dtype=np.float64) + for i in range(3): + line = ax.lines[i] + assert PeriodIndex(data=line.get_xdata()).freq == idxh.freq + tm.assert_numpy_array_equal(line.get_xdata(orig=False), expected_x) + expected_y += high[i].values + tm.assert_numpy_array_equal(line.get_ydata(orig=False), expected_y) + + # check low dataframe result + expected_x = np.array( + [ + 1514, + 1519, + 1523, + 1527, + 1531, + 1536, + 1540, + 1544, + 1549, + 1553, + 1558, + 1562, + ], + dtype=np.float64, + ) + expected_y = np.zeros(len(expected_x), dtype=np.float64) + for i in range(3): + lines = ax.lines[3 + i] + assert PeriodIndex(data=lines.get_xdata()).freq == idxh.freq + tm.assert_numpy_array_equal(lines.get_xdata(orig=False), expected_x) + expected_y += low[i].values + tm.assert_numpy_array_equal(lines.get_ydata(orig=False), expected_y) + + def test_mixed_freq_second_millisecond(self): + # GH 7772, GH 7760 + idxh = date_range("2014-07-01 09:00", freq="s", periods=5) + idxl = date_range("2014-07-01 09:00", freq="100ms", periods=50) + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + # high to low + _, ax = mpl.pyplot.subplots() + high.plot(ax=ax) + low.plot(ax=ax) + assert len(ax.get_lines()) == 2 + for line in ax.get_lines(): + assert PeriodIndex(data=line.get_xdata()).freq == "ms" + + def test_mixed_freq_second_millisecond_low_to_high(self): + # GH 7772, GH 7760 + idxh = date_range("2014-07-01 09:00", freq="s", periods=5) + idxl = date_range("2014-07-01 09:00", freq="100ms", periods=50) + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + # low to high + _, ax = mpl.pyplot.subplots() + low.plot(ax=ax) + high.plot(ax=ax) + assert len(ax.get_lines()) == 2 + for line in ax.get_lines(): + assert PeriodIndex(data=line.get_xdata()).freq == "ms" + + def test_irreg_dtypes(self): + # date + idx = [date(2000, 1, 1), date(2000, 1, 5), date(2000, 1, 20)] + df = DataFrame( + np.random.default_rng(2).standard_normal((len(idx), 3)), + Index(idx, dtype=object), + ) + _check_plot_works(df.plot) + + def test_irreg_dtypes_dt64(self): + # np.datetime64 + idx = date_range("1/1/2000", periods=10) + idx = idx[[0, 2, 5, 9]].astype(object) + df = DataFrame(np.random.default_rng(2).standard_normal((len(idx), 3)), idx) + _, ax = mpl.pyplot.subplots() + _check_plot_works(df.plot, ax=ax) + + def test_time(self): + t = datetime(1, 1, 1, 3, 30, 0) + deltas = np.random.default_rng(2).integers(1, 20, 3).cumsum() + ts = np.array([(t + timedelta(minutes=int(x))).time() for x in deltas]) + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(len(ts)), + "b": np.random.default_rng(2).standard_normal(len(ts)), + }, + index=ts, + ) + _, ax = mpl.pyplot.subplots() + df.plot(ax=ax) + + # verify tick labels + ticks = ax.get_xticks() + labels = ax.get_xticklabels() + for _tick, _label in zip(ticks, labels, strict=True): + m, s = divmod(int(_tick), 60) + h, m = divmod(m, 60) + rs = _label.get_text() + if len(rs) > 0: + if s != 0: + xp = time(h, m, s).strftime("%H:%M:%S") + else: + xp = time(h, m, s).strftime("%H:%M") + assert xp == rs + + def test_time_change_xlim(self): + t = datetime(1, 1, 1, 3, 30, 0) + deltas = np.random.default_rng(2).integers(1, 20, 3).cumsum() + ts = np.array([(t + timedelta(minutes=int(x))).time() for x in deltas]) + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(len(ts)), + "b": np.random.default_rng(2).standard_normal(len(ts)), + }, + index=ts, + ) + _, ax = mpl.pyplot.subplots() + df.plot(ax=ax) + + # verify tick labels + ticks = ax.get_xticks() + labels = ax.get_xticklabels() + for _tick, _label in zip(ticks, labels, strict=True): + m, s = divmod(int(_tick), 60) + h, m = divmod(m, 60) + rs = _label.get_text() + if len(rs) > 0: + if s != 0: + xp = time(h, m, s).strftime("%H:%M:%S") + else: + xp = time(h, m, s).strftime("%H:%M") + assert xp == rs + + # change xlim + ax.set_xlim("1:30", "5:00") + + # check tick labels again + ticks = ax.get_xticks() + labels = ax.get_xticklabels() + for _tick, _label in zip(ticks, labels, strict=True): + m, s = divmod(int(_tick), 60) + h, m = divmod(m, 60) + rs = _label.get_text() + if len(rs) > 0: + if s != 0: + xp = time(h, m, s).strftime("%H:%M:%S") + else: + xp = time(h, m, s).strftime("%H:%M") + assert xp == rs + + def test_time_musec(self): + t = datetime(1, 1, 1, 3, 30, 0) + deltas = np.random.default_rng(2).integers(1, 20, 3).cumsum() + ts = np.array([(t + timedelta(microseconds=int(x))).time() for x in deltas]) + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(len(ts)), + "b": np.random.default_rng(2).standard_normal(len(ts)), + }, + index=ts, + ) + _, ax = mpl.pyplot.subplots() + ax = df.plot(ax=ax) + + # verify tick labels + ticks = ax.get_xticks() + labels = ax.get_xticklabels() + for _tick, _label in zip(ticks, labels, strict=True): + m, s = divmod(int(_tick), 60) + + us = round((_tick - int(_tick)) * 1e6) + + h, m = divmod(m, 60) + rs = _label.get_text() + if len(rs) > 0: + if (us % 1000) != 0: + xp = time(h, m, s, us).strftime("%H:%M:%S.%f") + elif (us // 1000) != 0: + xp = time(h, m, s, us).strftime("%H:%M:%S.%f")[:-3] + elif s != 0: + xp = time(h, m, s, us).strftime("%H:%M:%S") + else: + xp = time(h, m, s, us).strftime("%H:%M") + assert xp == rs + + def test_secondary_upsample(self): + idxh = date_range("1/1/1999", periods=365, freq="D") + idxl = date_range("1/1/1999", periods=12, freq="ME") + high = Series(np.random.default_rng(2).standard_normal(len(idxh)), idxh) + low = Series(np.random.default_rng(2).standard_normal(len(idxl)), idxl) + _, ax = mpl.pyplot.subplots() + low.plot(ax=ax) + ax = high.plot(secondary_y=True, ax=ax) + for line in ax.get_lines(): + assert PeriodIndex(line.get_xdata()).freq == "D" + assert hasattr(ax, "left_ax") + assert not hasattr(ax, "right_ax") + for line in ax.left_ax.get_lines(): + assert PeriodIndex(line.get_xdata()).freq == "D" + + def test_secondary_legend(self): + fig = mpl.pyplot.figure() + ax = fig.add_subplot(211) + + # ts + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + df.plot(secondary_y=["A", "B"], ax=ax) + leg = ax.get_legend() + assert len(leg.get_lines()) == 4 + assert leg.get_texts()[0].get_text() == "A (right)" + assert leg.get_texts()[1].get_text() == "B (right)" + assert leg.get_texts()[2].get_text() == "C" + assert leg.get_texts()[3].get_text() == "D" + assert ax.right_ax.get_legend() is None + colors = set() + for line in leg.get_lines(): + colors.add(line.get_color()) + + # TODO: color cycle problems + assert len(colors) == 4 + + def test_secondary_legend_right(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + fig = mpl.pyplot.figure() + ax = fig.add_subplot(211) + df.plot(secondary_y=["A", "C"], mark_right=False, ax=ax) + leg = ax.get_legend() + assert len(leg.get_lines()) == 4 + assert leg.get_texts()[0].get_text() == "A" + assert leg.get_texts()[1].get_text() == "B" + assert leg.get_texts()[2].get_text() == "C" + assert leg.get_texts()[3].get_text() == "D" + + def test_secondary_legend_bar(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + fig, ax = mpl.pyplot.subplots() + df.plot(kind="bar", secondary_y=["A"], ax=ax) + leg = ax.get_legend() + assert leg.get_texts()[0].get_text() == "A (right)" + assert leg.get_texts()[1].get_text() == "B" + + def test_secondary_legend_bar_right(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + fig, ax = mpl.pyplot.subplots() + df.plot(kind="bar", secondary_y=["A"], mark_right=False, ax=ax) + leg = ax.get_legend() + assert leg.get_texts()[0].get_text() == "A" + assert leg.get_texts()[1].get_text() == "B" + + def test_secondary_legend_multi_col(self): + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + fig = mpl.pyplot.figure() + ax = fig.add_subplot(211) + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + ax = df.plot(secondary_y=["C", "D"], ax=ax) + leg = ax.get_legend() + assert len(leg.get_lines()) == 4 + assert ax.right_ax.get_legend() is None + colors = set() + for line in leg.get_lines(): + colors.add(line.get_color()) + + # TODO: color cycle problems + assert len(colors) == 4 + + def test_secondary_legend_nonts(self): + # non-ts + df = DataFrame( + 1.1 * np.arange(40).reshape((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(10)], dtype=object), + ) + fig = mpl.pyplot.figure() + ax = fig.add_subplot(211) + ax = df.plot(secondary_y=["A", "B"], ax=ax) + leg = ax.get_legend() + assert len(leg.get_lines()) == 4 + assert ax.right_ax.get_legend() is None + colors = set() + for line in leg.get_lines(): + colors.add(line.get_color()) + + # TODO: color cycle problems + assert len(colors) == 4 + + def test_secondary_legend_nonts_multi_col(self): + # non-ts + df = DataFrame( + 1.1 * np.arange(40).reshape((10, 4)), + columns=Index(list("ABCD"), dtype=object), + index=Index([f"i-{i}" for i in range(10)], dtype=object), + ) + fig = mpl.pyplot.figure() + ax = fig.add_subplot(211) + ax = df.plot(secondary_y=["C", "D"], ax=ax) + leg = ax.get_legend() + assert len(leg.get_lines()) == 4 + assert ax.right_ax.get_legend() is None + colors = set() + for line in leg.get_lines(): + colors.add(line.get_color()) + + # TODO: color cycle problems + assert len(colors) == 4 + + @pytest.mark.xfail(reason="Api changed in 3.6.0") + def test_format_date_axis(self): + rng = date_range("1/1/2012", periods=12, freq="ME") + df = DataFrame(np.random.default_rng(2).standard_normal((len(rng), 3)), rng) + _, ax = mpl.pyplot.subplots() + ax = df.plot(ax=ax) + xaxis = ax.get_xaxis() + for line in xaxis.get_ticklabels(): + if len(line.get_text()) > 0: + assert line.get_rotation() == 30 + + def test_ax_plot(self): + x = date_range(start="2012-01-02", periods=10, freq="D") + y = list(range(len(x))) + _, ax = mpl.pyplot.subplots() + lines = ax.plot(x, y, label="Y") + tm.assert_index_equal(DatetimeIndex(lines[0].get_xdata()), x) + + def test_mpl_nopandas(self): + dates = [date(2008, 12, 31), date(2009, 1, 31)] + values1 = np.arange(10.0, 11.0, 0.5) + values2 = np.arange(11.0, 12.0, 0.5) + + _, ax = mpl.pyplot.subplots() + ( + line1, + line2, + ) = ax.plot( + [x.toordinal() for x in dates], + values1, + "-", + [x.toordinal() for x in dates], + values2, + "-", + linewidth=4, + ) + + exp = np.array([x.toordinal() for x in dates], dtype=np.float64) + tm.assert_numpy_array_equal(line1.get_xydata()[:, 0], exp) + tm.assert_numpy_array_equal(line2.get_xydata()[:, 0], exp) + + def test_irregular_ts_shared_ax_xlim(self): + # GH 2960 + ts = Series( + np.arange(20, dtype=np.float64), index=date_range("2020-01-01", periods=20) + ) + ts_irregular = ts.iloc[[1, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15, 17, 18]] + + # plot the left section of the irregular series, then the right section + _, ax = mpl.pyplot.subplots() + ts_irregular[:5].plot(ax=ax) + ts_irregular[5:].plot(ax=ax) + + # check that axis limits are correct + left, right = ax.get_xlim() + assert left <= conv.DatetimeConverter.convert(ts_irregular.index.min(), "", ax) + assert right >= conv.DatetimeConverter.convert(ts_irregular.index.max(), "", ax) + + def test_secondary_y_non_ts_xlim(self): + # GH 3490 - non-timeseries with secondary y + index_1 = [1, 2, 3, 4] + index_2 = [5, 6, 7, 8] + s1 = Series(1, index=index_1) + s2 = Series(2, index=index_2) + + _, ax = mpl.pyplot.subplots() + s1.plot(ax=ax) + left_before, right_before = ax.get_xlim() + s2.plot(secondary_y=True, ax=ax) + left_after, right_after = ax.get_xlim() + + assert left_before >= left_after + assert right_before < right_after + + def test_secondary_y_regular_ts_xlim(self): + # GH 3490 - regular-timeseries with secondary y + index_1 = date_range(start="2000-01-01", periods=4, freq="D") + index_2 = date_range(start="2000-01-05", periods=4, freq="D") + s1 = Series(1, index=index_1) + s2 = Series(2, index=index_2) + + _, ax = mpl.pyplot.subplots() + s1.plot(ax=ax) + left_before, right_before = ax.get_xlim() + s2.plot(secondary_y=True, ax=ax) + left_after, right_after = ax.get_xlim() + + assert left_before >= left_after + assert right_before < right_after + + def test_secondary_y_mixed_freq_ts_xlim(self): + # GH 3490 - mixed frequency timeseries with secondary y + rng = date_range("2000-01-01", periods=10, freq="min") + ts = Series(1, index=rng) + + _, ax = mpl.pyplot.subplots() + ts.plot(ax=ax) + left_before, right_before = ax.get_xlim() + ts.resample("D").mean().plot(secondary_y=True, ax=ax) + left_after, right_after = ax.get_xlim() + + # a downsample should not have changed either limit + assert left_before == left_after + assert right_before == right_after + + def test_secondary_y_irregular_ts_xlim(self): + # GH 3490 - irregular-timeseries with secondary y + ts = Series( + np.arange(20, dtype=np.float64), index=date_range("2020-01-01", periods=20) + ) + ts_irregular = ts.iloc[[1, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15, 17, 18]] + + _, ax = mpl.pyplot.subplots() + ts_irregular[:5].plot(ax=ax) + # plot higher-x values on secondary axis + ts_irregular[5:].plot(secondary_y=True, ax=ax) + # ensure secondary limits aren't overwritten by plot on primary + ts_irregular[:5].plot(ax=ax) + + left, right = ax.get_xlim() + assert left <= conv.DatetimeConverter.convert(ts_irregular.index.min(), "", ax) + assert right >= conv.DatetimeConverter.convert(ts_irregular.index.max(), "", ax) + + def test_plot_outofbounds_datetime(self): + # 2579 - checking this does not raise + values = [date(1677, 1, 1), date(1677, 1, 2)] + _, ax = mpl.pyplot.subplots() + ax.plot(values) + + values = [datetime(1677, 1, 1, 12), datetime(1677, 1, 2, 12)] + ax.plot(values) + + def test_format_timedelta_ticks_narrow(self): + expected_labels = [f"00:00:00.0000000{i:0>2d}" for i in np.arange(10)] + + rng = timedelta_range("0", periods=10, freq="ns") + df = DataFrame(np.random.default_rng(2).standard_normal((len(rng), 3)), rng) + _, ax = mpl.pyplot.subplots() + df.plot(fontsize=2, ax=ax) + mpl.pyplot.draw() + labels = ax.get_xticklabels() + + result_labels = [x.get_text() for x in labels] + assert len(result_labels) == len(expected_labels) + assert result_labels == expected_labels + + def test_format_timedelta_ticks_wide(self, unit): + expected_labels = [ + "00:00:00", + "1 days 03:46:40", + "2 days 07:33:20", + "3 days 11:20:00", + "4 days 15:06:40", + "5 days 18:53:20", + "6 days 22:40:00", + "8 days 02:26:40", + "9 days 06:13:20", + ] + + rng = timedelta_range("0", periods=10, freq="1 D", unit=unit) + df = DataFrame(np.random.default_rng(2).standard_normal((len(rng), 3)), rng) + _, ax = mpl.pyplot.subplots() + ax = df.plot(fontsize=2, ax=ax) + mpl.pyplot.draw() + labels = ax.get_xticklabels() + + result_labels = [x.get_text() for x in labels] + assert len(result_labels) == len(expected_labels) + assert result_labels == expected_labels + + def test_timedelta_plot(self): + # test issue #8711 + s = Series(range(5), timedelta_range("1day", periods=5)) + _, ax = mpl.pyplot.subplots() + _check_plot_works(s.plot, ax=ax) + + def test_timedelta_long_period(self): + # test long period + index = timedelta_range("1 day 2 hr 30 min 10 s", periods=10, freq="1 D") + s = Series(np.random.default_rng(2).standard_normal(len(index)), index) + _, ax = mpl.pyplot.subplots() + _check_plot_works(s.plot, ax=ax) + + def test_timedelta_short_period(self): + # test short period + index = timedelta_range("1 day 2 hr 30 min 10 s", periods=10, freq="1 ns") + s = Series(np.random.default_rng(2).standard_normal(len(index)), index) + _, ax = mpl.pyplot.subplots() + _check_plot_works(s.plot, ax=ax) + + def test_hist(self): + # https://github.com/matplotlib/matplotlib/issues/8459 + rng = date_range("1/1/2011", periods=10, freq="h") + x = rng + w1 = np.arange(0, 1, 0.1) + w2 = np.arange(0, 1, 0.1)[::-1] + _, ax = mpl.pyplot.subplots() + ax.hist([x, x], weights=[w1, w2]) + + def test_overlapping_datetime(self): + # GB 6608 + s1 = Series( + [1, 2, 3], + index=[ + datetime(1995, 12, 31), + datetime(2000, 12, 31), + datetime(2005, 12, 31), + ], + ) + s2 = Series( + [1, 2, 3], + index=[ + datetime(1997, 12, 31), + datetime(2003, 12, 31), + datetime(2008, 12, 31), + ], + ) + + # plot first series, then add the second series to those axes, + # then try adding the first series again + _, ax = mpl.pyplot.subplots() + s1.plot(ax=ax) + s2.plot(ax=ax) + s1.plot(ax=ax) + + @pytest.mark.xfail(reason="GH9053 matplotlib does not use ax.xaxis.converter") + def test_add_matplotlib_datetime64(self): + # GH9053 - ensure that a plot with PeriodConverter still understands + # datetime64 data. This still fails because matplotlib overrides the + # ax.xaxis.converter with a DatetimeConverter + s = Series( + np.random.default_rng(2).standard_normal(10), + index=date_range("1970-01-02", periods=10), + ) + ax = s.plot() + with tm.assert_produces_warning(DeprecationWarning): + # multi-dimensional indexing + ax.plot(s.index, s.values, color="g") + l1, l2 = ax.lines + tm.assert_numpy_array_equal(l1.get_xydata(), l2.get_xydata()) + + def test_matplotlib_scatter_datetime64(self): + # https://github.com/matplotlib/matplotlib/issues/11391 + df = DataFrame(np.random.default_rng(2).random((10, 2)), columns=["x", "y"]) + df["time"] = date_range("2018-01-01", periods=10, freq="D") + _, ax = mpl.pyplot.subplots() + ax.scatter(x="time", y="y", data=df) + mpl.pyplot.draw() + label = ax.get_xticklabels()[0] + expected = "2018-01-01" + assert label.get_text() == expected + + def test_check_xticks_rot(self): + # https://github.com/pandas-dev/pandas/issues/29460 + # regular time series + x = to_datetime(["2020-05-01", "2020-05-02", "2020-05-03"]) + df = DataFrame({"x": x, "y": [1, 2, 3]}) + axes = df.plot(x="x", y="y") + _check_ticks_props(axes, xrot=0) + + def test_check_xticks_rot_irregular(self): + # irregular time series + x = to_datetime(["2020-05-01", "2020-05-02", "2020-05-04"]) + df = DataFrame({"x": x, "y": [1, 2, 3]}) + axes = df.plot(x="x", y="y") + _check_ticks_props(axes, xrot=30) + + def test_check_xticks_rot_use_idx(self): + # irregular time series + x = to_datetime(["2020-05-01", "2020-05-02", "2020-05-04"]) + df = DataFrame({"x": x, "y": [1, 2, 3]}) + # use timeseries index or not + axes = df.set_index("x").plot(y="y", use_index=True) + _check_ticks_props(axes, xrot=30) + axes = df.set_index("x").plot(y="y", use_index=False) + _check_ticks_props(axes, xrot=0) + + def test_check_xticks_rot_sharex(self): + # irregular time series + x = to_datetime(["2020-05-01", "2020-05-02", "2020-05-04"]) + df = DataFrame({"x": x, "y": [1, 2, 3]}) + # separate subplots + axes = df.plot(x="x", y="y", subplots=True, sharex=True) + _check_ticks_props(axes, xrot=30) + axes = df.plot(x="x", y="y", subplots=True, sharex=False) + _check_ticks_props(axes, xrot=0) + + @pytest.mark.parametrize( + "idx", + [ + date_range("2020-01-01", periods=5), + date_range("2020-01-01", periods=5, tz="UTC"), + timedelta_range("1 day", periods=5, freq="D"), + period_range("2020-01-01", periods=5, freq="D"), + Index([date(2000, 1, i) for i in [1, 3, 6, 20, 22]], dtype=object), + range(5), + ], + ) + def test_pickle_fig(self, temp_file, frame_or_series, idx): + # GH18439, GH#24088, statsmodels#4772 + df = frame_or_series(range(5), index=idx) + fig, ax = plt.subplots(1, 1) + df.plot(ax=ax) + with temp_file.open(mode="wb") as path: + pickle.dump(fig, path) + + +def _check_plot_works(f, freq=None, series=None, *args, **kwargs): + fig = plt.gcf() + + fig.clf() + ax = fig.add_subplot(211) + orig_ax = kwargs.pop("ax", plt.gca()) + orig_axfreq = getattr(orig_ax, "freq", None) + + ret = f(*args, **kwargs) + assert ret is not None # do something more intelligent + + ax = kwargs.pop("ax", plt.gca()) + if series is not None: + dfreq = series.index.freq + if isinstance(dfreq, BaseOffset): + dfreq = dfreq.rule_code + if orig_axfreq is None: + assert ax.freq == dfreq + + if freq is not None and orig_axfreq is None: + assert to_offset(ax.freq, is_period=True) == freq + + ax = fig.add_subplot(212) + kwargs["ax"] = ax + ret = f(*args, **kwargs) + assert ret is not None # TODO: do something more intelligent diff --git a/pandas/tests/plotting/test_groupby.py b/pandas/tests/plotting/test_groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..e86c4e9838d2459f2ad87fcf97a4c46bad2e820d --- /dev/null +++ b/pandas/tests/plotting/test_groupby.py @@ -0,0 +1,154 @@ +"""Test cases for GroupBy.plot""" + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, +) +from pandas.tests.plotting.common import ( + _check_axes_shape, + _check_legend_labels, +) + +pytest.importorskip("matplotlib") + + +class TestDataFrameGroupByPlots: + def test_series_groupby_plotting_nominally_works(self): + n = 10 + weight = Series(np.random.default_rng(2).normal(166, 20, size=n)) + gender = np.random.default_rng(2).choice(["male", "female"], size=n) + + weight.groupby(gender).plot() + + def test_series_groupby_plotting_nominally_works_hist(self): + n = 10 + height = Series(np.random.default_rng(2).normal(60, 10, size=n)) + gender = np.random.default_rng(2).choice(["male", "female"], size=n) + height.groupby(gender).hist() + + def test_series_groupby_plotting_nominally_works_alpha(self): + n = 10 + height = Series(np.random.default_rng(2).normal(60, 10, size=n)) + gender = np.random.default_rng(2).choice(["male", "female"], size=n) + # Regression test for GH8733 + height.groupby(gender).plot(alpha=0.5) + + def test_plotting_with_float_index_works(self): + # GH 7025 + df = DataFrame( + { + "def": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "val": np.random.default_rng(2).standard_normal(9), + }, + index=[1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], + ) + + df.groupby("def")["val"].plot() + + def test_plotting_with_float_index_works_apply(self): + # GH 7025 + df = DataFrame( + { + "def": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "val": np.random.default_rng(2).standard_normal(9), + }, + index=[1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0], + ) + df.groupby("def")["val"].apply(lambda x: x.plot()) + + def test_hist_single_row(self): + # GH10214 + bins = np.arange(80, 100 + 2, 1) + df = DataFrame({"Name": ["AAA", "BBB"], "ByCol": [1, 2], "Mark": [85, 89]}) + df["Mark"].hist(by=df["ByCol"], bins=bins) + + def test_hist_single_row_single_bycol(self): + # GH10214 + bins = np.arange(80, 100 + 2, 1) + df = DataFrame({"Name": ["AAA"], "ByCol": [1], "Mark": [85]}) + df["Mark"].hist(by=df["ByCol"], bins=bins) + + def test_plot_submethod_works(self): + df = DataFrame({"x": [1, 2, 3, 4, 5], "y": [1, 2, 3, 2, 1], "z": list("ababa")}) + df.groupby("z").plot.scatter("x", "y") + + def test_plot_submethod_works_line(self): + df = DataFrame({"x": [1, 2, 3, 4, 5], "y": [1, 2, 3, 2, 1], "z": list("ababa")}) + df.groupby("z")["x"].plot.line() + + def test_plot_kwargs(self): + df = DataFrame({"x": [1, 2, 3, 4, 5], "y": [1, 2, 3, 2, 1], "z": list("ababa")}) + + res = df.groupby("z").plot(kind="scatter", x="x", y="y") + # check that a scatter plot is effectively plotted: the axes should + # contain a PathCollection from the scatter plot (GH11805) + assert len(res["a"].collections) == 1 + + def test_plot_kwargs_scatter(self): + df = DataFrame({"x": [1, 2, 3, 4, 5], "y": [1, 2, 3, 2, 1], "z": list("ababa")}) + res = df.groupby("z").plot.scatter(x="x", y="y") + assert len(res["a"].collections) == 1 + + @pytest.mark.parametrize("column, expected_axes_num", [(None, 2), ("b", 1)]) + def test_groupby_hist_frame_with_legend(self, column, expected_axes_num): + # GH 6279 - DataFrameGroupBy histogram can have a legend + expected_layout = (1, expected_axes_num) + expected_labels = column or [["a"], ["b"]] + + index = Index(15 * ["1"] + 15 * ["2"], name="c") + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 2)), + index=index, + columns=["a", "b"], + ) + g = df.groupby("c") + + for axes in g.hist(legend=True, column=column): + _check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout) + for ax, expected_label in zip(axes[0], expected_labels, strict=True): + _check_legend_labels(ax, expected_label) + + @pytest.mark.parametrize("column", [None, "b"]) + def test_groupby_hist_frame_with_legend_raises(self, column): + # GH 6279 - DataFrameGroupBy histogram with legend and label raises + index = Index(15 * ["1"] + 15 * ["2"], name="c") + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 2)), + index=index, + columns=["a", "b"], + ) + g = df.groupby("c") + + with pytest.raises(ValueError, match="Cannot use both legend and label"): + g.hist(legend=True, column=column, label="d") + + def test_groupby_hist_series_with_legend(self): + # GH 6279 - SeriesGroupBy histogram can have a legend + index = Index(15 * ["1"] + 15 * ["2"], name="c") + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 2)), + index=index, + columns=["a", "b"], + ) + g = df.groupby("c") + + for ax in g["a"].hist(legend=True): + _check_axes_shape(ax, axes_num=1, layout=(1, 1)) + _check_legend_labels(ax, ["1", "2"]) + + def test_groupby_hist_series_with_legend_raises(self): + # GH 6279 - SeriesGroupBy histogram with legend and label raises + index = Index(15 * ["1"] + 15 * ["2"], name="c") + df = DataFrame( + np.random.default_rng(2).standard_normal((30, 2)), + index=index, + columns=["a", "b"], + ) + g = df.groupby("c") + + with pytest.raises(ValueError, match="Cannot use both legend and label"): + g.hist(legend=True, label="d") diff --git a/pandas/tests/plotting/test_hist_method.py b/pandas/tests/plotting/test_hist_method.py new file mode 100644 index 0000000000000000000000000000000000000000..e71d4ce5475a89dd686f831afcd1e4c4d726851d --- /dev/null +++ b/pandas/tests/plotting/test_hist_method.py @@ -0,0 +1,957 @@ +"""Test cases for .hist method""" + +import re + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, + date_range, + to_datetime, +) +import pandas._testing as tm +from pandas.tests.plotting.common import ( + _check_ax_scales, + _check_axes_shape, + _check_colors, + _check_legend_labels, + _check_patches_all_filled, + _check_plot_works, + _check_text_labels, + _check_ticks_props, + get_x_axis, + get_y_axis, +) + +mpl = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") + +from pandas.plotting._matplotlib.hist import _grouped_hist + + +@pytest.fixture +def ts(): + return Series( + np.arange(30, dtype=np.float64), + index=date_range("2020-01-01", periods=30, freq="B"), + name="ts", + ) + + +class TestSeriesPlots: + @pytest.mark.parametrize("kwargs", [{}, {"grid": False}, {"figsize": (8, 10)}]) + def test_hist_legacy_kwargs(self, ts, kwargs): + _check_plot_works(ts.hist, **kwargs) + + @pytest.mark.parametrize("kwargs", [{}, {"bins": 5}]) + def test_hist_legacy_kwargs_warning(self, ts, kwargs): + # _check_plot_works adds an ax so catch warning. see GH #13188 + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + _check_plot_works(ts.hist, by=ts.index.month, **kwargs) + + def test_hist_legacy_ax(self, ts): + fig, ax = mpl.pyplot.subplots(1, 1) + _check_plot_works(ts.hist, ax=ax, default_axes=True) + + def test_hist_legacy_ax_and_fig(self, ts): + fig, ax = mpl.pyplot.subplots(1, 1) + _check_plot_works(ts.hist, ax=ax, figure=fig, default_axes=True) + + def test_hist_legacy_fig(self, ts): + fig, _ = mpl.pyplot.subplots(1, 1) + _check_plot_works(ts.hist, figure=fig, default_axes=True) + + def test_hist_legacy_multi_ax(self, ts): + fig, (ax1, ax2) = mpl.pyplot.subplots(1, 2) + _check_plot_works(ts.hist, figure=fig, ax=ax1, default_axes=True) + _check_plot_works(ts.hist, figure=fig, ax=ax2, default_axes=True) + + def test_hist_legacy_by_fig_error(self, ts): + fig, _ = mpl.pyplot.subplots(1, 1) + msg = ( + "Cannot pass 'figure' when using the 'by' argument, since a new 'Figure' " + "instance will be created" + ) + with pytest.raises(ValueError, match=msg): + ts.hist(by=ts.index, figure=fig) + + def test_hist_bins_legacy(self): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2))) + ax = df.hist(bins=2)[0][0] + assert len(ax.patches) == 2 + + def test_hist_layout(self, hist_df): + df = hist_df + msg = "The 'layout' keyword is not supported when 'by' is None" + with pytest.raises(ValueError, match=msg): + df.height.hist(layout=(1, 1)) + + with pytest.raises(ValueError, match=msg): + df.height.hist(layout=[1, 1]) + + @pytest.mark.slow + @pytest.mark.parametrize( + "by, layout, axes_num, res_layout", + [ + ["gender", (2, 1), 2, (2, 1)], + ["gender", (3, -1), 2, (3, 1)], + ["category", (4, 1), 4, (4, 1)], + ["category", (2, -1), 4, (2, 2)], + ["category", (3, -1), 4, (3, 2)], + ["category", (-1, 4), 4, (1, 4)], + ["classroom", (2, 2), 3, (2, 2)], + ], + ) + def test_hist_layout_with_by(self, hist_df, by, layout, axes_num, res_layout): + df = hist_df + + # _check_plot_works adds an `ax` kwarg to the method call + # so we get a warning about an axis being cleared, even + # though we don't explicitly pass one, see GH #13188 + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works(df.height.hist, by=getattr(df, by), layout=layout) + _check_axes_shape(axes, axes_num=axes_num, layout=res_layout) + + def test_hist_layout_with_by_shape(self, hist_df): + df = hist_df + + axes = df.height.hist(by=df.category, layout=(4, 2), figsize=(12, 7)) + _check_axes_shape(axes, axes_num=4, layout=(4, 2), figsize=(12, 7)) + + def test_hist_no_overlap(self): + x = Series(np.random.default_rng(2).standard_normal(2)) + y = Series(np.random.default_rng(2).standard_normal(2)) + plt.subplot(121) + x.hist() + plt.subplot(122) + y.hist() + fig = plt.gcf() + axes = fig.axes + assert len(axes) == 2 + + def test_hist_by_no_extra_plots(self, hist_df): + df = hist_df + df.height.hist(by=df.gender) + assert len(mpl.pyplot.get_fignums()) == 1 + + def test_plot_fails_when_ax_differs_from_figure(self, ts): + fig1 = plt.figure(1) + fig2 = plt.figure(2) + ax1 = fig1.add_subplot(111) + msg = "passed axis not bound to passed figure" + with pytest.raises(AssertionError, match=msg): + ts.hist(ax=ax1, figure=fig2) + + @pytest.mark.parametrize( + "histtype, expected", + [ + ("bar", True), + ("barstacked", True), + ("step", False), + ("stepfilled", True), + ], + ) + def test_histtype_argument(self, histtype, expected): + # GH23992 Verify functioning of histtype argument + ser = Series(np.random.default_rng(2).integers(1, 10)) + ax = ser.hist(histtype=histtype) + _check_patches_all_filled(ax, filled=expected) + + @pytest.mark.parametrize( + "by, expected_axes_num, expected_layout", [(None, 1, (1, 1)), ("b", 2, (1, 2))] + ) + def test_hist_with_legend(self, by, expected_axes_num, expected_layout): + # GH 6279 - Series histogram can have a legend + index = 5 * ["1"] + 5 * ["2"] + s = Series(np.random.default_rng(2).standard_normal(10), index=index, name="a") + s.index.name = "b" + + # Use default_axes=True when plotting method generate subplots itself + axes = _check_plot_works(s.hist, default_axes=True, legend=True, by=by) + _check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout) + _check_legend_labels(axes, "a") + + @pytest.mark.parametrize("by", [None, "b"]) + def test_hist_with_legend_raises(self, by): + # GH 6279 - Series histogram with legend and label raises + index = 5 * ["1"] + 5 * ["2"] + s = Series(np.random.default_rng(2).standard_normal(10), index=index, name="a") + s.index.name = "b" + + with pytest.raises(ValueError, match="Cannot use both legend and label"): + s.hist(legend=True, by=by, label="c") + + def test_hist_kwargs(self, ts): + _, ax = mpl.pyplot.subplots() + ax = ts.plot.hist(bins=5, ax=ax) + assert len(ax.patches) == 5 + _check_text_labels(ax.yaxis.get_label(), "Frequency") + + def test_hist_kwargs_horizontal(self, ts): + _, ax = mpl.pyplot.subplots() + ax = ts.plot.hist(bins=5, ax=ax) + ax = ts.plot.hist(orientation="horizontal", ax=ax) + _check_text_labels(ax.xaxis.get_label(), "Frequency") + + def test_hist_kwargs_align(self, ts): + _, ax = mpl.pyplot.subplots() + ax = ts.plot.hist(bins=5, ax=ax) + ax = ts.plot.hist(align="left", stacked=True, ax=ax) + + @pytest.mark.xfail(reason="Api changed in 3.6.0") + def test_hist_kde(self, ts): + pytest.importorskip("scipy") + _, ax = mpl.pyplot.subplots() + ax = ts.plot.hist(logy=True, ax=ax) + _check_ax_scales(ax, yaxis="log") + xlabels = ax.get_xticklabels() + # ticks are values, thus ticklabels are blank + _check_text_labels(xlabels, [""] * len(xlabels)) + ylabels = ax.get_yticklabels() + _check_text_labels(ylabels, [""] * len(ylabels)) + + def test_hist_kde_plot_works(self, ts): + pytest.importorskip("scipy") + _check_plot_works(ts.plot.kde) + + def test_hist_kde_density_works(self, ts): + pytest.importorskip("scipy") + _check_plot_works(ts.plot.density) + + @pytest.mark.xfail(reason="Api changed in 3.6.0") + def test_hist_kde_logy(self, ts): + pytest.importorskip("scipy") + _, ax = mpl.pyplot.subplots() + ax = ts.plot.kde(logy=True, ax=ax) + _check_ax_scales(ax, yaxis="log") + xlabels = ax.get_xticklabels() + _check_text_labels(xlabels, [""] * len(xlabels)) + ylabels = ax.get_yticklabels() + _check_text_labels(ylabels, [""] * len(ylabels)) + + def test_hist_kde_color_bins(self, ts): + pytest.importorskip("scipy") + _, ax = mpl.pyplot.subplots() + ax = ts.plot.hist(logy=True, bins=10, color="b", ax=ax) + _check_ax_scales(ax, yaxis="log") + assert len(ax.patches) == 10 + _check_colors(ax.patches, facecolors=["b"] * 10) + + def test_hist_kde_color(self, ts): + pytest.importorskip("scipy") + _, ax = mpl.pyplot.subplots() + ax = ts.plot.kde(logy=True, color="r", ax=ax) + _check_ax_scales(ax, yaxis="log") + lines = ax.get_lines() + assert len(lines) == 1 + _check_colors(lines, ["r"]) + + +class TestDataFramePlots: + @pytest.mark.slow + def test_hist_df_legacy(self, hist_df): + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + _check_plot_works(hist_df.hist) + + @pytest.mark.slow + def test_hist_df_legacy_layout(self): + # make sure layout is handled + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2))) + df[2] = to_datetime( + np.random.default_rng(2).integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works(df.hist, grid=False) + _check_axes_shape(axes, axes_num=3, layout=(2, 2)) + assert not axes[1, 1].get_visible() + + _check_plot_works(df[[2]].hist) + + @pytest.mark.slow + def test_hist_df_legacy_layout2(self): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 1))) + _check_plot_works(df.hist) + + @pytest.mark.slow + def test_hist_df_legacy_layout3(self): + # make sure layout is handled + df = DataFrame(np.random.default_rng(2).standard_normal((10, 5))) + df[5] = to_datetime( + np.random.default_rng(2).integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works(df.hist, layout=(4, 2)) + _check_axes_shape(axes, axes_num=6, layout=(4, 2)) + + @pytest.mark.slow + @pytest.mark.parametrize( + "kwargs", [{"sharex": True, "sharey": True}, {"figsize": (8, 10)}, {"bins": 5}] + ) + def test_hist_df_legacy_layout_kwargs(self, kwargs): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 5))) + df[5] = to_datetime( + np.random.default_rng(2).integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + # make sure sharex, sharey is handled + # handle figsize arg + # check bins argument + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + _check_plot_works(df.hist, **kwargs) + + @pytest.mark.slow + def test_hist_df_legacy_layout_labelsize_rot(self, frame_or_series): + # make sure xlabelsize and xrot are handled + obj = frame_or_series(range(10)) + xf, yf = 20, 18 + xrot, yrot = 30, 40 + axes = obj.hist(xlabelsize=xf, xrot=xrot, ylabelsize=yf, yrot=yrot) + _check_ticks_props(axes, xlabelsize=xf, xrot=xrot, ylabelsize=yf, yrot=yrot) + + @pytest.mark.slow + def test_hist_df_legacy_rectangles(self): + ser = Series(range(10)) + ax = ser.hist(cumulative=True, bins=4, density=True) + # height of last bin (index 5) must be 1.0 + rects = [x for x in ax.get_children() if isinstance(x, mpl.patches.Rectangle)] + tm.assert_almost_equal(rects[-1].get_height(), 1.0) + + @pytest.mark.slow + def test_hist_df_legacy_scale(self): + ser = Series(range(10)) + ax = ser.hist(log=True) + # scale of y must be 'log' + _check_ax_scales(ax, yaxis="log") + + @pytest.mark.slow + def test_hist_df_legacy_external_error(self): + ser = Series(range(10)) + # propagate attr exception from matplotlib.Axes.hist + with tm.external_error_raised(AttributeError): + ser.hist(foo="bar") + + def test_hist_non_numerical_or_datetime_raises(self): + # gh-10444, GH32590 + df = DataFrame( + { + "a": np.random.default_rng(2).random(10), + "b": np.random.default_rng(2).integers(0, 10, 10), + "c": to_datetime( + np.random.default_rng(2).integers( + 1582800000000000000, 1583500000000000000, 10, dtype=np.int64 + ) + ), + "d": to_datetime( + np.random.default_rng(2).integers( + 1582800000000000000, 1583500000000000000, 10, dtype=np.int64 + ), + utc=True, + ), + } + ) + df_o = df.astype(object) + + msg = "hist method requires numerical or datetime columns, nothing to plot." + with pytest.raises(ValueError, match=msg): + df_o.hist() + + @pytest.mark.parametrize( + "layout_test", + ( + {"layout": None, "expected_size": (2, 2)}, # default is 2x2 + {"layout": (2, 2), "expected_size": (2, 2)}, + {"layout": (4, 1), "expected_size": (4, 1)}, + {"layout": (1, 4), "expected_size": (1, 4)}, + {"layout": (3, 3), "expected_size": (3, 3)}, + {"layout": (-1, 4), "expected_size": (1, 4)}, + {"layout": (4, -1), "expected_size": (4, 1)}, + {"layout": (-1, 2), "expected_size": (2, 2)}, + {"layout": (2, -1), "expected_size": (2, 2)}, + ), + ) + def test_hist_layout(self, layout_test): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2))) + df[2] = to_datetime( + np.random.default_rng(2).integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + axes = df.hist(layout=layout_test["layout"]) + expected = layout_test["expected_size"] + _check_axes_shape(axes, axes_num=3, layout=expected) + + def test_hist_layout_error(self): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2))) + df[2] = to_datetime( + np.random.default_rng(2).integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + # layout too small for all 4 plots + msg = "Layout of 1x1 must be larger than required size 3" + with pytest.raises(ValueError, match=msg): + df.hist(layout=(1, 1)) + + # invalid format for layout + msg = re.escape("Layout must be a tuple of (rows, columns)") + with pytest.raises(ValueError, match=msg): + df.hist(layout=(1,)) + msg = "At least one dimension of layout must be positive" + with pytest.raises(ValueError, match=msg): + df.hist(layout=(-1, -1)) + + # GH 9351 + def test_tight_layout(self): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2))) + df[2] = to_datetime( + np.random.default_rng(2).integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + # Use default_axes=True when plotting method generate subplots itself + _check_plot_works(df.hist, default_axes=True) + mpl.pyplot.tight_layout() + + def test_hist_subplot_xrot(self): + # GH 30288 + df = DataFrame( + { + "length": [1.5, 0.5, 1.2, 0.9, 3], + "animal": ["pig", "rabbit", "pig", "pig", "rabbit"], + } + ) + # Use default_axes=True when plotting method generate subplots itself + axes = _check_plot_works( + df.hist, + default_axes=True, + column="length", + by="animal", + bins=5, + xrot=0, + ) + _check_ticks_props(axes, xrot=0) + + @pytest.mark.parametrize( + "column, expected", + [ + (None, ["width", "length", "height"]), + (["length", "width", "height"], ["length", "width", "height"]), + ], + ) + def test_hist_column_order_unchanged(self, column, expected): + # GH29235 + + df = DataFrame( + { + "width": [0.7, 0.2, 0.15, 0.2, 1.1], + "length": [1.5, 0.5, 1.2, 0.9, 3], + "height": [3, 0.5, 3.4, 2, 1], + }, + index=["pig", "rabbit", "duck", "chicken", "horse"], + ) + + # Use default_axes=True when plotting method generate subplots itself + axes = _check_plot_works( + df.hist, + default_axes=True, + column=column, + layout=(1, 3), + ) + result = [axes[0, i].get_title() for i in range(3)] + assert result == expected + + @pytest.mark.parametrize( + "histtype, expected", + [ + ("bar", True), + ("barstacked", True), + ("step", False), + ("stepfilled", True), + ], + ) + def test_histtype_argument(self, histtype, expected): + # GH23992 Verify functioning of histtype argument + df = DataFrame( + np.random.default_rng(2).integers(1, 10, size=(10, 2)), columns=["a", "b"] + ) + ax = df.hist(histtype=histtype) + _check_patches_all_filled(ax, filled=expected) + + @pytest.mark.parametrize("by", [None, "c"]) + @pytest.mark.parametrize("column", [None, "b"]) + def test_hist_with_legend(self, by, column): + # GH 6279 - DataFrame histogram can have a legend + expected_axes_num = 1 if by is None and column is not None else 2 + expected_layout = (1, expected_axes_num) + expected_labels = column or ["a", "b"] + if by is not None: + expected_labels = [expected_labels] * 2 + + index = Index(5 * ["1"] + 5 * ["2"], name="c") + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), + index=index, + columns=["a", "b"], + ) + + # Use default_axes=True when plotting method generate subplots itself + axes = _check_plot_works( + df.hist, + default_axes=True, + legend=True, + by=by, + column=column, + ) + + _check_axes_shape(axes, axes_num=expected_axes_num, layout=expected_layout) + if by is None and column is None: + axes = axes[0] + for expected_label, ax in zip(expected_labels, axes, strict=True): + _check_legend_labels(ax, expected_label) + + @pytest.mark.parametrize("by", [None, "c"]) + @pytest.mark.parametrize("column", [None, "b"]) + def test_hist_with_legend_raises(self, by, column): + # GH 6279 - DataFrame histogram with legend and label raises + index = Index(5 * ["1"] + 5 * ["2"], name="c") + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 2)), + index=index, + columns=["a", "b"], + ) + + with pytest.raises(ValueError, match="Cannot use both legend and label"): + df.hist(legend=True, by=by, column=column, label="d") + + def test_hist_df_kwargs(self): + df = DataFrame(np.random.default_rng(2).standard_normal((10, 2))) + _, ax = mpl.pyplot.subplots() + ax = df.plot.hist(bins=5, ax=ax) + assert len(ax.patches) == 10 + + def test_hist_df_with_nonnumerics(self): + # GH 9853 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=["A", "B", "C", "D"], + ) + df["E"] = ["x", "y"] * 5 + _, ax = mpl.pyplot.subplots() + ax = df.plot.hist(bins=5, ax=ax) + assert len(ax.patches) == 20 + + def test_hist_df_with_nonnumerics_no_bins(self): + # GH 9853 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=["A", "B", "C", "D"], + ) + df["E"] = ["x", "y"] * 5 + _, ax = mpl.pyplot.subplots() + ax = df.plot.hist(ax=ax) # bins=10 + assert len(ax.patches) == 40 + + def test_hist_secondary_legend(self): + # GH 9610 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), columns=list("abcd") + ) + + # primary -> secondary + _, ax = mpl.pyplot.subplots() + ax = df["a"].plot.hist(legend=True, ax=ax) + df["b"].plot.hist(ax=ax, legend=True, secondary_y=True) + # both legends are drawn on left ax + # left and right axis must be visible + _check_legend_labels(ax, labels=["a", "b (right)"]) + assert ax.get_yaxis().get_visible() + assert ax.right_ax.get_yaxis().get_visible() + + def test_hist_secondary_secondary(self): + # GH 9610 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), columns=list("abcd") + ) + # secondary -> secondary + _, ax = mpl.pyplot.subplots() + ax = df["a"].plot.hist(legend=True, secondary_y=True, ax=ax) + df["b"].plot.hist(ax=ax, legend=True, secondary_y=True) + # both legends are draw on left ax + # left axis must be invisible, right axis must be visible + _check_legend_labels(ax.left_ax, labels=["a (right)", "b (right)"]) + assert not ax.left_ax.get_yaxis().get_visible() + assert ax.get_yaxis().get_visible() + + def test_hist_secondary_primary(self): + # GH 9610 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), columns=list("abcd") + ) + # secondary -> primary + _, ax = mpl.pyplot.subplots() + ax = df["a"].plot.hist(legend=True, secondary_y=True, ax=ax) + # right axes is returned + df["b"].plot.hist(ax=ax, legend=True) + # both legends are draw on left ax + # left and right axis must be visible + _check_legend_labels(ax.left_ax, labels=["a (right)", "b"]) + assert ax.left_ax.get_yaxis().get_visible() + assert ax.get_yaxis().get_visible() + + def test_hist_with_nans_and_weights(self): + # GH 48884 + df = DataFrame( + [[np.nan, 0.2, 0.3], [0.4, np.nan, np.nan], [0.7, 0.8, 0.9]], + columns=list("abc"), + ) + weights = np.array([0.25, 0.3, 0.45]) + no_nan_df = DataFrame([[0.4, 0.2, 0.3], [0.7, 0.8, 0.9]], columns=list("abc")) + no_nan_weights = np.array([[0.3, 0.25, 0.25], [0.45, 0.45, 0.45]]) + + _, ax0 = mpl.pyplot.subplots() + df.plot.hist(ax=ax0, weights=weights) + rects = [x for x in ax0.get_children() if isinstance(x, mpl.patches.Rectangle)] + heights = [rect.get_height() for rect in rects] + _, ax1 = mpl.pyplot.subplots() + no_nan_df.plot.hist(ax=ax1, weights=no_nan_weights) + no_nan_rects = [ + x for x in ax1.get_children() if isinstance(x, mpl.patches.Rectangle) + ] + no_nan_heights = [rect.get_height() for rect in no_nan_rects] + assert all(h0 == h1 for h0, h1 in zip(heights, no_nan_heights, strict=True)) + + idxerror_weights = np.array([[0.3, 0.25], [0.45, 0.45]]) + + msg = "weights must have the same shape as data, or be a single column" + _, ax2 = mpl.pyplot.subplots() + with pytest.raises(ValueError, match=msg): + no_nan_df.plot.hist(ax=ax2, weights=idxerror_weights) + + +class TestDataFrameGroupByPlots: + def test_grouped_hist_legacy(self): + rs = np.random.default_rng(10) + df = DataFrame(rs.standard_normal((10, 1)), columns=["A"]) + df["B"] = to_datetime( + rs.integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + df["C"] = rs.integers(0, 4, 10) + df["D"] = ["X"] * 10 + + axes = _grouped_hist(df.A, by=df.C) + _check_axes_shape(axes, axes_num=4, layout=(2, 2)) + + def test_grouped_hist_legacy_axes_shape_no_col(self): + rs = np.random.default_rng(10) + df = DataFrame(rs.standard_normal((10, 1)), columns=["A"]) + df["B"] = to_datetime( + rs.integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + df["C"] = rs.integers(0, 4, 10) + df["D"] = ["X"] * 10 + axes = df.hist(by=df.C) + _check_axes_shape(axes, axes_num=4, layout=(2, 2)) + + def test_grouped_hist_legacy_single_key(self): + rs = np.random.default_rng(2) + df = DataFrame(rs.standard_normal((10, 1)), columns=["A"]) + df["B"] = to_datetime( + rs.integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + df["C"] = rs.integers(0, 4, 10) + df["D"] = ["X"] * 10 + # group by a key with single value + axes = df.hist(by="D", rot=30) + _check_axes_shape(axes, axes_num=1, layout=(1, 1)) + _check_ticks_props(axes, xrot=30) + + def test_grouped_hist_legacy_grouped_hist_kwargs(self): + rs = np.random.default_rng(2) + df = DataFrame(rs.standard_normal((10, 1)), columns=["A"]) + df["B"] = to_datetime( + rs.integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + df["C"] = rs.integers(0, 4, 10) + # make sure kwargs to hist are handled + xf, yf = 20, 18 + xrot, yrot = 30, 40 + + axes = _grouped_hist( + df.A, + by=df.C, + cumulative=True, + bins=4, + xlabelsize=xf, + xrot=xrot, + ylabelsize=yf, + yrot=yrot, + density=True, + ) + # height of last bin (index 5) must be 1.0 + for ax in axes.ravel(): + rects = [ + x for x in ax.get_children() if isinstance(x, mpl.patches.Rectangle) + ] + height = rects[-1].get_height() + tm.assert_almost_equal(height, 1.0) + _check_ticks_props(axes, xlabelsize=xf, xrot=xrot, ylabelsize=yf, yrot=yrot) + + def test_grouped_hist_legacy_grouped_hist(self): + rs = np.random.default_rng(2) + df = DataFrame(rs.standard_normal((10, 1)), columns=["A"]) + df["B"] = to_datetime( + rs.integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + df["C"] = rs.integers(0, 4, 10) + df["D"] = ["X"] * 10 + axes = _grouped_hist(df.A, by=df.C, log=True) + # scale of y must be 'log' + _check_ax_scales(axes, yaxis="log") + + def test_grouped_hist_legacy_external_err(self): + rs = np.random.default_rng(2) + df = DataFrame(rs.standard_normal((10, 1)), columns=["A"]) + df["B"] = to_datetime( + rs.integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + df["C"] = rs.integers(0, 4, 10) + df["D"] = ["X"] * 10 + # propagate attr exception from matplotlib.Axes.hist + with tm.external_error_raised(AttributeError): + _grouped_hist(df.A, by=df.C, foo="bar") + + def test_grouped_hist_legacy_figsize_err(self): + rs = np.random.default_rng(2) + df = DataFrame(rs.standard_normal((10, 1)), columns=["A"]) + df["B"] = to_datetime( + rs.integers( + 812419200000000000, + 819331200000000000, + size=10, + dtype=np.int64, + ) + ) + df["C"] = rs.integers(0, 4, 10) + df["D"] = ["X"] * 10 + msg = "Specify figure size by tuple instead" + with pytest.raises(ValueError, match=msg): + df.hist(by="C", figsize="default") + + def test_grouped_hist_legacy2(self): + n = 10 + weight = Series(np.random.default_rng(2).normal(166, 20, size=n)) + height = Series(np.random.default_rng(2).normal(60, 10, size=n)) + gender_int = np.random.default_rng(2).choice([0, 1], size=n) + df_int = DataFrame({"height": height, "weight": weight, "gender": gender_int}) + gb = df_int.groupby("gender") + axes = gb.hist() + assert len(axes) == 2 + assert len(mpl.pyplot.get_fignums()) == 2 + + @pytest.mark.slow + @pytest.mark.parametrize( + "msg, plot_col, by_col, layout", + [ + [ + "Layout of 1x1 must be larger than required size 2", + "weight", + "gender", + (1, 1), + ], + [ + "Layout of 1x3 must be larger than required size 4", + "height", + "category", + (1, 3), + ], + [ + "At least one dimension of layout must be positive", + "height", + "category", + (-1, -1), + ], + ], + ) + def test_grouped_hist_layout_error(self, hist_df, msg, plot_col, by_col, layout): + df = hist_df + with pytest.raises(ValueError, match=msg): + df.hist(column=plot_col, by=getattr(df, by_col), layout=layout) + + @pytest.mark.slow + def test_grouped_hist_layout_warning(self, hist_df): + df = hist_df + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works( + df.hist, column="height", by=df.gender, layout=(2, 1) + ) + _check_axes_shape(axes, axes_num=2, layout=(2, 1)) + + @pytest.mark.slow + @pytest.mark.parametrize( + "layout, check_layout, figsize", + [[(4, 1), (4, 1), None], [(-1, 1), (4, 1), None], [(4, 2), (4, 2), (12, 8)]], + ) + def test_grouped_hist_layout_figsize(self, hist_df, layout, check_layout, figsize): + df = hist_df + axes = df.hist(column="height", by=df.category, layout=layout, figsize=figsize) + _check_axes_shape(axes, axes_num=4, layout=check_layout, figsize=figsize) + + @pytest.mark.slow + @pytest.mark.parametrize("kwargs", [{}, {"column": "height", "layout": (2, 2)}]) + def test_grouped_hist_layout_by_warning(self, hist_df, kwargs): + df = hist_df + # GH 6769 + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works(df.hist, by="classroom", **kwargs) + _check_axes_shape(axes, axes_num=3, layout=(2, 2)) + + @pytest.mark.slow + @pytest.mark.parametrize( + "kwargs, axes_num, layout", + [ + [{"by": "gender", "layout": (3, 5)}, 2, (3, 5)], + [{"column": ["height", "weight", "category"]}, 3, (2, 2)], + ], + ) + def test_grouped_hist_layout_axes(self, hist_df, kwargs, axes_num, layout): + df = hist_df + axes = df.hist(**kwargs) + _check_axes_shape(axes, axes_num=axes_num, layout=layout) + + def test_grouped_hist_multiple_axes(self, hist_df): + # GH 6970, GH 7069 + df = hist_df + + fig, axes = mpl.pyplot.subplots(2, 3) + returned = df.hist(column=["height", "weight", "category"], ax=axes[0]) + _check_axes_shape(returned, axes_num=3, layout=(1, 3)) + tm.assert_numpy_array_equal(returned, axes[0]) + assert returned[0].figure is fig + + def test_grouped_hist_multiple_axes_no_cols(self, hist_df): + # GH 6970, GH 7069 + df = hist_df + + fig, axes = mpl.pyplot.subplots(2, 3) + returned = df.hist(by="classroom", ax=axes[1]) + _check_axes_shape(returned, axes_num=3, layout=(1, 3)) + tm.assert_numpy_array_equal(returned, axes[1]) + assert returned[0].figure is fig + + def test_grouped_hist_multiple_axes_error(self, hist_df): + # GH 6970, GH 7069 + df = hist_df + fig, axes = mpl.pyplot.subplots(2, 3) + # pass different number of axes from required + msg = "The number of passed axes must be 1, the same as the output plot" + with pytest.raises(ValueError, match=msg): + axes = df.hist(column="height", ax=axes) + + def test_axis_share_x(self, hist_df): + df = hist_df + # GH4089 + ax1, ax2 = df.hist(column="height", by=df.gender, sharex=True) + + # share x + assert get_x_axis(ax1).joined(ax1, ax2) + assert get_x_axis(ax2).joined(ax1, ax2) + + # don't share y + assert not get_y_axis(ax1).joined(ax1, ax2) + assert not get_y_axis(ax2).joined(ax1, ax2) + + def test_axis_share_y(self, hist_df): + df = hist_df + ax1, ax2 = df.hist(column="height", by=df.gender, sharey=True) + + # share y + assert get_y_axis(ax1).joined(ax1, ax2) + assert get_y_axis(ax2).joined(ax1, ax2) + + # don't share x + assert not get_x_axis(ax1).joined(ax1, ax2) + assert not get_x_axis(ax2).joined(ax1, ax2) + + def test_axis_share_xy(self, hist_df): + df = hist_df + ax1, ax2 = df.hist(column="height", by=df.gender, sharex=True, sharey=True) + + # share both x and y + assert get_x_axis(ax1).joined(ax1, ax2) + assert get_x_axis(ax2).joined(ax1, ax2) + + assert get_y_axis(ax1).joined(ax1, ax2) + assert get_y_axis(ax2).joined(ax1, ax2) + + @pytest.mark.parametrize( + "histtype, expected", + [ + ("bar", True), + ("barstacked", True), + ("step", False), + ("stepfilled", True), + ], + ) + def test_histtype_argument(self, histtype, expected): + # GH23992 Verify functioning of histtype argument + df = DataFrame( + np.random.default_rng(2).integers(1, 10, size=(10, 2)), columns=["a", "b"] + ) + ax = df.hist(by="a", histtype=histtype) + _check_patches_all_filled(ax, filled=expected) diff --git a/pandas/tests/plotting/test_misc.py b/pandas/tests/plotting/test_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6d2499787c02347b0f152fe0d4e0231568e544 --- /dev/null +++ b/pandas/tests/plotting/test_misc.py @@ -0,0 +1,866 @@ +"""Test cases for misc plot functions""" + +import os + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import ( + DataFrame, + Index, + Series, + Timestamp, + date_range, + interval_range, + period_range, + plotting, + read_csv, +) +import pandas._testing as tm +from pandas.tests.plotting.common import ( + _check_colors, + _check_legend_labels, + _check_plot_works, + _check_text_labels, + _check_ticks_props, +) + +mpl = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") +cm = pytest.importorskip("matplotlib.cm") + +import re + +from pandas.plotting._matplotlib.style import get_standard_colors + + +@pytest.fixture +def iris(datapath) -> DataFrame: + """ + The iris dataset as a DataFrame. + """ + return read_csv(datapath("io", "data", "csv", "iris.csv")) + + +@td.skip_if_installed("matplotlib") +def test_import_error_message(): + # GH-19810 + df = DataFrame({"A": [1, 2]}) + + with pytest.raises(ImportError, match="matplotlib is required for plotting"): + df.plot() + + +def test_get_accessor_args(): + func = plotting._core.PlotAccessor._get_call_args + + msg = "Called plot accessor for type list, expected Series or DataFrame" + with pytest.raises(TypeError, match=msg): + func(backend_name="", data=[], args=[], kwargs={}) + + msg = "should not be called with positional arguments" + with pytest.raises(TypeError, match=msg): + func(backend_name="", data=Series(dtype=object), args=["line", None], kwargs={}) + + x, y, kind, kwargs = func( + backend_name="", + data=DataFrame(), + args=["x"], + kwargs={"y": "y", "kind": "bar", "grid": False}, + ) + assert x == "x" + assert y == "y" + assert kind == "bar" + assert kwargs == {"grid": False} + + x, y, kind, kwargs = func( + backend_name="pandas.plotting._matplotlib", + data=Series(dtype=object), + args=[], + kwargs={}, + ) + assert x is None + assert y is None + assert kind == "line" + assert len(kwargs) == 24 + + +@pytest.mark.parametrize("kind", plotting.PlotAccessor._all_kinds) +@pytest.mark.parametrize( + "data", [DataFrame(np.arange(15).reshape(5, 3)), Series(range(5))] +) +@pytest.mark.parametrize( + "index", + [ + Index(range(5)), + date_range("2020-01-01", periods=5), + period_range("2020-01-01", periods=5), + ], +) +def test_savefig(kind, data, index): + fig, ax = plt.subplots() + data.index = index + kwargs = {} + if kind in ["hexbin", "scatter", "pie"]: + if isinstance(data, Series): + pytest.skip(f"{kind} not supported with Series") + kwargs = {"x": 0, "y": 1} + data.plot(kind=kind, ax=ax, **kwargs) + fig.savefig(os.devnull) + + +class TestSeriesPlots: + def test_autocorrelation_plot(self): + ser = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + # Ensure no UserWarning when making plot + with tm.assert_produces_warning(None): + _check_plot_works(plotting.autocorrelation_plot, series=ser) + _check_plot_works(plotting.autocorrelation_plot, series=ser.values) + + ax = plotting.autocorrelation_plot(ser, label="Test") + _check_legend_labels(ax, labels=["Test"]) + + @pytest.mark.parametrize("kwargs", [{}, {"lag": 5}]) + def test_lag_plot(self, kwargs): + ser = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + _check_plot_works(plotting.lag_plot, series=ser, **kwargs) + + def test_bootstrap_plot(self): + ser = Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + _check_plot_works(plotting.bootstrap_plot, series=ser, size=10) + + +class TestDataFramePlots: + @pytest.mark.parametrize("pass_axis", [False, True]) + def test_scatter_matrix_axis(self, pass_axis): + pytest.importorskip("scipy") + scatter_matrix = plotting.scatter_matrix + + ax = None + if pass_axis: + _, ax = mpl.pyplot.subplots(3, 3) + + df = DataFrame(np.random.default_rng(2).standard_normal((10, 3))) + + # we are plotting multiples on a sub-plot + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works( + scatter_matrix, + frame=df, + range_padding=0.1, + ax=ax, + ) + axes0_labels = axes[0][0].yaxis.get_majorticklabels() + # GH 5662 + expected = ["-2", "-1", "0"] + _check_text_labels(axes0_labels, expected) + _check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0) + + @pytest.mark.parametrize("pass_axis", [False, True]) + def test_scatter_matrix_axis_smaller(self, pass_axis): + pytest.importorskip("scipy") + scatter_matrix = plotting.scatter_matrix + + ax = None + if pass_axis: + _, ax = mpl.pyplot.subplots(3, 3) + + df = DataFrame(np.random.default_rng(11).standard_normal((10, 3))) + df[0] = (df[0] - 2) / 3 + + # we are plotting multiples on a sub-plot + with tm.assert_produces_warning(UserWarning, check_stacklevel=False): + axes = _check_plot_works( + scatter_matrix, + frame=df, + range_padding=0.1, + ax=ax, + ) + axes0_labels = axes[0][0].yaxis.get_majorticklabels() + expected = ["-1.25", "-1.0", "-0.75", "-0.5"] + _check_text_labels(axes0_labels, expected) + _check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0) + + @pytest.mark.slow + def test_andrews_curves_no_warning(self, iris): + # Ensure no UserWarning when making plot + with tm.assert_produces_warning(None): + _check_plot_works(plotting.andrews_curves, frame=iris, class_column="Name") + + @pytest.mark.slow + @pytest.mark.parametrize( + "linecolors", + [ + ("#556270", "#4ECDC4", "#C7F464"), + ["dodgerblue", "aquamarine", "seagreen"], + ], + ) + @pytest.mark.parametrize( + "df", + [ + "iris", + DataFrame( + { + "A": np.random.default_rng(2).standard_normal(10), + "B": np.random.default_rng(2).standard_normal(10), + "C": np.random.default_rng(2).standard_normal(10), + "Name": ["A"] * 10, + } + ), + ], + ) + def test_andrews_curves_linecolors(self, request, df, linecolors): + if isinstance(df, str): + df = request.getfixturevalue(df) + ax = _check_plot_works( + plotting.andrews_curves, frame=df, class_column="Name", color=linecolors + ) + _check_colors( + ax.get_lines()[:10], linecolors=linecolors, mapping=df["Name"][:10] + ) + + @pytest.mark.slow + @pytest.mark.parametrize( + "df", + [ + "iris", + DataFrame( + { + "A": np.random.default_rng(2).standard_normal(10), + "B": np.random.default_rng(2).standard_normal(10), + "C": np.random.default_rng(2).standard_normal(10), + "Name": ["A"] * 10, + } + ), + ], + ) + def test_andrews_curves_cmap(self, request, df): + if isinstance(df, str): + df = request.getfixturevalue(df) + cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())] + ax = _check_plot_works( + plotting.andrews_curves, frame=df, class_column="Name", color=cmaps + ) + _check_colors(ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]) + + @pytest.mark.slow + def test_andrews_curves_handle(self): + colors = ["b", "g", "r"] + df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors}) + ax = plotting.andrews_curves(df, "Name", color=colors) + handles, _ = ax.get_legend_handles_labels() + _check_colors(handles, linecolors=colors) + + @pytest.mark.slow + @pytest.mark.parametrize( + "color", + [("#556270", "#4ECDC4", "#C7F464"), ["dodgerblue", "aquamarine", "seagreen"]], + ) + def test_parallel_coordinates_colors(self, iris, color): + df = iris + + ax = _check_plot_works( + plotting.parallel_coordinates, frame=df, class_column="Name", color=color + ) + _check_colors(ax.get_lines()[:10], linecolors=color, mapping=df["Name"][:10]) + + @pytest.mark.slow + def test_parallel_coordinates_cmap(self, iris): + df = iris + + ax = _check_plot_works( + plotting.parallel_coordinates, + frame=df, + class_column="Name", + colormap=cm.jet, + ) + cmaps = [mpl.cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())] + _check_colors(ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]) + + @pytest.mark.slow + def test_parallel_coordinates_line_diff(self, iris): + df = iris + + ax = _check_plot_works( + plotting.parallel_coordinates, frame=df, class_column="Name" + ) + nlines = len(ax.get_lines()) + nxticks = len(ax.xaxis.get_ticklabels()) + + ax = _check_plot_works( + plotting.parallel_coordinates, frame=df, class_column="Name", axvlines=False + ) + assert len(ax.get_lines()) == (nlines - nxticks) + + @pytest.mark.slow + def test_parallel_coordinates_handles(self, iris): + df = iris + colors = ["b", "g", "r"] + df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors}) + ax = plotting.parallel_coordinates(df, "Name", color=colors) + handles, _ = ax.get_legend_handles_labels() + _check_colors(handles, linecolors=colors) + + # not sure if this is indicative of a problem + @pytest.mark.filterwarnings("ignore:Attempting to set:UserWarning") + def test_parallel_coordinates_with_sorted_labels(self): + # GH 15908 + df = DataFrame( + { + "feat": list(range(30)), + "class": [2 for _ in range(10)] + + [3 for _ in range(10)] + + [1 for _ in range(10)], + } + ) + ax = plotting.parallel_coordinates(df, "class", sort_labels=True) + polylines, labels = ax.get_legend_handles_labels() + color_label_tuples = zip( + [polyline.get_color() for polyline in polylines], labels, strict=True + ) + ordered_color_label_tuples = sorted(color_label_tuples, key=lambda x: x[1]) + prev_next_tupels = zip( + list(ordered_color_label_tuples[0:-1]), + list(ordered_color_label_tuples[1:]), + strict=True, + ) + for prev, nxt in prev_next_tupels: + # labels and colors are ordered strictly increasing + assert prev[1] < nxt[1] and prev[0] < nxt[0] + + def test_radviz_no_warning(self, iris): + # Ensure no UserWarning when making plot + with tm.assert_produces_warning(None): + _check_plot_works(plotting.radviz, frame=iris, class_column="Name") + + @pytest.mark.parametrize( + "color", + [("#556270", "#4ECDC4", "#C7F464"), ["dodgerblue", "aquamarine", "seagreen"]], + ) + def test_radviz_color(self, iris, color): + df = iris + ax = _check_plot_works( + plotting.radviz, frame=df, class_column="Name", color=color + ) + # skip Circle drawn as ticks + patches = [p for p in ax.patches[:20] if p.get_label() != ""] + _check_colors(patches[:10], facecolors=color, mapping=df["Name"][:10]) + + def test_radviz_color_cmap(self, iris): + df = iris + ax = _check_plot_works( + plotting.radviz, frame=df, class_column="Name", colormap=cm.jet + ) + cmaps = [mpl.cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())] + patches = [p for p in ax.patches[:20] if p.get_label() != ""] + _check_colors(patches, facecolors=cmaps, mapping=df["Name"][:10]) + + def test_radviz_colors_handles(self): + colors = [[0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]] + df = DataFrame( + {"A": [1, 2, 3], "B": [2, 1, 3], "C": [3, 2, 1], "Name": ["b", "g", "r"]} + ) + ax = plotting.radviz(df, "Name", color=colors) + handles, _ = ax.get_legend_handles_labels() + _check_colors(handles, facecolors=colors) + + def test_subplot_titles(self, iris): + df = iris.drop("Name", axis=1).head() + # Use the column names as the subplot titles + title = list(df.columns) + + # Case len(title) == len(df) + plot = df.plot(subplots=True, title=title) + assert [p.get_title() for p in plot] == title + + def test_subplot_titles_too_much(self, iris): + df = iris.drop("Name", axis=1).head() + # Use the column names as the subplot titles + title = list(df.columns) + # Case len(title) > len(df) + msg = ( + "The length of `title` must equal the number of columns if " + "using `title` of type `list` and `subplots=True`" + ) + with pytest.raises(ValueError, match=msg): + df.plot(subplots=True, title=[*title, "kittens > puppies"]) + + def test_subplot_titles_too_little(self, iris): + df = iris.drop("Name", axis=1).head() + # Use the column names as the subplot titles + title = list(df.columns) + msg = ( + "The length of `title` must equal the number of columns if " + "using `title` of type `list` and `subplots=True`" + ) + # Case len(title) < len(df) + with pytest.raises(ValueError, match=msg): + df.plot(subplots=True, title=title[:2]) + + def test_subplot_titles_subplots_false(self, iris): + df = iris.drop("Name", axis=1).head() + # Use the column names as the subplot titles + title = list(df.columns) + # Case subplots=False and title is of type list + msg = ( + "Using `title` of type `list` is not supported unless " + "`subplots=True` is passed" + ) + with pytest.raises(ValueError, match=msg): + df.plot(subplots=False, title=title) + + def test_subplot_titles_numeric_square_layout(self, iris): + df = iris.drop("Name", axis=1).head() + # Use the column names as the subplot titles + title = list(df.columns) + # Case df with 3 numeric columns but layout of (2,2) + plot = df.drop("SepalWidth", axis=1).plot( + subplots=True, layout=(2, 2), title=title[:-1] + ) + title_list = [ax.get_title() for sublist in plot for ax in sublist] + assert title_list == [*title[:3], ""] + + def test_get_standard_colors_random_seed(self): + # GH17525 + df = DataFrame(np.zeros((10, 10))) + + # Make sure that the random seed isn't reset by get_standard_colors + plotting.parallel_coordinates(df, 0) + rand1 = np.random.default_rng(None).random() + plotting.parallel_coordinates(df, 0) + rand2 = np.random.default_rng(None).random() + assert rand1 != rand2 + + def test_get_standard_colors_consistency(self): + # GH17525 + # Make sure it produces the same colors every time it's called + color1 = get_standard_colors(1, color_type="random") + color2 = get_standard_colors(1, color_type="random") + assert color1 == color2 + + def test_get_standard_colors_default_num_colors(self): + # Make sure the default color_types returns the specified amount + color1 = get_standard_colors(1, color_type="default") + color2 = get_standard_colors(9, color_type="default") + color3 = get_standard_colors(20, color_type="default") + assert len(color1) == 1 + assert len(color2) == 9 + assert len(color3) == 20 + + def test_plot_single_color(self): + # Example from #20585. All 3 bars should have the same color + df = DataFrame( + { + "account-start": ["2017-02-03", "2017-03-03", "2017-01-01"], + "client": ["Alice Anders", "Bob Baker", "Charlie Chaplin"], + "balance": [-1432.32, 10.43, 30000.00], + "db-id": [1234, 2424, 251], + "proxy-id": [525, 1525, 2542], + "rank": [52, 525, 32], + } + ) + ax = df.client.value_counts().plot.bar() + colors = [rect.get_facecolor() for rect in ax.get_children()[0:3]] + assert all(color == colors[0] for color in colors) + + def test_get_standard_colors_no_appending(self): + # GH20726 + + # Make sure not to add more colors so that matplotlib can cycle + # correctly. + color_before = mpl.cm.gnuplot(range(5)) + color_after = get_standard_colors(1, color=color_before) + assert len(color_after) == len(color_before) + + df = DataFrame( + np.random.default_rng(2).standard_normal((48, 4)), columns=list("ABCD") + ) + + color_list = mpl.cm.gnuplot(np.linspace(0, 1, 16)) + p = df.A.plot.bar(figsize=(16, 7), color=color_list) + assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor() + + @pytest.mark.parametrize("kind", ["bar", "line"]) + def test_dictionary_color(self, kind): + # issue-8193 + # Test plot color dictionary format + data_files = ["a", "b"] + + expected = [(0.5, 0.24, 0.6), (0.3, 0.7, 0.7)] + + df1 = DataFrame(np.random.default_rng(2).random((2, 2)), columns=data_files) + dic_color = {"b": (0.3, 0.7, 0.7), "a": (0.5, 0.24, 0.6)} + + ax = df1.plot(kind=kind, color=dic_color) + if kind == "bar": + colors = [rect.get_facecolor()[0:-1] for rect in ax.get_children()[0:3:2]] + else: + colors = [rect.get_color() for rect in ax.get_lines()[0:2]] + assert all(color == expected[index] for index, color in enumerate(colors)) + + def test_bar_plot(self): + # GH38947 + # Test bar plot with string and int index + expected = [mpl.text.Text(0, 0, "0"), mpl.text.Text(1, 0, "Total")] + + df = DataFrame( + { + "a": [1, 2], + }, + index=Index([0, "Total"]), + ) + plot_bar = df.plot.bar() + assert all( + (a.get_text() == b.get_text()) + for a, b in zip(plot_bar.get_xticklabels(), expected, strict=True) + ) + + def test_barh_plot_labels_mixed_integer_string(self): + # GH39126 + # Test barh plot with string and integer at the same column + df = DataFrame([{"word": 1, "value": 0}, {"word": "knowledge", "value": 2}]) + plot_barh = df.plot.barh(x="word", legend=None) + expected_yticklabels = [ + mpl.text.Text(0, 0, "1"), + mpl.text.Text(0, 1, "knowledge"), + ] + assert all( + actual.get_text() == expected.get_text() + for actual, expected in zip( + plot_barh.get_yticklabels(), expected_yticklabels, strict=True + ) + ) + + def test_has_externally_shared_axis_x_axis(self): + # GH33819 + # Test _has_externally_shared_axis() works for x-axis + func = plotting._matplotlib.tools._has_externally_shared_axis + + fig = mpl.pyplot.figure() + plots = fig.subplots(2, 4) + + # Create *externally* shared axes for first and third columns + plots[0][0] = fig.add_subplot(231, sharex=plots[1][0]) + plots[0][2] = fig.add_subplot(233, sharex=plots[1][2]) + + # Create *internally* shared axes for second and third columns + plots[0][1].twinx() + plots[0][2].twinx() + + # First column is only externally shared + # Second column is only internally shared + # Third column is both + # Fourth column is neither + assert func(plots[0][0], "x") + assert not func(plots[0][1], "x") + assert func(plots[0][2], "x") + assert not func(plots[0][3], "x") + + def test_has_externally_shared_axis_y_axis(self): + # GH33819 + # Test _has_externally_shared_axis() works for y-axis + func = plotting._matplotlib.tools._has_externally_shared_axis + + fig = mpl.pyplot.figure() + plots = fig.subplots(4, 2) + + # Create *externally* shared axes for first and third rows + plots[0][0] = fig.add_subplot(321, sharey=plots[0][1]) + plots[2][0] = fig.add_subplot(325, sharey=plots[2][1]) + + # Create *internally* shared axes for second and third rows + plots[1][0].twiny() + plots[2][0].twiny() + + # First row is only externally shared + # Second row is only internally shared + # Third row is both + # Fourth row is neither + assert func(plots[0][0], "y") + assert not func(plots[1][0], "y") + assert func(plots[2][0], "y") + assert not func(plots[3][0], "y") + + def test_has_externally_shared_axis_invalid_compare_axis(self): + # GH33819 + # Test _has_externally_shared_axis() raises an exception when + # passed an invalid value as compare_axis parameter + func = plotting._matplotlib.tools._has_externally_shared_axis + + fig = mpl.pyplot.figure() + plots = fig.subplots(4, 2) + + # Create arbitrary axes + plots[0][0] = fig.add_subplot(321, sharey=plots[0][1]) + + # Check that an invalid compare_axis value triggers the expected exception + msg = "needs 'x' or 'y' as a second parameter" + with pytest.raises(ValueError, match=msg): + func(plots[0][0], "z") + + def test_externally_shared_axes(self): + # Example from GH33819 + # Create data + df = DataFrame( + { + "a": np.random.default_rng(2).standard_normal(10), + "b": np.random.default_rng(2).standard_normal(10), + } + ) + + # Create figure + fig = mpl.pyplot.figure() + plots = fig.subplots(2, 3) + + # Create *externally* shared axes + plots[0][0] = fig.add_subplot(231, sharex=plots[1][0]) + # note: no plots[0][1] that's the twin only case + plots[0][2] = fig.add_subplot(233, sharex=plots[1][2]) + + # Create *internally* shared axes + # note: no plots[0][0] that's the external only case + twin_ax1 = plots[0][1].twinx() + twin_ax2 = plots[0][2].twinx() + + # Plot data to primary axes + df["a"].plot(ax=plots[0][0], title="External share only").set_xlabel( + "this label should never be visible" + ) + df["a"].plot(ax=plots[1][0]) + + df["a"].plot(ax=plots[0][1], title="Internal share (twin) only").set_xlabel( + "this label should always be visible" + ) + df["a"].plot(ax=plots[1][1]) + + df["a"].plot(ax=plots[0][2], title="Both").set_xlabel( + "this label should never be visible" + ) + df["a"].plot(ax=plots[1][2]) + + # Plot data to twinned axes + df["b"].plot(ax=twin_ax1, color="green") + df["b"].plot(ax=twin_ax2, color="yellow") + + assert not plots[0][0].xaxis.get_label().get_visible() + assert plots[0][1].xaxis.get_label().get_visible() + assert not plots[0][2].xaxis.get_label().get_visible() + + def test_plot_bar_axis_units_timestamp_conversion(self): + # GH 38736 + # Ensure string x-axis from the second plot will not be converted to datetime + # due to axis data from first plot + df = DataFrame( + [1.0], + index=[Timestamp("2022-02-22 22:22:22")], + ) + _check_plot_works(df.plot) + s = Series({"A": 1.0}) + _check_plot_works(s.plot.bar) + + def test_bar_plt_xaxis_intervalrange(self): + # GH 38969 + # Ensure IntervalIndex x-axis produces a bar plot as expected + expected = [mpl.text.Text(0, 0, "([0, 1],)"), mpl.text.Text(1, 0, "([1, 2],)")] + s = Series( + [1, 2], + index=[interval_range(0, 2, closed="both")], + ) + _check_plot_works(s.plot.bar) + assert all( + (a.get_text() == b.get_text()) + for a, b in zip(s.plot.bar().get_xticklabels(), expected, strict=True) + ) + + +@pytest.fixture +def df_bar_data(): + return np.random.default_rng(3).integers(0, 100, 5) + + +@pytest.fixture +def df_bar_df(df_bar_data) -> DataFrame: + df_bar_df = DataFrame( + { + "A": df_bar_data, + "B": df_bar_data[::-1], + "C": df_bar_data[0], + "D": df_bar_data[-1], + } + ) + return df_bar_df + + +def _df_bar_xyheight_from_ax_helper(df_bar_data, ax, subplot_division): + subplot_data_df_list = [] + + # get xy and height of squares representing data, separated by subplots + for i in range(len(subplot_division)): + subplot_data = np.array( + [ + (x.get_x(), x.get_y(), x.get_height()) + for x in ax[i].findobj(plt.Rectangle) + if x.get_height() in df_bar_data + ] + ) + subplot_data_df_list.append( + DataFrame(data=subplot_data, columns=["x_coord", "y_coord", "height"]) + ) + + return subplot_data_df_list + + +def _df_bar_subplot_checker(df_bar_data, df_bar_df, subplot_data_df, subplot_columns): + subplot_sliced_by_source = [ + subplot_data_df.iloc[ + len(df_bar_data) * i : len(df_bar_data) * (i + 1) + ].reset_index() + for i in range(len(subplot_columns)) + ] + + if len(subplot_columns) == 1: + expected_total_height = df_bar_df.loc[:, subplot_columns[0]] + else: + expected_total_height = df_bar_df.loc[:, subplot_columns].sum(axis=1) + + for i in range(len(subplot_columns)): + sliced_df = subplot_sliced_by_source[i] + if i == 0: + # Checks that the bar chart starts y=0 + assert (sliced_df["y_coord"] == 0).all() + height_iter = sliced_df["y_coord"].add(sliced_df["height"]) + else: + height_iter = height_iter + sliced_df["height"] + + if i + 1 == len(subplot_columns): + # Checks final height matches what is expected + tm.assert_series_equal( + height_iter, expected_total_height, check_names=False, check_dtype=False + ) + else: + # Checks each preceding bar ends where the next one starts + next_start_coord = subplot_sliced_by_source[i + 1]["y_coord"] + tm.assert_series_equal( + height_iter, next_start_coord, check_names=False, check_dtype=False + ) + + +# GH Issue 61018 +@pytest.mark.parametrize("columns_used", [["A", "B"], ["C", "D"], ["D", "A"]]) +def test_bar_1_subplot_1_double_stacked(df_bar_data, df_bar_df, columns_used): + df_bar_df_trimmed = df_bar_df[columns_used] + subplot_division = [columns_used] + ax = df_bar_df_trimmed.plot(subplots=subplot_division, kind="bar", stacked=True) + subplot_data_df_list = _df_bar_xyheight_from_ax_helper( + df_bar_data, ax, subplot_division + ) + for i in range(len(subplot_data_df_list)): + _df_bar_subplot_checker( + df_bar_data, df_bar_df_trimmed, subplot_data_df_list[i], subplot_division[i] + ) + + +@pytest.mark.parametrize( + "columns_used", [["A", "B", "C"], ["A", "C", "B"], ["D", "A", "C"]] +) +def test_bar_2_subplot_1_double_stacked(df_bar_data, df_bar_df, columns_used): + df_bar_df_trimmed = df_bar_df[columns_used] + subplot_division = [(columns_used[0], columns_used[1]), (columns_used[2],)] + ax = df_bar_df_trimmed.plot(subplots=subplot_division, kind="bar", stacked=True) + subplot_data_df_list = _df_bar_xyheight_from_ax_helper( + df_bar_data, ax, subplot_division + ) + for i in range(len(subplot_data_df_list)): + _df_bar_subplot_checker( + df_bar_data, df_bar_df_trimmed, subplot_data_df_list[i], subplot_division[i] + ) + + +@pytest.mark.parametrize( + "subplot_division", + [ + [("A", "B"), ("C", "D")], + [("A", "D"), ("C", "B")], + [("B", "C"), ("D", "A")], + [("B", "D"), ("C", "A")], + ], +) +def test_bar_2_subplot_2_double_stacked(df_bar_data, df_bar_df, subplot_division): + ax = df_bar_df.plot(subplots=subplot_division, kind="bar", stacked=True) + subplot_data_df_list = _df_bar_xyheight_from_ax_helper( + df_bar_data, ax, subplot_division + ) + for i in range(len(subplot_data_df_list)): + _df_bar_subplot_checker( + df_bar_data, df_bar_df, subplot_data_df_list[i], subplot_division[i] + ) + + +@pytest.mark.parametrize( + "subplot_division", + [[("A", "B", "C")], [("A", "D", "B")], [("C", "A", "D")], [("D", "C", "A")]], +) +def test_bar_2_subplots_1_triple_stacked(df_bar_data, df_bar_df, subplot_division): + ax = df_bar_df.plot(subplots=subplot_division, kind="bar", stacked=True) + subplot_data_df_list = _df_bar_xyheight_from_ax_helper( + df_bar_data, ax, subplot_division + ) + for i in range(len(subplot_data_df_list)): + _df_bar_subplot_checker( + df_bar_data, df_bar_df, subplot_data_df_list[i], subplot_division[i] + ) + + +def test_bar_subplots_stacking_bool(df_bar_data, df_bar_df): + subplot_division = [("A"), ("B"), ("C"), ("D")] + ax = df_bar_df.plot(subplots=True, kind="bar", stacked=True) + subplot_data_df_list = _df_bar_xyheight_from_ax_helper( + df_bar_data, ax, subplot_division + ) + for i in range(len(subplot_data_df_list)): + _df_bar_subplot_checker( + df_bar_data, df_bar_df, subplot_data_df_list[i], subplot_division[i] + ) + + +def test_plot_bar_label_count_default(): + df = DataFrame( + [(30, 10, 10, 10), (20, 20, 20, 20), (10, 30, 30, 10)], columns=list("ABCD") + ) + df.plot(subplots=True, kind="bar", title=["A", "B", "C", "D"]) + + +def test_plot_bar_label_count_expected_fail(): + df = DataFrame( + [(30, 10, 10, 10), (20, 20, 20, 20), (10, 30, 30, 10)], columns=list("ABCD") + ) + error_regex = re.escape( + "The number of titles (4) must equal the number of subplots (3)." + ) + with pytest.raises(ValueError, match=error_regex): + df.plot( + subplots=[("A", "B")], + kind="bar", + title=["A&B", "C", "D", "Extra Title"], + ) + + +def test_plot_bar_label_count_expected_success(): + df = DataFrame( + [(30, 10, 10, 10), (20, 20, 20, 20), (10, 30, 30, 10)], columns=list("ABCD") + ) + df.plot(subplots=[("A", "B", "D")], kind="bar", title=["A&B&D", "C"]) diff --git a/pandas/tests/plotting/test_series.py b/pandas/tests/plotting/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef1e660236c899ef19e9419e9c64402d81beace --- /dev/null +++ b/pandas/tests/plotting/test_series.py @@ -0,0 +1,1005 @@ +"""Test cases for Series.plot""" + +from datetime import datetime +from itertools import chain + +import numpy as np +import pytest + +from pandas.compat import is_platform_linux +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Series, + date_range, + period_range, + plotting, +) +import pandas._testing as tm +from pandas.tests.plotting.common import ( + _check_ax_scales, + _check_axes_shape, + _check_colors, + _check_grid_settings, + _check_has_errorbars, + _check_legend_labels, + _check_plot_works, + _check_text_labels, + _check_ticks_props, + _unpack_cycler, + get_y_axis, +) + +from pandas.tseries.offsets import CustomBusinessDay + +mpl = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") + +from pandas.plotting._matplotlib.converter import DatetimeConverter +from pandas.plotting._matplotlib.style import get_standard_colors + +pytestmark = [ + pytest.mark.filterwarnings( + "ignore:divide by zero encountered in scalar divide:RuntimeWarning" + ), + pytest.mark.filterwarnings( + "ignore:invalid value encountered in scalar multiply:RuntimeWarning" + ), +] + + +@pytest.fixture +def ts(): + return Series( + np.arange(10, dtype=np.float64), + index=date_range("2020-01-01", periods=10), + name="ts", + ) + + +@pytest.fixture +def series(): + return Series( + range(10), dtype=np.float64, name="series", index=[f"i_{i}" for i in range(10)] + ) + + +class TestSeriesPlots: + @pytest.mark.slow + @pytest.mark.parametrize("kwargs", [{"label": "foo"}, {"use_index": False}]) + def test_plot(self, ts, kwargs): + _check_plot_works(ts.plot, **kwargs) + + @pytest.mark.slow + def test_plot_tick_props(self, ts): + axes = _check_plot_works(ts.plot, rot=0) + _check_ticks_props(axes, xrot=0) + + @pytest.mark.slow + @pytest.mark.parametrize( + "scale, exp_scale", + [ + [{"logy": True}, {"yaxis": "log"}], + [{"logx": True}, {"xaxis": "log"}], + [{"loglog": True}, {"xaxis": "log", "yaxis": "log"}], + ], + ) + def test_plot_scales(self, ts, scale, exp_scale): + ax = _check_plot_works(ts.plot, style=".", **scale) + _check_ax_scales(ax, **exp_scale) + + @pytest.mark.slow + def test_plot_ts_bar(self, ts): + _check_plot_works(ts[:10].plot.bar) + + @pytest.mark.slow + def test_plot_ts_area_stacked(self, ts): + _check_plot_works(ts.plot.area, stacked=False) + + def test_plot_iseries(self): + ser = Series(range(5), period_range("2020-01-01", periods=5)) + _check_plot_works(ser.plot) + + @pytest.mark.parametrize( + "kind", + [ + "line", + "bar", + "barh", + pytest.param("kde", marks=td.skip_if_no("scipy")), + "hist", + "box", + ], + ) + def test_plot_series_kinds(self, series, kind): + _check_plot_works(series[:5].plot, kind=kind) + + def test_plot_series_barh(self, series): + _check_plot_works(series[:10].plot.barh) + + def test_plot_series_bar_ax(self): + ax = _check_plot_works( + Series(np.random.default_rng(2).standard_normal(10)).plot.bar, color="black" + ) + _check_colors([ax.patches[0]], facecolors=["black"]) + + @pytest.mark.parametrize("kwargs", [{}, {"layout": (-1, 1)}, {"layout": (1, -1)}]) + def test_plot_6951(self, ts, kwargs): + # GH 6951 + ax = _check_plot_works(ts.plot, subplots=True, **kwargs) + _check_axes_shape(ax, axes_num=1, layout=(1, 1)) + + def test_plot_figsize_and_title(self, series): + # figsize and title + _, ax = mpl.pyplot.subplots() + ax = series.plot(title="Test", figsize=(16, 8), ax=ax) + _check_text_labels(ax.title, "Test") + _check_axes_shape(ax, axes_num=1, layout=(1, 1), figsize=(16, 8)) + + def test_dont_modify_rcParams(self): + # GH 8242 + key = "axes.prop_cycle" + colors = mpl.pyplot.rcParams[key] + _, ax = mpl.pyplot.subplots() + Series([1, 2, 3]).plot(ax=ax) + assert colors == mpl.pyplot.rcParams[key] + + @pytest.mark.parametrize("kwargs", [{}, {"secondary_y": True}]) + def test_ts_line_lim(self, ts, kwargs): + _, ax = mpl.pyplot.subplots() + ax = ts.plot(ax=ax, **kwargs) + xmin, xmax = ax.get_xlim() + lines = ax.get_lines() + assert xmin <= lines[0].get_data(orig=False)[0][0] + assert xmax >= lines[0].get_data(orig=False)[0][-1] + + def test_ts_area_lim(self, ts): + _, ax = mpl.pyplot.subplots() + ax = ts.plot.area(stacked=False, ax=ax) + xmin, xmax = ax.get_xlim() + line = ax.get_lines()[0].get_data(orig=False)[0] + assert xmin <= line[0] + assert xmax >= line[-1] + _check_ticks_props(ax, xrot=0) + + def test_ts_area_lim_xcompat(self, ts): + # GH 7471 + _, ax = mpl.pyplot.subplots() + ax = ts.plot.area(stacked=False, x_compat=True, ax=ax) + xmin, xmax = ax.get_xlim() + line = ax.get_lines()[0].get_data(orig=False)[0] + assert xmin <= line[0] + assert xmax >= line[-1] + _check_ticks_props(ax, xrot=30) + + def test_ts_tz_area_lim_xcompat(self, ts): + tz_ts = ts.copy() + tz_ts.index = tz_ts.tz_localize("GMT").tz_convert("CET") + _, ax = mpl.pyplot.subplots() + ax = tz_ts.plot.area(stacked=False, x_compat=True, ax=ax) + xmin, xmax = ax.get_xlim() + line = ax.get_lines()[0].get_data(orig=False)[0] + assert xmin <= line[0] + assert xmax >= line[-1] + _check_ticks_props(ax, xrot=0) + + def test_ts_tz_area_lim_xcompat_secondary_y(self, ts): + tz_ts = ts.copy() + tz_ts.index = tz_ts.tz_localize("GMT").tz_convert("CET") + _, ax = mpl.pyplot.subplots() + ax = tz_ts.plot.area(stacked=False, secondary_y=True, ax=ax) + xmin, xmax = ax.get_xlim() + line = ax.get_lines()[0].get_data(orig=False)[0] + assert xmin <= line[0] + assert xmax >= line[-1] + _check_ticks_props(ax, xrot=0) + + def test_area_sharey_dont_overwrite(self, ts): + # GH37942 + fig, (ax1, ax2) = mpl.pyplot.subplots(1, 2, sharey=True) + + abs(ts).plot(ax=ax1, kind="area") + abs(ts).plot(ax=ax2, kind="area") + + assert get_y_axis(ax1).joined(ax1, ax2) + assert get_y_axis(ax2).joined(ax1, ax2) + + def test_label(self): + s = Series([1, 2]) + _, ax = mpl.pyplot.subplots() + ax = s.plot(label="LABEL", legend=True, ax=ax) + _check_legend_labels(ax, labels=["LABEL"]) + + def test_label_none(self): + s = Series([1, 2]) + _, ax = mpl.pyplot.subplots() + ax = s.plot(legend=True, ax=ax) + _check_legend_labels(ax, labels=[""]) + + def test_label_ser_name(self): + s = Series([1, 2], name="NAME") + _, ax = mpl.pyplot.subplots() + ax = s.plot(legend=True, ax=ax) + _check_legend_labels(ax, labels=["NAME"]) + + def test_label_ser_name_override(self): + s = Series([1, 2], name="NAME") + # override the default + _, ax = mpl.pyplot.subplots() + ax = s.plot(legend=True, label="LABEL", ax=ax) + _check_legend_labels(ax, labels=["LABEL"]) + + def test_label_ser_name_override_dont_draw(self): + s = Series([1, 2], name="NAME") + # Add lebel info, but don't draw + _, ax = mpl.pyplot.subplots() + ax = s.plot(legend=False, label="LABEL", ax=ax) + assert ax.get_legend() is None # Hasn't been drawn + ax.legend() # draw it + _check_legend_labels(ax, labels=["LABEL"]) + + def test_boolean(self): + # GH 23719 + s = Series([False, False, True]) + _check_plot_works(s.plot, include_bool=True) + + msg = "no numeric data to plot" + with pytest.raises(TypeError, match=msg): + _check_plot_works(s.plot) + + @pytest.mark.parametrize("index", [None, date_range("2020-01-01", periods=4)]) + def test_line_area_nan_series(self, index): + values = [1, 2, np.nan, 3] + d = Series(values, index=index) + ax = _check_plot_works(d.plot) + masked = ax.lines[0].get_ydata() + # remove nan for comparison purpose + exp = np.array([1, 2, 3], dtype=np.float64) + tm.assert_numpy_array_equal(np.delete(masked.data, 2), exp) + tm.assert_numpy_array_equal(masked.mask, np.array([False, False, True, False])) + + expected = np.array([1, 2, 0, 3], dtype=np.float64) + ax = _check_plot_works(d.plot, stacked=True) + tm.assert_numpy_array_equal(ax.lines[0].get_ydata(), expected) + ax = _check_plot_works(d.plot.area) + tm.assert_numpy_array_equal(ax.lines[0].get_ydata(), expected) + ax = _check_plot_works(d.plot.area, stacked=False) + tm.assert_numpy_array_equal(ax.lines[0].get_ydata(), expected) + + def test_line_use_index_false(self): + s = Series([1, 2, 3], index=["a", "b", "c"]) + s.index.name = "The Index" + _, ax = mpl.pyplot.subplots() + ax = s.plot(use_index=False, ax=ax) + label = ax.get_xlabel() + assert label == "" + + def test_line_use_index_false_diff_var(self): + s = Series([1, 2, 3], index=["a", "b", "c"]) + s.index.name = "The Index" + _, ax = mpl.pyplot.subplots() + ax2 = s.plot.bar(use_index=False, ax=ax) + label2 = ax2.get_xlabel() + assert label2 == "" + + @pytest.mark.xfail( + is_platform_linux(), + reason="Weird rounding problems", + strict=False, + ) + @pytest.mark.parametrize("axis, meth", [("yaxis", "bar"), ("xaxis", "barh")]) + def test_bar_log(self, axis, meth): + expected = np.array([1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]) + + _, ax = mpl.pyplot.subplots() + ax = getattr(Series([200, 500]).plot, meth)(log=True, ax=ax) + tm.assert_numpy_array_equal(getattr(ax, axis).get_ticklocs(), expected) + + @pytest.mark.xfail( + is_platform_linux(), + reason="Weird rounding problems", + strict=False, + ) + @pytest.mark.parametrize( + "axis, kind, res_meth", + [["yaxis", "bar", "get_ylim"], ["xaxis", "barh", "get_xlim"]], + ) + def test_bar_log_kind_bar(self, axis, kind, res_meth): + # GH 9905 + expected = np.array([1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1]) + + _, ax = mpl.pyplot.subplots() + ax = Series([0.1, 0.01, 0.001]).plot(log=True, kind=kind, ax=ax) + ymin = 0.0007943282347242822 + ymax = 0.12589254117941673 + res = getattr(ax, res_meth)() + tm.assert_almost_equal(res[0], ymin) + tm.assert_almost_equal(res[1], ymax) + tm.assert_numpy_array_equal(getattr(ax, axis).get_ticklocs(), expected) + + def test_bar_ignore_index(self): + df = Series([1, 2, 3, 4], index=["a", "b", "c", "d"]) + _, ax = mpl.pyplot.subplots() + ax = df.plot.bar(use_index=False, ax=ax) + _check_text_labels(ax.get_xticklabels(), ["0", "1", "2", "3"]) + + def test_bar_user_colors(self): + s = Series([1, 2, 3, 4]) + ax = s.plot.bar(color=["red", "blue", "blue", "red"]) + result = [p.get_facecolor() for p in ax.patches] + expected = [ + (1.0, 0.0, 0.0, 1.0), + (0.0, 0.0, 1.0, 1.0), + (0.0, 0.0, 1.0, 1.0), + (1.0, 0.0, 0.0, 1.0), + ] + assert result == expected + + def test_rotation_default(self): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 5))) + # Default rot 0 + _, ax = mpl.pyplot.subplots() + axes = df.plot(ax=ax) + _check_ticks_props(axes, xrot=0) + + def test_rotation_30(self): + df = DataFrame(np.random.default_rng(2).standard_normal((5, 5))) + _, ax = mpl.pyplot.subplots() + axes = df.plot(rot=30, ax=ax) + _check_ticks_props(axes, xrot=30) + + def test_irregular_datetime(self): + rng = date_range("1/1/2000", "1/15/2000") + rng = rng[[0, 1, 2, 3, 5, 9, 10, 11, 12]] + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + _, ax = mpl.pyplot.subplots() + ax = ser.plot(ax=ax) + xp = DatetimeConverter.convert(datetime(1999, 1, 1), "", ax) + ax.set_xlim("1/1/1999", "1/1/2001") + assert xp == ax.get_xlim()[0] + _check_ticks_props(ax, xrot=30) + + def test_unsorted_index_xlim(self): + ser = Series( + [0.0, 1.0, np.nan, 3.0, 4.0, 5.0, 6.0], + index=[1.0, 0.0, 3.0, 2.0, np.nan, 3.0, 2.0], + ) + _, ax = mpl.pyplot.subplots() + ax = ser.plot(ax=ax) + xmin, xmax = ax.get_xlim() + lines = ax.get_lines() + assert xmin <= np.nanmin(lines[0].get_data(orig=False)[0]) + assert xmax >= np.nanmax(lines[0].get_data(orig=False)[0]) + + def test_pie_series(self): + # if sum of values is less than 1.0, pie handle them as rate and draw + # semicircle. + series = Series( + np.random.default_rng(2).integers(1, 5), + index=["a", "b", "c", "d", "e"], + name="YLABEL", + ) + ax = _check_plot_works(series.plot.pie) + _check_text_labels(ax.texts, series.index) + assert ax.get_ylabel() == "" + + def test_pie_arrow_type(self): + # GH 59192 + pytest.importorskip("pyarrow") + ser = Series([1, 2, 3, 4], dtype="int32[pyarrow]") + _check_plot_works(ser.plot.pie) + + def test_pie_series_no_label(self): + series = Series( + np.random.default_rng(2).integers(1, 5), + index=["a", "b", "c", "d", "e"], + name="YLABEL", + ) + ax = _check_plot_works(series.plot.pie, labels=None) + _check_text_labels(ax.texts, [""] * 5) + + def test_pie_series_less_colors_than_elements(self): + series = Series( + np.random.default_rng(2).integers(1, 5), + index=["a", "b", "c", "d", "e"], + name="YLABEL", + ) + color_args = ["r", "g", "b"] + ax = _check_plot_works(series.plot.pie, colors=color_args) + + color_expected = ["r", "g", "b", "r", "g"] + _check_colors(ax.patches, facecolors=color_expected) + + def test_pie_series_labels_and_colors(self): + series = Series( + np.random.default_rng(2).integers(1, 5), + index=["a", "b", "c", "d", "e"], + name="YLABEL", + ) + # with labels and colors + labels = ["A", "B", "C", "D", "E"] + color_args = ["r", "g", "b", "c", "m"] + ax = _check_plot_works(series.plot.pie, labels=labels, colors=color_args) + _check_text_labels(ax.texts, labels) + _check_colors(ax.patches, facecolors=color_args) + + def test_pie_series_autopct_and_fontsize(self): + series = Series( + np.random.default_rng(2).integers(1, 5), + index=["a", "b", "c", "d", "e"], + name="YLABEL", + ) + color_args = ["r", "g", "b", "c", "m"] + ax = _check_plot_works( + series.plot.pie, colors=color_args, autopct="%.2f", fontsize=7 + ) + pcts = [f"{s * 100:.2f}" for s in series.values / series.sum()] + expected_texts = list(chain.from_iterable(zip(series.index, pcts, strict=True))) + _check_text_labels(ax.texts, expected_texts) + for t in ax.texts: + assert t.get_fontsize() == 7 + + def test_pie_series_negative_raises(self): + # includes negative value + series = Series([1, 2, 0, 4, -1], index=["a", "b", "c", "d", "e"]) + with pytest.raises(ValueError, match="pie plot doesn't allow negative values"): + series.plot.pie() + + def test_pie_series_nan(self): + # includes nan + series = Series([1, 2, np.nan, 4], index=["a", "b", "c", "d"], name="YLABEL") + ax = _check_plot_works(series.plot.pie) + _check_text_labels(ax.texts, ["a", "b", "", "d"]) + + def test_pie_nan(self): + s = Series([1, np.nan, 1, 1]) + _, ax = mpl.pyplot.subplots() + ax = s.plot.pie(legend=True, ax=ax) + expected = ["0", "", "2", "3"] + result = [x.get_text() for x in ax.texts] + assert result == expected + + def test_df_series_secondary_legend(self): + # GH 9779 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), columns=list("abc") + ) + s = Series(np.random.default_rng(2).standard_normal(10), name="x") + + # primary -> secondary (without passing ax) + _, ax = mpl.pyplot.subplots() + ax = df.plot(ax=ax) + s.plot(legend=True, secondary_y=True, ax=ax) + # both legends are drawn on left ax + # left and right axis must be visible + _check_legend_labels(ax, labels=["a", "b", "c", "x (right)"]) + assert ax.get_yaxis().get_visible() + assert ax.right_ax.get_yaxis().get_visible() + + def test_df_series_secondary_legend_both(self): + # GH 9779 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), columns=list("abc") + ) + s = Series(np.random.default_rng(2).standard_normal(10), name="x") + # secondary -> secondary (without passing ax) + _, ax = mpl.pyplot.subplots() + ax = df.plot(secondary_y=True, ax=ax) + s.plot(legend=True, secondary_y=True, ax=ax) + # both legends are drawn on left ax + # left axis must be invisible and right axis must be visible + expected = ["a (right)", "b (right)", "c (right)", "x (right)"] + _check_legend_labels(ax.left_ax, labels=expected) + assert not ax.left_ax.get_yaxis().get_visible() + assert ax.get_yaxis().get_visible() + + def test_df_series_secondary_legend_both_with_axis_2(self): + # GH 9779 + df = DataFrame( + np.random.default_rng(2).standard_normal((10, 3)), columns=list("abc") + ) + s = Series(np.random.default_rng(2).standard_normal(10), name="x") + # secondary -> secondary (with passing ax) + _, ax = mpl.pyplot.subplots() + ax = df.plot(secondary_y=True, mark_right=False, ax=ax) + s.plot(ax=ax, legend=True, secondary_y=True) + # both legends are drawn on left ax + # left axis must be invisible and right axis must be visible + expected = ["a", "b", "c", "x (right)"] + _check_legend_labels(ax.left_ax, expected) + assert not ax.left_ax.get_yaxis().get_visible() + assert ax.get_yaxis().get_visible() + + @pytest.mark.parametrize( + "input_logy, expected_scale", [(True, "log"), ("sym", "symlog")] + ) + @pytest.mark.parametrize("secondary_kwarg", [{}, {"secondary_y": True}]) + def test_secondary_logy(self, input_logy, expected_scale, secondary_kwarg): + # GH 25545, GH 24980 + s1 = Series(np.random.default_rng(2).standard_normal(10)) + ax1 = s1.plot(logy=input_logy, **secondary_kwarg) + assert ax1.get_yscale() == expected_scale + + def test_plot_fails_with_dupe_color_and_style(self): + x = Series(np.random.default_rng(2).standard_normal(2)) + _, ax = mpl.pyplot.subplots() + msg = ( + "Cannot pass 'style' string with a color symbol and 'color' keyword " + "argument. Please use one or the other or pass 'style' without a color " + "symbol" + ) + with pytest.raises(ValueError, match=msg): + x.plot(style="k--", color="k", ax=ax) + + @pytest.mark.parametrize( + "bw_method, ind", + [ + ["scott", 20], + [None, 20], + [None, np.int_(20)], + [0.5, np.linspace(-100, 100, 20)], + ], + ) + def test_kde_kwargs(self, ts, bw_method, ind): + pytest.importorskip("scipy") + _check_plot_works(ts.plot.kde, bw_method=bw_method, ind=ind) + + @pytest.mark.parametrize( + "bw_method, ind, weights", + [ + ["scott", 20, None], + [None, 20, None], + [None, np.int_(20), None], + [0.5, np.linspace(-100, 100, 20), None], + ["scott", 40, np.linspace(0.0, 2.0, 50)], + ], + ) + def test_kde_kwargs_weights(self, bw_method, ind, weights): + # GH59337 + pytest.importorskip("scipy") + s = Series(np.random.default_rng(2).uniform(size=50)) + _check_plot_works(s.plot.kde, bw_method=bw_method, ind=ind, weights=weights) + + def test_density_kwargs(self, ts): + pytest.importorskip("scipy") + sample_points = np.linspace(-100, 100, 20) + _check_plot_works(ts.plot.density, bw_method=0.5, ind=sample_points) + + def test_kde_kwargs_check_axes(self, ts): + pytest.importorskip("scipy") + _, ax = mpl.pyplot.subplots() + sample_points = np.linspace(-100, 100, 20) + ax = ts.plot.kde(logy=True, bw_method=0.5, ind=sample_points, ax=ax) + _check_ax_scales(ax, yaxis="log") + _check_text_labels(ax.yaxis.get_label(), "Density") + + def test_kde_missing_vals(self): + pytest.importorskip("scipy") + s = Series(np.random.default_rng(2).uniform(size=50)) + s[0] = np.nan + axes = _check_plot_works(s.plot.kde) + + # gh-14821: check if the values have any missing values + assert any(~np.isnan(axes.lines[0].get_xdata())) + + @pytest.mark.xfail(reason="Api changed in 3.6.0") + def test_boxplot_series(self, ts): + _, ax = mpl.pyplot.subplots() + ax = ts.plot.box(logy=True, ax=ax) + _check_ax_scales(ax, yaxis="log") + xlabels = ax.get_xticklabels() + _check_text_labels(xlabels, [ts.name]) + ylabels = ax.get_yticklabels() + _check_text_labels(ylabels, [""] * len(ylabels)) + + @pytest.mark.parametrize( + "kind", + plotting.PlotAccessor._common_kinds + plotting.PlotAccessor._series_kinds, + ) + def test_kind_kwarg(self, kind): + pytest.importorskip("scipy") + s = Series(range(3)) + _, ax = mpl.pyplot.subplots() + s.plot(kind=kind, ax=ax) + mpl.pyplot.close() + + @pytest.mark.parametrize( + "kind", + plotting.PlotAccessor._common_kinds + plotting.PlotAccessor._series_kinds, + ) + def test_kind_attr(self, kind): + pytest.importorskip("scipy") + s = Series(range(3)) + _, ax = mpl.pyplot.subplots() + getattr(s.plot, kind)() + mpl.pyplot.close() + + @pytest.mark.parametrize("kind", plotting.PlotAccessor._common_kinds) + def test_invalid_plot_data(self, kind): + s = Series(list("abcd")) + _, ax = mpl.pyplot.subplots() + msg = "no numeric data to plot" + with pytest.raises(TypeError, match=msg): + s.plot(kind=kind, ax=ax) + + @pytest.mark.parametrize("kind", plotting.PlotAccessor._common_kinds) + def test_valid_object_plot(self, kind): + pytest.importorskip("scipy") + s = Series(range(10), dtype=object) + _check_plot_works(s.plot, kind=kind) + + @pytest.mark.parametrize("kind", plotting.PlotAccessor._common_kinds) + def test_partially_invalid_plot_data(self, kind): + s = Series(["a", "b", 1.0, 2]) + _, ax = mpl.pyplot.subplots() + msg = "no numeric data to plot" + with pytest.raises(TypeError, match=msg): + s.plot(kind=kind, ax=ax) + + def test_invalid_kind(self): + s = Series([1, 2]) + with pytest.raises(ValueError, match="invalid_kind is not a valid plot kind"): + s.plot(kind="invalid_kind") + + def test_dup_datetime_index_plot(self): + dr1 = date_range("1/1/2009", periods=4) + dr2 = date_range("1/2/2009", periods=4) + index = dr1.append(dr2) + values = np.random.default_rng(2).standard_normal(index.size) + s = Series(values, index=index) + _check_plot_works(s.plot) + + def test_errorbar_asymmetrical(self): + # GH9536 + s = Series(np.arange(10), name="x") + err = np.random.default_rng(2).random((2, 10)) + + ax = s.plot(yerr=err, xerr=err) + + result = np.vstack([i.vertices[:, 1] for i in ax.collections[1].get_paths()]) + expected = (err.T * np.array([-1, 1])) + s.to_numpy().reshape(-1, 1) + tm.assert_numpy_array_equal(result, expected) + + def test_errorbar_asymmetrical_error(self): + # GH9536 + s = Series(np.arange(10), name="x") + msg = ( + "Asymmetrical error bars should be provided " + f"with the shape \\(2, {len(s)}\\)" + ) + with pytest.raises(ValueError, match=msg): + s.plot(yerr=np.random.default_rng(2).random((2, 11))) + + @pytest.mark.slow + @pytest.mark.parametrize("kind", ["line", "bar"]) + @pytest.mark.parametrize( + "yerr", + [ + Series(np.abs(np.random.default_rng(2).standard_normal(10))), + np.abs(np.random.default_rng(2).standard_normal(10)), + list(np.abs(np.random.default_rng(2).standard_normal(10))), + DataFrame( + np.abs(np.random.default_rng(2).standard_normal((10, 2))), + columns=["x", "y"], + ), + ], + ) + def test_errorbar_plot(self, kind, yerr): + s = Series(np.arange(10), name="x") + ax = _check_plot_works(s.plot, yerr=yerr, kind=kind) + _check_has_errorbars(ax, xerr=0, yerr=1) + + @pytest.mark.slow + def test_errorbar_plot_yerr_0(self): + s = Series(np.arange(10), name="x") + s_err = np.abs(np.random.default_rng(2).standard_normal(10)) + ax = _check_plot_works(s.plot, xerr=s_err) + _check_has_errorbars(ax, xerr=1, yerr=0) + + @pytest.mark.slow + @pytest.mark.parametrize( + "yerr", + [ + Series(np.abs(np.random.default_rng(2).standard_normal(12))), + DataFrame( + np.abs(np.random.default_rng(2).standard_normal((12, 2))), + columns=["x", "y"], + ), + ], + ) + def test_errorbar_plot_ts(self, yerr): + # test time series plotting + ix = date_range("1/1/2000", "1/1/2001", freq="ME") + ts = Series(np.arange(12), index=ix, name="x") + yerr.index = ix + + ax = _check_plot_works(ts.plot, yerr=yerr) + _check_has_errorbars(ax, xerr=0, yerr=1) + + @pytest.mark.slow + def test_errorbar_plot_invalid_yerr_shape(self): + s = Series(np.arange(10), name="x") + # check incorrect lengths and types + with tm.external_error_raised(ValueError): + s.plot(yerr=np.arange(11)) + + @pytest.mark.slow + def test_errorbar_plot_invalid_yerr(self): + s = Series(np.arange(10), name="x") + s_err = ["zzz"] * 10 + with tm.external_error_raised(TypeError): + s.plot(yerr=s_err) + + @pytest.mark.slow + def test_table_true(self, series): + _check_plot_works(series.plot, table=True) + + @pytest.mark.slow + def test_table_self(self, series): + _check_plot_works(series.plot, table=series) + + @pytest.mark.slow + def test_series_grid_settings(self): + # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792 + pytest.importorskip("scipy") + _check_grid_settings( + Series([1, 2, 3]), + plotting.PlotAccessor._series_kinds + plotting.PlotAccessor._common_kinds, + ) + + @pytest.mark.parametrize("c", ["r", "red", "green", "#FF0000"]) + def test_standard_colors(self, c): + result = get_standard_colors(1, color=c) + assert result == [c] + + result = get_standard_colors(1, color=[c]) + assert result == [c] + + result = get_standard_colors(3, color=c) + assert result == [c] * 3 + + result = get_standard_colors(3, color=[c]) + assert result == [c] * 3 + + def test_standard_colors_all(self): + # multiple colors like mediumaquamarine + for c in mpl.colors.cnames: + result = get_standard_colors(num_colors=1, color=c) + assert result == [c] + + result = get_standard_colors(num_colors=1, color=[c]) + assert result == [c] + + result = get_standard_colors(num_colors=3, color=c) + assert result == [c] * 3 + + result = get_standard_colors(num_colors=3, color=[c]) + assert result == [c] * 3 + + # single letter colors like k + for c in mpl.colors.ColorConverter.colors: + result = get_standard_colors(num_colors=1, color=c) + assert result == [c] + + result = get_standard_colors(num_colors=1, color=[c]) + assert result == [c] + + result = get_standard_colors(num_colors=3, color=c) + assert result == [c] * 3 + + result = get_standard_colors(num_colors=3, color=[c]) + assert result == [c] * 3 + + def test_series_plot_color_kwargs(self): + # GH1890 + _, ax = mpl.pyplot.subplots() + ax = Series(np.arange(12) + 1).plot(color="green", ax=ax) + _check_colors(ax.get_lines(), linecolors=["green"]) + + def test_time_series_plot_color_kwargs(self): + # #1890 + _, ax = mpl.pyplot.subplots() + ax = Series(np.arange(12) + 1, index=date_range("1/1/2000", periods=12)).plot( + color="green", ax=ax + ) + _check_colors(ax.get_lines(), linecolors=["green"]) + + def test_time_series_plot_color_with_empty_kwargs(self): + def_colors = _unpack_cycler(mpl.rcParams) + index = date_range("1/1/2000", periods=12) + s = Series(np.arange(1, 13), index=index) + + ncolors = 3 + + _, ax = mpl.pyplot.subplots() + for i in range(ncolors): + ax = s.plot(ax=ax) + _check_colors(ax.get_lines(), linecolors=def_colors[:ncolors]) + + def test_xticklabels(self): + # GH11529 + s = Series(np.arange(10), index=[f"P{i:02d}" for i in range(10)]) + _, ax = mpl.pyplot.subplots() + ax = s.plot(xticks=[0, 3, 5, 9], ax=ax) + exp = [f"P{i:02d}" for i in [0, 3, 5, 9]] + _check_text_labels(ax.get_xticklabels(), exp) + + def test_xtick_barPlot(self): + # GH28172 + s = Series(range(10), index=[f"P{i:02d}" for i in range(10)]) + ax = s.plot.bar(xticks=range(0, 11, 2)) + exp = np.array(list(range(0, 11, 2))) + tm.assert_numpy_array_equal(exp, ax.get_xticks()) + + def test_custom_business_day_freq(self): + # GH7222 + s = Series( + range(100, 121), + index=pd.bdate_range( + start="2014-05-01", + end="2014-06-01", + freq=CustomBusinessDay(holidays=["2014-05-26"]), + ), + ) + + _check_plot_works(s.plot) + + @pytest.mark.xfail( + reason="GH#24426, see also " + "github.com/pandas-dev/pandas/commit/" + "ef1bd69fa42bbed5d09dd17f08c44fc8bfc2b685#r61470674" + ) + def test_plot_accessor_updates_on_inplace(self): + ser = Series([1, 2, 3, 4]) + _, ax = mpl.pyplot.subplots() + ax = ser.plot(ax=ax) + before = ax.xaxis.get_ticklocs() + + ser.drop([0, 1], inplace=True) + _, ax = mpl.pyplot.subplots() + after = ax.xaxis.get_ticklocs() + tm.assert_numpy_array_equal(before, after) + + @pytest.mark.parametrize("kind", ["line", "area"]) + def test_plot_xlim_for_series(self, kind): + # test if xlim is also correctly plotted in Series for line and area + # GH 27686 + s = Series([2, 3]) + _, ax = mpl.pyplot.subplots() + s.plot(kind=kind, ax=ax) + xlims = ax.get_xlim() + + assert xlims[0] < 0 + assert xlims[1] > 1 + + def test_plot_no_rows(self): + # GH 27758 + df = Series(dtype=int) + assert df.empty + ax = df.plot() + assert len(ax.get_lines()) == 1 + line = ax.get_lines()[0] + assert len(line.get_xdata()) == 0 + assert len(line.get_ydata()) == 0 + + def test_plot_no_numeric_data(self): + df = Series(["a", "b", "c"]) + with pytest.raises(TypeError, match="no numeric data to plot"): + df.plot() + + @pytest.mark.parametrize( + "data, index", + [ + ([1, 2, 3, 4], [3, 2, 1, 0]), + ([10, 50, 20, 30], [1910, 1920, 1980, 1950]), + ], + ) + def test_plot_order(self, data, index): + # GH38865 Verify plot order of a Series + ser = Series(data=data, index=index) + ax = ser.plot(kind="bar") + + expected = ser.tolist() + result = [ + patch.get_bbox().ymax + for patch in sorted(ax.patches, key=lambda patch: patch.get_bbox().xmax) + ] + assert expected == result + + def test_style_single_ok(self): + s = Series([1, 2]) + ax = s.plot(style="s", color="C3") + assert ax.lines[0].get_color() == "C3" + + @pytest.mark.parametrize( + "index_name, old_label, new_label", + [(None, "", "new"), ("old", "old", "new"), (None, "", "")], + ) + @pytest.mark.parametrize("kind", ["line", "area", "bar", "barh", "hist"]) + def test_xlabel_ylabel_series(self, kind, index_name, old_label, new_label): + # GH 9093 + ser = Series([1, 2, 3, 4]) + ser.index.name = index_name + + # default is the ylabel is not shown and xlabel is index name (reverse for barh) + ax = ser.plot(kind=kind) + if kind == "barh": + assert ax.get_xlabel() == "" + assert ax.get_ylabel() == old_label + elif kind == "hist": + assert ax.get_xlabel() == "" + assert ax.get_ylabel() == "Frequency" + else: + assert ax.get_ylabel() == "" + assert ax.get_xlabel() == old_label + + # old xlabel will be overridden and assigned ylabel will be used as ylabel + ax = ser.plot(kind=kind, ylabel=new_label, xlabel=new_label) + assert ax.get_ylabel() == new_label + assert ax.get_xlabel() == new_label + + @pytest.mark.parametrize( + "index", + [ + pd.timedelta_range(start=0, periods=2, freq="D"), + [pd.Timedelta(days=1), pd.Timedelta(days=2)], + ], + ) + def test_timedelta_index(self, index): + # GH37454 + xlims = (3, 1) + ax = Series([1, 2], index=index).plot(xlim=(xlims)) + assert ax.get_xlim() == (3, 1) + + def test_series_none_color(self): + # GH51953 + series = Series([1, 2, 3]) + ax = series.plot(color=None) + expected = _unpack_cycler(mpl.pyplot.rcParams)[:1] + _check_colors(ax.get_lines(), linecolors=expected) + + @pytest.mark.slow + def test_plot_no_warning(self, ts): + # GH 55138 + # TODO(3.0): this can be removed once Period[B] deprecation is enforced + with tm.assert_produces_warning(False): + _ = ts.plot() + + def test_secondary_y_subplot_axis_labels(self): + # GH#14102 + s1 = Series([5, 7, 6, 8, 7], index=[1, 2, 3, 4, 5]) + s2 = Series([6, 4, 5, 3, 4], index=[1, 2, 3, 4, 5]) + + ax = plt.subplot(2, 1, 1) + s1.plot(ax=ax) + s2.plot(ax=ax, secondary_y=True) + ax2 = plt.subplot(2, 1, 2) + s1.plot(ax=ax2) + assert len(ax.xaxis.get_minor_ticks()) == 0 + assert len(ax.get_xticklabels()) > 0 + + def test_bar_line_plot(self): + """ + Test that bar and line plots with the same x values are superposed + and that the x limits are set such that the plots are visible. + """ + # GH61161 + index = period_range("2023", periods=3, freq="Y") + years = set(index.year.astype(str)) + s = Series([1, 2, 3], index=index) + ax = plt.subplot() + s.plot(kind="bar", ax=ax) + bar_xticks = [ + label for label in ax.get_xticklabels() if label.get_text() in years + ] + s.plot(kind="line", ax=ax, color="r") + line_xticks = [ + label for label in ax.get_xticklabels() if label.get_text() in years + ] + assert len(bar_xticks) == len(index) + assert bar_xticks == line_xticks + x_limits = ax.get_xlim() + assert x_limits[0] <= bar_xticks[0].get_position()[0] + assert x_limits[1] >= bar_xticks[-1].get_position()[0] diff --git a/pandas/tests/plotting/test_style.py b/pandas/tests/plotting/test_style.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c89e0a7893f501b2f8f7f0916ca1a008201b7d --- /dev/null +++ b/pandas/tests/plotting/test_style.py @@ -0,0 +1,149 @@ +import pytest + +from pandas import Series + +mpl = pytest.importorskip("matplotlib") +plt = pytest.importorskip("matplotlib.pyplot") +from pandas.plotting._matplotlib.style import get_standard_colors + + +class TestGetStandardColors: + @pytest.mark.parametrize( + "num_colors, expected", + [ + (3, ["red", "green", "blue"]), + (5, ["red", "green", "blue", "red", "green"]), + (7, ["red", "green", "blue", "red", "green", "blue", "red"]), + (2, ["red", "green"]), + (1, ["red"]), + ], + ) + def test_default_colors_named_from_prop_cycle(self, num_colors, expected): + mpl_params = { + "axes.prop_cycle": plt.cycler(color=["red", "green", "blue"]), + } + with mpl.rc_context(rc=mpl_params): + result = get_standard_colors(num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected", + [ + (1, ["b"]), + (3, ["b", "g", "r"]), + (4, ["b", "g", "r", "y"]), + (5, ["b", "g", "r", "y", "b"]), + (7, ["b", "g", "r", "y", "b", "g", "r"]), + ], + ) + def test_default_colors_named_from_prop_cycle_string(self, num_colors, expected): + mpl_params = { + "axes.prop_cycle": plt.cycler(color="bgry"), + } + with mpl.rc_context(rc=mpl_params): + result = get_standard_colors(num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected_name", + [ + (1, ["C0"]), + (3, ["C0", "C1", "C2"]), + ( + 12, + [ + "C0", + "C1", + "C2", + "C3", + "C4", + "C5", + "C6", + "C7", + "C8", + "C9", + "C0", + "C1", + ], + ), + ], + ) + def test_default_colors_named_undefined_prop_cycle(self, num_colors, expected_name): + with mpl.rc_context(rc={}): + expected = [mpl.colors.to_hex(x) for x in expected_name] + result = get_standard_colors(num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected", + [ + (1, ["red", "green", (0.1, 0.2, 0.3)]), + (2, ["red", "green", (0.1, 0.2, 0.3)]), + (3, ["red", "green", (0.1, 0.2, 0.3)]), + (4, ["red", "green", (0.1, 0.2, 0.3), "red"]), + ], + ) + def test_user_input_color_sequence(self, num_colors, expected): + color = ["red", "green", (0.1, 0.2, 0.3)] + result = get_standard_colors(color=color, num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected", + [ + (1, ["r", "g", "b", "k"]), + (2, ["r", "g", "b", "k"]), + (3, ["r", "g", "b", "k"]), + (4, ["r", "g", "b", "k"]), + (5, ["r", "g", "b", "k", "r"]), + (6, ["r", "g", "b", "k", "r", "g"]), + ], + ) + def test_user_input_color_string(self, num_colors, expected): + color = "rgbk" + result = get_standard_colors(color=color, num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected", + [ + (1, [(0.1, 0.2, 0.3)]), + (2, [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3)]), + (3, [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3), (0.1, 0.2, 0.3)]), + ], + ) + def test_user_input_color_floats(self, num_colors, expected): + color = (0.1, 0.2, 0.3) + result = get_standard_colors(color=color, num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "color, num_colors, expected", + [ + ("Crimson", 1, ["Crimson"]), + ("DodgerBlue", 2, ["DodgerBlue", "DodgerBlue"]), + ("firebrick", 3, ["firebrick", "firebrick", "firebrick"]), + ], + ) + def test_user_input_named_color_string(self, color, num_colors, expected): + result = get_standard_colors(color=color, num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize("color", ["", [], (), Series([], dtype="object")]) + def test_empty_color_raises(self, color): + with pytest.raises(ValueError, match="Invalid color argument"): + get_standard_colors(color=color, num_colors=1) + + @pytest.mark.parametrize( + "color", + [ + "bad_color", + ("red", "green", "bad_color"), + (0.1,), + (0.1, 0.2), + (0.1, 0.2, 0.3, 0.4, 0.5), # must be either 3 or 4 floats + ], + ) + def test_bad_color_raises(self, color): + with pytest.raises(ValueError, match="Invalid color"): + get_standard_colors(color=color, num_colors=5) diff --git a/pandas/tests/resample/__init__.py b/pandas/tests/resample/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/resample/conftest.py b/pandas/tests/resample/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..6c45ece5d8fb963d32945877b4af6521ea7b9450 --- /dev/null +++ b/pandas/tests/resample/conftest.py @@ -0,0 +1,33 @@ +import pytest + +# The various methods we support +downsample_methods = [ + "min", + "max", + "first", + "last", + "sum", + "mean", + "sem", + "median", + "prod", + "var", + "std", + "ohlc", + "quantile", +] +upsample_methods = ["count", "size"] +series_methods = ["nunique"] +resample_methods = downsample_methods + upsample_methods + series_methods + + +@pytest.fixture(params=downsample_methods) +def downsample_method(request): + """Fixture for parametrization of Grouper downsample methods.""" + return request.param + + +@pytest.fixture(params=resample_methods) +def resample_method(request): + """Fixture for parametrization of Grouper resample methods.""" + return request.param diff --git a/pandas/tests/resample/test_base.py b/pandas/tests/resample/test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..359ad72bd67f32d472d502550dfb28285b33a4af --- /dev/null +++ b/pandas/tests/resample/test_base.py @@ -0,0 +1,554 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning + +from pandas.core.dtypes.common import is_extension_array_dtype + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + MultiIndex, + NaT, + PeriodIndex, + Series, + TimedeltaIndex, +) +import pandas._testing as tm +from pandas.core.groupby.groupby import DataError +from pandas.core.groupby.grouper import Grouper +from pandas.core.indexes.datetimes import date_range +from pandas.core.indexes.period import period_range +from pandas.core.indexes.timedeltas import timedelta_range +from pandas.core.resample import _asfreq_compat + + +@pytest.fixture( + params=[ + "linear", + "time", + "index", + "values", + "nearest", + "zero", + "slinear", + "quadratic", + "cubic", + "barycentric", + "krogh", + "from_derivatives", + "piecewise_polynomial", + "pchip", + "akima", + ], +) +def all_1d_no_arg_interpolation_methods(request): + return request.param + + +@pytest.mark.parametrize("freq", ["2D", "1h"]) +@pytest.mark.parametrize( + "index", + [ + timedelta_range("1 day", "10 day", freq="D"), + date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D"), + ], +) +def test_asfreq(frame_or_series, index, freq): + obj = frame_or_series(range(len(index)), index=index) + idx_range = date_range if isinstance(index, DatetimeIndex) else timedelta_range + + result = obj.resample(freq).asfreq() + new_index = idx_range(obj.index[0], obj.index[-1], freq=freq) + expected = obj.reindex(new_index) + tm.assert_almost_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + timedelta_range("1 day", "10 day", freq="D"), + date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D"), + ], +) +def test_asfreq_fill_value(index): + # test for fill value during resampling, issue 3715 + + ser = Series(range(len(index)), index=index, name="a") + idx_range = date_range if isinstance(index, DatetimeIndex) else timedelta_range + + result = ser.resample("1h").asfreq() + new_index = idx_range(ser.index[0], ser.index[-1], freq="1h") + expected = ser.reindex(new_index) + tm.assert_series_equal(result, expected) + + # Explicit cast to float to avoid implicit cast when setting None + frame = ser.astype("float").to_frame("value") + frame.iloc[1] = None + result = frame.resample("1h").asfreq(fill_value=4.0) + new_index = idx_range(frame.index[0], frame.index[-1], freq="1h") + expected = frame.reindex(new_index, fill_value=4.0) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + timedelta_range("1 day", "3 day", freq="D"), + date_range(datetime(2005, 1, 1), datetime(2005, 1, 3), freq="D"), + period_range(datetime(2005, 1, 1), datetime(2005, 1, 3), freq="D"), + ], +) +def test_resample_interpolate(index): + # GH#12925 + df = DataFrame(range(len(index)), index=index) + result = df.resample("1min").asfreq().interpolate() + expected = df.resample("1min").interpolate() + tm.assert_frame_equal(result, expected) + + +def test_resample_interpolate_inplace_deprecated(): + # GH#58690 + dti = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + + df = DataFrame(range(len(dti)), index=dti) + rs = df.resample("1min") + msg = "The 'inplace' keyword in DatetimeIndexResampler.interpolate" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + rs.interpolate(inplace=False) + + msg2 = "Cannot interpolate inplace on a resampled object" + with pytest.raises(ValueError, match=msg2): + with tm.assert_produces_warning(Pandas4Warning, match=msg): + rs.interpolate(inplace=True) + + +def test_resample_interpolate_regular_sampling_off_grid( + all_1d_no_arg_interpolation_methods, +): + pytest.importorskip("scipy") + # GH#21351 + index = date_range("2000-01-01 00:01:00", periods=5, freq="2h") + ser = Series(np.arange(5.0), index) + + method = all_1d_no_arg_interpolation_methods + result = ser.resample("1h").interpolate(method) + + if method == "linear": + values = np.repeat(np.arange(0.0, 4.0), 2) + np.tile([1 / 3, 2 / 3], 4) + elif method == "nearest": + values = np.repeat(np.arange(0.0, 5.0), 2)[1:-1] + elif method == "zero": + values = np.repeat(np.arange(0.0, 4.0), 2) + else: + values = 0.491667 + np.arange(0.0, 4.0, 0.5) + values = np.insert(values, 0, np.nan) + index = date_range("2000-01-01 00:00:00", periods=9, freq="1h") + expected = Series(values, index=index) + tm.assert_series_equal(result, expected) + + +def test_resample_interpolate_irregular_sampling(all_1d_no_arg_interpolation_methods): + pytest.importorskip("scipy") + # GH#21351 + ser = Series( + np.linspace(0.0, 1.0, 5), + index=DatetimeIndex( + [ + "2000-01-01 00:00:03", + "2000-01-01 00:00:22", + "2000-01-01 00:00:24", + "2000-01-01 00:00:31", + "2000-01-01 00:00:39", + ] + ), + ) + + # Resample to 5 second sampling and interpolate with the given method + ser_resampled = ser.resample("5s").interpolate(all_1d_no_arg_interpolation_methods) + + # Check that none of the resampled values are NaN, except the first one + # which lies 3 seconds before the first actual data point + assert np.isnan(ser_resampled.iloc[0]) + assert not ser_resampled.iloc[1:].isna().any() + + +def test_raises_on_non_datetimelike_index(): + # this is a non datetimelike index + xp = DataFrame() + msg = ( + "Only valid with DatetimeIndex, TimedeltaIndex or PeriodIndex, " + "but got an instance of 'RangeIndex'" + ) + with pytest.raises(TypeError, match=msg): + xp.resample("YE") + + +@pytest.mark.parametrize( + "index", + [ + PeriodIndex([], freq="D", name="a"), + DatetimeIndex([], name="a"), + TimedeltaIndex([], name="a"), + ], +) +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_resample_empty_series(freq, index, resample_method): + # GH12771 & GH12868 + + ser = Series(index=index, dtype=float) + if freq == "ME" and isinstance(ser.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + ser.resample(freq) + return + elif freq == "ME" and isinstance(ser.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + rs = ser.resample(freq) + result = getattr(rs, resample_method)() + + if resample_method == "ohlc": + expected = DataFrame( + [], index=ser.index[:0], columns=["open", "high", "low", "close"] + ) + expected.index = _asfreq_compat(ser.index, freq) + tm.assert_frame_equal(result, expected, check_dtype=False) + else: + expected = ser.copy() + expected.index = _asfreq_compat(ser.index, freq) + tm.assert_series_equal(result, expected, check_dtype=False) + + tm.assert_index_equal(result.index, expected.index) + assert result.index.freq == expected.index.freq + + +@pytest.mark.parametrize("min_count", [0, 1]) +def test_resample_empty_sum_string(string_dtype_no_object, min_count): + # https://github.com/pandas-dev/pandas/issues/60229 + dtype = string_dtype_no_object + ser = Series( + pd.NA, + index=DatetimeIndex( + [ + "2000-01-01 00:00:00", + "2000-01-01 00:00:10", + "2000-01-01 00:00:20", + "2000-01-01 00:00:30", + ] + ), + dtype=dtype, + ) + rs = ser.resample("20s") + result = rs.sum(min_count=min_count) + + value = "" if min_count == 0 else pd.NA + index = date_range(start="2000-01-01", freq="20s", periods=2, unit="us") + expected = Series(value, index=index, dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "freq", + [ + pytest.param("ME", marks=pytest.mark.xfail(reason="Don't know why this fails")), + "D", + "h", + ], +) +def test_resample_nat_index_series(freq, resample_method): + # GH39227 + + ser = Series(range(5), index=PeriodIndex([NaT] * 5, freq=freq)) + + rs = ser.resample(freq) + result = getattr(rs, resample_method)() + + if resample_method == "ohlc": + expected = DataFrame( + [], index=ser.index[:0], columns=["open", "high", "low", "close"] + ) + tm.assert_frame_equal(result, expected, check_dtype=False) + else: + expected = ser[:0].copy() + tm.assert_series_equal(result, expected, check_dtype=False) + tm.assert_index_equal(result.index, expected.index) + assert result.index.freq == expected.index.freq + + +@pytest.mark.parametrize( + "index", + [ + PeriodIndex([], freq="D", name="a"), + DatetimeIndex([], name="a"), + TimedeltaIndex([], name="a"), + ], +) +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +@pytest.mark.parametrize("resample_method", ["count", "size"]) +def test_resample_count_empty_series(freq, index, resample_method): + # GH28427 + ser = Series(index=index) + if freq == "ME" and isinstance(ser.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + ser.resample(freq) + return + elif freq == "ME" and isinstance(ser.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + rs = ser.resample(freq) + + result = getattr(rs, resample_method)() + + index = _asfreq_compat(ser.index, freq) + + expected = Series([], dtype="int64", index=index, name=ser.name) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "index", [DatetimeIndex([]), TimedeltaIndex([]), PeriodIndex([], freq="D")] +) +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_resample_empty_dataframe(index, freq, resample_method): + # GH13212 + df = DataFrame(index=index) + # count retains dimensions too + if freq == "ME" and isinstance(df.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + df.resample(freq, group_keys=False) + return + elif freq == "ME" and isinstance(df.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + rs = df.resample(freq, group_keys=False) + result = getattr(rs, resample_method)() + if resample_method == "ohlc": + # TODO: no tests with len(df.columns) > 0 + mi = MultiIndex.from_product([df.columns, ["open", "high", "low", "close"]]) + expected = DataFrame([], index=df.index[:0], columns=mi, dtype=np.float64) + expected.index = _asfreq_compat(df.index, freq) + + elif resample_method != "size": + expected = df.copy() + else: + # GH14962 + expected = Series([], dtype=np.int64) + + expected.index = _asfreq_compat(df.index, freq) + + tm.assert_index_equal(result.index, expected.index) + assert result.index.freq == expected.index.freq + tm.assert_almost_equal(result, expected) + + # test size for GH13212 (currently stays as df) + + +@pytest.mark.parametrize( + "index", [DatetimeIndex([]), TimedeltaIndex([]), PeriodIndex([], freq="D")] +) +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_resample_count_empty_dataframe(freq, index): + # GH28427 + empty_frame_dti = DataFrame(index=index, columns=Index(["a"], dtype=object)) + + if freq == "ME" and isinstance(empty_frame_dti.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + empty_frame_dti.resample(freq) + return + elif freq == "ME" and isinstance(empty_frame_dti.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + result = empty_frame_dti.resample(freq).count() + + index = _asfreq_compat(empty_frame_dti.index, freq) + + expected = DataFrame(dtype="int64", index=index, columns=Index(["a"], dtype=object)) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "index", [DatetimeIndex([]), TimedeltaIndex([]), PeriodIndex([], freq="D")] +) +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_resample_size_empty_dataframe(freq, index): + # GH28427 + + empty_frame_dti = DataFrame(index=index, columns=Index(["a"], dtype=object)) + + if freq == "ME" and isinstance(empty_frame_dti.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + empty_frame_dti.resample(freq) + return + elif freq == "ME" and isinstance(empty_frame_dti.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + result = empty_frame_dti.resample(freq).size() + + index = _asfreq_compat(empty_frame_dti.index, freq) + + expected = Series([], dtype="int64", index=index) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("index", [DatetimeIndex([]), TimedeltaIndex([])]) +@pytest.mark.parametrize("freq", ["D", "h"]) +@pytest.mark.parametrize( + "method", ["ffill", "bfill", "nearest", "asfreq", "interpolate", "mean"] +) +def test_resample_apply_empty_dataframe(index, freq, method): + # GH#55572 + empty_frame_dti = DataFrame(index=index) + + rs = empty_frame_dti.resample(freq) + result = rs.apply(getattr(rs, method)) + + expected_index = _asfreq_compat(empty_frame_dti.index, freq) + expected = DataFrame([], index=expected_index) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + PeriodIndex([], freq="M", name="a"), + DatetimeIndex([], name="a"), + TimedeltaIndex([], name="a"), + ], +) +@pytest.mark.parametrize("dtype", [float, int, object, "datetime64[ns]"]) +def test_resample_empty_dtypes(index, dtype, resample_method): + # Empty series were sometimes causing a segfault (for the functions + # with Cython bounds-checking disabled) or an IndexError. We just run + # them to ensure they no longer do. (GH #10228) + empty_series_dti = Series([], index, dtype) + rs = empty_series_dti.resample("D", group_keys=False) + try: + getattr(rs, resample_method)() + except DataError: + # Ignore these since some combinations are invalid + # (ex: doing mean with dtype of np.object_) + pass + + +@pytest.mark.parametrize( + "index", + [ + PeriodIndex([], freq="D", name="a"), + DatetimeIndex([], name="a"), + TimedeltaIndex([], name="a"), + ], +) +@pytest.mark.parametrize("freq", ["ME", "D", "h"]) +def test_apply_to_empty_series(index, freq): + # GH 14313 + ser = Series(index=index) + + if freq == "ME" and isinstance(ser.index, TimedeltaIndex): + msg = ( + "Resampling on a TimedeltaIndex requires fixed-duration `freq`, " + "e.g. '24h' or '3D', not " + ) + with pytest.raises(ValueError, match=msg): + ser.resample(freq) + return + elif freq == "ME" and isinstance(ser.index, PeriodIndex): + # index is PeriodIndex, so convert to corresponding Period freq + freq = "M" + result = ser.resample(freq, group_keys=False).apply(lambda x: 1) + expected = ser.resample(freq).apply("sum") + + tm.assert_series_equal(result, expected, check_dtype=False) + + +@pytest.mark.parametrize( + "index", + [ + timedelta_range("1 day", "10 day", freq="D"), + date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D"), + period_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D"), + ], +) +def test_resampler_is_iterable(index): + # GH 15314 + series = Series(range(len(index)), index=index) + freq = "h" + tg = Grouper(freq=freq, convention="start") + grouped = series.groupby(tg) + resampled = series.resample(freq) + for (rk, rv), (gk, gv) in zip(resampled, grouped): + assert rk == gk + tm.assert_series_equal(rv, gv) + + +@pytest.mark.parametrize( + "index", + [ + timedelta_range("1 day", "10 day", freq="D"), + date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D"), + period_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D"), + ], +) +def test_resample_quantile(index): + # GH 15023 + ser = Series(range(len(index)), index=index) + q = 0.75 + freq = "h" + + result = ser.resample(freq).quantile(q) + expected = ser.resample(freq).agg(lambda x: x.quantile(q)).rename(ser.name) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("how", ["first", "last"]) +def test_first_last_skipna(any_real_nullable_dtype, skipna, how): + # GH#57019 + if is_extension_array_dtype(any_real_nullable_dtype): + na_value = Series(dtype=any_real_nullable_dtype).dtype.na_value + else: + na_value = np.nan + df = DataFrame( + { + "a": [2, 1, 1, 2], + "b": [na_value, 3.0, na_value, 4.0], + "c": [na_value, 3.0, na_value, 4.0], + }, + index=date_range("2020-01-01", periods=4, freq="D", unit="ns"), + dtype=any_real_nullable_dtype, + ) + rs = df.resample("ME") + method = getattr(rs, how) + result = method(skipna=skipna) + + ts = pd.to_datetime("2020-01-31").as_unit("ns") + gb = df.groupby(df.shape[0] * [ts]) + expected = getattr(gb, how)(skipna=skipna) + expected.index.freq = "ME" + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/resample/test_datetime_index.py b/pandas/tests/resample/test_datetime_index.py new file mode 100644 index 0000000000000000000000000000000000000000..6867b6cf9927142888a066717a15278fbfee9013 --- /dev/null +++ b/pandas/tests/resample/test_datetime_index.py @@ -0,0 +1,2190 @@ +from datetime import datetime +from functools import partial +import zoneinfo + +import numpy as np +import pytest + +from pandas._libs import lib +from pandas._libs.tslibs import Day +from pandas._typing import DatetimeNaTType +from pandas.compat import is_platform_windows +from pandas.compat.pyarrow import pa_version_under22p0 +from pandas.errors import Pandas4Warning +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + Timedelta, + Timestamp, + isna, + notna, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouper +from pandas.core.indexes.datetimes import date_range +from pandas.core.indexes.period import ( + Period, + period_range, +) +from pandas.core.resample import ( + DatetimeIndex, + _get_timestamp_range_edges, +) + +from pandas.tseries import offsets +from pandas.tseries.frequencies import to_offset +from pandas.tseries.offsets import Minute + + +@pytest.fixture +def simple_date_range_series(): + """ + Series with date range index and random data for test purposes. + """ + + def _simple_date_range_series(start, end, freq="D"): + rng = date_range(start, end, freq=freq) + return Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + return _simple_date_range_series + + +def test_custom_grouper(unit): + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="Min") + dti = index.as_unit(unit) + s = Series(np.array([1] * len(dti)), index=dti, dtype="int64") + + b = Grouper(freq=Minute(5)) + g = s.groupby(b) + + # check all cython functions work + g.ohlc() # doesn't use _cython_agg_general + funcs = ["sum", "mean", "prod", "min", "max", "var"] + for f in funcs: + g._cython_agg_general(f, alt=None, numeric_only=True) + + b = Grouper(freq=Minute(5), closed="right", label="right") + g = s.groupby(b) + # check all cython functions work + g.ohlc() # doesn't use _cython_agg_general + funcs = ["sum", "mean", "prod", "min", "max", "var"] + for f in funcs: + g._cython_agg_general(f, alt=None, numeric_only=True) + + assert g.ngroups == 2593 + assert notna(g.mean()).all() + + # construct expected val + arr = [1] + [5] * 2592 + idx = dti[0:-1:5] + idx = idx.append(dti[-1:]) + idx = DatetimeIndex(idx, freq="5min").as_unit(unit) + expect = Series(arr, index=idx) + + # GH2763 - return input dtype if we can + result = g.agg("sum") + tm.assert_series_equal(result, expect) + + +def test_custom_grouper_df(unit): + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + b = Grouper(freq=Minute(5), closed="right", label="right") + dti = index.as_unit(unit) + df = DataFrame( + np.random.default_rng(2).random((len(dti), 10)), index=dti, dtype="float64" + ) + r = df.groupby(b).agg("sum") + + assert len(r.columns) == 10 + assert len(r.index) == 2593 + + +@pytest.mark.parametrize( + "closed, expected", + [ + ( + "right", + lambda s: Series( + [s.iloc[0], s[1:6].mean(), s[6:11].mean(), s[11:].mean()], + index=date_range("1/1/2000", periods=4, freq="5min", name="index"), + ), + ), + ( + "left", + lambda s: Series( + [s[:5].mean(), s[5:10].mean(), s[10:].mean()], + index=date_range( + "1/1/2000 00:05", periods=3, freq="5min", name="index" + ), + ), + ), + ], +) +def test_resample_basic(closed, expected, unit): + index = date_range("1/1/2000 00:00:00", "1/1/2000 00:13:00", freq="Min") + s = Series(range(len(index)), index=index) + s.index.name = "index" + s.index = s.index.as_unit(unit) + expected = expected(s) + expected.index = expected.index.as_unit(unit) + result = s.resample("5min", closed=closed, label="right").mean() + tm.assert_series_equal(result, expected) + + +def test_resample_integerarray(unit): + # GH 25580, resample on IntegerArray + ts = Series( + range(9), + index=date_range("1/1/2000", periods=9, freq="min").as_unit(unit), + dtype="Int64", + ) + result = ts.resample("3min").sum() + expected = Series( + [3, 12, 21], + index=date_range("1/1/2000", periods=3, freq="3min").as_unit(unit), + dtype="Int64", + ) + tm.assert_series_equal(result, expected) + + result = ts.resample("3min").mean() + expected = Series( + [1, 4, 7], + index=date_range("1/1/2000", periods=3, freq="3min").as_unit(unit), + dtype="Float64", + ) + tm.assert_series_equal(result, expected) + + +def test_resample_basic_grouper(unit): + index = date_range("1/1/2000 00:00:00", "1/1/2000 00:13:00", freq="Min") + s = Series(range(len(index)), index=index) + s.index.name = "index" + s.index = s.index.as_unit(unit) + result = s.resample("5Min").last() + grouper = Grouper(freq=Minute(5), closed="left", label="left") + expected = s.groupby(grouper).agg(lambda x: x.iloc[-1]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "keyword,value", + [("label", "righttt"), ("closed", "righttt"), ("convention", "starttt")], +) +def test_resample_string_kwargs(keyword, value, unit): + # see gh-19303 + # Check that wrong keyword argument strings raise an error + index = date_range("1/1/2000 00:00:00", "1/1/2000 00:13:00", freq="Min") + series = Series(range(len(index)), index=index) + series.index.name = "index" + series.index = series.index.as_unit(unit) + msg = f"Unsupported value {value} for `{keyword}`" + with pytest.raises(ValueError, match=msg): + series.resample("5min", **({keyword: value})) + + +def test_resample_how(downsample_method, unit): + if downsample_method == "ohlc": + pytest.skip("covered by test_resample_how_ohlc") + index = date_range("1/1/2000 00:00:00", "1/1/2000 00:13:00", freq="Min") + s = Series(range(len(index)), index=index) + s.index.name = "index" + s.index = s.index.as_unit(unit) + grouplist = np.ones_like(s) + grouplist[0] = 0 + grouplist[1:6] = 1 + grouplist[6:11] = 2 + grouplist[11:] = 3 + expected = s.groupby(grouplist).agg(downsample_method) + expected.index = date_range( + "1/1/2000", periods=4, freq="5min", name="index" + ).as_unit(unit) + + result = getattr( + s.resample("5min", closed="right", label="right"), downsample_method + )() + tm.assert_series_equal(result, expected) + + +def test_resample_how_ohlc(unit): + index = date_range("1/1/2000 00:00:00", "1/1/2000 00:13:00", freq="Min") + s = Series(range(len(index)), index=index) + s.index.name = "index" + s.index = s.index.as_unit(unit) + grouplist = np.ones_like(s) + grouplist[0] = 0 + grouplist[1:6] = 1 + grouplist[6:11] = 2 + grouplist[11:] = 3 + + def _ohlc(group): + if isna(group).all(): + return np.repeat(np.nan, 4) + return [group.iloc[0], group.max(), group.min(), group.iloc[-1]] + + expected = DataFrame( + s.groupby(grouplist).agg(_ohlc).values.tolist(), + index=date_range("1/1/2000", periods=4, freq="5min", name="index").as_unit( + unit + ), + columns=["open", "high", "low", "close"], + ) + + result = s.resample("5min", closed="right", label="right").ohlc() + tm.assert_frame_equal(result, expected) + + +def test_resample_how_callables(unit): + # GH#7929 + data = np.arange(5, dtype=np.int64) + msg = "'d' is deprecated and will be removed in a future version." + with tm.assert_produces_warning(Pandas4Warning, match=msg): + ind = date_range(start="2014-01-01", periods=len(data), freq="d").as_unit(unit) + df = DataFrame({"A": data, "B": data}, index=ind) + + def fn(x, a=1): + return str(type(x)) + + class FnClass: + def __call__(self, x): + return str(type(x)) + + df_standard = df.resample("ME").apply(fn) + df_lambda = df.resample("ME").apply(lambda x: str(type(x))) + df_partial = df.resample("ME").apply(partial(fn)) + df_partial2 = df.resample("ME").apply(partial(fn, a=2)) + df_class = df.resample("ME").apply(FnClass()) + + tm.assert_frame_equal(df_standard, df_lambda) + tm.assert_frame_equal(df_standard, df_partial) + tm.assert_frame_equal(df_standard, df_partial2) + tm.assert_frame_equal(df_standard, df_class) + + +def test_resample_rounding(unit): + # GH 8371 + # odd results when rounding is needed + + ts = [ + "2014-11-08 00:00:01", + "2014-11-08 00:00:02", + "2014-11-08 00:00:02", + "2014-11-08 00:00:03", + "2014-11-08 00:00:07", + "2014-11-08 00:00:07", + "2014-11-08 00:00:08", + "2014-11-08 00:00:08", + "2014-11-08 00:00:08", + "2014-11-08 00:00:09", + "2014-11-08 00:00:10", + "2014-11-08 00:00:11", + "2014-11-08 00:00:11", + "2014-11-08 00:00:13", + "2014-11-08 00:00:14", + "2014-11-08 00:00:15", + "2014-11-08 00:00:17", + "2014-11-08 00:00:20", + "2014-11-08 00:00:21", + ] + df = DataFrame({"value": [1] * 19}, index=pd.to_datetime(ts)) + df.index = df.index.as_unit(unit) + + result = df.resample("6s").sum() + expected = DataFrame( + {"value": [4, 9, 4, 2]}, + index=date_range("2014-11-08", freq="6s", periods=4).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + result = df.resample("7s").sum() + expected = DataFrame( + {"value": [4, 10, 4, 1]}, + index=date_range("2014-11-08", freq="7s", periods=4).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + result = df.resample("11s").sum() + expected = DataFrame( + {"value": [11, 8]}, + index=date_range("2014-11-08", freq="11s", periods=2).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + result = df.resample("13s").sum() + expected = DataFrame( + {"value": [13, 6]}, + index=date_range("2014-11-08", freq="13s", periods=2).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + result = df.resample("17s").sum() + expected = DataFrame( + {"value": [16, 3]}, + index=date_range("2014-11-08", freq="17s", periods=2).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_basic_from_daily(unit): + # from daily + dti = date_range( + start=datetime(2005, 1, 1), end=datetime(2005, 1, 10), freq="D", name="index" + ).as_unit(unit) + + s = Series(np.random.default_rng(2).random(len(dti)), dti) + + # to weekly + msg = "'w-sun' is deprecated and will be removed in a future version." + with tm.assert_produces_warning(Pandas4Warning, match=msg): + result = s.resample("w-sun").last() + + assert len(result) == 3 + assert (result.index.dayofweek == [6, 6, 6]).all() + assert result.iloc[0] == s["1/2/2005"] + assert result.iloc[1] == s["1/9/2005"] + assert result.iloc[2] == s.iloc[-1] + + result = s.resample("W-MON").last() + assert len(result) == 2 + assert (result.index.dayofweek == [0, 0]).all() + assert result.iloc[0] == s["1/3/2005"] + assert result.iloc[1] == s["1/10/2005"] + + result = s.resample("W-TUE").last() + assert len(result) == 2 + assert (result.index.dayofweek == [1, 1]).all() + assert result.iloc[0] == s["1/4/2005"] + assert result.iloc[1] == s["1/10/2005"] + + result = s.resample("W-WED").last() + assert len(result) == 2 + assert (result.index.dayofweek == [2, 2]).all() + assert result.iloc[0] == s["1/5/2005"] + assert result.iloc[1] == s["1/10/2005"] + + result = s.resample("W-THU").last() + assert len(result) == 2 + assert (result.index.dayofweek == [3, 3]).all() + assert result.iloc[0] == s["1/6/2005"] + assert result.iloc[1] == s["1/10/2005"] + + result = s.resample("W-FRI").last() + assert len(result) == 2 + assert (result.index.dayofweek == [4, 4]).all() + assert result.iloc[0] == s["1/7/2005"] + assert result.iloc[1] == s["1/10/2005"] + + # to biz day + result = s.resample("B").last() + assert len(result) == 7 + assert (result.index.dayofweek == [4, 0, 1, 2, 3, 4, 0]).all() + + assert result.iloc[0] == s["1/2/2005"] + assert result.iloc[1] == s["1/3/2005"] + assert result.iloc[5] == s["1/9/2005"] + assert result.index.name == "index" + + +def test_resample_upsampling_picked_but_not_correct(unit): + # Test for issue #3020 + dates = date_range("01-Jan-2014", "05-Jan-2014", freq="D").as_unit(unit) + series = Series(1, index=dates) + + result = series.resample("D").mean() + assert result.index[0] == dates[0] + + # GH 5955 + # incorrect deciding to upsample when the axis frequency matches the + # resample frequency + + s = Series( + np.arange(1.0, 6), index=[datetime(1975, 1, i, 12, 0) for i in range(1, 6)] + ) + s.index = s.index.as_unit(unit) + expected = Series( + np.arange(1.0, 6), + index=date_range("19750101", periods=5, freq="D").as_unit(unit), + ) + + result = s.resample("D").count() + tm.assert_series_equal(result, Series(1, index=expected.index)) + + result1 = s.resample("D").sum() + result2 = s.resample("D").mean() + tm.assert_series_equal(result1, expected) + tm.assert_series_equal(result2, expected) + + +@pytest.mark.parametrize("f", ["sum", "mean", "prod", "min", "max", "var"]) +def test_resample_frame_basic_cy_funcs(f, unit): + df = DataFrame( + np.random.default_rng(2).standard_normal((50, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=50, freq="B"), + ) + df.index = df.index.as_unit(unit) + + b = Grouper(freq="ME") + g = df.groupby(b) + + # check all cython functions work + g._cython_agg_general(f, alt=None, numeric_only=True) + + +@pytest.mark.parametrize("freq", ["YE", "ME"]) +def test_resample_frame_basic_M_A(freq, unit): + df = DataFrame( + np.random.default_rng(2).standard_normal((50, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=50, freq="B"), + ) + df.index = df.index.as_unit(unit) + result = df.resample(freq).mean() + tm.assert_series_equal(result["A"], df["A"].resample(freq).mean()) + + +def test_resample_upsample(unit): + # from daily + dti = date_range( + start=datetime(2005, 1, 1), end=datetime(2005, 1, 10), freq="D", name="index" + ).as_unit(unit) + + s = Series(np.random.default_rng(2).random(len(dti)), dti) + + # to minutely, by padding + result = s.resample("Min").ffill() + assert len(result) == 12961 + assert result.iloc[0] == s.iloc[0] + assert result.iloc[-1] == s.iloc[-1] + + assert result.index.name == "index" + + +def test_resample_how_method(unit): + # GH9915 + s = Series( + [11, 22], + index=[ + Timestamp("2015-03-31 21:48:52.672000"), + Timestamp("2015-03-31 21:49:52.739000"), + ], + ) + s.index = s.index.as_unit(unit) + expected = Series( + [11, np.nan, np.nan, np.nan, np.nan, np.nan, 22], + index=DatetimeIndex( + [ + Timestamp("2015-03-31 21:48:50"), + Timestamp("2015-03-31 21:49:00"), + Timestamp("2015-03-31 21:49:10"), + Timestamp("2015-03-31 21:49:20"), + Timestamp("2015-03-31 21:49:30"), + Timestamp("2015-03-31 21:49:40"), + Timestamp("2015-03-31 21:49:50"), + ], + freq="10s", + ), + ) + expected.index = expected.index.as_unit(unit) + tm.assert_series_equal(s.resample("10s").mean(), expected) + + +def test_resample_extra_index_point(unit): + # GH#9756 + index = date_range(start="20150101", end="20150331", freq="BME").as_unit(unit) + expected = DataFrame({"A": Series([21, 41, 63], index=index)}) + + index = date_range(start="20150101", end="20150331", freq="B").as_unit(unit) + df = DataFrame({"A": Series(range(len(index)), index=index)}, dtype="int64") + result = df.resample("BME").last() + tm.assert_frame_equal(result, expected) + + +def test_upsample_with_limit(unit): + rng = date_range("1/1/2000", periods=3, freq="5min").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + + result = ts.resample("min").ffill(limit=2) + expected = ts.reindex(result.index, method="ffill", limit=2) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("freq", ["1D", "10h", "5Min", "10s"]) +@pytest.mark.parametrize("rule", ["YE", "3ME", "15D", "30h", "15Min", "30s"]) +def test_nearest_upsample_with_limit(tz_aware_fixture, freq, rule, unit): + # GH 33939 + rng = date_range("1/1/2000", periods=3, freq=freq, tz=tz_aware_fixture).as_unit( + unit + ) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + + result = ts.resample(rule).nearest(limit=2) + expected = ts.reindex(result.index, method="nearest", limit=2) + tm.assert_series_equal(result, expected) + + +def test_resample_ohlc(unit): + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 2), freq="Min") + s = Series(range(len(index)), index=index) + s.index.name = "index" + s.index = s.index.as_unit(unit) + + grouper = Grouper(freq=Minute(5)) + expect = s.groupby(grouper).agg(lambda x: x.iloc[-1]) + result = s.resample("5Min").ohlc() + + assert len(result) == len(expect) + assert len(result.columns) == 4 + + xs = result.iloc[-2] + assert xs["open"] == s.iloc[-6] + assert xs["high"] == s[-6:-1].max() + assert xs["low"] == s[-6:-1].min() + assert xs["close"] == s.iloc[-2] + + xs = result.iloc[0] + assert xs["open"] == s.iloc[0] + assert xs["high"] == s[:5].max() + assert xs["low"] == s[:5].min() + assert xs["close"] == s.iloc[4] + + +def test_resample_ohlc_result(unit): + # GH 12332 + index = date_range("1-1-2000", "2-15-2000", freq="h").as_unit(unit) + index = index.union(date_range("4-15-2000", "5-15-2000", freq="h").as_unit(unit)) + s = Series(range(len(index)), index=index) + + a = s.loc[:"4-15-2000"].resample("30min").ohlc() + assert isinstance(a, DataFrame) + + b = s.loc[:"4-14-2000"].resample("30min").ohlc() + assert isinstance(b, DataFrame) + + +def test_resample_ohlc_result_odd_period(unit): + # GH12348 + # raising on odd period + rng = date_range("2013-12-30", "2014-01-07").as_unit(unit) + index = rng.drop( + [ + Timestamp("2014-01-01"), + Timestamp("2013-12-31"), + Timestamp("2014-01-04"), + Timestamp("2014-01-05"), + ] + ) + df = DataFrame(data=np.arange(len(index)), index=index) + result = df.resample("B").mean() + expected = df.reindex(index=date_range(rng[0], rng[-1], freq="B").as_unit(unit)) + tm.assert_frame_equal(result, expected) + + +def test_resample_ohlc_dataframe(unit): + df = ( + DataFrame( + { + "PRICE": { + Timestamp("2011-01-06 10:59:05", tz=None): 24990, + Timestamp("2011-01-06 12:43:33", tz=None): 25499, + Timestamp("2011-01-06 12:54:09", tz=None): 25499, + }, + "VOLUME": { + Timestamp("2011-01-06 10:59:05", tz=None): 1500000000, + Timestamp("2011-01-06 12:43:33", tz=None): 5000000000, + Timestamp("2011-01-06 12:54:09", tz=None): 100000000, + }, + } + ) + ).reindex(["VOLUME", "PRICE"], axis=1) + df.index = df.index.as_unit(unit) + df.columns.name = "Cols" + res = df.resample("h").ohlc() + exp = pd.concat( + [df["VOLUME"].resample("h").ohlc(), df["PRICE"].resample("h").ohlc()], + axis=1, + keys=df.columns, + ) + assert exp.columns.names[0] == "Cols" + tm.assert_frame_equal(exp, res) + + df.columns = [["a", "b"], ["c", "d"]] + res = df.resample("h").ohlc() + exp.columns = pd.MultiIndex.from_tuples( + [ + ("a", "c", "open"), + ("a", "c", "high"), + ("a", "c", "low"), + ("a", "c", "close"), + ("b", "d", "open"), + ("b", "d", "high"), + ("b", "d", "low"), + ("b", "d", "close"), + ] + ) + tm.assert_frame_equal(exp, res) + + # dupe columns fail atm + # df.columns = ['PRICE', 'PRICE'] + + +def test_resample_reresample(unit): + dti = date_range( + start=datetime(2005, 1, 1), end=datetime(2005, 1, 10), freq="D" + ).as_unit(unit) + s = Series(np.random.default_rng(2).random(len(dti)), dti) + bs = s.resample("B", closed="right", label="right").mean() + result = bs.resample("8h").mean() + assert len(result) == 25 + assert isinstance(result.index.freq, offsets.DateOffset) + assert result.index.freq == offsets.Hour(8) + + +@pytest.mark.parametrize( + "freq, expected_kwargs", + [ + ["YE-DEC", {"start": "1990", "end": "2000", "freq": "Y-DEC"}], + ["YE-JUN", {"start": "1990", "end": "2000", "freq": "Y-JUN"}], + ["ME", {"start": "1990-01", "end": "2000-01", "freq": "M"}], + ], +) +def test_resample_timestamp_to_period( + simple_date_range_series, freq, expected_kwargs, unit +): + ts = simple_date_range_series("1/1/1990", "1/1/2000") + ts.index = ts.index.as_unit(unit) + + result = ts.resample(freq).mean().to_period() + expected = ts.resample(freq).mean() + expected.index = period_range(**expected_kwargs) + tm.assert_series_equal(result, expected) + + +def test_ohlc_5min(unit): + def _ohlc(group): + if isna(group).all(): + return np.repeat(np.nan, 4) + return [group.iloc[0], group.max(), group.min(), group.iloc[-1]] + + rng = date_range("1/1/2000 00:00:00", "1/1/2000 5:59:50", freq="10s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + resampled = ts.resample("5min", closed="right", label="right").ohlc() + + assert (resampled.loc["1/1/2000 00:00"] == ts.iloc[0]).all() + + exp = _ohlc(ts[1:31]) + assert (resampled.loc["1/1/2000 00:05"] == exp).all() + + exp = _ohlc(ts["1/1/2000 5:55:01":]) + assert (resampled.loc["1/1/2000 6:00:00"] == exp).all() + + +def test_downsample_non_unique(unit): + rng = date_range("1/1/2000", "2/29/2000").as_unit(unit) + rng2 = rng.repeat(5).values + ts = Series(np.random.default_rng(2).standard_normal(len(rng2)), index=rng2) + + result = ts.resample("ME").mean() + + expected = ts.groupby(lambda x: x.month).mean() + assert len(result) == 2 + tm.assert_almost_equal(result.iloc[0], expected[1]) + tm.assert_almost_equal(result.iloc[1], expected[2]) + + +def test_asfreq_non_unique(unit): + # GH #1077 + rng = date_range("1/1/2000", "2/29/2000").as_unit(unit) + rng2 = rng.repeat(2).values + ts = Series(np.random.default_rng(2).standard_normal(len(rng2)), index=rng2) + + msg = "cannot reindex on an axis with duplicate labels" + with pytest.raises(ValueError, match=msg): + ts.asfreq("B") + + +@pytest.mark.parametrize("freq", ["min", "5min", "15min", "30min", "4h", "12h"]) +def test_resample_anchored_ticks(freq, unit): + # If a fixed delta (5 minute, 4 hour) evenly divides a day, we should + # "anchor" the origin at midnight so we get regular intervals rather + # than starting from the first timestamp which might start in the + # middle of a desired interval + + rng = date_range("1/1/2000 04:00:00", periods=86400, freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + ts[:2] = np.nan # so results are the same + result = ts[2:].resample(freq, closed="left", label="left").mean() + expected = ts.resample(freq, closed="left", label="left").mean() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("end", [1, 2]) +def test_resample_single_group(end, unit): + mysum = lambda x: x.sum() + + rng = date_range("2000-1-1", f"2000-{end}-10", freq="D").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + tm.assert_series_equal(ts.resample("ME").sum(), ts.resample("ME").apply(mysum)) + + +def test_resample_single_group_std(unit): + # GH 3849 + s = Series( + [30.1, 31.6], + index=[Timestamp("20070915 15:30:00"), Timestamp("20070915 15:40:00")], + ) + s.index = s.index.as_unit(unit) + expected = Series( + [0.75], index=DatetimeIndex([Timestamp("20070915")], freq="D").as_unit(unit) + ) + result = s.resample("D").apply(lambda x: np.std(x)) + tm.assert_series_equal(result, expected) + + +def test_resample_offset(unit): + # GH 31809 + + rng = date_range("1/1/2000 00:00:00", "1/1/2000 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + resampled = ts.resample("5min", offset="2min").mean() + exp_rng = date_range("12/31/1999 23:57:00", "1/1/2000 01:57", freq="5min").as_unit( + unit + ) + tm.assert_index_equal(resampled.index, exp_rng) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"origin": "1999-12-31 23:57:00"}, + {"origin": Timestamp("1970-01-01 00:02:00")}, + {"origin": "epoch", "offset": "2m"}, + # origin of '1999-31-12 12:02:00' should be equivalent for this case + {"origin": "1999-12-31 12:02:00"}, + {"offset": "-3m"}, + ], +) +def test_resample_origin(kwargs, unit): + # GH 31809 + rng = date_range("2000-01-01 00:00:00", "2000-01-01 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + exp_rng = date_range( + "1999-12-31 23:57:00", "2000-01-01 01:57", freq="5min" + ).as_unit(unit) + + resampled = ts.resample("5min", **kwargs).mean() + tm.assert_index_equal(resampled.index, exp_rng) + + +@pytest.mark.parametrize( + "origin", ["invalid_value", "epch", "startday", "startt", "2000-30-30", object()] +) +def test_resample_bad_origin(origin, unit): + rng = date_range("2000-01-01 00:00:00", "2000-01-01 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + msg = ( + "'origin' should be equal to 'epoch', 'start', 'start_day', " + "'end', 'end_day' or should be a Timestamp convertible type. Got " + f"'{origin}' instead." + ) + with pytest.raises(ValueError, match=msg): + ts.resample("5min", origin=origin) + + +@pytest.mark.parametrize("offset", ["invalid_value", "12dayys", "2000-30-30", object()]) +def test_resample_bad_offset(offset, unit): + rng = date_range("2000-01-01 00:00:00", "2000-01-01 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + msg = f"'offset' should be a Timedelta convertible type. Got '{offset}' instead." + with pytest.raises(ValueError, match=msg): + ts.resample("5min", offset=offset) + + +def test_resample_origin_prime_freq(unit): + # GH 31809 + start, end = "2000-10-01 23:30:00", "2000-10-02 00:30:00" + rng = date_range(start, end, freq="7min").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + exp_rng = date_range( + "2000-10-01 23:14:00", "2000-10-02 00:22:00", freq="17min" + ).as_unit(unit) + resampled = ts.resample("17min").mean() + tm.assert_index_equal(resampled.index, exp_rng) + resampled = ts.resample("17min", origin="start_day").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + exp_rng = date_range( + "2000-10-01 23:30:00", "2000-10-02 00:21:00", freq="17min" + ).as_unit(unit) + resampled = ts.resample("17min", origin="start").mean() + tm.assert_index_equal(resampled.index, exp_rng) + resampled = ts.resample("17min", offset="23h30min").mean() + tm.assert_index_equal(resampled.index, exp_rng) + resampled = ts.resample("17min", origin="start_day", offset="23h30min").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + exp_rng = date_range( + "2000-10-01 23:18:00", "2000-10-02 00:26:00", freq="17min" + ).as_unit(unit) + resampled = ts.resample("17min", origin="epoch").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + exp_rng = date_range( + "2000-10-01 23:24:00", "2000-10-02 00:15:00", freq="17min" + ).as_unit(unit) + resampled = ts.resample("17min", origin="2000-01-01").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + +def test_resample_origin_with_tz(unit): + # GH 31809 + msg = "The origin must have the same timezone as the index." + + tz = "Europe/Paris" + rng = date_range( + "2000-01-01 00:00:00", "2000-01-01 02:00", freq="s", tz=tz + ).as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + exp_rng = date_range( + "1999-12-31 23:57:00", "2000-01-01 01:57", freq="5min", tz=tz + ).as_unit(unit) + resampled = ts.resample("5min", origin="1999-12-31 23:57:00+00:00").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + # origin of '1999-31-12 12:02:00+03:00' should be equivalent for this case + resampled = ts.resample("5min", origin="1999-12-31 12:02:00+03:00").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + resampled = ts.resample("5min", origin="epoch", offset="2m").mean() + tm.assert_index_equal(resampled.index, exp_rng) + + with pytest.raises(ValueError, match=msg): + ts.resample("5min", origin="12/31/1999 23:57:00").mean() + + # if the series is not tz aware, origin should not be tz aware + rng = date_range("2000-01-01 00:00:00", "2000-01-01 02:00", freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + with pytest.raises(ValueError, match=msg): + ts.resample("5min", origin="12/31/1999 23:57:00+03:00").mean() + + +def test_resample_origin_epoch_with_tz_day_vs_24h(unit): + # GH 34474 + start, end = "2000-10-01 23:30:00+0500", "2000-12-02 00:30:00+0500" + rng = date_range(start, end, freq="7min").as_unit(unit) + random_values = np.random.default_rng(2).standard_normal(len(rng)) + ts_1 = Series(random_values, index=rng) + + result_1 = ts_1.resample("D").mean() + result_2 = ts_1.resample("24h", origin="epoch").mean() + tm.assert_series_equal(result_1, result_2, check_freq=False) + # GH#41943 check_freq=False bc Day and Hour(24) no longer compare as equal + + # check that we have the same behavior with epoch even if we are not timezone aware + ts_no_tz = ts_1.tz_localize(None) + result_3 = ts_no_tz.resample("D").mean() + result_4 = ts_no_tz.resample("24h", origin="epoch").mean() + tm.assert_series_equal(result_1, result_3.tz_localize(rng.tz), check_freq=False) + tm.assert_series_equal(result_1, result_4.tz_localize(rng.tz), check_freq=False) + + # check that we have the similar results with two different timezones (+2H and +5H) + start, end = "2000-10-01 23:30:00+0200", "2000-12-02 00:30:00+0200" + rng = date_range(start, end, freq="7min").as_unit(unit) + ts_2 = Series(random_values, index=rng) + result_5 = ts_2.resample("D").mean() + result_6 = ts_2.resample("24h", origin="epoch").mean() + tm.assert_series_equal(result_1.tz_localize(None), result_5.tz_localize(None)) + tm.assert_series_equal(result_1.tz_localize(None), result_6.tz_localize(None)) + + +def test_resample_origin_with_day_freq_on_dst(unit): + # GH 31809 + tz = "America/Chicago" + msg = "The '(origin|offset)' keyword does not take effect" + + def _create_series(values, timestamps, freq="D"): + return Series( + values, + index=DatetimeIndex( + [Timestamp(t, tz=tz) for t in timestamps], freq=freq, ambiguous=True + ).as_unit(unit), + ) + + # test classical behavior of origin in a DST context + start = Timestamp("2013-11-02", tz=tz) + end = Timestamp("2013-11-03 23:59", tz=tz) + rng = date_range(start, end, freq="1h").as_unit(unit) + ts = Series(np.ones(len(rng)), index=rng) + + expected = _create_series([24.0, 25.0], ["2013-11-02", "2013-11-03"]) + for origin in ["epoch", "start", "start_day", start, None]: + warn = RuntimeWarning if origin != "start_day" else None + with tm.assert_produces_warning(warn, match=msg): + result = ts.resample("D", origin=origin).sum() + tm.assert_series_equal(result, expected) + + # test complex behavior of origin/offset in a DST context + start = Timestamp("2013-11-03", tz=tz) + end = Timestamp("2013-11-03 23:59", tz=tz) + rng = date_range(start, end, freq="1h").as_unit(unit) + ts = Series(np.ones(len(rng)), index=rng) + + # GH#61985 changed this to behave like "B" rather than "24h" + expected_ts = ["2013-11-03 00:00-05:00"] + expected = _create_series([25.0], expected_ts) + with tm.assert_produces_warning(RuntimeWarning, match=msg): + result = ts.resample("D", origin="start", offset="-2h").sum() + tm.assert_series_equal(result, expected) + + expected_ts = ["2013-11-02 22:00-05:00", "2013-11-03 21:00-06:00"] + expected = _create_series([22.0, 3.0], expected_ts, freq="24h") + result = ts.resample("24h", origin="start", offset="-2h").sum() + tm.assert_series_equal(result, expected) + + # GH#61985 changed this to behave like "B" rather than "24h" + expected_ts = ["2013-11-03 00:00-05:00"] + expected = _create_series([25.0], expected_ts) + with tm.assert_produces_warning(RuntimeWarning, match=msg): + result = ts.resample("D", origin="start", offset="2h").sum() + tm.assert_series_equal(result, expected) + + expected_ts = ["2013-11-03 00:00-05:00"] + expected = _create_series([25.0], expected_ts) + with tm.assert_produces_warning(RuntimeWarning, match=msg): + result = ts.resample("D", origin="start", offset="-1h").sum() + tm.assert_series_equal(result, expected) + + expected_ts = ["2013-11-03 00:00-05:00"] + expected = _create_series([25.0], expected_ts) + with tm.assert_produces_warning(RuntimeWarning, match=msg): + result = ts.resample("D", origin="start", offset="1h").sum() + tm.assert_series_equal(result, expected) + + +def test_resample_dst_midnight_last_nonexistent(): + # GH 58380 + ts = Series( + 1, + date_range("2024-04-19", "2024-04-20", tz="Africa/Cairo", freq="15min"), + ) + + expected = Series([len(ts)], index=DatetimeIndex([ts.index[0]], freq="7D")) + + result = ts.resample("7D").sum() + tm.assert_series_equal(result, expected) + + +def test_resample_daily_anchored(unit): + rng = date_range("1/1/2000 0:00:00", periods=10000, freq="min").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + ts[:2] = np.nan # so results are the same + + result = ts[2:].resample("D", closed="left", label="left").mean() + expected = ts.resample("D", closed="left", label="left").mean() + tm.assert_series_equal(result, expected) + + +def test_resample_to_period_monthly_buglet(unit): + # GH #1259 + + rng = date_range("1/1/2000", "12/31/2000").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + result = ts.resample("ME").mean().to_period() + exp_index = period_range("Jan-2000", "Dec-2000", freq="M") + tm.assert_index_equal(result.index, exp_index) + + +def test_period_with_agg(): + # aggregate a period resampler with a lambda + s2 = Series( + np.random.default_rng(2).integers(0, 5, 50), + index=period_range("2012-01-01", freq="h", periods=50), + dtype="float64", + ) + + expected = s2.to_timestamp().resample("D").mean().to_period() + result = s2.resample("D").agg(lambda x: x.mean()) + tm.assert_series_equal(result, expected) + + +def test_resample_segfault(unit): + # GH 8573 + # segfaulting in older versions + all_wins_and_wagers = [ + (1, datetime(2013, 10, 1, 16, 20), 1, 0), + (2, datetime(2013, 10, 1, 16, 10), 1, 0), + (2, datetime(2013, 10, 1, 18, 15), 1, 0), + (2, datetime(2013, 10, 1, 16, 10, 31), 1, 0), + ] + + df = DataFrame.from_records( + all_wins_and_wagers, columns=("ID", "timestamp", "A", "B") + ).set_index("timestamp") + df.index = df.index.as_unit(unit) + result = df.groupby("ID").resample("5min").sum() + expected = df.groupby("ID").apply(lambda x: x.resample("5min").sum()) + tm.assert_frame_equal(result, expected) + + +def test_resample_dtype_preservation(unit): + # GH 12202 + # validation tests for dtype preservation + + df = DataFrame( + { + "date": date_range(start="2016-01-01", periods=4, freq="W").as_unit(unit), + "group": [1, 1, 2, 2], + "val": Series([5, 6, 7, 8], dtype="int32"), + } + ).set_index("date") + + result = df.resample("1D").ffill() + assert result.val.dtype == np.int32 + + result = df.groupby("group").resample("1D").ffill() + assert result.val.dtype == np.int32 + + +def test_resample_dtype_coercion(unit): + pytest.importorskip("scipy.interpolate") + + # GH 16361 + df = {"a": [1, 3, 1, 4]} + df = DataFrame(df, index=date_range("2017-01-01", "2017-01-04").as_unit(unit)) + + expected = df.astype("float64").resample("h").mean()["a"].interpolate("cubic") + + result = df.resample("h")["a"].mean().interpolate("cubic") + tm.assert_series_equal(result, expected) + + result = df.resample("h").mean()["a"].interpolate("cubic") + tm.assert_series_equal(result, expected) + + +def test_weekly_resample_buglet(unit): + # #1327 + rng = date_range("1/1/2000", freq="B", periods=20).as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + resampled = ts.resample("W").mean() + expected = ts.resample("W-SUN").mean() + tm.assert_series_equal(resampled, expected) + + +def test_monthly_resample_error(unit): + # #1451 + dates = date_range("4/16/2012 20:00", periods=5000, freq="h").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(dates)), index=dates) + # it works! + ts.resample("ME") + + +def test_nanosecond_resample_error(): + # GH 12307 - Values falls after last bin when + # Resampling using pd.tseries.offsets.Nano as period + start = 1443707890427 + exp_start = 1443707890400 + indx = date_range(start=pd.to_datetime(start), periods=10, freq="100ns") + ts = Series(range(len(indx)), index=indx) + r = ts.resample(pd.tseries.offsets.Nano(100)) + result = r.agg("mean") + + exp_indx = date_range(start=pd.to_datetime(exp_start), periods=10, freq="100ns") + exp = Series(range(len(exp_indx)), index=exp_indx, dtype=float) + + tm.assert_series_equal(result, exp) + + +def test_resample_anchored_intraday(unit): + # #1471, #1458 + + rng = date_range("1/1/2012", "4/1/2012", freq="100min").as_unit(unit) + df = DataFrame(rng.month, index=rng) + + result = df.resample("ME").mean() + expected = df.resample("ME").mean().to_period() + expected = expected.to_timestamp(how="end") + expected.index += Timedelta(1, unit="us") - Timedelta(1, unit="D") + expected.index = expected.index.as_unit(unit)._with_freq("infer") + assert expected.index.freq == "ME" + tm.assert_frame_equal(result, expected) + + result = df.resample("ME", closed="left").mean() + exp = df.shift(1, freq="D").resample("ME").mean().to_period() + exp = exp.to_timestamp(how="end") + + exp.index = exp.index + Timedelta(1, unit="us") - Timedelta(1, unit="D") + exp.index = exp.index.as_unit(unit)._with_freq("infer") + assert exp.index.freq == "ME" + tm.assert_frame_equal(result, exp) + + +def test_resample_anchored_intraday2(unit): + rng = date_range("1/1/2012", "4/1/2012", freq="100min").as_unit(unit) + df = DataFrame(rng.month, index=rng) + + result = df.resample("QE").mean() + expected = df.resample("QE").mean().to_period() + expected = expected.to_timestamp(how="end") + expected.index += Timedelta(1, unit="us") - Timedelta(1, unit="D") + expected.index._data.freq = "QE" + expected.index._freq = lib.no_default + expected.index = expected.index.as_unit(unit) + tm.assert_frame_equal(result, expected) + + result = df.resample("QE", closed="left").mean() + expected = df.shift(1, freq="D").resample("QE").mean() + expected = expected.to_period() + expected = expected.to_timestamp(how="end") + expected.index += Timedelta(1, unit="us") - Timedelta(1, unit="D") + expected.index._data.freq = "QE" + expected.index._freq = lib.no_default + expected.index = expected.index.as_unit(unit) + tm.assert_frame_equal(result, expected) + + +def test_resample_anchored_intraday3(simple_date_range_series, unit): + ts = simple_date_range_series("2012-04-29 23:00", "2012-04-30 5:00", freq="h") + ts.index = ts.index.as_unit(unit) + resampled = ts.resample("ME").mean() + assert len(resampled) == 1 + + +@pytest.mark.parametrize("freq", ["MS", "BMS", "QS-MAR", "YS-DEC", "YS-JUN"]) +def test_resample_anchored_monthstart(simple_date_range_series, freq, unit): + ts = simple_date_range_series("1/1/2000", "12/31/2002") + ts.index = ts.index.as_unit(unit) + ts.resample(freq).mean() + + +@pytest.mark.parametrize("label, sec", [[None, 2.0], ["right", "4.2"]]) +def test_resample_anchored_multiday(label, sec): + # When resampling a range spanning multiple days, ensure that the + # start date gets used to determine the offset. Fixes issue where + # a one day period is not a multiple of the frequency. + # + # See: https://github.com/pandas-dev/pandas/issues/8683 + + index1 = date_range("2014-10-14 23:06:23.206", periods=3, freq="400ms") + index2 = date_range("2014-10-15 23:00:00", periods=2, freq="2200ms") + index = index1.union(index2) + + s = Series(np.random.default_rng(2).standard_normal(5), index=index) + + # Ensure left closing works + result = s.resample("2200ms", label=label).mean() + assert result.index[-1] == Timestamp(f"2014-10-15 23:00:{sec}00") + + +def test_corner_cases(unit): + # miscellaneous test coverage + + rng = date_range("1/1/2000", periods=12, freq="min").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + result = ts.resample("5min", closed="right", label="left").mean() + ex_index = date_range("1999-12-31 23:55", periods=4, freq="5min").as_unit(unit) + tm.assert_index_equal(result.index, ex_index) + + +def test_corner_cases_date(simple_date_range_series, unit): + # resample to periods + ts = simple_date_range_series("2000-04-28", "2000-04-30 11:00", freq="h") + ts.index = ts.index.as_unit(unit) + result = ts.resample("ME").mean().to_period() + assert len(result) == 1 + assert result.index[0] == Period("2000-04", freq="M") + + +def test_anchored_lowercase_buglet(unit): + dates = date_range("4/16/2012 20:00", periods=50000, freq="s").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(dates)), index=dates) + # it works! + msg = "'d' is deprecated and will be removed in a future version." + with tm.assert_produces_warning(Pandas4Warning, match=msg): + ts.resample("d").mean() + + +def test_upsample_apply_functions(unit): + # #1596 + rng = date_range("2012-06-12", periods=4, freq="h").as_unit(unit) + + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + result = ts.resample("20min").aggregate(["mean", "sum"]) + assert isinstance(result, DataFrame) + + +def test_resample_not_monotonic(unit): + rng = date_range("2012-06-12", periods=200, freq="h").as_unit(unit) + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + ts = ts.take(np.random.default_rng(2).permutation(len(ts))) + + result = ts.resample("D").sum() + exp = ts.sort_index().resample("D").sum() + tm.assert_series_equal(result, exp) + + +@pytest.mark.parametrize( + "dtype", + [ + "int64", + "int32", + "float64", + "float32", + ], +) +def test_resample_median_bug_1688(dtype, unit): + # GH#55958 + dti = DatetimeIndex( + [datetime(2012, 1, 1, 0, 0, 0), datetime(2012, 1, 1, 0, 5, 0)] + ).as_unit(unit) + df = DataFrame( + [1, 2], + index=dti, + dtype=dtype, + ) + + result = df.resample("min").apply(lambda x: x.mean()) + exp = df.asfreq("min") + tm.assert_frame_equal(result, exp) + + result = df.resample("min").median() + exp = df.asfreq("min") + tm.assert_frame_equal(result, exp) + + +def test_how_lambda_functions(simple_date_range_series, unit): + ts = simple_date_range_series("1/1/2000", "4/1/2000") + ts.index = ts.index.as_unit(unit) + + result = ts.resample("ME").apply(lambda x: x.mean()) + exp = ts.resample("ME").mean() + tm.assert_series_equal(result, exp) + + foo_exp = ts.resample("ME").mean() + foo_exp.name = "foo" + bar_exp = ts.resample("ME").std() + bar_exp.name = "bar" + + result = ts.resample("ME").apply([lambda x: x.mean(), lambda x: x.std(ddof=1)]) + result.columns = ["foo", "bar"] + tm.assert_series_equal(result["foo"], foo_exp) + tm.assert_series_equal(result["bar"], bar_exp) + + # this is a MI Series, so comparing the names of the results + # doesn't make sense + result = ts.resample("ME").aggregate( + {"foo": lambda x: x.mean(), "bar": lambda x: x.std(ddof=1)} + ) + tm.assert_series_equal(result["foo"], foo_exp, check_names=False) + tm.assert_series_equal(result["bar"], bar_exp, check_names=False) + + +def test_resample_unequal_times(unit): + # #1772 + start = datetime(1999, 3, 1, 5) + # end hour is less than start + end = datetime(2012, 7, 31, 4) + bad_ind = date_range(start, end, freq="30min").as_unit(unit) + df = DataFrame({"close": 1}, index=bad_ind) + + # it works! + df.resample("YS").sum() + + +def test_resample_consistency(unit): + # GH 6418 + # resample with bfill / limit / reindex consistency + + i30 = date_range("2002-02-02", periods=4, freq="30min").as_unit(unit) + s = Series(np.arange(4.0), index=i30) + s.iloc[2] = np.nan + + # Upsample by factor 3 with reindex() and resample() methods: + i10 = date_range(i30[0], i30[-1], freq="10min").as_unit(unit) + + s10 = s.reindex(index=i10, method="bfill") + s10_2 = s.reindex(index=i10, method="bfill", limit=2) + with tm.assert_produces_warning(Pandas4Warning): + rl = s.reindex_like(s10, method="bfill", limit=2) + r10_2 = s.resample("10Min").bfill(limit=2) + r10 = s.resample("10Min").bfill() + + # s10_2, r10, r10_2, rl should all be equal + tm.assert_series_equal(s10_2, r10) + tm.assert_series_equal(s10_2, r10_2) + tm.assert_series_equal(s10_2, rl) + + +dates1: list[DatetimeNaTType] = [ + datetime(2014, 10, 1), + datetime(2014, 9, 3), + datetime(2014, 11, 5), + datetime(2014, 9, 5), + datetime(2014, 10, 8), + datetime(2014, 7, 15), +] + +dates2: list[DatetimeNaTType] = [*dates1[:2], pd.NaT, *dates1[2:4], pd.NaT, *dates1[4:]] +dates3 = [pd.NaT, *dates1, pd.NaT] + + +@pytest.mark.parametrize("dates", [dates1, dates2, dates3]) +def test_resample_timegrouper(dates, unit): + # GH 7227 + dates = DatetimeIndex(dates).as_unit(unit) + df = DataFrame({"A": dates, "B": np.arange(len(dates))}) + result = df.set_index("A").resample("ME").count() + exp_idx = DatetimeIndex( + ["2014-07-31", "2014-08-31", "2014-09-30", "2014-10-31", "2014-11-30"], + freq="ME", + name="A", + ).as_unit(unit) + expected = DataFrame({"B": [1, 0, 2, 2, 1]}, index=exp_idx) + if df["A"].isna().any(): + expected.index = expected.index._with_freq(None) + tm.assert_frame_equal(result, expected) + + result = df.groupby(Grouper(freq="ME", key="A")).count() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dates", [dates1, dates2, dates3]) +def test_resample_timegrouper2(dates, unit): + dates = DatetimeIndex(dates).as_unit(unit) + + df = DataFrame({"A": dates, "B": np.arange(len(dates)), "C": np.arange(len(dates))}) + result = df.set_index("A").resample("ME").count() + + exp_idx = DatetimeIndex( + ["2014-07-31", "2014-08-31", "2014-09-30", "2014-10-31", "2014-11-30"], + freq="ME", + name="A", + ).as_unit(unit) + expected = DataFrame( + {"B": [1, 0, 2, 2, 1], "C": [1, 0, 2, 2, 1]}, + index=exp_idx, + columns=["B", "C"], + ) + if df["A"].isna().any(): + expected.index = expected.index._with_freq(None) + tm.assert_frame_equal(result, expected) + + result = df.groupby(Grouper(freq="ME", key="A")).count() + tm.assert_frame_equal(result, expected) + + +def test_resample_nunique(unit): + # GH 12352 + df = DataFrame( + { + "ID": { + Timestamp("2015-06-05 00:00:00"): "0010100903", + Timestamp("2015-06-08 00:00:00"): "0010150847", + }, + "DATE": { + Timestamp("2015-06-05 00:00:00"): "2015-06-05", + Timestamp("2015-06-08 00:00:00"): "2015-06-08", + }, + } + ) + df.index = df.index.as_unit(unit) + r = df.resample("D") + g = df.groupby(Grouper(freq="D")) + expected = df.groupby(Grouper(freq="D")).ID.apply(lambda x: x.nunique()) + assert expected.name == "ID" + + for t in [r, g]: + result = t.ID.nunique() + tm.assert_series_equal(result, expected) + + result = df.ID.resample("D").nunique() + tm.assert_series_equal(result, expected) + + result = df.ID.groupby(Grouper(freq="D")).nunique() + tm.assert_series_equal(result, expected) + + +def test_resample_nunique_preserves_column_level_names(unit): + # see gh-23222 + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 4)), + columns=Index(list("ABCD"), dtype=object), + index=date_range("2000-01-01", periods=5, freq="D"), + ).abs() + df.index = df.index.as_unit(unit) + df.columns = pd.MultiIndex.from_arrays( + [df.columns.tolist()] * 2, names=["lev0", "lev1"] + ) + result = df.resample("1h").nunique() + tm.assert_index_equal(df.columns, result.columns) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x.nunique(), + lambda x: x.agg(Series.nunique), + lambda x: x.agg("nunique"), + ], + ids=["nunique", "series_nunique", "nunique_str"], +) +def test_resample_nunique_with_date_gap(func, unit): + # GH 13453 + # Since all elements are unique, these should all be the same + index = date_range("1-1-2000", "2-15-2000", freq="h").as_unit(unit) + index2 = date_range("4-15-2000", "5-15-2000", freq="h").as_unit(unit) + index3 = index.append(index2) + s = Series(range(len(index3)), index=index3, dtype="int64") + r = s.resample("ME") + result = r.count() + expected = func(r) + tm.assert_series_equal(result, expected) + + +def test_resample_group_info(unit): + # GH10914 + + # use a fixed seed to always have the same uniques + n = 100 + k = 10 + prng = np.random.default_rng(2) + + dr = date_range(start="2015-08-27", periods=n // 10, freq="min").as_unit(unit) + ts = Series(prng.integers(0, n // k, n).astype("int64"), index=prng.choice(dr, n)) + + left = ts.resample("30min").nunique() + ix = date_range(start=ts.index.min(), end=ts.index.max(), freq="30min").as_unit( + unit + ) + + vals = ts.values + bins = np.searchsorted(ix.values, ts.index, side="right") + + sorter = np.lexsort((vals, bins)) + vals, bins = vals[sorter], bins[sorter] + + mask = np.r_[True, vals[1:] != vals[:-1]] + mask |= np.r_[True, bins[1:] != bins[:-1]] + + arr = np.bincount(bins[mask] - 1, minlength=len(ix)).astype("int64", copy=False) + right = Series(arr, index=ix) + + tm.assert_series_equal(left, right) + + +def test_resample_size(unit): + n = 10000 + dr = date_range("2015-09-19", periods=n, freq="min").as_unit(unit) + ts = Series( + np.random.default_rng(2).standard_normal(n), + index=np.random.default_rng(2).choice(dr, n), + ) + + left = ts.resample("7min").size() + ix = date_range(start=left.index.min(), end=ts.index.max(), freq="7min").as_unit( + unit + ) + + bins = np.searchsorted(ix.values, ts.index.values, side="right") + val = np.bincount(bins, minlength=len(ix) + 1)[1:].astype("int64", copy=False) + + right = Series(val, index=ix) + tm.assert_series_equal(left, right) + + +def test_resample_across_dst(): + # The test resamples a DatetimeIndex with values before and after a + # DST change + # Issue: 14682 + + # The DatetimeIndex we will start with + # (note that DST happens at 03:00+02:00 -> 02:00+01:00) + # 2016-10-30 02:23:00+02:00, 2016-10-30 02:23:00+01:00 + df1 = DataFrame([1477786980, 1477790580], columns=["ts"]) + dti1 = DatetimeIndex( + pd.to_datetime(df1.ts, unit="s") + .dt.tz_localize("UTC") + .dt.tz_convert("Europe/Madrid") + ) + + # The expected DatetimeIndex after resampling. + # 2016-10-30 02:00:00+02:00, 2016-10-30 02:00:00+01:00 + df2 = DataFrame([1477785600, 1477789200], columns=["ts"]) + dti2 = DatetimeIndex( + pd.to_datetime(df2.ts, unit="s") + .dt.tz_localize("UTC") + .dt.tz_convert("Europe/Madrid"), + freq="h", + ) + df = DataFrame([5, 5], index=dti1) + + result = df.resample(rule="h").sum() + expected = DataFrame([5, 5], index=dti2) + + tm.assert_frame_equal(result, expected) + + +def test_groupby_with_dst_time_change(unit): + # GH 24972 + index = ( + DatetimeIndex([1478064900001000000, 1480037118776792000], tz="UTC") + .tz_convert("America/Chicago") + .as_unit(unit) + ) + + df = DataFrame([1, 2], index=index) + result = df.groupby(Grouper(freq="1D")).last() + expected_index_values = date_range( + "2016-11-02", "2016-11-24", freq="D", tz="America/Chicago" + ).as_unit(unit) + + index = DatetimeIndex(expected_index_values) + expected = DataFrame([1.0] + ([np.nan] * 21) + [2.0], index=index) + tm.assert_frame_equal(result, expected) + + +def test_resample_dst_anchor(unit): + # 5172 + dti = DatetimeIndex([datetime(2012, 11, 4, 23)], tz="US/Eastern").as_unit(unit) + df = DataFrame([5], index=dti) + + dti = DatetimeIndex(df.index.normalize(), freq="D").as_unit(unit) + expected = DataFrame([5], index=dti) + tm.assert_frame_equal(df.resample(rule="D").sum(), expected) + df.resample(rule="MS").sum() + tm.assert_frame_equal( + df.resample(rule="MS").sum(), + DataFrame( + [5], + index=DatetimeIndex( + [datetime(2012, 11, 1)], tz="US/Eastern", freq="MS" + ).as_unit(unit), + ), + ) + + +def test_resample_dst_anchor2(unit): + dti = date_range( + "2013-09-30", "2013-11-02", freq="30Min", tz="Europe/Paris" + ).as_unit(unit) + values = range(dti.size) + df = DataFrame({"a": values, "b": values, "c": values}, index=dti, dtype="int64") + how = {"a": "min", "b": "max", "c": "count"} + + rs = df.resample("W-MON") + result = rs.agg(how)[["a", "b", "c"]] + expected = DataFrame( + { + "a": [0, 48, 384, 720, 1056, 1394], + "b": [47, 383, 719, 1055, 1393, 1586], + "c": [48, 336, 336, 336, 338, 193], + }, + index=date_range( + "9/30/2013", "11/4/2013", freq="W-MON", tz="Europe/Paris" + ).as_unit(unit), + ) + tm.assert_frame_equal( + result, + expected, + "W-MON Frequency", + ) + + rs2 = df.resample("2W-MON") + result2 = rs2.agg(how)[["a", "b", "c"]] + expected2 = DataFrame( + { + "a": [0, 48, 720, 1394], + "b": [47, 719, 1393, 1586], + "c": [48, 672, 674, 193], + }, + index=date_range( + "9/30/2013", "11/11/2013", freq="2W-MON", tz="Europe/Paris" + ).as_unit(unit), + ) + tm.assert_frame_equal( + result2, + expected2, + "2W-MON Frequency", + ) + + rs3 = df.resample("MS") + result3 = rs3.agg(how)[["a", "b", "c"]] + expected3 = DataFrame( + {"a": [0, 48, 1538], "b": [47, 1537, 1586], "c": [48, 1490, 49]}, + index=date_range("9/1/2013", "11/1/2013", freq="MS", tz="Europe/Paris").as_unit( + unit + ), + ) + tm.assert_frame_equal( + result3, + expected3, + "MS Frequency", + ) + + rs4 = df.resample("2MS") + result4 = rs4.agg(how)[["a", "b", "c"]] + expected4 = DataFrame( + {"a": [0, 1538], "b": [1537, 1586], "c": [1538, 49]}, + index=date_range( + "9/1/2013", "11/1/2013", freq="2MS", tz="Europe/Paris" + ).as_unit(unit), + ) + tm.assert_frame_equal( + result4, + expected4, + "2MS Frequency", + ) + + df_daily = df["10/26/2013":"10/29/2013"] + rs_d = df_daily.resample("D") + result_d = rs_d.agg({"a": "min", "b": "max", "c": "count"})[["a", "b", "c"]] + expected_d = DataFrame( + { + "a": [1248, 1296, 1346, 1394], + "b": [1295, 1345, 1393, 1441], + "c": [48, 50, 48, 48], + }, + index=date_range( + "10/26/2013", "10/29/2013", freq="D", tz="Europe/Paris" + ).as_unit(unit), + ) + tm.assert_frame_equal( + result_d, + expected_d, + "D Frequency", + ) + + +def test_downsample_across_dst(unit): + # GH 8531 + tz = zoneinfo.ZoneInfo("Europe/Berlin") + dt = datetime(2014, 10, 26) + dates = date_range(dt.astimezone(tz), periods=4, freq="2h").as_unit(unit) + result = Series(5, index=dates).resample("h").mean() + expected = Series( + [5.0, np.nan] * 3 + [5.0], + index=date_range(dt.astimezone(tz), periods=7, freq="h").as_unit(unit), + ) + tm.assert_series_equal(result, expected) + + +def test_downsample_across_dst_weekly(unit): + # GH 9119, GH 21459 + df = DataFrame( + index=DatetimeIndex( + ["2017-03-25", "2017-03-26", "2017-03-27", "2017-03-28", "2017-03-29"], + tz="Europe/Amsterdam", + ).as_unit(unit), + data=[11, 12, 13, 14, 15], + ) + result = df.resample("1W").sum() + expected = DataFrame( + [23, 42], + index=DatetimeIndex( + ["2017-03-26", "2017-04-02"], tz="Europe/Amsterdam", freq="W" + ).as_unit(unit), + ) + tm.assert_frame_equal(result, expected) + + +def test_downsample_across_dst_weekly_2(unit): + # GH 9119, GH 21459 + idx = date_range("2013-04-01", "2013-05-01", tz="Europe/London", freq="h").as_unit( + unit + ) + s = Series(index=idx, dtype=np.float64) + result = s.resample("W").mean() + expected = Series( + index=date_range("2013-04-07", freq="W", periods=5, tz="Europe/London").as_unit( + unit + ), + dtype=np.float64, + ) + tm.assert_series_equal(result, expected) + + +def test_downsample_dst_at_midnight(unit): + # GH 25758 + start = datetime(2018, 11, 3, 12) + end = datetime(2018, 11, 5, 12) + index = date_range(start, end, freq="1h").as_unit(unit) + index = index.tz_localize("UTC").tz_convert("America/Havana") + data = list(range(len(index))) + dataframe = DataFrame(data, index=index) + result = dataframe.groupby(Grouper(freq="1D")).mean() + + dti = date_range("2018-11-03", periods=3).tz_localize( + "America/Havana", ambiguous=True + ) + dti = DatetimeIndex(dti, freq="D").as_unit(unit) + expected = DataFrame([7.5, 28.0, 44.5], index=dti) + tm.assert_frame_equal(result, expected) + + +def test_resample_with_nat(unit): + # GH 13020 + index = DatetimeIndex( + [ + pd.NaT, + "1970-01-01 00:00:00", + pd.NaT, + "1970-01-01 00:00:01", + "1970-01-01 00:00:02", + ] + ).as_unit(unit) + frame = DataFrame([2, 3, 5, 7, 11], index=index) + + index_1s = DatetimeIndex( + ["1970-01-01 00:00:00", "1970-01-01 00:00:01", "1970-01-01 00:00:02"] + ).as_unit(unit) + frame_1s = DataFrame([3.0, 7.0, 11.0], index=index_1s) + tm.assert_frame_equal(frame.resample("1s").mean(), frame_1s) + + index_2s = DatetimeIndex(["1970-01-01 00:00:00", "1970-01-01 00:00:02"]).as_unit( + unit + ) + frame_2s = DataFrame([5.0, 11.0], index=index_2s) + tm.assert_frame_equal(frame.resample("2s").mean(), frame_2s) + + index_3s = DatetimeIndex(["1970-01-01 00:00:00"]).as_unit(unit) + frame_3s = DataFrame([7.0], index=index_3s) + tm.assert_frame_equal(frame.resample("3s").mean(), frame_3s) + + tm.assert_frame_equal(frame.resample("60s").mean(), frame_3s) + + +def test_resample_datetime_values(unit): + # GH 13119 + # check that datetime dtype is preserved when NaT values are + # introduced by the resampling + + dates = [datetime(2016, 1, 15), datetime(2016, 1, 19)] + df = DataFrame({"timestamp": dates}, index=dates) + df.index = df.index.as_unit(unit) + + exp = Series( + [datetime(2016, 1, 15), pd.NaT, datetime(2016, 1, 19)], + index=date_range("2016-01-15", periods=3, freq="2D").as_unit(unit), + name="timestamp", + ) + + res = df.resample("2D").first()["timestamp"] + tm.assert_series_equal(res, exp) + res = df["timestamp"].resample("2D").first() + tm.assert_series_equal(res, exp) + + +def test_resample_apply_with_additional_args(unit): + # GH 14615 + index = date_range("1/1/2000 00:00:00", "1/1/2000 00:13:00", freq="Min") + series = Series(range(len(index)), index=index) + series.index.name = "index" + + def f(data, add_arg): + return np.mean(data) * add_arg + + series.index = series.index.as_unit(unit) + + multiplier = 10 + result = series.resample("D").apply(f, multiplier) + expected = series.resample("D").mean().multiply(multiplier) + tm.assert_series_equal(result, expected) + + # Testing as kwarg + result = series.resample("D").apply(f, add_arg=multiplier) + expected = series.resample("D").mean().multiply(multiplier) + tm.assert_series_equal(result, expected) + + +def test_resample_apply_with_additional_args2(): + # Testing dataframe + def f(data, add_arg): + return np.mean(data) * add_arg + + multiplier = 10 + + df = DataFrame({"A": 1, "B": 2}, index=date_range("2017", periods=10)) + result = df.groupby("A").resample("D").agg(f, multiplier).astype(float) + expected = df.groupby("A").resample("D").mean().multiply(multiplier) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("k", [1, 2, 3]) +@pytest.mark.parametrize( + "n1, freq1, n2, freq2", + [ + (30, "s", 0.5, "Min"), + (60, "s", 1, "Min"), + (3600, "s", 1, "h"), + (60, "Min", 1, "h"), + (21600, "s", 0.25, "D"), + (86400, "s", 1, "D"), + (43200, "s", 0.5, "D"), + (1440, "Min", 1, "D"), + (12, "h", 0.5, "D"), + (24, "h", 1, "D"), + ], +) +def test_resample_equivalent_offsets(n1, freq1, n2, freq2, k, unit): + # GH 24127 + n1_ = n1 * k + n2_ = n2 * k + dti = date_range("1991-09-05", "1991-09-06", freq=freq1).as_unit(unit) + ser = Series(range(len(dti)), index=dti) + + result1 = ser.resample(str(n1_) + freq1).mean() + result2 = ser.resample(str(n2_) + freq2).mean() + if freq2 == "D" and isinstance(result2.index.freq, Day): + # GH#55502 Day is no longer a Tick so no longer compares as equivalent, + # but the actual values we expect should still match + result2.index.freq = to_offset(Timedelta(days=result2.index.freq.n)) + tm.assert_series_equal(result1, result2) + + +@pytest.mark.parametrize( + "first,last,freq,exp_first,exp_last", + [ + ("19910905", "19920406", "D", "19910905", "19920407"), + ("19910905 00:00", "19920406 06:00", "D", "19910905", "19920407"), + ("19910905 06:00", "19920406 06:00", "h", "19910905 06:00", "19920406 07:00"), + ("19910906", "19920406", "ME", "19910831", "19920430"), + ("19910831", "19920430", "ME", "19910831", "19920531"), + ("1991-08", "1992-04", "ME", "19910831", "19920531"), + ], +) +def test_get_timestamp_range_edges(first, last, freq, exp_first, exp_last, unit): + first = Period(first) + first = first.to_timestamp(first.freq).as_unit(unit) + last = Period(last) + last = last.to_timestamp(last.freq).as_unit(unit) + + exp_first = Timestamp(exp_first) + exp_last = Timestamp(exp_last) + + freq = pd.tseries.frequencies.to_offset(freq) + result = _get_timestamp_range_edges(first, last, freq, unit="ns") + expected = (exp_first, exp_last) + assert result == expected + + +@pytest.mark.parametrize("duplicates", [True, False]) +def test_resample_apply_product(duplicates, unit): + # GH 5586 + index = date_range(start="2012-01-31", freq="ME", periods=12).as_unit(unit) + + ts = Series(range(12), index=index) + df = DataFrame({"A": ts, "B": ts + 2}) + if duplicates: + df.columns = ["A", "A"] + + result = df.resample("QE").apply(np.prod) + expected = DataFrame( + np.array([[0, 24], [60, 210], [336, 720], [990, 1716]], dtype=np.int64), + index=DatetimeIndex( + ["2012-03-31", "2012-06-30", "2012-09-30", "2012-12-31"], freq="QE-DEC" + ).as_unit(unit), + columns=df.columns, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "first,last,freq_in,freq_out,exp_last", + [ + ( + "2020-03-28", + "2020-03-31", + "D", + "24h", + "2020-03-30 01:00", + ), # includes transition into DST + ( + "2020-03-28", + "2020-10-27", + "D", + "24h", + "2020-10-27 00:00", + ), # includes transition into and out of DST + ( + "2020-10-25", + "2020-10-27", + "D", + "24h", + "2020-10-26 23:00", + ), # includes transition out of DST + ( + "2020-03-28", + "2020-03-31", + "24h", + "D", + "2020-03-30 00:00", + ), # same as above, but from 24H to D + ("2020-03-28", "2020-10-27", "24h", "D", "2020-10-27 00:00"), + ("2020-10-25", "2020-10-27", "24h", "D", "2020-10-26 00:00"), + ], +) +def test_resample_calendar_day_with_dst( + first: str, last: str, freq_in: str, freq_out: str, exp_last: str, unit +): + # GH 35219 + ts = Series( + 1.0, date_range(first, last, freq=freq_in, tz="Europe/Amsterdam").as_unit(unit) + ) + result = ts.resample(freq_out).ffill() + expected = Series( + 1.0, + date_range(first, exp_last, freq=freq_out, tz="Europe/Amsterdam").as_unit(unit), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max", "first", "last"]) +def test_resample_aggregate_functions_min_count(func, unit): + # GH#37768 + index = date_range(start="2020", freq="ME", periods=3).as_unit(unit) + ser = Series([1, np.nan, np.nan], index) + result = getattr(ser.resample("QE"), func)(min_count=2) + expected = Series( + [np.nan], + index=DatetimeIndex(["2020-03-31"], freq="QE-DEC").as_unit(unit), + ) + tm.assert_series_equal(result, expected) + + +def test_resample_unsigned_int(any_unsigned_int_numpy_dtype, unit): + # gh-43329 + df = DataFrame( + index=date_range(start="2000-01-01", end="2000-01-03 23", freq="12h").as_unit( + unit + ), + columns=["x"], + data=[0, 1, 0] * 2, + dtype=any_unsigned_int_numpy_dtype, + ) + df = df.loc[(df.index < "2000-01-02") | (df.index > "2000-01-03"), :] + + result = df.resample("D").max() + + expected = DataFrame( + [1, np.nan, 0], + columns=["x"], + index=date_range(start="2000-01-01", end="2000-01-03 23", freq="D").as_unit( + unit + ), + ) + tm.assert_frame_equal(result, expected) + + +def test_long_rule_non_nano(): + # https://github.com/pandas-dev/pandas/issues/51024 + idx = date_range("0300-01-01", "2000-01-01", unit="s", freq="100YE") + ser = Series([1, 4, 2, 8, 5, 7, 1, 4, 2, 8, 5, 7, 1, 4, 2, 8, 5], index=idx) + result = ser.resample("200YE").mean() + expected_idx = DatetimeIndex( + np.array( + [ + "0300-12-31", + "0500-12-31", + "0700-12-31", + "0900-12-31", + "1100-12-31", + "1300-12-31", + "1500-12-31", + "1700-12-31", + "1900-12-31", + ] + ).astype("datetime64[s]"), + freq="200YE-DEC", + ) + expected = Series([1.0, 3.0, 6.5, 4.0, 3.0, 6.5, 4.0, 3.0, 6.5], index=expected_idx) + tm.assert_series_equal(result, expected) + + +def test_resample_empty_series_with_tz(): + # GH#53664 + df = DataFrame({"ts": [], "values": []}).astype( + {"ts": "datetime64[ns, Atlantic/Faroe]"} + ) + rs = df.resample("2MS", on="ts", closed="left", label="left") + result = rs["values"].sum() + + expected_idx = DatetimeIndex( + [], freq="2MS", name="ts", dtype="datetime64[ns, Atlantic/Faroe]" + ) + expected = Series([], index=expected_idx, name="values", dtype="float64") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("freq", ["2M", "2m", "2Q", "2Q-SEP", "2q-sep", "1Y", "2Y-MAR"]) +def test_resample_M_Q_Y_raises(freq): + msg = f"Invalid frequency: {freq}" + + s = Series(range(10), index=date_range("20130101", freq="D", periods=10)) + with pytest.raises(ValueError, match=msg): + s.resample(freq).mean() + + +@pytest.mark.parametrize("freq", ["2BM", "1bm", "1BQ", "2BQ-MAR", "2bq=-mar"]) +def test_resample_BM_BQ_raises(freq): + msg = f"Invalid frequency: {freq}" + + s = Series(range(10), index=date_range("20130101", freq="D", periods=10)) + with pytest.raises(ValueError, match=msg): + s.resample(freq).mean() + + +@pytest.mark.parametrize( + "freq,freq_depr,data", + [ + ("1W-SUN", "1w-sun", ["2013-01-06"]), + ("1D", "1d", ["2013-01-01"]), + ("1B", "1b", ["2013-01-01"]), + ("1C", "1c", ["2013-01-01"]), + ], +) +def test_resample_depr_lowercase_frequency(freq, freq_depr, data): + msg = f"'{freq_depr[1:]}' is deprecated and will be removed in a future version." + + s = Series(range(5), index=date_range("20130101", freq="h", periods=5, unit="ns")) + with tm.assert_produces_warning(Pandas4Warning, match=msg): + result = s.resample(freq_depr).mean() + + exp_dti = DatetimeIndex(data=data, dtype="datetime64[ns]", freq=freq) + expected = Series(2.0, index=exp_dti) + tm.assert_series_equal(result, expected, check_freq=False) + # GH#41943 check_freq=False bc 24H and D no longer compare as equal + + +def test_resample_ms_closed_right(unit): + # https://github.com/pandas-dev/pandas/issues/55271 + dti = date_range(start="2020-01-31", freq="1min", periods=6000, unit=unit) + df = DataFrame({"ts": dti}, index=dti) + grouped = df.resample("MS", closed="right") + result = grouped.last() + exp_dti = DatetimeIndex( + [datetime(2020, 1, 1), datetime(2020, 2, 1)], freq="MS" + ).as_unit(unit) + expected = DataFrame( + {"ts": [datetime(2020, 2, 1), datetime(2020, 2, 4, 3, 59)]}, + index=exp_dti, + ).astype(f"M8[{unit}]") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("freq", ["B", "C"]) +def test_resample_c_b_closed_right(freq: str, unit): + # https://github.com/pandas-dev/pandas/issues/55281 + dti = date_range(start="2020-01-31", freq="1min", periods=6000, unit=unit) + df = DataFrame({"ts": dti}, index=dti) + grouped = df.resample(freq, closed="right") + result = grouped.last() + + exp_dti = DatetimeIndex( + [ + datetime(2020, 1, 30), + datetime(2020, 1, 31), + datetime(2020, 2, 3), + datetime(2020, 2, 4), + ], + freq=freq, + ).as_unit(unit) + expected = DataFrame( + { + "ts": [ + datetime(2020, 1, 31), + datetime(2020, 2, 3), + datetime(2020, 2, 4), + datetime(2020, 2, 4, 3, 59), + ] + }, + index=exp_dti, + ).astype(f"M8[{unit}]") + tm.assert_frame_equal(result, expected) + + +def test_resample_b_55282(unit): + # https://github.com/pandas-dev/pandas/issues/55282 + dti = date_range("2023-09-26", periods=6, freq="12h", unit=unit) + ser = Series([1, 2, 3, 4, 5, 6], index=dti) + result = ser.resample("B", closed="right", label="right").mean() + + exp_dti = DatetimeIndex( + [ + datetime(2023, 9, 26), + datetime(2023, 9, 27), + datetime(2023, 9, 28), + datetime(2023, 9, 29), + ], + freq="B", + ).as_unit(unit) + expected = Series( + [1.0, 2.5, 4.5, 6.0], + index=exp_dti, + ) + tm.assert_series_equal(result, expected) + + +@td.skip_if_no("pyarrow") +@pytest.mark.parametrize( + "tz", + [ + None, + pytest.param( + "UTC", + marks=pytest.mark.xfail( + condition=is_platform_windows() and pa_version_under22p0, + reason="TODO: Set ARROW_TIMEZONE_DATABASE env var in CI", + ), + ), + ], +) +def test_arrow_timestamp_resample(tz): + # GH 56371 + idx = Series(date_range("2020-01-01", periods=5), dtype="timestamp[ns][pyarrow]") + if tz is not None: + idx = idx.dt.tz_localize(tz) + expected = Series(np.arange(5, dtype=np.float64), index=idx) + result = expected.resample("1D").mean() + tm.assert_series_equal(result, expected) + + +@td.skip_if_no("pyarrow") +def test_arrow_timestamp_resample_keep_index_name(): + # https://github.com/pandas-dev/pandas/issues/61222 + idx = Series(date_range("2020-01-01", periods=5), dtype="timestamp[ns][pyarrow]") + expected = Series(np.arange(5, dtype=np.float64), index=idx) + expected.index.name = "index_name" + result = expected.resample("1D").mean() + tm.assert_series_equal(result, expected) + + +def test_resample_unit_second_large_years(): + # GH#57427 + index = DatetimeIndex( + date_range(start=Timestamp("1950-01-01"), periods=10, freq="1000YS", unit="s") + ) + ser = Series(1, index=index) + result = ser.resample("2000YS").sum() + expected = Series(2, index=index[::2]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("freq", ["1A", "2A-MAR"]) +def test_resample_A_raises(freq): + msg = f"Invalid frequency: {freq[1:]}" + + s = Series(range(10), index=date_range("20130101", freq="D", periods=10)) + with pytest.raises(ValueError, match=msg): + s.resample(freq).mean() diff --git a/pandas/tests/resample/test_period_index.py b/pandas/tests/resample/test_period_index.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e2613b823f75f83bd1fdb840d2e25df26c8123 --- /dev/null +++ b/pandas/tests/resample/test_period_index.py @@ -0,0 +1,1032 @@ +from datetime import ( + datetime, + timezone, +) +import re +import warnings +import zoneinfo + +import dateutil +import numpy as np +import pytest + +from pandas._libs.tslibs.ccalendar import ( + DAYS, + MONTHS, +) +from pandas._libs.tslibs.period import IncompatibleFrequency +from pandas.errors import InvalidIndexError + +import pandas as pd +from pandas import ( + DataFrame, + Series, + Timestamp, +) +import pandas._testing as tm +from pandas.core.indexes.datetimes import date_range +from pandas.core.indexes.period import ( + Period, + PeriodIndex, + period_range, +) +from pandas.core.resample import _get_period_range_edges + +from pandas.tseries import offsets + + +@pytest.fixture +def simple_period_range_series(): + """ + Series with period range index and random data for test purposes. + """ + + def _simple_period_range_series(start, end, freq="D"): + with warnings.catch_warnings(): + # suppress Period[B] deprecation warning + msg = "|".join(["Period with BDay freq", r"PeriodDtype\[B\] is deprecated"]) + warnings.filterwarnings( + "ignore", + msg, + category=FutureWarning, + ) + rng = period_range(start, end, freq=freq) + return Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + return _simple_period_range_series + + +class TestPeriodIndex: + @pytest.mark.parametrize("freq", ["2D", "1h", "2h"]) + def test_asfreq(self, frame_or_series, freq): + # GH 12884, 15944 + + obj = frame_or_series(range(5), index=period_range("2020-01-01", periods=5)) + + expected = obj.to_timestamp().resample(freq).asfreq() + result = obj.to_timestamp().resample(freq).asfreq() + tm.assert_almost_equal(result, expected) + + start = obj.index[0].to_timestamp(how="start") + end = (obj.index[-1] + obj.index.freq).to_timestamp(how="start") + new_index = date_range(start=start, end=end, freq=freq, inclusive="left") + expected = obj.to_timestamp().reindex(new_index).to_period(freq) + + result = obj.resample(freq).asfreq() + tm.assert_almost_equal(result, expected) + + result = obj.resample(freq).asfreq().to_timestamp().to_period() + tm.assert_almost_equal(result, expected) + + def test_asfreq_fill_value(self): + # test for fill value during resampling, issue 3715 + + index = period_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + s = Series(range(len(index)), index=index) + new_index = date_range( + s.index[0].to_timestamp(how="start"), + (s.index[-1]).to_timestamp(how="start"), + freq="1h", + ) + expected = s.to_timestamp().reindex(new_index, fill_value=4.0) + result = s.to_timestamp().resample("1h").asfreq(fill_value=4.0) + tm.assert_series_equal(result, expected) + + frame = s.to_frame("value") + new_index = date_range( + frame.index[0].to_timestamp(how="start"), + (frame.index[-1]).to_timestamp(how="start"), + freq="1h", + ) + expected = frame.to_timestamp().reindex(new_index, fill_value=3.0) + result = frame.to_timestamp().resample("1h").asfreq(fill_value=3.0) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("freq", ["h", "12h", "2D", "W"]) + @pytest.mark.parametrize("kwargs", [{"on": "date"}, {"level": "d"}]) + def test_selection(self, freq, kwargs): + # This is a bug, these should be implemented + # GH 14008 + index = period_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + rng = np.arange(len(index), dtype=np.int64) + df = DataFrame( + {"date": index, "a": rng}, + index=pd.MultiIndex.from_arrays([rng, index], names=["v", "d"]), + ) + msg = ( + "Resampling from level= or on= selection with a PeriodIndex is " + r"not currently supported, use \.set_index\(\.\.\.\) to " + "explicitly set index" + ) + with pytest.raises(NotImplementedError, match=msg): + df.resample(freq, **kwargs) + + @pytest.mark.parametrize("month", MONTHS) + @pytest.mark.parametrize("meth", ["ffill", "bfill"]) + @pytest.mark.parametrize("conv", ["start", "end"]) + @pytest.mark.parametrize( + ("offset", "period"), [("D", "D"), ("B", "B"), ("ME", "M"), ("QE", "Q")] + ) + def test_annual_upsample_cases( + self, offset, period, conv, meth, month, simple_period_range_series + ): + ts = simple_period_range_series("1/1/1990", "12/31/1990", freq=f"Y-{month}") + warn = FutureWarning if period == "B" else None + msg = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = getattr(ts.resample(period, convention=conv), meth)() + expected = result.to_timestamp(period, how=conv) + expected = expected.asfreq(offset, meth).to_period() + tm.assert_series_equal(result, expected) + + def test_basic_downsample(self, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "6/30/1995", freq="M") + result = ts.resample("Y-DEC").mean() + + expected = ts.groupby(ts.index.year).mean() + expected.index = period_range("1/1/1990", "6/30/1995", freq="Y-DEC") + tm.assert_series_equal(result, expected) + + # this is ok + tm.assert_series_equal(ts.resample("Y-DEC").mean(), result) + tm.assert_series_equal(ts.resample("Y").mean(), result) + + @pytest.mark.parametrize( + "rule,expected_error_msg", + [ + ("Y-DEC", ""), + ("Q-MAR", ""), + ("M", ""), + ("W-THU", ""), + ], + ) + def test_not_subperiod(self, simple_period_range_series, rule, expected_error_msg): + # These are incompatible period rules for resampling + ts = simple_period_range_series("1/1/1990", "6/30/1995", freq="W-WED") + msg = ( + "Frequency cannot be resampled to " + f"{expected_error_msg}, as they are not sub or super periods" + ) + with pytest.raises(IncompatibleFrequency, match=msg): + ts.resample(rule).mean() + + @pytest.mark.parametrize("freq", ["D", "2D"]) + def test_basic_upsample(self, freq, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "6/30/1995", freq="M") + result = ts.resample("Y-DEC").mean() + + resampled = result.resample(freq, convention="end").ffill() + expected = result.to_timestamp(freq, how="end") + expected = expected.asfreq(freq, "ffill").to_period(freq) + tm.assert_series_equal(resampled, expected) + + def test_upsample_with_limit(self): + rng = period_range("1/1/2000", periods=5, freq="Y") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), rng) + + result = ts.resample("M", convention="end").ffill(limit=2) + expected = ts.asfreq("M").reindex(result.index, method="ffill", limit=2) + tm.assert_series_equal(result, expected) + + def test_annual_upsample(self, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "12/31/1995", freq="Y-DEC") + df = DataFrame({"a": ts}) + rdf = df.resample("D").ffill() + exp = df["a"].resample("D").ffill() + tm.assert_series_equal(rdf["a"], exp) + + def test_annual_upsample2(self): + rng = period_range("2000", "2003", freq="Y-DEC") + ts = Series([1, 2, 3, 4], index=rng) + + result = ts.resample("M").ffill() + ex_index = period_range("2000-01", "2003-12", freq="M") + + expected = ts.asfreq("M", how="start").reindex(ex_index, method="ffill") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("month", MONTHS) + @pytest.mark.parametrize("convention", ["start", "end"]) + @pytest.mark.parametrize( + ("offset", "period"), [("D", "D"), ("B", "B"), ("ME", "M")] + ) + def test_quarterly_upsample( + self, month, offset, period, convention, simple_period_range_series + ): + freq = f"Q-{month}" + ts = simple_period_range_series("1/1/1990", "12/31/1991", freq=freq) + warn = FutureWarning if period == "B" else None + msg = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = ts.resample(period, convention=convention).ffill() + expected = result.to_timestamp(period, how=convention) + expected = expected.asfreq(offset, "ffill").to_period() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("target", ["D", "B"]) + @pytest.mark.parametrize("convention", ["start", "end"]) + def test_monthly_upsample(self, target, convention, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "12/31/1995", freq="M") + + warn = None if target == "D" else FutureWarning + msg = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = ts.resample(target, convention=convention).ffill() + expected = result.to_timestamp(target, how=convention) + expected = expected.asfreq(target, "ffill").to_period() + tm.assert_series_equal(result, expected) + + def test_resample_basic(self): + # GH3609 + s = Series( + range(100), + index=date_range("20130101", freq="s", periods=100, name="idx"), + dtype="float", + ) + s[10:30] = np.nan + index = PeriodIndex( + [Period("2013-01-01 00:00", "min"), Period("2013-01-01 00:01", "min")], + name="idx", + ) + expected = Series([34.5, 79.5], index=index) + result = s.to_period().resample("min").mean() + tm.assert_series_equal(result, expected) + result2 = s.resample("min").mean().to_period() + tm.assert_series_equal(result2, expected) + + @pytest.mark.parametrize( + "freq,expected_vals", [("M", [31, 29, 31, 9]), ("2M", [31 + 29, 31 + 9])] + ) + def test_resample_count(self, freq, expected_vals): + # GH12774 + series = Series(1, index=period_range(start="2000", periods=100)) + result = series.resample(freq).count() + expected_index = period_range( + start="2000", freq=freq, periods=len(expected_vals) + ) + expected = Series(expected_vals, index=expected_index) + tm.assert_series_equal(result, expected) + + def test_resample_same_freq(self, resample_method): + # GH12770 + series = Series(range(3), index=period_range(start="2000", periods=3, freq="M")) + expected = series + + result = getattr(series.resample("M"), resample_method)() + tm.assert_series_equal(result, expected) + + def test_resample_incompat_freq(self): + msg = ( + "Frequency cannot be resampled to , " + "as they are not sub or super periods" + ) + pi = period_range(start="2000", periods=3, freq="M") + ser = Series(range(3), index=pi) + rs = ser.resample("W") + with pytest.raises(IncompatibleFrequency, match=msg): + # TODO: should this raise at the resample call instead of at the mean call? + rs.mean() + + @pytest.mark.parametrize( + "tz", + [ + zoneinfo.ZoneInfo("America/Los_Angeles"), + dateutil.tz.gettz("America/Los_Angeles"), + ], + ) + def test_with_local_timezone(self, tz): + # see gh-5430 + local_timezone = tz + + start = datetime( + year=2013, month=11, day=1, hour=0, minute=0, tzinfo=timezone.utc + ) + # 1 day later + end = datetime( + year=2013, month=11, day=2, hour=0, minute=0, tzinfo=timezone.utc + ) + + index = date_range(start, end, freq="h", name="idx") + + series = Series(1, index=index) + series = series.tz_convert(local_timezone) + msg = "Converting to PeriodArray/Index representation will drop timezone" + with tm.assert_produces_warning(UserWarning, match=msg): + result = series.resample("D").mean().to_period() + + # Create the expected series + # Index is moved back a day with the timezone conversion from UTC to + # Pacific + expected_index = ( + period_range(start=start, end=end, freq="D", name="idx") - offsets.Day() + ) + expected = Series(1.0, index=expected_index) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "tz", + [ + zoneinfo.ZoneInfo("America/Los_Angeles"), + dateutil.tz.gettz("America/Los_Angeles"), + ], + ) + def test_resample_with_tz(self, tz, unit): + # GH 13238 + dti = date_range("2017-01-01", periods=48, freq="h", tz=tz, unit=unit) + ser = Series(2, index=dti) + result = ser.resample("D").mean() + exp_dti = pd.DatetimeIndex( + ["2017-01-01", "2017-01-02"], tz=tz, freq="D" + ).as_unit(unit) + expected = Series( + 2.0, + index=exp_dti, + ) + tm.assert_series_equal(result, expected) + + def test_resample_nonexistent_time_bin_edge(self): + # GH 19375 + index = date_range("2017-03-12", "2017-03-12 1:45:00", freq="15min") + s = Series(np.zeros(len(index)), index=index) + expected = s.tz_localize("US/Pacific") + expected.index = pd.DatetimeIndex(expected.index, freq="900s") + result = expected.resample("900s").mean() + tm.assert_series_equal(result, expected) + + def test_resample_nonexistent_time_bin_edge2(self): + # GH 23742 + index = date_range(start="2017-10-10", end="2017-10-20", freq="1h") + index = index.tz_localize("UTC").tz_convert("America/Sao_Paulo") + df = DataFrame(data=list(range(len(index))), index=index) + result = df.groupby(pd.Grouper(freq="1D")).count() + expected = date_range( + start="2017-10-09", + end="2017-10-20", + freq="D", + tz="America/Sao_Paulo", + nonexistent="shift_forward", + inclusive="left", + ) + tm.assert_index_equal(result.index, expected) + + def test_resample_ambiguous_time_bin_edge(self): + # GH 10117 + idx = date_range( + "2014-10-25 22:00:00", + "2014-10-26 00:30:00", + freq="30min", + tz="Europe/London", + ) + expected = Series(np.zeros(len(idx)), index=idx) + result = expected.resample("30min").mean() + tm.assert_series_equal(result, expected) + + def test_fill_method_and_how_upsample(self): + # GH2073 + s = Series( + np.arange(9, dtype="int64"), + index=date_range("2010-01-01", periods=9, freq="QE"), + ) + last = s.resample("ME").ffill() + both = s.resample("ME").ffill().resample("ME").last().astype("int64") + tm.assert_series_equal(last, both) + + @pytest.mark.parametrize("day", DAYS) + @pytest.mark.parametrize("target", ["D", "B"]) + @pytest.mark.parametrize("convention", ["start", "end"]) + def test_weekly_upsample(self, day, target, convention, simple_period_range_series): + freq = f"W-{day}" + ts = simple_period_range_series("1/1/1990", "07/31/1990", freq=freq) + warn = None if target == "D" else FutureWarning + msg = r"PeriodDtype\[B\] is deprecated" + with tm.assert_produces_warning(warn, match=msg): + result = ts.resample(target, convention=convention).ffill() + expected = result.to_timestamp(target, how=convention) + expected = expected.asfreq(target, "ffill").to_period() + tm.assert_series_equal(result, expected) + + def test_resample_to_timestamps(self, simple_period_range_series): + ts = simple_period_range_series("1/1/1990", "12/31/1995", freq="M") + + result = ts.resample("Y-DEC").mean().to_timestamp() + expected = ts.resample("Y-DEC").mean().to_timestamp(how="start") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("month", MONTHS) + def test_resample_to_quarterly(self, simple_period_range_series, month): + ts = simple_period_range_series("1990", "1992", freq=f"Y-{month}") + quar_ts = ts.resample(f"Q-{month}").ffill() + + stamps = ts.to_timestamp("D", how="start") + qdates = period_range( + ts.index[0].asfreq("D", "start"), + ts.index[-1].asfreq("D", "end"), + freq=f"Q-{month}", + ) + + expected = stamps.reindex(qdates.to_timestamp("D", "s"), method="ffill") + expected.index = qdates + + tm.assert_series_equal(quar_ts, expected) + + @pytest.mark.parametrize("how", ["start", "end"]) + def test_resample_to_quarterly_start_end(self, simple_period_range_series, how): + # conforms, but different month + ts = simple_period_range_series("1990", "1992", freq="Y-JUN") + result = ts.resample("Q-MAR", convention=how).ffill() + expected = ts.asfreq("Q-MAR", how=how) + expected = expected.reindex(result.index, method="ffill") + + # FIXME: don't leave commented-out + # .to_timestamp('D') + # expected = expected.resample('Q-MAR').ffill() + + tm.assert_series_equal(result, expected) + + def test_resample_fill_missing(self): + rng = PeriodIndex([2000, 2005, 2007, 2009], freq="Y") + + s = Series(np.random.default_rng(2).standard_normal(4), index=rng) + + stamps = s.to_timestamp() + filled = s.resample("Y").ffill() + expected = stamps.resample("YE").ffill().to_period("Y") + tm.assert_series_equal(filled, expected) + + def test_cant_fill_missing_dups(self): + rng = PeriodIndex([2000, 2005, 2005, 2007, 2007], freq="Y") + s = Series(np.random.default_rng(2).standard_normal(5), index=rng) + msg = "Reindexing only valid with uniquely valued Index objects" + with pytest.raises(InvalidIndexError, match=msg): + s.resample("Y").ffill() + + def test_resample_5minute(self): + rng = period_range("1/1/2000", "1/5/2000", freq="min") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + expected = ts.to_timestamp().resample("5min").mean() + result = ts.resample("5min").mean().to_timestamp() + tm.assert_series_equal(result, expected) + + expected = expected.to_period("5min") + result = ts.resample("5min").mean() + tm.assert_series_equal(result, expected) + result = ts.resample("5min").mean().to_timestamp().to_period() + tm.assert_series_equal(result, expected) + + def test_upsample_daily_business_daily(self, simple_period_range_series): + ts = simple_period_range_series("1/1/2000", "2/1/2000", freq="B") + + result = ts.resample("D").asfreq() + expected = ts.asfreq("D").reindex(period_range("1/3/2000", "2/1/2000")) + tm.assert_series_equal(result, expected) + + ts = simple_period_range_series("1/1/2000", "2/1/2000") + result = ts.resample("h", convention="s").asfreq() + exp_rng = period_range("1/1/2000", "2/1/2000 23:00", freq="h") + expected = ts.asfreq("h", how="s").reindex(exp_rng) + tm.assert_series_equal(result, expected) + + def test_resample_irregular_sparse(self): + dr = date_range(start="1/1/2012", freq="5min", periods=1000) + s = Series(np.array(100), index=dr) + # subset the data. + subset = s[:"2012-01-04 06:55"] + + result = subset.resample("10min").apply(len) + expected = s.resample("10min").apply(len).loc[result.index] + tm.assert_series_equal(result, expected) + + def test_resample_weekly_all_na(self): + rng = date_range("1/1/2000", periods=10, freq="W-WED") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + result = ts.resample("W-THU").asfreq() + + assert result.isna().all() + + result = ts.resample("W-THU").asfreq().ffill()[:-1] + expected = ts.asfreq("W-THU").ffill() + tm.assert_series_equal(result, expected) + + def test_resample_tz_localized(self, unit): + dr = date_range(start="2012-4-13", end="2012-5-1", unit=unit) + ts = Series(range(len(dr)), index=dr) + + ts_utc = ts.tz_localize("UTC") + ts_local = ts_utc.tz_convert("America/Los_Angeles") + + result = ts_local.resample("W").mean() + + ts_local_naive = ts_local.copy() + ts_local_naive.index = ts_local_naive.index.tz_localize(None) + + exp = ts_local_naive.resample("W").mean().tz_localize("America/Los_Angeles") + exp.index = pd.DatetimeIndex(exp.index, freq="W") + + tm.assert_series_equal(result, exp) + + # it works + result = ts_local.resample("D").mean() + + def test_resample_tz_localized2(self): + # #2245 + idx = date_range( + "2001-09-20 15:59", "2001-09-20 16:00", freq="min", tz="Australia/Sydney" + ) + s = Series([1, 2], index=idx) + + # GH#61985 changed this to behave like "B" rather than "24h" + result = s.resample("D", closed="right", label="right").mean() + ex_index = date_range("2001-09-20", periods=2, freq="D", tz="Australia/Sydney") + expected = Series([np.nan, 1.5], index=ex_index) + + tm.assert_series_equal(result, expected) + + # for good measure + msg = "Converting to PeriodArray/Index representation will drop timezone " + with tm.assert_produces_warning(UserWarning, match=msg): + result = s.resample("D").mean().to_period() + ex_index = period_range("2001-09-20", periods=1, freq="D") + expected = Series([1.5], index=ex_index) + tm.assert_series_equal(result, expected) + + def test_resample_tz_localized3(self): + # GH 6397 + # comparing an offset that doesn't propagate tz's + rng = date_range("1/1/2011", periods=20000, freq="h") + rng = rng.tz_localize("EST") + ts = DataFrame(index=rng) + ts["first"] = np.random.default_rng(2).standard_normal(len(rng)) + ts["second"] = np.cumsum(np.random.default_rng(2).standard_normal(len(rng))) + expected = DataFrame( + { + "first": ts.resample("YE").sum()["first"], + "second": ts.resample("YE").mean()["second"], + }, + columns=["first", "second"], + ) + result = ( + ts.resample("YE") + .agg({"first": "sum", "second": "mean"}) + .reindex(columns=["first", "second"]) + ) + tm.assert_frame_equal(result, expected) + + def test_closed_left_corner(self): + # #1465 + s = Series( + np.random.default_rng(2).standard_normal(21), + index=date_range(start="1/1/2012 9:30", freq="1min", periods=21), + ) + s.iloc[0] = np.nan + + result = s.resample("10min", closed="left", label="right").mean() + exp = s[1:].resample("10min", closed="left", label="right").mean() + tm.assert_series_equal(result, exp) + + result = s.resample("10min", closed="left", label="left").mean() + exp = s[1:].resample("10min", closed="left", label="left").mean() + + ex_index = date_range(start="1/1/2012 9:30", freq="10min", periods=3) + + tm.assert_index_equal(result.index, ex_index) + tm.assert_series_equal(result, exp) + + def test_quarterly_resampling(self): + rng = period_range("2000Q1", periods=10, freq="Q-DEC") + ts = Series(np.arange(10), index=rng) + + result = ts.resample("Y").mean() + exp = ts.to_timestamp().resample("YE").mean().to_period() + tm.assert_series_equal(result, exp) + + def test_resample_weekly_bug_1726(self): + # 8/6/12 is a Monday + ind = date_range(start="8/6/2012", end="8/26/2012", freq="D") + n = len(ind) + data = [[x] * 5 for x in range(n)] + df = DataFrame(data, columns=["open", "high", "low", "close", "vol"], index=ind) + + # it works! + df.resample("W-MON", closed="left", label="left").first() + + def test_resample_with_dst_time_change(self): + # GH 15549 + index = ( + pd.DatetimeIndex([1457537600000000000, 1458059600000000000]) + .tz_localize("UTC") + .tz_convert("America/Chicago") + ) + df = DataFrame([1, 2], index=index) + result = df.resample("12h", closed="right", label="right").last().ffill() + + expected_index_values = [ + "2016-03-09 12:00:00-06:00", + "2016-03-10 00:00:00-06:00", + "2016-03-10 12:00:00-06:00", + "2016-03-11 00:00:00-06:00", + "2016-03-11 12:00:00-06:00", + "2016-03-12 00:00:00-06:00", + "2016-03-12 12:00:00-06:00", + "2016-03-13 00:00:00-06:00", + "2016-03-13 13:00:00-05:00", + "2016-03-14 01:00:00-05:00", + "2016-03-14 13:00:00-05:00", + "2016-03-15 01:00:00-05:00", + "2016-03-15 13:00:00-05:00", + ] + index = ( + pd.to_datetime(expected_index_values, utc=True) + .tz_convert("America/Chicago") + .as_unit(index.unit) + ) + index = pd.DatetimeIndex(index, freq="12h") + expected = DataFrame( + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0], + index=index, + ) + tm.assert_frame_equal(result, expected) + + def test_resample_bms_2752(self): + # GH2753 + timeseries = Series( + index=pd.bdate_range("20000101", "20000201"), dtype=np.float64 + ) + res1 = timeseries.resample("BMS").mean() + res2 = timeseries.resample("BMS").mean().resample("B").mean() + assert res1.index[0] == Timestamp("20000103") + assert res1.index[0] == res2.index[0] + + @pytest.mark.xfail(reason="Commented out for more than 3 years. Should this work?") + def test_monthly_convention_span(self): + rng = period_range("2000-01", periods=3, freq="ME") + ts = Series(np.arange(3), index=rng) + + # hacky way to get same thing + exp_index = period_range("2000-01-01", "2000-03-31", freq="D") + expected = ts.asfreq("D", how="end").reindex(exp_index) + expected = expected.fillna(method="bfill") + + result = ts.resample("D").mean() + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "from_freq, to_freq", [("D", "ME"), ("QE", "YE"), ("ME", "QE"), ("D", "W")] + ) + def test_default_right_closed_label(self, from_freq, to_freq): + idx = date_range(start="8/15/2012", periods=100, freq=from_freq) + df = DataFrame(np.random.default_rng(2).standard_normal((len(idx), 2)), idx) + + resampled = df.resample(to_freq).mean() + tm.assert_frame_equal( + resampled, df.resample(to_freq, closed="right", label="right").mean() + ) + + @pytest.mark.parametrize( + "from_freq, to_freq", + [("D", "MS"), ("QE", "YS"), ("ME", "QS"), ("h", "D"), ("min", "h")], + ) + def test_default_left_closed_label(self, from_freq, to_freq): + idx = date_range(start="8/15/2012", periods=100, freq=from_freq) + df = DataFrame(np.random.default_rng(2).standard_normal((len(idx), 2)), idx) + + resampled = df.resample(to_freq).mean() + tm.assert_frame_equal( + resampled, df.resample(to_freq, closed="left", label="left").mean() + ) + + def test_all_values_single_bin(self): + # GH#2070 + index = period_range(start="2012-01-01", end="2012-12-31", freq="M") + ser = Series(np.random.default_rng(2).standard_normal(len(index)), index=index) + + result = ser.resample("Y").mean() + tm.assert_almost_equal(result.iloc[0], ser.mean()) + + def test_evenly_divisible_with_no_extra_bins(self): + # GH#4076 + # when the frequency is evenly divisible, sometimes extra bins + + df = DataFrame( + np.random.default_rng(2).standard_normal((9, 3)), + index=date_range("2000-1-1", periods=9, unit="ns"), + ) + result = df.resample("5D").mean() + expected = pd.concat([df.iloc[0:5].mean(), df.iloc[5:].mean()], axis=1).T + expected.index = pd.DatetimeIndex( + [Timestamp("2000-1-1"), Timestamp("2000-1-6")], dtype="M8[ns]", freq="5D" + ) + tm.assert_frame_equal(result, expected) + + def test_evenly_divisible_with_no_extra_bins2(self): + index = date_range(start="2001-5-4", periods=28) + df = DataFrame( + [ + { + "REST_KEY": 1, + "DLY_TRN_QT": 80, + "DLY_SLS_AMT": 90, + "COOP_DLY_TRN_QT": 30, + "COOP_DLY_SLS_AMT": 20, + } + ] + * 28 + + [ + { + "REST_KEY": 2, + "DLY_TRN_QT": 70, + "DLY_SLS_AMT": 10, + "COOP_DLY_TRN_QT": 50, + "COOP_DLY_SLS_AMT": 20, + } + ] + * 28, + index=index.append(index), + ).sort_index() + + index = date_range("2001-5-4", periods=4, freq="7D") + expected = DataFrame( + [ + { + "REST_KEY": 14, + "DLY_TRN_QT": 14, + "DLY_SLS_AMT": 14, + "COOP_DLY_TRN_QT": 14, + "COOP_DLY_SLS_AMT": 14, + } + ] + * 4, + index=index, + ) + result = df.resample("7D").count() + tm.assert_frame_equal(result, expected) + + expected = DataFrame( + [ + { + "REST_KEY": 21, + "DLY_TRN_QT": 1050, + "DLY_SLS_AMT": 700, + "COOP_DLY_TRN_QT": 560, + "COOP_DLY_SLS_AMT": 280, + } + ] + * 4, + index=index, + ) + result = df.resample("7D").sum() + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("freq, period_mult", [("h", 24), ("12h", 2)]) + def test_upsampling_ohlc(self, freq, period_mult): + # GH 13083 + pi = period_range(start="2000", freq="D", periods=10) + s = Series(range(len(pi)), index=pi) + expected = s.to_timestamp().resample(freq).ohlc().to_period(freq) + + # timestamp-based resampling doesn't include all sub-periods + # of the last original period, so extend accordingly: + new_index = period_range(start="2000", freq=freq, periods=period_mult * len(pi)) + expected = expected.reindex(new_index) + result = s.resample(freq).ohlc() + tm.assert_frame_equal(result, expected) + + result = s.resample(freq).ohlc().to_timestamp().to_period() + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "periods, values", + [ + ( + [ + pd.NaT, + "1970-01-01 00:00:00", + pd.NaT, + "1970-01-01 00:00:02", + "1970-01-01 00:00:03", + ], + [2, 3, 5, 7, 11], + ), + ( + [ + pd.NaT, + pd.NaT, + "1970-01-01 00:00:00", + pd.NaT, + pd.NaT, + pd.NaT, + "1970-01-01 00:00:02", + "1970-01-01 00:00:03", + pd.NaT, + pd.NaT, + ], + [1, 2, 3, 5, 6, 8, 7, 11, 12, 13], + ), + ], + ) + @pytest.mark.parametrize( + "freq, expected_values", + [ + ("1s", [3, np.nan, 7, 11]), + ("2s", [3, (7 + 11) / 2]), + ("3s", [(3 + 7) / 2, 11]), + ], + ) + def test_resample_with_nat(self, periods, values, freq, expected_values): + # GH 13224 + index = PeriodIndex(periods, freq="s") + frame = DataFrame(values, index=index) + + expected_index = period_range( + "1970-01-01 00:00:00", periods=len(expected_values), freq=freq + ) + expected = DataFrame(expected_values, index=expected_index) + result = frame.resample(freq).mean() + tm.assert_frame_equal(result, expected) + + def test_resample_with_only_nat(self): + # GH 13224 + pi = PeriodIndex([pd.NaT] * 3, freq="s") + frame = DataFrame([2, 3, 5], index=pi, columns=["a"]) + expected_index = PeriodIndex(data=[], freq=pi.freq) + expected = DataFrame(index=expected_index, columns=["a"], dtype="float64") + result = frame.resample("1s").mean() + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "start,end,start_freq,end_freq,offset", + [ + ("19910905", "19910909 03:00", "h", "24h", "10h"), + ("19910905", "19910909 12:00", "h", "24h", "10h"), + ("19910905", "19910909 23:00", "h", "24h", "10h"), + ("19910905 10:00", "19910909", "h", "24h", "10h"), + ("19910905 10:00", "19910909 10:00", "h", "24h", "10h"), + ("19910905", "19910909 10:00", "h", "24h", "10h"), + ("19910905 12:00", "19910909", "h", "24h", "10h"), + ("19910905 12:00", "19910909 03:00", "h", "24h", "10h"), + ("19910905 12:00", "19910909 12:00", "h", "24h", "10h"), + ("19910905 12:00", "19910909 12:00", "h", "24h", "34h"), + ("19910905 12:00", "19910909 12:00", "h", "17h", "10h"), + ("19910905 12:00", "19910909 12:00", "h", "17h", "3h"), + ("19910905", "19910913 06:00", "2h", "24h", "10h"), + ("19910905", "19910905 01:39", "Min", "5Min", "3Min"), + ("19910905", "19910905 03:18", "2Min", "5Min", "3Min"), + ], + ) + def test_resample_with_offset(self, start, end, start_freq, end_freq, offset): + # GH 23882 & 31809 + pi = period_range(start, end, freq=start_freq) + ser = Series(np.arange(len(pi)), index=pi) + result = ser.resample(end_freq, offset=offset).mean() + result = result.to_timestamp(end_freq) + + expected = ser.to_timestamp().resample(end_freq, offset=offset).mean() + tm.assert_series_equal(result, expected) + + def test_resample_with_offset_month(self): + # GH 23882 & 31809 + pi = period_range("19910905 12:00", "19910909 1:00", freq="h") + ser = Series(np.arange(len(pi)), index=pi) + result = ser.resample("M").mean() + result = result.to_timestamp("M") + expected = ser.to_timestamp().resample("ME").mean() + # TODO: is non-tick the relevant characteristic? (GH 33815) + expected.index = expected.index._with_freq(None) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "first,last,freq,freq_to_offset,exp_first,exp_last", + [ + ("19910905", "19920406", "D", "D", "19910905", "19920406"), + ("19910905 00:00", "19920406 06:00", "D", "D", "19910905", "19920406"), + ( + "19910905 06:00", + "19920406 06:00", + "h", + "h", + "19910905 06:00", + "19920406 06:00", + ), + ("19910906", "19920406", "M", "ME", "1991-09", "1992-04"), + ("19910831", "19920430", "M", "ME", "1991-08", "1992-04"), + ("1991-08", "1992-04", "M", "ME", "1991-08", "1992-04"), + ], + ) + def test_get_period_range_edges( + self, first, last, freq, freq_to_offset, exp_first, exp_last + ): + first = Period(first) + last = Period(last) + + exp_first = Period(exp_first, freq=freq) + exp_last = Period(exp_last, freq=freq) + + freq = pd.tseries.frequencies.to_offset(freq_to_offset) + result = _get_period_range_edges(first, last, freq) + expected = (exp_first, exp_last) + assert result == expected + + def test_sum_min_count(self): + # GH 19974 + index = date_range(start="2018", freq="ME", periods=6) + data = np.ones(6) + data[3:6] = np.nan + s = Series(data, index).to_period() + result = s.resample("Q").sum(min_count=1) + expected = Series( + [3.0, np.nan], index=PeriodIndex(["2018Q1", "2018Q2"], freq="Q-DEC") + ) + tm.assert_series_equal(result, expected) + + def test_resample_t_l_deprecated(self): + # GH#52536 + msg_t = "Invalid frequency: T" + msg_l = "Invalid frequency: L" + + with pytest.raises(ValueError, match=msg_l): + period_range( + "2020-01-01 00:00:00 00:00", "2020-01-01 00:00:00 00:01", freq="L" + ) + rng_l = period_range( + "2020-01-01 00:00:00 00:00", "2020-01-01 00:00:00 00:01", freq="ms" + ) + ser = Series(np.arange(len(rng_l)), index=rng_l) + + with pytest.raises(ValueError, match=msg_t): + ser.resample("T").mean() + + @pytest.mark.parametrize( + "freq, freq_depr, freq_depr_res", + [ + ("2Q", "2q", "2y"), + ("2M", "2m", "2q"), + ], + ) + def test_resample_lowercase_frequency_raises(self, freq, freq_depr, freq_depr_res): + msg = f"Invalid frequency: {freq_depr}" + with pytest.raises(ValueError, match=msg): + period_range("2020-01-01", "2020-08-01", freq=freq_depr) + + msg = f"Invalid frequency: {freq_depr_res}" + rng = period_range("2020-01-01", "2020-08-01", freq=freq) + ser = Series(np.arange(len(rng)), index=rng) + with pytest.raises(ValueError, match=msg): + ser.resample(freq_depr_res).mean() + + @pytest.mark.parametrize( + "offset", + [ + offsets.MonthBegin(), + offsets.BYearBegin(2), + offsets.BusinessHour(2), + ], + ) + def test_asfreq_invalid_period_offset(self, offset, frame_or_series): + # GH#55785 + msg = re.escape(f"{offset} is not supported as period frequency") + + obj = frame_or_series(range(5), index=period_range("2020-01-01", periods=5)) + with pytest.raises(ValueError, match=msg): + obj.asfreq(freq=offset) + + +@pytest.mark.parametrize( + "freq", + [ + ("2ME"), + ("2QE"), + ("2QE-FEB"), + ("2YE"), + ("2YE-MAR"), + ("2me"), + ("2qe"), + ("2ye-mar"), + ], +) +def test_resample_frequency_ME_QE_YE_raises(frame_or_series, freq): + # GH#9586 + msg = f"{freq[1:]} is not supported as period frequency" + + obj = frame_or_series(range(5), index=period_range("2020-01-01", periods=5)) + msg = f"Invalid frequency: {freq}" + with pytest.raises(ValueError, match=msg): + obj.resample(freq) + + +def test_corner_cases_period(simple_period_range_series): + # miscellaneous test coverage + len0pts = simple_period_range_series("2007-01", "2010-05", freq="M")[:0] + # it works + result = len0pts.resample("Y-DEC").mean() + assert len(result) == 0 + + +@pytest.mark.parametrize("freq", ["2BME", "2CBME", "2SME", "2BQE-FEB", "2BYE-MAR"]) +def test_resample_frequency_invalid_freq(frame_or_series, freq): + # GH#9586 + msg = f"Invalid frequency: {freq}" + + obj = frame_or_series(range(5), index=period_range("2020-01-01", periods=5)) + with pytest.raises(ValueError, match=msg): + obj.resample(freq) diff --git a/pandas/tests/resample/test_resample_api.py b/pandas/tests/resample/test_resample_api.py new file mode 100644 index 0000000000000000000000000000000000000000..36ef01178b3bc905bc4099fe6cc3e271f3872b5f --- /dev/null +++ b/pandas/tests/resample/test_resample_api.py @@ -0,0 +1,1018 @@ +from datetime import datetime +import re + +import numpy as np +import pytest + +from pandas._libs import lib +from pandas._libs.tslibs import Day + +import pandas as pd +from pandas import ( + DataFrame, + NamedAgg, + Series, +) +import pandas._testing as tm +from pandas.core.indexes.datetimes import date_range + + +@pytest.fixture +def dti(): + return date_range(start=datetime(2005, 1, 1), end=datetime(2005, 1, 10), freq="Min") + + +@pytest.fixture +def _test_series(dti): + return Series(np.random.default_rng(2).random(len(dti)), dti) + + +@pytest.fixture +def test_frame(dti, _test_series): + return DataFrame({"A": _test_series, "B": _test_series, "C": np.arange(len(dti))}) + + +def test_str(_test_series): + r = _test_series.resample("h") + assert ( + "DatetimeIndexResampler [freq=, closed=left, " + "label=left, convention=start, origin=start_day]" in str(r) + ) + + r = _test_series.resample("h", origin="2000-01-01") + assert ( + "DatetimeIndexResampler [freq=, closed=left, " + "label=left, convention=start, origin=2000-01-01 00:00:00]" in str(r) + ) + + +def test_api(_test_series): + r = _test_series.resample("h") + result = r.mean() + assert isinstance(result, Series) + assert len(result) == 217 + + r = _test_series.to_frame().resample("h") + result = r.mean() + assert isinstance(result, DataFrame) + assert len(result) == 217 + + +def test_groupby_resample_api(): + # GH 12448 + # .groupby(...).resample(...) hitting warnings + # when appropriate + df = DataFrame( + { + "date": date_range(start="2016-01-01", periods=4, freq="W"), + "group": [1, 1, 2, 2], + "val": [5, 6, 7, 8], + } + ).set_index("date") + + # replication step + i = ( + date_range("2016-01-03", periods=8).tolist() + + date_range("2016-01-17", periods=8).tolist() + ) + index = pd.MultiIndex.from_arrays([[1] * 8 + [2] * 8, i], names=["group", "date"]) + expected = DataFrame({"val": [5] * 7 + [6] + [7] * 7 + [8]}, index=index) + result = df.groupby("group").apply(lambda x: x.resample("1D").ffill())[["val"]] + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_on_api(): + # GH 15021 + # .groupby(...).resample(on=...) results in an unexpected + # keyword warning. + df = DataFrame( + { + "key": ["A", "B"] * 5, + "dates": date_range("2016-01-01", periods=10), + "values": np.random.default_rng(2).standard_normal(10), + } + ) + + expected = df.set_index("dates").groupby("key").resample("D").mean() + result = df.groupby("key").resample("D", on="dates").mean() + tm.assert_frame_equal(result, expected) + + +def test_resample_group_keys(): + df = DataFrame({"A": 1, "B": 2}, index=date_range("2000", periods=10, unit="ns")) + expected = df.copy() + + # group_keys=False + g = df.resample("5D", group_keys=False) + result = g.apply(lambda x: x) + tm.assert_frame_equal(result, expected) + + # group_keys defaults to False + g = df.resample("5D") + result = g.apply(lambda x: x) + tm.assert_frame_equal(result, expected) + + # group_keys=True + expected.index = pd.MultiIndex.from_arrays( + [ + pd.to_datetime(["2000-01-01", "2000-01-06"]).as_unit("ns").repeat(5), + expected.index, + ] + ) + g = df.resample("5D", group_keys=True) + result = g.apply(lambda x: x) + tm.assert_frame_equal(result, expected) + + +def test_pipe(test_frame, _test_series): + # GH17905 + + # series + r = _test_series.resample("h") + expected = r.max() - r.mean() + result = r.pipe(lambda x: x.max() - x.mean()) + tm.assert_series_equal(result, expected) + + # dataframe + r = test_frame.resample("h") + expected = r.max() - r.mean() + result = r.pipe(lambda x: x.max() - x.mean()) + tm.assert_frame_equal(result, expected) + + +def test_getitem(test_frame): + r = test_frame.resample("h") + tm.assert_index_equal(r._selected_obj.columns, test_frame.columns) + + r = test_frame.resample("h")["B"] + assert r._selected_obj.name == test_frame.columns[1] + + # technically this is allowed + r = test_frame.resample("h")["A", "B"] + tm.assert_index_equal(r._selected_obj.columns, test_frame.columns[[0, 1]]) + + r = test_frame.resample("h")["A", "B"] + tm.assert_index_equal(r._selected_obj.columns, test_frame.columns[[0, 1]]) + + +@pytest.mark.parametrize("key", [["D"], ["A", "D"]]) +def test_select_bad_cols(key, test_frame): + g = test_frame.resample("h") + # 'A' should not be referenced as a bad column... + # will have to rethink regex if you change message! + msg = r"^\"Columns not found: 'D'\"$" + with pytest.raises(KeyError, match=msg): + g[key] + + +def test_attribute_access(test_frame): + r = test_frame.resample("h") + tm.assert_series_equal(r.A.sum(), r["A"].sum()) + + +@pytest.mark.parametrize("attr", ["groups", "ngroups", "indices"]) +def test_api_compat_before_use(attr): + # make sure that we are setting the binner + # on these attributes + rng = date_range("1/1/2012", periods=100, freq="s") + ts = Series(np.arange(len(rng)), index=rng) + rs = ts.resample("30s") + + # before use + getattr(rs, attr) + + # after grouper is initialized is ok + rs.mean() + getattr(rs, attr) + + +def tests_raises_on_nuisance(test_frame, using_infer_string): + df = test_frame + df["D"] = "foo" + r = df.resample("h") + result = r[["A", "B"]].mean() + expected = pd.concat([r.A.mean(), r.B.mean()], axis=1) + tm.assert_frame_equal(result, expected) + + expected = r[["A", "B", "C"]].mean() + msg = re.escape("agg function failed [how->mean,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'mean'" + with pytest.raises(TypeError, match=msg): + r.mean() + result = r.mean(numeric_only=True) + tm.assert_frame_equal(result, expected) + + +def test_downsample_but_actually_upsampling(): + # this is reindex / asfreq + rng = date_range("1/1/2012", periods=100, freq="s") + ts = Series(np.arange(len(rng), dtype="int64"), index=rng) + result = ts.resample("20s").asfreq() + expected = Series( + [0, 20, 40, 60, 80], + index=date_range("2012-01-01 00:00:00", freq="20s", periods=5), + ) + tm.assert_series_equal(result, expected) + + +def test_combined_up_downsampling_of_irregular(): + # since we are really doing an operation like this + # ts2.resample('2s').mean().ffill() + # preserve these semantics + + rng = date_range("1/1/2012", periods=100, freq="s", unit="ns") + ts = Series(np.arange(len(rng)), index=rng) + ts2 = ts.iloc[[0, 1, 2, 3, 5, 7, 11, 15, 16, 25, 30]] + + result = ts2.resample("2s").mean().ffill() + expected = Series( + [ + 0.5, + 2.5, + 5.0, + 7.0, + 7.0, + 11.0, + 11.0, + 15.0, + 16.0, + 16.0, + 16.0, + 16.0, + 25.0, + 25.0, + 25.0, + 30.0, + ], + index=pd.DatetimeIndex( + [ + "2012-01-01 00:00:00", + "2012-01-01 00:00:02", + "2012-01-01 00:00:04", + "2012-01-01 00:00:06", + "2012-01-01 00:00:08", + "2012-01-01 00:00:10", + "2012-01-01 00:00:12", + "2012-01-01 00:00:14", + "2012-01-01 00:00:16", + "2012-01-01 00:00:18", + "2012-01-01 00:00:20", + "2012-01-01 00:00:22", + "2012-01-01 00:00:24", + "2012-01-01 00:00:26", + "2012-01-01 00:00:28", + "2012-01-01 00:00:30", + ], + dtype="datetime64[ns]", + freq="2s", + ), + ) + tm.assert_series_equal(result, expected) + + +def test_transform_series(_test_series): + r = _test_series.resample("20min") + expected = _test_series.groupby(pd.Grouper(freq="20min")).transform("mean") + result = r.transform("mean") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("on", [None, "date"]) +def test_transform_frame(on): + # GH#47079 + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + index.name = "date" + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=list("AB"), index=index + ) + expected = df.groupby(pd.Grouper(freq="20min")).transform("mean") + if on == "date": + # Move date to being a column; result will then have a RangeIndex + expected = expected.reset_index(drop=True) + df = df.reset_index() + + r = df.resample("20min", on=on) + result = r.transform("mean") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func", + [ + lambda x: x.resample("20min", group_keys=False), + lambda x: x.groupby(pd.Grouper(freq="20min"), group_keys=False), + ], + ids=["resample", "groupby"], +) +def test_apply_without_aggregation(func, _test_series): + # both resample and groupby should work w/o aggregation + t = func(_test_series) + result = t.apply(lambda x: x) + tm.assert_series_equal(result, _test_series) + + +def test_apply_without_aggregation2(_test_series): + grouped = _test_series.to_frame(name="foo").resample("20min", group_keys=False) + result = grouped["foo"].apply(lambda x: x) + tm.assert_series_equal(result, _test_series.rename("foo")) + + +def test_agg_consistency(): + # make sure that we are consistent across + # similar aggregations with and w/o selection list + df = DataFrame( + np.random.default_rng(2).standard_normal((1000, 3)), + index=date_range("1/1/2012", freq="s", periods=1000), + columns=["A", "B", "C"], + ) + + r = df.resample("3min") + + msg = r"Label\(s\) \['r1', 'r2'\] do not exist" + with pytest.raises(KeyError, match=msg): + r.agg({"r1": "mean", "r2": "sum"}) + + +def test_agg_consistency_int_str_column_mix(): + # GH#39025 + df = DataFrame( + np.random.default_rng(2).standard_normal((1000, 2)), + index=date_range("1/1/2012", freq="s", periods=1000), + columns=[1, "a"], + ) + + r = df.resample("3min") + + msg = r"Label\(s\) \[2, 'b'\] do not exist" + with pytest.raises(KeyError, match=msg): + r.agg({2: "mean", "b": "sum"}) + + +# TODO(GH#14008): once GH 14008 is fixed, move these tests into +# `Base` test class + + +@pytest.fixture +def index(): + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D", unit="ns") + index.name = "date" + return index + + +@pytest.fixture +def df(index): + frame = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=list("AB"), index=index + ) + return frame + + +@pytest.fixture +def df_col(df): + return df.reset_index() + + +@pytest.fixture +def df_mult(df_col, index): + df_mult = df_col.copy() + df_mult.index = pd.MultiIndex.from_arrays( + [range(10), index], names=["index", "date"] + ) + return df_mult + + +@pytest.fixture +def a_mean(df): + return df.resample("2D")["A"].mean() + + +@pytest.fixture +def a_std(df): + return df.resample("2D")["A"].std() + + +@pytest.fixture +def a_sum(df): + return df.resample("2D")["A"].sum() + + +@pytest.fixture +def b_mean(df): + return df.resample("2D")["B"].mean() + + +@pytest.fixture +def b_std(df): + return df.resample("2D")["B"].std() + + +@pytest.fixture +def b_sum(df): + return df.resample("2D")["B"].sum() + + +@pytest.fixture +def df_resample(df): + return df.resample("2D") + + +@pytest.fixture +def df_col_resample(df_col): + return df_col.resample("2D", on="date") + + +@pytest.fixture +def df_mult_resample(df_mult): + return df_mult.resample("2D", level="date") + + +@pytest.fixture +def df_grouper_resample(df): + return df.groupby(pd.Grouper(freq="2D")) + + +@pytest.fixture( + params=["df_resample", "df_col_resample", "df_mult_resample", "df_grouper_resample"] +) +def cases(request): + return request.getfixturevalue(request.param) + + +def test_agg_mixed_column_aggregation(cases, a_mean, a_std, b_mean, b_std, request): + expected = pd.concat([a_mean, a_std, b_mean, b_std], axis=1) + expected.columns = pd.MultiIndex.from_product([["A", "B"], ["mean", ""]]) + # "date" is an index and a column, so get included in the agg + if "df_mult" in request.node.callspec.id: + date_mean = cases["date"].mean() + date_std = cases["date"].std() + expected = pd.concat([date_mean, date_std, expected], axis=1) + expected.columns = pd.MultiIndex.from_product( + [["date", "A", "B"], ["mean", ""]] + ) + result = cases.aggregate([np.mean, lambda x: np.std(x, ddof=1)]) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg", + [ + {"func": {"A": np.mean, "B": lambda x: np.std(x, ddof=1)}}, + {"A": ("A", np.mean), "B": ("B", lambda x: np.std(x, ddof=1))}, + {"A": NamedAgg("A", np.mean), "B": NamedAgg("B", lambda x: np.std(x, ddof=1))}, + ], +) +def test_agg_both_mean_std_named_result(cases, a_mean, b_std, agg): + expected = pd.concat([a_mean, b_std], axis=1) + result = cases.aggregate(**agg) + tm.assert_frame_equal(result, expected, check_like=True) + + +def test_agg_both_mean_std_dict_of_list(cases, a_mean, a_std): + expected = pd.concat([a_mean, a_std], axis=1) + expected.columns = pd.MultiIndex.from_tuples([("A", "mean"), ("A", "std")]) + result = cases.aggregate({"A": ["mean", "std"]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg", [{"func": ["mean", "sum"]}, {"mean": "mean", "sum": "sum"}] +) +def test_agg_both_mean_sum(cases, a_mean, a_sum, agg): + expected = pd.concat([a_mean, a_sum], axis=1) + expected.columns = ["mean", "sum"] + result = cases["A"].aggregate(**agg) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "agg", + [ + {"A": {"mean": "mean", "sum": "sum"}}, + { + "A": {"mean": "mean", "sum": "sum"}, + "B": {"mean2": "mean", "sum2": "sum"}, + }, + ], +) +def test_agg_dict_of_dict_specificationerror(cases, agg): + msg = "nested renamer is not supported" + with pytest.raises(pd.errors.SpecificationError, match=msg): + cases.aggregate(agg) + + +def test_agg_dict_of_lists(cases, a_mean, a_std, b_mean, b_std): + expected = pd.concat([a_mean, a_std, b_mean, b_std], axis=1) + expected.columns = pd.MultiIndex.from_tuples( + [("A", "mean"), ("A", "std"), ("B", "mean"), ("B", "std")] + ) + result = cases.aggregate({"A": ["mean", "std"], "B": ["mean", "std"]}) + tm.assert_frame_equal(result, expected, check_like=True) + + +@pytest.mark.parametrize( + "agg", + [ + {"func": {"A": np.sum, "B": lambda x: np.std(x, ddof=1)}}, + {"A": ("A", np.sum), "B": ("B", lambda x: np.std(x, ddof=1))}, + {"A": NamedAgg("A", np.sum), "B": NamedAgg("B", lambda x: np.std(x, ddof=1))}, + ], +) +def test_agg_with_lambda(cases, agg): + # passed lambda + rcustom = cases["B"].apply(lambda x: np.std(x, ddof=1)) + expected = pd.concat([cases["A"].sum(), rcustom], axis=1) + result = cases.agg(**agg) + tm.assert_frame_equal(result, expected, check_like=True) + + +@pytest.mark.parametrize( + "agg", + [ + {"func": {"result1": np.sum, "result2": np.mean}}, + {"A": ("result1", np.sum), "B": ("result2", np.mean)}, + {"A": NamedAgg("result1", np.sum), "B": NamedAgg("result2", np.mean)}, + ], +) +def test_agg_no_column(cases, agg): + msg = r"Label\(s\) \['result1', 'result2'\] do not exist" + with pytest.raises(KeyError, match=msg): + cases[["A", "B"]].agg(**agg) + + +@pytest.mark.parametrize( + "cols, agg", + [ + [None, {"A": ["sum", "std"], "B": ["mean", "std"]}], + [ + [ + "A", + "B", + ], + {"A": ["sum", "std"], "B": ["mean", "std"]}, + ], + ], +) +def test_agg_specificationerror_nested(cases, cols, agg, a_sum, a_std, b_mean, b_std): + # agg with different hows + # equivalent of using a selection list / or not + expected = pd.concat([a_sum, a_std, b_mean, b_std], axis=1) + expected.columns = pd.MultiIndex.from_tuples( + [("A", "sum"), ("A", "std"), ("B", "mean"), ("B", "std")] + ) + if cols is not None: + obj = cases[cols] + else: + obj = cases + + result = obj.agg(agg) + tm.assert_frame_equal(result, expected, check_like=True) + + +@pytest.mark.parametrize( + "agg", [{"A": ["sum", "std"]}, {"A": ["sum", "std"], "B": ["mean", "std"]}] +) +def test_agg_specificationerror_series(cases, agg): + msg = "nested renamer is not supported" + + # series like aggs + with pytest.raises(pd.errors.SpecificationError, match=msg): + cases["A"].agg(agg) + + +def test_agg_specificationerror_invalid_names(cases): + # errors + # invalid names in the agg specification + msg = r"Label\(s\) \['B'\] do not exist" + with pytest.raises(KeyError, match=msg): + cases[["A"]].agg({"A": ["sum", "std"], "B": ["mean", "std"]}) + + +def test_agg_nested_dicts(): + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + index.name = "date" + df = DataFrame( + np.random.default_rng(2).random((10, 2)), columns=list("AB"), index=index + ) + df_col = df.reset_index() + df_mult = df_col.copy() + df_mult.index = pd.MultiIndex.from_arrays( + [range(10), df.index], names=["index", "date"] + ) + r = df.resample("2D") + cases = [ + r, + df_col.resample("2D", on="date"), + df_mult.resample("2D", level="date"), + df.groupby(pd.Grouper(freq="2D")), + ] + + msg = "nested renamer is not supported" + for t in cases: + with pytest.raises(pd.errors.SpecificationError, match=msg): + t.aggregate({"r1": {"A": ["mean", "sum"]}, "r2": {"B": ["mean", "sum"]}}) + + for t in cases: + with pytest.raises(pd.errors.SpecificationError, match=msg): + t[["A", "B"]].agg( + {"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}} + ) + + with pytest.raises(pd.errors.SpecificationError, match=msg): + t.agg({"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}}) + + +def test_try_aggregate_non_existing_column(): + # GH 16766 + data = [ + {"dt": datetime(2017, 6, 1, 0), "x": 1.0, "y": 2.0}, + {"dt": datetime(2017, 6, 1, 1), "x": 2.0, "y": 2.0}, + {"dt": datetime(2017, 6, 1, 2), "x": 3.0, "y": 1.5}, + ] + df = DataFrame(data).set_index("dt") + + # Error as we don't have 'z' column + msg = r"Label\(s\) \['z'\] do not exist" + with pytest.raises(KeyError, match=msg): + df.resample("30min").agg({"x": ["mean"], "y": ["median"], "z": ["sum"]}) + + +def test_agg_list_like_func_with_args(): + # 50624 + df = DataFrame( + {"x": [1, 2, 3]}, index=date_range("2020-01-01", periods=3, freq="D") + ) + + def foo1(x, a=1, c=0): + return x + a + c + + def foo2(x, b=2, c=0): + return x + b + c + + msg = r"foo1\(\) got an unexpected keyword argument 'b'" + with pytest.raises(TypeError, match=msg): + df.resample("D").agg([foo1, foo2], 3, b=3, c=4) + + result = df.resample("D").agg([foo1, foo2], 3, c=4) + expected = DataFrame( + [[8, 8], [9, 9], [10, 10]], + index=date_range("2020-01-01", periods=3, freq="D"), + columns=pd.MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]), + ) + tm.assert_frame_equal(result, expected) + + +def test_selection_api_validation(): + # GH 13500 + index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D") + + rng = np.arange(len(index), dtype=np.int64) + df = DataFrame( + {"date": index, "a": rng}, + index=pd.MultiIndex.from_arrays([rng, index], names=["v", "d"]), + ) + df_exp = DataFrame({"a": rng}, index=index) + + # non DatetimeIndex + msg = ( + "Only valid with DatetimeIndex, TimedeltaIndex or PeriodIndex, " + "but got an instance of 'Index'" + ) + with pytest.raises(TypeError, match=msg): + df.resample("2D", level="v") + + msg = "The Grouper cannot specify both a key and a level!" + with pytest.raises(ValueError, match=msg): + df.resample("2D", on="date", level="d") + + msg = "unhashable type: 'list'" + with pytest.raises(TypeError, match=msg): + df.resample("2D", on=["a", "date"]) + + msg = r"\"Level \['a', 'date'\] not found\"" + with pytest.raises(KeyError, match=msg): + df.resample("2D", level=["a", "date"]) + + # upsampling not allowed + msg = ( + "Upsampling from level= or on= selection is not supported, use " + r"\.set_index\(\.\.\.\) to explicitly set index to datetime-like" + ) + with pytest.raises(ValueError, match=msg): + df.resample("2D", level="d").asfreq() + with pytest.raises(ValueError, match=msg): + df.resample("2D", on="date").asfreq() + + exp = df_exp.resample("2D").sum() + exp.index.name = "date" + result = df.resample("2D", on="date").sum() + tm.assert_frame_equal(exp, result) + + exp.index.name = "d" + with pytest.raises( + TypeError, match="datetime64 type does not support operation 'sum'" + ): + df.resample("2D", level="d").sum() + result = df.resample("2D", level="d").sum(numeric_only=True) + tm.assert_frame_equal(exp, result) + + +@pytest.mark.parametrize( + "col_name", ["t2", "t2x", "t2q", "T_2M", "t2p", "t2m", "t2m1", "T2M"] +) +def test_agg_with_datetime_index_list_agg_func(col_name): + # GH 22660 + # The parametrized column names would get converted to dates by our + # date parser. Some would result in OutOfBoundsError (ValueError) while + # others would result in OverflowError when passed into Timestamp. + # We catch these errors and move on to the correct branch. + df = DataFrame( + list(range(200)), + index=date_range( + start="2017-01-01", freq="15min", periods=200, tz="Europe/Berlin" + ), + columns=[col_name], + ) + result = df.resample("1D").aggregate(["mean"]) + expected = DataFrame( + [47.5, 143.5, 195.5], + index=date_range(start="2017-01-01", freq="D", periods=3, tz="Europe/Berlin"), + columns=pd.MultiIndex(levels=[[col_name], ["mean"]], codes=[[0], [0]]), + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_agg_readonly(): + # GH#31710 cython needs to allow readonly data + index = date_range("2020-01-01", "2020-01-02", freq="1h", unit="ns") + arr = np.zeros_like(index) + arr.setflags(write=False) + + ser = Series(arr, index=index) + rs = ser.resample("1D") + + expected = Series([pd.Timestamp(0), pd.Timestamp(0)], index=index[::24]) + expected.index.freq = Day(1) # GH#41943 no longer equivalent to 24h + + result = rs.agg("last") + tm.assert_series_equal(result, expected) + + result = rs.agg("first") + tm.assert_series_equal(result, expected) + + result = rs.agg("max") + tm.assert_series_equal(result, expected) + + result = rs.agg("min") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "start,end,freq,data,resample_freq,origin,closed,exp_data,exp_end,exp_periods", + [ + ( + "2000-10-01 23:30:00", + "2000-10-02 00:26:00", + "7min", + [0, 3, 6, 9, 12, 15, 18, 21, 24], + "17min", + "end", + None, + [0, 18, 27, 63], + "20001002 00:26:00", + 4, + ), + ( + "20200101 8:26:35", + "20200101 9:31:58", + "77s", + [1] * 51, + "7min", + "end", + "right", + [1, 6, 5, 6, 5, 6, 5, 6, 5, 6], + "2020-01-01 09:30:45", + 10, + ), + ( + "2000-10-01 23:30:00", + "2000-10-02 00:26:00", + "7min", + [0, 3, 6, 9, 12, 15, 18, 21, 24], + "17min", + "end", + "left", + [0, 18, 27, 39, 24], + "20001002 00:43:00", + 5, + ), + ( + "2000-10-01 23:30:00", + "2000-10-02 00:26:00", + "7min", + [0, 3, 6, 9, 12, 15, 18, 21, 24], + "17min", + "end_day", + None, + [3, 15, 45, 45], + "2000-10-02 00:29:00", + 4, + ), + ], +) +def test_end_and_end_day_origin( + start, + end, + freq, + data, + resample_freq, + origin, + closed, + exp_data, + exp_end, + exp_periods, +): + rng = date_range(start, end, freq=freq) + ts = Series(data, index=rng) + + res = ts.resample(resample_freq, origin=origin, closed=closed).sum() + expected = Series( + exp_data, + index=date_range(end=exp_end, freq=resample_freq, periods=exp_periods), + ) + + tm.assert_series_equal(res, expected) + + +@pytest.mark.parametrize( + # expected_data is a string when op raises a ValueError + "method, numeric_only, expected_data", + [ + ("sum", True, {"num": [25]}), + ("sum", False, {"cat": ["cat_1cat_2"], "num": [25]}), + ("sum", lib.no_default, {"cat": ["cat_1cat_2"], "num": [25]}), + ("prod", True, {"num": [100]}), + ("prod", False, "can't multiply sequence"), + ("prod", lib.no_default, "can't multiply sequence"), + ("min", True, {"num": [5]}), + ("min", False, {"cat": ["cat_1"], "num": [5]}), + ("min", lib.no_default, {"cat": ["cat_1"], "num": [5]}), + ("max", True, {"num": [20]}), + ("max", False, {"cat": ["cat_2"], "num": [20]}), + ("max", lib.no_default, {"cat": ["cat_2"], "num": [20]}), + ("first", True, {"num": [5]}), + ("first", False, {"cat": ["cat_1"], "num": [5]}), + ("first", lib.no_default, {"cat": ["cat_1"], "num": [5]}), + ("last", True, {"num": [20]}), + ("last", False, {"cat": ["cat_2"], "num": [20]}), + ("last", lib.no_default, {"cat": ["cat_2"], "num": [20]}), + ("mean", True, {"num": [12.5]}), + ("mean", False, "Could not convert"), + ("mean", lib.no_default, "Could not convert"), + ("median", True, {"num": [12.5]}), + ("median", False, r"Cannot convert \['cat_1' 'cat_2'\] to numeric"), + ("median", lib.no_default, r"Cannot convert \['cat_1' 'cat_2'\] to numeric"), + ("std", True, {"num": [10.606601717798213]}), + ("std", False, "could not convert string to float"), + ("std", lib.no_default, "could not convert string to float"), + ("var", True, {"num": [112.5]}), + ("var", False, "could not convert string to float"), + ("var", lib.no_default, "could not convert string to float"), + ("sem", True, {"num": [7.5]}), + ("sem", False, "could not convert string to float"), + ("sem", lib.no_default, "could not convert string to float"), + ], +) +def test_frame_downsample_method( + method, numeric_only, expected_data, using_infer_string +): + # GH#46442 test if `numeric_only` behave as expected for DataFrameGroupBy + + index = date_range("2018-01-01", periods=2, freq="D") + expected_index = date_range("2018-12-31", periods=1, freq="YE") + df = DataFrame({"cat": ["cat_1", "cat_2"], "num": [5, 20]}, index=index) + resampled = df.resample("YE") + if numeric_only is lib.no_default: + kwargs = {} + else: + kwargs = {"numeric_only": numeric_only} + + func = getattr(resampled, method) + if isinstance(expected_data, str): + if method in ("var", "mean", "median", "prod"): + klass = TypeError + msg = re.escape(f"agg function failed [how->{method},dtype->") + if using_infer_string: + msg = f"dtype 'str' does not support operation '{method}'" + elif method in ["sum", "std", "sem"] and using_infer_string: + klass = TypeError + msg = f"dtype 'str' does not support operation '{method}'" + else: + klass = ValueError + msg = expected_data + with pytest.raises(klass, match=msg): + _ = func(**kwargs) + else: + result = func(**kwargs) + expected = DataFrame(expected_data, index=expected_index) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "method, numeric_only, expected_data", + [ + ("sum", True, ()), + ("sum", False, ["cat_1cat_2"]), + ("sum", lib.no_default, ["cat_1cat_2"]), + ("prod", True, ()), + ("prod", False, ()), + ("prod", lib.no_default, ()), + ("min", True, ()), + ("min", False, ["cat_1"]), + ("min", lib.no_default, ["cat_1"]), + ("max", True, ()), + ("max", False, ["cat_2"]), + ("max", lib.no_default, ["cat_2"]), + ("first", True, ()), + ("first", False, ["cat_1"]), + ("first", lib.no_default, ["cat_1"]), + ("last", True, ()), + ("last", False, ["cat_2"]), + ("last", lib.no_default, ["cat_2"]), + ], +) +def test_series_downsample_method( + method, numeric_only, expected_data, using_infer_string +): + # GH#46442 test if `numeric_only` behave as expected for SeriesGroupBy + + index = date_range("2018-01-01", periods=2, freq="D") + expected_index = date_range("2018-12-31", periods=1, freq="YE") + df = Series(["cat_1", "cat_2"], index=index) + resampled = df.resample("YE") + kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only} + + func = getattr(resampled, method) + if numeric_only and numeric_only is not lib.no_default: + msg = rf"Cannot use numeric_only=True with SeriesGroupBy\.{method}" + with pytest.raises(TypeError, match=msg): + func(**kwargs) + elif method == "prod": + msg = re.escape("agg function failed [how->prod,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'prod'" + with pytest.raises(TypeError, match=msg): + func(**kwargs) + + else: + result = func(**kwargs) + expected = Series(expected_data, index=expected_index) + tm.assert_series_equal(result, expected) + + +def test_resample_empty(): + # GH#52484 + df = DataFrame( + index=pd.to_datetime( + ["2018-01-01 00:00:00", "2018-01-01 12:00:00", "2018-01-02 00:00:00"] + ) + ) + expected = DataFrame( + index=pd.to_datetime( + [ + "2018-01-01 00:00:00", + "2018-01-01 08:00:00", + "2018-01-01 16:00:00", + "2018-01-02 00:00:00", + ] + ) + ) + result = df.resample("8h").mean() + tm.assert_frame_equal(result, expected) + + +def test_asfreq_respects_origin_with_fixed_freq_all_seconds_equal(): + # GH#62725: Ensure Resampler.asfreq respects origin="start_day" + # when all datetimes share identical seconds values. + idx = [ + datetime(2025, 10, 17, 17, 15, 10), + datetime(2025, 10, 17, 17, 16, 10), + datetime(2025, 10, 17, 17, 17, 10), + ] + df = DataFrame({"value": [0, 1, 2]}, index=idx) + + result = df.resample("1min", origin="start_day").asfreq() + + # Expected index: list of Timestamps, matching dtype + exp_idx = pd.DatetimeIndex( + [ + pd.Timestamp("2025-10-17 17:15:00"), + pd.Timestamp("2025-10-17 17:16:00"), + pd.Timestamp("2025-10-17 17:17:00"), + ], + dtype=result.index.dtype, + freq="min", + ) + + exp = DataFrame({"value": [np.nan, np.nan, np.nan]}, index=exp_idx) + tm.assert_frame_equal(result, exp) diff --git a/pandas/tests/resample/test_resampler_grouper.py b/pandas/tests/resample/test_resampler_grouper.py new file mode 100644 index 0000000000000000000000000000000000000000..862578decb782af4e056f0f076b1f4cba894c200 --- /dev/null +++ b/pandas/tests/resample/test_resampler_grouper.py @@ -0,0 +1,671 @@ +from textwrap import dedent + +import numpy as np +import pytest + +from pandas.compat import is_platform_windows + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + TimedeltaIndex, + Timestamp, +) +import pandas._testing as tm +from pandas.core.indexes.datetimes import date_range + + +@pytest.fixture +def test_frame(): + return DataFrame( + {"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)}, + index=date_range("1/1/2000", freq="s", periods=40, unit="ns"), + ) + + +def test_tab_complete_ipython6_warning(ip): + from IPython.core.completer import provisionalcompleter + + code = dedent( + """\ + import numpy as np + from pandas import Series, date_range + data = np.arange(10, dtype=np.float64) + index = date_range("2020-01-01", periods=len(data)) + s = Series(data, index=index) + rs = s.resample("D") + """ + ) + ip.run_cell(code) + + # GH 31324 newer jedi version raises Deprecation warning; + # appears resolved 2021-02-02 + with tm.assert_produces_warning(None, raise_on_extra_warnings=False): + with provisionalcompleter("ignore"): + list(ip.Completer.completions("rs.", 1)) + + +def test_deferred_with_groupby(): + # GH 12486 + # support deferred resample ops with groupby + data = [ + ["2010-01-01", "A", 2], + ["2010-01-02", "A", 3], + ["2010-01-05", "A", 8], + ["2010-01-10", "A", 7], + ["2010-01-13", "A", 3], + ["2010-01-01", "B", 5], + ["2010-01-03", "B", 2], + ["2010-01-04", "B", 1], + ["2010-01-11", "B", 7], + ["2010-01-14", "B", 3], + ] + + df = DataFrame(data, columns=["date", "id", "score"]) + df.date = pd.to_datetime(df.date) + + def f_0(x): + return x.set_index("date").resample("D").asfreq() + + expected = df.groupby("id").apply(f_0) + result = df.set_index("date").groupby("id").resample("D").asfreq() + tm.assert_frame_equal(result, expected) + + df = DataFrame( + { + "date": date_range(start="2016-01-01", periods=4, freq="W"), + "group": [1, 1, 2, 2], + "val": [5, 6, 7, 8], + } + ).set_index("date") + + def f_1(x): + return x.resample("1D").ffill() + + expected = df.groupby("group").apply(f_1) + result = df.groupby("group").resample("1D").ffill() + tm.assert_frame_equal(result, expected) + + +def test_getitem(test_frame): + g = test_frame.groupby("A") + + expected = g.B.apply(lambda x: x.resample("2s").mean()) + + result = g.resample("2s").B.mean() + tm.assert_series_equal(result, expected) + + result = g.B.resample("2s").mean() + tm.assert_series_equal(result, expected) + + result = g.resample("2s").mean().B + tm.assert_series_equal(result, expected) + + +def test_getitem_multiple(): + # GH 13174 + # multiple calls after selection causing an issue with aliasing + data = [{"id": 1, "buyer": "A"}, {"id": 2, "buyer": "B"}] + df = DataFrame(data, index=date_range("2016-01-01", periods=2)) + r = df.groupby("id").resample("1D") + result = r["buyer"].count() + + exp_mi = pd.MultiIndex.from_arrays([[1, 2], df.index], names=("id", None)) + expected = Series( + [1, 1], + index=exp_mi, + name="buyer", + ) + tm.assert_series_equal(result, expected) + + result = r["buyer"].count() + tm.assert_series_equal(result, expected) + + +def test_groupby_resample_on_api_with_getitem(): + # GH 17813 + df = DataFrame( + {"id": list("aabbb"), "date": date_range("1-1-2016", periods=5), "data": 1} + ) + exp = df.set_index("date").groupby("id").resample("2D")["data"].sum() + result = df.groupby("id").resample("2D", on="date")["data"].sum() + tm.assert_series_equal(result, exp) + + +def test_groupby_with_origin(): + # GH 31809 + + freq = "1399min" # prime number that is smaller than 24h + start, end = "1/1/2000 00:00:00", "1/31/2000 00:00" + middle = "1/15/2000 00:00:00" + + rng = date_range(start, end, freq="1231min") # prime number + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + ts2 = ts[middle:end] + + # proves that grouper without a fixed origin does not work + # when dealing with unusual frequencies + simple_grouper = pd.Grouper(freq=freq) + count_ts = ts.groupby(simple_grouper).agg("count") + count_ts = count_ts[middle:end] + count_ts2 = ts2.groupby(simple_grouper).agg("count") + with pytest.raises(AssertionError, match="Index are different"): + tm.assert_index_equal(count_ts.index, count_ts2.index) + + # test origin on 1970-01-01 00:00:00 + origin = Timestamp(0) + adjusted_grouper = pd.Grouper(freq=freq, origin=origin) + adjusted_count_ts = ts.groupby(adjusted_grouper).agg("count") + adjusted_count_ts = adjusted_count_ts[middle:end] + adjusted_count_ts2 = ts2.groupby(adjusted_grouper).agg("count") + tm.assert_series_equal(adjusted_count_ts, adjusted_count_ts2) + + # test origin on 2049-10-18 20:00:00 + origin_future = Timestamp(0) + pd.Timedelta("1399min") * 30_000 + adjusted_grouper2 = pd.Grouper(freq=freq, origin=origin_future) + adjusted2_count_ts = ts.groupby(adjusted_grouper2).agg("count") + adjusted2_count_ts = adjusted2_count_ts[middle:end] + adjusted2_count_ts2 = ts2.groupby(adjusted_grouper2).agg("count") + tm.assert_series_equal(adjusted2_count_ts, adjusted2_count_ts2) + + # both grouper use an adjusted timestamp that is a multiple of 1399 min + # they should be equals even if the adjusted_timestamp is in the future + tm.assert_series_equal(adjusted_count_ts, adjusted2_count_ts2) + + +def test_nearest(): + # GH 17496 + # Resample nearest + index = date_range("1/1/2000", periods=3, freq="min", unit="ns") + result = Series(range(3), index=index).resample("20s").nearest() + + expected = Series( + [0, 0, 1, 1, 1, 2, 2], + index=pd.DatetimeIndex( + [ + "2000-01-01 00:00:00", + "2000-01-01 00:00:20", + "2000-01-01 00:00:40", + "2000-01-01 00:01:00", + "2000-01-01 00:01:20", + "2000-01-01 00:01:40", + "2000-01-01 00:02:00", + ], + dtype="datetime64[ns]", + freq="20s", + ), + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "f", + [ + "first", + "last", + "median", + "sem", + "sum", + "mean", + "min", + "max", + "size", + "count", + "nearest", + "bfill", + "ffill", + "asfreq", + "ohlc", + ], +) +def test_methods(f, test_frame): + g = test_frame.groupby("A") + r = g.resample("2s") + + result = getattr(r, f)() + expected = g.apply(lambda x: getattr(x.resample("2s"), f)()) + tm.assert_equal(result, expected) + + +def test_methods_nunique(test_frame): + # series only + g = test_frame.groupby("A") + r = g.resample("2s") + result = r.B.nunique() + expected = g.B.apply(lambda x: x.resample("2s").nunique()) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("f", ["std", "var"]) +def test_methods_std_var(f, test_frame): + g = test_frame.groupby("A") + r = g.resample("2s") + result = getattr(r, f)(ddof=1) + expected = g.apply(lambda x: getattr(x.resample("2s"), f)(ddof=1)) + tm.assert_frame_equal(result, expected) + + +def test_apply(test_frame): + g = test_frame.groupby("A") + r = g.resample("2s") + + # reduction + expected = g.resample("2s").sum() + + def f_0(x): + return x.resample("2s").sum() + + result = r.apply(f_0) + tm.assert_frame_equal(result, expected) + + def f_1(x): + return x.resample("2s").apply(lambda y: y.sum()) + + result = g.apply(f_1) + tm.assert_frame_equal(result, expected) + + +def test_apply_with_mutated_index(): + # GH 15169 + index = date_range("1-1-2015", "12-31-15", freq="D") + df = DataFrame( + data={"col1": np.random.default_rng(2).random(len(index))}, index=index + ) + + def f(x): + s = Series([1, 2], index=["a", "b"]) + return s + + expected = df.groupby(pd.Grouper(freq="ME")).apply(f) + + result = df.resample("ME").apply(f) + tm.assert_frame_equal(result, expected) + + # A case for series + expected = df["col1"].groupby(pd.Grouper(freq="ME"), group_keys=False).apply(f) + result = df["col1"].resample("ME").apply(f) + tm.assert_series_equal(result, expected) + + +def test_apply_columns_multilevel(): + # GH 16231 + cols = pd.MultiIndex.from_tuples([("A", "a", "", "one"), ("B", "b", "i", "two")]) + ind = date_range(start="2017-01-01", freq="15Min", periods=8) + df = DataFrame( + np.array([0] * 16, dtype=np.int64).reshape(8, 2), index=ind, columns=cols + ) + agg_dict = {col: (np.sum if col[3] == "one" else np.mean) for col in df.columns} + result = df.resample("h").apply(lambda x: agg_dict[x.name](x)) + expected = DataFrame( + 2 * [[0, 0.0]], + index=date_range(start="2017-01-01", freq="1h", periods=2), + columns=pd.MultiIndex.from_tuples( + [("A", "a", "", "one"), ("B", "b", "i", "two")] + ), + ) + tm.assert_frame_equal(result, expected) + + +def test_apply_non_naive_index(): + def weighted_quantile(series, weights, q): + series = series.sort_values() + cumsum = weights.reindex(series.index).fillna(0).cumsum() + cutoff = cumsum.iloc[-1] * q + return series[cumsum >= cutoff].iloc[0] + + times = date_range("2017-6-23 18:00", periods=8, freq="15min", tz="UTC") + data = Series([1.0, 1, 1, 1, 1, 2, 2, 0], index=times) + weights = Series([160.0, 91, 65, 43, 24, 10, 1, 0], index=times) + + result = data.resample("D").apply(weighted_quantile, weights=weights, q=0.5) + ind = date_range( + "2017-06-23 00:00:00+00:00", "2017-06-23 00:00:00+00:00", freq="D", tz="UTC" + ) + expected = Series([1.0], index=ind) + tm.assert_series_equal(result, expected) + + +def test_resample_groupby_with_label(unit): + # GH 13235 + index = date_range("2000-01-01", freq="2D", periods=5, unit=unit) + df = DataFrame(index=index, data={"col0": [0, 0, 1, 1, 2], "col1": [1, 1, 1, 1, 1]}) + result = df.groupby("col0").resample("1W", label="left").sum() + + mi = [ + np.array([0, 0, 1, 2], dtype=np.int64), + np.array( + ["1999-12-26", "2000-01-02", "2000-01-02", "2000-01-02"], + dtype=f"M8[{unit}]", + ), + ] + mindex = pd.MultiIndex.from_arrays(mi, names=["col0", None]) + expected = DataFrame(data={"col1": [1, 1, 2, 1]}, index=mindex) + + tm.assert_frame_equal(result, expected) + + +def test_consistency_with_window(test_frame): + # consistent return values with window + df = test_frame + expected = Index([1, 2, 3], name="A") + result = df.groupby("A").resample("2s").mean() + assert result.index.nlevels == 2 + tm.assert_index_equal(result.index.levels[0], expected) + + result = df.groupby("A").rolling(20).mean() + assert result.index.nlevels == 2 + tm.assert_index_equal(result.index.levels[0], expected) + + +def test_median_duplicate_columns(): + # GH 14233 + + df = DataFrame( + np.random.default_rng(2).standard_normal((20, 3)), + columns=list("aaa"), + index=date_range("2012-01-01", periods=20, freq="s"), + ) + result = df.resample("5s").median() + df.columns = ["a", "b", "c"] + expected = df.resample("5s").median() + expected.columns = result.columns + tm.assert_frame_equal(result, expected) + + +def test_apply_to_one_column_of_df(): + # GH: 36951 + df = DataFrame( + {"col": range(10), "col1": range(10, 20)}, + index=date_range("2012-01-01", periods=10, freq="20min"), + ) + + # access "col" via getattr -> make sure we handle AttributeError + result = df.resample("h").apply(lambda group: group.col.sum()) + expected = Series( + [3, 12, 21, 9], index=date_range("2012-01-01", periods=4, freq="h") + ) + tm.assert_series_equal(result, expected) + + # access "col" via _getitem__ -> make sure we handle KeyErrpr + result = df.resample("h").apply(lambda group: group["col"].sum()) + tm.assert_series_equal(result, expected) + + +def test_resample_groupby_agg(): + # GH: 33548 + df = DataFrame( + { + "cat": [ + "cat_1", + "cat_1", + "cat_2", + "cat_1", + "cat_2", + "cat_1", + "cat_2", + "cat_1", + ], + "num": [5, 20, 22, 3, 4, 30, 10, 50], + "date": [ + "2019-2-1", + "2018-02-03", + "2020-3-11", + "2019-2-2", + "2019-2-2", + "2018-12-4", + "2020-3-11", + "2020-12-12", + ], + } + ) + df["date"] = pd.to_datetime(df["date"]) + + resampled = df.groupby("cat").resample("YE", on="date") + expected = resampled[["num"]].sum() + result = resampled.agg({"num": "sum"}) + + tm.assert_frame_equal(result, expected) + + +def test_resample_groupby_agg_listlike(): + # GH 42905 + ts = Timestamp("2021-02-28 00:00:00") + df = DataFrame({"class": ["beta"], "value": [69]}, index=Index([ts], name="date")) + resampled = df.groupby("class").resample("ME")["value"] + result = resampled.agg(["sum", "size"]) + expected = DataFrame( + [[69, 1]], + index=pd.MultiIndex.from_tuples([("beta", ts)], names=["class", "date"]), + columns=["sum", "size"], + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("keys", [["a"], ["a", "b"]]) +def test_empty(keys): + # GH 26411 + df = DataFrame([], columns=["a", "b"], index=TimedeltaIndex([])) + result = df.groupby(keys).resample(rule=pd.to_timedelta("00:00:01")).mean() + expected_columns = ["b"] if keys == ["a"] else [] + expected = ( + DataFrame(columns=["a", "b"]) + .set_index(keys, drop=False) + .set_index(TimedeltaIndex([]), append=True)[expected_columns] + ) + if len(keys) == 1: + expected.index.name = keys[0] + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("consolidate", [True, False]) +def test_resample_groupby_agg_object_dtype_all_nan(consolidate): + # https://github.com/pandas-dev/pandas/issues/39329 + + dates = date_range("2020-01-01", periods=15, freq="D", unit="ns") + df1 = DataFrame({"key": "A", "date": dates, "col1": range(15), "col_object": "val"}) + df2 = DataFrame({"key": "B", "date": dates, "col1": range(15)}) + df = pd.concat([df1, df2], ignore_index=True) + if consolidate: + df = df._consolidate() + + result = df.groupby(["key"]).resample("W", on="date").min() + idx = pd.MultiIndex.from_arrays( + [ + ["A"] * 3 + ["B"] * 3, + pd.to_datetime(["2020-01-05", "2020-01-12", "2020-01-19"] * 2).as_unit( + "ns" + ), + ], + names=["key", "date"], + ) + expected = DataFrame( + { + "col1": [0, 5, 12] * 2, + "col_object": ["val"] * 3 + [np.nan] * 3, + }, + index=idx, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("min_count", [0, 1]) +def test_groupby_resample_empty_sum_string( + string_dtype_no_object, test_frame, min_count +): + # https://github.com/pandas-dev/pandas/issues/60229 + dtype = string_dtype_no_object + test_frame = test_frame.assign(B=pd.array([pd.NA] * len(test_frame), dtype=dtype)) + gbrs = test_frame.groupby("A").resample("40s") + result = gbrs.sum(min_count=min_count) + + index = pd.MultiIndex( + levels=[[1, 2, 3], [pd.to_datetime("2000-01-01", unit="ns").as_unit("ns")]], + codes=[[0, 1, 2], [0, 0, 0]], + names=["A", None], + ) + value = "" if min_count == 0 else pd.NA + expected = DataFrame({"B": value}, index=index, dtype=dtype) + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_with_list_of_keys(): + # GH 47362 + df = DataFrame( + data={ + "date": date_range(start="2016-01-01", periods=8), + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "val": [1, 7, 5, 2, 3, 10, 5, 1], + } + ) + result = df.groupby("group").resample("2D", on="date")[["val"]].mean() + + mi_exp = pd.MultiIndex.from_arrays( + [[0, 0, 1, 1], df["date"]._values[::2]], names=["group", "date"] + ) + expected = DataFrame( + data={ + "val": [4.0, 3.5, 6.5, 3.0], + }, + index=mi_exp, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("keys", [["a"], ["a", "b"]]) +def test_resample_no_index(keys): + # GH 47705 + df = DataFrame([], columns=["a", "b", "date"]) + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index("date") + result = df.groupby(keys).resample(rule=pd.to_timedelta("00:00:01")).mean() + expected_columns = ["b"] if keys == ["a"] else [] + expected = DataFrame(columns=["a", "b", "date"]).set_index(keys, drop=False) + expected["date"] = pd.to_datetime(expected["date"]) + expected = expected.set_index("date", append=True, drop=True)[expected_columns] + if len(keys) == 1: + expected.index.name = keys[0] + + tm.assert_frame_equal(result, expected) + + +def test_resample_no_columns(): + # GH#52484 + df = DataFrame( + index=Index( + pd.to_datetime( + ["2018-01-01 00:00:00", "2018-01-01 12:00:00", "2018-01-02 00:00:00"] + ), + name="date", + ) + ) + result = df.groupby([0, 0, 1]).resample(rule=pd.to_timedelta("06:00:00")).mean() + index = pd.to_datetime( + [ + "2018-01-01 00:00:00", + "2018-01-01 06:00:00", + "2018-01-01 12:00:00", + "2018-01-02 00:00:00", + ] + ) + expected = DataFrame( + index=pd.MultiIndex( + levels=[np.array([0, 1], dtype=np.intp), index], + codes=[[0, 0, 0, 1], [0, 1, 2, 3]], + names=[None, "date"], + ) + ) + + # GH#52710 - Index comes out as 32-bit on 64-bit Windows + tm.assert_frame_equal(result, expected, check_index_type=not is_platform_windows()) + + +def test_groupby_resample_size_all_index_same(): + # GH 46826 + df = DataFrame( + {"A": [1] * 3 + [2] * 3 + [1] * 3 + [2] * 3, "B": np.arange(12)}, + index=date_range("31/12/2000 18:00", freq="h", periods=12, unit="ns"), + ) + result = df.groupby("A").resample("D").size() + + mi_exp = pd.MultiIndex.from_arrays( + [ + [1, 1, 2, 2], + pd.DatetimeIndex(["2000-12-31", "2001-01-01"] * 2, dtype="M8[ns]"), + ], + names=["A", None], + ) + expected = Series( + 3, + index=mi_exp, + ) + tm.assert_series_equal(result, expected) + + +def test_groupby_resample_on_index_with_list_of_keys(): + # GH 50840 + df = DataFrame( + data={ + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "val": [3, 1, 4, 1, 5, 9, 2, 6], + }, + index=date_range(start="2016-01-01", periods=8, name="date"), + ) + result = df.groupby("group").resample("2D")[["val"]].mean() + + mi_exp = pd.MultiIndex.from_arrays( + [[0, 0, 1, 1], df.index[::2]], names=["group", "date"] + ) + expected = DataFrame( + data={ + "val": [2.0, 2.5, 7.0, 4.0], + }, + index=mi_exp, + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_on_index_with_list_of_keys_multi_columns(): + # GH 50876 + df = DataFrame( + data={ + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "first_val": [3, 1, 4, 1, 5, 9, 2, 6], + "second_val": [2, 7, 1, 8, 2, 8, 1, 8], + "third_val": [1, 4, 1, 4, 2, 1, 3, 5], + }, + index=date_range(start="2016-01-01", periods=8, name="date"), + ) + result = df.groupby("group").resample("2D")[["first_val", "second_val"]].mean() + + mi_exp = pd.MultiIndex.from_arrays( + [[0, 0, 1, 1], df.index[::2]], names=["group", "date"] + ) + expected = DataFrame( + data={ + "first_val": [2.0, 2.5, 7.0, 4.0], + "second_val": [4.5, 4.5, 5.0, 4.5], + }, + index=mi_exp, + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_on_index_with_list_of_keys_missing_column(): + # GH 50876 + df = DataFrame( + data={ + "group": [0, 0, 0, 0, 1, 1, 1, 1], + "val": [3, 1, 4, 1, 5, 9, 2, 6], + }, + index=Series( + date_range(start="2016-01-01", periods=8), + name="date", + ), + ) + gb = df.groupby("group") + rs = gb.resample("2D") + with pytest.raises(KeyError, match="Columns not found"): + rs[["val_not_in_dataframe"]] diff --git a/pandas/tests/resample/test_time_grouper.py b/pandas/tests/resample/test_time_grouper.py new file mode 100644 index 0000000000000000000000000000000000000000..e214a9f17824dbfc4dc00896b651d9113fdac472 --- /dev/null +++ b/pandas/tests/resample/test_time_grouper.py @@ -0,0 +1,439 @@ +from datetime import datetime +from operator import methodcaller + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + Timestamp, +) +import pandas._testing as tm +from pandas.core.groupby.grouper import Grouper +from pandas.core.indexes.datetimes import date_range + + +@pytest.fixture +def test_series(): + return Series( + np.random.default_rng(2).standard_normal(1000), + index=date_range("1/1/2000", periods=1000), + ) + + +def test_apply(test_series): + grouper = Grouper(freq="YE", label="right", closed="right") + + grouped = test_series.groupby(grouper) + + def f(x): + return x.sort_values()[-3:] + + applied = grouped.apply(f) + expected = test_series.groupby(lambda x: x.year).apply(f) + + applied.index = applied.index.droplevel(0) + expected.index = expected.index.droplevel(0) + tm.assert_series_equal(applied, expected) + + +def test_count(test_series): + test_series[::3] = np.nan + + expected = test_series.groupby(lambda x: x.year).count() + + grouper = Grouper(freq="YE", label="right", closed="right") + result = test_series.groupby(grouper).count() + expected.index = result.index + tm.assert_series_equal(result, expected) + + result = test_series.resample("YE").count() + expected.index = result.index + tm.assert_series_equal(result, expected) + + +def test_numpy_reduction(test_series): + result = test_series.resample("YE", closed="right").prod() + expected = test_series.groupby(lambda x: x.year).agg(np.prod) + expected.index = result.index + tm.assert_series_equal(result, expected) + + +def test_apply_iteration(): + # #2300 + N = 1000 + ind = date_range(start="2000-01-01", freq="D", periods=N) + df = DataFrame({"open": 1, "close": 2}, index=ind) + tg = Grouper(freq="ME") + + grouper, _ = tg._get_grouper(df) + + # Errors + grouped = df.groupby(grouper, group_keys=False) + + def f(df): + return df["close"] / df["open"] + + # it works! + result = grouped.apply(f) + tm.assert_index_equal(result.index, df.index) + + +@pytest.mark.parametrize( + "index", + [ + Index([1, 2]), + Index(["a", "b"]), + Index([1.1, 2.2]), + pd.MultiIndex.from_arrays([[1, 2], ["a", "b"]]), + ], +) +def test_fails_on_no_datetime_index(index): + name = type(index).__name__ + df = DataFrame({"a": range(len(index))}, index=index) + + msg = ( + "Only valid with DatetimeIndex, TimedeltaIndex " + f"or PeriodIndex, but got an instance of '{name}'" + ) + with pytest.raises(TypeError, match=msg): + df.groupby(Grouper(freq="D")) + + +def test_aaa_group_order(): + # GH 12840 + # check TimeGrouper perform stable sorts + n = 20 + data = np.random.default_rng(2).standard_normal((n, 4)) + df = DataFrame(data, columns=["A", "B", "C", "D"]) + df["key"] = [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + datetime(2013, 1, 3), + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] * 4 + grouped = df.groupby(Grouper(key="key", freq="D")) + + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 1)), df[::5]) + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 2)), df[1::5]) + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 3)), df[2::5]) + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 4)), df[3::5]) + tm.assert_frame_equal(grouped.get_group(datetime(2013, 1, 5)), df[4::5]) + + +def test_aggregate_normal(resample_method): + """Check TimeGrouper's aggregation is identical as normal groupby.""" + + data = np.random.default_rng(2).standard_normal((20, 4)) + normal_df = DataFrame(data, columns=["A", "B", "C", "D"]) + normal_df["key"] = [1, 2, 3, 4, 5] * 4 + + dt_df = DataFrame(data, columns=["A", "B", "C", "D"]) + dt_df["key"] = Index( + [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + datetime(2013, 1, 3), + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] + * 4, + dtype="M8[ns]", + ) + + normal_grouped = normal_df.groupby("key") + dt_grouped = dt_df.groupby(Grouper(key="key", freq="D")) + + expected = getattr(normal_grouped, resample_method)() + dt_result = getattr(dt_grouped, resample_method)() + expected.index = date_range( + start="2013-01-01", freq="D", periods=5, unit="ns", name="key" + ) + tm.assert_equal(expected, dt_result) + + +@pytest.mark.xfail(reason="if TimeGrouper is used included, 'nth' doesn't work yet") +def test_aggregate_nth(): + """Check TimeGrouper's aggregation is identical as normal groupby.""" + + data = np.random.default_rng(2).standard_normal((20, 4)) + normal_df = DataFrame(data, columns=["A", "B", "C", "D"]) + normal_df["key"] = [1, 2, 3, 4, 5] * 4 + + dt_df = DataFrame(data, columns=["A", "B", "C", "D"]) + dt_df["key"] = [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + datetime(2013, 1, 3), + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] * 4 + + normal_grouped = normal_df.groupby("key") + dt_grouped = dt_df.groupby(Grouper(key="key", freq="D")) + + expected = normal_grouped.nth(3) + expected.index = date_range(start="2013-01-01", freq="D", periods=5, name="key") + dt_result = dt_grouped.nth(3) + tm.assert_frame_equal(expected, dt_result) + + +@pytest.mark.parametrize( + "method, method_args, unit", + [ + ("sum", {}, 0), + ("sum", {"min_count": 0}, 0), + ("sum", {"min_count": 1}, np.nan), + ("prod", {}, 1), + ("prod", {"min_count": 0}, 1), + ("prod", {"min_count": 1}, np.nan), + ], +) +def test_resample_entirely_nat_window(method, method_args, unit): + ser = Series([0] * 2 + [np.nan] * 2, index=date_range("2017", periods=4, unit="ns")) + result = methodcaller(method, **method_args)(ser.resample("2D")) + + exp_dti = pd.DatetimeIndex(["2017-01-01", "2017-01-03"], dtype="M8[ns]", freq="2D") + expected = Series([0.0, unit], index=exp_dti) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func, fill_value", + [("min", np.nan), ("max", np.nan), ("sum", 0), ("prod", 1), ("count", 0)], +) +def test_aggregate_with_nat(func, fill_value): + # check TimeGrouper's aggregation is identical as normal groupby + # if NaT is included, 'var', 'std', 'mean', 'first','last' + # and 'nth' doesn't work yet + + n = 20 + data = np.random.default_rng(2).standard_normal((n, 4)).astype("int64") + normal_df = DataFrame(data, columns=["A", "B", "C", "D"]) + normal_df["key"] = [1, 2, np.nan, 4, 5] * 4 + + dt_df = DataFrame(data, columns=["A", "B", "C", "D"]) + dt_df["key"] = Index( + [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + pd.NaT, + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] + * 4, + dtype="M8[ns]", + ) + + normal_grouped = normal_df.groupby("key") + dt_grouped = dt_df.groupby(Grouper(key="key", freq="D")) + + normal_result = getattr(normal_grouped, func)() + dt_result = getattr(dt_grouped, func)() + + pad = DataFrame([[fill_value] * 4], index=[3], columns=["A", "B", "C", "D"]) + expected = pd.concat([normal_result, pad]) + expected = expected.sort_index() + dti = date_range( + start="2013-01-01", + freq="D", + periods=5, + name="key", + unit=dt_df["key"]._values.unit, + ) + expected.index = dti._with_freq(None) # TODO: is this desired? + tm.assert_frame_equal(expected, dt_result) + assert dt_result.index.name == "key" + + +def test_aggregate_with_nat_size(): + # GH 9925 + n = 20 + data = np.random.default_rng(2).standard_normal((n, 4)).astype("int64") + normal_df = DataFrame(data, columns=["A", "B", "C", "D"]) + normal_df["key"] = [1, 2, np.nan, 4, 5] * 4 + + dt_df = DataFrame(data, columns=["A", "B", "C", "D"]) + dt_df["key"] = Index( + [ + datetime(2013, 1, 1), + datetime(2013, 1, 2), + pd.NaT, + datetime(2013, 1, 4), + datetime(2013, 1, 5), + ] + * 4, + dtype="M8[ns]", + ) + + normal_grouped = normal_df.groupby("key") + dt_grouped = dt_df.groupby(Grouper(key="key", freq="D")) + + normal_result = normal_grouped.size() + dt_result = dt_grouped.size() + + pad = Series([0], index=[3]) + expected = pd.concat([normal_result, pad]) + expected = expected.sort_index() + expected.index = date_range( + start="2013-01-01", + freq="D", + periods=5, + name="key", + unit=dt_df["key"]._values.unit, + )._with_freq(None) + tm.assert_series_equal(expected, dt_result) + assert dt_result.index.name == "key" + + +def test_repr(): + # GH18203 + result = repr(Grouper(key="A", freq="h")) + expected = ( + "TimeGrouper(key='A', freq=, sort=True, dropna=True, " + "closed='left', label='left', how='mean', " + "convention='e', origin='start_day')" + ) + assert result == expected + + result = repr(Grouper(key="A", freq="h", origin="2000-01-01")) + expected = ( + "TimeGrouper(key='A', freq=, sort=True, dropna=True, " + "closed='left', label='left', how='mean', " + "convention='e', origin=Timestamp('2000-01-01 00:00:00'))" + ) + assert result == expected + + +@pytest.mark.parametrize( + "method, method_args, expected_values", + [ + ("sum", {}, [1, 0, 1]), + ("sum", {"min_count": 0}, [1, 0, 1]), + ("sum", {"min_count": 1}, [1, np.nan, 1]), + ("sum", {"min_count": 2}, [np.nan, np.nan, np.nan]), + ("prod", {}, [1, 1, 1]), + ("prod", {"min_count": 0}, [1, 1, 1]), + ("prod", {"min_count": 1}, [1, np.nan, 1]), + ("prod", {"min_count": 2}, [np.nan, np.nan, np.nan]), + ], +) +def test_upsample_sum(method, method_args, expected_values): + ser = Series(1, index=date_range("2017", periods=2, freq="h", unit="ns")) + resampled = ser.resample("30min") + index = pd.DatetimeIndex( + ["2017-01-01T00:00:00", "2017-01-01T00:30:00", "2017-01-01T01:00:00"], + dtype="M8[ns]", + freq="30min", + ) + result = methodcaller(method, **method_args)(resampled) + expected = Series(expected_values, index=index) + tm.assert_series_equal(result, expected) + + +@pytest.fixture +def groupy_test_df(): + return DataFrame( + {"price": [10, 11, 9], "volume": [50, 60, 50]}, + index=date_range("01/01/2018", periods=3, freq="W", unit="ns"), + ) + + +def test_groupby_resample_interpolate_raises(groupy_test_df): + # GH 35325 + + # Make a copy of the test data frame that has index.name=None + groupy_test_df_without_index_name = groupy_test_df.copy() + groupy_test_df_without_index_name.index.name = None + + dfs = [groupy_test_df, groupy_test_df_without_index_name] + + for df in dfs: + with pytest.raises( + NotImplementedError, + match="Direct interpolation of MultiIndex data frames is not supported", + ): + df.groupby("volume").resample("1D").interpolate(method="linear") + + +def test_groupby_resample_interpolate_with_apply_syntax(groupy_test_df): + # GH 35325 + + # Make a copy of the test data frame that has index.name=None + groupy_test_df_without_index_name = groupy_test_df.copy() + groupy_test_df_without_index_name.index.name = None + + dfs = [groupy_test_df, groupy_test_df_without_index_name] + + for df in dfs: + result = df.groupby("volume").apply( + lambda x: x.resample("1D").interpolate(method="linear"), + ) + + volume = [50] * 15 + [60] + week_starting = [ + *list(date_range("2018-01-07", "2018-01-21", unit="ns")), + Timestamp("2018-01-14"), + ] + expected_ind = pd.MultiIndex.from_arrays( + [volume, week_starting], + names=["volume", df.index.name], + ) + + expected = DataFrame( + data={ + "price": [ + 10.0, + 9.928571428571429, + 9.857142857142858, + 9.785714285714286, + 9.714285714285714, + 9.642857142857142, + 9.571428571428571, + 9.5, + 9.428571428571429, + 9.357142857142858, + 9.285714285714286, + 9.214285714285714, + 9.142857142857142, + 9.071428571428571, + 9.0, + 11.0, + ] + }, + index=expected_ind, + ) + tm.assert_frame_equal(result, expected) + + +def test_groupby_resample_interpolate_with_apply_syntax_off_grid(groupy_test_df): + """Similar test as test_groupby_resample_interpolate_with_apply_syntax but + with resampling that results in missing anchor points when interpolating. + See GH#21351.""" + # GH#21351 + result = groupy_test_df.groupby("volume").apply( + lambda x: x.resample("265h").interpolate(method="linear") + ) + + volume = [50, 50, 60] + week_starting = pd.DatetimeIndex( + [ + Timestamp("2018-01-07"), + Timestamp("2018-01-18 01:00:00"), + Timestamp("2018-01-14"), + ] + ).as_unit("ns") + expected_ind = pd.MultiIndex.from_arrays( + [volume, week_starting], + names=["volume", "week_starting"], + ) + + expected = DataFrame( + data={"price": [10.0, 9.5, 11.0]}, + index=expected_ind, + ) + tm.assert_frame_equal(result, expected, check_names=False) diff --git a/pandas/tests/resample/test_timedelta.py b/pandas/tests/resample/test_timedelta.py new file mode 100644 index 0000000000000000000000000000000000000000..3bec66e3a1aa2c06bbcdac1be6a982a89f25d814 --- /dev/null +++ b/pandas/tests/resample/test_timedelta.py @@ -0,0 +1,218 @@ +from datetime import timedelta + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm +from pandas.core.indexes.timedeltas import timedelta_range + + +def test_asfreq_bug(): + df = DataFrame(data=[1, 3], index=[timedelta(), timedelta(minutes=3)]) + result = df.resample("1min").asfreq() + expected = DataFrame( + data=[1, np.nan, np.nan, 3], + index=timedelta_range("0 day", periods=4, freq="1min", unit="us"), + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_with_nat(): + # GH 13223 + index = pd.to_timedelta(["0s", pd.NaT, "2s"]) + result = DataFrame({"value": [2, 3, 5]}, index).resample("1s").mean() + expected = DataFrame( + {"value": [2.5, np.nan, 5.0]}, + index=timedelta_range("0 day", periods=3, freq="1s"), + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_as_freq_with_subperiod(): + # GH 13022 + index = timedelta_range("00:00:00", "00:10:00", freq="5min") + df = DataFrame(data={"value": [1, 5, 10]}, index=index) + result = df.resample("2min").asfreq() + expected_data = {"value": [1, np.nan, np.nan, np.nan, np.nan, 10]} + expected = DataFrame( + data=expected_data, index=timedelta_range("00:00:00", "00:10:00", freq="2min") + ) + tm.assert_frame_equal(result, expected) + + +def test_resample_with_timedeltas(): + expected = DataFrame({"A": np.arange(1480)}) + expected = expected.groupby(expected.index // 30).sum() + expected.index = timedelta_range("0 days", freq="30min", periods=50) + + df = DataFrame( + {"A": np.arange(1480)}, + index=pd.to_timedelta(np.arange(1480), unit="min").as_unit("us"), + ) + result = df.resample("30min").sum() + + tm.assert_frame_equal(result, expected) + + s = df["A"] + result = s.resample("30min").sum() + tm.assert_series_equal(result, expected["A"]) + + +def test_resample_single_period_timedelta(): + s = Series(list(range(5)), index=timedelta_range("1 day", freq="s", periods=5)) + result = s.resample("2s").sum() + expected = Series([1, 5, 4], index=timedelta_range("1 day", freq="2s", periods=3)) + tm.assert_series_equal(result, expected) + + +def test_resample_timedelta_idempotency(): + # GH 12072 + index = timedelta_range("0", periods=9, freq="10ms") + series = Series(range(9), index=index) + result = series.resample("10ms").mean() + expected = series.astype(float) + tm.assert_series_equal(result, expected) + + +def test_resample_offset_with_timedeltaindex(): + # GH 10530 & 31809 + rng = timedelta_range(start="0s", periods=25, freq="s") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + with_base = ts.resample("2s", offset="5s").mean() + without_base = ts.resample("2s").mean() + + exp_without_base = timedelta_range(start="0s", end="25s", freq="2s") + exp_with_base = timedelta_range(start="5s", end="29s", freq="2s") + + tm.assert_index_equal(without_base.index, exp_without_base) + tm.assert_index_equal(with_base.index, exp_with_base) + + +def test_resample_categorical_data_with_timedeltaindex(): + # GH #12169 + df = DataFrame({"Group_obj": "A"}, index=pd.to_timedelta(list(range(20)), unit="s")) + df["Group"] = df["Group_obj"].astype("category") + result = df.resample("10s").agg(lambda x: (x.value_counts().index[0])) + exp_tdi = pd.TimedeltaIndex(np.array([0, 10], dtype="m8[s]"), freq="10s") + expected = DataFrame( + {"Group_obj": ["A", "A"], "Group": ["A", "A"]}, + index=exp_tdi, + ) + expected = expected.reindex(["Group_obj", "Group"], axis=1) + expected["Group"] = expected["Group_obj"].astype("category") + tm.assert_frame_equal(result, expected) + + +def test_resample_timedelta_values(): + # GH 13119 + # check that timedelta dtype is preserved when NaT values are + # introduced by the resampling + + times = timedelta_range("1 day", "6 day", freq="4D") + df = DataFrame({"time": times}, index=times) + + times2 = timedelta_range("1 day", "6 day", freq="2D") + exp = Series(times2, index=times2, name="time") + exp.iloc[1] = pd.NaT + + res = df.resample("2D").first()["time"] + tm.assert_series_equal(res, exp) + res = df["time"].resample("2D").first() + tm.assert_series_equal(res, exp) + + +@pytest.mark.parametrize( + "start, end, freq, resample_freq", + [ + ("8h", "21h59min50s", "10s", "3h"), # GH 30353 example + ("3h", "22h", "1h", "5h"), + ("527D", "5006D", "3D", "10D"), + ("1D", "10D", "1D", "2D"), # GH 13022 example + # tests that worked before GH 33498: + ("8h", "21h59min50s", "10s", "2h"), + ("0h", "21h59min50s", "10s", "3h"), + ("10D", "85D", "D", "2D"), + ], +) +def test_resample_timedelta_edge_case(start, end, freq, resample_freq): + # GH 33498 + # check that the timedelta bins does not contains an extra bin + idx = timedelta_range(start=start, end=end, freq=freq) + s = Series(np.arange(len(idx)), index=idx) + result = s.resample(resample_freq).min() + expected_index = timedelta_range(freq=resample_freq, start=start, end=end) + tm.assert_index_equal(result.index, expected_index) + assert result.index.freq == expected_index.freq + assert not np.isnan(result.iloc[-1]) + + +@pytest.mark.parametrize("duplicates", [True, False]) +def test_resample_with_timedelta_yields_no_empty_groups(duplicates): + # GH 10603 + df = DataFrame( + np.random.default_rng(2).normal(size=(10000, 4)), + index=timedelta_range(start="0s", periods=10000, freq="3906250ns"), + ) + if duplicates: + # case with non-unique columns + df.columns = ["A", "B", "A", "C"] + + result = df.loc["1s":, :].resample("3s").apply(lambda x: len(x)) + + expected = DataFrame( + [[768] * 4] * 12 + [[528] * 4], + index=timedelta_range(start="1s", periods=13, freq="3s", unit="ns"), + ) + expected.columns = df.columns + tm.assert_frame_equal(result, expected) + + +def test_resample_quantile_timedelta(unit): + # GH: 29485 + dtype = np.dtype(f"m8[{unit}]") + df = DataFrame( + {"value": pd.to_timedelta(np.arange(4), unit="s").astype(dtype)}, + index=pd.date_range("20200101", periods=4, tz="UTC"), + ) + result = df.resample("2D").quantile(0.99) + expected = DataFrame( + { + "value": [ + pd.Timedelta("0 days 00:00:00.990000"), + pd.Timedelta("0 days 00:00:02.990000"), + ] + }, + index=pd.date_range("20200101", periods=2, tz="UTC", freq="2D"), + ).astype(dtype) + tm.assert_frame_equal(result, expected) + + +def test_resample_closed_right(): + # GH#45414 + idx = pd.Index([pd.Timedelta(seconds=120 + i * 30) for i in range(10)]) + ser = Series(range(10), index=idx) + result = ser.resample("min", closed="right", label="right").sum() + expected = Series( + [0, 3, 7, 11, 15, 9], + index=pd.TimedeltaIndex( + [pd.Timedelta(seconds=120 + i * 60) for i in range(6)], freq="min" + ), + ) + tm.assert_series_equal(result, expected) + + +@td.skip_if_no("pyarrow") +def test_arrow_duration_resample(): + # GH 56371 + idx = pd.Index(timedelta_range("1 day", periods=5), dtype="duration[ns][pyarrow]") + expected = Series(np.arange(5, dtype=np.float64), index=idx) + result = expected.resample("1D").mean() + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/reshape/__init__.py b/pandas/tests/reshape/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/reshape/test_crosstab.py b/pandas/tests/reshape/test_crosstab.py new file mode 100644 index 0000000000000000000000000000000000000000..1482da8a074eb41b64d276683ffc7258b4e9d0bb --- /dev/null +++ b/pandas/tests/reshape/test_crosstab.py @@ -0,0 +1,879 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + CategoricalDtype, + CategoricalIndex, + DataFrame, + Index, + MultiIndex, + Series, + crosstab, +) +import pandas._testing as tm + + +@pytest.fixture +def df(): + df = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + return pd.concat([df, df], ignore_index=True) + + +class TestCrosstab: + def test_crosstab_single(self, df): + result = crosstab(df["A"], df["C"]) + expected = df.groupby(["A", "C"]).size().unstack() + tm.assert_frame_equal(result, expected.fillna(0).astype(np.int64)) + + def test_crosstab_multiple(self, df): + result = crosstab(df["A"], [df["B"], df["C"]]) + expected = df.groupby(["A", "B", "C"]).size() + expected = expected.unstack("B").unstack("C").fillna(0).astype(np.int64) + tm.assert_frame_equal(result, expected) + + result = crosstab([df["B"], df["C"]], df["A"]) + expected = df.groupby(["B", "C", "A"]).size() + expected = expected.unstack("A").fillna(0).astype(np.int64) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("box", [np.array, list, tuple]) + def test_crosstab_ndarray(self, box): + # GH 44076 + a = box(np.random.default_rng(2).integers(0, 5, size=100)) + b = box(np.random.default_rng(2).integers(0, 3, size=100)) + c = box(np.random.default_rng(2).integers(0, 10, size=100)) + + df = DataFrame({"a": a, "b": b, "c": c}) + + result = crosstab(a, [b, c], rownames=["a"], colnames=("b", "c")) + expected = crosstab(df["a"], [df["b"], df["c"]]) + tm.assert_frame_equal(result, expected) + + result = crosstab([b, c], a, colnames=["a"], rownames=("b", "c")) + expected = crosstab([df["b"], df["c"]], df["a"]) + tm.assert_frame_equal(result, expected) + + # assign arbitrary names + result = crosstab(a, c) + expected = crosstab(df["a"], df["c"]) + expected.index.names = ["row_0"] + expected.columns.names = ["col_0"] + tm.assert_frame_equal(result, expected) + + def test_crosstab_non_aligned(self): + # GH 17005 + a = Series([0, 1, 1], index=["a", "b", "c"]) + b = Series([3, 4, 3, 4, 3], index=["a", "b", "c", "d", "f"]) + c = np.array([3, 4, 3], dtype=np.int64) + + expected = DataFrame( + [[1, 0], [1, 1]], + index=Index([0, 1], name="row_0"), + columns=Index([3, 4], name="col_0"), + ) + + result = crosstab(a, b) + tm.assert_frame_equal(result, expected) + + result = crosstab(a, c) + tm.assert_frame_equal(result, expected) + + def test_crosstab_margins(self): + a = np.random.default_rng(2).integers(0, 7, size=100) + b = np.random.default_rng(2).integers(0, 3, size=100) + c = np.random.default_rng(2).integers(0, 5, size=100) + + df = DataFrame({"a": a, "b": b, "c": c}) + + result = crosstab(a, [b, c], rownames=["a"], colnames=("b", "c"), margins=True) + + assert result.index.names == ("a",) + assert result.columns.names == ["b", "c"] + + all_cols = result["All", ""] + exp_cols = df.groupby(["a"]).size().astype("i8") + # to keep index.name + exp_margin = Series([len(df)], index=Index(["All"], name="a")) + exp_cols = pd.concat([exp_cols, exp_margin]) + exp_cols.name = ("All", "") + + tm.assert_series_equal(all_cols, exp_cols) + + all_rows = result.loc["All"] + exp_rows = df.groupby(["b", "c"]).size().astype("i8") + exp_rows = pd.concat([exp_rows, Series([len(df)], index=[("All", "")])]) + exp_rows.name = "All" + + exp_rows = exp_rows.reindex(all_rows.index) + exp_rows = exp_rows.fillna(0).astype(np.int64) + tm.assert_series_equal(all_rows, exp_rows) + + def test_crosstab_margins_set_margin_name(self): + # GH 15972 + a = np.random.default_rng(2).integers(0, 7, size=100) + b = np.random.default_rng(2).integers(0, 3, size=100) + c = np.random.default_rng(2).integers(0, 5, size=100) + + df = DataFrame({"a": a, "b": b, "c": c}) + + result = crosstab( + a, + [b, c], + rownames=["a"], + colnames=("b", "c"), + margins=True, + margins_name="TOTAL", + ) + + assert result.index.names == ("a",) + assert result.columns.names == ["b", "c"] + + all_cols = result["TOTAL", ""] + exp_cols = df.groupby(["a"]).size().astype("i8") + # to keep index.name + exp_margin = Series([len(df)], index=Index(["TOTAL"], name="a")) + exp_cols = pd.concat([exp_cols, exp_margin]) + exp_cols.name = ("TOTAL", "") + + tm.assert_series_equal(all_cols, exp_cols) + + all_rows = result.loc["TOTAL"] + exp_rows = df.groupby(["b", "c"]).size().astype("i8") + exp_rows = pd.concat([exp_rows, Series([len(df)], index=[("TOTAL", "")])]) + exp_rows.name = "TOTAL" + + exp_rows = exp_rows.reindex(all_rows.index) + exp_rows = exp_rows.fillna(0).astype(np.int64) + tm.assert_series_equal(all_rows, exp_rows) + + msg = "margins_name argument must be a string" + for margins_name in [666, None, ["a", "b"]]: + with pytest.raises(ValueError, match=msg): + crosstab( + a, + [b, c], + rownames=["a"], + colnames=("b", "c"), + margins=True, + margins_name=margins_name, + ) + + def test_crosstab_pass_values(self): + a = np.random.default_rng(2).integers(0, 7, size=100) + b = np.random.default_rng(2).integers(0, 3, size=100) + c = np.random.default_rng(2).integers(0, 5, size=100) + values = np.random.default_rng(2).standard_normal(100) + + table = crosstab( + [a, b], c, values, aggfunc="sum", rownames=["foo", "bar"], colnames=["baz"] + ) + + df = DataFrame({"foo": a, "bar": b, "baz": c, "values": values}) + + expected = df.pivot_table( + "values", index=["foo", "bar"], columns="baz", aggfunc="sum" + ) + tm.assert_frame_equal(table, expected) + + def test_crosstab_dropna(self): + # GH 3820 + a = np.array(["foo", "foo", "foo", "bar", "bar", "foo", "foo"], dtype=object) + b = np.array(["one", "one", "two", "one", "two", "two", "two"], dtype=object) + c = np.array( + ["dull", "dull", "dull", "dull", "dull", "shiny", "shiny"], dtype=object + ) + res = crosstab(a, [b, c], rownames=["a"], colnames=["b", "c"], dropna=False) + m = MultiIndex.from_tuples( + [("one", "dull"), ("one", "shiny"), ("two", "dull"), ("two", "shiny")], + names=["b", "c"], + ) + tm.assert_index_equal(res.columns, m) + + def test_crosstab_no_overlap(self): + # GS 10291 + + s1 = Series([1, 2, 3], index=[1, 2, 3]) + s2 = Series([4, 5, 6], index=[4, 5, 6]) + + actual = crosstab(s1, s2) + expected = DataFrame( + index=Index([], dtype="int64", name="row_0"), + columns=Index([], dtype="int64", name="col_0"), + ) + + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna(self): + # GH 12577 + # pivot_table counts null into margin ('All') + # when margins=true and dropna=true + + df = DataFrame({"a": [1, 2, 2, 2, 2, np.nan], "b": [3, 3, 4, 4, 4, 4]}) + actual = crosstab(df.a, df.b, margins=True, dropna=True) + expected = DataFrame([[1, 0, 1], [1, 3, 4], [2, 3, 5]]) + expected.index = Index([1.0, 2.0, "All"], name="a") + expected.columns = Index([3, 4, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna2(self): + df = DataFrame( + {"a": [1, np.nan, np.nan, np.nan, 2, np.nan], "b": [3, np.nan, 4, 4, 4, 4]} + ) + actual = crosstab(df.a, df.b, margins=True, dropna=True) + expected = DataFrame([[1, 0, 1], [0, 1, 1], [1, 1, 2]]) + expected.index = Index([1.0, 2.0, "All"], name="a") + expected.columns = Index([3.0, 4.0, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna3(self): + df = DataFrame( + {"a": [1, np.nan, np.nan, np.nan, np.nan, 2], "b": [3, 3, 4, 4, 4, 4]} + ) + actual = crosstab(df.a, df.b, margins=True, dropna=True) + expected = DataFrame([[1, 0, 1], [0, 1, 1], [1, 1, 2]]) + expected.index = Index([1.0, 2.0, "All"], name="a") + expected.columns = Index([3, 4, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna4(self): + # GH 12642 + # _add_margins raises KeyError: Level None not found + # when margins=True and dropna=False + # GH: 10772: Keep np.nan in result with dropna=False + df = DataFrame({"a": [1, 2, 2, 2, 2, np.nan], "b": [3, 3, 4, 4, 4, 4]}) + actual = crosstab(df.a, df.b, margins=True, dropna=False) + expected = DataFrame([[1, 0, 1], [1, 3, 4], [0, 1, 1], [2, 4, 6]]) + expected.index = Index([1.0, 2.0, np.nan, "All"], name="a") + expected.columns = Index([3, 4, "All"], name="b") + tm.assert_frame_equal(actual, expected) + + def test_margin_dropna5(self): + # GH: 10772: Keep np.nan in result with dropna=False + df = DataFrame( + {"a": [1, np.nan, np.nan, np.nan, 2, np.nan], "b": [3, np.nan, 4, 4, 4, 4]} + ) + actual = crosstab(df.a, df.b, margins=True, dropna=False) + expected = DataFrame( + [[1, 0, 0, 1.0], [0, 1, 0, 1.0], [0, 3, 1, 4.0], [1, 4, 1, 6.0]] + ) + expected.index = Index([1.0, 2.0, np.nan, "All"], name="a") + expected.columns = Index([3.0, 4.0, np.nan, "All"], name="b") + tm.assert_frame_equal(actual, expected, check_dtype=False) + + def test_margin_dropna6(self): + # GH: 10772: Keep np.nan in result with dropna=False + a = np.array(["foo", "foo", "foo", "bar", "bar", "foo", "foo"], dtype=object) + b = np.array(["one", "one", "two", "one", "two", np.nan, "two"], dtype=object) + c = np.array( + ["dull", "dull", "dull", "dull", "dull", "shiny", "shiny"], dtype=object + ) + + actual = crosstab( + a, [b, c], rownames=["a"], colnames=["b", "c"], margins=True, dropna=False + ) + m = MultiIndex.from_arrays( + [ + ["one", "one", "two", "two", np.nan, np.nan, "All"], + ["dull", "shiny", "dull", "shiny", "dull", "shiny", ""], + ], + names=["b", "c"], + ) + expected = DataFrame( + [[1, 0, 1, 0, 0, 0, 2], [2, 0, 1, 1, 0, 1, 5], [3, 0, 2, 1, 0, 1, 7]], + columns=m, + ) + expected.index = Index(["bar", "foo", "All"], name="a") + tm.assert_frame_equal(actual, expected) + + actual = crosstab( + [a, b], c, rownames=["a", "b"], colnames=["c"], margins=True, dropna=False + ) + m = MultiIndex.from_arrays( + [ + ["bar", "bar", "bar", "foo", "foo", "foo", "All"], + ["one", "two", np.nan, "one", "two", np.nan, ""], + ], + names=["a", "b"], + ) + expected = DataFrame( + [ + [1, 0, 1.0], + [1, 0, 1.0], + [0, 0, np.nan], + [2, 0, 2.0], + [1, 1, 2.0], + [0, 1, 1.0], + [5, 2, 7.0], + ], + index=m, + ) + expected.columns = Index(["dull", "shiny", "All"], name="c") + tm.assert_frame_equal(actual, expected) + + actual = crosstab( + [a, b], c, rownames=["a", "b"], colnames=["c"], margins=True, dropna=True + ) + m = MultiIndex.from_arrays( + [["bar", "bar", "foo", "foo", "All"], ["one", "two", "one", "two", ""]], + names=["a", "b"], + ) + expected = DataFrame( + [[1, 0, 1], [1, 0, 1], [2, 0, 2], [1, 1, 2], [5, 1, 6]], index=m + ) + expected.columns = Index(["dull", "shiny", "All"], name="c") + tm.assert_frame_equal(actual, expected) + + def test_crosstab_normalize(self): + # Issue 12578 + df = DataFrame( + {"a": [1, 2, 2, 2, 2], "b": [3, 3, 4, 4, 4], "c": [1, 1, np.nan, 1, 1]} + ) + + rindex = Index([1, 2], name="a") + cindex = Index([3, 4], name="b") + full_normal = DataFrame([[0.2, 0], [0.2, 0.6]], index=rindex, columns=cindex) + row_normal = DataFrame([[1.0, 0], [0.25, 0.75]], index=rindex, columns=cindex) + col_normal = DataFrame([[0.5, 0], [0.5, 1.0]], index=rindex, columns=cindex) + + # Check all normalize args + tm.assert_frame_equal(crosstab(df.a, df.b, normalize="all"), full_normal) + tm.assert_frame_equal(crosstab(df.a, df.b, normalize=True), full_normal) + tm.assert_frame_equal(crosstab(df.a, df.b, normalize="index"), row_normal) + tm.assert_frame_equal(crosstab(df.a, df.b, normalize="columns"), col_normal) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize=1), + crosstab(df.a, df.b, normalize="columns"), + ) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize=0), crosstab(df.a, df.b, normalize="index") + ) + + row_normal_margins = DataFrame( + [[1.0, 0], [0.25, 0.75], [0.4, 0.6]], + index=Index([1, 2, "All"], name="a", dtype="object"), + columns=Index([3, 4], name="b", dtype="object"), + ) + col_normal_margins = DataFrame( + [[0.5, 0, 0.2], [0.5, 1.0, 0.8]], + index=Index([1, 2], name="a", dtype="object"), + columns=Index([3, 4, "All"], name="b", dtype="object"), + ) + + all_normal_margins = DataFrame( + [[0.2, 0, 0.2], [0.2, 0.6, 0.8], [0.4, 0.6, 1]], + index=Index([1, 2, "All"], name="a", dtype="object"), + columns=Index([3, 4, "All"], name="b", dtype="object"), + ) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize="index", margins=True), row_normal_margins + ) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize="columns", margins=True), col_normal_margins + ) + tm.assert_frame_equal( + crosstab(df.a, df.b, normalize=True, margins=True), all_normal_margins + ) + + def test_crosstab_normalize_arrays(self): + # GH#12578 + df = DataFrame( + {"a": [1, 2, 2, 2, 2], "b": [3, 3, 4, 4, 4], "c": [1, 1, np.nan, 1, 1]} + ) + + # Test arrays + crosstab( + [np.array([1, 1, 2, 2]), np.array([1, 2, 1, 2])], np.array([1, 2, 1, 2]) + ) + + # Test with aggfunc + norm_counts = DataFrame( + [[0.25, 0, 0.25], [0.25, 0.5, 0.75], [0.5, 0.5, 1]], + index=Index([1, 2, "All"], name="a", dtype="object"), + columns=Index([3, 4, "All"], name="b"), + ) + test_case = crosstab( + df.a, df.b, df.c, aggfunc="count", normalize="all", margins=True + ) + tm.assert_frame_equal(test_case, norm_counts) + + df = DataFrame( + {"a": [1, 2, 2, 2, 2], "b": [3, 3, 4, 4, 4], "c": [0, 4, np.nan, 3, 3]} + ) + + norm_sum = DataFrame( + [[0, 0, 0.0], [0.4, 0.6, 1], [0.4, 0.6, 1]], + index=Index([1, 2, "All"], name="a", dtype="object"), + columns=Index([3, 4, "All"], name="b", dtype="object"), + ) + test_case = crosstab( + df.a, df.b, df.c, aggfunc=np.sum, normalize="all", margins=True + ) + tm.assert_frame_equal(test_case, norm_sum) + + def test_crosstab_with_empties(self): + # Check handling of empties + df = DataFrame( + { + "a": [1, 2, 2, 2, 2], + "b": [3, 3, 4, 4, 4], + "c": [np.nan, np.nan, np.nan, np.nan, np.nan], + } + ) + + empty = DataFrame( + [[0.0, 0.0], [0.0, 0.0]], + index=Index([1, 2], name="a", dtype="int64"), + columns=Index([3, 4], name="b"), + ) + + for i in [True, "index", "columns"]: + calculated = crosstab(df.a, df.b, values=df.c, aggfunc="count", normalize=i) + tm.assert_frame_equal(empty, calculated) + + nans = DataFrame( + [[0.0, np.nan], [0.0, 0.0]], + index=Index([1, 2], name="a", dtype="int64"), + columns=Index([3, 4], name="b"), + ) + + calculated = crosstab(df.a, df.b, values=df.c, aggfunc="count", normalize=False) + tm.assert_frame_equal(nans, calculated) + + def test_crosstab_errors(self): + # Issue 12578 + + df = DataFrame( + {"a": [1, 2, 2, 2, 2], "b": [3, 3, 4, 4, 4], "c": [1, 1, np.nan, 1, 1]} + ) + + error = "values cannot be used without an aggfunc." + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, values=df.c) + + error = "aggfunc cannot be used without values" + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, aggfunc=np.mean) + + error = "Not a valid normalize argument" + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, normalize="42") + + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, normalize=42) + + error = "Not a valid margins argument" + with pytest.raises(ValueError, match=error): + crosstab(df.a, df.b, normalize="all", margins=42) + + def test_crosstab_with_categorial_columns(self): + # GH 8860 + df = DataFrame( + { + "MAKE": ["Honda", "Acura", "Tesla", "Honda", "Honda", "Acura"], + "MODEL": ["Sedan", "Sedan", "Electric", "Pickup", "Sedan", "Sedan"], + } + ) + categories = ["Sedan", "Electric", "Pickup"] + df["MODEL"] = df["MODEL"].astype("category").cat.set_categories(categories) + result = crosstab(df["MAKE"], df["MODEL"]) + + expected_index = Index(["Acura", "Honda", "Tesla"], name="MAKE") + expected_columns = CategoricalIndex( + categories, categories=categories, ordered=False, name="MODEL" + ) + expected_data = [[2, 0, 0], [2, 0, 1], [0, 1, 0]] + expected = DataFrame( + expected_data, index=expected_index, columns=expected_columns + ) + tm.assert_frame_equal(result, expected) + + def test_crosstab_with_numpy_size(self): + # GH 4003 + df = DataFrame( + { + "A": ["one", "one", "two", "three"] * 6, + "B": ["A", "B", "C"] * 8, + "C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4, + "D": np.random.default_rng(2).standard_normal(24), + "E": np.random.default_rng(2).standard_normal(24), + } + ) + result = crosstab( + index=[df["A"], df["B"]], + columns=[df["C"]], + margins=True, + aggfunc=np.size, + values=df["D"], + ) + expected_index = MultiIndex( + levels=[["All", "one", "three", "two"], ["", "A", "B", "C"]], + codes=[[1, 1, 1, 2, 2, 2, 3, 3, 3, 0], [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]], + names=["A", "B"], + ) + expected_column = Index(["bar", "foo", "All"], name="C") + expected_data = np.array( + [ + [2.0, 2.0, 4.0], + [2.0, 2.0, 4.0], + [2.0, 2.0, 4.0], + [2.0, np.nan, 2.0], + [np.nan, 2.0, 2.0], + [2.0, np.nan, 2.0], + [np.nan, 2.0, 2.0], + [2.0, np.nan, 2.0], + [np.nan, 2.0, 2.0], + [12.0, 12.0, 24.0], + ] + ) + expected = DataFrame( + expected_data, index=expected_index, columns=expected_column + ) + # aggfunc is np.size, resulting in integers + expected["All"] = expected["All"].astype("int64") + tm.assert_frame_equal(result, expected) + + def test_crosstab_duplicate_names(self): + # GH 13279 / 22529 + + s1 = Series(range(3), name="foo") + s2_foo = Series(range(1, 4), name="foo") + s2_bar = Series(range(1, 4), name="bar") + s3 = Series(range(3), name="waldo") + + # check result computed with duplicate labels against + # result computed with unique labels, then relabelled + mapper = {"bar": "foo"} + + # duplicate row, column labels + result = crosstab(s1, s2_foo) + expected = crosstab(s1, s2_bar).rename_axis(columns=mapper, axis=1) + tm.assert_frame_equal(result, expected) + + # duplicate row, unique column labels + result = crosstab([s1, s2_foo], s3) + expected = crosstab([s1, s2_bar], s3).rename_axis(index=mapper, axis=0) + tm.assert_frame_equal(result, expected) + + # unique row, duplicate column labels + result = crosstab(s3, [s1, s2_foo]) + expected = crosstab(s3, [s1, s2_bar]).rename_axis(columns=mapper, axis=1) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("names", [["a", ("b", "c")], [("a", "b"), "c"]]) + def test_crosstab_tuple_name(self, names): + s1 = Series(range(3), name=names[0]) + s2 = Series(range(1, 4), name=names[1]) + + mi = MultiIndex.from_arrays([range(3), range(1, 4)], names=names) + expected = Series(1, index=mi).unstack(1, fill_value=0) + + result = crosstab(s1, s2) + tm.assert_frame_equal(result, expected) + + def test_crosstab_both_tuple_names(self): + # GH 18321 + s1 = Series(range(3), name=("a", "b")) + s2 = Series(range(3), name=("c", "d")) + + expected = DataFrame( + np.eye(3, dtype="int64"), + index=Index(range(3), name=("a", "b")), + columns=Index(range(3), name=("c", "d")), + ) + result = crosstab(s1, s2) + tm.assert_frame_equal(result, expected) + + def test_crosstab_unsorted_order(self): + df = DataFrame({"b": [3, 1, 2], "a": [5, 4, 6]}, index=["C", "A", "B"]) + result = crosstab(df.index, [df.b, df.a]) + e_idx = Index(["A", "B", "C"], name="row_0") + e_columns = MultiIndex.from_tuples([(1, 4), (2, 6), (3, 5)], names=["b", "a"]) + expected = DataFrame( + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], index=e_idx, columns=e_columns + ) + tm.assert_frame_equal(result, expected) + + def test_crosstab_normalize_multiple_columns(self): + # GH 15150 + df = DataFrame( + { + "A": ["one", "one", "two", "three"] * 6, + "B": ["A", "B", "C"] * 8, + "C": ["foo", "foo", "foo", "bar", "bar", "bar"] * 4, + "D": [0] * 24, + "E": [0] * 24, + } + ) + + result = crosstab( + [df.A, df.B], + df.C, + values=df.D, + aggfunc=np.sum, + normalize=True, + margins=True, + ) + expected = DataFrame( + np.array([0] * 29 + [1], dtype=float).reshape(10, 3), + columns=Index(["bar", "foo", "All"], name="C"), + index=MultiIndex.from_tuples( + [ + ("one", "A"), + ("one", "B"), + ("one", "C"), + ("three", "A"), + ("three", "B"), + ("three", "C"), + ("two", "A"), + ("two", "B"), + ("two", "C"), + ("All", ""), + ], + names=["A", "B"], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_margin_normalize(self): + # GH 27500 + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + } + ) + # normalize on index + result = crosstab( + [df.A, df.B], df.C, margins=True, margins_name="Sub-Total", normalize=0 + ) + expected = DataFrame( + [[0.5, 0.5], [0.5, 0.5], [0.666667, 0.333333], [0, 1], [0.444444, 0.555556]] + ) + expected.index = MultiIndex( + levels=[["Sub-Total", "bar", "foo"], ["", "one", "two"]], + codes=[[1, 1, 2, 2, 0], [1, 2, 1, 2, 0]], + names=["A", "B"], + ) + expected.columns = Index(["large", "small"], name="C") + tm.assert_frame_equal(result, expected) + + # normalize on columns + result = crosstab( + [df.A, df.B], df.C, margins=True, margins_name="Sub-Total", normalize=1 + ) + expected = DataFrame( + [ + [0.25, 0.2, 0.222222], + [0.25, 0.2, 0.222222], + [0.5, 0.2, 0.333333], + [0, 0.4, 0.222222], + ] + ) + expected.columns = Index(["large", "small", "Sub-Total"], name="C") + expected.index = MultiIndex( + levels=[["bar", "foo"], ["one", "two"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + names=["A", "B"], + ) + tm.assert_frame_equal(result, expected) + + # normalize on both index and column + result = crosstab( + [df.A, df.B], df.C, margins=True, margins_name="Sub-Total", normalize=True + ) + expected = DataFrame( + [ + [0.111111, 0.111111, 0.222222], + [0.111111, 0.111111, 0.222222], + [0.222222, 0.111111, 0.333333], + [0.000000, 0.222222, 0.222222], + [0.444444, 0.555555, 1], + ] + ) + expected.columns = Index(["large", "small", "Sub-Total"], name="C") + expected.index = MultiIndex( + levels=[["Sub-Total", "bar", "foo"], ["", "one", "two"]], + codes=[[1, 1, 2, 2, 0], [1, 2, 1, 2, 0]], + names=["A", "B"], + ) + tm.assert_frame_equal(result, expected) + + def test_margin_normalize_multiple_columns(self): + # GH 35144 + # use multiple columns with margins and normalization + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + } + ) + result = crosstab( + index=df.C, + columns=[df.A, df.B], + margins=True, + margins_name="margin", + normalize=True, + ) + expected = DataFrame( + [ + [0.111111, 0.111111, 0.222222, 0.000000, 0.444444], + [0.111111, 0.111111, 0.111111, 0.222222, 0.555556], + [0.222222, 0.222222, 0.333333, 0.222222, 1.0], + ], + index=["large", "small", "margin"], + ) + expected.columns = MultiIndex( + levels=[["bar", "foo", "margin"], ["", "one", "two"]], + codes=[[0, 0, 1, 1, 2], [1, 2, 1, 2, 0]], + names=["A", "B"], + ) + expected.index.name = "C" + tm.assert_frame_equal(result, expected) + + def test_margin_support_Float(self): + # GH 50313 + # use Float64 formats and function aggfunc with margins + df = DataFrame( + {"A": [1, 2, 2, 1], "B": [3, 3, 4, 5], "C": [-1.0, 10.0, 1.0, 10.0]}, + dtype="Float64", + ) + result = crosstab( + df["A"], + df["B"], + values=df["C"], + aggfunc="sum", + margins=True, + ) + expected = DataFrame( + [ + [-1.0, pd.NA, 10.0, 9.0], + [10.0, 1.0, pd.NA, 11.0], + [9.0, 1.0, 10.0, 20.0], + ], + index=Index([1.0, 2.0, "All"], dtype="object", name="A"), + columns=Index([3.0, 4.0, 5.0, "All"], dtype="object", name="B"), + dtype="Float64", + ) + tm.assert_frame_equal(result, expected) + + def test_margin_with_ordered_categorical_column(self): + # GH 25278 + df = DataFrame( + { + "First": ["B", "B", "C", "A", "B", "C"], + "Second": ["C", "B", "B", "B", "C", "A"], + } + ) + df["First"] = df["First"].astype(CategoricalDtype(ordered=True)) + customized_categories_order = ["C", "A", "B"] + df["First"] = df["First"].cat.reorder_categories(customized_categories_order) + result = crosstab(df["First"], df["Second"], margins=True) + + expected_index = Index(["C", "A", "B", "All"], name="First") + expected_columns = Index(["A", "B", "C", "All"], name="Second") + expected_data = [[1, 1, 0, 2], [0, 1, 0, 1], [0, 1, 2, 3], [1, 3, 2, 6]] + expected = DataFrame( + expected_data, index=expected_index, columns=expected_columns + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("a_dtype", ["category", "int64"]) +@pytest.mark.parametrize("b_dtype", ["category", "int64"]) +def test_categoricals(a_dtype, b_dtype): + # https://github.com/pandas-dev/pandas/issues/37465 + g = np.random.default_rng(2) + a = Series(g.integers(0, 3, size=100)).astype(a_dtype) + b = Series(g.integers(0, 2, size=100)).astype(b_dtype) + result = crosstab(a, b, margins=True, dropna=False) + columns = Index([0, 1, "All"], dtype="object", name="col_0") + index = Index([0, 1, 2, "All"], dtype="object", name="row_0") + values = [[10, 18, 28], [23, 16, 39], [17, 16, 33], [50, 50, 100]] + expected = DataFrame(values, index, columns) + tm.assert_frame_equal(result, expected) + + # Verify when categorical does not have all values present + a.loc[a == 1] = 2 + a_is_cat = isinstance(a.dtype, CategoricalDtype) + assert not a_is_cat or a.value_counts().loc[1] == 0 + result = crosstab(a, b, margins=True, dropna=False) + values = [[10, 18, 28], [0, 0, 0], [40, 32, 72], [50, 50, 100]] + expected = DataFrame(values, index, columns) + if not a_is_cat: + expected = expected.loc[[0, 2, "All"]] + expected["All"] = expected["All"].astype("int64") + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/reshape/test_cut.py b/pandas/tests/reshape/test_cut.py new file mode 100644 index 0000000000000000000000000000000000000000..909c10d3f73b20bf55ceba1663e89fd0c692cb30 --- /dev/null +++ b/pandas/tests/reshape/test_cut.py @@ -0,0 +1,828 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DatetimeIndex, + Index, + Interval, + IntervalIndex, + Series, + TimedeltaIndex, + Timestamp, + cut, + date_range, + interval_range, + isna, + qcut, + timedelta_range, + to_datetime, +) +import pandas._testing as tm +from pandas.api.types import CategoricalDtype +import pandas.core.reshape.tile as tmod + + +def test_simple(): + data = np.ones(5, dtype="int64") + result = cut(data, 4, labels=False) + + expected = np.array([1, 1, 1, 1, 1]) + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + +@pytest.mark.parametrize("func", [list, np.array]) +def test_bins(func): + data = func([0.2, 1.4, 2.5, 6.2, 9.7, 2.1]) + result, bins = cut(data, 3, retbins=True) + + intervals = IntervalIndex.from_breaks(bins.round(3)) + intervals = intervals.take([0, 0, 0, 1, 2, 0]) + expected = Categorical(intervals, ordered=True) + + tm.assert_categorical_equal(result, expected) + tm.assert_almost_equal(bins, np.array([0.1905, 3.36666667, 6.53333333, 9.7])) + + +def test_right(): + data = np.array([0.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) + result, bins = cut(data, 4, right=True, retbins=True) + + intervals = IntervalIndex.from_breaks(bins.round(3)) + expected = Categorical(intervals, ordered=True) + expected = expected.take([0, 0, 0, 2, 3, 0, 0]) + + tm.assert_categorical_equal(result, expected) + tm.assert_almost_equal(bins, np.array([0.1905, 2.575, 4.95, 7.325, 9.7])) + + +def test_no_right(): + data = np.array([0.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) + result, bins = cut(data, 4, right=False, retbins=True) + + intervals = IntervalIndex.from_breaks(bins.round(3), closed="left") + intervals = intervals.take([0, 0, 0, 2, 3, 0, 1]) + expected = Categorical(intervals, ordered=True) + + tm.assert_categorical_equal(result, expected) + tm.assert_almost_equal(bins, np.array([0.2, 2.575, 4.95, 7.325, 9.7095])) + + +def test_bins_from_interval_index(): + c = cut(range(5), 3) + expected = c + result = cut(range(5), bins=expected.categories) + tm.assert_categorical_equal(result, expected) + + expected = Categorical.from_codes( + np.append(c.codes, -1), categories=c.categories, ordered=True + ) + result = cut(range(6), bins=expected.categories) + tm.assert_categorical_equal(result, expected) + + +def test_bins_from_interval_index_doc_example(): + # Make sure we preserve the bins. + ages = np.array([10, 15, 13, 12, 23, 25, 28, 59, 60]) + c = cut(ages, bins=[0, 18, 35, 70]) + expected = IntervalIndex.from_tuples([(0, 18), (18, 35), (35, 70)]) + tm.assert_index_equal(c.categories, expected) + + result = cut([25, 20, 50], bins=c.categories) + tm.assert_index_equal(result.categories, expected) + tm.assert_numpy_array_equal(result.codes, np.array([1, 1, 2], dtype="int8")) + + +def test_bins_not_overlapping_from_interval_index(): + # see gh-23980 + msg = "Overlapping IntervalIndex is not accepted" + ii = IntervalIndex.from_tuples([(0, 10), (2, 12), (4, 14)]) + + with pytest.raises(ValueError, match=msg): + cut([5, 6], bins=ii) + + +def test_bins_not_monotonic(): + msg = "bins must increase monotonically" + data = [0.2, 1.4, 2.5, 6.2, 9.7, 2.1] + + with pytest.raises(ValueError, match=msg): + cut(data, [0.1, 1.5, 1, 10]) + + +@pytest.mark.parametrize( + "x, bins, expected", + [ + ( + date_range("2017-12-31", periods=3), + [Timestamp.min, Timestamp("2018-01-01"), Timestamp.max], + IntervalIndex.from_tuples( + [ + (Timestamp.min, Timestamp("2018-01-01")), + (Timestamp("2018-01-01"), Timestamp.max), + ] + ), + ), + ( + [-1, 0, 1], + np.array( + [np.iinfo(np.int64).min, 0, np.iinfo(np.int64).max], dtype="int64" + ), + IntervalIndex.from_tuples( + [(np.iinfo(np.int64).min, 0), (0, np.iinfo(np.int64).max)] + ), + ), + ( + [ + np.timedelta64(-1, "ns"), + np.timedelta64(0, "ns"), + np.timedelta64(1, "ns"), + ], + np.array( + [ + np.timedelta64(-np.iinfo(np.int64).max, "ns"), + np.timedelta64(0, "ns"), + np.timedelta64(np.iinfo(np.int64).max, "ns"), + ] + ), + IntervalIndex.from_tuples( + [ + ( + np.timedelta64(-np.iinfo(np.int64).max, "ns"), + np.timedelta64(0, "ns"), + ), + ( + np.timedelta64(0, "ns"), + np.timedelta64(np.iinfo(np.int64).max, "ns"), + ), + ] + ), + ), + ], +) +def test_bins_monotonic_not_overflowing(x, bins, expected): + # GH 26045 + result = cut(x, bins) + tm.assert_index_equal(result.categories, expected) + + +def test_wrong_num_labels(): + msg = "Bin labels must be one fewer than the number of bin edges" + data = [0.2, 1.4, 2.5, 6.2, 9.7, 2.1] + + with pytest.raises(ValueError, match=msg): + cut(data, [0, 1, 10], labels=["foo", "bar", "baz"]) + + +@pytest.mark.parametrize( + "x,bins,msg", + [ + ([], 2, "Cannot cut empty array"), + ([1, 2, 3], 0.5, "`bins` should be a positive integer"), + ], +) +def test_cut_corner(x, bins, msg): + with pytest.raises(ValueError, match=msg): + cut(x, bins) + + +@pytest.mark.parametrize("arg", [2, np.eye(2), DataFrame(np.eye(2))]) +@pytest.mark.parametrize("cut_func", [cut, qcut]) +def test_cut_not_1d_arg(arg, cut_func): + msg = "Input array must be 1 dimensional" + with pytest.raises(ValueError, match=msg): + cut_func(arg, 2) + + +@pytest.mark.parametrize( + "data", + [ + [0, 1, 2, 3, 4, np.inf], + [-np.inf, 0, 1, 2, 3, 4], + [-np.inf, 0, 1, 2, 3, 4, np.inf], + ], +) +def test_int_bins_with_inf(data): + # GH 24314 + msg = "cannot specify integer `bins` when input data contains infinity" + with pytest.raises(ValueError, match=msg): + cut(data, bins=3) + + +def test_cut_out_of_range_more(): + # see gh-1511 + name = "x" + + ser = Series([0, -1, 0, 1, -3], name=name) + ind = cut(ser, [0, 1], labels=False) + + exp = Series([np.nan, np.nan, np.nan, 0, np.nan], name=name) + tm.assert_series_equal(ind, exp) + + +@pytest.mark.parametrize( + "right,breaks,closed", + [ + (True, [-1e-3, 0.25, 0.5, 0.75, 1], "right"), + (False, [0, 0.25, 0.5, 0.75, 1 + 1e-3], "left"), + ], +) +def test_labels(right, breaks, closed): + arr = np.tile(np.arange(0, 1.01, 0.1), 4) + + result, bins = cut(arr, 4, retbins=True, right=right) + ex_levels = IntervalIndex.from_breaks(breaks, closed=closed) + tm.assert_index_equal(result.categories, ex_levels) + + +def test_cut_pass_series_name_to_factor(): + name = "foo" + ser = Series(np.random.default_rng(2).standard_normal(100), name=name) + + factor = cut(ser, 4) + assert factor.name == name + + +def test_label_precision(): + arr = np.arange(0, 0.73, 0.01) + result = cut(arr, 4, precision=2) + + ex_levels = IntervalIndex.from_breaks([-0.00072, 0.18, 0.36, 0.54, 0.72]) + tm.assert_index_equal(result.categories, ex_levels) + + +@pytest.mark.parametrize("labels", [None, False]) +def test_na_handling(labels): + arr = np.arange(0, 0.75, 0.01) + arr[::3] = np.nan + + result = cut(arr, 4, labels=labels) + result = np.asarray(result) + + expected = np.where(isna(arr), np.nan, result) + tm.assert_almost_equal(result, expected) + + +def test_inf_handling(): + data = np.arange(6) + data_ser = Series(data, dtype="int64") + + bins = [-np.inf, 2, 4, np.inf] + result = cut(data, bins) + result_ser = cut(data_ser, bins) + + ex_uniques = IntervalIndex.from_breaks(bins) + tm.assert_index_equal(result.categories, ex_uniques) + + assert result[5] == Interval(4, np.inf) + assert result[0] == Interval(-np.inf, 2) + assert result_ser[5] == Interval(4, np.inf) + assert result_ser[0] == Interval(-np.inf, 2) + + +def test_cut_out_of_bounds(): + arr = np.random.default_rng(2).standard_normal(100) + result = cut(arr, [-1, 0, 1]) + + mask = isna(result) + ex_mask = (arr < -1) | (arr > 1) + tm.assert_numpy_array_equal(mask, ex_mask) + + +@pytest.mark.parametrize( + "get_labels,get_expected", + [ + ( + lambda labels: labels, + lambda labels: Categorical( + ["Medium"] + 4 * ["Small"] + ["Medium", "Large"], + categories=labels, + ordered=True, + ), + ), + ( + lambda labels: Categorical.from_codes([0, 1, 2], labels), + lambda labels: Categorical.from_codes([1] + 4 * [0] + [1, 2], labels), + ), + ], +) +def test_cut_pass_labels(get_labels, get_expected): + bins = [0, 25, 50, 100] + arr = [50, 5, 10, 15, 20, 30, 70] + labels = ["Small", "Medium", "Large"] + + result = cut(arr, bins, labels=get_labels(labels)) + tm.assert_categorical_equal(result, get_expected(labels)) + + +def test_cut_pass_labels_compat(): + # see gh-16459 + arr = [50, 5, 10, 15, 20, 30, 70] + labels = ["Good", "Medium", "Bad"] + + result = cut(arr, 3, labels=labels) + exp = cut(arr, 3, labels=Categorical(labels, categories=labels, ordered=True)) + tm.assert_categorical_equal(result, exp) + + +@pytest.mark.parametrize("x", [np.arange(11.0), np.arange(11.0) / 1e10]) +def test_round_frac_just_works(x): + # It works. + cut(x, 2) + + +@pytest.mark.parametrize( + "val,precision,expected", + [ + (-117.9998, 3, -118), + (117.9998, 3, 118), + (117.9998, 2, 118), + (0.000123456, 2, 0.00012), + ], +) +def test_round_frac(val, precision, expected): + # see gh-1979 + result = tmod._round_frac(val, precision=precision) + assert result == expected + + +def test_cut_return_intervals(): + ser = Series([0, 1, 2, 3, 4, 5, 6, 7, 8]) + result = cut(ser, 3) + + exp_bins = np.linspace(0, 8, num=4).round(3) + exp_bins[0] -= 0.008 + + expected = Series( + IntervalIndex.from_breaks(exp_bins, closed="right").take( + [0, 0, 0, 1, 1, 1, 2, 2, 2] + ) + ).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +def test_series_ret_bins(): + # see gh-8589 + ser = Series(np.arange(4)) + result, bins = cut(ser, 2, retbins=True) + + expected = Series( + IntervalIndex.from_breaks([-0.003, 1.5, 3], closed="right").repeat(2) + ).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "kwargs,msg", + [ + ({"duplicates": "drop"}, None), + ({}, "Bin edges must be unique"), + ({"duplicates": "raise"}, "Bin edges must be unique"), + ({"duplicates": "foo"}, "invalid value for 'duplicates' parameter"), + ], +) +def test_cut_duplicates_bin(kwargs, msg): + # see gh-20947 + bins = [0, 2, 4, 6, 10, 10] + values = Series(np.array([1, 3, 5, 7, 9]), index=["a", "b", "c", "d", "e"]) + + if msg is not None: + with pytest.raises(ValueError, match=msg): + cut(values, bins, **kwargs) + else: + result = cut(values, bins, **kwargs) + expected = cut(values, pd.unique(np.asarray(bins))) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("data", [9.0, -9.0, 0.0]) +@pytest.mark.parametrize("length", [1, 2]) +def test_single_bin(data, length): + # see gh-14652, gh-15428 + ser = Series([data] * length) + result = cut(ser, 1, labels=False) + + expected = Series([0] * length, dtype=np.intp) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "values,threshold", + [ + ([0.1, 0.1, 0.1], 0.001), # small positive values + ([-0.1, -0.1, -0.1], 0.001), # negative values + ([0.01, 0.01, 0.01], 0.0001), # very small values + ], +) +def test_single_bin_edge_adjustment(values, threshold): + # gh-58517 - edge adjustment mutation when all values are same + result, bins = cut(values, 3, retbins=True) + + bin_range = bins[-1] - bins[0] + assert bin_range < threshold + + +@pytest.mark.parametrize( + "array_1_writeable,array_2_writeable", [(True, True), (True, False), (False, False)] +) +def test_cut_read_only(array_1_writeable, array_2_writeable): + # issue 18773 + array_1 = np.arange(0, 100, 10) + array_1.flags.writeable = array_1_writeable + + array_2 = np.arange(0, 100, 10) + array_2.flags.writeable = array_2_writeable + + hundred_elements = np.arange(100) + tm.assert_categorical_equal( + cut(hundred_elements, array_1), cut(hundred_elements, array_2) + ) + + +@pytest.mark.parametrize( + "conv", + [ + lambda v: Timestamp(v), + lambda v: to_datetime(v), + lambda v: np.datetime64(v), + lambda v: Timestamp(v).to_pydatetime(), + ], +) +def test_datetime_bin(conv): + data = [np.datetime64("2012-12-13"), np.datetime64("2012-12-15")] + bin_data = ["2012-12-12", "2012-12-14", "2012-12-16"] + + expected = Series( + IntervalIndex( + [ + Interval(Timestamp(bin_data[0]), Timestamp(bin_data[1])), + Interval(Timestamp(bin_data[1]), Timestamp(bin_data[2])), + ] + ) + ) + + bins = [conv(v) for v in bin_data] + result = Series(cut(data, bins=bins)) + + if type(bins[0]) is np.datetime64: + # The bins have microsecond dtype -> so does result + expected = expected.astype("interval[datetime64[s]]") + + expected = expected.astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("box", [Series, Index, np.array, list]) +def test_datetime_cut(unit, box): + # see gh-14714 + # + # Testing time data when it comes in various collection types. + data = to_datetime(["2013-01-01", "2013-01-02", "2013-01-03"]).astype(f"M8[{unit}]") + data = box(data) + result, _ = cut(data, 3, retbins=True) + + if unit == "s": + # See https://github.com/pandas-dev/pandas/pull/56101#discussion_r1405325425 + # for why we round to 8 seconds instead of 7 + left = DatetimeIndex( + ["2012-12-31 23:57:08", "2013-01-01 16:00:00", "2013-01-02 08:00:00"], + dtype=f"M8[{unit}]", + ) + else: + left = DatetimeIndex( + [ + "2012-12-31 23:57:07.200000", + "2013-01-01 16:00:00", + "2013-01-02 08:00:00", + ], + dtype=f"M8[{unit}]", + ) + right = DatetimeIndex( + ["2013-01-01 16:00:00", "2013-01-02 08:00:00", "2013-01-03 00:00:00"], + dtype=f"M8[{unit}]", + ) + + exp_intervals = IntervalIndex.from_arrays(left, right) + expected = Series(exp_intervals).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(Series(result), expected) + + +@pytest.mark.parametrize("box", [list, np.array, Index, Series]) +def test_datetime_tz_cut_mismatched_tzawareness(box): + # GH#54964 + bins = box( + [ + Timestamp("2013-01-01 04:57:07.200000"), + Timestamp("2013-01-01 21:00:00"), + Timestamp("2013-01-02 13:00:00"), + Timestamp("2013-01-03 05:00:00"), + ] + ) + ser = Series(date_range("20130101", periods=3, tz="US/Eastern")) + + msg = "Cannot use timezone-naive bins with timezone-aware values" + with pytest.raises(ValueError, match=msg): + cut(ser, bins) + + +@pytest.mark.parametrize( + "bins", + [ + 3, + [ + Timestamp("2013-01-01 04:57:07.200000", tz="UTC").tz_convert("US/Eastern"), + Timestamp("2013-01-01 21:00:00", tz="UTC").tz_convert("US/Eastern"), + Timestamp("2013-01-02 13:00:00", tz="UTC").tz_convert("US/Eastern"), + Timestamp("2013-01-03 05:00:00", tz="UTC").tz_convert("US/Eastern"), + ], + ], +) +@pytest.mark.parametrize("box", [list, np.array, Index, Series]) +def test_datetime_tz_cut(bins, box): + # see gh-19872 + tz = "US/Eastern" + ser = Series(date_range("20130101", periods=3, tz=tz, unit="ns")) + + if not isinstance(bins, int): + bins = box(bins) + + result = cut(ser, bins) + ii = IntervalIndex( + [ + Interval( + Timestamp("2012-12-31 23:57:07.200000", tz=tz), + Timestamp("2013-01-01 16:00:00", tz=tz), + ), + Interval( + Timestamp("2013-01-01 16:00:00", tz=tz), + Timestamp("2013-01-02 08:00:00", tz=tz), + ), + Interval( + Timestamp("2013-01-02 08:00:00", tz=tz), + Timestamp("2013-01-03 00:00:00", tz=tz), + ), + ] + ) + if isinstance(bins, int): + # the dtype is inferred from ser, which has nanosecond unit + ii = ii.astype("interval[datetime64[ns, US/Eastern]]") + expected = Series(ii).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +def test_datetime_nan_error(): + msg = "bins must be of datetime64 dtype" + + with pytest.raises(ValueError, match=msg): + cut(date_range("20130101", periods=3), bins=[0, 2, 4]) + + +def test_datetime_nan_mask(): + result = cut( + date_range("20130102", periods=5), bins=date_range("20130101", periods=2) + ) + + mask = result.categories.isna() + tm.assert_numpy_array_equal(mask, np.array([False])) + + mask = result.isna() + tm.assert_numpy_array_equal(mask, np.array([False, True, True, True, True])) + + +@pytest.mark.parametrize("tz", [None, "UTC", "US/Pacific"]) +def test_datetime_cut_roundtrip(tz, unit): + # see gh-19891 + ser = Series(date_range("20180101", periods=3, tz=tz, unit=unit)) + result, result_bins = cut(ser, 2, retbins=True) + + expected = cut(ser, result_bins) + tm.assert_series_equal(result, expected) + + if unit == "s": + # TODO: constructing DatetimeIndex with dtype="M8[s]" without truncating + # the first entry here raises in array_to_datetime. Should truncate + # instead of raising? + # See https://github.com/pandas-dev/pandas/pull/56101#discussion_r1405325425 + # for why we round to 8 seconds instead of 7 + expected_bins = DatetimeIndex( + ["2017-12-31 23:57:08", "2018-01-02 00:00:00", "2018-01-03 00:00:00"], + dtype=f"M8[{unit}]", + ) + else: + expected_bins = DatetimeIndex( + [ + "2017-12-31 23:57:07.200000", + "2018-01-02 00:00:00", + "2018-01-03 00:00:00", + ], + dtype=f"M8[{unit}]", + ) + expected_bins = expected_bins.tz_localize(tz) + tm.assert_index_equal(result_bins, expected_bins) + + +def test_timedelta_cut_roundtrip(): + # see gh-19891 + ser = Series(timedelta_range("1day", periods=3)) + result, result_bins = cut(ser, 2, retbins=True) + + expected = cut(ser, result_bins) + tm.assert_series_equal(result, expected) + + expected_bins = TimedeltaIndex( + ["0 days 23:57:07.200000", "2 days 00:00:00", "3 days 00:00:00"] + ) + tm.assert_index_equal(result_bins, expected_bins) + + +@pytest.mark.parametrize("bins", [6, 7]) +@pytest.mark.parametrize( + "box, compare", + [ + (Series, tm.assert_series_equal), + (np.array, tm.assert_categorical_equal), + (list, tm.assert_equal), + ], +) +def test_cut_bool_coercion_to_int(bins, box, compare): + # issue 20303 + data_expected = box([0, 1, 1, 0, 1] * 10) + data_result = box([False, True, True, False, True] * 10) + expected = cut(data_expected, bins, duplicates="drop") + result = cut(data_result, bins, duplicates="drop") + compare(result, expected) + + +@pytest.mark.parametrize("labels", ["foo", 1, True]) +def test_cut_incorrect_labels(labels): + # GH 13318 + values = range(5) + msg = "Bin labels must either be False, None or passed in as a list-like argument" + with pytest.raises(ValueError, match=msg): + cut(values, 4, labels=labels) + + +@pytest.mark.parametrize("bins", [3, [0, 5, 15]]) +@pytest.mark.parametrize("right", [True, False]) +@pytest.mark.parametrize("include_lowest", [True, False]) +def test_cut_nullable_integer(bins, right, include_lowest): + a = np.random.default_rng(2).integers(0, 10, size=50).astype(float) + a[::2] = np.nan + b = a.astype(object) + b[::2] = pd.NA + result = cut( + pd.array(b, dtype="Int64"), bins, right=right, include_lowest=include_lowest + ) + expected = cut(a, bins, right=right, include_lowest=include_lowest) + tm.assert_categorical_equal(result, expected) + + +@pytest.mark.parametrize( + "data, bins, labels, expected_codes, expected_labels", + [ + ([15, 17, 19], [14, 16, 18, 20], ["A", "B", "A"], [0, 1, 0], ["A", "B"]), + ([1, 3, 5], [0, 2, 4, 6, 8], [2, 0, 1, 2], [2, 0, 1], [0, 1, 2]), + ], +) +def test_cut_non_unique_labels(data, bins, labels, expected_codes, expected_labels): + # GH 33141 + result = cut(data, bins=bins, labels=labels, ordered=False) + expected = Categorical.from_codes( + expected_codes, categories=expected_labels, ordered=False + ) + tm.assert_categorical_equal(result, expected) + + +@pytest.mark.parametrize( + "data, bins, labels, expected_codes, expected_labels", + [ + ([15, 17, 19], [14, 16, 18, 20], ["C", "B", "A"], [0, 1, 2], ["C", "B", "A"]), + ([1, 3, 5], [0, 2, 4, 6, 8], [3, 0, 1, 2], [0, 1, 2], [3, 0, 1, 2]), + ], +) +def test_cut_unordered_labels(data, bins, labels, expected_codes, expected_labels): + # GH 33141 + result = cut(data, bins=bins, labels=labels, ordered=False) + expected = Categorical.from_codes( + expected_codes, categories=expected_labels, ordered=False + ) + tm.assert_categorical_equal(result, expected) + + +def test_cut_unordered_with_missing_labels_raises_error(): + # GH 33141 + msg = "'labels' must be provided if 'ordered = False'" + with pytest.raises(ValueError, match=msg): + cut([0.5, 3], bins=[0, 1, 2], ordered=False) + + +def test_cut_unordered_with_series_labels(): + # https://github.com/pandas-dev/pandas/issues/36603 + ser = Series([1, 2, 3, 4, 5]) + bins = Series([0, 2, 4, 6]) + labels = Series(["a", "b", "c"]) + result = cut(ser, bins=bins, labels=labels, ordered=False) + expected = Series(["a", "a", "b", "b", "c"], dtype="category") + tm.assert_series_equal(result, expected) + + +def test_cut_no_warnings(): + df = DataFrame({"value": np.random.default_rng(2).integers(0, 100, 20)}) + labels = [f"{i} - {i + 9}" for i in range(0, 100, 10)] + with tm.assert_produces_warning(False): + df["group"] = cut(df.value, range(0, 105, 10), right=False, labels=labels) + + +def test_cut_with_duplicated_index_lowest_included(): + # GH 42185 + expected = Series( + [Interval(-0.001, 2, closed="right")] * 3 + + [Interval(2, 4, closed="right"), Interval(-0.001, 2, closed="right")], + index=[0, 1, 2, 3, 0], + dtype="category", + ).cat.as_ordered() + + ser = Series([0, 1, 2, 3, 0], index=[0, 1, 2, 3, 0]) + result = cut(ser, bins=[0, 2, 4], include_lowest=True) + tm.assert_series_equal(result, expected) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered in cast:RuntimeWarning") +def test_cut_with_nonexact_categorical_indices(): + # GH 42424 + + ser = Series(range(100)) + ser1 = cut(ser, 10).value_counts().head(5) + ser2 = cut(ser, 10).value_counts().tail(5) + result = DataFrame({"1": ser1, "2": ser2}) + + index = pd.CategoricalIndex( + [ + Interval(-0.099, 9.9, closed="right"), + Interval(9.9, 19.8, closed="right"), + Interval(19.8, 29.7, closed="right"), + Interval(29.7, 39.6, closed="right"), + Interval(39.6, 49.5, closed="right"), + Interval(49.5, 59.4, closed="right"), + Interval(59.4, 69.3, closed="right"), + Interval(69.3, 79.2, closed="right"), + Interval(79.2, 89.1, closed="right"), + Interval(89.1, 99, closed="right"), + ], + ordered=True, + ) + + expected = DataFrame( + {"1": [10] * 5 + [np.nan] * 5, "2": [np.nan] * 5 + [10] * 5}, index=index + ) + + tm.assert_frame_equal(expected, result) + + +def test_cut_with_timestamp_tuple_labels(): + # GH 40661 + labels = [(Timestamp(10),), (Timestamp(20),), (Timestamp(30),)] + result = cut([2, 4, 6], bins=[1, 3, 5, 7], labels=labels) + + expected = Categorical.from_codes([0, 1, 2], labels, ordered=True) + tm.assert_categorical_equal(result, expected) + + +def test_cut_bins_datetime_intervalindex(): + # https://github.com/pandas-dev/pandas/issues/46218 + bins = interval_range(Timestamp("2022-02-25"), Timestamp("2022-02-27"), freq="1D") + # passing Series instead of list is important to trigger bug + result = cut(Series([Timestamp("2022-02-26")]), bins=bins) + expected = Categorical.from_codes([0], bins, ordered=True) + tm.assert_categorical_equal(result.array, expected) + + +def test_cut_with_nullable_int64(): + # GH 30787 + series = Series([0, 1, 2, 3, 4, pd.NA, 6, 7], dtype="Int64") + bins = [0, 2, 4, 6, 8] + intervals = IntervalIndex.from_breaks(bins) + + expected = Series( + Categorical.from_codes([-1, 0, 0, 1, 1, -1, 2, 3], intervals, ordered=True) + ) + + result = cut(series, bins=bins) + + tm.assert_series_equal(result, expected) + + +def test_cut_datetime_array_no_attributeerror(): + # GH 55431 + ser = Series(to_datetime(["2023-10-06 12:00:00+0000", "2023-10-07 12:00:00+0000"])) + + result = cut(ser.array, bins=2) + + categories = result.categories + expected = Categorical.from_codes([0, 1], categories=categories, ordered=True) + + tm.assert_categorical_equal( + result, expected, check_dtype=True, check_category_order=True + ) diff --git a/pandas/tests/reshape/test_from_dummies.py b/pandas/tests/reshape/test_from_dummies.py new file mode 100644 index 0000000000000000000000000000000000000000..0997baf7c3f74bf6b73e62e02283cea9eb1bd696 --- /dev/null +++ b/pandas/tests/reshape/test_from_dummies.py @@ -0,0 +1,477 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + from_dummies, + get_dummies, +) +import pandas._testing as tm + + +@pytest.fixture +def dummies_basic(): + return DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [0, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [1, 0, 0], + "col2_c": [0, 0, 1], + }, + ) + + +@pytest.fixture +def dummies_with_unassigned(): + return DataFrame( + { + "col1_a": [1, 0, 0], + "col1_b": [0, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [0, 0, 0], + "col2_c": [0, 0, 1], + }, + ) + + +def test_error_wrong_data_type(): + dummies = [0, 1, 0] + with pytest.raises( + TypeError, + match=r"Expected 'data' to be a 'DataFrame'; Received 'data' of type: list", + ): + from_dummies(dummies) + + +def test_error_no_prefix_contains_unassigned(): + dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]}) + with pytest.raises( + ValueError, + match=( + r"Dummy DataFrame contains unassigned value\(s\); " + r"First instance in row: 2" + ), + ): + from_dummies(dummies) + + +def test_error_no_prefix_wrong_default_category_type(): + dummies = DataFrame({"a": [1, 0, 1], "b": [0, 1, 1]}) + with pytest.raises( + TypeError, + match=( + r"Expected 'default_category' to be of type 'None', 'Hashable', or 'dict'; " + r"Received 'default_category' of type: list" + ), + ): + from_dummies(dummies, default_category=["c", "d"]) + + +def test_error_no_prefix_multi_assignment(): + dummies = DataFrame({"a": [1, 0, 1], "b": [0, 1, 1]}) + with pytest.raises( + ValueError, + match=( + r"Dummy DataFrame contains multi-assignment\(s\); " + r"First instance in row: 2" + ), + ): + from_dummies(dummies) + + +def test_error_no_prefix_contains_nan(): + dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, np.nan]}) + with pytest.raises( + ValueError, match=r"Dummy DataFrame contains NA value in column: 'b'" + ): + from_dummies(dummies) + + +def test_error_contains_non_dummies(): + dummies = DataFrame( + {"a": [1, 6, 3, 1], "b": [0, 1, 0, 2], "c": ["c1", "c2", "c3", "c4"]} + ) + with pytest.raises( + TypeError, + match=r"Passed DataFrame contains non-dummy data", + ): + from_dummies(dummies) + + +def test_error_with_prefix_multiple_separators(): + dummies = DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [0, 1, 0], + "col2-a": [0, 1, 0], + "col2-b": [1, 0, 1], + }, + ) + with pytest.raises( + ValueError, + match=(r"Separator not specified for column: col2-a"), + ): + from_dummies(dummies, sep="_") + + +def test_error_with_prefix_sep_wrong_type(dummies_basic): + with pytest.raises( + TypeError, + match=( + r"Expected 'sep' to be of type 'str' or 'None'; " + r"Received 'sep' of type: list" + ), + ): + from_dummies(dummies_basic, sep=["_"]) + + +def test_error_with_prefix_contains_unassigned(dummies_with_unassigned): + with pytest.raises( + ValueError, + match=( + r"Dummy DataFrame contains unassigned value\(s\); " + r"First instance in row: 2" + ), + ): + from_dummies(dummies_with_unassigned, sep="_") + + +def test_error_with_prefix_default_category_wrong_type(dummies_with_unassigned): + with pytest.raises( + TypeError, + match=( + r"Expected 'default_category' to be of type 'None', 'Hashable', or 'dict'; " + r"Received 'default_category' of type: list" + ), + ): + from_dummies(dummies_with_unassigned, sep="_", default_category=["x", "y"]) + + +def test_error_with_prefix_default_category_dict_not_complete( + dummies_with_unassigned, +): + with pytest.raises( + ValueError, + match=( + r"Length of 'default_category' \(1\) did not match " + r"the length of the columns being encoded \(2\)" + ), + ): + from_dummies(dummies_with_unassigned, sep="_", default_category={"col1": "x"}) + + +def test_error_with_prefix_contains_nan(dummies_basic): + # Set float64 dtype to avoid upcast when setting np.nan + dummies_basic["col2_c"] = dummies_basic["col2_c"].astype("float64") + dummies_basic.loc[2, "col2_c"] = np.nan + with pytest.raises( + ValueError, match=r"Dummy DataFrame contains NA value in column: 'col2_c'" + ): + from_dummies(dummies_basic, sep="_") + + +def test_error_with_prefix_contains_non_dummies(dummies_basic): + # Set object dtype to avoid upcast when setting "str" + dummies_basic["col2_c"] = dummies_basic["col2_c"].astype(object) + dummies_basic.loc[2, "col2_c"] = "str" + with pytest.raises(TypeError, match=r"Passed DataFrame contains non-dummy data"): + from_dummies(dummies_basic, sep="_") + + +def test_error_with_prefix_double_assignment(): + dummies = DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [1, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [1, 0, 0], + "col2_c": [0, 0, 1], + }, + ) + with pytest.raises( + ValueError, + match=( + r"Dummy DataFrame contains multi-assignment\(s\); " + r"First instance in row: 0" + ), + ): + from_dummies(dummies, sep="_") + + +def test_roundtrip_series_to_dataframe(): + categories = Series(["a", "b", "c", "a"]) + dummies = get_dummies(categories) + result = from_dummies(dummies) + expected = DataFrame({"": ["a", "b", "c", "a"]}) + tm.assert_frame_equal(result, expected) + + +def test_roundtrip_single_column_dataframe(): + categories = DataFrame({"": ["a", "b", "c", "a"]}) + dummies = get_dummies(categories) + result = from_dummies(dummies, sep="_") + expected = categories + tm.assert_frame_equal(result, expected) + + +def test_roundtrip_with_prefixes(): + categories = DataFrame({"col1": ["a", "b", "a"], "col2": ["b", "a", "c"]}) + dummies = get_dummies(categories) + result = from_dummies(dummies, sep="_") + expected = categories + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_string_cats_basic(): + dummies = DataFrame({"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}) + expected = DataFrame({"": ["a", "b", "c", "a"]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_string_cats_basic_bool_values(): + dummies = DataFrame( + { + "a": [True, False, False, True], + "b": [False, True, False, False], + "c": [False, False, True, False], + } + ) + expected = DataFrame({"": ["a", "b", "c", "a"]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_string_cats_basic_mixed_bool_values(): + dummies = DataFrame( + {"a": [1, 0, 0, 1], "b": [False, True, False, False], "c": [0, 0, 1, 0]} + ) + expected = DataFrame({"": ["a", "b", "c", "a"]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_int_cats_basic(): + dummies = DataFrame( + {1: [1, 0, 0, 0], 25: [0, 1, 0, 0], 2: [0, 0, 1, 0], 5: [0, 0, 0, 1]} + ) + expected = DataFrame({"": [1, 25, 2, 5]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_float_cats_basic(): + dummies = DataFrame( + {1.0: [1, 0, 0, 0], 25.0: [0, 1, 0, 0], 2.5: [0, 0, 1, 0], 5.84: [0, 0, 0, 1]} + ) + expected = DataFrame({"": [1.0, 25.0, 2.5, 5.84]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_mixed_cats_basic(): + dummies = DataFrame( + { + 1.23: [1, 0, 0, 0, 0], + "c": [0, 1, 0, 0, 0], + 2: [0, 0, 1, 0, 0], + False: [0, 0, 0, 1, 0], + None: [0, 0, 0, 0, 1], + } + ) + expected = DataFrame({"": [1.23, "c", 2, False, None]}, dtype="object") + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +def test_no_prefix_string_cats_contains_get_dummies_NaN_column(): + dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0], "NaN": [0, 0, 1]}) + expected = DataFrame({"": ["a", "b", "NaN"]}) + result = from_dummies(dummies) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "default_category, expected", + [ + pytest.param( + "c", + {"": ["a", "b", "c"]}, + id="default_category is a str", + ), + pytest.param( + 1, + {"": ["a", "b", 1]}, + id="default_category is an int", + ), + pytest.param( + 1.25, + {"": ["a", "b", 1.25]}, + id="default_category is a float", + ), + pytest.param( + 0, + {"": ["a", "b", 0]}, + id="default_category is a 0", + ), + pytest.param( + False, + {"": ["a", "b", False]}, + id="default_category is a bool", + ), + pytest.param( + (1, 2), + {"": ["a", "b", (1, 2)]}, + id="default_category is a tuple", + ), + ], +) +def test_no_prefix_string_cats_default_category( + default_category, expected, using_infer_string +): + dummies = DataFrame({"a": [1, 0, 0], "b": [0, 1, 0]}) + result = from_dummies(dummies, default_category=default_category) + expected = DataFrame(expected, dtype=dummies.columns.dtype) + tm.assert_frame_equal(result, expected) + + +def test_with_prefix_basic(dummies_basic): + expected = DataFrame({"col1": ["a", "b", "a"], "col2": ["b", "a", "c"]}) + result = from_dummies(dummies_basic, sep="_") + tm.assert_frame_equal(result, expected) + + +def test_with_prefix_contains_get_dummies_NaN_column(): + dummies = DataFrame( + { + "col1_a": [1, 0, 0], + "col1_b": [0, 1, 0], + "col1_NaN": [0, 0, 1], + "col2_a": [0, 1, 0], + "col2_b": [0, 0, 0], + "col2_c": [0, 0, 1], + "col2_NaN": [1, 0, 0], + }, + ) + expected = DataFrame({"col1": ["a", "b", "NaN"], "col2": ["NaN", "a", "c"]}) + result = from_dummies(dummies, sep="_") + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "default_category, expected", + [ + pytest.param( + "x", + {"col1": ["a", "b", "x"], "col2": ["x", "a", "c"]}, + id="default_category is a str", + ), + pytest.param( + 0, + {"col1": ["a", "b", 0], "col2": [0, "a", "c"]}, + id="default_category is a 0", + ), + pytest.param( + False, + {"col1": ["a", "b", False], "col2": [False, "a", "c"]}, + id="default_category is a False", + ), + pytest.param( + {"col2": 1, "col1": 2.5}, + {"col1": ["a", "b", 2.5], "col2": [1, "a", "c"]}, + id="default_category is a dict with int and float values", + ), + pytest.param( + {"col2": None, "col1": False}, + {"col1": ["a", "b", False], "col2": [None, "a", "c"]}, + id="default_category is a dict with bool and None values", + ), + pytest.param( + {"col2": (1, 2), "col1": [1.25, False]}, + {"col1": ["a", "b", [1.25, False]], "col2": [(1, 2), "a", "c"]}, + id="default_category is a dict with list and tuple values", + ), + ], +) +def test_with_prefix_default_category( + dummies_with_unassigned, default_category, expected, using_infer_string +): + result = from_dummies( + dummies_with_unassigned, sep="_", default_category=default_category + ) + expected = DataFrame(expected) + if using_infer_string: + expected = expected.astype("str") + tm.assert_frame_equal(result, expected) + + +def test_ea_categories(): + # GH 54300 + df = DataFrame({"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}) + df.columns = df.columns.astype("string[python]") + result = from_dummies(df) + expected = DataFrame({"": Series(list("abca"), dtype="string[python]")}) + tm.assert_frame_equal(result, expected) + + +def test_ea_categories_with_sep(): + # GH 54300 + df = DataFrame( + { + "col1_a": [1, 0, 1], + "col1_b": [0, 1, 0], + "col2_a": [0, 1, 0], + "col2_b": [1, 0, 0], + "col2_c": [0, 0, 1], + } + ) + df.columns = df.columns.astype("string[python]") + result = from_dummies(df, sep="_") + expected = DataFrame( + { + "col1": Series(list("aba"), dtype="string[python]"), + "col2": Series(list("bac"), dtype="string[python]"), + } + ) + expected.columns = expected.columns.astype("string[python]") + tm.assert_frame_equal(result, expected) + + +def test_maintain_original_index(): + # GH 54300 + df = DataFrame( + {"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}, index=list("abcd") + ) + result = from_dummies(df) + expected = DataFrame({"": list("abca")}, index=list("abcd")) + tm.assert_frame_equal(result, expected) + + +def test_int_columns_with_float_default(): + # https://github.com/pandas-dev/pandas/pull/60694 + df = DataFrame( + { + 3: [1, 0, 0], + 4: [0, 1, 0], + }, + ) + with pytest.raises(ValueError, match="Trying to coerce float values to integers"): + from_dummies(df, default_category=0.5) + + +def test_object_dtype_preserved(): + # https://github.com/pandas-dev/pandas/pull/60694 + # When the input has object dtype, the result should as + # well even when infer_string is True. + df = DataFrame( + { + "x": [1, 0, 0], + "y": [0, 1, 0], + }, + ) + df.columns = df.columns.astype("object") + result = from_dummies(df, default_category="z") + expected = DataFrame({"": ["x", "y", "z"]}, dtype="object") + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/reshape/test_get_dummies.py b/pandas/tests/reshape/test_get_dummies.py new file mode 100644 index 0000000000000000000000000000000000000000..c776e7b2e3d9a4f6ab99042cf95a6996d1ba5b60 --- /dev/null +++ b/pandas/tests/reshape/test_get_dummies.py @@ -0,0 +1,741 @@ +import re +import unicodedata + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_integer_dtype + +import pandas as pd +from pandas import ( + ArrowDtype, + Categorical, + CategoricalDtype, + CategoricalIndex, + DataFrame, + Index, + RangeIndex, + Series, + SparseDtype, + get_dummies, +) +import pandas._testing as tm +from pandas.core.arrays.sparse import SparseArray + +try: + import pyarrow as pa +except ImportError: + pa = None + + +class TestGetDummies: + @pytest.fixture + def df(self): + return DataFrame({"A": ["a", "b", "a"], "B": ["b", "b", "c"], "C": [1, 2, 3]}) + + @pytest.fixture(params=["uint8", "i8", np.float64, bool, None]) + def dtype(self, request): + return np.dtype(request.param) + + @pytest.fixture(params=["dense", "sparse"]) + def sparse(self, request): + # params are strings to simplify reading test results, + # e.g. TestGetDummies::test_basic[uint8-sparse] instead of [uint8-True] + return request.param == "sparse" + + def effective_dtype(self, dtype): + if dtype is None: + return np.uint8 + return dtype + + def test_get_dummies_raises_on_dtype_object(self, df): + msg = "dtype=object is not a valid dtype for get_dummies" + with pytest.raises(ValueError, match=msg): + get_dummies(df, dtype="object") + + def test_get_dummies_basic(self, sparse, dtype): + s_list = list("abc") + s_series = Series(s_list) + s_series_index = Series(s_list, list("ABC")) + + expected = DataFrame( + {"a": [1, 0, 0], "b": [0, 1, 0], "c": [0, 0, 1]}, + dtype=self.effective_dtype(dtype), + ) + if sparse: + if dtype.kind == "b": + expected = expected.apply(SparseArray, fill_value=False) + else: + expected = expected.apply(SparseArray, fill_value=0.0) + result = get_dummies(s_list, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_series, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + expected.index = list("ABC") + result = get_dummies(s_series_index, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_basic_types(self, sparse, dtype, using_infer_string): + # GH 10531 + s_list = list("abc") + s_series = Series(s_list) + s_df = DataFrame( + {"a": [0, 1, 0, 1, 2], "b": ["A", "A", "B", "C", "C"], "c": [2, 3, 3, 3, 2]} + ) + + expected = DataFrame( + {"a": [1, 0, 0], "b": [0, 1, 0], "c": [0, 0, 1]}, + dtype=self.effective_dtype(dtype), + columns=list("abc"), + ) + if sparse: + if is_integer_dtype(dtype): + fill_value = 0 + elif dtype == bool: + fill_value = False + else: + fill_value = 0.0 + + expected = expected.apply(SparseArray, fill_value=fill_value) + result = get_dummies(s_list, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_series, sparse=sparse, dtype=dtype) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_df, columns=s_df.columns, sparse=sparse, dtype=dtype) + if sparse: + dtype_name = f"Sparse[{self.effective_dtype(dtype).name}, {fill_value}]" + else: + dtype_name = self.effective_dtype(dtype).name + + expected = Series({dtype_name: 8}, name="count") + result = result.dtypes.value_counts() + result.index = [str(i) for i in result.index] + tm.assert_series_equal(result, expected) + + result = get_dummies(s_df, columns=["a"], sparse=sparse, dtype=dtype) + + key = "str" if using_infer_string else "object" + expected_counts = {"int64": 1, key: 1} + expected_counts[dtype_name] = 3 + expected_counts.get(dtype_name, 0) + + expected = Series(expected_counts, name="count").sort_index() + result = result.dtypes.value_counts() + result.index = [str(i) for i in result.index] + result = result.sort_index() + tm.assert_series_equal(result, expected) + + def test_get_dummies_just_na(self, sparse): + just_na_list = [np.nan] + just_na_series = Series(just_na_list) + just_na_series_index = Series(just_na_list, index=["A"]) + + res_list = get_dummies(just_na_list, sparse=sparse) + res_series = get_dummies(just_na_series, sparse=sparse) + res_series_index = get_dummies(just_na_series_index, sparse=sparse) + + assert res_list.empty + assert res_series.empty + assert res_series_index.empty + + assert res_list.index.tolist() == [0] + assert res_series.index.tolist() == [0] + assert res_series_index.index.tolist() == ["A"] + + def test_get_dummies_include_na(self, sparse, dtype): + s = ["a", "b", np.nan] + res = get_dummies(s, sparse=sparse, dtype=dtype) + exp = DataFrame( + {"a": [1, 0, 0], "b": [0, 1, 0]}, dtype=self.effective_dtype(dtype) + ) + if sparse: + if dtype.kind == "b": + exp = exp.apply(SparseArray, fill_value=False) + else: + exp = exp.apply(SparseArray, fill_value=0.0) + tm.assert_frame_equal(res, exp) + + # Sparse dataframes do not allow nan labelled columns, see #GH8822 + res_na = get_dummies(s, dummy_na=True, sparse=sparse, dtype=dtype) + exp_na = DataFrame( + {np.nan: [0, 0, 1], "a": [1, 0, 0], "b": [0, 1, 0]}, + dtype=self.effective_dtype(dtype), + ) + exp_na = exp_na.reindex(["a", "b", np.nan], axis=1) + # hack (NaN handling in assert_index_equal) + exp_na.columns = res_na.columns + if sparse: + if dtype.kind == "b": + exp_na = exp_na.apply(SparseArray, fill_value=False) + else: + exp_na = exp_na.apply(SparseArray, fill_value=0.0) + tm.assert_frame_equal(res_na, exp_na) + + res_just_na = get_dummies([np.nan], dummy_na=True, sparse=sparse, dtype=dtype) + exp_just_na = DataFrame( + Series(1, index=[0]), columns=[np.nan], dtype=self.effective_dtype(dtype) + ) + tm.assert_numpy_array_equal(res_just_na.values, exp_just_na.values) + + def test_get_dummies_unicode(self, sparse): + # See GH 6885 - get_dummies chokes on unicode values + e = "e" + eacute = unicodedata.lookup("LATIN SMALL LETTER E WITH ACUTE") + s = [e, eacute, eacute] + res = get_dummies(s, prefix="letter", sparse=sparse) + exp = DataFrame( + {"letter_e": [True, False, False], f"letter_{eacute}": [False, True, True]} + ) + if sparse: + exp = exp.apply(SparseArray, fill_value=False) + tm.assert_frame_equal(res, exp) + + def test_dataframe_dummies_all_obj(self, df, sparse): + df = df[["A", "B"]] + result = get_dummies(df, sparse=sparse) + expected = DataFrame( + {"A_a": [1, 0, 1], "A_b": [0, 1, 0], "B_b": [1, 1, 0], "B_c": [0, 0, 1]}, + dtype=bool, + ) + if sparse: + expected = DataFrame( + { + "A_a": SparseArray([1, 0, 1], dtype="bool"), + "A_b": SparseArray([0, 1, 0], dtype="bool"), + "B_b": SparseArray([1, 1, 0], dtype="bool"), + "B_c": SparseArray([0, 0, 1], dtype="bool"), + } + ) + + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_string_dtype(self, df, any_string_dtype): + # GH44965 + df = df[["A", "B"]] + df = df.astype({"A": "str", "B": any_string_dtype}) + result = get_dummies(df) + expected = DataFrame( + { + "A_a": [1, 0, 1], + "A_b": [0, 1, 0], + "B_b": [1, 1, 0], + "B_c": [0, 0, 1], + }, + dtype=bool, + ) + if any_string_dtype == "string" and any_string_dtype.na_value is pd.NA: + expected[["B_b", "B_c"]] = expected[["B_b", "B_c"]].astype("boolean") + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_mix_default(self, df, sparse, dtype): + result = get_dummies(df, sparse=sparse, dtype=dtype) + if sparse: + arr = SparseArray + if dtype.kind == "b": + typ = SparseDtype(dtype, False) + else: + typ = SparseDtype(dtype, 0) + else: + arr = np.array + typ = dtype + expected = DataFrame( + { + "C": [1, 2, 3], + "A_a": arr([1, 0, 1], dtype=typ), + "A_b": arr([0, 1, 0], dtype=typ), + "B_b": arr([1, 1, 0], dtype=typ), + "B_c": arr([0, 0, 1], dtype=typ), + } + ) + expected = expected[["C", "A_a", "A_b", "B_b", "B_c"]] + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_prefix_list(self, df, sparse): + prefixes = ["from_A", "from_B"] + result = get_dummies(df, prefix=prefixes, sparse=sparse) + expected = DataFrame( + { + "C": [1, 2, 3], + "from_A_a": [True, False, True], + "from_A_b": [False, True, False], + "from_B_b": [True, True, False], + "from_B_c": [False, False, True], + }, + ) + expected[["C"]] = df[["C"]] + cols = ["from_A_a", "from_A_b", "from_B_b", "from_B_c"] + expected = expected[["C", *cols]] + + typ = SparseArray if sparse else Series + expected[cols] = expected[cols].apply(lambda x: typ(x)) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_prefix_str(self, df, sparse): + # not that you should do this... + result = get_dummies(df, prefix="bad", sparse=sparse) + bad_columns = ["bad_a", "bad_b", "bad_b", "bad_c"] + expected = DataFrame( + [ + [1, True, False, True, False], + [2, False, True, True, False], + [3, True, False, False, True], + ], + columns=["C", *bad_columns], + ) + expected = expected.astype({"C": np.int64}) + if sparse: + # work around astyping & assigning with duplicate columns + # https://github.com/pandas-dev/pandas/issues/14427 + expected = pd.concat( + [ + Series([1, 2, 3], name="C"), + Series([True, False, True], name="bad_a", dtype="Sparse[bool]"), + Series([False, True, False], name="bad_b", dtype="Sparse[bool]"), + Series([True, True, False], name="bad_b", dtype="Sparse[bool]"), + Series([False, False, True], name="bad_c", dtype="Sparse[bool]"), + ], + axis=1, + ) + + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_subset(self, df, sparse): + result = get_dummies(df, prefix=["from_A"], columns=["A"], sparse=sparse) + expected = DataFrame( + { + "B": ["b", "b", "c"], + "C": [1, 2, 3], + "from_A_a": [1, 0, 1], + "from_A_b": [0, 1, 0], + }, + ) + cols = expected.columns + expected[cols[1:]] = expected[cols[1:]].astype(bool) + expected[["C"]] = df[["C"]] + if sparse: + cols = ["from_A_a", "from_A_b"] + expected[cols] = expected[cols].astype(SparseDtype("bool", False)) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_prefix_sep(self, df, sparse): + result = get_dummies(df, prefix_sep="..", sparse=sparse) + expected = DataFrame( + { + "C": [1, 2, 3], + "A..a": [True, False, True], + "A..b": [False, True, False], + "B..b": [True, True, False], + "B..c": [False, False, True], + }, + ) + expected[["C"]] = df[["C"]] + expected = expected[["C", "A..a", "A..b", "B..b", "B..c"]] + if sparse: + cols = ["A..a", "A..b", "B..b", "B..c"] + expected[cols] = expected[cols].astype(SparseDtype("bool", False)) + + tm.assert_frame_equal(result, expected) + + result = get_dummies(df, prefix_sep=["..", "__"], sparse=sparse) + expected = expected.rename(columns={"B..b": "B__b", "B..c": "B__c"}) + tm.assert_frame_equal(result, expected) + + result = get_dummies(df, prefix_sep={"A": "..", "B": "__"}, sparse=sparse) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_prefix_bad_length(self, df, sparse): + msg = re.escape( + "Length of 'prefix' (1) did not match the length of the columns being " + "encoded (2)" + ) + with pytest.raises(ValueError, match=msg): + get_dummies(df, prefix=["too few"], sparse=sparse) + + def test_dataframe_dummies_prefix_sep_bad_length(self, df, sparse): + msg = re.escape( + "Length of 'prefix_sep' (1) did not match the length of the columns being " + "encoded (2)" + ) + with pytest.raises(ValueError, match=msg): + get_dummies(df, prefix_sep=["bad"], sparse=sparse) + + def test_dataframe_dummies_prefix_dict(self, sparse): + prefixes = {"A": "from_A", "B": "from_B"} + df = DataFrame({"C": [1, 2, 3], "A": ["a", "b", "a"], "B": ["b", "b", "c"]}) + result = get_dummies(df, prefix=prefixes, sparse=sparse) + + expected = DataFrame( + { + "C": [1, 2, 3], + "from_A_a": [1, 0, 1], + "from_A_b": [0, 1, 0], + "from_B_b": [1, 1, 0], + "from_B_c": [0, 0, 1], + } + ) + + columns = ["from_A_a", "from_A_b", "from_B_b", "from_B_c"] + expected[columns] = expected[columns].astype(bool) + if sparse: + expected[columns] = expected[columns].astype(SparseDtype("bool", False)) + + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_with_na(self, df, sparse, dtype): + df.loc[3, :] = [np.nan, np.nan, np.nan] + result = get_dummies(df, dummy_na=True, sparse=sparse, dtype=dtype).sort_index( + axis=1 + ) + + if sparse: + arr = SparseArray + if dtype.kind == "b": + typ = SparseDtype(dtype, False) + else: + typ = SparseDtype(dtype, 0) + else: + arr = np.array + typ = dtype + + expected = DataFrame( + { + "C": [1, 2, 3, np.nan], + "A_a": arr([1, 0, 1, 0], dtype=typ), + "A_b": arr([0, 1, 0, 0], dtype=typ), + "A_nan": arr([0, 0, 0, 1], dtype=typ), + "B_b": arr([1, 1, 0, 0], dtype=typ), + "B_c": arr([0, 0, 1, 0], dtype=typ), + "B_nan": arr([0, 0, 0, 1], dtype=typ), + } + ).sort_index(axis=1) + + tm.assert_frame_equal(result, expected) + + result = get_dummies(df, dummy_na=False, sparse=sparse, dtype=dtype) + expected = expected[["C", "A_a", "A_b", "B_b", "B_c"]] + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_with_categorical(self, df, sparse, dtype): + df["cat"] = Categorical(["x", "y", "y"]) + result = get_dummies(df, sparse=sparse, dtype=dtype).sort_index(axis=1) + if sparse: + arr = SparseArray + if dtype.kind == "b": + typ = SparseDtype(dtype, False) + else: + typ = SparseDtype(dtype, 0) + else: + arr = np.array + typ = dtype + + expected = DataFrame( + { + "C": [1, 2, 3], + "A_a": arr([1, 0, 1], dtype=typ), + "A_b": arr([0, 1, 0], dtype=typ), + "B_b": arr([1, 1, 0], dtype=typ), + "B_c": arr([0, 0, 1], dtype=typ), + "cat_x": arr([1, 0, 0], dtype=typ), + "cat_y": arr([0, 1, 1], dtype=typ), + } + ).sort_index(axis=1) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "get_dummies_kwargs,expected", + [ + ( + {"data": DataFrame({"ä": ["a"]})}, + "ä_a", + ), + ( + {"data": DataFrame({"x": ["ä"]})}, + "x_ä", + ), + ( + {"data": DataFrame({"x": ["a"]}), "prefix": "ä"}, + "ä_a", + ), + ( + {"data": DataFrame({"x": ["a"]}), "prefix_sep": "ä"}, + "xäa", + ), + ], + ) + def test_dataframe_dummies_unicode(self, get_dummies_kwargs, expected): + # GH22084 get_dummies incorrectly encodes unicode characters + # in dataframe column names + result = get_dummies(**get_dummies_kwargs) + expected = DataFrame({expected: [True]}) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_basic_drop_first(self, sparse): + # GH12402 Add a new parameter `drop_first` to avoid collinearity + # Basic case + s_list = list("abc") + s_series = Series(s_list) + s_series_index = Series(s_list, list("ABC")) + + expected = DataFrame({"b": [0, 1, 0], "c": [0, 0, 1]}, dtype=bool) + + result = get_dummies(s_list, drop_first=True, sparse=sparse) + if sparse: + expected = expected.apply(SparseArray, fill_value=False) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_series, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + expected.index = list("ABC") + result = get_dummies(s_series_index, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_basic_drop_first_one_level(self, sparse): + # Test the case that categorical variable only has one level. + s_list = list("aaa") + s_series = Series(s_list) + s_series_index = Series(s_list, list("ABC")) + + expected = DataFrame(index=RangeIndex(3)) + + result = get_dummies(s_list, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + result = get_dummies(s_series, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + expected = DataFrame(index=list("ABC")) + result = get_dummies(s_series_index, drop_first=True, sparse=sparse) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_basic_drop_first_NA(self, sparse): + # Test NA handling together with drop_first + s_NA = ["a", "b", np.nan] + res = get_dummies(s_NA, drop_first=True, sparse=sparse) + exp = DataFrame({"b": [0, 1, 0]}, dtype=bool) + if sparse: + exp = exp.apply(SparseArray, fill_value=False) + + tm.assert_frame_equal(res, exp) + + res_na = get_dummies(s_NA, dummy_na=True, drop_first=True, sparse=sparse) + exp_na = DataFrame({"b": [0, 1, 0], np.nan: [0, 0, 1]}, dtype=bool).reindex( + ["b", np.nan], axis=1 + ) + if sparse: + exp_na = exp_na.apply(SparseArray, fill_value=False) + tm.assert_frame_equal(res_na, exp_na) + + res_just_na = get_dummies( + [np.nan], dummy_na=True, drop_first=True, sparse=sparse + ) + exp_just_na = DataFrame(index=RangeIndex(1)) + tm.assert_frame_equal(res_just_na, exp_just_na) + + def test_dataframe_dummies_drop_first(self, df, sparse): + df = df[["A", "B"]] + result = get_dummies(df, drop_first=True, sparse=sparse) + expected = DataFrame({"A_b": [0, 1, 0], "B_c": [0, 0, 1]}, dtype=bool) + if sparse: + expected = expected.apply(SparseArray, fill_value=False) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_drop_first_with_categorical(self, df, sparse, dtype): + df["cat"] = Categorical(["x", "y", "y"]) + result = get_dummies(df, drop_first=True, sparse=sparse) + expected = DataFrame( + {"C": [1, 2, 3], "A_b": [0, 1, 0], "B_c": [0, 0, 1], "cat_y": [0, 1, 1]} + ) + cols = ["A_b", "B_c", "cat_y"] + expected[cols] = expected[cols].astype(bool) + expected = expected[["C", "A_b", "B_c", "cat_y"]] + if sparse: + for col in cols: + expected[col] = SparseArray(expected[col]) + tm.assert_frame_equal(result, expected) + + def test_dataframe_dummies_drop_first_with_na(self, df, sparse): + df.loc[3, :] = [np.nan, np.nan, np.nan] + result = get_dummies( + df, dummy_na=True, drop_first=True, sparse=sparse + ).sort_index(axis=1) + expected = DataFrame( + { + "C": [1, 2, 3, np.nan], + "A_b": [0, 1, 0, 0], + "A_nan": [0, 0, 0, 1], + "B_c": [0, 0, 1, 0], + "B_nan": [0, 0, 0, 1], + } + ) + cols = ["A_b", "A_nan", "B_c", "B_nan"] + expected[cols] = expected[cols].astype(bool) + expected = expected.sort_index(axis=1) + if sparse: + for col in cols: + expected[col] = SparseArray(expected[col]) + + tm.assert_frame_equal(result, expected) + + result = get_dummies(df, dummy_na=False, drop_first=True, sparse=sparse) + expected = expected[["C", "A_b", "B_c"]] + tm.assert_frame_equal(result, expected) + + def test_get_dummies_int_int(self): + data = Series([1, 2, 1]) + result = get_dummies(data) + expected = DataFrame([[1, 0], [0, 1], [1, 0]], columns=[1, 2], dtype=bool) + tm.assert_frame_equal(result, expected) + + data = Series(Categorical(["a", "b", "a"])) + result = get_dummies(data) + expected = DataFrame( + [[1, 0], [0, 1], [1, 0]], columns=Categorical(["a", "b"]), dtype=bool + ) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_int_df(self, dtype): + data = DataFrame( + { + "A": [1, 2, 1], + "B": Categorical(["a", "b", "a"]), + "C": [1, 2, 1], + "D": [1.0, 2.0, 1.0], + } + ) + columns = ["C", "D", "A_1", "A_2", "B_a", "B_b"] + expected = DataFrame( + [[1, 1.0, 1, 0, 1, 0], [2, 2.0, 0, 1, 0, 1], [1, 1.0, 1, 0, 1, 0]], + columns=columns, + ) + expected[columns[2:]] = expected[columns[2:]].astype(dtype) + result = get_dummies(data, columns=["A", "B"], dtype=dtype) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("ordered", [True, False]) + def test_dataframe_dummies_preserve_categorical_dtype(self, dtype, ordered): + # GH13854 + cat = Categorical(list("xy"), categories=list("xyz"), ordered=ordered) + result = get_dummies(cat, dtype=dtype) + + data = np.array([[1, 0, 0], [0, 1, 0]], dtype=self.effective_dtype(dtype)) + cols = CategoricalIndex( + cat.categories, categories=cat.categories, ordered=ordered + ) + expected = DataFrame(data, columns=cols, dtype=self.effective_dtype(dtype)) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("sparse", [True, False]) + def test_get_dummies_dont_sparsify_all_columns(self, sparse): + # GH18914 + df = DataFrame.from_dict({"GDP": [1, 2], "Nation": ["AB", "CD"]}) + df = get_dummies(df, columns=["Nation"], sparse=sparse) + df2 = df.reindex(columns=["GDP"]) + + tm.assert_frame_equal(df[["GDP"]], df2) + + def test_get_dummies_duplicate_columns(self, df): + # GH20839 + df.columns = ["A", "A", "A"] + result = get_dummies(df).sort_index(axis=1) + + expected = DataFrame( + [ + [1, True, False, True, False], + [2, False, True, True, False], + [3, True, False, False, True], + ], + columns=["A", "A_a", "A_b", "A_b", "A_c"], + ).sort_index(axis=1) + + expected = expected.astype({"A": np.int64}) + + tm.assert_frame_equal(result, expected) + + def test_get_dummies_all_sparse(self): + df = DataFrame({"A": [1, 2]}) + result = get_dummies(df, columns=["A"], sparse=True) + dtype = SparseDtype("bool", False) + expected = DataFrame( + { + "A_1": SparseArray([1, 0], dtype=dtype), + "A_2": SparseArray([0, 1], dtype=dtype), + } + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("values", ["baz"]) + def test_get_dummies_with_string_values(self, values): + # issue #28383 + df = DataFrame( + { + "bar": [1, 2, 3, 4, 5, 6], + "foo": ["one", "one", "one", "two", "two", "two"], + "baz": ["A", "B", "C", "A", "B", "C"], + "zoo": ["x", "y", "z", "q", "w", "t"], + } + ) + + msg = "Input must be a list-like for parameter `columns`" + + with pytest.raises(TypeError, match=msg): + get_dummies(df, columns=values) + + def test_get_dummies_ea_dtype_series(self, any_numeric_ea_and_arrow_dtype): + # GH#32430 + ser = Series(list("abca")) + result = get_dummies(ser, dtype=any_numeric_ea_and_arrow_dtype) + expected = DataFrame( + {"a": [1, 0, 0, 1], "b": [0, 1, 0, 0], "c": [0, 0, 1, 0]}, + dtype=any_numeric_ea_and_arrow_dtype, + ) + tm.assert_frame_equal(result, expected) + + def test_get_dummies_ea_dtype_dataframe(self, any_numeric_ea_and_arrow_dtype): + # GH#32430 + df = DataFrame({"x": list("abca")}) + result = get_dummies(df, dtype=any_numeric_ea_and_arrow_dtype) + expected = DataFrame( + {"x_a": [1, 0, 0, 1], "x_b": [0, 1, 0, 0], "x_c": [0, 0, 1, 0]}, + dtype=any_numeric_ea_and_arrow_dtype, + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dtype_type", ["string", "category"]) + def test_get_dummies_ea_dtype(self, dtype_type, string_dtype_no_object): + # GH#56273 + dtype = string_dtype_no_object + exp_dtype = "boolean" if dtype.na_value is pd.NA else "bool" + if dtype_type == "category": + dtype = CategoricalDtype(Index(["a"], dtype)) + df = DataFrame({"name": Series(["a"], dtype=dtype), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype=exp_dtype)}) + tm.assert_frame_equal(result, expected) + + @td.skip_if_no("pyarrow") + def test_get_dummies_arrow_dtype(self): + # GH#56273 + df = DataFrame({"name": Series(["a"], dtype=ArrowDtype(pa.string())), "x": 1}) + result = get_dummies(df) + expected = DataFrame({"x": 1, "name_a": Series([True], dtype="bool[pyarrow]")}) + tm.assert_frame_equal(result, expected) + + df = DataFrame( + { + "name": Series( + ["a"], + dtype=CategoricalDtype(Index(["a"], dtype=ArrowDtype(pa.string()))), + ), + "x": 1, + } + ) + result = get_dummies(df) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/reshape/test_melt.py b/pandas/tests/reshape/test_melt.py new file mode 100644 index 0000000000000000000000000000000000000000..fba9c28282e9491ce5f237e4d92068a4320953ca --- /dev/null +++ b/pandas/tests/reshape/test_melt.py @@ -0,0 +1,1280 @@ +import re + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + date_range, + lreshape, + melt, + wide_to_long, +) +import pandas._testing as tm + + +@pytest.fixture +def df(): + res = DataFrame( + np.random.default_rng(2).standard_normal((10, 4)), + columns=Index(list("ABCD")), + index=date_range("2000-01-01", periods=10, freq="B"), + ) + res["id1"] = (res["A"] > 0).astype(np.int64) + res["id2"] = (res["B"] > 0).astype(np.int64) + return res + + +@pytest.fixture +def df1(): + res = DataFrame( + [ + [1.067683, -1.110463, 0.20867], + [-1.321405, 0.368915, -1.055342], + [-0.807333, 0.08298, -0.873361], + ] + ) + res.columns = [list("ABC"), list("abc")] + res.columns.names = ["CAP", "low"] + return res + + +@pytest.fixture +def var_name(): + return "var" + + +@pytest.fixture +def value_name(): + return "val" + + +class TestMelt: + def test_top_level_method(self, df): + result = melt(df) + assert result.columns.tolist() == ["variable", "value"] + + def test_method_signatures(self, df, df1, var_name, value_name): + tm.assert_frame_equal(df.melt(), melt(df)) + + tm.assert_frame_equal( + df.melt(id_vars=["id1", "id2"], value_vars=["A", "B"]), + melt(df, id_vars=["id1", "id2"], value_vars=["A", "B"]), + ) + + tm.assert_frame_equal( + df.melt(var_name=var_name, value_name=value_name), + melt(df, var_name=var_name, value_name=value_name), + ) + + tm.assert_frame_equal(df1.melt(col_level=0), melt(df1, col_level=0)) + + def test_default_col_names(self, df): + result = df.melt() + assert result.columns.tolist() == ["variable", "value"] + + result1 = df.melt(id_vars=["id1"]) + assert result1.columns.tolist() == ["id1", "variable", "value"] + + result2 = df.melt(id_vars=["id1", "id2"]) + assert result2.columns.tolist() == ["id1", "id2", "variable", "value"] + + def test_value_vars(self, df): + result3 = df.melt(id_vars=["id1", "id2"], value_vars="A") + assert len(result3) == 10 + + result4 = df.melt(id_vars=["id1", "id2"], value_vars=["A", "B"]) + expected4 = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + "variable": ["A"] * 10 + ["B"] * 10, + "value": (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", "variable", "value"], + ) + tm.assert_frame_equal(result4, expected4) + + @pytest.mark.parametrize("type_", (tuple, list, np.array)) + def test_value_vars_types(self, type_, df): + # GH 15348 + expected = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + "variable": ["A"] * 10 + ["B"] * 10, + "value": (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", "variable", "value"], + ) + result = df.melt(id_vars=["id1", "id2"], value_vars=type_(("A", "B"))) + tm.assert_frame_equal(result, expected) + + def test_vars_work_with_multiindex(self, df1): + expected = DataFrame( + { + ("A", "a"): df1[("A", "a")], + "CAP": ["B"] * len(df1), + "low": ["b"] * len(df1), + "value": df1[("B", "b")], + }, + columns=[("A", "a"), "CAP", "low", "value"], + ) + + result = df1.melt(id_vars=[("A", "a")], value_vars=[("B", "b")]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "id_vars, value_vars, col_level, expected", + [ + ( + ["A"], + ["B"], + 0, + { + "A": {0: 1.067683, 1: -1.321405, 2: -0.807333}, + "CAP": {0: "B", 1: "B", 2: "B"}, + "value": {0: -1.110463, 1: 0.368915, 2: 0.08298}, + }, + ), + ( + ["a"], + ["b"], + 1, + { + "a": {0: 1.067683, 1: -1.321405, 2: -0.807333}, + "low": {0: "b", 1: "b", 2: "b"}, + "value": {0: -1.110463, 1: 0.368915, 2: 0.08298}, + }, + ), + ], + ) + def test_single_vars_work_with_multiindex( + self, id_vars, value_vars, col_level, expected, df1 + ): + result = df1.melt(id_vars, value_vars, col_level=col_level) + expected = DataFrame(expected) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "id_vars, value_vars", + [ + [("A", "a"), [("B", "b")]], + [[("A", "a")], ("B", "b")], + [("A", "a"), ("B", "b")], + ], + ) + def test_tuple_vars_fail_with_multiindex(self, id_vars, value_vars, df1): + # melt should fail with an informative error message if + # the columns have a MultiIndex and a tuple is passed + # for id_vars or value_vars. + msg = r"(id|value)_vars must be a list of tuples when columns are a MultiIndex" + with pytest.raises(ValueError, match=msg): + df1.melt(id_vars=id_vars, value_vars=value_vars) + + def test_custom_var_name(self, df, var_name): + result5 = df.melt(var_name=var_name) + assert result5.columns.tolist() == ["var", "value"] + + result6 = df.melt(id_vars=["id1"], var_name=var_name) + assert result6.columns.tolist() == ["id1", "var", "value"] + + result7 = df.melt(id_vars=["id1", "id2"], var_name=var_name) + assert result7.columns.tolist() == ["id1", "id2", "var", "value"] + + result8 = df.melt(id_vars=["id1", "id2"], value_vars="A", var_name=var_name) + assert result8.columns.tolist() == ["id1", "id2", "var", "value"] + + result9 = df.melt( + id_vars=["id1", "id2"], value_vars=["A", "B"], var_name=var_name + ) + expected9 = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + var_name: ["A"] * 10 + ["B"] * 10, + "value": (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", var_name, "value"], + ) + tm.assert_frame_equal(result9, expected9) + + def test_custom_value_name(self, df, value_name): + result10 = df.melt(value_name=value_name) + assert result10.columns.tolist() == ["variable", "val"] + + result11 = df.melt(id_vars=["id1"], value_name=value_name) + assert result11.columns.tolist() == ["id1", "variable", "val"] + + result12 = df.melt(id_vars=["id1", "id2"], value_name=value_name) + assert result12.columns.tolist() == ["id1", "id2", "variable", "val"] + + result13 = df.melt( + id_vars=["id1", "id2"], value_vars="A", value_name=value_name + ) + assert result13.columns.tolist() == ["id1", "id2", "variable", "val"] + + result14 = df.melt( + id_vars=["id1", "id2"], value_vars=["A", "B"], value_name=value_name + ) + expected14 = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + "variable": ["A"] * 10 + ["B"] * 10, + value_name: (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", "variable", value_name], + ) + tm.assert_frame_equal(result14, expected14) + + def test_custom_var_and_value_name(self, df, value_name, var_name): + result15 = df.melt(var_name=var_name, value_name=value_name) + assert result15.columns.tolist() == ["var", "val"] + + result16 = df.melt(id_vars=["id1"], var_name=var_name, value_name=value_name) + assert result16.columns.tolist() == ["id1", "var", "val"] + + result17 = df.melt( + id_vars=["id1", "id2"], var_name=var_name, value_name=value_name + ) + assert result17.columns.tolist() == ["id1", "id2", "var", "val"] + + result18 = df.melt( + id_vars=["id1", "id2"], + value_vars="A", + var_name=var_name, + value_name=value_name, + ) + assert result18.columns.tolist() == ["id1", "id2", "var", "val"] + + result19 = df.melt( + id_vars=["id1", "id2"], + value_vars=["A", "B"], + var_name=var_name, + value_name=value_name, + ) + expected19 = DataFrame( + { + "id1": df["id1"].tolist() * 2, + "id2": df["id2"].tolist() * 2, + var_name: ["A"] * 10 + ["B"] * 10, + value_name: (df["A"].tolist() + df["B"].tolist()), + }, + columns=["id1", "id2", var_name, value_name], + ) + tm.assert_frame_equal(result19, expected19) + + df20 = df.copy() + df20.columns.name = "foo" + result20 = df20.melt() + assert result20.columns.tolist() == ["foo", "value"] + + @pytest.mark.parametrize("col_level", [0, "CAP"]) + def test_col_level(self, col_level, df1): + res = df1.melt(col_level=col_level) + assert res.columns.tolist() == ["CAP", "value"] + + def test_multiindex(self, df1): + res = df1.melt() + assert res.columns.tolist() == ["CAP", "low", "value"] + + @pytest.mark.parametrize( + "col", + [ + date_range("2010", periods=5, tz="US/Pacific"), + pd.Categorical(["a", "b", "c", "a", "d"]), + [0, 1, 0, 0, 0], + ], + ) + def test_pandas_dtypes(self, col): + # GH 15785 + col = pd.Series(col) + df = DataFrame( + {"klass": range(5), "col": col, "attr1": [1, 0, 0, 0, 0], "attr2": col} + ) + expected_value = pd.concat([pd.Series([1, 0, 0, 0, 0]), col], ignore_index=True) + result = melt( + df, id_vars=["klass", "col"], var_name="attribute", value_name="value" + ) + expected = DataFrame( + { + 0: list(range(5)) * 2, + 1: pd.concat([col] * 2, ignore_index=True), + 2: ["attr1"] * 5 + ["attr2"] * 5, + 3: expected_value, + } + ) + expected.columns = ["klass", "col", "attribute", "value"] + tm.assert_frame_equal(result, expected) + + def test_preserve_category(self): + # GH 15853 + data = DataFrame({"A": [1, 2], "B": pd.Categorical(["X", "Y"])}) + result = melt(data, ["B"], ["A"]) + expected = DataFrame( + {"B": pd.Categorical(["X", "Y"]), "variable": ["A", "A"], "value": [1, 2]} + ) + + tm.assert_frame_equal(result, expected) + + def test_melt_missing_columns_raises(self): + # GH-23575 + # This test is to ensure that pandas raises an error if melting is + # attempted with column names absent from the dataframe + + # Generate data + df = DataFrame( + np.random.default_rng(2).standard_normal((5, 4)), columns=list("abcd") + ) + + # Try to melt with missing `value_vars` column name + msg = "The following id_vars or value_vars are not present in the DataFrame:" + with pytest.raises(KeyError, match=msg): + df.melt(["a", "b"], ["C", "d"]) + + # Try to melt with missing `id_vars` column name + with pytest.raises(KeyError, match=msg): + df.melt(["A", "b"], ["c", "d"]) + + # Multiple missing + with pytest.raises( + KeyError, + match=msg, + ): + df.melt(["a", "b", "not_here", "or_there"], ["c", "d"]) + + # Multiindex melt fails if column is missing from multilevel melt + df.columns = [list("ABCD"), list("abcd")] + with pytest.raises(KeyError, match=msg): + df.melt([("E", "a")], [("B", "b")]) + # Multiindex fails if column is missing from single level melt + with pytest.raises(KeyError, match=msg): + df.melt(["A"], ["F"], col_level=0) + + def test_melt_mixed_int_str_id_vars(self): + # GH 29718 + df = DataFrame({0: ["foo"], "a": ["bar"], "b": [1], "d": [2]}) + result = melt(df, id_vars=[0, "a"], value_vars=["b", "d"]) + expected = DataFrame( + {0: ["foo"] * 2, "a": ["bar"] * 2, "variable": list("bd"), "value": [1, 2]} + ) + # the df's columns are mixed type and thus object -> preserves object dtype + expected["variable"] = expected["variable"].astype(object) + tm.assert_frame_equal(result, expected) + + def test_melt_mixed_int_str_value_vars(self): + # GH 29718 + df = DataFrame({0: ["foo"], "a": ["bar"]}) + result = melt(df, value_vars=[0, "a"]) + expected = DataFrame({"variable": [0, "a"], "value": ["foo", "bar"]}) + tm.assert_frame_equal(result, expected) + + def test_ignore_index(self): + # GH 17440 + df = DataFrame({"foo": [0], "bar": [1]}, index=["first"]) + result = melt(df, ignore_index=False) + expected = DataFrame( + {"variable": ["foo", "bar"], "value": [0, 1]}, index=["first", "first"] + ) + tm.assert_frame_equal(result, expected) + + def test_ignore_multiindex(self): + # GH 17440 + index = pd.MultiIndex.from_tuples( + [("first", "second"), ("first", "third")], names=["baz", "foobar"] + ) + df = DataFrame({"foo": [0, 1], "bar": [2, 3]}, index=index) + result = melt(df, ignore_index=False) + + expected_index = pd.MultiIndex.from_tuples( + [("first", "second"), ("first", "third")] * 2, names=["baz", "foobar"] + ) + expected = DataFrame( + {"variable": ["foo"] * 2 + ["bar"] * 2, "value": [0, 1, 2, 3]}, + index=expected_index, + ) + + tm.assert_frame_equal(result, expected) + + def test_ignore_index_name_and_type(self): + # GH 17440 + index = Index(["foo", "bar"], dtype="category", name="baz") + df = DataFrame({"x": [0, 1], "y": [2, 3]}, index=index) + result = melt(df, ignore_index=False) + + expected_index = Index(["foo", "bar"] * 2, dtype="category", name="baz") + expected = DataFrame( + {"variable": ["x", "x", "y", "y"], "value": [0, 1, 2, 3]}, + index=expected_index, + ) + + tm.assert_frame_equal(result, expected) + + def test_melt_with_duplicate_columns(self): + # GH#41951 + df = DataFrame([["id", 2, 3]], columns=["a", "b", "b"]) + result = df.melt(id_vars=["a"], value_vars=["b"]) + expected = DataFrame( + [["id", "b", 2], ["id", "b", 3]], columns=["a", "variable", "value"] + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["Int8", "Int64"]) + def test_melt_ea_dtype(self, dtype): + # GH#41570 + df = DataFrame( + { + "a": pd.Series([1, 2], dtype="Int8"), + "b": pd.Series([3, 4], dtype=dtype), + } + ) + result = df.melt() + expected = DataFrame( + { + "variable": ["a", "a", "b", "b"], + "value": pd.Series([1, 2, 3, 4], dtype=dtype), + } + ) + tm.assert_frame_equal(result, expected) + + def test_melt_ea_columns(self): + # GH 54297 + df = DataFrame( + { + "A": {0: "a", 1: "b", 2: "c"}, + "B": {0: 1, 1: 3, 2: 5}, + "C": {0: 2, 1: 4, 2: 6}, + } + ) + df.columns = df.columns.astype("string[python]") + result = df.melt(id_vars=["A"], value_vars=["B"]) + expected = DataFrame( + { + "A": list("abc"), + "variable": pd.Series(["B"] * 3, dtype="string[python]"), + "value": [1, 3, 5], + } + ) + tm.assert_frame_equal(result, expected) + + def test_melt_preserves_datetime(self): + df = DataFrame( + data=[ + { + "type": "A0", + "start_date": pd.Timestamp("2023/03/01", tz="Asia/Tokyo"), + "end_date": pd.Timestamp("2023/03/10", tz="Asia/Tokyo"), + }, + { + "type": "A1", + "start_date": pd.Timestamp("2023/03/01", tz="Asia/Tokyo"), + "end_date": pd.Timestamp("2023/03/11", tz="Asia/Tokyo"), + }, + ], + index=["aaaa", "bbbb"], + ) + result = df.melt( + id_vars=["type"], + value_vars=["start_date", "end_date"], + var_name="start/end", + value_name="date", + ) + expected = DataFrame( + { + "type": {0: "A0", 1: "A1", 2: "A0", 3: "A1"}, + "start/end": { + 0: "start_date", + 1: "start_date", + 2: "end_date", + 3: "end_date", + }, + "date": { + 0: pd.Timestamp("2023-03-01 00:00:00+0900", tz="Asia/Tokyo"), + 1: pd.Timestamp("2023-03-01 00:00:00+0900", tz="Asia/Tokyo"), + 2: pd.Timestamp("2023-03-10 00:00:00+0900", tz="Asia/Tokyo"), + 3: pd.Timestamp("2023-03-11 00:00:00+0900", tz="Asia/Tokyo"), + }, + } + ) + tm.assert_frame_equal(result, expected) + + def test_melt_allows_non_scalar_id_vars(self): + df = DataFrame( + data={"a": [1, 2, 3], "b": [4, 5, 6]}, + index=["11", "22", "33"], + ) + result = df.melt( + id_vars="a", + var_name=0, + value_name=1, + ) + expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]}) + tm.assert_frame_equal(result, expected) + + def test_melt_allows_non_string_var_name(self): + df = DataFrame( + data={"a": [1, 2, 3], "b": [4, 5, 6]}, + index=["11", "22", "33"], + ) + result = df.melt( + id_vars=["a"], + var_name=0, + value_name=1, + ) + expected = DataFrame({"a": [1, 2, 3], 0: ["b"] * 3, 1: [4, 5, 6]}) + tm.assert_frame_equal(result, expected) + + def test_melt_non_scalar_var_name_raises(self): + df = DataFrame( + data={"a": [1, 2, 3], "b": [4, 5, 6]}, + index=["11", "22", "33"], + ) + with pytest.raises(ValueError, match=r".* must be a scalar."): + df.melt(id_vars=["a"], var_name=[1, 2]) + + def test_melt_multiindex_columns_var_name(self): + # GH 58033 + df = DataFrame({("A", "a"): [1], ("A", "b"): [2]}) + + expected = DataFrame( + [("A", "a", 1), ("A", "b", 2)], columns=["first", "second", "value"] + ) + + tm.assert_frame_equal(df.melt(var_name=["first", "second"]), expected) + tm.assert_frame_equal(df.melt(var_name=["first"]), expected[["first", "value"]]) + + def test_melt_multiindex_columns_var_name_too_many(self): + # GH 58033 + df = DataFrame({("A", "a"): [1], ("A", "b"): [2]}) + + with pytest.raises( + ValueError, match="but the dataframe columns only have 2 levels" + ): + df.melt(var_name=["first", "second", "third"]) + + def test_melt_duplicate_column_header_raises(self): + # GH61475 + df = DataFrame([[1, 2, 3], [3, 4, 5]], columns=["A", "A", "B"]) + msg = "id_vars cannot contain duplicate columns." + + with pytest.raises(ValueError, match=msg): + df.melt(id_vars=["A"], value_vars=["B"]) + + +class TestLreshape: + def test_pairs(self): + data = { + "birthdt": [ + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + ], + "birthwt": [1766, 3301, 1454, 3139, 4133], + "id": [101, 102, 103, 104, 105], + "sex": ["Male", "Female", "Female", "Female", "Female"], + "visitdt1": [ + "11jan2009", + "22dec2008", + "04jan2009", + "29dec2008", + "20jan2009", + ], + "visitdt2": ["21jan2009", np.nan, "22jan2009", "31dec2008", "03feb2009"], + "visitdt3": ["05feb2009", np.nan, np.nan, "02jan2009", "15feb2009"], + "wt1": [1823, 3338, 1549, 3298, 4306], + "wt2": [2011.0, np.nan, 1892.0, 3338.0, 4575.0], + "wt3": [2293.0, np.nan, np.nan, 3377.0, 4805.0], + } + + df = DataFrame(data) + + spec = { + "visitdt": [f"visitdt{i:d}" for i in range(1, 4)], + "wt": [f"wt{i:d}" for i in range(1, 4)], + } + result = lreshape(df, spec) + + exp_data = { + "birthdt": [ + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + "08jan2009", + "30dec2008", + "21dec2008", + "11jan2009", + "08jan2009", + "21dec2008", + "11jan2009", + ], + "birthwt": [ + 1766, + 3301, + 1454, + 3139, + 4133, + 1766, + 1454, + 3139, + 4133, + 1766, + 3139, + 4133, + ], + "id": [101, 102, 103, 104, 105, 101, 103, 104, 105, 101, 104, 105], + "sex": [ + "Male", + "Female", + "Female", + "Female", + "Female", + "Male", + "Female", + "Female", + "Female", + "Male", + "Female", + "Female", + ], + "visitdt": [ + "11jan2009", + "22dec2008", + "04jan2009", + "29dec2008", + "20jan2009", + "21jan2009", + "22jan2009", + "31dec2008", + "03feb2009", + "05feb2009", + "02jan2009", + "15feb2009", + ], + "wt": [ + 1823.0, + 3338.0, + 1549.0, + 3298.0, + 4306.0, + 2011.0, + 1892.0, + 3338.0, + 4575.0, + 2293.0, + 3377.0, + 4805.0, + ], + } + exp = DataFrame(exp_data, columns=result.columns) + tm.assert_frame_equal(result, exp) + + result = lreshape(df, spec, dropna=False) + exp_data = { + "birthdt": [ + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + "08jan2009", + "20dec2008", + "30dec2008", + "21dec2008", + "11jan2009", + ], + "birthwt": [ + 1766, + 3301, + 1454, + 3139, + 4133, + 1766, + 3301, + 1454, + 3139, + 4133, + 1766, + 3301, + 1454, + 3139, + 4133, + ], + "id": [ + 101, + 102, + 103, + 104, + 105, + 101, + 102, + 103, + 104, + 105, + 101, + 102, + 103, + 104, + 105, + ], + "sex": [ + "Male", + "Female", + "Female", + "Female", + "Female", + "Male", + "Female", + "Female", + "Female", + "Female", + "Male", + "Female", + "Female", + "Female", + "Female", + ], + "visitdt": [ + "11jan2009", + "22dec2008", + "04jan2009", + "29dec2008", + "20jan2009", + "21jan2009", + np.nan, + "22jan2009", + "31dec2008", + "03feb2009", + "05feb2009", + np.nan, + np.nan, + "02jan2009", + "15feb2009", + ], + "wt": [ + 1823.0, + 3338.0, + 1549.0, + 3298.0, + 4306.0, + 2011.0, + np.nan, + 1892.0, + 3338.0, + 4575.0, + 2293.0, + np.nan, + np.nan, + 3377.0, + 4805.0, + ], + } + exp = DataFrame(exp_data, columns=result.columns) + tm.assert_frame_equal(result, exp) + + spec = { + "visitdt": [f"visitdt{i:d}" for i in range(1, 3)], + "wt": [f"wt{i:d}" for i in range(1, 4)], + } + msg = "All column lists must be same length" + with pytest.raises(ValueError, match=msg): + lreshape(df, spec) + + +class TestWideToLong: + def test_simple(self): + x = np.random.default_rng(2).standard_normal(3) + df = DataFrame( + { + "A1970": {0: "a", 1: "b", 2: "c"}, + "A1980": {0: "d", 1: "e", 2: "f"}, + "B1970": {0: 2.5, 1: 1.2, 2: 0.7}, + "B1980": {0: 3.2, 1: 1.3, 2: 0.1}, + "X": dict(zip(range(3), x)), + } + ) + df["id"] = df.index + exp_data = { + "X": x.tolist() + x.tolist(), + "A": ["a", "b", "c", "d", "e", "f"], + "B": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1], + "year": [1970, 1970, 1970, 1980, 1980, 1980], + "id": [0, 1, 2, 0, 1, 2], + } + expected = DataFrame(exp_data) + expected = expected.set_index(["id", "year"])[["X", "A", "B"]] + result = wide_to_long(df, ["A", "B"], i="id", j="year") + tm.assert_frame_equal(result, expected) + + def test_stubs(self): + # GH9204 wide_to_long call should not modify 'stubs' list + df = DataFrame([[0, 1, 2, 3, 8], [4, 5, 6, 7, 9]]) + df.columns = ["id", "inc1", "inc2", "edu1", "edu2"] + stubs = ["inc", "edu"] + + wide_to_long(df, stubs, i="id", j="age") + + assert stubs == ["inc", "edu"] + + def test_separating_character(self): + # GH14779 + + x = np.random.default_rng(2).standard_normal(3) + df = DataFrame( + { + "A.1970": {0: "a", 1: "b", 2: "c"}, + "A.1980": {0: "d", 1: "e", 2: "f"}, + "B.1970": {0: 2.5, 1: 1.2, 2: 0.7}, + "B.1980": {0: 3.2, 1: 1.3, 2: 0.1}, + "X": dict(zip(range(3), x)), + } + ) + df["id"] = df.index + exp_data = { + "X": x.tolist() + x.tolist(), + "A": ["a", "b", "c", "d", "e", "f"], + "B": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1], + "year": [1970, 1970, 1970, 1980, 1980, 1980], + "id": [0, 1, 2, 0, 1, 2], + } + expected = DataFrame(exp_data) + expected = expected.set_index(["id", "year"])[["X", "A", "B"]] + result = wide_to_long(df, ["A", "B"], i="id", j="year", sep=".") + tm.assert_frame_equal(result, expected) + + def test_escapable_characters(self): + x = np.random.default_rng(2).standard_normal(3) + df = DataFrame( + { + "A(quarterly)1970": {0: "a", 1: "b", 2: "c"}, + "A(quarterly)1980": {0: "d", 1: "e", 2: "f"}, + "B(quarterly)1970": {0: 2.5, 1: 1.2, 2: 0.7}, + "B(quarterly)1980": {0: 3.2, 1: 1.3, 2: 0.1}, + "X": dict(zip(range(3), x)), + } + ) + df["id"] = df.index + exp_data = { + "X": x.tolist() + x.tolist(), + "A(quarterly)": ["a", "b", "c", "d", "e", "f"], + "B(quarterly)": [2.5, 1.2, 0.7, 3.2, 1.3, 0.1], + "year": [1970, 1970, 1970, 1980, 1980, 1980], + "id": [0, 1, 2, 0, 1, 2], + } + expected = DataFrame(exp_data) + expected = expected.set_index(["id", "year"])[ + ["X", "A(quarterly)", "B(quarterly)"] + ] + result = wide_to_long(df, ["A(quarterly)", "B(quarterly)"], i="id", j="year") + tm.assert_frame_equal(result, expected) + + def test_unbalanced(self): + # test that we can have a varying amount of time variables + df = DataFrame( + { + "A2010": [1.0, 2.0], + "A2011": [3.0, 4.0], + "B2010": [5.0, 6.0], + "X": ["X1", "X2"], + } + ) + df["id"] = df.index + exp_data = { + "X": ["X1", "X2", "X1", "X2"], + "A": [1.0, 2.0, 3.0, 4.0], + "B": [5.0, 6.0, np.nan, np.nan], + "id": [0, 1, 0, 1], + "year": [2010, 2010, 2011, 2011], + } + expected = DataFrame(exp_data) + expected = expected.set_index(["id", "year"])[["X", "A", "B"]] + result = wide_to_long(df, ["A", "B"], i="id", j="year") + tm.assert_frame_equal(result, expected) + + def test_character_overlap(self): + # Test we handle overlapping characters in both id_vars and value_vars + df = DataFrame( + { + "A11": ["a11", "a22", "a33"], + "A12": ["a21", "a22", "a23"], + "B11": ["b11", "b12", "b13"], + "B12": ["b21", "b22", "b23"], + "BB11": [1, 2, 3], + "BB12": [4, 5, 6], + "BBBX": [91, 92, 93], + "BBBZ": [91, 92, 93], + } + ) + df["id"] = df.index + expected = DataFrame( + { + "BBBX": [91, 92, 93, 91, 92, 93], + "BBBZ": [91, 92, 93, 91, 92, 93], + "A": ["a11", "a22", "a33", "a21", "a22", "a23"], + "B": ["b11", "b12", "b13", "b21", "b22", "b23"], + "BB": [1, 2, 3, 4, 5, 6], + "id": [0, 1, 2, 0, 1, 2], + "year": [11, 11, 11, 12, 12, 12], + } + ) + expected = expected.set_index(["id", "year"])[["BBBX", "BBBZ", "A", "B", "BB"]] + result = wide_to_long(df, ["A", "B", "BB"], i="id", j="year") + tm.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + + def test_invalid_separator(self): + # if an invalid separator is supplied an empty data frame is returned + sep = "nope!" + df = DataFrame( + { + "A2010": [1.0, 2.0], + "A2011": [3.0, 4.0], + "B2010": [5.0, 6.0], + "X": ["X1", "X2"], + } + ) + df["id"] = df.index + exp_data = { + "X": "", + "A2010": [], + "A2011": [], + "B2010": [], + "id": [], + "year": [], + "A": [], + "B": [], + } + expected = DataFrame(exp_data).astype({"year": np.int64}) + expected = expected.set_index(["id", "year"])[ + ["X", "A2010", "A2011", "B2010", "A", "B"] + ] + expected.index = expected.index.set_levels([0, 1], level=0) + result = wide_to_long(df, ["A", "B"], i="id", j="year", sep=sep) + tm.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + + def test_num_string_disambiguation(self): + # Test that we can disambiguate number value_vars from + # string value_vars + df = DataFrame( + { + "A11": ["a11", "a22", "a33"], + "A12": ["a21", "a22", "a23"], + "B11": ["b11", "b12", "b13"], + "B12": ["b21", "b22", "b23"], + "BB11": [1, 2, 3], + "BB12": [4, 5, 6], + "Arating": [91, 92, 93], + "Arating_old": [91, 92, 93], + } + ) + df["id"] = df.index + expected = DataFrame( + { + "Arating": [91, 92, 93, 91, 92, 93], + "Arating_old": [91, 92, 93, 91, 92, 93], + "A": ["a11", "a22", "a33", "a21", "a22", "a23"], + "B": ["b11", "b12", "b13", "b21", "b22", "b23"], + "BB": [1, 2, 3, 4, 5, 6], + "id": [0, 1, 2, 0, 1, 2], + "year": [11, 11, 11, 12, 12, 12], + } + ) + expected = expected.set_index(["id", "year"])[ + ["Arating", "Arating_old", "A", "B", "BB"] + ] + result = wide_to_long(df, ["A", "B", "BB"], i="id", j="year") + tm.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + + def test_invalid_suffixtype(self): + # If all stubs names end with a string, but a numeric suffix is + # assumed, an empty data frame is returned + df = DataFrame( + { + "Aone": [1.0, 2.0], + "Atwo": [3.0, 4.0], + "Bone": [5.0, 6.0], + "X": ["X1", "X2"], + } + ) + df["id"] = df.index + exp_data = { + "X": "", + "Aone": [], + "Atwo": [], + "Bone": [], + "id": [], + "year": [], + "A": [], + "B": [], + } + expected = DataFrame(exp_data).astype({"year": np.int64}) + + expected = expected.set_index(["id", "year"]) + expected.index = expected.index.set_levels([0, 1], level=0) + result = wide_to_long(df, ["A", "B"], i="id", j="year") + tm.assert_frame_equal(result.sort_index(axis=1), expected.sort_index(axis=1)) + + def test_multiple_id_columns(self): + # Taken from http://www.ats.ucla.edu/stat/stata/modules/reshapel.htm + df = DataFrame( + { + "famid": [1, 1, 1, 2, 2, 2, 3, 3, 3], + "birth": [1, 2, 3, 1, 2, 3, 1, 2, 3], + "ht1": [2.8, 2.9, 2.2, 2, 1.8, 1.9, 2.2, 2.3, 2.1], + "ht2": [3.4, 3.8, 2.9, 3.2, 2.8, 2.4, 3.3, 3.4, 2.9], + } + ) + expected = DataFrame( + { + "ht": [ + 2.8, + 3.4, + 2.9, + 3.8, + 2.2, + 2.9, + 2.0, + 3.2, + 1.8, + 2.8, + 1.9, + 2.4, + 2.2, + 3.3, + 2.3, + 3.4, + 2.1, + 2.9, + ], + "famid": [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3], + "birth": [1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3, 1, 1, 2, 2, 3, 3], + "age": [1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + } + ) + expected = expected.set_index(["famid", "birth", "age"])[["ht"]] + result = wide_to_long(df, "ht", i=["famid", "birth"], j="age") + tm.assert_frame_equal(result, expected) + + def test_non_unique_idvars(self): + # GH16382 + # Raise an error message if non unique id vars (i) are passed + df = DataFrame( + {"A_A1": [1, 2, 3, 4, 5], "B_B1": [1, 2, 3, 4, 5], "x": [1, 1, 1, 1, 1]} + ) + msg = "the id variables need to uniquely identify each row" + with pytest.raises(ValueError, match=msg): + wide_to_long(df, ["A_A", "B_B"], i="x", j="colname") + + def test_cast_j_int(self): + df = DataFrame( + { + "actor_1": ["CCH Pounder", "Johnny Depp", "Christoph Waltz"], + "actor_2": ["Joel David Moore", "Orlando Bloom", "Rory Kinnear"], + "actor_fb_likes_1": [1000.0, 40000.0, 11000.0], + "actor_fb_likes_2": [936.0, 5000.0, 393.0], + "title": ["Avatar", "Pirates of the Caribbean", "Spectre"], + } + ) + + expected = DataFrame( + { + "actor": [ + "CCH Pounder", + "Johnny Depp", + "Christoph Waltz", + "Joel David Moore", + "Orlando Bloom", + "Rory Kinnear", + ], + "actor_fb_likes": [1000.0, 40000.0, 11000.0, 936.0, 5000.0, 393.0], + "num": [1, 1, 1, 2, 2, 2], + "title": [ + "Avatar", + "Pirates of the Caribbean", + "Spectre", + "Avatar", + "Pirates of the Caribbean", + "Spectre", + ], + } + ).set_index(["title", "num"]) + result = wide_to_long( + df, ["actor", "actor_fb_likes"], i="title", j="num", sep="_" + ) + + tm.assert_frame_equal(result, expected) + + def test_identical_stubnames(self): + df = DataFrame( + { + "A2010": [1.0, 2.0], + "A2011": [3.0, 4.0], + "B2010": [5.0, 6.0], + "A": ["X1", "X2"], + } + ) + msg = "stubname can't be identical to a column name" + with pytest.raises(ValueError, match=msg): + wide_to_long(df, ["A", "B"], i="A", j="colname") + + def test_nonnumeric_suffix(self): + df = DataFrame( + { + "treatment_placebo": [1.0, 2.0], + "treatment_test": [3.0, 4.0], + "result_placebo": [5.0, 6.0], + "A": ["X1", "X2"], + } + ) + expected = DataFrame( + { + "A": ["X1", "X2", "X1", "X2"], + "colname": ["placebo", "placebo", "test", "test"], + "result": [5.0, 6.0, np.nan, np.nan], + "treatment": [1.0, 2.0, 3.0, 4.0], + } + ) + expected = expected.set_index(["A", "colname"]) + result = wide_to_long( + df, ["result", "treatment"], i="A", j="colname", suffix="[a-z]+", sep="_" + ) + tm.assert_frame_equal(result, expected) + + def test_mixed_type_suffix(self): + df = DataFrame( + { + "A": ["X1", "X2"], + "result_1": [0, 9], + "result_foo": [5.0, 6.0], + "treatment_1": [1.0, 2.0], + "treatment_foo": [3.0, 4.0], + } + ) + expected = DataFrame( + { + "A": ["X1", "X2", "X1", "X2"], + "colname": ["1", "1", "foo", "foo"], + "result": [0.0, 9.0, 5.0, 6.0], + "treatment": [1.0, 2.0, 3.0, 4.0], + } + ).set_index(["A", "colname"]) + result = wide_to_long( + df, ["result", "treatment"], i="A", j="colname", suffix=".+", sep="_" + ) + tm.assert_frame_equal(result, expected) + + def test_float_suffix(self): + df = DataFrame( + { + "treatment_1.1": [1.0, 2.0], + "treatment_2.1": [3.0, 4.0], + "result_1.2": [5.0, 6.0], + "result_1": [0, 9], + "A": ["X1", "X2"], + } + ) + expected = DataFrame( + { + "A": ["X1", "X2", "X1", "X2", "X1", "X2", "X1", "X2"], + "colname": [1.2, 1.2, 1.0, 1.0, 1.1, 1.1, 2.1, 2.1], + "result": [5.0, 6.0, 0.0, 9.0, np.nan, np.nan, np.nan, np.nan], + "treatment": [np.nan, np.nan, np.nan, np.nan, 1.0, 2.0, 3.0, 4.0], + } + ) + expected = expected.set_index(["A", "colname"]) + result = wide_to_long( + df, ["result", "treatment"], i="A", j="colname", suffix="[0-9.]+", sep="_" + ) + tm.assert_frame_equal(result, expected) + + def test_col_substring_of_stubname(self): + # GH22468 + # Don't raise ValueError when a column name is a substring + # of a stubname that's been passed as a string + wide_data = { + "node_id": {0: 0, 1: 1, 2: 2, 3: 3, 4: 4}, + "A": {0: 0.80, 1: 0.0, 2: 0.25, 3: 1.0, 4: 0.81}, + "PA0": {0: 0.74, 1: 0.56, 2: 0.56, 3: 0.98, 4: 0.6}, + "PA1": {0: 0.77, 1: 0.64, 2: 0.52, 3: 0.98, 4: 0.67}, + "PA3": {0: 0.34, 1: 0.70, 2: 0.52, 3: 0.98, 4: 0.67}, + } + wide_df = DataFrame.from_dict(wide_data) + expected = wide_to_long(wide_df, stubnames=["PA"], i=["node_id", "A"], j="time") + result = wide_to_long(wide_df, stubnames="PA", i=["node_id", "A"], j="time") + tm.assert_frame_equal(result, expected) + + def test_raise_of_column_name_value(self): + # GH34731, enforced in 2.0 + # raise a ValueError if the resultant value column name matches + # a name in the dataframe already (default name is "value") + df = DataFrame({"col": list("ABC"), "value": range(10, 16, 2)}) + + with pytest.raises( + ValueError, match=re.escape("value_name (value) cannot match") + ): + df.melt(id_vars="value", value_name="value") + + def test_missing_stubname(self, any_string_dtype): + # GH46044 + df = DataFrame({"id": ["1", "2"], "a-1": [100, 200], "a-2": [300, 400]}) + df = df.astype({"id": any_string_dtype}) + result = wide_to_long( + df, + stubnames=["a", "b"], + i="id", + j="num", + sep="-", + ) + index = Index( + [("1", 1), ("2", 1), ("1", 2), ("2", 2)], + name=("id", "num"), + ) + expected = DataFrame( + {"a": [100, 200, 300, 400], "b": [np.nan] * 4}, + index=index, + ) + new_level = expected.index.levels[0].astype(any_string_dtype) + if any_string_dtype == "object": + new_level = expected.index.levels[0].astype("str") + expected.index = expected.index.set_levels(new_level, level=0) + tm.assert_frame_equal(result, expected) + + +def test_wide_to_long_string_columns(string_storage): + # GH 57066 + string_dtype = pd.StringDtype(string_storage, na_value=np.nan) + df = DataFrame( + { + "ID": {0: 1}, + "R_test1": {0: 1}, + "R_test2": {0: 1}, + "R_test3": {0: 2}, + "D": {0: 1}, + } + ) + df.columns = df.columns.astype(string_dtype) + result = wide_to_long( + df, stubnames="R", i="ID", j="UNPIVOTED", sep="_", suffix=".*" + ) + expected = DataFrame( + [[1, 1], [1, 1], [1, 2]], + columns=Index(["D", "R"]), + index=pd.MultiIndex.from_arrays( + [ + [1, 1, 1], + Index(["test1", "test2", "test3"], dtype=string_dtype), + ], + names=["ID", "UNPIVOTED"], + ), + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/reshape/test_pivot.py b/pandas/tests/reshape/test_pivot.py new file mode 100644 index 0000000000000000000000000000000000000000..6745ba0bac765a17b583724cc57dc207bb2e81a6 --- /dev/null +++ b/pandas/tests/reshape/test_pivot.py @@ -0,0 +1,2961 @@ +from datetime import ( + date, + datetime, + timedelta, +) +from itertools import product +import re + +import numpy as np +import pytest + +from pandas._config import using_string_dtype + +import pandas as pd +from pandas import ( + ArrowDtype, + Categorical, + DataFrame, + Grouper, + Index, + MultiIndex, + Series, + concat, + date_range, +) +import pandas._testing as tm +from pandas.api.types import CategoricalDtype +from pandas.core.reshape import reshape as reshape_lib +from pandas.core.reshape.pivot import pivot_table + + +class TestPivotTable: + @pytest.fixture + def data(self): + return DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": [ + "dull", + "dull", + "shiny", + "dull", + "dull", + "shiny", + "shiny", + "dull", + "shiny", + "shiny", + "shiny", + ], + "D": np.random.default_rng(2).standard_normal(11), + "E": np.random.default_rng(2).standard_normal(11), + "F": np.random.default_rng(2).standard_normal(11), + } + ) + + def test_pivot_table(self, observed, data): + index = ["A", "B"] + columns = "C" + table = pivot_table( + data, values="D", index=index, columns=columns, observed=observed + ) + + table2 = data.pivot_table( + values="D", index=index, columns=columns, observed=observed + ) + tm.assert_frame_equal(table, table2) + + # this works + pivot_table(data, values="D", index=index, observed=observed) + + if len(index) > 1: + assert table.index.names == tuple(index) + else: + assert table.index.name == index[0] + + if len(columns) > 1: + assert table.columns.names == columns + else: + assert table.columns.name == columns[0] + + expected = data.groupby([*index, columns])["D"].agg("mean").unstack() + tm.assert_frame_equal(table, expected) + + def test_pivot_table_categorical_observed_equal(self, observed): + # issue #24923 + df = DataFrame( + {"col1": list("abcde"), "col2": list("fghij"), "col3": [1, 2, 3, 4, 5]} + ) + + expected = df.pivot_table( + index="col1", values="col3", columns="col2", aggfunc="sum", fill_value=0 + ) + + expected.index = expected.index.astype("category") + expected.columns = expected.columns.astype("category") + + df.col1 = df.col1.astype("category") + df.col2 = df.col2.astype("category") + + result = df.pivot_table( + index="col1", + values="col3", + columns="col2", + aggfunc="sum", + fill_value=0, + observed=observed, + ) + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_nocols(self): + df = DataFrame( + {"rows": ["a", "b", "c"], "cols": ["x", "y", "z"], "values": [1, 2, 3]} + ) + rs = df.pivot_table(columns="cols", aggfunc="sum") + xp = df.pivot_table(index="cols", aggfunc="sum").T + tm.assert_frame_equal(rs, xp) + + rs = df.pivot_table(columns="cols", aggfunc={"values": "mean"}) + xp = df.pivot_table(index="cols", aggfunc={"values": "mean"}).T + tm.assert_frame_equal(rs, xp) + + def test_pivot_table_dropna(self): + df = DataFrame( + { + "amount": {0: 60000, 1: 100000, 2: 50000, 3: 30000}, + "customer": {0: "A", 1: "A", 2: "B", 3: "C"}, + "month": {0: 201307, 1: 201309, 2: 201308, 3: 201310}, + "product": {0: "a", 1: "b", 2: "c", 3: "d"}, + "quantity": {0: 2000000, 1: 500000, 2: 1000000, 3: 1000000}, + } + ) + pv_col = df.pivot_table( + "quantity", "month", ["customer", "product"], dropna=False + ) + pv_ind = df.pivot_table( + "quantity", ["customer", "product"], "month", dropna=False + ) + + m = MultiIndex.from_tuples( + [ + ("A", "a"), + ("A", "b"), + ("A", "c"), + ("A", "d"), + ("B", "a"), + ("B", "b"), + ("B", "c"), + ("B", "d"), + ("C", "a"), + ("C", "b"), + ("C", "c"), + ("C", "d"), + ], + names=["customer", "product"], + ) + tm.assert_index_equal(pv_col.columns, m) + tm.assert_index_equal(pv_ind.index, m) + + def test_pivot_table_categorical(self): + cat1 = Categorical( + ["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True + ) + cat2 = Categorical( + ["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True + ) + df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]}) + result = pivot_table( + df, values="values", index=["A", "B"], dropna=True, observed=False + ) + + exp_index = MultiIndex.from_arrays([cat1, cat2], names=["A", "B"]) + expected = DataFrame({"values": [1.0, 2.0, 3.0, 4.0]}, index=exp_index) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_dropna_categoricals(self, dropna): + # GH 15193 + categories = ["a", "b", "c", "d"] + + df = DataFrame( + { + "A": ["a", "a", "a", "b", "b", "b", "c", "c", "c"], + "B": [1, 2, 3, 1, 2, 3, 1, 2, 3], + "C": range(9), + } + ) + + df["A"] = df["A"].astype(CategoricalDtype(categories, ordered=False)) + result = df.pivot_table( + index="B", columns="A", values="C", dropna=dropna, observed=False + ) + expected_columns = Series(["a", "b", "c"], name="A") + expected_columns = expected_columns.astype( + CategoricalDtype(categories, ordered=False) + ) + expected_index = Series([1, 2, 3], name="B") + expected = DataFrame( + [[0.0, 3.0, 6.0], [1.0, 4.0, 7.0], [2.0, 5.0, 8.0]], + index=expected_index, + columns=expected_columns, + ) + if not dropna: + # add back the non observed to compare + expected = expected.reindex(columns=Categorical(categories)).astype("float") + + tm.assert_frame_equal(result, expected) + + def test_pivot_with_non_observable_dropna(self, dropna): + # gh-21133 + df = DataFrame( + { + "A": Categorical( + [np.nan, "low", "high", "low", "high"], + categories=["low", "high"], + ordered=True, + ), + "B": [0.0, 1.0, 2.0, 3.0, 4.0], + } + ) + + result = df.pivot_table(index="A", values="B", dropna=dropna, observed=False) + if dropna: + values = [2.0, 3.0] + codes = [0, 1] + else: + # GH: 10772 + values = [2.0, 3.0, 0.0] + codes = [0, 1, -1] + expected = DataFrame( + {"B": values}, + index=Index( + Categorical.from_codes(codes, categories=["low", "high"], ordered=True), + name="A", + ), + ) + + tm.assert_frame_equal(result, expected) + + def test_pivot_with_non_observable_dropna_multi_cat(self, dropna): + # gh-21378 + df = DataFrame( + { + "A": Categorical( + ["left", "low", "high", "low", "high"], + categories=["low", "high", "left"], + ordered=True, + ), + "B": range(5), + } + ) + + result = df.pivot_table(index="A", values="B", dropna=dropna, observed=False) + expected = DataFrame( + {"B": [2.0, 3.0, 0.0]}, + index=Index( + Categorical.from_codes( + [0, 1, 2], categories=["low", "high", "left"], ordered=True + ), + name="A", + ), + ) + if not dropna: + expected["B"] = expected["B"].astype(float) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "left_right", [([0] * 4, [1] * 4), (range(3), range(1, 4))] + ) + def test_pivot_with_interval_index(self, left_right, dropna, closed): + # GH 25814 + left, right = left_right + interval_values = Categorical(pd.IntervalIndex.from_arrays(left, right, closed)) + df = DataFrame({"A": interval_values, "B": 1}) + + result = df.pivot_table(index="A", values="B", dropna=dropna, observed=False) + expected = DataFrame( + {"B": 1.0}, index=Index(interval_values.unique(), name="A") + ) + if not dropna: + expected = expected.astype(float) + tm.assert_frame_equal(result, expected) + + def test_pivot_with_interval_index_margins(self): + # GH 25815 + ordered_cat = pd.IntervalIndex.from_arrays([0, 0, 1, 1], [1, 1, 2, 2]) + df = DataFrame( + { + "A": np.arange(4, 0, -1, dtype=np.intp), + "B": ["a", "b", "a", "b"], + "C": Categorical(ordered_cat, ordered=True).sort_values( + ascending=False + ), + } + ) + + pivot_tab = pivot_table( + df, + index="C", + columns="B", + values="A", + aggfunc="sum", + margins=True, + observed=False, + ) + + result = pivot_tab["All"] + expected = Series( + [3, 7, 10], + index=Index([pd.Interval(0, 1), pd.Interval(1, 2), "All"], name="C"), + name="All", + dtype=np.intp, + ) + tm.assert_series_equal(result, expected) + + def test_pass_array(self, data): + result = data.pivot_table("D", index=data.A, columns=data.C) + expected = data.pivot_table("D", index="A", columns="C") + tm.assert_frame_equal(result, expected) + + def test_pass_function(self, data): + result = data.pivot_table("D", index=lambda x: x // 5, columns=data.C) + expected = data.pivot_table("D", index=data.index // 5, columns="C") + tm.assert_frame_equal(result, expected) + + def test_pivot_table_multiple(self, data): + index = ["A", "B"] + columns = "C" + table = pivot_table(data, index=index, columns=columns) + expected = data.groupby([*index, columns]).agg("mean").unstack() + tm.assert_frame_equal(table, expected) + + def test_pivot_dtypes(self): + # can convert dtypes + f = DataFrame( + { + "a": ["cat", "bat", "cat", "bat"], + "v": [1, 2, 3, 4], + "i": ["a", "b", "a", "b"], + } + ) + assert f.dtypes["v"] == "int64" + + z = pivot_table( + f, values="v", index=["a"], columns=["i"], fill_value=0, aggfunc="sum" + ) + result = z.dtypes + expected = Series([np.dtype("int64")] * 2, index=Index(list("ab"), name="i")) + tm.assert_series_equal(result, expected) + + # cannot convert dtypes + f = DataFrame( + { + "a": ["cat", "bat", "cat", "bat"], + "v": [1.5, 2.5, 3.5, 4.5], + "i": ["a", "b", "a", "b"], + } + ) + assert f.dtypes["v"] == "float64" + + z = pivot_table( + f, values="v", index=["a"], columns=["i"], fill_value=0, aggfunc="mean" + ) + result = z.dtypes + expected = Series([np.dtype("float64")] * 2, index=Index(list("ab"), name="i")) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "columns,values", + [ + ("bool1", ["float1", "float2"]), + ("bool1", ["float1", "float2", "bool1"]), + ("bool2", ["float1", "float2", "bool1"]), + ], + ) + def test_pivot_preserve_dtypes(self, columns, values): + # GH 7142 regression test + v = np.arange(5, dtype=np.float64) + df = DataFrame( + {"float1": v, "float2": v + 2.0, "bool1": v <= 2, "bool2": v <= 3} + ) + + df_res = df.reset_index().pivot_table( + index="index", columns=columns, values=values + ) + + result = dict(df_res.dtypes) + expected = {col: np.dtype("float64") for col in df_res} + assert result == expected + + def test_pivot_no_values(self): + # GH 14380 + idx = pd.DatetimeIndex( + ["2011-01-01", "2011-02-01", "2011-01-02", "2011-01-01", "2011-01-02"] + ) + df = DataFrame({"A": [1, 2, 3, 4, 5]}, index=idx) + res = df.pivot_table(index=df.index.month, columns=df.index.day) + + exp_columns = MultiIndex.from_tuples([("A", 1), ("A", 2)]) + exp_columns = exp_columns.set_levels( + exp_columns.levels[1].astype(np.int32), level=1 + ) + exp = DataFrame( + [[2.5, 4.0], [2.0, np.nan]], + index=Index([1, 2], dtype=np.int32), + columns=exp_columns, + ) + tm.assert_frame_equal(res, exp) + + df = DataFrame( + { + "A": [1, 2, 3, 4, 5], + "dt": date_range("2011-01-01", freq="D", periods=5, unit="ns"), + }, + index=idx, + ) + res = df.pivot_table(index=df.index.month, columns=Grouper(key="dt", freq="ME")) + exp_columns = MultiIndex.from_arrays( + [["A"], pd.DatetimeIndex(["2011-01-31"], dtype="M8[ns]")], + names=[None, "dt"], + ) + exp = DataFrame( + [3.25, 2.0], index=Index([1, 2], dtype=np.int32), columns=exp_columns + ) + tm.assert_frame_equal(res, exp) + + res = df.pivot_table( + index=Grouper(freq="YE"), columns=Grouper(key="dt", freq="ME") + ) + exp = DataFrame( + [3.0], + index=pd.DatetimeIndex(["2011-12-31"], freq="YE"), + columns=exp_columns, + ) + tm.assert_frame_equal(res, exp) + + def test_pivot_multi_values(self, data): + result = pivot_table( + data, values=["D", "E"], index="A", columns=["B", "C"], fill_value=0 + ) + expected = pivot_table( + data.drop(["F"], axis=1), index="A", columns=["B", "C"], fill_value=0 + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_multi_functions(self, data): + f = lambda func: pivot_table( + data, values=["D", "E"], index=["A", "B"], columns="C", aggfunc=func + ) + result = f(["mean", "std"]) + means = f("mean") + stds = f("std") + expected = concat([means, stds], keys=["mean", "std"], axis=1) + tm.assert_frame_equal(result, expected) + + # margins not supported?? + f = lambda func: pivot_table( + data, + values=["D", "E"], + index=["A", "B"], + columns="C", + aggfunc=func, + margins=True, + ) + result = f(["mean", "std"]) + means = f("mean") + stds = f("std") + expected = concat([means, stds], keys=["mean", "std"], axis=1) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_index_with_nan(self, method): + # GH 3588 + nan = np.nan + df = DataFrame( + { + "a": ["R1", "R2", nan, "R4"], + "b": ["C1", "C2", "C3", "C4"], + "c": [10, 15, 17, 20], + } + ) + if method: + result = df.pivot(index="a", columns="b", values="c") + else: + result = pd.pivot(df, index="a", columns="b", values="c") + expected = DataFrame( + [ + [nan, nan, 17, nan], + [10, nan, nan, nan], + [nan, 15, nan, nan], + [nan, nan, nan, 20], + ], + index=Index([nan, "R1", "R2", "R4"], name="a"), + columns=Index(["C1", "C2", "C3", "C4"], name="b"), + ) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(df.pivot(index="b", columns="a", values="c"), expected.T) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_index_with_nan_dates(self, method): + # GH9491 + df = DataFrame( + { + "a": date_range("2014-02-01", periods=6, freq="D"), + "c": 100 + np.arange(6), + } + ) + df["b"] = df["a"] - pd.Timestamp("2014-02-02") + df.loc[1, "a"] = df.loc[3, "a"] = np.nan + df.loc[1, "b"] = df.loc[4, "b"] = np.nan + + if method: + pv = df.pivot(index="a", columns="b", values="c") + else: + pv = pd.pivot(df, index="a", columns="b", values="c") + assert pv.notna().values.sum() == len(df) + + for _, row in df.iterrows(): + assert pv.loc[row["a"], row["b"]] == row["c"] + + if method: + result = df.pivot(index="b", columns="a", values="c") + else: + result = pd.pivot(df, index="b", columns="a", values="c") + tm.assert_frame_equal(result, pv.T) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_tz(self, method, unit): + # GH 5878 + df = DataFrame( + { + "dt1": pd.DatetimeIndex( + [ + datetime(2013, 1, 1, 9, 0), + datetime(2013, 1, 2, 9, 0), + datetime(2013, 1, 1, 9, 0), + datetime(2013, 1, 2, 9, 0), + ], + dtype=f"M8[{unit}, US/Pacific]", + ), + "dt2": pd.DatetimeIndex( + [ + datetime(2014, 1, 1, 9, 0), + datetime(2014, 1, 1, 9, 0), + datetime(2014, 1, 2, 9, 0), + datetime(2014, 1, 2, 9, 0), + ], + dtype=f"M8[{unit}, Asia/Tokyo]", + ), + "data1": np.arange(4, dtype="int64"), + "data2": np.arange(4, dtype="int64"), + } + ) + + exp_col1 = Index(["data1", "data1", "data2", "data2"]) + exp_col2 = pd.DatetimeIndex( + ["2014/01/01 09:00", "2014/01/02 09:00"] * 2, + name="dt2", + dtype=f"M8[{unit}, Asia/Tokyo]", + ) + exp_col = MultiIndex.from_arrays([exp_col1, exp_col2]) + exp_idx = pd.DatetimeIndex( + ["2013/01/01 09:00", "2013/01/02 09:00"], + name="dt1", + dtype=f"M8[{unit}, US/Pacific]", + ) + expected = DataFrame( + [[0, 2, 0, 2], [1, 3, 1, 3]], + index=exp_idx, + columns=exp_col, + ) + + if method: + pv = df.pivot(index="dt1", columns="dt2") + else: + pv = pd.pivot(df, index="dt1", columns="dt2") + tm.assert_frame_equal(pv, expected) + + expected = DataFrame( + [[0, 2], [1, 3]], + index=exp_idx, + columns=exp_col2[:2], + ) + + if method: + pv = df.pivot(index="dt1", columns="dt2", values="data1") + else: + pv = pd.pivot(df, index="dt1", columns="dt2", values="data1") + tm.assert_frame_equal(pv, expected) + + def test_pivot_tz_in_values(self): + # GH 14948 + df = DataFrame( + [ + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-12 13:00:00-0700", tz="US/Pacific"), + }, + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-12 08:00:00-0700", tz="US/Pacific"), + }, + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-12 14:00:00-0700", tz="US/Pacific"), + }, + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-25 11:00:00-0700", tz="US/Pacific"), + }, + { + "uid": "aa", + "ts": pd.Timestamp("2016-08-25 13:00:00-0700", tz="US/Pacific"), + }, + ] + ) + + df = df.set_index("ts").reset_index() + mins = df.ts.map(lambda x: x.replace(hour=0, minute=0, second=0)) + + result = pivot_table( + df.set_index("ts").reset_index(), + values="ts", + index=["uid"], + columns=[mins], + aggfunc="min", + ) + expected = DataFrame( + [ + [ + pd.Timestamp("2016-08-12 08:00:00-0700", tz="US/Pacific"), + pd.Timestamp("2016-08-25 11:00:00-0700", tz="US/Pacific"), + ] + ], + index=Index(["aa"], name="uid"), + columns=pd.DatetimeIndex( + [ + pd.Timestamp("2016-08-12 00:00:00", tz="US/Pacific"), + pd.Timestamp("2016-08-25 00:00:00", tz="US/Pacific"), + ], + name="ts", + ), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_periods(self, method): + df = DataFrame( + { + "p1": [ + pd.Period("2013-01-01", "D"), + pd.Period("2013-01-02", "D"), + pd.Period("2013-01-01", "D"), + pd.Period("2013-01-02", "D"), + ], + "p2": [ + pd.Period("2013-01", "M"), + pd.Period("2013-01", "M"), + pd.Period("2013-02", "M"), + pd.Period("2013-02", "M"), + ], + "data1": np.arange(4, dtype="int64"), + "data2": np.arange(4, dtype="int64"), + } + ) + + exp_col1 = Index(["data1", "data1", "data2", "data2"]) + exp_col2 = pd.PeriodIndex(["2013-01", "2013-02"] * 2, name="p2", freq="M") + exp_col = MultiIndex.from_arrays([exp_col1, exp_col2]) + expected = DataFrame( + [[0, 2, 0, 2], [1, 3, 1, 3]], + index=pd.PeriodIndex(["2013-01-01", "2013-01-02"], name="p1", freq="D"), + columns=exp_col, + ) + if method: + pv = df.pivot(index="p1", columns="p2") + else: + pv = pd.pivot(df, index="p1", columns="p2") + tm.assert_frame_equal(pv, expected) + + expected = DataFrame( + [[0, 2], [1, 3]], + index=pd.PeriodIndex(["2013-01-01", "2013-01-02"], name="p1", freq="D"), + columns=pd.PeriodIndex(["2013-01", "2013-02"], name="p2", freq="M"), + ) + if method: + pv = df.pivot(index="p1", columns="p2", values="data1") + else: + pv = pd.pivot(df, index="p1", columns="p2", values="data1") + tm.assert_frame_equal(pv, expected) + + def test_pivot_periods_with_margins(self): + # GH 28323 + df = DataFrame( + { + "a": [1, 1, 2, 2], + "b": [ + pd.Period("2019Q1"), + pd.Period("2019Q2"), + pd.Period("2019Q1"), + pd.Period("2019Q2"), + ], + "x": 1.0, + } + ) + + expected = DataFrame( + data=1.0, + index=Index([1, 2, "All"], name="a"), + columns=Index([pd.Period("2019Q1"), pd.Period("2019Q2"), "All"], name="b"), + ) + + result = df.pivot_table(index="a", columns="b", values="x", margins=True) + tm.assert_frame_equal(expected, result) + + @pytest.mark.parametrize("box", [list, np.array, Series, Index]) + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_list_like_values(self, box, method): + # issue #17160 + values = box(["baz", "zoo"]) + df = DataFrame( + { + "foo": ["one", "one", "one", "two", "two", "two"], + "bar": ["A", "B", "C", "A", "B", "C"], + "baz": [1, 2, 3, 4, 5, 6], + "zoo": ["x", "y", "z", "q", "w", "t"], + } + ) + + if method: + result = df.pivot(index="foo", columns="bar", values=values) + else: + result = pd.pivot(df, index="foo", columns="bar", values=values) + + data = [[1, 2, 3, "x", "y", "z"], [4, 5, 6, "q", "w", "t"]] + index = Index(data=["one", "two"], name="foo") + columns = MultiIndex( + levels=[["baz", "zoo"], ["A", "B", "C"]], + codes=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + names=[None, "bar"], + ) + expected = DataFrame(data=data, index=index, columns=columns) + expected["baz"] = expected["baz"].astype(object) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "values", + [ + ["bar", "baz"], + np.array(["bar", "baz"]), + Series(["bar", "baz"]), + Index(["bar", "baz"]), + ], + ) + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_list_like_values_nans(self, values, method): + # issue #17160 + df = DataFrame( + { + "foo": ["one", "one", "one", "two", "two", "two"], + "bar": ["A", "B", "C", "A", "B", "C"], + "baz": [1, 2, 3, 4, 5, 6], + "zoo": ["x", "y", "z", "q", "w", "t"], + } + ) + + if method: + result = df.pivot(index="zoo", columns="foo", values=values) + else: + result = pd.pivot(df, index="zoo", columns="foo", values=values) + + data = [ + [np.nan, "A", np.nan, 4], + [np.nan, "C", np.nan, 6], + [np.nan, "B", np.nan, 5], + ["A", np.nan, 1, np.nan], + ["B", np.nan, 2, np.nan], + ["C", np.nan, 3, np.nan], + ] + index = Index(data=["q", "t", "w", "x", "y", "z"], name="zoo") + columns = MultiIndex( + levels=[["bar", "baz"], ["one", "two"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + names=[None, "foo"], + ) + expected = DataFrame(data=data, index=index, columns=columns) + expected["baz"] = expected["baz"].astype(object) + tm.assert_frame_equal(result, expected) + + def test_pivot_columns_none_raise_error(self): + # GH 30924 + df = DataFrame({"col1": ["a", "b", "c"], "col2": [1, 2, 3], "col3": [1, 2, 3]}) + msg = r"pivot\(\) missing 1 required keyword-only argument: 'columns'" + with pytest.raises(TypeError, match=msg): + df.pivot(index="col1", values="col3") + + @pytest.mark.xfail( + reason="MultiIndexed unstack with tuple names fails with KeyError GH#19966" + ) + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_multiindex(self, method): + # issue #17160 + index = Index(data=[0, 1, 2, 3, 4, 5]) + data = [ + ["one", "A", 1, "x"], + ["one", "B", 2, "y"], + ["one", "C", 3, "z"], + ["two", "A", 4, "q"], + ["two", "B", 5, "w"], + ["two", "C", 6, "t"], + ] + columns = MultiIndex( + levels=[["bar", "baz"], ["first", "second"]], + codes=[[0, 0, 1, 1], [0, 1, 0, 1]], + ) + df = DataFrame(data=data, index=index, columns=columns, dtype="object") + if method: + result = df.pivot( + index=("bar", "first"), + columns=("bar", "second"), + values=("baz", "first"), + ) + else: + result = pd.pivot( + df, + index=("bar", "first"), + columns=("bar", "second"), + values=("baz", "first"), + ) + + data = { + "A": Series([1, 4], index=["one", "two"]), + "B": Series([2, 5], index=["one", "two"]), + "C": Series([3, 6], index=["one", "two"]), + } + expected = DataFrame(data) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("method", [True, False]) + def test_pivot_with_tuple_of_values(self, method): + # issue #17160 + df = DataFrame( + { + "foo": ["one", "one", "one", "two", "two", "two"], + "bar": ["A", "B", "C", "A", "B", "C"], + "baz": [1, 2, 3, 4, 5, 6], + "zoo": ["x", "y", "z", "q", "w", "t"], + } + ) + with pytest.raises(KeyError, match=r"^\('bar', 'baz'\)$"): + # tuple is seen as a single column name + if method: + df.pivot(index="zoo", columns="foo", values=("bar", "baz")) + else: + pd.pivot(df, index="zoo", columns="foo", values=("bar", "baz")) + + def _check_output( + self, + result, + values_col, + data, + index=None, + columns=None, + margins_col="All", + ): + if index is None: + index = ["A", "B"] + if columns is None: + columns = ["C"] + col_margins = result.loc[result.index[:-1], margins_col] + expected_col_margins = data.groupby(index)[values_col].mean() + tm.assert_series_equal(col_margins, expected_col_margins, check_names=False) + assert col_margins.name == margins_col + + result = result.sort_index() + index_margins = result.loc[(margins_col, "")].iloc[:-1] + + expected_ix_margins = data.groupby(columns)[values_col].mean() + tm.assert_series_equal(index_margins, expected_ix_margins, check_names=False) + assert index_margins.name == (margins_col, "") + + grand_total_margins = result.loc[(margins_col, ""), margins_col] + expected_total_margins = data[values_col].mean() + assert grand_total_margins == expected_total_margins + + def test_margins(self, data): + # column specified + result = data.pivot_table( + values="D", index=["A", "B"], columns="C", margins=True, aggfunc="mean" + ) + self._check_output(result, "D", data) + + # Set a different margins_name (not 'All') + result = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + margins=True, + aggfunc="mean", + margins_name="Totals", + ) + self._check_output(result, "D", data, margins_col="Totals") + + # no column specified + table = data.pivot_table( + index=["A", "B"], columns="C", margins=True, aggfunc="mean" + ) + for value_col in table.columns.levels[0]: + self._check_output(table[value_col], value_col, data) + + def test_no_col(self, data, using_infer_string): + # no col + + # to help with a buglet + data.columns = [k * 2 for k in data.columns] + msg = re.escape("agg function failed [how->mean,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'mean'" + with pytest.raises(TypeError, match=msg): + data.pivot_table(index=["AA", "BB"], margins=True, aggfunc="mean") + table = data.drop(columns="CC").pivot_table( + index=["AA", "BB"], margins=True, aggfunc="mean" + ) + for value_col in table.columns: + totals = table.loc[("All", ""), value_col] + assert totals == data[value_col].mean() + + with pytest.raises(TypeError, match=msg): + data.pivot_table(index=["AA", "BB"], margins=True, aggfunc="mean") + table = data.drop(columns="CC").pivot_table( + index=["AA", "BB"], margins=True, aggfunc="mean" + ) + for item in ["DD", "EE", "FF"]: + totals = table.loc[("All", ""), item] + assert totals == data[item].mean() + + @pytest.mark.parametrize( + "columns, aggfunc, values, expected_columns", + [ + ( + "A", + "mean", + [[5.5, 5.5, 2.2, 2.2], [8.0, 8.0, 4.4, 4.4]], + Index(["bar", "All", "foo", "All"], name="A"), + ), + ( + ["A", "B"], + "sum", + [ + [9, 13, 22, 5, 6, 11], + [14, 18, 32, 11, 11, 22], + ], + MultiIndex.from_tuples( + [ + ("bar", "one"), + ("bar", "two"), + ("bar", "All"), + ("foo", "one"), + ("foo", "two"), + ("foo", "All"), + ], + names=["A", "B"], + ), + ), + ], + ) + def test_margin_with_only_columns_defined( + self, columns, aggfunc, values, expected_columns, using_infer_string + ): + # GH 31016 + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + } + ) + if aggfunc != "sum": + msg = re.escape("agg function failed [how->mean,dtype->") + if using_infer_string: + msg = "dtype 'str' does not support operation 'mean'" + with pytest.raises(TypeError, match=msg): + df.pivot_table(columns=columns, margins=True, aggfunc=aggfunc) + if "B" not in columns: + df = df.drop(columns="B") + result = df.drop(columns="C").pivot_table( + columns=columns, margins=True, aggfunc=aggfunc + ) + expected = DataFrame(values, index=Index(["D", "E"]), columns=expected_columns) + + tm.assert_frame_equal(result, expected) + + def test_margins_dtype(self, data): + # GH 17013 + + df = data.copy() + df[["D", "E", "F"]] = np.arange(len(df) * 3).reshape(len(df), 3).astype("i8") + + mi_val = [*list(product(["bar", "foo"], ["one", "two"])), ("All", "")] + mi = MultiIndex.from_tuples(mi_val, names=("A", "B")) + expected = DataFrame( + {"dull": [12, 21, 3, 9, 45], "shiny": [33, 0, 36, 51, 120]}, index=mi + ).rename_axis("C", axis=1) + expected["All"] = expected["dull"] + expected["shiny"] + + result = df.pivot_table( + values="D", + index=["A", "B"], + columns="C", + margins=True, + aggfunc="sum", + fill_value=0, + ) + + tm.assert_frame_equal(expected, result) + + def test_margins_dtype_len(self, data): + mi_val = [*list(product(["bar", "foo"], ["one", "two"])), ("All", "")] + mi = MultiIndex.from_tuples(mi_val, names=("A", "B")) + expected = DataFrame( + {"dull": [1, 1, 2, 1, 5], "shiny": [2, 0, 2, 2, 6]}, index=mi + ).rename_axis("C", axis=1) + expected["All"] = expected["dull"] + expected["shiny"] + + result = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + margins=True, + aggfunc=len, + fill_value=0, + ) + + tm.assert_frame_equal(expected, result) + + @pytest.mark.parametrize("cols", [(1, 2), ("a", "b"), (1, "b"), ("a", 1)]) + def test_pivot_table_multiindex_only(self, cols): + # GH 17038 + df2 = DataFrame({cols[0]: [1, 2, 3], cols[1]: [1, 2, 3], "v": [4, 5, 6]}) + + result = df2.pivot_table(values="v", columns=cols) + expected = DataFrame( + [[4.0, 5.0, 6.0]], + columns=MultiIndex.from_tuples([(1, 1), (2, 2), (3, 3)], names=cols), + index=Index(["v"], dtype="str" if cols == ("a", "b") else "object"), + ) + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_retains_tz(self): + dti = date_range("2016-01-01", periods=3, tz="Europe/Amsterdam") + df = DataFrame( + { + "A": np.random.default_rng(2).standard_normal(3), + "B": np.random.default_rng(2).standard_normal(3), + "C": dti, + } + ) + result = df.pivot_table(index=["B", "C"], dropna=False) + + # check tz retention + assert result.index.levels[1].equals(dti) + + def test_pivot_integer_columns(self): + # caused by upstream bug in unstack + + d = date.min + data = list( + product( + ["foo", "bar"], + ["A", "B", "C"], + ["x1", "x2"], + [d + timedelta(i) for i in range(20)], + [1.0], + ) + ) + df = DataFrame(data) + table = df.pivot_table(values=4, index=[0, 1, 3], columns=[2]) + + df2 = df.rename(columns=str) + table2 = df2.pivot_table(values="4", index=["0", "1", "3"], columns=["2"]) + + tm.assert_frame_equal(table, table2, check_names=False) + + def test_pivot_no_level_overlap(self): + # GH #1181 + + data = DataFrame( + { + "a": ["a", "a", "a", "a", "b", "b", "b", "b"] * 2, + "b": [0, 0, 0, 0, 1, 1, 1, 1] * 2, + "c": (["foo"] * 4 + ["bar"] * 4) * 2, + "value": np.random.default_rng(2).standard_normal(16), + } + ) + + table = data.pivot_table("value", index="a", columns=["b", "c"]) + + grouped = data.groupby(["a", "b", "c"])["value"].mean() + expected = grouped.unstack("b").unstack("c").dropna(axis=1, how="all") + tm.assert_frame_equal(table, expected) + + def test_pivot_columns_lexsorted(self): + n = 10000 + + dtype = np.dtype( + [ + ("Index", object), + ("Symbol", object), + ("Year", int), + ("Month", int), + ("Day", int), + ("Quantity", int), + ("Price", float), + ] + ) + + products = np.array( + [ + ("SP500", "ADBE"), + ("SP500", "NVDA"), + ("SP500", "ORCL"), + ("NDQ100", "AAPL"), + ("NDQ100", "MSFT"), + ("NDQ100", "GOOG"), + ("FTSE", "DGE.L"), + ("FTSE", "TSCO.L"), + ("FTSE", "GSK.L"), + ], + dtype=[("Index", object), ("Symbol", object)], + ) + items = np.empty(n, dtype=dtype) + iproduct = np.random.default_rng(2).integers(0, len(products), n) + items["Index"] = products["Index"][iproduct] + items["Symbol"] = products["Symbol"][iproduct] + dr = date_range(date(2000, 1, 1), date(2010, 12, 31)) + dates = dr[np.random.default_rng(2).integers(0, len(dr), n)] + items["Year"] = dates.year + items["Month"] = dates.month + items["Day"] = dates.day + items["Price"] = np.random.default_rng(2).lognormal(4.0, 2.0, n) + + df = DataFrame(items) + + pivoted = df.pivot_table( + "Price", + index=["Month", "Day"], + columns=["Index", "Symbol", "Year"], + aggfunc="mean", + ) + + assert pivoted.columns.is_monotonic_increasing + + def test_pivot_complex_aggfunc(self, data): + f = {"D": ["std"], "E": ["sum"]} + expected = data.groupby(["A", "B"]).agg(f).unstack("B") + result = data.pivot_table(index="A", columns="B", aggfunc=f) + + tm.assert_frame_equal(result, expected) + + def test_margins_no_values_no_cols(self, data): + # Regression test on pivot table: no values or cols passed. + result = data[["A", "B"]].pivot_table( + index=["A", "B"], aggfunc=len, margins=True + ) + result_list = result.tolist() + assert sum(result_list[:-1]) == result_list[-1] + + def test_margins_no_values_two_rows(self, data): + # Regression test on pivot table: no values passed but rows are a + # multi-index + result = data[["A", "B", "C"]].pivot_table( + index=["A", "B"], columns="C", aggfunc=len, margins=True + ) + assert result.All.tolist() == [3.0, 1.0, 4.0, 3.0, 11.0] + + def test_margins_no_values_one_row_one_col(self, data): + # Regression test on pivot table: no values passed but row and col + # defined + result = data[["A", "B"]].pivot_table( + index="A", columns="B", aggfunc=len, margins=True + ) + assert result.All.tolist() == [4.0, 7.0, 11.0] + + def test_margins_no_values_two_row_two_cols(self, data): + # Regression test on pivot table: no values passed but rows and cols + # are multi-indexed + data["D"] = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"] + result = data[["A", "B", "C", "D"]].pivot_table( + index=["A", "B"], columns=["C", "D"], aggfunc=len, margins=True + ) + assert result.All.tolist() == [3.0, 1.0, 4.0, 3.0, 11.0] + + @pytest.mark.parametrize("margin_name", ["foo", "one", 666, None, ["a", "b"]]) + def test_pivot_table_with_margins_set_margin_name(self, margin_name, data): + # see gh-3335 + msg = ( + f'Conflicting name "{margin_name}" in margins|' + "margins_name argument must be a string" + ) + with pytest.raises(ValueError, match=msg): + # multi-index index + pivot_table( + data, + values="D", + index=["A", "B"], + columns=["C"], + margins=True, + margins_name=margin_name, + ) + with pytest.raises(ValueError, match=msg): + # multi-index column + pivot_table( + data, + values="D", + index=["C"], + columns=["A", "B"], + margins=True, + margins_name=margin_name, + ) + with pytest.raises(ValueError, match=msg): + # non-multi-index index/column + pivot_table( + data, + values="D", + index=["A"], + columns=["B"], + margins=True, + margins_name=margin_name, + ) + + def test_pivot_timegrouper(self): + df = DataFrame( + { + "Branch": "A A A A A A A B".split(), + "Buyer": "Carl Mark Carl Carl Joe Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 1, 8, 1, 9, 3], + "Date": [ + datetime(2013, 1, 1), + datetime(2013, 1, 1), + datetime(2013, 10, 1), + datetime(2013, 10, 2), + datetime(2013, 10, 1), + datetime(2013, 10, 2), + datetime(2013, 12, 2), + datetime(2013, 12, 2), + ], + } + ).set_index("Date") + + expected = DataFrame( + np.array([10, 18, 3], dtype="int64").reshape(1, 3), + index=pd.DatetimeIndex([datetime(2013, 12, 31)], freq="YE"), + columns="Carl Joe Mark".split(), + ) + expected.index.name = "Date" + expected.columns.name = "Buyer" + + result = pivot_table( + df, + index=Grouper(freq="YE"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index="Buyer", + columns=Grouper(freq="YE"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + expected = DataFrame( + np.array([1, np.nan, 3, 9, 18, np.nan]).reshape(2, 3), + index=pd.DatetimeIndex( + [datetime(2013, 1, 1), datetime(2013, 7, 1)], freq="6MS" + ), + columns="Carl Joe Mark".split(), + ) + expected.index.name = "Date" + expected.columns.name = "Buyer" + + result = pivot_table( + df, + index=Grouper(freq="6MS"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + # passing the name + df = df.reset_index() + result = pivot_table( + df, + index=Grouper(freq="6MS", key="Date"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS", key="Date"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + msg = "'The grouper name foo is not found'" + with pytest.raises(KeyError, match=msg): + pivot_table( + df, + index=Grouper(freq="6MS", key="foo"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + with pytest.raises(KeyError, match=msg): + pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS", key="foo"), + values="Quantity", + aggfunc="sum", + ) + + # passing the level + df = df.set_index("Date") + result = pivot_table( + df, + index=Grouper(freq="6MS", level="Date"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS", level="Date"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + msg = "The level foo is not valid" + with pytest.raises(ValueError, match=msg): + pivot_table( + df, + index=Grouper(freq="6MS", level="foo"), + columns="Buyer", + values="Quantity", + aggfunc="sum", + ) + with pytest.raises(ValueError, match=msg): + pivot_table( + df, + index="Buyer", + columns=Grouper(freq="6MS", level="foo"), + values="Quantity", + aggfunc="sum", + ) + + def test_pivot_timegrouper_double(self): + # double grouper + df = DataFrame( + { + "Branch": "A A A A A A A B".split(), + "Buyer": "Carl Mark Carl Carl Joe Joe Joe Carl".split(), + "Quantity": [1, 3, 5, 1, 8, 1, 9, 3], + "Date": [ + datetime(2013, 11, 1, 13, 0), + datetime(2013, 9, 1, 13, 5), + datetime(2013, 10, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 11, 1, 20, 0), + datetime(2013, 10, 2, 10, 0), + datetime(2013, 10, 2, 12, 0), + datetime(2013, 12, 5, 14, 0), + ], + "PayDay": [ + datetime(2013, 10, 4, 0, 0), + datetime(2013, 10, 15, 13, 5), + datetime(2013, 9, 5, 20, 0), + datetime(2013, 11, 2, 10, 0), + datetime(2013, 10, 7, 20, 0), + datetime(2013, 9, 5, 10, 0), + datetime(2013, 12, 30, 12, 0), + datetime(2013, 11, 20, 14, 0), + ], + } + ) + + result = pivot_table( + df, + index=Grouper(freq="ME", key="Date"), + columns=Grouper(freq="ME", key="PayDay"), + values="Quantity", + aggfunc="sum", + ) + expected = DataFrame( + np.array( + [ + np.nan, + 3, + np.nan, + np.nan, + 6, + np.nan, + 1, + 9, + np.nan, + 9, + np.nan, + np.nan, + np.nan, + np.nan, + 3, + np.nan, + ] + ).reshape(4, 4), + index=pd.DatetimeIndex( + [ + datetime(2013, 9, 30), + datetime(2013, 10, 31), + datetime(2013, 11, 30), + datetime(2013, 12, 31), + ], + freq="ME", + ), + columns=pd.DatetimeIndex( + [ + datetime(2013, 9, 30), + datetime(2013, 10, 31), + datetime(2013, 11, 30), + datetime(2013, 12, 31), + ], + freq="ME", + ), + ) + expected.index.name = "Date" + expected.columns.name = "PayDay" + + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index=Grouper(freq="ME", key="PayDay"), + columns=Grouper(freq="ME", key="Date"), + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + tuples = [ + (datetime(2013, 9, 30), datetime(2013, 10, 31)), + (datetime(2013, 10, 31), datetime(2013, 9, 30)), + (datetime(2013, 10, 31), datetime(2013, 11, 30)), + (datetime(2013, 10, 31), datetime(2013, 12, 31)), + (datetime(2013, 11, 30), datetime(2013, 10, 31)), + (datetime(2013, 12, 31), datetime(2013, 11, 30)), + ] + idx = MultiIndex.from_tuples(tuples, names=["Date", "PayDay"]) + expected = DataFrame( + np.array( + [3, np.nan, 6, np.nan, 1, np.nan, 9, np.nan, 9, np.nan, np.nan, 3] + ).reshape(6, 2), + index=idx, + columns=["A", "B"], + ) + expected.columns.name = "Branch" + + result = pivot_table( + df, + index=[Grouper(freq="ME", key="Date"), Grouper(freq="ME", key="PayDay")], + columns=["Branch"], + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index=["Branch"], + columns=[Grouper(freq="ME", key="Date"), Grouper(freq="ME", key="PayDay")], + values="Quantity", + aggfunc="sum", + ) + tm.assert_frame_equal(result, expected.T) + + def test_pivot_datetime_tz(self): + dates1 = pd.DatetimeIndex( + [ + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + ], + dtype="M8[ns, US/Pacific]", + name="dt1", + ) + dates2 = pd.DatetimeIndex( + [ + "2013-01-01 15:00:00", + "2013-01-01 15:00:00", + "2013-01-01 15:00:00", + "2013-02-01 15:00:00", + "2013-02-01 15:00:00", + "2013-02-01 15:00:00", + ], + dtype="M8[ns, Asia/Tokyo]", + ) + df = DataFrame( + { + "label": ["a", "a", "a", "b", "b", "b"], + "dt1": dates1, + "dt2": dates2, + "value1": np.arange(6, dtype="int64"), + "value2": [1, 2] * 3, + } + ) + + exp_idx = dates1[:3] + exp_col1 = Index(["value1", "value1"]) + exp_col2 = Index(["a", "b"], name="label") + exp_col = MultiIndex.from_arrays([exp_col1, exp_col2]) + expected = DataFrame( + [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]], index=exp_idx, columns=exp_col + ) + result = pivot_table(df, index=["dt1"], columns=["label"], values=["value1"]) + tm.assert_frame_equal(result, expected) + + exp_col1 = Index(["sum", "sum", "sum", "sum", "mean", "mean", "mean", "mean"]) + exp_col2 = Index(["value1", "value1", "value2", "value2"] * 2) + exp_col3 = pd.DatetimeIndex( + ["2013-01-01 15:00:00", "2013-02-01 15:00:00"] * 4, + dtype="M8[ns, Asia/Tokyo]", + name="dt2", + ) + exp_col = MultiIndex.from_arrays([exp_col1, exp_col2, exp_col3]) + expected1 = DataFrame( + np.array( + [ + [ + 0, + 3, + 1, + 2, + ], + [1, 4, 2, 1], + [2, 5, 1, 2], + ], + dtype="int64", + ), + index=exp_idx, + columns=exp_col[:4], + ) + expected2 = DataFrame( + np.array( + [ + [0.0, 3.0, 1.0, 2.0], + [1.0, 4.0, 2.0, 1.0], + [2.0, 5.0, 1.0, 2.0], + ], + ), + index=exp_idx, + columns=exp_col[4:], + ) + expected = concat([expected1, expected2], axis=1) + + result = pivot_table( + df, + index=["dt1"], + columns=["dt2"], + values=["value1", "value2"], + aggfunc=["sum", "mean"], + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_dtaccessor(self): + # GH 8103 + dates1 = pd.DatetimeIndex( + [ + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + "2011-07-19 07:00:00", + "2011-07-19 08:00:00", + "2011-07-19 09:00:00", + ] + ) + dates2 = pd.DatetimeIndex( + [ + "2013-01-01 15:00:00", + "2013-01-01 15:00:00", + "2013-01-01 15:00:00", + "2013-02-01 15:00:00", + "2013-02-01 15:00:00", + "2013-02-01 15:00:00", + ] + ) + df = DataFrame( + { + "label": ["a", "a", "a", "b", "b", "b"], + "dt1": dates1, + "dt2": dates2, + "value1": np.arange(6, dtype="int64"), + "value2": [1, 2] * 3, + } + ) + + result = pivot_table( + df, index="label", columns=df["dt1"].dt.hour, values="value1" + ) + + exp_idx = Index(["a", "b"], name="label") + expected = DataFrame( + {7: [0.0, 3.0], 8: [1.0, 4.0], 9: [2.0, 5.0]}, + index=exp_idx, + columns=Index([7, 8, 9], dtype=np.int32, name="dt1"), + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, index=df["dt2"].dt.month, columns=df["dt1"].dt.hour, values="value1" + ) + + expected = DataFrame( + {7: [0.0, 3.0], 8: [1.0, 4.0], 9: [2.0, 5.0]}, + index=Index([1, 2], dtype=np.int32, name="dt2"), + columns=Index([7, 8, 9], dtype=np.int32, name="dt1"), + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index=df["dt2"].dt.year.values, + columns=[df["dt1"].dt.hour, df["dt2"].dt.month], + values="value1", + ) + + exp_col = MultiIndex.from_arrays( + [ + np.array([7, 7, 8, 8, 9, 9], dtype=np.int32), + np.array([1, 2] * 3, dtype=np.int32), + ], + names=["dt1", "dt2"], + ) + expected = DataFrame( + np.array([[0.0, 3.0, 1.0, 4.0, 2.0, 5.0]]), + index=Index([2013], dtype=np.int32), + columns=exp_col, + ) + tm.assert_frame_equal(result, expected) + + result = pivot_table( + df, + index=np.array(["X", "X", "X", "X", "Y", "Y"]), + columns=[df["dt1"].dt.hour, df["dt2"].dt.month], + values="value1", + ) + expected = DataFrame( + np.array( + [[0, 3, 1, np.nan, 2, np.nan], [np.nan, np.nan, np.nan, 4, np.nan, 5]] + ), + index=["X", "Y"], + columns=exp_col, + ) + tm.assert_frame_equal(result, expected) + + def test_daily(self): + rng = date_range("1/1/2000", "12/31/2004", freq="D") + ts = Series(np.arange(len(rng)), index=rng) + + result = pivot_table( + DataFrame(ts), index=ts.index.year, columns=ts.index.dayofyear + ) + result.columns = result.columns.droplevel(0) + + doy = np.asarray(ts.index.dayofyear) + + expected = {} + for y in ts.index.year.unique().values: + mask = ts.index.year == y + expected[y] = Series(ts.values[mask], index=doy[mask]) + expected = DataFrame(expected, dtype=float).T + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(result, expected) + + def test_monthly(self): + rng = date_range("1/1/2000", "12/31/2004", freq="ME") + ts = Series(np.arange(len(rng)), index=rng) + + result = pivot_table(DataFrame(ts), index=ts.index.year, columns=ts.index.month) + result.columns = result.columns.droplevel(0) + + month = np.asarray(ts.index.month) + expected = {} + for y in ts.index.year.unique().values: + mask = ts.index.year == y + expected[y] = Series(ts.values[mask], index=month[mask]) + expected = DataFrame(expected, dtype=float).T + expected.index = expected.index.astype(np.int32) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_with_iterator_values(self, data): + # GH 12017 + aggs = {"D": "sum", "E": "mean"} + + pivot_values_list = pivot_table( + data, index=["A"], values=list(aggs.keys()), aggfunc=aggs + ) + + pivot_values_keys = pivot_table( + data, index=["A"], values=aggs.keys(), aggfunc=aggs + ) + tm.assert_frame_equal(pivot_values_keys, pivot_values_list) + + agg_values_gen = (value for value in aggs) + pivot_values_gen = pivot_table( + data, index=["A"], values=agg_values_gen, aggfunc=aggs + ) + tm.assert_frame_equal(pivot_values_gen, pivot_values_list) + + def test_pivot_table_margins_name_with_aggfunc_list(self): + # GH 13354 + margins_name = "Weekly" + costs = DataFrame( + { + "item": ["bacon", "cheese", "bacon", "cheese"], + "cost": [2.5, 4.5, 3.2, 3.3], + "day": ["ME", "ME", "T", "T"], + } + ) + table = costs.pivot_table( + index="item", + columns="day", + margins=True, + margins_name=margins_name, + aggfunc=["mean", "max"], + ) + ix = Index(["bacon", "cheese", margins_name], name="item") + tups = [ + ("mean", "cost", "ME"), + ("mean", "cost", "T"), + ("mean", "cost", margins_name), + ("max", "cost", "ME"), + ("max", "cost", "T"), + ("max", "cost", margins_name), + ] + cols = MultiIndex.from_tuples(tups, names=[None, None, "day"]) + expected = DataFrame(table.values, index=ix, columns=cols) + tm.assert_frame_equal(table, expected) + + def test_categorical_margins(self, observed): + # GH 10989 + df = DataFrame( + {"x": np.arange(8), "y": np.arange(8) // 4, "z": np.arange(8) % 2} + ) + + expected = DataFrame([[1.0, 2.0, 1.5], [5, 6, 5.5], [3, 4, 3.5]]) + expected.index = Index([0, 1, "All"], name="y") + expected.columns = Index([0, 1, "All"], name="z") + + table = df.pivot_table("x", "y", "z", dropna=observed, margins=True) + tm.assert_frame_equal(table, expected) + + def test_categorical_margins_category(self, observed): + df = DataFrame( + {"x": np.arange(8), "y": np.arange(8) // 4, "z": np.arange(8) % 2} + ) + + expected = DataFrame([[1.0, 2.0, 1.5], [5, 6, 5.5], [3, 4, 3.5]]) + expected.index = Index([0, 1, "All"], name="y") + expected.columns = Index([0, 1, "All"], name="z") + + df.y = df.y.astype("category") + df.z = df.z.astype("category") + table = df.pivot_table( + "x", "y", "z", dropna=observed, margins=True, observed=False + ) + tm.assert_frame_equal(table, expected) + + def test_margins_casted_to_float(self): + # GH 24893 + df = DataFrame( + { + "A": [2, 4, 6, 8], + "B": [1, 4, 5, 8], + "C": [1, 3, 4, 6], + "D": ["X", "X", "Y", "Y"], + } + ) + + result = pivot_table(df, index="D", margins=True) + expected = DataFrame( + {"A": [3.0, 7.0, 5], "B": [2.5, 6.5, 4.5], "C": [2.0, 5.0, 3.5]}, + index=Index(["X", "Y", "All"], name="D"), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_with_categorical(self, observed, ordered): + # gh-21370 + idx = [np.nan, "low", "high", "low", np.nan] + col = [np.nan, "A", "B", np.nan, "A"] + df = DataFrame( + { + "In": Categorical(idx, categories=["low", "high"], ordered=ordered), + "Col": Categorical(col, categories=["A", "B"], ordered=ordered), + "Val": range(1, 6), + } + ) + # case with index/columns/value + result = df.pivot_table( + index="In", columns="Col", values="Val", observed=observed + ) + + expected_cols = pd.CategoricalIndex(["A", "B"], ordered=ordered, name="Col") + + expected = DataFrame(data=[[2.0, np.nan], [np.nan, 3.0]], columns=expected_cols) + expected.index = Index( + Categorical(["low", "high"], categories=["low", "high"], ordered=ordered), + name="In", + ) + + tm.assert_frame_equal(result, expected) + + # case with columns/value + result = df.pivot_table(columns="Col", values="Val", observed=observed) + + expected = DataFrame( + data=[[3.5, 3.0]], columns=expected_cols, index=Index(["Val"]) + ) + + tm.assert_frame_equal(result, expected) + + def test_categorical_aggfunc(self, observed): + # GH 9534 + df = DataFrame( + {"C1": ["A", "B", "C", "C"], "C2": ["a", "a", "b", "b"], "V": [1, 2, 3, 4]} + ) + df["C1"] = df["C1"].astype("category") + result = df.pivot_table( + "V", + index="C1", + columns="C2", + dropna=observed, + aggfunc="count", + observed=False, + ) + + expected_index = pd.CategoricalIndex( + ["A", "B", "C"], categories=["A", "B", "C"], ordered=False, name="C1" + ) + expected_columns = Index(["a", "b"], name="C2") + expected_data = np.array([[1, 0], [1, 0], [0, 2]], dtype=np.int64) + expected = DataFrame( + expected_data, index=expected_index, columns=expected_columns + ) + tm.assert_frame_equal(result, expected) + + def test_categorical_pivot_index_ordering(self, observed): + # GH 8731 + df = DataFrame( + { + "Sales": [100, 120, 220], + "Month": ["January", "January", "January"], + "Year": [2013, 2014, 2013], + } + ) + months = [ + "January", + "February", + "March", + "April", + "May", + "June", + "July", + "August", + "September", + "October", + "November", + "December", + ] + df["Month"] = df["Month"].astype("category").cat.set_categories(months) + result = df.pivot_table( + values="Sales", + index="Month", + columns="Year", + observed=observed, + aggfunc="sum", + ) + expected_columns = Index([2013, 2014], name="Year", dtype="int64") + expected_index = pd.CategoricalIndex( + months, categories=months, ordered=False, name="Month" + ) + expected_data = [[320, 120]] + [[0, 0]] * 11 + expected = DataFrame( + expected_data, index=expected_index, columns=expected_columns + ) + if observed: + expected = expected.loc[["January"]] + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_not_series(self): + # GH 4386 + # pivot_table always returns a DataFrame + # when values is not list like and columns is None + # and aggfunc is not instance of list + df = DataFrame({"col1": [3, 4, 5], "col2": ["C", "D", "E"], "col3": [1, 3, 9]}) + + result = df.pivot_table("col1", index=["col3", "col2"], aggfunc="sum") + m = MultiIndex.from_arrays([[1, 3, 9], ["C", "D", "E"]], names=["col3", "col2"]) + expected = DataFrame([3, 4, 5], index=m, columns=["col1"]) + + tm.assert_frame_equal(result, expected) + + result = df.pivot_table("col1", index="col3", columns="col2", aggfunc="sum") + expected = DataFrame( + [[3, np.nan, np.nan], [np.nan, 4, np.nan], [np.nan, np.nan, 5]], + index=Index([1, 3, 9], name="col3"), + columns=Index(["C", "D", "E"], name="col2"), + ) + + tm.assert_frame_equal(result, expected) + + result = df.pivot_table("col1", index="col3", aggfunc=["sum"]) + m = MultiIndex.from_arrays([["sum"], ["col1"]]) + expected = DataFrame([3, 4, 5], index=Index([1, 3, 9], name="col3"), columns=m) + + tm.assert_frame_equal(result, expected) + + def test_pivot_margins_name_unicode(self): + # issue #13292 + greek = "\u0394\u03bf\u03ba\u03b9\u03bc\u03ae" + frame = DataFrame({"foo": [1, 2, 3]}, columns=Index(["foo"], dtype=object)) + table = pivot_table( + frame, index=["foo"], aggfunc=len, margins=True, margins_name=greek + ) + index = Index([1, 2, 3, greek], dtype="object", name="foo") + expected = DataFrame(index=index, columns=[]) + tm.assert_frame_equal(table, expected) + + def test_pivot_string_as_func(self): + # GH #18713 + # for correctness purposes + data = DataFrame( + { + "A": [ + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + "foo", + "foo", + "foo", + ], + "B": [ + "one", + "one", + "one", + "two", + "one", + "one", + "one", + "two", + "two", + "two", + "one", + ], + "C": range(11), + } + ) + + result = pivot_table(data, index="A", columns="B", aggfunc="sum") + mi = MultiIndex( + levels=[["C"], ["one", "two"]], codes=[[0, 0], [0, 1]], names=[None, "B"] + ) + expected = DataFrame( + {("C", "one"): {"bar": 15, "foo": 13}, ("C", "two"): {"bar": 7, "foo": 20}}, + columns=mi, + ).rename_axis("A") + tm.assert_frame_equal(result, expected) + + result = pivot_table(data, index="A", columns="B", aggfunc=["sum", "mean"]) + mi = MultiIndex( + levels=[["sum", "mean"], ["C"], ["one", "two"]], + codes=[[0, 0, 1, 1], [0, 0, 0, 0], [0, 1, 0, 1]], + names=[None, None, "B"], + ) + expected = DataFrame( + { + ("mean", "C", "one"): {"bar": 5.0, "foo": 3.25}, + ("mean", "C", "two"): {"bar": 7.0, "foo": 6.666666666666667}, + ("sum", "C", "one"): {"bar": 15, "foo": 13}, + ("sum", "C", "two"): {"bar": 7, "foo": 20}, + }, + columns=mi, + ).rename_axis("A") + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("kwargs", [{"a": 2}, {"a": 2, "b": 3}, {"b": 3, "a": 2}]) + def test_pivot_table_kwargs(self, kwargs): + # GH#57884 + def f(x, a, b=3): + return x.sum() * a + b + + def g(x): + return f(x, **kwargs) + + df = DataFrame( + { + "A": ["good", "bad", "good", "bad", "good"], + "B": ["one", "two", "one", "three", "two"], + "X": [2, 5, 4, 20, 10], + } + ) + result = pivot_table( + df, index="A", columns="B", values="X", aggfunc=f, **kwargs + ) + expected = pivot_table(df, index="A", columns="B", values="X", aggfunc=g) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "kwargs", [{}, {"b": 10}, {"a": 3}, {"a": 3, "b": 10}, {"b": 10, "a": 3}] + ) + def test_pivot_table_kwargs_margin(self, data, kwargs): + # GH#57884 + def f(x, a=5, b=7): + return (x.sum() + b) * a + + def g(x): + return f(x, **kwargs) + + result = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + aggfunc=f, + margins=True, + fill_value=0, + **kwargs, + ) + + expected = data.pivot_table( + values="D", + index=["A", "B"], + columns="C", + aggfunc=g, + margins=True, + fill_value=0, + ) + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "f, f_numpy", + [ + ("sum", np.sum), + ("mean", np.mean), + ("min", np.min), + (["sum", "mean"], [np.sum, np.mean]), + (["sum", "min"], [np.sum, np.min]), + (["max", "mean"], [np.max, np.mean]), + ], + ) + def test_pivot_string_func_vs_func(self, f, f_numpy, data): + # GH #18713 + # for consistency purposes + data = data.drop(columns="C") + result = pivot_table(data, index="A", columns="B", aggfunc=f) + expected = pivot_table(data, index="A", columns="B", aggfunc=f_numpy) + tm.assert_frame_equal(result, expected) + + @pytest.mark.slow + def test_pivot_number_of_levels_larger_than_int32_warns( + self, performance_warning, monkeypatch + ): + # GH 20601 + # GH 26314: Change ValueError to PerformanceWarning + class MockUnstacker(reshape_lib._Unstacker): + def __init__(self, *args, **kwargs) -> None: + # __init__ will raise the warning + super().__init__(*args, **kwargs) + raise Exception("Don't compute final result.") + + def _make_selectors(self) -> None: + pass + + with monkeypatch.context() as m: + m.setattr(reshape_lib, "_Unstacker", MockUnstacker) + df = DataFrame( + {"ind1": np.arange(2**16), "ind2": np.arange(2**16), "count": 0} + ) + + msg = "The following operation may generate" + with tm.assert_produces_warning(performance_warning, match=msg): + with pytest.raises(Exception, match="Don't compute final result."): + df.pivot_table( + index="ind1", columns="ind2", values="count", aggfunc="count" + ) + + def test_pivot_table_aggfunc_dropna(self, dropna): + # GH 22159 + df = DataFrame( + { + "fruit": ["apple", "peach", "apple"], + "size": [1, 1, 2], + "taste": [7, 6, 6], + } + ) + + def ret_one(x): + return 1 + + def ret_sum(x): + return sum(x) + + def ret_none(x): + return np.nan + + result = pivot_table( + df, columns="fruit", aggfunc=[ret_sum, ret_none, ret_one], dropna=dropna + ) + + data = [[3, 1, np.nan, np.nan, 1, 1], [13, 6, np.nan, np.nan, 1, 1]] + col = MultiIndex.from_product( + [["ret_sum", "ret_none", "ret_one"], ["apple", "peach"]], + names=[None, "fruit"], + ) + expected = DataFrame(data, index=["size", "taste"], columns=col) + + if dropna: + expected = expected.dropna(axis="columns") + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_aggfunc_scalar_dropna(self, dropna): + # GH 22159 + df = DataFrame( + {"A": ["one", "two", "one"], "x": [3, np.nan, 2], "y": [1, np.nan, np.nan]} + ) + + result = pivot_table(df, columns="A", aggfunc="mean", dropna=dropna) + + data = [[2.5, np.nan], [1, np.nan]] + col = Index(["one", "two"], name="A") + expected = DataFrame(data, index=["x", "y"], columns=col) + + if dropna: + expected = expected.dropna(axis="columns") + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("margins", [True, False]) + def test_pivot_table_empty_aggfunc(self, margins): + # GH 9186 & GH 13483 & GH 49240 + df = DataFrame( + { + "A": [2, 2, 3, 3, 2], + "id": [5, 6, 7, 8, 9], + "C": ["p", "q", "q", "p", "q"], + "D": [None, None, None, None, None], + } + ) + result = df.pivot_table( + index="A", columns="D", values="id", aggfunc=np.size, margins=margins + ) + exp_cols = Index([], name="D") + expected = DataFrame(index=Index([], dtype="int64", name="A"), columns=exp_cols) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_no_column_raises(self): + # GH 10326 + def agg(arr): + return np.mean(arr) + + df = DataFrame({"X": [0, 0, 1, 1], "Y": [0, 1, 0, 1], "Z": [10, 20, 30, 40]}) + with pytest.raises(KeyError, match="notpresent"): + df.pivot_table("notpresent", "X", "Y", aggfunc=agg) + + def test_pivot_table_multiindex_columns_doctest_case(self): + # The relevant characteristic is that the call + # to maybe_downcast_to_dtype(agged[v], data[v].dtype) in + # __internal_pivot_table has `agged[v]` a DataFrame instead of Series, + # In this case this is because agged.columns is a MultiIndex and 'v' + # is only indexing on its first level. + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + } + ) + + table = pivot_table( + df, + values=["D", "E"], + index=["A", "C"], + aggfunc={"D": "mean", "E": ["min", "max", "mean"]}, + ) + cols = MultiIndex.from_tuples( + [("D", "mean"), ("E", "max"), ("E", "mean"), ("E", "min")] + ) + index = MultiIndex.from_tuples( + [("bar", "large"), ("bar", "small"), ("foo", "large"), ("foo", "small")], + names=["A", "C"], + ) + vals = np.array( + [ + [5.5, 9.0, 7.5, 6.0], + [5.5, 9.0, 8.5, 8.0], + [2.0, 5.0, 4.5, 4.0], + [2.33333333, 6.0, 4.33333333, 2.0], + ] + ) + expected = DataFrame(vals, columns=cols, index=index) + expected[("E", "min")] = expected[("E", "min")].astype(np.int64) + expected[("E", "max")] = expected[("E", "max")].astype(np.int64) + tm.assert_frame_equal(table, expected) + + def test_pivot_table_sort_false(self): + # GH#39143 + df = DataFrame( + { + "a": ["d1", "d4", "d3"], + "col": ["a", "b", "c"], + "num": [23, 21, 34], + "year": ["2018", "2018", "2019"], + } + ) + result = df.pivot_table( + index=["a", "col"], columns="year", values="num", aggfunc="sum", sort=False + ) + expected = DataFrame( + [[23, np.nan], [21, np.nan], [np.nan, 34]], + columns=Index(["2018", "2019"], name="year"), + index=MultiIndex.from_arrays( + [["d1", "d4", "d3"], ["a", "b", "c"]], names=["a", "col"] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_nullable_margins(self): + # GH#48681 + df = DataFrame( + {"a": "A", "b": [1, 2], "sales": Series([10, 11], dtype="Int64")} + ) + + result = df.pivot_table(index="b", columns="a", margins=True, aggfunc="sum") + expected = DataFrame( + [[10, 10], [11, 11], [21, 21]], + index=Index([1, 2, "All"], name="b"), + columns=MultiIndex.from_tuples( + [("sales", "A"), ("sales", "All")], names=[None, "a"] + ), + dtype="Int64", + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_sort_false_with_multiple_values(self): + df = DataFrame( + { + "firstname": ["John", "Michael"], + "lastname": ["Foo", "Bar"], + "height": [173, 182], + "age": [47, 33], + } + ) + result = df.pivot_table( + index=["lastname", "firstname"], values=["height", "age"], sort=False + ) + expected = DataFrame( + [[173.0, 47.0], [182.0, 33.0]], + columns=["height", "age"], + index=MultiIndex.from_tuples( + [("Foo", "John"), ("Bar", "Michael")], + names=["lastname", "firstname"], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_with_margins_and_numeric_columns(self): + # GH 26568 + df = DataFrame([["a", "x", 1], ["a", "y", 2], ["b", "y", 3], ["b", "z", 4]]) + df.columns = [10, 20, 30] + + result = df.pivot_table( + index=10, columns=20, values=30, aggfunc="sum", fill_value=0, margins=True + ) + + expected = DataFrame([[1, 2, 0, 3], [0, 3, 4, 7], [1, 5, 4, 10]]) + expected.columns = ["x", "y", "z", "All"] + expected.index = ["a", "b", "All"] + expected.columns.name = 20 + expected.index.name = 10 + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "dtype,expected_dtype", [("Int64", "Float64"), ("int64", "float64")] + ) + def test_pivot_ea_dtype_dropna(self, dropna, dtype, expected_dtype): + # GH#47477 + # GH#47971 + df = DataFrame({"x": "a", "y": "b", "age": Series([20, 40], dtype=dtype)}) + result = df.pivot_table( + index="x", columns="y", values="age", aggfunc="mean", dropna=dropna + ) + expected = DataFrame( + [[30]], + index=Index(["a"], name="x"), + columns=Index(["b"], name="y"), + dtype=expected_dtype, + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_datetime_warning(self): + # GH#48683 + df = DataFrame( + { + "a": "A", + "b": [1, 2], + "date": pd.Timestamp("2019-12-31"), + "sales": [10.0, 11], + } + ) + with tm.assert_produces_warning(None): + result = df.pivot_table( + index=["b", "date"], columns="a", margins=True, aggfunc="sum" + ) + expected = DataFrame( + [[10.0, 10.0], [11.0, 11.0], [21.0, 21.0]], + index=MultiIndex.from_arrays( + [ + Index([1, 2, "All"], name="b"), + Index( + [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31"), ""], + dtype=object, + name="date", + ), + ] + ), + columns=MultiIndex.from_tuples( + [("sales", "A"), ("sales", "All")], names=[None, "a"] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_with_mixed_nested_tuples(self): + # GH 50342 + df = DataFrame( + { + "A": ["foo", "foo", "foo", "foo", "foo", "bar", "bar", "bar", "bar"], + "B": ["one", "one", "one", "two", "two", "one", "one", "two", "two"], + "C": [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + "D": [1, 2, 2, 3, 3, 4, 5, 6, 7], + "E": [2, 4, 5, 5, 6, 6, 8, 9, 9], + ("col5",): [ + "foo", + "foo", + "foo", + "foo", + "foo", + "bar", + "bar", + "bar", + "bar", + ], + ("col6", 6): [ + "one", + "one", + "one", + "two", + "two", + "one", + "one", + "two", + "two", + ], + (7, "seven"): [ + "small", + "large", + "large", + "small", + "small", + "large", + "small", + "small", + "large", + ], + } + ) + result = pivot_table( + df, values="D", index=["A", "B"], columns=[(7, "seven")], aggfunc="sum" + ) + expected = DataFrame( + [[4.0, 5.0], [7.0, 6.0], [4.0, 1.0], [np.nan, 6.0]], + columns=Index(["large", "small"], name=(7, "seven")), + index=MultiIndex.from_arrays( + [["bar", "bar", "foo", "foo"], ["one", "two"] * 2], names=["A", "B"] + ), + ) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_aggfunc_nunique_with_different_values(self): + test = DataFrame( + { + "a": range(10), + "b": range(10), + "c": range(10), + "d": range(10), + } + ) + + columnval = MultiIndex.from_arrays( + [ + ["nunique" for i in range(10)], + ["c" for i in range(10)], + range(10), + ], + names=(None, None, "b"), + ) + nparr = np.full((10, 10), np.nan) + np.fill_diagonal(nparr, 1.0) + + expected = DataFrame(nparr, index=Index(range(10), name="a"), columns=columnval) + result = test.pivot_table( + index=[ + "a", + ], + columns=[ + "b", + ], + values=[ + "c", + ], + aggfunc=["nunique"], + ) + + tm.assert_frame_equal(result, expected) + + def test_pivot_table_index_and_column_keys_with_nan(self, dropna): + # GH#61113 + data = {"row": [None, *range(4)], "col": [*range(4), None], "val": range(5)} + df = DataFrame(data) + result = df.pivot_table(values="val", index="row", columns="col", dropna=dropna) + e_axis = [*range(4), None] + nan = np.nan + e_data = [ + [nan, 1.0, nan, nan, nan], + [nan, nan, 2.0, nan, nan], + [nan, nan, nan, 3.0, nan], + [nan, nan, nan, nan, 4.0], + [0.0, nan, nan, nan, nan], + ] + expected = DataFrame( + data=e_data, + index=Index(data=e_axis, name="row"), + columns=Index(data=e_axis, name="col"), + ) + if dropna: + expected = expected.loc[[0, 1, 2], [1, 2, 3]] + + tm.assert_frame_equal(left=result, right=expected) + + @pytest.mark.parametrize( + "index, columns, e_data, e_index, e_cols", + [ + ( + "Category", + "Value", + [ + [1.0, np.nan, 1.0, np.nan], + [np.nan, 1.0, np.nan, 1.0], + ], + Index(data=["A", "B"], name="Category"), + Index(data=[10, 20, 40, 50], name="Value"), + ), + ( + "Value", + "Category", + [ + [1.0, np.nan], + [np.nan, 1.0], + [1.0, np.nan], + [np.nan, 1.0], + ], + Index(data=[10, 20, 40, 50], name="Value"), + Index(data=["A", "B"], name="Category"), + ), + ], + ids=["values-and-columns", "values-and-index"], + ) + def test_pivot_table_values_as_two_params( + self, index, columns, e_data, e_index, e_cols + ): + # GH#57876 + data = {"Category": ["A", "B", "A", "B"], "Value": [10, 20, 40, 50]} + df = DataFrame(data) + result = df.pivot_table( + index=index, columns=columns, values="Value", aggfunc="count" + ) + expected = DataFrame(data=e_data, index=e_index, columns=e_cols) + tm.assert_frame_equal(result, expected) + + def test_pivot_table_margins_include_nan_groups(self): + # GH#61509 + df = DataFrame( + { + "i": [1, 2, 3], + "g1": ["a", "b", "b"], + "g2": ["x", None, None], + } + ) + + result = df.pivot_table( + index="g1", + columns="g2", + values="i", + aggfunc="count", + dropna=False, + margins=True, + ) + + expected = DataFrame( + { + "x": {"a": 1.0, "b": np.nan, "All": 1.0}, + np.nan: {"a": np.nan, "b": 2.0, "All": 2.0}, + "All": {"a": 1.0, "b": 2.0, "All": 3.0}, + } + ) + expected.index.name = "g1" + expected.columns.name = "g2" + tm.assert_frame_equal(result, expected, check_dtype=False) + + +class TestPivot: + def test_pivot(self): + data = { + "index": ["A", "B", "C", "C", "B", "A"], + "columns": ["One", "One", "One", "Two", "Two", "Two"], + "values": [1.0, 2.0, 3.0, 3.0, 2.0, 1.0], + } + + frame = DataFrame(data) + pivoted = frame.pivot(index="index", columns="columns", values="values") + + expected = DataFrame( + { + "One": {"A": 1.0, "B": 2.0, "C": 3.0}, + "Two": {"A": 1.0, "B": 2.0, "C": 3.0}, + } + ) + + expected.index.name, expected.columns.name = "index", "columns" + tm.assert_frame_equal(pivoted, expected) + + # name tracking + assert pivoted.index.name == "index" + assert pivoted.columns.name == "columns" + + # don't specify values + pivoted = frame.pivot(index="index", columns="columns") + assert pivoted.index.name == "index" + assert pivoted.columns.names == (None, "columns") + + def test_pivot_duplicates(self): + data = DataFrame( + { + "a": ["bar", "bar", "foo", "foo", "foo"], + "b": ["one", "two", "one", "one", "two"], + "c": [1.0, 2.0, 3.0, 3.0, 4.0], + } + ) + with pytest.raises(ValueError, match="duplicate entries"): + data.pivot(index="a", columns="b", values="c") + + def test_pivot_empty(self): + df = DataFrame(columns=["a", "b", "c"]) + result = df.pivot(index="a", columns="b", values="c") + expected = DataFrame(index=[], columns=[]) + tm.assert_frame_equal(result, expected, check_names=False) + + def test_pivot_integer_bug(self, any_string_dtype): + df = DataFrame( + data=[("A", "1", "A1"), ("B", "2", "B2")], dtype=any_string_dtype + ) + + result = df.pivot(index=1, columns=0, values=2) + expected_columns = Index(["A", "B"], name=0, dtype=any_string_dtype) + tm.assert_index_equal(result.columns, expected_columns) + + def test_pivot_index_none(self): + # GH#3962 + data = { + "index": ["A", "B", "C", "C", "B", "A"], + "columns": ["One", "One", "One", "Two", "Two", "Two"], + "values": [1.0, 2.0, 3.0, 3.0, 2.0, 1.0], + } + + frame = DataFrame(data).set_index("index") + result = frame.pivot(columns="columns", values="values") + expected = DataFrame( + { + "One": {"A": 1.0, "B": 2.0, "C": 3.0}, + "Two": {"A": 1.0, "B": 2.0, "C": 3.0}, + } + ) + + expected.index.name, expected.columns.name = "index", "columns" + tm.assert_frame_equal(result, expected) + + # omit values + result = frame.pivot(columns="columns") + + expected.columns = MultiIndex.from_tuples( + [("values", "One"), ("values", "Two")], names=[None, "columns"] + ) + expected.index.name = "index" + tm.assert_frame_equal(result, expected, check_names=False) + assert result.index.name == "index" + assert result.columns.names == (None, "columns") + expected.columns = expected.columns.droplevel(0) + result = frame.pivot(columns="columns", values="values") + + expected.columns.name = "columns" + tm.assert_frame_equal(result, expected) + + def test_pivot_index_list_values_none_immutable_args(self): + # GH37635 + df = DataFrame( + { + "lev1": [1, 1, 1, 2, 2, 2], + "lev2": [1, 1, 2, 1, 1, 2], + "lev3": [1, 2, 1, 2, 1, 2], + "lev4": [1, 2, 3, 4, 5, 6], + "values": [0, 1, 2, 3, 4, 5], + } + ) + index = ["lev1", "lev2"] + columns = ["lev3"] + result = df.pivot(index=index, columns=columns) + + expected = DataFrame( + np.array( + [ + [1.0, 2.0, 0.0, 1.0], + [3.0, np.nan, 2.0, np.nan], + [5.0, 4.0, 4.0, 3.0], + [np.nan, 6.0, np.nan, 5.0], + ] + ), + index=MultiIndex.from_arrays( + [(1, 1, 2, 2), (1, 2, 1, 2)], names=["lev1", "lev2"] + ), + columns=MultiIndex.from_arrays( + [("lev4", "lev4", "values", "values"), (1, 2, 1, 2)], + names=[None, "lev3"], + ), + ) + + tm.assert_frame_equal(result, expected) + + assert index == ["lev1", "lev2"] + assert columns == ["lev3"] + + def test_pivot_columns_not_given(self): + # GH#48293 + df = DataFrame({"a": [1], "b": 1}) + with pytest.raises(TypeError, match="missing 1 required keyword-only argument"): + df.pivot() + + # this still fails because columns=None gets passed down to unstack as level=None + # while at that point None was converted to NaN + @pytest.mark.xfail( + using_string_dtype(), reason="TODO(infer_string) None is cast to NaN" + ) + def test_pivot_columns_is_none(self): + # GH#48293 + df = DataFrame({None: [1], "b": 2, "c": 3}) + result = df.pivot(columns=None) + expected = DataFrame({("b", 1): [2], ("c", 1): 3}) + tm.assert_frame_equal(result, expected) + + result = df.pivot(columns=None, index="b") + expected = DataFrame({("c", 1): 3}, index=Index([2], name="b")) + tm.assert_frame_equal(result, expected) + + result = df.pivot(columns=None, index="b", values="c") + expected = DataFrame({1: 3}, index=Index([2], name="b")) + tm.assert_frame_equal(result, expected) + + def test_pivot_index_is_none(self, using_infer_string): + # GH#48293 + df = DataFrame({None: [1], "b": 2, "c": 3}) + + result = df.pivot(columns="b", index=None) + expected = DataFrame({("c", 2): 3}, index=[1]) + expected.columns.names = [None, "b"] + tm.assert_frame_equal(result, expected) + + result = df.pivot(columns="b", index=None, values="c") + expected = DataFrame(3, index=[1], columns=Index([2], name="b")) + if using_infer_string: + expected.index.name = np.nan + tm.assert_frame_equal(result, expected) + + def test_pivot_values_is_none(self): + # GH#48293 + df = DataFrame({None: [1], "b": 2, "c": 3}) + + result = df.pivot(columns="b", index="c", values=None) + expected = DataFrame( + 1, index=Index([3], name="c"), columns=Index([2], name="b") + ) + tm.assert_frame_equal(result, expected) + + result = df.pivot(columns="b", values=None) + expected = DataFrame(1, index=[0], columns=Index([2], name="b")) + tm.assert_frame_equal(result, expected) + + def test_pivot_not_changing_index_name(self): + # GH#52692 + df = DataFrame({"one": ["a"], "two": 0, "three": 1}) + expected = df.copy(deep=True) + df.pivot(index="one", columns="two", values="three") + tm.assert_frame_equal(df, expected) + + def test_pivot_table_empty_dataframe_correct_index(self): + # GH 21932 + df = DataFrame([], columns=["a", "b", "value"]) + pivot = df.pivot_table(index="a", columns="b", values="value", aggfunc="count") + + expected = Index([], dtype="object", name="b") + tm.assert_index_equal(pivot.columns, expected) + + def test_pivot_table_handles_explicit_datetime_types(self): + # GH#43574 + df = DataFrame( + [ + {"a": "x", "date_str": "2023-01-01", "amount": 1}, + {"a": "y", "date_str": "2023-01-02", "amount": 2}, + {"a": "z", "date_str": "2023-01-03", "amount": 3}, + ] + ) + df["date"] = pd.to_datetime(df["date_str"]) + + with tm.assert_produces_warning(False): + pivot = df.pivot_table( + index=["a", "date"], values=["amount"], aggfunc="sum", margins=True + ) + + expected = MultiIndex.from_tuples( + [ + ("x", datetime.strptime("2023-01-01 00:00:00", "%Y-%m-%d %H:%M:%S")), + ("y", datetime.strptime("2023-01-02 00:00:00", "%Y-%m-%d %H:%M:%S")), + ("z", datetime.strptime("2023-01-03 00:00:00", "%Y-%m-%d %H:%M:%S")), + ("All", ""), + ], + names=["a", "date"], + ) + tm.assert_index_equal(pivot.index, expected) + + def test_pivot_table_with_margins_and_numeric_column_names(self): + # GH#26568 + df = DataFrame([["a", "x", 1], ["a", "y", 2], ["b", "y", 3], ["b", "z", 4]]) + + result = df.pivot_table( + index=0, columns=1, values=2, aggfunc="sum", fill_value=0, margins=True + ) + + expected = DataFrame( + [[1, 2, 0, 3], [0, 3, 4, 7], [1, 5, 4, 10]], + columns=Index(["x", "y", "z", "All"], name=1), + index=Index(["a", "b", "All"], name=0), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("m", [1, 10]) + def test_unstack_copy(self, m): + # GH#56633 + levels = np.arange(m) + index = MultiIndex.from_product([levels] * 2) + values = np.arange(m * m * 100).reshape(m * m, 100) + df = DataFrame(values, index, np.arange(100)) + df_orig = df.copy() + result = df.unstack(sort=False) + result.iloc[0, 0] = -1 + tm.assert_frame_equal(df, df_orig) + + def test_pivot_empty_with_datetime(self): + # GH#59126 + df = DataFrame( + { + "timestamp": Series([], dtype=pd.DatetimeTZDtype(tz="UTC")), + "category": Series([], dtype=str), + "value": Series([], dtype=str), + } + ) + df_pivoted = df.pivot_table( + index="category", columns="value", values="timestamp" + ) + assert df_pivoted.empty + + def test_pivot_margins_with_none_index(self): + # GH#58722 + df = DataFrame( + { + "x": [1, 1, 2], + "y": [3, 3, 4], + "z": [5, 5, 6], + "w": [7, 8, 9], + } + ) + result = df.pivot_table( + index=None, + columns=["y", "z"], + values="w", + margins=True, + aggfunc="count", + ) + expected = DataFrame( + [[2, 2, 1, 1]], + index=["w"], + columns=MultiIndex( + levels=[[3, 4], [5, 6, "All"]], + codes=[[0, 0, 1, 1], [0, 2, 1, 2]], + names=["y", "z"], + ), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning") + def test_pivot_with_pyarrow_categorical(self): + # GH#53051 + pa = pytest.importorskip("pyarrow") + + df = DataFrame( + {"string_column": ["A", "B", "C"], "number_column": [1, 2, 3]} + ).astype( + { + "string_column": ArrowDtype(pa.dictionary(pa.int32(), pa.string())), + "number_column": "float[pyarrow]", + } + ) + + df = df.pivot(columns=["string_column"], values=["number_column"]) + + multi_index = MultiIndex.from_arrays( + [["number_column", "number_column", "number_column"], ["A", "B", "C"]], + names=(None, "string_column"), + ) + df_expected = DataFrame( + [[1.0, np.nan, np.nan], [np.nan, 2.0, np.nan], [np.nan, np.nan, 3.0]], + columns=multi_index, + ) + tm.assert_frame_equal( + df, df_expected, check_dtype=False, check_column_type=False + ) + + @pytest.mark.parametrize("freq", ["D", "M", "Q", "Y"]) + def test_pivot_empty_dataframe_period_dtype(self, freq): + # GH#62705 + + dtype = pd.PeriodDtype(freq=freq) + df = DataFrame({"index": [], "columns": [], "values": []}) + df = df.astype({"values": dtype}) + result = df.pivot(index="index", columns="columns", values="values") + + expected_index = Index([], name="index", dtype="float64") + expected_columns = Index([], name="columns", dtype="float64") + expected = DataFrame( + index=expected_index, columns=expected_columns, dtype=dtype + ) + + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/reshape/test_pivot_multilevel.py b/pandas/tests/reshape/test_pivot_multilevel.py new file mode 100644 index 0000000000000000000000000000000000000000..af70210b37f3c0396b1dd972f64822001c04dc50 --- /dev/null +++ b/pandas/tests/reshape/test_pivot_multilevel.py @@ -0,0 +1,301 @@ +import numpy as np +import pytest + +from pandas._libs import lib + +import pandas as pd +from pandas import ( + Index, + MultiIndex, +) +import pandas._testing as tm + + +@pytest.mark.parametrize( + "input_index, input_columns, input_values, " + "expected_values, expected_columns, expected_index", + [ + ( + ["lev4"], + "lev3", + "values", + [ + [0.0, np.nan], + [np.nan, 1.0], + [2.0, np.nan], + [np.nan, 3.0], + [4.0, np.nan], + [np.nan, 5.0], + [6.0, np.nan], + [np.nan, 7.0], + ], + Index([1, 2], name="lev3"), + Index([1, 2, 3, 4, 5, 6, 7, 8], name="lev4"), + ), + ( + ["lev4"], + "lev3", + lib.no_default, + [ + [1.0, np.nan, 1.0, np.nan, 0.0, np.nan], + [np.nan, 1.0, np.nan, 1.0, np.nan, 1.0], + [1.0, np.nan, 2.0, np.nan, 2.0, np.nan], + [np.nan, 1.0, np.nan, 2.0, np.nan, 3.0], + [2.0, np.nan, 1.0, np.nan, 4.0, np.nan], + [np.nan, 2.0, np.nan, 1.0, np.nan, 5.0], + [2.0, np.nan, 2.0, np.nan, 6.0, np.nan], + [np.nan, 2.0, np.nan, 2.0, np.nan, 7.0], + ], + MultiIndex.from_tuples( + [ + ("lev1", 1), + ("lev1", 2), + ("lev2", 1), + ("lev2", 2), + ("values", 1), + ("values", 2), + ], + names=[None, "lev3"], + ), + Index([1, 2, 3, 4, 5, 6, 7, 8], name="lev4"), + ), + ( + ["lev1", "lev2"], + "lev3", + "values", + [[0, 1], [2, 3], [4, 5], [6, 7]], + Index([1, 2], name="lev3"), + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev1", "lev2"] + ), + ), + ( + ["lev1", "lev2"], + "lev3", + lib.no_default, + [[1, 2, 0, 1], [3, 4, 2, 3], [5, 6, 4, 5], [7, 8, 6, 7]], + MultiIndex.from_tuples( + [("lev4", 1), ("lev4", 2), ("values", 1), ("values", 2)], + names=[None, "lev3"], + ), + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev1", "lev2"] + ), + ), + ], +) +def test_pivot_list_like_index( + input_index, + input_columns, + input_values, + expected_values, + expected_columns, + expected_index, +): + # GH 21425, test when index is given a list + df = pd.DataFrame( + { + "lev1": [1, 1, 1, 1, 2, 2, 2, 2], + "lev2": [1, 1, 2, 2, 1, 1, 2, 2], + "lev3": [1, 2, 1, 2, 1, 2, 1, 2], + "lev4": [1, 2, 3, 4, 5, 6, 7, 8], + "values": [0, 1, 2, 3, 4, 5, 6, 7], + } + ) + + result = df.pivot(index=input_index, columns=input_columns, values=input_values) + expected = pd.DataFrame( + expected_values, columns=expected_columns, index=expected_index + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "input_index, input_columns, input_values, " + "expected_values, expected_columns, expected_index", + [ + ( + "lev4", + ["lev3"], + "values", + [ + [0.0, np.nan], + [np.nan, 1.0], + [2.0, np.nan], + [np.nan, 3.0], + [4.0, np.nan], + [np.nan, 5.0], + [6.0, np.nan], + [np.nan, 7.0], + ], + Index([1, 2], name="lev3"), + Index([1, 2, 3, 4, 5, 6, 7, 8], name="lev4"), + ), + ( + ["lev1", "lev2"], + ["lev3"], + "values", + [[0, 1], [2, 3], [4, 5], [6, 7]], + Index([1, 2], name="lev3"), + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev1", "lev2"] + ), + ), + ( + ["lev1"], + ["lev2", "lev3"], + "values", + [[0, 1, 2, 3], [4, 5, 6, 7]], + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev2", "lev3"] + ), + Index([1, 2], name="lev1"), + ), + ( + ["lev1", "lev2"], + ["lev3", "lev4"], + "values", + [ + [0.0, 1.0, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, 2.0, 3.0, np.nan, np.nan, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, 4.0, 5.0, np.nan, np.nan], + [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, 6.0, 7.0], + ], + MultiIndex.from_tuples( + [(1, 1), (2, 2), (1, 3), (2, 4), (1, 5), (2, 6), (1, 7), (2, 8)], + names=["lev3", "lev4"], + ), + MultiIndex.from_tuples( + [(1, 1), (1, 2), (2, 1), (2, 2)], names=["lev1", "lev2"] + ), + ), + ], +) +def test_pivot_list_like_columns( + input_index, + input_columns, + input_values, + expected_values, + expected_columns, + expected_index, +): + # GH 21425, test when columns is given a list + df = pd.DataFrame( + { + "lev1": [1, 1, 1, 1, 2, 2, 2, 2], + "lev2": [1, 1, 2, 2, 1, 1, 2, 2], + "lev3": [1, 2, 1, 2, 1, 2, 1, 2], + "lev4": [1, 2, 3, 4, 5, 6, 7, 8], + "values": [0, 1, 2, 3, 4, 5, 6, 7], + } + ) + + result = df.pivot(index=input_index, columns=input_columns, values=input_values) + expected = pd.DataFrame( + expected_values, columns=expected_columns, index=expected_index + ) + tm.assert_frame_equal(result, expected) + + +def test_pivot_multiindexed_rows_and_cols(): + # GH 36360 + + df = pd.DataFrame( + data=np.arange(12).reshape(4, 3), + columns=MultiIndex.from_tuples( + [(0, 0), (0, 1), (0, 2)], names=["col_L0", "col_L1"] + ), + index=MultiIndex.from_tuples( + [(0, 0, 0), (0, 0, 1), (1, 1, 1), (1, 0, 0)], + names=["idx_L0", "idx_L1", "idx_L2"], + ), + ) + + res = df.pivot_table( + index=["idx_L0"], + columns=["idx_L1"], + values=[(0, 1)], + aggfunc=lambda col: col.values.sum(), + ) + + expected = pd.DataFrame( + data=[[5, np.nan], [10, 7.0]], + columns=MultiIndex.from_tuples( + [(0, 1, 0), (0, 1, 1)], names=["col_L0", "col_L1", "idx_L1"] + ), + index=Index([0, 1], dtype="int64", name="idx_L0"), + ) + expected = expected.astype("float64") + + tm.assert_frame_equal(res, expected) + + +def test_pivot_df_multiindex_index_none(): + # GH 23955 + df = pd.DataFrame( + [ + ["A", "A1", "label1", 1], + ["A", "A2", "label2", 2], + ["B", "A1", "label1", 3], + ["B", "A2", "label2", 4], + ], + columns=["index_1", "index_2", "label", "value"], + ) + df = df.set_index(["index_1", "index_2"]) + + result = df.pivot(columns="label", values="value") + expected = pd.DataFrame( + [[1.0, np.nan], [np.nan, 2.0], [3.0, np.nan], [np.nan, 4.0]], + index=df.index, + columns=Index(["label1", "label2"], name="label"), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "index, columns, e_data, e_index, e_cols", + [ + ( + "index", + ["col", "value"], + [ + [50.0, np.nan, 100.0, np.nan], + [np.nan, 100.0, np.nan, 200.0], + ], + Index(data=["A", "B"], name="index"), + MultiIndex.from_arrays( + arrays=[[1, 1, 2, 2], [50, 100, 100, 200]], names=["col", "value"] + ), + ), + ( + ["index", "value"], + "col", + [ + [50.0, np.nan], + [np.nan, 100.0], + [100.0, np.nan], + [np.nan, 200.0], + ], + MultiIndex.from_arrays( + arrays=[["A", "A", "B", "B"], [50, 100, 100, 200]], + names=["index", "value"], + ), + Index(data=[1, 2], name="col"), + ), + ], + ids=["values-and-columns", "values-and-index"], +) +def test_pivot_table_multiindex_values_as_two_params( + index, columns, e_data, e_index, e_cols +): + # GH#61292 + data = [ + ["A", 1, 50, -1], + ["B", 1, 100, -2], + ["A", 2, 100, -2], + ["B", 2, 200, -4], + ] + df = pd.DataFrame(data=data, columns=["index", "col", "value", "extra"]) + result = df.pivot_table(values="value", index=index, columns=columns) + expected = pd.DataFrame(data=e_data, index=e_index, columns=e_cols) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/reshape/test_qcut.py b/pandas/tests/reshape/test_qcut.py new file mode 100644 index 0000000000000000000000000000000000000000..51617bc3536807fca79f06c8476dba0694f5d445 --- /dev/null +++ b/pandas/tests/reshape/test_qcut.py @@ -0,0 +1,308 @@ +import os + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + DatetimeIndex, + Interval, + IntervalIndex, + NaT, + Series, + Timedelta, + TimedeltaIndex, + Timestamp, + cut, + date_range, + isna, + qcut, + timedelta_range, +) +import pandas._testing as tm +from pandas.api.types import CategoricalDtype + +from pandas.tseries.offsets import Day + + +def test_qcut(): + arr = np.random.default_rng(2).standard_normal(1000) + + # We store the bins as Index that have been + # rounded to comparisons are a bit tricky. + labels, _ = qcut(arr, 4, retbins=True) + ex_bins = np.quantile(arr, [0, 0.25, 0.5, 0.75, 1.0]) + + result = labels.categories.left.values + assert np.allclose(result, ex_bins[:-1], atol=1e-2) + + result = labels.categories.right.values + assert np.allclose(result, ex_bins[1:], atol=1e-2) + + ex_levels = cut(arr, ex_bins, include_lowest=True) + tm.assert_categorical_equal(labels, ex_levels) + + +def test_qcut_bounds(): + arr = np.random.default_rng(2).standard_normal(1000) + + factor = qcut(arr, 10, labels=False) + assert len(np.unique(factor)) == 10 + + +def test_qcut_specify_quantiles(): + arr = np.random.default_rng(2).standard_normal(100) + factor = qcut(arr, [0, 0.25, 0.5, 0.75, 1.0]) + + expected = qcut(arr, 4) + tm.assert_categorical_equal(factor, expected) + + +def test_qcut_all_bins_same(): + with pytest.raises(ValueError, match="edges.*unique"): + qcut([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 3) + + +def test_qcut_include_lowest(): + values = np.arange(10) + ii = qcut(values, 4) + + ex_levels = IntervalIndex( + [ + Interval(-0.001, 2.25), + Interval(2.25, 4.5), + Interval(4.5, 6.75), + Interval(6.75, 9), + ] + ) + tm.assert_index_equal(ii.categories, ex_levels) + + +def test_qcut_nas(): + arr = np.random.default_rng(2).standard_normal(100) + arr[:20] = np.nan + + result = qcut(arr, 4) + assert isna(result[:20]).all() + + +def test_qcut_index(): + result = qcut([0, 2], 2) + intervals = [Interval(-0.001, 1), Interval(1, 2)] + + expected = Categorical(intervals, ordered=True) + tm.assert_categorical_equal(result, expected) + + +def test_qcut_binning_issues(datapath): + # see gh-1978, gh-1979 + cut_file = datapath(os.path.join("reshape", "data", "cut_data.csv")) + arr = np.loadtxt(cut_file) + result = qcut(arr, 20) + + starts = result.categories.left + ends = result.categories.right + assert (starts < ends).all() + assert (starts[1:] <= ends[:-1]).all() + + +def test_qcut_return_intervals(): + ser = Series([0, 1, 2, 3, 4, 5, 6, 7, 8]) + res = qcut(ser, [0, 0.333, 0.666, 1]) + + exp_levels = np.array( + [Interval(-0.001, 2.664), Interval(2.664, 5.328), Interval(5.328, 8)] + ) + exp = Series(exp_levels.take([0, 0, 0, 1, 1, 1, 2, 2, 2])).astype( + CategoricalDtype(ordered=True) + ) + tm.assert_series_equal(res, exp) + + +@pytest.mark.parametrize("labels", ["foo", 1, True]) +def test_qcut_incorrect_labels(labels): + # GH 13318 + values = range(5) + msg = "Bin labels must either be False, None or passed in as a list-like argument" + with pytest.raises(ValueError, match=msg): + qcut(values, 4, labels=labels) + + +@pytest.mark.parametrize("labels", [["a", "b", "c"], list(range(3))]) +def test_qcut_wrong_length_labels(labels): + # GH 13318 + values = range(10) + msg = "Bin labels must be one fewer than the number of bin edges" + with pytest.raises(ValueError, match=msg): + qcut(values, 4, labels=labels) + + +@pytest.mark.parametrize( + "labels, expected", + [ + (["a", "b", "c"], ["a", "b", "c"]), + (list(range(3)), [0, 1, 2]), + ], +) +def test_qcut_list_like_labels(labels, expected): + # GH 13318 + values = range(3) + result = qcut(values, 3, labels=labels) + expected = Categorical(expected, ordered=True) + tm.assert_categorical_equal(result, expected) + + +@pytest.mark.parametrize( + "kwargs,msg", + [ + ({"duplicates": "drop"}, None), + ({}, "Bin edges must be unique"), + ({"duplicates": "raise"}, "Bin edges must be unique"), + ({"duplicates": "foo"}, "invalid value for 'duplicates' parameter"), + ], +) +def test_qcut_duplicates_bin(kwargs, msg): + # see gh-7751 + values = [0, 0, 0, 0, 1, 2, 3] + + if msg is not None: + with pytest.raises(ValueError, match=msg): + qcut(values, 3, **kwargs) + else: + result = qcut(values, 3, **kwargs) + expected = IntervalIndex([Interval(-0.001, 1), Interval(1, 3)]) + tm.assert_index_equal(result.categories, expected) + + +@pytest.mark.parametrize( + "data,start,end", [(9.0, 8.999, 9.0), (0.0, -0.001, 0.0), (-9.0, -9.001, -9.0)] +) +@pytest.mark.parametrize("length", [1, 2]) +@pytest.mark.parametrize("labels", [None, False]) +def test_single_quantile(data, start, end, length, labels): + # see gh-15431 + ser = Series([data] * length) + result = qcut(ser, 1, labels=labels) + + if labels is None: + intervals = IntervalIndex([Interval(start, end)] * length, closed="right") + expected = Series(intervals).astype(CategoricalDtype(ordered=True)) + else: + expected = Series([0] * length, dtype=np.intp) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "ser", + [ + DatetimeIndex(["20180101", NaT, "20180103"]), + TimedeltaIndex(["0 days", NaT, "2 days"]), + ], + ids=lambda x: str(x.dtype), +) +def test_qcut_nat(ser, unit): + # see gh-19768 + ser = Series(ser) + ser = ser.dt.as_unit(unit) + td = Timedelta(1, unit=unit).as_unit(unit) + + left = Series([ser[0] - td, np.nan, ser[2] - Day()], dtype=ser.dtype) + right = Series([ser[2] - Day(), np.nan, ser[2]], dtype=ser.dtype) + intervals = IntervalIndex.from_arrays(left, right) + expected = Series(Categorical(intervals, ordered=True)) + + result = qcut(ser, 2) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("bins", [3, np.linspace(0, 1, 4)]) +def test_datetime_tz_qcut(bins): + # see gh-19872 + tz = "US/Eastern" + ser = Series(date_range("20130101", periods=3, tz=tz, unit="ns")) + + result = qcut(ser, bins) + expected = Series( + IntervalIndex( + [ + Interval( + Timestamp("2012-12-31 23:59:59.999999999", tz=tz), + Timestamp("2013-01-01 16:00:00", tz=tz), + ), + Interval( + Timestamp("2013-01-01 16:00:00", tz=tz), + Timestamp("2013-01-02 08:00:00", tz=tz), + ), + Interval( + Timestamp("2013-01-02 08:00:00", tz=tz), + Timestamp("2013-01-03 00:00:00", tz=tz), + ), + ] + ) + ).astype(CategoricalDtype(ordered=True)) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "arg,expected_bins", + [ + [ + timedelta_range("1day", periods=3), + TimedeltaIndex(["1 days", "2 days", "3 days"]), + ], + [ + date_range("20180101", periods=3), + DatetimeIndex(["2018-01-01", "2018-01-02", "2018-01-03"]), + ], + ], +) +def test_date_like_qcut_bins(arg, expected_bins, unit): + # see gh-19891 + arg = arg.as_unit(unit) + expected_bins = expected_bins.as_unit(unit) + ser = Series(arg) + result, result_bins = qcut(ser, 2, retbins=True) + tm.assert_index_equal(result_bins, expected_bins) + + +@pytest.mark.parametrize("bins", [6, 7]) +@pytest.mark.parametrize( + "box, compare", + [ + (Series, tm.assert_series_equal), + (np.array, tm.assert_categorical_equal), + (list, tm.assert_equal), + ], +) +def test_qcut_bool_coercion_to_int(bins, box, compare): + # issue 20303 + data_expected = box([0, 1, 1, 0, 1] * 10) + data_result = box([False, True, True, False, True] * 10) + expected = qcut(data_expected, bins, duplicates="drop") + result = qcut(data_result, bins, duplicates="drop") + compare(result, expected) + + +@pytest.mark.parametrize("q", [2, 5, 10]) +def test_qcut_nullable_integer(q, any_numeric_ea_dtype): + arr = pd.array(np.arange(100), dtype=any_numeric_ea_dtype) + arr[::2] = pd.NA + + result = qcut(arr, q) + expected = qcut(arr.astype(float), q) + + tm.assert_categorical_equal(result, expected) + + +@pytest.mark.parametrize("scale", [1.0, 1 / 3, 17.0]) +@pytest.mark.parametrize("q", [3, 7, 9]) +@pytest.mark.parametrize("precision", [1, 3, 16]) +def test_qcut_contains(scale, q, precision): + # GH-59355 + arr = (scale * np.arange(q + 1)).round(precision) + result = qcut(arr, q, precision=precision) + + for value, bucket in zip(arr, result): + assert value in bucket diff --git a/pandas/tests/reshape/test_union_categoricals.py b/pandas/tests/reshape/test_union_categoricals.py new file mode 100644 index 0000000000000000000000000000000000000000..081feae6fc43fef590d546ad040d4dc0f412028d --- /dev/null +++ b/pandas/tests/reshape/test_union_categoricals.py @@ -0,0 +1,369 @@ +import numpy as np +import pytest + +from pandas.core.dtypes.concat import union_categoricals + +import pandas as pd +from pandas import ( + Categorical, + CategoricalIndex, + Series, +) +import pandas._testing as tm + + +class TestUnionCategoricals: + @pytest.mark.parametrize( + "a, b, combined", + [ + (list("abc"), list("abd"), list("abcabd")), + ([0, 1, 2], [2, 3, 4], [0, 1, 2, 2, 3, 4]), + ([0, 1.2, 2], [2, 3.4, 4], [0, 1.2, 2, 2, 3.4, 4]), + ( + ["b", "b", np.nan, "a"], + ["a", np.nan, "c"], + ["b", "b", np.nan, "a", "a", np.nan, "c"], + ), + ( + pd.date_range("2014-01-01", "2014-01-05"), + pd.date_range("2014-01-06", "2014-01-07"), + pd.date_range("2014-01-01", "2014-01-07"), + ), + ( + pd.date_range("2014-01-01", "2014-01-05", tz="US/Central"), + pd.date_range("2014-01-06", "2014-01-07", tz="US/Central"), + pd.date_range("2014-01-01", "2014-01-07", tz="US/Central"), + ), + ( + pd.period_range("2014-01-01", "2014-01-05"), + pd.period_range("2014-01-06", "2014-01-07"), + pd.period_range("2014-01-01", "2014-01-07"), + ), + ], + ) + @pytest.mark.parametrize("box", [Categorical, CategoricalIndex, Series]) + def test_union_categorical(self, a, b, combined, box): + # GH 13361 + result = union_categoricals([box(Categorical(a)), box(Categorical(b))]) + expected = Categorical(combined) + tm.assert_categorical_equal(result, expected) + + def test_union_categorical_ordered_appearance(self): + # new categories ordered by appearance + s = Categorical(["x", "y", "z"]) + s2 = Categorical(["a", "b", "c"]) + result = union_categoricals([s, s2]) + expected = Categorical( + ["x", "y", "z", "a", "b", "c"], categories=["x", "y", "z", "a", "b", "c"] + ) + tm.assert_categorical_equal(result, expected) + + def test_union_categorical_ordered_true(self): + s = Categorical([0, 1.2, 2], ordered=True) + s2 = Categorical([0, 1.2, 2], ordered=True) + result = union_categoricals([s, s2]) + expected = Categorical([0, 1.2, 2, 0, 1.2, 2], ordered=True) + tm.assert_categorical_equal(result, expected) + + def test_union_categorical_match_types(self): + # must exactly match types + s = Categorical([0, 1.2, 2]) + s2 = Categorical([2, 3, 4]) + msg = "dtype of categories must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([s, s2]) + + def test_union_categorical_empty(self): + msg = "No Categoricals to union" + with pytest.raises(ValueError, match=msg): + union_categoricals([]) + + def test_union_categoricals_nan(self): + # GH 13759 + res = union_categoricals( + [Categorical([1, 2, np.nan]), Categorical([3, 2, np.nan])] + ) + exp = Categorical([1, 2, np.nan, 3, 2, np.nan]) + tm.assert_categorical_equal(res, exp) + + res = union_categoricals( + [Categorical(["A", "B"]), Categorical(["B", "B", np.nan])] + ) + exp = Categorical(["A", "B", "B", "B", np.nan]) + tm.assert_categorical_equal(res, exp) + + val1 = [pd.Timestamp("2011-01-01"), pd.Timestamp("2011-03-01"), pd.NaT] + val2 = [pd.NaT, pd.Timestamp("2011-01-01"), pd.Timestamp("2011-02-01")] + + res = union_categoricals([Categorical(val1), Categorical(val2)]) + exp = Categorical( + val1 + val2, + categories=[ + pd.Timestamp("2011-01-01"), + pd.Timestamp("2011-03-01"), + pd.Timestamp("2011-02-01"), + ], + ) + tm.assert_categorical_equal(res, exp) + + # all NaN + res = union_categoricals( + [ + Categorical(np.array([np.nan, np.nan], dtype=object)), + Categorical(["X"], categories=pd.Index(["X"], dtype=object)), + ] + ) + exp = Categorical([np.nan, np.nan, "X"]) + tm.assert_categorical_equal(res, exp) + + res = union_categoricals( + [Categorical([np.nan, np.nan]), Categorical([np.nan, np.nan])] + ) + exp = Categorical([np.nan, np.nan, np.nan, np.nan]) + tm.assert_categorical_equal(res, exp) + + @pytest.mark.parametrize("val", [[], ["1"]]) + def test_union_categoricals_empty(self, val, request, using_infer_string): + # GH 13759 + if using_infer_string and val == ["1"]: + request.applymarker( + pytest.mark.xfail( + reason="TDOD(infer_string) object and strings dont match" + ) + ) + res = union_categoricals([Categorical([]), Categorical(val)]) + exp = Categorical(val) + tm.assert_categorical_equal(res, exp) + + def test_union_categorical_same_category(self): + # check fastpath + c1 = Categorical([1, 2, 3, 4], categories=[1, 2, 3, 4]) + c2 = Categorical([3, 2, 1, np.nan], categories=[1, 2, 3, 4]) + res = union_categoricals([c1, c2]) + exp = Categorical([1, 2, 3, 4, 3, 2, 1, np.nan], categories=[1, 2, 3, 4]) + tm.assert_categorical_equal(res, exp) + + def test_union_categorical_same_category_str(self): + c1 = Categorical(["z", "z", "z"], categories=["x", "y", "z"]) + c2 = Categorical(["x", "x", "x"], categories=["x", "y", "z"]) + res = union_categoricals([c1, c2]) + exp = Categorical(["z", "z", "z", "x", "x", "x"], categories=["x", "y", "z"]) + tm.assert_categorical_equal(res, exp) + + def test_union_categorical_same_categories_different_order(self): + # https://github.com/pandas-dev/pandas/issues/19096 + c1 = Categorical(["a", "b", "c"], categories=["a", "b", "c"]) + c2 = Categorical(["a", "b", "c"], categories=["b", "a", "c"]) + result = union_categoricals([c1, c2]) + expected = Categorical( + ["a", "b", "c", "a", "b", "c"], categories=["a", "b", "c"] + ) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_ordered(self): + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([1, 2, 3], ordered=False) + + msg = "Categorical.ordered must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2]) + + res = union_categoricals([c1, c1]) + exp = Categorical([1, 2, 3, 1, 2, 3], ordered=True) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3, np.nan], ordered=True) + c2 = Categorical([3, 2], categories=[1, 2, 3], ordered=True) + + res = union_categoricals([c1, c2]) + exp = Categorical([1, 2, 3, np.nan, 3, 2], ordered=True) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([1, 2, 3], categories=[3, 2, 1], ordered=True) + + msg = "to union ordered Categoricals, all categories must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2]) + + def test_union_categoricals_ignore_order(self): + # GH 15219 + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([1, 2, 3], ordered=False) + + res = union_categoricals([c1, c2], ignore_order=True) + exp = Categorical([1, 2, 3, 1, 2, 3]) + tm.assert_categorical_equal(res, exp) + + msg = "Categorical.ordered must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2], ignore_order=False) + + res = union_categoricals([c1, c1], ignore_order=True) + exp = Categorical([1, 2, 3, 1, 2, 3]) + tm.assert_categorical_equal(res, exp) + + res = union_categoricals([c1, c1], ignore_order=False) + exp = Categorical([1, 2, 3, 1, 2, 3], categories=[1, 2, 3], ordered=True) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3, np.nan], ordered=True) + c2 = Categorical([3, 2], categories=[1, 2, 3], ordered=True) + + res = union_categoricals([c1, c2], ignore_order=True) + exp = Categorical([1, 2, 3, np.nan, 3, 2]) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([1, 2, 3], categories=[3, 2, 1], ordered=True) + + res = union_categoricals([c1, c2], ignore_order=True) + exp = Categorical([1, 2, 3, 1, 2, 3]) + tm.assert_categorical_equal(res, exp) + + res = union_categoricals([c2, c1], ignore_order=True, sort_categories=True) + exp = Categorical([1, 2, 3, 1, 2, 3], categories=[1, 2, 3]) + tm.assert_categorical_equal(res, exp) + + c1 = Categorical([1, 2, 3], ordered=True) + c2 = Categorical([4, 5, 6], ordered=True) + result = union_categoricals([c1, c2], ignore_order=True) + expected = Categorical([1, 2, 3, 4, 5, 6]) + tm.assert_categorical_equal(result, expected) + + msg = "to union ordered Categoricals, all categories must be the same" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2], ignore_order=False) + + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2]) + + def test_union_categoricals_sort(self): + # GH 13846 + c1 = Categorical(["x", "y", "z"]) + c2 = Categorical(["a", "b", "c"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical( + ["x", "y", "z", "a", "b", "c"], categories=["a", "b", "c", "x", "y", "z"] + ) + tm.assert_categorical_equal(result, expected) + + # fastpath + c1 = Categorical(["a", "b"], categories=["b", "a", "c"]) + c2 = Categorical(["b", "c"], categories=["b", "a", "c"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical(["a", "b", "b", "c"], categories=["a", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical(["a", "b"], categories=["c", "a", "b"]) + c2 = Categorical(["b", "c"], categories=["c", "a", "b"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical(["a", "b", "b", "c"], categories=["a", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + # fastpath - skip resort + c1 = Categorical(["a", "b"], categories=["a", "b", "c"]) + c2 = Categorical(["b", "c"], categories=["a", "b", "c"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical(["a", "b", "b", "c"], categories=["a", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical(["x", np.nan]) + c2 = Categorical([np.nan, "b"]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical(["x", np.nan, np.nan, "b"], categories=["b", "x"]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical([np.nan]) + c2 = Categorical([np.nan]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical([np.nan, np.nan]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical([]) + c2 = Categorical([]) + result = union_categoricals([c1, c2], sort_categories=True) + expected = Categorical([]) + tm.assert_categorical_equal(result, expected) + + c1 = Categorical(["b", "a"], categories=["b", "a", "c"], ordered=True) + c2 = Categorical(["a", "c"], categories=["b", "a", "c"], ordered=True) + msg = "Cannot use sort_categories=True with ordered Categoricals" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, c2], sort_categories=True) + + def test_union_categoricals_sort_false(self): + # GH 13846 + c1 = Categorical(["x", "y", "z"]) + c2 = Categorical(["a", "b", "c"]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical( + ["x", "y", "z", "a", "b", "c"], categories=["x", "y", "z", "a", "b", "c"] + ) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_fastpath(self): + # fastpath + c1 = Categorical(["a", "b"], categories=["b", "a", "c"]) + c2 = Categorical(["b", "c"], categories=["b", "a", "c"]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical(["a", "b", "b", "c"], categories=["b", "a", "c"]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_skipresort(self): + # fastpath - skip resort + c1 = Categorical(["a", "b"], categories=["a", "b", "c"]) + c2 = Categorical(["b", "c"], categories=["a", "b", "c"]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical(["a", "b", "b", "c"], categories=["a", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_one_nan(self): + c1 = Categorical(["x", np.nan]) + c2 = Categorical([np.nan, "b"]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical(["x", np.nan, np.nan, "b"], categories=["x", "b"]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_only_nan(self): + c1 = Categorical([np.nan]) + c2 = Categorical([np.nan]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical([np.nan, np.nan]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_empty(self): + c1 = Categorical([]) + c2 = Categorical([]) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical([]) + tm.assert_categorical_equal(result, expected) + + def test_union_categoricals_sort_false_ordered_true(self): + c1 = Categorical(["b", "a"], categories=["b", "a", "c"], ordered=True) + c2 = Categorical(["a", "c"], categories=["b", "a", "c"], ordered=True) + result = union_categoricals([c1, c2], sort_categories=False) + expected = Categorical( + ["b", "a", "a", "c"], categories=["b", "a", "c"], ordered=True + ) + tm.assert_categorical_equal(result, expected) + + def test_union_categorical_unwrap(self): + # GH 14173 + c1 = Categorical(["a", "b"]) + c2 = Series(["b", "c"], dtype="category") + result = union_categoricals([c1, c2]) + expected = Categorical(["a", "b", "b", "c"]) + tm.assert_categorical_equal(result, expected) + + c2 = CategoricalIndex(c2) + result = union_categoricals([c1, c2]) + tm.assert_categorical_equal(result, expected) + + c1 = Series(c1) + result = union_categoricals([c1, c2]) + tm.assert_categorical_equal(result, expected) + + msg = "all components to combine must be Categorical" + with pytest.raises(TypeError, match=msg): + union_categoricals([c1, ["a", "b", "c"]]) diff --git a/pandas/tests/series/__init__.py b/pandas/tests/series/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/series/test_api.py b/pandas/tests/series/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..4b369bb0bc86935ac94113a858cfb5d99082e34f --- /dev/null +++ b/pandas/tests/series/test_api.py @@ -0,0 +1,278 @@ +import inspect +import pydoc + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + Series, + date_range, + period_range, + timedelta_range, +) +import pandas._testing as tm + + +class TestSeriesMisc: + def test_tab_completion(self): + # GH 9910 + s = Series(list("abcd")) + # Series of str values should have .str but not .dt/.cat in __dir__ + assert "str" in dir(s) + assert "dt" not in dir(s) + assert "cat" not in dir(s) + + def test_tab_completion_dt(self): + # similarly for .dt + s = Series(date_range("1/1/2015", periods=5)) + assert "dt" in dir(s) + assert "str" not in dir(s) + assert "cat" not in dir(s) + + def test_tab_completion_cat(self): + # Similarly for .cat, but with the twist that str and dt should be + # there if the categories are of that type first cat and str. + s = Series(list("abbcd"), dtype="category") + assert "cat" in dir(s) + assert "str" in dir(s) # as it is a string categorical + assert "dt" not in dir(s) + + def test_tab_completion_cat_str(self): + # similar to cat and str + s = Series(date_range("1/1/2015", periods=5)).astype("category") + assert "cat" in dir(s) + assert "str" not in dir(s) + assert "dt" in dir(s) # as it is a datetime categorical + + def test_tab_completion_with_categorical(self): + # test the tab completion display + ok_for_cat = [ + "categories", + "codes", + "ordered", + "set_categories", + "add_categories", + "remove_categories", + "rename_categories", + "reorder_categories", + "remove_unused_categories", + "as_ordered", + "as_unordered", + ] + + s = Series(list("aabbcde")).astype("category") + results = sorted({r for r in s.cat.__dir__() if not r.startswith("_")}) + tm.assert_almost_equal(results, sorted(set(ok_for_cat))) + + @pytest.mark.parametrize( + "index", + [ + Index(list("ab") * 5, dtype="category"), + Index([str(i) for i in range(10)]), + Index(["foo", "bar", "baz"] * 2), + date_range("2020-01-01", periods=10), + period_range("2020-01-01", periods=10, freq="D"), + timedelta_range("1 day", periods=10), + Index(np.arange(10), dtype=np.uint64), + Index(np.arange(10), dtype=np.int64), + Index(np.arange(10), dtype=np.float64), + Index([True, False]), + Index([f"a{i}" for i in range(101)]), + pd.MultiIndex.from_tuples(zip("ABCD", "EFGH")), + pd.MultiIndex.from_tuples(zip([0, 1, 2, 3], "EFGH")), + ], + ) + def test_index_tab_completion(self, index): + # dir contains string-like values of the Index. + s = Series(index=index, dtype=object) + dir_s = dir(s) + for i, x in enumerate(s.index.unique(level=0)): + if i < 100: + assert not isinstance(x, str) or not x.isidentifier() or x in dir_s + else: + assert x not in dir_s + + @pytest.mark.parametrize("ser", [Series(dtype=object), Series([1])]) + def test_not_hashable(self, ser): + msg = "unhashable type: 'Series'" + with pytest.raises(TypeError, match=msg): + hash(ser) + + def test_contains(self, datetime_series): + tm.assert_contains_all(datetime_series.index, datetime_series) + + def test_axis_alias(self): + s = Series([1, 2, np.nan]) + tm.assert_series_equal(s.dropna(axis="rows"), s.dropna(axis="index")) + assert s.dropna().sum(axis="rows") == 3 + assert s._get_axis_number("rows") == 0 + assert s._get_axis_name("rows") == "index" + + def test_class_axis(self): + # https://github.com/pandas-dev/pandas/issues/18147 + # no exception and no empty docstring + assert pydoc.getdoc(Series.index) + + def test_ndarray_compat(self): + # test numpy compat with Series as sub-class of NDFrame + tsdf = DataFrame( + np.random.default_rng(2).standard_normal((1000, 3)), + columns=["A", "B", "C"], + index=date_range("1/1/2000", periods=1000), + ) + + def f(x): + return x[x.idxmax()] + + result = tsdf.apply(f) + expected = tsdf.max() + tm.assert_series_equal(result, expected) + + def test_ndarray_compat_like_func(self): + # using an ndarray like function + s = Series(np.random.default_rng(2).standard_normal(10)) + result = Series(np.ones_like(s)) + expected = Series(1, index=range(10), dtype="float64") + tm.assert_series_equal(result, expected) + + def test_empty_method(self): + s_empty = Series(dtype=object) + assert s_empty.empty + + @pytest.mark.parametrize("dtype", ["int64", object]) + def test_empty_method_full_series(self, dtype): + full_series = Series(index=[1], dtype=dtype) + assert not full_series.empty + + @pytest.mark.parametrize("dtype", [None, "Int64"]) + def test_integer_series_size(self, dtype): + # GH 25580 + s = Series(range(9), dtype=dtype) + assert s.size == 9 + + def test_attrs(self): + s = Series([0, 1], name="abc") + assert s.attrs == {} + s.attrs["version"] = 1 + result = s + 1 + assert result.attrs == {"version": 1} + + def test_inspect_getmembers(self): + # GH38782 + ser = Series(dtype=object) + inspect.getmembers(ser) + + def test_unknown_attribute(self): + # GH#9680 + tdi = timedelta_range(start=0, periods=10, freq="1s") + ser = Series(np.random.default_rng(2).normal(size=10), index=tdi) + assert "foo" not in ser.__dict__ + msg = "'Series' object has no attribute 'foo'" + with pytest.raises(AttributeError, match=msg): + ser.foo + + @pytest.mark.parametrize("op", ["year", "day", "second", "weekday"]) + def test_datetime_series_no_datelike_attrs(self, op, datetime_series): + # GH#7206 + msg = f"'Series' object has no attribute '{op}'" + with pytest.raises(AttributeError, match=msg): + getattr(datetime_series, op) + + def test_series_datetimelike_attribute_access(self): + # attribute access should still work! + ser = Series({"year": 2000, "month": 1, "day": 10}) + assert ser.year == 2000 + assert ser.month == 1 + assert ser.day == 10 + + def test_series_datetimelike_attribute_access_invalid(self): + ser = Series({"year": 2000, "month": 1, "day": 10}) + msg = "'Series' object has no attribute 'weekday'" + with pytest.raises(AttributeError, match=msg): + ser.weekday + + @pytest.mark.parametrize( + "kernel, has_numeric_only", + [ + ("skew", True), + ("var", True), + ("all", False), + ("prod", True), + ("any", False), + ("idxmin", False), + ("quantile", False), + ("idxmax", False), + ("min", True), + ("sem", True), + ("mean", True), + ("nunique", False), + ("max", True), + ("sum", True), + ("count", False), + ("median", True), + ("std", True), + ("rank", True), + ("pct_change", False), + ("cummax", False), + ("shift", False), + ("diff", False), + ("cumsum", False), + ("cummin", False), + ("cumprod", False), + ("fillna", False), + ("ffill", False), + ("bfill", False), + ("sample", False), + ("tail", False), + ("take", False), + ("head", False), + ("cov", False), + ("corr", False), + ], + ) + @pytest.mark.parametrize("dtype", [bool, int, float, object]) + def test_numeric_only(self, kernel, has_numeric_only, dtype): + # GH#47500 + ser = Series([0, 1, 1], dtype=dtype) + if kernel == "corrwith": + args = (ser,) + elif kernel == "corr": + args = (ser,) + elif kernel == "cov": + args = (ser,) + elif kernel == "nth": + args = (0,) + elif kernel == "fillna": + args = (True,) + elif kernel == "fillna": + args = ("ffill",) + elif kernel == "take": + args = ([0],) + elif kernel == "quantile": + args = (0.5,) + else: + args = () + method = getattr(ser, kernel) + if not has_numeric_only: + msg = ( + "(got an unexpected keyword argument 'numeric_only'" + "|too many arguments passed in)" + ) + with pytest.raises(TypeError, match=msg): + method(*args, numeric_only=True) + elif dtype is object: + msg = f"Series.{kernel} does not allow numeric_only=True with non-numeric" + with pytest.raises(TypeError, match=msg): + method(*args, numeric_only=True) + else: + result = method(*args, numeric_only=True) + expected = method(*args, numeric_only=False) + if isinstance(expected, Series): + # transformer + tm.assert_series_equal(result, expected) + else: + # reducer + assert result == expected diff --git a/pandas/tests/series/test_arithmetic.py b/pandas/tests/series/test_arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2bba16b86cd64353203dcdf473bbe155c13b13 --- /dev/null +++ b/pandas/tests/series/test_arithmetic.py @@ -0,0 +1,1085 @@ +from datetime import ( + date, + timedelta, + timezone, +) +from decimal import Decimal +from enum import ( + Enum, + auto, +) +import operator + +import numpy as np +import pytest + +from pandas._libs import lib + +import pandas as pd +from pandas import ( + Categorical, + DatetimeTZDtype, + Index, + Series, + Timedelta, + bdate_range, + date_range, + isna, +) +import pandas._testing as tm +from pandas.core import ops +from pandas.core.computation import expressions as expr + + +@pytest.fixture(autouse=True, params=[0, 1000000], ids=["numexpr", "python"]) +def switch_numexpr_min_elements(request, monkeypatch): + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", request.param) + yield + + +def _permute(obj): + return obj.take(np.random.default_rng(2).permutation(len(obj))) + + +class TestSeriesFlexArithmetic: + @pytest.mark.parametrize( + "ts", + [ + (lambda x: x, lambda x: x * 2, False), + (lambda x: x, lambda x: x[::2], False), + (lambda x: x, lambda x: 5, True), + ( + lambda x: Series(range(10), dtype=np.float64), + lambda x: Series(range(10), dtype=np.float64), + True, + ), + ], + ) + @pytest.mark.parametrize( + "opname", ["add", "sub", "mul", "floordiv", "truediv", "pow"] + ) + def test_flex_method_equivalence(self, opname, ts): + # check that Series.{opname} behaves like Series.__{opname}__, + tser = Series( + np.arange(20, dtype=np.float64), + index=date_range("2020-01-01", periods=20), + name="ts", + ) + + series = ts[0](tser) + other = ts[1](tser) + check_reverse = ts[2] + + op = getattr(Series, opname) + alt = getattr(operator, opname) + + result = op(series, other) + expected = alt(series, other) + tm.assert_almost_equal(result, expected) + if check_reverse: + rop = getattr(Series, "r" + opname) + result = rop(series, other) + expected = alt(other, series) + tm.assert_almost_equal(result, expected) + + def test_flex_method_subclass_metadata_preservation(self, all_arithmetic_operators): + # GH 13208 + class MySeries(Series): + _metadata = ["x"] + + @property + def _constructor(self): + return MySeries + + opname = all_arithmetic_operators + op = getattr(Series, opname) + m = MySeries([1, 2, 3], name="test") + m.x = 42 + result = op(m, 1) + assert result.x == 42 + + def test_flex_add_scalar_fill_value(self): + # GH12723 + ser = Series([0, 1, np.nan, 3, 4, 5]) + + exp = ser.fillna(0).add(2) + res = ser.add(2, fill_value=0) + tm.assert_series_equal(res, exp) + + pairings = [(Series.div, operator.truediv, 1), (Series.rdiv, ops.rtruediv, 1)] + for op in ["add", "sub", "mul", "pow", "truediv", "floordiv"]: + fv = 0 + lop = getattr(Series, op) + lequiv = getattr(operator, op) + rop = getattr(Series, "r" + op) + # bind op at definition time... + requiv = lambda x, y, op=op: getattr(operator, op)(y, x) + pairings.append((lop, lequiv, fv)) + pairings.append((rop, requiv, fv)) + + @pytest.mark.parametrize("op, equiv_op, fv", pairings) + def test_operators_combine(self, op, equiv_op, fv): + def _check_fill(meth, op, a, b, fill_value=0): + exp_index = a.index.union(b.index) + a = a.reindex(exp_index) + b = b.reindex(exp_index) + + amask = isna(a) + bmask = isna(b) + + exp_values = [] + for i in range(len(exp_index)): + with np.errstate(all="ignore"): + if amask[i]: + if bmask[i]: + exp_values.append(np.nan) + continue + exp_values.append(op(fill_value, b[i])) + elif bmask[i]: + if amask[i]: + exp_values.append(np.nan) + continue + exp_values.append(op(a[i], fill_value)) + else: + exp_values.append(op(a[i], b[i])) + + result = meth(a, b, fill_value=fill_value) + expected = Series(exp_values, exp_index) + tm.assert_series_equal(result, expected) + + a = Series([np.nan, 1.0, 2.0, 3.0, np.nan], index=np.arange(5)) + b = Series([np.nan, 1, np.nan, 3, np.nan, 4.0], index=np.arange(6)) + + result = op(a, b) + exp = equiv_op(a, b) + tm.assert_series_equal(result, exp) + _check_fill(op, equiv_op, a, b, fill_value=fv) + # should accept axis=0 or axis='rows' + op(a, b, axis=0) + + @pytest.mark.parametrize("kind", ["datetime", "timedelta"]) + def test_rhs_extension_array_sub_with_fill_value(self, kind): + # GH:62467 + if kind == "datetime": + left = Series( + [pd.Timestamp("2025-08-20"), pd.Timestamp("2025-08-21")], + dtype=np.dtype("datetime64[ns]"), + ) + else: + left = Series( + [Timedelta(days=1), Timedelta(days=2)], + dtype=np.dtype("timedelta64[ns]"), + ) + + right = ( + left._values + ) # DatetimeArray or TimedeltaArray which is an ExtensionArray + + result = left.sub(right, fill_value=left.iloc[0]) + expected = Series(np.zeros(len(left), dtype=np.dtype("timedelta64[ns]"))) + tm.assert_series_equal(result, expected) + + def test_flex_disallows_dataframe(self): + # GH#46179 + df = pd.DataFrame( + {2010: [1], 2020: [3]}, + index=pd.MultiIndex.from_product([["a"], ["b"]], names=["scen", "mod"]), + ) + + ser = Series( + [10.0, 20.0, 30.0], + index=pd.MultiIndex.from_product( + [["a"], ["b"], [0, 1, 2]], names=["scen", "mod", "id"] + ), + ) + + msg = "Series.add does not support a DataFrame `other`" + with pytest.raises(TypeError, match=msg): + ser.add(df, axis=0) + + +class TestSeriesArithmetic: + # Some of these may end up in tests/arithmetic, but are not yet sorted + + def test_add_series_with_period_index(self): + rng = pd.period_range("1/1/2000", "1/1/2010", freq="Y") + ts = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + result = ts + ts[::2] + expected = ts + ts + expected.iloc[1::2] = np.nan + tm.assert_series_equal(result, expected) + + result = ts + _permute(ts[::2]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "target_add,input_value,expected_value", + [ + ("!", ["hello", "world"], ["hello!", "world!"]), + ("m", ["hello", "world"], ["hellom", "worldm"]), + ], + ) + def test_string_addition(self, target_add, input_value, expected_value): + # GH28658 - ensure adding 'm' does not raise an error + a = Series(input_value) + + result = a + target_add + expected = Series(expected_value) + tm.assert_series_equal(result, expected) + + def test_divmod(self): + # GH#25557 + a = Series([1, 1, 1, np.nan], index=["a", "b", "c", "d"]) + b = Series([2, np.nan, 1, np.nan], index=["a", "b", "d", "e"]) + + result = a.divmod(b) + expected = divmod(a, b) + tm.assert_series_equal(result[0], expected[0]) + tm.assert_series_equal(result[1], expected[1]) + + result = a.rdivmod(b) + expected = divmod(b, a) + tm.assert_series_equal(result[0], expected[0]) + tm.assert_series_equal(result[1], expected[1]) + + @pytest.mark.parametrize("index", [None, range(9)]) + def test_series_integer_mod(self, index): + # GH#24396 + s1 = Series(range(1, 10)) + s2 = Series("foo", index=index) + + msg = "not all arguments converted during string formatting|'mod' not supported" + + with pytest.raises(TypeError, match=msg): + s2 % s1 + + def test_add_with_duplicate_index(self): + # GH14227 + s1 = Series([1, 2], index=[1, 1]) + s2 = Series([10, 10], index=[1, 2]) + result = s1 + s2 + expected = Series([11, 12, np.nan], index=[1, 1, 2]) + tm.assert_series_equal(result, expected) + + def test_add_na_handling(self): + ser = Series( + [Decimal("1.3"), Decimal("2.3")], index=[date(2012, 1, 1), date(2012, 1, 2)] + ) + + result = ser + ser.shift(1) + result2 = ser.shift(1) + ser + assert isna(result.iloc[0]) + assert isna(result2.iloc[0]) + + def test_add_corner_cases(self, datetime_series): + empty = Series([], index=Index([]), dtype=np.float64) + + result = datetime_series + empty + assert np.isnan(result).all() + + result = empty + empty + assert len(result) == 0 + + def test_add_float_plus_int(self, datetime_series): + # float + int + int_ts = datetime_series.astype(int)[:-5] + added = datetime_series + int_ts + expected = Series( + datetime_series.values[:-5] + int_ts.values, + index=datetime_series.index[:-5], + name="ts", + ) + tm.assert_series_equal(added[:-5], expected) + + def test_mul_empty_int_corner_case(self): + s1 = Series([], [], dtype=np.int32) + s2 = Series({"x": 0.0}) + tm.assert_series_equal(s1 * s2, Series([np.nan], index=["x"])) + + def test_sub_datetimelike_align(self): + # GH#7500 + # datetimelike ops need to align + dt = Series(date_range("2012-1-1", periods=3, freq="D", unit="ns")) + dt.iloc[2] = np.nan + dt2 = dt[::-1] + + expected = Series([timedelta(0), timedelta(0), pd.NaT], dtype="m8[ns]") + # name is reset + result = dt2 - dt + tm.assert_series_equal(result, expected) + + expected = Series(expected, name=0) + result = (dt2.to_frame() - dt.to_frame())[0] + tm.assert_series_equal(result, expected) + + def test_alignment_doesnt_change_tz(self): + # GH#33671 + dti = date_range("2016-01-01", periods=10, tz="CET") + dti_utc = dti.tz_convert("UTC") + ser = Series(10, index=dti) + ser_utc = Series(10, index=dti_utc) + + # we don't care about the result, just that original indexes are unchanged + ser * ser_utc + + assert ser.index is dti + assert ser_utc.index is dti_utc + + def test_alignment_categorical(self): + # GH13365 + cat = Categorical(["3z53", "3z53", "LoJG", "LoJG", "LoJG", "N503"]) + ser1 = Series(2, index=cat) + ser2 = Series(2, index=cat[:-1]) + result = ser1 * ser2 + + exp_index = ["3z53"] * 4 + ["LoJG"] * 9 + ["N503"] + exp_index = pd.CategoricalIndex(exp_index, categories=cat.categories) + exp_values = [4.0] * 13 + [np.nan] + expected = Series(exp_values, exp_index) + + tm.assert_series_equal(result, expected) + + def test_arithmetic_with_duplicate_index(self): + # GH#8363 + # integer ops with a non-unique index + index = [2, 2, 3, 3, 4] + ser = Series(np.arange(1, 6, dtype="int64"), index=index) + other = Series(np.arange(5, dtype="int64"), index=index) + result = ser - other + expected = Series(1, index=[2, 2, 3, 3, 4]) + tm.assert_series_equal(result, expected) + + # GH#8363 + # datetime ops with a non-unique index + ser = Series(date_range("20130101 09:00:00", periods=5, unit="ns"), index=index) + other = Series(date_range("20130101", periods=5, unit="ns"), index=index) + result = ser - other + expected = Series(Timedelta("9 hours"), index=[2, 2, 3, 3, 4], dtype="m8[ns]") + tm.assert_series_equal(result, expected) + + def test_masked_and_non_masked_propagate_na(self): + # GH#45810 + ser1 = Series([0, np.nan], dtype="float") + ser2 = Series([0, 1], dtype="Int64") + result = ser1 * ser2 + expected = Series([0, pd.NA], dtype="Float64") + tm.assert_series_equal(result, expected) + + def test_mask_div_propagate_na_for_non_na_dtype(self): + # GH#42630 + ser1 = Series([15, pd.NA, 5, 4], dtype="Int64") + ser2 = Series([15, 5, np.nan, 4]) + result = ser1 / ser2 + expected = Series([1.0, pd.NA, pd.NA, 1.0], dtype="Float64") + tm.assert_series_equal(result, expected) + + result = ser2 / ser1 + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("val", [3, 3.5]) + def test_add_list_to_masked_array(self, val): + # GH#22962, behavior changed by GH#62552 + ser = Series([1, None, 3], dtype="Int64") + result = ser + [1, None, val] # noqa: RUF005 + expected = Series([2, pd.NA, 3 + val], dtype="Float64") + tm.assert_series_equal(result, expected) + + result = [1, None, val] + ser # noqa: RUF005 + tm.assert_series_equal(result, expected) + + def test_add_list_to_masked_array_boolean(self): + # GH#22962 + ser = Series([True, None, False], dtype="boolean") + result = ser + [True, None, True] # noqa: RUF005 + expected = Series([2, pd.NA, 1], dtype=object) + tm.assert_series_equal(result, expected) + + result = [True, None, True] + ser # noqa: RUF005 + tm.assert_series_equal(result, expected) + + +# ------------------------------------------------------------------ +# Comparisons + + +class TestSeriesFlexComparison: + @pytest.mark.parametrize("axis", [0, None, "index"]) + def test_comparison_flex_basic(self, axis, comparison_op): + left = Series(np.random.default_rng(2).standard_normal(10)) + right = Series(np.random.default_rng(2).standard_normal(10)) + result = getattr(left, comparison_op.__name__)(right, axis=axis) + expected = comparison_op(left, right) + tm.assert_series_equal(result, expected) + + def test_comparison_bad_axis(self, comparison_op): + left = Series(np.random.default_rng(2).standard_normal(10)) + right = Series(np.random.default_rng(2).standard_normal(10)) + + msg = "No axis named 1 for object type" + with pytest.raises(ValueError, match=msg): + getattr(left, comparison_op.__name__)(right, axis=1) + + @pytest.mark.parametrize( + "values, op", + [ + ([False, False, True, False], "eq"), + ([True, True, False, True], "ne"), + ([False, False, True, False], "le"), + ([False, False, False, False], "lt"), + ([False, True, True, False], "ge"), + ([False, True, False, False], "gt"), + ], + ) + def test_comparison_flex_alignment(self, values, op): + left = Series([1, 3, 2], index=list("abc")) + right = Series([2, 2, 2], index=list("bcd")) + result = getattr(left, op)(right) + expected = Series(values, index=list("abcd")) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "left", + [ + Series(Categorical(["a", "b", "a"])), + Series(pd.period_range("2020Q1", periods=3, freq="Q")), + ], + ids=["categorical", "period"], + ) + def test_rhs_extension_array_eq_with_fill_value(self, left): + # GH:#62467 + right = left._values # this is an ExtensionArray + + result = left.eq(right, fill_value=left.iloc[0]) + expected = Series([True, True, True]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "values, op, fill_value", + [ + ([False, False, True, True], "eq", 2), + ([True, True, False, False], "ne", 2), + ([False, False, True, True], "le", 0), + ([False, False, False, True], "lt", 0), + ([True, True, True, False], "ge", 0), + ([True, True, False, False], "gt", 0), + ], + ) + def test_comparison_flex_alignment_fill(self, values, op, fill_value): + left = Series([1, 3, 2], index=list("abc")) + right = Series([2, 2, 2], index=list("bcd")) + result = getattr(left, op)(right, fill_value=fill_value) + expected = Series(values, index=list("abcd")) + tm.assert_series_equal(result, expected) + + def test_eq_objects(self) -> None: + # GH#62191 Test eq with Enum and List elements + + class Thing(Enum): + FIRST = auto() + SECOND = auto() + + left = Series([Thing.FIRST, Thing.SECOND]) + py_l = [Thing.FIRST, Thing.SECOND] + + result = left.eq(Thing.FIRST) + expected = Series([True, False]) + tm.assert_series_equal(result, expected) + + result = left.eq(py_l) + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + result = left.eq(np.asarray(py_l)) + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + result = left.eq(Series(py_l)) + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + result = Series([[1, 2], [3, 4]]).eq([1, 2]) + expected = Series([True, False]) + with pytest.raises(AssertionError): + tm.assert_series_equal(result, expected) + expected = Series([False, False]) + tm.assert_series_equal(result, expected) + + def test_eq_with_index(self) -> None: + # GH#62191 Test eq with non-trivial indices + left = Series([1, 2], index=[1, 0]) + py_l = [1, 2] + + # assuming Python list has the same index as the Series + result = left.eq(py_l) + expected = Series([True, True], index=[1, 0]) + tm.assert_series_equal(result, expected) + + # assuming np.ndarray has the same index as the Series + result = left.eq(np.asarray(py_l)) + expected = Series([True, True], index=[1, 0]) + tm.assert_series_equal(result, expected) + + result = left.eq(Series(py_l)) + expected = Series([False, False]) + tm.assert_series_equal(result, expected) + + result = left.eq(Series([2, 1])) + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + +class TestSeriesComparison: + def test_comparison_different_length(self): + a = Series(["a", "b", "c"]) + b = Series(["b", "a"]) + msg = "only compare identically-labeled Series" + with pytest.raises(ValueError, match=msg): + a < b + + a = Series([1, 2]) + b = Series([2, 3, 4]) + with pytest.raises(ValueError, match=msg): + a == b + + @pytest.mark.parametrize("opname", ["eq", "ne", "gt", "lt", "ge", "le"]) + def test_ser_flex_cmp_return_dtypes(self, opname): + # GH#15115 + ser = Series([1, 3, 2], index=range(3)) + const = 2 + result = getattr(ser, opname)(const).dtypes + expected = np.dtype("bool") + assert result == expected + + @pytest.mark.parametrize("opname", ["eq", "ne", "gt", "lt", "ge", "le"]) + def test_ser_flex_cmp_return_dtypes_empty(self, opname): + # GH#15115 empty Series case + ser = Series([1, 3, 2], index=range(3)) + empty = ser.iloc[:0] + const = 2 + result = getattr(empty, opname)(const).dtypes + expected = np.dtype("bool") + assert result == expected + + @pytest.mark.parametrize( + "names", [(None, None, None), ("foo", "bar", None), ("baz", "baz", "baz")] + ) + def test_ser_cmp_result_names(self, names, comparison_op): + # datetime64 dtype + op = comparison_op + dti = date_range("1949-06-07 03:00:00", freq="h", periods=5, name=names[0]) + ser = Series(dti).rename(names[1]) + result = op(ser, dti) + assert result.name == names[2] + + # datetime64tz dtype + dti = dti.tz_localize("US/Central") + dti = pd.DatetimeIndex(dti, freq="infer") # freq not preserved by tz_localize + ser = Series(dti).rename(names[1]) + result = op(ser, dti) + assert result.name == names[2] + + # timedelta64 dtype + tdi = dti - dti.shift(1) + ser = Series(tdi).rename(names[1]) + result = op(ser, tdi) + assert result.name == names[2] + + # interval dtype + if op in [operator.eq, operator.ne]: + # interval dtype comparisons not yet implemented + ii = pd.interval_range(start=0, periods=5, name=names[0]) + ser = Series(ii).rename(names[1]) + result = op(ser, ii) + assert result.name == names[2] + + # categorical + if op in [operator.eq, operator.ne]: + # categorical dtype comparisons raise for inequalities + cidx = tdi.astype("category") + ser = Series(cidx).rename(names[1]) + result = op(ser, cidx) + assert result.name == names[2] + + def test_comparisons(self): + s = Series(["a", "b", "c"]) + s2 = Series([False, True, False]) + + # it works! + exp = Series([False, False, False]) + tm.assert_series_equal(s == s2, exp) + tm.assert_series_equal(s2 == s, exp) + + # ----------------------------------------------------------------- + # Categorical Dtype Comparisons + + def test_categorical_comparisons(self): + # GH#8938 + # allow equality comparisons + a = Series(list("abc"), dtype="category") + b = Series(list("abc"), dtype="object") + c = Series(["a", "b", "cc"], dtype="object") + d = Series(list("acb"), dtype="object") + e = Categorical(list("abc")) + f = Categorical(list("acb")) + + # vs scalar + assert not (a == "a").all() + assert ((a != "a") == ~(a == "a")).all() + + assert not ("a" == a).all() + assert (a == "a")[0] + assert ("a" == a)[0] + assert not ("a" != a)[0] + + # vs list-like + assert (a == a).all() + assert not (a != a).all() + + assert (a == list(a)).all() + assert (a == b).all() + assert (b == a).all() + assert ((~(a == b)) == (a != b)).all() + assert ((~(b == a)) == (b != a)).all() + + assert not (a == c).all() + assert not (c == a).all() + assert not (a == d).all() + assert not (d == a).all() + + # vs a cat-like + assert (a == e).all() + assert (e == a).all() + assert not (a == f).all() + assert not (f == a).all() + + assert (~(a == e) == (a != e)).all() + assert (~(e == a) == (e != a)).all() + assert (~(a == f) == (a != f)).all() + assert (~(f == a) == (f != a)).all() + + # non-equality is not comparable + msg = "can only compare equality or not" + with pytest.raises(TypeError, match=msg): + a < b + with pytest.raises(TypeError, match=msg): + b < a + with pytest.raises(TypeError, match=msg): + a > b + with pytest.raises(TypeError, match=msg): + b > a + + def test_unequal_categorical_comparison_raises_type_error(self): + # unequal comparison should raise for unordered cats + cat = Series(Categorical(list("abc"))) + msg = "can only compare equality or not" + with pytest.raises(TypeError, match=msg): + cat > "b" + + cat = Series(Categorical(list("abc"), ordered=False)) + with pytest.raises(TypeError, match=msg): + cat > "b" + + # https://github.com/pandas-dev/pandas/issues/9836#issuecomment-92123057 + # and following comparisons with scalars not in categories should raise + # for unequal comps, but not for equal/not equal + cat = Series(Categorical(list("abc"), ordered=True)) + + msg = "Invalid comparison between dtype=category and str" + with pytest.raises(TypeError, match=msg): + cat < "d" + with pytest.raises(TypeError, match=msg): + cat > "d" + with pytest.raises(TypeError, match=msg): + "d" < cat + with pytest.raises(TypeError, match=msg): + "d" > cat + + tm.assert_series_equal(cat == "d", Series([False, False, False])) + tm.assert_series_equal(cat != "d", Series([True, True, True])) + + # ----------------------------------------------------------------- + + def test_comparison_tuples(self): + # GH#11339 + # comparisons vs tuple + s = Series([(1, 1), (1, 2)]) + + result = s == (1, 2) + expected = Series([False, True]) + tm.assert_series_equal(result, expected) + + result = s != (1, 2) + expected = Series([True, False]) + tm.assert_series_equal(result, expected) + + result = s == (0, 0) + expected = Series([False, False]) + tm.assert_series_equal(result, expected) + + result = s != (0, 0) + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + s = Series([(1, 1), (1, 1)]) + + result = s == (1, 1) + expected = Series([True, True]) + tm.assert_series_equal(result, expected) + + result = s != (1, 1) + expected = Series([False, False]) + tm.assert_series_equal(result, expected) + + def test_comparison_frozenset(self): + ser = Series([frozenset([1]), frozenset([1, 2])]) + + result = ser == frozenset([1]) + expected = Series([True, False]) + tm.assert_series_equal(result, expected) + + def test_comparison_operators_with_nas(self, comparison_op): + ser = Series(bdate_range("1/1/2000", periods=10), dtype=object) + ser[::2] = np.nan + + # test that comparisons work + val = ser[5] + + result = comparison_op(ser, val) + expected = comparison_op(ser.dropna(), val).reindex(ser.index) + + if comparison_op is operator.ne: + expected = expected.fillna(True).astype(bool) + else: + expected = expected.fillna(False).astype(bool) + + tm.assert_series_equal(result, expected) + + def test_ne(self): + ts = Series([3, 4, 5, 6, 7], [3, 4, 5, 6, 7], dtype=float) + expected = np.array([True, True, False, True, True]) + tm.assert_numpy_array_equal(ts.index != 5, expected) + tm.assert_numpy_array_equal(~(ts.index == 5), expected) + + @pytest.mark.parametrize("right_data", [[2, 2, 2], [2, 2, 2, 2]]) + def test_comp_ops_df_compat(self, right_data, frame_or_series): + # GH 1134 + # GH 50083 to clarify that index and columns must be identically labeled + left = Series([1, 2, 3], index=list("ABC"), name="x") + right = Series(right_data, index=list("ABDC")[: len(right_data)], name="x") + if frame_or_series is not Series: + msg = ( + rf"Can only compare identically-labeled \(both index and columns\) " + f"{frame_or_series.__name__} objects" + ) + left = left.to_frame() + right = right.to_frame() + else: + msg = ( + f"Can only compare identically-labeled {frame_or_series.__name__} " + f"objects" + ) + + with pytest.raises(ValueError, match=msg): + left == right + with pytest.raises(ValueError, match=msg): + right == left + + with pytest.raises(ValueError, match=msg): + left != right + with pytest.raises(ValueError, match=msg): + right != left + + with pytest.raises(ValueError, match=msg): + left < right + with pytest.raises(ValueError, match=msg): + right < left + + def test_compare_series_interval_keyword(self): + # GH#25338 + ser = Series(["IntervalA", "IntervalB", "IntervalC"]) + result = ser == "IntervalA" + expected = Series([True, False, False]) + tm.assert_series_equal(result, expected) + + +# ------------------------------------------------------------------ +# Unsorted +# These arithmetic tests were previously in other files, eventually +# should be parametrized and put into tests.arithmetic + + +class TestTimeSeriesArithmetic: + def test_series_add_tz_mismatch_converts_to_utc(self): + rng = date_range("1/1/2011", periods=100, freq="h", tz="utc") + + perm = np.random.default_rng(2).permutation(100)[:90] + ser1 = Series( + np.random.default_rng(2).standard_normal(90), + index=rng.take(perm).tz_convert("US/Eastern"), + ) + + perm = np.random.default_rng(2).permutation(100)[:90] + ser2 = Series( + np.random.default_rng(2).standard_normal(90), + index=rng.take(perm).tz_convert("Europe/Berlin"), + ) + + result = ser1 + ser2 + + uts1 = ser1.tz_convert("utc") + uts2 = ser2.tz_convert("utc") + expected = uts1 + uts2 + + # sort since input indexes are not equal + expected = expected.sort_index() + + assert result.index.tz is timezone.utc + tm.assert_series_equal(result, expected) + + def test_series_add_aware_naive_raises(self): + rng = date_range("1/1/2011", periods=10, freq="h") + ser = Series(np.random.default_rng(2).standard_normal(len(rng)), index=rng) + + ser_utc = ser.tz_localize("utc") + + msg = "Cannot join tz-naive with tz-aware DatetimeIndex" + with pytest.raises(Exception, match=msg): + ser + ser_utc + + with pytest.raises(Exception, match=msg): + ser_utc + ser + + # TODO: belongs in tests/arithmetic? + def test_datetime_understood(self, unit): + # Ensures it doesn't fail to create the right series + # reported in issue#16726 + series = Series(date_range("2012-01-01", periods=3, unit=unit)) + offset = pd.offsets.DateOffset(days=6) + result = series - offset + exp_dti = pd.to_datetime(["2011-12-26", "2011-12-27", "2011-12-28"]).as_unit( + unit + ) + expected = Series(exp_dti) + tm.assert_series_equal(result, expected) + + def test_align_date_objects_with_datetimeindex(self): + rng = date_range("1/1/2000", periods=20) + ts = Series(np.random.default_rng(2).standard_normal(20), index=rng) + + ts_slice = ts[5:] + ts2 = ts_slice.copy() + ts2.index = [x.date() for x in ts2.index] + + result = ts + ts2 + result2 = ts2 + ts + expected = ts + ts[5:] + expected.index = expected.index._with_freq(None) + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + +class TestNamePreservation: + @pytest.mark.parametrize("box", [list, tuple, np.array, Index, Series, pd.array]) + @pytest.mark.parametrize("flex", [True, False]) + def test_series_ops_name_retention(self, flex, box, names, all_binary_operators): + # GH#33930 consistent name-retention + op = all_binary_operators + + left = Series(range(10), name=names[0]) + right = Series(range(10), name=names[1]) + + name = op.__name__.strip("_") + is_logical = name in ["and", "rand", "xor", "rxor", "or", "ror"] + + msg = ( + r"Logical ops \(and, or, xor\) between Pandas objects and " + "dtype-less sequences" + ) + + right = box(right) + if flex: + if is_logical: + # Series doesn't have these as flex methods + return + result = getattr(left, name)(right) + else: + if is_logical and box in [list, tuple]: + with pytest.raises(TypeError, match=msg): + # GH#52264 logical ops with dtype-less sequences deprecated + op(left, right) + return + result = op(left, right) + + assert isinstance(result, Series) + if box in [Index, Series]: + assert result.name is names[2] or result.name == names[2] + else: + assert result.name is names[0] or result.name == names[0] + + def test_binop_maybe_preserve_name(self, datetime_series): + # names match, preserve + result = datetime_series * datetime_series + assert result.name == datetime_series.name + result = datetime_series.mul(datetime_series) + assert result.name == datetime_series.name + + result = datetime_series * datetime_series[:-2] + assert result.name == datetime_series.name + + # names don't match, don't preserve + cp = datetime_series.copy() + cp.name = "something else" + result = datetime_series + cp + assert result.name is None + result = datetime_series.add(cp) + assert result.name is None + + ops = ["add", "sub", "mul", "div", "truediv", "floordiv", "mod", "pow"] + ops = ops + ["r" + op for op in ops] + for op in ops: + # names match, preserve + ser = datetime_series.copy() + result = getattr(ser, op)(ser) + assert result.name == datetime_series.name + + # names don't match, don't preserve + cp = datetime_series.copy() + cp.name = "changed" + result = getattr(ser, op)(cp) + assert result.name is None + + def test_scalarop_preserve_name(self, datetime_series): + result = datetime_series * 2 + assert result.name == datetime_series.name + + +class TestInplaceOperations: + @pytest.mark.parametrize( + "dtype1, dtype2, dtype_expected, dtype_mul", + ( + ("Int64", "Int64", "Int64", "Int64"), + ("float", "float", "float", "float"), + ("Int64", "float", "Float64", "Float64"), + ("Int64", "Float64", "Float64", "Float64"), + ), + ) + def test_series_inplace_ops(self, dtype1, dtype2, dtype_expected, dtype_mul): + # GH 37910 + + ser1 = Series([1], dtype=dtype1) + ser2 = Series([2], dtype=dtype2) + ser1 += ser2 + expected = Series([3], dtype=dtype_expected) + tm.assert_series_equal(ser1, expected) + + ser1 -= ser2 + expected = Series([1], dtype=dtype_expected) + tm.assert_series_equal(ser1, expected) + + ser1 *= ser2 + expected = Series([2], dtype=dtype_mul) + tm.assert_series_equal(ser1, expected) + + +def test_none_comparison(request, series_with_simple_index): + series = series_with_simple_index + + if len(series) < 1: + request.applymarker( + pytest.mark.xfail(reason="Test doesn't make sense on empty data") + ) + + # bug brought up by #1079 + # changed from TypeError in 0.17.0 + series.iloc[0] = np.nan + + # noinspection PyComparisonWithNone + result = series == None # noqa: E711 + assert not result.iat[0] + assert not result.iat[1] + + # noinspection PyComparisonWithNone + result = series != None # noqa: E711 + assert result.iat[0] + assert result.iat[1] + + result = None == series # noqa: E711 + assert not result.iat[0] + assert not result.iat[1] + + result = None != series # noqa: E711 + assert result.iat[0] + assert result.iat[1] + + if lib.is_np_dtype(series.dtype, "M") or isinstance(series.dtype, DatetimeTZDtype): + # Following DatetimeIndex (and Timestamp) convention, + # inequality comparisons with Series[datetime64] raise + msg = "Invalid comparison" + with pytest.raises(TypeError, match=msg): + None > series + with pytest.raises(TypeError, match=msg): + series > None + else: + result = None > series + assert not result.iat[0] + assert not result.iat[1] + + result = series < None + assert not result.iat[0] + assert not result.iat[1] + + +def test_series_varied_multiindex_alignment(): + # GH 20414 + s1 = Series( + range(8), + index=pd.MultiIndex.from_product( + [list("ab"), list("xy"), [1, 2]], names=["ab", "xy", "num"] + ), + ) + s2 = Series( + [1000 * i for i in range(1, 5)], + index=pd.MultiIndex.from_product([list("xy"), [1, 2]], names=["xy", "num"]), + ) + result = s1.loc[pd.IndexSlice[["a"], :, :]] + s2 + expected = Series( + [1000, 2001, 3002, 4003], + index=pd.MultiIndex.from_tuples( + [("a", "x", 1), ("a", "x", 2), ("a", "y", 1), ("a", "y", 2)], + names=["ab", "xy", "num"], + ), + ) + tm.assert_series_equal(result, expected) + + +def test_rmod_consistent_large_series(): + # GH 29602 + result = Series([2] * 10001).rmod(-1) + expected = Series([1] * 10001) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "index", + [ + date_range("2016-01-01", periods=3), + date_range("2016-01-01", tz="US/Pacific", periods=3), + pd.timedelta_range("1 Day", periods=3), + ], +) +def test_comparison_mismatched_datetime_units(index): + # GH#63459 + + ser = Series(1, index=index) + ser2 = Series(1, index=index.as_unit("ns")) + + result = ser == ser2 + expected = Series([True, True, True], index=ser.index) + tm.assert_series_equal(result, expected) + + result2 = ser2 < ser + expected2 = Series([False, False, False], index=ser2.index) + tm.assert_series_equal(result2, expected2) diff --git a/pandas/tests/series/test_arrow_interface.py b/pandas/tests/series/test_arrow_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..3b4340064ea1bc4827c858d4deff602538240eb2 --- /dev/null +++ b/pandas/tests/series/test_arrow_interface.py @@ -0,0 +1,117 @@ +import ctypes + +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + +pa = pytest.importorskip("pyarrow", minversion="16.0") + + +def test_series_arrow_interface(): + s = pd.Series([1, 4, 2]) + + capsule = s.__arrow_c_stream__() + assert ( + ctypes.pythonapi.PyCapsule_IsValid( + ctypes.py_object(capsule), b"arrow_array_stream" + ) + == 1 + ) + + ca = pa.chunked_array(s) + expected = pa.chunked_array([[1, 4, 2]]) + assert ca.equals(expected) + ca = pa.chunked_array(s, type=pa.int32()) + expected = pa.chunked_array([[1, 4, 2]], type=pa.int32()) + assert ca.equals(expected) + + +def test_series_arrow_interface_arrow_dtypes(): + s = pd.Series([1, 4, 2], dtype="Int64[pyarrow]") + + capsule = s.__arrow_c_stream__() + assert ( + ctypes.pythonapi.PyCapsule_IsValid( + ctypes.py_object(capsule), b"arrow_array_stream" + ) + == 1 + ) + + ca = pa.chunked_array(s) + expected = pa.chunked_array([[1, 4, 2]]) + assert ca.equals(expected) + ca = pa.chunked_array(s, type=pa.int32()) + expected = pa.chunked_array([[1, 4, 2]], type=pa.int32()) + assert ca.equals(expected) + + +def test_series_arrow_interface_stringdtype(): + s = pd.Series(["foo", "bar"], dtype="string[pyarrow]") + + capsule = s.__arrow_c_stream__() + assert ( + ctypes.pythonapi.PyCapsule_IsValid( + ctypes.py_object(capsule), b"arrow_array_stream" + ) + == 1 + ) + + ca = pa.chunked_array(s) + expected = pa.chunked_array([["foo", "bar"]], type=pa.large_string()) + assert ca.equals(expected) + + +class ArrowArrayWrapper: + def __init__(self, array): + self.array = array + + def __arrow_c_array__(self, requested_schema=None): + return self.array.__arrow_c_array__(requested_schema) + + +class ArrowStreamWrapper: + def __init__(self, chunked_array): + self.stream = chunked_array + + def __arrow_c_stream__(self, requested_schema=None): + return self.stream.__arrow_c_stream__(requested_schema) + + +@td.skip_if_no("pyarrow", min_version="14.0") +def test_dataframe_from_arrow(): + # objects with __arrow_c_stream__ + arr = pa.chunked_array([[1, 2, 3], [4, 5]]) + + result = pd.Series.from_arrow(arr) + expected = pd.Series([1, 2, 3, 4, 5]) + tm.assert_series_equal(result, expected) + + # not only pyarrow object are supported + result = pd.Series.from_arrow(ArrowStreamWrapper(arr)) + tm.assert_series_equal(result, expected) + + # table works as well, but will be seen as a StructArray + table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + result = pd.Series.from_arrow(table) + expected = pd.Series([{"a": 1, "b": "a"}, {"a": 2, "b": "b"}, {"a": 3, "b": "c"}]) + tm.assert_series_equal(result, expected) + + # objects with __arrow_c_array__ + arr = pa.array([1, 2, 3]) + + expected = pd.Series([1, 2, 3]) + result = pd.Series.from_arrow(arr) + tm.assert_series_equal(result, expected) + + result = pd.Series.from_arrow(ArrowArrayWrapper(arr)) + tm.assert_series_equal(result, expected) + + # only accept actual Arrow objects + with pytest.raises( + TypeError, match="Expected an Arrow-compatible array-like object" + ): + pd.Series.from_arrow([1, 2, 3]) diff --git a/pandas/tests/series/test_constructors.py b/pandas/tests/series/test_constructors.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0dcb9ec5facdbd08ba5608c64c4a454b6419db --- /dev/null +++ b/pandas/tests/series/test_constructors.py @@ -0,0 +1,2294 @@ +from collections import OrderedDict +from collections.abc import Iterator +from datetime import ( + datetime, + timedelta, +) + +from dateutil.tz import tzoffset +import numpy as np +from numpy import ma +import pytest + +from pandas._libs import ( + iNaT, + lib, +) +from pandas.compat import HAS_PYARROW +from pandas.compat.numpy import np_version_gt2 +from pandas.errors import ( + IntCastingNaNError, + Pandas4Warning, +) + +from pandas.core.dtypes.dtypes import CategoricalDtype + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + DatetimeIndex, + DatetimeTZDtype, + Index, + Interval, + IntervalIndex, + MultiIndex, + NaT, + Period, + RangeIndex, + Series, + Timestamp, + date_range, + isna, + period_range, + timedelta_range, +) +import pandas._testing as tm +from pandas.core.arrays import ( + IntegerArray, + IntervalArray, + period_array, +) +from pandas.core.internals.blocks import NumpyBlock + + +class TestSeriesConstructors: + def test_from_ints_with_non_nano_dt64_dtype(self, index_or_series): + values = np.arange(10) + + res = index_or_series(values, dtype="M8[s]") + expected = index_or_series(values.astype("M8[s]")) + tm.assert_equal(res, expected) + + res = index_or_series(list(values), dtype="M8[s]") + tm.assert_equal(res, expected) + + def test_from_na_value_and_interval_of_datetime_dtype(self): + # GH#41805 + ser = Series([None], dtype="interval[datetime64[ns]]") + assert ser.isna().all() + assert ser.dtype == "interval[datetime64[ns], right]" + + def test_infer_with_date_and_datetime(self): + # GH#49341 pre-2.0 we inferred datetime-and-date to datetime64, which + # was inconsistent with Index behavior + ts = Timestamp(2016, 1, 1) + vals = [ts.to_pydatetime(), ts.date()] + + ser = Series(vals) + expected = Series(vals, dtype=object) + tm.assert_series_equal(ser, expected) + + idx = Index(vals) + expected = Index(vals, dtype=object) + tm.assert_index_equal(idx, expected) + + def test_unparsable_strings_with_dt64_dtype(self): + # pre-2.0 these would be silently ignored and come back with object dtype + vals = ["aa"] + msg = "^Unknown datetime string format, unable to parse: aa$" + with pytest.raises(ValueError, match=msg): + Series(vals, dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + Series(np.array(vals, dtype=object), dtype="datetime64[ns]") + + def test_invalid_dtype_conversion_datetime_to_timedelta(self): + # GH#60728 + vals = Series([NaT, Timestamp(2025, 1, 1)], dtype="datetime64[ns]") + msg = r"^Cannot cast DatetimeArray to dtype timedelta64\[ns\]$" + with pytest.raises(TypeError, match=msg): + Series(vals, dtype="timedelta64[ns]") + + @pytest.mark.parametrize( + "constructor", + [ + # NOTE: some overlap with test_constructor_empty but that test does not + # test for None or an empty generator. + # test_constructor_pass_none tests None but only with the index also + # passed. + (lambda idx: Series(index=idx)), + (lambda idx: Series(None, index=idx)), + (lambda idx: Series({}, index=idx)), + (lambda idx: Series((), index=idx)), + (lambda idx: Series([], index=idx)), + (lambda idx: Series((_ for _ in []), index=idx)), + (lambda idx: Series(data=None, index=idx)), + (lambda idx: Series(data={}, index=idx)), + (lambda idx: Series(data=(), index=idx)), + (lambda idx: Series(data=[], index=idx)), + (lambda idx: Series(data=(_ for _ in []), index=idx)), + ], + ) + @pytest.mark.parametrize("empty_index", [None, []]) + def test_empty_constructor(self, constructor, empty_index): + # GH 49573 (addition of empty_index parameter) + expected = Series(index=empty_index) + result = constructor(empty_index) + + assert result.dtype == object + assert len(result.index) == 0 + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_invalid_dtype(self): + # GH15520 + msg = "not understood" + invalid_list = [Timestamp, "Timestamp", list] + for dtype in invalid_list: + with pytest.raises(TypeError, match=msg): + Series([], name="time", dtype=dtype) + + def test_invalid_compound_dtype(self): + # GH#13296 + c_dtype = np.dtype([("a", "i8"), ("b", "f4")]) + cdt_arr = np.array([(1, 0.4), (256, -13)], dtype=c_dtype) + + with pytest.raises(ValueError, match="Use DataFrame instead"): + Series(cdt_arr, index=["A", "B"]) + + def test_scalar_conversion(self): + # Pass in scalar is disabled + scalar = Series(0.5) + assert not isinstance(scalar, float) + + def test_scalar_extension_dtype(self, ea_scalar_and_dtype): + # GH 28401 + + ea_scalar, ea_dtype = ea_scalar_and_dtype + + ser = Series(ea_scalar, index=range(3)) + expected = Series([ea_scalar] * 3, dtype=ea_dtype) + + assert ser.dtype == ea_dtype + tm.assert_series_equal(ser, expected) + + def test_constructor(self, datetime_series, using_infer_string): + empty_series = Series() + assert datetime_series.index._is_all_dates + + # Pass in Series + derived = Series(datetime_series) + assert derived.index._is_all_dates + + tm.assert_index_equal(derived.index, datetime_series.index) + # Ensure new index is not created + assert id(datetime_series.index) == id(derived.index) + + # Mixed type Series + mixed = Series(["hello", np.nan], index=[0, 1]) + assert mixed.dtype == np.object_ if not using_infer_string else "str" + assert np.isnan(mixed[1]) + + assert not empty_series.index._is_all_dates + assert not Series().index._is_all_dates + + # exception raised is of type ValueError GH35744 + with pytest.raises( + ValueError, + match=r"Data must be 1-dimensional, got ndarray of shape \(3, 3\) instead", + ): + Series(np.random.default_rng(2).standard_normal((3, 3)), index=np.arange(3)) + + mixed.name = "Series" + rs = Series(mixed).name + xp = "Series" + assert rs == xp + + # raise on MultiIndex GH4187 + m = MultiIndex.from_arrays([[1, 2], [3, 4]]) + msg = "initializing a Series from a MultiIndex is not supported" + with pytest.raises(NotImplementedError, match=msg): + Series(m) + + def test_constructor_index_ndim_gt_1_raises(self): + # GH#18579 + df = DataFrame([[1, 2], [3, 4], [5, 6]], index=[3, 6, 9]) + with pytest.raises(ValueError, match="Index data must be 1-dimensional"): + Series([1, 3, 2], index=df) + + @pytest.mark.parametrize("input_class", [list, dict, OrderedDict]) + def test_constructor_empty(self, input_class, using_infer_string): + empty = Series() + empty2 = Series(input_class()) + + # these are Index() and RangeIndex() which don't compare type equal + # but are just .equals + tm.assert_series_equal(empty, empty2, check_index_type=False) + + # With explicit dtype: + empty = Series(dtype="float64") + empty2 = Series(input_class(), dtype="float64") + tm.assert_series_equal(empty, empty2, check_index_type=False) + + # GH 18515 : with dtype=category: + empty = Series(dtype="category") + empty2 = Series(input_class(), dtype="category") + tm.assert_series_equal(empty, empty2, check_index_type=False) + + if input_class is not list: + # With index: + empty = Series(index=range(10)) + empty2 = Series(input_class(), index=range(10)) + tm.assert_series_equal(empty, empty2) + + # With index and dtype float64: + empty = Series(np.nan, index=range(10)) + empty2 = Series(input_class(), index=range(10), dtype="float64") + tm.assert_series_equal(empty, empty2) + + # GH 19853 : with empty string, index and dtype str + empty = Series("", dtype=str, index=range(3)) + if using_infer_string: + empty2 = Series("", index=range(3), dtype="str") + else: + empty2 = Series("", index=range(3)) + tm.assert_series_equal(empty, empty2) + + @pytest.mark.parametrize("input_arg", [np.nan, float("nan")]) + def test_constructor_nan(self, input_arg): + empty = Series(dtype="float64", index=range(10)) + empty2 = Series(input_arg, index=range(10)) + + tm.assert_series_equal(empty, empty2, check_index_type=False) + + @pytest.mark.parametrize( + "dtype", + ["f8", "i8", "M8[ns]", "m8[ns]", "category", "object", "datetime64[ns, UTC]"], + ) + @pytest.mark.parametrize("index", [None, Index([])]) + def test_constructor_dtype_only(self, dtype, index): + # GH-20865 + result = Series(dtype=dtype, index=index) + assert result.dtype == dtype + assert len(result) == 0 + + def test_constructor_no_data_index_order(self): + result = Series(index=["b", "a", "c"]) + assert result.index.tolist() == ["b", "a", "c"] + + def test_constructor_no_data_string_type(self): + # GH 22477 + result = Series(index=[1], dtype=str) + assert np.isnan(result.iloc[0]) + + @pytest.mark.parametrize("item", ["entry", "ѐ", 13]) + def test_constructor_string_element_string_type(self, item): + # GH 22477 + result = Series(item, index=[1], dtype=str) + assert result.iloc[0] == str(item) + + def test_constructor_dtype_str_na_values(self, string_dtype): + # https://github.com/pandas-dev/pandas/issues/21083 + ser = Series(["x", None], dtype=string_dtype) + result = ser.isna() + expected = Series([False, True]) + tm.assert_series_equal(result, expected) + assert ser.iloc[1] is None + + ser = Series(["x", np.nan], dtype=string_dtype) + assert np.isnan(ser.iloc[1]) + + def test_constructor_series(self): + index1 = ["d", "b", "a", "c"] + index2 = sorted(index1) + s1 = Series([4, 7, -5, 3], index=index1) + s2 = Series(s1, index=index2) + + tm.assert_series_equal(s2, s1.sort_index()) + + def test_constructor_iterable(self): + # GH 21987 + class Iter: + def __iter__(self) -> Iterator: + yield from range(10) + + expected = Series(list(range(10)), dtype="int64") + result = Series(Iter(), dtype="int64") + tm.assert_series_equal(result, expected) + + def test_constructor_sequence(self): + # GH 21987 + expected = Series(list(range(10)), dtype="int64") + result = Series(range(10), dtype="int64") + tm.assert_series_equal(result, expected) + + def test_constructor_single_str(self): + # GH 21987 + expected = Series(["abc"]) + result = Series("abc") + tm.assert_series_equal(result, expected) + + def test_constructor_list_like(self): + # make sure that we are coercing different + # list-likes to standard dtypes and not + # platform specific + expected = Series([1, 2, 3], dtype="int64") + for obj in [[1, 2, 3], (1, 2, 3), np.array([1, 2, 3], dtype="int64")]: + result = Series(obj, index=[0, 1, 2]) + tm.assert_series_equal(result, expected) + + def test_constructor_boolean_index(self): + # GH#18579 + s1 = Series([1, 2, 3], index=[4, 5, 6]) + + index = s1 == 2 + result = Series([1, 3, 2], index=index) + expected = Series([1, 3, 2], index=[False, True, False]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["bool", "int32", "int64", "float64"]) + def test_constructor_index_dtype(self, dtype): + # GH 17088 + + s = Series(Index([0, 2, 4]), dtype=dtype) + assert s.dtype == dtype + + @pytest.mark.parametrize( + "input_vals", + [ + [1, 2], + ["1", "2"], + list(date_range("1/1/2011", periods=2, freq="h")), + list(date_range("1/1/2011", periods=2, freq="h", tz="US/Eastern")), + [Interval(left=0, right=5)], + ], + ) + def test_constructor_list_str(self, input_vals, string_dtype): + # GH 16605 + # Ensure that data elements from a list are converted to strings + # when dtype is str, 'str', or 'U' + result = Series(input_vals, dtype=string_dtype) + expected = Series(input_vals).astype(string_dtype) + tm.assert_series_equal(result, expected) + + def test_constructor_list_str_na(self, string_dtype): + result = Series([1.0, 2.0, np.nan], dtype=string_dtype) + expected = Series(["1.0", "2.0", np.nan], dtype=object) + tm.assert_series_equal(result, expected) + assert np.isnan(result[2]) + + def test_constructor_generator(self): + gen = (i for i in range(10)) + + result = Series(gen) + exp = Series(range(10)) + tm.assert_series_equal(result, exp) + + # same but with non-default index + gen = (i for i in range(10)) + result = Series(gen, index=range(10, 20)) + exp.index = range(10, 20) + tm.assert_series_equal(result, exp) + + def test_constructor_map(self): + # GH8909 + m = (x for x in range(10)) + + result = Series(m) + exp = Series(range(10)) + tm.assert_series_equal(result, exp) + + # same but with non-default index + m = (x for x in range(10)) + result = Series(m, index=range(10, 20)) + exp.index = range(10, 20) + tm.assert_series_equal(result, exp) + + def test_constructor_categorical(self): + msg = "Constructing a Categorical with a dtype and values containing" + with tm.assert_produces_warning(Pandas4Warning, match=msg): + cat = Categorical([0, 1, 2, 0, 1, 2], ["a", "b", "c"]) + res = Series(cat) + tm.assert_categorical_equal(res.values, cat) + + # can cast to a new dtype + result = Series(Categorical([1, 2, 3]), dtype="int64") + expected = Series([1, 2, 3], dtype="int64") + tm.assert_series_equal(result, expected) + + def test_construct_from_categorical_with_dtype(self): + # GH12574 + ser = Series(Categorical([1, 2, 3]), dtype="category") + assert isinstance(ser.dtype, CategoricalDtype) + + def test_construct_intlist_values_category_dtype(self): + ser = Series([1, 2, 3], dtype="category") + assert isinstance(ser.dtype, CategoricalDtype) + + def test_constructor_categorical_with_coercion(self): + factor = Categorical(["a", "b", "b", "a", "a", "c", "c", "c"]) + # test basic creation / coercion of categoricals + s = Series(factor, name="A") + assert s.dtype == "category" + assert len(s) == len(factor) + + # in a frame + df = DataFrame({"A": factor}) + result = df["A"] + tm.assert_series_equal(result, s) + result = df.iloc[:, 0] + tm.assert_series_equal(result, s) + assert len(df) == len(factor) + + df = DataFrame({"A": s}) + result = df["A"] + tm.assert_series_equal(result, s) + assert len(df) == len(factor) + + # multiples + df = DataFrame({"A": s, "B": s, "C": 1}) + result1 = df["A"] + result2 = df["B"] + tm.assert_series_equal(result1, s) + tm.assert_series_equal(result2, s, check_names=False) + assert result2.name == "B" + assert len(df) == len(factor) + + def test_constructor_categorical_with_coercion2(self): + # GH8623 + x = DataFrame( + [[1, "John P. Doe"], [2, "Jane Dove"], [1, "John P. Doe"]], + columns=["person_id", "person_name"], + ) + x["person_name"] = Categorical(x.person_name) # doing this breaks transform + + expected = x.iloc[0].person_name + result = x.person_name.iloc[0] + assert result == expected + + result = x.person_name[0] + assert result == expected + + result = x.person_name.loc[0] + assert result == expected + + def test_constructor_series_to_categorical(self): + # see GH#16524: test conversion of Series to Categorical + series = Series(["a", "b", "c"]) + + result = Series(series, dtype="category") + expected = Series(["a", "b", "c"], dtype="category") + + tm.assert_series_equal(result, expected) + + def test_constructor_categorical_dtype(self): + result = Series( + ["a", "b"], dtype=CategoricalDtype(["a", "b", "c"], ordered=True) + ) + assert isinstance(result.dtype, CategoricalDtype) + tm.assert_index_equal(result.cat.categories, Index(["a", "b", "c"])) + assert result.cat.ordered + + result = Series(["a", "b"], dtype=CategoricalDtype(["b", "a"])) + assert isinstance(result.dtype, CategoricalDtype) + tm.assert_index_equal(result.cat.categories, Index(["b", "a"])) + assert result.cat.ordered is False + + # GH 19565 - Check broadcasting of scalar with Categorical dtype + result = Series( + "a", index=[0, 1], dtype=CategoricalDtype(["a", "b"], ordered=True) + ) + expected = Series( + ["a", "a"], index=[0, 1], dtype=CategoricalDtype(["a", "b"], ordered=True) + ) + tm.assert_series_equal(result, expected) + + def test_constructor_categorical_string(self): + # GH 26336: the string 'category' maintains existing CategoricalDtype + cdt = CategoricalDtype(categories=list("dabc"), ordered=True) + expected = Series(list("abcabc"), dtype=cdt) + + # Series(Categorical, dtype='category') keeps existing dtype + cat = Categorical(list("abcabc"), dtype=cdt) + result = Series(cat, dtype="category") + tm.assert_series_equal(result, expected) + + # Series(Series[Categorical], dtype='category') keeps existing dtype + result = Series(result, dtype="category") + tm.assert_series_equal(result, expected) + + def test_categorical_sideeffects_free(self): + # Passing a categorical to a Series and then changing values in either + # the series or the categorical should not change the values in the + # other one, IF you specify copy! + cat = Categorical(["a", "b", "c", "a"]) + s = Series(cat, copy=True) + assert s.cat is not cat + s = s.cat.rename_categories([1, 2, 3]) + exp_s = np.array([1, 2, 3, 1], dtype=np.int64) + exp_cat = np.array(["a", "b", "c", "a"], dtype=np.object_) + tm.assert_numpy_array_equal(s.__array__(), exp_s) + tm.assert_numpy_array_equal(cat.__array__(), exp_cat) + + # setting + s[0] = 2 + exp_s2 = np.array([2, 2, 3, 1], dtype=np.int64) + tm.assert_numpy_array_equal(s.__array__(), exp_s2) + tm.assert_numpy_array_equal(cat.__array__(), exp_cat) + + # however, copy is False by default + # so this WILL change values + cat = Categorical(["a", "b", "c", "a"]) + s = Series(cat, copy=False) + assert s._values is cat + s = s.cat.rename_categories([1, 2, 3]) + assert s._values is not cat + exp_s = np.array([1, 2, 3, 1], dtype=np.int64) + tm.assert_numpy_array_equal(s.__array__(), exp_s) + + s[0] = 2 + exp_s2 = np.array([2, 2, 3, 1], dtype=np.int64) + tm.assert_numpy_array_equal(s.__array__(), exp_s2) + + def test_unordered_compare_equal(self): + left = Series(["a", "b", None], dtype=CategoricalDtype(["a", "b"])) + right = Series(Categorical(["a", "b", np.nan], categories=["a", "b"])) + tm.assert_series_equal(left, right) + + def test_constructor_maskedarray(self): + data = ma.masked_all((3,), dtype=float) + result = Series(data) + expected = Series([np.nan, np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + data[0] = 0.0 + data[2] = 2.0 + index = ["a", "b", "c"] + result = Series(data, index=index) + expected = Series([0.0, np.nan, 2.0], index=index) + tm.assert_series_equal(result, expected) + + data[1] = 1.0 + result = Series(data, index=index) + expected = Series([0.0, 1.0, 2.0], index=index) + tm.assert_series_equal(result, expected) + + data = ma.masked_all((3,), dtype=int) + result = Series(data) + expected = Series([np.nan, np.nan, np.nan], dtype=float) + tm.assert_series_equal(result, expected) + + data[0] = 0 + data[2] = 2 + index = ["a", "b", "c"] + result = Series(data, index=index) + expected = Series([0, np.nan, 2], index=index, dtype=float) + tm.assert_series_equal(result, expected) + + data[1] = 1 + result = Series(data, index=index) + expected = Series([0, 1, 2], index=index, dtype=int) + with pytest.raises(AssertionError, match="Series classes are different"): + # TODO should this be raising at all? + # https://github.com/pandas-dev/pandas/issues/56131 + tm.assert_series_equal(result, expected) + + data = ma.masked_all((3,), dtype=bool) + result = Series(data) + expected = Series([np.nan, np.nan, np.nan], dtype=object) + tm.assert_series_equal(result, expected) + + data[0] = True + data[2] = False + index = ["a", "b", "c"] + result = Series(data, index=index) + expected = Series([True, np.nan, False], index=index, dtype=object) + tm.assert_series_equal(result, expected) + + data[1] = True + result = Series(data, index=index) + expected = Series([True, True, False], index=index, dtype=bool) + with pytest.raises(AssertionError, match="Series classes are different"): + # TODO should this be raising at all? + # https://github.com/pandas-dev/pandas/issues/56131 + tm.assert_series_equal(result, expected) + + data = ma.masked_all((3,), dtype="M8[ns]") + result = Series(data) + expected = Series([iNaT, iNaT, iNaT], dtype="M8[ns]") + tm.assert_series_equal(result, expected) + + data[0] = datetime(2001, 1, 1) + data[2] = datetime(2001, 1, 3) + index = ["a", "b", "c"] + result = Series(data, index=index) + expected = Series( + [datetime(2001, 1, 1), iNaT, datetime(2001, 1, 3)], + index=index, + dtype="M8[ns]", + ) + tm.assert_series_equal(result, expected) + + data[1] = datetime(2001, 1, 2) + result = Series(data, index=index) + expected = Series( + [datetime(2001, 1, 1), datetime(2001, 1, 2), datetime(2001, 1, 3)], + index=index, + dtype="M8[ns]", + ) + tm.assert_series_equal(result, expected) + + def test_constructor_maskedarray_hardened(self): + # Check numpy masked arrays with hard masks -- from GH24574 + data = ma.masked_all((3,), dtype=float).harden_mask() + result = Series(data) + expected = Series([np.nan, np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + def test_series_ctor_plus_datetimeindex(self): + rng = date_range("20090415", "20090519", freq="B") + data = dict.fromkeys(rng, 1) + + result = Series(data, index=rng) + assert result.index.is_(rng) + + def test_constructor_default_index(self): + s = Series([0, 1, 2]) + tm.assert_index_equal(s.index, Index(range(3)), exact=True) + + @pytest.mark.parametrize( + "input", + [ + [1, 2, 3], + (1, 2, 3), + list(range(3)), + Categorical(["a", "b", "a"]), + (i for i in range(3)), + (x for x in range(3)), + ], + ) + def test_constructor_index_mismatch(self, input): + # GH 19342 + # test that construction of a Series with an index of different length + # raises an error + msg = r"Length of values \(3\) does not match length of index \(4\)" + with pytest.raises(ValueError, match=msg): + Series(input, index=np.arange(4)) + + def test_constructor_numpy_scalar(self): + # GH 19342 + # construction with a numpy scalar + # should not raise + result = Series(np.array(100), index=np.arange(4), dtype="int64") + expected = Series(100, index=np.arange(4), dtype="int64") + tm.assert_series_equal(result, expected) + + def test_constructor_broadcast_list(self): + # GH 19342 + # construction with single-element container and index + # should raise + msg = r"Length of values \(1\) does not match length of index \(3\)" + with pytest.raises(ValueError, match=msg): + Series(["foo"], index=["a", "b", "c"]) + + def test_constructor_corner(self): + df = DataFrame(range(5), index=date_range("2020-01-01", periods=5)) + objs = [df, df] + s = Series(objs, index=[0, 1]) + assert isinstance(s, Series) + + def test_constructor_sanitize(self): + s = Series(np.array([1.0, 1.0, 8.0]), dtype="i8") + assert s.dtype == np.dtype("i8") + + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + Series(np.array([1.0, 1.0, np.nan]), copy=True, dtype="i8") + + def test_constructor_copy(self): + # GH15125 + # test dtype parameter has no side effects on copy=True + for data in [[1.0], np.array([1.0])]: + x = Series(data) + y = Series(x, copy=True, dtype=float) + + # copy=True maintains original data in Series + tm.assert_series_equal(x, y) + + # changes to origin of copy does not affect the copy + x[0] = 2.0 + assert not x.equals(y) + assert x[0] == 2.0 + assert y[0] == 1.0 + + @pytest.mark.parametrize( + "index", + [ + date_range("20170101", periods=3, tz="US/Eastern"), + date_range("20170101", periods=3), + timedelta_range("1 day", periods=3), + period_range("2012Q1", periods=3, freq="Q"), + Index(list("abc")), + Index([1, 2, 3]), + RangeIndex(0, 3), + ], + ids=lambda x: type(x).__name__, + ) + def test_constructor_limit_copies(self, index): + # GH 17449 + # limit copies of input + s = Series(index) + + # we make 1 copy; this is just a smoke test here + assert s._mgr.blocks[0].values is not index + + def test_constructor_shallow_copy(self): + # constructing a Series from Series with copy=False should still + # give a "shallow" copy (share data, not attributes) + # https://github.com/pandas-dev/pandas/issues/49523 + s = Series([1, 2, 3]) + s_orig = s.copy() + s2 = Series(s) + assert s2._mgr is not s._mgr + # Overwriting index of s2 doesn't change s + s2.index = ["a", "b", "c"] + tm.assert_series_equal(s, s_orig) + + def test_constructor_pass_none(self): + s = Series(None, index=range(5)) + assert s.dtype == np.float64 + + s = Series(None, index=range(5), dtype=object) + assert s.dtype == np.object_ + + # GH 7431 + # inference on the index + s = Series(index=np.array([None])) + expected = Series(index=Index([None])) + tm.assert_series_equal(s, expected) + + def test_constructor_pass_nan_nat(self): + # GH 13467 + exp = Series([np.nan, np.nan], dtype=np.float64) + assert exp.dtype == np.float64 + tm.assert_series_equal(Series([np.nan, np.nan]), exp) + tm.assert_series_equal(Series(np.array([np.nan, np.nan])), exp) + + exp = Series([NaT, NaT]) + assert exp.dtype == "datetime64[s]" + tm.assert_series_equal(Series([NaT, NaT]), exp) + tm.assert_series_equal(Series(np.array([NaT, NaT])), exp) + + tm.assert_series_equal(Series([NaT, np.nan]), exp) + tm.assert_series_equal(Series(np.array([NaT, np.nan])), exp) + + tm.assert_series_equal(Series([np.nan, NaT]), exp) + tm.assert_series_equal(Series(np.array([np.nan, NaT])), exp) + + def test_constructor_cast(self): + msg = "could not convert string to float" + with pytest.raises(ValueError, match=msg): + Series(["a", "b", "c"], dtype=float) + + def test_constructor_signed_int_overflow_raises(self): + # GH#41734 disallow silent overflow, enforced in 2.0 + if np_version_gt2: + msg = "The elements provided in the data cannot all be casted to the dtype" + err = OverflowError + else: + msg = "Values are too large to be losslessly converted" + err = ValueError + with pytest.raises(err, match=msg): + Series([1, 200, 923442], dtype="int8") + + with pytest.raises(err, match=msg): + Series([1, 200, 923442], dtype="uint8") + + @pytest.mark.parametrize( + "values", + [ + np.array([1], dtype=np.uint16), + np.array([1], dtype=np.uint32), + np.array([1], dtype=np.uint64), + [np.uint16(1)], + [np.uint32(1)], + [np.uint64(1)], + ], + ) + def test_constructor_numpy_uints(self, values): + # GH#47294 + value = values[0] + result = Series(values) + + assert result[0].dtype == value.dtype + assert result[0] == value + + def test_constructor_unsigned_dtype_overflow(self, any_unsigned_int_numpy_dtype): + # see gh-15832 + if np_version_gt2: + msg = ( + f"The elements provided in the data cannot " + f"all be casted to the dtype {any_unsigned_int_numpy_dtype}" + ) + else: + msg = "Trying to coerce negative values to unsigned integers" + with pytest.raises(OverflowError, match=msg): + Series([-1], dtype=any_unsigned_int_numpy_dtype) + + def test_constructor_floating_data_int_dtype(self, frame_or_series): + # GH#40110 + arr = np.random.default_rng(2).standard_normal(2) + + # Long-standing behavior (for Series, new in 2.0 for DataFrame) + # has been to ignore the dtype on these; + # not clear if this is what we want long-term + # expected = frame_or_series(arr) + + # GH#49599 as of 2.0 we raise instead of silently retaining float dtype + msg = "Trying to coerce float values to integer" + with pytest.raises(ValueError, match=msg): + frame_or_series(arr, dtype="i8") + + with pytest.raises(ValueError, match=msg): + frame_or_series(list(arr), dtype="i8") + + # pre-2.0, when we had NaNs, we silently ignored the integer dtype + arr[0] = np.nan + # expected = frame_or_series(arr) + + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + frame_or_series(arr, dtype="i8") + + exc = IntCastingNaNError + if frame_or_series is Series: + # TODO: try to align these + exc = ValueError + msg = "cannot convert float NaN to integer" + with pytest.raises(exc, match=msg): + # same behavior if we pass list instead of the ndarray + frame_or_series(list(arr), dtype="i8") + + # float array that can be losslessly cast to integers + arr = np.array([1.0, 2.0], dtype="float64") + expected = frame_or_series(arr.astype("i8")) + + obj = frame_or_series(arr, dtype="i8") + tm.assert_equal(obj, expected) + + obj = frame_or_series(list(arr), dtype="i8") + tm.assert_equal(obj, expected) + + def test_constructor_coerce_float_fail(self, any_int_numpy_dtype): + # see gh-15832 + # Updated: make sure we treat this list the same as we would treat + # the equivalent ndarray + # GH#49599 pre-2.0 we silently retained float dtype, in 2.0 we raise + vals = [1, 2, 3.5] + + msg = "Trying to coerce float values to integer" + with pytest.raises(ValueError, match=msg): + Series(vals, dtype=any_int_numpy_dtype) + with pytest.raises(ValueError, match=msg): + Series(np.array(vals), dtype=any_int_numpy_dtype) + + def test_constructor_coerce_float_valid(self, float_numpy_dtype): + s = Series([1, 2, 3.5], dtype=float_numpy_dtype) + expected = Series([1, 2, 3.5]).astype(float_numpy_dtype) + tm.assert_series_equal(s, expected) + + def test_constructor_invalid_coerce_ints_with_float_nan(self, any_int_numpy_dtype): + # GH 22585 + # Updated: make sure we treat this list the same as we would treat the + # equivalent ndarray + vals = [1, 2, np.nan] + # pre-2.0 this would return with a float dtype, in 2.0 we raise + + msg = "cannot convert float NaN to integer" + with pytest.raises(ValueError, match=msg): + Series(vals, dtype=any_int_numpy_dtype) + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + with pytest.raises(IntCastingNaNError, match=msg): + Series(np.array(vals), dtype=any_int_numpy_dtype) + + def test_constructor_dtype_no_cast(self): + # see gh-1572 + s = Series([1, 2, 3]) + s2 = Series(s, dtype=np.int64) + + s2[1] = 5 + assert s[1] == 2 + + def test_constructor_datelike_coercion(self): + # GH 9477 + # incorrectly inferring on dateimelike looking when object dtype is + # specified + s = Series([Timestamp("20130101"), "NOV"], dtype=object) + assert s.iloc[0] == Timestamp("20130101") + assert s.iloc[1] == "NOV" + assert s.dtype == object + + def test_constructor_datelike_coercion2(self): + # the dtype was being reset on the slicing and re-inferred to datetime + # even thought the blocks are mixed + belly = "216 3T19".split() + wing1 = "2T15 4H19".split() + wing2 = "416 4T20".split() + mat = pd.to_datetime("2016-01-22 2019-09-07".split()) + df = DataFrame({"wing1": wing1, "wing2": wing2, "mat": mat}, index=belly) + + result = df.loc["3T19"] + assert result.dtype == object + result = df.loc["216"] + assert result.dtype == object + + def test_constructor_mixed_int_and_timestamp(self, frame_or_series): + # specifically Timestamp with nanos, not datetimes + objs = [Timestamp(9), 10, NaT._value] + result = frame_or_series(objs, dtype="M8[ns]") + + expected = frame_or_series([Timestamp(9), Timestamp(10), NaT]) + tm.assert_equal(result, expected) + + def test_constructor_datetimes_with_nulls(self): + # gh-15869 + for arr in [ + np.array([None, None, None, None, datetime.now(), None]), + np.array([None, None, datetime.now(), None]), + ]: + result = Series(arr) + assert result.dtype == "M8[us]" + + def test_constructor_dtype_datetime64(self): + s = Series(iNaT, dtype="M8[ns]", index=range(5)) + assert isna(s).all() + + # in theory this should be all nulls, but since + # we are not specifying a dtype is ambiguous + s = Series(iNaT, index=range(5)) + assert not isna(s).all() + + s = Series(np.nan, dtype="M8[ns]", index=range(5)) + assert isna(s).all() + + s = Series([datetime(2001, 1, 2, 0, 0), iNaT], dtype="M8[ns]") + assert isna(s[1]) + assert s.dtype == "M8[ns]" + + s = Series([datetime(2001, 1, 2, 0, 0), np.nan], dtype="M8[ns]") + assert isna(s[1]) + assert s.dtype == "M8[ns]" + + def test_constructor_dtype_datetime64_10(self): + # GH3416 + pydates = [datetime(2013, 1, 1), datetime(2013, 1, 2), datetime(2013, 1, 3)] + dates = [np.datetime64(x) for x in pydates] + + ser = Series(dates) + assert ser.dtype == "M8[us]" + + ser.iloc[0] = np.nan + assert ser.dtype == "M8[us]" + + # GH3414 related + expected = Series(pydates, dtype="datetime64[ms]") + + result = Series(Series(dates).astype(np.int64) / 1000, dtype="M8[ms]") + tm.assert_series_equal(result, expected) + + result = Series(dates, dtype="datetime64[ms]") + tm.assert_series_equal(result, expected) + + expected = Series( + [NaT, datetime(2013, 1, 2), datetime(2013, 1, 3)], dtype="datetime64[ns]" + ) + result = Series([np.nan, *dates[1:]], dtype="datetime64[ns]") + tm.assert_series_equal(result, expected) + + def test_constructor_dtype_datetime64_11(self): + pydates = [datetime(2013, 1, 1), datetime(2013, 1, 2), datetime(2013, 1, 3)] + dates = [np.datetime64(x) for x in pydates] + + dts = Series(dates, dtype="datetime64[ns]") + + # valid astype + dts.astype("int64") + + # invalid casting + msg = r"Converting from datetime64\[ns\] to int32 is not supported" + with pytest.raises(TypeError, match=msg): + dts.astype("int32") + + # ints are ok + # we test with np.int64 to get similar results on + # windows / 32-bit platforms + result = Series(dts, dtype=np.int64) + expected = Series(dts.astype(np.int64)) + tm.assert_series_equal(result, expected) + + def test_constructor_dtype_datetime64_9(self): + # invalid dates can be help as object + result = Series([datetime(2, 1, 1)]) + assert result[0] == datetime(2, 1, 1, 0, 0) + + result = Series([datetime(3000, 1, 1)]) + assert result[0] == datetime(3000, 1, 1, 0, 0) + + def test_constructor_dtype_datetime64_8(self): + # don't mix types + result = Series([Timestamp("20130101"), 1], index=["a", "b"]) + assert result["a"] == Timestamp("20130101") + assert result["b"] == 1 + + def test_constructor_dtype_datetime64_7(self): + # GH6529 + # coerce datetime64 non-ns properly + dates = date_range("01-Jan-2015", "01-Dec-2015", freq="ME") + values2 = dates.view(np.ndarray).astype("datetime64[ns]") + expected = Series(values2, index=dates) + + for unit in ["s", "D", "ms", "us", "ns"]: + dtype = np.dtype(f"M8[{unit}]") + values1 = dates.view(np.ndarray).astype(dtype) + result = Series(values1, dates) + if unit == "D": + # for unit="D" we cast to nearest-supported reso, i.e. "s" + dtype = np.dtype("M8[s]") + assert result.dtype == dtype + tm.assert_series_equal(result, expected.astype(dtype)) + + # GH 13876 + # coerce to non-ns to object properly + expected = Series(values2, index=dates, dtype=object) + for dtype in ["s", "D", "ms", "us", "ns"]: + values1 = dates.view(np.ndarray).astype(f"M8[{dtype}]") + result = Series(values1, index=dates, dtype=object) + tm.assert_series_equal(result, expected) + + # leave datetime.date alone + dates2 = np.array([d.date() for d in dates.to_pydatetime()], dtype=object) + series1 = Series(dates2, dates) + tm.assert_numpy_array_equal(series1.values, dates2) + assert series1.dtype == object + + def test_constructor_dtype_datetime64_6(self): + # as of 2.0, these no longer infer datetime64 based on the strings, + # matching the Index behavior + + ser = Series([None, NaT, "2013-08-05 15:30:00.000001"]) + assert ser.dtype == object + + ser = Series([np.nan, NaT, "2013-08-05 15:30:00.000001"]) + assert ser.dtype == object + + ser = Series([NaT, None, "2013-08-05 15:30:00.000001"]) + assert ser.dtype == object + + ser = Series([NaT, np.nan, "2013-08-05 15:30:00.000001"]) + assert ser.dtype == object + + def test_constructor_dtype_datetime64_5(self): + # tz-aware (UTC and other tz's) + # GH 8411 + dr = date_range("20130101", periods=3) + assert Series(dr).iloc[0].tz is None + dr = date_range("20130101", periods=3, tz="UTC") + assert str(Series(dr).iloc[0].tz) == "UTC" + dr = date_range("20130101", periods=3, tz="US/Eastern") + assert str(Series(dr).iloc[0].tz) == "US/Eastern" + + def test_constructor_dtype_datetime64_4(self): + # non-convertible + ser = Series([1479596223000, -1479590, NaT]) + assert ser.dtype == "object" + assert ser[2] is NaT + assert "NaT" in str(ser) + + def test_constructor_dtype_datetime64_3(self): + # if we passed a NaT it remains + ser = Series([datetime(2010, 1, 1), datetime(2, 1, 1), NaT]) + assert ser.dtype == "M8[us]" + assert ser[2] is NaT + assert "NaT" in str(ser) + + def test_constructor_dtype_datetime64_2(self): + # if we passed a nan it remains + ser = Series([datetime(2010, 1, 1), datetime(2, 1, 1), np.nan]) + assert ser.dtype == "M8[us]" + assert ser[2] is NaT + assert "NaT" in str(ser) + + def test_constructor_with_datetime_tz(self): + # 8260 + # support datetime64 with tz + + dr = date_range("20130101", periods=3, tz="US/Eastern", unit="ns") + s = Series(dr) + assert s.dtype.name == "datetime64[ns, US/Eastern]" + assert s.dtype == "datetime64[ns, US/Eastern]" + assert isinstance(s.dtype, DatetimeTZDtype) + assert "datetime64[ns, US/Eastern]" in str(s) + + # export + result = s.values + assert isinstance(result, np.ndarray) + assert result.dtype == "datetime64[ns]" + + exp = DatetimeIndex(result) + exp = exp.tz_localize("UTC").tz_convert(tz=s.dt.tz) + tm.assert_index_equal(dr, exp) + + # indexing + result = s.iloc[0] + assert result == Timestamp("2013-01-01 00:00:00-0500", tz="US/Eastern") + result = s[0] + assert result == Timestamp("2013-01-01 00:00:00-0500", tz="US/Eastern") + + result = s[Series([True, True, False], index=s.index)] + tm.assert_series_equal(result, s[0:2]) + + result = s.iloc[0:1] + tm.assert_series_equal(result, Series(dr[0:1])) + + # concat + result = pd.concat([s.iloc[0:1], s.iloc[1:]]) + tm.assert_series_equal(result, s) + + # short str + assert "datetime64[ns, US/Eastern]" in str(s) + + # formatting with NaT + result = s.shift() + assert "datetime64[ns, US/Eastern]" in str(result) + assert "NaT" in str(result) + + result = DatetimeIndex(s, freq="infer") + tm.assert_index_equal(result, dr) + + def test_constructor_with_datetime_tz5(self): + # long str + ser = Series(date_range("20130101", periods=1000, tz="US/Eastern", unit="ns")) + assert "datetime64[ns, US/Eastern]" in str(ser) + + def test_constructor_with_datetime_tz4(self): + # inference + ser = Series( + [ + Timestamp("2013-01-01 13:00:00-0800", tz="US/Pacific").as_unit("s"), + Timestamp("2013-01-02 14:00:00-0800", tz="US/Pacific").as_unit("s"), + ] + ) + assert ser.dtype == "datetime64[s, US/Pacific]" + assert lib.infer_dtype(ser, skipna=True) == "datetime64" + + def test_constructor_with_datetime_tz3(self): + ser = Series( + [ + Timestamp("2013-01-01 13:00:00-0800", tz="US/Pacific"), + Timestamp("2013-01-02 14:00:00-0800", tz="US/Eastern"), + ] + ) + assert ser.dtype == "object" + assert lib.infer_dtype(ser, skipna=True) == "datetime" + + def test_constructor_with_datetime_tz2(self): + # with all NaT + ser = Series(NaT, index=[0, 1], dtype="datetime64[ns, US/Eastern]") + dti = DatetimeIndex(["NaT", "NaT"], tz="US/Eastern").as_unit("ns") + expected = Series(dti) + tm.assert_series_equal(ser, expected) + + def test_constructor_no_partial_datetime_casting(self): + # GH#40111 + vals = [ + "nan", + Timestamp("1990-01-01"), + "2015-03-14T16:15:14.123-08:00", + "2019-03-04T21:56:32.620-07:00", + None, + ] + ser = Series(vals) + assert all(ser[i] is vals[i] for i in range(len(vals))) + + @pytest.mark.parametrize("arr_dtype", [np.int64, np.float64]) + @pytest.mark.parametrize("kind", ["M", "m"]) + @pytest.mark.parametrize("unit", ["ns", "us", "ms", "s", "h", "m", "D"]) + def test_construction_to_datetimelike_unit(self, arr_dtype, kind, unit): + # tests all units + # gh-19223 + # TODO: GH#19223 was about .astype, doesn't belong here + dtype = f"{kind}8[{unit}]" + arr = np.array([1, 2, 3], dtype=arr_dtype) + ser = Series(arr) + result = ser.astype(dtype) + + expected = Series(arr.astype(dtype)) + + if unit in ["ns", "us", "ms", "s"]: + assert result.dtype == dtype + assert expected.dtype == dtype + else: + # Otherwise we cast to nearest-supported unit, i.e. seconds + assert result.dtype == f"{kind}8[s]" + assert expected.dtype == f"{kind}8[s]" + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("arg", ["2013-01-01 00:00:00", NaT, np.nan, None]) + def test_constructor_with_naive_string_and_datetimetz_dtype(self, arg): + # GH 17415: With naive string + result = Series([arg], dtype="datetime64[ns, CET]") + expected = Series([Timestamp(arg)], dtype="M8[ns]").dt.tz_localize("CET") + tm.assert_series_equal(result, expected) + + def test_constructor_datetime64_bigendian(self): + # GH#30976 + ms = np.datetime64(1, "ms") + arr = np.array([np.datetime64(1, "ms")], dtype=">M8[ms]") + + result = Series(arr) + expected = Series([Timestamp(ms)]).astype("M8[ms]") + assert expected.dtype == "M8[ms]" + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("interval_constructor", [IntervalIndex, IntervalArray]) + def test_construction_interval(self, interval_constructor): + # construction from interval & array of intervals + intervals = interval_constructor.from_breaks(np.arange(3), closed="right") + result = Series(intervals) + assert result.dtype == "interval[int64, right]" + tm.assert_index_equal(Index(result.values), Index(intervals)) + + @pytest.mark.parametrize( + "data_constructor", [list, np.array], ids=["list", "ndarray[object]"] + ) + def test_constructor_infer_interval(self, data_constructor): + # GH 23563: consistent closed results in interval dtype + data = [Interval(0, 1), Interval(0, 2), None] + result = Series(data_constructor(data)) + expected = Series(IntervalArray(data)) + assert result.dtype == "interval[float64, right]" + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "data_constructor", [list, np.array], ids=["list", "ndarray[object]"] + ) + def test_constructor_interval_mixed_closed(self, data_constructor): + # GH 23563: mixed closed results in object dtype (not interval dtype) + data = [Interval(0, 1, closed="both"), Interval(0, 2, closed="neither")] + result = Series(data_constructor(data)) + assert result.dtype == object + assert result.tolist() == data + + def test_construction_consistency(self): + # make sure that we are not re-localizing upon construction + # GH 14928 + ser = Series(date_range("20130101", periods=3, tz="US/Eastern")) + + result = Series(ser, dtype=ser.dtype) + tm.assert_series_equal(result, ser) + + result = Series(ser.dt.tz_convert("UTC"), dtype=ser.dtype) + tm.assert_series_equal(result, ser) + + # Pre-2.0 dt64 values were treated as utc, which was inconsistent + # with DatetimeIndex, which treats them as wall times, see GH#33401 + result = Series(ser.values, dtype=ser.dtype) + expected = Series(ser.values).dt.tz_localize(ser.dtype.tz) + tm.assert_series_equal(result, expected) + + with tm.assert_produces_warning(None): + # one suggested alternative to the deprecated (changed in 2.0) usage + middle = Series(ser.values).dt.tz_localize("UTC") + result = middle.dt.tz_convert(ser.dtype.tz) + tm.assert_series_equal(result, ser) + + with tm.assert_produces_warning(None): + # the other suggested alternative to the deprecated usage + result = Series(ser.values.view("int64"), dtype=ser.dtype) + tm.assert_series_equal(result, ser) + + @pytest.mark.parametrize( + "data_constructor", [list, np.array], ids=["list", "ndarray[object]"] + ) + def test_constructor_infer_period(self, data_constructor): + data = [Period("2000", "D"), Period("2001", "D"), None] + result = Series(data_constructor(data)) + expected = Series(period_array(data)) + tm.assert_series_equal(result, expected) + assert result.dtype == "Period[D]" + + @pytest.mark.xfail(reason="PeriodDtype Series not supported yet") + def test_construct_from_ints_including_iNaT_scalar_period_dtype(self): + series = Series([0, 1000, 2000, pd._libs.iNaT], dtype="period[D]") + + val = series[3] + assert isna(val) + + series[2] = val + assert isna(series[2]) + + def test_constructor_period_incompatible_frequency(self): + data = [Period("2000", "D"), Period("2001", "Y")] + result = Series(data) + assert result.dtype == object + assert result.tolist() == data + + def test_constructor_periodindex(self): + # GH7932 + # converting a PeriodIndex when put in a Series + + pi = period_range("20130101", periods=5, freq="D") + s = Series(pi) + assert s.dtype == "Period[D]" + expected = Series(pi.astype(object)) + assert expected.dtype == object + + def test_constructor_dict(self): + d = {"a": 0.0, "b": 1.0, "c": 2.0} + + result = Series(d) + expected = Series(d, index=sorted(d.keys())) + tm.assert_series_equal(result, expected) + + result = Series(d, index=["b", "c", "d", "a"]) + expected = Series([1, 2, np.nan, 0], index=["b", "c", "d", "a"]) + tm.assert_series_equal(result, expected) + + pidx = period_range("2020-01-01", periods=10, freq="D") + d = {pidx[0]: 0, pidx[1]: 1} + result = Series(d, index=pidx) + expected = Series(np.nan, pidx, dtype=np.float64) + expected.iloc[0] = 0 + expected.iloc[1] = 1 + tm.assert_series_equal(result, expected) + + def test_constructor_dict_list_value_explicit_dtype(self): + # GH 18625 + d = {"a": [[2], [3], [4]]} + result = Series(d, index=["a"], dtype="object") + expected = Series(d, index=["a"]) + tm.assert_series_equal(result, expected) + + def test_constructor_dict_order(self): + # GH19018 + # initialization ordering: by insertion order + d = {"b": 1, "a": 0, "c": 2} + result = Series(d) + expected = Series([1, 0, 2], index=list("bac")) + tm.assert_series_equal(result, expected) + + def test_constructor_dict_extension(self, ea_scalar_and_dtype): + ea_scalar, ea_dtype = ea_scalar_and_dtype + d = {"a": ea_scalar} + result = Series(d, index=["a"]) + expected = Series(ea_scalar, index=["a"], dtype=ea_dtype) + + assert result.dtype == ea_dtype + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("value", [2, np.nan, None, float("nan")]) + def test_constructor_dict_nan_key(self, value): + # GH 18480 + d = {1: "a", value: "b", float("nan"): "c", 4: "d"} + result = Series(d).sort_values() + expected = Series(["a", "b", "c", "d"], index=[1, value, np.nan, 4]) + tm.assert_series_equal(result, expected) + + # MultiIndex: + d = {(1, 1): "a", (2, np.nan): "b", (3, value): "c"} + result = Series(d).sort_values() + expected = Series( + ["a", "b", "c"], index=Index([(1, 1), (2, np.nan), (3, value)]) + ) + tm.assert_series_equal(result, expected) + + def test_constructor_dict_datetime64_index(self): + # GH 9456 + + dates_as_str = ["1984-02-19", "1988-11-06", "1989-12-03", "1990-03-15"] + values = [42544017.198965244, 1234565, 40512335.181958228, -1] + + def create_data(constructor): + return dict(zip((constructor(x) for x in dates_as_str), values)) + + data_datetime64 = create_data(np.datetime64) + data_datetime = create_data(lambda x: datetime.strptime(x, "%Y-%m-%d")) + data_Timestamp = create_data(Timestamp) + + expected = Series(values, (Timestamp(x) for x in dates_as_str)) + + result_datetime64 = Series(data_datetime64) + result_datetime = Series(data_datetime) + result_Timestamp = Series(data_Timestamp) + + tm.assert_series_equal( + result_datetime64, expected.set_axis(expected.index.as_unit("s")) + ) + tm.assert_series_equal(result_datetime, expected) + tm.assert_series_equal(result_Timestamp, expected) + + def test_constructor_dict_tuple_indexer(self): + # GH 12948 + data = {(1, 1, None): -1.0} + result = Series(data) + expected = Series( + -1.0, index=MultiIndex(levels=[[1], [1], [np.nan]], codes=[[0], [0], [-1]]) + ) + tm.assert_series_equal(result, expected) + + def test_constructor_mapping(self, non_dict_mapping_subclass): + # GH 29788 + ndm = non_dict_mapping_subclass({3: "three"}) + result = Series(ndm) + expected = Series(["three"], index=[3]) + + tm.assert_series_equal(result, expected) + + def test_constructor_list_of_tuples(self): + data = [(1, 1), (2, 2), (2, 3)] + s = Series(data) + assert list(s) == data + + def test_constructor_tuple_of_tuples(self): + data = ((1, 1), (2, 2), (2, 3)) + s = Series(data) + assert tuple(s) == data + + @pytest.mark.parametrize( + "data, expected_values, expected_index", + [ + ({(1, 2): 3, (None, 5): 6}, [3, 6], [(1, 2), (None, 5)]), + ({(1,): 3, (4, 5): 6}, [3, 6], [(1, None), (4, 5)]), + ], + ) + def test_constructor_dict_of_tuples(self, data, expected_values, expected_index): + # GH 60695 + result = Series(data).sort_values() + expected = Series(expected_values, index=MultiIndex.from_tuples(expected_index)) + tm.assert_series_equal(result, expected) + + # https://github.com/pandas-dev/pandas/issues/22698 + @pytest.mark.filterwarnings("ignore:elementwise comparison:FutureWarning") + def test_fromDict(self, using_infer_string): + data = {"a": 0, "b": 1, "c": 2, "d": 3} + + series = Series(data) + tm.assert_is_sorted(series.index) + + data = {"a": 0, "b": "1", "c": "2", "d": datetime.now()} + series = Series(data) + assert series.dtype == np.object_ + + data = {"a": 0, "b": "1", "c": "2", "d": "3"} + series = Series(data) + assert series.dtype == np.object_ if not using_infer_string else "str" + + data = {"a": "0", "b": "1"} + series = Series(data, dtype=float) + assert series.dtype == np.float64 + + def test_fromValue(self, datetime_series, using_infer_string): + nans = Series(np.nan, index=datetime_series.index, dtype=np.float64) + assert nans.dtype == np.float64 + assert len(nans) == len(datetime_series) + + strings = Series("foo", index=datetime_series.index) + assert strings.dtype == np.object_ if not using_infer_string else "str" + assert len(strings) == len(datetime_series) + + d = datetime.now() + dates = Series(d, index=datetime_series.index) + assert dates.dtype == "M8[us]" + assert len(dates) == len(datetime_series) + + # GH12336 + # Test construction of categorical series from value + categorical = Series(0, index=datetime_series.index, dtype="category") + expected = Series(0, index=datetime_series.index).astype("category") + assert categorical.dtype == "category" + assert len(categorical) == len(datetime_series) + tm.assert_series_equal(categorical, expected) + + def test_constructor_dtype_timedelta64(self): + # basic + td = Series([timedelta(days=i) for i in range(3)]) + assert td.dtype == "timedelta64[us]" + + td = Series([timedelta(days=1)]) + assert td.dtype == "timedelta64[us]" + + td = Series([timedelta(days=1), timedelta(days=2), np.timedelta64(1, "s")]) + + assert td.dtype == "timedelta64[us]" + + # mixed with NaT + td = Series([timedelta(days=1), NaT], dtype="m8[ns]") + assert td.dtype == "timedelta64[ns]" + + td = Series([timedelta(days=1), np.nan], dtype="m8[ns]") + assert td.dtype == "timedelta64[ns]" + + td = Series([np.timedelta64(300000000), NaT], dtype="m8[ns]") + assert td.dtype == "timedelta64[ns]" + + # improved inference + # GH5689 + td = Series([np.timedelta64(300000000), NaT]) + assert td.dtype == "timedelta64[ns]" + + # because iNaT is int, not coerced to timedelta + td = Series([np.timedelta64(300000000), iNaT]) + assert td.dtype == "object" + + td = Series([np.timedelta64(300000000), np.nan]) + assert td.dtype == "timedelta64[ns]" + + td = Series([NaT, np.timedelta64(300000000)]) + assert td.dtype == "timedelta64[ns]" + + td = Series([np.timedelta64(1, "s")]) + assert td.dtype == "timedelta64[s]" + + # valid astype + td.astype("int64") + + # invalid casting + msg = r"Converting from timedelta64\[s\] to int32 is not supported" + with pytest.raises(TypeError, match=msg): + td.astype("int32") + + # this is an invalid casting + msg = "|".join( + [ + "Could not convert object to NumPy timedelta", + "Could not convert 'foo' to NumPy timedelta", + ] + ) + with pytest.raises(ValueError, match=msg): + Series([timedelta(days=1), "foo"], dtype="m8[ns]") + + # leave as object here + td = Series([timedelta(days=i) for i in range(3)] + ["foo"]) + assert td.dtype == "object" + + # as of 2.0, these no longer infer timedelta64 based on the strings, + # matching Index behavior + ser = Series([None, NaT, "1 Day"]) + assert ser.dtype == object + + ser = Series([np.nan, NaT, "1 Day"]) + assert ser.dtype == object + + ser = Series([NaT, None, "1 Day"]) + assert ser.dtype == object + + ser = Series([NaT, np.nan, "1 Day"]) + assert ser.dtype == object + + # GH 16406 + def test_constructor_mixed_tz(self): + s = Series([Timestamp("20130101"), Timestamp("20130101", tz="US/Eastern")]) + expected = Series( + [Timestamp("20130101"), Timestamp("20130101", tz="US/Eastern")], + dtype="object", + ) + tm.assert_series_equal(s, expected) + + def test_NaT_scalar(self): + series = Series([0, 1000, 2000, iNaT], dtype="M8[ns]") + + val = series[3] + assert isna(val) + + series[2] = val + assert isna(series[2]) + + def test_NaT_cast(self): + # GH10747 + result = Series([np.nan]).astype("M8[ns]") + expected = Series([NaT], dtype="M8[ns]") + tm.assert_series_equal(result, expected) + + def test_constructor_name_hashable(self): + for n in [777, 777.0, "name", datetime(2001, 11, 11), (1,), "\u05d0"]: + for data in [[1, 2, 3], np.ones(3), {"a": 0, "b": 1}]: + s = Series(data, name=n) + assert s.name == n + + def test_constructor_name_unhashable(self): + msg = r"Series\.name must be a hashable type" + for n in [["name_list"], np.ones(2), {1: 2}]: + for data in [["name_list"], np.ones(2), {1: 2}]: + with pytest.raises(TypeError, match=msg): + Series(data, name=n) + + def test_auto_conversion(self): + series = Series(list(date_range("1/1/2000", periods=10, unit="ns"))) + assert series.dtype == "M8[ns]" + + def test_convert_non_ns(self): + # convert from a numpy array of non-ns timedelta64 + arr = np.array([1, 2, 3], dtype="timedelta64[s]") + ser = Series(arr) + assert ser.dtype == arr.dtype + + tdi = timedelta_range("00:00:01", periods=3, freq="s").as_unit("s") + expected = Series(tdi) + assert expected.dtype == arr.dtype + tm.assert_series_equal(ser, expected) + + # convert from a numpy array of non-ns datetime64 + arr = np.array( + ["2013-01-01", "2013-01-02", "2013-01-03"], dtype="datetime64[D]" + ) + ser = Series(arr) + expected = Series(date_range("20130101", periods=3, freq="D"), dtype="M8[s]") + assert expected.dtype == "M8[s]" + tm.assert_series_equal(ser, expected) + + arr = np.array( + ["2013-01-01 00:00:01", "2013-01-01 00:00:02", "2013-01-01 00:00:03"], + dtype="datetime64[s]", + ) + ser = Series(arr) + expected = Series( + date_range("20130101 00:00:01", periods=3, freq="s"), dtype="M8[s]" + ) + assert expected.dtype == "M8[s]" + tm.assert_series_equal(ser, expected) + + @pytest.mark.parametrize( + "index", + [ + date_range("1/1/2000", periods=10), + timedelta_range("1 day", periods=10), + period_range("2000-Q1", periods=10, freq="Q"), + ], + ids=lambda x: type(x).__name__, + ) + def test_constructor_cant_cast_datetimelike(self, index): + # floats are not ok + # strip Index to convert PeriodIndex -> Period + # We don't care whether the error message says + # PeriodIndex or PeriodArray + msg = f"Cannot cast {type(index).__name__.rstrip('Index')}.*? to " + + with pytest.raises(TypeError, match=msg): + Series(index, dtype=float) + + # ints are ok + # we test with np.int64 to get similar results on + # windows / 32-bit platforms + result = Series(index, dtype=np.int64) + expected = Series(index.astype(np.int64)) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "index", + [ + date_range("1/1/2000", periods=10), + timedelta_range("1 day", periods=10), + period_range("2000-Q1", periods=10, freq="Q"), + ], + ids=lambda x: type(x).__name__, + ) + def test_constructor_cast_object(self, index): + s = Series(index, dtype=object) + exp = Series(index).astype(object) + tm.assert_series_equal(s, exp) + + s = Series(Index(index, dtype=object), dtype=object) + exp = Series(index).astype(object) + tm.assert_series_equal(s, exp) + + s = Series(index.astype(object), dtype=object) + exp = Series(index).astype(object) + tm.assert_series_equal(s, exp) + + @pytest.mark.parametrize("dtype", [np.datetime64, np.timedelta64]) + def test_constructor_generic_timestamp_no_frequency(self, dtype, request): + # see gh-15524, gh-15987 + msg = "dtype has no unit. Please pass in" + + if np.dtype(dtype).name not in ["timedelta64", "datetime64"]: + mark = pytest.mark.xfail(reason="GH#33890 Is assigned ns unit") + request.applymarker(mark) + + with pytest.raises(ValueError, match=msg): + Series([], dtype=dtype) + + @pytest.mark.parametrize("unit", ["ps", "as", "fs", "Y", "M", "W", "D", "h", "m"]) + @pytest.mark.parametrize("kind", ["m", "M"]) + def test_constructor_generic_timestamp_bad_frequency(self, kind, unit): + # see gh-15524, gh-15987 + # as of 2.0 we raise on any non-supported unit rather than silently + # cast to nanos; previously we only raised for frequencies higher + # than ns + dtype = f"{kind}8[{unit}]" + + msg = "dtype=.* is not supported. Supported resolutions are" + with pytest.raises(TypeError, match=msg): + Series([], dtype=dtype) + + with pytest.raises(TypeError, match=msg): + # pre-2.0 the DataFrame cast raised but the Series case did not + DataFrame([[0]], dtype=dtype) + + @pytest.mark.parametrize("dtype", [None, "uint8", "category"]) + def test_constructor_range_dtype(self, dtype): + # GH 16804 + expected = Series([0, 1, 2, 3, 4], dtype=dtype or "int64") + result = Series(range(5), dtype=dtype) + tm.assert_series_equal(result, expected) + + def test_constructor_range_overflows(self): + # GH#30173 range objects that overflow int64 + rng = range(2**63, 2**63 + 4) + ser = Series(rng) + expected = Series(list(rng)) + tm.assert_series_equal(ser, expected) + assert list(ser) == list(rng) + assert ser.dtype == np.uint64 + + rng2 = range(2**63 + 4, 2**63, -1) + ser2 = Series(rng2) + expected2 = Series(list(rng2)) + tm.assert_series_equal(ser2, expected2) + assert list(ser2) == list(rng2) + assert ser2.dtype == np.uint64 + + rng3 = range(-(2**63), -(2**63) - 4, -1) + ser3 = Series(rng3) + expected3 = Series(list(rng3)) + tm.assert_series_equal(ser3, expected3) + assert list(ser3) == list(rng3) + assert ser3.dtype == object + + rng4 = range(2**73, 2**73 + 4) + ser4 = Series(rng4) + expected4 = Series(list(rng4)) + tm.assert_series_equal(ser4, expected4) + assert list(ser4) == list(rng4) + assert ser4.dtype == object + + def test_constructor_tz_mixed_data(self): + # GH 13051 + dt_list = [ + Timestamp("2016-05-01 02:03:37"), + Timestamp("2016-04-30 19:03:37-0700", tz="US/Pacific"), + ] + result = Series(dt_list) + expected = Series(dt_list, dtype=object) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("pydt", [True, False]) + def test_constructor_data_aware_dtype_naive(self, tz_aware_fixture, pydt): + # GH#25843, GH#41555, GH#33401 + tz = tz_aware_fixture + ts = Timestamp("2019", tz=tz) + if pydt: + ts = ts.to_pydatetime() + + msg = ( + "Cannot convert timezone-aware data to timezone-naive dtype. " + r"Use pd.Series\(values\).dt.tz_localize\(None\) instead." + ) + with pytest.raises(ValueError, match=msg): + Series([ts], dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + Series(np.array([ts], dtype=object), dtype="datetime64[ns]") + + with pytest.raises(ValueError, match=msg): + Series({0: ts}, dtype="datetime64[ns]") + + msg = "Cannot unbox tzaware Timestamp to tznaive dtype" + with pytest.raises(TypeError, match=msg): + Series(ts, index=[0], dtype="datetime64[ns]") + + def test_constructor_datetime64(self): + rng = date_range("1/1/2000 00:00:00", "1/1/2000 1:59:50", freq="10s") + dates = np.asarray(rng) + + series = Series(dates) + assert np.issubdtype(series.dtype, np.dtype("M8[ns]")) + + def test_constructor_datetimelike_scalar_to_string_dtype( + self, nullable_string_dtype + ): + # https://github.com/pandas-dev/pandas/pull/33846 + result = Series("M", index=[1, 2, 3], dtype=nullable_string_dtype) + expected = Series(["M", "M", "M"], index=[1, 2, 3], dtype=nullable_string_dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("box", [lambda x: x, np.datetime64]) + def test_constructor_sparse_datetime64(self, box): + # https://github.com/pandas-dev/pandas/issues/35762 + values = [box("2012-01-01"), box("2013-01-01")] + dtype = pd.SparseDtype("datetime64[ns]") + result = Series(values, dtype=dtype) + arr = pd.arrays.SparseArray(values, dtype=dtype) + expected = Series(arr) + tm.assert_series_equal(result, expected) + + def test_construction_from_ordered_collection(self): + # https://github.com/pandas-dev/pandas/issues/36044 + result = Series({"a": 1, "b": 2}.keys()) + expected = Series(["a", "b"]) + tm.assert_series_equal(result, expected) + + result = Series({"a": 1, "b": 2}.values()) + expected = Series([1, 2]) + tm.assert_series_equal(result, expected) + + def test_construction_from_large_int_scalar_no_overflow(self): + # https://github.com/pandas-dev/pandas/issues/36291 + n = 1_000_000_000_000_000_000_000 + result = Series(n, index=[0]) + expected = Series(n) + tm.assert_series_equal(result, expected) + + def test_constructor_list_of_periods_infers_period_dtype(self): + series = Series(list(period_range("2000-01-01", periods=10, freq="D"))) + assert series.dtype == "Period[D]" + + series = Series( + [Period("2011-01-01", freq="D"), Period("2011-02-01", freq="D")] + ) + assert series.dtype == "Period[D]" + + def test_constructor_subclass_dict(self, dict_subclass): + data = dict_subclass((x, 10.0 * x) for x in range(10)) + series = Series(data) + expected = Series(dict(data.items())) + tm.assert_series_equal(series, expected) + + def test_constructor_ordereddict(self): + # GH3283 + data = OrderedDict( + (f"col{i}", np.random.default_rng(2).random()) for i in range(12) + ) + + series = Series(data) + expected = Series(list(data.values()), list(data.keys())) + tm.assert_series_equal(series, expected) + + # Test with subclass + class A(OrderedDict): + pass + + series = Series(A(data)) + tm.assert_series_equal(series, expected) + + @pytest.mark.parametrize( + "data, expected_index_multi", + [ + ({("a", "a"): 0.0, ("b", "a"): 1.0, ("b", "c"): 2.0}, True), + ({("a",): 0.0, ("a", "b"): 1.0}, True), + ({"z": 111.0, ("a", "a"): 0.0, ("b", "a"): 1.0, ("b", "c"): 2.0}, False), + ], + ) + def test_constructor_dict_multiindex(self, data, expected_index_multi): + # GH#60695 + result = Series(data) + + if expected_index_multi: + expected = Series( + list(data.values()), + index=MultiIndex.from_tuples(list(data.keys())), + ) + tm.assert_series_equal(result, expected) + else: + expected = Series( + list(data.values()), + index=Index(list(data.keys())), + ) + tm.assert_series_equal(result, expected) + + def test_constructor_dict_multiindex_reindex_flat(self): + # construction involves reindexing with a MultiIndex corner case + data = {("i", "i"): 0, ("i", "j"): 1, ("j", "i"): 2, "j": np.nan} + expected = Series(data) + + result = Series(expected[:-1].to_dict(), index=expected.index) + tm.assert_series_equal(result, expected) + + def test_constructor_dict_timedelta_index(self): + # GH #12169 : Resample category data with timedelta index + # construct Series from dict as data and TimedeltaIndex as index + # will result NaN in result Series data + expected = Series( + data=["A", "B", "C"], index=pd.to_timedelta([0, 10, 20], unit="s") + ) + + result = Series( + data={ + pd.to_timedelta(0, unit="s"): "A", + pd.to_timedelta(10, unit="s"): "B", + pd.to_timedelta(20, unit="s"): "C", + }, + index=pd.to_timedelta([0, 10, 20], unit="s"), + ) + tm.assert_series_equal(result, expected) + + def test_constructor_infer_index_tz(self): + values = [188.5, 328.25] + tzinfo = tzoffset(None, 7200) + index = [ + datetime(2012, 5, 11, 11, tzinfo=tzinfo), + datetime(2012, 5, 11, 12, tzinfo=tzinfo), + ] + series = Series(data=values, index=index) + + assert series.index.tz == tzinfo + + # it works! GH#2443 + repr(series.index[0]) + + def test_constructor_with_pandas_dtype(self): + # going through 2D->1D path + vals = [(1,), (2,), (3,)] + ser = Series(vals) + dtype = ser.array.dtype # NumpyEADtype + ser2 = Series(vals, dtype=dtype) + tm.assert_series_equal(ser, ser2) + + def test_constructor_int_dtype_missing_values(self): + # GH#43017 + result = Series(index=[0], dtype="int64") + expected = Series(np.nan, index=[0], dtype="float64") + tm.assert_series_equal(result, expected) + + def test_constructor_bool_dtype_missing_values(self): + # GH#43018 + result = Series(index=[0], dtype="bool") + expected = Series(True, index=[0], dtype="bool") + tm.assert_series_equal(result, expected) + + def test_constructor_int64_dtype(self, any_int_dtype): + # GH#44923 + result = Series(["0", "1", "2"], dtype=any_int_dtype) + expected = Series([0, 1, 2], dtype=any_int_dtype) + tm.assert_series_equal(result, expected) + + def test_constructor_raise_on_lossy_conversion_of_strings(self): + # GH#44923 + if not np_version_gt2: + raises = pytest.raises( + ValueError, match="string values cannot be losslessly cast to int8" + ) + else: + raises = pytest.raises( + OverflowError, match="The elements provided in the data" + ) + with raises: + Series(["128"], dtype="int8") + + def test_constructor_dtype_timedelta_alternative_construct(self): + # GH#35465 + result = Series([1000000, 200000, 3000000], dtype="timedelta64[us]") + expected = Series(pd.to_timedelta([1000000, 200000, 3000000], unit="us")) + tm.assert_series_equal(result, expected) + + def test_constructor_dtype_timedelta_ns_s_astype_int64(self): + # GH#35465 + result = Series([1000000, 200000, 3000000], dtype="timedelta64[ns]").astype( + "int64" + ) + expected = Series([1000000, 200000, 3000000], dtype="timedelta64[s]").astype( + "int64" + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings( + "ignore:elementwise comparison failed:DeprecationWarning" + ) + @pytest.mark.parametrize("func", [Series, DataFrame, Index, pd.array]) + def test_constructor_mismatched_null_nullable_dtype( + self, func, any_numeric_ea_dtype + ): + # GH#44514 + msg = "|".join( + [ + "cannot safely cast non-equivalent object", + r"int\(\) argument must be a string, a bytes-like object " + "or a (real )?number", + r"Cannot cast array data from dtype\('O'\) to dtype\('float64'\) " + "according to the rule 'safe'", + "object cannot be converted to a FloatingDtype", + "'values' contains non-numeric NA", + ] + ) + + for null in [*tm.NP_NAT_OBJECTS, NaT]: + with pytest.raises(TypeError, match=msg): + func([null, 1.0, 3.0], dtype=any_numeric_ea_dtype) + + def test_series_constructor_ea_int_from_bool(self): + # GH#42137 + result = Series([True, False, True, pd.NA], dtype="Int64") + expected = Series([1, 0, 1, pd.NA], dtype="Int64") + tm.assert_series_equal(result, expected) + + result = Series([True, False, True], dtype="Int64") + expected = Series([1, 0, 1], dtype="Int64") + tm.assert_series_equal(result, expected) + + def test_series_constructor_ea_int_from_string_bool(self): + # GH#42137 + with pytest.raises(ValueError, match="invalid literal"): + Series(["True", "False", "True", pd.NA], dtype="Int64") + + @pytest.mark.parametrize("val", [1, 1.0]) + def test_series_constructor_overflow_uint_ea(self, val): + # GH#38798 + max_val = np.iinfo(np.uint64).max - 1 + result = Series([max_val, val], dtype="UInt64") + expected = Series(np.array([max_val, 1], dtype="uint64"), dtype="UInt64") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("val", [1, 1.0]) + def test_series_constructor_overflow_uint_ea_with_na(self, val): + # GH#38798 + max_val = np.iinfo(np.uint64).max - 1 + result = Series([max_val, val, pd.NA], dtype="UInt64") + expected = Series( + IntegerArray( + np.array([max_val, 1, 0], dtype="uint64"), + np.array([0, 0, 1], dtype=np.bool_), + ) + ) + tm.assert_series_equal(result, expected) + + def test_series_constructor_overflow_uint_with_nan(self): + # GH#38798 + max_val = np.iinfo(np.uint64).max - 1 + result = Series([max_val, pd.NA], dtype="UInt64") + expected = Series( + IntegerArray( + np.array([max_val, 1], dtype="uint64"), + np.array([0, 1], dtype=np.bool_), + ) + ) + tm.assert_series_equal(result, expected) + + def test_series_constructor_ea_all_na(self): + # GH#38798 + result = Series([pd.NA, pd.NA], dtype="UInt64") + expected = Series( + IntegerArray( + np.array([1, 1], dtype="uint64"), + np.array([1, 1], dtype=np.bool_), + ) + ) + tm.assert_series_equal(result, expected) + + def test_series_from_index_dtype_equal_does_not_copy(self): + # GH#52008 + idx = Index([1, 2, 3]) + expected = idx.copy(deep=True) + ser = Series(idx, dtype="int64") + ser.iloc[0] = 100 + tm.assert_index_equal(idx, expected) + + def test_series_string_inference(self): + # GH#54430 + with pd.option_context("future.infer_string", True): + ser = Series(["a", "b"]) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series(["a", "b"], dtype=dtype) + tm.assert_series_equal(ser, expected) + + expected = Series(["a", 1], dtype="object") + with pd.option_context("future.infer_string", True): + ser = Series(["a", 1]) + tm.assert_series_equal(ser, expected) + + @pytest.mark.parametrize("na_value", [None, np.nan, pd.NA]) + def test_series_string_with_na_inference(self, na_value): + # GH#54430 + with pd.option_context("future.infer_string", True): + ser = Series(["a", na_value]) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series(["a", None], dtype=dtype) + tm.assert_series_equal(ser, expected) + + def test_series_string_inference_scalar(self): + # GH#54430 + with pd.option_context("future.infer_string", True): + ser = Series("a", index=[1]) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series("a", index=[1], dtype=dtype) + tm.assert_series_equal(ser, expected) + + def test_series_string_inference_array_string_dtype(self): + # GH#54496 + with pd.option_context("future.infer_string", True): + ser = Series(np.array(["a", "b"])) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series(["a", "b"], dtype=dtype) + tm.assert_series_equal(ser, expected) + + def test_series_string_inference_storage_definition(self): + # https://github.com/pandas-dev/pandas/issues/54793 + # but after PDEP-14 (string dtype), it was decided to keep dtype="string" + # returning the NA string dtype, so expected is changed from + # "string[pyarrow_numpy]" to "string[python]" + expected = Series( + ["a", "b"], dtype="string[pyarrow]" if HAS_PYARROW else "string[python]" + ) + with pd.option_context("future.infer_string", True): + result = Series(["a", "b"], dtype="string") + tm.assert_series_equal(result, expected) + + expected = Series(["a", "b"], dtype=pd.StringDtype(na_value=np.nan)) + with pd.option_context("future.infer_string", True): + result = Series(["a", "b"], dtype="str") + tm.assert_series_equal(result, expected) + + def test_series_constructor_infer_string_scalar(self): + # GH#55537 + with pd.option_context("future.infer_string", True): + ser = Series("a", index=[1, 2], dtype="string[python]") + expected = Series(["a", "a"], index=[1, 2], dtype="string[python]") + tm.assert_series_equal(ser, expected) + assert ser.dtype.storage == "python" + + def test_series_string_inference_na_first(self): + # GH#55655 + with pd.option_context("future.infer_string", True): + result = Series([pd.NA, "b"]) + dtype = pd.StringDtype("pyarrow" if HAS_PYARROW else "python", na_value=np.nan) + expected = Series([None, "b"], dtype=dtype) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("klass", [Series, Index]) + def test_inference_on_pandas_objects(self, klass): + # GH#56012 + obj = klass([Timestamp("2019-12-31")], dtype=object) + # This doesn't do inference + result = Series(obj) + assert result.dtype == np.object_ + + +class TestSeriesConstructorIndexCoercion: + def test_series_constructor_datetimelike_index_coercion(self): + idx = date_range("2020-01-01", periods=5) + ser = Series( + np.random.default_rng(2).standard_normal(len(idx)), idx.astype(object) + ) + # as of 2.0, we no longer silently cast the object-dtype index + # to DatetimeIndex GH#39307, GH#23598 + assert not isinstance(ser.index, DatetimeIndex) + + @pytest.mark.parametrize("container", [None, np.array, Series, Index]) + @pytest.mark.parametrize("data", [1.0, range(4)]) + def test_series_constructor_infer_multiindex(self, container, data): + indexes = [["a", "a", "b", "b"], ["x", "y", "x", "y"]] + if container is not None: + indexes = [container(ind) for ind in indexes] + + multi = Series(data, index=indexes) + assert isinstance(multi.index, MultiIndex) + + # TODO: make this not cast to object in pandas 3.0 + @pytest.mark.skipif( + not np_version_gt2, reason="StringDType only available in numpy 2 and above" + ) + @pytest.mark.parametrize( + "data", + [ + ["a", "b", "c"], + ["a", "b", np.nan], + ], + ) + def test_np_string_array_object_cast(self, data): + from numpy.dtypes import StringDType + + arr = np.array(data, dtype=StringDType()) + res = Series(arr) + assert res.dtype == np.object_ + + if data[-1] is np.nan: + # as of GH#62522 the comparison op for `res==data` casts data + # using sanitize_array, which casts to 'str' dtype, which does not + # consider string 'nan' to be equal to np.nan, + # (which apparently numpy does? weird.) + assert (res.iloc[:-1] == data[:-1]).all() + assert res.iloc[-1] == "nan" + else: + assert (res == data).all() + + +class TestSeriesConstructorInternals: + def test_constructor_no_pandas_array(self): + ser = Series([1, 2, 3]) + result = Series(ser.array) + tm.assert_series_equal(ser, result) + assert isinstance(result._mgr.blocks[0], NumpyBlock) + assert result._mgr.blocks[0].is_numeric + + def test_from_array(self): + result = Series(pd.array(["1h", "2h"], dtype="timedelta64[ns]")) + assert result._mgr.blocks[0].is_extension is False + + result = Series(pd.array(["2015"], dtype="datetime64[ns]")) + assert result._mgr.blocks[0].is_extension is False + + def test_from_list_dtype(self): + result = Series(["1h", "2h"], dtype="timedelta64[ns]") + assert result._mgr.blocks[0].is_extension is False + + result = Series(["2015"], dtype="datetime64[ns]") + assert result._mgr.blocks[0].is_extension is False + + +def test_constructor(rand_series_with_duplicate_datetimeindex): + dups = rand_series_with_duplicate_datetimeindex + assert isinstance(dups, Series) + assert isinstance(dups.index, DatetimeIndex) + + +@pytest.mark.parametrize( + "input_dict,expected", + [ + ({0: 0}, np.array([[0]], dtype=np.int64)), + ({"a": "a"}, np.array([["a"]], dtype=object)), + ({1: 1}, np.array([[1]], dtype=np.int64)), + ], +) +def test_numpy_array(input_dict, expected): + result = np.array([Series(input_dict)]) + tm.assert_numpy_array_equal(result, expected) + + +def test_index_ordered_dict_keys(): + # GH 22077 + + param_index = OrderedDict( + [ + ((("a", "b"), ("c", "d")), 1), + ((("a", None), ("c", "d")), 2), + ] + ) + series = Series([1, 2], index=param_index.keys()) + expected = Series( + [1, 2], + index=MultiIndex.from_tuples( + [(("a", "b"), ("c", "d")), (("a", None), ("c", "d"))] + ), + ) + tm.assert_series_equal(series, expected) + + +@pytest.mark.parametrize( + "input_list", + [ + [1, complex("nan"), 2], + [1 + 1j, complex("nan"), 2 + 2j], + ], +) +def test_series_with_complex_nan(input_list): + # GH#53627 + ser = Series(input_list) + result = Series(ser.array) + assert ser.dtype == "complex128" + tm.assert_series_equal(ser, result) + + +def test_dict_keys_rangeindex(): + result = Series({0: 1, 1: 2}) + expected = Series([1, 2], index=RangeIndex(2)) + tm.assert_series_equal(result, expected, check_index_type=True) diff --git a/pandas/tests/series/test_cumulative.py b/pandas/tests/series/test_cumulative.py new file mode 100644 index 0000000000000000000000000000000000000000..db83cf1112e7452df3328dd52b2ebe8a6232161c --- /dev/null +++ b/pandas/tests/series/test_cumulative.py @@ -0,0 +1,284 @@ +""" +Tests for Series cumulative operations. + +See also +-------- +tests.frame.test_cumulative +""" + +import re + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + +methods = { + "cumsum": np.cumsum, + "cumprod": np.cumprod, + "cummin": np.minimum.accumulate, + "cummax": np.maximum.accumulate, +} + + +class TestSeriesCumulativeOps: + @pytest.mark.parametrize("func", [np.cumsum, np.cumprod]) + def test_datetime_series(self, datetime_series, func): + tm.assert_numpy_array_equal( + func(datetime_series).values, + func(np.array(datetime_series)), + check_dtype=True, + ) + + # with missing values + ts = datetime_series.copy() + ts[::2] = np.nan + + result = func(ts)[1::2] + expected = func(np.array(ts.dropna())) + + tm.assert_numpy_array_equal(result.values, expected, check_dtype=False) + + @pytest.mark.parametrize("method", ["cummin", "cummax"]) + def test_cummin_cummax(self, datetime_series, method): + ufunc = methods[method] + + result = getattr(datetime_series, method)().values + expected = ufunc(np.array(datetime_series)) + + tm.assert_numpy_array_equal(result, expected) + ts = datetime_series.copy() + ts[::2] = np.nan + result = getattr(ts, method)()[1::2] + expected = ufunc(ts.dropna()) + + result.index = result.index._with_freq(None) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "ts", + [ + pd.Timedelta(0), + pd.Timestamp("1999-12-31"), + pd.Timestamp("1999-12-31").tz_localize("US/Pacific"), + ], + ) + @pytest.mark.parametrize( + "method, skipna, exp_tdi", + [ + ["cummax", True, ["NaT", "2 days", "NaT", "2 days", "NaT", "3 days"]], + ["cummin", True, ["NaT", "2 days", "NaT", "1 days", "NaT", "1 days"]], + [ + "cummax", + False, + ["NaT", "NaT", "NaT", "NaT", "NaT", "NaT"], + ], + [ + "cummin", + False, + ["NaT", "NaT", "NaT", "NaT", "NaT", "NaT"], + ], + ], + ) + def test_cummin_cummax_datetimelike(self, ts, method, skipna, exp_tdi): + # with ts==pd.Timedelta(0), we are testing td64; with naive Timestamp + # we are testing datetime64[ns]; with Timestamp[US/Pacific] + # we are testing dt64tz + tdi = pd.to_timedelta(["NaT", "2 days", "NaT", "1 days", "NaT", "3 days"]) + ser = pd.Series(tdi + ts) + + exp_tdi = pd.to_timedelta(exp_tdi) + expected = pd.Series(exp_tdi + ts) + result = getattr(ser, method)(skipna=skipna) + tm.assert_series_equal(expected, result) + + def test_cumsum_datetimelike(self): + # GH#57956 + df = pd.DataFrame( + [ + [pd.Timedelta(0), pd.Timedelta(days=1)], + [pd.Timedelta(days=2), pd.NaT], + [pd.Timedelta(hours=-6), pd.Timedelta(hours=12)], + ] + ) + result = df.cumsum() + expected = pd.DataFrame( + [ + [pd.Timedelta(0), pd.Timedelta(days=1)], + [pd.Timedelta(days=2), pd.NaT], + [pd.Timedelta(days=1, hours=18), pd.Timedelta(days=1, hours=12)], + ] + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "func, exp", + [ + ("cummin", "2012-1-1"), + ("cummax", "2012-1-2"), + ], + ) + def test_cummin_cummax_period(self, func, exp): + # GH#28385 + ser = pd.Series( + [pd.Period("2012-1-1", freq="D"), pd.NaT, pd.Period("2012-1-2", freq="D")] + ) + result = getattr(ser, func)(skipna=False) + expected = pd.Series([pd.Period("2012-1-1", freq="D"), pd.NaT, pd.NaT]) + tm.assert_series_equal(result, expected) + + result = getattr(ser, func)(skipna=True) + exp = pd.Period(exp, freq="D") + expected = pd.Series([pd.Period("2012-1-1", freq="D"), pd.NaT, exp]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "arg", + [ + [False, False, False, True, True, False, False], + [False, False, False, False, False, False, False], + ], + ) + @pytest.mark.parametrize( + "func", [lambda x: x, lambda x: ~x], ids=["identity", "inverse"] + ) + @pytest.mark.parametrize("method", methods.keys()) + def test_cummethods_bool(self, arg, func, method): + # GH#6270 + # checking Series method vs the ufunc applied to the values + + ser = func(pd.Series(arg)) + ufunc = methods[method] + + exp_vals = ufunc(ser.values) + expected = pd.Series(exp_vals) + + result = getattr(ser, method)() + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "method, expected", + [ + ["cumsum", pd.Series([0, 1, np.nan, 1], dtype=object)], + ["cumprod", pd.Series([False, 0, np.nan, 0])], + ["cummin", pd.Series([False, False, np.nan, False])], + ["cummax", pd.Series([False, True, np.nan, True])], + ], + ) + def test_cummethods_bool_in_object_dtype(self, method, expected): + ser = pd.Series([False, True, np.nan, False]) + result = getattr(ser, method)() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "method, order", + [ + ["cummax", "abc"], + ["cummin", "cba"], + ], + ) + def test_cummax_cummin_on_ordered_categorical(self, method, order): + # GH#52335 + cat = pd.CategoricalDtype(list(order), ordered=True) + ser = pd.Series( + list("ababcab"), + dtype=cat, + ) + result = getattr(ser, method)() + expected = pd.Series( + list("abbbccc"), + dtype=cat, + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "skip, exp", + [ + [True, ["a", np.nan, "b", "b", "c"]], + [False, ["a", np.nan, np.nan, np.nan, np.nan]], + ], + ) + @pytest.mark.parametrize( + "method, order", + [ + ["cummax", "abc"], + ["cummin", "cba"], + ], + ) + def test_cummax_cummin_ordered_categorical_nan(self, skip, exp, method, order): + # GH#52335 + cat = pd.CategoricalDtype(list(order), ordered=True) + ser = pd.Series( + ["a", np.nan, "b", "a", "c"], + dtype=cat, + ) + result = getattr(ser, method)(skipna=skip) + expected = pd.Series( + exp, + dtype=cat, + ) + tm.assert_series_equal( + result, + expected, + ) + + def test_cumprod_timedelta(self): + # GH#48111 + ser = pd.Series([pd.Timedelta(days=1), pd.Timedelta(days=3)]) + with pytest.raises(TypeError, match="cumprod not supported for Timedelta"): + ser.cumprod() + + @pytest.mark.parametrize( + "data, op, skipna, expected_data", + [ + ([], "cumsum", True, []), + ([], "cumsum", False, []), + (["x", "z", "y"], "cumsum", True, ["x", "xz", "xzy"]), + (["x", "z", "y"], "cumsum", False, ["x", "xz", "xzy"]), + (["x", pd.NA, "y"], "cumsum", True, ["x", pd.NA, "xy"]), + (["x", pd.NA, "y"], "cumsum", False, ["x", pd.NA, pd.NA]), + ([pd.NA, "x", "y"], "cumsum", True, [pd.NA, "x", "xy"]), + ([pd.NA, "x", "y"], "cumsum", False, [pd.NA, pd.NA, pd.NA]), + ([pd.NA, pd.NA, pd.NA], "cumsum", True, [pd.NA, pd.NA, pd.NA]), + ([pd.NA, pd.NA, pd.NA], "cumsum", False, [pd.NA, pd.NA, pd.NA]), + ([], "cummin", True, []), + ([], "cummin", False, []), + (["y", "z", "x"], "cummin", True, ["y", "y", "x"]), + (["y", "z", "x"], "cummin", False, ["y", "y", "x"]), + (["y", pd.NA, "x"], "cummin", True, ["y", pd.NA, "x"]), + (["y", pd.NA, "x"], "cummin", False, ["y", pd.NA, pd.NA]), + ([pd.NA, "y", "x"], "cummin", True, [pd.NA, "y", "x"]), + ([pd.NA, "y", "x"], "cummin", False, [pd.NA, pd.NA, pd.NA]), + ([pd.NA, pd.NA, pd.NA], "cummin", True, [pd.NA, pd.NA, pd.NA]), + ([pd.NA, pd.NA, pd.NA], "cummin", False, [pd.NA, pd.NA, pd.NA]), + ([], "cummax", True, []), + ([], "cummax", False, []), + (["x", "z", "y"], "cummax", True, ["x", "z", "z"]), + (["x", "z", "y"], "cummax", False, ["x", "z", "z"]), + (["x", pd.NA, "y"], "cummax", True, ["x", pd.NA, "y"]), + (["x", pd.NA, "y"], "cummax", False, ["x", pd.NA, pd.NA]), + ([pd.NA, "x", "y"], "cummax", True, [pd.NA, "x", "y"]), + ([pd.NA, "x", "y"], "cummax", False, [pd.NA, pd.NA, pd.NA]), + ([pd.NA, pd.NA, pd.NA], "cummax", True, [pd.NA, pd.NA, pd.NA]), + ([pd.NA, pd.NA, pd.NA], "cummax", False, [pd.NA, pd.NA, pd.NA]), + ], + ) + def test_cum_methods_ea_strings( + self, string_dtype_no_object, data, op, skipna, expected_data + ): + # https://github.com/pandas-dev/pandas/pull/60633 - pyarrow + # https://github.com/pandas-dev/pandas/pull/60938 - Python + ser = pd.Series(data, dtype=string_dtype_no_object) + method = getattr(ser, op) + expected = pd.Series(expected_data, dtype=string_dtype_no_object) + result = method(skipna=skipna) + tm.assert_series_equal(result, expected) + + def test_cumprod_pyarrow_strings(self, pyarrow_string_dtype, skipna): + # https://github.com/pandas-dev/pandas/pull/60633 + ser = pd.Series(list("xyz"), dtype=pyarrow_string_dtype) + msg = re.escape(f"operation 'cumprod' not supported for dtype '{ser.dtype}'") + with pytest.raises(TypeError, match=msg): + ser.cumprod(skipna=skipna) diff --git a/pandas/tests/series/test_formats.py b/pandas/tests/series/test_formats.py new file mode 100644 index 0000000000000000000000000000000000000000..76c8914e60b76c71a39604c9fde1f8e731fc7e8e --- /dev/null +++ b/pandas/tests/series/test_formats.py @@ -0,0 +1,592 @@ +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Index, + Series, + date_range, + option_context, + period_range, + timedelta_range, +) + + +class TestSeriesRepr: + def test_multilevel_name_print_0(self): + # GH#55415 None does not get printed, but 0 does + # (matching DataFrame and flat index behavior) + mi = pd.MultiIndex.from_product([range(2, 3), range(3, 4)], names=[0, None]) + ser = Series(1.5, index=mi) + + res = repr(ser) + expected = "0 \n2 3 1.5\ndtype: float64" + assert res == expected + + def test_multilevel_name_print(self, lexsorted_two_level_string_multiindex): + index = lexsorted_two_level_string_multiindex + ser = Series(range(len(index)), index=index, name="sth") + expected = [ + "first second", + "foo one 0", + " two 1", + " three 2", + "bar one 3", + " two 4", + "baz two 5", + " three 6", + "qux one 7", + " two 8", + " three 9", + "Name: sth, dtype: int64", + ] + expected = "\n".join(expected) + assert repr(ser) == expected + + def test_small_name_printing(self): + # Test small Series. + s = Series([0, 1, 2]) + + s.name = "test" + assert "Name: test" in repr(s) + + s.name = None + assert "Name:" not in repr(s) + + def test_big_name_printing(self): + # Test big Series (diff code path). + s = Series(range(1000)) + + s.name = "test" + assert "Name: test" in repr(s) + + s.name = None + assert "Name:" not in repr(s) + + def test_empty_name_printing(self): + s = Series(index=date_range("20010101", "20020101"), name="test", dtype=object) + assert "Name: test" in repr(s) + + @pytest.mark.parametrize("args", [(), (0, -1)]) + def test_float_range(self, args): + str( + Series( + np.random.default_rng(2).standard_normal(1000), + index=np.arange(1000, *args), + ) + ) + + def test_empty_object(self): + # empty + str(Series(dtype=object)) + + def test_string(self, string_series): + str(string_series) + str(string_series.astype(int)) + + # with NaNs + string_series[5:7] = np.nan + str(string_series) + + def test_object(self, object_series): + str(object_series) + + def test_datetime(self, datetime_series): + str(datetime_series) + # with Nones + ots = datetime_series.astype("O") + ots[::2] = None + repr(ots) + + @pytest.mark.parametrize( + "name", + [ + "", + 1, + 1.2, + "foo", + "\u03b1\u03b2\u03b3", + "loooooooooooooooooooooooooooooooooooooooooooooooooooong", + ("foo", "bar", "baz"), + (1, 2), + ("foo", 1, 2.3), + ("\u03b1", "\u03b2", "\u03b3"), + ("\u03b1", "bar"), + ], + ) + def test_various_names(self, name, string_series): + # various names + string_series.name = name + repr(string_series) + + def test_tuple_name(self): + biggie = Series( + np.random.default_rng(2).standard_normal(1000), + index=np.arange(1000), + name=("foo", "bar", "baz"), + ) + repr(biggie) + + @pytest.mark.parametrize("arg", [100, 1001]) + def test_tidy_repr_name_0(self, arg): + # tidy repr + ser = Series(np.random.default_rng(2).standard_normal(arg), name=0) + rep_str = repr(ser) + assert "Name: 0" in rep_str + + def test_newline(self, any_string_dtype): + ser = Series( + ["a\n\r\tb"], + name="a\n\r\td", + index=Index(["a\n\r\tf"], dtype=any_string_dtype), + dtype=any_string_dtype, + ) + assert "\t" not in repr(ser) + assert "\r" not in repr(ser) + assert "a\n" not in repr(ser) + + @pytest.mark.parametrize( + "name, expected", + [ + ["foo", "Series([], Name: foo, dtype: int64)"], + [None, "Series([], dtype: int64)"], + ], + ) + def test_empty_int64(self, name, expected): + # with empty series (#4651) + s = Series([], dtype=np.int64, name=name) + assert repr(s) == expected + + def test_repr_bool_fails(self, capsys): + s = Series( + [ + DataFrame(np.random.default_rng(2).standard_normal((2, 2))) + for i in range(5) + ] + ) + + # It works (with no Cython exception barf)! + repr(s) + + captured = capsys.readouterr() + assert captured.err == "" + + def test_repr_name_iterable_indexable(self): + s = Series([1, 2, 3], name=np.int64(3)) + + # it works! + repr(s) + + s.name = ("\u05d0",) * 2 + repr(s) + + def test_repr_max_rows(self): + # GH 6863 + with option_context("display.max_rows", None): + str(Series(range(1001))) # should not raise exception + + def test_unicode_string_with_unicode(self): + df = Series(["\u05d0"], name="\u05d1") + str(df) + + ser = Series(["\u03c3"] * 10) + repr(ser) + + ser2 = Series(["\u05d0"] * 1000) + ser2.name = "title1" + repr(ser2) + + def test_str_to_bytes_raises(self): + # GH 26447 + df = Series(["abc"], name="abc") + msg = "^'str' object cannot be interpreted as an integer$" + with pytest.raises(TypeError, match=msg): + bytes(df) + + def test_timeseries_repr_object_dtype(self): + index = Index( + [datetime(2000, 1, 1) + timedelta(i) for i in range(1000)], dtype=object + ) + ts = Series(np.random.default_rng(2).standard_normal(len(index)), index) + repr(ts) + + ts = Series( + np.arange(20, dtype=np.float64), index=date_range("2020-01-01", periods=20) + ) + assert repr(ts).splitlines()[-1].startswith("Freq:") + + ts2 = ts.iloc[np.random.default_rng(2).integers(0, len(ts) - 1, 400)] + repr(ts2).splitlines()[-1] + + def test_latex_repr(self): + pytest.importorskip("jinja2") # uses Styler implementation + result = r"""\begin{tabular}{ll} +\toprule + & 0 \\ +\midrule +0 & $\alpha$ \\ +1 & b \\ +2 & c \\ +\bottomrule +\end{tabular} +""" + with option_context( + "styler.format.escape", None, "styler.render.repr", "latex" + ): + s = Series([r"$\alpha$", "b", "c"]) + assert result == s._repr_latex_() + + assert s._repr_latex_() is None + + def test_index_repr_in_frame_with_nan(self): + # see gh-25061 + i = Index([1, np.nan]) + s = Series([1, 2], index=i) + exp = """1.0 1\nNaN 2\ndtype: int64""" + + assert repr(s) == exp + + def test_series_repr_nat(self): + series = Series([0, 1000, 2000, pd.NaT._value], dtype="M8[ns]") + + result = repr(series) + expected = ( + "0 1970-01-01 00:00:00.000000\n" + "1 1970-01-01 00:00:00.000001\n" + "2 1970-01-01 00:00:00.000002\n" + "3 NaT\n" + "dtype: datetime64[ns]" + ) + assert result == expected + + def test_float_repr(self): + # GH#35603 + # check float format when cast to object + ser = Series([1.0]).astype(object) + expected = "0 1.0\ndtype: object" + assert repr(ser) == expected + + def test_different_null_objects(self): + # GH#45263 + ser = Series([1, 2, 3, 4], [True, None, np.nan, pd.NaT]) + result = repr(ser) + expected = "True 1\nNone 2\nNaN 3\nNaT 4\ndtype: int64" + assert result == expected + + def test_2d_extension_type(self): + # GH#33770 + + # Define a stub extension type with just enough code to run Series.__repr__() + class DtypeStub(pd.api.extensions.ExtensionDtype): + @property + def type(self): + return np.ndarray + + @property + def name(self): + return "DtypeStub" + + class ExtTypeStub(pd.api.extensions.ExtensionArray): + def __len__(self) -> int: + return 2 + + def __getitem__(self, ix): + return [ix == 1, ix == 0] + + @property + def dtype(self): + return DtypeStub() + + series = Series(ExtTypeStub(), copy=False) + res = repr(series) # This line crashed before GH#33770 was fixed. + expected = "\n".join( + ["0 [False True]", "1 [True False]", "dtype: DtypeStub"] + ) + assert res == expected + + +class TestCategoricalRepr: + def test_categorical_repr_unicode(self): + # see gh-21002 + + class County: + name = "San Sebastián" + state = "PR" + + def __repr__(self) -> str: + return self.name + ", " + self.state + + cat = Categorical([County() for _ in range(61)]) + idx = Index(cat) + ser = idx.to_series() + + repr(ser) + str(ser) + + def test_categorical_repr(self, using_infer_string): + a = Series(Categorical([1, 2, 3, 4])) + exp = ( + "0 1\n1 2\n2 3\n3 4\n" + "dtype: category\nCategories (4, int64): [1, 2, 3, 4]" + ) + + assert exp == a.__str__() + + a = Series(Categorical(["a", "b"] * 25)) + exp = ( + "0 a\n1 b\n" + " ..\n" + "48 a\n49 b\n" + "Length: 50, dtype: category\nCategories (2, object): ['a', 'b']" + ) + if using_infer_string: + exp = exp.replace("object", "str") + with option_context("display.max_rows", 5): + assert exp == repr(a) + + levs = list("abcdefghijklmnopqrstuvwxyz") + a = Series(Categorical(["a", "b"], categories=levs, ordered=True)) + exp = ( + "0 a\n1 b\n" + "dtype: category\n" + "Categories (26, object): ['a' < 'b' < 'c' < 'd' ... " + "'w' < 'x' < 'y' < 'z']" + ) + if using_infer_string: + exp = exp.replace("object", "str") + assert exp == a.__str__() + + def test_categorical_series_repr(self): + s = Series(Categorical([1, 2, 3])) + exp = """0 1 +1 2 +2 3 +dtype: category +Categories (3, int64): [1, 2, 3]""" + + assert repr(s) == exp + + s = Series(Categorical(np.arange(10))) + exp = f"""0 0 +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +7 7 +8 8 +9 9 +dtype: category +Categories (10, {np.dtype(int)}): [0, 1, 2, 3, ..., 6, 7, 8, 9]""" + + assert repr(s) == exp + + def test_categorical_series_repr_ordered(self): + s = Series(Categorical([1, 2, 3], ordered=True)) + exp = """0 1 +1 2 +2 3 +dtype: category +Categories (3, int64): [1 < 2 < 3]""" + + assert repr(s) == exp + + s = Series(Categorical(np.arange(10), ordered=True)) + exp = f"""0 0 +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +7 7 +8 8 +9 9 +dtype: category +Categories (10, {np.dtype(int)}): [0 < 1 < 2 < 3 ... 6 < 7 < 8 < 9]""" + + assert repr(s) == exp + + def test_categorical_series_repr_datetime(self): + idx = date_range("2011-01-01 09:00", freq="h", periods=5, unit="ns") + s = Series(Categorical(idx)) + exp = """0 2011-01-01 09:00:00 +1 2011-01-01 10:00:00 +2 2011-01-01 11:00:00 +3 2011-01-01 12:00:00 +4 2011-01-01 13:00:00 +dtype: category +Categories (5, datetime64[ns]): [2011-01-01 09:00:00, 2011-01-01 10:00:00, 2011-01-01 11:00:00, + 2011-01-01 12:00:00, 2011-01-01 13:00:00]""" # noqa: E501 + + assert repr(s) == exp + + idx = date_range( + "2011-01-01 09:00", freq="h", periods=5, tz="US/Eastern", unit="ns" + ) + s = Series(Categorical(idx)) + exp = """0 2011-01-01 09:00:00-05:00 +1 2011-01-01 10:00:00-05:00 +2 2011-01-01 11:00:00-05:00 +3 2011-01-01 12:00:00-05:00 +4 2011-01-01 13:00:00-05:00 +dtype: category +Categories (5, datetime64[ns, US/Eastern]): [2011-01-01 09:00:00-05:00, 2011-01-01 10:00:00-05:00, + 2011-01-01 11:00:00-05:00, 2011-01-01 12:00:00-05:00, + 2011-01-01 13:00:00-05:00]""" # noqa: E501 + + assert repr(s) == exp + + def test_categorical_series_repr_datetime_ordered(self): + idx = date_range("2011-01-01 09:00", freq="h", periods=5, unit="ns") + s = Series(Categorical(idx, ordered=True)) + exp = """0 2011-01-01 09:00:00 +1 2011-01-01 10:00:00 +2 2011-01-01 11:00:00 +3 2011-01-01 12:00:00 +4 2011-01-01 13:00:00 +dtype: category +Categories (5, datetime64[ns]): [2011-01-01 09:00:00 < 2011-01-01 10:00:00 < 2011-01-01 11:00:00 < + 2011-01-01 12:00:00 < 2011-01-01 13:00:00]""" # noqa: E501 + + assert repr(s) == exp + + idx = date_range( + "2011-01-01 09:00", freq="h", periods=5, tz="US/Eastern", unit="ns" + ) + s = Series(Categorical(idx, ordered=True)) + exp = """0 2011-01-01 09:00:00-05:00 +1 2011-01-01 10:00:00-05:00 +2 2011-01-01 11:00:00-05:00 +3 2011-01-01 12:00:00-05:00 +4 2011-01-01 13:00:00-05:00 +dtype: category +Categories (5, datetime64[ns, US/Eastern]): [2011-01-01 09:00:00-05:00 < 2011-01-01 10:00:00-05:00 < + 2011-01-01 11:00:00-05:00 < 2011-01-01 12:00:00-05:00 < + 2011-01-01 13:00:00-05:00]""" # noqa: E501 + + assert repr(s) == exp + + def test_categorical_series_repr_period(self): + idx = period_range("2011-01-01 09:00", freq="h", periods=5) + s = Series(Categorical(idx)) + exp = """0 2011-01-01 09:00 +1 2011-01-01 10:00 +2 2011-01-01 11:00 +3 2011-01-01 12:00 +4 2011-01-01 13:00 +dtype: category +Categories (5, period[h]): [2011-01-01 09:00, 2011-01-01 10:00, 2011-01-01 11:00, 2011-01-01 12:00, + 2011-01-01 13:00]""" # noqa: E501 + + assert repr(s) == exp + + idx = period_range("2011-01", freq="M", periods=5) + s = Series(Categorical(idx)) + exp = """0 2011-01 +1 2011-02 +2 2011-03 +3 2011-04 +4 2011-05 +dtype: category +Categories (5, period[M]): [2011-01, 2011-02, 2011-03, 2011-04, 2011-05]""" + + assert repr(s) == exp + + def test_categorical_series_repr_period_ordered(self): + idx = period_range("2011-01-01 09:00", freq="h", periods=5) + s = Series(Categorical(idx, ordered=True)) + exp = """0 2011-01-01 09:00 +1 2011-01-01 10:00 +2 2011-01-01 11:00 +3 2011-01-01 12:00 +4 2011-01-01 13:00 +dtype: category +Categories (5, period[h]): [2011-01-01 09:00 < 2011-01-01 10:00 < 2011-01-01 11:00 < 2011-01-01 12:00 < + 2011-01-01 13:00]""" # noqa: E501 + + assert repr(s) == exp + + idx = period_range("2011-01", freq="M", periods=5) + s = Series(Categorical(idx, ordered=True)) + exp = """0 2011-01 +1 2011-02 +2 2011-03 +3 2011-04 +4 2011-05 +dtype: category +Categories (5, period[M]): [2011-01 < 2011-02 < 2011-03 < 2011-04 < 2011-05]""" + + assert repr(s) == exp + + def test_categorical_series_repr_timedelta(self): + idx = timedelta_range("1 days", periods=5) + s = Series(Categorical(idx)) + exp = """0 1 days +1 2 days +2 3 days +3 4 days +4 5 days +dtype: category +Categories (5, timedelta64[us]): [1 days, 2 days, 3 days, 4 days, 5 days]""" + + assert repr(s) == exp + + idx = timedelta_range("1 hours", periods=10) + s = Series(Categorical(idx)) + exp = """0 0 days 01:00:00 +1 1 days 01:00:00 +2 2 days 01:00:00 +3 3 days 01:00:00 +4 4 days 01:00:00 +5 5 days 01:00:00 +6 6 days 01:00:00 +7 7 days 01:00:00 +8 8 days 01:00:00 +9 9 days 01:00:00 +dtype: category +Categories (10, timedelta64[us]): [0 days 01:00:00, 1 days 01:00:00, 2 days 01:00:00, + 3 days 01:00:00, ..., 6 days 01:00:00, 7 days 01:00:00, + 8 days 01:00:00, 9 days 01:00:00]""" # noqa: E501 + + assert repr(s) == exp + + def test_categorical_series_repr_timedelta_ordered(self): + idx = timedelta_range("1 days", periods=5) + s = Series(Categorical(idx, ordered=True)) + exp = """0 1 days +1 2 days +2 3 days +3 4 days +4 5 days +dtype: category +Categories (5, timedelta64[us]): [1 days < 2 days < 3 days < 4 days < 5 days]""" + + assert repr(s) == exp + + idx = timedelta_range("1 hours", periods=10) + s = Series(Categorical(idx, ordered=True)) + exp = """0 0 days 01:00:00 +1 1 days 01:00:00 +2 2 days 01:00:00 +3 3 days 01:00:00 +4 4 days 01:00:00 +5 5 days 01:00:00 +6 6 days 01:00:00 +7 7 days 01:00:00 +8 8 days 01:00:00 +9 9 days 01:00:00 +dtype: category +Categories (10, timedelta64[us]): [0 days 01:00:00 < 1 days 01:00:00 < 2 days 01:00:00 < + 3 days 01:00:00 ... 6 days 01:00:00 < 7 days 01:00:00 < + 8 days 01:00:00 < 9 days 01:00:00]""" # noqa: E501 + + assert repr(s) == exp diff --git a/pandas/tests/series/test_iteration.py b/pandas/tests/series/test_iteration.py new file mode 100644 index 0000000000000000000000000000000000000000..db5d80b3798b91b6c3ba212c2dbe7c5c5c56a5a6 --- /dev/null +++ b/pandas/tests/series/test_iteration.py @@ -0,0 +1,33 @@ +class TestIteration: + def test_keys(self, datetime_series): + assert datetime_series.keys() is datetime_series.index + + def test_iter_datetimes(self, datetime_series): + for i, val in enumerate(datetime_series): + assert val == datetime_series.iloc[i] + + def test_iter_strings(self, string_series): + for i, val in enumerate(string_series): + assert val == string_series.iloc[i] + + def test_iteritems_datetimes(self, datetime_series): + for idx, val in datetime_series.items(): + assert val == datetime_series[idx] # noqa: PLR1733 + + def test_iteritems_strings(self, string_series): + for idx, val in string_series.items(): + assert val == string_series[idx] # noqa: PLR1733 + + # assert is lazy (generators don't define reverse, lists do) + assert not hasattr(string_series.items(), "reverse") + + def test_items_datetimes(self, datetime_series): + for idx, val in datetime_series.items(): + assert val == datetime_series[idx] # noqa: PLR1733 + + def test_items_strings(self, string_series): + for idx, val in string_series.items(): + assert val == string_series[idx] # noqa: PLR1733 + + # assert is lazy (generators don't define reverse, lists do) + assert not hasattr(string_series.items(), "reverse") diff --git a/pandas/tests/series/test_logical_ops.py b/pandas/tests/series/test_logical_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a32a88b77c3a1bd864e819e5041a1844025bc4b5 --- /dev/null +++ b/pandas/tests/series/test_logical_ops.py @@ -0,0 +1,510 @@ +from datetime import datetime +import operator + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + Series, + bdate_range, +) +import pandas._testing as tm +from pandas.core import ops + + +class TestSeriesLogicalOps: + @pytest.mark.parametrize("bool_op", [operator.and_, operator.or_, operator.xor]) + def test_bool_operators_with_nas(self, bool_op): + # boolean &, |, ^ should work with object arrays and propagate NAs + ser = Series(bdate_range("1/1/2000", periods=10), dtype=object) + ser[::2] = np.nan + + mask = ser.isna() + filled = ser.fillna(ser[0]) + + result = bool_op(ser < ser[9], ser > ser[3]) + + expected = bool_op(filled < filled[9], filled > filled[3]) + expected[mask] = False + tm.assert_series_equal(result, expected) + + def test_logical_operators_bool_dtype_with_empty(self): + # GH#9016: support bitwise op for integer types + index = list("bca") + + s_tft = Series([True, False, True], index=index) + s_fff = Series([False, False, False], index=index) + s_empty = Series([], dtype=object) + + res = s_tft & s_empty + expected = s_fff.sort_index() + tm.assert_series_equal(res, expected) + + res = s_tft | s_empty + expected = s_tft.sort_index() + tm.assert_series_equal(res, expected) + + def test_logical_operators_int_dtype_with_int_dtype(self): + # GH#9016: support bitwise op for integer types + + s_0123 = Series(range(4), dtype="int64") + s_3333 = Series([3] * 4) + s_4444 = Series([4] * 4) + + res = s_0123 & s_3333 + expected = Series(range(4), dtype="int64") + tm.assert_series_equal(res, expected) + + res = s_0123 | s_4444 + expected = Series(range(4, 8), dtype="int64") + tm.assert_series_equal(res, expected) + + s_1111 = Series([1] * 4, dtype="int8") + res = s_0123 & s_1111 + expected = Series([0, 1, 0, 1], dtype="int64") + tm.assert_series_equal(res, expected) + + res = s_0123.astype(np.int16) | s_1111.astype(np.int32) + expected = Series([1, 1, 3, 3], dtype="int32") + tm.assert_series_equal(res, expected) + + def test_logical_operators_int_dtype_with_int_scalar(self): + # GH#9016: support bitwise op for integer types + s_0123 = Series(range(4), dtype="int64") + + res = s_0123 & 0 + expected = Series([0] * 4) + tm.assert_series_equal(res, expected) + + res = s_0123 & 1 + expected = Series([0, 1, 0, 1]) + tm.assert_series_equal(res, expected) + + def test_logical_operators_int_dtype_with_float(self): + # GH#9016: support bitwise op for integer types + s_0123 = Series(range(4), dtype="int64") + + err_msg = ( + r"Logical ops \(and, or, xor\) between Pandas objects and " + "dtype-less sequences" + ) + + msg = "Cannot perform.+with a dtyped.+array and scalar of type" + with pytest.raises(TypeError, match=msg): + s_0123 & np.nan + with pytest.raises(TypeError, match=msg): + s_0123 & 3.14 + msg = "unsupported operand type.+for &:" + with pytest.raises(TypeError, match=err_msg): + s_0123 & [0.1, 4, 3.14, 2] + with pytest.raises(TypeError, match=msg): + s_0123 & np.array([0.1, 4, 3.14, 2]) + with pytest.raises(TypeError, match=msg): + s_0123 & Series([0.1, 4, -3.14, 2]) + + def test_logical_operators_int_dtype_with_str(self): + s_1111 = Series([1] * 4, dtype="int8") + + err_msg = ( + r"Logical ops \(and, or, xor\) between Pandas objects and " + "dtype-less sequences" + ) + + msg = "Cannot perform 'and_' with a dtyped.+array and scalar of type" + with pytest.raises(TypeError, match=msg): + s_1111 & "a" + with pytest.raises(TypeError, match=err_msg): + s_1111 & ["a", "b", "c", "d"] + + def test_logical_operators_int_dtype_with_bool(self): + # GH#9016: support bitwise op for integer types + s_0123 = Series(range(4), dtype="int64") + + expected = Series([False] * 4) + + result = s_0123 & False + tm.assert_series_equal(result, expected) + + msg = ( + r"Logical ops \(and, or, xor\) between Pandas objects and " + "dtype-less sequences" + ) + with pytest.raises(TypeError, match=msg): + s_0123 & [False] + + with pytest.raises(TypeError, match=msg): + s_0123 & (False,) + + result = s_0123 ^ False + expected = Series([False, True, True, True]) + tm.assert_series_equal(result, expected) + + def test_logical_operators_int_dtype_with_object(self): + # GH#9016: support bitwise op for integer types + s_0123 = Series(range(4), dtype="int64") + + result = s_0123 & Series([False, np.nan, False, False]) + expected = Series([False] * 4) + tm.assert_series_equal(result, expected) + + s_abNd = Series(["a", "b", np.nan, "d"]) + with pytest.raises( + TypeError, match="unsupported.* 'int' and 'str'|'rand_' not supported" + ): + s_0123 & s_abNd + + def test_logical_operators_bool_dtype_with_int(self): + index = list("bca") + + s_tft = Series([True, False, True], index=index) + s_fff = Series([False, False, False], index=index) + + res = s_tft & 0 + expected = s_fff + tm.assert_series_equal(res, expected) + + res = s_tft & 1 + expected = s_tft + tm.assert_series_equal(res, expected) + + def test_logical_ops_bool_dtype_with_ndarray(self): + # make sure we operate on ndarray the same as Series + left = Series([True, True, True, False, True]) + right = [True, False, None, True, np.nan] + + msg = ( + r"Logical ops \(and, or, xor\) between Pandas objects and " + "dtype-less sequences" + ) + + expected = Series([True, False, False, False, False]) + with pytest.raises(TypeError, match=msg): + left & right + result = left & np.array(right) + tm.assert_series_equal(result, expected) + result = left & Index(right) + tm.assert_series_equal(result, expected) + result = left & Series(right) + tm.assert_series_equal(result, expected) + + expected = Series([True, True, True, True, True]) + with pytest.raises(TypeError, match=msg): + left | right + result = left | np.array(right) + tm.assert_series_equal(result, expected) + result = left | Index(right) + tm.assert_series_equal(result, expected) + result = left | Series(right) + tm.assert_series_equal(result, expected) + + expected = Series([False, True, True, True, True]) + with pytest.raises(TypeError, match=msg): + left ^ right + result = left ^ np.array(right) + tm.assert_series_equal(result, expected) + result = left ^ Index(right) + tm.assert_series_equal(result, expected) + result = left ^ Series(right) + tm.assert_series_equal(result, expected) + + def test_logical_operators_int_dtype_with_bool_dtype_and_reindex(self): + # GH#9016: support bitwise op for integer types + + index = list("bca") + + s_tft = Series([True, False, True], index=index) + s_tft = Series([True, False, True], index=index) + s_tff = Series([True, False, False], index=index) + + s_0123 = Series(range(4), dtype="int64") + + # s_0123 will be all false now because of reindexing like s_tft + expected = Series([False] * 7, index=[0, 1, 2, 3, "a", "b", "c"]) + result = s_tft & s_0123 + tm.assert_series_equal(result, expected) + + # GH#52538: no longer to object type when reindex is needed; + # matches DataFrame behavior + msg = r"unsupported operand type\(s\) for &: 'float' and 'bool'" + with pytest.raises(TypeError, match=msg): + s_0123 & s_tft + + s_a0b1c0 = Series([1], list("b")) + + res = s_tft & s_a0b1c0 + expected = s_tff.reindex(list("abc")) + tm.assert_series_equal(res, expected) + + res = s_tft | s_a0b1c0 + expected = s_tft.reindex(list("abc")) + tm.assert_series_equal(res, expected) + + def test_scalar_na_logical_ops_corners(self): + s = Series([2, 3, 4, 5, 6, 7, 8, 9, 10]) + + msg = "Cannot perform.+with a dtyped.+array and scalar of type" + with pytest.raises(TypeError, match=msg): + s & datetime(2005, 1, 1) + + s = Series([2, 3, 4, 5, 6, 7, 8, 9, datetime(2005, 1, 1)]) + s[::2] = np.nan + + expected = Series(True, index=s.index) + expected[::2] = False + + msg = ( + r"Logical ops \(and, or, xor\) between Pandas objects and " + "dtype-less sequences" + ) + with pytest.raises(TypeError, match=msg): + s & list(s) + + def test_scalar_na_logical_ops_corners_aligns(self): + s = Series([2, 3, 4, 5, 6, 7, 8, 9, datetime(2005, 1, 1)]) + s[::2] = np.nan + d = DataFrame({"A": s}) + + expected = DataFrame(False, index=range(9), columns=["A", *list(range(9))]) + + result = s & d + tm.assert_frame_equal(result, expected) + + result = d & s + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("op", [operator.and_, operator.or_, operator.xor]) + def test_logical_ops_with_index(self, op): + # GH#22092, GH#19792 + ser = Series([True, True, False, False]) + idx1 = Index([True, False, True, False]) + idx2 = Index([1, 0, 1, 0]) + + expected = Series([op(ser[n], idx1[n]) for n in range(len(ser))]) + + result = op(ser, idx1) + tm.assert_series_equal(result, expected) + + expected = Series([op(ser[n], idx2[n]) for n in range(len(ser))], dtype=bool) + + result = op(ser, idx2) + tm.assert_series_equal(result, expected) + + def test_reversed_xor_with_index_returns_series(self): + # GH#22092, GH#19792 pre-2.0 these were aliased to setops + ser = Series([True, True, False, False]) + idx1 = Index([True, False, True, False], dtype=bool) + idx2 = Index([1, 0, 1, 0]) + + expected = Series([False, True, True, False]) + result = idx1 ^ ser + tm.assert_series_equal(result, expected) + + result = idx2 ^ ser + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "op", + [ + ops.rand_, + ops.ror_, + ], + ) + def test_reversed_logical_op_with_index_returns_series(self, op): + # GH#22092, GH#19792 + ser = Series([True, True, False, False]) + idx1 = Index([True, False, True, False]) + idx2 = Index([1, 0, 1, 0]) + + expected = Series(op(idx1.values, ser.values)) + result = op(ser, idx1) + tm.assert_series_equal(result, expected) + + expected = op(ser, Series(idx2)) + result = op(ser, idx2) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "op, expected", + [ + (ops.rand_, [False, False]), + (ops.ror_, [True, True]), + (ops.rxor, [True, True]), + ], + ) + def test_reverse_ops_with_index(self, op, expected): + # https://github.com/pandas-dev/pandas/pull/23628 + # multi-set Index ops are buggy, so let's avoid duplicates... + # GH#49503 + ser = Series([True, False]) + idx = Index([False, True]) + + result = op(ser, idx) + expected = Series(expected) + tm.assert_series_equal(result, expected) + + def test_logical_ops_label_based(self, using_infer_string): + # GH#4947 + # logical ops should be label based + + a = Series([True, False, True], list("bca")) + b = Series([False, True, False], list("abc")) + + expected = Series([False, True, False], list("abc")) + result = a & b + tm.assert_series_equal(result, expected) + + expected = Series([True, True, False], list("abc")) + result = a | b + tm.assert_series_equal(result, expected) + + expected = Series([True, False, False], list("abc")) + result = a ^ b + tm.assert_series_equal(result, expected) + + # rhs is bigger + a = Series([True, False, True], list("bca")) + b = Series([False, True, False, True], list("abcd")) + + expected = Series([False, True, False, False], list("abcd")) + result = a & b + tm.assert_series_equal(result, expected) + + expected = Series([True, True, False, False], list("abcd")) + result = a | b + tm.assert_series_equal(result, expected) + + # filling + + # vs empty + empty = Series([], dtype=object) + + result = a & empty + expected = Series([False, False, False], list("abc")) + tm.assert_series_equal(result, expected) + + result = a | empty + expected = Series([True, True, False], list("abc")) + tm.assert_series_equal(result, expected) + + # vs non-matching + result = a & Series([1], ["z"]) + expected = Series([False, False, False, False], list("abcz")) + tm.assert_series_equal(result, expected) + + result = a | Series([1], ["z"]) + expected = Series([True, True, False, False], list("abcz")) + tm.assert_series_equal(result, expected) + + # identity + # we would like s[s|e] == s to hold for any e, whether empty or not + for e in [ + empty.copy(), + Series([1], ["z"]), + Series(np.nan, b.index), + Series(np.nan, a.index), + ]: + result = a[a | e] + tm.assert_series_equal(result, a[a]) + + for e in [Series(["z"])]: + if using_infer_string: + # TODO(infer_string) should this behave differently? + # -> https://github.com/pandas-dev/pandas/issues/60234 + with pytest.raises( + TypeError, match="not supported for dtype|unsupported operand type" + ): + result = a[a | e] + else: + result = a[a | e] + tm.assert_series_equal(result, a[a]) + + # vs scalars + index = list("bca") + t = Series([True, False, True]) + + for v in [True, 1, 2]: + result = Series([True, False, True], index=index) | v + expected = Series([True, True, True], index=index) + tm.assert_series_equal(result, expected) + + msg = "Cannot perform.+with a dtyped.+array and scalar of type" + for v in [np.nan, "foo"]: + with pytest.raises(TypeError, match=msg): + t | v + + for v in [False, 0]: + result = Series([True, False, True], index=index) | v + expected = Series([True, False, True], index=index) + tm.assert_series_equal(result, expected) + + for v in [True, 1]: + result = Series([True, False, True], index=index) & v + expected = Series([True, False, True], index=index) + tm.assert_series_equal(result, expected) + + for v in [False, 0]: + result = Series([True, False, True], index=index) & v + expected = Series([False, False, False], index=index) + tm.assert_series_equal(result, expected) + msg = "Cannot perform.+with a dtyped.+array and scalar of type" + for v in [np.nan]: + with pytest.raises(TypeError, match=msg): + t & v + + def test_logical_ops_df_compat(self): + # GH#1134 + s1 = Series([True, False, True], index=list("ABC"), name="x") + s2 = Series([True, True, False], index=list("ABD"), name="x") + + exp = Series([True, False, False, False], index=list("ABCD"), name="x") + tm.assert_series_equal(s1 & s2, exp) + tm.assert_series_equal(s2 & s1, exp) + + # True | np.nan => True + exp_or1 = Series([True, True, True, False], index=list("ABCD"), name="x") + tm.assert_series_equal(s1 | s2, exp_or1) + # np.nan | True => np.nan, filled with False + exp_or = Series([True, True, False, False], index=list("ABCD"), name="x") + tm.assert_series_equal(s2 | s1, exp_or) + + # DataFrame doesn't fill nan with False + tm.assert_frame_equal(s1.to_frame() & s2.to_frame(), exp.to_frame()) + tm.assert_frame_equal(s2.to_frame() & s1.to_frame(), exp.to_frame()) + + exp = DataFrame({"x": [True, True, np.nan, np.nan]}, index=list("ABCD")) + tm.assert_frame_equal(s1.to_frame() | s2.to_frame(), exp_or1.to_frame()) + tm.assert_frame_equal(s2.to_frame() | s1.to_frame(), exp_or.to_frame()) + + # different length + s3 = Series([True, False, True], index=list("ABC"), name="x") + s4 = Series([True, True, True, True], index=list("ABCD"), name="x") + + exp = Series([True, False, True, False], index=list("ABCD"), name="x") + tm.assert_series_equal(s3 & s4, exp) + tm.assert_series_equal(s4 & s3, exp) + + # np.nan | True => np.nan, filled with False + exp_or1 = Series([True, True, True, False], index=list("ABCD"), name="x") + tm.assert_series_equal(s3 | s4, exp_or1) + # True | np.nan => True + exp_or = Series([True, True, True, True], index=list("ABCD"), name="x") + tm.assert_series_equal(s4 | s3, exp_or) + + tm.assert_frame_equal(s3.to_frame() & s4.to_frame(), exp.to_frame()) + tm.assert_frame_equal(s4.to_frame() & s3.to_frame(), exp.to_frame()) + + tm.assert_frame_equal(s3.to_frame() | s4.to_frame(), exp_or1.to_frame()) + tm.assert_frame_equal(s4.to_frame() | s3.to_frame(), exp_or.to_frame()) + + def test_int_dtype_different_index_not_bool(self): + # GH 52500 + ser1 = Series([1, 2, 3], index=[10, 11, 23], name="a") + ser2 = Series([10, 20, 30], index=[11, 10, 23], name="a") + result = np.bitwise_xor(ser1, ser2) + expected = Series([21, 8, 29], index=[10, 11, 23], name="a") + tm.assert_series_equal(result, expected) + + result = ser1 ^ ser2 + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/series/test_missing.py b/pandas/tests/series/test_missing.py new file mode 100644 index 0000000000000000000000000000000000000000..1c88329a83b0ef5f67468b5b6db8e595ea55d2a4 --- /dev/null +++ b/pandas/tests/series/test_missing.py @@ -0,0 +1,88 @@ +from datetime import timedelta + +import numpy as np +import pytest + +from pandas._libs import iNaT + +import pandas as pd +from pandas import ( + Categorical, + Index, + NaT, + Series, + isna, +) +import pandas._testing as tm + + +class TestSeriesMissingData: + def test_categorical_nan_handling(self): + # NaNs are represented as -1 in labels + s = Series(Categorical(["a", "b", np.nan, "a"])) + tm.assert_index_equal(s.cat.categories, Index(["a", "b"])) + tm.assert_numpy_array_equal( + s.values.codes, np.array([0, 1, -1, 0], dtype=np.int8) + ) + + def test_timedelta64_nan(self): + td = Series([timedelta(days=i) for i in range(10)]) + + # nan ops on timedeltas + td1 = td.copy() + td1[0] = np.nan + assert isna(td1[0]) + assert td1[0]._value == iNaT + td1[0] = td[0] + assert not isna(td1[0]) + + # GH#16674 iNaT is treated as an integer when given by the user + with pytest.raises(TypeError, match="Invalid value"): + td1[1] = iNaT + + td1[2] = NaT + assert isna(td1[2]) + assert td1[2]._value == iNaT + td1[2] = td[2] + assert not isna(td1[2]) + + # boolean setting + # GH#2899 boolean setting + td3 = np.timedelta64(timedelta(days=3)) + td7 = np.timedelta64(timedelta(days=7)) + td[(td > td3) & (td < td7)] = np.nan + assert isna(td).sum() == 3 + + @pytest.mark.xfail( + reason="Chained inequality raises when trying to define 'selector'" + ) + def test_logical_range_select(self, datetime_series): + # NumPy limitation =( + # https://github.com/pandas-dev/pandas/commit/9030dc021f07c76809848925cb34828f6c8484f3 + + selector = -0.5 <= datetime_series <= 0.5 + expected = (datetime_series >= -0.5) & (datetime_series <= 0.5) + tm.assert_series_equal(selector, expected) + + def test_valid(self, datetime_series): + ts = datetime_series.copy() + ts.index = ts.index._with_freq(None) + ts[::2] = np.nan + + result = ts.dropna() + assert len(result) == ts.count() + tm.assert_series_equal(result, ts[1::2]) + tm.assert_series_equal(result, ts[pd.notna(ts)]) + + +def test_hasnans_uncached_for_series(): + # GH#19700 + # set float64 dtype to avoid upcast when setting nan + idx = Index([0, 1], dtype="float64") + assert idx.hasnans is False + assert "hasnans" in idx._cache + ser = idx.to_series() + assert ser.hasnans is False + assert not hasattr(ser, "_cache") + ser.iloc[-1] = np.nan + assert ser.hasnans is True diff --git a/pandas/tests/series/test_npfuncs.py b/pandas/tests/series/test_npfuncs.py new file mode 100644 index 0000000000000000000000000000000000000000..f30c01b49639935a8ace4ed1d91baa84f899d3b6 --- /dev/null +++ b/pandas/tests/series/test_npfuncs.py @@ -0,0 +1,52 @@ +""" +Tests for np.foo applied to Series, not necessarily ufuncs. +""" + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import Series +import pandas._testing as tm + + +class TestPtp: + def test_ptp(self): + # GH#21614 + N = 1000 + arr = np.random.default_rng(2).standard_normal(N) + ser = Series(arr) + assert np.ptp(ser) == np.ptp(arr) + + +def test_numpy_unique(datetime_series): + # it works! + np.unique(datetime_series) + + +@pytest.mark.parametrize("index", [["a", "b", "c", "d", "e"], None]) +def test_numpy_argwhere(index): + # GH#35331 + + s = Series(range(5), index=index, dtype=np.int64) + + result = np.argwhere(s > 2).astype(np.int64) + expected = np.array([[3], [4]], dtype=np.int64) + + tm.assert_numpy_array_equal(result, expected) + + +@td.skip_if_no("pyarrow") +def test_log_arrow_backed_missing_value(using_nan_is_na): + # GH#56285 + ser = Series([1, 2, None], dtype="float64[pyarrow]") + if using_nan_is_na: + result = np.log(ser) + expected = np.log(Series([1, 2, None], dtype="float64[pyarrow]")) + tm.assert_series_equal(result, expected) + else: + # we get cast to object which raises + msg = "loop of ufunc does not support argument" + with pytest.raises(TypeError, match=msg): + np.log(ser) diff --git a/pandas/tests/series/test_reductions.py b/pandas/tests/series/test_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..d0be93324dde56e6d393873d3c62fc3c750dee8e --- /dev/null +++ b/pandas/tests/series/test_reductions.py @@ -0,0 +1,233 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import Series +import pandas._testing as tm + + +@pytest.mark.parametrize("operation, expected", [("min", "a"), ("max", "b")]) +def test_reductions_series_strings(operation, expected): + # GH#31746 + ser = Series(["a", "b"], dtype="string") + res_operation_serie = getattr(ser, operation)() + assert res_operation_serie == expected + + +@pytest.mark.parametrize("as_period", [True, False]) +def test_mode_extension_dtype(as_period): + # GH#41927 preserve dt64tz dtype + ser = Series([pd.Timestamp(1979, 4, n) for n in range(1, 5)]) + + if as_period: + ser = ser.dt.to_period("D") + else: + ser = ser.dt.tz_localize("US/Central") + + res = ser.mode() + assert res.dtype == ser.dtype + tm.assert_series_equal(res, ser) + + +def test_mode_nullable_dtype(any_numeric_ea_dtype): + # GH#55340 + ser = Series([1, 3, 2, pd.NA, 3, 2, pd.NA], dtype=any_numeric_ea_dtype) + result = ser.mode(dropna=False) + expected = Series([2, 3, pd.NA], dtype=any_numeric_ea_dtype) + tm.assert_series_equal(result, expected) + + result = ser.mode(dropna=True) + expected = Series([2, 3], dtype=any_numeric_ea_dtype) + tm.assert_series_equal(result, expected) + + ser[-1] = pd.NA + + result = ser.mode(dropna=True) + expected = Series([2, 3], dtype=any_numeric_ea_dtype) + tm.assert_series_equal(result, expected) + + result = ser.mode(dropna=False) + expected = Series([pd.NA], dtype=any_numeric_ea_dtype) + tm.assert_series_equal(result, expected) + + +def test_mode_nullable_dtype_edge_case(any_numeric_ea_dtype): + # GH##58926 + ser = Series([1, 2, 3, 1], dtype=any_numeric_ea_dtype) + result = ser.mode(dropna=False) + expected = Series([1], dtype=any_numeric_ea_dtype) + tm.assert_series_equal(result, expected) + + ser2 = Series([1, 1, 2, 3, pd.NA], dtype=any_numeric_ea_dtype) + result = ser2.mode(dropna=False) + expected = Series([1], dtype=any_numeric_ea_dtype) + tm.assert_series_equal(result, expected) + + ser3 = Series([1, pd.NA, pd.NA], dtype=any_numeric_ea_dtype) + result = ser3.mode(dropna=False) + expected = Series([pd.NA], dtype=any_numeric_ea_dtype) + tm.assert_series_equal(result, expected) + + ser4 = Series([1, 1, pd.NA, pd.NA], dtype=any_numeric_ea_dtype) + result = ser4.mode(dropna=False) + expected = Series([1, pd.NA], dtype=any_numeric_ea_dtype) + tm.assert_series_equal(result, expected) + + +def test_mode_infer_string(): + # GH#56183 + pytest.importorskip("pyarrow") + ser = Series(["a", "b"], dtype=object) + with pd.option_context("future.infer_string", True): + result = ser.mode() + expected = Series(["a", "b"], dtype=object) + tm.assert_series_equal(result, expected) + + +def test_reductions_td64_with_nat(): + # GH#8617 + ser = Series([0, pd.NaT], dtype="m8[ns]") + exp = ser[0] + assert ser.median() == exp + assert ser.min() == exp + assert ser.max() == exp + + +def test_td64_sum_empty(skipna): + # GH#37151 + ser = Series([], dtype="timedelta64[ns]") + + result = ser.sum(skipna=skipna) + assert isinstance(result, pd.Timedelta) + assert result == pd.Timedelta(0) + + +def test_td64_summation_overflow(): + # GH#9442 + ser = Series(pd.date_range("20130101", periods=100000, freq="h", unit="ns")) + ser[0] += pd.Timedelta("1s 1ms") + + # mean + result = (ser - ser.min()).mean() + expected = pd.Timedelta((pd.TimedeltaIndex(ser - ser.min()).asi8 / len(ser)).sum()) + + # the computation is converted to float so + # might be some loss of precision + assert np.allclose(result._value / 1000, expected._value / 1000) + + # sum + msg = "overflow in timedelta operation" + with pytest.raises(ValueError, match=msg): + (ser - ser.min()).sum() + + s1 = ser[0:10000] + with pytest.raises(ValueError, match=msg): + (s1 - s1.min()).sum() + s2 = ser[0:1000] + (s2 - s2.min()).sum() + + +def test_prod_numpy16_bug(): + ser = Series([1.0, 1.0, 1.0], index=range(3)) + result = ser.prod() + + assert not isinstance(result, Series) + + +@pytest.mark.parametrize("func", [np.any, np.all]) +@pytest.mark.parametrize("kwargs", [{"keepdims": True}, {"out": object()}]) +def test_validate_any_all_out_keepdims_raises(kwargs, func): + ser = Series([1, 2]) + param = next(iter(kwargs)) + name = func.__name__ + + msg = ( + f"the '{param}' parameter is not " + "supported in the pandas " + rf"implementation of {name}\(\)" + ) + with pytest.raises(ValueError, match=msg): + func(ser, **kwargs) + + +def test_validate_sum_initial(): + ser = Series([1, 2]) + msg = ( + r"the 'initial' parameter is not " + r"supported in the pandas " + r"implementation of sum\(\)" + ) + with pytest.raises(ValueError, match=msg): + np.sum(ser, initial=10) + + +def test_validate_median_initial(): + ser = Series([1, 2]) + msg = ( + r"the 'overwrite_input' parameter is not " + r"supported in the pandas " + r"implementation of median\(\)" + ) + with pytest.raises(ValueError, match=msg): + # It seems like np.median doesn't dispatch, so we use the + # method instead of the ufunc. + ser.median(overwrite_input=True) + + +def test_validate_stat_keepdims(): + ser = Series([1, 2]) + msg = ( + r"the 'keepdims' parameter is not " + r"supported in the pandas " + r"implementation of sum\(\)" + ) + with pytest.raises(ValueError, match=msg): + np.sum(ser, keepdims=True) + + +def test_mean_with_convertible_string_raises(): + # GH#44008 + ser = Series(["1", "2"]) + assert ser.sum() == "12" + + msg = "Could not convert string '12' to numeric|does not support|Cannot perform" + with pytest.raises(TypeError, match=msg): + ser.mean() + + df = ser.to_frame() + msg = r"Could not convert \['12'\] to numeric|does not support|Cannot perform" + with pytest.raises(TypeError, match=msg): + df.mean() + + +def test_mean_dont_convert_j_to_complex(): + # GH#36703 + df = pd.DataFrame([{"db": "J", "numeric": 123}]) + msg = r"Could not convert \['J'\] to numeric|does not support|Cannot perform" + with pytest.raises(TypeError, match=msg): + df.mean() + + with pytest.raises(TypeError, match=msg): + df.agg("mean") + + msg = "Could not convert string 'J' to numeric|does not support|Cannot perform" + with pytest.raises(TypeError, match=msg): + df["db"].mean() + msg = "Could not convert string 'J' to numeric|ufunc 'divide'|Cannot perform" + with pytest.raises(TypeError, match=msg): + np.mean(df["db"].astype("string").array) + + +def test_median_with_convertible_string_raises(): + # GH#34671 this _could_ return a string "2", but definitely not float 2.0 + msg = r"Cannot convert \['1' '2' '3'\] to numeric|does not support|Cannot perform" + ser = Series(["1", "2", "3"]) + with pytest.raises(TypeError, match=msg): + ser.median() + + msg = ( + r"Cannot convert \[\['1' '2' '3'\]\] to numeric|does not support|Cannot perform" + ) + df = ser.to_frame() + with pytest.raises(TypeError, match=msg): + df.median() diff --git a/pandas/tests/series/test_subclass.py b/pandas/tests/series/test_subclass.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d5afcf884b12b3007905061b7c503359e71a5d --- /dev/null +++ b/pandas/tests/series/test_subclass.py @@ -0,0 +1,82 @@ +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm + +pytestmark = pytest.mark.filterwarnings( + "ignore:Passing a BlockManager|Passing a SingleBlockManager:DeprecationWarning" +) + + +class TestSeriesSubclassing: + @pytest.mark.parametrize( + "idx_method, indexer, exp_data, exp_idx", + [ + ["loc", ["a", "b"], [1, 2], "ab"], + ["iloc", [2, 3], [3, 4], "cd"], + ], + ) + def test_indexing_sliced(self, idx_method, indexer, exp_data, exp_idx): + s = tm.SubclassedSeries([1, 2, 3, 4], index=list("abcd")) + res = getattr(s, idx_method)[indexer] + exp = tm.SubclassedSeries(exp_data, index=list(exp_idx)) + tm.assert_series_equal(res, exp) + + def test_to_frame(self): + s = tm.SubclassedSeries([1, 2, 3, 4], index=list("abcd"), name="xxx") + res = s.to_frame() + exp = tm.SubclassedDataFrame({"xxx": [1, 2, 3, 4]}, index=list("abcd")) + tm.assert_frame_equal(res, exp) + + def test_subclass_unstack(self): + # GH 15564 + s = tm.SubclassedSeries([1, 2, 3, 4], index=[list("aabb"), list("xyxy")]) + + res = s.unstack() + exp = tm.SubclassedDataFrame({"x": [1, 3], "y": [2, 4]}, index=["a", "b"]) + + tm.assert_frame_equal(res, exp) + + def test_subclass_empty_repr(self): + sub_series = tm.SubclassedSeries() + assert "SubclassedSeries" in repr(sub_series) + + def test_asof(self): + N = 3 + rng = pd.date_range("1/1/1990", periods=N, freq="53s") + s = tm.SubclassedSeries({"A": [np.nan, np.nan, np.nan]}, index=rng) + + result = s.asof(rng[-2:]) + assert isinstance(result, tm.SubclassedSeries) + + def test_explode(self): + s = tm.SubclassedSeries([[1, 2, 3], "foo", [], [3, 4]]) + result = s.explode() + assert isinstance(result, tm.SubclassedSeries) + + def test_equals(self): + # https://github.com/pandas-dev/pandas/pull/34402 + # allow subclass in both directions + s1 = pd.Series([1, 2, 3]) + s2 = tm.SubclassedSeries([1, 2, 3]) + assert s1.equals(s2) + assert s2.equals(s1) + + +class SubclassedSeries(pd.Series): + @property + def _constructor(self): + def _new(*args, **kwargs): + # some constructor logic that accesses the Series' name + if self.name == "test": + return pd.Series(*args, **kwargs) + return SubclassedSeries(*args, **kwargs) + + return _new + + +def test_constructor_from_dict(): + # https://github.com/pandas-dev/pandas/issues/52445 + result = SubclassedSeries({"a": 1, "b": 2, "c": 3}) + assert isinstance(result, SubclassedSeries) diff --git a/pandas/tests/series/test_ufunc.py b/pandas/tests/series/test_ufunc.py new file mode 100644 index 0000000000000000000000000000000000000000..9eaf1632528d87be4d79b69631151842e9050cdc --- /dev/null +++ b/pandas/tests/series/test_ufunc.py @@ -0,0 +1,475 @@ +from collections import deque +import re +import string + +import numpy as np +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.arrays import SparseArray + + +@pytest.fixture(params=[np.add, np.logaddexp]) +def ufunc(request): + # dunder op + return request.param + + +@pytest.fixture( + params=[pytest.param(True, marks=pytest.mark.fails_arm_wheels), False], + ids=["sparse", "dense"], +) +def sparse(request): + return request.param + + +@pytest.fixture +def arrays_for_binary_ufunc(): + """ + A pair of random, length-100 integer-dtype arrays, that are mostly 0. + """ + a1 = np.random.default_rng(2).integers(0, 10, 100, dtype="int64") + a2 = np.random.default_rng(2).integers(0, 10, 100, dtype="int64") + a1[::3] = 0 + a2[::4] = 0 + return a1, a2 + + +@pytest.mark.parametrize("ufunc", [np.positive, np.floor, np.exp]) +def test_unary_ufunc(ufunc, sparse): + # Test that ufunc(pd.Series) == pd.Series(ufunc) + arr = np.random.default_rng(2).integers(0, 10, 10, dtype="int64") + arr[::2] = 0 + if sparse: + arr = SparseArray(arr, dtype=pd.SparseDtype("int64", 0)) + + index = list(string.ascii_letters[:10]) + name = "name" + series = pd.Series(arr, index=index, name=name) + + result = ufunc(series) + expected = pd.Series(ufunc(arr), index=index, name=name) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("flip", [True, False], ids=["flipped", "straight"]) +def test_binary_ufunc_with_array(flip, sparse, ufunc, arrays_for_binary_ufunc): + # Test that ufunc(pd.Series(a), array) == pd.Series(ufunc(a, b)) + a1, a2 = arrays_for_binary_ufunc + if sparse: + a1 = SparseArray(a1, dtype=pd.SparseDtype("int64", 0)) + a2 = SparseArray(a2, dtype=pd.SparseDtype("int64", 0)) + + name = "name" # op(pd.Series, array) preserves the name. + series = pd.Series(a1, name=name) + other = a2 + + array_args = (a1, a2) + series_args = (series, other) # ufunc(series, array) + + if flip: + array_args = reversed(array_args) + series_args = reversed(series_args) # ufunc(array, series) + + expected = pd.Series(ufunc(*array_args), name=name) + result = ufunc(*series_args) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("flip", [True, False], ids=["flipped", "straight"]) +def test_binary_ufunc_with_index(flip, sparse, ufunc, arrays_for_binary_ufunc): + # Test that + # * func(pd.Series(a), pd.Series(b)) == pd.Series(ufunc(a, b)) + # * ufunc(Index, pd.Series) dispatches to pd.Series (returns a pd.Series) + a1, a2 = arrays_for_binary_ufunc + if sparse: + a1 = SparseArray(a1, dtype=pd.SparseDtype("int64", 0)) + a2 = SparseArray(a2, dtype=pd.SparseDtype("int64", 0)) + + name = "name" # op(pd.Series, array) preserves the name. + series = pd.Series(a1, name=name) + + other = pd.Index(a2, name=name).astype("int64") + + array_args = (a1, a2) + series_args = (series, other) # ufunc(series, array) + + if flip: + array_args = reversed(array_args) + series_args = reversed(series_args) # ufunc(array, series) + + expected = pd.Series(ufunc(*array_args), name=name) + result = ufunc(*series_args) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("shuffle", [True, False], ids=["unaligned", "aligned"]) +@pytest.mark.parametrize("flip", [True, False], ids=["flipped", "straight"]) +def test_binary_ufunc_with_series( + flip, shuffle, sparse, ufunc, arrays_for_binary_ufunc +): + # Test that + # * func(pd.Series(a), pd.Series(b)) == pd.Series(ufunc(a, b)) + # with alignment between the indices + a1, a2 = arrays_for_binary_ufunc + if sparse: + a1 = SparseArray(a1, dtype=pd.SparseDtype("int64", 0)) + a2 = SparseArray(a2, dtype=pd.SparseDtype("int64", 0)) + + name = "name" # op(pd.Series, array) preserves the name. + series = pd.Series(a1, name=name) + other = pd.Series(a2, name=name) + + idx = np.random.default_rng(2).permutation(len(a1)) + + if shuffle: + other = other.take(idx) + if flip: + index = other.align(series)[0].index + else: + index = series.align(other)[0].index + else: + index = series.index + + array_args = (a1, a2) + series_args = (series, other) # ufunc(series, array) + + if flip: + array_args = tuple(reversed(array_args)) + series_args = tuple(reversed(series_args)) # ufunc(array, series) + + expected = pd.Series(ufunc(*array_args), index=index, name=name) + result = ufunc(*series_args) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("flip", [True, False]) +def test_binary_ufunc_scalar(ufunc, sparse, flip, arrays_for_binary_ufunc): + # Test that + # * ufunc(pd.Series, scalar) == pd.Series(ufunc(array, scalar)) + # * ufunc(pd.Series, scalar) == ufunc(scalar, pd.Series) + arr, _ = arrays_for_binary_ufunc + if sparse: + arr = SparseArray(arr) + other = 2 + series = pd.Series(arr, name="name") + + series_args = (series, other) + array_args = (arr, other) + + if flip: + series_args = tuple(reversed(series_args)) + array_args = tuple(reversed(array_args)) + + expected = pd.Series(ufunc(*array_args), name="name") + result = ufunc(*series_args) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("ufunc", [np.divmod]) # TODO: np.modf, np.frexp +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.filterwarnings("ignore:divide by zero:RuntimeWarning") +def test_multiple_output_binary_ufuncs(ufunc, sparse, shuffle, arrays_for_binary_ufunc): + # Test that + # the same conditions from binary_ufunc_scalar apply to + # ufuncs with multiple outputs. + + a1, a2 = arrays_for_binary_ufunc + # work around https://github.com/pandas-dev/pandas/issues/26987 + a1[a1 == 0] = 1 + a2[a2 == 0] = 1 + + if sparse: + a1 = SparseArray(a1, dtype=pd.SparseDtype("int64", 0)) + a2 = SparseArray(a2, dtype=pd.SparseDtype("int64", 0)) + + s1 = pd.Series(a1) + s2 = pd.Series(a2) + + if shuffle: + # ensure we align before applying the ufunc + s2 = s2.sample(frac=1) + + expected = ufunc(a1, a2) + assert isinstance(expected, tuple) + + result = ufunc(s1, s2) + assert isinstance(result, tuple) + tm.assert_series_equal(result[0], pd.Series(expected[0])) + tm.assert_series_equal(result[1], pd.Series(expected[1])) + + +def test_multiple_output_ufunc(sparse, arrays_for_binary_ufunc): + # Test that the same conditions from unary input apply to multi-output + # ufuncs + arr, _ = arrays_for_binary_ufunc + + if sparse: + arr = SparseArray(arr) + + series = pd.Series(arr, name="name") + result = np.modf(series) + expected = np.modf(arr) + + assert isinstance(result, tuple) + assert isinstance(expected, tuple) + + tm.assert_series_equal(result[0], pd.Series(expected[0], name="name")) + tm.assert_series_equal(result[1], pd.Series(expected[1], name="name")) + + +def test_binary_ufunc_drops_series_name(ufunc, sparse, arrays_for_binary_ufunc): + # Drop the names when they differ. + a1, a2 = arrays_for_binary_ufunc + s1 = pd.Series(a1, name="a") + s2 = pd.Series(a2, name="b") + + result = ufunc(s1, s2) + assert result.name is None + + +def test_object_series_ok(): + class Dummy: + def __init__(self, value) -> None: + self.value = value + + def __add__(self, other): + return self.value + other.value + + arr = np.array([Dummy(0), Dummy(1)]) + ser = pd.Series(arr) + tm.assert_series_equal(np.add(ser, ser), pd.Series(np.add(ser, arr))) + tm.assert_series_equal(np.add(ser, Dummy(1)), pd.Series(np.add(ser, Dummy(1)))) + + +@pytest.fixture( + params=[ + pd.array([1, 3, 2], dtype=np.int64), + pd.array([1, 3, 2], dtype="Int64"), + pd.array([1, 3, 2], dtype="Float32"), + pd.array([1, 10, 2], dtype="Sparse[int]"), + pd.to_datetime(["2000", "2010", "2001"]), + pd.to_datetime(["2000", "2010", "2001"]).tz_localize("CET"), + pd.to_datetime(["2000", "2010", "2001"]).to_period(freq="D"), + pd.to_timedelta(["1 Day", "3 Days", "2 Days"]), + pd.IntervalIndex([pd.Interval(0, 1), pd.Interval(2, 3), pd.Interval(1, 2)]), + ], + ids=lambda x: str(x.dtype), +) +def values_for_np_reduce(request): + # min/max tests assume that these are monotonic increasing + return request.param + + +class TestNumpyReductions: + # TODO: cases with NAs, axis kwarg for DataFrame + + def test_multiply(self, values_for_np_reduce, box_with_array, request): + box = box_with_array + values = values_for_np_reduce + + with tm.assert_produces_warning(None): + obj = box(values) + + if isinstance(values, pd.core.arrays.SparseArray): + mark = pytest.mark.xfail(reason="SparseArray has no 'prod'") + request.applymarker(mark) + + if values.dtype.kind in "iuf": + result = np.multiply.reduce(obj) + if box is pd.DataFrame: + expected = obj.prod(numeric_only=False) + tm.assert_series_equal(result, expected) + elif box is pd.Index: + # Index has no 'prod' + expected = obj._values.prod() + assert result == expected + else: + expected = obj.prod() + assert result == expected + else: + msg = "|".join( + [ + "does not support operation", + "unsupported operand type", + "ufunc 'multiply' cannot use operands", + ] + ) + with pytest.raises(TypeError, match=msg): + np.multiply.reduce(obj) + + def test_add(self, values_for_np_reduce, box_with_array): + box = box_with_array + values = values_for_np_reduce + + with tm.assert_produces_warning(None): + obj = box(values) + + if values.dtype.kind in "miuf": + result = np.add.reduce(obj) + if box is pd.DataFrame: + expected = obj.sum(numeric_only=False) + tm.assert_series_equal(result, expected) + elif box is pd.Index: + # Index has no 'sum' + expected = obj._values.sum() + assert result == expected + else: + expected = obj.sum() + assert result == expected + else: + msg = "|".join( + [ + "does not support operation", + "unsupported operand type", + "ufunc 'add' cannot use operands", + ] + ) + with pytest.raises(TypeError, match=msg): + np.add.reduce(obj) + + def test_max(self, values_for_np_reduce, box_with_array, using_python_scalars): + box = box_with_array + values = values_for_np_reduce + + same_type = True + if box is pd.Index and values.dtype.kind in "if": + # ATM Index casts to object, so we get python ints/floats + same_type = False + + with tm.assert_produces_warning(None): + obj = box(values) + + result = np.maximum.reduce(obj) + if box is pd.DataFrame: + # TODO: cases with axis kwarg + expected = obj.max(numeric_only=False) + tm.assert_series_equal(result, expected) + else: + expected = values[1] + if using_python_scalars and values.dtype.kind in "if": + expected = expected.item() + assert result == expected + if same_type: + # check we have e.g. Timestamp instead of dt64 + assert type(result) == type(expected) + + def test_min(self, values_for_np_reduce, box_with_array, using_python_scalars): + box = box_with_array + values = values_for_np_reduce + + same_type = True + if box is pd.Index and values.dtype.kind in "if": + # ATM Index casts to object, so we get python ints/floats + same_type = False + + with tm.assert_produces_warning(None): + obj = box(values) + + result = np.minimum.reduce(obj) + if box is pd.DataFrame: + expected = obj.min(numeric_only=False) + tm.assert_series_equal(result, expected) + else: + expected = values[0] + if using_python_scalars and values.dtype.kind in ["i", "f"]: + expected = expected.item() + assert result == expected + if same_type: + # check we have e.g. Timestamp instead of dt64 + assert type(result) == type(expected) + + +@pytest.mark.parametrize("type_", [list, deque, tuple]) +def test_binary_ufunc_other_types(type_): + a = pd.Series([1, 2, 3], name="name") + b = type_([3, 4, 5]) + + result = np.add(a, b) + expected = pd.Series(np.add(a.to_numpy(), b), name="name") + tm.assert_series_equal(result, expected) + + +def test_object_dtype_ok(): + class Thing: + def __init__(self, value) -> None: + self.value = value + + def __add__(self, other): + other = getattr(other, "value", other) + return type(self)(self.value + other) + + def __eq__(self, other) -> bool: + return type(other) is Thing and self.value == other.value + + def __repr__(self) -> str: + return f"Thing({self.value})" + + s = pd.Series([Thing(1), Thing(2)]) + result = np.add(s, Thing(1)) + expected = pd.Series([Thing(2), Thing(3)]) + tm.assert_series_equal(result, expected) + + +def test_outer(): + # https://github.com/pandas-dev/pandas/issues/27186 + ser = pd.Series([1, 2, 3]) + obj = np.array([1, 2, 3]) + + with pytest.raises(NotImplementedError, match="^$"): + np.subtract.outer(ser, obj) + + +def test_np_matmul(): + # GH26650 + df1 = pd.DataFrame(data=[[-1, 1, 10]]) + df2 = pd.DataFrame(data=[-1, 1, 10]) + expected = pd.DataFrame(data=[102]) + + result = np.matmul(df1, df2) + tm.assert_frame_equal(expected, result) + + +@pytest.mark.parametrize("box", [pd.Index, pd.Series]) +def test_np_matmul_1D(box, using_python_scalars): + result = np.matmul(box([1, 2]), box([2, 3])) + assert result == 8 + if using_python_scalars: + assert type(result) == int, type(result) + else: + assert type(result) == np.int64, type(result) + + +def test_array_ufuncs_for_many_arguments(): + # GH39853 + def add3(x, y, z): + return x + y + z + + ufunc = np.frompyfunc(add3, 3, 1) + ser = pd.Series([1, 2]) + + result = ufunc(ser, ser, 1) + expected = pd.Series([3, 5], dtype=object) + tm.assert_series_equal(result, expected) + + df = pd.DataFrame([[1, 2]]) + + msg = ( + "Cannot apply ufunc " + "to mixed DataFrame and Series inputs." + ) + with pytest.raises(NotImplementedError, match=re.escape(msg)): + ufunc(ser, ser, df) + + +def test_np_trunc(): + # This used to test np.fix, which is not a ufunc but is composed of + # several ufunc calls under the hood with `out` and `where` keywords. But numpy + # is deprecating that (or at least discussing deprecating) in favor of np.trunc, + # which _is_ a ufunc without the out keyword usage. + ser = pd.Series([-1.5, -0.5, 0.5, 1.5]) + result = np.trunc(ser) + expected = pd.Series([-1.0, -0.0, 0.0, 1.0]) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/series/test_unary.py b/pandas/tests/series/test_unary.py new file mode 100644 index 0000000000000000000000000000000000000000..8f153788e413c5e9198dc35867bc628823555dbf --- /dev/null +++ b/pandas/tests/series/test_unary.py @@ -0,0 +1,50 @@ +import pytest + +from pandas import Series +import pandas._testing as tm + + +class TestSeriesUnaryOps: + # __neg__, __pos__, __invert__ + + def test_neg(self): + ser = Series(range(5), dtype="float64", name="series") + tm.assert_series_equal(-ser, -1 * ser) + + def test_invert(self): + ser = Series(range(5), dtype="float64", name="series") + tm.assert_series_equal(-(ser < 0), ~(ser < 0)) + + @pytest.mark.parametrize( + "source, neg_target, abs_target", + [ + ([1, 2, 3], [-1, -2, -3], [1, 2, 3]), + ([1, 2, None], [-1, -2, None], [1, 2, None]), + ], + ) + def test_all_numeric_unary_operators( + self, any_numeric_ea_dtype, source, neg_target, abs_target + ): + # GH38794 + dtype = any_numeric_ea_dtype + ser = Series(source, dtype=dtype) + neg_result, pos_result, abs_result = -ser, +ser, abs(ser) + if dtype.startswith("U"): + neg_target = -Series(source, dtype=dtype) + else: + neg_target = Series(neg_target, dtype=dtype) + + abs_target = Series(abs_target, dtype=dtype) + + tm.assert_series_equal(neg_result, neg_target) + tm.assert_series_equal(pos_result, ser) + tm.assert_series_equal(abs_result, abs_target) + + @pytest.mark.parametrize("op", ["__neg__", "__abs__"]) + def test_unary_float_op_mask(self, float_ea_dtype, op): + dtype = float_ea_dtype + ser = Series([1.1, 2.2, 3.3], dtype=dtype) + result = getattr(ser, op)() + target = result.copy(deep=True) + ser[0] = None + tm.assert_series_equal(result, target) diff --git a/pandas/tests/series/test_validate.py b/pandas/tests/series/test_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..3c867f7582b7d3250bf5e009ffbf7545da404712 --- /dev/null +++ b/pandas/tests/series/test_validate.py @@ -0,0 +1,26 @@ +import pytest + + +@pytest.mark.parametrize( + "func", + [ + "reset_index", + "_set_name", + "sort_values", + "sort_index", + "rename", + "dropna", + "drop_duplicates", + ], +) +@pytest.mark.parametrize("inplace", [1, "True", [1, 2, 3], 5.0]) +def test_validate_bool_args(string_series, func, inplace): + """Tests for error handling related to data types of method arguments.""" + msg = 'For argument "inplace" expected type bool' + kwargs = {"inplace": inplace} + + if func == "_set_name": + kwargs["name"] = "hello" + + with pytest.raises(ValueError, match=msg): + getattr(string_series, func)(**kwargs) diff --git a/pandas/tests/strings/__init__.py b/pandas/tests/strings/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4bec6a2378932b9d49302a04b4f0f80a9e3e3b --- /dev/null +++ b/pandas/tests/strings/__init__.py @@ -0,0 +1,23 @@ +import numpy as np + +import pandas as pd + + +def is_object_or_nan_string_dtype(dtype): + """ + Check if string-like dtype is following NaN semantics, i.e. is object + dtype or a NaN-variant of the StringDtype. + """ + return (isinstance(dtype, np.dtype) and dtype == "object") or ( + dtype.na_value is np.nan + ) + + +def _convert_na_value(ser, expected): + if ser.dtype != object: + if ser.dtype.na_value is np.nan: + expected = expected.fillna(np.nan) + else: + # GH#18463 + expected = expected.fillna(pd.NA) + return expected diff --git a/pandas/tests/strings/conftest.py b/pandas/tests/strings/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..d84d0db2c019df9206c33b0edf166f11df8eef57 --- /dev/null +++ b/pandas/tests/strings/conftest.py @@ -0,0 +1,132 @@ +import pytest + +from pandas import Series +from pandas.core.strings.accessor import StringMethods + +_any_string_method = [ + ("cat", (), {"sep": ","}), + ("cat", (Series(list("zyx")),), {"sep": ",", "join": "left"}), + ("center", (10,), {}), + ("contains", ("a",), {}), + ("count", ("a",), {}), + ("decode", ("UTF-8",), {}), + ("encode", ("UTF-8",), {}), + ("endswith", ("a",), {}), + ("endswith", ((),), {}), + ("endswith", (("a",),), {}), + ("endswith", (("a", "b"),), {}), + ("endswith", (("a", "MISSING"),), {}), + ("endswith", ("a",), {"na": True}), + ("endswith", ("a",), {"na": False}), + ("extract", ("([a-z]*)",), {"expand": False}), + ("extract", ("([a-z]*)",), {"expand": True}), + ("extractall", ("([a-z]*)",), {}), + ("find", ("a",), {}), + ("findall", ("a",), {}), + ("get", (0,), {}), + # because "index" (and "rindex") fail intentionally + # if the string is not found, search only for empty string + ("index", ("",), {}), + ("join", (",",), {}), + ("ljust", (10,), {}), + ("match", ("a",), {}), + ("fullmatch", ("a",), {}), + ("normalize", ("NFC",), {}), + ("pad", (10,), {}), + ("partition", (" ",), {"expand": False}), + ("partition", (" ",), {"expand": True}), + ("repeat", (3,), {}), + ("replace", ("a", "z"), {}), + ("rfind", ("a",), {}), + ("rindex", ("",), {}), + ("rjust", (10,), {}), + ("rpartition", (" ",), {"expand": False}), + ("rpartition", (" ",), {"expand": True}), + ("slice", (0, 1), {}), + ("slice_replace", (0, 1, "z"), {}), + ("split", (" ",), {"expand": False}), + ("split", (" ",), {"expand": True}), + ("startswith", ("a",), {}), + ("startswith", (("a",),), {}), + ("startswith", (("a", "b"),), {}), + ("startswith", (("a", "MISSING"),), {}), + ("startswith", ((),), {}), + ("startswith", ("a",), {"na": True}), + ("startswith", ("a",), {"na": False}), + ("removeprefix", ("a",), {}), + ("removesuffix", ("a",), {}), + # translating unicode points of "a" to "d" + ("translate", ({97: 100},), {}), + ("wrap", (2,), {}), + ("zfill", (10,), {}), + # methods without positional arguments: zip with empty tuple and empty dict + *zip( + [ + "capitalize", + "cat", + "get_dummies", + "isalnum", + "isalpha", + "isascii", + "isdecimal", + "isdigit", + "islower", + "isnumeric", + "isspace", + "istitle", + "isupper", + "len", + "lower", + "lstrip", + "partition", + "rpartition", + "rsplit", + "rstrip", + "slice", + "slice_replace", + "split", + "strip", + "swapcase", + "title", + "upper", + "casefold", + ], + [()] * 100, + [{}] * 100, + ), +] +ids, _, _ = zip(*_any_string_method) # use method name as fixture-id +missing_methods = {f for f in dir(StringMethods) if not f.startswith("_")} - set(ids) + +# test that the above list captures all methods of StringMethods +assert not missing_methods + + +@pytest.fixture(params=_any_string_method, ids=ids) +def any_string_method(request): + """ + Fixture for all public methods of `StringMethods` + + This fixture returns a tuple of the method name and sample arguments + necessary to call the method. + + Returns + ------- + method_name : str + The name of the method in `StringMethods` + args : tuple + Sample values for the positional arguments + kwargs : dict + Sample values for the keyword arguments + + Examples + -------- + >>> def test_something(any_string_method): + ... s = Series(["a", "b", np.nan, "d"]) + ... + ... method_name, args, kwargs = any_string_method + ... method = getattr(s.str, method_name) + ... # will not raise + ... method(*args, **kwargs) + """ + return request.param diff --git a/pandas/tests/strings/test_api.py b/pandas/tests/strings/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..cbb663cdca4ad6ceadce4f13176d306f5b388ebe --- /dev/null +++ b/pandas/tests/strings/test_api.py @@ -0,0 +1,216 @@ +import weakref + +import numpy as np +import pytest + +from pandas import ( + CategoricalDtype, + DataFrame, + Index, + MultiIndex, + Series, + _testing as tm, +) +from pandas.core.strings.accessor import StringMethods + +# subset of the full set from pandas/conftest.py +_any_allowed_skipna_inferred_dtype = [ + ("string", ["a", np.nan, "c"]), + ("bytes", [b"a", np.nan, b"c"]), + ("empty", [np.nan, np.nan, np.nan]), + ("empty", []), + ("mixed-integer", ["a", np.nan, 2]), +] +ids, _ = zip( + *_any_allowed_skipna_inferred_dtype, strict=True +) # use inferred type as id + + +@pytest.fixture(params=_any_allowed_skipna_inferred_dtype, ids=ids) +def any_allowed_skipna_inferred_dtype(request): + """ + Fixture for all (inferred) dtypes allowed in StringMethods.__init__ + + The covered (inferred) types are: + * 'string' + * 'empty' + * 'bytes' + * 'mixed' + * 'mixed-integer' + + Returns + ------- + inferred_dtype : str + The string for the inferred dtype from _libs.lib.infer_dtype + values : np.ndarray + An array of object dtype that will be inferred to have + `inferred_dtype` + + Examples + -------- + >>> from pandas._libs import lib + >>> + >>> def test_something(any_allowed_skipna_inferred_dtype): + ... inferred_dtype, values = any_allowed_skipna_inferred_dtype + ... # will pass + ... assert lib.infer_dtype(values, skipna=True) == inferred_dtype + ... + ... # constructor for .str-accessor will also pass + ... Series(values).str + """ + inferred_dtype, values = request.param + values = np.array(values, dtype=object) # object dtype to avoid casting + + # correctness of inference tested in tests/dtypes/test_inference.py + return inferred_dtype, values + + +def test_api(any_string_dtype): + # GH 6106, GH 9322 + assert Series.str is StringMethods + assert isinstance(Series([""], dtype=any_string_dtype).str, StringMethods) + + +def test_no_circular_reference(any_string_dtype): + # GH 47667 + ser = Series([""], dtype=any_string_dtype) + ref = weakref.ref(ser) + ser.str # Used to cache and cause circular reference + del ser + assert ref() is None + + +def test_api_mi_raises(): + # GH 23679 + mi = MultiIndex.from_arrays([["a", "b", "c"]]) + msg = "Can only use .str accessor with Index, not MultiIndex" + with pytest.raises(AttributeError, match=msg): + mi.str + assert not hasattr(mi, "str") + + +@pytest.mark.parametrize("dtype", [object, "category"]) +def test_api_per_dtype(index_or_series, dtype, any_skipna_inferred_dtype): + # one instance of parametrized fixture + box = index_or_series + inferred_dtype, values = any_skipna_inferred_dtype + + t = box(values, dtype=dtype) # explicit dtype to avoid casting + + types_passing_constructor = [ + "string", + "unicode", + "empty", + "bytes", + "mixed", + "mixed-integer", + ] + if inferred_dtype in types_passing_constructor: + # GH 6106 + assert isinstance(t.str, StringMethods) + else: + # GH 9184, GH 23011, GH 23163 + msg = "Can only use .str accessor with string values.*" + with pytest.raises(AttributeError, match=msg): + t.str + assert not hasattr(t, "str") + + +@pytest.mark.parametrize("dtype", [object, "category"]) +def test_api_per_method( + index_or_series, + dtype, + any_allowed_skipna_inferred_dtype, + any_string_method, + request, + using_infer_string, +): + # this test does not check correctness of the different methods, + # just that the methods work on the specified (inferred) dtypes, + # and raise on all others + box = index_or_series + + # one instance of each parametrized fixture + inferred_dtype, values = any_allowed_skipna_inferred_dtype + method_name, args, kwargs = any_string_method + + reason = None + if box is Index and values.size == 0: + if method_name in ["partition", "rpartition"] and kwargs.get("expand", True): + raises = TypeError + reason = "Method cannot deal with empty Index" + elif method_name == "split" and kwargs.get("expand", None): + raises = TypeError + reason = "Split fails on empty Series when expand=True" + elif method_name == "get_dummies": + raises = ValueError + reason = "Need to fortify get_dummies corner cases" + + elif ( + box is Index + and inferred_dtype == "empty" + and dtype == object + and method_name == "get_dummies" + ): + raises = ValueError + reason = "Need to fortify get_dummies corner cases" + + if reason is not None: + mark = pytest.mark.xfail(raises=raises, reason=reason) + request.applymarker(mark) + + t = box(values, dtype=dtype) # explicit dtype to avoid casting + method = getattr(t.str, method_name) + + if using_infer_string and dtype == "category": + string_allowed = method_name not in ["decode"] + else: + string_allowed = True + bytes_allowed = method_name in ["decode", "get", "len", "slice"] + # as of v0.23.4, all methods except 'cat' are very lenient with the + # allowed data types, just returning NaN for entries that error. + # This could be changed with an 'errors'-kwarg to the `str`-accessor, + # see discussion in GH 13877 + mixed_allowed = method_name not in ["cat"] + + allowed_types = ( + ["empty"] + + ["string", "unicode"] * string_allowed + + ["bytes"] * bytes_allowed + + ["mixed", "mixed-integer"] * mixed_allowed + ) + + if inferred_dtype in allowed_types: + # xref GH 23555, GH 23556 + method(*args, **kwargs) # works! + else: + # GH 23011, GH 23163 + msg = ( + f"Cannot use .str.{method_name} with values of " + f"inferred dtype {inferred_dtype!r}." + "|a bytes-like object is required, not 'str'" + ) + with pytest.raises(TypeError, match=msg): + method(*args, **kwargs) + + +def test_api_for_categorical(any_string_method, any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/10661 + s = Series(list("aabb"), dtype=any_string_dtype) + s = s + " " + s + c = s.astype("category") + c = c.astype(CategoricalDtype(c.dtype.categories.astype("object"))) + assert isinstance(c.str, StringMethods) + + method_name, args, kwargs = any_string_method + + result = getattr(c.str, method_name)(*args, **kwargs) + expected = getattr(s.astype("object").str, method_name)(*args, **kwargs) + + if isinstance(result, DataFrame): + tm.assert_frame_equal(result, expected) + elif isinstance(result, Series): + tm.assert_series_equal(result, expected) + else: + # str.cat(others=None) returns string, for example + assert result == expected diff --git a/pandas/tests/strings/test_case_justify.py b/pandas/tests/strings/test_case_justify.py new file mode 100644 index 0000000000000000000000000000000000000000..819556f961fa39fa2e93388fd12d37b0f9aefa4d --- /dev/null +++ b/pandas/tests/strings/test_case_justify.py @@ -0,0 +1,423 @@ +from datetime import datetime +import operator + +import numpy as np +import pytest + +from pandas import ( + Series, + _testing as tm, +) + + +def test_title(any_string_dtype): + s = Series(["FOO", "BAR", np.nan, "Blah", "blurg"], dtype=any_string_dtype) + result = s.str.title() + expected = Series(["Foo", "Bar", np.nan, "Blah", "Blurg"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_title_mixed_object(): + s = Series(["FOO", np.nan, "bar", True, datetime.today(), "blah", None, 1, 2.0]) + result = s.str.title() + expected = Series( + ["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_almost_equal(result, expected) + + +def test_lower_upper(any_string_dtype): + s = Series(["om", np.nan, "nom", "nom"], dtype=any_string_dtype) + + result = s.str.upper() + expected = Series(["OM", np.nan, "NOM", "NOM"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + result = result.str.lower() + tm.assert_series_equal(result, s) + + +def test_lower_upper_mixed_object(): + s = Series(["a", np.nan, "b", True, datetime.today(), "foo", None, 1, 2.0]) + + result = s.str.upper() + expected = Series( + ["A", np.nan, "B", np.nan, np.nan, "FOO", None, np.nan, np.nan], dtype=object + ) + tm.assert_series_equal(result, expected) + + result = s.str.lower() + expected = Series( + ["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "data, expected", + [ + ( + ["FOO", "BAR", np.nan, "Blah", "blurg"], + ["Foo", "Bar", np.nan, "Blah", "Blurg"], + ), + (["a", "b", "c"], ["A", "B", "C"]), + (["a b", "a bc. de"], ["A b", "A bc. de"]), + ], +) +def test_capitalize(data, expected, any_string_dtype): + s = Series(data, dtype=any_string_dtype) + result = s.str.capitalize() + expected = Series(expected, dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_capitalize_mixed_object(): + s = Series(["FOO", np.nan, "bar", True, datetime.today(), "blah", None, 1, 2.0]) + result = s.str.capitalize() + expected = Series( + ["Foo", np.nan, "Bar", np.nan, np.nan, "Blah", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_swapcase(any_string_dtype): + s = Series(["FOO", "BAR", np.nan, "Blah", "blurg"], dtype=any_string_dtype) + result = s.str.swapcase() + expected = Series(["foo", "bar", np.nan, "bLAH", "BLURG"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_swapcase_mixed_object(): + s = Series(["FOO", np.nan, "bar", True, datetime.today(), "Blah", None, 1, 2.0]) + result = s.str.swapcase() + expected = Series( + ["foo", np.nan, "BAR", np.nan, np.nan, "bLAH", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_casefold(): + # GH25405 + expected = Series(["ss", np.nan, "case", "ssd"]) + s = Series(["ß", np.nan, "case", "ßd"]) + result = s.str.casefold() + + tm.assert_series_equal(result, expected) + + +def test_casemethods(any_string_dtype): + values = ["aaa", "bbb", "CCC", "Dddd", "eEEE"] + s = Series(values, dtype=any_string_dtype) + assert s.str.lower().tolist() == [v.lower() for v in values] + assert s.str.upper().tolist() == [v.upper() for v in values] + assert s.str.title().tolist() == [v.title() for v in values] + assert s.str.capitalize().tolist() == [v.capitalize() for v in values] + assert s.str.swapcase().tolist() == [v.swapcase() for v in values] + + +def test_pad(any_string_dtype): + s = Series(["a", "b", np.nan, "c", np.nan, "eeeeee"], dtype=any_string_dtype) + + result = s.str.pad(5, side="left") + expected = Series( + [" a", " b", np.nan, " c", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="right") + expected = Series( + ["a ", "b ", np.nan, "c ", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="both") + expected = Series( + [" a ", " b ", np.nan, " c ", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_pad_mixed_object(): + s = Series(["a", np.nan, "b", True, datetime.today(), "ee", None, 1, 2.0]) + + result = s.str.pad(5, side="left") + expected = Series( + [" a", np.nan, " b", np.nan, np.nan, " ee", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="right") + expected = Series( + ["a ", np.nan, "b ", np.nan, np.nan, "ee ", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="both") + expected = Series( + [" a ", np.nan, " b ", np.nan, np.nan, " ee ", None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_pad_fillchar(any_string_dtype): + s = Series(["a", "b", np.nan, "c", np.nan, "eeeeee"], dtype=any_string_dtype) + + result = s.str.pad(5, side="left", fillchar="X") + expected = Series( + ["XXXXa", "XXXXb", np.nan, "XXXXc", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="right", fillchar="X") + expected = Series( + ["aXXXX", "bXXXX", np.nan, "cXXXX", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.pad(5, side="both", fillchar="X") + expected = Series( + ["XXaXX", "XXbXX", np.nan, "XXcXX", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_pad_fillchar_bad_arg_raises(any_string_dtype): + s = Series(["a", "b", np.nan, "c", np.nan, "eeeeee"], dtype=any_string_dtype) + + msg = "fillchar must be a character, not str" + with pytest.raises(TypeError, match=msg): + s.str.pad(5, fillchar="XY") + + msg = "fillchar must be a character, not int" + with pytest.raises(TypeError, match=msg): + s.str.pad(5, fillchar=5) + + +@pytest.mark.parametrize("method_name", ["center", "ljust", "rjust", "zfill", "pad"]) +def test_pad_width_bad_arg_raises(method_name, any_string_dtype): + # see gh-13598 + s = Series(["1", "22", "a", "bb"], dtype=any_string_dtype) + op = operator.methodcaller(method_name, "f") + + msg = "width must be of integer type, not str" + with pytest.raises(TypeError, match=msg): + op(s.str) + + +def test_center_ljust_rjust(any_string_dtype): + s = Series(["a", "b", np.nan, "c", np.nan, "eeeeee"], dtype=any_string_dtype) + + result = s.str.center(5) + expected = Series( + [" a ", " b ", np.nan, " c ", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.ljust(5) + expected = Series( + ["a ", "b ", np.nan, "c ", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + result = s.str.rjust(5) + expected = Series( + [" a", " b", np.nan, " c", np.nan, "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_center_ljust_rjust_mixed_object(): + s = Series(["a", np.nan, "b", True, datetime.today(), "c", "eee", None, 1, 2.0]) + + result = s.str.center(5) + expected = Series( + [ + " a ", + np.nan, + " b ", + np.nan, + np.nan, + " c ", + " eee ", + None, + np.nan, + np.nan, + ], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + result = s.str.ljust(5) + expected = Series( + [ + "a ", + np.nan, + "b ", + np.nan, + np.nan, + "c ", + "eee ", + None, + np.nan, + np.nan, + ], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + result = s.str.rjust(5) + expected = Series( + [ + " a", + np.nan, + " b", + np.nan, + np.nan, + " c", + " eee", + None, + np.nan, + np.nan, + ], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_center_ljust_rjust_fillchar(any_string_dtype): + # GH#54533, GH#54792 + s = Series(["a", "bb", "cccc", "ddddd", "eeeeee"], dtype=any_string_dtype) + + result = s.str.center(5, fillchar="X") + expected = Series( + ["XXaXX", "XXbbX", "Xcccc", "ddddd", "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + expected = np.array([v.center(5, "X") for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + result = s.str.ljust(5, fillchar="X") + expected = Series( + ["aXXXX", "bbXXX", "ccccX", "ddddd", "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + expected = np.array([v.ljust(5, "X") for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + result = s.str.rjust(5, fillchar="X") + expected = Series( + ["XXXXa", "XXXbb", "Xcccc", "ddddd", "eeeeee"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + expected = np.array([v.rjust(5, "X") for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + +def test_center_ljust_rjust_fillchar_bad_arg_raises(any_string_dtype): + s = Series(["a", "bb", "cccc", "ddddd", "eeeeee"], dtype=any_string_dtype) + + # If fillchar is not a character, normal str raises TypeError + # 'aaa'.ljust(5, 'XY') + # TypeError: must be char, not str + template = "fillchar must be a character, not {dtype}" + + with pytest.raises(TypeError, match=template.format(dtype="str")): + s.str.center(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + s.str.ljust(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="str")): + s.str.rjust(5, fillchar="XY") + + with pytest.raises(TypeError, match=template.format(dtype="int")): + s.str.center(5, fillchar=1) + + with pytest.raises(TypeError, match=template.format(dtype="int")): + s.str.ljust(5, fillchar=1) + + with pytest.raises(TypeError, match=template.format(dtype="int")): + s.str.rjust(5, fillchar=1) + + +def test_zfill(any_string_dtype): + s = Series(["1", "22", "aaa", "333", "45678"], dtype=any_string_dtype) + + result = s.str.zfill(5) + expected = Series( + ["00001", "00022", "00aaa", "00333", "45678"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + expected = np.array([v.zfill(5) for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + result = s.str.zfill(3) + expected = Series(["001", "022", "aaa", "333", "45678"], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + expected = np.array([v.zfill(3) for v in np.array(s)], dtype=np.object_) + tm.assert_numpy_array_equal(np.array(result, dtype=np.object_), expected) + + s = Series(["1", np.nan, "aaa", np.nan, "45678"], dtype=any_string_dtype) + result = s.str.zfill(5) + expected = Series( + ["00001", np.nan, "00aaa", np.nan, "45678"], dtype=any_string_dtype + ) + tm.assert_series_equal(result, expected) + + +def test_wrap(any_string_dtype): + # test values are: two words less than width, two words equal to width, + # two words greater than width, one word less than width, one word + # equal to width, one word greater than width, multiple tokens with + # trailing whitespace equal to width + s = Series( + [ + "hello world", + "hello world!", + "hello world!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdefa", + "ab ab ab ab ", + "ab ab ab ab a", + "\t", + ], + dtype=any_string_dtype, + ) + + # expected values + expected = Series( + [ + "hello world", + "hello world!", + "hello\nworld!!", + "abcdefabcde", + "abcdefabcdef", + "abcdefabcdef\na", + "ab ab ab ab", + "ab ab ab ab\na", + "", + ], + dtype=any_string_dtype, + ) + + result = s.str.wrap(12, break_long_words=True) + tm.assert_series_equal(result, expected) + + +def test_wrap_unicode(any_string_dtype): + # test with pre and post whitespace (non-unicode), NaN, and non-ascii Unicode + s = Series( + [" pre ", np.nan, "\xac\u20ac\U00008000 abadcafe"], dtype=any_string_dtype + ) + expected = Series( + [" pre", np.nan, "\xac\u20ac\U00008000 ab\nadcafe"], dtype=any_string_dtype + ) + result = s.str.wrap(6) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/strings/test_cat.py b/pandas/tests/strings/test_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..40883fd9c756f4cad31758495b22590cd8607f4c --- /dev/null +++ b/pandas/tests/strings/test_cat.py @@ -0,0 +1,444 @@ +from datetime import datetime +import re + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + _testing as tm, + concat, + option_context, +) + + +@pytest.fixture +def index_or_series2(index_or_series): + return index_or_series + + +@pytest.mark.parametrize("other", [None, Series, Index]) +def test_str_cat_name(index_or_series, other): + # GH 21053 + box = index_or_series + values = ["a", "b"] + if other: + other = other(values) + else: + other = values + result = box(values, name="name").str.cat(other, sep=",") + assert result.name == "name" + + +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +def test_str_cat(index_or_series, infer_string): + with option_context("future.infer_string", infer_string): + box = index_or_series + # test_cat above tests "str_cat" from ndarray; + # here testing "str.cat" from Series/Index to ndarray/list + s = box(["a", "a", "b", "b", "c", np.nan]) + + # single array + result = s.str.cat() + expected = "aabbc" + assert result == expected + + result = s.str.cat(na_rep="-") + expected = "aabbc-" + assert result == expected + + result = s.str.cat(sep="_", na_rep="NA") + expected = "a_a_b_b_c_NA" + assert result == expected + + t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object) + expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"]) + + # Series/Index with array + result = s.str.cat(t, na_rep="-") + tm.assert_equal(result, expected) + + # Series/Index with list + result = s.str.cat(list(t), na_rep="-") + tm.assert_equal(result, expected) + + # errors for incorrect lengths + rgx = r"If `others` contains arrays or lists \(or other list-likes.*" + z = Series(["1", "2", "3"]) + + with pytest.raises(ValueError, match=rgx): + s.str.cat(z.values) + + with pytest.raises(ValueError, match=rgx): + s.str.cat(list(z)) + + +def test_str_cat_raises_intuitive_error(index_or_series): + # GH 11334 + box = index_or_series + s = box(["a", "b", "c", "d"]) + message = "Did you mean to supply a `sep` keyword?" + with pytest.raises(ValueError, match=message): + s.str.cat("|") + with pytest.raises(ValueError, match=message): + s.str.cat(" ") + + +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +@pytest.mark.parametrize("sep", ["", None]) +@pytest.mark.parametrize("dtype_target", ["object", "category"]) +@pytest.mark.parametrize("dtype_caller", ["object", "category"]) +def test_str_cat_categorical( + index_or_series, dtype_caller, dtype_target, sep, infer_string +): + box = index_or_series + + with option_context("future.infer_string", infer_string): + s = Index(["a", "a", "b", "a"], dtype=dtype_caller) + s = s if box == Index else Series(s, index=s, dtype=s.dtype) + t = Index(["b", "a", "b", "c"], dtype=dtype_target) + + expected = Index( + ["ab", "aa", "bb", "ac"], dtype=object if dtype_caller == "object" else None + ) + expected = ( + expected + if box == Index + else Series( + expected, index=Index(s, dtype=dtype_caller), dtype=expected.dtype + ) + ) + + # Series/Index with unaligned Index -> t.values + result = s.str.cat(t.values, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series having matching Index + t = Series(t.values, index=Index(s, dtype=dtype_caller)) + result = s.str.cat(t, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series.values + result = s.str.cat(t.values, sep=sep) + tm.assert_equal(result, expected) + + # Series/Index with Series having different Index + t = Series(t.values, index=t.values) + expected = Index( + ["aa", "aa", "bb", "bb", "aa"], + dtype=object if dtype_caller == "object" else None, + ) + dtype = object if dtype_caller == "object" else s.dtype.categories.dtype + expected = ( + expected + if box == Index + else Series( + expected, + index=Index(expected.str[:1], dtype=dtype), + dtype=expected.dtype, + ) + ) + + result = s.str.cat(t, sep=sep) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [[1, 2, 3], [0.1, 0.2, 0.3], [1, 2, "b"]], + ids=["integers", "floats", "mixed"], +) +# without dtype=object, np.array would cast [1, 2, 'b'] to ['1', '2', 'b'] +@pytest.mark.parametrize( + "box", + [Series, Index, list, lambda x: np.array(x, dtype=object)], + ids=["Series", "Index", "list", "np.array"], +) +def test_str_cat_wrong_dtype_raises(box, data): + # GH 22722 + s = Series(["a", "b", "c"]) + t = box(data) + + msg = "Concatenation requires list-likes containing only strings.*" + with pytest.raises(TypeError, match=msg): + # need to use outer and na_rep, as otherwise Index would not raise + s.str.cat(t, join="outer", na_rep="-") + + +def test_str_cat_mixed_inputs(index_or_series): + box = index_or_series + s = Index(["a", "b", "c", "d"]) + s = s if box == Index else Series(s, index=s) + + t = Series(["A", "B", "C", "D"], index=s.values) + d = concat([t, Series(s, index=s)], axis=1) + + expected = Index(["aAa", "bBb", "cCc", "dDd"]) + expected = expected if box == Index else Series(expected.values, index=s.values) + + # Series/Index with DataFrame + result = s.str.cat(d) + tm.assert_equal(result, expected) + + # Series/Index with two-dimensional ndarray + result = s.str.cat(d.values) + tm.assert_equal(result, expected) + + # Series/Index with list of Series + result = s.str.cat([t, s]) + tm.assert_equal(result, expected) + + # Series/Index with mixed list of Series/array + result = s.str.cat([t, s.values]) + tm.assert_equal(result, expected) + + # Series/Index with list of Series; different indexes + t.index = ["b", "c", "d", "a"] + expected = box(["aDa", "bAb", "cBc", "dCd"]) + expected = expected if box == Index else Series(expected.values, index=s.values) + result = s.str.cat([t, s]) + tm.assert_equal(result, expected) + + # Series/Index with mixed list; different index + result = s.str.cat([t, s.values]) + tm.assert_equal(result, expected) + + # Series/Index with DataFrame; different indexes + d.index = ["b", "c", "d", "a"] + expected = box(["aDd", "bAa", "cBb", "dCc"]) + expected = expected if box == Index else Series(expected.values, index=s.values) + result = s.str.cat(d) + tm.assert_equal(result, expected) + + # errors for incorrect lengths + rgx = r"If `others` contains arrays or lists \(or other list-likes.*" + z = Series(["1", "2", "3"]) + e = concat([z, z], axis=1) + + # two-dimensional ndarray + with pytest.raises(ValueError, match=rgx): + s.str.cat(e.values) + + # list of list-likes + with pytest.raises(ValueError, match=rgx): + s.str.cat([z.values, s.values]) + + # mixed list of Series/list-like + with pytest.raises(ValueError, match=rgx): + s.str.cat([z.values, s]) + + # errors for incorrect arguments in list-like + rgx = "others must be Series, Index, DataFrame,.*" + # make sure None/NaN do not crash checks in _get_series_list + u = Series(["a", np.nan, "c", None]) + + # mix of string and Series + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, "u"]) + + # DataFrame in list + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, d]) + + # 2-dim ndarray in list + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, d.values]) + + # nested lists + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, [u, d]]) + + # forbidden input type: set + # GH 23009 + with pytest.raises(TypeError, match=rgx): + s.str.cat(set(u)) + + # forbidden input type: set in list + # GH 23009 + with pytest.raises(TypeError, match=rgx): + s.str.cat([u, set(u)]) + + # other forbidden input type, e.g. int + with pytest.raises(TypeError, match=rgx): + s.str.cat(1) + + # nested list-likes + with pytest.raises(TypeError, match=rgx): + s.str.cat(iter([t.values, list(s)])) + + +def test_str_cat_align_indexed(index_or_series, join_type): + # https://github.com/pandas-dev/pandas/issues/18657 + box = index_or_series + + s = Series(["a", "b", "c", "d"], index=["a", "b", "c", "d"]) + t = Series(["D", "A", "E", "B"], index=["d", "a", "e", "b"]) + sa, ta = s.align(t, join=join_type) + # result after manual alignment of inputs + expected = sa.str.cat(ta, na_rep="-") + + if box == Index: + s = Index(s) + sa = Index(sa) + expected = Index(expected) + + result = s.str.cat(t, join=join_type, na_rep="-") + tm.assert_equal(result, expected) + + +def test_str_cat_align_mixed_inputs(join_type): + s = Series(["a", "b", "c", "d"]) + t = Series(["d", "a", "e", "b"], index=[3, 0, 4, 1]) + d = concat([t, t], axis=1) + + expected_outer = Series(["aaa", "bbb", "c--", "ddd", "-ee"]) + expected = expected_outer.loc[s.index.join(t.index, how=join_type)] + + # list of Series + result = s.str.cat([t, t], join=join_type, na_rep="-") + tm.assert_series_equal(result, expected) + + # DataFrame + result = s.str.cat(d, join=join_type, na_rep="-") + tm.assert_series_equal(result, expected) + + # mixed list of indexed/unindexed + u = np.array(["A", "B", "C", "D"]) + expected_outer = Series(["aaA", "bbB", "c-C", "ddD", "-e-"]) + # joint index of rhs [t, u]; u will be forced have index of s + rhs_idx = ( + t.index.intersection(s.index) + if join_type == "inner" + else t.index.union(s.index) + if join_type == "outer" + else t.index.append(s.index.difference(t.index)) + ) + + expected = expected_outer.loc[s.index.join(rhs_idx, how=join_type)] + result = s.str.cat([t, u], join=join_type, na_rep="-") + tm.assert_series_equal(result, expected) + + with pytest.raises(TypeError, match="others must be Series,.*"): + # nested lists are forbidden + s.str.cat([t, list(u)], join=join_type) + + # errors for incorrect lengths + rgx = r"If `others` contains arrays or lists \(or other list-likes.*" + z = Series(["1", "2", "3"]).values + + # unindexed object of wrong length + with pytest.raises(ValueError, match=rgx): + s.str.cat(z, join=join_type) + + # unindexed object of wrong length in list + with pytest.raises(ValueError, match=rgx): + s.str.cat([t, z], join=join_type) + + +def test_str_cat_datetime_index_unsorted(join_type): + # https://github.com/pandas-dev/pandas/pull/62843 + values = [datetime(2024, 1, 1), datetime(2024, 1, 2)] + s = Series(["a", "b"], index=[values[1], values[0]]) + others = Series(["c", "d"], index=[values[0], values[1]]) + result = s.str.cat(others, join=join_type) + if join_type in {"outer", "right"}: + expected = Series(["bc", "ad"], index=[values[0], values[1]]) + else: + expected = Series(["ad", "bc"], index=[values[1], values[0]]) + tm.assert_series_equal(result, expected) + + +def test_str_cat_all_na(index_or_series, index_or_series2): + # GH 24044 + box = index_or_series + other = index_or_series2 + + # check that all NaNs in caller / target work + s = Index(["a", "b", "c", "d"]) + s = s if box == Index else Series(s, index=s) + t = other([np.nan] * 4, dtype=object) + # add index of s for alignment + t = t if other == Index else Series(t, index=s) + + # all-NA target + if box == Series: + expected = Series([np.nan] * 4, index=s.index, dtype=s.dtype) + else: # box == Index + # TODO: Strimg option, this should return string dtype + expected = Index([np.nan] * 4, dtype=object) + result = s.str.cat(t, join="left") + tm.assert_equal(result, expected) + + # all-NA caller (only for Series) + if other == Series: + expected = Series([np.nan] * 4, dtype=object, index=t.index) + result = t.str.cat(s, join="left") + tm.assert_series_equal(result, expected) + + +def test_str_cat_special_cases(): + s = Series(["a", "b", "c", "d"]) + t = Series(["d", "a", "e", "b"], index=[3, 0, 4, 1]) + + # iterator of elements with different types + expected = Series(["aaa", "bbb", "c-c", "ddd", "-e-"]) + result = s.str.cat(iter([t, s.values]), join="outer", na_rep="-") + tm.assert_series_equal(result, expected) + + # right-align with different indexes in others + expected = Series(["aa-", "d-d"], index=[0, 3]) + result = s.str.cat([t.loc[[0]], t.loc[[3]]], join="right", na_rep="-") + tm.assert_series_equal(result, expected) + + +def test_cat_on_filtered_index(): + df = DataFrame( + index=MultiIndex.from_product( + [[2011, 2012], [1, 2, 3]], names=["year", "month"] + ) + ) + + df = df.reset_index() + df = df[df.month > 1] + + str_year = df.year.astype("str") + str_month = df.month.astype("str") + str_both = str_year.str.cat(str_month, sep=" ") + + assert str_both.loc[1] == "2011 2" + + str_multiple = str_year.str.cat([str_month, str_month], sep=" ") + + assert str_multiple.loc[1] == "2011 2 2" + + +@pytest.mark.parametrize("klass", [tuple, list, np.array, Series, Index]) +def test_cat_different_classes(klass): + # https://github.com/pandas-dev/pandas/issues/33425 + s = Series(["a", "b", "c"]) + result = s.str.cat(klass(["x", "y", "z"])) + expected = Series(["ax", "by", "cz"]) + tm.assert_series_equal(result, expected) + + +def test_cat_on_series_dot_str(): + # GH 28277 + ps = Series(["AbC", "de", "FGHI", "j", "kLLLm"]) + + message = re.escape( + "others must be Series, Index, DataFrame, np.ndarray " + "or list-like (either containing only strings or " + "containing only objects of type Series/Index/" + "np.ndarray[1-dim])" + ) + with pytest.raises(TypeError, match=message): + ps.str.cat(others=ps.str) diff --git a/pandas/tests/strings/test_extract.py b/pandas/tests/strings/test_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..5a77ce618a88d8325d2bf47bbcc55083141c1b0e --- /dev/null +++ b/pandas/tests/strings/test_extract.py @@ -0,0 +1,784 @@ +from datetime import datetime +import re + +import numpy as np +import pytest + +from pandas.core.dtypes.dtypes import ArrowDtype + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + _testing as tm, +) + + +def test_extract_expand_kwarg_wrong_type_raises(any_string_dtype): + # TODO: should this raise TypeError + values = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) + with pytest.raises(ValueError, match="expand must be True or False"): + values.str.extract(".*(BAD[_]+).*(BAD)", expand=None) + + +def test_extract_expand_kwarg(any_string_dtype): + s = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) + expected = DataFrame(["BAD__", np.nan, np.nan], dtype=any_string_dtype) + + result = s.str.extract(".*(BAD[_]+).*") + tm.assert_frame_equal(result, expected) + + result = s.str.extract(".*(BAD[_]+).*", expand=True) + tm.assert_frame_equal(result, expected) + + expected = DataFrame( + [["BAD__", "BAD"], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + result = s.str.extract(".*(BAD[_]+).*(BAD)", expand=False) + tm.assert_frame_equal(result, expected) + + +def test_extract_expand_False_mixed_object(): + ser = Series( + ["aBAD_BAD", np.nan, "BAD_b_BAD", True, datetime.today(), "foo", None, 1, 2.0] + ) + + # two groups + result = ser.str.extract(".*(BAD[_]+).*(BAD)", expand=False) + er = [np.nan, np.nan] # empty row + expected = DataFrame( + [["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er], dtype=object + ) + tm.assert_frame_equal(result, expected) + + # single group + result = ser.str.extract(".*(BAD[_]+).*BAD", expand=False) + expected = Series( + ["BAD_", np.nan, "BAD_", np.nan, np.nan, np.nan, None, np.nan, np.nan], + dtype=object, + ) + tm.assert_series_equal(result, expected) + + +def test_extract_expand_index_raises(): + # GH9980 + # Index only works with one regex group since + # multi-group would expand to a frame + idx = Index(["A1", "A2", "A3", "A4", "B5"]) + msg = "only one regex group is supported with Index" + with pytest.raises(ValueError, match=msg): + idx.str.extract("([AB])([123])", expand=False) + + +def test_extract_expand_no_capture_groups_raises(index_or_series, any_string_dtype): + s_or_idx = index_or_series(["A1", "B2", "C3"], dtype=any_string_dtype) + msg = "pattern contains no capture groups" + + # no groups + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("[ABC][123]", expand=False) + + # only non-capturing groups + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("(?:[AB]).*", expand=False) + + +def test_extract_expand_single_capture_group(index_or_series, any_string_dtype): + # single group renames series/index properly + s_or_idx = index_or_series(["A1", "A2"], dtype=any_string_dtype) + result = s_or_idx.str.extract(r"(?PA)\d", expand=False) + + expected = index_or_series(["A", "A"], name="uno", dtype=any_string_dtype) + if index_or_series == Series: + tm.assert_series_equal(result, expected) + else: + tm.assert_index_equal(result, expected) + + +def test_extract_expand_capture_groups(any_string_dtype): + s = Series(["A1", "B2", "C3"], dtype=any_string_dtype) + # one group, no matches + result = s.str.extract("(_)", expand=False) + expected = Series([np.nan, np.nan, np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + # two groups, no matches + result = s.str.extract("(_)(_)", expand=False) + expected = DataFrame( + [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one group, some matches + result = s.str.extract("([AB])[123]", expand=False) + expected = Series(["A", "B", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + # two groups, some matches + result = s.str.extract("([AB])([123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one named group + result = s.str.extract("(?P[AB])", expand=False) + expected = Series(["A", "B", np.nan], name="letter", dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + # two named groups + result = s.str.extract("(?P[AB])(?P[123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # mix named and unnamed groups + result = s.str.extract("([AB])(?P[123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], + columns=[0, "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # one normal group, one non-capturing group + result = s.str.extract("([AB])(?:[123])", expand=False) + expected = Series(["A", "B", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + # two normal groups, one non-capturing group + s = Series(["A11", "B22", "C33"], dtype=any_string_dtype) + result = s.str.extract("([AB])([123])(?:[123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one optional group followed by one normal group + s = Series(["A1", "B2", "3"], dtype=any_string_dtype) + result = s.str.extract("(?P[AB])?(?P[123])", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, "3"]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # one normal group followed by one optional group + s = Series(["A1", "B2", "C"], dtype=any_string_dtype) + result = s.str.extract("(?P[ABC])(?P[123])?", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], ["C", np.nan]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_expand_capture_groups_index(index, any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/6348 + # not passing index to the extractor + data = ["A1", "B2", "C"] + + if len(index) == 0: + pytest.skip("Test requires len(index) > 0") + while len(index) < len(data): + index = index.repeat(2) + + index = index[: len(data)] + ser = Series(data, index=index, dtype=any_string_dtype) + + result = ser.str.extract(r"(\d)", expand=False) + expected = Series(["1", "2", np.nan], index=index, dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.extract(r"(?P\D)(?P\d)?", expand=False) + expected = DataFrame( + [["A", "1"], ["B", "2"], ["C", np.nan]], + columns=["letter", "number"], + index=index, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_single_series_name_is_preserved(any_string_dtype): + s = Series(["a3", "b3", "c2"], name="bob", dtype=any_string_dtype) + result = s.str.extract(r"(?P[a-z])", expand=False) + expected = Series(["a", "b", "c"], name="sue", dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_extract_expand_True(any_string_dtype): + # Contains tests like those in test_match and some others. + s = Series(["fooBAD__barBAD", np.nan, "foo"], dtype=any_string_dtype) + + result = s.str.extract(".*(BAD[_]+).*(BAD)", expand=True) + expected = DataFrame( + [["BAD__", "BAD"], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_expand_True_mixed_object(): + er = [np.nan, np.nan] # empty row + mixed = Series( + [ + "aBAD_BAD", + np.nan, + "BAD_b_BAD", + True, + datetime.today(), + "foo", + None, + 1, + 2.0, + ] + ) + + result = mixed.str.extract(".*(BAD[_]+).*(BAD)", expand=True) + expected = DataFrame( + [["BAD_", "BAD"], er, ["BAD_", "BAD"], er, er, er, er, er, er], dtype=object + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_expand_True_single_capture_group_raises( + index_or_series, any_string_dtype +): + # these should work for both Series and Index + # no groups + s_or_idx = index_or_series(["A1", "B2", "C3"], dtype=any_string_dtype) + msg = "pattern contains no capture groups" + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("[ABC][123]", expand=True) + + # only non-capturing groups + with pytest.raises(ValueError, match=msg): + s_or_idx.str.extract("(?:[AB]).*", expand=True) + + +def test_extract_expand_True_single_capture_group(index_or_series, any_string_dtype): + # single group renames series/index properly + s_or_idx = index_or_series(["A1", "A2"], dtype=any_string_dtype) + result = s_or_idx.str.extract(r"(?PA)\d", expand=True) + expected = DataFrame({"uno": ["A", "A"]}, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("name", [None, "series_name"]) +def test_extract_series(name, any_string_dtype): + # extract should give the same result whether or not the series has a name. + s = Series(["A1", "B2", "C3"], name=name, dtype=any_string_dtype) + + # one group, no matches + result = s.str.extract("(_)", expand=True) + expected = DataFrame([np.nan, np.nan, np.nan], dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + # two groups, no matches + result = s.str.extract("(_)(_)", expand=True) + expected = DataFrame( + [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one group, some matches + result = s.str.extract("([AB])[123]", expand=True) + expected = DataFrame(["A", "B", np.nan], dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + # two groups, some matches + result = s.str.extract("([AB])([123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one named group + result = s.str.extract("(?P[AB])", expand=True) + expected = DataFrame({"letter": ["A", "B", np.nan]}, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + # two named groups + result = s.str.extract("(?P[AB])(?P[123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # mix named and unnamed groups + result = s.str.extract("([AB])(?P[123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], + columns=[0, "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # one normal group, one non-capturing group + result = s.str.extract("([AB])(?:[123])", expand=True) + expected = DataFrame(["A", "B", np.nan], dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + +def test_extract_optional_groups(any_string_dtype): + # two normal groups, one non-capturing group + s = Series(["A11", "B22", "C33"], dtype=any_string_dtype) + result = s.str.extract("([AB])([123])(?:[123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, np.nan]], dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one optional group followed by one normal group + s = Series(["A1", "B2", "3"], dtype=any_string_dtype) + result = s.str.extract("(?P[AB])?(?P[123])", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], [np.nan, "3"]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # one normal group followed by one optional group + s = Series(["A1", "B2", "C"], dtype=any_string_dtype) + result = s.str.extract("(?P[ABC])(?P[123])?", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], ["C", np.nan]], + columns=["letter", "number"], + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_dataframe_capture_groups_index(index, any_string_dtype): + # GH6348 + # not passing index to the extractor + + data = ["A1", "B2", "C"] + + if len(index) < len(data): + pytest.skip(f"Index needs more than {len(data)} values") + + index = index[: len(data)] + s = Series(data, index=index, dtype=any_string_dtype) + + result = s.str.extract(r"(\d)", expand=True) + expected = DataFrame(["1", "2", np.nan], index=index, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + result = s.str.extract(r"(?P\D)(?P\d)?", expand=True) + expected = DataFrame( + [["A", "1"], ["B", "2"], ["C", np.nan]], + columns=["letter", "number"], + index=index, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extract_single_group_returns_frame(any_string_dtype): + # GH11386 extract should always return DataFrame, even when + # there is only one group. Prior to v0.18.0, extract returned + # Series when there was only one group in the regex. + s = Series(["a3", "b3", "c2"], name="series_name", dtype=any_string_dtype) + result = s.str.extract(r"(?P[a-z])", expand=True) + expected = DataFrame({"letter": ["a", "b", "c"]}, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected) + + +def test_extractall(any_string_dtype): + data = [ + "dave@google.com", + "tdhock5@gmail.com", + "maudelaperriere@gmail.com", + "rob@gmail.com some text steve@gmail.com", + "a@b.com some text c@d.com and e@f.com", + np.nan, + "", + ] + expected_tuples = [ + ("dave", "google", "com"), + ("tdhock5", "gmail", "com"), + ("maudelaperriere", "gmail", "com"), + ("rob", "gmail", "com"), + ("steve", "gmail", "com"), + ("a", "b", "com"), + ("c", "d", "com"), + ("e", "f", "com"), + ] + pat = r""" + (?P[a-z0-9]+) + @ + (?P[a-z]+) + \. + (?P[a-z]{2,4}) + """ + expected_columns = ["user", "domain", "tld"] + s = Series(data, dtype=any_string_dtype) + # extractall should return a DataFrame with one row for each match, indexed by the + # subject from which the match came. + expected_index = MultiIndex.from_tuples( + [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (4, 0), (4, 1), (4, 2)], + names=(None, "match"), + ) + expected = DataFrame( + expected_tuples, expected_index, expected_columns, dtype=any_string_dtype + ) + result = s.str.extractall(pat, flags=re.VERBOSE) + tm.assert_frame_equal(result, expected) + + # The index of the input Series should be used to construct the index of the output + # DataFrame: + mi = MultiIndex.from_tuples( + [ + ("single", "Dave"), + ("single", "Toby"), + ("single", "Maude"), + ("multiple", "robAndSteve"), + ("multiple", "abcdef"), + ("none", "missing"), + ("none", "empty"), + ] + ) + s = Series(data, index=mi, dtype=any_string_dtype) + expected_index = MultiIndex.from_tuples( + [ + ("single", "Dave", 0), + ("single", "Toby", 0), + ("single", "Maude", 0), + ("multiple", "robAndSteve", 0), + ("multiple", "robAndSteve", 1), + ("multiple", "abcdef", 0), + ("multiple", "abcdef", 1), + ("multiple", "abcdef", 2), + ], + names=(None, None, "match"), + ) + expected = DataFrame( + expected_tuples, expected_index, expected_columns, dtype=any_string_dtype + ) + result = s.str.extractall(pat, flags=re.VERBOSE) + tm.assert_frame_equal(result, expected) + + # MultiIndexed subject with names. + s = Series(data, index=mi, dtype=any_string_dtype) + s.index.names = ("matches", "description") + expected_index.names = ("matches", "description", "match") + expected = DataFrame( + expected_tuples, expected_index, expected_columns, dtype=any_string_dtype + ) + result = s.str.extractall(pat, flags=re.VERBOSE) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "pat,expected_names", + [ + # optional groups. + ("(?P[AB])?(?P[123])", ["letter", "number"]), + # only one of two groups has a name. + ("([AB])?(?P[123])", [0, "number"]), + ], +) +def test_extractall_column_names(pat, expected_names, any_string_dtype): + s = Series(["", "A1", "32"], dtype=any_string_dtype) + + result = s.str.extractall(pat) + expected = DataFrame( + [("A", "1"), (np.nan, "3"), (np.nan, "2")], + index=MultiIndex.from_tuples([(1, 0), (2, 0), (2, 1)], names=(None, "match")), + columns=expected_names, + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extractall_single_group(any_string_dtype): + s = Series(["a3", "b3", "d4c2"], name="series_name", dtype=any_string_dtype) + expected_index = MultiIndex.from_tuples( + [(0, 0), (1, 0), (2, 0), (2, 1)], names=(None, "match") + ) + + # extractall(one named group) returns DataFrame with one named column. + result = s.str.extractall(r"(?P[a-z])") + expected = DataFrame( + {"letter": ["a", "b", "d", "c"]}, index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # extractall(one un-named group) returns DataFrame with one un-named column. + result = s.str.extractall(r"([a-z])") + expected = DataFrame( + ["a", "b", "d", "c"], index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + +def test_extractall_single_group_with_quantifier(any_string_dtype): + # GH#13382 + # extractall(one un-named group with quantifier) returns DataFrame with one un-named + # column. + s = Series(["ab3", "abc3", "d4cd2"], name="series_name", dtype=any_string_dtype) + result = s.str.extractall(r"([a-z]+)") + expected = DataFrame( + ["ab", "abc", "d", "cd"], + index=MultiIndex.from_tuples( + [(0, 0), (1, 0), (2, 0), (2, 1)], names=(None, "match") + ), + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "data, names", + [ + ([], (None,)), + ([], ("i1",)), + ([], (None, "i2")), + ([], ("i1", "i2")), + (["a3", "b3", "d4c2"], (None,)), + (["a3", "b3", "d4c2"], ("i1", "i2")), + (["a3", "b3", "d4c2"], (None, "i2")), + ], +) +def test_extractall_no_matches(data, names, any_string_dtype): + # GH19075 extractall with no matches should return a valid MultiIndex + n = len(data) + if len(names) == 1: + index = Index(range(n), name=names[0]) + else: + tuples = (tuple([i] * (n - 1)) for i in range(n)) + index = MultiIndex.from_tuples(tuples, names=names) + s = Series(data, name="series_name", index=index, dtype=any_string_dtype) + expected_index = MultiIndex.from_tuples([], names=((*names, "match"))) + + # one un-named group. + result = s.str.extractall("(z)") + expected = DataFrame(columns=range(1), index=expected_index, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected, check_column_type=True) + + # two un-named groups. + result = s.str.extractall("(z)(z)") + expected = DataFrame(columns=range(2), index=expected_index, dtype=any_string_dtype) + tm.assert_frame_equal(result, expected, check_column_type=True) + + # one named group. + result = s.str.extractall("(?Pz)") + expected = DataFrame( + columns=["first"], index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # two named groups. + result = s.str.extractall("(?Pz)(?Pz)") + expected = DataFrame( + columns=["first", "second"], index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + # one named, one un-named. + result = s.str.extractall("(z)(?Pz)") + expected = DataFrame( + columns=[0, "second"], index=expected_index, dtype=any_string_dtype + ) + tm.assert_frame_equal(result, expected) + + +def test_extractall_stringindex(any_string_dtype): + s = Series(["a1a2", "b1", "c1"], name="xxx", dtype=any_string_dtype) + result = s.str.extractall(r"[ab](?P\d)") + expected = DataFrame( + {"digit": ["1", "2", "1"]}, + index=MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0)], names=[None, "match"]), + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + # index should return the same result as the default index without name thus + # index.name doesn't affect to the result + if any_string_dtype == "object": + for idx in [ + Index(["a1a2", "b1", "c1"], dtype=object), + Index(["a1a2", "b1", "c1"], name="xxx", dtype=object), + ]: + result = idx.str.extractall(r"[ab](?P\d)") + tm.assert_frame_equal(result, expected) + + s = Series( + ["a1a2", "b1", "c1"], + name="s_name", + index=Index(["XX", "yy", "zz"], name="idx_name"), + dtype=any_string_dtype, + ) + result = s.str.extractall(r"[ab](?P\d)") + expected = DataFrame( + {"digit": ["1", "2", "1"]}, + index=MultiIndex.from_tuples( + [("XX", 0), ("XX", 1), ("yy", 0)], names=["idx_name", "match"] + ), + dtype=any_string_dtype, + ) + tm.assert_frame_equal(result, expected) + + +def test_extractall_no_capture_groups_raises(any_string_dtype): + # Does not make sense to use extractall with a regex that has no capture groups. + # (it returns DataFrame with one column for each capture group) + s = Series(["a3", "b3", "d4c2"], name="series_name", dtype=any_string_dtype) + with pytest.raises(ValueError, match="no capture groups"): + s.str.extractall(r"[a-z]") + + +def test_extract_index_one_two_groups(): + s = Series(["a3", "b3", "d4c2"], index=["A3", "B3", "D4"], name="series_name") + r = s.index.str.extract(r"([A-Z])", expand=True) + e = DataFrame(["A", "B", "D"]) + tm.assert_frame_equal(r, e) + + # Prior to v0.18.0, index.str.extract(regex with one group) + # returned Index. With more than one group, extract raised an + # error (GH9980). Now extract always returns DataFrame. + r = s.index.str.extract(r"(?P[A-Z])(?P[0-9])", expand=True) + e_list = [("A", "3"), ("B", "3"), ("D", "4")] + e = DataFrame(e_list, columns=["letter", "digit"]) + tm.assert_frame_equal(r, e) + + +def test_extractall_same_as_extract(any_string_dtype): + s = Series(["a3", "b3", "c2"], name="series_name", dtype=any_string_dtype) + + pattern_two_noname = r"([a-z])([0-9])" + extract_two_noname = s.str.extract(pattern_two_noname, expand=True) + has_multi_index = s.str.extractall(pattern_two_noname) + no_multi_index = has_multi_index.xs(0, level="match") + tm.assert_frame_equal(extract_two_noname, no_multi_index) + + pattern_two_named = r"(?P[a-z])(?P[0-9])" + extract_two_named = s.str.extract(pattern_two_named, expand=True) + has_multi_index = s.str.extractall(pattern_two_named) + no_multi_index = has_multi_index.xs(0, level="match") + tm.assert_frame_equal(extract_two_named, no_multi_index) + + pattern_one_named = r"(?P[a-z])" + extract_one_named = s.str.extract(pattern_one_named, expand=True) + has_multi_index = s.str.extractall(pattern_one_named) + no_multi_index = has_multi_index.xs(0, level="match") + tm.assert_frame_equal(extract_one_named, no_multi_index) + + pattern_one_noname = r"([a-z])" + extract_one_noname = s.str.extract(pattern_one_noname, expand=True) + has_multi_index = s.str.extractall(pattern_one_noname) + no_multi_index = has_multi_index.xs(0, level="match") + tm.assert_frame_equal(extract_one_noname, no_multi_index) + + +def test_extractall_same_as_extract_subject_index(any_string_dtype): + # same as above tests, but s has a MultiIndex. + mi = MultiIndex.from_tuples( + [("A", "first"), ("B", "second"), ("C", "third")], + names=("capital", "ordinal"), + ) + s = Series(["a3", "b3", "c2"], index=mi, name="series_name", dtype=any_string_dtype) + + pattern_two_noname = r"([a-z])([0-9])" + extract_two_noname = s.str.extract(pattern_two_noname, expand=True) + has_match_index = s.str.extractall(pattern_two_noname) + no_match_index = has_match_index.xs(0, level="match") + tm.assert_frame_equal(extract_two_noname, no_match_index) + + pattern_two_named = r"(?P[a-z])(?P[0-9])" + extract_two_named = s.str.extract(pattern_two_named, expand=True) + has_match_index = s.str.extractall(pattern_two_named) + no_match_index = has_match_index.xs(0, level="match") + tm.assert_frame_equal(extract_two_named, no_match_index) + + pattern_one_named = r"(?P[a-z])" + extract_one_named = s.str.extract(pattern_one_named, expand=True) + has_match_index = s.str.extractall(pattern_one_named) + no_match_index = has_match_index.xs(0, level="match") + tm.assert_frame_equal(extract_one_named, no_match_index) + + pattern_one_noname = r"([a-z])" + extract_one_noname = s.str.extract(pattern_one_noname, expand=True) + has_match_index = s.str.extractall(pattern_one_noname) + no_match_index = has_match_index.xs(0, level="match") + tm.assert_frame_equal(extract_one_noname, no_match_index) + + +def test_extractall_preserves_dtype(): + # Ensure that when extractall is called on a series with specific dtypes set, that + # the dtype is preserved in the resulting DataFrame's column. + pa = pytest.importorskip("pyarrow") + + result = Series(["abc", "ab"], dtype=ArrowDtype(pa.string())).str.extractall("(ab)") + assert result.dtypes[0] == "string[pyarrow]" + + +@pytest.mark.parametrize( + "pat, expected_data", + [ + (r"(a(?=b))", [None, "a", None, None]), + (r"((?<=a)b)", [None, "b", None, None]), + (r"(a(?!b))", ["a", None, "a", None]), + (r"((? \g \g", + ["Three Two One", "Baz Bar Foo"], + ), + ( + r"\3 \2 \1", + ["Three Two One", "Baz Bar Foo"], + ), + ( + r"\g<3> \g<2> \g<1>", + ["Three Two One", "Baz Bar Foo"], + ), + ( + r"\g<2>0", + ["Two0", "Bar0"], + ), + ( + r"\g<2>0 \1", + ["Two0 One", "Bar0 Foo"], + ), + ], + ids=[ + "named_groups_full_swap", + "numbered_groups_no_g_full_swap", + "numbered_groups_full_swap", + "single_group_with_literal", + "mixed_group_reference_with_literal", + ], +) +@pytest.mark.parametrize("use_compile", [True, False]) +def test_replace_named_groups_regex_swap( + any_string_dtype, use_compile, repl, expected_list +): + # GH#57636 + ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype) + pattern = r"(?P\w+) (?P\w+) (?P\w+)" + if use_compile: + pattern = re.compile(pattern) + result = ser.str.replace(pattern, repl, regex=True) + expected = Series(expected_list, dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "repl", + [ + r"\g<20>", + r"\20", + r"\40", + r"\4", + ], +) +@pytest.mark.parametrize("use_compile", [True, False]) +def test_replace_named_groups_regex_swap_expected_fail( + any_string_dtype, repl, use_compile, request +): + # GH#57636 + if ( + not use_compile + and r"\g" not in repl + and isinstance(any_string_dtype, StringDtype) + and any_string_dtype.storage == "pyarrow" + ): + # calls pyarrow method directly + if repl == r"\20": + mark = pytest.mark.xfail(reason="PyArrow interprets as group + literal") + request.applymarker(mark) + + pa = pytest.importorskip("pyarrow") + error_type = pa.ArrowInvalid + error_msg = r"only has \d parenthesized subexpressions" + else: + error_type = re.error + error_msg = "invalid group reference" + + pattern = r"(?P\w+) (?P\w+) (?P\w+)" + if use_compile: + pattern = re.compile(pattern) + ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype) + + with pytest.raises(error_type, match=error_msg): + ser.str.replace(pattern, repl, regex=True) + + +@pytest.mark.parametrize( + "pattern, repl", + [ + (r"(\w+) (\w+) (\w+)", r"\20"), + (r"(?P\w+) (?P\w+) (?P\w+)", r"\20"), + ], +) +def test_pyarrow_ambiguous_group_references(pyarrow_string_dtype, pattern, repl): + # GH#62653 + ser = Series(["One Two Three", "Foo Bar Baz"], dtype=pyarrow_string_dtype) + + result = ser.str.replace(pattern, repl, regex=True) + expected = Series(["Two0", "Bar0"], dtype=pyarrow_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pattern, repl, expected_list", + [ + ( + r"\[(?P\d+)\]", + r"(\1)", + ["var.one(0)", "var.two(1)", "var.three(2)"], + ), + ( + r"\[(\d+)\]", + r"(\1)", + ["var.one(0)", "var.two(1)", "var.three(2)"], + ), + ], +) +@td.skip_if_no("pyarrow") +def test_pyarrow_backend_group_replacement(pattern, repl, expected_list): + ser = Series(["var.one[0]", "var.two[1]", "var.three[2]"]).convert_dtypes( + dtype_backend="pyarrow" + ) + result = ser.str.replace(pattern, repl, regex=True) + expected = Series(expected_list).convert_dtypes(dtype_backend="pyarrow") + tm.assert_series_equal(result, expected) + + +def test_replace_callable_named_groups(any_string_dtype): + # test regex named groups + ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype) + pat = r"(?P\w+) (?P\w+) (?P\w+)" + repl = lambda m: m.group("middle").swapcase() + result = ser.str.replace(pat, repl, regex=True) + expected = Series(["bAR", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_replace_compiled_regex(any_string_dtype): + # GH 15446 + ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) + + # test with compiled regex + pat = re.compile(r"BAD_*") + result = ser.str.replace(pat, "", regex=True) + expected = Series(["foobar", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.replace(pat, "", n=1, regex=True) + expected = Series(["foobarBAD", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_replace_compiled_regex_mixed_object(): + pat = re.compile(r"BAD_*") + ser = Series( + ["aBAD", np.nan, "bBAD", True, datetime.today(), "fooBAD", None, 1, 2.0] + ) + result = Series(ser).str.replace(pat, "", regex=True) + expected = Series( + ["a", np.nan, "b", np.nan, np.nan, "foo", None, np.nan, np.nan], dtype=object + ) + tm.assert_series_equal(result, expected) + + +def test_replace_compiled_regex_unicode(any_string_dtype): + ser = Series([b"abcd,\xc3\xa0".decode("utf-8")], dtype=any_string_dtype) + expected = Series([b"abcd, \xc3\xa0".decode("utf-8")], dtype=any_string_dtype) + pat = re.compile(r"(?<=\w),(?=\w)", flags=re.UNICODE) + result = ser.str.replace(pat, ", ", regex=True) + tm.assert_series_equal(result, expected) + + +def test_replace_compiled_regex_raises(any_string_dtype): + # case and flags provided to str.replace will have no effect + # and will produce warnings + ser = Series(["fooBAD__barBAD__bad", np.nan], dtype=any_string_dtype) + pat = re.compile(r"BAD_*") + + msg = "case and flags cannot be set when pat is a compiled regex" + + with pytest.raises(ValueError, match=msg): + ser.str.replace(pat, "", flags=re.IGNORECASE, regex=True) + + with pytest.raises(ValueError, match=msg): + ser.str.replace(pat, "", case=False, regex=True) + + with pytest.raises(ValueError, match=msg): + ser.str.replace(pat, "", case=True, regex=True) + + +def test_replace_compiled_regex_callable(any_string_dtype): + # test with callable + ser = Series(["fooBAD__barBAD", np.nan], dtype=any_string_dtype) + repl = lambda m: m.group(0).swapcase() + pat = re.compile("[a-z][A-Z]{2}") + result = ser.str.replace(pat, repl, n=2, regex=True) + expected = Series(["foObaD__baRbaD", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("regex,expected_val", [(True, "bao"), (False, "foo")]) +def test_replace_literal(regex, expected_val, any_string_dtype): + # GH16808 literal replace (regex=False vs regex=True) + ser = Series(["f.o", "foo", np.nan], dtype=any_string_dtype) + expected = Series(["bao", expected_val, np.nan], dtype=any_string_dtype) + result = ser.str.replace("f.", "ba", regex=regex) + tm.assert_series_equal(result, expected) + + +def test_replace_literal_callable_raises(any_string_dtype): + ser = Series([], dtype=any_string_dtype) + repl = lambda m: m.group(0).swapcase() + + msg = "Cannot use a callable replacement when regex=False" + with pytest.raises(ValueError, match=msg): + ser.str.replace("abc", repl, regex=False) + + +def test_replace_literal_compiled_raises(any_string_dtype): + ser = Series([], dtype=any_string_dtype) + pat = re.compile("[a-z][A-Z]{2}") + + msg = "Cannot use a compiled regex as replacement pattern with regex=False" + with pytest.raises(ValueError, match=msg): + ser.str.replace(pat, "", regex=False) + + +def test_replace_moar(any_string_dtype): + # PR #1179 + ser = Series( + ["A", "B", "C", "Aaba", "Baca", "", np.nan, "CABA", "dog", "cat"], + dtype=any_string_dtype, + ) + + result = ser.str.replace("A", "YYY") + expected = Series( + ["YYY", "B", "C", "YYYaba", "Baca", "", np.nan, "CYYYBYYY", "dog", "cat"], + dtype=any_string_dtype, + ) + tm.assert_series_equal(result, expected) + + result = ser.str.replace("A", "YYY", case=False) + expected = Series( + [ + "YYY", + "B", + "C", + "YYYYYYbYYY", + "BYYYcYYY", + "", + np.nan, + "CYYYBYYY", + "dog", + "cYYYt", + ], + dtype=any_string_dtype, + ) + tm.assert_series_equal(result, expected) + + result = ser.str.replace("^.a|dog", "XX-XX ", case=False, regex=True) + expected = Series( + [ + "A", + "B", + "C", + "XX-XX ba", + "XX-XX ca", + "", + np.nan, + "XX-XX BA", + "XX-XX ", + "XX-XX t", + ], + dtype=any_string_dtype, + ) + tm.assert_series_equal(result, expected) + + +def test_replace_not_case_sensitive_not_regex(any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/41602 + ser = Series(["A.", "a.", "Ab", "ab", np.nan], dtype=any_string_dtype) + + result = ser.str.replace("a", "c", case=False, regex=False) + expected = Series(["c.", "c.", "cb", "cb", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + result = ser.str.replace("a.", "c.", case=False, regex=False) + expected = Series(["c.", "c.", "Ab", "ab", np.nan], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +def test_replace_regex(any_string_dtype): + # https://github.com/pandas-dev/pandas/pull/24809 + s = Series(["a", "b", "ac", np.nan, ""], dtype=any_string_dtype) + result = s.str.replace("^.$", "a", regex=True) + expected = Series(["a", "a", "ac", np.nan, ""], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("regex", [True, False]) +def test_replace_regex_single_character(regex, any_string_dtype): + # https://github.com/pandas-dev/pandas/pull/24809, enforced in 2.0 + # GH 24804 + s = Series(["a.b", ".", "b", np.nan, ""], dtype=any_string_dtype) + + result = s.str.replace(".", "a", regex=regex) + if regex: + expected = Series(["aaa", "a", "a", np.nan, ""], dtype=any_string_dtype) + else: + expected = Series(["aab", "a", "b", np.nan, ""], dtype=any_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pat, expected_data", + [ + (r"a(?=b)", ["aa", "xb", "ba", "bb"]), + (r"(?<=a)b", ["aa", "ax", "ba", "bb"]), + (r"a(?!b)", ["xx", "ab", "bx", "bb"]), + (r"(?" + ) + with pytest.raises(UnicodeEncodeError, match=msg): + ser.str.encode("cp1252") + + result = ser.str.encode("cp1252", "ignore") + expected = ser.map(lambda x: x.encode("cp1252", "ignore")) + tm.assert_series_equal(result, expected) + + +def test_decode_errors_kwarg(): + ser = Series([b"a", b"b", b"a\x9d"]) + + msg = ( + "'charmap' codec can't decode byte 0x9d in position 1: " + "character maps to " + ) + with pytest.raises(UnicodeDecodeError, match=msg): + ser.str.decode("cp1252") + + result = ser.str.decode("cp1252", "ignore") + expected = ser.map(lambda x: x.decode("cp1252", "ignore")).astype("str") + tm.assert_series_equal(result, expected) + + +def test_decode_string_dtype(string_dtype): + # https://github.com/pandas-dev/pandas/pull/60940 + ser = Series([b"a", b"b"]) + result = ser.str.decode("utf-8", dtype=string_dtype) + expected = Series(["a", "b"], dtype=string_dtype) + tm.assert_series_equal(result, expected) + + +def test_decode_object_dtype(object_dtype): + # https://github.com/pandas-dev/pandas/pull/60940 + ser = Series([b"a", rb"\ud800"]) + result = ser.str.decode("utf-8", dtype=object_dtype) + expected = Series(["a", r"\ud800"], dtype=object_dtype) + tm.assert_series_equal(result, expected) + + +def test_decode_bad_dtype(): + # https://github.com/pandas-dev/pandas/pull/60940 + ser = Series([b"a", b"b"]) + msg = "dtype must be string or object, got dtype='int64'" + with pytest.raises(ValueError, match=msg): + ser.str.decode("utf-8", dtype="int64") + + +@pytest.mark.parametrize( + "form, expected", + [ + ("NFKC", ["ABC", "ABC", "123", np.nan, "アイエ"]), + ("NFC", ["ABC", "ABC", "123", np.nan, "アイエ"]), # noqa: RUF001 + ], +) +def test_normalize(form, expected, any_string_dtype): + ser = Series( + ["ABC", "ABC", "123", np.nan, "アイエ"], # noqa: RUF001 + index=["a", "b", "c", "d", "e"], + dtype=any_string_dtype, + ) + expected = Series(expected, index=["a", "b", "c", "d", "e"], dtype=any_string_dtype) + result = ser.str.normalize(form) + tm.assert_series_equal(result, expected) + + +def test_normalize_bad_arg_raises(any_string_dtype): + ser = Series( + ["ABC", "ABC", "123", np.nan, "アイエ"], # noqa: RUF001 + index=["a", "b", "c", "d", "e"], + dtype=any_string_dtype, + ) + with pytest.raises(ValueError, match="invalid normalization form"): + ser.str.normalize("xxx") + + +def test_normalize_index(): + idx = Index(["ABC", "123", "アイエ"]) # noqa: RUF001 + expected = Index(["ABC", "123", "アイエ"]) + result = idx.str.normalize("NFKC") + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize( + "values,inferred_type", + [ + (["a", "b"], "string"), + (["a", "b", 1], "mixed-integer"), + (["a", "b", 1.3], "mixed"), + (["a", "b", 1.3, 1], "mixed-integer"), + (["aa", datetime(2011, 1, 1)], "mixed"), + ], +) +def test_index_str_accessor_visibility(values, inferred_type, index_or_series): + obj = index_or_series(values) + if index_or_series is Index: + assert obj.inferred_type == inferred_type + + assert isinstance(obj.str, StringMethods) + + +@pytest.mark.parametrize( + "values,inferred_type", + [ + ([1, np.nan], "floating"), + ([datetime(2011, 1, 1)], "datetime64"), + ([timedelta(1)], "timedelta64"), + ], +) +def test_index_str_accessor_non_string_values_raises( + values, inferred_type, index_or_series +): + obj = index_or_series(values) + if index_or_series is Index: + assert obj.inferred_type == inferred_type + + msg = "Can only use .str accessor with string values" + with pytest.raises(AttributeError, match=msg): + obj.str + + +def test_index_str_accessor_multiindex_raises(): + # MultiIndex has mixed dtype, but not allow to use accessor + idx = MultiIndex.from_tuples([("a", "b"), ("a", "b")]) + assert idx.inferred_type == "mixed" + + msg = "Can only use .str accessor with Index, not MultiIndex" + with pytest.raises(AttributeError, match=msg): + idx.str + + +def test_str_accessor_no_new_attributes(any_string_dtype): + # https://github.com/pandas-dev/pandas/issues/10673 + ser = Series(list("aabbcde"), dtype=any_string_dtype) + with pytest.raises(AttributeError, match="You cannot add any new attribute"): + ser.str.xlabel = "a" + + +def test_cat_on_bytes_raises(): + lhs = Series(np.array(list("abc"), "S1").astype(object)) + rhs = Series(np.array(list("def"), "S1").astype(object)) + msg = "Cannot use .str.cat with values of inferred dtype 'bytes'" + with pytest.raises(TypeError, match=msg): + lhs.str.cat(rhs) + + +def test_str_accessor_in_apply_func(): + # https://github.com/pandas-dev/pandas/issues/38979 + df = DataFrame(zip("abc", "def", strict=True)) + expected = Series(["A/D", "B/E", "C/F"]) + result = df.apply(lambda f: "/".join(f.str.upper()), axis=1) + tm.assert_series_equal(result, expected) + + +def test_zfill(): + # https://github.com/pandas-dev/pandas/issues/20868 + value = Series(["-1", "1", "1000", 10, np.nan]) + expected = Series(["-01", "001", "1000", np.nan, np.nan], dtype=object) + tm.assert_series_equal(value.str.zfill(3), expected) + + value = Series(["-2", "+5"]) + expected = Series(["-0002", "+0005"]) + tm.assert_series_equal(value.str.zfill(5), expected) + + +def test_zfill_with_non_integer_argument(): + value = Series(["-2", "+5"]) + wid = "a" + msg = f"width must be of integer type, not {type(wid).__name__}" + with pytest.raises(TypeError, match=msg): + value.str.zfill(wid) + + +def test_zfill_with_leading_sign(): + value = Series(["-cat", "-1", "+dog"]) + expected = Series(["-0cat", "-0001", "+0dog"]) + tm.assert_series_equal(value.str.zfill(5), expected) + + +def test_get_with_dict_label(): + # GH47911 + s = Series( + [ + {"name": "Hello", "value": "World"}, + {"name": "Goodbye", "value": "Planet"}, + {"value": "Sea"}, + ] + ) + result = s.str.get("name") + expected = Series(["Hello", "Goodbye", None], dtype=object) + tm.assert_series_equal(result, expected) + result = s.str.get("value") + expected = Series(["World", "Planet", "Sea"], dtype=object) + tm.assert_series_equal(result, expected) + + +def test_series_str_decode(): + # GH 22613 + result = Series([b"x", b"y"]).str.decode(encoding="UTF-8", errors="strict") + expected = Series(["x", "y"], dtype="str") + tm.assert_series_equal(result, expected) + + +def test_decode_with_dtype_none(): + with option_context("future.infer_string", True): + ser = Series([b"a", b"b", b"c"]) + result = ser.str.decode("utf-8", dtype=None) + expected = Series(["a", "b", "c"], dtype="str") + tm.assert_series_equal(result, expected) + + +def test_setitem_with_different_string_storage(): + # GH#52987 + # Test setitem with values from different string storage type + pytest.importorskip("pyarrow") + + # Test Series[string[python]].__setitem__(Series[string[pyarrow]]) + ser_python = Series(range(5), dtype="string[python]") + ser_pyarrow = ser_python.astype("string[pyarrow]") + + ser_python[:2] = ser_pyarrow[:2] + expected = Series(["0", "1", "2", "3", "4"], dtype="string[python]") + tm.assert_series_equal(ser_python, expected) + + # Test Series[string[pyarrow]].__setitem__(Series[string[python]]) + ser_pyarrow = Series(range(5), dtype="string[pyarrow]") + ser_python = ser_pyarrow.astype("string[python]") + + ser_pyarrow[:2] = ser_python[:2] + expected = Series(["0", "1", "2", "3", "4"], dtype="string[pyarrow]") + tm.assert_series_equal(ser_pyarrow, expected) + + # Test with slice and missing values + ser_python = Series(["a", "b", None, "d", "e"], dtype="string[python]") + ser_pyarrow = Series(["X", "Y", None], dtype="string[pyarrow]") + + ser_python[1:4] = ser_pyarrow + expected = Series(["a", "X", "Y", NA, "e"], dtype="string[python]") + tm.assert_series_equal(ser_python, expected) + + +@pytest.mark.parametrize( + "pat, expected", + [ + # lookaround assertions + (r"(?=abc)", True), + (r"(?<=123)", True), + (r"(?!xyz)", True), + (r"(?\w+)\s+(?P=word)\b", True), + ], +) +def test_has_regex_unsupported_code(pat, expected): + # https://github.com/pandas-dev/pandas/issues/60833 + assert ArrowStringArrayMixin._has_unsupported_regex(pat) == expected diff --git a/pandas/tests/test_aggregation.py b/pandas/tests/test_aggregation.py new file mode 100644 index 0000000000000000000000000000000000000000..3a01805cc2365c3f5024064465341d1fe664eeed --- /dev/null +++ b/pandas/tests/test_aggregation.py @@ -0,0 +1,93 @@ +import numpy as np +import pytest + +from pandas.core.apply import ( + _make_unique_kwarg_list, + maybe_mangle_lambdas, +) + + +def test_maybe_mangle_lambdas_passthrough(): + assert maybe_mangle_lambdas("mean") == "mean" + assert maybe_mangle_lambdas(lambda x: x).__name__ == "" + # don't mangle single lambda. + assert maybe_mangle_lambdas([lambda x: x])[0].__name__ == "" + + +def test_maybe_mangle_lambdas_listlike(): + aggfuncs = [lambda x: 1, lambda x: 2] + result = maybe_mangle_lambdas(aggfuncs) + assert result[0].__name__ == "" + assert result[1].__name__ == "" + assert aggfuncs[0](None) == result[0](None) + assert aggfuncs[1](None) == result[1](None) + + +def test_maybe_mangle_lambdas(): + func = {"A": [lambda x: 0, lambda x: 1]} + result = maybe_mangle_lambdas(func) + assert result["A"][0].__name__ == "" + assert result["A"][1].__name__ == "" + + +def test_maybe_mangle_lambdas_args(): + func = {"A": [lambda x, a, b=1: (0, a, b), lambda x: 1]} + result = maybe_mangle_lambdas(func) + assert result["A"][0].__name__ == "" + assert result["A"][1].__name__ == "" + + assert func["A"][0](0, 1) == (0, 1, 1) + assert func["A"][0](0, 1, 2) == (0, 1, 2) + assert func["A"][0](0, 2, b=3) == (0, 2, 3) + + +def test_maybe_mangle_lambdas_named(): + func = {"C": np.mean, "D": {"foo": np.mean, "bar": np.mean}} + result = maybe_mangle_lambdas(func) + assert result == func + + +@pytest.mark.parametrize( + "order, expected_reorder", + [ + ( + [ + ("height", ""), + ("height", "max"), + ("weight", "max"), + ("height", ""), + ("weight", ""), + ], + [ + ("height", "_0"), + ("height", "max"), + ("weight", "max"), + ("height", "_1"), + ("weight", ""), + ], + ), + ( + [ + ("col2", "min"), + ("col1", ""), + ("col1", ""), + ("col1", ""), + ], + [ + ("col2", "min"), + ("col1", "_0"), + ("col1", "_1"), + ("col1", "_2"), + ], + ), + ( + [("col", ""), ("col", ""), ("col", "")], + [("col", "_0"), ("col", "_1"), ("col", "_2")], + ), + ], +) +def test_make_unique(order, expected_reorder): + # GH 27519, test if make_unique function reorders correctly + result = _make_unique_kwarg_list(order) + + assert result == expected_reorder diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py new file mode 100644 index 0000000000000000000000000000000000000000..ee34dff8446955cb4363db28d6f8d3115f2cb768 --- /dev/null +++ b/pandas/tests/test_algos.py @@ -0,0 +1,2083 @@ +from datetime import datetime +import struct + +import numpy as np +import pytest + +from pandas._libs import ( + algos as libalgos, + hashtable as ht, +) + +from pandas.core.dtypes.common import ( + is_bool_dtype, + is_complex_dtype, + is_float_dtype, + is_integer_dtype, + is_object_dtype, +) +from pandas.core.dtypes.dtypes import ( + CategoricalDtype, + DatetimeTZDtype, +) + +import pandas as pd +from pandas import ( + Categorical, + CategoricalIndex, + DataFrame, + DatetimeIndex, + Index, + IntervalIndex, + MultiIndex, + NaT, + Period, + PeriodIndex, + Series, + Timedelta, + Timestamp, + cut, + date_range, + timedelta_range, + to_datetime, + to_timedelta, +) +import pandas._testing as tm +import pandas.core.algorithms as algos +from pandas.core.arrays import ( + DatetimeArray, + TimedeltaArray, +) +import pandas.core.common as com + + +class TestFactorize: + def test_factorize_complex(self): + # GH#17927 + array = np.array([1, 2, 2 + 1j], dtype=complex) + labels, uniques = algos.factorize(array) + + expected_labels = np.array([0, 1, 2], dtype=np.intp) + tm.assert_numpy_array_equal(labels, expected_labels) + + expected_uniques = np.array([(1 + 0j), (2 + 0j), (2 + 1j)], dtype=complex) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_factorize(self, index_or_series_obj, sort): + obj = index_or_series_obj + result_codes, result_uniques = obj.factorize(sort=sort) + + constructor = Index + if isinstance(obj, MultiIndex): + constructor = MultiIndex.from_tuples + expected_arr = obj.unique() + if expected_arr.dtype == np.float16: + expected_arr = expected_arr.astype(np.float32) + expected_uniques = constructor(expected_arr) + if ( + isinstance(obj, Index) + and expected_uniques.dtype == bool + and obj.dtype == object + ): + expected_uniques = expected_uniques.astype(object) + + if sort: + expected_uniques = expected_uniques.sort_values() + + # construct an integer ndarray so that + # `expected_uniques.take(expected_codes)` is equal to `obj` + expected_uniques_list = list(expected_uniques) + expected_codes = [expected_uniques_list.index(val) for val in obj] + expected_codes = np.asarray(expected_codes, dtype=np.intp) + + tm.assert_numpy_array_equal(result_codes, expected_codes) + tm.assert_index_equal(result_uniques, expected_uniques, exact=True) + + def test_series_factorize_use_na_sentinel_false(self): + # GH#35667 + values = np.array([1, 2, 1, np.nan]) + ser = Series(values) + codes, uniques = ser.factorize(use_na_sentinel=False) + + expected_codes = np.array([0, 1, 0, 2], dtype=np.intp) + expected_uniques = Index([1.0, 2.0, np.nan]) + + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_index_equal(uniques, expected_uniques) + + def test_basic(self): + items = np.array(["a", "b", "b", "a", "a", "c", "c", "c"], dtype=object) + codes, uniques = algos.factorize(items) + tm.assert_numpy_array_equal(uniques, np.array(["a", "b", "c"], dtype=object)) + + codes, uniques = algos.factorize(items, sort=True) + exp = np.array([0, 1, 1, 0, 0, 2, 2, 2], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array(["a", "b", "c"], dtype=object) + tm.assert_numpy_array_equal(uniques, exp) + + arr = np.arange(5, dtype=np.intp)[::-1] + + codes, uniques = algos.factorize(arr) + exp = np.array([0, 1, 2, 3, 4], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array([4, 3, 2, 1, 0], dtype=arr.dtype) + tm.assert_numpy_array_equal(uniques, exp) + + codes, uniques = algos.factorize(arr, sort=True) + exp = np.array([4, 3, 2, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array([0, 1, 2, 3, 4], dtype=arr.dtype) + tm.assert_numpy_array_equal(uniques, exp) + + arr = np.arange(5.0)[::-1] + + codes, uniques = algos.factorize(arr) + exp = np.array([0, 1, 2, 3, 4], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array([4.0, 3.0, 2.0, 1.0, 0.0], dtype=arr.dtype) + tm.assert_numpy_array_equal(uniques, exp) + + codes, uniques = algos.factorize(arr, sort=True) + exp = np.array([4, 3, 2, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = np.array([0.0, 1.0, 2.0, 3.0, 4.0], dtype=arr.dtype) + tm.assert_numpy_array_equal(uniques, exp) + + def test_mixed(self): + # doc example reshaping.rst + x = Series(["A", "A", np.nan, "B", 3.14, np.inf]) + codes, uniques = algos.factorize(x) + + exp = np.array([0, 0, -1, 1, 2, 3], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = Index(["A", "B", 3.14, np.inf]) + tm.assert_index_equal(uniques, exp) + + codes, uniques = algos.factorize(x, sort=True) + exp = np.array([2, 2, -1, 3, 0, 1], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = Index([3.14, np.inf, "A", "B"]) + tm.assert_index_equal(uniques, exp) + + def test_factorize_datetime64(self): + # M8 + v1 = Timestamp("20130101 09:00:00.00004") + v2 = Timestamp("20130101") + x = Series([v1, v1, v1, v2, v2, v1]) + codes, uniques = algos.factorize(x) + + exp = np.array([0, 0, 0, 1, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = DatetimeIndex([v1, v2]) + tm.assert_index_equal(uniques, exp) + + codes, uniques = algos.factorize(x, sort=True) + exp = np.array([1, 1, 1, 0, 0, 1], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + exp = DatetimeIndex([v2, v1]) + tm.assert_index_equal(uniques, exp) + + def test_factorize_period(self): + # period + v1 = Period("201302", freq="M") + v2 = Period("201303", freq="M") + x = Series([v1, v1, v1, v2, v2, v1]) + + # periods are not 'sorted' as they are converted back into an index + codes, uniques = algos.factorize(x) + exp = np.array([0, 0, 0, 1, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + tm.assert_index_equal(uniques, PeriodIndex([v1, v2])) + + codes, uniques = algos.factorize(x, sort=True) + exp = np.array([0, 0, 0, 1, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + tm.assert_index_equal(uniques, PeriodIndex([v1, v2])) + + def test_factorize_timedelta(self): + # GH 5986 + v1 = to_timedelta("1 day 1 min") + v2 = to_timedelta("1 day") + x = Series([v1, v2, v1, v1, v2, v2, v1]) + codes, uniques = algos.factorize(x) + exp = np.array([0, 1, 0, 0, 1, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + tm.assert_index_equal(uniques, to_timedelta([v1, v2])) + + codes, uniques = algos.factorize(x, sort=True) + exp = np.array([1, 0, 1, 1, 0, 0, 1], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp) + tm.assert_index_equal(uniques, to_timedelta([v2, v1])) + + def test_factorize_nan(self): + # nan should map to na_sentinel, not reverse_indexer[na_sentinel] + # rizer.factorize should not raise an exception if na_sentinel indexes + # outside of reverse_indexer + key = np.array([1, 2, 1, np.nan], dtype="O") + rizer = ht.ObjectFactorizer(len(key)) + for na_sentinel in (-1, 20): + ids = rizer.factorize(key, na_sentinel=na_sentinel) + expected = np.array([0, 1, 0, na_sentinel], dtype=np.intp) + assert len(set(key)) == len(set(expected)) + tm.assert_numpy_array_equal(pd.isna(key), expected == na_sentinel) + tm.assert_numpy_array_equal(ids, expected) + + def test_factorizer_with_mask(self): + # GH#49549 + data = np.array([1, 2, 3, 1, 1, 0], dtype="int64") + mask = np.array([False, False, False, False, False, True]) + rizer = ht.Int64Factorizer(len(data)) + result = rizer.factorize(data, mask=mask) + expected = np.array([0, 1, 2, 0, 0, -1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + expected_uniques = np.array([1, 2, 3], dtype="int64") + tm.assert_numpy_array_equal(rizer.uniques.to_array(), expected_uniques) + + def test_factorizer_object_with_nan(self): + # GH#49549 + data = np.array([1, 2, 3, 1, np.nan]) + rizer = ht.ObjectFactorizer(len(data)) + result = rizer.factorize(data.astype(object)) + expected = np.array([0, 1, 2, 0, -1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + expected_uniques = np.array([1, 2, 3], dtype=object) + tm.assert_numpy_array_equal(rizer.uniques.to_array(), expected_uniques) + + @pytest.mark.parametrize( + "data, expected_codes, expected_uniques", + [ + ( + [(1, 1), (1, 2), (0, 0), (1, 2), "nonsense"], + [0, 1, 2, 1, 3], + [(1, 1), (1, 2), (0, 0), "nonsense"], + ), + ( + [(1, 1), (1, 2), (0, 0), (1, 2), (1, 2, 3)], + [0, 1, 2, 1, 3], + [(1, 1), (1, 2), (0, 0), (1, 2, 3)], + ), + ([(1, 1), (1, 2), (0, 0), (1, 2)], [0, 1, 2, 1], [(1, 1), (1, 2), (0, 0)]), + ], + ) + def test_factorize_tuple_list(self, data, expected_codes, expected_uniques): + # GH9454 + data = com.asarray_tuplesafe(data, dtype=object) + codes, uniques = pd.factorize(data) + + tm.assert_numpy_array_equal(codes, np.array(expected_codes, dtype=np.intp)) + + expected_uniques_array = com.asarray_tuplesafe(expected_uniques, dtype=object) + tm.assert_numpy_array_equal(uniques, expected_uniques_array) + + def test_complex_sorting(self): + # gh 12666 - check no segfault + x17 = np.array([complex(i) for i in range(17)], dtype=object) + + msg = "'[<>]' not supported between instances of .*" + with pytest.raises(TypeError, match=msg): + algos.factorize(x17[::-1], sort=True) + + def test_numeric_dtype_factorize(self, any_real_numpy_dtype): + # GH41132 + dtype = any_real_numpy_dtype + data = np.array([1, 2, 2, 1], dtype=dtype) + expected_codes = np.array([0, 1, 1, 0], dtype=np.intp) + expected_uniques = np.array([1, 2], dtype=dtype) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_float64_factorize(self, writable): + data = np.array([1.0, 1e8, 1.0, 1e-8, 1e8, 1.0], dtype=np.float64) + data.setflags(write=writable) + expected_codes = np.array([0, 1, 0, 2, 1, 0], dtype=np.intp) + expected_uniques = np.array([1.0, 1e8, 1e-8], dtype=np.float64) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_uint64_factorize(self, writable): + data = np.array([2**64 - 1, 1, 2**64 - 1], dtype=np.uint64) + data.setflags(write=writable) + expected_codes = np.array([0, 1, 0], dtype=np.intp) + expected_uniques = np.array([2**64 - 1, 1], dtype=np.uint64) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_int64_factorize(self, writable): + data = np.array([2**63 - 1, -(2**63), 2**63 - 1], dtype=np.int64) + data.setflags(write=writable) + expected_codes = np.array([0, 1, 0], dtype=np.intp) + expected_uniques = np.array([2**63 - 1, -(2**63)], dtype=np.int64) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_string_factorize(self, writable): + data = np.array(["a", "c", "a", "b", "c"], dtype=object) + data.setflags(write=writable) + expected_codes = np.array([0, 1, 0, 2, 1], dtype=np.intp) + expected_uniques = np.array(["a", "c", "b"], dtype=object) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_object_factorize(self, writable): + data = np.array(["a", "c", None, np.nan, "a", "b", NaT, "c"], dtype=object) + data.setflags(write=writable) + expected_codes = np.array([0, 1, -1, -1, 0, 2, -1, 1], dtype=np.intp) + expected_uniques = np.array(["a", "c", "b"], dtype=object) + + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_datetime64_factorize(self, writable): + # GH35650 Verify whether read-only datetime64 array can be factorized + data = np.array([np.datetime64("2020-01-01T00:00:00.000")], dtype="M8[ns]") + data.setflags(write=writable) + expected_codes = np.array([0], dtype=np.intp) + expected_uniques = np.array( + ["2020-01-01T00:00:00.000000000"], dtype="datetime64[ns]" + ) + + codes, uniques = pd.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + def test_factorize_rangeindex(self, sort): + # increasing -> sort doesn't matter + ri = pd.RangeIndex.from_range(range(10)) + expected = np.arange(10, dtype=np.intp), ri + + result = algos.factorize(ri, sort=sort) + tm.assert_numpy_array_equal(result[0], expected[0]) + tm.assert_index_equal(result[1], expected[1], exact=True) + + result = ri.factorize(sort=sort) + tm.assert_numpy_array_equal(result[0], expected[0]) + tm.assert_index_equal(result[1], expected[1], exact=True) + + def test_factorize_rangeindex_decreasing(self, sort): + # decreasing -> sort matters + ri = pd.RangeIndex.from_range(range(10)) + expected = np.arange(10, dtype=np.intp), ri + + ri2 = ri[::-1] + expected = expected[0], ri2 + if sort: + expected = expected[0][::-1], expected[1][::-1] + + result = algos.factorize(ri2, sort=sort) + tm.assert_numpy_array_equal(result[0], expected[0]) + tm.assert_index_equal(result[1], expected[1], exact=True) + + result = ri2.factorize(sort=sort) + tm.assert_numpy_array_equal(result[0], expected[0]) + tm.assert_index_equal(result[1], expected[1], exact=True) + + def test_deprecate_order(self): + # gh 19727 - check warning is raised for deprecated keyword, order. + # Test not valid once order keyword is removed. + data = np.array([2**63, 1, 2**63], dtype=np.uint64) + with pytest.raises(TypeError, match="got an unexpected keyword"): + algos.factorize(data, order=True) + with tm.assert_produces_warning(False): + algos.factorize(data) + + @pytest.mark.parametrize( + "data", + [ + np.array([0, 1, 0], dtype="u8"), + np.array([-(2**63), 1, -(2**63)], dtype="i8"), + np.array(["__nan__", "foo", "__nan__"], dtype="object"), + ], + ) + def test_parametrized_factorize_na_value_default(self, data): + # arrays that include the NA default for that type, but isn't used. + codes, uniques = algos.factorize(data) + expected_uniques = data[[0, 1]] + expected_codes = np.array([0, 1, 0], dtype=np.intp) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + @pytest.mark.parametrize( + "data, na_value", + [ + (np.array([0, 1, 0, 2], dtype="u8"), 0), + (np.array([1, 0, 1, 2], dtype="u8"), 1), + (np.array([-(2**63), 1, -(2**63), 0], dtype="i8"), -(2**63)), + (np.array([1, -(2**63), 1, 0], dtype="i8"), 1), + (np.array(["a", "", "a", "b"], dtype=object), "a"), + (np.array([(), ("a", 1), (), ("a", 2)], dtype=object), ()), + (np.array([("a", 1), (), ("a", 1), ("a", 2)], dtype=object), ("a", 1)), + ], + ) + def test_parametrized_factorize_na_value(self, data, na_value): + codes, uniques = algos.factorize_array(data, na_value=na_value) + expected_uniques = data[[1, 3]] + expected_codes = np.array([-1, 0, -1, 1], dtype=np.intp) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_numpy_array_equal(uniques, expected_uniques) + + @pytest.mark.parametrize( + "data, uniques", + [ + ( + np.array(["b", "a", None, "b"], dtype=object), + np.array(["b", "a"], dtype=object), + ), + ( + pd.array([2, 1, pd.NA, 2], dtype="Int64"), + pd.array([2, 1], dtype="Int64"), + ), + ], + ids=["numpy_array", "extension_array"], + ) + def test_factorize_use_na_sentinel(self, sort, data, uniques): + codes, uniques = algos.factorize(data, sort=sort, use_na_sentinel=True) + if sort: + expected_codes = np.array([1, 0, -1, 1], dtype=np.intp) + expected_uniques = algos.safe_sort(uniques) + else: + expected_codes = np.array([0, 1, -1, 0], dtype=np.intp) + expected_uniques = uniques + tm.assert_numpy_array_equal(codes, expected_codes) + if isinstance(data, np.ndarray): + tm.assert_numpy_array_equal(uniques, expected_uniques) + else: + tm.assert_extension_array_equal(uniques, expected_uniques) + + @pytest.mark.parametrize( + "data, expected_codes, expected_uniques", + [ + ( + ["a", None, "b", "a"], + np.array([0, 1, 2, 0], dtype=np.dtype("intp")), + np.array(["a", np.nan, "b"], dtype=object), + ), + ( + ["a", np.nan, "b", "a"], + np.array([0, 1, 2, 0], dtype=np.dtype("intp")), + np.array(["a", np.nan, "b"], dtype=object), + ), + ], + ) + def test_object_factorize_use_na_sentinel_false( + self, data, expected_codes, expected_uniques + ): + codes, uniques = algos.factorize( + np.array(data, dtype=object), use_na_sentinel=False + ) + + tm.assert_numpy_array_equal(uniques, expected_uniques, strict_nan=True) + tm.assert_numpy_array_equal(codes, expected_codes, strict_nan=True) + + @pytest.mark.parametrize( + "data, expected_codes, expected_uniques", + [ + ( + np.array([1, None, 1, 2], dtype=object), + np.array([0, 1, 0, 2], dtype=np.dtype("intp")), + np.array([1, np.nan, 2], dtype="O"), + ), + ( + np.array([1, np.nan, 1, 2], dtype=np.float64), + np.array([0, 1, 0, 2], dtype=np.dtype("intp")), + np.array([1, np.nan, 2], dtype=np.float64), + ), + ], + ) + def test_int_factorize_use_na_sentinel_false( + self, data, expected_codes, expected_uniques + ): + codes, uniques = algos.factorize(data, use_na_sentinel=False) + + tm.assert_numpy_array_equal(uniques, expected_uniques, strict_nan=True) + tm.assert_numpy_array_equal(codes, expected_codes, strict_nan=True) + + @pytest.mark.parametrize( + "data, expected_codes, expected_uniques", + [ + ( + Index(Categorical(["a", "a", "b"])), + np.array([0, 0, 1], dtype=np.intp), + CategoricalIndex(["a", "b"], categories=["a", "b"], dtype="category"), + ), + ( + Series(Categorical(["a", "a", "b"])), + np.array([0, 0, 1], dtype=np.intp), + CategoricalIndex(["a", "b"], categories=["a", "b"], dtype="category"), + ), + ( + Series(DatetimeIndex(["2017", "2017"], tz="US/Eastern")), + np.array([0, 0], dtype=np.intp), + DatetimeIndex(["2017"], tz="US/Eastern"), + ), + ], + ) + def test_factorize_mixed_values(self, data, expected_codes, expected_uniques): + # GH 19721 + codes, uniques = algos.factorize(data) + tm.assert_numpy_array_equal(codes, expected_codes) + tm.assert_index_equal(uniques, expected_uniques) + + def test_factorize_interval_non_nano(self, unit): + # GH#56099 + left = DatetimeIndex(["2016-01-01", np.nan, "2015-10-11"]).as_unit(unit) + right = DatetimeIndex(["2016-01-02", np.nan, "2015-10-15"]).as_unit(unit) + idx = IntervalIndex.from_arrays(left, right) + codes, cats = idx.factorize() + assert cats.dtype == f"interval[datetime64[{unit}], right]" + + ts = Timestamp(0).as_unit(unit) + idx2 = IntervalIndex.from_arrays(left - ts, right - ts) + codes2, cats2 = idx2.factorize() + assert cats2.dtype == f"interval[timedelta64[{unit}], right]" + + idx3 = IntervalIndex.from_arrays( + left.tz_localize("US/Pacific"), right.tz_localize("US/Pacific") + ) + codes3, cats3 = idx3.factorize() + assert cats3.dtype == f"interval[datetime64[{unit}, US/Pacific], right]" + + +class TestUnique: + def test_ints(self): + arr = np.random.default_rng(2).integers(0, 100, size=50) + + result = algos.unique(arr) + assert isinstance(result, np.ndarray) + + def test_objects(self): + arr = np.random.default_rng(2).integers(0, 100, size=50).astype("O") + + result = algos.unique(arr) + assert isinstance(result, np.ndarray) + + def test_object_refcount_bug(self): + lst = np.array(["A", "B", "C", "D", "E"], dtype=object) + for i in range(1000): + len(algos.unique(lst)) + + def test_index_returned(self, index): + # GH#57043 + index = index.repeat(2) + result = algos.unique(index) + + # dict.fromkeys preserves the order + unique_values = list(dict.fromkeys(index.values)) + if isinstance(index, MultiIndex): + expected = MultiIndex.from_tuples(unique_values, names=index.names) + else: + expected = Index(unique_values, dtype=index.dtype) + if isinstance(index.dtype, DatetimeTZDtype): + expected = expected.normalize() + tm.assert_index_equal(result, expected, exact=True) + + def test_factorize_multiindex_empty(self): + # GH#57517 + mi = MultiIndex.from_product( + [Index([], name="a", dtype=object), Index([], name="i", dtype="f4")] + ) + codes, uniques = mi.factorize() + exp_codes = np.array([], dtype=np.intp) + tm.assert_numpy_array_equal(codes, exp_codes) + tm.assert_index_equal(uniques, mi[:0]) + + def test_dtype_preservation(self, any_numpy_dtype): + # GH 15442 + if any_numpy_dtype in (tm.BYTES_DTYPES + tm.STRING_DTYPES): + data = [1, 2, 2] + uniques = [1, 2] + elif is_integer_dtype(any_numpy_dtype): + data = [1, 2, 2] + uniques = [1, 2] + elif is_float_dtype(any_numpy_dtype): + data = [1, 2, 2] + uniques = [1.0, 2.0] + elif is_complex_dtype(any_numpy_dtype): + data = [complex(1, 0), complex(2, 0), complex(2, 0)] + uniques = [complex(1, 0), complex(2, 0)] + elif is_bool_dtype(any_numpy_dtype): + data = [True, True, False] + uniques = [True, False] + elif is_object_dtype(any_numpy_dtype): + data = ["A", "B", "B"] + uniques = ["A", "B"] + else: + # datetime64[ns]/M8[ns]/timedelta64[ns]/m8[ns] tested elsewhere + data = [1, 2, 2] + uniques = [1, 2] + + result = Series(data, dtype=any_numpy_dtype).unique() + expected = np.array(uniques, dtype=any_numpy_dtype) + + if any_numpy_dtype in tm.STRING_DTYPES: + expected = expected.astype(object) + + if expected.dtype.kind in ["m", "M"]: + # We get TimedeltaArray/DatetimeArray + assert isinstance(result, (DatetimeArray, TimedeltaArray)) + result = np.array(result) + tm.assert_numpy_array_equal(result, expected) + + def test_datetime64_dtype_array_returned(self): + # GH 9431 + dt_arr = np.array( + [ + "2015-01-03T00:00:00.000000000", + "2015-01-01T00:00:00.000000000", + ], + dtype="M8[ns]", + ) + + dt_index = to_datetime( + [ + "2015-01-03T00:00:00.000000000", + "2015-01-01T00:00:00.000000000", + "2015-01-01T00:00:00.000000000", + ] + ) + result = algos.unique(dt_index) + expected = to_datetime(dt_arr) + tm.assert_index_equal(result, expected, exact=True) + + s = Series(dt_index) + result = algos.unique(s) + tm.assert_numpy_array_equal(result, dt_arr) + assert result.dtype == dt_arr.dtype + + arr = s.values + result = algos.unique(arr) + tm.assert_numpy_array_equal(result, dt_arr) + assert result.dtype == dt_arr.dtype + + def test_datetime_non_ns(self): + a = np.array(["2000", "2000", "2001"], dtype="datetime64[s]") + result = pd.unique(a) + expected = np.array(["2000", "2001"], dtype="datetime64[s]") + tm.assert_numpy_array_equal(result, expected) + + def test_timedelta_non_ns(self): + a = np.array(["2000", "2000", "2001"], dtype="timedelta64[s]") + result = pd.unique(a) + expected = np.array([2000, 2001], dtype="timedelta64[s]") + tm.assert_numpy_array_equal(result, expected) + + def test_timedelta64_dtype_array_returned(self): + # GH 9431 + td_arr = np.array([31200, 45678, 10000], dtype="m8[ns]") + + td_index = to_timedelta([31200, 45678, 31200, 10000, 45678]) + result = algos.unique(td_index) + expected = to_timedelta(td_arr) + tm.assert_index_equal(result, expected) + assert result.dtype == expected.dtype + + s = Series(td_index) + result = algos.unique(s) + tm.assert_numpy_array_equal(result, td_arr) + assert result.dtype == td_arr.dtype + + arr = s.values + result = algos.unique(arr) + tm.assert_numpy_array_equal(result, td_arr) + assert result.dtype == td_arr.dtype + + def test_uint64_overflow(self): + s = Series([1, 2, 2**63, 2**63], dtype=np.uint64) + exp = np.array([1, 2, 2**63], dtype=np.uint64) + tm.assert_numpy_array_equal(algos.unique(s), exp) + + def test_nan_in_object_array(self): + duplicated_items = ["a", np.nan, "c", "c"] + result = pd.unique(np.array(duplicated_items, dtype=object)) + expected = np.array(["a", np.nan, "c"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + def test_categorical(self): + # we are expecting to return in the order + # of appearance + expected = Categorical(list("bac")) + + # we are expecting to return in the order + # of the categories + expected_o = Categorical(list("bac"), categories=list("abc"), ordered=True) + + # GH 15939 + c = Categorical(list("baabc")) + result = c.unique() + tm.assert_categorical_equal(result, expected) + + result = algos.unique(c) + tm.assert_categorical_equal(result, expected) + + c = Categorical(list("baabc"), ordered=True) + result = c.unique() + tm.assert_categorical_equal(result, expected_o) + + result = algos.unique(c) + tm.assert_categorical_equal(result, expected_o) + + # Series of categorical dtype + s = Series(Categorical(list("baabc")), name="foo") + result = s.unique() + tm.assert_categorical_equal(result, expected) + + result = pd.unique(s) + tm.assert_categorical_equal(result, expected) + + # CI -> return CI + ci = CategoricalIndex(Categorical(list("baabc"), categories=list("abc"))) + expected = CategoricalIndex(expected) + result = ci.unique() + tm.assert_index_equal(result, expected) + + result = pd.unique(ci) + tm.assert_index_equal(result, expected) + + def test_datetime64tz_aware(self, unit): + # GH 15939 + + dti = Index( + [ + Timestamp("20160101", tz="US/Eastern"), + Timestamp("20160101", tz="US/Eastern"), + ] + ).as_unit(unit) + ser = Series(dti) + + result = ser.unique() + expected = dti[:1]._data + tm.assert_extension_array_equal(result, expected) + + result = dti.unique() + expected = dti[:1] + tm.assert_index_equal(result, expected) + + result = pd.unique(ser) + expected = dti[:1]._data + tm.assert_extension_array_equal(result, expected) + + result = pd.unique(dti) + expected = dti[:1] + tm.assert_index_equal(result, expected) + + def test_order_of_appearance(self): + # 9346 + # light testing of guarantee of order of appearance + # these also are the doc-examples + result = pd.unique(Series([2, 1, 3, 3])) + tm.assert_numpy_array_equal(result, np.array([2, 1, 3], dtype="int64")) + + result = pd.unique(Series([2] + [1] * 5)) + tm.assert_numpy_array_equal(result, np.array([2, 1], dtype="int64")) + + data = np.array(["a", "a", "b", "c"], dtype=object) + result = pd.unique(data) + expected = np.array(["a", "b", "c"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + result = pd.unique(Series(Categorical(list("aabc")))) + expected = Categorical(list("abc")) + tm.assert_categorical_equal(result, expected) + + def test_order_of_appearance_dt64(self, unit): + ser = Series([Timestamp("20160101"), Timestamp("20160101")]).dt.as_unit(unit) + result = pd.unique(ser) + expected = np.array(["2016-01-01T00:00:00.000000000"], dtype=f"M8[{unit}]") + tm.assert_numpy_array_equal(result, expected) + + def test_order_of_appearance_dt64tz(self, unit): + dti = DatetimeIndex( + [ + Timestamp("20160101", tz="US/Eastern"), + Timestamp("20160101", tz="US/Eastern"), + ] + ).as_unit(unit) + result = pd.unique(dti) + expected = DatetimeIndex( + ["2016-01-01 00:00:00"], dtype=f"datetime64[{unit}, US/Eastern]", freq=None + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "arg ,expected", + [ + (("1", "1", "2"), np.array(["1", "2"], dtype=object)), + (("foo",), np.array(["foo"], dtype=object)), + ], + ) + def test_tuple_with_strings(self, arg, expected): + # see GH 17108 + arg = com.asarray_tuplesafe(arg, dtype=object) + result = pd.unique(arg) + tm.assert_numpy_array_equal(result, expected) + + def test_obj_none_preservation(self): + # GH 20866 + arr = np.array(["foo", None], dtype=object) + result = pd.unique(arr) + expected = np.array(["foo", None], dtype=object) + + tm.assert_numpy_array_equal(result, expected, strict_nan=True) + + def test_signed_zero(self): + # GH 21866 + a = np.array([-0.0, 0.0]) + result = pd.unique(a) + expected = np.array([-0.0]) # 0.0 and -0.0 are equivalent + tm.assert_numpy_array_equal(result, expected) + + def test_different_nans(self): + # GH 21866 + # create different nans from bit-patterns: + NAN1 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000000))[0] + NAN2 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000001))[0] + assert NAN1 != NAN1 + assert NAN2 != NAN2 + a = np.array([NAN1, NAN2]) # NAN1 and NAN2 are equivalent + result = pd.unique(a) + expected = np.array([np.nan]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("el_type", [np.float64, object]) + def test_first_nan_kept(self, el_type): + # GH 22295 + # create different nans from bit-patterns: + bits_for_nan1 = 0xFFF8000000000001 + bits_for_nan2 = 0x7FF8000000000001 + NAN1 = struct.unpack("d", struct.pack("=Q", bits_for_nan1))[0] + NAN2 = struct.unpack("d", struct.pack("=Q", bits_for_nan2))[0] + assert NAN1 != NAN1 + assert NAN2 != NAN2 + a = np.array([NAN1, NAN2], dtype=el_type) + result = pd.unique(a) + assert result.size == 1 + # use bit patterns to identify which nan was kept: + result_nan_bits = struct.unpack("=Q", struct.pack("d", result[0]))[0] + assert result_nan_bits == bits_for_nan1 + + def test_do_not_mangle_na_values(self, unique_nulls_fixture, unique_nulls_fixture2): + # GH 22295 + if unique_nulls_fixture is unique_nulls_fixture2: + return # skip it, values not unique + a = np.array([unique_nulls_fixture, unique_nulls_fixture2], dtype=object) + result = pd.unique(a) + assert result.size == 2 + assert a[0] is unique_nulls_fixture + assert a[1] is unique_nulls_fixture2 + + def test_unique_masked(self, any_numeric_ea_dtype): + # GH#48019 + ser = Series([1, pd.NA, 2] * 3, dtype=any_numeric_ea_dtype) + result = pd.unique(ser) + expected = pd.array([1, pd.NA, 2], dtype=any_numeric_ea_dtype) + tm.assert_extension_array_equal(result, expected) + + def test_unique_NumpyExtensionArray(self): + arr_complex = pd.array( + [1 + 1j, 2, 3] + ) # NumpyEADtype('complex128') => NumpyExtensionArray + result = pd.unique(arr_complex) + expected = pd.array([1 + 1j, 2 + 0j, 3 + 0j]) + tm.assert_extension_array_equal(result, expected) + + +def test_nunique_ints(index_or_series_or_array): + # GH#36327 + values = index_or_series_or_array(np.random.default_rng(2).integers(0, 20, 30)) + result = algos.nunique_ints(values) + expected = len(algos.unique(values)) + assert result == expected + + +class TestIsin: + def test_invalid(self): + msg = ( + r"only list-like objects are allowed to be passed to isin\(\), " + r"you passed a `int`" + ) + with pytest.raises(TypeError, match=msg): + algos.isin(1, 1) + with pytest.raises(TypeError, match=msg): + algos.isin(1, [1]) + with pytest.raises(TypeError, match=msg): + algos.isin([1], 1) + + def test_basic(self): + result = algos.isin(np.array([1, 2]), [1]) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series([1, 2]), [1]) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series([1, 2]), Series([1])) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series([1, 2]), {1}) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + arg = np.array(["a", "b"], dtype=object) + result = algos.isin(arg, ["a"]) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series(arg), Series(["a"])) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(Series(arg), {"a"}) + expected = np.array([True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arg, [1]) + expected = np.array([False, False]) + tm.assert_numpy_array_equal(result, expected) + + def test_i8(self): + arr = date_range("20130101", periods=3).values + result = algos.isin(arr, [arr[0]]) + expected = np.array([True, False, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arr, arr[0:2]) + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arr, set(arr[0:2])) + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + arr = timedelta_range("1 day", periods=3).values + result = algos.isin(arr, [arr[0]]) + expected = np.array([True, False, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arr, arr[0:2]) + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.isin(arr, set(arr[0:2])) + expected = np.array([True, True, False]) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype1", ["m8[ns]", "M8[ns]", "M8[ns, UTC]", "period[D]"]) + @pytest.mark.parametrize("dtype", ["i8", "f8", "u8"]) + def test_isin_datetimelike_values_numeric_comps(self, dtype, dtype1): + # Anything but object and we get all-False shortcut + + dta = date_range("2013-01-01", periods=3)._values + arr = Series(dta.view("i8")).array.view(dtype1) + + comps = arr.view("i8").astype(dtype) + + result = algos.isin(comps, arr) + expected = np.zeros(comps.shape, dtype=bool) + tm.assert_numpy_array_equal(result, expected) + + def test_large(self): + s = date_range("20000101", periods=2000000, freq="s").values + result = algos.isin(s, s[0:2]) + expected = np.zeros(len(s), dtype=bool) + expected[0] = True + expected[1] = True + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["m8[ns]", "M8[ns]", "M8[ns, UTC]", "period[D]"]) + def test_isin_datetimelike_all_nat(self, dtype): + # GH#56427 + dta = date_range("2013-01-01", periods=3)._values + arr = Series(dta.view("i8")).array.view(dtype) + + arr[0] = NaT + result = algos.isin(arr, [NaT]) + expected = np.array([True, False, False], dtype=bool) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("dtype", ["m8[ns]", "M8[ns]", "M8[ns, UTC]"]) + def test_isin_datetimelike_strings_returns_false(self, dtype): + # GH#53111 + dta = date_range("2013-01-01", periods=3)._values + arr = Series(dta.view("i8")).array.view(dtype) + + vals = [str(x) for x in arr] + res = algos.isin(arr, vals) + assert not res.any() + + vals2 = np.array(vals, dtype=str) + res2 = algos.isin(arr, vals2) + assert not res2.any() + + def test_isin_dt64tz_with_nat(self): + # the all-NaT values used to get inferred to tznaive, which was evaluated + # as non-matching GH#56427 + dti = date_range("2016-01-01", periods=3, tz="UTC") + ser = Series(dti) + ser[0] = NaT + + res = algos.isin(ser._values, [NaT]) + exp = np.array([True, False, False], dtype=bool) + tm.assert_numpy_array_equal(res, exp) + + def test_categorical_from_codes(self): + # GH 16639 + vals = np.array([0, 1, 2, 0]) + cats = ["a", "b", "c"] + Sd = Series(Categorical([1]).from_codes(vals, cats)) + St = Series(Categorical([1]).from_codes(np.array([0, 1]), cats)) + expected = np.array([True, True, False, True]) + result = algos.isin(Sd, St) + tm.assert_numpy_array_equal(expected, result) + + def test_categorical_isin(self): + vals = np.array([0, 1, 2, 0]) + cats = ["a", "b", "c"] + cat = Categorical([1]).from_codes(vals, cats) + other = Categorical([1]).from_codes(np.array([0, 1]), cats) + + expected = np.array([True, True, False, True]) + result = algos.isin(cat, other) + tm.assert_numpy_array_equal(expected, result) + + def test_same_nan_is_in(self): + # GH 22160 + # nan is special, because from " a is b" doesn't follow "a == b" + # at least, isin() should follow python's "np.nan in [nan] == True" + # casting to -> np.float64 -> another float-object somewhere on + # the way could lead jeopardize this behavior + comps = np.array([np.nan], dtype=object) # could be casted to float64 + values = [np.nan] + expected = np.array([True]) + result = algos.isin(comps, values) + tm.assert_numpy_array_equal(expected, result) + + def test_same_nan_is_in_large(self): + # https://github.com/pandas-dev/pandas/issues/22205 + s = np.tile(1.0, 1_000_001) + s[0] = np.nan + result = algos.isin(s, np.array([np.nan, 1])) + expected = np.ones(len(s), dtype=bool) + tm.assert_numpy_array_equal(result, expected) + + def test_same_nan_is_in_large_series(self): + # https://github.com/pandas-dev/pandas/issues/22205 + s = np.tile(1.0, 1_000_001) + series = Series(s) + s[0] = np.nan + result = series.isin(np.array([np.nan, 1])) + expected = Series(np.ones(len(s), dtype=bool)) + tm.assert_series_equal(result, expected) + + def test_same_object_is_in(self): + # GH 22160 + # there could be special treatment for nans + # the user however could define a custom class + # with similar behavior, then we at least should + # fall back to usual python's behavior: "a in [a] == True" + class LikeNan: + def __eq__(self, other) -> bool: + return False + + def __hash__(self): + return 0 + + a, b = LikeNan(), LikeNan() + + arg = np.array([a], dtype=object) + + # same object -> True + tm.assert_numpy_array_equal(algos.isin(arg, [a]), np.array([True])) + # different objects -> False + tm.assert_numpy_array_equal(algos.isin(arg, [b]), np.array([False])) + + def test_different_nans(self): + # GH 22160 + # all nans are handled as equivalent + + comps = [float("nan")] + values = [float("nan")] + assert comps[0] is not values[0] # different nan-objects + + # as list of python-objects: + result = algos.isin(np.array(comps), values) + tm.assert_numpy_array_equal(np.array([True]), result) + + # as object-array: + result = algos.isin( + np.asarray(comps, dtype=object), np.asarray(values, dtype=object) + ) + tm.assert_numpy_array_equal(np.array([True]), result) + + # as float64-array: + result = algos.isin( + np.asarray(comps, dtype=np.float64), np.asarray(values, dtype=np.float64) + ) + tm.assert_numpy_array_equal(np.array([True]), result) + + def test_no_cast(self): + # GH 22160 + # ensure 42 is not casted to a string + comps = np.array(["ss", 42], dtype=object) + values = ["42"] + expected = np.array([False, False]) + + result = algos.isin(comps, values) + tm.assert_numpy_array_equal(expected, result) + + @pytest.mark.parametrize("empty", [[], Series(dtype=object), np.array([])]) + def test_empty(self, empty): + # see gh-16991 + vals = Index(["a", "b"]) + expected = np.array([False, False]) + + result = algos.isin(vals, empty) + tm.assert_numpy_array_equal(expected, result) + + def test_different_nan_objects(self): + # GH 22119 + comps = np.array(["nan", np.nan * 1j, float("nan")], dtype=object) + vals = np.array([float("nan")], dtype=object) + expected = np.array([False, False, True]) + result = algos.isin(comps, vals) + tm.assert_numpy_array_equal(expected, result) + + def test_different_nans_as_float64(self): + # GH 21866 + # create different nans from bit-patterns, + # these nans will land in different buckets in the hash-table + # if no special care is taken + NAN1 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000000))[0] + NAN2 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000001))[0] + assert NAN1 != NAN1 + assert NAN2 != NAN2 + + # check that NAN1 and NAN2 are equivalent: + arr = np.array([NAN1, NAN2], dtype=np.float64) + lookup1 = np.array([NAN1], dtype=np.float64) + result = algos.isin(arr, lookup1) + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) + + lookup2 = np.array([NAN2], dtype=np.float64) + result = algos.isin(arr, lookup2) + expected = np.array([True, True]) + tm.assert_numpy_array_equal(result, expected) + + def test_isin_int_df_string_search(self): + """Comparing df with int`s (1,2) with a string at isin() ("1") + -> should not match values because int 1 is not equal str 1""" + df = DataFrame({"values": [1, 2]}) + result = df.isin(["1"]) + expected_false = DataFrame({"values": [False, False]}) + tm.assert_frame_equal(result, expected_false) + + def test_isin_nan_df_string_search(self): + """Comparing df with nan value (np.nan,2) with a string at isin() ("NaN") + -> should not match values because np.nan is not equal str NaN""" + df = DataFrame({"values": [np.nan, 2]}) + result = df.isin(np.array(["NaN"], dtype=object)) + expected_false = DataFrame({"values": [False, False]}) + tm.assert_frame_equal(result, expected_false) + + def test_isin_float_df_string_search(self): + """Comparing df with floats (1.4245,2.32441) with a string at isin() ("1.4245") + -> should not match values because float 1.4245 is not equal str 1.4245""" + df = DataFrame({"values": [1.4245, 2.32441]}) + result = df.isin(np.array(["1.4245"], dtype=object)) + expected_false = DataFrame({"values": [False, False]}) + tm.assert_frame_equal(result, expected_false) + + def test_isin_unsigned_dtype(self): + # GH#46485 + ser = Series([1378774140726870442], dtype=np.uint64) + result = ser.isin([1378774140726870528]) + expected = Series(False) + tm.assert_series_equal(result, expected) + + +class TestValueCounts: + def test_value_counts(self): + arr = np.random.default_rng(1234).standard_normal(4) + factor = cut(arr, 4) + + # assert isinstance(factor, n) + result = algos.value_counts_internal(factor) + breaks = [-1.606, -1.018, -0.431, 0.155, 0.741] + index = IntervalIndex.from_breaks(breaks).astype(CategoricalDtype(ordered=True)) + expected = Series([1, 0, 2, 1], index=index, name="count") + tm.assert_series_equal(result.sort_index(), expected.sort_index()) + + def test_value_counts_bins(self): + s = [1, 2, 3, 4] + result = algos.value_counts_internal(s, bins=1) + expected = Series( + [4], index=IntervalIndex.from_tuples([(0.996, 4.0)]), name="count" + ) + tm.assert_series_equal(result, expected) + + result = algos.value_counts_internal(s, bins=2, sort=False) + expected = Series( + [2, 2], + index=IntervalIndex.from_tuples([(0.996, 2.5), (2.5, 4.0)]), + name="count", + ) + tm.assert_series_equal(result, expected) + + def test_value_counts_dtypes(self): + result = algos.value_counts_internal(np.array([1, 1.0])) + assert len(result) == 1 + + result = algos.value_counts_internal(np.array([1, 1.0]), bins=1) + assert len(result) == 1 + + result = algos.value_counts_internal(Series([1, 1.0, "1"])) # object + assert len(result) == 2 + + msg = "bins argument only works with numeric data" + with pytest.raises(TypeError, match=msg): + algos.value_counts_internal(np.array(["1", 1], dtype=object), bins=1) + + def test_value_counts_nat(self): + td = Series([np.timedelta64(10000), NaT], dtype="timedelta64[ns]") + dt = to_datetime(["NaT", "2014-01-01"]) + + for ser in [td, dt]: + vc = algos.value_counts_internal(ser) + vc_with_na = algos.value_counts_internal(ser, dropna=False) + assert len(vc) == 1 + assert len(vc_with_na) == 2 + + exp_dt = Series({Timestamp("2014-01-01 00:00:00"): 1}, name="count") + result_dt = algos.value_counts_internal(dt) + tm.assert_series_equal(result_dt, exp_dt) + + exp_td = Series([1], index=[np.timedelta64(10000)], name="count") + result_td = algos.value_counts_internal(td) + tm.assert_series_equal(result_td, exp_td) + + @pytest.mark.parametrize("dtype", [object, "M8[us]"]) + def test_value_counts_datetime_outofbounds(self, dtype): + # GH 13663 + ser = Series( + [ + datetime(3000, 1, 1), + datetime(5000, 1, 1), + datetime(5000, 1, 1), + datetime(6000, 1, 1), + datetime(3000, 1, 1), + datetime(3000, 1, 1), + ], + dtype=dtype, + ) + + res = ser.value_counts() + + exp_index = Index( + [datetime(3000, 1, 1), datetime(5000, 1, 1), datetime(6000, 1, 1)], + dtype=dtype, + ) + exp = Series([3, 2, 1], index=exp_index, name="count") + tm.assert_series_equal(res, exp) + + def test_categorical(self): + s = Series(Categorical(list("aaabbc"))) + result = s.value_counts() + expected = Series( + [3, 2, 1], index=CategoricalIndex(["a", "b", "c"]), name="count" + ) + + tm.assert_series_equal(result, expected, check_index_type=True) + + # preserve order? + s = s.cat.as_ordered() + result = s.value_counts() + expected.index = expected.index.as_ordered() + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_categorical_nans(self): + s = Series(Categorical(list("aaaaabbbcc"))) # 4,3,2,1 (nan) + s.iloc[1] = np.nan + result = s.value_counts() + expected = Series( + [4, 3, 2], + index=CategoricalIndex(["a", "b", "c"], categories=["a", "b", "c"]), + name="count", + ) + tm.assert_series_equal(result, expected, check_index_type=True) + result = s.value_counts(dropna=False) + expected = Series( + [4, 3, 2, 1], index=CategoricalIndex(["a", "b", "c", np.nan]), name="count" + ) + tm.assert_series_equal(result, expected, check_index_type=True) + + # out of order + s = Series( + Categorical(list("aaaaabbbcc"), ordered=True, categories=["b", "a", "c"]) + ) + s.iloc[1] = np.nan + result = s.value_counts() + expected = Series( + [4, 3, 2], + index=CategoricalIndex( + ["a", "b", "c"], + categories=["b", "a", "c"], + ordered=True, + ), + name="count", + ) + tm.assert_series_equal(result, expected, check_index_type=True) + + result = s.value_counts(dropna=False) + expected = Series( + [4, 3, 2, 1], + index=CategoricalIndex( + ["a", "b", "c", np.nan], categories=["b", "a", "c"], ordered=True + ), + name="count", + ) + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_categorical_zeroes(self): + # keep the `d` category with 0 + s = Series(Categorical(list("bbbaac"), categories=list("abcd"), ordered=True)) + result = s.value_counts() + expected = Series( + [3, 2, 1, 0], + index=Categorical( + ["b", "a", "c", "d"], categories=list("abcd"), ordered=True + ), + name="count", + ) + tm.assert_series_equal(result, expected, check_index_type=True) + + def test_value_counts_dropna(self): + # https://github.com/pandas-dev/pandas/issues/9443#issuecomment-73719328 + + tm.assert_series_equal( + Series([True, True, False]).value_counts(dropna=True), + Series([2, 1], index=[True, False], name="count"), + ) + tm.assert_series_equal( + Series([True, True, False]).value_counts(dropna=False), + Series([2, 1], index=[True, False], name="count"), + ) + + tm.assert_series_equal( + Series([True] * 3 + [False] * 2 + [None] * 5).value_counts(dropna=True), + Series([3, 2], index=Index([True, False], dtype=object), name="count"), + ) + tm.assert_series_equal( + Series([True] * 5 + [False] * 3 + [None] * 2).value_counts(dropna=False), + Series([5, 3, 2], index=[True, False, None], name="count"), + ) + tm.assert_series_equal( + Series([10.3, 5.0, 5.0]).value_counts(dropna=True), + Series([2, 1], index=[5.0, 10.3], name="count"), + ) + tm.assert_series_equal( + Series([10.3, 5.0, 5.0]).value_counts(dropna=False), + Series([2, 1], index=[5.0, 10.3], name="count"), + ) + + tm.assert_series_equal( + Series([10.3, 5.0, 5.0, None]).value_counts(dropna=True), + Series([2, 1], index=[5.0, 10.3], name="count"), + ) + + result = Series([10.3, 10.3, 5.0, 5.0, 5.0, None]).value_counts(dropna=False) + expected = Series([3, 2, 1], index=[5.0, 10.3, None], name="count") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("dtype", (np.float64, object, "M8[ns]")) + def test_value_counts_normalized(self, dtype): + # GH12558 + s = Series([1] * 2 + [2] * 3 + [np.nan] * 5) + s_typed = s.astype(dtype) + result = s_typed.value_counts(normalize=True, dropna=False) + expected = Series( + [0.5, 0.3, 0.2], + index=Series([np.nan, 2.0, 1.0], dtype=dtype), + name="proportion", + ) + tm.assert_series_equal(result, expected) + + result = s_typed.value_counts(normalize=True, dropna=True) + expected = Series( + [0.6, 0.4], index=Series([2.0, 1.0], dtype=dtype), name="proportion" + ) + tm.assert_series_equal(result, expected) + + def test_value_counts_uint64(self): + arr = np.array([2**63], dtype=np.uint64) + expected = Series([1], index=[2**63], name="count") + result = algos.value_counts_internal(arr) + + tm.assert_series_equal(result, expected) + + arr = np.array([-1, 2**63], dtype=object) + expected = Series([1, 1], index=[-1, 2**63], name="count") + result = algos.value_counts_internal(arr) + + tm.assert_series_equal(result, expected) + + def test_value_counts_series(self): + # GH#54857 + values = np.array([3, 1, 2, 3, 4, np.nan]) + result = Series(values).value_counts(bins=3) + expected = Series( + [2, 2, 1], + index=IntervalIndex.from_tuples( + [(0.996, 2.0), (2.0, 3.0), (3.0, 4.0)], dtype="interval[float64, right]" + ), + name="count", + ) + tm.assert_series_equal(result, expected) + + def test_value_counts_stability(self): + # GH 63155 + arr = np.random.default_rng(2).integers(0, 32, 64) + result = algos.value_counts_internal(arr, sort=True) + + value_counts = Series(arr).value_counts(sort=False) + expected = value_counts.sort_values(ascending=False, kind="stable") + tm.assert_series_equal(result, expected) + + unstable_sorted = value_counts.sort_values(ascending=False, kind="quicksort") + with pytest.raises(AssertionError): + tm.assert_series_equal(result, unstable_sorted) + + +class TestDuplicated: + def test_duplicated_with_nas(self): + keys = np.array([0, 1, np.nan, 0, 2, np.nan], dtype=object) + + result = algos.duplicated(keys) + expected = np.array([False, False, False, True, False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep="first") + expected = np.array([False, False, False, True, False, True]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep="last") + expected = np.array([True, False, True, False, False, False]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep=False) + expected = np.array([True, False, True, True, False, True]) + tm.assert_numpy_array_equal(result, expected) + + keys = np.empty(8, dtype=object) + for i, t in enumerate( + zip([0, 0, np.nan, np.nan] * 2, [0, np.nan, 0, np.nan] * 2, strict=True) + ): + keys[i] = t + + result = algos.duplicated(keys) + falses = [False] * 4 + trues = [True] * 4 + expected = np.array(falses + trues) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep="last") + expected = np.array(trues + falses) + tm.assert_numpy_array_equal(result, expected) + + result = algos.duplicated(keys, keep=False) + expected = np.array(trues + trues) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize( + "case", + [ + np.array([1, 2, 1, 5, 3, 2, 4, 1, 5, 6]), + np.array([1.1, 2.2, 1.1, np.nan, 3.3, 2.2, 4.4, 1.1, np.nan, 6.6]), + np.array( + [ + 1 + 1j, + 2 + 2j, + 1 + 1j, + 5 + 5j, + 3 + 3j, + 2 + 2j, + 4 + 4j, + 1 + 1j, + 5 + 5j, + 6 + 6j, + ] + ), + np.array(["a", "b", "a", "e", "c", "b", "d", "a", "e", "f"], dtype=object), + np.array([1, 2**63, 1, 3**5, 10, 2**63, 39, 1, 3**5, 7], dtype=np.uint64), + ], + ) + def test_numeric_object_likes(self, case): + exp_first = np.array( + [False, False, True, False, False, True, False, True, True, False] + ) + exp_last = np.array( + [True, True, True, True, False, False, False, False, False, False] + ) + exp_false = exp_first | exp_last + + res_first = algos.duplicated(case, keep="first") + tm.assert_numpy_array_equal(res_first, exp_first) + + res_last = algos.duplicated(case, keep="last") + tm.assert_numpy_array_equal(res_last, exp_last) + + res_false = algos.duplicated(case, keep=False) + tm.assert_numpy_array_equal(res_false, exp_false) + + # index + for idx in [Index(case), Index(case, dtype="category")]: + res_first = idx.duplicated(keep="first") + tm.assert_numpy_array_equal(res_first, exp_first) + + res_last = idx.duplicated(keep="last") + tm.assert_numpy_array_equal(res_last, exp_last) + + res_false = idx.duplicated(keep=False) + tm.assert_numpy_array_equal(res_false, exp_false) + + # series + for s in [Series(case), Series(case, dtype="category")]: + res_first = s.duplicated(keep="first") + tm.assert_series_equal(res_first, Series(exp_first)) + + res_last = s.duplicated(keep="last") + tm.assert_series_equal(res_last, Series(exp_last)) + + res_false = s.duplicated(keep=False) + tm.assert_series_equal(res_false, Series(exp_false)) + + def test_datetime_likes(self): + dt = [ + "2011-01-01", + "2011-01-02", + "2011-01-01", + "NaT", + "2011-01-03", + "2011-01-02", + "2011-01-04", + "2011-01-01", + "NaT", + "2011-01-06", + ] + td = [ + "1 days", + "2 days", + "1 days", + "NaT", + "3 days", + "2 days", + "4 days", + "1 days", + "NaT", + "6 days", + ] + + cases = [ + np.array([Timestamp(d) for d in dt]), + np.array([Timestamp(d, tz="US/Eastern") for d in dt]), + np.array([Period(d, freq="D") for d in dt]), + np.array([np.datetime64(d) for d in dt]), + np.array([Timedelta(d) for d in td]), + ] + + exp_first = np.array( + [False, False, True, False, False, True, False, True, True, False] + ) + exp_last = np.array( + [True, True, True, True, False, False, False, False, False, False] + ) + exp_false = exp_first | exp_last + + for case in cases: + res_first = algos.duplicated(case, keep="first") + tm.assert_numpy_array_equal(res_first, exp_first) + + res_last = algos.duplicated(case, keep="last") + tm.assert_numpy_array_equal(res_last, exp_last) + + res_false = algos.duplicated(case, keep=False) + tm.assert_numpy_array_equal(res_false, exp_false) + + # index + for idx in [ + Index(case), + Index(case, dtype="category"), + Index(case, dtype=object), + ]: + res_first = idx.duplicated(keep="first") + tm.assert_numpy_array_equal(res_first, exp_first) + + res_last = idx.duplicated(keep="last") + tm.assert_numpy_array_equal(res_last, exp_last) + + res_false = idx.duplicated(keep=False) + tm.assert_numpy_array_equal(res_false, exp_false) + + # series + for s in [ + Series(case), + Series(case, dtype="category"), + Series(case, dtype=object), + ]: + res_first = s.duplicated(keep="first") + tm.assert_series_equal(res_first, Series(exp_first)) + + res_last = s.duplicated(keep="last") + tm.assert_series_equal(res_last, Series(exp_last)) + + res_false = s.duplicated(keep=False) + tm.assert_series_equal(res_false, Series(exp_false)) + + @pytest.mark.parametrize("case", [Index([1, 2, 3]), pd.RangeIndex(0, 3)]) + def test_unique_index(self, case): + assert case.is_unique is True + tm.assert_numpy_array_equal(case.duplicated(), np.array([False, False, False])) + + @pytest.mark.parametrize( + "arr, uniques", + [ + ( + [(0, 0), (0, 1), (1, 0), (1, 1), (0, 0), (0, 1), (1, 0), (1, 1)], + [(0, 0), (0, 1), (1, 0), (1, 1)], + ), + ( + [("b", "c"), ("a", "b"), ("a", "b"), ("b", "c")], + [("b", "c"), ("a", "b")], + ), + ([("a", 1), ("b", 2), ("a", 3), ("a", 1)], [("a", 1), ("b", 2), ("a", 3)]), + ], + ) + def test_unique_tuples(self, arr, uniques): + # https://github.com/pandas-dev/pandas/issues/16519 + expected = np.empty(len(uniques), dtype=object) + expected[:] = uniques + + msg = ( + r"unique requires a Series, Index, ExtensionArray, np.ndarray " + r"or NumpyExtensionArray got list" + ) + with pytest.raises(TypeError, match=msg): + # GH#52986 + pd.unique(arr) + + res = pd.unique(com.asarray_tuplesafe(arr, dtype=object)) + tm.assert_numpy_array_equal(res, expected) + + @pytest.mark.parametrize( + "array,expected", + [ + ( + [1 + 1j, 0, 1, 1j, 1 + 2j, 1 + 2j], + np.array([(1 + 1j), 0j, (1 + 0j), 1j, (1 + 2j)], dtype=complex), + ) + ], + ) + def test_unique_complex_numbers(self, array, expected): + # GH 17927 + msg = ( + r"unique requires a Series, Index, ExtensionArray, np.ndarray " + r"or NumpyExtensionArray got list" + ) + + with pytest.raises(TypeError, match=msg): + # GH#52986 + pd.unique(array) + + res = pd.unique(np.array(array)) + tm.assert_numpy_array_equal(res, expected) + + +class TestHashTable: + @pytest.mark.parametrize( + "htable, data", + [ + ( + ht.PyObjectHashTable, + np.array([f"foo_{i}" for i in range(1000)], dtype=object), + ), + ( + ht.StringHashTable, + np.array([f"foo_{i}" for i in range(1000)], dtype=object), + ), + (ht.Float64HashTable, np.arange(1000, dtype=np.float64)), + (ht.Int64HashTable, np.arange(1000, dtype=np.int64)), + (ht.UInt64HashTable, np.arange(1000, dtype=np.uint64)), + ], + ) + def test_hashtable_unique(self, htable, data, writable): + # output of maker has guaranteed unique elements + s = Series(data, dtype=data.dtype) + if htable == ht.Float64HashTable: + # add NaN for float column + s.loc[500] = np.nan + elif htable == ht.PyObjectHashTable: + # use different NaN types for object column + s.loc[500:502] = [np.nan, None, NaT] + + # create duplicated selection + s_duplicated = s.sample(frac=3, replace=True).reset_index(drop=True) + s_duplicated.values.setflags(write=writable) + + # drop_duplicates has own cython code (hash_table_func_helper.pxi) + # and is tested separately; keeps first occurrence like ht.unique() + expected_unique = s_duplicated.drop_duplicates(keep="first").values + result_unique = htable().unique(s_duplicated.values) + tm.assert_numpy_array_equal(result_unique, expected_unique) + + # test return_inverse=True + # reconstruction can only succeed if the inverse is correct + result_unique, result_inverse = htable().unique( + s_duplicated.values, return_inverse=True + ) + tm.assert_numpy_array_equal(result_unique, expected_unique) + reconstr = result_unique[result_inverse] + tm.assert_numpy_array_equal(reconstr, s_duplicated.values) + + @pytest.mark.parametrize( + "htable, data", + [ + ( + ht.PyObjectHashTable, + np.array([f"foo_{i}" for i in range(1000)], dtype=object), + ), + ( + ht.StringHashTable, + np.array([f"foo_{i}" for i in range(1000)], dtype=object), + ), + (ht.Float64HashTable, np.arange(1000, dtype=np.float64)), + (ht.Int64HashTable, np.arange(1000, dtype=np.int64)), + (ht.UInt64HashTable, np.arange(1000, dtype=np.uint64)), + ], + ) + def test_hashtable_factorize(self, htable, writable, data): + # output of maker has guaranteed unique elements + s = Series(data, dtype=data.dtype) + if htable == ht.Float64HashTable: + # add NaN for float column + s.loc[500] = np.nan + elif htable == ht.PyObjectHashTable: + # use different NaN types for object column + s.loc[500:502] = [np.nan, None, NaT] + + # create duplicated selection + s_duplicated = s.sample(frac=3, replace=True).reset_index(drop=True) + s_duplicated.values.setflags(write=writable) + na_mask = s_duplicated.isna().values + + result_unique, result_inverse = htable().factorize(s_duplicated.values) + + # drop_duplicates has own cython code (hash_table_func_helper.pxi) + # and is tested separately; keeps first occurrence like ht.factorize() + # since factorize removes all NaNs, we do the same here + expected_unique = s_duplicated.dropna().drop_duplicates().values + tm.assert_numpy_array_equal(result_unique, expected_unique) + + # reconstruction can only succeed if the inverse is correct. Since + # factorize removes the NaNs, those have to be excluded here as well + result_reconstruct = result_unique[result_inverse[~na_mask]] + expected_reconstruct = s_duplicated.dropna().values + tm.assert_numpy_array_equal(result_reconstruct, expected_reconstruct) + + +class TestRank: + @pytest.mark.parametrize( + "arr", + [ + [np.nan, np.nan, 5.0, 5.0, 5.0, np.nan, 1, 2, 3, np.nan], + [4.0, np.nan, 5.0, 5.0, 5.0, np.nan, 1, 2, 4.0, np.nan], + ], + ) + def test_scipy_compat(self, arr): + sp_stats = pytest.importorskip("scipy.stats") + + arr = np.array(arr) + + mask = ~np.isfinite(arr) + result = libalgos.rank_1d(arr) + arr[mask] = np.inf + exp = sp_stats.rankdata(arr) + exp[mask] = np.nan + tm.assert_almost_equal(result, exp) + + def test_basic(self, writable, any_int_numpy_dtype): + exp = np.array([1, 2], dtype=np.float64) + + data = np.array([1, 100], dtype=any_int_numpy_dtype) + data.setflags(write=writable) + ser = Series(data) + result = algos.rank(ser) + tm.assert_numpy_array_equal(result, exp) + + @pytest.mark.parametrize("dtype", [np.float64, np.uint64]) + def test_uint64_overflow(self, dtype): + exp = np.array([1, 2], dtype=np.float64) + + s = Series([1, 2**63], dtype=dtype) + tm.assert_numpy_array_equal(algos.rank(s), exp) + + @pytest.mark.parametrize("method", ["average", "min", "max"]) + def test_rank_tiny_values(self, method): + # GH62036: regression test for ranking with tiny float values + exp = np.array([4.0, 1.0, 3.0, np.nan, 2.0], dtype=np.float64) + s = Series( + [5.4954145e29, -9.791984e-21, 9.3715776e-26, pd.NA, 1.8790257e-28], + dtype="Float64", + ) + s = s.astype(object) + result = algos.rank(s, method=method) + tm.assert_numpy_array_equal(result, exp) + + def test_too_many_ndims(self): + arr = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) + msg = "Array with ndim > 2 are not supported" + + with pytest.raises(TypeError, match=msg): + algos.rank(arr) + + @pytest.mark.single_cpu + def test_pct_max_many_rows(self): + # GH 18271 + values = np.arange(2**24 + 1) + result = algos.rank(values, pct=True).max() + assert result == 1 + + values = np.arange(2**25 + 2).reshape(2**24 + 1, 2) + result = algos.rank(values, pct=True).max() + assert result == 1 + + +class TestMode: + def test_no_mode(self): + exp = Series([], dtype=np.float64, index=Index([], dtype=int)) + result, _ = algos.mode(np.array([])) + tm.assert_numpy_array_equal(result, exp.values) + + def test_mode_single(self, any_real_numpy_dtype): + # GH 15714 + exp_single = [1] + data_single = [1] + + exp_multi = [1] + data_multi = [1, 1] + + ser = Series(data_single, dtype=any_real_numpy_dtype) + exp = Series(exp_single, dtype=any_real_numpy_dtype) + result, _ = algos.mode(ser.values) + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + ser = Series(data_multi, dtype=any_real_numpy_dtype) + exp = Series(exp_multi, dtype=any_real_numpy_dtype) + result, _ = algos.mode(ser.values) + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_mode_obj_int(self): + exp = Series([1], dtype=int) + result, _ = algos.mode(exp.values) + tm.assert_numpy_array_equal(result, exp.values) + + exp = Series(["a", "b", "c"], dtype=object) + result, _ = algos.mode(exp.values) + tm.assert_numpy_array_equal(result, exp.values) + + def test_number_mode(self, any_real_numpy_dtype): + exp_single = [1] + data_single = [1] * 5 + [2] * 3 + + exp_multi = [1, 3] + data_multi = [1] * 5 + [2] * 3 + [3] * 5 + + ser = Series(data_single, dtype=any_real_numpy_dtype) + exp = Series(exp_single, dtype=any_real_numpy_dtype) + result, _ = algos.mode(ser.values) + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + ser = Series(data_multi, dtype=any_real_numpy_dtype) + exp = Series(exp_multi, dtype=any_real_numpy_dtype) + result, _ = algos.mode(ser.values) + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_strobj_mode(self): + exp = ["b"] + data = ["a"] * 2 + ["b"] * 3 + + ser = Series(data, dtype="c") + exp = Series(exp, dtype="c") + result, _ = algos.mode(ser.values) + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + @pytest.mark.parametrize("dt", [str, object]) + def test_strobj_multi_char(self, dt, using_infer_string): + exp = ["bar"] + data = ["foo"] * 2 + ["bar"] * 3 + + ser = Series(data, dtype=dt) + exp = Series(exp, dtype=dt) + result, _ = algos.mode(ser.values) + if using_infer_string and dt is str: + tm.assert_extension_array_equal(result, exp.values) + else: + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_datelike_mode(self): + exp = Series(["1900-05-03", "2011-01-03", "2013-01-02"], dtype="M8[ns]") + ser = Series(["2011-01-03", "2013-01-02", "1900-05-03"], dtype="M8[ns]") + tm.assert_extension_array_equal(algos.mode(ser.values), exp._values) + tm.assert_series_equal(ser.mode(), exp) + + exp = Series(["2011-01-03", "2013-01-02"], dtype="M8[ns]") + ser = Series( + ["2011-01-03", "2013-01-02", "1900-05-03", "2011-01-03", "2013-01-02"], + dtype="M8[ns]", + ) + tm.assert_extension_array_equal(algos.mode(ser.values), exp._values) + tm.assert_series_equal(ser.mode(), exp) + + def test_timedelta_mode(self): + exp = Series(["-1 days", "0 days", "1 days"], dtype="timedelta64[ns]") + ser = Series(["1 days", "-1 days", "0 days"], dtype="timedelta64[ns]") + tm.assert_extension_array_equal(algos.mode(ser.values), exp._values) + tm.assert_series_equal(ser.mode(), exp) + + exp = Series(["2 min", "1 day"], dtype="timedelta64[ns]") + ser = Series( + ["1 day", "1 day", "-1 day", "-1 day 2 min", "2 min", "2 min"], + dtype="timedelta64[ns]", + ) + tm.assert_extension_array_equal(algos.mode(ser.values), exp._values) + tm.assert_series_equal(ser.mode(), exp) + + def test_mixed_dtype(self): + exp = Series(["foo"], dtype=object) + ser = Series([1, "foo", "foo"]) + result, _ = algos.mode(ser.values) + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_uint64_overflow(self): + exp = Series([2**63], dtype=np.uint64) + ser = Series([1, 2**63, 2**63], dtype=np.uint64) + result, _ = algos.mode(ser.values) + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + exp = Series([1, 2**63], dtype=np.uint64) + ser = Series([1, 2**63], dtype=np.uint64) + result, _ = algos.mode(ser.values) + tm.assert_numpy_array_equal(result, exp.values) + tm.assert_series_equal(ser.mode(), exp) + + def test_categorical(self): + c = Categorical([1, 2]) + exp = c + res = Series(c).mode()._values + tm.assert_categorical_equal(res, exp) + + c = Categorical([1, "a", "a"]) + exp = Categorical(["a"], categories=[1, "a"]) + res = Series(c).mode()._values + tm.assert_categorical_equal(res, exp) + + c = Categorical([1, 1, 2, 3, 3]) + exp = Categorical([1, 3], categories=[1, 2, 3]) + res = Series(c).mode()._values + tm.assert_categorical_equal(res, exp) + + def test_index(self): + idx = Index([1, 2, 3]) + exp = Series([1, 2, 3], dtype=np.int64) + result, _ = algos.mode(idx) + tm.assert_numpy_array_equal(result, exp.values) + + idx = Index([1, "a", "a"]) + exp = Series(["a"], dtype=object) + result, _ = algos.mode(idx) + tm.assert_numpy_array_equal(result, exp.values) + + idx = Index([1, 1, 2, 3, 3]) + exp = Series([1, 3], dtype=np.int64) + result, _ = algos.mode(idx) + tm.assert_numpy_array_equal(result, exp.values) + + idx = Index( + ["1 day", "1 day", "-1 day", "-1 day 2 min", "2 min", "2 min"], + dtype="timedelta64[ns]", + ) + with pytest.raises(AttributeError, match="TimedeltaIndex"): + # algos.mode expects Arraylike, does *not* unwrap TimedeltaIndex + algos.mode(idx) + + def test_ser_mode_with_name(self): + # GH 46737 + ser = Series([1, 1, 3], name="foo") + result = ser.mode() + expected = Series([1], name="foo") + tm.assert_series_equal(result, expected) + + +class TestDiff: + @pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) + def test_diff_datetimelike_nat(self, dtype): + # NaT - NaT is NaT, not 0 + arr = np.arange(12).astype(np.int64).view(dtype).reshape(3, 4) + arr[:, 2] = arr.dtype.type("NaT", "ns") + result = algos.diff(arr, 1, axis=0) + + expected = np.ones(arr.shape, dtype="timedelta64[ns]") * 4 + expected[:, 2] = np.timedelta64("NaT", "ns") + expected[0, :] = np.timedelta64("NaT", "ns") + + tm.assert_numpy_array_equal(result, expected) + + result = algos.diff(arr.T, 1, axis=1) + tm.assert_numpy_array_equal(result, expected.T) + + def test_diff_ea_axis(self): + dta = date_range("2016-01-01", periods=3, tz="US/Pacific")._data + + msg = "cannot diff DatetimeArray on axis=1" + with pytest.raises(ValueError, match=msg): + algos.diff(dta, 1, axis=1) + + @pytest.mark.parametrize("dtype", ["int8", "int16"]) + def test_diff_low_precision_int(self, dtype): + arr = np.array([0, 1, 1, 0, 0], dtype=dtype) + result = algos.diff(arr, 1) + expected = np.array([np.nan, 1, 0, -1, 0], dtype="float32") + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("op", [np.array, pd.array]) +def test_union_with_duplicates(op): + # GH#36289 + lvals = op([3, 1, 3, 4]) + rvals = op([2, 3, 1, 1]) + expected = op([3, 3, 1, 1, 4, 2]) + if isinstance(expected, np.ndarray): + result = algos.union_with_duplicates(lvals, rvals) + tm.assert_numpy_array_equal(result, expected) + else: + result = algos.union_with_duplicates(lvals, rvals) + tm.assert_extension_array_equal(result, expected) diff --git a/pandas/tests/test_col.py b/pandas/tests/test_col.py new file mode 100644 index 0000000000000000000000000000000000000000..74cac1b8d1c1e0b105dc33270fb48e55708622c8 --- /dev/null +++ b/pandas/tests/test_col.py @@ -0,0 +1,293 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas._libs.properties import cache_readonly + +import pandas as pd +import pandas._testing as tm +from pandas.api.typing import Expression +from pandas.tests.test_register_accessor import ensure_removed + + +@pytest.mark.parametrize( + ("expr", "expected_values", "expected_str"), + [ + (pd.col("a"), [1, 2], "col('a')"), + (pd.col("a") * 2, [2, 4], "col('a') * 2"), + (pd.col("a").sum(), [3, 3], "col('a').sum()"), + (pd.col("a") + 1, [2, 3], "col('a') + 1"), + (1 + pd.col("a"), [2, 3], "1 + col('a')"), + (pd.col("a") - 1, [0, 1], "col('a') - 1"), + (1 - pd.col("a"), [0, -1], "1 - col('a')"), + (pd.col("a") * 1, [1, 2], "col('a') * 1"), + (1 * pd.col("a"), [1, 2], "1 * col('a')"), + (pd.col("a") / 1, [1.0, 2.0], "col('a') / 1"), + (1 / pd.col("a"), [1.0, 0.5], "1 / col('a')"), + (pd.col("a") // 1, [1, 2], "col('a') // 1"), + (1 // pd.col("a"), [1, 0], "1 // col('a')"), + (pd.col("a") % 1, [0, 0], "col('a') % 1"), + (1 % pd.col("a"), [0, 1], "1 % col('a')"), + (pd.col("a") > 1, [False, True], "col('a') > 1"), + (pd.col("a") >= 1, [True, True], "col('a') >= 1"), + (pd.col("a") < 1, [False, False], "col('a') < 1"), + (pd.col("a") <= 1, [True, False], "col('a') <= 1"), + (pd.col("a") == 1, [True, False], "col('a') == 1"), + (np.power(pd.col("a"), 2), [1, 4], "power(col('a'), 2)"), + (np.divide(pd.col("a"), pd.col("a")), [1.0, 1.0], "divide(col('a'), col('a'))"), + ( + (pd.col("a") + 1) * (pd.col("b") + 2), + [10, 18], + "(col('a') + 1) * (col('b') + 2)", + ), + ( + (pd.col("a") - 1).astype("bool"), + [False, True], + "(col('a') - 1).astype('bool')", + ), + # Unary operators + (-pd.col("a"), [-1, -2], "-col('a')"), + (+pd.col("a"), [1, 2], "+col('a')"), + (-(pd.col("a") + 1), [-2, -3], "-(col('a') + 1)"), + (-pd.col("a") * 2, [-2, -4], "(-col('a')) * 2"), + (abs(pd.col("a")), [1, 2], "abs(col('a'))"), + (abs(pd.col("a") - 2), [1, 0], "abs(col('a') - 2)"), + ], +) +def test_col_simple( + expr: Expression, expected_values: list[object], expected_str: str +) -> None: + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = df.assign(c=expr) + expected = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": expected_values}) + tm.assert_frame_equal(result, expected) + assert str(expr) == expected_str + + +def test_frame_getitem() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + expr = pd.col("a") == 2 + result = df[expr] + expected = df.iloc[[1]] + tm.assert_frame_equal(result, expected) + + +def test_frame_setitem() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + expr = pd.col("a") == 2 + + result = df.copy() + result[expr] = 100 + expected = pd.DataFrame({"a": [1, 100], "b": [3, 100]}) + tm.assert_frame_equal(result, expected) + + +def test_frame_loc() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + expr = pd.col("a") == 2 + result = df.copy() + result.loc[expr, "b"] = 100 + expected = pd.DataFrame({"a": [1, 2], "b": [3, 100]}) + tm.assert_frame_equal(result, expected) + + +def test_frame_iloc() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + expr = pd.col("a") == 2 + result = df.copy() + result.iloc[expr, 1] = 100 + expected = pd.DataFrame({"a": [1, 2], "b": [3, 100]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected_values", "expected_str"), + [ + (pd.col("a").dt.year, [2020], "col('a').dt.year"), + (pd.col("a").dt.strftime("%B"), ["January"], "col('a').dt.strftime('%B')"), + (pd.col("b").str.upper(), ["FOO"], "col('b').str.upper()"), + ], +) +def test_namespaces( + expr: Expression, expected_values: list[object], expected_str: str +) -> None: + df = pd.DataFrame({"a": [datetime(2020, 1, 1)], "b": ["foo"]}) + result = df.assign(c=expr) + expected = pd.DataFrame( + {"a": [datetime(2020, 1, 1)], "b": ["foo"], "c": expected_values} + ) + tm.assert_frame_equal(result, expected, check_dtype=False) + assert str(expr) == expected_str + + +def test_invalid() -> None: + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + with pytest.raises(ValueError, match=r"did you mean one of \['a', 'b'\] instead"): + df.assign(c=pd.col("c").mean()) + df = pd.DataFrame({f"col_{i}": [0] for i in range(11)}) + msg = ( + "did you mean one of " + r"\['col_0', 'col_1', 'col_2', 'col_3', " + "'col_4', 'col_5', 'col_6', 'col_7', " + r"'col_8', 'col_9',\.\.\.\] instead" + ) + "" + with pytest.raises(ValueError, match=msg): + df.assign(c=pd.col("c").mean()) + + +def test_custom_accessor() -> None: + df = pd.DataFrame({"a": [1, 2, 3]}) + + class XYZAccessor: + def __init__(self, pandas_obj): + self._obj = pandas_obj + + def mean(self): + return self._obj.mean() + + with ensure_removed(pd.Series, "xyz"): + pd.api.extensions.register_series_accessor("xyz")(XYZAccessor) + result = df.assign(b=pd.col("a").xyz.mean()) + expected = pd.DataFrame({"a": [1, 2, 3], "b": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("expr", "expected_values", "expected_str"), + [ + ( + pd.col("a") & pd.col("b"), + [False, False, True, False], + "col('a') & col('b')", + ), + ( + pd.col("a") & True, + [True, False, True, False], + "col('a') & True", + ), + ( + pd.col("a") | pd.col("b"), + [True, True, True, True], + "col('a') | col('b')", + ), + ( + pd.col("a") | False, + [True, False, True, False], + "col('a') | False", + ), + ( + pd.col("a") ^ pd.col("b"), + [True, True, False, True], + "col('a') ^ col('b')", + ), + ( + pd.col("a") ^ True, + [False, True, False, True], + "col('a') ^ True", + ), + ( + ~pd.col("a"), + [False, True, False, True], + "~col('a')", + ), + ], +) +def test_col_logical_ops( + expr: Expression, expected_values: list[bool], expected_str: str +) -> None: + # https://github.com/pandas-dev/pandas/issues/63322 + df = pd.DataFrame({"a": [True, False, True, False], "b": [False, True, True, True]}) + result = df.assign(c=expr) + expected = pd.DataFrame( + { + "a": [True, False, True, False], + "b": [False, True, True, True], + "c": expected_values, + } + ) + tm.assert_frame_equal(result, expected) + assert str(expr) == expected_str + + # Test that the expression works with .loc + result = df.loc[expr] + expected = df[expected_values] + tm.assert_frame_equal(result, expected) + + +def test_expression_getitem() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + df = pd.DataFrame({"a": [1, 2, 3]}) + expr = pd.col("a")[1] + expected_str = "col('a')[1]" + + assert str(expr) == expected_str + + result = df.assign(b=expr) + expected = pd.DataFrame({"a": [1, 2, 3], "b": [2, 2, 2]}) + tm.assert_frame_equal(result, expected) + + +def test_property() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + df = pd.DataFrame({"a": [1, 2, 3]}) + expr = pd.col("a").index + expected_str = "col('a').index" + + assert str(expr) == expected_str + + result = df.assign(b=expr) + expected = pd.DataFrame({"a": [1, 2, 3], "b": [0, 1, 2]}) + tm.assert_frame_equal(result, expected) + + +def test_cached_property() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + # Ensure test is valid + assert isinstance(pd.Index.dtype, cache_readonly) + + df = pd.DataFrame({"a": [1, 2, 3]}) + expr = pd.col("a").index.dtype + expected_str = "col('a').index.dtype" + assert str(expr) == expected_str + + result = df.assign(b=expr) + expected = pd.DataFrame({"a": [1, 2, 3], "b": np.int64}) + tm.assert_frame_equal(result, expected) + + +def test_qcut() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + df = pd.DataFrame({"a": [1, 2, 3]}) + expr = pd.qcut(pd.col("a"), 3) + expected_str = "qcut(x=col('a'), q=3, labels=None, retbins=False, precision=3)" + assert str(expr) == expected_str, str(expr) + + result = df.assign(b=expr) + expected = pd.DataFrame({"a": [1, 2, 3], "b": pd.qcut(df["a"], 3)}) + tm.assert_frame_equal(result, expected) + + +def test_where() -> None: + # https://github.com/pandas-dev/pandas/pull/63439 + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + expr = pd.col("a").where(pd.col("b") == 5, 100) + expected_str = "col('a').where(col('b') == 5, 100)" + assert str(expr) == expected_str, str(expr) + + result = df.assign(c=expr) + expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [100, 2, 100]}) + tm.assert_frame_equal(result, expected) + + expr = pd.col("a").where(pd.col("b") == 5, pd.col("a") + 1) + expected_str = "col('a').where(col('b') == 5, col('a') + 1)" + assert str(expr) == expected_str, str(expr) + + result = df.assign(c=expr) + expected = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [2, 2, 4]}) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/test_common.py b/pandas/tests/test_common.py new file mode 100644 index 0000000000000000000000000000000000000000..eab2ce6a2ea943a96471cdfe919dfb8909ce1e8e --- /dev/null +++ b/pandas/tests/test_common.py @@ -0,0 +1,273 @@ +import collections +from functools import partial +import string +import subprocess +import sys + +import numpy as np +import pytest + +from pandas.compat import WASM + +import pandas as pd +from pandas import Series +import pandas._testing as tm +from pandas.core import ops +import pandas.core.common as com +from pandas.util.version import Version + + +class TestGetCallableName: + def fn(self, x): + return x + + partial1 = partial(fn) + partial2 = partial(partial1) + lambda_ = lambda x: x + + class SomeCall: + def __call__(self): + # This shouldn't actually get called below; SomeCall.__init__ + # should. + raise NotImplementedError + + @pytest.mark.parametrize( + "func, expected", + [ + (fn, "fn"), + (partial1, "fn"), + (partial2, "fn"), + (lambda_, ""), + (SomeCall(), "SomeCall"), + (1, None), + ], + ) + def test_get_callable_name(self, func, expected): + assert com.get_callable_name(func) == expected + + +class TestRandomState: + def test_seed(self): + seed = 5 + assert com.random_state(seed).uniform() == np.random.RandomState(seed).uniform() + + def test_object(self): + seed = 10 + state_obj = np.random.RandomState(seed) + assert ( + com.random_state(state_obj).uniform() + == np.random.RandomState(seed).uniform() + ) + + def test_default(self): + assert com.random_state() is np.random + + def test_array_like(self): + state = np.random.default_rng(None).integers(0, 2**31, size=624, dtype="uint32") + assert ( + com.random_state(state).uniform() == np.random.RandomState(state).uniform() + ) + + def test_bit_generators(self): + seed = 3 + assert ( + com.random_state(np.random.MT19937(seed)).uniform() + == np.random.RandomState(np.random.MT19937(seed)).uniform() + ) + + seed = 11 + assert ( + com.random_state(np.random.PCG64(seed)).uniform() + == np.random.RandomState(np.random.PCG64(seed)).uniform() + ) + + @pytest.mark.parametrize("state", ["test", 5.5]) + def test_error(self, state): + msg = ( + "random_state must be an integer, array-like, a BitGenerator, Generator, " + "a numpy RandomState, or None" + ) + with pytest.raises(ValueError, match=msg): + com.random_state(state) + + +@pytest.mark.parametrize("args, expected", [((1, 2, None), True), ((1, 2, 3), False)]) +def test_any_none(args, expected): + assert com.any_none(*args) is expected + + +@pytest.mark.parametrize( + "args, expected", + [((1, 2, 3), True), ((1, 2, None), False), ((None, None, None), False)], +) +def test_all_not_none(args, expected): + assert com.all_not_none(*args) is expected + + +@pytest.mark.parametrize( + "left, right, expected", + [ + (Series([1], name="x"), Series([2], name="x"), "x"), + (Series([1], name="x"), Series([2], name="y"), None), + (Series([1]), Series([2], name="x"), None), + (Series([1], name="x"), Series([2]), None), + (Series([1], name="x"), [2], "x"), + ([1], Series([2], name="y"), "y"), + # matching NAs + (Series([1], name=np.nan), pd.Index([], name=np.nan), np.nan), + (Series([1], name=np.nan), pd.Index([], name=pd.NaT), None), + (Series([1], name=pd.NA), pd.Index([], name=pd.NA), pd.NA), + # tuple name GH#39757 + ( + Series([1], name=np.int64(1)), + pd.Index([], name=(np.int64(1), np.int64(2))), + None, + ), + ( + Series([1], name=(np.int64(1), np.int64(2))), + pd.Index([], name=(np.int64(1), np.int64(2))), + (np.int64(1), np.int64(2)), + ), + pytest.param( + Series([1], name=(np.float64("nan"), np.int64(2))), + pd.Index([], name=(np.float64("nan"), np.int64(2))), + (np.float64("nan"), np.int64(2)), + marks=pytest.mark.xfail( + reason="Not checking for matching NAs inside tuples." + ), + ), + ], +) +def test_maybe_match_name(left, right, expected): + res = ops.common._maybe_match_name(left, right) + assert res is expected or res == expected + + +@pytest.mark.parametrize( + "into, msg", + [ + ( + # uninitialized defaultdict + collections.defaultdict, + r"to_dict\(\) only accepts initialized defaultdicts", + ), + ( + # non-mapping subtypes,, instance + [], + "unsupported type: ", + ), + ( + # non-mapping subtypes, class + list, + "unsupported type: ", + ), + ], +) +def test_standardize_mapping_type_error(into, msg): + with pytest.raises(TypeError, match=msg): + com.standardize_mapping(into) + + +def test_standardize_mapping(): + fill = {"bad": "data"} + assert com.standardize_mapping(fill) == dict + + # Convert instance to type + assert com.standardize_mapping({}) == dict + + dd = collections.defaultdict(list) + assert isinstance(com.standardize_mapping(dd), partial) + + +def test_git_version(): + # GH 21295 + git_version = pd.__git_version__ + assert len(git_version) == 40 + assert all(c in string.hexdigits for c in git_version) + + +def test_version_tag(): + version = Version(pd.__version__) + try: + version > Version("0.0.1") + except TypeError as err: + raise ValueError( + "No git tags exist, please sync tags between upstream and your repo" + ) from err + + +@pytest.mark.parametrize("obj", [obj for obj in pd.__dict__.values() if callable(obj)]) +def test_serializable(obj, temp_file): + # GH 35611 + unpickled = tm.round_trip_pickle(obj, temp_file) + assert type(obj) == type(unpickled) + + +class TestIsBoolIndexer: + def test_non_bool_array_with_na(self): + # in particular, this should not raise + arr = np.array(["A", "B", np.nan], dtype=object) + assert not com.is_bool_indexer(arr) + + def test_list_subclass(self): + # GH#42433 + + class MyList(list): + pass + + val = MyList(["a"]) + + assert not com.is_bool_indexer(val) + + val = MyList([True]) + assert com.is_bool_indexer(val) + + def test_frozenlist(self): + # GH#42461 + data = {"col1": [1, 2], "col2": [3, 4]} + df = pd.DataFrame(data=data) + + frozen = df.index.names[1:] + assert not com.is_bool_indexer(frozen) + + result = df[frozen] + expected = df[[]] + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("scalar", [1, True]) + def test_numpyextensionarray(self, scalar): + # GH 63391 + arr = pd.arrays.NumpyExtensionArray(np.array([scalar])) + assert com.is_bool_indexer(arr) is isinstance(scalar, bool) + + +@pytest.mark.parametrize("with_exception", [True, False]) +def test_temp_setattr(with_exception): + # GH#45954 + ser = Series(dtype=object) + ser.name = "first" + # Raise a ValueError in either case to satisfy pytest.raises + match = "Inside exception raised" if with_exception else "Outside exception raised" + with pytest.raises(ValueError, match=match): + with com.temp_setattr(ser, "name", "second"): + assert ser.name == "second" + if with_exception: + raise ValueError("Inside exception raised") + raise ValueError("Outside exception raised") + assert ser.name == "first" + + +@pytest.mark.skipif(WASM, reason="Can't start subprocesses in WASM") +@pytest.mark.single_cpu +def test_str_size(): + # GH#21758 + a = "a" + expected = sys.getsizeof(a) + pyexe = sys.executable.replace("\\", "/") + call = [ + pyexe, + "-c", + "a='a';import sys;sys.getsizeof(a);import pandas;print(sys.getsizeof(a));", + ] + result = subprocess.check_output(call).decode()[-4:-1].strip("\n") + assert int(result) == int(expected) diff --git a/pandas/tests/test_downstream.py b/pandas/tests/test_downstream.py new file mode 100644 index 0000000000000000000000000000000000000000..60b6537e11bac0ae201dd52577ae59542b4e81f2 --- /dev/null +++ b/pandas/tests/test_downstream.py @@ -0,0 +1,302 @@ +""" +Testing that we work in the downstream packages +""" + +import array +from functools import partial +import importlib +import subprocess +import sys + +import numpy as np +import pytest + +from pandas.errors import IntCastingNaNError + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Series, + TimedeltaIndex, +) +import pandas._testing as tm +from pandas.util.version import Version + + +@pytest.fixture +def df(): + return DataFrame({"A": [1, 2, 3]}) + + +def test_dask(df): + # dask sets "compute.use_numexpr" to False, so catch the current value + # and ensure to reset it afterwards to avoid impacting other tests + olduse = pd.get_option("compute.use_numexpr") + + try: + pytest.importorskip("toolz") + dd = pytest.importorskip("dask.dataframe") + + ddf = dd.from_pandas(df, npartitions=3) + assert ddf.A is not None + assert ddf.compute() is not None + finally: + pd.set_option("compute.use_numexpr", olduse) + + +# TODO(CoW) see https://github.com/pandas-dev/pandas/pull/51082 +@pytest.mark.skip(reason="not implemented with CoW") +def test_dask_ufunc(): + # dask sets "compute.use_numexpr" to False, so catch the current value + # and ensure to reset it afterwards to avoid impacting other tests + olduse = pd.get_option("compute.use_numexpr") + + try: + da = pytest.importorskip("dask.array") + dd = pytest.importorskip("dask.dataframe") + + s = Series([1.5, 2.3, 3.7, 4.0]) + ds = dd.from_pandas(s, npartitions=2) + + result = da.log(ds).compute() + expected = np.log(s) + tm.assert_series_equal(result, expected) + finally: + pd.set_option("compute.use_numexpr", olduse) + + +def test_construct_dask_float_array_int_dtype_match_ndarray(): + # GH#40110 make sure we treat a float-dtype dask array with the same + # rules we would for an ndarray + dd = pytest.importorskip("dask.dataframe") + + arr = np.array([1, 2.5, 3]) + darr = dd.from_array(arr) + + res = Series(darr) + expected = Series(arr) + tm.assert_series_equal(res, expected) + + # GH#49599 in 2.0 we raise instead of silently ignoring the dtype + msg = "Trying to coerce float values to integers" + with pytest.raises(ValueError, match=msg): + Series(darr, dtype="i8") + + msg = r"Cannot convert non-finite values \(NA or inf\) to integer" + arr[2] = np.nan + with pytest.raises(IntCastingNaNError, match=msg): + Series(darr, dtype="i8") + # which is the same as we get with a numpy input + with pytest.raises(IntCastingNaNError, match=msg): + Series(arr, dtype="i8") + + +def test_xarray(df): + pytest.importorskip("xarray") + + assert df.to_xarray() is not None + + +def test_xarray_cftimeindex_nearest(): + # https://github.com/pydata/xarray/issues/3751 + cftime = pytest.importorskip("cftime") + xarray = pytest.importorskip("xarray") + + times = xarray.date_range("0001", periods=2, use_cftime=True) + key = cftime.DatetimeGregorian(2000, 1, 1) + result = times.get_indexer([key], method="nearest") + expected = 1 + assert result == expected + + +@pytest.mark.single_cpu +def test_oo_optimizable(): + # GH 21071 + subprocess.check_call([sys.executable, "-OO", "-c", "import pandas"]) + + +@pytest.mark.single_cpu +def test_oo_optimized_datetime_index_unpickle(): + # GH 42866 + subprocess.check_call( + [ + sys.executable, + "-OO", + "-c", + ( + "import pandas as pd, pickle; " + "pickle.loads(pickle.dumps(pd.date_range('2021-01-01', periods=1)))" + ), + ] + ) + + +def test_statsmodels(): + smf = pytest.importorskip("statsmodels.formula.api") + + df = DataFrame( + {"Lottery": range(5), "Literacy": range(5), "Pop1831": range(100, 105)} + ) + smf.ols("Lottery ~ Literacy + np.log(Pop1831)", data=df).fit() + + +def test_scikit_learn(): + pytest.importorskip("sklearn") + from sklearn import ( + datasets, + svm, + ) + + digits = datasets.load_digits() + clf = svm.SVC(gamma=0.001, C=100.0) + clf.fit(digits.data[:-1], digits.target[:-1]) + clf.predict(digits.data[-1:]) + + +def test_seaborn(mpl_cleanup): + seaborn = pytest.importorskip("seaborn") + tips = DataFrame( + {"day": pd.date_range("2023", freq="D", periods=5), "total_bill": range(5)} + ) + seaborn.stripplot(x="day", y="total_bill", data=tips) + + +@pytest.mark.xfail(reason="pandas_datareader uses old variant of deprecate_kwarg") +def test_pandas_datareader(): + # https://github.com/pandas-dev/pandas/pull/61468 + # https://github.com/pydata/pandas-datareader/issues/1005 + pytest.importorskip("pandas_datareader") + + +@pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning") +def test_pyarrow(df): + pyarrow = pytest.importorskip("pyarrow") + table = pyarrow.Table.from_pandas(df) + result = table.to_pandas() + tm.assert_frame_equal(result, df) + + +def test_yaml_dump(df): + # GH#42748 + yaml = pytest.importorskip("yaml") + + dumped = yaml.dump(df) + + loaded = yaml.load(dumped, Loader=yaml.Loader) + tm.assert_frame_equal(df, loaded) + + loaded2 = yaml.load(dumped, Loader=yaml.UnsafeLoader) + tm.assert_frame_equal(df, loaded2) + + +@pytest.mark.parametrize("dependency", ["numpy", "dateutil"]) +def test_missing_required_dependency(monkeypatch, dependency): + # GH#61030 + original_import = __import__ + mock_error = ImportError(f"Mock error for {dependency}") + + def mock_import(name, *args, **kwargs): + if name == dependency: + raise mock_error + return original_import(name, *args, **kwargs) + + monkeypatch.setattr("builtins.__import__", mock_import) + + with pytest.raises(ImportError, match=dependency): + importlib.reload(importlib.import_module("pandas")) + + +def test_frame_setitem_dask_array_into_new_col(request): + # GH#47128 + + # dask sets "compute.use_numexpr" to False, so catch the current value + # and ensure to reset it afterwards to avoid impacting other tests + olduse = pd.get_option("compute.use_numexpr") + + try: + dask = pytest.importorskip("dask") + da = pytest.importorskip("dask.array") + if Version(dask.__version__) <= Version("2025.1.0") and Version( + np.__version__ + ) >= Version("2.1"): + request.applymarker( + pytest.mark.xfail(reason="loc.__setitem__ incorrectly mutated column c") + ) + + dda = da.array([1, 2]) + df = DataFrame({"a": ["a", "b"]}) + df["b"] = dda + df["c"] = dda + df.loc[[False, True], "b"] = 100 + result = df.loc[[1], :] + expected = DataFrame({"a": ["b"], "b": [100], "c": [2]}, index=[1]) + tm.assert_frame_equal(result, expected) + finally: + pd.set_option("compute.use_numexpr", olduse) + + +def test_pandas_priority(): + # GH#48347 + + class MyClass: + __pandas_priority__ = 5000 + + def __radd__(self, other): + return self + + left = MyClass() + right = Series(range(3)) + + assert right.__add__(left) is NotImplemented + assert right + left is left + + +@pytest.mark.parametrize("dtype", ["M8[ns]", "m8[ns]"]) +@pytest.mark.parametrize( + "box", [memoryview, partial(array.array, "i"), "dask", "xarray"] +) +def test_from_obscure_array(dtype, box): + # GH#24539 recognize e.g xarray, dask, ... + # Note: we dont do this for PeriodArray bc _from_sequence won't accept + # an array of integers + # TODO: could check with arraylike of Period objects + # GH#24539 recognize e.g xarray, dask, ... + arr = np.array([1, 2, 3], dtype=np.int64) + if box == "dask": + da = pytest.importorskip("dask.array") + data = da.array(arr) + elif box == "xarray": + xr = pytest.importorskip("xarray") + data = xr.DataArray(arr) + else: + data = box(arr) + + func = {"M8[ns]": pd.to_datetime, "m8[ns]": pd.to_timedelta}[dtype] + result = func(arr).array + expected = func(data).array + tm.assert_equal(result, expected) + + # Let's check the Indexes while we're here + idx_cls = {"M8[ns]": DatetimeIndex, "m8[ns]": TimedeltaIndex}[dtype] + result = idx_cls(arr) + expected = idx_cls(data) + tm.assert_index_equal(result, expected) + + +def test_xarray_coerce_unit(): + # GH44053 + xr = pytest.importorskip("xarray") + + arr = xr.DataArray([1, 2, 3]) + result = pd.to_datetime(arr, unit="ns") + expected = DatetimeIndex( + [ + "1970-01-01 00:00:00.000000001", + "1970-01-01 00:00:00.000000002", + "1970-01-01 00:00:00.000000003", + ], + dtype="datetime64[ns]", + freq=None, + ) + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/test_errors.py b/pandas/tests/test_errors.py new file mode 100644 index 0000000000000000000000000000000000000000..a9895e89cbf24c893f4a8801da715e0b491a6964 --- /dev/null +++ b/pandas/tests/test_errors.py @@ -0,0 +1,144 @@ +import warnings + +import pytest + +from pandas.errors import ( + AbstractMethodError, + Pandas4Warning, + Pandas5Warning, + PandasChangeWarning, + PandasDeprecationWarning, + PandasPendingDeprecationWarning, + UndefinedVariableError, +) + +import pandas as pd +import pandas._testing as tm + + +@pytest.mark.parametrize( + "exc", + [ + "AttributeConflictWarning", + "CSSWarning", + "CategoricalConversionWarning", + "ClosedFileError", + "DataError", + "DatabaseError", + "DtypeWarning", + "EmptyDataError", + "IncompatibilityWarning", + "IndexingError", + "InvalidColumnName", + "InvalidComparison", + "InvalidVersion", + "LossySetitemError", + "MergeError", + "NoBufferPresent", + "NumExprClobberingError", + "NumbaUtilError", + "OptionError", + "OutOfBoundsDatetime", + "ParserError", + "ParserWarning", + "PerformanceWarning", + "PossibleDataLossError", + "PossiblePrecisionLoss", + "PyperclipException", + "SpecificationError", + "UnsortedIndexError", + "UnsupportedFunctionCall", + "ValueLabelTypeMismatch", + ], +) +def test_exception_importable(exc): + from pandas import errors + + err = getattr(errors, exc) + assert err is not None + + # check that we can raise on them + + msg = "^$" + + with pytest.raises(err, match=msg): + raise err() + + +def test_catch_oob(): + from pandas import errors + + msg = "Cannot cast 1500-01-01 00:00:00 to unit='ns' without overflow" + with pytest.raises(errors.OutOfBoundsDatetime, match=msg): + pd.Timestamp("15000101").as_unit("ns") + + +@pytest.mark.parametrize("is_local", [True, False]) +def test_catch_undefined_variable_error(is_local): + variable_name = "x" + if is_local: + msg = f"local variable '{variable_name}' is not defined" + else: + msg = f"name '{variable_name}' is not defined" + + with pytest.raises(UndefinedVariableError, match=msg): + raise UndefinedVariableError(variable_name, is_local) + + +class Foo: + @classmethod + def classmethod(cls): + raise AbstractMethodError(cls, methodtype="classmethod") + + @property + def property(self): + raise AbstractMethodError(self, methodtype="property") + + def method(self): + raise AbstractMethodError(self) + + +def test_AbstractMethodError_classmethod(): + xpr = "This classmethod must be defined in the concrete class Foo" + with pytest.raises(AbstractMethodError, match=xpr): + Foo.classmethod() + + xpr = "This property must be defined in the concrete class Foo" + with pytest.raises(AbstractMethodError, match=xpr): + Foo().property + + xpr = "This method must be defined in the concrete class Foo" + with pytest.raises(AbstractMethodError, match=xpr): + Foo().method() + + +@pytest.mark.parametrize( + "warn_category, catch_category", + [ + (Pandas4Warning, PandasChangeWarning), + (Pandas4Warning, PandasDeprecationWarning), + (Pandas5Warning, PandasChangeWarning), + (Pandas5Warning, PandasPendingDeprecationWarning), + ], +) +def test_pandas_warnings(warn_category, catch_category): + # https://github.com/pandas-dev/pandas/pull/61468 + with tm.assert_produces_warning(catch_category): + warnings.warn("test", category=warn_category) + + +@pytest.mark.parametrize( + "warn_category, filter_category", + [ + (Pandas4Warning, PandasChangeWarning), + (Pandas4Warning, PandasDeprecationWarning), + (Pandas5Warning, PandasChangeWarning), + (Pandas5Warning, PandasPendingDeprecationWarning), + ], +) +def test_pandas_warnings_filter(warn_category, filter_category): + # https://github.com/pandas-dev/pandas/pull/61468 + # Ensure users can suppress warnings. + with tm.assert_produces_warning(None), warnings.catch_warnings(): + warnings.filterwarnings(category=filter_category, action="ignore") + warnings.warn("test", category=warn_category) diff --git a/pandas/tests/test_expressions.py b/pandas/tests/test_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..11a87f5e353c64bfca6b4642bcc60d42ae3b36d8 --- /dev/null +++ b/pandas/tests/test_expressions.py @@ -0,0 +1,475 @@ +import operator +import re + +import numpy as np +import pytest + +from pandas.compat._optional import import_optional_dependency + +from pandas import option_context +import pandas._testing as tm +from pandas.core.api import DataFrame +from pandas.core.computation import expressions as expr +from pandas.util.version import Version + + +@pytest.fixture +def _frame(): + return DataFrame( + np.random.default_rng(2).standard_normal((10001, 4)), + columns=list("ABCD"), + dtype="float64", + ) + + +@pytest.fixture +def _frame2(): + return DataFrame( + np.random.default_rng(2).standard_normal((100, 4)), + columns=list("ABCD"), + dtype="float64", + ) + + +@pytest.fixture +def _mixed(_frame): + return DataFrame( + { + "A": _frame["A"], + "B": _frame["B"].astype("float32"), + "C": _frame["C"].astype("int64"), + "D": _frame["D"].astype("int32"), + } + ) + + +@pytest.fixture +def _mixed2(_frame2): + return DataFrame( + { + "A": _frame2["A"], + "B": _frame2["B"].astype("float32"), + "C": _frame2["C"].astype("int64"), + "D": _frame2["D"].astype("int32"), + } + ) + + +@pytest.fixture +def _integer(): + return DataFrame( + np.random.default_rng(2).integers(1, 100, size=(10001, 4)), + columns=list("ABCD"), + dtype="int64", + ) + + +@pytest.fixture +def _integer_integers(_integer): + # integers to get a case with zeros + return _integer * np.random.default_rng(2).integers(0, 2, size=np.shape(_integer)) + + +@pytest.fixture +def _integer2(): + return DataFrame( + np.random.default_rng(2).integers(1, 100, size=(101, 4)), + columns=list("ABCD"), + dtype="int64", + ) + + +@pytest.fixture +def _array(_frame): + return _frame["A"].to_numpy() + + +@pytest.fixture +def _array2(_frame2): + return _frame2["A"].to_numpy() + + +@pytest.fixture +def _array_mixed(_mixed): + return _mixed["D"].to_numpy() + + +@pytest.fixture +def _array_mixed2(_mixed2): + return _mixed2["D"].to_numpy() + + +@pytest.mark.skipif(not expr.USE_NUMEXPR, reason="not using numexpr") +class TestExpressions: + @staticmethod + def call_op(df, other, flex: bool, opname: str): + if flex: + op = lambda x, y: getattr(x, opname)(y) + op.__name__ = opname + else: + op = getattr(operator, opname) + + with option_context("compute.use_numexpr", False): + expected = op(df, other) + + expr.get_test_result() + + result = op(df, other) + return result, expected + + @pytest.mark.parametrize( + "fixture", + [ + "_integer", + "_integer2", + "_integer_integers", + "_frame", + "_frame2", + "_mixed", + "_mixed2", + ], + ) + @pytest.mark.parametrize("flex", [True, False]) + @pytest.mark.parametrize( + "arith", ["add", "sub", "mul", "mod", "truediv", "floordiv"] + ) + def test_run_arithmetic(self, request, fixture, flex, arith, monkeypatch): + df = request.getfixturevalue(fixture) + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 0) + result, expected = self.call_op(df, df, flex, arith) + + if arith == "truediv": + assert all(x.kind == "f" for x in expected.dtypes.values) + tm.assert_equal(expected, result) + + for i in range(len(df.columns)): + result, expected = self.call_op( + df.iloc[:, i], df.iloc[:, i], flex, arith + ) + if arith == "truediv": + assert expected.dtype.kind == "f" + tm.assert_equal(expected, result) + + @pytest.mark.parametrize( + "fixture", + [ + "_integer", + "_integer2", + "_integer_integers", + "_frame", + "_frame2", + "_mixed", + "_mixed2", + ], + ) + @pytest.mark.parametrize("flex", [True, False]) + def test_run_binary(self, request, fixture, flex, comparison_op, monkeypatch): + """ + tests solely that the result is the same whether or not numexpr is + enabled. Need to test whether the function does the correct thing + elsewhere. + """ + df = request.getfixturevalue(fixture) + arith = comparison_op.__name__ + with option_context("compute.use_numexpr", False): + other = df + 1 + + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 0) + expr.set_test_mode(True) + + result, expected = self.call_op(df, other, flex, arith) + + used_numexpr = expr.get_test_result() + assert used_numexpr, "Did not use numexpr as expected." + tm.assert_equal(expected, result) + + for i in range(len(df.columns)): + binary_comp = other.iloc[:, i] + 1 + self.call_op(df.iloc[:, i], binary_comp, flex, "add") + + def test_invalid(self): + array = np.random.default_rng(2).standard_normal(1_000_001) + array2 = np.random.default_rng(2).standard_normal(100) + + # no op + result = expr._can_use_numexpr(operator.add, None, array, array, "evaluate") + assert not result + + # min elements + result = expr._can_use_numexpr(operator.add, "+", array2, array2, "evaluate") + assert not result + + # ok, we only check on first part of expression + result = expr._can_use_numexpr(operator.add, "+", array, array2, "evaluate") + assert result + + @pytest.mark.filterwarnings("ignore:invalid value encountered in:RuntimeWarning") + @pytest.mark.parametrize( + "opname,op_str", + [("add", "+"), ("sub", "-"), ("mul", "*"), ("truediv", "/"), ("pow", "**")], + ) + @pytest.mark.parametrize( + "left_fix,right_fix", [("_array", "_array2"), ("_array_mixed", "_array_mixed2")] + ) + def test_binary_ops(self, request, opname, op_str, left_fix, right_fix): + left = request.getfixturevalue(left_fix) + right = request.getfixturevalue(right_fix) + + def testit(left, right, opname, op_str): + if opname == "pow": + left = np.abs(left) + + op = getattr(operator, opname) + + # array has 0s + result = expr.evaluate(op, left, left, use_numexpr=True) + expected = expr.evaluate(op, left, left, use_numexpr=False) + tm.assert_numpy_array_equal(result, expected) + + result = expr._can_use_numexpr(op, op_str, right, right, "evaluate") + assert not result + + with option_context("compute.use_numexpr", False): + testit(left, right, opname, op_str) + + expr.set_numexpr_threads(1) + testit(left, right, opname, op_str) + expr.set_numexpr_threads() + testit(left, right, opname, op_str) + + @pytest.mark.parametrize( + "left_fix,right_fix", [("_array", "_array2"), ("_array_mixed", "_array_mixed2")] + ) + def test_comparison_ops(self, request, comparison_op, left_fix, right_fix): + left = request.getfixturevalue(left_fix) + right = request.getfixturevalue(right_fix) + + def testit(): + f12 = left + 1 + f22 = right + 1 + + op = comparison_op + + result = expr.evaluate(op, left, f12, use_numexpr=True) + expected = expr.evaluate(op, left, f12, use_numexpr=False) + tm.assert_numpy_array_equal(result, expected) + + result = expr._can_use_numexpr(op, op, right, f22, "evaluate") + assert not result + + with option_context("compute.use_numexpr", False): + testit() + + expr.set_numexpr_threads(1) + testit() + expr.set_numexpr_threads() + testit() + + @pytest.mark.parametrize("cond", [True, False]) + @pytest.mark.parametrize("fixture", ["_frame", "_frame2", "_mixed", "_mixed2"]) + def test_where(self, request, cond, fixture): + df = request.getfixturevalue(fixture) + + def testit(): + c = np.empty(df.shape, dtype=np.bool_) + c.fill(cond) + result = expr.where(c, df.values, df.values + 1) + expected = np.where(c, df.values, df.values + 1) + tm.assert_numpy_array_equal(result, expected) + + with option_context("compute.use_numexpr", False): + testit() + + expr.set_numexpr_threads(1) + testit() + expr.set_numexpr_threads() + testit() + + @pytest.mark.parametrize( + "op_str,opname", [("/", "truediv"), ("//", "floordiv"), ("**", "pow")] + ) + def test_bool_ops_raise_on_arithmetic(self, op_str, opname): + df = DataFrame( + { + "a": np.random.default_rng(2).random(10) > 0.5, + "b": np.random.default_rng(2).random(10) > 0.5, + } + ) + + msg = f"operator '{opname}' not implemented for bool dtypes" + f = getattr(operator, opname) + err_msg = re.escape(msg) + + with pytest.raises(NotImplementedError, match=err_msg): + f(df, df) + + with pytest.raises(NotImplementedError, match=err_msg): + f(df.a, df.b) + + with pytest.raises(NotImplementedError, match=err_msg): + f(df.a, True) + + with pytest.raises(NotImplementedError, match=err_msg): + f(False, df.a) + + with pytest.raises(NotImplementedError, match=err_msg): + f(False, df) + + with pytest.raises(NotImplementedError, match=err_msg): + f(df, True) + + @pytest.mark.parametrize( + "op_str,opname", [("+", "add"), ("*", "mul"), ("-", "sub")] + ) + def test_bool_ops_warn_on_arithmetic(self, op_str, opname, monkeypatch): + n = 10 + df = DataFrame( + { + "a": np.random.default_rng(2).random(n) > 0.5, + "b": np.random.default_rng(2).random(n) > 0.5, + } + ) + + subs = {"+": "|", "*": "&", "-": "^"} + sub_funcs = {"|": "or_", "&": "and_", "^": "xor"} + + f = getattr(operator, opname) + fe = getattr(operator, sub_funcs[subs[op_str]]) + + if op_str == "-": + # raises TypeError + return + + msg = "operator is not supported by numexpr" + ne = import_optional_dependency("numexpr", errors="ignore") + warning = ( + UserWarning + if ne + and op_str in {"+", "*"} + and Version(ne.__version__) < Version("2.13.1") + else None + ) + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 5) + with option_context("compute.use_numexpr", True): + with tm.assert_produces_warning(warning, match=msg): + r = f(df, df) + e = fe(df, df) + tm.assert_frame_equal(r, e) + + with tm.assert_produces_warning(warning, match=msg): + r = f(df.a, df.b) + e = fe(df.a, df.b) + tm.assert_series_equal(r, e) + + with tm.assert_produces_warning(warning, match=msg): + r = f(df.a, True) + e = fe(df.a, True) + tm.assert_series_equal(r, e) + + with tm.assert_produces_warning(warning, match=msg): + r = f(False, df.a) + e = fe(False, df.a) + tm.assert_series_equal(r, e) + + with tm.assert_produces_warning(warning, match=msg): + r = f(False, df) + e = fe(False, df) + tm.assert_frame_equal(r, e) + + with tm.assert_produces_warning(warning, match=msg): + r = f(df, True) + e = fe(df, True) + tm.assert_frame_equal(r, e) + + @pytest.mark.parametrize( + "test_input,expected", + [ + ( + DataFrame( + [[0, 1, 2, "aa"], [0, 1, 2, "aa"]], columns=["a", "b", "c", "dtype"] + ), + DataFrame([[False, False], [False, False]], columns=["a", "dtype"]), + ), + ( + DataFrame( + [[0, 3, 2, "aa"], [0, 4, 2, "aa"], [0, 1, 1, "bb"]], + columns=["a", "b", "c", "dtype"], + ), + DataFrame( + [[False, False], [False, False], [False, False]], + columns=["a", "dtype"], + ), + ), + ], + ) + def test_bool_ops_column_name_dtype(self, test_input, expected): + # GH 22383 - .ne fails if columns containing column name 'dtype' + result = test_input.loc[:, ["a", "dtype"]].ne(test_input.loc[:, ["a", "dtype"]]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "arith", ("add", "sub", "mul", "mod", "truediv", "floordiv") + ) + @pytest.mark.parametrize("axis", (0, 1)) + def test_frame_series_axis(self, axis, arith, _frame, monkeypatch): + # GH#26736 Dataframe.floordiv(Series, axis=1) fails + + df = _frame + if axis == 1: + other = df.iloc[0, :] + else: + other = df.iloc[:, 0] + + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 0) + + op_func = getattr(df, arith) + + with option_context("compute.use_numexpr", False): + expected = op_func(other, axis=axis) + + result = op_func(other, axis=axis) + tm.assert_frame_equal(expected, result) + + @pytest.mark.parametrize( + "op", + [ + "__mod__", + "__rmod__", + "__floordiv__", + "__rfloordiv__", + ], + ) + @pytest.mark.parametrize("scalar", [-5, 5]) + def test_python_semantics_with_numexpr_installed( + self, op, box_with_array, scalar, monkeypatch + ): + # https://github.com/pandas-dev/pandas/issues/36047 + with monkeypatch.context() as m: + m.setattr(expr, "_MIN_ELEMENTS", 0) + data = np.arange(-50, 50) + obj = box_with_array(data) + method = getattr(obj, op) + result = method(scalar) + + # compare result with numpy + with option_context("compute.use_numexpr", False): + expected = method(scalar) + + tm.assert_equal(result, expected) + + # compare result element-wise with Python + for i, elem in enumerate(data): + if box_with_array == DataFrame: + scalar_result = result.iloc[i, 0] + else: + scalar_result = result[i] + try: + expected = getattr(int(elem), op)(scalar) + except ZeroDivisionError: + pass + else: + assert scalar_result == expected diff --git a/pandas/tests/test_flags.py b/pandas/tests/test_flags.py new file mode 100644 index 0000000000000000000000000000000000000000..9294b3fc3319b78b59d5637acdf3fd75737cd836 --- /dev/null +++ b/pandas/tests/test_flags.py @@ -0,0 +1,48 @@ +import pytest + +import pandas as pd + + +class TestFlags: + def test_equality(self): + a = pd.DataFrame().set_flags(allows_duplicate_labels=True).flags + b = pd.DataFrame().set_flags(allows_duplicate_labels=False).flags + + assert a == a + assert b == b + assert a != b + assert a != 2 + + def test_set(self): + df = pd.DataFrame().set_flags(allows_duplicate_labels=True) + a = df.flags + a.allows_duplicate_labels = False + assert a.allows_duplicate_labels is False + a["allows_duplicate_labels"] = True + assert a.allows_duplicate_labels is True + + def test_repr(self): + a = repr(pd.DataFrame({"A"}).set_flags(allows_duplicate_labels=True).flags) + assert a == "" + a = repr(pd.DataFrame({"A"}).set_flags(allows_duplicate_labels=False).flags) + assert a == "" + + def test_obj_ref(self): + df = pd.DataFrame() + flags = df.flags + del df + with pytest.raises(ValueError, match="object has been deleted"): + flags.allows_duplicate_labels = True + + def test_getitem(self): + df = pd.DataFrame() + flags = df.flags + assert flags["allows_duplicate_labels"] is True + flags["allows_duplicate_labels"] = False + assert flags["allows_duplicate_labels"] is False + + with pytest.raises(KeyError, match="a"): + flags["a"] + + with pytest.raises(ValueError, match="a"): + flags["a"] = 10 diff --git a/pandas/tests/test_multilevel.py b/pandas/tests/test_multilevel.py new file mode 100644 index 0000000000000000000000000000000000000000..ff7ab22c197d8467bf19ed7a42cdbd9305cc88b5 --- /dev/null +++ b/pandas/tests/test_multilevel.py @@ -0,0 +1,376 @@ +import datetime + +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + ArrowDtype, + DataFrame, + MultiIndex, + Series, +) +import pandas._testing as tm + + +class TestMultiLevel: + def test_reindex_level(self, multiindex_year_month_day_dataframe_random_data): + # axis=0 + ymd = multiindex_year_month_day_dataframe_random_data + + month_sums = ymd.groupby("month").sum() + result = month_sums.reindex(ymd.index, level=1) + expected = ymd.groupby(level="month").transform("sum") + + tm.assert_frame_equal(result, expected) + + # Series + result = month_sums["A"].reindex(ymd.index, level=1) + expected = ymd["A"].groupby(level="month").transform("sum") + tm.assert_series_equal(result, expected, check_names=False) + + def test_reindex(self, multiindex_dataframe_random_data): + frame = multiindex_dataframe_random_data + + expected = frame.iloc[[0, 3]] + reindexed = frame.loc[[("foo", "one"), ("bar", "one")]] + tm.assert_frame_equal(reindexed, expected) + + def test_reindex_preserve_levels( + self, multiindex_year_month_day_dataframe_random_data + ): + ymd = multiindex_year_month_day_dataframe_random_data + + new_index = ymd.index[::10] + chunk = ymd.reindex(new_index) + assert chunk.index.is_(new_index) + + chunk = ymd.loc[new_index] + assert chunk.index.equals(new_index) + + ymdT = ymd.T + chunk = ymdT.reindex(columns=new_index) + assert chunk.columns.is_(new_index) + + chunk = ymdT.loc[:, new_index] + assert chunk.columns.equals(new_index) + + def test_groupby_transform(self, multiindex_dataframe_random_data): + frame = multiindex_dataframe_random_data + + s = frame["A"] + grouper = s.index.get_level_values(0) + + grouped = s.groupby(grouper, group_keys=False) + + applied = grouped.apply(lambda x: x * 2) + expected = grouped.transform(lambda x: x * 2) + result = applied.reindex(expected.index) + tm.assert_series_equal(result, expected, check_names=False) + + def test_groupby_corner(self): + midx = MultiIndex( + levels=[["foo"], ["bar"], ["baz"]], + codes=[[0], [0], [0]], + names=["one", "two", "three"], + ) + df = DataFrame( + [np.random.default_rng(2).random(4)], + columns=["a", "b", "c", "d"], + index=midx, + ) + # should work + df.groupby(level="three") + + def test_setitem_with_expansion_multiindex_columns( + self, multiindex_year_month_day_dataframe_random_data + ): + ymd = multiindex_year_month_day_dataframe_random_data + + df = ymd[:5].T + df[2000, 1, 10] = df[2000, 1, 7] + assert isinstance(df.columns, MultiIndex) + assert (df[2000, 1, 10] == df[2000, 1, 7]).all() + + def test_alignment(self): + x = Series( + data=[1, 2, 3], index=MultiIndex.from_tuples([("A", 1), ("A", 2), ("B", 3)]) + ) + + y = Series( + data=[4, 5, 6], index=MultiIndex.from_tuples([("Z", 1), ("Z", 2), ("B", 3)]) + ) + + res = x - y + exp_index = x.index.union(y.index) + exp = x.reindex(exp_index) - y.reindex(exp_index) + tm.assert_series_equal(res, exp) + + # hit non-monotonic code path + res = x[::-1] - y[::-1] + exp_index = x.index.union(y.index) + exp = x.reindex(exp_index) - y.reindex(exp_index) + tm.assert_series_equal(res, exp) + + def test_groupby_multilevel(self, multiindex_year_month_day_dataframe_random_data): + ymd = multiindex_year_month_day_dataframe_random_data + + result = ymd.groupby(level=[0, 1]).mean() + + k1 = ymd.index.get_level_values(0) + k2 = ymd.index.get_level_values(1) + + expected = ymd.groupby([k1, k2]).mean() + + tm.assert_frame_equal(result, expected) + assert result.index.names == ymd.index.names[:2] + + result2 = ymd.groupby(level=ymd.index.names[:2]).mean() + tm.assert_frame_equal(result, result2) + + def test_multilevel_consolidate(self): + index = MultiIndex.from_tuples( + [("foo", "one"), ("foo", "two"), ("bar", "one"), ("bar", "two")] + ) + df = DataFrame( + np.random.default_rng(2).standard_normal((4, 4)), index=index, columns=index + ) + df["Totals", ""] = df.sum(axis=1) + df = df._consolidate() + + def test_level_with_tuples(self): + index = MultiIndex( + levels=[[("foo", "bar", 0), ("foo", "baz", 0), ("foo", "qux", 0)], [0, 1]], + codes=[[0, 0, 1, 1, 2, 2], [0, 1, 0, 1, 0, 1]], + ) + + series = Series(np.random.default_rng(2).standard_normal(6), index=index) + frame = DataFrame(np.random.default_rng(2).standard_normal((6, 4)), index=index) + + result = series[("foo", "bar", 0)] + result2 = series.loc[("foo", "bar", 0)] + expected = series[:2] + expected.index = expected.index.droplevel(0) + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + with pytest.raises(KeyError, match=r"^\(\('foo', 'bar', 0\), 2\)$"): + series[("foo", "bar", 0), 2] + + result = frame.loc[("foo", "bar", 0)] + result2 = frame.xs(("foo", "bar", 0)) + expected = frame[:2] + expected.index = expected.index.droplevel(0) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + index = MultiIndex( + levels=[[("foo", "bar"), ("foo", "baz"), ("foo", "qux")], [0, 1]], + codes=[[0, 0, 1, 1, 2, 2], [0, 1, 0, 1, 0, 1]], + ) + + series = Series(np.random.default_rng(2).standard_normal(6), index=index) + frame = DataFrame(np.random.default_rng(2).standard_normal((6, 4)), index=index) + + result = series[("foo", "bar")] + result2 = series.loc[("foo", "bar")] + expected = series[:2] + expected.index = expected.index.droplevel(0) + tm.assert_series_equal(result, expected) + tm.assert_series_equal(result2, expected) + + result = frame.loc[("foo", "bar")] + result2 = frame.xs(("foo", "bar")) + expected = frame[:2] + expected.index = expected.index.droplevel(0) + tm.assert_frame_equal(result, expected) + tm.assert_frame_equal(result2, expected) + + def test_reindex_level_partial_selection(self, multiindex_dataframe_random_data): + frame = multiindex_dataframe_random_data + + result = frame.reindex(["foo", "qux"], level=0) + expected = frame.iloc[[0, 1, 2, 7, 8, 9]] + tm.assert_frame_equal(result, expected) + + result = frame.T.reindex(["foo", "qux"], axis=1, level=0) + tm.assert_frame_equal(result, expected.T) + + result = frame.loc[["foo", "qux"]] + tm.assert_frame_equal(result, expected) + + result = frame["A"].loc[["foo", "qux"]] + tm.assert_series_equal(result, expected["A"]) + + result = frame.T.loc[:, ["foo", "qux"]] + tm.assert_frame_equal(result, expected.T) + + @pytest.mark.parametrize("d", [4, "d"]) + def test_empty_frame_groupby_dtypes_consistency(self, d): + # GH 20888 + group_keys = ["a", "b", "c"] + df = DataFrame({"a": [1], "b": [2], "c": [3], "d": [d]}) + + g = df[df.a == 2].groupby(group_keys) + result = g.first().index + expected = MultiIndex( + levels=[[1], [2], [3]], codes=[[], [], []], names=["a", "b", "c"] + ) + + tm.assert_index_equal(result, expected) + + def test_duplicate_groupby_issues(self): + idx_tp = [ + ("600809", "20061231"), + ("600809", "20070331"), + ("600809", "20070630"), + ("600809", "20070331"), + ] + dt = ["demo", "demo", "demo", "demo"] + + idx = MultiIndex.from_tuples(idx_tp, names=["STK_ID", "RPT_Date"]) + s = Series(dt, index=idx) + + result = s.groupby(s.index).first() + assert len(result) == 3 + + def test_subsets_multiindex_dtype(self): + # GH 20757 + data = [["x", 1]] + columns = [("a", "b", np.nan), ("a", "c", 0.0)] + df = DataFrame(data, columns=MultiIndex.from_tuples(columns)) + expected = df.dtypes.a.b + result = df.a.b.dtypes + tm.assert_series_equal(result, expected) + + def test_datetime_object_multiindex(self): + data_dic = { + (0, datetime.date(2018, 3, 3)): {"A": 1, "B": 10}, + (0, datetime.date(2018, 3, 4)): {"A": 2, "B": 11}, + (1, datetime.date(2018, 3, 3)): {"A": 3, "B": 12}, + (1, datetime.date(2018, 3, 4)): {"A": 4, "B": 13}, + } + result = DataFrame.from_dict(data_dic, orient="index") + data = {"A": [1, 2, 3, 4], "B": [10, 11, 12, 13]} + index = [ + [0, 0, 1, 1], + [ + datetime.date(2018, 3, 3), + datetime.date(2018, 3, 4), + datetime.date(2018, 3, 3), + datetime.date(2018, 3, 4), + ], + ] + expected = DataFrame(data=data, index=index) + + tm.assert_frame_equal(result, expected) + + def test_multiindex_with_na(self): + df = DataFrame( + [ + ["A", np.nan, 1.23, 4.56], + ["A", "G", 1.23, 4.56], + ["A", "D", 9.87, 10.54], + ], + columns=["pivot_0", "pivot_1", "col_1", "col_2"], + ).set_index(["pivot_0", "pivot_1"]) + + df.at[("A", "F"), "col_2"] = 0.0 + + expected = DataFrame( + [ + ["A", np.nan, 1.23, 4.56], + ["A", "G", 1.23, 4.56], + ["A", "D", 9.87, 10.54], + ["A", "F", np.nan, 0.0], + ], + columns=["pivot_0", "pivot_1", "col_1", "col_2"], + ).set_index(["pivot_0", "pivot_1"]) + + tm.assert_frame_equal(df, expected) + + @pytest.mark.parametrize("na", [None, np.nan]) + def test_multiindex_insert_level_with_na(self, na): + # GH 59003 + df = DataFrame([0], columns=[["A"], ["B"]]) + df[na, "B"] = 1 + tm.assert_frame_equal(df[na], DataFrame([1], columns=["B"])) + + def test_multiindex_dt_with_nan(self): + # GH#60388 + df = DataFrame( + [ + [1, np.nan, 5, np.nan], + [2, np.nan, 6, np.nan], + [np.nan, 3, np.nan, 7], + [np.nan, 4, np.nan, 8], + ], + index=Series(["a", "b", "c", "d"], dtype=object, name="sub"), + columns=MultiIndex.from_product( + [ + ["value1", "value2"], + [datetime.datetime(2024, 11, 1), datetime.datetime(2024, 11, 2)], + ], + names=[None, "Date"], + ), + ) + df = df.reset_index() + result = df[df.columns[0]] + expected = Series(["a", "b", "c", "d"], name=("sub", np.nan)) + tm.assert_series_equal(result, expected) + + @pytest.mark.filterwarnings("ignore:Passing a BlockManager:DeprecationWarning") + def test_multiindex_with_pyarrow_categorical(self): + # GH#53051 + pa = pytest.importorskip("pyarrow") + + df = DataFrame( + {"string_column": ["A", "B", "C"], "number_column": [1, 2, 3]} + ).astype( + { + "string_column": ArrowDtype(pa.dictionary(pa.int32(), pa.string())), + "number_column": "float[pyarrow]", + } + ) + + df = df.set_index(["string_column", "number_column"]) + + df_expected = DataFrame( + index=MultiIndex.from_arrays( + [["A", "B", "C"], [1, 2, 3]], names=["string_column", "number_column"] + ) + ) + tm.assert_frame_equal( + df, + df_expected, + check_index_type=False, + check_column_type=False, + ) + + +class TestSorted: + """everything you wanted to test about sorting""" + + def test_sort_non_lexsorted(self): + # degenerate case where we sort but don't + # have a satisfying result :< + # GH 15797 + idx = MultiIndex( + [["A", "B", "C"], ["c", "b", "a"]], [[0, 1, 2, 0, 1, 2], [0, 2, 1, 1, 0, 2]] + ) + + df = DataFrame({"col": range(len(idx))}, index=idx, dtype="int64") + assert df.index.is_monotonic_increasing is False + + sorted = df.sort_index() + assert sorted.index.is_monotonic_increasing is True + + expected = DataFrame( + {"col": [1, 4, 5, 2]}, + index=MultiIndex.from_tuples( + [("B", "a"), ("B", "c"), ("C", "a"), ("C", "b")] + ), + dtype="int64", + ) + result = sorted.loc[pd.IndexSlice["B":"C", "a":"c"], :] + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/test_nanops.py b/pandas/tests/test_nanops.py new file mode 100644 index 0000000000000000000000000000000000000000..531019e7222c75df5874de0175519c9416f940eb --- /dev/null +++ b/pandas/tests/test_nanops.py @@ -0,0 +1,1319 @@ +from functools import partial + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_integer_dtype + +import pandas as pd +from pandas import ( + Series, + isna, +) +import pandas._testing as tm +from pandas.core import nanops + +use_bn = nanops._USE_BOTTLENECK + + +@pytest.fixture +def disable_bottleneck(monkeypatch): + with monkeypatch.context() as m: + m.setattr(nanops, "_USE_BOTTLENECK", False) + yield + + +@pytest.fixture +def arr_shape(): + return 11, 7 + + +@pytest.fixture +def arr_float(arr_shape): + return np.random.default_rng(2).standard_normal(arr_shape) + + +@pytest.fixture +def arr_complex(arr_float): + return arr_float + arr_float * 1j + + +@pytest.fixture +def arr_int(arr_shape): + return np.random.default_rng(2).integers(-10, 10, arr_shape) + + +@pytest.fixture +def arr_bool(arr_shape): + return np.random.default_rng(2).integers(0, 2, arr_shape) == 0 + + +@pytest.fixture +def arr_str(arr_float): + return np.abs(arr_float).astype("S") + + +@pytest.fixture +def arr_utf(arr_float): + return np.abs(arr_float).astype("U") + + +@pytest.fixture +def arr_date(arr_shape): + return np.random.default_rng(2).integers(0, 20000, arr_shape).astype("M8[ns]") + + +@pytest.fixture +def arr_tdelta(arr_shape): + return np.random.default_rng(2).integers(0, 20000, arr_shape).astype("m8[ns]") + + +@pytest.fixture +def arr_nan(arr_shape): + return np.tile(np.nan, arr_shape) + + +@pytest.fixture +def arr_float_nan(arr_float, arr_nan): + return np.vstack([arr_float, arr_nan]) + + +@pytest.fixture +def arr_nan_float1(arr_nan, arr_float): + return np.vstack([arr_nan, arr_float]) + + +@pytest.fixture +def arr_nan_nan(arr_nan): + return np.vstack([arr_nan, arr_nan]) + + +@pytest.fixture +def arr_inf(arr_float): + return arr_float * np.inf + + +@pytest.fixture +def arr_float_inf(arr_float, arr_inf): + return np.vstack([arr_float, arr_inf]) + + +@pytest.fixture +def arr_nan_inf(arr_nan, arr_inf): + return np.vstack([arr_nan, arr_inf]) + + +@pytest.fixture +def arr_float_nan_inf(arr_float, arr_nan, arr_inf): + return np.vstack([arr_float, arr_nan, arr_inf]) + + +@pytest.fixture +def arr_nan_nan_inf(arr_nan, arr_inf): + return np.vstack([arr_nan, arr_nan, arr_inf]) + + +@pytest.fixture +def arr_obj( + arr_float, arr_int, arr_bool, arr_complex, arr_str, arr_utf, arr_date, arr_tdelta +): + return np.vstack( + [ + arr_float.astype("O"), + arr_int.astype("O"), + arr_bool.astype("O"), + arr_complex.astype("O"), + arr_str.astype("O"), + arr_utf.astype("O"), + arr_date.astype("O"), + arr_tdelta.astype("O"), + ] + ) + + +@pytest.fixture +def arr_nan_nanj(arr_nan): + with np.errstate(invalid="ignore"): + return arr_nan + arr_nan * 1j + + +@pytest.fixture +def arr_complex_nan(arr_complex, arr_nan_nanj): + with np.errstate(invalid="ignore"): + return np.vstack([arr_complex, arr_nan_nanj]) + + +@pytest.fixture +def arr_nan_infj(arr_inf): + with np.errstate(invalid="ignore"): + return arr_inf * 1j + + +@pytest.fixture +def arr_complex_nan_infj(arr_complex, arr_nan_infj): + with np.errstate(invalid="ignore"): + return np.vstack([arr_complex, arr_nan_infj]) + + +@pytest.fixture +def arr_float_1d(arr_float): + return arr_float[:, 0] + + +@pytest.fixture +def arr_nan_1d(arr_nan): + return arr_nan[:, 0] + + +@pytest.fixture +def arr_float_nan_1d(arr_float_nan): + return arr_float_nan[:, 0] + + +@pytest.fixture +def arr_float1_nan_1d(arr_float1_nan): + return arr_float1_nan[:, 0] + + +@pytest.fixture +def arr_nan_float1_1d(arr_nan_float1): + return arr_nan_float1[:, 0] + + +class TestnanopsDataFrame: + def setup_method(self): + nanops._USE_BOTTLENECK = False + + arr_shape = (11, 7) + + self.arr_float = np.random.default_rng(2).standard_normal(arr_shape) + self.arr_float1 = np.random.default_rng(2).standard_normal(arr_shape) + self.arr_complex = self.arr_float + self.arr_float1 * 1j + self.arr_int = np.random.default_rng(2).integers(-10, 10, arr_shape) + self.arr_bool = np.random.default_rng(2).integers(0, 2, arr_shape) == 0 + self.arr_str = np.abs(self.arr_float).astype("S") + self.arr_utf = np.abs(self.arr_float).astype("U") + self.arr_date = ( + np.random.default_rng(2).integers(0, 20000, arr_shape).astype("M8[ns]") + ) + self.arr_tdelta = ( + np.random.default_rng(2).integers(0, 20000, arr_shape).astype("m8[ns]") + ) + + self.arr_nan = np.tile(np.nan, arr_shape) + self.arr_float_nan = np.vstack([self.arr_float, self.arr_nan]) + self.arr_float1_nan = np.vstack([self.arr_float1, self.arr_nan]) + self.arr_nan_float1 = np.vstack([self.arr_nan, self.arr_float1]) + self.arr_nan_nan = np.vstack([self.arr_nan, self.arr_nan]) + + self.arr_inf = self.arr_float * np.inf + self.arr_float_inf = np.vstack([self.arr_float, self.arr_inf]) + + self.arr_nan_inf = np.vstack([self.arr_nan, self.arr_inf]) + self.arr_float_nan_inf = np.vstack([self.arr_float, self.arr_nan, self.arr_inf]) + self.arr_nan_nan_inf = np.vstack([self.arr_nan, self.arr_nan, self.arr_inf]) + self.arr_obj = np.vstack( + [ + self.arr_float.astype("O"), + self.arr_int.astype("O"), + self.arr_bool.astype("O"), + self.arr_complex.astype("O"), + self.arr_str.astype("O"), + self.arr_utf.astype("O"), + self.arr_date.astype("O"), + self.arr_tdelta.astype("O"), + ] + ) + + with np.errstate(invalid="ignore"): + self.arr_nan_nanj = self.arr_nan + self.arr_nan * 1j + self.arr_complex_nan = np.vstack([self.arr_complex, self.arr_nan_nanj]) + + self.arr_nan_infj = self.arr_inf * 1j + self.arr_complex_nan_infj = np.vstack([self.arr_complex, self.arr_nan_infj]) + + self.arr_float_2d = self.arr_float + self.arr_float1_2d = self.arr_float1 + + self.arr_nan_2d = self.arr_nan + self.arr_float_nan_2d = self.arr_float_nan + self.arr_float1_nan_2d = self.arr_float1_nan + self.arr_nan_float1_2d = self.arr_nan_float1 + + self.arr_float_1d = self.arr_float[:, 0] + self.arr_float1_1d = self.arr_float1[:, 0] + + self.arr_nan_1d = self.arr_nan[:, 0] + self.arr_float_nan_1d = self.arr_float_nan[:, 0] + self.arr_float1_nan_1d = self.arr_float1_nan[:, 0] + self.arr_nan_float1_1d = self.arr_nan_float1[:, 0] + + def teardown_method(self): + nanops._USE_BOTTLENECK = use_bn + + def check_results(self, targ, res, axis, check_dtype=True): + res = getattr(res, "asm8", res) + + if ( + axis != 0 + and hasattr(targ, "shape") + and targ.ndim + and targ.shape != res.shape + ): + res = np.split(res, [targ.shape[0]], axis=0)[0] + + try: + tm.assert_almost_equal(targ, res, check_dtype=check_dtype) + except AssertionError: + # handle timedelta dtypes + if hasattr(targ, "dtype") and targ.dtype == "m8[ns]": + raise + + # There are sometimes rounding errors with + # complex and object dtypes. + # If it isn't one of those, re-raise the error. + if not hasattr(res, "dtype") or res.dtype.kind not in ["c", "O"]: + raise + # convert object dtypes to something that can be split into + # real and imaginary parts + if res.dtype.kind == "O": + if targ.dtype.kind != "O": + res = res.astype(targ.dtype) + else: + cast_dtype = "c16" if hasattr(np, "complex128") else "f8" + res = res.astype(cast_dtype) + targ = targ.astype(cast_dtype) + # there should never be a case where numpy returns an object + # but nanops doesn't, so make that an exception + elif targ.dtype.kind == "O": + raise + tm.assert_almost_equal(np.real(targ), np.real(res), check_dtype=check_dtype) + tm.assert_almost_equal(np.imag(targ), np.imag(res), check_dtype=check_dtype) + + def check_fun_data( + self, + testfunc, + targfunc, + testar, + testarval, + targarval, + skipna, + check_dtype=True, + empty_targfunc=None, + **kwargs, + ): + for axis in [*list(range(targarval.ndim)), None]: + targartempval = targarval if skipna else testarval + if skipna and empty_targfunc and isna(targartempval).all(): + targ = empty_targfunc(targartempval, axis=axis, **kwargs) + else: + targ = targfunc(targartempval, axis=axis, **kwargs) + + if targartempval.dtype == object and ( + targfunc is np.any or targfunc is np.all + ): + # GH#12863 the numpy functions will retain e.g. floatiness + if isinstance(targ, np.ndarray): + targ = targ.astype(bool) + else: + targ = bool(targ) + + if testfunc.__name__ in ["nanargmax", "nanargmin"] and ( + testar.startswith("arr_nan") + or (testar.endswith("nan") and (not skipna or axis == 1)) + ): + with pytest.raises(ValueError, match="Encountered .* NA value"): + testfunc(testarval, axis=axis, skipna=skipna, **kwargs) + return + res = testfunc(testarval, axis=axis, skipna=skipna, **kwargs) + + if ( + isinstance(targ, np.complex128) + and isinstance(res, float) + and np.isnan(targ) + and np.isnan(res) + ): + # GH#18463 + targ = res + + self.check_results(targ, res, axis, check_dtype=check_dtype) + if skipna: + res = testfunc(testarval, axis=axis, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + if axis is None: + res = testfunc(testarval, skipna=skipna, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + if skipna and axis is None: + res = testfunc(testarval, **kwargs) + self.check_results(targ, res, axis, check_dtype=check_dtype) + + if testarval.ndim <= 1: + return + + # Recurse on lower-dimension + testarval2 = np.take(testarval, 0, axis=-1) + targarval2 = np.take(targarval, 0, axis=-1) + self.check_fun_data( + testfunc, + targfunc, + testar, + testarval2, + targarval2, + skipna=skipna, + check_dtype=check_dtype, + empty_targfunc=empty_targfunc, + **kwargs, + ) + + def check_fun( + self, testfunc, targfunc, testar, skipna, empty_targfunc=None, **kwargs + ): + targar = testar + if testar.endswith("_nan") and hasattr(self, testar[:-4]): + targar = testar[:-4] + + testarval = getattr(self, testar) + targarval = getattr(self, targar) + self.check_fun_data( + testfunc, + targfunc, + testar, + testarval, + targarval, + skipna=skipna, + empty_targfunc=empty_targfunc, + **kwargs, + ) + + def check_funs( + self, + testfunc, + targfunc, + skipna, + allow_complex=True, + allow_all_nan=True, + allow_date=True, + allow_tdelta=True, + allow_obj=True, + **kwargs, + ): + self.check_fun(testfunc, targfunc, "arr_float", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_float_nan", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_int", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_bool", skipna, **kwargs) + objs = [ + self.arr_float.astype("O"), + self.arr_int.astype("O"), + self.arr_bool.astype("O"), + ] + + if allow_all_nan: + self.check_fun(testfunc, targfunc, "arr_nan", skipna, **kwargs) + + if allow_complex: + self.check_fun(testfunc, targfunc, "arr_complex", skipna, **kwargs) + self.check_fun(testfunc, targfunc, "arr_complex_nan", skipna, **kwargs) + if allow_all_nan: + self.check_fun(testfunc, targfunc, "arr_nan_nanj", skipna, **kwargs) + objs += [self.arr_complex.astype("O")] + + if allow_date: + targfunc(self.arr_date) + self.check_fun(testfunc, targfunc, "arr_date", skipna, **kwargs) + objs += [self.arr_date.astype("O")] + + if allow_tdelta: + try: + targfunc(self.arr_tdelta) + except TypeError: + pass + else: + self.check_fun(testfunc, targfunc, "arr_tdelta", skipna, **kwargs) + objs += [self.arr_tdelta.astype("O")] + + if allow_obj: + self.arr_obj = np.vstack(objs) + # some nanops handle object dtypes better than their numpy + # counterparts, so the numpy functions need to be given something + # else + if allow_obj == "convert": + targfunc = partial( + self._badobj_wrap, func=targfunc, allow_complex=allow_complex + ) + self.check_fun(testfunc, targfunc, "arr_obj", skipna, **kwargs) + + def _badobj_wrap(self, value, func, allow_complex=True, **kwargs): + if value.dtype.kind == "O": + if allow_complex: + value = value.astype("c16") + else: + value = value.astype("f8") + return func(value, **kwargs) + + @pytest.mark.parametrize( + "nan_op,np_op", [(nanops.nanany, np.any), (nanops.nanall, np.all)] + ) + def test_nan_funcs(self, nan_op, np_op, skipna): + self.check_funs(nan_op, np_op, skipna, allow_all_nan=False, allow_date=False) + + def test_nansum(self, skipna): + self.check_funs( + nanops.nansum, + np.sum, + skipna, + allow_date=False, + check_dtype=False, + empty_targfunc=np.nansum, + ) + + def test_nanmean(self, skipna): + self.check_funs( + nanops.nanmean, np.mean, skipna, allow_obj=False, allow_date=False + ) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_nanmedian(self, skipna): + self.check_funs( + nanops.nanmedian, + np.median, + skipna, + allow_complex=False, + allow_date=False, + allow_obj="convert", + ) + + @pytest.mark.parametrize("ddof", range(3)) + def test_nanvar(self, ddof, skipna): + self.check_funs( + nanops.nanvar, + np.var, + skipna, + allow_complex=False, + allow_date=False, + allow_obj="convert", + ddof=ddof, + ) + + @pytest.mark.parametrize("ddof", range(3)) + def test_nanstd(self, ddof, skipna): + self.check_funs( + nanops.nanstd, + np.std, + skipna, + allow_complex=False, + allow_date=False, + allow_obj="convert", + ddof=ddof, + ) + + @pytest.mark.parametrize("ddof", range(3)) + def test_nansem(self, ddof, skipna): + sp_stats = pytest.importorskip("scipy.stats") + + with np.errstate(invalid="ignore"): + self.check_funs( + nanops.nansem, + sp_stats.sem, + skipna, + allow_complex=False, + allow_date=False, + allow_tdelta=False, + allow_obj="convert", + ddof=ddof, + ) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + @pytest.mark.parametrize( + "nan_op,np_op", [(nanops.nanmin, np.min), (nanops.nanmax, np.max)] + ) + def test_nanops_with_warnings(self, nan_op, np_op, skipna): + self.check_funs(nan_op, np_op, skipna, allow_obj=False) + + def _argminmax_wrap(self, value, axis=None, func=None): + res = func(value, axis) + nans = np.min(value, axis) + nullnan = isna(nans) + if res.ndim: + res[nullnan] = -1 + elif (hasattr(nullnan, "all") and nullnan.all()) or ( + not hasattr(nullnan, "all") and nullnan + ): + res = -1 + return res + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_nanargmax(self, skipna): + func = partial(self._argminmax_wrap, func=np.argmax) + self.check_funs(nanops.nanargmax, func, skipna, allow_obj=False) + + @pytest.mark.filterwarnings("ignore::RuntimeWarning") + def test_nanargmin(self, skipna): + func = partial(self._argminmax_wrap, func=np.argmin) + self.check_funs(nanops.nanargmin, func, skipna, allow_obj=False) + + def _skew_kurt_wrap(self, values, axis=None, func=None): + if not isinstance(values.dtype.type, np.floating): + values = values.astype("f8") + result = func(values, axis=axis, bias=False) + # fix for handling cases where all elements in an axis are the same + if isinstance(result, np.ndarray): + result[np.max(values, axis=axis) == np.min(values, axis=axis)] = 0 + return result + elif np.max(values) == np.min(values): + return 0.0 + return result + + def test_nanskew(self, skipna): + sp_stats = pytest.importorskip("scipy.stats") + + func = partial(self._skew_kurt_wrap, func=sp_stats.skew) + with np.errstate(invalid="ignore"): + self.check_funs( + nanops.nanskew, + func, + skipna, + allow_complex=False, + allow_date=False, + allow_tdelta=False, + ) + + def test_nankurt(self, skipna): + sp_stats = pytest.importorskip("scipy.stats") + + func1 = partial(sp_stats.kurtosis, fisher=True) + func = partial(self._skew_kurt_wrap, func=func1) + with np.errstate(invalid="ignore"): + self.check_funs( + nanops.nankurt, + func, + skipna, + allow_complex=False, + allow_date=False, + allow_tdelta=False, + ) + + def test_nanprod(self, skipna): + self.check_funs( + nanops.nanprod, + np.prod, + skipna, + allow_date=False, + allow_tdelta=False, + empty_targfunc=np.nanprod, + ) + + def check_nancorr_nancov_2d(self, checkfun, targ0, targ1, **kwargs): + res00 = checkfun(self.arr_float_2d, self.arr_float1_2d, **kwargs) + res01 = checkfun( + self.arr_float_2d, + self.arr_float1_2d, + min_periods=len(self.arr_float_2d) - 1, + **kwargs, + ) + tm.assert_almost_equal(targ0, res00) + tm.assert_almost_equal(targ0, res01) + + res10 = checkfun(self.arr_float_nan_2d, self.arr_float1_nan_2d, **kwargs) + res11 = checkfun( + self.arr_float_nan_2d, + self.arr_float1_nan_2d, + min_periods=len(self.arr_float_2d) - 1, + **kwargs, + ) + tm.assert_almost_equal(targ1, res10) + tm.assert_almost_equal(targ1, res11) + + targ2 = np.nan + res20 = checkfun(self.arr_nan_2d, self.arr_float1_2d, **kwargs) + res21 = checkfun(self.arr_float_2d, self.arr_nan_2d, **kwargs) + res22 = checkfun(self.arr_nan_2d, self.arr_nan_2d, **kwargs) + res23 = checkfun(self.arr_float_nan_2d, self.arr_nan_float1_2d, **kwargs) + res24 = checkfun( + self.arr_float_nan_2d, + self.arr_nan_float1_2d, + min_periods=len(self.arr_float_2d) - 1, + **kwargs, + ) + res25 = checkfun( + self.arr_float_2d, + self.arr_float1_2d, + min_periods=len(self.arr_float_2d) + 1, + **kwargs, + ) + tm.assert_almost_equal(targ2, res20) + tm.assert_almost_equal(targ2, res21) + tm.assert_almost_equal(targ2, res22) + tm.assert_almost_equal(targ2, res23) + tm.assert_almost_equal(targ2, res24) + tm.assert_almost_equal(targ2, res25) + + def check_nancorr_nancov_1d(self, checkfun, targ0, targ1, **kwargs): + res00 = checkfun(self.arr_float_1d, self.arr_float1_1d, **kwargs) + res01 = checkfun( + self.arr_float_1d, + self.arr_float1_1d, + min_periods=len(self.arr_float_1d) - 1, + **kwargs, + ) + tm.assert_almost_equal(targ0, res00) + tm.assert_almost_equal(targ0, res01) + + res10 = checkfun(self.arr_float_nan_1d, self.arr_float1_nan_1d, **kwargs) + res11 = checkfun( + self.arr_float_nan_1d, + self.arr_float1_nan_1d, + min_periods=len(self.arr_float_1d) - 1, + **kwargs, + ) + tm.assert_almost_equal(targ1, res10) + tm.assert_almost_equal(targ1, res11) + + targ2 = np.nan + res20 = checkfun(self.arr_nan_1d, self.arr_float1_1d, **kwargs) + res21 = checkfun(self.arr_float_1d, self.arr_nan_1d, **kwargs) + res22 = checkfun(self.arr_nan_1d, self.arr_nan_1d, **kwargs) + res23 = checkfun(self.arr_float_nan_1d, self.arr_nan_float1_1d, **kwargs) + res24 = checkfun( + self.arr_float_nan_1d, + self.arr_nan_float1_1d, + min_periods=len(self.arr_float_1d) - 1, + **kwargs, + ) + res25 = checkfun( + self.arr_float_1d, + self.arr_float1_1d, + min_periods=len(self.arr_float_1d) + 1, + **kwargs, + ) + tm.assert_almost_equal(targ2, res20) + tm.assert_almost_equal(targ2, res21) + tm.assert_almost_equal(targ2, res22) + tm.assert_almost_equal(targ2, res23) + tm.assert_almost_equal(targ2, res24) + tm.assert_almost_equal(targ2, res25) + + def test_nancorr(self): + targ0 = np.corrcoef(self.arr_float_2d, self.arr_float1_2d)[0, 1] + targ1 = np.corrcoef(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0, 1] + self.check_nancorr_nancov_2d(nanops.nancorr, targ0, targ1) + targ0 = np.corrcoef(self.arr_float_1d, self.arr_float1_1d)[0, 1] + targ1 = np.corrcoef(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0, 1] + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="pearson") + + def test_nancorr_pearson(self): + targ0 = np.corrcoef(self.arr_float_2d, self.arr_float1_2d)[0, 1] + targ1 = np.corrcoef(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0, 1] + self.check_nancorr_nancov_2d(nanops.nancorr, targ0, targ1, method="pearson") + targ0 = np.corrcoef(self.arr_float_1d, self.arr_float1_1d)[0, 1] + targ1 = np.corrcoef(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0, 1] + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="pearson") + + def test_nancorr_kendall(self): + sp_stats = pytest.importorskip("scipy.stats") + + targ0 = sp_stats.kendalltau(self.arr_float_2d, self.arr_float1_2d)[0] + targ1 = sp_stats.kendalltau(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0] + self.check_nancorr_nancov_2d(nanops.nancorr, targ0, targ1, method="kendall") + targ0 = sp_stats.kendalltau(self.arr_float_1d, self.arr_float1_1d)[0] + targ1 = sp_stats.kendalltau(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0] + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="kendall") + + def test_nancorr_spearman(self): + sp_stats = pytest.importorskip("scipy.stats") + + targ0 = sp_stats.spearmanr(self.arr_float_2d, self.arr_float1_2d)[0] + targ1 = sp_stats.spearmanr(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0] + self.check_nancorr_nancov_2d(nanops.nancorr, targ0, targ1, method="spearman") + targ0 = sp_stats.spearmanr(self.arr_float_1d, self.arr_float1_1d)[0] + targ1 = sp_stats.spearmanr(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0] + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="spearman") + + def test_invalid_method(self): + pytest.importorskip("scipy") + targ0 = np.corrcoef(self.arr_float_2d, self.arr_float1_2d)[0, 1] + targ1 = np.corrcoef(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0, 1] + msg = "Unknown method 'foo', expected one of 'kendall', 'spearman'" + with pytest.raises(ValueError, match=msg): + self.check_nancorr_nancov_1d(nanops.nancorr, targ0, targ1, method="foo") + + def test_nancov(self): + targ0 = np.cov(self.arr_float_2d, self.arr_float1_2d)[0, 1] + targ1 = np.cov(self.arr_float_2d.flat, self.arr_float1_2d.flat)[0, 1] + self.check_nancorr_nancov_2d(nanops.nancov, targ0, targ1) + targ0 = np.cov(self.arr_float_1d, self.arr_float1_1d)[0, 1] + targ1 = np.cov(self.arr_float_1d.flat, self.arr_float1_1d.flat)[0, 1] + self.check_nancorr_nancov_1d(nanops.nancov, targ0, targ1) + + +@pytest.mark.parametrize( + "arr, correct", + [ + ("arr_complex", False), + ("arr_int", False), + ("arr_bool", False), + ("arr_str", False), + ("arr_utf", False), + ("arr_complex_nan", False), + ("arr_nan_nanj", False), + ("arr_nan_infj", True), + ("arr_complex_nan_infj", True), + ], +) +def test_has_infs_non_float(request, arr, correct, disable_bottleneck): + val = request.getfixturevalue(arr) + while getattr(val, "ndim", True): + res0 = nanops._has_infs(val) + if correct: + assert res0 + else: + assert not res0 + + if not hasattr(val, "ndim"): + break + + # Reduce dimension for next step in the loop + val = np.take(val, 0, axis=-1) + + +@pytest.mark.parametrize( + "arr, correct", + [ + ("arr_float", False), + ("arr_nan", False), + ("arr_float_nan", False), + ("arr_nan_nan", False), + ("arr_float_inf", True), + ("arr_inf", True), + ("arr_nan_inf", True), + ("arr_float_nan_inf", True), + ("arr_nan_nan_inf", True), + ], +) +@pytest.mark.parametrize("astype", [None, "f4", "f2"]) +def test_has_infs_floats(request, arr, correct, astype, disable_bottleneck): + val = request.getfixturevalue(arr) + if astype is not None: + val = val.astype(astype) + while getattr(val, "ndim", True): + res0 = nanops._has_infs(val) + if correct: + assert res0 + else: + assert not res0 + + if not hasattr(val, "ndim"): + break + + # Reduce dimension for next step in the loop + val = np.take(val, 0, axis=-1) + + +@pytest.mark.parametrize( + "fixture", ["arr_float", "arr_complex", "arr_int", "arr_bool", "arr_str", "arr_utf"] +) +def test_bn_ok_dtype(fixture, request, disable_bottleneck): + obj = request.getfixturevalue(fixture) + assert nanops._bn_ok_dtype(obj.dtype, "test") + + +@pytest.mark.parametrize( + "fixture", + [ + "arr_date", + "arr_tdelta", + "arr_obj", + ], +) +def test_bn_not_ok_dtype(fixture, request, disable_bottleneck): + obj = request.getfixturevalue(fixture) + assert not nanops._bn_ok_dtype(obj.dtype, "test") + + +class TestEnsureNumeric: + def test_numeric_values(self): + # Test integer + assert nanops._ensure_numeric(1) == 1 + + # Test float + assert nanops._ensure_numeric(1.1) == 1.1 + + # Test complex + assert nanops._ensure_numeric(1 + 2j) == 1 + 2j + + def test_ndarray(self): + # Test numeric ndarray + values = np.array([1, 2, 3]) + assert np.allclose(nanops._ensure_numeric(values), values) + + # Test object ndarray + o_values = values.astype(object) + assert np.allclose(nanops._ensure_numeric(o_values), values) + + # Test convertible string ndarray + s_values = np.array(["1", "2", "3"], dtype=object) + msg = r"Could not convert \['1' '2' '3'\] to numeric" + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric(s_values) + + # Test non-convertible string ndarray + s_values = np.array(["foo", "bar", "baz"], dtype=object) + msg = r"Could not convert .* to numeric" + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric(s_values) + + def test_convertable_values(self): + with pytest.raises(TypeError, match="Could not convert string '1' to numeric"): + nanops._ensure_numeric("1") + with pytest.raises( + TypeError, match="Could not convert string '1.1' to numeric" + ): + nanops._ensure_numeric("1.1") + with pytest.raises( + TypeError, match=r"Could not convert string '1\+1j' to numeric" + ): + nanops._ensure_numeric("1+1j") + + def test_non_convertable_values(self): + msg = "Could not convert string 'foo' to numeric" + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric("foo") + + # with the wrong type, python raises TypeError for us + msg = "argument must be a string or a number" + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric({}) + with pytest.raises(TypeError, match=msg): + nanops._ensure_numeric([]) + + +class TestNanvarFixedValues: + # xref GH10242 + # Samples from a normal distribution. + @pytest.fixture + def variance(self): + return 3.0 + + @pytest.fixture + def samples(self, variance): + return self.prng.normal(scale=variance**0.5, size=100000) + + def test_nanvar_all_finite(self, samples, variance): + actual_variance = nanops.nanvar(samples) + tm.assert_almost_equal(actual_variance, variance, rtol=1e-2) + + def test_nanvar_nans(self, samples, variance): + samples_test = np.nan * np.ones(2 * samples.shape[0]) + samples_test[::2] = samples + + actual_variance = nanops.nanvar(samples_test, skipna=True) + tm.assert_almost_equal(actual_variance, variance, rtol=1e-2) + + actual_variance = nanops.nanvar(samples_test, skipna=False) + tm.assert_almost_equal(actual_variance, np.nan, rtol=1e-2) + + def test_nanstd_nans(self, samples, variance): + samples_test = np.nan * np.ones(2 * samples.shape[0]) + samples_test[::2] = samples + + actual_std = nanops.nanstd(samples_test, skipna=True) + tm.assert_almost_equal(actual_std, variance**0.5, rtol=1e-2) + + actual_std = nanops.nanvar(samples_test, skipna=False) + tm.assert_almost_equal(actual_std, np.nan, rtol=1e-2) + + def test_nanvar_axis(self, samples, variance): + # Generate some sample data. + samples_unif = self.prng.uniform(size=samples.shape[0]) + samples = np.vstack([samples, samples_unif]) + + actual_variance = nanops.nanvar(samples, axis=1) + tm.assert_almost_equal( + actual_variance, np.array([variance, 1.0 / 12]), rtol=1e-2 + ) + + def test_nanvar_ddof(self): + n = 5 + samples = self.prng.uniform(size=(10000, n + 1)) + samples[:, -1] = np.nan # Force use of our own algorithm. + + variance_0 = nanops.nanvar(samples, axis=1, skipna=True, ddof=0).mean() + variance_1 = nanops.nanvar(samples, axis=1, skipna=True, ddof=1).mean() + variance_2 = nanops.nanvar(samples, axis=1, skipna=True, ddof=2).mean() + + # The unbiased estimate. + var = 1.0 / 12 + tm.assert_almost_equal(variance_1, var, rtol=1e-2) + + # The underestimated variance. + tm.assert_almost_equal(variance_0, (n - 1.0) / n * var, rtol=1e-2) + + # The overestimated variance. + tm.assert_almost_equal(variance_2, (n - 1.0) / (n - 2.0) * var, rtol=1e-2) + + @pytest.mark.parametrize("axis", range(2)) + @pytest.mark.parametrize("ddof", range(3)) + def test_ground_truth(self, axis, ddof): + # Test against values that were precomputed with Numpy. + samples = np.empty((4, 4)) + samples[:3, :3] = np.array( + [ + [0.97303362, 0.21869576, 0.55560287], + [0.72980153, 0.03109364, 0.99155171], + [0.09317602, 0.60078248, 0.15871292], + ] + ) + samples[3] = samples[:, 3] = np.nan + + # Actual variances along axis=0, 1 for ddof=0, 1, 2 + variance = np.array( + [ + [ + [0.13762259, 0.05619224, 0.11568816], + [0.20643388, 0.08428837, 0.17353224], + [0.41286776, 0.16857673, 0.34706449], + ], + [ + [0.09519783, 0.16435395, 0.05082054], + [0.14279674, 0.24653093, 0.07623082], + [0.28559348, 0.49306186, 0.15246163], + ], + ] + ) + + # Test nanvar. + var = nanops.nanvar(samples, skipna=True, axis=axis, ddof=ddof) + tm.assert_almost_equal(var[:3], variance[axis, ddof]) + assert np.isnan(var[3]) + + # Test nanstd. + std = nanops.nanstd(samples, skipna=True, axis=axis, ddof=ddof) + tm.assert_almost_equal(std[:3], variance[axis, ddof] ** 0.5) + assert np.isnan(std[3]) + + @pytest.mark.parametrize("ddof", range(3)) + def test_nanstd_roundoff(self, ddof): + # Regression test for GH 10242 (test data taken from GH 10489). Ensure + # that variance is stable. + data = Series(766897346 * np.ones(10)) + result = data.std(ddof=ddof) + assert result == 0.0 + + @property + def prng(self): + return np.random.default_rng(2) + + +class TestNanskewFixedValues: + # xref GH 11974 + # Test data + skewness value (computed with scipy.stats.skew) + @pytest.fixture + def samples(self): + return np.sin(np.linspace(0, 1, 200)) + + @pytest.fixture + def actual_skew(self): + return -0.1875895205961754 + + @pytest.mark.parametrize("val", [3075.2, 3075.3, 3075.5]) + def test_constant_series(self, val): + # xref GH 11974 + data = val * np.ones(300) + skew = nanops.nanskew(data) + assert skew == 0.0 + + def test_all_finite(self): + alpha, beta = 0.3, 0.1 + left_tailed = self.prng.beta(alpha, beta, size=100) + assert nanops.nanskew(left_tailed) < 0 + + alpha, beta = 0.1, 0.3 + right_tailed = self.prng.beta(alpha, beta, size=100) + assert nanops.nanskew(right_tailed) > 0 + + def test_ground_truth(self, samples, actual_skew): + skew = nanops.nanskew(samples) + tm.assert_almost_equal(skew, actual_skew) + + def test_axis(self, samples, actual_skew): + samples = np.vstack([samples, np.nan * np.ones(len(samples))]) + skew = nanops.nanskew(samples, axis=1) + tm.assert_almost_equal(skew, np.array([actual_skew, np.nan])) + + def test_nans(self, samples): + samples = np.hstack([samples, np.nan]) + skew = nanops.nanskew(samples, skipna=False) + assert np.isnan(skew) + + def test_nans_skipna(self, samples, actual_skew): + samples = np.hstack([samples, np.nan]) + skew = nanops.nanskew(samples, skipna=True) + tm.assert_almost_equal(skew, actual_skew) + + @pytest.mark.parametrize( + "initial_data, nobs", + [ + ([-2.05191341e-05, -4.10391103e-05], 27), + ([-2.05191341e-10, -4.10391103e-10], 27), + ([-2.05191341e-05, -4.10391103e-05], 10_000), + ([-2.05191341e-10, -4.10391103e-10], 10_000), + ], + ) + def test_low_variance(self, initial_data, nobs): + st = pytest.importorskip("scipy.stats") + data = np.zeros((nobs,), dtype=np.float64) + data[: len(initial_data)] = initial_data + skew = nanops.nanskew(data) + expected = st.skew(data, bias=False) + tm.assert_almost_equal(skew, expected) + + @property + def prng(self): + return np.random.default_rng(2) + + +class TestNankurtFixedValues: + # xref GH 11974 + # Test data + kurtosis value (computed with scipy.stats.kurtosis) + @pytest.fixture + def samples(self): + return np.sin(np.linspace(0, 1, 200)) + + @pytest.fixture + def actual_kurt(self): + return -1.2058303433799713 + + @pytest.mark.parametrize("val", [3075.2, 3075.3, 3075.5]) + def test_constant_series(self, val): + # xref GH 11974 + data = val * np.ones(300) + kurt = nanops.nankurt(data) + tm.assert_equal(kurt, 0.0) + + def test_all_finite(self): + alpha, beta = 0.3, 0.1 + left_tailed = self.prng.beta(alpha, beta, size=100) + assert nanops.nankurt(left_tailed) < 2 + + alpha, beta = 0.1, 0.3 + right_tailed = self.prng.beta(alpha, beta, size=100) + assert nanops.nankurt(right_tailed) < 0 + + def test_ground_truth(self, samples, actual_kurt): + kurt = nanops.nankurt(samples) + tm.assert_almost_equal(kurt, actual_kurt) + + def test_axis(self, samples, actual_kurt): + samples = np.vstack([samples, np.nan * np.ones(len(samples))]) + kurt = nanops.nankurt(samples, axis=1) + tm.assert_almost_equal(kurt, np.array([actual_kurt, np.nan])) + + def test_nans(self, samples): + samples = np.hstack([samples, np.nan]) + kurt = nanops.nankurt(samples, skipna=False) + assert np.isnan(kurt) + + def test_nans_skipna(self, samples, actual_kurt): + samples = np.hstack([samples, np.nan]) + kurt = nanops.nankurt(samples, skipna=True) + tm.assert_almost_equal(kurt, actual_kurt) + + @pytest.mark.parametrize( + "initial_data, nobs", + [ + ([-2.05191341e-05, -4.10391103e-05], 27), + ([-2.05191341e-10, -4.10391103e-10], 27), + ([-2.05191341e-05, -4.10391103e-05], 10_000), + ([-2.05191341e-10, -4.10391103e-10], 10_000), + ], + ) + def test_low_variance(self, initial_data, nobs): + # GH#57972 + st = pytest.importorskip("scipy.stats") + data = np.zeros((nobs,), dtype=np.float64) + data[: len(initial_data)] = initial_data + kurt = nanops.nankurt(data) + expected = st.kurtosis(data, bias=False) + tm.assert_almost_equal(kurt, expected) + + @property + def prng(self): + return np.random.default_rng(2) + + +class TestDatetime64NaNOps: + # Enabling mean changes the behavior of DataFrame.mean + # See https://github.com/pandas-dev/pandas/issues/24752 + def test_nanmean(self, unit): + dti = pd.date_range("2016-01-01", periods=3).as_unit(unit) + expected = dti[1] + + for obj in [dti, dti._data]: + result = nanops.nanmean(obj) + assert result == expected + + dti2 = dti.insert(1, pd.NaT) + + for obj in [dti2, dti2._data]: + result = nanops.nanmean(obj) + assert result == expected + + @pytest.mark.parametrize("constructor", ["M8", "m8"]) + def test_nanmean_skipna_false(self, constructor, unit): + dtype = f"{constructor}[{unit}]" + arr = np.arange(12).astype(np.int64).view(dtype).reshape(4, 3) + + arr[-1, -1] = "NaT" + + result = nanops.nanmean(arr, skipna=False) + assert np.isnat(result) + assert result.dtype == dtype + + result = nanops.nanmean(arr, axis=0, skipna=False) + expected = np.array([4, 5, "NaT"], dtype=arr.dtype) + tm.assert_numpy_array_equal(result, expected) + + result = nanops.nanmean(arr, axis=1, skipna=False) + expected = np.array([arr[0, 1], arr[1, 1], arr[2, 1], arr[-1, -1]]) + tm.assert_numpy_array_equal(result, expected) + + +def test_use_bottleneck(): + if nanops._BOTTLENECK_INSTALLED: + with pd.option_context("use_bottleneck", True): + assert pd.get_option("use_bottleneck") + + with pd.option_context("use_bottleneck", False): + assert not pd.get_option("use_bottleneck") + + +@pytest.mark.parametrize( + "numpy_op, expected", + [ + (np.sum, 10), + (np.nansum, 10), + (np.mean, 2.5), + (np.nanmean, 2.5), + (np.median, 2.5), + (np.nanmedian, 2.5), + (np.min, 1), + (np.max, 4), + (np.nanmin, 1), + (np.nanmax, 4), + ], +) +def test_numpy_ops(numpy_op, expected): + # GH8383 + result = numpy_op(Series([1, 2, 3, 4])) + assert result == expected + + +@pytest.mark.parametrize( + "operation", + [ + nanops.nanany, + nanops.nanall, + nanops.nansum, + nanops.nanmean, + nanops.nanmedian, + nanops.nanstd, + nanops.nanvar, + nanops.nansem, + nanops.nanargmax, + nanops.nanargmin, + nanops.nanmax, + nanops.nanmin, + nanops.nanskew, + nanops.nankurt, + nanops.nanprod, + ], +) +def test_nanops_independent_of_mask_param(operation): + # GH22764 + ser = Series([1, 2, np.nan, 3, np.nan, 4]) + mask = ser.isna() + median_expected = operation(ser._values) + median_result = operation(ser._values, mask=mask) + assert median_expected == median_result + + +@pytest.mark.parametrize("min_count", [-1, 0]) +def test_check_below_min_count_negative_or_zero_min_count(min_count): + # GH35227 + result = nanops.check_below_min_count((21, 37), None, min_count) + expected_result = False + assert result == expected_result + + +@pytest.mark.parametrize( + "mask", [None, np.array([False, False, True]), np.array([True] + 9 * [False])] +) +@pytest.mark.parametrize("min_count, expected_result", [(1, False), (101, True)]) +def test_check_below_min_count_positive_min_count(mask, min_count, expected_result): + # GH35227 + shape = (10, 10) + result = nanops.check_below_min_count(shape, mask, min_count) + assert result == expected_result + + +@td.skip_if_windows +@td.skip_if_32bit +@pytest.mark.parametrize("min_count, expected_result", [(1, False), (2812191852, True)]) +def test_check_below_min_count_large_shape(min_count, expected_result): + # GH35227 large shape used to show that the issue is fixed + shape = (2244367, 1253) + result = nanops.check_below_min_count(shape, mask=None, min_count=min_count) + assert result == expected_result + + +@pytest.mark.parametrize("func", ["nanmean", "nansum"]) +def test_check_bottleneck_disallow(any_real_numpy_dtype, func): + # GH 42878 bottleneck sometimes produces unreliable results for mean and sum + assert not nanops._bn_ok_dtype(np.dtype(any_real_numpy_dtype).type, func) + + +@pytest.mark.parametrize("val", [2**55, -(2**55), 20150515061816532]) +def test_nanmean_overflow(disable_bottleneck, val, using_python_scalars): + # GH 10155 + # In the previous implementation mean can overflow for int dtypes, it + # is now consistent with numpy + + ser = Series(val, index=range(500), dtype=np.int64) + result = ser.mean() + assert result == val + if using_python_scalars: + assert type(result) == float + else: + np_result = ser.values.mean() + assert result == np_result + assert result.dtype == np.float64 + + +@pytest.mark.parametrize( + "dtype", + [ + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + getattr(np, "float128", None), + ], +) +@pytest.mark.parametrize("method", ["mean", "std", "var", "skew", "kurt", "min", "max"]) +def test_returned_dtype(disable_bottleneck, dtype, method, using_python_scalars): + if dtype is None: + pytest.skip("np.float128 not available") + + ser = Series(range(10), dtype=dtype) + result = getattr(ser, method)() + if using_python_scalars: + if is_integer_dtype(dtype) and method in ["min", "max"]: + assert isinstance(result, int) + else: + assert type(result) == float + elif is_integer_dtype(dtype) and method not in ["min", "max"]: + assert result.dtype == np.float64 + else: + assert result.dtype == dtype diff --git a/pandas/tests/test_optional_dependency.py b/pandas/tests/test_optional_dependency.py new file mode 100644 index 0000000000000000000000000000000000000000..cd276914bfb21217c602c40f2aef07d8eff2255d --- /dev/null +++ b/pandas/tests/test_optional_dependency.py @@ -0,0 +1,100 @@ +import sys +import types + +import pytest + +from pandas.compat._optional import ( + VERSIONS, + import_optional_dependency, +) + +import pandas._testing as tm + + +def test_import_optional(): + match = "Import .*notapackage.* pip .* conda .* notapackage" + with pytest.raises(ImportError, match=match) as exc_info: + import_optional_dependency("notapackage") + # The original exception should be there as context: + assert isinstance(exc_info.value.__context__, ImportError) + + result = import_optional_dependency("notapackage", errors="ignore") + assert result is None + + +def test_xlrd_version_fallback(): + pytest.importorskip("xlrd") + import_optional_dependency("xlrd") + + +def test_bad_version(monkeypatch): + name = "fakemodule" + module = types.ModuleType(name) + module.__version__ = "0.9.0" + sys.modules[name] = module + monkeypatch.setitem(VERSIONS, name, "1.0.0") + + match = "Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'" + with pytest.raises(ImportError, match=match): + import_optional_dependency("fakemodule") + + # Test min_version parameter + result = import_optional_dependency("fakemodule", min_version="0.8") + assert result is module + + with tm.assert_produces_warning(UserWarning, match=match): + result = import_optional_dependency("fakemodule", errors="warn") + assert result is None + + module.__version__ = "1.0.0" # exact match is OK + result = import_optional_dependency("fakemodule") + assert result is module + + with pytest.raises(ImportError, match="Pandas requires version '1.1.0'"): + import_optional_dependency("fakemodule", min_version="1.1.0") + + with tm.assert_produces_warning(UserWarning, match="Pandas requires version"): + result = import_optional_dependency( + "fakemodule", errors="warn", min_version="1.1.0" + ) + assert result is None + + result = import_optional_dependency( + "fakemodule", errors="ignore", min_version="1.1.0" + ) + assert result is None + + +def test_submodule(monkeypatch): + # Create a fake module with a submodule + name = "fakemodule" + module = types.ModuleType(name) + module.__version__ = "0.9.0" + sys.modules[name] = module + sub_name = "submodule" + submodule = types.ModuleType(sub_name) + setattr(module, sub_name, submodule) + sys.modules[f"{name}.{sub_name}"] = submodule + monkeypatch.setitem(VERSIONS, name, "1.0.0") + + match = "Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'" + with pytest.raises(ImportError, match=match): + import_optional_dependency("fakemodule.submodule") + + with tm.assert_produces_warning(UserWarning, match=match): + result = import_optional_dependency("fakemodule.submodule", errors="warn") + assert result is None + + module.__version__ = "1.0.0" # exact match is OK + result = import_optional_dependency("fakemodule.submodule") + assert result is submodule + + +def test_no_version_raises(monkeypatch): + name = "fakemodule" + module = types.ModuleType(name) + sys.modules[name] = module + monkeypatch.setitem(VERSIONS, name, "1.0.0") + + with pytest.raises(ImportError, match="Can't determine .* fakemodule"): + import_optional_dependency(name) diff --git a/pandas/tests/test_register_accessor.py b/pandas/tests/test_register_accessor.py new file mode 100644 index 0000000000000000000000000000000000000000..9deff5613939412f0e51a8784f3e0958e9989140 --- /dev/null +++ b/pandas/tests/test_register_accessor.py @@ -0,0 +1,123 @@ +from collections.abc import Generator +import contextlib +import weakref + +import pytest + +import pandas as pd +import pandas._testing as tm +from pandas.core import accessor + + +def test_dirname_mixin() -> None: + # GH37173 + + class X(accessor.DirNamesMixin): + x = 1 + y: int + + def __init__(self) -> None: + self.z = 3 + + result = [attr_name for attr_name in dir(X()) if not attr_name.startswith("_")] + + assert result == ["x", "z"] + + +@contextlib.contextmanager +def ensure_removed(obj, attr) -> Generator[None, None, None]: + """Ensure that an attribute added to 'obj' during the test is + removed when we're done + """ + try: + yield + finally: + try: + delattr(obj, attr) + except AttributeError: + pass + obj._accessors.discard(attr) + + +class MyAccessor: + def __init__(self, obj) -> None: + self.obj = obj + self.item = "item" + + @property + def prop(self): + return self.item + + def method(self): + return self.item + + +@pytest.mark.parametrize( + "obj, registrar", + [ + (pd.Series, pd.api.extensions.register_series_accessor), + (pd.DataFrame, pd.api.extensions.register_dataframe_accessor), + (pd.Index, pd.api.extensions.register_index_accessor), + ], +) +def test_register(obj, registrar): + with ensure_removed(obj, "mine"): + before = set(dir(obj)) + registrar("mine")(MyAccessor) + o = obj([]) if obj is not pd.Series else obj([], dtype=object) + assert o.mine.prop == "item" + after = set(dir(obj)) + assert (before ^ after) == {"mine"} + assert "mine" in obj._accessors + + +def test_accessor_works(): + with ensure_removed(pd.Series, "mine"): + pd.api.extensions.register_series_accessor("mine")(MyAccessor) + + s = pd.Series([1, 2]) + assert s.mine.obj is s + + assert s.mine.prop == "item" + assert s.mine.method() == "item" + + +def test_overwrite_warns(): + match = r".*MyAccessor.*fake.*Series.*" + with tm.assert_produces_warning(UserWarning, match=match): + with ensure_removed(pd.Series, "fake"): + setattr(pd.Series, "fake", 123) + pd.api.extensions.register_series_accessor("fake")(MyAccessor) + s = pd.Series([1, 2]) + assert s.fake.prop == "item" + + +def test_raises_attribute_error(): + with ensure_removed(pd.Series, "bad"): + + @pd.api.extensions.register_series_accessor("bad") + class Bad: + def __init__(self, data) -> None: + raise AttributeError("whoops") + + with pytest.raises(AttributeError, match="whoops"): + pd.Series([], dtype=object).bad + + +@pytest.mark.parametrize( + "klass, registrar", + [ + (pd.Series, pd.api.extensions.register_series_accessor), + (pd.DataFrame, pd.api.extensions.register_dataframe_accessor), + (pd.Index, pd.api.extensions.register_index_accessor), + ], +) +def test_no_circular_reference(klass, registrar): + # GH 41357 + with ensure_removed(klass, "access"): + registrar("access")(MyAccessor) + obj = klass([0]) + ref = weakref.ref(obj) + assert obj.access.obj is obj + del obj + assert ref() is None diff --git a/pandas/tests/test_sorting.py b/pandas/tests/test_sorting.py new file mode 100644 index 0000000000000000000000000000000000000000..4596238946c62fd788ffb67aba8394ba41745dd4 --- /dev/null +++ b/pandas/tests/test_sorting.py @@ -0,0 +1,475 @@ +from collections import defaultdict +from datetime import datetime +from itertools import product + +import numpy as np +import pytest + +from pandas import ( + NA, + DataFrame, + MultiIndex, + Series, + array, + concat, + merge, +) +import pandas._testing as tm +from pandas.core.algorithms import safe_sort +import pandas.core.common as com +from pandas.core.sorting import ( + _decons_group_index, + get_group_index, + is_int64_overflow_possible, + lexsort_indexer, + nargsort, +) + + +@pytest.fixture +def left_right(): + low, high, n = -1 << 10, 1 << 10, 1 << 20 + left = DataFrame( + np.random.default_rng(2).integers(low, high, (n, 7)), columns=list("ABCDEFG") + ) + left["left"] = left.sum(axis=1) + right = left.sample( + frac=1, random_state=np.random.default_rng(2), ignore_index=True + ) + right.columns = [*right.columns[:-1].tolist(), "right"] + right["right"] *= -1 + return left, right + + +class TestSorting: + @pytest.mark.slow + def test_int64_overflow(self): + B = np.concatenate((np.arange(1000), np.arange(1000), np.arange(500))) + A = np.arange(2500) + df = DataFrame( + { + "A": A, + "B": B, + "C": A, + "D": B, + "E": A, + "F": B, + "G": A, + "H": B, + "values": np.random.default_rng(2).standard_normal(2500), + } + ) + + lg = df.groupby(["A", "B", "C", "D", "E", "F", "G", "H"]) + rg = df.groupby(["H", "G", "F", "E", "D", "C", "B", "A"]) + + left = lg.sum()["values"] + right = rg.sum()["values"] + + exp_index, _ = left.index.sortlevel() + tm.assert_index_equal(left.index, exp_index) + + exp_index, _ = right.index.sortlevel(0) + tm.assert_index_equal(right.index, exp_index) + + tups = list(map(tuple, df[["A", "B", "C", "D", "E", "F", "G", "H"]].values)) + tups = com.asarray_tuplesafe(tups) + + expected = df.groupby(tups).sum()["values"] + + for k, v in expected.items(): + assert left[k] == right[k[::-1]] + assert left[k] == v + assert len(left) == len(right) + + def test_int64_overflow_groupby_large_range(self): + # GH9096 + values = range(55109) + data = DataFrame.from_dict({"a": values, "b": values, "c": values, "d": values}) + grouped = data.groupby(["a", "b", "c", "d"]) + assert len(grouped) == len(values) + + @pytest.mark.slow + @pytest.mark.parametrize("agg", ["mean", "median"]) + def test_int64_overflow_groupby_large_df_shuffled(self, agg): + rs = np.random.default_rng(2) + arr = rs.integers(-1 << 12, 1 << 12, (1 << 15, 5)) + i = rs.choice(len(arr), len(arr) * 4) + arr = np.vstack((arr, arr[i])) # add some duplicate rows + + i = rs.permutation(len(arr)) + arr = arr[i] # shuffle rows + + df = DataFrame(arr, columns=list("abcde")) + df["jim"], df["joe"] = np.zeros((2, len(df))) + gr = df.groupby(list("abcde")) + + # verify this is testing what it is supposed to test! + assert is_int64_overflow_possible( + tuple(ping.ngroups for ping in gr._grouper.groupings) + ) + + mi = MultiIndex.from_arrays( + [ar.ravel() for ar in np.array_split(np.unique(arr, axis=0), 5, axis=1)], + names=list("abcde"), + ) + + res = DataFrame( + np.zeros((len(mi), 2)), columns=["jim", "joe"], index=mi + ).sort_index() + + tm.assert_frame_equal(getattr(gr, agg)(), res) + + @pytest.mark.parametrize( + "order, na_position, exp", + [ + [ + True, + "last", + list(range(5, 105)) + list(range(5)) + list(range(105, 110)), + ], + [ + True, + "first", + list(range(5)) + list(range(105, 110)) + list(range(5, 105)), + ], + [ + False, + "last", + list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)), + ], + [ + False, + "first", + list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)), + ], + ], + ) + def test_lexsort_indexer(self, order, na_position, exp): + keys = [[np.nan] * 5 + list(range(100)) + [np.nan] * 5] + result = lexsort_indexer(keys, orders=order, na_position=na_position) + tm.assert_numpy_array_equal(result, np.array(exp, dtype=np.intp)) + + @pytest.mark.parametrize( + "ascending, na_position, exp", + [ + [ + True, + "last", + list(range(5, 105)) + list(range(5)) + list(range(105, 110)), + ], + [ + True, + "first", + list(range(5)) + list(range(105, 110)) + list(range(5, 105)), + ], + [ + False, + "last", + list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)), + ], + [ + False, + "first", + list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)), + ], + ], + ) + def test_nargsort(self, ascending, na_position, exp): + # list places NaNs last, np.array(..., dtype="O") may not place NaNs first + items = np.array([np.nan] * 5 + list(range(100)) + [np.nan] * 5, dtype="O") + + # mergesort is the most difficult to get right because we want it to be + # stable. + + # According to numpy/core/tests/test_multiarray, """The number of + # sorted items must be greater than ~50 to check the actual algorithm + # because quick and merge sort fall over to insertion sort for small + # arrays.""" + + result = nargsort( + items, kind="mergesort", ascending=ascending, na_position=na_position + ) + tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False) + + +class TestMerge: + def test_int64_overflow_outer_merge(self): + # #2690, combinatorial explosion + df1 = DataFrame( + np.random.default_rng(2).standard_normal((1000, 7)), + columns=[*list("ABCDEF"), "G1"], + ) + df2 = DataFrame( + np.random.default_rng(3).standard_normal((1000, 7)), + columns=[*list("ABCDEF"), "G2"], + ) + result = merge(df1, df2, how="outer") + assert len(result) == 2000 + + @pytest.mark.slow + def test_int64_overflow_check_sum_col(self, left_right): + left, right = left_right + + out = merge(left, right, how="outer") + assert len(out) == len(left) + tm.assert_series_equal(out["left"], -out["right"], check_names=False) + result = out.iloc[:, :-2].sum(axis=1) + tm.assert_series_equal(out["left"], result, check_names=False) + assert result.name is None + + @pytest.mark.slow + def test_int64_overflow_how_merge(self, left_right, join_type): + left, right = left_right + + out = merge(left, right, how="outer") + out.sort_values(out.columns.tolist(), inplace=True) + tm.assert_frame_equal(out, merge(left, right, how=join_type, sort=True)) + + @pytest.mark.slow + def test_int64_overflow_sort_false_order(self, left_right): + left, right = left_right + + # check that left merge w/ sort=False maintains left frame order + out = merge(left, right, how="left", sort=False) + tm.assert_frame_equal(left, out[left.columns.tolist()]) + + out = merge(right, left, how="left", sort=False) + tm.assert_frame_equal(right, out[right.columns.tolist()]) + + @pytest.mark.slow + def test_int64_overflow_one_to_many_none_match(self, join_type, sort): + # one-2-many/none match + how = join_type + low, high, n = -1 << 10, 1 << 10, 1 << 11 + left = DataFrame( + np.random.default_rng(2).integers(low, high, (n, 7)).astype("int64"), + columns=list("ABCDEFG"), + ) + + # confirm that this is checking what it is supposed to check + shape = left.apply(Series.nunique).values + assert is_int64_overflow_possible(shape) + + # add duplicates to left frame + left = concat([left, left], ignore_index=True) + + right = DataFrame( + np.random.default_rng(3).integers(low, high, (n // 2, 7)).astype("int64"), + columns=list("ABCDEFG"), + ) + + # add duplicates & overlap with left to the right frame + i = np.random.default_rng(4).choice(len(left), n) + right = concat([right, right, left.iloc[i]], ignore_index=True) + + left["left"] = np.random.default_rng(2).standard_normal(len(left)) + right["right"] = np.random.default_rng(2).standard_normal(len(right)) + + # shuffle left & right frames + left = left.sample( + frac=1, ignore_index=True, random_state=np.random.default_rng(5) + ) + right = right.sample( + frac=1, ignore_index=True, random_state=np.random.default_rng(6) + ) + + # manually compute outer merge + ldict, rdict = defaultdict(list), defaultdict(list) + + for idx, row in left.set_index(list("ABCDEFG")).iterrows(): + ldict[idx].append(row["left"]) + + for idx, row in right.set_index(list("ABCDEFG")).iterrows(): + rdict[idx].append(row["right"]) + + vals = [] + for k, lval in ldict.items(): + rval = rdict.get(k, [np.nan]) + for lv, rv in product(lval, rval): + vals.append((*k, lv, rv)) + + for k, rval in rdict.items(): + if k not in ldict: + vals.extend((*k, np.nan, rv) for rv in rval) + + out = DataFrame(vals, columns=[*list("ABCDEFG"), "left", "right"]) + out = out.sort_values(out.columns.to_list(), ignore_index=True) + + jmask = { + "left": out["left"].notna(), + "right": out["right"].notna(), + "inner": out["left"].notna() & out["right"].notna(), + "outer": np.ones(len(out), dtype="bool"), + } + + mask = jmask[how] + frame = out[mask].sort_values(out.columns.to_list(), ignore_index=True) + assert mask.all() ^ mask.any() or how == "outer" + + res = merge(left, right, how=how, sort=sort) + if sort: + kcols = list("ABCDEFG") + tm.assert_frame_equal( + res[kcols], res[kcols].sort_values(kcols, kind="mergesort") + ) + + # as in GH9092 dtypes break with outer/right join + # 2021-12-18: dtype does not break anymore + tm.assert_frame_equal( + frame, res.sort_values(res.columns.to_list(), ignore_index=True) + ) + + +@pytest.mark.parametrize( + "codes_list, shape", + [ + [ + [ + np.tile([0, 1, 2, 3, 0, 1, 2, 3], 100).astype(np.int64), + np.tile([0, 2, 4, 3, 0, 1, 2, 3], 100).astype(np.int64), + np.tile([5, 1, 0, 2, 3, 0, 5, 4], 100).astype(np.int64), + ], + (4, 5, 6), + ], + [ + [ + np.tile(np.arange(10000, dtype=np.int64), 5), + np.tile(np.arange(10000, dtype=np.int64), 5), + ], + (10000, 10000), + ], + ], +) +def test_decons(codes_list, shape): + group_index = get_group_index(codes_list, shape, sort=True, xnull=True) + codes_list2 = _decons_group_index(group_index, shape) + + for a, b in zip(codes_list, codes_list2, strict=True): + tm.assert_numpy_array_equal(a, b) + + +class TestSafeSort: + @pytest.mark.parametrize( + "arg, exp", + [ + [[3, 1, 2, 0, 4], [0, 1, 2, 3, 4]], + [ + np.array(list("baaacb"), dtype=object), + np.array(list("aaabbc"), dtype=object), + ], + [[], []], + ], + ) + def test_basic_sort(self, arg, exp): + result = safe_sort(np.array(arg)) + expected = np.array(exp) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("verify", [True, False]) + @pytest.mark.parametrize( + "codes, exp_codes", + [ + [[0, 1, 1, 2, 3, 0, -1, 4], [3, 1, 1, 2, 0, 3, -1, 4]], + [[], []], + ], + ) + def test_codes(self, verify, codes, exp_codes): + values = np.array([3, 1, 2, 0, 4]) + expected = np.array([0, 1, 2, 3, 4]) + + result, result_codes = safe_sort( + values, codes, use_na_sentinel=True, verify=verify + ) + expected_codes = np.array(exp_codes, dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + tm.assert_numpy_array_equal(result_codes, expected_codes) + + def test_codes_out_of_bound(self): + values = np.array([3, 1, 2, 0, 4]) + expected = np.array([0, 1, 2, 3, 4]) + + # out of bound indices + codes = [0, 101, 102, 2, 3, 0, 99, 4] + result, result_codes = safe_sort(values, codes, use_na_sentinel=True) + expected_codes = np.array([3, -1, -1, 2, 0, 3, -1, 4], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + tm.assert_numpy_array_equal(result_codes, expected_codes) + + @pytest.mark.parametrize("codes", [[-1, -1], [2, -1], [2, 2]]) + def test_codes_empty_array_out_of_bound(self, codes): + empty_values = np.array([]) + expected_codes = -np.ones_like(codes, dtype=np.intp) + _, result_codes = safe_sort(empty_values, codes) + tm.assert_numpy_array_equal(result_codes, expected_codes) + + def test_mixed_integer(self): + values = np.array(["b", 1, 0, "a", 0, "b"], dtype=object) + result = safe_sort(values) + expected = np.array([0, 0, 1, "a", "b", "b"], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + def test_mixed_integer_with_codes(self): + values = np.array(["b", 1, 0, "a"], dtype=object) + codes = [0, 1, 2, 3, 0, -1, 1] + result, result_codes = safe_sort(values, codes) + expected = np.array([0, 1, "a", "b"], dtype=object) + expected_codes = np.array([3, 1, 0, 2, 3, -1, 1], dtype=np.intp) + tm.assert_numpy_array_equal(result, expected) + tm.assert_numpy_array_equal(result_codes, expected_codes) + + def test_unsortable(self): + # GH 13714 + arr = np.array([1, 2, datetime.now(), 0, 3], dtype=object) + msg = "'[<>]' not supported between instances of .*" + with pytest.raises(TypeError, match=msg): + safe_sort(arr) + + @pytest.mark.parametrize( + "arg, codes, err, msg", + [ + [1, None, TypeError, "Only np.ndarray, ExtensionArray, and Index"], + [np.array([0, 1, 2]), 1, TypeError, "Only list-like objects or None"], + [np.array([0, 1, 2, 1]), [0, 1], ValueError, "values should be unique"], + ], + ) + def test_exceptions(self, arg, codes, err, msg): + with pytest.raises(err, match=msg): + safe_sort(values=arg, codes=codes) + + @pytest.mark.parametrize( + "arg, exp", [[[1, 3, 2], [1, 2, 3]], [[1, 3, NA, 2], [1, 2, 3, NA]]] + ) + def test_extension_array(self, arg, exp): + a = array(arg, dtype="Int64") + result = safe_sort(a) + expected = array(exp, dtype="Int64") + tm.assert_extension_array_equal(result, expected) + + @pytest.mark.parametrize("verify", [True, False]) + def test_extension_array_codes(self, verify): + a = array([1, 3, 2], dtype="Int64") + result, codes = safe_sort(a, [0, 1, -1, 2], use_na_sentinel=True, verify=verify) + expected_values = array([1, 2, 3], dtype="Int64") + expected_codes = np.array([0, 2, -1, 1], dtype=np.intp) + tm.assert_extension_array_equal(result, expected_values) + tm.assert_numpy_array_equal(codes, expected_codes) + + +def test_mixed_str_null(nulls_fixture): + values = np.array(["b", nulls_fixture, "a", "b"], dtype=object) + result = safe_sort(values) + expected = np.array(["a", "b", "b", nulls_fixture], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +def test_safe_sort_multiindex(): + # GH#48412 + arr1 = Series([2, 1, NA, NA], dtype="Int64") + arr2 = [2, 1, 3, 3] + midx = MultiIndex.from_arrays([arr1, arr2]) + result = safe_sort(midx) + expected = MultiIndex.from_arrays( + [Series([1, 2, NA, NA], dtype="Int64"), [1, 2, 3, 3]] + ) + tm.assert_index_equal(result, expected) diff --git a/pandas/tests/test_take.py b/pandas/tests/test_take.py new file mode 100644 index 0000000000000000000000000000000000000000..451ef42fff3d170682f5f0a6440df5ba9cb85a08 --- /dev/null +++ b/pandas/tests/test_take.py @@ -0,0 +1,317 @@ +from datetime import datetime + +import numpy as np +import pytest + +from pandas._libs import iNaT + +from pandas import array +import pandas._testing as tm +import pandas.core.algorithms as algos + + +@pytest.fixture( + params=[ + (np.int8, np.int16(127), np.int8), + (np.int8, np.int16(128), np.int16), + (np.int32, 1, np.int32), + (np.int32, 2.0, np.float64), + (np.int32, 3.0 + 4.0j, np.complex128), + (np.int32, True, np.object_), + (np.int32, "", np.object_), + (np.float64, 1, np.float64), + (np.float64, 2.0, np.float64), + (np.float64, 3.0 + 4.0j, np.complex128), + (np.float64, True, np.object_), + (np.float64, "", np.object_), + (np.complex128, 1, np.complex128), + (np.complex128, 2.0, np.complex128), + (np.complex128, 3.0 + 4.0j, np.complex128), + (np.complex128, True, np.object_), + (np.complex128, "", np.object_), + (np.bool_, 1, np.object_), + (np.bool_, 2.0, np.object_), + (np.bool_, 3.0 + 4.0j, np.object_), + (np.bool_, True, np.bool_), + (np.bool_, "", np.object_), + ] +) +def dtype_fill_out_dtype(request): + return request.param + + +class TestTake: + def test_1d_fill_nonna(self, dtype_fill_out_dtype): + dtype, fill_value, out_dtype = dtype_fill_out_dtype + data = np.random.default_rng(2).integers(0, 2, 4).astype(dtype) + indexer = [2, 1, 0, -1] + + result = algos.take_nd(data, indexer, fill_value=fill_value) + assert (result[[0, 1, 2]] == data[[2, 1, 0]]).all() + assert result[3] == fill_value + assert result.dtype == out_dtype + + indexer = [2, 1, 0, 1] + + result = algos.take_nd(data, indexer, fill_value=fill_value) + assert (result[[0, 1, 2, 3]] == data[indexer]).all() + assert result.dtype == dtype + + def test_2d_fill_nonna(self, dtype_fill_out_dtype): + dtype, fill_value, out_dtype = dtype_fill_out_dtype + data = np.random.default_rng(2).integers(0, 2, (5, 3)).astype(dtype) + indexer = [2, 1, 0, -1] + + result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value) + assert (result[[0, 1, 2], :] == data[[2, 1, 0], :]).all() + assert (result[3, :] == fill_value).all() + assert result.dtype == out_dtype + + result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value) + assert (result[:, [0, 1, 2]] == data[:, [2, 1, 0]]).all() + assert (result[:, 3] == fill_value).all() + assert result.dtype == out_dtype + + indexer = [2, 1, 0, 1] + result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value) + assert (result[[0, 1, 2, 3], :] == data[indexer, :]).all() + assert result.dtype == dtype + + result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value) + assert (result[:, [0, 1, 2, 3]] == data[:, indexer]).all() + assert result.dtype == dtype + + def test_3d_fill_nonna(self, dtype_fill_out_dtype): + dtype, fill_value, out_dtype = dtype_fill_out_dtype + + data = np.random.default_rng(2).integers(0, 2, (5, 4, 3)).astype(dtype) + indexer = [2, 1, 0, -1] + + result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value) + assert (result[[0, 1, 2], :, :] == data[[2, 1, 0], :, :]).all() + assert (result[3, :, :] == fill_value).all() + assert result.dtype == out_dtype + + result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value) + assert (result[:, [0, 1, 2], :] == data[:, [2, 1, 0], :]).all() + assert (result[:, 3, :] == fill_value).all() + assert result.dtype == out_dtype + + result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value) + assert (result[:, :, [0, 1, 2]] == data[:, :, [2, 1, 0]]).all() + assert (result[:, :, 3] == fill_value).all() + assert result.dtype == out_dtype + + indexer = [2, 1, 0, 1] + result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value) + assert (result[[0, 1, 2, 3], :, :] == data[indexer, :, :]).all() + assert result.dtype == dtype + + result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value) + assert (result[:, [0, 1, 2, 3], :] == data[:, indexer, :]).all() + assert result.dtype == dtype + + result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value) + assert (result[:, :, [0, 1, 2, 3]] == data[:, :, indexer]).all() + assert result.dtype == dtype + + def test_1d_other_dtypes(self): + arr = np.random.default_rng(2).standard_normal(10).astype(np.float32) + + indexer = [1, 2, 3, -1] + result = algos.take_nd(arr, indexer) + expected = arr.take(indexer) + expected[-1] = np.nan + tm.assert_almost_equal(result, expected) + + def test_2d_other_dtypes(self): + arr = np.random.default_rng(2).standard_normal((10, 5)).astype(np.float32) + + indexer = [1, 2, 3, -1] + + # axis=0 + result = algos.take_nd(arr, indexer, axis=0) + expected = arr.take(indexer, axis=0) + expected[-1] = np.nan + tm.assert_almost_equal(result, expected) + + # axis=1 + result = algos.take_nd(arr, indexer, axis=1) + expected = arr.take(indexer, axis=1) + expected[:, -1] = np.nan + tm.assert_almost_equal(result, expected) + + def test_1d_bool(self): + arr = np.array([0, 1, 0], dtype=bool) + + result = algos.take_nd(arr, [0, 2, 2, 1]) + expected = arr.take([0, 2, 2, 1]) + tm.assert_numpy_array_equal(result, expected) + + result = algos.take_nd(arr, [0, 2, -1]) + assert result.dtype == np.object_ + + def test_2d_bool(self): + arr = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]], dtype=bool) + + result = algos.take_nd(arr, [0, 2, 2, 1]) + expected = arr.take([0, 2, 2, 1], axis=0) + tm.assert_numpy_array_equal(result, expected) + + result = algos.take_nd(arr, [0, 2, 2, 1], axis=1) + expected = arr.take([0, 2, 2, 1], axis=1) + tm.assert_numpy_array_equal(result, expected) + + result = algos.take_nd(arr, [0, 2, -1]) + assert result.dtype == np.object_ + + def test_2d_float32(self): + arr = np.random.default_rng(2).standard_normal((4, 3)).astype(np.float32) + indexer = [0, 2, -1, 1, -1] + + # axis=0 + result = algos.take_nd(arr, indexer, axis=0) + + expected = arr.take(indexer, axis=0) + expected[[2, 4], :] = np.nan + tm.assert_almost_equal(result, expected) + + # axis=1 + result = algos.take_nd(arr, indexer, axis=1) + expected = arr.take(indexer, axis=1) + expected[:, [2, 4]] = np.nan + tm.assert_almost_equal(result, expected) + + def test_2d_datetime64(self): + # 2005/01/01 - 2006/01/01 + arr = ( + np.random.default_rng(2).integers(11_045_376, 11_360_736, (5, 3)) + * 100_000_000_000 + ) + arr = arr.view(dtype="datetime64[ns]") + indexer = [0, 2, -1, 1, -1] + + # axis=0 + result = algos.take_nd(arr, indexer, axis=0) + expected = arr.take(indexer, axis=0) + expected.view(np.int64)[[2, 4], :] = iNaT + tm.assert_almost_equal(result, expected) + + result = algos.take_nd(arr, indexer, axis=0, fill_value=datetime(2007, 1, 1)) + expected = arr.take(indexer, axis=0) + expected[[2, 4], :] = datetime(2007, 1, 1) + tm.assert_almost_equal(result, expected) + + # axis=1 + result = algos.take_nd(arr, indexer, axis=1) + expected = arr.take(indexer, axis=1) + expected.view(np.int64)[:, [2, 4]] = iNaT + tm.assert_almost_equal(result, expected) + + result = algos.take_nd(arr, indexer, axis=1, fill_value=datetime(2007, 1, 1)) + expected = arr.take(indexer, axis=1) + expected[:, [2, 4]] = datetime(2007, 1, 1) + tm.assert_almost_equal(result, expected) + + def test_take_axis_0(self): + arr = np.arange(12).reshape(4, 3) + result = algos.take(arr, [0, -1]) + expected = np.array([[0, 1, 2], [9, 10, 11]]) + tm.assert_numpy_array_equal(result, expected) + + # allow_fill=True + result = algos.take(arr, [0, -1], allow_fill=True, fill_value=0) + expected = np.array([[0, 1, 2], [0, 0, 0]]) + tm.assert_numpy_array_equal(result, expected) + + def test_take_axis_1(self): + arr = np.arange(12).reshape(4, 3) + result = algos.take(arr, [0, -1], axis=1) + expected = np.array([[0, 2], [3, 5], [6, 8], [9, 11]]) + tm.assert_numpy_array_equal(result, expected) + + # allow_fill=True + result = algos.take(arr, [0, -1], axis=1, allow_fill=True, fill_value=0) + expected = np.array([[0, 0], [3, 0], [6, 0], [9, 0]]) + tm.assert_numpy_array_equal(result, expected) + + # GH#26976 make sure we validate along the correct axis + with pytest.raises(IndexError, match="indices are out-of-bounds"): + algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0) + + def test_take_non_hashable_fill_value(self): + arr = np.array([1, 2, 3]) + indexer = np.array([1, -1]) + with pytest.raises(ValueError, match="fill_value must be a scalar"): + algos.take(arr, indexer, allow_fill=True, fill_value=[1]) + + # with object dtype it is allowed + arr = np.array([1, 2, 3], dtype=object) + result = algos.take(arr, indexer, allow_fill=True, fill_value=[1]) + expected = np.array([2, [1]], dtype=object) + tm.assert_numpy_array_equal(result, expected) + + +class TestExtensionTake: + # The take method found in pd.api.extensions + + def test_bounds_check_large(self): + arr = np.array([1, 2]) + + msg = "indices are out-of-bounds" + with pytest.raises(IndexError, match=msg): + algos.take(arr, [2, 3], allow_fill=True) + + msg = "index 2 is out of bounds for( axis 0 with)? size 2" + with pytest.raises(IndexError, match=msg): + algos.take(arr, [2, 3], allow_fill=False) + + def test_bounds_check_small(self): + arr = np.array([1, 2, 3], dtype=np.int64) + indexer = [0, -1, -2] + + msg = r"'indices' contains values less than allowed \(-2 < -1\)" + with pytest.raises(ValueError, match=msg): + algos.take(arr, indexer, allow_fill=True) + + result = algos.take(arr, indexer) + expected = np.array([1, 3, 2], dtype=np.int64) + tm.assert_numpy_array_equal(result, expected) + + @pytest.mark.parametrize("allow_fill", [True, False]) + def test_take_empty(self, allow_fill): + arr = np.array([], dtype=np.int64) + # empty take is ok + result = algos.take(arr, [], allow_fill=allow_fill) + tm.assert_numpy_array_equal(arr, result) + + msg = "|".join( + [ + "cannot do a non-empty take from an empty axes.", + "indices are out-of-bounds", + ] + ) + with pytest.raises(IndexError, match=msg): + algos.take(arr, [0], allow_fill=allow_fill) + + def test_take_na_empty(self): + result = algos.take(np.array([]), [-1, -1], allow_fill=True, fill_value=0.0) + expected = np.array([0.0, 0.0]) + tm.assert_numpy_array_equal(result, expected) + + def test_take_coerces_list(self): + # GH#52981 coercing is deprecated, disabled in 3.0 + arr = [1, 2, 3] + msg = ( + "pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, " + "Index, Series, or NumpyExtensionArray got list" + ) + with pytest.raises(TypeError, match=msg): + algos.take(arr, [0, 0]) + + def test_take_NumpyExtensionArray(self): + # GH#59177 + arr = array([1 + 1j, 2, 3]) # NumpyEADtype('complex128') (NumpyExtensionArray) + assert algos.take(arr, [2]) == 2 + arr = array([1, 2, 3]) # Int64Dtype() (ExtensionArray) + assert algos.take(arr, [2]) == 2 diff --git a/pandas/tests/tools/__init__.py b/pandas/tests/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/tools/test_to_datetime.py b/pandas/tests/tools/test_to_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..ec97985b496fd6367517a2449e793d5140949138 --- /dev/null +++ b/pandas/tests/tools/test_to_datetime.py @@ -0,0 +1,3829 @@ +"""test to_datetime""" + +import calendar +from collections import deque +from datetime import ( + date, + datetime, + timedelta, + timezone, +) +from decimal import Decimal +import locale +import zoneinfo + +from dateutil.parser import parse +import numpy as np +import pytest + +from pandas._libs import tslib +from pandas._libs.tslibs import ( + iNaT, + parsing, +) +from pandas.compat import ( + PY314, + WASM, +) +from pandas.errors import ( + OutOfBoundsDatetime, + OutOfBoundsTimedelta, +) +import pandas.util._test_decorators as td + +from pandas.core.dtypes.common import is_datetime64_ns_dtype + +import pandas as pd +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + NaT, + Series, + Timestamp, + date_range, + isna, + to_datetime, +) +import pandas._testing as tm +from pandas.core.arrays import DatetimeArray +from pandas.core.tools import datetimes as tools +from pandas.core.tools.datetimes import start_caching_at + +PARSING_ERR_MSG = ( + r"You might want to try:\n" + r" - passing `format` if your strings have a consistent format;\n" + r" - passing `format=\'ISO8601\'` if your strings are all ISO8601 " + r"but not necessarily in exactly the same format;\n" + r" - passing `format=\'mixed\'`, and the format will be inferred " + r"for each element individually. You might want to use `dayfirst` " + r"alongside this." +) + +if PY314: + NOT_99 = ", not 99" + DAY_IS_OUT_OF_RANGE = ( + r"day \d{1,2} must be in range 1\.\.\d{1,2} for " + r"month \d{1,2} in year \d{4}" + ) +else: + NOT_99 = "" + DAY_IS_OUT_OF_RANGE = "day is out of range for month" + + +class TestTimeConversionFormats: + def test_to_datetime_readonly(self, writable): + # GH#34857 + arr = np.array([], dtype=object) + arr.setflags(write=writable) + result = to_datetime(arr) + expected = to_datetime([]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "format, expected", + [ + [ + "%d/%m/%Y", + [Timestamp("20000101"), Timestamp("20000201"), Timestamp("20000301")], + ], + [ + "%m/%d/%Y", + [Timestamp("20000101"), Timestamp("20000102"), Timestamp("20000103")], + ], + ], + ) + def test_to_datetime_format(self, cache, index_or_series, format, expected): + values = index_or_series(["1/1/2000", "1/2/2000", "1/3/2000"]) + result = to_datetime(values, format=format, cache=cache) + expected = index_or_series(expected) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "arg, expected, format", + [ + ["1/1/2000", "20000101", "%d/%m/%Y"], + ["1/1/2000", "20000101", "%m/%d/%Y"], + ["1/2/2000", "20000201", "%d/%m/%Y"], + ["1/2/2000", "20000102", "%m/%d/%Y"], + ["1/3/2000", "20000301", "%d/%m/%Y"], + ["1/3/2000", "20000103", "%m/%d/%Y"], + ], + ) + def test_to_datetime_format_scalar(self, cache, arg, expected, format): + result = to_datetime(arg, format=format, cache=cache) + expected = Timestamp(expected) + assert result == expected + + def test_to_datetime_format_YYYYMMDD(self, cache): + ser = Series([19801222, 19801222] + [19810105] * 5) + expected = Series([Timestamp(x) for x in ser.apply(str)]) + + result = to_datetime(ser, format="%Y%m%d", cache=cache) + tm.assert_series_equal(result, expected) + + result = to_datetime(ser.apply(str), format="%Y%m%d", cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_format_YYYYMMDD_with_nat(self, cache): + # Explicit cast to float to explicit cast when setting np.nan + ser = Series([19801222, 19801222] + [19810105] * 5, dtype="float") + # with NaT + expected = Series( + [Timestamp("19801222"), Timestamp("19801222")] + + [Timestamp("19810105")] * 5, + dtype="M8[us]", + ) + expected[2] = np.nan + ser[2] = np.nan + + result = to_datetime(ser, format="%Y%m%d", cache=cache) + tm.assert_series_equal(result, expected) + + # string with NaT + ser2 = ser.apply(str) + ser2[2] = "nat" + with pytest.raises( + ValueError, + match=( + 'unconverted data remains when parsing with format "%Y%m%d": ".0". ' + ), + ): + # https://github.com/pandas-dev/pandas/issues/50051 + to_datetime(ser2, format="%Y%m%d", cache=cache) + + def test_to_datetime_format_YYYYMM_with_nat(self, cache): + # https://github.com/pandas-dev/pandas/issues/50237 + # Explicit cast to float to explicit cast when setting np.nan + ser = Series([198012, 198012] + [198101] * 5, dtype="float") + expected = Series( + [Timestamp("19801201"), Timestamp("19801201")] + + [Timestamp("19810101")] * 5, + dtype="M8[us]", + ) + expected[2] = np.nan + ser[2] = np.nan + result = to_datetime(ser, format="%Y%m", cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_format_YYYYMMDD_oob_for_ns(self, cache): + # coercion + # GH 7930, GH 14487 + ser = Series([20121231, 20141231, 99991231]) + result = to_datetime(ser, format="%Y%m%d", errors="raise", cache=cache) + expected = Series( + np.array(["2012-12-31", "2014-12-31", "9999-12-31"], dtype="M8[s]"), + dtype="M8[us]", + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_format_YYYYMMDD_coercion(self, cache): + # coercion + # GH 7930 + ser = Series([20121231, 20141231, 999999999999999999999999999991231]) + result = to_datetime(ser, format="%Y%m%d", errors="coerce", cache=cache) + expected = Series(["20121231", "20141231", "NaT"], dtype="M8[us]") + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "input_s", + [ + # Null values with Strings + ["19801222", "20010112", None], + ["19801222", "20010112", np.nan], + ["19801222", "20010112", NaT], + ["19801222", "20010112", "NaT"], + # Null values with Integers + [19801222, 20010112, None], + [19801222, 20010112, np.nan], + [19801222, 20010112, NaT], + [19801222, 20010112, "NaT"], + ], + ) + def test_to_datetime_format_YYYYMMDD_with_none(self, input_s): + # GH 30011 + # format='%Y%m%d' + # with None + expected = Series([Timestamp("19801222"), Timestamp("20010112"), NaT]) + result = Series(to_datetime(input_s, format="%Y%m%d")) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "input_s, expected", + [ + # NaN before strings with invalid date values + [ + ["19801222", np.nan, "20010012", "10019999"], + [Timestamp("19801222"), np.nan, np.nan, np.nan], + ], + # NaN after strings with invalid date values + [ + ["19801222", "20010012", "10019999", np.nan], + [Timestamp("19801222"), np.nan, np.nan, np.nan], + ], + # NaN before integers with invalid date values + [ + [20190813, np.nan, 20010012, 20019999], + [Timestamp("20190813"), np.nan, np.nan, np.nan], + ], + # NaN after integers with invalid date values + [ + [20190813, 20010012, np.nan, 20019999], + [Timestamp("20190813"), np.nan, np.nan, np.nan], + ], + ], + ) + def test_to_datetime_format_YYYYMMDD_overflow(self, input_s, expected): + # GH 25512 + # format='%Y%m%d', errors='coerce' + input_s = Series(input_s) + result = to_datetime(input_s, format="%Y%m%d", errors="coerce") + expected = Series(expected) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "data, format, expected", + [ + ([pd.NA], "%Y%m%d%H%M%S", ["NaT"]), + ([pd.NA], None, ["NaT"]), + ( + [pd.NA, "20210202202020"], + "%Y%m%d%H%M%S", + ["NaT", "2021-02-02 20:20:20"], + ), + (["201010", pd.NA], "%y%m%d", ["2020-10-10", "NaT"]), + (["201010", pd.NA], "%d%m%y", ["2010-10-20", "NaT"]), + ([None, np.nan, pd.NA], None, ["NaT", "NaT", "NaT"]), + ([None, np.nan, pd.NA], "%Y%m%d", ["NaT", "NaT", "NaT"]), + ], + ) + def test_to_datetime_with_NA(self, data, format, expected): + # GH#42957 + result = to_datetime(data, format=format) + expected = DatetimeIndex(expected) + tm.assert_index_equal(result, expected) + + def test_to_datetime_with_NA_with_warning(self): + # GH#42957 + result = to_datetime(["201010", pd.NA]) + expected = DatetimeIndex(["2010-10-20", "NaT"]) + tm.assert_index_equal(result, expected) + + def test_to_datetime_format_integer(self, cache): + # GH 10178 + ser = Series([2000, 2001, 2002]) + expected = Series([Timestamp(x) for x in ser.apply(str)]) + + result = to_datetime(ser, format="%Y", cache=cache) + tm.assert_series_equal(result, expected) + + ser = Series([200001, 200105, 200206]) + expected = Series([Timestamp(x[:4] + "-" + x[4:]) for x in ser.apply(str)]) + + result = to_datetime(ser, format="%Y%m", cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_format_microsecond(self, cache): + month_abbr = calendar.month_abbr[4] + val = f"01-{month_abbr}-2011 00:00:01.978" + + format = "%d-%b-%Y %H:%M:%S.%f" + result = to_datetime(val, format=format, cache=cache) + exp = datetime.strptime(val, format) + assert result == exp + + @pytest.mark.parametrize( + "value, format, dt", + [ + ["01/10/2010 15:20", "%m/%d/%Y %H:%M", Timestamp("2010-01-10 15:20")], + ["01/10/2010 05:43", "%m/%d/%Y %I:%M", Timestamp("2010-01-10 05:43")], + [ + "01/10/2010 13:56:01", + "%m/%d/%Y %H:%M:%S", + Timestamp("2010-01-10 13:56:01"), + ], + # The 3 tests below are locale-dependent. + # They pass, except when the machine locale is zh_CN or it_IT . + pytest.param( + "01/10/2010 08:14 PM", + "%m/%d/%Y %I:%M %p", + Timestamp("2010-01-10 20:14"), + marks=pytest.mark.xfail( + locale.getlocale()[0] in ("zh_CN", "it_IT"), + reason="fail on a CI build with LC_ALL=zh_CN.utf8/it_IT.utf8", + strict=False, + ), + ), + pytest.param( + "01/10/2010 07:40 AM", + "%m/%d/%Y %I:%M %p", + Timestamp("2010-01-10 07:40"), + marks=pytest.mark.xfail( + locale.getlocale()[0] in ("zh_CN", "it_IT"), + reason="fail on a CI build with LC_ALL=zh_CN.utf8/it_IT.utf8", + strict=False, + ), + ), + pytest.param( + "01/10/2010 09:12:56 AM", + "%m/%d/%Y %I:%M:%S %p", + Timestamp("2010-01-10 09:12:56"), + marks=pytest.mark.xfail( + locale.getlocale()[0] in ("zh_CN", "it_IT"), + reason="fail on a CI build with LC_ALL=zh_CN.utf8/it_IT.utf8", + strict=False, + ), + ), + ], + ) + def test_to_datetime_format_time(self, cache, value, format, dt): + assert to_datetime(value, format=format, cache=cache) == dt + + @td.skip_if_not_us_locale + def test_to_datetime_with_non_exact(self, cache): + # GH 10834 + # 8904 + # exact kw + ser = Series( + ["19MAY11", "foobar19MAY11", "19MAY11:00:00:00", "19MAY11 00:00:00Z"] + ) + result = to_datetime(ser, format="%d%b%y", exact=False, cache=cache) + expected = to_datetime( + ser.str.extract(r"(\d+\w+\d+)", expand=False), format="%d%b%y", cache=cache + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "format, expected", + [ + ("%Y-%m-%d", Timestamp(2000, 1, 3)), + ("%Y-%d-%m", Timestamp(2000, 3, 1)), + ("%Y-%m-%d %H", Timestamp(2000, 1, 3, 12)), + ("%Y-%d-%m %H", Timestamp(2000, 3, 1, 12)), + ("%Y-%m-%d %H:%M", Timestamp(2000, 1, 3, 12, 34)), + ("%Y-%d-%m %H:%M", Timestamp(2000, 3, 1, 12, 34)), + ("%Y-%m-%d %H:%M:%S", Timestamp(2000, 1, 3, 12, 34, 56)), + ("%Y-%d-%m %H:%M:%S", Timestamp(2000, 3, 1, 12, 34, 56)), + ("%Y-%m-%d %H:%M:%S.%f", Timestamp(2000, 1, 3, 12, 34, 56, 123456)), + ("%Y-%d-%m %H:%M:%S.%f", Timestamp(2000, 3, 1, 12, 34, 56, 123456)), + ( + "%Y-%m-%d %H:%M:%S.%f%z", + Timestamp(2000, 1, 3, 12, 34, 56, 123456, tz="UTC+01:00"), + ), + ( + "%Y-%d-%m %H:%M:%S.%f%z", + Timestamp(2000, 3, 1, 12, 34, 56, 123456, tz="UTC+01:00"), + ), + ], + ) + def test_non_exact_doesnt_parse_whole_string(self, cache, format, expected): + # https://github.com/pandas-dev/pandas/issues/50412 + # the formats alternate between ISO8601 and non-ISO8601 to check both paths + result = to_datetime( + "2000-01-03 12:34:56.123456+01:00", format=format, exact=False + ) + assert result == expected + + @pytest.mark.parametrize( + "arg", + [ + "2012-01-01 09:00:00.000000001", + "2012-01-01 09:00:00.000001", + "2012-01-01 09:00:00.001", + "2012-01-01 09:00:00.001000", + "2012-01-01 09:00:00.001000000", + ], + ) + def test_parse_nanoseconds_with_formula(self, cache, arg): + # GH8989 + # truncating the nanoseconds when a format was provided + expected = to_datetime(arg, cache=cache) + result = to_datetime(arg, format="%Y-%m-%d %H:%M:%S.%f", cache=cache) + assert result == expected + + @pytest.mark.parametrize( + "value,fmt,expected", + [ + ["2009324", "%Y%W%w", "2009-08-13"], + ["2013020", "%Y%U%w", "2013-01-13"], + ], + ) + def test_to_datetime_format_weeks(self, value, fmt, expected, cache): + assert to_datetime(value, format=fmt, cache=cache) == Timestamp(expected) + + @pytest.mark.parametrize( + "fmt,dates,expected_dates", + [ + [ + "%Y-%m-%d %H:%M:%S %Z", + ["2010-01-01 12:00:00 UTC"] * 2, + [Timestamp("2010-01-01 12:00:00", tz="UTC")] * 2, + ], + [ + "%Y-%m-%d %H:%M:%S%z", + ["2010-01-01 12:00:00+0100"] * 2, + [ + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=60)) + ) + ] + * 2, + ], + [ + "%Y-%m-%d %H:%M:%S %z", + ["2010-01-01 12:00:00 +0100"] * 2, + [ + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=60)) + ) + ] + * 2, + ], + [ + "%Y-%m-%d %H:%M:%S %z", + ["2010-01-01 12:00:00 Z", "2010-01-01 12:00:00 Z"], + [ + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=0)) + ), + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=0)) + ), + ], + ], + ], + ) + def test_to_datetime_parse_tzname_or_tzoffset(self, fmt, dates, expected_dates): + # GH 13486 + result = to_datetime(dates, format=fmt) + expected = Index(expected_dates) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "fmt,dates,expected_dates", + [ + [ + "%Y-%m-%d %H:%M:%S %Z", + [ + "2010-01-01 12:00:00 UTC", + "2010-01-01 12:00:00 GMT", + "2010-01-01 12:00:00 US/Pacific", + ], + [ + Timestamp("2010-01-01 12:00:00", tz="UTC"), + Timestamp("2010-01-01 12:00:00", tz="GMT"), + Timestamp("2010-01-01 12:00:00", tz="US/Pacific"), + ], + ], + [ + "%Y-%m-%d %H:%M:%S %z", + ["2010-01-01 12:00:00 +0100", "2010-01-01 12:00:00 -0100"], + [ + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=60)) + ), + Timestamp( + "2010-01-01 12:00:00", tzinfo=timezone(timedelta(minutes=-60)) + ), + ], + ], + ], + ) + def test_to_datetime_parse_tzname_or_tzoffset_utc_false_removed( + self, fmt, dates, expected_dates + ): + # GH#13486, GH#50887, GH#57275 + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + with pytest.raises(ValueError, match=msg): + to_datetime(dates, format=fmt) + + def test_to_datetime_parse_tzname_or_tzoffset_different_tz_to_utc(self): + # GH 32792 + dates = [ + "2010-01-01 12:00:00 +0100", + "2010-01-01 12:00:00 -0100", + "2010-01-01 12:00:00 +0300", + "2010-01-01 12:00:00 +0400", + ] + expected_dates = [ + "2010-01-01 11:00:00+00:00", + "2010-01-01 13:00:00+00:00", + "2010-01-01 09:00:00+00:00", + "2010-01-01 08:00:00+00:00", + ] + fmt = "%Y-%m-%d %H:%M:%S %z" + + result = to_datetime(dates, format=fmt, utc=True) + expected = DatetimeIndex(expected_dates) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "offset", ["+0", "-1foo", "UTCbar", ":10", "+01:000:01", ""] + ) + def test_to_datetime_parse_timezone_malformed(self, offset): + fmt = "%Y-%m-%d %H:%M:%S %z" + date = "2010-01-01 12:00:00 " + offset + + msg = "|".join( + [ + r'^time data ".*" doesn\'t match format ".*". ' f"{PARSING_ERR_MSG}$", + r'^unconverted data remains when parsing with format ".*": ".*". ' + f"{PARSING_ERR_MSG}$", + ] + ) + with pytest.raises(ValueError, match=msg): + to_datetime([date], format=fmt) + + def test_to_datetime_parse_timezone_keeps_name(self): + # GH 21697 + fmt = "%Y-%m-%d %H:%M:%S %z" + arg = Index(["2010-01-01 12:00:00 Z"], name="foo") + result = to_datetime(arg, format=fmt) + expected = DatetimeIndex(["2010-01-01 12:00:00"], tz="UTC", name="foo") + tm.assert_index_equal(result, expected) + + +class TestToDatetime: + def test_to_datetime_mixed_string_resos(self): + # GH#62801 + vals = [ + "2016-01-01 01:02:03", + "2016-01-01 01:02:03.001", + "2016-01-01 01:02:03.001002", + "2016-01-01 01:02:03.001002003", + ] + expected = DatetimeIndex([Timestamp(x).as_unit("ns") for x in vals]) + + result1 = DatetimeIndex(vals) + tm.assert_index_equal(result1, expected) + + result2 = to_datetime(vals, format="ISO8601") + tm.assert_index_equal(result2, expected) + + result3 = to_datetime(vals, format="mixed") + tm.assert_index_equal(result3, expected) + + def test_to_datetime_none(self): + # GH#23055 + assert to_datetime(None) is NaT + + @pytest.mark.filterwarnings("ignore:Could not infer format") + def test_to_datetime_overflow(self): + # we should get an OutOfBoundsDatetime, NOT OverflowError + # TODO: Timestamp raises ValueError("could not convert string to Timestamp") + # can we make these more consistent? + arg = "08335394550" + msg = 'Parsing "08335394550" to datetime overflows' + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(arg) + + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime([arg]) + + res = to_datetime(arg, errors="coerce") + assert res is NaT + res = to_datetime([arg], errors="coerce") + exp = Index([NaT], dtype="M8[s]") + tm.assert_index_equal(res, exp) + + def test_to_datetime_mixed_datetime_and_string(self): + # GH#47018 adapted old doctest with new behavior + d1 = datetime(2020, 1, 1, 17, tzinfo=timezone(-timedelta(hours=1))) + d2 = datetime(2020, 1, 1, 18, tzinfo=timezone(-timedelta(hours=1))) + res = to_datetime(["2020-01-01 17:00 -0100", d2]) + expected = to_datetime([d1, d2]).tz_convert(timezone(timedelta(minutes=-60))) + tm.assert_index_equal(res, expected) + + def test_to_datetime_mixed_string_and_numeric(self): + # GH#55780 np.array(vals) would incorrectly cast the number to str + vals = ["2016-01-01", 0] + expected = DatetimeIndex([Timestamp(x) for x in vals]) + result = to_datetime(vals, format="mixed") + result2 = to_datetime(vals[::-1], format="mixed")[::-1] + result3 = DatetimeIndex(vals) + result4 = DatetimeIndex(vals[::-1])[::-1] + + tm.assert_index_equal(result, expected) + tm.assert_index_equal(result2, expected) + tm.assert_index_equal(result3, expected) + tm.assert_index_equal(result4, expected) + + @pytest.mark.parametrize( + "format", ["%Y-%m-%d", "%Y-%d-%m"], ids=["ISO8601", "non-ISO8601"] + ) + def test_to_datetime_mixed_date_and_string(self, format): + # https://github.com/pandas-dev/pandas/issues/50108 + d1 = date(2020, 1, 2) + res = to_datetime(["2020-01-01", d1], format=format) + expected = DatetimeIndex(["2020-01-01", "2020-01-02"], dtype="M8[us]") + tm.assert_index_equal(res, expected) + + @pytest.mark.parametrize( + "fmt", + ["%Y-%d-%m %H:%M:%S%z", "%Y-%m-%d %H:%M:%S%z"], + ids=["non-ISO8601 format", "ISO8601 format"], + ) + @pytest.mark.parametrize( + "utc, args, expected", + [ + pytest.param( + True, + ["2000-01-01 01:00:00-08:00", "2000-01-01 02:00:00-08:00"], + DatetimeIndex( + ["2000-01-01 09:00:00+00:00", "2000-01-01 10:00:00+00:00"], + dtype="datetime64[us, UTC]", + ), + id="all tz-aware, with utc", + ), + pytest.param( + False, + ["2000-01-01 01:00:00+00:00", "2000-01-01 02:00:00+00:00"], + DatetimeIndex( + ["2000-01-01 01:00:00+00:00", "2000-01-01 02:00:00+00:00"], + ).as_unit("us"), + id="all tz-aware, without utc", + ), + pytest.param( + True, + ["2000-01-01 01:00:00-08:00", "2000-01-01 02:00:00+00:00"], + DatetimeIndex( + ["2000-01-01 09:00:00+00:00", "2000-01-01 02:00:00+00:00"], + dtype="datetime64[us, UTC]", + ), + id="all tz-aware, mixed offsets, with utc", + ), + pytest.param( + True, + ["2000-01-01 01:00:00", "2000-01-01 02:00:00+00:00"], + DatetimeIndex( + ["2000-01-01 01:00:00+00:00", "2000-01-01 02:00:00+00:00"], + dtype="datetime64[us, UTC]", + ), + id="tz-aware string, naive pydatetime, with utc", + ), + ], + ) + @pytest.mark.parametrize( + "constructor", + [Timestamp, lambda x: Timestamp(x).to_pydatetime()], + ) + def test_to_datetime_mixed_datetime_and_string_with_format( + self, fmt, utc, args, expected, constructor + ): + # https://github.com/pandas-dev/pandas/issues/49298 + # https://github.com/pandas-dev/pandas/issues/50254 + # note: ISO8601 formats go down a fastpath, so we need to check both + # an ISO8601 format and a non-ISO8601 one + ts1 = constructor(args[0]) + ts2 = args[1] + result = to_datetime([ts1, ts2], format=fmt, utc=utc) + if constructor is Timestamp: + expected = expected.as_unit("us") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "fmt", + ["%Y-%d-%m %H:%M:%S%z", "%Y-%m-%d %H:%M:%S%z"], + ids=["non-ISO8601 format", "ISO8601 format"], + ) + @pytest.mark.parametrize( + "constructor", + [Timestamp, lambda x: Timestamp(x).to_pydatetime()], + ) + def test_to_datetime_mixed_dt_and_str_with_format_mixed_offsets_utc_false_removed( + self, fmt, constructor + ): + # https://github.com/pandas-dev/pandas/issues/49298 + # https://github.com/pandas-dev/pandas/issues/50254 + # GH#57275 + # note: ISO8601 formats go down a fastpath, so we need to check both + # an ISO8601 format and a non-ISO8601 one + args = ["2000-01-01 01:00:00", "2000-01-01 02:00:00+00:00"] + ts1 = constructor(args[0]) + ts2 = args[1] + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + + with pytest.raises(ValueError, match=msg): + to_datetime([ts1, ts2], format=fmt, utc=False) + + @pytest.mark.parametrize( + "fmt, expected", + [ + pytest.param( + "%Y-%m-%d %H:%M:%S%z", + [ + Timestamp("2000-01-01 09:00:00+0100", tz="UTC+01:00"), + Timestamp("2000-01-02 02:00:00+0200", tz="UTC+02:00"), + NaT, + ], + id="ISO8601, non-UTC", + ), + pytest.param( + "%Y-%d-%m %H:%M:%S%z", + [ + Timestamp("2000-01-01 09:00:00+0100", tz="UTC+01:00"), + Timestamp("2000-02-01 02:00:00+0200", tz="UTC+02:00"), + NaT, + ], + id="non-ISO8601, non-UTC", + ), + ], + ) + def test_to_datetime_mixed_offsets_with_none_tz_utc_false_removed( + self, fmt, expected + ): + # https://github.com/pandas-dev/pandas/issues/50071 + # GH#57275 + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + + with pytest.raises(ValueError, match=msg): + to_datetime( + ["2000-01-01 09:00:00+01:00", "2000-01-02 02:00:00+02:00", None], + format=fmt, + utc=False, + ) + + @pytest.mark.parametrize( + "fmt, expected", + [ + pytest.param( + "%Y-%m-%d %H:%M:%S%z", + DatetimeIndex( + ["2000-01-01 08:00:00+00:00", "2000-01-02 00:00:00+00:00", "NaT"], + dtype="datetime64[us, UTC]", + ), + id="ISO8601, UTC", + ), + pytest.param( + "%Y-%d-%m %H:%M:%S%z", + DatetimeIndex( + ["2000-01-01 08:00:00+00:00", "2000-02-01 00:00:00+00:00", "NaT"], + dtype="datetime64[us, UTC]", + ), + id="non-ISO8601, UTC", + ), + ], + ) + def test_to_datetime_mixed_offsets_with_none(self, fmt, expected): + # https://github.com/pandas-dev/pandas/issues/50071 + result = to_datetime( + ["2000-01-01 09:00:00+01:00", "2000-01-02 02:00:00+02:00", None], + format=fmt, + utc=True, + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "fmt", + ["%Y-%d-%m %H:%M:%S%z", "%Y-%m-%d %H:%M:%S%z"], + ids=["non-ISO8601 format", "ISO8601 format"], + ) + @pytest.mark.parametrize( + "args", + [ + pytest.param( + ["2000-01-01 01:00:00-08:00", "2000-01-01 02:00:00-07:00"], + id="all tz-aware, mixed timezones, without utc", + ), + ], + ) + @pytest.mark.parametrize( + "constructor", + [Timestamp, lambda x: Timestamp(x).to_pydatetime()], + ) + def test_to_datetime_mixed_datetime_and_string_with_format_raises( + self, fmt, args, constructor + ): + # https://github.com/pandas-dev/pandas/issues/49298 + # note: ISO8601 formats go down a fastpath, so we need to check both + # an ISO8601 format and a non-ISO8601 one + ts1 = constructor(args[0]) + ts2 = constructor(args[1]) + with pytest.raises( + ValueError, match="cannot be converted to datetime64 unless utc=True" + ): + to_datetime([ts1, ts2], format=fmt, utc=False) + + def test_to_datetime_np_str(self): + # GH#32264 + # GH#48969 + value = np.str_("2019-02-04 10:18:46.297000+0000") + + ser = Series([value]) + + exp = Timestamp("2019-02-04 10:18:46.297000", tz="UTC") + + assert to_datetime(value) == exp + assert to_datetime(ser.iloc[0]) == exp + + res = to_datetime([value]) + expected = Index([exp]) + tm.assert_index_equal(res, expected) + + res = to_datetime(ser) + expected = Series(expected) + tm.assert_series_equal(res, expected) + + @pytest.mark.parametrize( + "s, _format, dt", + [ + ["2015-1-1", "%G-%V-%u", datetime(2014, 12, 29, 0, 0)], + ["2015-1-4", "%G-%V-%u", datetime(2015, 1, 1, 0, 0)], + ["2015-1-7", "%G-%V-%u", datetime(2015, 1, 4, 0, 0)], + ["2024-52-1", "%G-%V-%u", datetime(2024, 12, 23, 0, 0)], + ["2024-52-7", "%G-%V-%u", datetime(2024, 12, 29, 0, 0)], + ["2025-1-1", "%G-%V-%u", datetime(2024, 12, 30, 0, 0)], + ["2020-53-1", "%G-%V-%u", datetime(2020, 12, 28, 0, 0)], + ], + ) + def test_to_datetime_iso_week_year_format(self, s, _format, dt): + # See GH#16607 + assert to_datetime(s, format=_format) == dt + + @pytest.mark.parametrize( + "msg, s, _format", + [ + [ + "Week 53 does not exist in ISO year 2024", + "2024 53 1", + "%G %V %u", + ], + [ + "Week 53 does not exist in ISO year 2023", + "2023 53 1", + "%G %V %u", + ], + ], + ) + def test_invalid_iso_week_53(self, msg, s, _format): + # See GH#60885 + with pytest.raises(ValueError, match=msg): + to_datetime(s, format=_format) + + @pytest.mark.parametrize( + "msg, s, _format", + [ + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 50", + "%Y %V", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51", + "%G %V", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 Monday", + "%G %A", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 Mon", + "%G %a", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 6", + "%G %w", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 6", + "%G %u", + ], + [ + "ISO year directive '%G' must be used with the ISO week directive " + "'%V' and a weekday directive '%A', '%a', '%w', or '%u'.", + "2051", + "%G", + ], + [ + "Day of the year directive '%j' is not compatible with ISO year " + "directive '%G'. Use '%Y' instead.", + "1999 51 6 256", + "%G %V %u %j", + ], + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 51 Sunday", + "%Y %V %A", + ], + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 51 Sun", + "%Y %V %a", + ], + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 51 1", + "%Y %V %w", + ], + [ + "ISO week directive '%V' is incompatible with the year directive " + "'%Y'. Use the ISO year '%G' instead.", + "1999 51 1", + "%Y %V %u", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "20", + "%V", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51 Sunday", + "%V %A", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51 Sun", + "%V %a", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51 1", + "%V %w", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "1999 51 1", + "%V %u", + ], + [ + "Day of the year directive '%j' is not compatible with ISO year " + "directive '%G'. Use '%Y' instead.", + "1999 50", + "%G %j", + ], + [ + "ISO week directive '%V' must be used with the ISO year directive " + "'%G' and a weekday directive '%A', '%a', '%w', or '%u'.", + "20 Monday", + "%V %A", + ], + ], + ) + @pytest.mark.parametrize("errors", ["raise", "coerce"]) + def test_error_iso_week_year(self, msg, s, _format, errors): + # See GH#16607, GH#50308 + # This test checks for errors thrown when giving the wrong format + # However, as discussed on PR#25541, overriding the locale + # causes a different error to be thrown due to the format being + # locale specific, but the test data is in english. + # Therefore, the tests only run when locale is not overwritten, + # as a sort of solution to this problem. + if locale.getlocale() != ("zh_CN", "UTF-8") and locale.getlocale() != ( + "it_IT", + "UTF-8", + ): + with pytest.raises(ValueError, match=msg): + to_datetime(s, format=_format, errors=errors) + + @pytest.mark.parametrize("tz", [None, "US/Central"]) + def test_to_datetime_dtarr(self, tz): + # DatetimeArray + dti = date_range("1965-04-03", periods=19, freq="2W", tz=tz) + arr = dti._data + + result = to_datetime(arr) + assert result is arr + + # Doesn't work on Windows since tzpath not set correctly + @td.skip_if_windows + @pytest.mark.parametrize("utc", [True, False]) + @pytest.mark.parametrize("tz", [None, "US/Central"]) + def test_to_datetime_arrow(self, tz, utc, index_or_series): + pa = pytest.importorskip("pyarrow") + + dti = date_range("1965-04-03", periods=19, freq="2W", tz=tz) + dti = index_or_series(dti) + + dti_arrow = dti.astype(pd.ArrowDtype(pa.timestamp(unit="ns", tz=tz))) + + result = to_datetime(dti_arrow, utc=utc) + expected = to_datetime(dti, utc=utc).astype( + pd.ArrowDtype(pa.timestamp(unit="ns", tz=tz if not utc else "UTC")) + ) + if not utc and index_or_series is not Series: + # Doesn't hold for utc=True, since that will astype + # to_datetime also returns a new object for series + assert result is dti_arrow + if index_or_series is Series: + tm.assert_series_equal(result, expected) + else: + tm.assert_index_equal(result, expected) + + def test_to_datetime_pydatetime(self): + actual = to_datetime(datetime(2008, 1, 15)) + assert actual == datetime(2008, 1, 15) + + def test_to_datetime_YYYYMMDD(self): + actual = to_datetime("20080115") + assert actual == datetime(2008, 1, 15) + + @td.skip_if_windows # `tm.set_timezone` does not work in windows + @pytest.mark.skipif(WASM, reason="tzset is not available on WASM") + def test_to_datetime_now(self): + # See GH#18666 + with tm.set_timezone("US/Eastern"): + # GH#18705 + now = Timestamp("now") + pdnow = to_datetime("now") + pdnow2 = to_datetime(["now"])[0] + + # These should all be equal with infinite perf; this gives + # a generous margin of 10 seconds + assert abs(pdnow._value - now._value) < 1e10 + assert abs(pdnow2._value - now._value) < 1e10 + + assert pdnow.tzinfo is None + assert pdnow2.tzinfo is None + + @td.skip_if_windows # `tm.set_timezone` does not work on Windows + @pytest.mark.skipif(WASM, reason="tzset is not available on WASM") + @pytest.mark.parametrize("tz", ["Pacific/Auckland", "US/Samoa"]) + def test_to_datetime_today(self, tz): + # See GH#18666 + # Test with one timezone far ahead of UTC and another far behind, so + # one of these will _almost_ always be in a different day from UTC. + # Unfortunately this test between 12 and 1 AM Samoa time + # this both of these timezones _and_ UTC will all be in the same day, + # so this test will not detect the regression introduced in #18666. + with tm.set_timezone(tz): + nptoday = np.datetime64("today").astype("datetime64[us]").astype(np.int64) + pdtoday = to_datetime("today") + pdtoday2 = to_datetime(["today"])[0] + + tstoday = Timestamp("today") + tstoday2 = Timestamp.today() + + # These should all be equal with infinite perf; this gives + # a generous margin of 10 seconds + assert abs(pdtoday.normalize()._value - nptoday) < 1e10 + assert abs(pdtoday2.normalize()._value - nptoday) < 1e10 + assert abs(pdtoday._value - tstoday._value) < 1e10 + assert abs(pdtoday._value - tstoday2._value) < 1e10 + + assert pdtoday.tzinfo is None + assert pdtoday2.tzinfo is None + + @pytest.mark.parametrize("arg", ["now", "today"]) + def test_to_datetime_today_now_unicode_bytes(self, arg): + to_datetime([arg]) + + @pytest.mark.filterwarnings( + "ignore:Timestamp.utcnow is deprecated:DeprecationWarning" + ) + @pytest.mark.skipif(WASM, reason="tzset is not available on WASM") + @pytest.mark.parametrize( + "format, expected_ds", + [ + ("%Y-%m-%d %H:%M:%S%z", "2020-01-03"), + ("%Y-%d-%m %H:%M:%S%z", "2020-03-01"), + (None, "2020-01-03"), + ], + ) + @pytest.mark.parametrize( + "string, attribute", + [ + ("now", "utcnow"), + ("today", "today"), + ], + ) + def test_to_datetime_now_with_format(self, format, expected_ds, string, attribute): + # https://github.com/pandas-dev/pandas/issues/50359 + result = to_datetime(["2020-01-03 00:00:00Z", string], format=format, utc=True) + expected = DatetimeIndex( + [expected_ds, getattr(Timestamp, attribute)()], dtype="datetime64[s, UTC]" + ) + assert (expected - result).max().total_seconds() < 1 + + @pytest.mark.parametrize( + "dt", [np.datetime64("2000-01-01"), np.datetime64("2000-01-02")] + ) + def test_to_datetime_dt64s(self, cache, dt): + assert to_datetime(dt, cache=cache) == Timestamp(dt) + + @pytest.mark.parametrize( + "arg, format", + [ + ("2001-01-01", "%Y-%m-%d"), + ("01-01-2001", "%d-%m-%Y"), + ], + ) + def test_to_datetime_dt64s_and_str(self, arg, format): + # https://github.com/pandas-dev/pandas/issues/50036 + result = to_datetime([arg, np.datetime64("2020-01-01")], format=format) + expected = DatetimeIndex(["2001-01-01", "2020-01-01"]) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "dt", [np.datetime64("1000-01-01"), np.datetime64("5000-01-02")] + ) + @pytest.mark.parametrize("errors", ["raise", "coerce"]) + def test_to_datetime_dt64s_out_of_ns_bounds(self, cache, dt, errors): + # GH#50369 We cast to the nearest supported reso, i.e. "s" + ts = to_datetime(dt, errors=errors, cache=cache) + assert isinstance(ts, Timestamp) + assert ts.unit == "s" + assert ts.asm8 == dt + + ts = Timestamp(dt) + assert ts.unit == "s" + assert ts.asm8 == dt + + def test_to_datetime_dt64d_out_of_bounds(self, cache): + dt64 = np.datetime64(np.iinfo(np.int64).max, "D") + + msg = "Out of bounds second timestamp: 25252734927768524-07-27" + with pytest.raises(OutOfBoundsDatetime, match=msg): + Timestamp(dt64) + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(dt64, errors="raise", cache=cache) + + assert to_datetime(dt64, errors="coerce", cache=cache) is NaT + + @pytest.mark.parametrize("unit", ["s", "D"]) + def test_to_datetime_array_of_dt64s(self, cache, unit): + # https://github.com/pandas-dev/pandas/issues/31491 + # Need at least 50 to ensure cache is used. + dts = [ + np.datetime64("2000-01-01", unit), + np.datetime64("2000-01-02", unit), + ] * 30 + # Assuming all datetimes are in bounds, to_datetime() returns + # an array that is equal to Timestamp() parsing + result = to_datetime(dts, cache=cache) + expected = DatetimeIndex([Timestamp(x).asm8 for x in dts], dtype="M8[s]") + + tm.assert_index_equal(result, expected) + + # A list of datetimes where the last one is out of bounds + dts_with_oob = [*dts, np.datetime64("9999-01-01")] + + # As of GH#51978 we do not raise in this case + to_datetime(dts_with_oob, errors="raise") + + result = to_datetime(dts_with_oob, errors="coerce", cache=cache) + expected = DatetimeIndex(np.array(dts_with_oob, dtype="M8[s]")) + tm.assert_index_equal(result, expected) + + def test_to_datetime_tz(self, cache): + # xref 8260 + # uniform returns a DatetimeIndex + arr = [ + Timestamp("2013-01-01 13:00:00-0800", tz="US/Pacific"), + Timestamp("2013-01-02 14:00:00-0800", tz="US/Pacific"), + ] + result = to_datetime(arr, cache=cache) + expected = DatetimeIndex( + ["2013-01-01 13:00:00", "2013-01-02 14:00:00"], tz="US/Pacific" + ) + tm.assert_index_equal(result, expected) + + def test_to_datetime_tz_mixed(self, cache): + # mixed tzs will raise if errors='raise' + # https://github.com/pandas-dev/pandas/issues/50585 + arr = [ + Timestamp("2013-01-01 13:00:00", tz="US/Pacific"), + Timestamp("2013-01-02 14:00:00", tz="US/Eastern"), + ] + msg = ( + "Tz-aware datetime.datetime cannot be " + "converted to datetime64 unless utc=True" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(arr, cache=cache) + + result = to_datetime(arr, cache=cache, errors="coerce") + expected = DatetimeIndex( + ["2013-01-01 13:00:00-08:00", "NaT"], dtype="datetime64[us, US/Pacific]" + ) + tm.assert_index_equal(result, expected) + + def test_to_datetime_different_offsets_removed(self, cache): + # inspired by asv timeseries.ToDatetimeNONISO8601 benchmark + # see GH-26097 for more + # GH#57275 + ts_string_1 = "March 1, 2018 12:00:00+0400" + ts_string_2 = "March 1, 2018 12:00:00+0500" + arr = [ts_string_1] * 5 + [ts_string_2] * 5 + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + with pytest.raises(ValueError, match=msg): + to_datetime(arr, cache=cache) + + def test_to_datetime_tz_pytz(self, cache): + # see gh-8260 + pytz = pytest.importorskip("pytz") + us_eastern = pytz.timezone("US/Eastern") + arr = np.array( + [ + us_eastern.localize( + datetime(year=2000, month=1, day=1, hour=3, minute=0) + ), + us_eastern.localize( + datetime(year=2000, month=6, day=1, hour=3, minute=0) + ), + ], + dtype=object, + ) + result = to_datetime(arr, utc=True, cache=cache) + expected = DatetimeIndex( + ["2000-01-01 08:00:00+00:00", "2000-06-01 07:00:00+00:00"], + dtype="datetime64[us, UTC]", + freq=None, + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "init_constructor, end_constructor", + [ + (Index, DatetimeIndex), + (list, DatetimeIndex), + (np.array, DatetimeIndex), + (Series, Series), + ], + ) + def test_to_datetime_utc_true(self, cache, init_constructor, end_constructor): + # See gh-11934 & gh-6415 + data = ["20100102 121314", "20100102 121315"] + expected_data = [ + Timestamp("2010-01-02 12:13:14", tz="utc"), + Timestamp("2010-01-02 12:13:15", tz="utc"), + ] + + result = to_datetime( + init_constructor(data), format="%Y%m%d %H%M%S", utc=True, cache=cache + ) + expected = end_constructor(expected_data) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "scalar, expected", + [ + ["20100102 121314", Timestamp("2010-01-02 12:13:14", tz="utc")], + ["20100102 121315", Timestamp("2010-01-02 12:13:15", tz="utc")], + ], + ) + def test_to_datetime_utc_true_scalar(self, cache, scalar, expected): + # Test scalar case as well + result = to_datetime(scalar, format="%Y%m%d %H%M%S", utc=True, cache=cache) + assert result == expected + + def test_to_datetime_utc_true_with_series_single_value(self, cache): + # GH 15760 UTC=True with Series + ts = 1.5e18 + result = to_datetime(Series([ts]), utc=True, cache=cache) + expected = Series([Timestamp(ts, tz="utc")]) + tm.assert_series_equal(result, expected) + + def test_to_datetime_utc_true_with_series_tzaware_string(self, cache): + ts = "2013-01-01 00:00:00-01:00" + expected_ts = "2013-01-01 01:00:00" + data = Series([ts] * 3) + result = to_datetime(data, utc=True, cache=cache) + expected = Series([Timestamp(expected_ts, tz="utc")] * 3) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "date, dtype", + [ + ("2013-01-01 01:00:00", "datetime64[ns]"), + ("2013-01-01 01:00:00", "datetime64[ns, UTC]"), + ], + ) + def test_to_datetime_utc_true_with_series_datetime_ns(self, cache, date, dtype): + expected = Series( + [Timestamp("2013-01-01 01:00:00", tz="UTC")], dtype="M8[ns, UTC]" + ) + result = to_datetime(Series([date], dtype=dtype), utc=True, cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_tz_psycopg2(self, request, cache): + # xref 8260 + psycopg2_tz = pytest.importorskip("psycopg2.tz") + + # misc cases + tz1 = psycopg2_tz.FixedOffsetTimezone(offset=-300, name=None) + tz2 = psycopg2_tz.FixedOffsetTimezone(offset=-240, name=None) + arr = np.array( + [ + datetime(2000, 1, 1, 3, 0, tzinfo=tz1), + datetime(2000, 6, 1, 3, 0, tzinfo=tz2), + ], + dtype=object, + ) + + result = to_datetime(arr, errors="coerce", utc=True, cache=cache) + expected = DatetimeIndex( + ["2000-01-01 08:00:00+00:00", "2000-06-01 07:00:00+00:00"], + dtype="datetime64[us, UTC]", + freq=None, + ) + tm.assert_index_equal(result, expected) + + # dtype coercion + i = DatetimeIndex( + ["2000-01-01 08:00:00"], + tz=psycopg2_tz.FixedOffsetTimezone(offset=-300, name=None), + ).as_unit("us") + assert not is_datetime64_ns_dtype(i) + + # tz coercion + result = to_datetime(i, errors="coerce", cache=cache) + tm.assert_index_equal(result, i) + + result = to_datetime(i, errors="coerce", utc=True, cache=cache) + expected = DatetimeIndex(["2000-01-01 13:00:00"], dtype="datetime64[us, UTC]") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("arg", [True, False]) + def test_datetime_bool(self, cache, arg): + # GH13176 + msg = r"dtype bool cannot be converted to datetime64\[ns\]" + with pytest.raises(TypeError, match=msg): + to_datetime(arg) + assert to_datetime(arg, errors="coerce", cache=cache) is NaT + + def test_datetime_bool_arrays_mixed(self, cache): + msg = f"{type(cache)} is not convertible to datetime" + with pytest.raises(TypeError, match=msg): + to_datetime([False, datetime.today()], cache=cache) + with pytest.raises( + ValueError, + match=( + r'^time data "True" doesn\'t match format "%Y%m%d". ' + f"{PARSING_ERR_MSG}$" + ), + ): + to_datetime(["20130101", True], cache=cache) + tm.assert_index_equal( + to_datetime([0, False, NaT, 0.0], errors="coerce", cache=cache), + DatetimeIndex( + [to_datetime(0, cache=cache), NaT, NaT, to_datetime(0, cache=cache)] + ), + ) + + @pytest.mark.parametrize("arg", [bool, to_datetime]) + def test_datetime_invalid_datatype(self, arg): + # GH13176 + msg = "is not convertible to datetime" + with pytest.raises(TypeError, match=msg): + to_datetime(arg) + + @pytest.mark.parametrize("errors", ["coerce", "raise"]) + def test_invalid_format_raises(self, errors): + # https://github.com/pandas-dev/pandas/issues/50255 + with pytest.raises( + ValueError, match="':' is a bad directive in format 'H%:M%:S%" + ): + to_datetime(["00:00:00"], format="H%:M%:S%", errors=errors) + + @pytest.mark.parametrize("value", ["a", "00:01:99"]) + @pytest.mark.parametrize("format", [None, "%H:%M:%S"]) + def test_datetime_invalid_scalar(self, value, format): + # GH24763 + res = to_datetime(value, errors="coerce", format=format) + assert res is NaT + + msg = "|".join( + [ + r'^time data "a" doesn\'t match format "%H:%M:%S". ' + f"{PARSING_ERR_MSG}$", + r'^Given date string "a" not likely a datetime$', + r'^unconverted data remains when parsing with format "%H:%M:%S": "9". ' + f"{PARSING_ERR_MSG}$", + rf"^second must be in 0..59{NOT_99}: 00:01:99$", + ] + ) + with pytest.raises(ValueError, match=msg): + to_datetime(value, errors="raise", format=format) + + @pytest.mark.parametrize("value", ["3000/12/11 00:00:00"]) + @pytest.mark.parametrize("format", [None, "%H:%M:%S"]) + def test_datetime_outofbounds_scalar(self, value, format): + # GH24763 + res = to_datetime(value, errors="coerce", format=format) + if format is None: + assert isinstance(res, Timestamp) + assert res == Timestamp(value) + else: + assert res is NaT + + if format is not None: + msg = r'^time data ".*" doesn\'t match format ".*"' + with pytest.raises(ValueError, match=msg): + to_datetime(value, errors="raise", format=format) + else: + res = to_datetime(value, errors="raise", format=format) + assert isinstance(res, Timestamp) + assert res == Timestamp(value) + + @pytest.mark.parametrize( + ("values"), [(["a"]), (["00:01:99"]), (["a", "b", "99:00:00"])] + ) + @pytest.mark.parametrize("format", [(None), ("%H:%M:%S")]) + def test_datetime_invalid_index(self, values, format): + # GH24763 + # Not great to have logic in tests, but this one's hard to + # parametrise over + if format is None and len(values) > 1: + warn = UserWarning + else: + warn = None + + with tm.assert_produces_warning( + warn, match="Could not infer format", raise_on_extra_warnings=False + ): + res = to_datetime(values, errors="coerce", format=format) + tm.assert_index_equal(res, DatetimeIndex([NaT] * len(values))) + + msg = "|".join( + [ + r'^Given date string "a" not likely a datetime$', + r'^time data "a" doesn\'t match format "%H:%M:%S". ' + f"{PARSING_ERR_MSG}$", + r'^unconverted data remains when parsing with format "%H:%M:%S": "9". ' + f"{PARSING_ERR_MSG}$", + rf"^second must be in 0..59{NOT_99}: 00:01:99$", + ] + ) + with pytest.raises(ValueError, match=msg): + with tm.assert_produces_warning( + warn, match="Could not infer format", raise_on_extra_warnings=False + ): + to_datetime(values, errors="raise", format=format) + + @pytest.mark.parametrize("utc", [True, None]) + @pytest.mark.parametrize("format", ["%Y%m%d %H:%M:%S", None]) + @pytest.mark.parametrize("constructor", [list, tuple, np.array, Index, deque]) + def test_to_datetime_cache(self, utc, format, constructor): + date = "20130101 00:00:00" + test_dates = [date] * 10**5 + data = constructor(test_dates) + + result = to_datetime(data, utc=utc, format=format, cache=True) + expected = to_datetime(data, utc=utc, format=format, cache=False) + + tm.assert_index_equal(result, expected) + + def test_to_datetime_from_deque(self): + # GH 29403 + result = to_datetime(deque([Timestamp("2010-06-02 09:30:00")] * 51)) + expected = to_datetime([Timestamp("2010-06-02 09:30:00")] * 51) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("utc", [True, None]) + @pytest.mark.parametrize("format", ["%Y%m%d %H:%M:%S", None]) + def test_to_datetime_cache_series(self, utc, format): + date = "20130101 00:00:00" + test_dates = [date] * 10**5 + data = Series(test_dates) + result = to_datetime(data, utc=utc, format=format, cache=True) + expected = to_datetime(data, utc=utc, format=format, cache=False) + tm.assert_series_equal(result, expected) + + def test_to_datetime_cache_scalar(self): + date = "20130101 00:00:00" + result = to_datetime(date, cache=True) + expected = Timestamp("20130101 00:00:00") + assert result == expected + + @pytest.mark.parametrize( + "datetimelikes,expected_values,exp_unit", + ( + ( + (None, np.nan) + (NaT,) * start_caching_at, + (NaT,) * (start_caching_at + 2), + "s", + ), + ( + (None, Timestamp("2012-07-26").as_unit("s")) + + (NaT,) * start_caching_at, + (NaT, Timestamp("2012-07-26").as_unit("s")) + (NaT,) * start_caching_at, + "s", + ), + ( + (None,) + + (NaT,) * start_caching_at + + ("2012 July 26", Timestamp("2012-07-26")), + (NaT,) * (start_caching_at + 1) + + (Timestamp("2012-07-26"), Timestamp("2012-07-26")), + "us", + ), + ), + ) + def test_convert_object_to_datetime_with_cache( + self, datetimelikes, expected_values, exp_unit + ): + # GH#39882 + ser = Series( + datetimelikes, + dtype="object", + ) + result_series = to_datetime(ser, errors="coerce") + expected_series = Series( + expected_values, + dtype=f"datetime64[{exp_unit}]", + ) + tm.assert_series_equal(result_series, expected_series) + + @pytest.mark.parametrize( + "input", + [ + Series([NaT] * 20 + [None] * 20, dtype="object"), + Series([NaT] * 60 + [None] * 60, dtype="object"), + Series([None] * 20), + Series([None] * 60), + Series([""] * 20), + Series([""] * 60), + Series([pd.NA] * 20), + Series([pd.NA] * 60), + Series([np.nan] * 20), + Series([np.nan] * 60), + ], + ) + def test_to_datetime_converts_null_like_to_nat(self, cache, input): + # GH35888 + expected = Series([NaT] * len(input), dtype="M8[s]") + result = to_datetime(input, cache=cache) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "date, format", + [ + ("2017-20", "%Y-%W"), + ("20 Sunday", "%W %A"), + ("20 Sun", "%W %a"), + ("2017-21", "%Y-%U"), + ("20 Sunday", "%U %A"), + ("20 Sun", "%U %a"), + ], + ) + def test_week_without_day_and_calendar_year(self, date, format): + # GH16774 + + msg = "Cannot use '%W' or '%U' without day and year" + with pytest.raises(ValueError, match=msg): + to_datetime(date, format=format) + + def test_to_datetime_coerce(self): + # GH#26122, GH#57275 + ts_strings = [ + "March 1, 2018 12:00:00+0400", + "March 1, 2018 12:00:00+0500", + "20100240", + ] + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + with pytest.raises(ValueError, match=msg): + to_datetime(ts_strings, errors="coerce") + + @pytest.mark.parametrize( + "string_arg, format", + [("March 1, 2018", "%B %d, %Y"), ("2018-03-01", "%Y-%m-%d")], + ) + @pytest.mark.parametrize( + "outofbounds", + [ + datetime(9999, 1, 1), + date(9999, 1, 1), + np.datetime64("9999-01-01"), + "January 1, 9999", + "9999-01-01", + ], + ) + def test_to_datetime_coerce_oob(self, string_arg, format, outofbounds): + # https://github.com/pandas-dev/pandas/issues/50255 + ts_strings = [string_arg, outofbounds] + result = to_datetime(ts_strings, errors="coerce", format=format) + if isinstance(outofbounds, str) and ( + format.startswith("%B") ^ outofbounds.startswith("J") + ): + # the strings don't match the given format, so they raise and we coerce + expected = DatetimeIndex([datetime(2018, 3, 1), NaT], dtype="M8[us]") + elif isinstance(outofbounds, datetime): + expected = DatetimeIndex( + [datetime(2018, 3, 1), outofbounds], dtype="M8[us]" + ) + else: + expected = DatetimeIndex( + [datetime(2018, 3, 1), outofbounds], dtype="M8[us]" + ) + tm.assert_index_equal(result, expected) + + def test_to_datetime_malformed_no_raise(self): + # GH 28299 + # GH 48633 + ts_strings = ["200622-12-31", "111111-24-11"] + with tm.assert_produces_warning( + UserWarning, match="Could not infer format", raise_on_extra_warnings=False + ): + result = to_datetime(ts_strings, errors="coerce") + # TODO: should Index get "s" by default here? + exp = Index([NaT, NaT], dtype="M8[s]") + tm.assert_index_equal(result, exp) + + def test_to_datetime_malformed_raise(self): + # GH 48633 + ts_strings = ["200622-12-31", "111111-24-11"] + msg = ( + 'Parsed string "200622-12-31" gives an invalid tzoffset, which must ' + r"be between -timedelta\(hours=24\) and timedelta\(hours=24\)" + ) + with pytest.raises( + ValueError, + match=msg, + ): + with tm.assert_produces_warning( + UserWarning, match="Could not infer format" + ): + to_datetime( + ts_strings, + errors="raise", + ) + + def test_iso_8601_strings_with_same_offset(self): + # GH 17697, 11736 + ts_str = "2015-11-18 15:30:00+05:30" + result = to_datetime(ts_str) + expected = Timestamp(ts_str) + assert result == expected + + expected = DatetimeIndex([Timestamp(ts_str)] * 2) + result = to_datetime([ts_str] * 2) + tm.assert_index_equal(result, expected) + + result = DatetimeIndex([ts_str] * 2) + tm.assert_index_equal(result, expected) + + def test_iso_8601_strings_with_different_offsets_removed(self): + # GH#17697, GH#11736, GH#50887, GH#57275 + ts_strings = ["2015-11-18 15:30:00+05:30", "2015-11-18 16:30:00+06:30", NaT] + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + with pytest.raises(ValueError, match=msg): + to_datetime(ts_strings) + + def test_iso_8601_strings_with_different_offsets_utc(self): + ts_strings = ["2015-11-18 15:30:00+05:30", "2015-11-18 16:30:00+06:30", NaT] + result = to_datetime(ts_strings, utc=True) + expected = DatetimeIndex( + [Timestamp(2015, 11, 18, 10), Timestamp(2015, 11, 18, 10), NaT], tz="UTC" + ) + tm.assert_index_equal(result, expected) + + def test_mixed_offsets_with_native_datetime_utc_false_raises(self): + # GH#25978, GH#57275 + + vals = [ + "nan", + Timestamp("1990-01-01"), + "2015-03-14T16:15:14.123-08:00", + "2019-03-04T21:56:32.620-07:00", + None, + "today", + "now", + ] + ser = Series(vals) + assert all(ser[i] is vals[i] for i in range(len(vals))) # GH#40111 + + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + with pytest.raises(ValueError, match=msg): + to_datetime(ser) + + def test_non_iso_strings_with_tz_offset(self): + result = to_datetime(["March 1, 2018 12:00:00+0400"] * 2) + expected = DatetimeIndex( + [datetime(2018, 3, 1, 12, tzinfo=timezone(timedelta(minutes=240)))] * 2 + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "ts, expected", + [ + (Timestamp("2018-01-01"), Timestamp("2018-01-01", tz="UTC")), + ( + Timestamp("2018-01-01", tz="US/Pacific"), + Timestamp("2018-01-01 08:00", tz="UTC"), + ), + ], + ) + def test_timestamp_utc_true(self, ts, expected): + # GH 24415 + result = to_datetime(ts, utc=True) + assert result == expected + + @pytest.mark.parametrize("dt_str", ["00010101", "13000101", "30000101", "99990101"]) + def test_to_datetime_with_format_out_of_bounds(self, dt_str): + # GH 9107 + res = to_datetime(dt_str, format="%Y%m%d") + dtobj = datetime.strptime(dt_str, "%Y%m%d") + expected = Timestamp(dtobj) + assert res == expected + assert res.unit == expected.unit + + def test_to_datetime_utc(self): + arr = np.array([parse("2012-06-13T01:39:00Z")], dtype=object) + + result = to_datetime(arr, utc=True) + assert result.tz is timezone.utc + + def test_to_datetime_fixed_offset(self): + from pandas.tests.indexes.datetimes.test_timezones import FixedOffset + + fixed_off = FixedOffset(-420, "-07:00") + + dates = [ + datetime(2000, 1, 1, tzinfo=fixed_off), + datetime(2000, 1, 2, tzinfo=fixed_off), + datetime(2000, 1, 3, tzinfo=fixed_off), + ] + result = to_datetime(dates) + assert result.tz == fixed_off + + @pytest.mark.parametrize( + "date", + [ + ["2020-10-26 00:00:00+06:00", "2020-10-26 00:00:00+01:00"], + ["2020-10-26 00:00:00+06:00", Timestamp("2018-01-01", tz="US/Pacific")], + [ + "2020-10-26 00:00:00+06:00", + datetime(2020, 1, 1, 18).astimezone( + zoneinfo.ZoneInfo("Australia/Melbourne") + ), + ], + ], + ) + def test_to_datetime_mixed_offsets_with_utc_false_removed(self, date): + # GH#50887, GH#57275 + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + with pytest.raises(ValueError, match=msg): + to_datetime(date, utc=False) + + +class TestToDatetimeUnit: + @pytest.mark.parametrize("unit", ["Y", "M"]) + @pytest.mark.parametrize("item", [150, float(150)]) + def test_to_datetime_month_or_year_unit_int(self, cache, unit, item, request): + # GH#50870 Note we have separate tests that pd.Timestamp gets these right + ts = Timestamp(item, unit=unit) + dtype = "M8[s]" + expected = DatetimeIndex([ts], dtype=dtype) + + result = to_datetime([item], unit=unit, cache=cache) + tm.assert_index_equal(result, expected) + + result = to_datetime(np.array([item], dtype=object), unit=unit, cache=cache) + tm.assert_index_equal(result, expected) + + result = to_datetime(np.array([item]), unit=unit, cache=cache) + tm.assert_index_equal(result, expected) + + # with a nan! + result = to_datetime(np.array([item, np.nan]), unit=unit, cache=cache) + assert result.isna()[1] + tm.assert_index_equal(result[:1], expected.astype("M8[s]")) + + @pytest.mark.parametrize("unit", ["Y", "M"]) + def test_to_datetime_month_or_year_unit_non_round_float(self, cache, unit): + # GH#50301 + # Match Timestamp behavior in disallowing non-round floats with + # Y or M unit + msg = f"Conversion of non-round float with unit={unit} is ambiguous" + with pytest.raises(ValueError, match=msg): + to_datetime([1.5], unit=unit, errors="raise") + with pytest.raises(ValueError, match=msg): + to_datetime(np.array([1.5]), unit=unit, errors="raise") + + msg = r"Given date string \"1.5\" not likely a datetime" + with pytest.raises(ValueError, match=msg): + to_datetime(["1.5"], unit=unit, errors="raise") + + res = to_datetime([1.5], unit=unit, errors="coerce") + expected = Index([NaT], dtype="M8[ns]") + tm.assert_index_equal(res, expected) + + # In 3.0, the string "1.5" is parsed as as it would be without unit, + # which fails. With errors="coerce" this becomes NaT. + res = to_datetime(["1.5"], unit=unit, errors="coerce") + expected = to_datetime([NaT]) + tm.assert_index_equal(res, expected) + + # round floats are OK; treated like integers to give + # closest-to-supported unit + res = to_datetime([1.0], unit=unit) + expected = to_datetime([1], unit=unit).as_unit("s") + tm.assert_index_equal(res, expected) + + def test_unit(self, cache): + # GH 11758 + # test proper behavior with errors + msg = "cannot specify both format and unit" + with pytest.raises(ValueError, match=msg): + to_datetime([1], unit="D", format="%Y%m%d", cache=cache) + + def test_unit_array_mixed_nans(self, cache): + values = [11111111111111111, 1, 1.0, iNaT, NaT, np.nan, "NaT", ""] + + result = to_datetime(values, unit="D", errors="coerce", cache=cache) + expected = DatetimeIndex( + ["NaT", "1970-01-02", "1970-01-02", "NaT", "NaT", "NaT", "NaT", "NaT"], + dtype="M8[s]", + ) + tm.assert_index_equal(result, expected) + + msg = "cannot convert input 11111111111111111 with the unit 'D'" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(values, unit="D", errors="raise", cache=cache) + + def test_unit_array_mixed_nans_large_int(self, cache): + values = [1420043460000000000000000, iNaT, NaT, np.nan, "NaT"] + + result = to_datetime(values, errors="coerce", unit="s", cache=cache) + expected = DatetimeIndex(["NaT", "NaT", "NaT", "NaT", "NaT"], dtype="M8[s]") + tm.assert_index_equal(result, expected) + + msg = "cannot convert input 1420043460000000000000000 with the unit 's'" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(values, errors="raise", unit="s", cache=cache) + + def test_to_datetime_invalid_str_not_out_of_bounds_valuerror(self, cache): + # if we have a string, then we raise a ValueError + # and NOT an OutOfBoundsDatetime + msg = "Unknown datetime string format, unable to parse: foo" + with pytest.raises(ValueError, match=msg): + to_datetime("foo", errors="raise", unit="s", cache=cache) + + @pytest.mark.parametrize("error", ["raise", "coerce"]) + def test_unit_consistency(self, cache, error): + # consistency of conversions + expected = Timestamp("1970-05-09 14:25:11") + result = to_datetime(11111111, unit="s", errors=error, cache=cache) + assert result == expected + assert isinstance(result, Timestamp) + + @pytest.mark.parametrize("errors", ["raise", "coerce"]) + @pytest.mark.parametrize("dtype", ["float64", "int64"]) + def test_unit_with_numeric(self, cache, errors, dtype): + # GH 13180 + # coercions from floats/ints are ok + expected = DatetimeIndex( + ["2015-06-19 05:33:20", "2015-05-27 22:33:20"], dtype="M8[ns]" + ) + arr = np.array([1.434692e18, 1.432766e18]).astype(dtype) + result = to_datetime(arr, errors=errors, cache=cache) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "exp, arr, warning", + [ + [ + ["NaT", "2015-06-19 05:33:20", "2015-05-27 22:33:20"], + ["foo", 1.434692e18, 1.432766e18], + UserWarning, + ], + [ + ["2015-06-19 05:33:20", "2015-05-27 22:33:20", "NaT", "NaT"], + [1.434692e18, 1.432766e18, "foo", "NaT"], + None, + ], + ], + ) + def test_unit_with_numeric_coerce(self, cache, exp, arr, warning): + # but we want to make sure that we are coercing + # if we have ints/strings + expected = DatetimeIndex(exp, dtype="M8[ns]") + with tm.assert_produces_warning(warning, match="Could not infer format"): + result = to_datetime(arr, errors="coerce", cache=cache) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "arr", + [ + [Timestamp("20130101"), 1.434692e18, 1.432766e18], + [1.434692e18, 1.432766e18, Timestamp("20130101")], + ], + ) + def test_unit_mixed(self, cache, arr): + # GH#50453 pre-2.0 with mixed numeric/datetimes and errors="coerce" + # the numeric entries would be coerced to NaT, was never clear exactly + # why. + # mixed integers/datetimes + expected = Index([Timestamp(x) for x in arr], dtype="M8[ns]") + result = to_datetime(arr, errors="coerce", cache=cache) + tm.assert_index_equal(result, expected) + + # GH#49037 pre-2.0 this raised, but it always worked with Series, + # was never clear why it was disallowed + result = to_datetime(arr, errors="raise", cache=cache) + tm.assert_index_equal(result, expected) + + result = DatetimeIndex(arr) + tm.assert_index_equal(result, expected) + + def test_unit_rounding(self, cache): + # GH 14156 & GH 20445: argument will incur floating point errors + # but no premature rounding + value = 1434743731.8770001 + result = to_datetime(value, unit="s", cache=cache) + expected = Timestamp("2015-06-19 19:55:31.877000093") + assert result == expected + + alt = Timestamp(value, unit="s") + assert alt == result + + @pytest.mark.parametrize("dtype", [int, float]) + def test_to_datetime_unit(self, dtype): + epoch = 1370745748 + ser = Series([epoch + t for t in range(20)]).astype(dtype) + result = to_datetime(ser, unit="s") + unit = "s" + expected = Series( + [ + Timestamp("2013-06-09 02:42:28") + timedelta(seconds=t) + for t in range(20) + ], + dtype=f"M8[{unit}]", + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("null", [iNaT, np.nan]) + def test_to_datetime_unit_with_nulls(self, null): + epoch = 1370745748 + ser = Series([epoch + t for t in range(20)] + [null]) + result = to_datetime(ser, unit="s") + # With np.nan, the list gets cast to a float64 array, which always + # gets ns unit. + unit = "s" + expected = Series( + [Timestamp("2013-06-09 02:42:28") + timedelta(seconds=t) for t in range(20)] + + [NaT], + dtype=f"M8[{unit}]", + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_unit_fractional_seconds(self): + # GH13834 + epoch = 1370745748 + ser = Series([epoch + t for t in np.arange(0, 2, 0.25)] + [iNaT]).astype(float) + result = to_datetime(ser, unit="s") + expected = Series( + [ + Timestamp("2013-06-09 02:42:28") + timedelta(seconds=t) + for t in np.arange(0, 2, 0.25) + ] + + [NaT], + dtype="M8[ns]", + ) + # GH20455 argument will incur floating point errors but no premature rounding + result = result.dt.round("ms") + tm.assert_series_equal(result, expected) + + def test_to_datetime_unit_na_values(self): + result = to_datetime([1, 2, "NaT", NaT, np.nan], unit="D") + expected = DatetimeIndex( + [Timestamp("1970-01-02"), Timestamp("1970-01-03")] + ["NaT"] * 3, + dtype="M8[s]", + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize("bad_val", ["foo", 111111111111111]) + def test_to_datetime_unit_invalid(self, bad_val): + if bad_val == "foo": + msg = f"Unknown datetime string format, unable to parse: {bad_val}" + else: + msg = "cannot convert input 111111111111111 with the unit 'D'" + with pytest.raises(ValueError, match=msg): + to_datetime([1, 2, bad_val], unit="D") + + @pytest.mark.parametrize("bad_val", ["foo", 111111111111111]) + def test_to_timestamp_unit_coerce(self, bad_val): + # coerce we can process + expected = DatetimeIndex( + [Timestamp("1970-01-02"), Timestamp("1970-01-03")] + ["NaT"] * 1, + dtype="M8[s]", + ) + result = to_datetime([1, 2, bad_val], unit="D", errors="coerce") + tm.assert_index_equal(result, expected) + + def test_float_to_datetime_raise_near_bounds(self): + # GH50183 + msg = "cannot convert input with unit 'D'" + oneday_in_ns = 1e9 * 60 * 60 * 24 + tsmax_in_days = 2**63 / oneday_in_ns # 2**63 ns, in days + # just in bounds + should_succeed = Series( + [0, tsmax_in_days - 0.005, -tsmax_in_days + 0.005], dtype=float + ) + expected = (should_succeed * oneday_in_ns).astype(np.int64) + for error_mode in ["raise", "coerce"]: + result1 = to_datetime(should_succeed, unit="D", errors=error_mode) + # Cast to `np.float64` so that `rtol` and inexact checking kick in + # (`check_exact` doesn't take place for integer dtypes) + tm.assert_almost_equal( + result1.astype(np.int64).astype(np.float64), + expected.astype(np.float64), + rtol=1e-10, + ) + # just out of bounds + should_fail1 = Series([0, tsmax_in_days + 0.005], dtype=float) + should_fail2 = Series([0, -tsmax_in_days - 0.005], dtype=float) + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(should_fail1, unit="D", errors="raise") + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(should_fail2, unit="D", errors="raise") + + +class TestToDatetimeDataFrame: + @pytest.fixture + def df(self): + return DataFrame( + { + "year": [2015, 2016], + "month": [2, 3], + "day": [4, 5], + "hour": [6, 7], + "minute": [58, 59], + "second": [10, 11], + "ms": [1, 1], + "us": [2, 2], + "ns": [3, 3], + } + ) + + def test_dataframe(self, df, cache): + result = to_datetime( + {"year": df["year"], "month": df["month"], "day": df["day"]}, cache=cache + ) + expected = Series( + [Timestamp("20150204 00:00:00"), Timestamp("20160305 00:0:00")] + ) + tm.assert_series_equal(result, expected) + + # dict-like + result = to_datetime(df[["year", "month", "day"]].to_dict(), cache=cache) + expected.index = Index([0, 1]) + tm.assert_series_equal(result, expected) + + def test_dataframe_dict_with_constructable(self, df, cache): + # dict but with constructable + df2 = df[["year", "month", "day"]].to_dict() + df2["month"] = 2 + result = to_datetime(df2, cache=cache) + expected2 = Series( + [Timestamp("20150204 00:00:00"), Timestamp("20160205 00:0:00")], + index=Index([0, 1]), + ) + tm.assert_series_equal(result, expected2) + + @pytest.mark.parametrize( + "unit", + [ + { + "year": "years", + "month": "months", + "day": "days", + "hour": "hours", + "minute": "minutes", + "second": "seconds", + }, + { + "year": "year", + "month": "month", + "day": "day", + "hour": "hour", + "minute": "minute", + "second": "second", + }, + ], + ) + def test_dataframe_field_aliases_column_subset(self, df, cache, unit): + # unit mappings + result = to_datetime(df[list(unit.keys())].rename(columns=unit), cache=cache) + expected = Series( + [Timestamp("20150204 06:58:10"), Timestamp("20160305 07:59:11")], + dtype="M8[us]", + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_field_aliases(self, df, cache): + d = { + "year": "year", + "month": "month", + "day": "day", + "hour": "hour", + "minute": "minute", + "second": "second", + "ms": "ms", + "us": "us", + "ns": "ns", + } + + result = to_datetime(df.rename(columns=d), cache=cache) + expected = Series( + [ + Timestamp("20150204 06:58:10.001002003"), + Timestamp("20160305 07:59:11.001002003"), + ] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_str_dtype(self, df, cache): + # coerce back to int + result = to_datetime(df.astype(str), cache=cache) + expected = Series( + [ + Timestamp("20150204 06:58:10.001002003"), + Timestamp("20160305 07:59:11.001002003"), + ] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_float32_dtype(self, df, cache): + # GH#60506 + # coerce to float64 + result = to_datetime(df.astype(np.float32), cache=cache) + expected = Series( + [ + Timestamp("20150204 06:58:10.001002003"), + Timestamp("20160305 07:59:11.001002003"), + ] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_coerce(self, cache): + # passing coerce + df2 = DataFrame({"year": [2015, 2016], "month": [2, 20], "day": [4, 5]}) + + msg = ( + r'^cannot assemble the datetimes: time data ".+" doesn\'t ' + r'match format "%Y%m%d"\.' + ) + with pytest.raises(ValueError, match=msg): + to_datetime(df2, cache=cache) + + result = to_datetime(df2, errors="coerce", cache=cache) + expected = Series([Timestamp("20150204 00:00:00"), NaT]) + tm.assert_series_equal(result, expected) + + def test_dataframe_extra_keys_raises(self, df, cache): + # extra columns + msg = r"extra keys have been passed to the datetime assemblage: \[foo\]" + df2 = df.copy() + df2["foo"] = 1 + with pytest.raises(ValueError, match=msg): + to_datetime(df2, cache=cache) + + @pytest.mark.parametrize( + "cols", + [ + ["year"], + ["year", "month"], + ["year", "month", "second"], + ["month", "day"], + ["year", "day", "second"], + ], + ) + def test_dataframe_missing_keys_raises(self, df, cache, cols): + # not enough + msg = ( + r"to assemble mappings requires at least that \[year, month, " + r"day\] be specified: \[.+\] is missing" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(df[cols], cache=cache) + + def test_dataframe_duplicate_columns_raises(self, cache): + # duplicates + msg = "cannot assemble with duplicate keys" + df2 = DataFrame({"year": [2015, 2016], "month": [2, 20], "day": [4, 5]}) + df2.columns = ["year", "year", "day"] + with pytest.raises(ValueError, match=msg): + to_datetime(df2, cache=cache) + + df2 = DataFrame( + {"year": [2015, 2016], "month": [2, 20], "day": [4, 5], "hour": [4, 5]} + ) + df2.columns = ["year", "month", "day", "day"] + with pytest.raises(ValueError, match=msg): + to_datetime(df2, cache=cache) + + def test_dataframe_int16(self, cache): + # GH#13451 + df = DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}) + + # int16 + result = to_datetime(df.astype("int16"), cache=cache) + expected = Series( + [Timestamp("20150204 00:00:00"), Timestamp("20160305 00:00:00")] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_mixed(self, cache): + # mixed dtypes + df = DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}) + df["month"] = df["month"].astype("int8") + df["day"] = df["day"].astype("int8") + result = to_datetime(df, cache=cache) + expected = Series( + [Timestamp("20150204 00:00:00"), Timestamp("20160305 00:00:00")] + ) + tm.assert_series_equal(result, expected) + + def test_dataframe_float(self, cache): + # float + df = DataFrame({"year": [2000, 2001], "month": [1.5, 1], "day": [1, 1]}) + msg = ( + r"^cannot assemble the datetimes: unconverted data remains when parsing " + r'with format ".*": "1".' + ) + with pytest.raises(ValueError, match=msg): + to_datetime(df, cache=cache) + + def test_dataframe_utc_true(self): + # GH#23760 + df = DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}) + result = to_datetime(df, utc=True) + expected = Series( + np.array(["2015-02-04", "2016-03-05"], dtype="datetime64[us]") + ).dt.tz_localize("UTC") + tm.assert_series_equal(result, expected) + + +class TestToDatetimeMisc: + def test_to_datetime_barely_out_of_bounds(self): + # GH#19529 + # GH#19382 close enough to bounds that dropping nanos would result + # in an in-bounds datetime + arr = np.array(["2262-04-11 23:47:16.854775808"], dtype=object) + + msg = "^Out of bounds nanosecond timestamp: .*" + with pytest.raises(OutOfBoundsDatetime, match=msg): + to_datetime(arr) + + @pytest.mark.parametrize( + "arg, exp_str", + [ + ["2012-01-01 00:00:00", "2012-01-01 00:00:00"], + ["20121001", "2012-10-01"], # bad iso 8601 + ], + ) + def test_to_datetime_iso8601(self, cache, arg, exp_str): + result = to_datetime([arg], cache=cache) + exp = Timestamp(exp_str) + assert result[0] == exp + + @pytest.mark.parametrize( + "input, format", + [ + ("2012", "%Y-%m"), + ("2012-01", "%Y-%m-%d"), + ("2012-01-01", "%Y-%m-%d %H"), + ("2012-01-01 10", "%Y-%m-%d %H:%M"), + ("2012-01-01 10:00", "%Y-%m-%d %H:%M:%S"), + ("2012-01-01 10:00:00", "%Y-%m-%d %H:%M:%S.%f"), + ("2012-01-01 10:00:00.123", "%Y-%m-%d %H:%M:%S.%f%z"), + (0, "%Y-%m-%d"), + ], + ) + @pytest.mark.parametrize("exact", [True, False]) + def test_to_datetime_iso8601_fails(self, input, format, exact): + # https://github.com/pandas-dev/pandas/issues/12649 + # `format` is longer than the string, so this fails regardless of `exact` + with pytest.raises( + ValueError, + match=(rf"time data \"{input}\" doesn't match format " rf"\"{format}\""), + ): + to_datetime(input, format=format, exact=exact) + + @pytest.mark.parametrize( + "input, format", + [ + ("2012-01-01", "%Y-%m"), + ("2012-01-01 10", "%Y-%m-%d"), + ("2012-01-01 10:00", "%Y-%m-%d %H"), + ("2012-01-01 10:00:00", "%Y-%m-%d %H:%M"), + (0, "%Y-%m-%d"), + ], + ) + def test_to_datetime_iso8601_exact_fails(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + # `format` is shorter than the date string, so only fails with `exact=True` + msg = "|".join( + [ + '^unconverted data remains when parsing with format ".*": ".*". ' + f"{PARSING_ERR_MSG}$", + f'^time data ".*" doesn\'t match format ".*". {PARSING_ERR_MSG}$', + ] + ) + with pytest.raises( + ValueError, + match=(msg), + ): + to_datetime(input, format=format) + + @pytest.mark.parametrize( + "input, format", + [ + ("2012-01-01", "%Y-%m"), + ("2012-01-01 00", "%Y-%m-%d"), + ("2012-01-01 00:00", "%Y-%m-%d %H"), + ("2012-01-01 00:00:00", "%Y-%m-%d %H:%M"), + ], + ) + def test_to_datetime_iso8601_non_exact(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + expected = Timestamp(2012, 1, 1) + result = to_datetime(input, format=format, exact=False) + assert result == expected + + @pytest.mark.parametrize( + "input, format", + [ + ("2020-01", "%Y/%m"), + ("2020-01-01", "%Y/%m/%d"), + ("2020-01-01 00", "%Y/%m/%dT%H"), + ("2020-01-01T00", "%Y/%m/%d %H"), + ("2020-01-01 00:00", "%Y/%m/%dT%H:%M"), + ("2020-01-01T00:00", "%Y/%m/%d %H:%M"), + ("2020-01-01 00:00:00", "%Y/%m/%dT%H:%M:%S"), + ("2020-01-01T00:00:00", "%Y/%m/%d %H:%M:%S"), + ], + ) + def test_to_datetime_iso8601_separator(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + with pytest.raises( + ValueError, + match=(rf"time data \"{input}\" doesn\'t match format " rf"\"{format}\""), + ): + to_datetime(input, format=format) + + @pytest.mark.parametrize( + "input, format", + [ + ("2020-01", "%Y-%m"), + ("2020-01-01", "%Y-%m-%d"), + ("2020-01-01 00", "%Y-%m-%d %H"), + ("2020-01-01T00", "%Y-%m-%dT%H"), + ("2020-01-01 00:00", "%Y-%m-%d %H:%M"), + ("2020-01-01T00:00", "%Y-%m-%dT%H:%M"), + ("2020-01-01 00:00:00", "%Y-%m-%d %H:%M:%S"), + ("2020-01-01T00:00:00", "%Y-%m-%dT%H:%M:%S"), + ("2020-01-01T00:00:00.000", "%Y-%m-%dT%H:%M:%S.%f"), + ("2020-01-01T00:00:00.000000", "%Y-%m-%dT%H:%M:%S.%f"), + ("2020-01-01T00:00:00.000000000", "%Y-%m-%dT%H:%M:%S.%f"), + ], + ) + def test_to_datetime_iso8601_valid(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + expected = Timestamp(2020, 1, 1) + result = to_datetime(input, format=format) + assert result == expected + + @pytest.mark.parametrize( + "input, format", + [ + ("2020-1", "%Y-%m"), + ("2020-1-1", "%Y-%m-%d"), + ("2020-1-1 0", "%Y-%m-%d %H"), + ("2020-1-1T0", "%Y-%m-%dT%H"), + ("2020-1-1 0:0", "%Y-%m-%d %H:%M"), + ("2020-1-1T0:0", "%Y-%m-%dT%H:%M"), + ("2020-1-1 0:0:0", "%Y-%m-%d %H:%M:%S"), + ("2020-1-1T0:0:0", "%Y-%m-%dT%H:%M:%S"), + ("2020-1-1T0:0:0.000", "%Y-%m-%dT%H:%M:%S.%f"), + ("2020-1-1T0:0:0.000000", "%Y-%m-%dT%H:%M:%S.%f"), + ("2020-1-1T0:0:0.000000000", "%Y-%m-%dT%H:%M:%S.%f"), + ], + ) + def test_to_datetime_iso8601_non_padded(self, input, format): + # https://github.com/pandas-dev/pandas/issues/21422 + expected = Timestamp(2020, 1, 1) + result = to_datetime(input, format=format) + assert result == expected + + @pytest.mark.parametrize( + "input, format", + [ + ("2020-01-01T00:00:00.000000000+00:00", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2020-01-01T00:00:00+00:00", "%Y-%m-%dT%H:%M:%S%z"), + ("2020-01-01T00:00:00Z", "%Y-%m-%dT%H:%M:%S%z"), + ], + ) + def test_to_datetime_iso8601_with_timezone_valid(self, input, format): + # https://github.com/pandas-dev/pandas/issues/12649 + expected = Timestamp(2020, 1, 1, tzinfo=timezone.utc) + result = to_datetime(input, format=format) + assert result == expected + + def test_to_datetime_default(self, cache): + rs = to_datetime("2001", cache=cache) + xp = datetime(2001, 1, 1) + assert rs == xp + + @pytest.mark.xfail(reason="fails to enforce dayfirst=True, which would raise") + def test_to_datetime_respects_dayfirst(self, cache): + # dayfirst is essentially broken + + # The msg here is not important since it isn't actually raised yet. + msg = "Invalid date specified" + with pytest.raises(ValueError, match=msg): + # if dayfirst is respected, then this would parse as month=13, which + # would raise + with tm.assert_produces_warning(UserWarning, match="Provide format"): + to_datetime("01-13-2012", dayfirst=True, cache=cache) + + def test_to_datetime_on_datetime64_series(self, cache): + # #2699 + ser = Series(date_range("1/1/2000", periods=10)) + + result = to_datetime(ser, cache=cache) + assert result[0] == ser[0] + + def test_to_datetime_with_space_in_series(self, cache): + # GH 6428 + ser = Series(["10/18/2006", "10/18/2008", " "]) + msg = ( + r'^time data " " doesn\'t match format "%m/%d/%Y". ' rf"{PARSING_ERR_MSG}$" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(ser, errors="raise", cache=cache) + result_coerce = to_datetime(ser, errors="coerce", cache=cache) + expected_coerce = Series([datetime(2006, 10, 18), datetime(2008, 10, 18), NaT]) + tm.assert_series_equal(result_coerce, expected_coerce) + + @td.skip_if_not_us_locale + def test_to_datetime_with_apply(self, cache): + # this is only locale tested with US/None locales + # GH 5195 + # with a format and coerce a single item to_datetime fails + td = Series(["May 04", "Jun 02", "Dec 11"], index=[1, 2, 3]) + expected = to_datetime(td, format="%b %y", cache=cache) + result = td.apply(to_datetime, format="%b %y", cache=cache) + tm.assert_series_equal(result, expected) + + def test_to_datetime_timezone_name(self): + # https://github.com/pandas-dev/pandas/issues/49748 + result = to_datetime("2020-01-01 00:00:00UTC", format="%Y-%m-%d %H:%M:%S%Z") + expected = Timestamp(2020, 1, 1).tz_localize("UTC") + assert result == expected + + @td.skip_if_not_us_locale + @pytest.mark.parametrize("errors", ["raise", "coerce"]) + def test_to_datetime_with_apply_with_empty_str(self, cache, errors): + # this is only locale tested with US/None locales + # GH 5195, GH50251 + # with a format and coerce a single item to_datetime fails + td = Series(["May 04", "Jun 02", ""], index=[1, 2, 3]) + expected = to_datetime(td, format="%b %y", errors=errors, cache=cache) + + result = td.apply( + lambda x: to_datetime(x, format="%b %y", errors="coerce", cache=cache) + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_empty_stt(self, cache): + # empty string + result = to_datetime("", cache=cache) + assert result is NaT + + def test_to_datetime_empty_str_list(self, cache): + result = to_datetime(["", ""], cache=cache) + assert isna(result).all() + + def test_to_datetime_zero(self, cache): + # ints + result = Timestamp(0) + expected = to_datetime(0, cache=cache) + assert result == expected + + def test_to_datetime_strings(self, cache): + # GH 3888 (strings) + expected = to_datetime(["2012"], cache=cache)[0] + result = to_datetime("2012", cache=cache) + assert result == expected + + def test_to_datetime_strings_variation(self, cache): + array = ["2012", "20120101", "20120101 12:01:01"] + expected = [to_datetime(dt_str, cache=cache) for dt_str in array] + result = [Timestamp(date_str) for date_str in array] + tm.assert_almost_equal(result, expected) + + @pytest.mark.parametrize("result", [Timestamp("2012"), to_datetime("2012")]) + def test_to_datetime_strings_vs_constructor(self, result): + expected = Timestamp(2012, 1, 1) + assert result == expected + + def test_to_datetime_unprocessable_input(self, cache): + # GH 4928 + # GH 21864 + msg = '^Given date string "1" not likely a datetime$' + with pytest.raises(ValueError, match=msg): + to_datetime([1, "1"], errors="raise", cache=cache) + + def test_to_datetime_other_datetime64_units(self): + # 5/25/2012 + scalar = np.int64(1337904000000000).view("M8[us]") + as_obj = scalar.astype("O") + + index = DatetimeIndex([scalar]) + assert index[0] == scalar.astype("O") + + value = Timestamp(scalar) + assert value == as_obj + + def test_to_datetime_list_of_integers(self): + rng = date_range("1/1/2000", periods=20, unit="ns") + rng = DatetimeIndex(rng.values) + + ints = list(rng.asi8) + + result = DatetimeIndex(ints) + + tm.assert_index_equal(rng, result) + + def test_to_datetime_overflow(self): + # gh-17637 + # we are overflowing Timedelta range here + msg = "Cannot cast 139999 days 00:00:00 to unit='ns' without overflow" + with pytest.raises(OutOfBoundsTimedelta, match=msg): + date_range(start="1/1/1700", freq="B", periods=100000, unit="ns") + + def test_to_datetime_float_with_nans_floating_point_error(self): + # GH#58419 + ser = Series([np.nan] * 1000 + [1712219033.0], dtype=np.float64) + result = to_datetime(ser, unit="s", errors="coerce") + expected = Series( + [NaT] * 1000 + [Timestamp("2024-04-04 08:23:53")], dtype="datetime64[s]" + ) + tm.assert_series_equal(result, expected) + + def test_string_invalid_operation(self, cache): + invalid = np.array(["87156549591102612381000001219H5"], dtype=object) + # GH #51084 + + with pytest.raises(ValueError, match="Unknown datetime string format"): + to_datetime(invalid, errors="raise", cache=cache) + + def test_string_na_nat_conversion(self, cache): + # GH #999, #858 + + strings = np.array(["1/1/2000", "1/2/2000", np.nan, "1/4/2000"], dtype=object) + + expected = np.empty(4, dtype="M8[us]") + for i, val in enumerate(strings): + if isna(val): + expected[i] = iNaT + else: + expected[i] = parse(val) + + result = tslib.array_to_datetime(strings)[0] + tm.assert_almost_equal(result, expected) + + result2 = to_datetime(strings, cache=cache) + assert isinstance(result2, DatetimeIndex) + tm.assert_numpy_array_equal(result, result2.values) + + def test_string_na_nat_conversion_malformed(self, cache): + malformed = np.array(["1/100/2000", np.nan], dtype=object) + + # GH 10636, default is now 'raise' + msg = r"Unknown datetime string format" + with pytest.raises(ValueError, match=msg): + to_datetime(malformed, errors="raise", cache=cache) + + with pytest.raises(ValueError, match=msg): + to_datetime(malformed, errors="raise", cache=cache) + + def test_string_na_nat_conversion_with_name(self, cache): + idx = ["a", "b", "c", "d", "e"] + series = Series( + ["1/1/2000", np.nan, "1/3/2000", np.nan, "1/5/2000"], index=idx, name="foo" + ) + dseries = Series( + [ + to_datetime("1/1/2000", cache=cache), + np.nan, + to_datetime("1/3/2000", cache=cache), + np.nan, + to_datetime("1/5/2000", cache=cache), + ], + index=idx, + name="foo", + ) + + result = to_datetime(series, cache=cache) + dresult = to_datetime(dseries, cache=cache) + + expected = Series(np.empty(5, dtype="M8[us]"), index=idx) + for i in range(5): + x = series.iloc[i] + if isna(x): + expected.iloc[i] = NaT + else: + expected.iloc[i] = to_datetime(x, cache=cache) + + tm.assert_series_equal(result, expected, check_names=False) + assert result.name == "foo" + + tm.assert_series_equal(dresult, expected, check_names=False) + assert dresult.name == "foo" + + @pytest.mark.parametrize( + "unit", + ["h", "m", "s", "ms", "us", "ns"], + ) + def test_dti_constructor_numpy_timeunits(self, cache, unit): + # GH 9114 + dtype = np.dtype(f"M8[{unit}]") + base = to_datetime(["2000-01-01T00:00", "2000-01-02T00:00", "NaT"], cache=cache) + + values = base.values.astype(dtype) + + if unit in ["h", "m"]: + # we cast to closest supported unit + unit = "s" + exp_dtype = np.dtype(f"M8[{unit}]") + expected = DatetimeIndex(base.astype(exp_dtype)) + assert expected.dtype == exp_dtype + + tm.assert_index_equal(DatetimeIndex(values), expected) + tm.assert_index_equal(to_datetime(values, cache=cache), expected) + + def test_dayfirst(self, cache): + # GH 5917 + arr = ["10/02/2014", "11/02/2014", "12/02/2014"] + expected = DatetimeIndex( + [datetime(2014, 2, 10), datetime(2014, 2, 11), datetime(2014, 2, 12)] + ) + idx1 = DatetimeIndex(arr, dayfirst=True) + idx2 = DatetimeIndex(np.array(arr), dayfirst=True) + idx3 = to_datetime(arr, dayfirst=True, cache=cache) + idx4 = to_datetime(np.array(arr), dayfirst=True, cache=cache) + idx5 = DatetimeIndex(Index(arr), dayfirst=True) + idx6 = DatetimeIndex(Series(arr), dayfirst=True) + tm.assert_index_equal(expected, idx1) + tm.assert_index_equal(expected, idx2) + tm.assert_index_equal(expected, idx3) + tm.assert_index_equal(expected, idx4) + tm.assert_index_equal(expected, idx5) + tm.assert_index_equal(expected, idx6) + + def test_dayfirst_warnings_valid_input(self): + # GH 12585 + warning_msg = ( + "Parsing dates in .* format when dayfirst=.* was specified. " + "Pass `dayfirst=.*` or specify a format to silence this warning." + ) + + # CASE 1: valid input + arr = ["31/12/2014", "10/03/2011"] + expected = DatetimeIndex( + ["2014-12-31", "2011-03-10"], dtype="datetime64[us]", freq=None + ) + + # A. dayfirst arg correct, no warning + res1 = to_datetime(arr, dayfirst=True) + tm.assert_index_equal(expected, res1) + + # B. dayfirst arg incorrect, warning + with tm.assert_produces_warning(UserWarning, match=warning_msg): + res2 = to_datetime(arr, dayfirst=False) + tm.assert_index_equal(expected, res2) + + def test_dayfirst_warnings_invalid_input(self): + # CASE 2: invalid input + # cannot consistently process with single format + # ValueError *always* raised + + # first in DD/MM/YYYY, second in MM/DD/YYYY + arr = ["31/12/2014", "03/30/2011"] + + with pytest.raises( + ValueError, + match=( + r'^time data "03/30/2011" doesn\'t match format ' + rf'"%d/%m/%Y". {PARSING_ERR_MSG}$' + ), + ): + to_datetime(arr, dayfirst=True) + + @pytest.mark.parametrize("klass", [DatetimeIndex, DatetimeArray._from_sequence]) + def test_to_datetime_dta_tz(self, klass): + # GH#27733 + dti = date_range("2015-04-05", periods=3).rename("foo") + expected = dti.tz_localize("UTC") + + obj = klass(dti) + expected = klass(expected) + + result = to_datetime(obj, utc=True) + tm.assert_equal(result, expected) + + +class TestGuessDatetimeFormat: + @pytest.mark.parametrize( + "test_list", + [ + [ + "2011-12-30 00:00:00.000000", + "2011-12-30 00:00:00.000000", + "2011-12-30 00:00:00.000000", + ], + [np.nan, np.nan, "2011-12-30 00:00:00.000000"], + ["", "2011-12-30 00:00:00.000000"], + ["NaT", "2011-12-30 00:00:00.000000"], + ["2011-12-30 00:00:00.000000", "random_string"], + ["now", "2011-12-30 00:00:00.000000"], + ["today", "2011-12-30 00:00:00.000000"], + ], + ) + def test_guess_datetime_format_for_array(self, test_list): + expected_format = "%Y-%m-%d %H:%M:%S.%f" + test_array = np.array(test_list, dtype=object) + assert tools._guess_datetime_format_for_array(test_array) == expected_format + + @td.skip_if_not_us_locale + def test_guess_datetime_format_for_array_all_nans(self): + format_for_string_of_nans = tools._guess_datetime_format_for_array( + np.array([np.nan, np.nan, np.nan], dtype="O") + ) + assert format_for_string_of_nans is None + + +class TestToDatetimeInferFormat: + @pytest.mark.parametrize( + "test_format", ["%m-%d-%Y", "%m/%d/%Y %H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S.%f"] + ) + def test_to_datetime_infer_datetime_format_consistent_format( + self, cache, test_format + ): + ser = Series(date_range("20000101", periods=50, freq="h")) + + s_as_dt_strings = ser.apply(lambda x: x.strftime(test_format)) + + with_format = to_datetime(s_as_dt_strings, format=test_format, cache=cache) + without_format = to_datetime(s_as_dt_strings, cache=cache) + + # Whether the format is explicitly passed, or + # it is inferred, the results should all be the same + tm.assert_series_equal(with_format, without_format) + + def test_to_datetime_inconsistent_format(self, cache): + data = ["01/01/2011 00:00:00", "01-02-2011 00:00:00", "2011-01-03T00:00:00"] + ser = Series(np.array(data)) + msg = ( + r'^time data "01-02-2011 00:00:00" doesn\'t match format ' + rf'"%m/%d/%Y %H:%M:%S". {PARSING_ERR_MSG}$' + ) + with pytest.raises(ValueError, match=msg): + to_datetime(ser, cache=cache) + + def test_to_datetime_consistent_format(self, cache): + data = ["Jan/01/2011", "Feb/01/2011", "Mar/01/2011"] + ser = Series(np.array(data)) + result = to_datetime(ser, cache=cache) + expected = Series( + ["2011-01-01", "2011-02-01", "2011-03-01"], dtype="datetime64[us]" + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_series_with_nans(self, cache): + ser = Series( + np.array( + ["01/01/2011 00:00:00", np.nan, "01/03/2011 00:00:00", np.nan], + dtype=object, + ) + ) + result = to_datetime(ser, cache=cache) + expected = Series( + ["2011-01-01", NaT, "2011-01-03", NaT], dtype="datetime64[us]" + ) + tm.assert_series_equal(result, expected) + + def test_to_datetime_series_start_with_nans(self, cache): + ser = Series( + np.array( + [ + np.nan, + np.nan, + "01/01/2011 00:00:00", + "01/02/2011 00:00:00", + "01/03/2011 00:00:00", + ], + dtype=object, + ) + ) + + result = to_datetime(ser, cache=cache) + expected = Series( + [NaT, NaT, "2011-01-01", "2011-01-02", "2011-01-03"], dtype="datetime64[us]" + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "tz_name, offset", + [("UTC", 0), ("UTC-3", 180), ("UTC+3", -180)], + ) + def test_infer_datetime_format_tz_name(self, tz_name, offset): + # GH 33133 + ser = Series([f"2019-02-02 08:07:13 {tz_name}"]) + result = to_datetime(ser) + tz = timezone(timedelta(minutes=offset)) + expected = Series([Timestamp("2019-02-02 08:07:13").tz_localize(tz)]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "ts,zero_tz", + [ + ("2019-02-02 08:07:13", "Z"), + ("2019-02-02 08:07:13", ""), + ("2019-02-02 08:07:13.012345", "Z"), + ("2019-02-02 08:07:13.012345", ""), + ], + ) + def test_infer_datetime_format_zero_tz(self, ts, zero_tz): + # GH 41047 + ser = Series([ts + zero_tz]) + result = to_datetime(ser) + tz = timezone.utc if zero_tz == "Z" else None + expected = Series([Timestamp(ts, tz=tz)]) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("format", [None, "%Y-%m-%d"]) + def test_to_datetime_iso8601_noleading_0s(self, cache, format): + # GH 11871 + ser = Series(["2014-1-1", "2014-2-2", "2015-3-3"]) + expected = Series( + [ + Timestamp("2014-01-01"), + Timestamp("2014-02-02"), + Timestamp("2015-03-03"), + ] + ) + result = to_datetime(ser, format=format, cache=cache) + tm.assert_series_equal(result, expected) + + +class TestDaysInMonth: + # tests for issue #10154 + + @pytest.mark.parametrize( + "arg, format", + [ + ["2015-02-29", None], + ["2015-02-29", "%Y-%m-%d"], + ["2015-02-32", "%Y-%m-%d"], + ["2015-04-31", "%Y-%m-%d"], + ], + ) + def test_day_not_in_month_coerce(self, cache, arg, format): + assert isna(to_datetime(arg, errors="coerce", format=format, cache=cache)) + + def test_day_not_in_month_raise(self, cache): + if PY314: + msg = "day 29 must be in range 1..28 for month 2 in year 2015: 2015-02-29" + else: + msg = "day is out of range for month: 2015-02-29" + with pytest.raises(ValueError, match=msg): + to_datetime("2015-02-29", errors="raise", cache=cache) + + @pytest.mark.parametrize( + "arg, format, msg", + [ + ( + "2015-02-29", + "%Y-%m-%d", + f"^{DAY_IS_OUT_OF_RANGE}. {PARSING_ERR_MSG}$", + ), + ( + "2015-29-02", + "%Y-%d-%m", + f"^{DAY_IS_OUT_OF_RANGE}. {PARSING_ERR_MSG}$", + ), + ( + "2015-02-32", + "%Y-%m-%d", + '^unconverted data remains when parsing with format "%Y-%m-%d": "2". ' + f"{PARSING_ERR_MSG}$", + ), + ( + "2015-32-02", + "%Y-%d-%m", + '^time data "2015-32-02" doesn\'t match format "%Y-%d-%m". ' + f"{PARSING_ERR_MSG}$", + ), + ( + "2015-04-31", + "%Y-%m-%d", + f"^{DAY_IS_OUT_OF_RANGE}. {PARSING_ERR_MSG}$", + ), + ( + "2015-31-04", + "%Y-%d-%m", + f"^{DAY_IS_OUT_OF_RANGE}. {PARSING_ERR_MSG}$", + ), + ], + ) + def test_day_not_in_month_raise_value(self, cache, arg, format, msg): + # https://github.com/pandas-dev/pandas/issues/50462 + with pytest.raises(ValueError, match=msg): + to_datetime(arg, errors="raise", format=format, cache=cache) + + +class TestDatetimeParsingWrappers: + @pytest.mark.parametrize( + "date_str, expected", + [ + ("2011-01-01", datetime(2011, 1, 1)), + ("2Q2005", datetime(2005, 4, 1)), + ("2Q05", datetime(2005, 4, 1)), + ("2005Q1", datetime(2005, 1, 1)), + ("05Q1", datetime(2005, 1, 1)), + ("2011Q3", datetime(2011, 7, 1)), + ("11Q3", datetime(2011, 7, 1)), + ("3Q2011", datetime(2011, 7, 1)), + ("3Q11", datetime(2011, 7, 1)), + # quarterly without space + ("2000Q4", datetime(2000, 10, 1)), + ("00Q4", datetime(2000, 10, 1)), + ("4Q2000", datetime(2000, 10, 1)), + ("4Q00", datetime(2000, 10, 1)), + ("2000q4", datetime(2000, 10, 1)), + ("2000-Q4", datetime(2000, 10, 1)), + ("00-Q4", datetime(2000, 10, 1)), + ("4Q-2000", datetime(2000, 10, 1)), + ("4Q-00", datetime(2000, 10, 1)), + ("00q4", datetime(2000, 10, 1)), + ("2005", datetime(2005, 1, 1)), + ("2005-11", datetime(2005, 11, 1)), + ("2005 11", datetime(2005, 11, 1)), + ("11-2005", datetime(2005, 11, 1)), + ("11 2005", datetime(2005, 11, 1)), + ("200511", datetime(2020, 5, 11)), + ("20051109", datetime(2005, 11, 9)), + ("20051109 10:15", datetime(2005, 11, 9, 10, 15)), + ("20051109 08H", datetime(2005, 11, 9, 8, 0)), + ("2005-11-09 10:15", datetime(2005, 11, 9, 10, 15)), + ("2005-11-09 08H", datetime(2005, 11, 9, 8, 0)), + ("2005/11/09 10:15", datetime(2005, 11, 9, 10, 15)), + ("2005/11/09 10:15:32", datetime(2005, 11, 9, 10, 15, 32)), + ("2005/11/09 10:15:32 AM", datetime(2005, 11, 9, 10, 15, 32)), + ("2005/11/09 10:15:32 PM", datetime(2005, 11, 9, 22, 15, 32)), + ("2005/11/09 08H", datetime(2005, 11, 9, 8, 0)), + ("Thu Sep 25 10:36:28 2003", datetime(2003, 9, 25, 10, 36, 28)), + ("Thu Sep 25 2003", datetime(2003, 9, 25)), + ("Sep 25 2003", datetime(2003, 9, 25)), + ("January 1 2014", datetime(2014, 1, 1)), + # GH#10537 + ("2014-06", datetime(2014, 6, 1)), + ("06-2014", datetime(2014, 6, 1)), + ("2014-6", datetime(2014, 6, 1)), + ("6-2014", datetime(2014, 6, 1)), + ("20010101 12", datetime(2001, 1, 1, 12)), + ("20010101 1234", datetime(2001, 1, 1, 12, 34)), + ("20010101 123456", datetime(2001, 1, 1, 12, 34, 56)), + ], + ) + def test_parsers(self, date_str, expected, cache): + # dateutil >= 2.5.0 defaults to yearfirst=True + # https://github.com/dateutil/dateutil/issues/217 + yearfirst = True + + result1, reso_attrname = parsing.parse_datetime_string_with_reso( + date_str, yearfirst=yearfirst + ) + + reso = { + "nanosecond": "ns", + }.get(reso_attrname, "us") + result2 = to_datetime(date_str, yearfirst=yearfirst) + result3 = to_datetime([date_str], yearfirst=yearfirst) + # result5 is used below + result4 = to_datetime( + np.array([date_str], dtype=object), yearfirst=yearfirst, cache=cache + ) + result6 = DatetimeIndex([date_str], yearfirst=yearfirst) + # result7 is used below + result8 = DatetimeIndex(Index([date_str]), yearfirst=yearfirst) + result9 = DatetimeIndex(Series([date_str]), yearfirst=yearfirst) + + for res in [result1, result2]: + assert res == expected + for res in [result3, result4, result6, result8, result9]: + exp = DatetimeIndex([Timestamp(expected)]).as_unit(reso) + tm.assert_index_equal(res, exp) + + # these really need to have yearfirst, but we don't support + if not yearfirst: + result5 = Timestamp(date_str) + assert result5 == expected + result7 = date_range(date_str, freq="S", periods=1, yearfirst=yearfirst) + assert result7 == expected + + def test_na_values_with_cache( + self, cache, unique_nulls_fixture, unique_nulls_fixture2 + ): + # GH22305 + expected = Index([NaT, NaT], dtype="datetime64[s]") + result = to_datetime([unique_nulls_fixture, unique_nulls_fixture2], cache=cache) + tm.assert_index_equal(result, expected) + + def test_parsers_nat(self): + # Test that each of several string-accepting methods return pd.NaT + result1, _ = parsing.parse_datetime_string_with_reso("NaT") + result2 = to_datetime("NaT") + result3 = Timestamp("NaT") + result4 = DatetimeIndex(["NaT"])[0] + assert result1 is NaT + assert result2 is NaT + assert result3 is NaT + assert result4 is NaT + + @pytest.mark.parametrize( + "date_str, dayfirst, yearfirst, expected", + [ + ("10-11-12", False, False, datetime(2012, 10, 11)), + ("10-11-12", True, False, datetime(2012, 11, 10)), + ("10-11-12", False, True, datetime(2010, 11, 12)), + ("10-11-12", True, True, datetime(2010, 12, 11)), + ("20/12/21", False, False, datetime(2021, 12, 20)), + ("20/12/21", True, False, datetime(2021, 12, 20)), + ("20/12/21", False, True, datetime(2020, 12, 21)), + ("20/12/21", True, True, datetime(2020, 12, 21)), + # GH 58859 + ("20201012", True, False, datetime(2020, 12, 10)), + ], + ) + def test_parsers_dayfirst_yearfirst( + self, cache, date_str, dayfirst, yearfirst, expected + ): + # OK + # 2.5.1 10-11-12 [dayfirst=0, yearfirst=0] -> 2012-10-11 00:00:00 + # 2.5.2 10-11-12 [dayfirst=0, yearfirst=1] -> 2012-10-11 00:00:00 + # 2.5.3 10-11-12 [dayfirst=0, yearfirst=0] -> 2012-10-11 00:00:00 + + # OK + # 2.5.1 10-11-12 [dayfirst=0, yearfirst=1] -> 2010-11-12 00:00:00 + # 2.5.2 10-11-12 [dayfirst=0, yearfirst=1] -> 2010-11-12 00:00:00 + # 2.5.3 10-11-12 [dayfirst=0, yearfirst=1] -> 2010-11-12 00:00:00 + + # bug fix in 2.5.2 + # 2.5.1 10-11-12 [dayfirst=1, yearfirst=1] -> 2010-11-12 00:00:00 + # 2.5.2 10-11-12 [dayfirst=1, yearfirst=1] -> 2010-12-11 00:00:00 + # 2.5.3 10-11-12 [dayfirst=1, yearfirst=1] -> 2010-12-11 00:00:00 + + # OK + # 2.5.1 10-11-12 [dayfirst=1, yearfirst=0] -> 2012-11-10 00:00:00 + # 2.5.2 10-11-12 [dayfirst=1, yearfirst=0] -> 2012-11-10 00:00:00 + # 2.5.3 10-11-12 [dayfirst=1, yearfirst=0] -> 2012-11-10 00:00:00 + + # OK + # 2.5.1 20/12/21 [dayfirst=0, yearfirst=0] -> 2021-12-20 00:00:00 + # 2.5.2 20/12/21 [dayfirst=0, yearfirst=0] -> 2021-12-20 00:00:00 + # 2.5.3 20/12/21 [dayfirst=0, yearfirst=0] -> 2021-12-20 00:00:00 + + # OK + # 2.5.1 20/12/21 [dayfirst=0, yearfirst=1] -> 2020-12-21 00:00:00 + # 2.5.2 20/12/21 [dayfirst=0, yearfirst=1] -> 2020-12-21 00:00:00 + # 2.5.3 20/12/21 [dayfirst=0, yearfirst=1] -> 2020-12-21 00:00:00 + + # revert of bug in 2.5.2 + # 2.5.1 20/12/21 [dayfirst=1, yearfirst=1] -> 2020-12-21 00:00:00 + # 2.5.2 20/12/21 [dayfirst=1, yearfirst=1] -> month must be in 1..12 + # 2.5.3 20/12/21 [dayfirst=1, yearfirst=1] -> 2020-12-21 00:00:00 + + # OK + # 2.5.1 20/12/21 [dayfirst=1, yearfirst=0] -> 2021-12-20 00:00:00 + # 2.5.2 20/12/21 [dayfirst=1, yearfirst=0] -> 2021-12-20 00:00:00 + # 2.5.3 20/12/21 [dayfirst=1, yearfirst=0] -> 2021-12-20 00:00:00 + + # str : dayfirst, yearfirst, expected + + # compare with dateutil result + dateutil_result = parse(date_str, dayfirst=dayfirst, yearfirst=yearfirst) + assert dateutil_result == expected + + result1, _ = parsing.parse_datetime_string_with_reso( + date_str, dayfirst=dayfirst, yearfirst=yearfirst + ) + + # we don't support dayfirst/yearfirst here: + if not dayfirst and not yearfirst: + result2 = Timestamp(date_str) + assert result2 == expected + + result3 = to_datetime( + date_str, dayfirst=dayfirst, yearfirst=yearfirst, cache=cache + ) + + result4 = DatetimeIndex([date_str], dayfirst=dayfirst, yearfirst=yearfirst)[0] + + assert result1 == expected + assert result3 == expected + assert result4 == expected + + @pytest.mark.parametrize( + "date_str, exp_def", + [["10:15", datetime(1, 1, 1, 10, 15)], ["9:05", datetime(1, 1, 1, 9, 5)]], + ) + def test_parsers_timestring(self, date_str, exp_def): + # must be the same as dateutil result + exp_now = parse(date_str) + + result1, _ = parsing.parse_datetime_string_with_reso(date_str) + result2 = to_datetime(date_str) + result3 = to_datetime([date_str]) + result4 = Timestamp(date_str) + result5 = DatetimeIndex([date_str])[0] + # parse time string return time string based on default date + # others are not, and can't be changed because it is used in + # time series plot + assert result1 == exp_def + assert result2 == exp_now + assert result3 == exp_now + assert result4 == exp_now + assert result5 == exp_now + + @pytest.mark.parametrize( + "dt_string, tz, dt_string_repr", + [ + ( + "2013-01-01 05:45+0545", + timezone(timedelta(minutes=345)), + "Timestamp('2013-01-01 05:45:00+0545', tz='UTC+05:45')", + ), + ( + "2013-01-01 05:30+0530", + timezone(timedelta(minutes=330)), + "Timestamp('2013-01-01 05:30:00+0530', tz='UTC+05:30')", + ), + ], + ) + def test_parsers_timezone_minute_offsets_roundtrip( + self, cache, dt_string, tz, dt_string_repr + ): + # GH11708 + base = to_datetime("2013-01-01 00:00:00", cache=cache) + base = base.tz_localize("UTC").tz_convert(tz) + dt_time = to_datetime(dt_string, cache=cache) + assert base == dt_time + assert dt_string_repr == repr(dt_time) + + +@pytest.fixture(params=["D", "s", "ms", "us", "ns"]) +def units(request): + """Day and some time units. + + * D + * s + * ms + * us + * ns + """ + return request.param + + +@pytest.fixture +def julian_dates(): + return date_range("2014-1-1", periods=10).to_julian_date().values + + +class TestOrigin: + def test_origin_and_unit(self): + # GH#42624 + ts = to_datetime(1, unit="s", origin=1) + expected = Timestamp("1970-01-01 00:00:02") + assert ts == expected + + ts = to_datetime(1, unit="s", origin=1_000_000_000) + expected = Timestamp("2001-09-09 01:46:41") + assert ts == expected + + def test_julian(self, julian_dates): + # gh-11276, gh-11745 + # for origin as julian + + result = Series(to_datetime(julian_dates, unit="D", origin="julian")) + expected = Series( + to_datetime(julian_dates - Timestamp(0).to_julian_date(), unit="D") + ) + tm.assert_series_equal(result, expected) + + def test_unix(self): + result = Series(to_datetime([0, 1, 2], unit="D", origin="unix")) + expected = Series( + [Timestamp("1970-01-01"), Timestamp("1970-01-02"), Timestamp("1970-01-03")], + dtype="M8[s]", + ) + tm.assert_series_equal(result, expected) + + def test_julian_round_trip(self): + result = to_datetime(2456658, origin="julian", unit="D") + assert result.to_julian_date() == 2456658 + + # out-of-bounds + msg = "1 is Out of Bounds for origin='julian'" + with pytest.raises(ValueError, match=msg): + to_datetime(1, origin="julian", unit="D") + + def test_invalid_unit(self, units, julian_dates): + # checking for invalid combination of origin='julian' and unit != D + if units != "D": + msg = "unit must be 'D' for origin='julian'" + with pytest.raises(ValueError, match=msg): + to_datetime(julian_dates, unit=units, origin="julian") + + @pytest.mark.parametrize("unit", ["ns", "D"]) + def test_invalid_origin(self, unit): + # need to have a numeric specified + msg = "it must be numeric with a unit specified" + with pytest.raises(ValueError, match=msg): + to_datetime("2005-01-01", origin="1960-01-01", unit=unit) + + @pytest.mark.parametrize( + "epochs", + [ + Timestamp(1960, 1, 1), + datetime(1960, 1, 1), + "1960-01-01", + np.datetime64("1960-01-01"), + ], + ) + def test_epoch(self, units, epochs): + epoch_1960 = Timestamp(1960, 1, 1) + units_from_epochs = np.arange(5, dtype=np.int64) + exp_unit = "s" if units == "D" else units + expected = Series( + [pd.Timedelta(x, unit=units) + epoch_1960 for x in units_from_epochs], + dtype=f"M8[{exp_unit}]", + ) + + result = Series(to_datetime(units_from_epochs, unit=units, origin=epochs)) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "origin, exc", + [ + ("random_string", ValueError), + ("epoch", ValueError), + ("13-24-1990", ValueError), + (datetime(1, 1, 1), OutOfBoundsDatetime), + ], + ) + def test_invalid_origins(self, origin, exc, units): + msg = "|".join( + [ + f"origin {origin} is Out of Bounds", + f"origin {origin} cannot be converted to a Timestamp", + "Cannot cast .* to unit='ns' without overflow", + ] + ) + with pytest.raises(exc, match=msg): + to_datetime(list(range(5)), unit=units, origin=origin) + + def test_invalid_origins_tzinfo(self): + # GH16842 + with pytest.raises(ValueError, match="must be tz-naive"): + to_datetime(1, unit="D", origin=datetime(2000, 1, 1, tzinfo=timezone.utc)) + + def test_incorrect_value_exception(self): + # GH47495 + msg = "Unknown datetime string format, unable to parse: yesterday" + with pytest.raises(ValueError, match=msg): + to_datetime(["today", "yesterday"]) + + @pytest.mark.parametrize( + "format, warning", + [ + (None, UserWarning), + ("%Y-%m-%d %H:%M:%S", None), + ("%Y-%d-%m %H:%M:%S", None), + ], + ) + def test_to_datetime_out_of_bounds_with_format_arg(self, format, warning): + # see gh-23830 + if format is None: + res = to_datetime("2417-10-10 00:00:00.00", format=format) + assert isinstance(res, Timestamp) + assert res.year == 2417 + assert res.month == 10 + assert res.day == 10 + else: + msg = "unconverted data remains when parsing with format.*" + with pytest.raises(ValueError, match=msg): + to_datetime("2417-10-10 00:00:00.00", format=format) + + @pytest.mark.parametrize( + "arg, origin, expected_str", + [ + [200 * 365, "unix", "2169-11-13 00:00:00"], + [200 * 365, "1870-01-01", "2069-11-13 00:00:00"], + [300 * 365, "1870-01-01", "2169-10-20 00:00:00"], + ], + ) + def test_processing_order(self, arg, origin, expected_str): + # make sure we handle out-of-bounds *before* + # constructing the dates + + result = to_datetime(arg, unit="D", origin=origin) + expected = Timestamp(expected_str) + assert result == expected + + result = to_datetime(200 * 365, unit="D", origin="1870-01-01") + expected = Timestamp("2069-11-13 00:00:00") + assert result == expected + + result = to_datetime(300 * 365, unit="D", origin="1870-01-01") + expected = Timestamp("2169-10-20 00:00:00") + assert result == expected + + @pytest.mark.parametrize( + "offset,utc,exp", + [ + ["Z", True, "2019-01-01T00:00:00.000Z"], + ["Z", None, "2019-01-01T00:00:00.000Z"], + ["-01:00", True, "2019-01-01T01:00:00.000Z"], + ["-01:00", None, "2019-01-01T00:00:00.000-01:00"], + ], + ) + def test_arg_tz_ns_unit(self, offset, utc, exp): + # GH 25546 + arg = "2019-01-01T00:00:00.000" + offset + result = to_datetime([arg], unit="ns", utc=utc) + expected = to_datetime([exp]).as_unit("us") + tm.assert_index_equal(result, expected) + + +class TestShouldCache: + @pytest.mark.parametrize( + "listlike,do_caching", + [ + ([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], False), + ([1, 1, 1, 1, 4, 5, 6, 7, 8, 9], True), + ], + ) + def test_should_cache(self, listlike, do_caching): + assert ( + tools.should_cache(listlike, check_count=len(listlike), unique_share=0.7) + == do_caching + ) + + @pytest.mark.parametrize( + "unique_share,check_count, err_message", + [ + (0.5, 11, r"check_count must be in next bounds: \[0; len\(arg\)\]"), + (10, 2, r"unique_share must be in next bounds: \(0; 1\)"), + ], + ) + def test_should_cache_errors(self, unique_share, check_count, err_message): + arg = [5] * 10 + + with pytest.raises(AssertionError, match=err_message): + tools.should_cache(arg, unique_share, check_count) + + @pytest.mark.parametrize( + "listlike", + [ + (deque([Timestamp("2010-06-02 09:30:00")] * 51)), + ([Timestamp("2010-06-02 09:30:00")] * 51), + (tuple([Timestamp("2010-06-02 09:30:00")] * 51)), + ], + ) + def test_no_slicing_errors_in_should_cache(self, listlike): + # GH#29403 + assert tools.should_cache(listlike) is True + + +def test_nullable_integer_to_datetime(): + # Test for #30050 + ser = Series([1, 2, None, 2**61, None], dtype="Int64") + ser_copy = ser.copy() + + res = to_datetime(ser, unit="ns") + + expected = Series( + [ + np.datetime64("1970-01-01 00:00:00.000000001"), + np.datetime64("1970-01-01 00:00:00.000000002"), + np.datetime64("NaT"), + np.datetime64("2043-01-25 23:56:49.213693952"), + np.datetime64("NaT"), + ] + ) + tm.assert_series_equal(res, expected) + # Check that ser isn't mutated + tm.assert_series_equal(ser, ser_copy) + + +@pytest.mark.parametrize("klass", [np.array, list]) +def test_na_to_datetime(nulls_fixture, klass): + if isinstance(nulls_fixture, Decimal): + with pytest.raises(TypeError, match="not convertible to datetime"): + to_datetime(klass([nulls_fixture])) + + else: + result = to_datetime(klass([nulls_fixture])) + + assert result[0] is NaT + + +@pytest.mark.parametrize("errors", ["raise", "coerce"]) +@pytest.mark.parametrize( + "args, format", + [ + (["03/24/2016", "03/25/2016", ""], "%m/%d/%Y"), + (["2016-03-24", "2016-03-25", ""], "%Y-%m-%d"), + ], + ids=["non-ISO8601", "ISO8601"], +) +def test_empty_string_datetime(errors, args, format): + # GH13044, GH50251 + td = Series(args) + + # coerce empty string to pd.NaT + result = to_datetime(td, format=format, errors=errors) + expected = Series(["2016-03-24", "2016-03-25", NaT], dtype="datetime64[us]") + tm.assert_series_equal(expected, result) + + +def test_empty_string_datetime_coerce__unit(): + # GH13044 + # coerce empty string to pd.NaT + result = to_datetime([1, ""], unit="s", errors="coerce") + expected = DatetimeIndex(["1970-01-01 00:00:01", "NaT"], dtype="datetime64[s]") + tm.assert_index_equal(expected, result) + + # verify that no exception is raised even when errors='raise' is set + result = to_datetime([1, ""], unit="s", errors="raise") + tm.assert_index_equal(expected, result) + + +def test_to_datetime_monotonic_increasing_index(cache): + # GH28238 + cstart = start_caching_at + times = date_range(Timestamp("1980"), periods=cstart, freq="YS") + times = times.to_frame(index=False, name="DT").sample(n=cstart, random_state=1) + times.index = times.index.to_series().astype(float) / 1000 + result = to_datetime(times.iloc[:, 0], cache=cache) + expected = times.iloc[:, 0] + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "series_length", + [40, start_caching_at, (start_caching_at + 1), (start_caching_at + 5)], +) +def test_to_datetime_cache_coerce_50_lines_outofbounds(series_length): + # GH#45319 + ser = Series( + [datetime.fromisoformat("1446-04-12 00:00:00+00:00")] + + ([datetime.fromisoformat("1991-10-20 00:00:00+00:00")] * series_length), + dtype=object, + ) + result1 = to_datetime(ser, errors="coerce", utc=True) + + expected1 = Series([Timestamp(x) for x in ser]) + assert expected1.dtype == "M8[us, UTC]" + tm.assert_series_equal(result1, expected1) + + result3 = to_datetime(ser, errors="raise", utc=True) + tm.assert_series_equal(result3, expected1) + + +def test_to_datetime_format_f_parse_nanos(): + # GH 48767 + timestamp = "15/02/2020 02:03:04.123456789" + timestamp_format = "%d/%m/%Y %H:%M:%S.%f" + result = to_datetime(timestamp, format=timestamp_format) + expected = Timestamp( + year=2020, + month=2, + day=15, + hour=2, + minute=3, + second=4, + microsecond=123456, + nanosecond=789, + ) + assert result == expected + + +def test_to_datetime_mixed_iso8601(): + # https://github.com/pandas-dev/pandas/issues/50411 + result = to_datetime(["2020-01-01", "2020-01-01 05:00:00"], format="ISO8601") + expected = DatetimeIndex(["2020-01-01 00:00:00", "2020-01-01 05:00:00"]) + tm.assert_index_equal(result, expected) + + +def test_to_datetime_mixed_other(): + # https://github.com/pandas-dev/pandas/issues/50411 + result = to_datetime(["01/11/2000", "12 January 2000"], format="mixed") + expected = DatetimeIndex(["2000-01-11", "2000-01-12"]) + tm.assert_index_equal(result, expected) + + +@pytest.mark.parametrize("exact", [True, False]) +@pytest.mark.parametrize("format", ["ISO8601", "mixed"]) +def test_to_datetime_mixed_or_iso_exact(exact, format): + msg = "Cannot use 'exact' when 'format' is 'mixed' or 'ISO8601'" + with pytest.raises(ValueError, match=msg): + to_datetime(["2020-01-01"], exact=exact, format=format) + + +def test_to_datetime_mixed_not_necessarily_iso8601_raise(): + # https://github.com/pandas-dev/pandas/issues/50411 + with pytest.raises(ValueError, match="Time data 01-01-2000 is not ISO8601 format"): + to_datetime(["2020-01-01", "01-01-2000"], format="ISO8601") + + +def test_to_datetime_mixed_not_necessarily_iso8601_coerce(): + # https://github.com/pandas-dev/pandas/issues/50411 + result = to_datetime( + ["2020-01-01", "01-01-2000"], format="ISO8601", errors="coerce" + ) + tm.assert_index_equal(result, DatetimeIndex(["2020-01-01 00:00:00", NaT])) + + +def test_to_datetime_iso8601_utc_single_naive(): + # GH#61389 + result = to_datetime("2023-10-15T14:30:00", utc=True, format="ISO8601") + expected = Timestamp("2023-10-15 14:30:00+00:00") + assert result == expected + + +def test_to_datetime_iso8601_utc_mixed_negative_offset(): + # GH#61389 + data = ["2023-10-15T10:30:00-12:00", "2023-10-15T14:30:00"] + result = to_datetime(data, utc=True, format="ISO8601") + + expected = DatetimeIndex( + [Timestamp("2023-10-15 22:30:00+00:00"), Timestamp("2023-10-15 14:30:00+00:00")] + ) + tm.assert_index_equal(result, expected) + + +def test_to_datetime_iso8601_utc_mixed_positive_offset(): + # GH#61389 + data = ["2023-10-15T10:30:00+08:00", "2023-10-15T14:30:00"] + result = to_datetime(data, utc=True, format="ISO8601") + + expected = DatetimeIndex( + [Timestamp("2023-10-15 02:30:00+00:00"), Timestamp("2023-10-15 14:30:00+00:00")] + ) + tm.assert_index_equal(result, expected) + + +def test_to_datetime_iso8601_utc_mixed_both_offsets(): + # GH#61389 + data = [ + "2023-10-15T10:30:00+08:00", + "2023-10-15T12:30:00-05:00", + "2023-10-15T14:30:00", + ] + result = to_datetime(data, utc=True, format="ISO8601") + + expected = DatetimeIndex( + [ + Timestamp("2023-10-15 02:30:00+00:00"), + Timestamp("2023-10-15 17:30:00+00:00"), + Timestamp("2023-10-15 14:30:00+00:00"), + ] + ) + tm.assert_index_equal(result, expected) + + +def test_unknown_tz_raises(): + # GH#18702, GH#51476 + dtstr = "2014 Jan 9 05:15 FAKE" + msg = '.*un-recognized timezone "FAKE".' + with pytest.raises(ValueError, match=msg): + Timestamp(dtstr) + + with pytest.raises(ValueError, match=msg): + to_datetime(dtstr) + with pytest.raises(ValueError, match=msg): + to_datetime([dtstr]) + + +def test_unformatted_input_raises(): + valid, invalid = "2024-01-01", "N" + ser = Series([valid] * start_caching_at + [invalid]) + msg = 'time data "N" doesn\'t match format "%Y-%m-%d"' + + with pytest.raises(ValueError, match=msg): + to_datetime(ser, format="%Y-%m-%d", exact=True, cache=True) + + +def test_from_numeric_arrow_dtype(any_numeric_ea_dtype): + # GH 52425 + pytest.importorskip("pyarrow") + ser = Series([1, 2], dtype=f"{any_numeric_ea_dtype.lower()}[pyarrow]") + result = to_datetime(ser) + expected = Series([1, 2], dtype="datetime64[ns]") + tm.assert_series_equal(result, expected) + + +def test_to_datetime_with_empty_str_utc_false_format_mixed(): + # GH 50887 + vals = ["2020-01-01 00:00+00:00", ""] + result = to_datetime(vals, format="mixed") + expected = Index([Timestamp("2020-01-01 00:00+00:00"), "NaT"], dtype="M8[us, UTC]") + tm.assert_index_equal(result, expected) + + # Check that a couple of other similar paths work the same way + alt = to_datetime(vals) + tm.assert_index_equal(alt, expected) + alt2 = DatetimeIndex(vals) + tm.assert_index_equal(alt2, expected) + + +def test_to_datetime_with_empty_str_utc_false_offsets_and_format_mixed(): + # GH#50887, GH#57275 + msg = "Mixed timezones detected. Pass utc=True in to_datetime" + + with pytest.raises(ValueError, match=msg): + to_datetime( + ["2020-01-01 00:00+00:00", "2020-01-01 00:00+02:00", ""], format="mixed" + ) + + +def test_to_datetime_mixed_tzs_mixed_types(): + # GH#55793, GH#55693 mismatched tzs but one is str and other is + # datetime object + ts = Timestamp("2016-01-02 03:04:05", tz="US/Pacific") + dtstr = "2023-10-30 15:06+01" + arr = [ts, dtstr] + + msg = ( + "Mixed timezones detected. Pass utc=True in to_datetime or tz='UTC' " + "in DatetimeIndex to convert to a common timezone" + ) + with pytest.raises(ValueError, match=msg): + to_datetime(arr) + with pytest.raises(ValueError, match=msg): + to_datetime(arr, format="mixed") + with pytest.raises(ValueError, match=msg): + DatetimeIndex(arr) + + +def test_to_datetime_mixed_types_matching_tzs(): + # GH#55793 + dtstr = "2023-11-01 09:22:03-07:00" + ts = Timestamp(dtstr) + arr = [ts, dtstr] + res1 = to_datetime(arr) + res2 = to_datetime(arr[::-1])[::-1] + res3 = to_datetime(arr, format="mixed") + res4 = DatetimeIndex(arr) + + expected = DatetimeIndex([ts, ts]) + tm.assert_index_equal(res1, expected) + tm.assert_index_equal(res2, expected) + tm.assert_index_equal(res3, expected) + tm.assert_index_equal(res4, expected) + + +dtstr = "2020-01-01 00:00+00:00" +ts = Timestamp(dtstr) + + +@pytest.mark.filterwarnings("ignore:Could not infer format:UserWarning") +@pytest.mark.parametrize( + "aware_val", + [dtstr, Timestamp(dtstr)], + ids=lambda x: type(x).__name__, +) +@pytest.mark.parametrize( + "naive_val", + [dtstr[:-6], ts.tz_localize(None), ts.date(), ts.asm8, ts.value, float(ts.value)], + ids=lambda x: type(x).__name__, +) +@pytest.mark.parametrize("naive_first", [True, False]) +def test_to_datetime_mixed_awareness_mixed_types(aware_val, naive_val, naive_first): + # GH#55793, GH#55693, GH#57275 + # Empty string parses to NaT + vals = [aware_val, naive_val, ""] + + vec = vals + if naive_first: + # alas, the behavior is order-dependent, so we test both ways + vec = [naive_val, aware_val, ""] + + # both_strs-> paths that were previously already deprecated with warning + # issued in _array_to_datetime_object + both_strs = isinstance(aware_val, str) and isinstance(naive_val, str) + has_numeric = isinstance(naive_val, (int, float)) + both_datetime = isinstance(naive_val, datetime) and isinstance(aware_val, datetime) + + mixed_msg = ( + "Mixed timezones detected. Pass utc=True in to_datetime or tz='UTC' " + "in DatetimeIndex to convert to a common timezone" + ) + + first_non_null = next(x for x in vec if x != "") + # if first_non_null is a not a string, _guess_datetime_format_for_array + # doesn't guess a format so we don't go through array_strptime + if not isinstance(first_non_null, str): + # that case goes through array_strptime which has different behavior + msg = mixed_msg + if naive_first and isinstance(aware_val, Timestamp): + if isinstance(naive_val, Timestamp): + msg = "Tz-aware datetime.datetime cannot be converted to datetime64" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + else: + if not naive_first and both_datetime: + msg = "Cannot mix tz-aware with tz-naive values" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + + # No warning/error with utc=True + to_datetime(vec, utc=True) + + elif has_numeric and vec.index(aware_val) < vec.index(naive_val): + msg = "time data .* doesn't match format" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + with pytest.raises(ValueError, match=msg): + to_datetime(vec, utc=True) + + elif both_strs and vec.index(aware_val) < vec.index(naive_val): + msg = r"time data \"2020-01-01 00:00\" doesn't match format" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + with pytest.raises(ValueError, match=msg): + to_datetime(vec, utc=True) + + elif both_strs and vec.index(naive_val) < vec.index(aware_val): + msg = "unconverted data remains when parsing with format" + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + with pytest.raises(ValueError, match=msg): + to_datetime(vec, utc=True) + + else: + msg = mixed_msg + with pytest.raises(ValueError, match=msg): + to_datetime(vec) + + # No warning/error with utc=True + to_datetime(vec, utc=True) + + if both_strs: + msg = mixed_msg + with pytest.raises(ValueError, match=msg): + to_datetime(vec, format="mixed") + with pytest.raises(ValueError, match=msg): + DatetimeIndex(vec) + else: + msg = mixed_msg + if naive_first and isinstance(aware_val, Timestamp): + if isinstance(naive_val, Timestamp): + msg = "Tz-aware datetime.datetime cannot be converted to datetime64" + with pytest.raises(ValueError, match=msg): + to_datetime(vec, format="mixed") + with pytest.raises(ValueError, match=msg): + DatetimeIndex(vec) + else: + if not naive_first and both_datetime: + msg = "Cannot mix tz-aware with tz-naive values" + with pytest.raises(ValueError, match=msg): + to_datetime(vec, format="mixed") + with pytest.raises(ValueError, match=msg): + DatetimeIndex(vec) + + +def test_to_datetime_wrapped_datetime64_ps(): + # GH#60341 + result = to_datetime([np.datetime64(1901901901901, "ps")]) + expected = DatetimeIndex( + ["1970-01-01 00:00:01.901901901"], dtype="datetime64[ns]", freq=None + ) + tm.assert_index_equal(result, expected) + + +def test_to_datetime_lxml_elementunicoderesult_with_format(cache): + etree = pytest.importorskip("lxml.etree") + + s = "2025-02-05 16:59:57" + node = etree.XML(f"{s}") + val = node.xpath("/date/node()")[0] # _ElementUnicodeResult + + out = to_datetime(Series([val]), format="%Y-%m-%d %H:%M:%S", cache=cache) + assert out.iloc[0] == Timestamp(s) diff --git a/pandas/tests/tools/test_to_numeric.py b/pandas/tests/tools/test_to_numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbc91d4c632f76f064f6389279834c74a458019 --- /dev/null +++ b/pandas/tests/tools/test_to_numeric.py @@ -0,0 +1,904 @@ +import decimal + +import numpy as np +from numpy import iinfo +import pytest + +import pandas.util._test_decorators as td + +import pandas as pd +from pandas import ( + ArrowDtype, + DataFrame, + Index, + Series, + option_context, + to_numeric, +) +import pandas._testing as tm + + +@pytest.fixture(params=[None, "raise", "coerce"]) +def errors(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def signed(request): + return request.param + + +@pytest.fixture(params=[lambda x: x, str], ids=["identity", "str"]) +def transform(request): + return request.param + + +@pytest.fixture(params=[47393996303418497800, 100000000000000000000]) +def large_val(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def multiple_elts(request): + return request.param + + +@pytest.fixture( + params=[ + (lambda x: Index(x, name="idx"), tm.assert_index_equal), + (lambda x: Series(x, name="ser"), tm.assert_series_equal), + (lambda x: np.array(Index(x).values), tm.assert_numpy_array_equal), + ] +) +def transform_assert_equal(request): + return request.param + + +@pytest.mark.parametrize( + "input_kwargs,result_kwargs", + [ + ({}, {"dtype": np.int64}), + ({"errors": "coerce", "downcast": "integer"}, {"dtype": np.int8}), + ], +) +def test_empty(input_kwargs, result_kwargs): + # see gh-16302 + ser = Series([], dtype=object) + result = to_numeric(ser, **input_kwargs) + + expected = Series([], **result_kwargs) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))] +) +@pytest.mark.parametrize("last_val", ["7", 7]) +def test_series(last_val, infer_string): + with option_context("future.infer_string", infer_string): + ser = Series(["1", "-3.14", last_val]) + result = to_numeric(ser) + + expected = Series([1, -3.14, 7]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "data", + [ + [1, 3, 4, 5], + [1.0, 3.0, 4.0, 5.0], + # Bool is regarded as numeric. + [True, False, True, True], + ], +) +def test_series_numeric(data): + ser = Series(data, index=list("ABCD"), name="EFG") + + result = to_numeric(ser) + tm.assert_series_equal(result, ser) + + +@pytest.mark.parametrize( + "data,msg", + [ + ([1, -3.14, "apple"], 'Unable to parse string "apple" at position 2'), + ( + ["orange", 1, -3.14, "apple"], + 'Unable to parse string "orange" at position 0', + ), + ], +) +def test_error(data, msg): + ser = Series(data) + + with pytest.raises(ValueError, match=msg): + to_numeric(ser, errors="raise") + + +def test_ignore_error(): + ser = Series([1, -3.14, "apple"]) + result = to_numeric(ser, errors="coerce") + + expected = Series([1, -3.14, np.nan]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "errors,exp", + [ + ("raise", 'Unable to parse string "apple" at position 2'), + # Coerces to float. + ("coerce", [1.0, 0.0, np.nan]), + ], +) +def test_bool_handling(errors, exp): + ser = Series([True, False, "apple"]) + + if isinstance(exp, str): + with pytest.raises(ValueError, match=exp): + to_numeric(ser, errors=errors) + else: + result = to_numeric(ser, errors=errors) + expected = Series(exp) + + tm.assert_series_equal(result, expected) + + +def test_list(): + ser = ["1", "-3.14", "7"] + res = to_numeric(ser) + + expected = np.array([1, -3.14, 7]) + tm.assert_numpy_array_equal(res, expected) + + +@pytest.mark.parametrize( + "data,arr_kwargs", + [ + ([1, 3, 4, 5], {"dtype": np.int64}), + ([1.0, 3.0, 4.0, 5.0], {}), + # Boolean is regarded as numeric. + ([True, False, True, True], {}), + ], +) +def test_list_numeric(data, arr_kwargs): + result = to_numeric(data) + expected = np.array(data, **arr_kwargs) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("kwargs", [{"dtype": "O"}, {}]) +def test_numeric(kwargs): + data = [1, -3.14, 7] + + ser = Series(data, **kwargs) + result = to_numeric(ser) + + expected = Series(data) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "columns", + [ + # One column. + "a", + # Multiple columns. + ["a", "b"], + ], +) +def test_numeric_df_columns(columns): + # see gh-14827 + df = DataFrame( + { + "a": [1.2, decimal.Decimal("3.14"), decimal.Decimal("infinity"), "0.1"], + "b": [1.0, 2.0, 3.0, 4.0], + } + ) + + expected = DataFrame({"a": [1.2, 3.14, np.inf, 0.1], "b": [1.0, 2.0, 3.0, 4.0]}) + df[columns] = df[columns].apply(to_numeric) + + tm.assert_frame_equal(df, expected) + + +@pytest.mark.parametrize( + "data,exp_data", + [ + ( + [[decimal.Decimal("3.14"), 1.0], decimal.Decimal("1.6"), 0.1], + [[3.14, 1.0], 1.6, 0.1], + ), + ([np.array([decimal.Decimal("3.14"), 1.0]), 0.1], [[3.14, 1.0], 0.1]), + ], +) +def test_numeric_embedded_arr_likes(data, exp_data): + # Test to_numeric with embedded lists and arrays + df = DataFrame({"a": data}) + df["a"] = df["a"].apply(to_numeric) + + expected = DataFrame({"a": exp_data}) + tm.assert_frame_equal(df, expected) + + +def test_all_nan(): + ser = Series(["a", "b", "c"]) + result = to_numeric(ser, errors="coerce") + + expected = Series([np.nan, np.nan, np.nan]) + tm.assert_series_equal(result, expected) + + +def test_type_check(errors): + # see gh-11776 + df = DataFrame({"a": [1, -3.14, 7], "b": ["4", "5", "6"]}) + kwargs = {"errors": errors} if errors is not None else {} + with pytest.raises(TypeError, match="1-d array"): + to_numeric(df, **kwargs) + + +@pytest.mark.parametrize("val", [1, 1.1, 20001]) +def test_scalar(val, signed, transform): + val = -val if signed else val + assert to_numeric(transform(val)) == float(val) + + +def test_really_large_scalar(large_val, signed, transform, errors): + # see gh-24910 + kwargs = {"errors": errors} if errors is not None else {} + val = -large_val if signed else large_val + + val = transform(val) + + expected = float(val) if errors == "coerce" else int(val) + tm.assert_almost_equal(to_numeric(val, **kwargs), expected) + + +def test_really_large_in_arr(large_val, signed, transform, multiple_elts, errors): + # see gh-24910 + kwargs = {"errors": errors} if errors is not None else {} + val = -large_val if signed else large_val + val = transform(val) + + extra_elt = "string" + arr = [val] + multiple_elts * [extra_elt] + + coercing = errors == "coerce" + + if errors in (None, "raise") and multiple_elts: + msg = 'Unable to parse string "string" at position 1' + + with pytest.raises(ValueError, match=msg): + to_numeric(arr, **kwargs) + else: + result = to_numeric(arr, **kwargs) + + exp_val = float(val) if (coercing) else int(val) + expected = [exp_val] + + if multiple_elts: + if coercing: + expected.append(np.nan) + exp_dtype = float + else: + expected.append(extra_elt) + exp_dtype = object + else: + exp_dtype = float if isinstance(exp_val, float) else object + + tm.assert_almost_equal(result, np.array(expected, dtype=exp_dtype)) + + +def test_really_large_in_arr_consistent(large_val, signed, multiple_elts, errors): + # see gh-24910 + # + # Even if we discover that we have to hold float, does not mean + # we should be lenient on subsequent elements that fail to be integer. + kwargs = {"errors": errors} if errors is not None else {} + arr = [str(-large_val if signed else large_val)] + + if multiple_elts: + arr.insert(0, large_val) + + result = to_numeric(arr, **kwargs) + expected = [float(i) if errors == "coerce" else int(i) for i in arr] + exp_dtype = float if errors == "coerce" else object + + tm.assert_almost_equal(result, np.array(expected, dtype=exp_dtype)) + + +@pytest.mark.parametrize( + "errors,checker", + [ + ("raise", 'Unable to parse string "fail" at position 0'), + ("coerce", lambda x: np.isnan(x)), + ], +) +def test_scalar_fail(errors, checker): + scalar = "fail" + + if isinstance(checker, str): + with pytest.raises(ValueError, match=checker): + to_numeric(scalar, errors=errors) + else: + assert checker(to_numeric(scalar, errors=errors)) + + +@pytest.mark.parametrize("data", [[1, 2, 3], [1.0, np.nan, 3, np.nan]]) +def test_numeric_dtypes(data, transform_assert_equal): + transform, assert_equal = transform_assert_equal + data = transform(data) + + result = to_numeric(data) + assert_equal(result, data) + + +@pytest.mark.parametrize( + "data,exp", + [ + (["1", "2", "3"], np.array([1, 2, 3], dtype="int64")), + (["1.5", "2.7", "3.4"], np.array([1.5, 2.7, 3.4])), + ], +) +def test_str(data, exp, transform_assert_equal): + transform, assert_equal = transform_assert_equal + result = to_numeric(transform(data)) + + expected = transform(exp) + assert_equal(result, expected) + + +def test_datetime_like(tz_naive_fixture, transform_assert_equal): + transform, assert_equal = transform_assert_equal + idx = pd.date_range("20130101", periods=3, tz=tz_naive_fixture) + + result = to_numeric(transform(idx)) + expected = transform(idx.asi8) + assert_equal(result, expected) + + +def test_timedelta(transform_assert_equal): + transform, assert_equal = transform_assert_equal + idx = pd.timedelta_range("1 days", periods=3, freq="D") + + result = to_numeric(transform(idx)) + expected = transform(idx.asi8) + assert_equal(result, expected) + + +@pytest.mark.parametrize( + "scalar", + [ + pd.Timedelta(1, "D"), + pd.Timestamp("2017-01-01T12"), + pd.Timestamp("2017-01-01T12", tz="US/Pacific"), + ], +) +def test_timedelta_timestamp_scalar(scalar): + # GH#59944 + result = to_numeric(scalar) + expected = to_numeric(Series(scalar))[0] + assert result == expected + + +def test_period(request, transform_assert_equal): + transform, assert_equal = transform_assert_equal + + idx = pd.period_range("2011-01", periods=3, freq="M", name="") + inp = transform(idx) + + if not isinstance(inp, Index): + request.applymarker( + pytest.mark.xfail(reason="Missing PeriodDtype support in to_numeric") + ) + result = to_numeric(inp) + expected = transform(idx.asi8) + assert_equal(result, expected) + + +@pytest.mark.parametrize( + "errors,expected", + [ + ("raise", "Invalid object type at position 0"), + ("coerce", Series([np.nan, 1.0, np.nan])), + ], +) +def test_non_hashable(errors, expected): + # see gh-13324 + ser = Series([[10.0, 2], 1.0, "apple"]) + + if isinstance(expected, str): + with pytest.raises(TypeError, match=expected): + to_numeric(ser, errors=errors) + else: + result = to_numeric(ser, errors=errors) + tm.assert_series_equal(result, expected) + + +def test_downcast_invalid_cast(): + # see gh-13352 + data = ["1", 2, 3] + invalid_downcast = "unsigned-integer" + msg = "invalid downcasting method provided" + + with pytest.raises(ValueError, match=msg): + to_numeric(data, downcast=invalid_downcast) + + +def test_errors_invalid_value(): + # see gh-26466 + data = ["1", 2, 3] + invalid_error_value = "invalid" + msg = "invalid error value specified" + + with pytest.raises(ValueError, match=msg): + to_numeric(data, errors=invalid_error_value) + + +@pytest.mark.parametrize( + "data", + [ + ["1", 2, 3], + [1, 2, 3], + np.array(["1970-01-02", "1970-01-03", "1970-01-04"], dtype="datetime64[D]"), + ], +) +@pytest.mark.parametrize( + "kwargs,exp_dtype", + [ + # Basic function tests. + ({}, np.int64), + ({"downcast": None}, np.int64), + # Support below np.float32 is rare and far between. + ({"downcast": "float"}, np.dtype(np.float32).char), + # Basic dtype support. + ({"downcast": "unsigned"}, np.dtype(np.typecodes["UnsignedInteger"][0])), + ], +) +def test_downcast_basic(data, kwargs, exp_dtype): + # see gh-13352 + result = to_numeric(data, **kwargs) + expected = np.array([1, 2, 3], dtype=exp_dtype) + tm.assert_numpy_array_equal(result, expected) + + +@pytest.mark.parametrize("signed_downcast", ["integer", "signed"]) +@pytest.mark.parametrize( + "data", + [ + ["1", 2, 3], + [1, 2, 3], + np.array(["1970-01-02", "1970-01-03", "1970-01-04"], dtype="datetime64[D]"), + ], +) +def test_signed_downcast(data, signed_downcast): + # see gh-13352 + smallest_int_dtype = np.dtype(np.typecodes["Integer"][0]) + expected = np.array([1, 2, 3], dtype=smallest_int_dtype) + + res = to_numeric(data, downcast=signed_downcast) + tm.assert_numpy_array_equal(res, expected) + + +def test_ignore_downcast_neg_to_unsigned(): + # Cannot cast to an unsigned integer + # because we have a negative number. + data = ["-1", 2, 3] + expected = np.array([-1, 2, 3], dtype=np.int64) + + res = to_numeric(data, downcast="unsigned") + tm.assert_numpy_array_equal(res, expected) + + +# Warning in 32 bit platforms +@pytest.mark.parametrize("downcast", ["integer", "signed", "unsigned"]) +@pytest.mark.parametrize( + "data,expected", + [ + (["1.1", 2, 3], np.array([1.1, 2, 3], dtype=np.float64)), + ( + [10000.0, 20000, 3000, 40000.36, 50000, 50000.00], + np.array( + [10000.0, 20000, 3000, 40000.36, 50000, 50000.00], dtype=np.float64 + ), + ), + ], +) +def test_ignore_downcast_cannot_convert_float(data, expected, downcast): + # Cannot cast to an integer (signed or unsigned) + # because we have a float number. + res = to_numeric(data, downcast=downcast) + tm.assert_numpy_array_equal(res, expected) + + +@pytest.mark.parametrize( + "downcast,expected_dtype", + [("integer", np.int16), ("signed", np.int16), ("unsigned", np.uint16)], +) +def test_downcast_not8bit(downcast, expected_dtype): + # the smallest integer dtype need not be np.(u)int8 + data = ["256", 257, 258] + + expected = np.array([256, 257, 258], dtype=expected_dtype) + res = to_numeric(data, downcast=downcast) + tm.assert_numpy_array_equal(res, expected) + + +@pytest.mark.parametrize( + "dtype,downcast,min_max", + [ + ("int8", "integer", [iinfo(np.int8).min, iinfo(np.int8).max]), + ("int16", "integer", [iinfo(np.int16).min, iinfo(np.int16).max]), + ("int32", "integer", [iinfo(np.int32).min, iinfo(np.int32).max]), + ("int64", "integer", [iinfo(np.int64).min, iinfo(np.int64).max]), + ("uint8", "unsigned", [iinfo(np.uint8).min, iinfo(np.uint8).max]), + ("uint16", "unsigned", [iinfo(np.uint16).min, iinfo(np.uint16).max]), + ("uint32", "unsigned", [iinfo(np.uint32).min, iinfo(np.uint32).max]), + ("uint64", "unsigned", [iinfo(np.uint64).min, iinfo(np.uint64).max]), + ("int16", "integer", [iinfo(np.int8).min, iinfo(np.int8).max + 1]), + ("int32", "integer", [iinfo(np.int16).min, iinfo(np.int16).max + 1]), + ("int64", "integer", [iinfo(np.int32).min, iinfo(np.int32).max + 1]), + ("int16", "integer", [iinfo(np.int8).min - 1, iinfo(np.int16).max]), + ("int32", "integer", [iinfo(np.int16).min - 1, iinfo(np.int32).max]), + ("int64", "integer", [iinfo(np.int32).min - 1, iinfo(np.int64).max]), + ("uint16", "unsigned", [iinfo(np.uint8).min, iinfo(np.uint8).max + 1]), + ("uint32", "unsigned", [iinfo(np.uint16).min, iinfo(np.uint16).max + 1]), + ("uint64", "unsigned", [iinfo(np.uint32).min, iinfo(np.uint32).max + 1]), + ], +) +def test_downcast_limits(dtype, downcast, min_max): + # see gh-14404: test the limits of each downcast. + series = to_numeric(Series(min_max), downcast=downcast) + assert series.dtype == dtype + + +def test_downcast_float64_to_float32(): + # GH-43693: Check float64 preservation when >= 16,777,217 + series = Series([16777217.0, np.finfo(np.float64).max, np.nan], dtype=np.float64) + result = to_numeric(series, downcast="float") + + assert series.dtype == result.dtype + + +def test_downcast_uint64(): + # see gh-14422: + # BUG: to_numeric doesn't work uint64 numbers + ser = Series([0, 9223372036854775808]) + result = to_numeric(ser, downcast="unsigned") + expected = Series([0, 9223372036854775808], dtype=np.uint64) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "data,exp_data", + [ + ( + [200, 300, "", "NaN", 30000000000000000000], + [200, 300, np.nan, np.nan, 30000000000000000000], + ), + ( + ["12345678901234567890", "1234567890", "ITEM"], + [12345678901234567890, 1234567890, np.nan], + ), + ], +) +def test_coerce_uint64_conflict(data, exp_data): + # see gh-17007 and gh-17125 + # + # Still returns float despite the uint64-nan conflict, + # which would normally force the casting to object. + result = to_numeric(Series(data), errors="coerce") + expected = Series(exp_data, dtype=float) + tm.assert_series_equal(result, expected) + + +def test_non_coerce_uint64_conflict(): + # see gh-17007 and gh-17125 + # + # For completeness. + ser = Series(["12345678901234567890", "1234567890", "ITEM"]) + + with pytest.raises(ValueError, match="Unable to parse string"): + to_numeric(ser, errors="raise") + + +@pytest.mark.parametrize("dc1", ["integer", "float", "unsigned"]) +@pytest.mark.parametrize("dc2", ["integer", "float", "unsigned"]) +def test_downcast_empty(dc1, dc2): + # GH32493 + + tm.assert_numpy_array_equal( + to_numeric([], downcast=dc1), + to_numeric([], downcast=dc2), + check_dtype=False, + ) + + +def test_failure_to_convert_uint64_string_to_NaN(): + # GH 32394 + result = to_numeric("uint64", errors="coerce") + assert np.isnan(result) + + ser = Series([32, 64, np.nan]) + result = to_numeric(Series(["32", "64", "uint64"]), errors="coerce") + tm.assert_series_equal(result, ser) + + +@pytest.mark.parametrize( + "strrep", + [ + "243.164", + "245.968", + "249.585", + "259.745", + "265.742", + "272.567", + "279.196", + "280.366", + "275.034", + "271.351", + "272.889", + "270.627", + "280.828", + "290.383", + "308.153", + "319.945", + "336.0", + "344.09", + "351.385", + "356.178", + "359.82", + "361.03", + "367.701", + "380.812", + "387.98", + "391.749", + "391.171", + "385.97", + "385.345", + "386.121", + "390.996", + "399.734", + "413.073", + "421.532", + "430.221", + "437.092", + "439.746", + "446.01", + "451.191", + "460.463", + "469.779", + "472.025", + "479.49", + "474.864", + "467.54", + "471.978", + ], +) +def test_precision_float_conversion(strrep): + # GH 31364 + result = to_numeric(strrep) + + assert result == float(strrep) + + +@pytest.mark.parametrize( + "values, expected", + [ + (["1", "2", None], Series([1, 2, pd.NA], dtype="Int64")), + (["1", "2", "3"], Series([1, 2, 3], dtype="Int64")), + (["1", "2", 3], Series([1, 2, 3], dtype="Int64")), + (["1", "2", 3.5], Series([1, 2, 3.5], dtype="Float64")), + (["1", None, 3.5], Series([1, pd.NA, 3.5], dtype="Float64")), + (["1", "2", "3.5"], Series([1, 2, 3.5], dtype="Float64")), + ], +) +def test_to_numeric_from_nullable_string(values, nullable_string_dtype, expected): + # https://github.com/pandas-dev/pandas/issues/37262 + s = Series(values, dtype=nullable_string_dtype) + result = to_numeric(s) + tm.assert_series_equal(result, expected) + + +def test_to_numeric_from_nullable_string_coerce(nullable_string_dtype): + # GH#52146 + values = ["a", "1"] + ser = Series(values, dtype=nullable_string_dtype) + result = to_numeric(ser, errors="coerce") + expected = Series([pd.NA, 1], dtype="Int64") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "data, input_dtype, downcast, expected_dtype", + ( + ([1, 1], "Int64", "integer", "Int8"), + ([1.0, pd.NA], "Float64", "integer", "Int8"), + ([1.0, 1.1], "Float64", "integer", "Float64"), + ([1, pd.NA], "Int64", "integer", "Int8"), + ([450, 300], "Int64", "integer", "Int16"), + ([1, 1], "Float64", "integer", "Int8"), + ([np.iinfo(np.int64).max - 1, 1], "Int64", "integer", "Int64"), + ([1, 1], "Int64", "signed", "Int8"), + ([1.0, 1.0], "Float32", "signed", "Int8"), + ([1.0, 1.1], "Float64", "signed", "Float64"), + ([1, pd.NA], "Int64", "signed", "Int8"), + ([450, -300], "Int64", "signed", "Int16"), + ([np.iinfo(np.uint64).max - 1, 1], "UInt64", "signed", "UInt64"), + ([1, 1], "Int64", "unsigned", "UInt8"), + ([1.0, 1.0], "Float32", "unsigned", "UInt8"), + ([1.0, 1.1], "Float64", "unsigned", "Float64"), + ([1, pd.NA], "Int64", "unsigned", "UInt8"), + ([450, -300], "Int64", "unsigned", "Int64"), + ([-1, -1], "Int32", "unsigned", "Int32"), + ([1, 1], "Float64", "float", "Float32"), + ([1, 1.1], "Float64", "float", "Float32"), + ([1, 1], "Float32", "float", "Float32"), + ([1, 1.1], "Float32", "float", "Float32"), + ), +) +def test_downcast_nullable_numeric(data, input_dtype, downcast, expected_dtype): + arr = pd.array(data, dtype=input_dtype) + result = to_numeric(arr, downcast=downcast) + expected = pd.array(data, dtype=expected_dtype) + tm.assert_extension_array_equal(result, expected) + + +def test_downcast_nullable_mask_is_copied(): + # GH38974 + + arr = pd.array([1, 2, pd.NA], dtype="Int64") + + result = to_numeric(arr, downcast="integer") + expected = pd.array([1, 2, pd.NA], dtype="Int8") + tm.assert_extension_array_equal(result, expected) + + arr[1] = pd.NA # should not modify result + tm.assert_extension_array_equal(result, expected) + + +def test_to_numeric_scientific_notation(): + # GH 15898 + result = to_numeric("1.7e+308") + expected = np.float64(1.7e308) + assert result == expected + + +@pytest.mark.parametrize("val", [9876543210.0, 2.0**128]) +def test_to_numeric_large_float_not_downcast_to_float_32(val): + # GH 19729 + expected = Series([val]) + result = to_numeric(expected, downcast="float") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "val, dtype", [(1, "Int64"), (1.5, "Float64"), (True, "boolean")] +) +def test_to_numeric_dtype_backend(val, dtype): + # GH#50505 + ser = Series([val], dtype=object) + result = to_numeric(ser, dtype_backend="numpy_nullable") + expected = Series([val], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "val, dtype", + [ + (1, "Int64"), + (1.5, "Float64"), + (True, "boolean"), + (1, "int64[pyarrow]"), + (1.5, "float64[pyarrow]"), + (True, "bool[pyarrow]"), + ], +) +def test_to_numeric_dtype_backend_na(val, dtype): + # GH#50505 + if "pyarrow" in dtype: + pytest.importorskip("pyarrow") + dtype_backend = "pyarrow" + else: + dtype_backend = "numpy_nullable" + ser = Series([val, None], dtype=object) + result = to_numeric(ser, dtype_backend=dtype_backend) + expected = Series([val, pd.NA], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "val, dtype, downcast", + [ + (1, "Int8", "integer"), + (1.5, "Float32", "float"), + (1, "Int8", "signed"), + (1, "int8[pyarrow]", "integer"), + (1.5, "float[pyarrow]", "float"), + (1, "int8[pyarrow]", "signed"), + ], +) +def test_to_numeric_dtype_backend_downcasting(val, dtype, downcast): + # GH#50505 + if "pyarrow" in dtype: + pytest.importorskip("pyarrow") + dtype_backend = "pyarrow" + else: + dtype_backend = "numpy_nullable" + ser = Series([val, None], dtype=object) + result = to_numeric(ser, dtype_backend=dtype_backend, downcast=downcast) + expected = Series([val, pd.NA], dtype=dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "smaller, dtype_backend", + [["UInt8", "numpy_nullable"], ["uint8[pyarrow]", "pyarrow"]], +) +def test_to_numeric_dtype_backend_downcasting_uint(smaller, dtype_backend): + # GH#50505 + if dtype_backend == "pyarrow": + pytest.importorskip("pyarrow") + ser = Series([1, pd.NA], dtype="UInt64") + result = to_numeric(ser, dtype_backend=dtype_backend, downcast="unsigned") + expected = Series([1, pd.NA], dtype=smaller) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "dtype", + [ + "Int64", + "UInt64", + "Float64", + "boolean", + "int64[pyarrow]", + "uint64[pyarrow]", + "float64[pyarrow]", + "bool[pyarrow]", + ], +) +def test_to_numeric_dtype_backend_already_nullable(dtype): + # GH#50505 + if "pyarrow" in dtype: + pytest.importorskip("pyarrow") + ser = Series([1, pd.NA], dtype=dtype) + result = to_numeric(ser, dtype_backend="numpy_nullable") + expected = Series([1, pd.NA], dtype=dtype) + tm.assert_series_equal(result, expected) + + +def test_to_numeric_dtype_backend_error(dtype_backend): + # GH#50505 + ser = Series(["a", "b", ""]) + expected = ser.copy() + with pytest.raises(ValueError, match="Unable to parse string"): + to_numeric(ser, dtype_backend=dtype_backend) + + result = to_numeric(ser, dtype_backend=dtype_backend, errors="coerce") + if dtype_backend == "pyarrow": + dtype = "double[pyarrow]" + else: + dtype = "Float64" + expected = Series([pd.NA, pd.NA, pd.NA], dtype=dtype) + tm.assert_series_equal(result, expected) + + +def test_invalid_dtype_backend(): + ser = Series([1, 2, 3]) + msg = ( + "dtype_backend numpy is invalid, only 'numpy_nullable' and " + "'pyarrow' are allowed." + ) + with pytest.raises(ValueError, match=msg): + to_numeric(ser, dtype_backend="numpy") + + +def test_coerce_pyarrow_backend(): + # GH 52588 + pa = pytest.importorskip("pyarrow") + ser = Series(list("12x"), dtype=ArrowDtype(pa.string())) + result = to_numeric(ser, errors="coerce", dtype_backend="pyarrow") + expected = Series([1, 2, None], dtype=ArrowDtype(pa.int64())) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/tools/test_to_time.py b/pandas/tests/tools/test_to_time.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f48c9e0721e51dbc4c003e61f145d28f5ccae6 --- /dev/null +++ b/pandas/tests/tools/test_to_time.py @@ -0,0 +1,64 @@ +from datetime import time +import locale + +import numpy as np +import pytest + +from pandas import Series +import pandas._testing as tm +from pandas.core.tools.times import to_time + +# The tests marked with this are locale-dependent. +# They pass, except when the machine locale is zh_CN or it_IT. +fails_on_non_english = pytest.mark.xfail( + locale.getlocale()[0] in ("zh_CN", "it_IT"), + reason="fail on a CI build with LC_ALL=zh_CN.utf8/it_IT.utf8", + strict=False, +) + + +class TestToTime: + @pytest.mark.parametrize( + "time_string", + [ + "14:15", + "1415", + pytest.param("2:15pm", marks=fails_on_non_english), + pytest.param("0215pm", marks=fails_on_non_english), + "14:15:00", + "141500", + pytest.param("2:15:00pm", marks=fails_on_non_english), + pytest.param("021500pm", marks=fails_on_non_english), + time(14, 15), + ], + ) + def test_parsers_time(self, time_string): + # GH#11818 + assert to_time(time_string) == time(14, 15) + + def test_odd_format(self): + new_string = "14.15" + assert to_time(new_string, format="%H.%M") == time(14, 15) + + def test_arraylike(self): + arg = ["14:15", "20:20"] + expected_arr = [time(14, 15), time(20, 20)] + assert to_time(arg) == expected_arr + assert to_time(arg, format="%H:%M") == expected_arr + assert to_time(arg, infer_time_format=True) == expected_arr + assert to_time(arg, format="%I:%M%p", errors="coerce") == [None, None] + + with pytest.raises(ValueError, match="errors must be"): + to_time(arg, format="%I:%M%p", errors="ignore") + + msg = "Cannot convert.+to a time with given format" + with pytest.raises(ValueError, match=msg): + to_time(arg, format="%I:%M%p", errors="raise") + + tm.assert_series_equal( + to_time(Series(arg, name="test")), Series(expected_arr, name="test") + ) + + res = to_time(np.array(arg)) + assert isinstance(res, list) + assert res == expected_arr diff --git a/pandas/tests/tools/test_to_timedelta.py b/pandas/tests/tools/test_to_timedelta.py new file mode 100644 index 0000000000000000000000000000000000000000..25c89401c6b3426ab6a866ee454d2b242e370f46 --- /dev/null +++ b/pandas/tests/tools/test_to_timedelta.py @@ -0,0 +1,377 @@ +from datetime import ( + time, + timedelta, +) + +import numpy as np +import pytest + +from pandas.compat import ( + IS64, + WASM, +) +from pandas.errors import ( + OutOfBoundsTimedelta, + Pandas4Warning, +) + +import pandas as pd +from pandas import ( + Series, + TimedeltaIndex, + isna, + to_timedelta, +) +import pandas._testing as tm +from pandas.core.arrays import TimedeltaArray + + +class TestTimedeltas: + def test_to_timedelta_mixed_unit_strings(self): + # https://github.com/pandas-dev/pandas/pull/63196#issuecomment-3595743721 + result = to_timedelta(["1 days 06:05:01.00003", "15.5us"]) + + expected = TimedeltaIndex([108_301_000_030_000, 15_500], dtype="m8[ns]") + tm.assert_index_equal(result, expected) + + def test_to_timedelta_all_nat_unit(self): + # With all-NaT entries, we get "s" unit + result = to_timedelta([None]) + assert result.unit == "s" + + result = TimedeltaIndex([None]) + assert result.unit == "s" + + def test_to_timedelta_month_raises(self): + obj = np.timedelta64(1, "M") + + msg = "Unit M is not supported." + with pytest.raises(ValueError, match=msg): + to_timedelta(obj) + with pytest.raises(ValueError, match=msg): + pd.Timedelta(obj) + with pytest.raises(ValueError, match=msg): + to_timedelta([obj]) + with pytest.raises(ValueError, match=msg): + TimedeltaIndex([obj]) + + def test_to_timedelta_none(self): + # GH#23055 + assert to_timedelta(None) is pd.NaT + + def test_to_timedelta_dt64_raises(self): + # Passing datetime64-dtype data to TimedeltaIndex is no longer + # supported GH#29794 + msg = r"dtype datetime64\[ns\] cannot be converted to timedelta64\[ns\]" + + ser = Series([pd.NaT], dtype="M8[ns]") + with pytest.raises(TypeError, match=msg): + to_timedelta(ser) + with pytest.raises(TypeError, match=msg): + ser.to_frame().apply(to_timedelta) + + def test_to_timedelta_readonly(self, writable): + # GH#34857 + arr = np.array([], dtype=object) + arr.setflags(write=writable) + result = to_timedelta(arr) + expected = to_timedelta([]) + tm.assert_index_equal(result, expected) + + def test_to_timedelta_null(self): + result = to_timedelta(["", ""]) + assert isna(result).all() + + def test_to_timedelta_same_np_timedelta64(self): + # pass thru + result = to_timedelta(np.array([np.timedelta64(1, "s")])) + expected = pd.Index(np.array([np.timedelta64(1, "s")])) + tm.assert_index_equal(result, expected) + + def test_to_timedelta_series(self): + # Series + expected = Series( + [timedelta(days=1), timedelta(days=1, seconds=1)], dtype="m8[us]" + ) + + msg = "'d' is deprecated and will be removed in a future version." + with tm.assert_produces_warning(Pandas4Warning, match=msg): + result = to_timedelta(Series(["1d", "1days 00:00:01"])) + tm.assert_series_equal(result, expected) + + def test_to_timedelta_units(self): + # with units + result = TimedeltaIndex( + [np.timedelta64(0, "ns"), np.timedelta64(10, "s").astype("m8[ns]")] + ) + expected = to_timedelta([0, 10], unit="s").as_unit("ns") + tm.assert_index_equal(result, expected) + + def test_to_timedelta_mixed_dtype(self): + # https://github.com/pandas-dev/pandas/issues/64044 + result = to_timedelta(np.array([0.5, 2]), unit="m") + expected = TimedeltaIndex( + ["0 days 00:00:30", "0 days 00:02:00"], dtype="timedelta64[ns]", freq=None + ) + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "dtype, unit", + [ + ["int64", "s"], + ["int64", "m"], + ["int64", "h"], + ["timedelta64[s]", "s"], + ["timedelta64[D]", "D"], + ], + ) + def test_to_timedelta_units_dtypes(self, dtype, unit): + # arrays of various dtypes + arr = np.array([1] * 5, dtype=dtype) + result = to_timedelta(arr, unit=unit) + exp_dtype = "m8[s]" + expected = TimedeltaIndex([np.timedelta64(1, unit)] * 5, dtype=exp_dtype) + tm.assert_index_equal(result, expected) + + def test_to_timedelta_oob_non_nano(self): + arr = np.array([pd.NaT._value + 1], dtype="timedelta64[m]") + + msg = ( + "Cannot convert -9223372036854775807 minutes to " + r"timedelta64\[s\] without overflow" + ) + with pytest.raises(OutOfBoundsTimedelta, match=msg): + to_timedelta(arr) + + with pytest.raises(OutOfBoundsTimedelta, match=msg): + TimedeltaIndex(arr) + + with pytest.raises(OutOfBoundsTimedelta, match=msg): + TimedeltaArray._from_sequence(arr, dtype="m8[s]") + + @pytest.mark.parametrize("box", [lambda x: x, pd.DataFrame]) + @pytest.mark.parametrize("errors", ["raise", "coerce"]) + def test_to_timedelta_dataframe(self, box, errors): + # GH 11776 + arg = box(np.arange(10).reshape(2, 5)) + with pytest.raises(TypeError, match="1-d array"): + to_timedelta(arg, errors=errors) + + def test_to_timedelta_invalid_errors(self): + # bad value for errors parameter + msg = "errors must be one of" + with pytest.raises(ValueError, match=msg): + to_timedelta(["foo"], errors="never") + + @pytest.mark.parametrize("arg", [[1, 2], 1]) + def test_to_timedelta_invalid_unit(self, arg): + # these will error + msg = "invalid unit abbreviation: foo" + with pytest.raises(ValueError, match=msg): + to_timedelta(arg, unit="foo") + + def test_to_timedelta_time(self): + # time not supported ATM + msg = ( + "Value must be Timedelta, string, integer, float, timedelta or convertible" + ) + with pytest.raises(ValueError, match=msg): + to_timedelta(time(second=1)) + assert to_timedelta(time(second=1), errors="coerce") is pd.NaT + + def test_to_timedelta_bad_value(self): + msg = "Could not convert 'foo' to NumPy timedelta" + with pytest.raises(ValueError, match=msg): + to_timedelta(["foo", "bar"]) + + def test_to_timedelta_bad_value_coerce(self): + tm.assert_index_equal( + TimedeltaIndex([pd.NaT, pd.NaT]), + to_timedelta(["foo", "bar"], errors="coerce"), + ) + + tm.assert_index_equal( + TimedeltaIndex(["1 day", pd.NaT, "1 min"]), + to_timedelta(["1 day", "bar", "1 min"], errors="coerce"), + ) + + @pytest.mark.parametrize( + "val, errors", + [ + ("1M", True), + ("1 M", True), + ("1Y", True), + ("1 Y", True), + ("1y", True), + ("1 y", True), + ("1m", False), + ("1 m", False), + ("1 day", False), + ("2day", False), + ], + ) + def test_unambiguous_timedelta_values(self, val, errors): + # GH36666 Deprecate use of strings denoting units with 'M', 'Y', 'm' or 'y' + # in pd.to_timedelta + msg = "Units 'M', 'Y' and 'y' do not represent unambiguous timedelta" + if errors: + with pytest.raises(ValueError, match=msg): + to_timedelta(val) + else: + # check it doesn't raise + to_timedelta(val) + + def test_to_timedelta_via_apply(self): + # GH 5458 + expected = Series([np.timedelta64(1, "s")], dtype="m8[us]") + result = Series(["00:00:01"]).apply(to_timedelta) + tm.assert_series_equal(result, expected) + + result = Series([to_timedelta("00:00:01")]) + tm.assert_series_equal(result, expected) + + def test_to_timedelta_inference_without_warning(self): + # GH#41731 inference produces a warning in the Series constructor, + # but _not_ in to_timedelta + vals = ["00:00:01", pd.NaT] + with tm.assert_produces_warning(None): + result = to_timedelta(vals) + + expected = TimedeltaIndex([pd.Timedelta(seconds=1), pd.NaT], dtype="m8[us]") + tm.assert_index_equal(result, expected) + + def test_to_timedelta_on_missing_values(self): + # GH5438 + timedelta_NaT = np.timedelta64("NaT") + + actual = to_timedelta(Series(["00:00:01", np.nan])) + expected = Series( + [np.timedelta64(1000000000, "ns"), timedelta_NaT], + dtype=f"{tm.ENDIAN}m8[us]", + ) + tm.assert_series_equal(actual, expected) + + ser = Series(["00:00:01", pd.NaT], dtype="m8[us]") + actual = to_timedelta(ser) + tm.assert_series_equal(actual, expected) + + @pytest.mark.parametrize("val", [np.nan, pd.NaT, pd.NA]) + def test_to_timedelta_on_missing_values_scalar(self, val): + actual = to_timedelta(val) + assert actual._value == np.timedelta64("NaT").astype("int64") + + @pytest.mark.parametrize("val", [np.nan, pd.NaT, pd.NA]) + def test_to_timedelta_on_missing_values_list(self, val): + actual = to_timedelta([val]) + assert actual[0]._value == np.timedelta64("NaT").astype("int64") + + @pytest.mark.skipif(WASM, reason="No fp exception support in WASM") + @pytest.mark.xfail(not IS64, reason="Floating point error") + def test_to_timedelta_float(self): + # https://github.com/pandas-dev/pandas/issues/25077 + arr = np.arange(0, 1, 1e-6)[-10:] + result = to_timedelta(arr, unit="s") + expected_asi8 = np.arange(999990000, 10**9, 1000, dtype="int64") + tm.assert_numpy_array_equal(result.asi8, expected_asi8) + + def test_to_timedelta_coerce_strings_unit(self): + arr = np.array([1, 2, "error"], dtype=object) + result = to_timedelta(arr, unit="ns", errors="coerce") + expected = to_timedelta([1, 2, pd.NaT], unit="ns") + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + "expected_val, result_val", [[timedelta(days=2), 2], [None, None]] + ) + def test_to_timedelta_nullable_int64_dtype(self, expected_val, result_val): + # GH 35574 + expected = Series([timedelta(days=1), expected_val], dtype="m8[s]") + result = to_timedelta(Series([1, result_val], dtype="Int64"), unit="days") + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + ("input", "expected"), + [ + ("8:53:08.71800000001", "8:53:08.718"), + ("8:53:08.718001", "8:53:08.718001"), + ("8:53:08.7180000001", "8:53:08.7180000001"), + ("-8:53:08.71800000001", "-8:53:08.718"), + ("8:53:08.7180000089", "8:53:08.718000008"), + ], + ) + @pytest.mark.parametrize("func", [pd.Timedelta, to_timedelta]) + def test_to_timedelta_precision_over_nanos(self, input, expected, func): + # GH: 36738 + expected = pd.Timedelta(expected) + result = func(input) + assert result == expected + + def test_to_timedelta_zerodim(self, fixed_now_ts): + # ndarray.item() incorrectly returns int for dt64[ns] and td64[ns] + dt64 = fixed_now_ts.to_datetime64() + arg = np.array(dt64) + + msg = ( + "Value must be Timedelta, string, integer, float, timedelta " + "or convertible, not datetime64" + ) + with pytest.raises(ValueError, match=msg): + to_timedelta(arg) + + arg2 = arg.view("m8[ns]") + result = to_timedelta(arg2) + assert isinstance(result, pd.Timedelta) + assert result._value == dt64.view("i8") + + def test_to_timedelta_numeric_ea(self, any_numeric_ea_dtype): + # GH#48796 + ser = Series([1, pd.NA], dtype=any_numeric_ea_dtype) + result = to_timedelta(ser) + expected = Series([pd.Timedelta(1, unit="ns"), pd.NaT]) + tm.assert_series_equal(result, expected) + + def test_to_timedelta_fraction(self): + result = to_timedelta(1.0 / 3, unit="h") + expected = pd.Timedelta("0 days 00:19:59.999999998") + assert result == expected + + def test_to_timedelta_unit_round_floats(self): + # When the float is round, we give the requested unit + # (or nearest-supported) like we do with integers + arr = np.array([45.0], dtype=object) + result = to_timedelta(arr, unit="s") + expected = to_timedelta([45], unit="s") + tm.assert_index_equal(result, expected) + + arr2 = arr.astype(np.float64) + result2 = to_timedelta(arr2, unit="s") + tm.assert_index_equal(result2, expected) + + def test_to_timedelta_unit_non_round_floats(self): + # With non-round floats, we have to give nanosecond + arr = np.array([45.5], dtype=object) + result = to_timedelta(arr, unit="s") + assert result.unit == "ns" + + arr2 = arr.astype(np.float64) + result2 = to_timedelta(arr2, unit="s") + assert result2.unit == "ns" + + +def test_from_numeric_arrow_dtype(any_numeric_ea_dtype): + # GH 52425 + pytest.importorskip("pyarrow") + ser = Series([1, 2], dtype=f"{any_numeric_ea_dtype.lower()}[pyarrow]") + result = to_timedelta(ser) + expected = Series([1, 2], dtype="timedelta64[ns]") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("unit", ["ns", "ms"]) +def test_from_timedelta_arrow_dtype(unit): + # GH 54298 + pytest.importorskip("pyarrow") + expected = Series([timedelta(1)], dtype=f"duration[{unit}][pyarrow]") + result = to_timedelta(expected) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/tseries/__init__.py b/pandas/tests/tseries/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/tslibs/__init__.py b/pandas/tests/tslibs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/tslibs/test_ccalendar.py b/pandas/tests/tslibs/test_ccalendar.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb9d3387c91a84baa83148021503b445b4793ce --- /dev/null +++ b/pandas/tests/tslibs/test_ccalendar.py @@ -0,0 +1,64 @@ +from datetime import ( + date, + datetime, +) + +from hypothesis import given +import numpy as np +import pytest + +from pandas._libs.tslibs import ccalendar + +from pandas._testing._hypothesis import DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ + + +@pytest.mark.parametrize( + "date_tuple,expected", + [ + ((2001, 3, 1), 60), + ((2004, 3, 1), 61), + ((1907, 12, 31), 365), # End-of-year, non-leap year. + ((2004, 12, 31), 366), # End-of-year, leap year. + ], +) +def test_get_day_of_year_numeric(date_tuple, expected): + assert ccalendar.get_day_of_year(*date_tuple) == expected + + +def test_get_day_of_year_dt(): + dt = datetime.fromordinal(1 + np.random.default_rng(2).integers(365 * 4000)) + result = ccalendar.get_day_of_year(dt.year, dt.month, dt.day) + + expected = (dt - dt.replace(month=1, day=1)).days + 1 + assert result == expected + + +@pytest.mark.parametrize( + "input_date_tuple, expected_iso_tuple", + [ + [(2020, 1, 1), (2020, 1, 3)], + [(2019, 12, 31), (2020, 1, 2)], + [(2019, 12, 30), (2020, 1, 1)], + [(2009, 12, 31), (2009, 53, 4)], + [(2010, 1, 1), (2009, 53, 5)], + [(2010, 1, 3), (2009, 53, 7)], + [(2010, 1, 4), (2010, 1, 1)], + [(2006, 1, 1), (2005, 52, 7)], + [(2005, 12, 31), (2005, 52, 6)], + [(2008, 12, 28), (2008, 52, 7)], + [(2008, 12, 29), (2009, 1, 1)], + ], +) +def test_dt_correct_iso_8601_year_week_and_day(input_date_tuple, expected_iso_tuple): + result = ccalendar.get_iso_calendar(*input_date_tuple) + expected_from_date_isocalendar = date(*input_date_tuple).isocalendar() + assert result == expected_from_date_isocalendar + assert result == expected_iso_tuple + + +@pytest.mark.slow +@given(DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ) +def test_isocalendar(dt): + expected = dt.isocalendar() + result = ccalendar.get_iso_calendar(dt.year, dt.month, dt.day) + assert result == expected diff --git a/pandas/tests/tslibs/test_np_datetime.py b/pandas/tests/tslibs/test_np_datetime.py new file mode 100644 index 0000000000000000000000000000000000000000..02edf1a09387766d71097ea0baedc2640cfb824b --- /dev/null +++ b/pandas/tests/tslibs/test_np_datetime.py @@ -0,0 +1,222 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit +from pandas._libs.tslibs.np_datetime import ( + OutOfBoundsDatetime, + OutOfBoundsTimedelta, + astype_overflowsafe, + is_unitless, + py_get_unit_from_dtype, + py_td64_to_tdstruct, +) + +import pandas._testing as tm + + +def test_is_unitless(): + dtype = np.dtype("M8[ns]") + assert not is_unitless(dtype) + + dtype = np.dtype("datetime64") + assert is_unitless(dtype) + + dtype = np.dtype("m8[ns]") + assert not is_unitless(dtype) + + dtype = np.dtype("timedelta64") + assert is_unitless(dtype) + + msg = "dtype must be datetime64 or timedelta64" + with pytest.raises(ValueError, match=msg): + is_unitless(np.dtype(np.int64)) + + msg = "Argument 'dtype' has incorrect type" + with pytest.raises(TypeError, match=msg): + is_unitless("foo") + + +def test_get_unit_from_dtype(): + # datetime64 + assert py_get_unit_from_dtype(np.dtype("M8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value + assert py_get_unit_from_dtype(np.dtype("M8[M]")) == NpyDatetimeUnit.NPY_FR_M.value + assert py_get_unit_from_dtype(np.dtype("M8[W]")) == NpyDatetimeUnit.NPY_FR_W.value + # B has been deprecated and removed -> no 3 + assert py_get_unit_from_dtype(np.dtype("M8[D]")) == NpyDatetimeUnit.NPY_FR_D.value + assert py_get_unit_from_dtype(np.dtype("M8[h]")) == NpyDatetimeUnit.NPY_FR_h.value + assert py_get_unit_from_dtype(np.dtype("M8[m]")) == NpyDatetimeUnit.NPY_FR_m.value + assert py_get_unit_from_dtype(np.dtype("M8[s]")) == NpyDatetimeUnit.NPY_FR_s.value + assert py_get_unit_from_dtype(np.dtype("M8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value + assert py_get_unit_from_dtype(np.dtype("M8[us]")) == NpyDatetimeUnit.NPY_FR_us.value + assert py_get_unit_from_dtype(np.dtype("M8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value + assert py_get_unit_from_dtype(np.dtype("M8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value + assert py_get_unit_from_dtype(np.dtype("M8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value + assert py_get_unit_from_dtype(np.dtype("M8[as]")) == NpyDatetimeUnit.NPY_FR_as.value + + # timedelta64 + assert py_get_unit_from_dtype(np.dtype("m8[Y]")) == NpyDatetimeUnit.NPY_FR_Y.value + assert py_get_unit_from_dtype(np.dtype("m8[M]")) == NpyDatetimeUnit.NPY_FR_M.value + assert py_get_unit_from_dtype(np.dtype("m8[W]")) == NpyDatetimeUnit.NPY_FR_W.value + # B has been deprecated and removed -> no 3 + assert py_get_unit_from_dtype(np.dtype("m8[D]")) == NpyDatetimeUnit.NPY_FR_D.value + assert py_get_unit_from_dtype(np.dtype("m8[h]")) == NpyDatetimeUnit.NPY_FR_h.value + assert py_get_unit_from_dtype(np.dtype("m8[m]")) == NpyDatetimeUnit.NPY_FR_m.value + assert py_get_unit_from_dtype(np.dtype("m8[s]")) == NpyDatetimeUnit.NPY_FR_s.value + assert py_get_unit_from_dtype(np.dtype("m8[ms]")) == NpyDatetimeUnit.NPY_FR_ms.value + assert py_get_unit_from_dtype(np.dtype("m8[us]")) == NpyDatetimeUnit.NPY_FR_us.value + assert py_get_unit_from_dtype(np.dtype("m8[ns]")) == NpyDatetimeUnit.NPY_FR_ns.value + assert py_get_unit_from_dtype(np.dtype("m8[ps]")) == NpyDatetimeUnit.NPY_FR_ps.value + assert py_get_unit_from_dtype(np.dtype("m8[fs]")) == NpyDatetimeUnit.NPY_FR_fs.value + assert py_get_unit_from_dtype(np.dtype("m8[as]")) == NpyDatetimeUnit.NPY_FR_as.value + + +def test_td64_to_tdstruct(): + val = 12454636234 # arbitrary value + + res1 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ns.value) + exp1 = { + "days": 0, + "hrs": 0, + "min": 0, + "sec": 12, + "ms": 454, + "us": 636, + "ns": 234, + "seconds": 12, + "microseconds": 454636, + "nanoseconds": 234, + } + assert res1 == exp1 + + res2 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_us.value) + exp2 = { + "days": 0, + "hrs": 3, + "min": 27, + "sec": 34, + "ms": 636, + "us": 234, + "ns": 0, + "seconds": 12454, + "microseconds": 636234, + "nanoseconds": 0, + } + assert res2 == exp2 + + res3 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_ms.value) + exp3 = { + "days": 144, + "hrs": 3, + "min": 37, + "sec": 16, + "ms": 234, + "us": 0, + "ns": 0, + "seconds": 13036, + "microseconds": 234000, + "nanoseconds": 0, + } + assert res3 == exp3 + + # Note this out of bounds for nanosecond Timedelta + res4 = py_td64_to_tdstruct(val, NpyDatetimeUnit.NPY_FR_s.value) + exp4 = { + "days": 144150, + "hrs": 21, + "min": 10, + "sec": 34, + "ms": 0, + "us": 0, + "ns": 0, + "seconds": 76234, + "microseconds": 0, + "nanoseconds": 0, + } + assert res4 == exp4 + + +class TestAstypeOverflowSafe: + def test_pass_non_dt64_array(self): + # check that we raise, not segfault + arr = np.arange(5) + dtype = np.dtype("M8[ns]") + + msg = ( + "astype_overflowsafe values.dtype and dtype must be either " + "both-datetime64 or both-timedelta64" + ) + with pytest.raises(TypeError, match=msg): + astype_overflowsafe(arr, dtype, copy=True) + + with pytest.raises(TypeError, match=msg): + astype_overflowsafe(arr, dtype, copy=False) + + def test_pass_non_dt64_dtype(self): + # check that we raise, not segfault + arr = np.arange(5, dtype="i8").view("M8[D]") + dtype = np.dtype("m8[ns]") + + msg = ( + "astype_overflowsafe values.dtype and dtype must be either " + "both-datetime64 or both-timedelta64" + ) + with pytest.raises(TypeError, match=msg): + astype_overflowsafe(arr, dtype, copy=True) + + with pytest.raises(TypeError, match=msg): + astype_overflowsafe(arr, dtype, copy=False) + + def test_astype_overflowsafe_dt64(self): + dtype = np.dtype("M8[ns]") + + dt = np.datetime64("2262-04-05", "D") + arr = dt + np.arange(10, dtype="m8[D]") + + # arr.astype silently overflows, so this + wrong = arr.astype(dtype) + roundtrip = wrong.astype(arr.dtype) + assert not (wrong == roundtrip).all() + + msg = "Out of bounds nanosecond timestamp" + with pytest.raises(OutOfBoundsDatetime, match=msg): + astype_overflowsafe(arr, dtype) + + # But converting to microseconds is fine, and we match numpy's results. + dtype2 = np.dtype("M8[us]") + result = astype_overflowsafe(arr, dtype2) + expected = arr.astype(dtype2) + tm.assert_numpy_array_equal(result, expected) + + def test_astype_overflowsafe_td64(self): + dtype = np.dtype("m8[ns]") + + dt = np.datetime64("2262-04-05", "D") + arr = dt + np.arange(10, dtype="m8[D]") + arr = arr.view("m8[D]") + + # arr.astype silently overflows, so this + wrong = arr.astype(dtype) + roundtrip = wrong.astype(arr.dtype) + assert not (wrong == roundtrip).all() + + msg = r"Cannot convert 106752 days to timedelta64\[ns\] without overflow" + with pytest.raises(OutOfBoundsTimedelta, match=msg): + astype_overflowsafe(arr, dtype) + + # But converting to microseconds is fine, and we match numpy's results. + dtype2 = np.dtype("m8[us]") + result = astype_overflowsafe(arr, dtype2) + expected = arr.astype(dtype2) + tm.assert_numpy_array_equal(result, expected) + + def test_astype_overflowsafe_disallow_rounding(self): + arr = np.array([-1500, 1500], dtype="M8[ns]") + dtype = np.dtype("M8[us]") + + msg = "Cannot losslessly cast '-1500 ns' to us" + with pytest.raises(ValueError, match=msg): + astype_overflowsafe(arr, dtype, round_ok=False) + + result = astype_overflowsafe(arr, dtype, round_ok=True) + expected = arr.astype(dtype) + tm.assert_numpy_array_equal(result, expected) diff --git a/pandas/tests/tslibs/test_npy_units.py b/pandas/tests/tslibs/test_npy_units.py new file mode 100644 index 0000000000000000000000000000000000000000..6d05dc79fbb2cf52688547b672365802463ce6f2 --- /dev/null +++ b/pandas/tests/tslibs/test_npy_units.py @@ -0,0 +1,27 @@ +import numpy as np + +from pandas._libs.tslibs.dtypes import abbrev_to_npy_unit +from pandas._libs.tslibs.vectorized import is_date_array_normalized + +# a datetime64 ndarray which *is* normalized +day_arr = np.arange(10, dtype="i8").view("M8[D]") + + +class TestIsDateArrayNormalized: + def test_is_date_array_normalized_day(self): + arr = day_arr + abbrev = "D" + unit = abbrev_to_npy_unit(abbrev) + result = is_date_array_normalized(arr.view("i8"), None, unit) + assert result is True + + def test_is_date_array_normalized_seconds(self): + abbrev = "s" + arr = day_arr.astype(f"M8[{abbrev}]") + unit = abbrev_to_npy_unit(abbrev) + result = is_date_array_normalized(arr.view("i8"), None, unit) + assert result is True + + arr[0] += np.timedelta64(1, abbrev) + result2 = is_date_array_normalized(arr.view("i8"), None, unit) + assert result2 is False diff --git a/pandas/tests/tslibs/test_parsing.py b/pandas/tests/tslibs/test_parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4e2e8ddb2349459ffd738cb99ceb7c5cd341ad --- /dev/null +++ b/pandas/tests/tslibs/test_parsing.py @@ -0,0 +1,426 @@ +""" +Tests for Timestamp parsing, aimed at pandas/_libs/tslibs/parsing.pyx +""" + +from datetime import datetime +import re + +from dateutil.parser import parse as du_parse +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + parsing, + strptime, +) +from pandas._libs.tslibs.parsing import parse_datetime_string_with_reso +from pandas.compat import ( + ISMUSL, + WASM, + is_platform_windows, +) +import pandas.util._test_decorators as td + +# Usually we wouldn't want this import in this test file (which is targeted at +# tslibs.parsing), but it is convenient to test the Timestamp constructor at +# the same time as the other parsing functions. +from pandas import ( + Timestamp, + option_context, +) +import pandas._testing as tm + + +@pytest.mark.skipif(WASM, reason="tzset is not available on WASM") +@pytest.mark.skipif( + is_platform_windows() or ISMUSL, + reason="TZ setting incorrect on Windows and MUSL Linux", +) +def test_parsing_tzlocal_deprecated(): + # GH#50791 + msg = "|".join( + [ + r"Parsing 'EST' as tzlocal \(dependent on system timezone\) " + r"is no longer supported\. " + "Pass the 'tz' keyword or call tz_localize after construction instead", + ".*included an un-recognized timezone", + ] + ) + dtstr = "Jan 15 2004 03:00 EST" + + with tm.set_timezone("US/Eastern"): + with pytest.raises(ValueError, match=msg): + parse_datetime_string_with_reso(dtstr) + + with pytest.raises(ValueError, match=msg): + parsing.py_parse_datetime_string(dtstr) + + with pytest.raises(ValueError, match=msg): + Timestamp(dtstr) + + +def test_parse_datetime_string_with_reso(): + (parsed, reso) = parse_datetime_string_with_reso("4Q1984") + (parsed_lower, reso_lower) = parse_datetime_string_with_reso("4q1984") + + assert reso == reso_lower + assert parsed == parsed_lower + + +def test_parse_datetime_string_with_reso_nanosecond_reso(): + # GH#46811 + parsed, reso = parse_datetime_string_with_reso("2022-04-20 09:19:19.123456789") + assert reso == "nanosecond" + + +def test_parse_datetime_string_with_reso_invalid_type(): + # Raise on invalid input, don't just return it + msg = "Argument 'date_string' has incorrect type (expected str, got tuple)" + with pytest.raises(TypeError, match=re.escape(msg)): + parse_datetime_string_with_reso((4, 5)) + + +@pytest.mark.parametrize( + "dashed,normal", [("1988-Q2", "1988Q2"), ("2Q-1988", "2Q1988")] +) +def test_parse_time_quarter_with_dash(dashed, normal): + # see gh-9688 + (parsed_dash, reso_dash) = parse_datetime_string_with_reso(dashed) + (parsed, reso) = parse_datetime_string_with_reso(normal) + + assert parsed_dash == parsed + assert reso_dash == reso + + +@pytest.mark.parametrize("dashed", ["-2Q1992", "2-Q1992", "4-4Q1992"]) +def test_parse_time_quarter_with_dash_error(dashed): + msg = f"Unknown datetime string format, unable to parse: {dashed}" + + with pytest.raises(parsing.DateParseError, match=msg): + parse_datetime_string_with_reso(dashed) + + +@pytest.mark.parametrize( + "date_string,expected", + [ + ("123.1234", False), + ("-50000", False), + ("999", False), + ("m", False), + ("T", False), + ("Mon Sep 16, 2013", True), + ("2012-01-01", True), + ("01/01/2012", True), + ("01012012", True), + ("0101", True), + ("1-1", True), + ], +) +def test_does_not_convert_mixed_integer(date_string, expected): + assert parsing._does_string_look_like_datetime(date_string) is expected + + +@pytest.mark.parametrize( + "date_str,kwargs,msg", + [ + ( + "2013Q5", + {}, + ( + "Incorrect quarterly string is given, " + "quarter must be between 1 and 4: 2013Q5" + ), + ), + # see gh-5418 + ( + "2013Q1", + {"freq": "INVLD-L-DEC-SAT"}, + ("Unable to retrieve month information from given freq: INVLD-L-DEC-SAT"), + ), + ], +) +def test_parsers_quarterly_with_freq_error(date_str, kwargs, msg): + with pytest.raises(parsing.DateParseError, match=msg): + parsing.parse_datetime_string_with_reso(date_str, **kwargs) + + +@pytest.mark.parametrize( + "date_str,freq,expected", + [ + ("2013Q2", None, datetime(2013, 4, 1)), + ("2013Q2", "Y-APR", datetime(2012, 8, 1)), + ("2013-Q2", "Y-DEC", datetime(2013, 4, 1)), + ], +) +def test_parsers_quarterly_with_freq(date_str, freq, expected): + result, _ = parsing.parse_datetime_string_with_reso(date_str, freq=freq) + assert result == expected + + +@pytest.mark.parametrize( + "date_str", ["2Q 2005", "2Q-200Y", "2Q-200", "22Q2005", "2Q200.", "6Q-20"] +) +def test_parsers_quarter_invalid(date_str): + if date_str == "6Q-20": + msg = ( + "Incorrect quarterly string is given, quarter " + f"must be between 1 and 4: {date_str}" + ) + else: + msg = f"Unknown datetime string format, unable to parse: {date_str}" + + with pytest.raises(ValueError, match=msg): + parsing.parse_datetime_string_with_reso(date_str) + + +@pytest.mark.parametrize( + "date_str,expected", + [("201101", datetime(2011, 1, 1, 0, 0)), ("200005", datetime(2000, 5, 1, 0, 0))], +) +def test_parsers_month_freq(date_str, expected): + result, _ = parsing.parse_datetime_string_with_reso(date_str, freq="ME") + assert result == expected + + +@td.skip_if_not_us_locale +@pytest.mark.parametrize( + "string,fmt", + [ + ("20111230", "%Y%m%d"), + ("201112300000", "%Y%m%d%H%M"), + ("20111230000000", "%Y%m%d%H%M%S"), + ("20111230T00", "%Y%m%dT%H"), + ("20111230T0000", "%Y%m%dT%H%M"), + ("20111230T000000", "%Y%m%dT%H%M%S"), + ("2011-12-30", "%Y-%m-%d"), + ("2011", "%Y"), + ("2011-01", "%Y-%m"), + ("30-12-2011", "%d-%m-%Y"), + ("2011-12-30 00:00:00", "%Y-%m-%d %H:%M:%S"), + ("2011-12-30T00:00:00", "%Y-%m-%dT%H:%M:%S"), + ("2011-12-30T00:00:00UTC", "%Y-%m-%dT%H:%M:%S%Z"), + ("2011-12-30T00:00:00Z", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+9", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+09", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+090", None), + ("2011-12-30T00:00:00+0900", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00-0900", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+09:00", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+09:000", None), + ("2011-12-30T00:00:00+9:0", "%Y-%m-%dT%H:%M:%S%z"), + ("2011-12-30T00:00:00+09:", None), + ("2011-12-30T00:00:00.000000UTC", "%Y-%m-%dT%H:%M:%S.%f%Z"), + ("2011-12-30T00:00:00.000000Z", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+9", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+09", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+090", None), + ("2011-12-30T00:00:00.000000+0900", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000-0900", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+09:00", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+09:000", None), + ("2011-12-30T00:00:00.000000+9:0", "%Y-%m-%dT%H:%M:%S.%f%z"), + ("2011-12-30T00:00:00.000000+09:", None), + ("2011-12-30 00:00:00.000000", "%Y-%m-%d %H:%M:%S.%f"), + ("Tue 24 Aug 2021 01:30:48", "%a %d %b %Y %H:%M:%S"), + ("Tuesday 24 Aug 2021 01:30:48", "%A %d %b %Y %H:%M:%S"), + ("Tue 24 Aug 2021 01:30:48 AM", "%a %d %b %Y %I:%M:%S %p"), + ("Tuesday 24 Aug 2021 01:30:48 AM", "%A %d %b %Y %I:%M:%S %p"), + ("27.03.2003 14:55:00.000", "%d.%m.%Y %H:%M:%S.%f"), # GH50317 + ("2023-11-09T20:23:46Z", "%Y-%m-%dT%H:%M:%S%z"), # GH57452 + ], +) +def test_guess_datetime_format_with_parseable_formats(string, fmt): + with tm.maybe_produces_warning( + UserWarning, fmt is not None and re.search(r"%d.*%m", fmt) + ): + result = parsing.guess_datetime_format(string) + assert result == fmt + + +@pytest.mark.parametrize("dayfirst,expected", [(True, "%d/%m/%Y"), (False, "%m/%d/%Y")]) +def test_guess_datetime_format_with_dayfirst(dayfirst, expected): + ambiguous_string = "01/01/2011" + result = parsing.guess_datetime_format(ambiguous_string, dayfirst=dayfirst) + assert result == expected + + +@td.skip_if_not_us_locale +@pytest.mark.parametrize( + "string,fmt", + [ + ("30/Dec/2011", "%d/%b/%Y"), + ("30/December/2011", "%d/%B/%Y"), + ("30/Dec/2011 00:00:00", "%d/%b/%Y %H:%M:%S"), + ], +) +def test_guess_datetime_format_with_locale_specific_formats(string, fmt): + result = parsing.guess_datetime_format(string) + assert result == fmt + + +@pytest.mark.parametrize( + "invalid_dt", + [ + "01/2013", + "12:00:00", + "1/1/1/1", + "this_is_not_a_datetime", + "51a", + "13/2019", + "202001", # YYYYMM isn't ISO8601 + "2020/01", # YYYY/MM isn't ISO8601 either + "87156549591102612381000001219H5", + ], +) +def test_guess_datetime_format_invalid_inputs(invalid_dt): + # A datetime string must include a year, month and a day for it to be + # guessable, in addition to being a string that looks like a datetime. + assert parsing.guess_datetime_format(invalid_dt) is None + + +@pytest.mark.parametrize("invalid_type_dt", [9, datetime(2011, 1, 1)]) +def test_guess_datetime_format_wrong_type_inputs(invalid_type_dt): + # A datetime string must include a year, month and a day for it to be + # guessable, in addition to being a string that looks like a datetime. + with pytest.raises( + TypeError, + match=r"^Argument 'dt_str' has incorrect type \(expected str, got .*\)$", + ): + parsing.guess_datetime_format(invalid_type_dt) + + +@pytest.mark.parametrize( + "string,fmt,dayfirst,warning", + [ + ("2011-1-1", "%Y-%m-%d", False, None), + ("2011-1-1", "%Y-%d-%m", True, None), + ("1/1/2011", "%m/%d/%Y", False, None), + ("1/1/2011", "%d/%m/%Y", True, None), + ("30-1-2011", "%d-%m-%Y", False, UserWarning), + ("30-1-2011", "%d-%m-%Y", True, None), + ("2011-1-1 0:0:0", "%Y-%m-%d %H:%M:%S", False, None), + ("2011-1-1 0:0:0", "%Y-%d-%m %H:%M:%S", True, None), + ("2011-1-3T00:00:0", "%Y-%m-%dT%H:%M:%S", False, None), + ("2011-1-3T00:00:0", "%Y-%d-%mT%H:%M:%S", True, None), + ("2011-1-1 00:00:00", "%Y-%m-%d %H:%M:%S", False, None), + ("2011-1-1 00:00:00", "%Y-%d-%m %H:%M:%S", True, None), + ], +) +def test_guess_datetime_format_no_padding(string, fmt, dayfirst, warning): + # see gh-11142 + msg = ( + rf"Parsing dates in {fmt} format when dayfirst=False \(the default\) " + "was specified. " + "Pass `dayfirst=True` or specify a format to silence this warning." + ) + with tm.assert_produces_warning(warning, match=msg): + result = parsing.guess_datetime_format(string, dayfirst=dayfirst) + assert result == fmt + + +def test_try_parse_dates(): + arr = np.array(["5/1/2000", "6/1/2000", "7/1/2000"], dtype=object) + result = parsing.try_parse_dates(arr, parser=lambda x: du_parse(x, dayfirst=True)) + + expected = np.array([du_parse(d, dayfirst=True) for d in arr]) + tm.assert_numpy_array_equal(result, expected) + + +def test_parse_datetime_string_with_reso_check_instance_type_raise_exception(): + # issue 20684 + msg = "Argument 'date_string' has incorrect type (expected str, got tuple)" + with pytest.raises(TypeError, match=re.escape(msg)): + parse_datetime_string_with_reso((1, 2, 3)) + + result = parse_datetime_string_with_reso("2019") + expected = (datetime(2019, 1, 1), "year") + assert result == expected + + +@pytest.mark.parametrize( + "fmt,expected", + [ + ("%Y %m %d %H:%M:%S", True), + ("%Y/%m/%d %H:%M:%S", True), + (r"%Y\%m\%d %H:%M:%S", True), + ("%Y-%m-%d %H:%M:%S", True), + ("%Y.%m.%d %H:%M:%S", True), + ("%Y%m%d %H:%M:%S", True), + ("%Y-%m-%dT%H:%M:%S", True), + ("%Y-%m-%dT%H:%M:%S%z", True), + ("%Y-%m-%dT%H:%M:%S%Z", False), + ("%Y-%m-%dT%H:%M:%S.%f", True), + ("%Y-%m-%dT%H:%M:%S.%f%z", True), + ("%Y-%m-%dT%H:%M:%S.%f%Z", False), + ("%Y%m%d", True), + ("%Y%m", False), + ("%Y", True), + ("%Y-%m-%d", True), + ("%Y-%m", True), + ], +) +def test_is_iso_format(fmt, expected): + # see gh-41047 + result = strptime._test_format_is_iso(fmt) + assert result == expected + + +@pytest.mark.parametrize( + "input", + [ + "2018-01-01T00:00:00.123456789", + "2018-01-01T00:00:00.123456", + "2018-01-01T00:00:00.123", + ], +) +def test_guess_datetime_format_f(input): + # https://github.com/pandas-dev/pandas/issues/49043 + result = parsing.guess_datetime_format(input) + expected = "%Y-%m-%dT%H:%M:%S.%f" + assert result == expected + + +def _helper_hypothesis_delimited_date(call, date_string, **kwargs): + msg, result = None, None + try: + result = call(date_string, **kwargs) + except ValueError as err: + msg = str(err) + return msg, result + + +@pytest.mark.parametrize("input", ["21-01-01", "01-01-21"]) +@pytest.mark.parametrize("dayfirst", [True, False]) +def test_parse_datetime_string_with_reso_dayfirst(dayfirst, input): + with option_context("display.date_dayfirst", dayfirst): + except_out_dateutil, result = _helper_hypothesis_delimited_date( + parsing.parse_datetime_string_with_reso, input + ) + + except_in_dateutil, expected = _helper_hypothesis_delimited_date( + du_parse, + input, + default=datetime(1, 1, 1), + dayfirst=dayfirst, + yearfirst=False, + ) + assert except_out_dateutil == except_in_dateutil + assert result[0] == expected + + +@pytest.mark.parametrize("input", ["21-01-01", "01-01-21"]) +@pytest.mark.parametrize("yearfirst", [True, False]) +def test_parse_datetime_string_with_reso_yearfirst(yearfirst, input): + with option_context("display.date_yearfirst", yearfirst): + except_out_dateutil, result = _helper_hypothesis_delimited_date( + parsing.parse_datetime_string_with_reso, input + ) + except_in_dateutil, expected = _helper_hypothesis_delimited_date( + du_parse, + input, + default=datetime(1, 1, 1), + dayfirst=False, + yearfirst=yearfirst, + ) + assert except_out_dateutil == except_in_dateutil + assert result[0] == expected diff --git a/pandas/tests/tslibs/test_period.py b/pandas/tests/tslibs/test_period.py new file mode 100644 index 0000000000000000000000000000000000000000..4c17caabae327adf3af24a9b13c7c5da7d576cd4 --- /dev/null +++ b/pandas/tests/tslibs/test_period.py @@ -0,0 +1,123 @@ +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + iNaT, + to_offset, +) +from pandas._libs.tslibs.period import ( + extract_ordinals, + get_period_field_arr, + period_asfreq, + period_ordinal, +) + +import pandas._testing as tm + + +def get_freq_code(freqstr: str) -> int: + off = to_offset(freqstr, is_period=True) + # error: "BaseOffset" has no attribute "_period_dtype_code" + code = off._period_dtype_code # type: ignore[attr-defined] + return code + + +@pytest.mark.parametrize( + "freq1,freq2,expected", + [ + ("D", "h", 24), + ("D", "min", 1440), + ("D", "s", 86400), + ("D", "ms", 86400000), + ("D", "us", 86400000000), + ("D", "ns", 86400000000000), + ("h", "min", 60), + ("h", "s", 3600), + ("h", "ms", 3600000), + ("h", "us", 3600000000), + ("h", "ns", 3600000000000), + ("min", "s", 60), + ("min", "ms", 60000), + ("min", "us", 60000000), + ("min", "ns", 60000000000), + ("s", "ms", 1000), + ("s", "us", 1000000), + ("s", "ns", 1000000000), + ("ms", "us", 1000), + ("ms", "ns", 1000000), + ("us", "ns", 1000), + ], +) +def test_intra_day_conversion_factors(freq1, freq2, expected): + assert ( + period_asfreq(1, get_freq_code(freq1), get_freq_code(freq2), False) == expected + ) + + +@pytest.mark.parametrize( + "freq,expected", [("Y", 0), ("M", 0), ("W", 1), ("D", 0), ("B", 0)] +) +def test_period_ordinal_start_values(freq, expected): + # information for Jan. 1, 1970. + assert period_ordinal(1970, 1, 1, 0, 0, 0, 0, 0, get_freq_code(freq)) == expected + + +@pytest.mark.parametrize( + "dt,expected", + [ + ((1970, 1, 4, 0, 0, 0, 0, 0), 1), + ((1970, 1, 5, 0, 0, 0, 0, 0), 2), + ((2013, 10, 6, 0, 0, 0, 0, 0), 2284), + ((2013, 10, 7, 0, 0, 0, 0, 0), 2285), + ], +) +def test_period_ordinal_week(dt, expected): + args = (*dt, get_freq_code("W")) + assert period_ordinal(*args) == expected + + +@pytest.mark.parametrize( + "day,expected", + [ + # Thursday (Oct. 3, 2013). + (3, 11415), + # Friday (Oct. 4, 2013). + (4, 11416), + # Saturday (Oct. 5, 2013). + (5, 11417), + # Sunday (Oct. 6, 2013). + (6, 11417), + # Monday (Oct. 7, 2013). + (7, 11417), + # Tuesday (Oct. 8, 2013). + (8, 11418), + ], +) +def test_period_ordinal_business_day(day, expected): + # 5000 is PeriodDtypeCode for BusinessDay + args = (2013, 10, day, 0, 0, 0, 0, 0, 5000) + assert period_ordinal(*args) == expected + + +class TestExtractOrdinals: + def test_extract_ordinals_raises(self): + # with non-object, make sure we raise TypeError, not segfault + arr = np.arange(5) + freq = to_offset("D") + with pytest.raises(TypeError, match="values must be object-dtype"): + extract_ordinals(arr, freq) + + def test_extract_ordinals_2d(self): + freq = to_offset("D") + arr = np.empty(10, dtype=object) + arr[:] = iNaT + + res = extract_ordinals(arr, freq) + res2 = extract_ordinals(arr.reshape(5, 2), freq) + tm.assert_numpy_array_equal(res, res2.reshape(-1)) + + +def test_get_period_field_array_raises_on_out_of_range(): + msg = "Buffer dtype mismatch, expected 'const int64_t' but got 'double'" + with pytest.raises(ValueError, match=msg): + get_period_field_arr(-1, np.empty(1), 0) diff --git a/pandas/tests/tslibs/test_resolution.py b/pandas/tests/tslibs/test_resolution.py new file mode 100644 index 0000000000000000000000000000000000000000..59004d2cabdeee68b479c125a6cc050259c41b5d --- /dev/null +++ b/pandas/tests/tslibs/test_resolution.py @@ -0,0 +1,56 @@ +import datetime + +import numpy as np +import pytest + +from pandas._libs.tslibs import ( + Resolution, + get_resolution, +) +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit + + +def test_get_resolution_nano(): + # don't return the fallback RESO_DAY + arr = np.array([1], dtype=np.int64) + res = get_resolution(arr) + assert res == Resolution.RESO_NS + + +def test_get_resolution_non_nano_data(): + arr = np.array([1], dtype=np.int64) + res = get_resolution(arr, None, NpyDatetimeUnit.NPY_FR_us.value) + assert res == Resolution.RESO_US + + res = get_resolution(arr, datetime.UTC, NpyDatetimeUnit.NPY_FR_us.value) + assert res == Resolution.RESO_US + + +@pytest.mark.parametrize( + "freqstr,expected", + [ + ("Y", "year"), + ("Q", "quarter"), + ("M", "month"), + ("D", "day"), + ("h", "hour"), + ("min", "minute"), + ("s", "second"), + ("ms", "millisecond"), + ("us", "microsecond"), + ("ns", "nanosecond"), + ], +) +def test_get_attrname_from_abbrev(freqstr, expected): + reso = Resolution.get_reso_from_freqstr(freqstr) + assert reso.attr_abbrev == freqstr + assert reso.attrname == expected + + +@pytest.mark.parametrize("freq", ["H", "S"]) +def test_unit_H_S_raises(freq): + # GH#59143 + msg = f"Invalid frequency: {freq}" + + with pytest.raises(ValueError, match=msg): + Resolution.get_reso_from_freqstr(freq) diff --git a/pandas/tests/tslibs/test_strptime.py b/pandas/tests/tslibs/test_strptime.py new file mode 100644 index 0000000000000000000000000000000000000000..c63d3dbd9f5c7410f8be929d58047e4a3c8bf9ed --- /dev/null +++ b/pandas/tests/tslibs/test_strptime.py @@ -0,0 +1,111 @@ +from datetime import ( + datetime, + timezone, +) + +import numpy as np +import pytest + +from pandas._libs.tslibs.dtypes import NpyDatetimeUnit +from pandas._libs.tslibs.strptime import array_strptime + +from pandas import ( + NaT, + Timestamp, +) +import pandas._testing as tm + +creso_infer = NpyDatetimeUnit.NPY_FR_GENERIC.value + + +class TestArrayStrptimeResolutionInference: + def test_array_strptime_resolution_all_nat(self): + arr = np.array([NaT, np.nan], dtype=object) + + fmt = "%Y-%m-%d %H:%M:%S" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + assert res.dtype == "M8[s]" + + res, _ = array_strptime(arr, fmt=fmt, utc=True, creso=creso_infer) + assert res.dtype == "M8[s]" + + @pytest.mark.parametrize("tz", [None, timezone.utc]) + def test_array_strptime_resolution_inference_homogeneous_strings(self, tz): + dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz) + dt0 = dt.replace(microsecond=0) + + fmt = "%Y-%m-%d %H:%M:%S" + dtstr = dt.strftime(fmt) + arr = np.array([dtstr] * 3, dtype=object) + expected = np.array([dt0.replace(tzinfo=None)] * 3, dtype="M8[us]") + + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "%Y-%m-%d %H:%M:%S.%f" + dtstr = dt.strftime(fmt) + arr = np.array([dtstr] * 3, dtype=object) + expected = np.array([dt.replace(tzinfo=None)] * 3, dtype="M8[us]") + + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "ISO8601" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + @pytest.mark.parametrize("tz", [None, timezone.utc]) + def test_array_strptime_resolution_mixed(self, tz): + dt = datetime(2016, 1, 2, 3, 4, 5, 678900, tzinfo=tz) + + ts = Timestamp(dt).as_unit("ns") + + arr = np.array([dt, ts], dtype=object) + expected = np.array( + [Timestamp(dt).as_unit("ns").asm8, ts.asm8], + dtype="M8[ns]", + ) + + fmt = "%Y-%m-%d %H:%M:%S" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + fmt = "ISO8601" + res, _ = array_strptime(arr, fmt=fmt, utc=False, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + def test_array_strptime_resolution_todaynow(self): + # specifically case where today/now is the *first* item + vals = np.array(["today", np.datetime64("2017-01-01", "us")], dtype=object) + + now = Timestamp("now").asm8 + res, _ = array_strptime(vals, fmt="%Y-%m-%d", utc=False, creso=creso_infer) + res2, _ = array_strptime( + vals[::-1], fmt="%Y-%m-%d", utc=False, creso=creso_infer + ) + + # 1s is an arbitrary cutoff for call overhead; in local testing the + # actual difference is about 250us + tolerance = np.timedelta64(1, "s") + + assert res.dtype == "M8[us]" + assert abs(res[0] - now) < tolerance + assert res[1] == vals[1] + + assert res2.dtype == "M8[us]" + assert abs(res2[1] - now) < tolerance * 2 + assert res2[0] == vals[1] + + def test_array_strptime_str_outside_nano_range(self): + vals = np.array(["2401-09-15"], dtype=object) + expected = np.array(["2401-09-15"], dtype="M8[us]") + fmt = "ISO8601" + res, _ = array_strptime(vals, fmt=fmt, creso=creso_infer) + tm.assert_numpy_array_equal(res, expected) + + # non-iso -> different path + vals2 = np.array(["Sep 15, 2401"], dtype=object) + expected2 = np.array(["2401-09-15"], dtype="M8[us]") + fmt2 = "%b %d, %Y" + res2, _ = array_strptime(vals2, fmt=fmt2, creso=creso_infer) + tm.assert_numpy_array_equal(res2, expected2) diff --git a/pandas/tests/tslibs/test_timezones.py b/pandas/tests/tslibs/test_timezones.py new file mode 100644 index 0000000000000000000000000000000000000000..c48986c597356f9e8f1070771fdc671e42380c0e --- /dev/null +++ b/pandas/tests/tslibs/test_timezones.py @@ -0,0 +1,193 @@ +from datetime import ( + datetime, + timedelta, + timezone, +) +import subprocess +import sys +import textwrap + +import dateutil.tz +import pytest + +from pandas._libs.tslibs import ( + conversion, + timezones, +) +from pandas.compat import is_platform_windows + +from pandas import Timestamp + + +@pytest.mark.single_cpu +def test_no_timezone_data(): + # https://github.com/pandas-dev/pandas/pull/63335 + # Test error message when timezone data is not available. + msg = "'No time zone found with key Europe/Brussels'" + code = textwrap.dedent( + f"""\ + import sys, zoneinfo, pandas as pd + sys.modules['tzdata'] = None + zoneinfo.reset_tzpath(['/path/to/nowhere']) + try: + pd.to_datetime('2012-01-01').tz_localize('Europe/Brussels') + except zoneinfo.ZoneInfoNotFoundError as err: + assert str(err) == "{msg}" + """ + ) + subprocess.check_call([sys.executable, "-c", code]) + + +def test_is_utc(utc_fixture): + tz = timezones.maybe_get_tz(utc_fixture) + assert timezones.is_utc(tz) + + +def test_cache_keys_are_distinct_for_pytz_vs_dateutil(): + pytz = pytest.importorskip("pytz") + for tz_name in pytz.common_timezones: + tz_p = timezones.maybe_get_tz(tz_name) + tz_d = timezones.maybe_get_tz("dateutil/" + tz_name) + + if tz_d is None: + pytest.skip(tz_name + ": dateutil does not know about this one") + + if not (tz_name == "UTC" and is_platform_windows()): + # they both end up as tzwin("UTC") on windows + assert timezones._p_tz_cache_key(tz_p) != timezones._p_tz_cache_key(tz_d) + + +def test_tzlocal_repr(): + # see gh-13583 + ts = Timestamp("2011-01-01", tz=dateutil.tz.tzlocal()) + assert ts.tz == dateutil.tz.tzlocal() + assert "tz='tzlocal()')" in repr(ts) + + +def test_tzlocal_maybe_get_tz(): + # see gh-13583 + tz = timezones.maybe_get_tz("tzlocal()") + assert tz == dateutil.tz.tzlocal() + + +def test_tzlocal_offset(): + # see gh-13583 + # + # Get offset using normal datetime for test. + ts = Timestamp("2011-01-01", tz=dateutil.tz.tzlocal()).as_unit("s") + + offset = dateutil.tz.tzlocal().utcoffset(datetime(2011, 1, 1)) + offset = offset.total_seconds() + + assert ts._value + offset == Timestamp("2011-01-01").as_unit("s")._value + + +def test_tzlocal_is_not_utc(): + # even if the machine running the test is localized to UTC + tz = dateutil.tz.tzlocal() + assert not timezones.is_utc(tz) + + assert not timezones.tz_compare(tz, dateutil.tz.tzutc()) + + +def test_tz_compare_utc(utc_fixture, utc_fixture2): + tz = timezones.maybe_get_tz(utc_fixture) + tz2 = timezones.maybe_get_tz(utc_fixture2) + assert timezones.tz_compare(tz, tz2) + + +@pytest.fixture( + params=[ + ("pytz/US/Eastern", lambda tz, x: tz.localize(x)), + (dateutil.tz.gettz("US/Eastern"), lambda tz, x: x.replace(tzinfo=tz)), + ] +) +def infer_setup(request): + eastern, localize = request.param + if isinstance(eastern, str) and eastern.startswith("pytz/"): + pytz = pytest.importorskip("pytz") + eastern = pytz.timezone(eastern.removeprefix("pytz/")) + + start_naive = datetime(2001, 1, 1) + end_naive = datetime(2009, 1, 1) + + start = localize(eastern, start_naive) + end = localize(eastern, end_naive) + + return eastern, localize, start, end, start_naive, end_naive + + +def test_infer_tz_compat(infer_setup): + eastern, _, start, end, start_naive, end_naive = infer_setup + + assert ( + timezones.infer_tzinfo(start, end) + is conversion.localize_pydatetime(start_naive, eastern).tzinfo + ) + assert ( + timezones.infer_tzinfo(start, None) + is conversion.localize_pydatetime(start_naive, eastern).tzinfo + ) + assert ( + timezones.infer_tzinfo(None, end) + is conversion.localize_pydatetime(end_naive, eastern).tzinfo + ) + + +def test_infer_tz_utc_localize(infer_setup): + _, _, start, end, start_naive, end_naive = infer_setup + utc = timezone.utc + + start = start_naive.astimezone(utc) + end = end_naive.astimezone(utc) + + assert timezones.infer_tzinfo(start, end) is utc + + +@pytest.mark.parametrize("ordered", [True, False]) +def test_infer_tz_mismatch(infer_setup, ordered): + eastern, _, _, _, start_naive, end_naive = infer_setup + msg = "Inputs must both have the same timezone" + + utc = timezone.utc + start = start_naive.astimezone(utc) + end = conversion.localize_pydatetime(end_naive, eastern) + + args = (start, end) if ordered else (end, start) + + with pytest.raises(AssertionError, match=msg): + timezones.infer_tzinfo(*args) + + +def test_maybe_get_tz_invalid_types(): + with pytest.raises(TypeError, match=""): + timezones.maybe_get_tz(44.0) + + with pytest.raises(TypeError, match=""): + timezones.maybe_get_tz(pytest) + + msg = "" + with pytest.raises(TypeError, match=msg): + timezones.maybe_get_tz(Timestamp("2021-01-01", tz="UTC")) + + +def test_maybe_get_tz_offset_only(): + # see gh-36004 + + # timezone.utc + tz = timezones.maybe_get_tz(timezone.utc) + assert tz == timezone(timedelta(hours=0, minutes=0)) + + # without UTC+- prefix + tz = timezones.maybe_get_tz("+01:15") + assert tz == timezone(timedelta(hours=1, minutes=15)) + + tz = timezones.maybe_get_tz("-01:15") + assert tz == timezone(-timedelta(hours=1, minutes=15)) + + # with UTC+- prefix + tz = timezones.maybe_get_tz("UTC+02:45") + assert tz == timezone(timedelta(hours=2, minutes=45)) + + tz = timezones.maybe_get_tz("UTC-02:45") + assert tz == timezone(-timedelta(hours=2, minutes=45)) diff --git a/pandas/tests/util/__init__.py b/pandas/tests/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/util/conftest.py b/pandas/tests/util/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..2e931ff42fe1528945dba38ae91e6cae0853afda --- /dev/null +++ b/pandas/tests/util/conftest.py @@ -0,0 +1,46 @@ +import pytest + + +@pytest.fixture(params=[True, False]) +def check_dtype(request): + """ + Fixture returning `True` or `False`, determining whether to check + if the `dtype` is identical or not, when comparing two data structures, + e.g. `Series`, `SparseArray` or `DataFrame`. + """ + return request.param + + +@pytest.fixture(params=[True, False]) +def check_exact(request): + """ + Fixture returning `True` or `False`, determining whether to + compare floating point numbers exactly or not. + """ + return request.param + + +@pytest.fixture(params=[True, False]) +def check_index_type(request): + """ + Fixture returning `True` or `False`, determining whether to check + if the `Index` types are identical or not. + """ + return request.param + + +@pytest.fixture(params=[0.5e-3, 0.5e-5]) +def rtol(request): + """ + Fixture returning 0.5e-3 or 0.5e-5. Those values are used as relative tolerance. + """ + return request.param + + +@pytest.fixture(params=[True, False]) +def check_categorical(request): + """ + Fixture returning `True` or `False`, determining whether to + compare internal `Categorical` exactly or not. + """ + return request.param diff --git a/pandas/tests/util/test_assert_almost_equal.py b/pandas/tests/util/test_assert_almost_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..091670ed69f11b72700766ba20024af37dbf68f5 --- /dev/null +++ b/pandas/tests/util/test_assert_almost_equal.py @@ -0,0 +1,586 @@ +import numpy as np +import pytest + +from pandas import ( + NA, + DataFrame, + Index, + NaT, + Series, + Timestamp, +) +import pandas._testing as tm + + +def _assert_almost_equal_both(a, b, **kwargs): + """ + Check that two objects are approximately equal. + + This check is performed commutatively. + + Parameters + ---------- + a : object + The first object to compare. + b : object + The second object to compare. + **kwargs + The arguments passed to `tm.assert_almost_equal`. + """ + tm.assert_almost_equal(a, b, **kwargs) + tm.assert_almost_equal(b, a, **kwargs) + + +def _assert_not_almost_equal(a, b, **kwargs): + """ + Check that two objects are not approximately equal. + + Parameters + ---------- + a : object + The first object to compare. + b : object + The second object to compare. + **kwargs + The arguments passed to `tm.assert_almost_equal`. + """ + try: + tm.assert_almost_equal(a, b, **kwargs) + msg = f"{a} and {b} were approximately equal when they shouldn't have been" + pytest.fail(reason=msg) + except AssertionError: + pass + + +def _assert_not_almost_equal_both(a, b, **kwargs): + """ + Check that two objects are not approximately equal. + + This check is performed commutatively. + + Parameters + ---------- + a : object + The first object to compare. + b : object + The second object to compare. + **kwargs + The arguments passed to `tm.assert_almost_equal`. + """ + _assert_not_almost_equal(a, b, **kwargs) + _assert_not_almost_equal(b, a, **kwargs) + + +@pytest.mark.parametrize( + "a,b", + [ + (1.1, 1.1), + (1.1, 1.100001), + (np.int16(1), 1.000001), + (np.float64(1.1), 1.1), + (np.uint32(5), 5), + ], +) +def test_assert_almost_equal_numbers(a, b): + _assert_almost_equal_both(a, b) + + +@pytest.mark.parametrize( + "a,b", + [ + (1.1, 1), + (1.1, True), + (1, 2), + (1.0001, np.int16(1)), + # The following two examples are not "almost equal" due to tol. + (0.1, 0.1001), + (0.0011, 0.0012), + ], +) +def test_assert_not_almost_equal_numbers(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize( + "a,b", + [ + (1.1, 1.1), + (1.1, 1.100001), + (1.1, 1.1001), + (0.000001, 0.000005), + (1000.0, 1000.0005), + # Testing this example, as per #13357 + (0.000011, 0.000012), + ], +) +def test_assert_almost_equal_numbers_atol(a, b): + # Equivalent to the deprecated check_less_precise=True, enforced in 2.0 + _assert_almost_equal_both(a, b, rtol=0.5e-3, atol=0.5e-3) + + +@pytest.mark.parametrize("a,b", [(1.1, 1.11), (0.1, 0.101), (0.000011, 0.001012)]) +def test_assert_not_almost_equal_numbers_atol(a, b): + _assert_not_almost_equal_both(a, b, atol=1e-3) + + +@pytest.mark.parametrize( + "a,b", + [ + (1.1, 1.1), + (1.1, 1.100001), + (1.1, 1.1001), + (1000.0, 1000.0005), + (1.1, 1.11), + (0.1, 0.101), + ], +) +def test_assert_almost_equal_numbers_rtol(a, b): + _assert_almost_equal_both(a, b, rtol=0.05) + + +@pytest.mark.parametrize("a,b", [(0.000011, 0.000012), (0.000001, 0.000005)]) +def test_assert_not_almost_equal_numbers_rtol(a, b): + _assert_not_almost_equal_both(a, b, rtol=0.05) + + +@pytest.mark.parametrize( + "a,b,rtol", + [ + (1.00001, 1.00005, 0.001), + (-0.908356 + 0.2j, -0.908358 + 0.2j, 1e-3), + (0.1 + 1.009j, 0.1 + 1.006j, 0.1), + (0.1001 + 2.0j, 0.1 + 2.001j, 0.01), + ], +) +def test_assert_almost_equal_complex_numbers(a, b, rtol): + _assert_almost_equal_both(a, b, rtol=rtol) + _assert_almost_equal_both(np.complex64(a), np.complex64(b), rtol=rtol) + _assert_almost_equal_both(np.complex128(a), np.complex128(b), rtol=rtol) + + +@pytest.mark.parametrize( + "a,b,rtol", + [ + (0.58310768, 0.58330768, 1e-7), + (-0.908 + 0.2j, -0.978 + 0.2j, 0.001), + (0.1 + 1j, 0.1 + 2j, 0.01), + (-0.132 + 1.001j, -0.132 + 1.005j, 1e-5), + (0.58310768j, 0.58330768j, 1e-9), + ], +) +def test_assert_not_almost_equal_complex_numbers(a, b, rtol): + _assert_not_almost_equal_both(a, b, rtol=rtol) + _assert_not_almost_equal_both(np.complex64(a), np.complex64(b), rtol=rtol) + _assert_not_almost_equal_both(np.complex128(a), np.complex128(b), rtol=rtol) + + +@pytest.mark.parametrize("a,b", [(0, 0), (0, 0.0), (0, np.float64(0)), (0.00000001, 0)]) +def test_assert_almost_equal_numbers_with_zeros(a, b): + _assert_almost_equal_both(a, b) + + +@pytest.mark.parametrize("a,b", [(0.001, 0), (1, 0)]) +def test_assert_not_almost_equal_numbers_with_zeros(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize("a,b", [(1, "abc"), (1, [1]), (1, object())]) +def test_assert_not_almost_equal_numbers_with_mixed(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize( + "left_dtype", ["M8[ns]", "m8[ns]", "float64", "int64", "object"] +) +@pytest.mark.parametrize( + "right_dtype", ["M8[ns]", "m8[ns]", "float64", "int64", "object"] +) +def test_assert_almost_equal_edge_case_ndarrays(left_dtype, right_dtype): + # Empty compare. + _assert_almost_equal_both( + np.array([], dtype=left_dtype), + np.array([], dtype=right_dtype), + check_dtype=False, + ) + + +def test_assert_almost_equal_sets(): + # GH#51727 + _assert_almost_equal_both({1, 2, 3}, {1, 2, 3}) + + +def test_assert_almost_not_equal_sets(): + # GH#51727 + msg = r"{1, 2, 3} != {1, 2, 4}" + with pytest.raises(AssertionError, match=msg): + _assert_almost_equal_both({1, 2, 3}, {1, 2, 4}) + + +def test_assert_almost_equal_dicts(): + _assert_almost_equal_both({"a": 1, "b": 2}, {"a": 1, "b": 2}) + + +@pytest.mark.parametrize( + "a,b", + [ + ({"a": 1, "b": 2}, {"a": 1, "b": 3}), + ({"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}), + ({"a": 1}, 1), + ({"a": 1}, "abc"), + ({"a": 1}, [1]), + ], +) +def test_assert_not_almost_equal_dicts(a, b): + _assert_not_almost_equal_both(a, b) + + +@pytest.mark.parametrize("val", [1, 2]) +def test_assert_almost_equal_dict_like_object(val): + dict_val = 1 + real_dict = {"a": val} + + class DictLikeObj: + def keys(self): + return ("a",) + + def __getitem__(self, item): + if item == "a": + return dict_val + + func = ( + _assert_almost_equal_both if val == dict_val else _assert_not_almost_equal_both + ) + func(real_dict, DictLikeObj(), check_dtype=False) + + +def test_assert_almost_equal_strings(): + _assert_almost_equal_both("abc", "abc") + + +@pytest.mark.parametrize("b", ["abcd", "abd", 1, [1]]) +def test_assert_not_almost_equal_strings(b): + _assert_not_almost_equal_both("abc", b) + + +@pytest.mark.parametrize("box", [list, np.array]) +def test_assert_almost_equal_iterables(box): + _assert_almost_equal_both(box([1, 2, 3]), box([1, 2, 3])) + + +@pytest.mark.parametrize( + "a,b", + [ + # Class is different. + (np.array([1, 2, 3]), [1, 2, 3]), + # Dtype is different. + (np.array([1, 2, 3]), np.array([1.0, 2.0, 3.0])), + # Can't compare generators. + (iter([1, 2, 3]), [1, 2, 3]), + ([1, 2, 3], [1, 2, 4]), + ([1, 2, 3], [1, 2, 3, 4]), + ([1, 2, 3], 1), + ], +) +def test_assert_not_almost_equal_iterables(a, b): + _assert_not_almost_equal(a, b) + + +def test_assert_almost_equal_null(): + _assert_almost_equal_both(None, None) + + +@pytest.mark.parametrize("a,b", [(None, np.nan), (None, 0), (np.nan, 0)]) +def test_assert_not_almost_equal_null(a, b): + _assert_not_almost_equal(a, b) + + +@pytest.mark.parametrize( + "a,b", + [ + (np.inf, np.inf), + (np.inf, float("inf")), + (np.array([np.inf, np.nan, -np.inf]), np.array([np.inf, np.nan, -np.inf])), + ], +) +def test_assert_almost_equal_inf(a, b): + _assert_almost_equal_both(a, b) + + +objs = [NA, np.nan, NaT, None, np.datetime64("NaT"), np.timedelta64("NaT")] + + +@pytest.mark.parametrize("left", objs) +@pytest.mark.parametrize("right", objs) +def test_mismatched_na_assert_almost_equal(left, right): + left_arr = np.array([left], dtype=object) + right_arr = np.array([right], dtype=object) + + msg = "Mismatched null-like values" + + if left is right: + _assert_almost_equal_both(left, right, check_dtype=False) + tm.assert_numpy_array_equal(left_arr, right_arr) + tm.assert_index_equal( + Index(left_arr, dtype=object), Index(right_arr, dtype=object) + ) + tm.assert_series_equal( + Series(left_arr, dtype=object), Series(right_arr, dtype=object) + ) + tm.assert_frame_equal( + DataFrame(left_arr, dtype=object), DataFrame(right_arr, dtype=object) + ) + + else: + with pytest.raises(AssertionError, match=msg): + _assert_almost_equal_both(left, right, check_dtype=False) + + # TODO: to get the same deprecation in assert_numpy_array_equal we need + # to change/deprecate the default for strict_nan to become True + # TODO: to get the same deprecation in assert_index_equal we need to + # change/deprecate array_equivalent_object to be stricter, as + # assert_index_equal uses Index.equal which uses array_equivalent. + with pytest.raises(AssertionError, match="Series are different"): + tm.assert_series_equal( + Series(left_arr, dtype=object), Series(right_arr, dtype=object) + ) + with pytest.raises(AssertionError, match="DataFrame.iloc.* are different"): + tm.assert_frame_equal( + DataFrame(left_arr, dtype=object), DataFrame(right_arr, dtype=object) + ) + + +def test_assert_not_almost_equal_inf(): + _assert_not_almost_equal_both(np.inf, 0) + + +@pytest.mark.parametrize( + "a,b", + [ + (Index([1.0, 1.1]), Index([1.0, 1.100001])), + (Series([1.0, 1.1]), Series([1.0, 1.100001])), + (np.array([1.1, 2.000001]), np.array([1.1, 2.0])), + (DataFrame({"a": [1.0, 1.1]}), DataFrame({"a": [1.0, 1.100001]})), + ], +) +def test_assert_almost_equal_pandas(a, b): + _assert_almost_equal_both(a, b) + + +def test_assert_almost_equal_object(): + a = [Timestamp("2011-01-01"), Timestamp("2011-01-01")] + b = [Timestamp("2011-01-01"), Timestamp("2011-01-01")] + _assert_almost_equal_both(a, b) + + +def test_assert_almost_equal_value_mismatch(): + msg = "expected 2\\.00000 but got 1\\.00000, with rtol=1e-05, atol=1e-08" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(1, 2) + + +@pytest.mark.parametrize( + "a,b,klass1,klass2", + [(np.array([1]), 1, "ndarray", "int"), (1, np.array([1]), "int", "ndarray")], +) +def test_assert_almost_equal_class_mismatch(a, b, klass1, klass2): + msg = f"""numpy array are different + +numpy array classes are different +\\[left\\]: {klass1} +\\[right\\]: {klass2}""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(a, b) + + +def test_assert_almost_equal_value_mismatch1(): + msg = """numpy array are different + +numpy array values are different \\(66\\.66667 %\\) +\\[left\\]: \\[nan, 2\\.0, 3\\.0\\] +\\[right\\]: \\[1\\.0, nan, 3\\.0\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array([np.nan, 2, 3]), np.array([1, np.nan, 3])) + + +def test_assert_almost_equal_value_mismatch2(): + msg = """numpy array are different + +numpy array values are different \\(50\\.0 %\\) +\\[left\\]: \\[1, 2\\] +\\[right\\]: \\[1, 3\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array([1, 2]), np.array([1, 3])) + + +def test_assert_almost_equal_value_mismatch3(): + msg = """numpy array are different + +numpy array values are different \\(16\\.66667 %\\) +\\[left\\]: \\[\\[1, 2\\], \\[3, 4\\], \\[5, 6\\]\\] +\\[right\\]: \\[\\[1, 3\\], \\[3, 4\\], \\[5, 6\\]\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal( + np.array([[1, 2], [3, 4], [5, 6]]), np.array([[1, 3], [3, 4], [5, 6]]) + ) + + +def test_assert_almost_equal_value_mismatch4(): + msg = """numpy array are different + +numpy array values are different \\(25\\.0 %\\) +\\[left\\]: \\[\\[1, 2\\], \\[3, 4\\]\\] +\\[right\\]: \\[\\[1, 3\\], \\[3, 4\\]\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array([[1, 2], [3, 4]]), np.array([[1, 3], [3, 4]])) + + +def test_assert_almost_equal_shape_mismatch_override(): + msg = """Index are different + +Index shapes are different +\\[left\\]: \\(2L*,\\) +\\[right\\]: \\(3L*,\\)""" + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array([1, 2]), np.array([3, 4, 5]), obj="Index") + + +def test_assert_almost_equal_unicode(): + # see gh-20503 + msg = """numpy array are different + +numpy array values are different \\(33\\.33333 %\\) +\\[left\\]: \\[á, à, ä\\] +\\[right\\]: \\[á, à, å\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(np.array(["á", "à", "ä"]), np.array(["á", "à", "å"])) + + +def test_assert_almost_equal_timestamp(): + a = np.array([Timestamp("2011-01-01"), Timestamp("2011-01-01")]) + b = np.array([Timestamp("2011-01-01"), Timestamp("2011-01-02")]) + + msg = """numpy array are different + +numpy array values are different \\(50\\.0 %\\) +\\[left\\]: \\[2011-01-01 00:00:00, 2011-01-01 00:00:00\\] +\\[right\\]: \\[2011-01-01 00:00:00, 2011-01-02 00:00:00\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal(a, b) + + +def test_assert_almost_equal_iterable_length_mismatch(): + msg = """Iterable are different + +Iterable length are different +\\[left\\]: 2 +\\[right\\]: 3""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal([1, 2], [3, 4, 5]) + + +def test_assert_almost_equal_iterable_values_mismatch(): + msg = """Iterable are different + +Iterable values are different \\(50\\.0 %\\) +\\[left\\]: \\[1, 2\\] +\\[right\\]: \\[1, 3\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_almost_equal([1, 2], [1, 3]) + + +subarr = np.empty(2, dtype=object) +subarr[:] = [np.array([None, "b"], dtype=object), np.array(["c", "d"], dtype=object)] + +NESTED_CASES = [ + # nested array + ( + np.array([np.array([50, 70, 90]), np.array([20, 30])], dtype=object), + np.array([np.array([50, 70, 90]), np.array([20, 30])], dtype=object), + ), + # >1 level of nesting + ( + np.array( + [ + np.array([np.array([50, 70]), np.array([90])], dtype=object), + np.array([np.array([20, 30])], dtype=object), + ], + dtype=object, + ), + np.array( + [ + np.array([np.array([50, 70]), np.array([90])], dtype=object), + np.array([np.array([20, 30])], dtype=object), + ], + dtype=object, + ), + ), + # lists + ( + np.array([[50, 70, 90], [20, 30]], dtype=object), + np.array([[50, 70, 90], [20, 30]], dtype=object), + ), + # mixed array/list + ( + np.array([np.array([1, 2, 3]), np.array([4, 5])], dtype=object), + np.array([[1, 2, 3], [4, 5]], dtype=object), + ), + ( + np.array([np.array([], dtype=object), None], dtype=object), + np.array([[], None], dtype=object), + ), + ( + np.array( + [ + np.array([np.array([1, 2, 3]), np.array([4, 5])], dtype=object), + np.array( + [np.array([6]), np.array([7, 8]), np.array([9])], dtype=object + ), + ], + dtype=object, + ), + np.array([[[1, 2, 3], [4, 5]], [[6], [7, 8], [9]]], dtype=object), + ), + # same-length lists + ( + np.array([subarr, None], dtype=object), + np.array([[[None, "b"], ["c", "d"]], None], dtype=object), + ), + # dicts + ( + np.array([{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object), + np.array([{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object), + ), + ( + np.array([{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object), + np.array([{"f1": 1, "f2": ["a", "b"]}], dtype=object), + ), + # array/list of dicts + ( + np.array( + [ + np.array( + [{"f1": 1, "f2": np.array(["a", "b"], dtype=object)}], dtype=object + ), + np.array([], dtype=object), + ], + dtype=object, + ), + np.array([[{"f1": 1, "f2": ["a", "b"]}], []], dtype=object), + ), +] + + +@pytest.mark.filterwarnings("ignore:elementwise comparison failed:DeprecationWarning") +@pytest.mark.parametrize("a,b", NESTED_CASES) +def test_assert_almost_equal_array_nested(a, b): + _assert_almost_equal_both(a, b) diff --git a/pandas/tests/util/test_assert_attr_equal.py b/pandas/tests/util/test_assert_attr_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..bbbb0bf2172b12f93c9f0f6a97751854d1566a99 --- /dev/null +++ b/pandas/tests/util/test_assert_attr_equal.py @@ -0,0 +1,33 @@ +from types import SimpleNamespace + +import pytest + +from pandas.core.dtypes.common import is_float + +import pandas._testing as tm + + +def test_assert_attr_equal(nulls_fixture): + obj = SimpleNamespace() + obj.na_value = nulls_fixture + tm.assert_attr_equal("na_value", obj, obj) + + +def test_assert_attr_equal_different_nulls(nulls_fixture, nulls_fixture2): + obj = SimpleNamespace() + obj.na_value = nulls_fixture + + obj2 = SimpleNamespace() + obj2.na_value = nulls_fixture2 + + if nulls_fixture is nulls_fixture2: + tm.assert_attr_equal("na_value", obj, obj2) + elif is_float(nulls_fixture) and is_float(nulls_fixture2): + # we consider float("nan") and np.float64("nan") to be equivalent + tm.assert_attr_equal("na_value", obj, obj2) + elif type(nulls_fixture) is type(nulls_fixture2): + # e.g. Decimal("NaN") + tm.assert_attr_equal("na_value", obj, obj2) + else: + with pytest.raises(AssertionError, match='"na_value" are different'): + tm.assert_attr_equal("na_value", obj, obj2) diff --git a/pandas/tests/util/test_assert_categorical_equal.py b/pandas/tests/util/test_assert_categorical_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..c17156457470839906ead0f965880e9d3638ab7f --- /dev/null +++ b/pandas/tests/util/test_assert_categorical_equal.py @@ -0,0 +1,88 @@ +import pytest + +from pandas import Categorical +import pandas._testing as tm + + +@pytest.mark.parametrize("c", [None, [1, 2, 3, 4, 5]]) +def test_categorical_equal(c): + c = Categorical([1, 2, 3, 4], categories=c) + tm.assert_categorical_equal(c, c) + + +@pytest.mark.parametrize("check_category_order", [True, False]) +def test_categorical_equal_order_mismatch(check_category_order): + c1 = Categorical([1, 2, 3, 4], categories=[1, 2, 3, 4]) + c2 = Categorical([1, 2, 3, 4], categories=[4, 3, 2, 1]) + kwargs = {"check_category_order": check_category_order} + + if check_category_order: + msg = """Categorical\\.categories are different + +Categorical\\.categories values are different \\(100\\.0 %\\) +\\[left\\]: Index\\(\\[1, 2, 3, 4\\], dtype='int64'\\) +\\[right\\]: Index\\(\\[4, 3, 2, 1\\], dtype='int64'\\)""" + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2, **kwargs) + else: + tm.assert_categorical_equal(c1, c2, **kwargs) + + +def test_categorical_equal_categories_mismatch(): + msg = """Categorical\\.categories are different + +Categorical\\.categories values are different \\(25\\.0 %\\) +\\[left\\]: Index\\(\\[1, 2, 3, 4\\], dtype='int64'\\) +\\[right\\]: Index\\(\\[1, 2, 3, 5\\], dtype='int64'\\)""" + + c1 = Categorical([1, 2, 3, 4]) + c2 = Categorical([1, 2, 3, 5]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2) + + +def test_categorical_equal_codes_mismatch(): + categories = [1, 2, 3, 4] + msg = """Categorical\\.codes are different + +Categorical\\.codes values are different \\(50\\.0 %\\) +\\[left\\]: \\[0, 1, 3, 2\\] +\\[right\\]: \\[0, 1, 2, 3\\]""" + + c1 = Categorical([1, 2, 4, 3], categories=categories) + c2 = Categorical([1, 2, 3, 4], categories=categories) + + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2) + + +def test_categorical_equal_ordered_mismatch(): + data = [1, 2, 3, 4] + msg = """Categorical are different + +Attribute "ordered" are different +\\[left\\]: False +\\[right\\]: True""" + + c1 = Categorical(data, ordered=False) + c2 = Categorical(data, ordered=True) + + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2) + + +@pytest.mark.parametrize("obj", ["index", "foo", "pandas"]) +def test_categorical_equal_object_override(obj): + data = [1, 2, 3, 4] + msg = f"""{obj} are different + +Attribute "ordered" are different +\\[left\\]: False +\\[right\\]: True""" + + c1 = Categorical(data, ordered=False) + c2 = Categorical(data, ordered=True) + + with pytest.raises(AssertionError, match=msg): + tm.assert_categorical_equal(c1, c2, obj=obj) diff --git a/pandas/tests/util/test_assert_extension_array_equal.py b/pandas/tests/util/test_assert_extension_array_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..5d82ae9af0e9573c916fbbb836a5dd64920794c3 --- /dev/null +++ b/pandas/tests/util/test_assert_extension_array_equal.py @@ -0,0 +1,125 @@ +import numpy as np +import pytest + +from pandas import ( + Timestamp, + array, +) +import pandas._testing as tm +from pandas.core.arrays.sparse import SparseArray + + +@pytest.mark.parametrize( + "kwargs", + [ + {}, # Default is check_exact=False + {"check_exact": False}, + {"check_exact": True}, + ], +) +def test_assert_extension_array_equal_not_exact(kwargs): + # see gh-23709 + arr1 = SparseArray([-0.17387645482451206, 0.3414148016424936]) + arr2 = SparseArray([-0.17387645482451206, 0.3414148016424937]) + + if kwargs.get("check_exact", False): + msg = """\ +ExtensionArray are different + +ExtensionArray values are different \\(50\\.0 %\\) +\\[left\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\] +\\[right\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(arr1, arr2, **kwargs) + else: + tm.assert_extension_array_equal(arr1, arr2, **kwargs) + + +@pytest.mark.parametrize("decimals", range(10)) +def test_assert_extension_array_equal_less_precise(decimals): + rtol = 0.5 * 10**-decimals + arr1 = SparseArray([0.5, 0.123456]) + arr2 = SparseArray([0.5, 0.123457]) + + if decimals >= 5: + msg = """\ +ExtensionArray are different + +ExtensionArray values are different \\(50\\.0 %\\) +\\[left\\]: \\[0\\.5, 0\\.123456\\] +\\[right\\]: \\[0\\.5, 0\\.123457\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(arr1, arr2, rtol=rtol) + else: + tm.assert_extension_array_equal(arr1, arr2, rtol=rtol) + + +def test_assert_extension_array_equal_dtype_mismatch(check_dtype): + end = 5 + kwargs = {"check_dtype": check_dtype} + + arr1 = SparseArray(np.arange(end, dtype="int64")) + arr2 = SparseArray(np.arange(end, dtype="int32")) + + if check_dtype: + msg = """\ +ExtensionArray are different + +Attribute "dtype" are different +\\[left\\]: Sparse\\[int64, 0\\] +\\[right\\]: Sparse\\[int32, 0\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(arr1, arr2, **kwargs) + else: + tm.assert_extension_array_equal(arr1, arr2, **kwargs) + + +def test_assert_extension_array_equal_missing_values(): + arr1 = SparseArray([np.nan, 1, 2, np.nan]) + arr2 = SparseArray([np.nan, 1, 2, 3]) + + msg = """\ +ExtensionArray NA mask are different + +ExtensionArray NA mask values are different \\(25\\.0 %\\) +\\[left\\]: \\[True, False, False, True\\] +\\[right\\]: \\[True, False, False, False\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(arr1, arr2) + + +@pytest.mark.parametrize("side", ["left", "right"]) +def test_assert_extension_array_equal_non_extension_array(side): + numpy_array = np.arange(5) + extension_array = SparseArray(numpy_array) + + msg = f"{side} is not an ExtensionArray" + args = ( + (numpy_array, extension_array) + if side == "left" + else (extension_array, numpy_array) + ) + + with pytest.raises(AssertionError, match=msg): + tm.assert_extension_array_equal(*args) + + +def test_assert_extension_array_equal_ignore_dtype_mismatch(any_int_dtype): + # https://github.com/pandas-dev/pandas/issues/35715 + left = array([1, 2, 3], dtype="Int64") + right = array([1, 2, 3], dtype=any_int_dtype) + tm.assert_extension_array_equal(left, right, check_dtype=False) + + +def test_assert_extension_array_equal_time_units(): + # https://github.com/pandas-dev/pandas/issues/55730 + timestamp = Timestamp("2023-11-04T12") + naive = array([timestamp], dtype="datetime64[ns]") + utc = array([timestamp], dtype="datetime64[ns, UTC]") + + tm.assert_extension_array_equal(naive, utc, check_dtype=False) + tm.assert_extension_array_equal(utc, naive, check_dtype=False) diff --git a/pandas/tests/util/test_assert_frame_equal.py b/pandas/tests/util/test_assert_frame_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..19abfe727fb4b3d4046d1ec99a84c8d6e2afc6c8 --- /dev/null +++ b/pandas/tests/util/test_assert_frame_equal.py @@ -0,0 +1,425 @@ +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning + +import pandas as pd +from pandas import DataFrame +import pandas._testing as tm + + +@pytest.fixture(params=[True, False]) +def by_blocks_fixture(request): + return request.param + + +def _assert_frame_equal_both(a, b, **kwargs): + """ + Check that two DataFrame equal. + + This check is performed commutatively. + + Parameters + ---------- + a : DataFrame + The first DataFrame to compare. + b : DataFrame + The second DataFrame to compare. + kwargs : dict + The arguments passed to `tm.assert_frame_equal`. + """ + tm.assert_frame_equal(a, b, **kwargs) + tm.assert_frame_equal(b, a, **kwargs) + + +@pytest.mark.parametrize("check_like", [True, False]) +def test_frame_equal_row_order_mismatch(check_like, frame_or_series): + df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"]) + df2 = DataFrame({"A": [3, 2, 1], "B": [6, 5, 4]}, index=["c", "b", "a"]) + + if not check_like: # Do not ignore row-column orderings. + msg = f"{frame_or_series.__name__}.index are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal( + df1, df2, check_like=check_like, obj=frame_or_series.__name__ + ) + else: + _assert_frame_equal_both( + df1, df2, check_like=check_like, obj=frame_or_series.__name__ + ) + + +@pytest.mark.parametrize( + "df1,df2", + [ + ({"A": [1, 2, 3]}, {"A": [1, 2, 3, 4]}), + ({"A": [1, 2, 3], "B": [4, 5, 6]}, {"A": [1, 2, 3]}), + ], +) +def test_frame_equal_shape_mismatch(df1, df2, frame_or_series): + df1 = DataFrame(df1) + df2 = DataFrame(df2) + msg = f"{frame_or_series.__name__} are different" + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, obj=frame_or_series.__name__) + + +@pytest.mark.parametrize( + "df1,df2,msg", + [ + # Index + ( + DataFrame.from_records({"a": [1, 2], "c": ["l1", "l2"]}, index=["a"]), + DataFrame.from_records({"a": [1.0, 2.0], "c": ["l1", "l2"]}, index=["a"]), + "DataFrame\\.index are different", + ), + # MultiIndex + ( + DataFrame.from_records( + {"a": [1, 2], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] + ), + DataFrame.from_records( + {"a": [1.0, 2.0], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] + ), + "DataFrame\\.index level \\[0\\] are different", + ), + ], +) +def test_frame_equal_index_dtype_mismatch(df1, df2, msg, check_index_type): + kwargs = {"check_index_type": check_index_type} + + if check_index_type: + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, **kwargs) + else: + tm.assert_frame_equal(df1, df2, **kwargs) + + +def test_empty_dtypes(check_dtype): + columns = ["col1", "col2"] + df1 = DataFrame(columns=columns) + df2 = DataFrame(columns=columns) + + kwargs = {"check_dtype": check_dtype} + df1["col1"] = df1["col1"].astype("int64") + + if check_dtype: + msg = r"Attributes of DataFrame\..* are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, **kwargs) + else: + tm.assert_frame_equal(df1, df2, **kwargs) + + +@pytest.mark.parametrize("check_like", [True, False]) +def test_frame_equal_index_mismatch(check_like, frame_or_series, using_infer_string): + if using_infer_string: + dtype = "str" + else: + dtype = "object" + msg = f"""{frame_or_series.__name__}\\.index are different + +{frame_or_series.__name__}\\.index values are different \\(33\\.33333 %\\) +\\[left\\]: Index\\(\\['a', 'b', 'c'\\], dtype='{dtype}'\\) +\\[right\\]: Index\\(\\['a', 'b', 'd'\\], dtype='{dtype}'\\) +At positional index 2, first diff: c != d""" + + df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"]) + df2 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "d"]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal( + df1, df2, check_like=check_like, obj=frame_or_series.__name__ + ) + + +@pytest.mark.parametrize("check_like", [True, False]) +def test_frame_equal_columns_mismatch(check_like, frame_or_series, using_infer_string): + if using_infer_string: + dtype = "str" + else: + dtype = "object" + msg = f"""{frame_or_series.__name__}\\.columns are different + +{frame_or_series.__name__}\\.columns values are different \\(50\\.0 %\\) +\\[left\\]: Index\\(\\['A', 'B'\\], dtype='{dtype}'\\) +\\[right\\]: Index\\(\\['A', 'b'\\], dtype='{dtype}'\\)""" + + df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=["a", "b", "c"]) + df2 = DataFrame({"A": [1, 2, 3], "b": [4, 5, 6]}, index=["a", "b", "c"]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal( + df1, df2, check_like=check_like, obj=frame_or_series.__name__ + ) + + +def test_frame_equal_block_mismatch(by_blocks_fixture, frame_or_series): + obj = frame_or_series.__name__ + msg = f"""{obj}\\.iloc\\[:, 1\\] \\(column name="B"\\) are different + +{obj}\\.iloc\\[:, 1\\] \\(column name="B"\\) values are different \\(33\\.33333 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[4, 5, 6\\] +\\[right\\]: \\[4, 5, 7\\]""" + + df1 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + df2 = DataFrame({"A": [1, 2, 3], "B": [4, 5, 7]}) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(df1, df2, by_blocks=by_blocks_fixture, obj=obj) + + +@pytest.mark.parametrize( + "df1,df2,msg", + [ + ( + {"A": ["á", "à", "ä"], "E": ["é", "è", "ë"]}, + {"A": ["á", "à", "ä"], "E": ["é", "è", "e̊"]}, + """{obj}\\.iloc\\[:, 1\\] \\(column name="E"\\) are different + +{obj}\\.iloc\\[:, 1\\] \\(column name="E"\\) values are different \\(33\\.33333 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[é, è, ë\\] +\\[right\\]: \\[é, è, e̊\\]""", + ), + ( + {"A": ["á", "à", "ä"], "E": ["é", "è", "ë"]}, + {"A": ["a", "a", "a"], "E": ["e", "e", "e"]}, + """{obj}\\.iloc\\[:, 0\\] \\(column name="A"\\) are different + +{obj}\\.iloc\\[:, 0\\] \\(column name="A"\\) values are different \\(100\\.0 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[á, à, ä\\] +\\[right\\]: \\[a, a, a\\]""", + ), + ], +) +def test_frame_equal_unicode(df1, df2, msg, by_blocks_fixture, frame_or_series): + # see gh-20503 + # + # Test ensures that `tm.assert_frame_equals` raises the right exception + # when comparing DataFrames containing differing unicode objects. + df1 = DataFrame(df1) + df2 = DataFrame(df2) + msg = msg.format(obj=frame_or_series.__name__) + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal( + df1, df2, by_blocks=by_blocks_fixture, obj=frame_or_series.__name__ + ) + + +def test_assert_frame_equal_extension_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/32747 + left = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + right = left.astype(int) + + msg = ( + "Attributes of DataFrame\\.iloc\\[:, 0\\] " + '\\(column name="a"\\) are different\n\n' + 'Attribute "dtype" are different\n' + "\\[left\\]: Int64\n" + "\\[right\\]: int[32|64]" + ) + + tm.assert_frame_equal(left, right, check_dtype=False) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(left, right, check_dtype=True) + + +def test_assert_frame_equal_interval_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/32747 + left = DataFrame({"a": [pd.Interval(0, 1)]}, dtype="interval") + right = left.astype(object) + + msg = ( + "Attributes of DataFrame\\.iloc\\[:, 0\\] " + '\\(column name="a"\\) are different\n\n' + 'Attribute "dtype" are different\n' + "\\[left\\]: interval\\[int64, right\\]\n" + "\\[right\\]: object" + ) + + tm.assert_frame_equal(left, right, check_dtype=False) + + with pytest.raises(AssertionError, match=msg): + tm.assert_frame_equal(left, right, check_dtype=True) + + +def test_assert_frame_equal_ignore_extension_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/35715 + left = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + right = DataFrame({"a": [1, 2, 3]}, dtype="Int32") + tm.assert_frame_equal(left, right, check_dtype=False) + + +def test_assert_frame_equal_ignore_extension_dtype_mismatch_cross_class(): + # https://github.com/pandas-dev/pandas/issues/35715 + left = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + right = DataFrame({"a": [1, 2, 3]}, dtype="int64") + tm.assert_frame_equal(left, right, check_dtype=False) + + +@pytest.mark.parametrize( + "dtype", ["timedelta64[ns]", "datetime64[ns, UTC]", "Period[D]"] +) +def test_assert_frame_equal_datetime_like_dtype_mismatch(dtype): + df1 = DataFrame({"a": []}, dtype=dtype) + df2 = DataFrame({"a": []}) + tm.assert_frame_equal(df1, df2, check_dtype=False) + + +def test_allows_duplicate_labels(): + left = DataFrame() + right = DataFrame().set_flags(allows_duplicate_labels=False) + tm.assert_frame_equal(left, left) + tm.assert_frame_equal(right, right) + tm.assert_frame_equal(left, right, check_flags=False) + tm.assert_frame_equal(right, left, check_flags=False) + + with pytest.raises(AssertionError, match="\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_numpy_array_equal(a, b) + + +def test_numpy_array_equal_identical_na(nulls_fixture): + a = np.array([nulls_fixture], dtype=object) + + tm.assert_numpy_array_equal(a, a) + + # matching but not the identical object + if hasattr(nulls_fixture, "copy"): + other = nulls_fixture.copy() + else: + other = copy.copy(nulls_fixture) + b = np.array([other], dtype=object) + tm.assert_numpy_array_equal(a, b) + + +def test_numpy_array_equal_different_na(): + a = np.array([np.nan], dtype=object) + b = np.array([pd.NA], dtype=object) + + msg = """numpy array are different + +numpy array values are different \\(100.0 %\\) +\\[left\\]: \\[nan\\] +\\[right\\]: \\[\\]""" + + with pytest.raises(AssertionError, match=msg): + tm.assert_numpy_array_equal(a, b) diff --git a/pandas/tests/util/test_assert_produces_warning.py b/pandas/tests/util/test_assert_produces_warning.py new file mode 100644 index 0000000000000000000000000000000000000000..9316f1452477c8b665339929cc8faf9498539b8f --- /dev/null +++ b/pandas/tests/util/test_assert_produces_warning.py @@ -0,0 +1,277 @@ +""" " +Test module for testing ``pandas._testing.assert_produces_warning``. +""" + +import warnings + +import pytest + +from pandas.errors import ( + DtypeWarning, + PerformanceWarning, +) + +import pandas._testing as tm + + +@pytest.fixture( + params=[ + (RuntimeWarning, UserWarning), + (UserWarning, FutureWarning), + (FutureWarning, RuntimeWarning), + (DeprecationWarning, PerformanceWarning), + (PerformanceWarning, FutureWarning), + (DtypeWarning, DeprecationWarning), + (ResourceWarning, DeprecationWarning), + (FutureWarning, DeprecationWarning), + ], + ids=lambda x: type(x).__name__, +) +def pair_different_warnings(request): + """ + Return pair or different warnings. + + Useful for testing how several different warnings are handled + in tm.assert_produces_warning. + """ + return request.param + + +def f(): + warnings.warn("f1", FutureWarning) # pdlint: ignore[warning_class] + warnings.warn("f2", RuntimeWarning) + + +def test_assert_produces_warning_honors_filter(): + # Raise by default. + msg = r"Caused unexpected warning\(s\)" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(RuntimeWarning): + f() + + with tm.assert_produces_warning(RuntimeWarning, raise_on_extra_warnings=False): + f() + + +@pytest.mark.parametrize( + "category", + [ + RuntimeWarning, + ResourceWarning, + UserWarning, + FutureWarning, + DeprecationWarning, + PerformanceWarning, + DtypeWarning, + ], +) +@pytest.mark.parametrize( + "message, match", + [ + ("", None), + ("", ""), + ("Warning message", r".*"), + ("Warning message", "War"), + ("Warning message", r"[Ww]arning"), + ("Warning message", "age"), + ("Warning message", r"age$"), + ("Message 12-234 with numbers", r"\d{2}-\d{3}"), + ("Message 12-234 with numbers", r"^Mes.*\d{2}-\d{3}"), + ("Message 12-234 with numbers", r"\d{2}-\d{3}\s\S+"), + ("Message, which we do not match", None), + ], +) +def test_catch_warning_category_and_match(category, message, match): + with tm.assert_produces_warning(category, match=match): + warnings.warn(message, category) + + +def test_fail_to_match_runtime_warning(): + category = RuntimeWarning + match = "Did not see this warning" + unmatched = ( + r"Did not see warning 'RuntimeWarning' matching 'Did not see this warning'. " + r"The emitted warning messages are " + r"\[RuntimeWarning\('This is not a match.'\), " + r"RuntimeWarning\('Another unmatched warning.'\)\]" + ) + with pytest.raises(AssertionError, match=unmatched): + with tm.assert_produces_warning(category, match=match): + warnings.warn("This is not a match.", category) + warnings.warn("Another unmatched warning.", category) + + +def test_fail_to_match_future_warning(): + category = FutureWarning + match = "Warning" + unmatched = ( + r"Did not see warning 'FutureWarning' matching 'Warning'. " + r"The emitted warning messages are " + r"\[FutureWarning\('This is not a match.'\), " + r"FutureWarning\('Another unmatched warning.'\)\]" + ) + with pytest.raises(AssertionError, match=unmatched): + with tm.assert_produces_warning(category, match=match): + warnings.warn("This is not a match.", category) + warnings.warn("Another unmatched warning.", category) + + +def test_fail_to_match_resource_warning(): + category = ResourceWarning + match = r"\d+" + unmatched = ( + r"Did not see warning 'ResourceWarning' matching '\\d\+'. " + r"The emitted warning messages are " + r"\[ResourceWarning\('This is not a match.'\), " + r"ResourceWarning\('Another unmatched warning.'\)\]" + ) + with pytest.raises(AssertionError, match=unmatched): + with tm.assert_produces_warning(category, match=match): + warnings.warn("This is not a match.", category) + warnings.warn("Another unmatched warning.", category) + + +def test_fail_to_catch_actual_warning(pair_different_warnings): + expected_category, actual_category = pair_different_warnings + match = "Did not see expected warning of class" + with pytest.raises(AssertionError, match=match): + with tm.assert_produces_warning(expected_category): + warnings.warn("warning message", actual_category) + + +def test_ignore_extra_warning(pair_different_warnings): + expected_category, extra_category = pair_different_warnings + with tm.assert_produces_warning(expected_category, raise_on_extra_warnings=False): + warnings.warn("Expected warning", expected_category) + warnings.warn("Unexpected warning OK", extra_category) + + +def test_raise_on_extra_warning(pair_different_warnings): + expected_category, extra_category = pair_different_warnings + match = r"Caused unexpected warning\(s\)" + with pytest.raises(AssertionError, match=match): + with tm.assert_produces_warning(expected_category): + warnings.warn("Expected warning", expected_category) + warnings.warn("Unexpected warning NOT OK", extra_category) + + +def test_same_category_different_messages_first_match(): + category = UserWarning + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", category) + warnings.warn("Do not match that", category) + warnings.warn("Do not match that either", category) + + +def test_same_category_different_messages_last_match(): + category = DeprecationWarning + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Do not match that", category) + warnings.warn("Do not match that either", category) + warnings.warn("Match this", category) + + +def test_match_multiple_warnings(): + # https://github.com/pandas-dev/pandas/issues/47829 + category = (FutureWarning, UserWarning) + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", FutureWarning) # pdlint: ignore[warning_class] + warnings.warn("Match this too", UserWarning) + + +def test_must_match_multiple_warnings(): + # https://github.com/pandas-dev/pandas/issues/56555 + category = (FutureWarning, UserWarning) + msg = "Did not see expected warning of class 'UserWarning'" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", FutureWarning) # pdlint: ignore[warning_class] + + +def test_must_match_multiple_warnings_messages(): + # https://github.com/pandas-dev/pandas/issues/56555 + category = (FutureWarning, UserWarning) + msg = r"The emitted warning messages are \[UserWarning\('Not this'\)\]" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(category, match=r"^Match this"): + warnings.warn("Match this", FutureWarning) # pdlint: ignore[warning_class] + warnings.warn("Not this", UserWarning) + + +def test_allow_partial_match_for_multiple_warnings(): + # https://github.com/pandas-dev/pandas/issues/56555 + category = (FutureWarning, UserWarning) + with tm.assert_produces_warning( + category, match=r"^Match this", must_find_all_warnings=False + ): + warnings.warn("Match this", FutureWarning) # pdlint: ignore[warning_class] + + +def test_allow_partial_match_for_multiple_warnings_messages(): + # https://github.com/pandas-dev/pandas/issues/56555 + category = (FutureWarning, UserWarning) + with tm.assert_produces_warning( + category, match=r"^Match this", must_find_all_warnings=False + ): + warnings.warn("Match this", FutureWarning) # pdlint: ignore[warning_class] + warnings.warn("Not this", UserWarning) + + +def test_right_category_wrong_match_raises(pair_different_warnings): + target_category, other_category = pair_different_warnings + with pytest.raises(AssertionError, match="Did not see warning.*matching"): + with tm.assert_produces_warning(target_category, match=r"^Match this"): + warnings.warn("Do not match it", target_category) + warnings.warn("Match this", other_category) + + +@pytest.mark.parametrize("false_or_none", [False, None]) +class TestFalseOrNoneExpectedWarning: + def test_raise_on_warning(self, false_or_none): + msg = r"Caused unexpected warning\(s\)" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(false_or_none): + f() + + def test_no_raise_without_warning(self, false_or_none): + with tm.assert_produces_warning(false_or_none): + pass + + def test_no_raise_with_false_raise_on_extra(self, false_or_none): + with tm.assert_produces_warning(false_or_none, raise_on_extra_warnings=False): + f() + + +def test_raises_during_exception(): + msg = "Did not see expected warning of class 'UserWarning'" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(UserWarning): + raise ValueError + + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(UserWarning): + warnings.warn( + "FutureWarning", FutureWarning + ) # pdlint: ignore[warning_class] + raise IndexError + + msg = "Caused unexpected warning" + with pytest.raises(AssertionError, match=msg): + with tm.assert_produces_warning(None): + warnings.warn( + "FutureWarning", FutureWarning + ) # pdlint: ignore[warning_class] + raise SystemError + + +def test_passes_during_exception(): + with pytest.raises(SyntaxError, match="Error"): + with tm.assert_produces_warning(None): + raise SyntaxError("Error") + + with pytest.raises(ValueError, match="Error"): + with tm.assert_produces_warning(FutureWarning, match="FutureWarning"): + warnings.warn( + "FutureWarning", FutureWarning + ) # pdlint: ignore[warning_class] + raise ValueError("Error") diff --git a/pandas/tests/util/test_assert_series_equal.py b/pandas/tests/util/test_assert_series_equal.py new file mode 100644 index 0000000000000000000000000000000000000000..683ca1d875ac5efde9d13498b5b36f17983196db --- /dev/null +++ b/pandas/tests/util/test_assert_series_equal.py @@ -0,0 +1,509 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + Categorical, + DataFrame, + Series, +) +import pandas._testing as tm + + +def _assert_series_equal_both(a, b, **kwargs): + """ + Check that two Series equal. + + This check is performed commutatively. + + Parameters + ---------- + a : Series + The first Series to compare. + b : Series + The second Series to compare. + kwargs : dict + The arguments passed to `tm.assert_series_equal`. + """ + tm.assert_series_equal(a, b, **kwargs) + tm.assert_series_equal(b, a, **kwargs) + + +def _assert_not_series_equal(a, b, **kwargs): + """ + Check that two Series are not equal. + + Parameters + ---------- + a : Series + The first Series to compare. + b : Series + The second Series to compare. + kwargs : dict + The arguments passed to `tm.assert_series_equal`. + """ + try: + tm.assert_series_equal(a, b, **kwargs) + msg = "The two Series were equal when they shouldn't have been" + + pytest.fail(msg=msg) + except AssertionError: + pass + + +def _assert_not_series_equal_both(a, b, **kwargs): + """ + Check that two Series are not equal. + + This check is performed commutatively. + + Parameters + ---------- + a : Series + The first Series to compare. + b : Series + The second Series to compare. + kwargs : dict + The arguments passed to `tm.assert_series_equal`. + """ + _assert_not_series_equal(a, b, **kwargs) + _assert_not_series_equal(b, a, **kwargs) + + +@pytest.mark.parametrize("data", [range(3), list("abc"), list("áàä")]) +def test_series_equal(data): + _assert_series_equal_both(Series(data), Series(data)) + + +@pytest.mark.parametrize( + "data1,data2", + [ + (range(3), range(1, 4)), + (list("abc"), list("xyz")), + (list("áàä"), list("éèë")), + (list("áàä"), list(b"aaa")), + (range(3), range(4)), + ], +) +def test_series_not_equal_value_mismatch(data1, data2): + _assert_not_series_equal_both(Series(data1), Series(data2)) + + +@pytest.mark.parametrize( + "kwargs", + [ + {"dtype": "float64"}, # dtype mismatch + {"index": [1, 2, 4]}, # index mismatch + {"name": "foo"}, # name mismatch + ], +) +def test_series_not_equal_metadata_mismatch(kwargs): + data = range(3) + s1 = Series(data) + + s2 = Series(data, **kwargs) + _assert_not_series_equal_both(s1, s2) + + +@pytest.mark.parametrize("data1,data2", [(0.12345, 0.12346), (0.1235, 0.1236)]) +@pytest.mark.parametrize("decimals", [0, 1, 2, 3, 5, 10]) +def test_less_precise(data1, data2, any_float_dtype, decimals): + rtol = 10**-decimals + s1 = Series([data1], dtype=any_float_dtype) + s2 = Series([data2], dtype=any_float_dtype) + + if decimals in (5, 10) or (decimals >= 3 and abs(data1 - data2) >= 0.0005): + msg = "Series values are different" + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + else: + _assert_series_equal_both(s1, s2, rtol=rtol) + + +@pytest.mark.parametrize( + "s1,s2,msg", + [ + # Index + ( + Series(["l1", "l2"], index=[1, 2]), + Series(["l1", "l2"], index=[1.0, 2.0]), + "Series\\.index are different", + ), + # MultiIndex + ( + DataFrame.from_records( + {"a": [1, 2], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] + ).c, + DataFrame.from_records( + {"a": [1.0, 2.0], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"] + ).c, + "Series\\.index level \\[0\\] are different", + ), + ], +) +def test_series_equal_index_dtype(s1, s2, msg, check_index_type): + kwargs = {"check_index_type": check_index_type} + + if check_index_type: + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, **kwargs) + else: + tm.assert_series_equal(s1, s2, **kwargs) + + +@pytest.mark.parametrize("check_like", [True, False]) +def test_series_equal_order_mismatch(check_like): + s1 = Series([1, 2, 3], index=["a", "b", "c"]) + s2 = Series([3, 2, 1], index=["c", "b", "a"]) + + if not check_like: # Do not ignore index ordering. + with pytest.raises(AssertionError, match="Series.index are different"): + tm.assert_series_equal(s1, s2, check_like=check_like) + else: + _assert_series_equal_both(s1, s2, check_like=check_like) + + +@pytest.mark.parametrize("check_index", [True, False]) +def test_series_equal_index_mismatch(check_index): + s1 = Series([1, 2, 3], index=["a", "b", "c"]) + s2 = Series([1, 2, 3], index=["c", "b", "a"]) + + if check_index: # Do not ignore index. + with pytest.raises(AssertionError, match="Series.index are different"): + tm.assert_series_equal(s1, s2, check_index=check_index) + else: + _assert_series_equal_both(s1, s2, check_index=check_index) + + +def test_series_invalid_param_combination(): + left = Series(dtype=object) + right = Series(dtype=object) + with pytest.raises( + ValueError, match="check_like must be False if check_index is False" + ): + tm.assert_series_equal(left, right, check_index=False, check_like=True) + + +def test_series_equal_length_mismatch(rtol): + msg = """Series are different + +Series length are different +\\[left\\]: 3, RangeIndex\\(start=0, stop=3, step=1\\) +\\[right\\]: 4, RangeIndex\\(start=0, stop=4, step=1\\)""" + + s1 = Series([1, 2, 3]) + s2 = Series([1, 2, 3, 4]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + + +def test_series_equal_numeric_values_mismatch(rtol): + msg = """Series are different + +Series values are different \\(33\\.33333 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[1, 2, 3\\] +\\[right\\]: \\[1, 2, 4\\]""" + + s1 = Series([1, 2, 3]) + s2 = Series([1, 2, 4]) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + + +def test_series_equal_categorical_values_mismatch(rtol, using_infer_string): + dtype = "str" if using_infer_string else "object" + msg = f"""Series are different + +Series values are different \\(66\\.66667 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\['a', 'b', 'c'\\] +Categories \\(3, {dtype}\\): \\['a', 'b', 'c'\\] +\\[right\\]: \\['a', 'c', 'b'\\] +Categories \\(3, {dtype}\\): \\['a', 'b', 'c'\\]""" + + s1 = Series(Categorical(["a", "b", "c"])) + s2 = Series(Categorical(["a", "c", "b"])) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + + +def test_series_equal_datetime_values_mismatch(rtol): + msg = """Series are different + +Series values are different \\(100.0 %\\) +\\[index\\]: \\[0, 1, 2\\] +\\[left\\]: \\[1514764800000000000, 1514851200000000000, 1514937600000000000\\] +\\[right\\]: \\[1549065600000000000, 1549152000000000000, 1549238400000000000\\]""" + + s1 = Series(pd.date_range("2018-01-01", periods=3, freq="D", unit="ns")) + s2 = Series(pd.date_range("2019-02-02", periods=3, freq="D", unit="ns")) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, rtol=rtol) + + +def test_series_equal_categorical_mismatch(check_categorical, using_infer_string): + if using_infer_string: + dtype = "str" + else: + dtype = "object" + msg = f"""Attributes of Series are different + +Attribute "dtype" are different +\\[left\\]: CategoricalDtype\\(categories=\\['a', 'b'\\], ordered=False, \ +categories_dtype={dtype}\\) +\\[right\\]: CategoricalDtype\\(categories=\\['a', 'b', 'c'\\], \ +ordered=False, categories_dtype={dtype}\\)""" + + s1 = Series(Categorical(["a", "b"])) + s2 = Series(Categorical(["a", "b"], categories=list("abc"))) + + if check_categorical: + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s2, check_categorical=check_categorical) + else: + _assert_series_equal_both(s1, s2, check_categorical=check_categorical) + + +def test_assert_series_equal_extension_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/32747 + left = Series(pd.array([1, 2, 3], dtype="Int64")) + right = left.astype(int) + + msg = """Attributes of Series are different + +Attribute "dtype" are different +\\[left\\]: Int64 +\\[right\\]: int[32|64]""" + + tm.assert_series_equal(left, right, check_dtype=False) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(left, right, check_dtype=True) + + +def test_assert_series_equal_interval_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/32747 + left = Series([pd.Interval(0, 1)], dtype="interval") + right = left.astype(object) + + msg = """Attributes of Series are different + +Attribute "dtype" are different +\\[left\\]: interval\\[int64, right\\] +\\[right\\]: object""" + + tm.assert_series_equal(left, right, check_dtype=False) + + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(left, right, check_dtype=True) + + +def test_series_equal_series_type(): + class MySeries(Series): + pass + + s1 = Series([1, 2]) + s2 = Series([1, 2]) + s3 = MySeries([1, 2]) + + tm.assert_series_equal(s1, s2, check_series_type=False) + tm.assert_series_equal(s1, s2, check_series_type=True) + + tm.assert_series_equal(s1, s3, check_series_type=False) + tm.assert_series_equal(s3, s1, check_series_type=False) + + with pytest.raises(AssertionError, match="Series classes are different"): + tm.assert_series_equal(s1, s3, check_series_type=True) + + with pytest.raises(AssertionError, match="Series classes are different"): + tm.assert_series_equal(s3, s1, check_series_type=True) + + +def test_series_equal_exact_for_nonnumeric(): + # https://github.com/pandas-dev/pandas/issues/35446 + s1 = Series(["a", "b"]) + s2 = Series(["a", "b"]) + s3 = Series(["b", "a"]) + + tm.assert_series_equal(s1, s2, check_exact=True) + tm.assert_series_equal(s2, s1, check_exact=True) + + msg = """Series are different + +Series values are different \\(100\\.0 %\\) +\\[index\\]: \\[0, 1\\] +\\[left\\]: \\[a, b\\] +\\[right\\]: \\[b, a\\]""" + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s1, s3, check_exact=True) + + msg = """Series are different + +Series values are different \\(100\\.0 %\\) +\\[index\\]: \\[0, 1\\] +\\[left\\]: \\[b, a\\] +\\[right\\]: \\[a, b\\]""" + with pytest.raises(AssertionError, match=msg): + tm.assert_series_equal(s3, s1, check_exact=True) + + +def test_assert_series_equal_ignore_extension_dtype_mismatch(): + # https://github.com/pandas-dev/pandas/issues/35715 + left = Series([1, 2, 3], dtype="Int64") + right = Series([1, 2, 3], dtype="Int32") + tm.assert_series_equal(left, right, check_dtype=False) + + +def test_assert_series_equal_ignore_extension_dtype_mismatch_cross_class(): + # https://github.com/pandas-dev/pandas/issues/35715 + left = Series([1, 2, 3], dtype="Int64") + right = Series([1, 2, 3], dtype="int64") + tm.assert_series_equal(left, right, check_dtype=False) + + +def test_allows_duplicate_labels(): + left = Series([1]) + right = Series([1]).set_flags(allows_duplicate_labels=False) + tm.assert_series_equal(left, left) + tm.assert_series_equal(right, right) + tm.assert_series_equal(left, right, check_flags=False) + tm.assert_series_equal(right, left, check_flags=False) + + with pytest.raises(AssertionError, match=">> cumavg([1, 2, 3]) + 2 + """ + ), + method="cumavg", + operation="average", +) +def cumavg(whatever): + pass + + +@doc(cumsum, method="cummax", operation="maximum") +def cummax(whatever): + pass + + +@doc(cummax, method="cummin", operation="minimum") +def cummin(whatever): + pass + + +def test_docstring_formatting(): + docstr = dedent( + """ + This is the cumsum method. + + It computes the cumulative sum. + """ + ) + assert cumsum.__doc__ == docstr + + +def test_docstring_appending(): + docstr = dedent( + """ + This is the cumavg method. + + It computes the cumulative average. + + Examples + -------- + + >>> cumavg([1, 2, 3]) + 2 + """ + ) + assert cumavg.__doc__ == docstr + + +def test_doc_template_from_func(): + docstr = dedent( + """ + This is the cummax method. + + It computes the cumulative maximum. + """ + ) + assert cummax.__doc__ == docstr + + +def test_inherit_doc_template(): + docstr = dedent( + """ + This is the cummin method. + + It computes the cumulative minimum. + """ + ) + assert cummin.__doc__ == docstr diff --git a/pandas/tests/util/test_hashing.py b/pandas/tests/util/test_hashing.py new file mode 100644 index 0000000000000000000000000000000000000000..d6bc7017c2483c7ee0996a81b56f9e8e67cc894d --- /dev/null +++ b/pandas/tests/util/test_hashing.py @@ -0,0 +1,418 @@ +import numpy as np +import pytest + +import pandas as pd +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + period_range, + timedelta_range, +) +import pandas._testing as tm +from pandas.core.util.hashing import hash_tuples +from pandas.util import ( + hash_array, + hash_pandas_object, +) + + +@pytest.fixture( + params=[ + Series([1, 2, 3] * 3, dtype="int32"), + Series([None, 2.5, 3.5] * 3, dtype="float32"), + Series(["a", "b", "c"] * 3, dtype="category"), + Series(["d", "e", "f"] * 3), + Series([True, False, True] * 3), + Series(pd.date_range("20130101", periods=9)), + Series(pd.date_range("20130101", periods=9, tz="US/Eastern")), + Series(timedelta_range("2000", periods=9)), + ] +) +def series(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def index(request): + return request.param + + +def test_consistency(): + # Check that our hash doesn't change because of a mistake + # in the actual code; this is the ground truth. + result = hash_pandas_object(Index(["foo", "bar", "baz"])) + expected = Series( + np.array( + [3600424527151052760, 1374399572096150070, 477881037637427054], + dtype="uint64", + ), + index=["foo", "bar", "baz"], + ) + tm.assert_series_equal(result, expected) + + +def test_hash_array(series): + arr = series.values + tm.assert_numpy_array_equal(hash_array(arr), hash_array(arr)) + + +@pytest.mark.parametrize("dtype", ["U", object]) +def test_hash_array_mixed(dtype): + result1 = hash_array(np.array(["3", "4", "All"])) + result2 = hash_array(np.array([3, 4, "All"], dtype=dtype)) + + tm.assert_numpy_array_equal(result1, result2) + + +@pytest.mark.parametrize("val", [5, "foo", pd.Timestamp("20130101")]) +def test_hash_array_errors(val): + msg = "must pass an ndarray-like" + with pytest.raises(TypeError, match=msg): + hash_array(val) + + +def test_hash_array_index_exception(): + # GH42003 TypeError instead of AttributeError + obj = pd.DatetimeIndex(["2018-10-28 01:20:00"], tz="Europe/Berlin") + + msg = "Use hash_pandas_object instead" + with pytest.raises(TypeError, match=msg): + hash_array(obj) + + +def test_hash_tuples(): + tuples = [(1, "one"), (1, "two"), (2, "one")] + result = hash_tuples(tuples) + + expected = hash_pandas_object(MultiIndex.from_tuples(tuples)).values + tm.assert_numpy_array_equal(result, expected) + + # We only need to support MultiIndex and list-of-tuples + msg = "|".join(["object is not iterable", "zip argument #1 must support iteration"]) + with pytest.raises(TypeError, match=msg): + hash_tuples(tuples[0]) + + +@pytest.mark.parametrize("val", [5, "foo", pd.Timestamp("20130101")]) +def test_hash_tuples_err(val): + msg = "must be convertible to a list-of-tuples" + with pytest.raises(TypeError, match=msg): + hash_tuples(val) + + +def test_multiindex_unique(): + mi = MultiIndex.from_tuples([(118, 472), (236, 118), (51, 204), (102, 51)]) + assert mi.is_unique is True + + result = hash_pandas_object(mi) + assert result.is_unique is True + + +def test_multiindex_objects(): + mi = MultiIndex( + levels=[["b", "d", "a"], [1, 2, 3]], + codes=[[0, 1, 0, 2], [2, 0, 0, 1]], + names=["col1", "col2"], + ) + recons = mi._sort_levels_monotonic() + + # These are equal. + assert mi.equals(recons) + assert Index(mi.values).equals(Index(recons.values)) + + +@pytest.mark.parametrize( + "obj", + [ + Series([1, 2, 3]), + Series([1.0, 1.5, 3.2]), + Series([1.0, 1.5, np.nan]), + Series([1.0, 1.5, 3.2], index=[1.5, 1.1, 3.3]), + Series(["a", "b", "c"]), + Series(["a", np.nan, "c"]), + Series(["a", None, "c"]), + Series([True, False, True]), + Series(dtype=object), + DataFrame({"x": ["a", "b", "c"], "y": [1, 2, 3]}), + DataFrame(), + DataFrame(np.full((10, 4), np.nan)), + DataFrame( + { + "A": [0.0, 1.0, 2.0, 3.0, 4.0], + "B": [0.0, 1.0, 0.0, 1.0, 0.0], + "C": Index(["foo1", "foo2", "foo3", "foo4", "foo5"], dtype=object), + "D": pd.date_range("20130101", periods=5), + } + ), + DataFrame(range(5), index=pd.date_range("2020-01-01", periods=5)), + Series(range(5), index=pd.date_range("2020-01-01", periods=5)), + Series(period_range("2020-01-01", periods=10, freq="D")), + Series(pd.date_range("20130101", periods=3, tz="US/Eastern")), + ], +) +def test_hash_pandas_object(obj, index): + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +@pytest.mark.parametrize( + "obj", + [ + Series([1, 2, 3]), + Series([1.0, 1.5, 3.2]), + Series([1.0, 1.5, np.nan]), + Series([1.0, 1.5, 3.2], index=[1.5, 1.1, 3.3]), + Series(["a", "b", "c"]), + Series(["a", np.nan, "c"]), + Series(["a", None, "c"]), + Series([True, False, True]), + DataFrame({"x": ["a", "b", "c"], "y": [1, 2, 3]}), + DataFrame(np.full((10, 4), np.nan)), + DataFrame( + { + "A": [0.0, 1.0, 2.0, 3.0, 4.0], + "B": [0.0, 1.0, 0.0, 1.0, 0.0], + "C": Index(["foo1", "foo2", "foo3", "foo4", "foo5"], dtype=object), + "D": pd.date_range("20130101", periods=5), + } + ), + DataFrame(range(5), index=pd.date_range("2020-01-01", periods=5)), + Series(range(5), index=pd.date_range("2020-01-01", periods=5)), + Series(period_range("2020-01-01", periods=10, freq="D")), + Series(pd.date_range("20130101", periods=3, tz="US/Eastern")), + ], +) +def test_hash_pandas_object_diff_index_non_empty(obj): + a = hash_pandas_object(obj, index=True) + b = hash_pandas_object(obj, index=False) + assert not (a == b).all() + + +@pytest.mark.parametrize( + "obj", + [ + Index([1, 2, 3]), + Index([True, False, True]), + timedelta_range("1 day", periods=2), + period_range("2020-01-01", freq="D", periods=2), + MultiIndex.from_product( + [range(5), ["foo", "bar", "baz"], pd.date_range("20130101", periods=2)] + ), + MultiIndex.from_product([pd.CategoricalIndex(list("aabc")), range(3)]), + ], +) +def test_hash_pandas_index(obj, index): + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +def test_hash_pandas_series(series, index): + a = hash_pandas_object(series, index=index) + b = hash_pandas_object(series, index=index) + tm.assert_series_equal(a, b) + + +def test_hash_pandas_series_diff_index(series): + a = hash_pandas_object(series, index=True) + b = hash_pandas_object(series, index=False) + assert not (a == b).all() + + +@pytest.mark.parametrize("klass", [Index, Series]) +@pytest.mark.parametrize("dtype", ["float64", "object"]) +def test_hash_pandas_empty_object(klass, dtype, index): + # These are by-definition the same with + # or without the index as the data is empty. + obj = klass([], dtype=dtype) + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +@pytest.mark.parametrize( + "s1", + [ + ["a", "b", "c", "d"], + [1000, 2000, 3000, 4000], + pd.date_range(0, periods=4), + ], +) +@pytest.mark.parametrize("categorize", [True, False]) +def test_categorical_consistency(s1, categorize): + # see gh-15143 + # + # Check that categoricals hash consistent with their values, + # not codes. This should work for categoricals of any dtype. + s1 = Series(s1) + s2 = s1.astype("category").cat.set_categories(s1) + s3 = s2.cat.set_categories(list(reversed(s1))) + + # These should all hash identically. + h1 = hash_pandas_object(s1, categorize=categorize) + h2 = hash_pandas_object(s2, categorize=categorize) + h3 = hash_pandas_object(s3, categorize=categorize) + + tm.assert_series_equal(h1, h2) + tm.assert_series_equal(h1, h3) + + +def test_categorical_with_nan_consistency(unit): + dti = pd.date_range("2012-01-01", periods=5, name="B", unit=unit) + cat = pd.Categorical.from_codes([-1, 0, 1, 2, 3, 4], categories=dti) + expected = hash_array(cat, categorize=False) + + ts = pd.Timestamp("2012-01-01").as_unit(unit) + cat2 = pd.Categorical.from_codes([-1, 0], categories=[ts]) + result = hash_array(cat2, categorize=False) + + assert result[0] in expected + assert result[1] in expected + + +def test_pandas_errors(): + msg = "Unexpected type for hashing" + with pytest.raises(TypeError, match=msg): + hash_pandas_object(pd.Timestamp("20130101")) + + +def test_hash_keys(): + # Using different hash keys, should have + # different hashes for the same data. + # + # This only matters for object dtypes. + obj = Series(list("abc")) + + a = hash_pandas_object(obj, hash_key="9876543210123456") + b = hash_pandas_object(obj, hash_key="9876543210123465") + + assert (a != b).all() + + +def test_df_hash_keys(): + # DataFrame version of the test_hash_keys. + # https://github.com/pandas-dev/pandas/issues/41404 + obj = DataFrame({"x": np.arange(3), "y": list("abc")}) + + a = hash_pandas_object(obj, hash_key="9876543210123456") + b = hash_pandas_object(obj, hash_key="9876543210123465") + + assert (a != b).all() + + +def test_df_encoding(): + # Check that DataFrame recognizes optional encoding. + # https://github.com/pandas-dev/pandas/issues/41404 + # https://github.com/pandas-dev/pandas/pull/42049 + obj = DataFrame({"x": np.arange(3), "y": list("a+c")}) + + a = hash_pandas_object(obj, encoding="utf8") + b = hash_pandas_object(obj, encoding="utf7") + + # Note that the "+" is encoded as "+-" in utf-7. + assert a[0] == b[0] + assert a[1] != b[1] + assert a[2] == b[2] + + +def test_invalid_key(): + # This only matters for object dtypes. + msg = "key should be a 16-byte string encoded" + + with pytest.raises(ValueError, match=msg): + hash_pandas_object(Series(list("abc")), hash_key="foo") + + +def test_already_encoded(index): + # If already encoded, then ok. + obj = Series(list("abc")).str.encode("utf8") + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +def test_alternate_encoding(index): + obj = Series(list("abc")) + a = hash_pandas_object(obj, index=index) + b = hash_pandas_object(obj, index=index) + tm.assert_series_equal(a, b) + + +@pytest.mark.parametrize("l_exp", range(8)) +@pytest.mark.parametrize("l_add", [0, 1]) +def test_same_len_hash_collisions(l_exp, l_add): + length = 2 ** (l_exp + 8) + l_add + idx = np.array([str(i) for i in range(length)], dtype=object) + + result = hash_array(idx, "utf8") + assert not result[0] == result[1] + + +def test_hash_collisions(): + # Hash collisions are bad. + # + # https://github.com/pandas-dev/pandas/issues/14711#issuecomment-264885726 + hashes = [ + "Ingrid-9Z9fKIZmkO7i7Cn51Li34pJm44fgX6DYGBNj3VPlOH50m7HnBlPxfIwFMrcNJNMP6PSgLmwWnInciMWrCSAlLEvt7JkJl4IxiMrVbXSa8ZQoVaq5xoQPjltuJEfwdNlO6jo8qRRHvD8sBEBMQASrRa6TsdaPTPCBo3nwIBpE7YzzmyH0vMBhjQZLx1aCT7faSEx7PgFxQhHdKFWROcysamgy9iVj8DO2Fmwg1NNl93rIAqC3mdqfrCxrzfvIY8aJdzin2cHVzy3QUJxZgHvtUtOLxoqnUHsYbNTeq0xcLXpTZEZCxD4PGubIuCNf32c33M7HFsnjWSEjE2yVdWKhmSVodyF8hFYVmhYnMCztQnJrt3O8ZvVRXd5IKwlLexiSp4h888w7SzAIcKgc3g5XQJf6MlSMftDXm9lIsE1mJNiJEv6uY6pgvC3fUPhatlR5JPpVAHNSbSEE73MBzJrhCAbOLXQumyOXigZuPoME7QgJcBalliQol7YZ9", + "Tim-b9MddTxOWW2AT1Py6vtVbZwGAmYCjbp89p8mxsiFoVX4FyDOF3wFiAkyQTUgwg9sVqVYOZo09Dh1AzhFHbgij52ylF0SEwgzjzHH8TGY8Lypart4p4onnDoDvVMBa0kdthVGKl6K0BDVGzyOXPXKpmnMF1H6rJzqHJ0HywfwS4XYpVwlAkoeNsiicHkJUFdUAhG229INzvIAiJuAHeJDUoyO4DCBqtoZ5TDend6TK7Y914yHlfH3g1WZu5LksKv68VQHJriWFYusW5e6ZZ6dKaMjTwEGuRgdT66iU5nqWTHRH8WSzpXoCFwGcTOwyuqPSe0fTe21DVtJn1FKj9F9nEnR9xOvJUO7E0piCIF4Ad9yAIDY4DBimpsTfKXCu1vdHpKYerzbndfuFe5AhfMduLYZJi5iAw8qKSwR5h86ttXV0Mc0QmXz8dsRvDgxjXSmupPxBggdlqUlC828hXiTPD7am0yETBV0F3bEtvPiNJfremszcV8NcqAoARMe", + ] + + # These should be different. + result1 = hash_array(np.asarray(hashes[0:1], dtype=object), "utf8") + expected1 = np.array([14963968704024874985], dtype=np.uint64) + tm.assert_numpy_array_equal(result1, expected1) + + result2 = hash_array(np.asarray(hashes[1:2], dtype=object), "utf8") + expected2 = np.array([16428432627716348016], dtype=np.uint64) + tm.assert_numpy_array_equal(result2, expected2) + + result = hash_array(np.asarray(hashes, dtype=object), "utf8") + tm.assert_numpy_array_equal(result, np.concatenate([expected1, expected2], axis=0)) + + +@pytest.mark.parametrize( + "data, result_data", + [ + [[tuple("1"), tuple("2")], [10345501319357378243, 8331063931016360761]], + [[(1,), (2,)], [9408946347443669104, 3278256261030523334]], + ], +) +def test_hash_with_tuple(data, result_data): + # GH#28969 array containing a tuple raises on call to arr.astype(str) + # apparently a numpy bug github.com/numpy/numpy/issues/9441 + + df = DataFrame({"data": data}) + result = hash_pandas_object(df) + expected = Series(result_data, dtype=np.uint64) + tm.assert_series_equal(result, expected) + + +def test_hashable_tuple_args(): + # require that the elements of such tuples are themselves hashable + + df3 = DataFrame( + { + "data": [ + ( + 1, + [], + ), + ( + 2, + {}, + ), + ] + } + ) + with pytest.raises(TypeError, match="unhashable type: 'list'"): + hash_pandas_object(df3) + + +def test_hash_object_none_key(): + # https://github.com/pandas-dev/pandas/issues/30887 + result = pd.util.hash_pandas_object(Series(["a", "b"]), hash_key=None) + expected = Series([4578374827886788867, 17338122309987883691], dtype="uint64") + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/util/test_numba.py b/pandas/tests/util/test_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..5dc4fa96f15ae5e279082747eca22b11c8bd4cc0 --- /dev/null +++ b/pandas/tests/util/test_numba.py @@ -0,0 +1,12 @@ +import pytest + +import pandas.util._test_decorators as td + +from pandas import option_context + + +@td.skip_if_installed("numba") +def test_numba_not_installed_option_context(): + with pytest.raises(ImportError, match="`Import numba` failed"): + with option_context("compute.use_numba", True): + pass diff --git a/pandas/tests/util/test_rewrite_warning.py b/pandas/tests/util/test_rewrite_warning.py new file mode 100644 index 0000000000000000000000000000000000000000..3db5e44d4fceaaa07f000abf9abc79764c6effbc --- /dev/null +++ b/pandas/tests/util/test_rewrite_warning.py @@ -0,0 +1,42 @@ +import warnings + +import pytest + +from pandas.util._exceptions import rewrite_warning + +import pandas._testing as tm + + +@pytest.mark.parametrize( + "target_category, target_message, hit", + [ + (FutureWarning, "Target message", True), + (FutureWarning, "Target", True), + (FutureWarning, "get mess", True), + (FutureWarning, "Missed message", False), + (DeprecationWarning, "Target message", False), + ], +) +@pytest.mark.parametrize( + "new_category", + [ + None, + DeprecationWarning, + ], +) +def test_rewrite_warning(target_category, target_message, hit, new_category): + new_message = "Rewritten message" + if hit: + expected_category = new_category if new_category else target_category + expected_message = new_message + else: + expected_category = FutureWarning + expected_message = "Target message" + with tm.assert_produces_warning(expected_category, match=expected_message): + with rewrite_warning( + target_message, target_category, new_message, new_category + ): + warnings.warn( + message="Target message", + category=FutureWarning, # pdlint: ignore[warning_class] + ) diff --git a/pandas/tests/util/test_shares_memory.py b/pandas/tests/util/test_shares_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..94bc51dca3f60c5c3765a0c30e5738c52629dfa7 --- /dev/null +++ b/pandas/tests/util/test_shares_memory.py @@ -0,0 +1,46 @@ +import numpy as np + +import pandas.util._test_decorators as td + +import pandas as pd +import pandas._testing as tm + + +def test_shares_memory_interval(): + obj = pd.interval_range(1, 5) + + assert tm.shares_memory(obj, obj) + assert tm.shares_memory(obj, obj._data) + assert tm.shares_memory(obj, obj[::-1]) + assert tm.shares_memory(obj, obj[:2]) + + assert not tm.shares_memory(obj, obj._data.copy()) + + +@td.skip_if_no("pyarrow") +def test_shares_memory_string(): + # GH#55823 + import pyarrow as pa + + obj = pd.array(["a", "b"], dtype=pd.StringDtype("pyarrow", na_value=pd.NA)) + assert tm.shares_memory(obj, obj) + + obj = pd.array(["a", "b"], dtype=pd.StringDtype("pyarrow", na_value=np.nan)) + assert tm.shares_memory(obj, obj) + + obj = pd.array(["a", "b"], dtype=pd.ArrowDtype(pa.string())) + assert tm.shares_memory(obj, obj) + + +def test_shares_memory_numpy(): + arr = np.arange(10) + view = arr[:5] + assert tm.shares_memory(arr, view) + arr2 = np.arange(10) + assert not tm.shares_memory(arr, arr2) + + +def test_shares_memory_rangeindex(): + idx = pd.RangeIndex(10) + arr = np.arange(10) + assert not tm.shares_memory(idx, arr) diff --git a/pandas/tests/util/test_show_versions.py b/pandas/tests/util/test_show_versions.py new file mode 100644 index 0000000000000000000000000000000000000000..72c9db23b210880793f37227c99e99e804800f08 --- /dev/null +++ b/pandas/tests/util/test_show_versions.py @@ -0,0 +1,81 @@ +import json +import os +import re + +from pandas.util._print_versions import ( + _get_dependency_info, + _get_sys_info, +) + +import pandas as pd + + +def test_show_versions(tmpdir): + # GH39701 + as_json = os.path.join(tmpdir, "test_output.json") + + pd.show_versions(as_json=as_json) + + with open(as_json, encoding="utf-8") as fd: + # check if file output is valid JSON, will raise an exception if not + result = json.load(fd) + + # Basic check that each version element is found in output + expected = { + "system": _get_sys_info(), + "dependencies": _get_dependency_info(), + } + + assert result == expected + + +def test_show_versions_console_json(capsys): + # GH39701 + pd.show_versions(as_json=True) + stdout = capsys.readouterr().out + + # check valid json is printed to the console if as_json is True + result = json.loads(stdout) + + # Basic check that each version element is found in output + expected = { + "system": _get_sys_info(), + "dependencies": _get_dependency_info(), + } + + assert result == expected + + +def test_show_versions_console(capsys): + # gh-32041 + # gh-32041 + pd.show_versions(as_json=False) + result = capsys.readouterr().out + + # check header + assert "INSTALLED VERSIONS" in result + + # check full commit hash + assert re.search(r"commit\s*:\s[0-9a-f]{40}\n", result) + + # check required dependency + # 2020-12-09 npdev has "dirty" in the tag + # 2022-05-25 npdev released with RC wo/ "dirty". + # Just ensure we match [0-9]+\..* since npdev version is variable + assert re.search(r"numpy\s*:\s[0-9]+\..*\n", result) + + # check optional dependency + assert re.search(r"pyarrow\s*:\s([0-9]+.*|None)\n", result) + + +def test_json_output_match(capsys, tmpdir): + # GH39701 + pd.show_versions(as_json=True) + result_console = capsys.readouterr().out + + out_path = os.path.join(tmpdir, "test_json.json") + pd.show_versions(as_json=out_path) + with open(out_path, encoding="utf-8") as out_fd: + result_file = out_fd.read() + + assert result_console == result_file diff --git a/pandas/tests/util/test_util.py b/pandas/tests/util/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..dfb8587d3924e1441ac9da0aeeaa5585c6b4fe6c --- /dev/null +++ b/pandas/tests/util/test_util.py @@ -0,0 +1,58 @@ +import os + +import pytest + +from pandas import ( + array, + compat, +) +import pandas._testing as tm + + +def test_numpy_err_state_is_default(): + expected = {"over": "warn", "divide": "warn", "invalid": "warn", "under": "ignore"} + import numpy as np + + # The error state should be unchanged after that import. + assert np.geterr() == expected + + +def test_convert_rows_list_to_csv_str(): + rows_list = ["aaa", "bbb", "ccc"] + ret = tm.convert_rows_list_to_csv_str(rows_list) + + if compat.is_platform_windows(): + expected = "aaa\r\nbbb\r\nccc\r\n" + else: + expected = "aaa\nbbb\nccc\n" + + assert ret == expected + + +@pytest.mark.parametrize("strict_data_files", [True, False]) +def test_datapath_missing(datapath): + with pytest.raises(ValueError, match="Could not find file"): + datapath("not_a_file") + + +def test_datapath(datapath): + args = ("io", "data", "csv", "iris.csv") + + result = datapath(*args) + expected = os.path.join(os.path.dirname(os.path.dirname(__file__)), *args) + + assert result == expected + + +def test_external_error_raised(): + with tm.external_error_raised(TypeError): + raise TypeError("Should not check this error message, so it will pass") + + +def test_is_sorted(): + arr = array([1, 2, 3], dtype="Int64") + tm.assert_is_sorted(arr) + + arr = array([4, 2, 3], dtype="Int64") + with pytest.raises(AssertionError, match="ExtensionArray are different"): + tm.assert_is_sorted(arr) diff --git a/pandas/tests/util/test_validate_args.py b/pandas/tests/util/test_validate_args.py new file mode 100644 index 0000000000000000000000000000000000000000..eef0931ec28efd02e3db7a85b0b3260742c1ff2d --- /dev/null +++ b/pandas/tests/util/test_validate_args.py @@ -0,0 +1,70 @@ +import pytest + +from pandas.util._validators import validate_args + + +@pytest.fixture +def _fname(): + return "func" + + +def test_bad_min_fname_arg_count(_fname): + msg = "'max_fname_arg_count' must be non-negative" + + with pytest.raises(ValueError, match=msg): + validate_args(_fname, (None,), -1, "foo") + + +def test_bad_arg_length_max_value_single(_fname): + args = (None, None) + compat_args = ("foo",) + + min_fname_arg_count = 0 + max_length = len(compat_args) + min_fname_arg_count + actual_length = len(args) + min_fname_arg_count + msg = ( + rf"{_fname}\(\) takes at most {max_length} " + rf"argument \({actual_length} given\)" + ) + + with pytest.raises(TypeError, match=msg): + validate_args(_fname, args, min_fname_arg_count, compat_args) + + +def test_bad_arg_length_max_value_multiple(_fname): + args = (None, None) + compat_args = {"foo": None} + + min_fname_arg_count = 2 + max_length = len(compat_args) + min_fname_arg_count + actual_length = len(args) + min_fname_arg_count + msg = ( + rf"{_fname}\(\) takes at most {max_length} " + rf"arguments \({actual_length} given\)" + ) + + with pytest.raises(TypeError, match=msg): + validate_args(_fname, args, min_fname_arg_count, compat_args) + + +@pytest.mark.parametrize("i", range(1, 3)) +def test_not_all_defaults(i, _fname): + bad_arg = "foo" + msg = ( + f"the '{bad_arg}' parameter is not supported " + rf"in the pandas implementation of {_fname}\(\)" + ) + + compat_args = {"foo": 2, "bar": -1, "baz": 3} + arg_vals = (1, -1, 3) + + with pytest.raises(ValueError, match=msg): + validate_args(_fname, arg_vals[:i], 2, compat_args) + + +def test_validation(_fname): + # No exceptions should be raised. + validate_args(_fname, (None,), 2, {"out": None}) + + compat_args = {"axis": 1, "out": None} + validate_args(_fname, (1, None), 2, compat_args) diff --git a/pandas/tests/util/test_validate_args_and_kwargs.py b/pandas/tests/util/test_validate_args_and_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..215026d648471c04cb8751506c03626fda73fc68 --- /dev/null +++ b/pandas/tests/util/test_validate_args_and_kwargs.py @@ -0,0 +1,84 @@ +import pytest + +from pandas.util._validators import validate_args_and_kwargs + + +@pytest.fixture +def _fname(): + return "func" + + +def test_invalid_total_length_max_length_one(_fname): + compat_args = ("foo",) + kwargs = {"foo": "FOO"} + args = ("FoO", "BaZ") + + min_fname_arg_count = 0 + max_length = len(compat_args) + min_fname_arg_count + actual_length = len(kwargs) + len(args) + min_fname_arg_count + + msg = ( + rf"{_fname}\(\) takes at most {max_length} " + rf"argument \({actual_length} given\)" + ) + + with pytest.raises(TypeError, match=msg): + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) + + +def test_invalid_total_length_max_length_multiple(_fname): + compat_args = ("foo", "bar", "baz") + kwargs = {"foo": "FOO", "bar": "BAR"} + args = ("FoO", "BaZ") + + min_fname_arg_count = 2 + max_length = len(compat_args) + min_fname_arg_count + actual_length = len(kwargs) + len(args) + min_fname_arg_count + + msg = ( + rf"{_fname}\(\) takes at most {max_length} " + rf"arguments \({actual_length} given\)" + ) + + with pytest.raises(TypeError, match=msg): + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) + + +@pytest.mark.parametrize("args,kwargs", [((), {"foo": -5, "bar": 2}), ((-5, 2), {})]) +def test_missing_args_or_kwargs(args, kwargs, _fname): + bad_arg = "bar" + min_fname_arg_count = 2 + + compat_args = {"foo": -5, bad_arg: 1} + + msg = ( + rf"the '{bad_arg}' parameter is not supported " + rf"in the pandas implementation of {_fname}\(\)" + ) + + with pytest.raises(ValueError, match=msg): + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) + + +def test_duplicate_argument(_fname): + min_fname_arg_count = 2 + + compat_args = {"foo": None, "bar": None, "baz": None} + kwargs = {"foo": None, "bar": None} + args = (None,) # duplicate value for "foo" + + msg = rf"{_fname}\(\) got multiple values for keyword argument 'foo'" + + with pytest.raises(TypeError, match=msg): + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) + + +def test_validation(_fname): + # No exceptions should be raised. + compat_args = {"foo": 1, "bar": None, "baz": -2} + kwargs = {"baz": -2} + + args = (1, None) + min_fname_arg_count = 2 + + validate_args_and_kwargs(_fname, args, kwargs, min_fname_arg_count, compat_args) diff --git a/pandas/tests/util/test_validate_inclusive.py b/pandas/tests/util/test_validate_inclusive.py new file mode 100644 index 0000000000000000000000000000000000000000..c1254c614ab305c447090b148ea6a036569f76e6 --- /dev/null +++ b/pandas/tests/util/test_validate_inclusive.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest + +from pandas.util._validators import validate_inclusive + +import pandas as pd + + +@pytest.mark.parametrize( + "invalid_inclusive", + ( + "ccc", + 2, + object(), + None, + np.nan, + pd.NA, + pd.DataFrame(), + ), +) +def test_invalid_inclusive(invalid_inclusive): + with pytest.raises( + ValueError, + match="Inclusive has to be either 'both', 'neither', 'left' or 'right'", + ): + validate_inclusive(invalid_inclusive) + + +@pytest.mark.parametrize( + "valid_inclusive, expected_tuple", + ( + ("left", (True, False)), + ("right", (False, True)), + ("both", (True, True)), + ("neither", (False, False)), + ), +) +def test_valid_inclusive(valid_inclusive, expected_tuple): + resultant_tuple = validate_inclusive(valid_inclusive) + assert expected_tuple == resultant_tuple diff --git a/pandas/tests/util/test_validate_kwargs.py b/pandas/tests/util/test_validate_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..85d93638f788fa16a32ac3d83392a71c17f3cd7c --- /dev/null +++ b/pandas/tests/util/test_validate_kwargs.py @@ -0,0 +1,69 @@ +import pytest + +from pandas.util._validators import ( + validate_bool_kwarg, + validate_kwargs, +) + + +@pytest.fixture +def _fname(): + return "func" + + +def test_bad_kwarg(_fname): + good_arg = "f" + bad_arg = good_arg + "o" + + compat_args = {good_arg: "foo", bad_arg + "o": "bar"} + kwargs = {good_arg: "foo", bad_arg: "bar"} + + msg = rf"{_fname}\(\) got an unexpected keyword argument '{bad_arg}'" + + with pytest.raises(TypeError, match=msg): + validate_kwargs(_fname, kwargs, compat_args) + + +@pytest.mark.parametrize("i", range(1, 3)) +def test_not_all_none(i, _fname): + bad_arg = "foo" + msg = ( + rf"the '{bad_arg}' parameter is not supported " + rf"in the pandas implementation of {_fname}\(\)" + ) + + compat_args = {"foo": 1, "bar": "s", "baz": None} + + kwarg_keys = ("foo", "bar", "baz") + kwarg_vals = (2, "s", None) + + kwargs = dict(zip(kwarg_keys[:i], kwarg_vals[:i], strict=True)) + + with pytest.raises(ValueError, match=msg): + validate_kwargs(_fname, kwargs, compat_args) + + +def test_validation(_fname): + # No exceptions should be raised. + compat_args = {"f": None, "b": 1, "ba": "s"} + + kwargs = {"f": None, "b": 1} + validate_kwargs(_fname, kwargs, compat_args) + + +@pytest.mark.parametrize("name", ["inplace", "copy"]) +@pytest.mark.parametrize("value", [1, "True", [1, 2, 3], 5.0]) +def test_validate_bool_kwarg_fail(name, value): + msg = ( + f'For argument "{name}" expected type bool, ' + f"received type {type(value).__name__}" + ) + + with pytest.raises(ValueError, match=msg): + validate_bool_kwarg(value, name) + + +@pytest.mark.parametrize("name", ["inplace", "copy"]) +@pytest.mark.parametrize("value", [True, False, None]) +def test_validate_bool_kwarg(name, value): + assert validate_bool_kwarg(value, name) == value diff --git a/pandas/tests/window/__init__.py b/pandas/tests/window/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pandas/tests/window/conftest.py b/pandas/tests/window/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..fe873b3b74254c5aeb6fbec48db19cd27e37dc1b --- /dev/null +++ b/pandas/tests/window/conftest.py @@ -0,0 +1,124 @@ +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import ( + DataFrame, + Series, + bdate_range, +) + + +@pytest.fixture(params=[True, False]) +def raw(request): + """raw keyword argument for rolling.apply""" + return request.param + + +@pytest.fixture( + params=[ + "sum", + "mean", + "median", + "max", + "min", + "var", + "std", + "kurt", + "skew", + "count", + "sem", + ] +) +def arithmetic_win_operators(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def center(request): + return request.param + + +@pytest.fixture(params=[None, 1]) +def min_periods(request): + return request.param + + +@pytest.fixture(params=[True, False]) +def adjust(request): + """adjust keyword argument for ewm""" + return request.param + + +@pytest.fixture(params=[True, False]) +def ignore_na(request): + """ignore_na keyword argument for ewm""" + return request.param + + +@pytest.fixture(params=[True, False]) +def numeric_only(request): + """numeric_only keyword argument""" + return request.param + + +@pytest.fixture( + params=[ + pytest.param("numba", marks=[td.skip_if_no("numba"), pytest.mark.single_cpu]), + "cython", + ] +) +def engine(request): + """engine keyword argument for rolling.apply""" + return request.param + + +@pytest.fixture( + params=[ + pytest.param( + ("numba", True), marks=[td.skip_if_no("numba"), pytest.mark.single_cpu] + ), + ("cython", True), + ("cython", False), + ] +) +def engine_and_raw(request): + """engine and raw keyword arguments for rolling.apply""" + return request.param + + +@pytest.fixture(params=["1 day", timedelta(days=1), np.timedelta64(1, "D")]) +def halflife_with_times(request): + """Halflife argument for EWM when times is specified.""" + return request.param + + +@pytest.fixture +def series(): + """Make mocked series as fixture.""" + arr = np.random.default_rng(2).standard_normal(100) + locs = np.arange(20, 40) + arr[locs] = np.nan + series = Series(arr, index=bdate_range(datetime(2009, 1, 1), periods=100)) + return series + + +@pytest.fixture +def frame(): + """Make mocked frame as fixture.""" + return DataFrame( + np.random.default_rng(2).standard_normal((100, 10)), + index=bdate_range(datetime(2009, 1, 1), periods=100), + ) + + +@pytest.fixture(params=[None, 1, 2, 5, 10]) +def step(request): + """step keyword argument for rolling window operations.""" + return request.param diff --git a/pandas/tests/window/test_api.py b/pandas/tests/window/test_api.py new file mode 100644 index 0000000000000000000000000000000000000000..877b50e37670c0f2a5b9d9f816083a35e265ccbe --- /dev/null +++ b/pandas/tests/window/test_api.py @@ -0,0 +1,385 @@ +import numpy as np +import pytest + +from pandas.errors import ( + DataError, + SpecificationError, +) + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Period, + Series, + Timestamp, + concat, + date_range, + timedelta_range, +) +import pandas._testing as tm + + +def test_getitem(step): + frame = DataFrame(np.random.default_rng(2).standard_normal((5, 5))) + r = frame.rolling(window=5, step=step) + tm.assert_index_equal(r._selected_obj.columns, frame[::step].columns) + + r = frame.rolling(window=5, step=step)[1] + assert r._selected_obj.name == frame[::step].columns[1] + + # technically this is allowed + r = frame.rolling(window=5, step=step)[1, 3] + tm.assert_index_equal(r._selected_obj.columns, frame[::step].columns[[1, 3]]) + + r = frame.rolling(window=5, step=step)[[1, 3]] + tm.assert_index_equal(r._selected_obj.columns, frame[::step].columns[[1, 3]]) + + +def test_select_bad_cols(): + df = DataFrame([[1, 2]], columns=["A", "B"]) + g = df.rolling(window=5) + with pytest.raises(KeyError, match="Columns not found: 'C'"): + g[["C"]] + with pytest.raises(KeyError, match="^[^A]+$"): + # A should not be referenced as a bad column... + # will have to rethink regex if you change message! + g[["A", "C"]] + + +def test_attribute_access(): + df = DataFrame([[1, 2]], columns=["A", "B"]) + r = df.rolling(window=5) + tm.assert_series_equal(r.A.sum(), r["A"].sum()) + msg = "'Rolling' object has no attribute 'F'" + with pytest.raises(AttributeError, match=msg): + r.F + + +def tests_skip_nuisance(step): + df = DataFrame({"A": range(5), "B": range(5, 10), "C": "foo"}) + r = df.rolling(window=3, step=step) + result = r[["A", "B"]].sum() + expected = DataFrame( + {"A": [np.nan, np.nan, 3, 6, 9], "B": [np.nan, np.nan, 18, 21, 24]}, + columns=list("AB"), + )[::step] + tm.assert_frame_equal(result, expected) + + +def test_sum_object_str_raises(step): + df = DataFrame({"A": range(5), "B": range(5, 10), "C": "foo"}) + r = df.rolling(window=3, step=step) + with pytest.raises( + DataError, match="Cannot aggregate non-numeric type: object|str" + ): + # GH#42738, enforced in 2.0 + r.sum() + + +def test_agg(step): + df = DataFrame({"A": range(5), "B": range(0, 10, 2)}) + + r = df.rolling(window=3, step=step) + a_mean = r["A"].mean() + a_std = r["A"].std() + a_sum = r["A"].sum() + b_mean = r["B"].mean() + b_std = r["B"].std() + + result = r.aggregate([np.mean, lambda x: np.std(x, ddof=1)]) + expected = concat([a_mean, a_std, b_mean, b_std], axis=1) + expected.columns = MultiIndex.from_product([["A", "B"], ["mean", ""]]) + tm.assert_frame_equal(result, expected) + + result = r.aggregate({"A": np.mean, "B": lambda x: np.std(x, ddof=1)}) + + expected = concat([a_mean, b_std], axis=1) + tm.assert_frame_equal(result, expected, check_like=True) + + result = r.aggregate({"A": ["mean", "std"]}) + expected = concat([a_mean, a_std], axis=1) + expected.columns = MultiIndex.from_tuples([("A", "mean"), ("A", "std")]) + tm.assert_frame_equal(result, expected) + + result = r["A"].aggregate(["mean", "sum"]) + expected = concat([a_mean, a_sum], axis=1) + expected.columns = ["mean", "sum"] + tm.assert_frame_equal(result, expected) + + msg = "nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + # using a dict with renaming + r.aggregate({"A": {"mean": "mean", "sum": "sum"}}) + + with pytest.raises(SpecificationError, match=msg): + r.aggregate( + {"A": {"mean": "mean", "sum": "sum"}, "B": {"mean2": "mean", "sum2": "sum"}} + ) + + result = r.aggregate({"A": ["mean", "std"], "B": ["mean", "std"]}) + expected = concat([a_mean, a_std, b_mean, b_std], axis=1) + + exp_cols = [("A", "mean"), ("A", "std"), ("B", "mean"), ("B", "std")] + expected.columns = MultiIndex.from_tuples(exp_cols) + tm.assert_frame_equal(result, expected, check_like=True) + + +def test_agg_apply(raw): + # passed lambda + df = DataFrame({"A": range(5), "B": range(0, 10, 2)}) + + r = df.rolling(window=3) + a_sum = r["A"].sum() + + result = r.agg({"A": np.sum, "B": lambda x: np.std(x, ddof=1)}) + rcustom = r["B"].apply(lambda x: np.std(x, ddof=1), raw=raw) + expected = concat([a_sum, rcustom], axis=1) + tm.assert_frame_equal(result, expected, check_like=True) + + +def test_agg_consistency(step): + df = DataFrame({"A": range(5), "B": range(0, 10, 2)}) + r = df.rolling(window=3, step=step) + + result = r.agg([np.sum, np.mean]).columns + expected = MultiIndex.from_product([list("AB"), ["sum", "mean"]]) + tm.assert_index_equal(result, expected) + + result = r["A"].agg([np.sum, np.mean]).columns + expected = Index(["sum", "mean"]) + tm.assert_index_equal(result, expected) + + result = r.agg({"A": [np.sum, np.mean]}).columns + expected = MultiIndex.from_tuples([("A", "sum"), ("A", "mean")]) + tm.assert_index_equal(result, expected) + + +def test_agg_nested_dicts(): + # API change for disallowing these types of nested dicts + df = DataFrame({"A": range(5), "B": range(0, 10, 2)}) + r = df.rolling(window=3) + + msg = "nested renamer is not supported" + with pytest.raises(SpecificationError, match=msg): + r.aggregate({"r1": {"A": ["mean", "sum"]}, "r2": {"B": ["mean", "sum"]}}) + + expected = concat( + [r["A"].mean(), r["A"].std(), r["B"].mean(), r["B"].std()], axis=1 + ) + expected.columns = MultiIndex.from_tuples( + [("ra", "mean"), ("ra", "std"), ("rb", "mean"), ("rb", "std")] + ) + with pytest.raises(SpecificationError, match=msg): + r[["A", "B"]].agg({"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}}) + + with pytest.raises(SpecificationError, match=msg): + r.agg({"A": {"ra": ["mean", "std"]}, "B": {"rb": ["mean", "std"]}}) + + +@pytest.mark.parametrize( + "func,window_size", + [ + ( + "rolling", + 2, + ), + ( + "expanding", + None, + ), + ], +) +def test_pipe(func, window_size): + # Issue #57076 + df = DataFrame( + { + "B": np.random.default_rng(2).standard_normal(10), + "C": np.random.default_rng(2).standard_normal(10), + } + ) + r = getattr(df, func)(window_size) + + expected = r.max() - r.mean() + result = r.pipe(lambda x: x.max() - x.mean()) + tm.assert_frame_equal(result, expected) + + expected = r.max() - 2 * r.min() + result = r.pipe(lambda x, k: x.max() - k * x.min(), k=2) + tm.assert_frame_equal(result, expected) + + +def test_count_nonnumeric_types(step): + # GH12541 + cols = [ + "int", + "float", + "string", + "datetime", + "timedelta", + "periods", + "fl_inf", + "fl_nan", + "str_nan", + "dt_nat", + "periods_nat", + ] + dt_nat_col = [Timestamp("20170101"), Timestamp("20170203"), Timestamp(None)] + + df = DataFrame( + { + "int": [1, 2, 3], + "float": [4.0, 5.0, 6.0], + "string": list("abc"), + "datetime": date_range("20170101", periods=3), + "timedelta": timedelta_range("1 s", periods=3, freq="s"), + "periods": [ + Period("2012-01"), + Period("2012-02"), + Period("2012-03"), + ], + "fl_inf": [1.0, 2.0, np.inf], + "fl_nan": [1.0, 2.0, np.nan], + "str_nan": ["aa", "bb", np.nan], + "dt_nat": dt_nat_col, + "periods_nat": [ + Period("2012-01"), + Period("2012-02"), + Period(None), + ], + }, + columns=cols, + ) + + expected = DataFrame( + { + "int": [1.0, 2.0, 2.0], + "float": [1.0, 2.0, 2.0], + "string": [1.0, 2.0, 2.0], + "datetime": [1.0, 2.0, 2.0], + "timedelta": [1.0, 2.0, 2.0], + "periods": [1.0, 2.0, 2.0], + "fl_inf": [1.0, 2.0, 2.0], + "fl_nan": [1.0, 2.0, 1.0], + "str_nan": [1.0, 2.0, 1.0], + "dt_nat": [1.0, 2.0, 1.0], + "periods_nat": [1.0, 2.0, 1.0], + }, + columns=cols, + )[::step] + + result = df.rolling(window=2, min_periods=0, step=step).count() + tm.assert_frame_equal(result, expected) + + result = df.rolling(1, min_periods=0, step=step).count() + expected = df.notna().astype(float)[::step] + tm.assert_frame_equal(result, expected) + + +def test_preserve_metadata(): + # GH 10565 + s = Series(np.arange(100), name="foo") + + s2 = s.rolling(30).sum() + s3 = s.rolling(20).sum() + assert s2.name == "foo" + assert s3.name == "foo" + + +@pytest.mark.parametrize( + "func,window_size,expected_vals", + [ + ( + "rolling", + 2, + [ + [np.nan, np.nan, np.nan, np.nan], + [15.0, 20.0, 25.0, 20.0], + [25.0, 30.0, 35.0, 30.0], + [np.nan, np.nan, np.nan, np.nan], + [20.0, 30.0, 35.0, 30.0], + [35.0, 40.0, 60.0, 40.0], + [60.0, 80.0, 85.0, 80], + ], + ), + ( + "expanding", + None, + [ + [10.0, 10.0, 20.0, 20.0], + [15.0, 20.0, 25.0, 20.0], + [20.0, 30.0, 30.0, 20.0], + [10.0, 10.0, 30.0, 30.0], + [20.0, 30.0, 35.0, 30.0], + [26.666667, 40.0, 50.0, 30.0], + [40.0, 80.0, 60.0, 30.0], + ], + ), + ], +) +def test_multiple_agg_funcs(func, window_size, expected_vals): + # GH 15072 + df = DataFrame( + [ + ["A", 10, 20], + ["A", 20, 30], + ["A", 30, 40], + ["B", 10, 30], + ["B", 30, 40], + ["B", 40, 80], + ["B", 80, 90], + ], + columns=["stock", "low", "high"], + ) + + f = getattr(df.groupby("stock"), func) + if window_size: + window = f(window_size) + else: + window = f() + + index = MultiIndex.from_tuples( + [("A", 0), ("A", 1), ("A", 2), ("B", 3), ("B", 4), ("B", 5), ("B", 6)], + names=["stock", None], + ) + columns = MultiIndex.from_tuples( + [("low", "mean"), ("low", "max"), ("high", "mean"), ("high", "min")] + ) + expected = DataFrame(expected_vals, index=index, columns=columns) + + result = window.agg({"low": ["mean", "max"], "high": ["mean", "min"]}) + + tm.assert_frame_equal(result, expected) + + +def test_dont_modify_attributes_after_methods( + arithmetic_win_operators, closed, center, min_periods, step +): + # GH 39554 + roll_obj = Series(range(1)).rolling( + 1, center=center, closed=closed, min_periods=min_periods, step=step + ) + expected = {attr: getattr(roll_obj, attr) for attr in roll_obj._attributes} + getattr(roll_obj, arithmetic_win_operators)() + result = {attr: getattr(roll_obj, attr) for attr in roll_obj._attributes} + assert result == expected + + +def test_rolling_min_min_periods(step): + a = Series([1, 2, 3, 4, 5]) + result = a.rolling(window=100, min_periods=1, step=step).min() + expected = Series(np.ones(len(a)))[::step] + tm.assert_series_equal(result, expected) + msg = "min_periods 5 must be <= window 3" + with pytest.raises(ValueError, match=msg): + Series([1, 2, 3]).rolling(window=3, min_periods=5, step=step).min() + + +def test_rolling_max_min_periods(step): + a = Series([1, 2, 3, 4, 5], dtype=np.float64) + result = a.rolling(window=100, min_periods=1, step=step).max() + expected = a[::step] + tm.assert_almost_equal(result, expected) + msg = "min_periods 5 must be <= window 3" + with pytest.raises(ValueError, match=msg): + Series([1, 2, 3]).rolling(window=3, min_periods=5, step=step).max() diff --git a/pandas/tests/window/test_apply.py b/pandas/tests/window/test_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..2398713585cfbe673c511ee41cffc8172a3595b8 --- /dev/null +++ b/pandas/tests/window/test_apply.py @@ -0,0 +1,318 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + Timestamp, + concat, + date_range, + isna, + notna, +) +import pandas._testing as tm + +from pandas.tseries import offsets + +# suppress warnings about empty slices, as we are deliberately testing +# with a 0-length Series +pytestmark = pytest.mark.filterwarnings( + "ignore:.*(empty slice|0 for slice).*:RuntimeWarning" +) + + +def f(x): + return x[np.isfinite(x)].mean() + + +@pytest.mark.parametrize("bad_raw", [None, 1, 0]) +def test_rolling_apply_invalid_raw(bad_raw): + with pytest.raises(ValueError, match="raw parameter must be `True` or `False`"): + Series(range(3)).rolling(1).apply(len, raw=bad_raw) + + +def test_rolling_apply_out_of_bounds(engine_and_raw): + # gh-1850 + engine, raw = engine_and_raw + + vals = Series([1, 2, 3, 4]) + + result = vals.rolling(10).apply(np.sum, engine=engine, raw=raw) + assert result.isna().all() + + result = vals.rolling(10, min_periods=1).apply(np.sum, engine=engine, raw=raw) + expected = Series([1, 3, 6, 10], dtype=float) + tm.assert_almost_equal(result, expected) + + +@pytest.mark.parametrize("window", [2, "2s"]) +def test_rolling_apply_with_pandas_objects(window): + # 5071 + df = DataFrame( + { + "A": np.random.default_rng(2).standard_normal(5), + "B": np.random.default_rng(2).integers(0, 10, size=5), + }, + index=date_range("20130101", periods=5, freq="s"), + ) + + # we have an equal spaced timeseries index + # so simulate removing the first period + def f(x): + if x.index[0] == df.index[0]: + return np.nan + return x.iloc[-1] + + result = df.rolling(window).apply(f, raw=False) + expected = df.iloc[2:].reindex_like(df) + tm.assert_frame_equal(result, expected) + + with tm.external_error_raised(AttributeError): + df.rolling(window).apply(f, raw=True) + + +def test_rolling_apply(engine_and_raw, step): + engine, raw = engine_and_raw + + expected = Series([], dtype="float64") + result = expected.rolling(10, step=step).apply( + lambda x: x.mean(), engine=engine, raw=raw + ) + tm.assert_series_equal(result, expected) + + # gh-8080 + s = Series([None, None, None]) + result = s.rolling(2, min_periods=0, step=step).apply( + lambda x: len(x), engine=engine, raw=raw + ) + expected = Series([1.0, 2.0, 2.0])[::step] + tm.assert_series_equal(result, expected) + + result = s.rolling(2, min_periods=0, step=step).apply(len, engine=engine, raw=raw) + tm.assert_series_equal(result, expected) + + +def test_all_apply(engine_and_raw): + engine, raw = engine_and_raw + + df = ( + DataFrame( + {"A": date_range("20130101", periods=5, freq="s"), "B": range(5)} + ).set_index("A") + * 2 + ) + er = df.rolling(window=1) + r = df.rolling(window="1s") + + result = r.apply(lambda x: 1, engine=engine, raw=raw) + expected = er.apply(lambda x: 1, engine=engine, raw=raw) + tm.assert_frame_equal(result, expected) + + +def test_ragged_apply(engine_and_raw): + engine, raw = engine_and_raw + + df = DataFrame({"B": range(5)}) + df.index = [ + Timestamp("20130101 09:00:00"), + Timestamp("20130101 09:00:02"), + Timestamp("20130101 09:00:03"), + Timestamp("20130101 09:00:05"), + Timestamp("20130101 09:00:06"), + ] + + f = lambda x: 1 + result = df.rolling(window="1s", min_periods=1).apply(f, engine=engine, raw=raw) + expected = df.copy() + expected["B"] = 1.0 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).apply(f, engine=engine, raw=raw) + expected = df.copy() + expected["B"] = 1.0 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).apply(f, engine=engine, raw=raw) + expected = df.copy() + expected["B"] = 1.0 + tm.assert_frame_equal(result, expected) + + +def test_invalid_engine(): + with pytest.raises(ValueError, match="engine must be either 'numba' or 'cython'"): + Series(range(1)).rolling(1).apply(lambda x: x, engine="foo") + + +def test_invalid_engine_kwargs_cython(): + with pytest.raises(ValueError, match="cython engine does not accept engine_kwargs"): + Series(range(1)).rolling(1).apply( + lambda x: x, engine="cython", engine_kwargs={"nopython": False} + ) + + +def test_invalid_raw_numba(): + with pytest.raises( + ValueError, match="raw must be `True` when using the numba engine" + ): + Series(range(1)).rolling(1).apply(lambda x: x, raw=False, engine="numba") + + +@pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]]) +def test_rolling_apply_args_kwargs(args_kwargs): + # GH 33433 + def numpysum(x, par): + return np.sum(x + par) + + df = DataFrame({"gr": [1, 1], "a": [1, 2]}) + + idx = Index(["gr", "a"]) + expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx) + + result = df.rolling(1).apply(numpysum, args=args_kwargs[0], kwargs=args_kwargs[1]) + tm.assert_frame_equal(result, expected) + + midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None]) + expected = Series([11.0, 12.0], index=midx, name="a") + + gb_rolling = df.groupby("gr")["a"].rolling(1) + + result = gb_rolling.apply(numpysum, args=args_kwargs[0], kwargs=args_kwargs[1]) + tm.assert_series_equal(result, expected) + + +def test_nans(raw): + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + + result = obj.rolling(50, min_periods=30).apply(f, raw=raw) + tm.assert_almost_equal(result.iloc[-1], np.mean(obj[10:-10])) + + # min_periods is working correctly + result = obj.rolling(20, min_periods=15).apply(f, raw=raw) + assert isna(result.iloc[23]) + assert not isna(result.iloc[24]) + + assert not isna(result.iloc[-6]) + assert isna(result.iloc[-5]) + + obj2 = Series(np.random.default_rng(2).standard_normal(20)) + result = obj2.rolling(10, min_periods=5).apply(f, raw=raw) + assert isna(result.iloc[3]) + assert notna(result.iloc[4]) + + result0 = obj.rolling(20, min_periods=0).apply(f, raw=raw) + result1 = obj.rolling(20, min_periods=1).apply(f, raw=raw) + tm.assert_almost_equal(result0, result1) + + +def test_center(raw): + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + + result = obj.rolling(20, min_periods=15, center=True).apply(f, raw=raw) + expected = ( + concat([obj, Series([np.nan] * 9)]) + .rolling(20, min_periods=15) + .apply(f, raw=raw) + .iloc[9:] + .reset_index(drop=True) + ) + tm.assert_series_equal(result, expected) + + +def test_series(raw, series): + result = series.rolling(50).apply(f, raw=raw) + assert isinstance(result, Series) + tm.assert_almost_equal(result.iloc[-1], np.mean(series[-50:])) + + +def test_frame(raw, frame): + result = frame.rolling(50).apply(f, raw=raw) + assert isinstance(result, DataFrame) + tm.assert_series_equal( + result.iloc[-1, :], + frame.iloc[-50:, :].apply(np.mean, axis=0, raw=raw), + check_names=False, + ) + + +def test_time_rule_series(raw, series): + win = 25 + minp = 10 + ser = series[::2].resample("B").mean() + series_result = ser.rolling(window=win, min_periods=minp).apply(f, raw=raw) + last_date = series_result.index[-1] + prev_date = last_date - 24 * offsets.BDay() + + trunc_series = series[::2].truncate(prev_date, last_date) + tm.assert_almost_equal(series_result.iloc[-1], np.mean(trunc_series)) + + +def test_time_rule_frame(raw, frame): + win = 25 + minp = 10 + frm = frame[::2].resample("B").mean() + frame_result = frm.rolling(window=win, min_periods=minp).apply(f, raw=raw) + last_date = frame_result.index[-1] + prev_date = last_date - 24 * offsets.BDay() + + trunc_frame = frame[::2].truncate(prev_date, last_date) + tm.assert_series_equal( + frame_result.xs(last_date), + trunc_frame.apply(np.mean, raw=raw), + check_names=False, + ) + + +@pytest.mark.parametrize("minp", [0, 99, 100]) +def test_min_periods(raw, series, minp, step): + result = series.rolling(len(series) + 1, min_periods=minp, step=step).apply( + f, raw=raw + ) + expected = series.rolling(len(series), min_periods=minp, step=step).apply( + f, raw=raw + ) + nan_mask = isna(result) + tm.assert_series_equal(nan_mask, isna(expected)) + + nan_mask = ~nan_mask + tm.assert_almost_equal(result[nan_mask], expected[nan_mask]) + + +def test_center_reindex_series(raw, series): + # shifter index + s = [f"x{x:d}" for x in range(12)] + minp = 10 + + series_xp = ( + series.reindex(list(series.index) + s) + .rolling(window=25, min_periods=minp) + .apply(f, raw=raw) + .shift(-12) + .reindex(series.index) + ) + series_rs = series.rolling(window=25, min_periods=minp, center=True).apply( + f, raw=raw + ) + tm.assert_series_equal(series_xp, series_rs) + + +def test_center_reindex_frame(raw): + # shifter index + frame = DataFrame(range(100), index=date_range("2020-01-01", freq="D", periods=100)) + s = [f"x{x:d}" for x in range(12)] + minp = 10 + + frame_xp = ( + frame.reindex(list(frame.index) + s) + .rolling(window=25, min_periods=minp) + .apply(f, raw=raw) + .shift(-12) + .reindex(frame.index) + ) + frame_rs = frame.rolling(window=25, min_periods=minp, center=True).apply(f, raw=raw) + tm.assert_frame_equal(frame_xp, frame_rs) diff --git a/pandas/tests/window/test_base_indexer.py b/pandas/tests/window/test_base_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c62ecc836c5043468acab308a5c26727d7652a5 --- /dev/null +++ b/pandas/tests/window/test_base_indexer.py @@ -0,0 +1,519 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + MultiIndex, + Series, + concat, + date_range, +) +import pandas._testing as tm +from pandas.api.indexers import ( + BaseIndexer, + FixedForwardWindowIndexer, +) +from pandas.core.indexers.objects import ( + ExpandingIndexer, + FixedWindowIndexer, + VariableOffsetWindowIndexer, +) + +from pandas.tseries.offsets import BusinessDay + + +def test_bad_get_window_bounds_signature(): + class BadIndexer(BaseIndexer): + def get_window_bounds(self): + return None + + indexer = BadIndexer() + with pytest.raises(ValueError, match="BadIndexer does not implement"): + Series(range(5)).rolling(indexer) + + +def test_expanding_indexer(): + s = Series(range(10)) + indexer = ExpandingIndexer() + result = s.rolling(indexer).mean() + expected = s.expanding().mean() + tm.assert_series_equal(result, expected) + + +def test_indexer_constructor_arg(): + # Example found in computation.rst + use_expanding = [True, False, True, False, True] + df = DataFrame({"values": range(5)}) + + class CustomIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + start = np.empty(num_values, dtype=np.int64) + end = np.empty(num_values, dtype=np.int64) + for i in range(num_values): + if self.use_expanding[i]: + start[i] = 0 + end[i] = i + 1 + else: + start[i] = i + end[i] = i + self.window_size + return start, end + + indexer = CustomIndexer(window_size=1, use_expanding=use_expanding) + result = df.rolling(indexer).sum() + expected = DataFrame({"values": [0.0, 1.0, 3.0, 3.0, 10.0]}) + tm.assert_frame_equal(result, expected) + + +def test_indexer_accepts_rolling_args(): + df = DataFrame({"values": range(5)}) + + class CustomIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + start = np.empty(num_values, dtype=np.int64) + end = np.empty(num_values, dtype=np.int64) + for i in range(num_values): + if ( + center + and min_periods == 1 + and closed == "both" + and step == 1 + and i == 2 + ): + start[i] = 0 + end[i] = num_values + else: + start[i] = i + end[i] = i + self.window_size + return start, end + + indexer = CustomIndexer(window_size=1) + result = df.rolling( + indexer, center=True, min_periods=1, closed="both", step=1 + ).sum() + expected = DataFrame({"values": [0.0, 1.0, 10.0, 3.0, 4.0]}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "func,np_func,expected,np_kwargs", + [ + ("count", len, [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 2.0, np.nan], {}), + ("min", np.min, [0.0, 1.0, 2.0, 3.0, 4.0, 6.0, 6.0, 7.0, 8.0, np.nan], {}), + ( + "max", + np.max, + [2.0, 3.0, 4.0, 100.0, 100.0, 100.0, 8.0, 9.0, 9.0, np.nan], + {}, + ), + ( + "std", + np.std, + [ + 1.0, + 1.0, + 1.0, + 55.71654452, + 54.85739087, + 53.9845657, + 1.0, + 1.0, + 0.70710678, + np.nan, + ], + {"ddof": 1}, + ), + ( + "var", + np.var, + [ + 1.0, + 1.0, + 1.0, + 3104.333333, + 3009.333333, + 2914.333333, + 1.0, + 1.0, + 0.500000, + np.nan, + ], + {"ddof": 1}, + ), + ( + "median", + np.median, + [1.0, 2.0, 3.0, 4.0, 6.0, 7.0, 7.0, 8.0, 8.5, np.nan], + {}, + ), + ], +) +def test_rolling_forward_window( + frame_or_series, func, np_func, expected, np_kwargs, step +): + # GH 32865 + values = np.arange(10.0) + values[5] = 100.0 + + indexer = FixedForwardWindowIndexer(window_size=3) + + match = "Forward-looking windows can't have center=True" + rolling = frame_or_series(values).rolling(window=indexer, center=True) + with pytest.raises(ValueError, match=match): + getattr(rolling, func)() + + match = "Forward-looking windows don't support setting the closed argument" + rolling = frame_or_series(values).rolling(window=indexer, closed="right") + with pytest.raises(ValueError, match=match): + getattr(rolling, func)() + + rolling = frame_or_series(values).rolling(window=indexer, min_periods=2, step=step) + result = getattr(rolling, func)() + + # Check that the function output matches the explicitly provided array + expected = frame_or_series(expected)[::step] + tm.assert_equal(result, expected) + + # Check that the rolling function output matches applying an alternative + # function to the rolling window object + expected2 = frame_or_series(rolling.apply(lambda x: np_func(x, **np_kwargs))) + tm.assert_equal(result, expected2) + + # Check that the function output matches applying an alternative function + # if min_periods isn't specified + # GH 39604: After count-min_periods deprecation, apply(lambda x: len(x)) + # is equivalent to count after setting min_periods=0 + min_periods = 0 if func == "count" else None + rolling3 = frame_or_series(values).rolling(window=indexer, min_periods=min_periods) + result3 = getattr(rolling3, func)() + expected3 = frame_or_series(rolling3.apply(lambda x: np_func(x, **np_kwargs))) + tm.assert_equal(result3, expected3) + + +def test_rolling_forward_skewness(frame_or_series, step): + values = np.arange(10.0) + values[5] = 100.0 + + indexer = FixedForwardWindowIndexer(window_size=5) + rolling = frame_or_series(values).rolling(window=indexer, min_periods=3, step=step) + result = rolling.skew() + + expected = frame_or_series( + [ + 0.0, + 2.232396, + 2.229508, + 2.228340, + 2.229091, + 2.231989, + 0.0, + 0.0, + np.nan, + np.nan, + ] + )[::step] + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "func,expected", + [ + ("cov", [2.0, 2.0, 2.0, 97.0, 2.0, -93.0, 2.0, 2.0, np.nan, np.nan]), + ( + "corr", + [ + 1.0, + 1.0, + 1.0, + 0.8704775290207161, + 0.018229084250926637, + -0.861357304646493, + 1.0, + 1.0, + np.nan, + np.nan, + ], + ), + ], +) +def test_rolling_forward_cov_corr(func, expected): + values1 = np.arange(10).reshape(-1, 1) + values2 = values1 * 2 + values1[5, 0] = 100 + values = np.concatenate([values1, values2], axis=1) + + indexer = FixedForwardWindowIndexer(window_size=3) + rolling = DataFrame(values).rolling(window=indexer, min_periods=3) + # We are interested in checking only pairwise covariance / correlation + result = getattr(rolling, func)().loc[(slice(None), 1), 0] + result = result.reset_index(drop=True) + expected = Series(expected).reset_index(drop=True) + expected.name = result.name + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "closed,expected_data", + [ + ["right", [0.0, 1.0, 2.0, 3.0, 7.0, 12.0, 6.0, 7.0, 8.0, 9.0]], + ["left", [0.0, 0.0, 1.0, 2.0, 5.0, 9.0, 5.0, 6.0, 7.0, 8.0]], + ], +) +def test_non_fixed_variable_window_indexer(closed, expected_data): + index = date_range("2020", periods=10) + df = DataFrame(range(10), index=index) + offset = BusinessDay(1) + indexer = VariableOffsetWindowIndexer(index=index, offset=offset) + result = df.rolling(indexer, closed=closed).sum() + expected = DataFrame(expected_data, index=index) + tm.assert_frame_equal(result, expected) + + +def test_variableoffsetwindowindexer_not_dti(): + # GH 54379 + with pytest.raises(ValueError, match="index must be a DatetimeIndex."): + VariableOffsetWindowIndexer(index="foo", offset=BusinessDay(1)) + + +def test_variableoffsetwindowindexer_not_offset(): + # GH 54379 + idx = date_range("2020", periods=10) + with pytest.raises(ValueError, match="offset must be a DateOffset-like object."): + VariableOffsetWindowIndexer(index=idx, offset="foo") + + +def test_fixed_forward_indexer_count(step): + # GH: 35579 + df = DataFrame({"b": [None, None, None, 7]}) + indexer = FixedForwardWindowIndexer(window_size=2) + result = df.rolling(window=indexer, min_periods=0, step=step).count() + expected = DataFrame({"b": [0.0, 0.0, 1.0, 1.0]})[::step] + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("end_value", "values"), [(1, [0.0, 1, 1, 3, 2]), (-1, [0.0, 1, 0, 3, 1])] +) +@pytest.mark.parametrize(("func", "args"), [("median", []), ("quantile", [0.5])]) +def test_indexer_quantile_sum(end_value, values, func, args): + # GH 37153 + class CustomIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + start = np.empty(num_values, dtype=np.int64) + end = np.empty(num_values, dtype=np.int64) + for i in range(num_values): + if self.use_expanding[i]: + start[i] = 0 + end[i] = max(i + end_value, 1) + else: + start[i] = i + end[i] = i + self.window_size + return start, end + + use_expanding = [True, False, True, False, True] + df = DataFrame({"values": range(5)}) + + indexer = CustomIndexer(window_size=1, use_expanding=use_expanding) + result = getattr(df.rolling(indexer), func)(*args) + expected = DataFrame({"values": values}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "indexer_class", [FixedWindowIndexer, FixedForwardWindowIndexer, ExpandingIndexer] +) +@pytest.mark.parametrize("window_size", [1, 2, 12]) +@pytest.mark.parametrize( + "df_data", + [ + {"a": [1, 1], "b": [0, 1]}, + {"a": [1, 2], "b": [0, 1]}, + {"a": [1] * 16, "b": [np.nan, 1, 2, np.nan, *list(range(4, 16))]}, + ], +) +def test_indexers_are_reusable_after_groupby_rolling( + indexer_class, window_size, df_data +): + # GH 43267 + df = DataFrame(df_data) + num_trials = 3 + indexer = indexer_class(window_size=window_size) + original_window_size = indexer.window_size + for i in range(num_trials): + df.groupby("a")["b"].rolling(window=indexer, min_periods=1).mean() + assert indexer.window_size == original_window_size + + +@pytest.mark.parametrize( + "window_size, num_values, expected_start, expected_end", + [ + (1, 1, [0], [1]), + (1, 2, [0, 1], [1, 2]), + (2, 1, [0], [1]), + (2, 2, [0, 1], [2, 2]), + (5, 12, range(12), list(range(5, 12)) + [12] * 5), + (12, 5, range(5), [5] * 5), + (0, 0, np.array([]), np.array([])), + (1, 0, np.array([]), np.array([])), + (0, 1, [0], [0]), + ], +) +def test_fixed_forward_indexer_bounds( + window_size, num_values, expected_start, expected_end, step +): + # GH 43267 + indexer = FixedForwardWindowIndexer(window_size=window_size) + start, end = indexer.get_window_bounds(num_values=num_values, step=step) + + tm.assert_numpy_array_equal( + start, np.array(expected_start[::step]), check_dtype=False + ) + tm.assert_numpy_array_equal(end, np.array(expected_end[::step]), check_dtype=False) + assert len(start) == len(end) + + +@pytest.mark.parametrize( + "df, window_size, expected", + [ + ( + DataFrame({"b": [0, 1, 2], "a": [1, 2, 2]}), + 2, + Series( + [0, 1.5, 2.0], + index=MultiIndex.from_arrays([[1, 2, 2], range(3)], names=["a", None]), + name="b", + dtype=np.float64, + ), + ), + ( + DataFrame( + { + "b": [np.nan, 1, 2, np.nan, *list(range(4, 18))], + "a": [1] * 7 + [2] * 11, + "c": range(18), + } + ), + 12, + Series( + [ + 3.6, + 3.6, + 4.25, + 5.0, + 5.0, + 5.5, + 6.0, + 12.0, + 12.5, + 13.0, + 13.5, + 14.0, + 14.5, + 15.0, + 15.5, + 16.0, + 16.5, + 17.0, + ], + index=MultiIndex.from_arrays( + [[1] * 7 + [2] * 11, range(18)], names=["a", None] + ), + name="b", + dtype=np.float64, + ), + ), + ], +) +def test_rolling_groupby_with_fixed_forward_specific(df, window_size, expected): + # GH 43267 + indexer = FixedForwardWindowIndexer(window_size=window_size) + result = df.groupby("a")["b"].rolling(window=indexer, min_periods=1).mean() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "group_keys", + [ + (1,), + (1, 2), + (2, 1), + (1, 1, 2), + (1, 2, 1), + (1, 1, 2, 2), + (1, 2, 3, 2, 3), + (1, 1, 2) * 4, + (1, 2, 3) * 5, + ], +) +@pytest.mark.parametrize("window_size", [1, 2, 3, 4, 5, 8, 20]) +def test_rolling_groupby_with_fixed_forward_many(group_keys, window_size): + # GH 43267 + df = DataFrame( + { + "a": np.array(list(group_keys)), + "b": np.arange(len(group_keys), dtype=np.float64) + 17, + "c": np.arange(len(group_keys), dtype=np.int64), + } + ) + + indexer = FixedForwardWindowIndexer(window_size=window_size) + result = df.groupby("a")["b"].rolling(window=indexer, min_periods=1).sum() + result.index.names = ["a", "c"] + + groups = df.groupby("a")[["a", "b", "c"]] + manual = concat( + [ + g.assign( + b=[ + g["b"].iloc[i : i + window_size].sum(min_count=1) + for i in range(len(g)) + ] + ) + for _, g in groups + ] + ) + manual = manual.set_index(["a", "c"])["b"] + + tm.assert_series_equal(result, manual) + + +def test_unequal_start_end_bounds(): + class CustomIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + return np.array([1]), np.array([1, 2]) + + indexer = CustomIndexer() + roll = Series(1).rolling(indexer) + match = "start" + with pytest.raises(ValueError, match=match): + roll.mean() + + with pytest.raises(ValueError, match=match): + next(iter(roll)) + + with pytest.raises(ValueError, match=match): + roll.corr(pairwise=True) + + with pytest.raises(ValueError, match=match): + roll.cov(pairwise=True) + + +def test_unequal_bounds_to_object(): + # GH 44470 + class CustomIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + return np.array([1]), np.array([2]) + + indexer = CustomIndexer() + roll = Series([1, 1]).rolling(indexer) + match = "start and end" + with pytest.raises(ValueError, match=match): + roll.mean() + + with pytest.raises(ValueError, match=match): + next(iter(roll)) + + with pytest.raises(ValueError, match=match): + roll.corr(pairwise=True) + + with pytest.raises(ValueError, match=match): + roll.cov(pairwise=True) diff --git a/pandas/tests/window/test_cython_aggregations.py b/pandas/tests/window/test_cython_aggregations.py new file mode 100644 index 0000000000000000000000000000000000000000..2e23618a3a201ce35a4217153510389eef791590 --- /dev/null +++ b/pandas/tests/window/test_cython_aggregations.py @@ -0,0 +1,114 @@ +from functools import partial +import sys + +import numpy as np +import pytest + +import pandas._libs.window.aggregations as window_aggregations + +from pandas import Series +import pandas._testing as tm + + +def _get_rolling_aggregations(): + # list pairs of name and function + # each function has this signature: + # (const float64_t[:] values, ndarray[int64_t] start, + # ndarray[int64_t] end, int64_t minp) -> np.ndarray + named_roll_aggs = ( + [ + ("roll_sum", window_aggregations.roll_sum), + ("roll_mean", window_aggregations.roll_mean), + ] + + [ + (f"roll_var({ddof})", partial(window_aggregations.roll_var, ddof=ddof)) + for ddof in [0, 1] + ] + + [ + ("roll_skew", window_aggregations.roll_skew), + ("roll_kurt", window_aggregations.roll_kurt), + ("roll_median_c", window_aggregations.roll_median_c), + ("roll_max", window_aggregations.roll_max), + ("roll_min", window_aggregations.roll_min), + ("roll_first", window_aggregations.roll_first), + ("roll_last", window_aggregations.roll_last), + ("roll_nunique", window_aggregations.roll_nunique), + ] + + [ + ( + f"roll_quantile({quantile},{interpolation})", + partial( + window_aggregations.roll_quantile, + quantile=quantile, + interpolation=interpolation, + ), + ) + for quantile in [0.0001, 0.5, 0.9999] + for interpolation in window_aggregations.interpolation_types + ] + + [ + ( + f"roll_rank({percentile},{method},{ascending})", + partial( + window_aggregations.roll_rank, + percentile=percentile, + method=method, + ascending=ascending, + ), + ) + for percentile in [True, False] + for method in window_aggregations.rolling_rank_tiebreakers.keys() + for ascending in [True, False] + ] + ) + # unzip to a list of 2 tuples, names and functions + unzipped = list(zip(*named_roll_aggs, strict=True)) + return {"ids": unzipped[0], "params": unzipped[1]} + + +_rolling_aggregations = _get_rolling_aggregations() + + +@pytest.fixture( + params=_rolling_aggregations["params"], ids=_rolling_aggregations["ids"] +) +def rolling_aggregation(request): + """Make a rolling aggregation function as fixture.""" + return request.param + + +def test_rolling_aggregation_boundary_consistency(rolling_aggregation): + # GH-45647 + minp, step, width, size, selection = 0, 1, 3, 11, [2, 7] + values = np.arange(1, 1 + size, dtype=np.float64) + end = np.arange(width, size, step, dtype=np.int64) + start = end - width + selarr = np.array(selection, dtype=np.int32) + result = Series(rolling_aggregation(values, start[selarr], end[selarr], minp)) + expected = Series(rolling_aggregation(values, start, end, minp)[selarr]) + tm.assert_equal(expected, result) + + +def test_rolling_aggregation_with_unused_elements(rolling_aggregation): + # GH-45647 + minp, width = 0, 5 # width at least 4 for kurt + size = 2 * width + 5 + values = np.arange(1, size + 1, dtype=np.float64) + values[width : width + 2] = sys.float_info.min + values[width + 2] = np.nan + values[width + 3 : width + 5] = sys.float_info.max + start = np.array([0, size - width], dtype=np.int64) + end = np.array([width, size], dtype=np.int64) + loc = np.array( + [j for i in range(len(start)) for j in range(start[i], end[i])], + dtype=np.int32, + ) + result = Series(rolling_aggregation(values, start, end, minp)) + compact_values = np.array(values[loc], dtype=np.float64) + compact_start = np.arange(0, len(start) * width, width, dtype=np.int64) + compact_end = compact_start + width + expected = Series( + rolling_aggregation(compact_values, compact_start, compact_end, minp) + ) + assert np.isfinite(expected.values).all(), "Not all expected values are finite" + tm.assert_equal(expected, result) diff --git a/pandas/tests/window/test_dtypes.py b/pandas/tests/window/test_dtypes.py new file mode 100644 index 0000000000000000000000000000000000000000..4007320b5de332ee4aef40b1ad1be9092eeb3347 --- /dev/null +++ b/pandas/tests/window/test_dtypes.py @@ -0,0 +1,173 @@ +import numpy as np +import pytest + +from pandas.errors import DataError + +from pandas.core.dtypes.common import pandas_dtype + +from pandas import ( + NA, + DataFrame, + Series, +) +import pandas._testing as tm + +# gh-12373 : rolling functions error on float32 data +# make sure rolling functions works for different dtypes +# +# further note that we are only checking rolling for fully dtype +# compliance (though both expanding and ewm inherit) + + +def get_dtype(dtype, coerce_int=None): + if coerce_int is False and "int" in dtype: + return None + return pandas_dtype(dtype) + + +@pytest.fixture( + params=[ + "object", + "category", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "m8[ns]", + "M8[ns]", + "datetime64[ns, UTC]", + ] +) +def dtypes(request): + """Dtypes for window tests""" + return request.param + + +@pytest.mark.parametrize( + "method, data, expected_data, coerce_int, min_periods", + [ + ("count", np.arange(5), [1, 2, 2, 2, 2], True, 0), + ("count", np.arange(10, 0, -2), [1, 2, 2, 2, 2], True, 0), + ("count", [0, 1, 2, np.nan, 4], [1, 2, 2, 1, 1], False, 0), + ("max", np.arange(5), [np.nan, 1, 2, 3, 4], True, None), + ("max", np.arange(10, 0, -2), [np.nan, 10, 8, 6, 4], True, None), + ("max", [0, 1, 2, np.nan, 4], [np.nan, 1, 2, np.nan, np.nan], False, None), + ("min", np.arange(5), [np.nan, 0, 1, 2, 3], True, None), + ("min", np.arange(10, 0, -2), [np.nan, 8, 6, 4, 2], True, None), + ("min", [0, 1, 2, np.nan, 4], [np.nan, 0, 1, np.nan, np.nan], False, None), + ("sum", np.arange(5), [np.nan, 1, 3, 5, 7], True, None), + ("sum", np.arange(10, 0, -2), [np.nan, 18, 14, 10, 6], True, None), + ("sum", [0, 1, 2, np.nan, 4], [np.nan, 1, 3, np.nan, np.nan], False, None), + ("mean", np.arange(5), [np.nan, 0.5, 1.5, 2.5, 3.5], True, None), + ("mean", np.arange(10, 0, -2), [np.nan, 9, 7, 5, 3], True, None), + ("mean", [0, 1, 2, np.nan, 4], [np.nan, 0.5, 1.5, np.nan, np.nan], False, None), + ("std", np.arange(5), [np.nan] + [np.sqrt(0.5)] * 4, True, None), + ("std", np.arange(10, 0, -2), [np.nan] + [np.sqrt(2)] * 4, True, None), + ( + "std", + [0, 1, 2, np.nan, 4], + [np.nan] + [np.sqrt(0.5)] * 2 + [np.nan] * 2, + False, + None, + ), + ("var", np.arange(5), [np.nan, 0.5, 0.5, 0.5, 0.5], True, None), + ("var", np.arange(10, 0, -2), [np.nan, 2, 2, 2, 2], True, None), + ("var", [0, 1, 2, np.nan, 4], [np.nan, 0.5, 0.5, np.nan, np.nan], False, None), + ("median", np.arange(5), [np.nan, 0.5, 1.5, 2.5, 3.5], True, None), + ("median", np.arange(10, 0, -2), [np.nan, 9, 7, 5, 3], True, None), + ( + "median", + [0, 1, 2, np.nan, 4], + [np.nan, 0.5, 1.5, np.nan, np.nan], + False, + None, + ), + ], +) +def test_series_dtypes( + method, data, expected_data, coerce_int, dtypes, min_periods, step +): + ser = Series(data, dtype=get_dtype(dtypes, coerce_int=coerce_int)) + rolled = ser.rolling(2, min_periods=min_periods, step=step) + + if dtypes in ("m8[ns]", "M8[ns]", "datetime64[ns, UTC]") and method != "count": + msg = "No numeric types to aggregate" + with pytest.raises(DataError, match=msg): + getattr(rolled, method)() + else: + result = getattr(rolled, method)() + expected = Series(expected_data, dtype="float64")[::step] + tm.assert_almost_equal(result, expected) + + +def test_series_nullable_int(any_signed_int_ea_dtype, step): + # GH 43016 + ser = Series([0, 1, NA], dtype=any_signed_int_ea_dtype) + result = ser.rolling(2, step=step).mean() + expected = Series([np.nan, 0.5, np.nan])[::step] + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "method, expected_data, min_periods", + [ + ("count", {0: Series([1, 2, 2, 2, 2]), 1: Series([1, 2, 2, 2, 2])}, 0), + ( + "max", + {0: Series([np.nan, 2, 4, 6, 8]), 1: Series([np.nan, 3, 5, 7, 9])}, + None, + ), + ( + "min", + {0: Series([np.nan, 0, 2, 4, 6]), 1: Series([np.nan, 1, 3, 5, 7])}, + None, + ), + ( + "sum", + {0: Series([np.nan, 2, 6, 10, 14]), 1: Series([np.nan, 4, 8, 12, 16])}, + None, + ), + ( + "mean", + {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])}, + None, + ), + ( + "std", + { + 0: Series([np.nan] + [np.sqrt(2)] * 4), + 1: Series([np.nan] + [np.sqrt(2)] * 4), + }, + None, + ), + ( + "var", + {0: Series([np.nan, 2, 2, 2, 2]), 1: Series([np.nan, 2, 2, 2, 2])}, + None, + ), + ( + "median", + {0: Series([np.nan, 1, 3, 5, 7]), 1: Series([np.nan, 2, 4, 6, 8])}, + None, + ), + ], +) +def test_dataframe_dtypes(method, expected_data, dtypes, min_periods, step): + df = DataFrame(np.arange(10).reshape((5, 2)), dtype=get_dtype(dtypes)) + rolled = df.rolling(2, min_periods=min_periods, step=step) + + if dtypes in ("m8[ns]", "M8[ns]", "datetime64[ns, UTC]") and method != "count": + msg = "Cannot aggregate non-numeric type" + with pytest.raises(DataError, match=msg): + getattr(rolled, method)() + else: + result = getattr(rolled, method)() + expected = DataFrame(expected_data, dtype="float64")[::step] + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/window/test_ewm.py b/pandas/tests/window/test_ewm.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea6c805a2ee4936501fbe9c1973572167efc914 --- /dev/null +++ b/pandas/tests/window/test_ewm.py @@ -0,0 +1,737 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + DatetimeIndex, + Series, + date_range, +) +import pandas._testing as tm + + +def test_doc_string(): + df = DataFrame({"B": [0, 1, 2, np.nan, 4]}) + df + df.ewm(com=0.5).mean() + + +def test_constructor(frame_or_series): + c = frame_or_series(range(5)).ewm + + # valid + c(com=0.5) + c(span=1.5) + c(alpha=0.5) + c(halflife=0.75) + c(com=0.5, span=None) + c(alpha=0.5, com=None) + c(halflife=0.75, alpha=None) + + # not valid: mutually exclusive + msg = "comass, span, halflife, and alpha are mutually exclusive" + with pytest.raises(ValueError, match=msg): + c(com=0.5, alpha=0.5) + with pytest.raises(ValueError, match=msg): + c(span=1.5, halflife=0.75) + with pytest.raises(ValueError, match=msg): + c(alpha=0.5, span=1.5) + + # not valid: com < 0 + msg = "comass must satisfy: comass >= 0" + with pytest.raises(ValueError, match=msg): + c(com=-0.5) + + # not valid: span < 1 + msg = "span must satisfy: span >= 1" + with pytest.raises(ValueError, match=msg): + c(span=0.5) + + # not valid: halflife <= 0 + msg = "halflife must satisfy: halflife > 0" + with pytest.raises(ValueError, match=msg): + c(halflife=0) + + # not valid: alpha <= 0 or alpha > 1 + msg = "alpha must satisfy: 0 < alpha <= 1" + for alpha in (-0.5, 1.5): + with pytest.raises(ValueError, match=msg): + c(alpha=alpha) + + +def test_ewma_times_not_datetime_type(): + msg = r"times must be datetime64 dtype." + with pytest.raises(ValueError, match=msg): + Series(range(5)).ewm(times=np.arange(5)) + + +def test_ewma_times_not_same_length(): + msg = "times must be the same length as the object." + with pytest.raises(ValueError, match=msg): + Series(range(5)).ewm(times=np.arange(4).astype("datetime64[ns]")) + + +def test_ewma_halflife_not_correct_type(): + msg = "halflife must be a timedelta convertible object" + with pytest.raises(ValueError, match=msg): + Series(range(5)).ewm(halflife=1, times=np.arange(5).astype("datetime64[ns]")) + + +def test_ewma_halflife_without_times(halflife_with_times): + msg = "halflife can only be a timedelta convertible argument if times is not None." + with pytest.raises(ValueError, match=msg): + Series(range(5)).ewm(halflife=halflife_with_times) + + +@pytest.mark.parametrize( + "times", + [ + np.arange(10).astype("datetime64[D]").astype("datetime64[ns]"), + date_range("2000", freq="D", periods=10), + date_range("2000", freq="D", periods=10).tz_localize("UTC"), + ], +) +@pytest.mark.parametrize("min_periods", [0, 2]) +def test_ewma_with_times_equal_spacing(halflife_with_times, times, min_periods): + halflife = halflife_with_times + data = np.arange(10.0) + data[::2] = np.nan + df = DataFrame({"A": data}) + result = df.ewm(halflife=halflife, min_periods=min_periods, times=times).mean() + expected = df.ewm(halflife=1.0, min_periods=min_periods).mean() + tm.assert_frame_equal(result, expected) + + +def test_ewma_with_times_variable_spacing(tz_aware_fixture, unit, adjust): + # GH 54328 + tz = tz_aware_fixture + halflife = "23 days" + times = ( + DatetimeIndex(["2020-01-01", "2020-01-10T00:04:05", "2020-02-23T05:00:23"]) + .tz_localize(tz) + .as_unit(unit) + ) + data = np.arange(3) + df = DataFrame(data) + result = df.ewm(halflife=halflife, times=times, adjust=adjust).mean() + if adjust: + expected = DataFrame([0.0, 0.5674161888241773, 1.545239952073459]) + else: + expected = DataFrame([0.0, 0.23762518642226227, 1.534926369128742]) + tm.assert_frame_equal(result, expected) + + +def test_ewm_with_nat_raises(halflife_with_times): + # GH#38535 + ser = Series(range(1)) + times = DatetimeIndex(["NaT"]) + with pytest.raises(ValueError, match="Cannot convert NaT values to integer"): + ser.ewm(com=0.1, halflife=halflife_with_times, times=times) + + +def test_ewm_with_times_getitem(halflife_with_times): + # GH 40164 + halflife = halflife_with_times + data = np.arange(10.0) + data[::2] = np.nan + times = date_range("2000", freq="D", periods=10) + df = DataFrame({"A": data, "B": data}) + result = df.ewm(halflife=halflife, times=times)["A"].mean() + expected = df.ewm(halflife=1.0)["A"].mean() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("arg", ["com", "halflife", "span", "alpha"]) +def test_ewm_getitem_attributes_retained(arg, adjust, ignore_na): + # GH 40164 + kwargs = {arg: 1, "adjust": adjust, "ignore_na": ignore_na} + ewm = DataFrame({"A": range(1), "B": range(1)}).ewm(**kwargs) + expected = {attr: getattr(ewm, attr) for attr in ewm._attributes} + ewm_slice = ewm["A"] + result = {attr: getattr(ewm, attr) for attr in ewm_slice._attributes} + assert result == expected + + +def test_ewma_times_adjust_false_with_disallowed_com(): + # GH 54328 + with pytest.raises( + NotImplementedError, + match=( + "None of com, span, or alpha can be specified " + "if times is provided and adjust=False" + ), + ): + Series(range(1)).ewm( + 0.1, + adjust=False, + times=date_range("2000", freq="D", periods=1), + halflife="1D", + ) + + +def test_ewma_times_adjust_false_with_disallowed_alpha(): + # GH 54328 + with pytest.raises( + NotImplementedError, + match=( + "None of com, span, or alpha can be specified " + "if times is provided and adjust=False" + ), + ): + Series(range(1)).ewm( + 0.1, + adjust=False, + times=date_range("2000", freq="D", periods=1), + alpha=0.5, + halflife="1D", + ) + + +def test_ewma_times_adjust_false_with_disallowed_span(): + # GH 54328 + with pytest.raises( + NotImplementedError, + match=( + "None of com, span, or alpha can be specified " + "if times is provided and adjust=False" + ), + ): + Series(range(1)).ewm( + 0.1, + adjust=False, + times=date_range("2000", freq="D", periods=1), + span=10, + halflife="1D", + ) + + +def test_times_string_col_raises(): + # GH 43265 + df = DataFrame( + {"A": np.arange(10.0), "time_col": date_range("2000", freq="D", periods=10)} + ) + with pytest.raises(ValueError, match="times must be datetime64"): + df.ewm(halflife="1 day", min_periods=0, times="time_col") + + +def test_ewm_sum_adjust_false_notimplemented(): + data = Series(range(1)).ewm(com=1, adjust=False) + with pytest.raises(NotImplementedError, match="sum is not"): + data.sum() + + +@pytest.mark.parametrize("method", ["sum", "std", "var", "cov", "corr"]) +def test_times_only_mean_implemented(frame_or_series, method): + # GH 51695 + halflife = "1 day" + times = date_range("2000", freq="D", periods=10) + ewm = frame_or_series(range(10)).ewm(halflife=halflife, times=times) + with pytest.raises( + NotImplementedError, match=f"{method} is not implemented with times" + ): + getattr(ewm, method)() + + +@pytest.mark.parametrize( + "expected_data, ignore", + [[[10.0, 5.0, 2.5, 11.25], False], [[10.0, 5.0, 5.0, 12.5], True]], +) +def test_ewm_sum(expected_data, ignore): + # xref from Numbagg tests + # https://github.com/numbagg/numbagg/blob/v0.2.1/numbagg/test/test_moving.py#L50 + data = Series([10, 0, np.nan, 10]) + result = data.ewm(alpha=0.5, ignore_na=ignore).sum() + expected = Series(expected_data) + tm.assert_series_equal(result, expected) + + +def test_ewma_adjust(): + vals = Series(np.zeros(1000)) + vals[5] = 1 + result = vals.ewm(span=100, adjust=False).mean().sum() + assert np.abs(result - 1) < 1e-2 + + +def test_ewma_cases(adjust, ignore_na): + # try adjust/ignore_na args matrix + + s = Series([1.0, 2.0, 4.0, 8.0]) + + if adjust: + expected = Series([1.0, 1.6, 2.736842, 4.923077]) + else: + expected = Series([1.0, 1.333333, 2.222222, 4.148148]) + + result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean() + tm.assert_series_equal(result, expected) + + +def test_ewma_nan_handling(): + s = Series([1.0] + [np.nan] * 5 + [1.0]) + result = s.ewm(com=5).mean() + tm.assert_series_equal(result, Series([1.0] * len(s))) + + s = Series([np.nan] * 2 + [1.0] + [np.nan] * 2 + [1.0]) + result = s.ewm(com=5).mean() + tm.assert_series_equal(result, Series([np.nan] * 2 + [1.0] * 4)) + + +@pytest.mark.parametrize( + "s, adjust, ignore_na, w", + [ + ( + [np.nan, 1.0, 101.0], + True, + False, + [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0], + ), + ( + [np.nan, 1.0, 101.0], + True, + True, + [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0], + ), + ( + [np.nan, 1.0, 101.0], + False, + False, + [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))], + ), + ( + [np.nan, 1.0, 101.0], + False, + True, + [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))], + ), + ( + [1.0, np.nan, 101.0], + True, + False, + [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, 1.0], + ), + ( + [1.0, np.nan, 101.0], + True, + True, + [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, 1.0], + ), + ( + [1.0, np.nan, 101.0], + False, + False, + [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, (1.0 / (1.0 + 2.0))], + ), + ( + [1.0, np.nan, 101.0], + False, + True, + [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, (1.0 / (1.0 + 2.0))], + ), + ( + [np.nan, 1.0, np.nan, np.nan, 101.0, np.nan], + True, + False, + [np.nan, (1.0 - (1.0 / (1.0 + 2.0))) ** 3, np.nan, np.nan, 1.0, np.nan], + ), + ( + [np.nan, 1.0, np.nan, np.nan, 101.0, np.nan], + True, + True, + [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), np.nan, np.nan, 1.0, np.nan], + ), + ( + [np.nan, 1.0, np.nan, np.nan, 101.0, np.nan], + False, + False, + [ + np.nan, + (1.0 - (1.0 / (1.0 + 2.0))) ** 3, + np.nan, + np.nan, + (1.0 / (1.0 + 2.0)), + np.nan, + ], + ), + ( + [np.nan, 1.0, np.nan, np.nan, 101.0, np.nan], + False, + True, + [ + np.nan, + (1.0 - (1.0 / (1.0 + 2.0))), + np.nan, + np.nan, + (1.0 / (1.0 + 2.0)), + np.nan, + ], + ), + ( + [1.0, np.nan, 101.0, 50.0], + True, + False, + [ + (1.0 - (1.0 / (1.0 + 2.0))) ** 3, + np.nan, + (1.0 - (1.0 / (1.0 + 2.0))), + 1.0, + ], + ), + ( + [1.0, np.nan, 101.0, 50.0], + True, + True, + [ + (1.0 - (1.0 / (1.0 + 2.0))) ** 2, + np.nan, + (1.0 - (1.0 / (1.0 + 2.0))), + 1.0, + ], + ), + ( + [1.0, np.nan, 101.0, 50.0], + False, + False, + [ + (1.0 - (1.0 / (1.0 + 2.0))) ** 3, + np.nan, + (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)), + (1.0 / (1.0 + 2.0)) + * ((1.0 - (1.0 / (1.0 + 2.0))) ** 2 + (1.0 / (1.0 + 2.0))), + ], + ), + ( + [1.0, np.nan, 101.0, 50.0], + False, + True, + [ + (1.0 - (1.0 / (1.0 + 2.0))) ** 2, + np.nan, + (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)), + (1.0 / (1.0 + 2.0)), + ], + ), + ], +) +def test_ewma_nan_handling_cases(s, adjust, ignore_na, w): + # GH 7603 + s = Series(s) + expected = (s.multiply(w).cumsum() / Series(w).cumsum()).ffill() + result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean() + + tm.assert_series_equal(result, expected) + if ignore_na is False: + # check that ignore_na defaults to False + result = s.ewm(com=2.0, adjust=adjust).mean() + tm.assert_series_equal(result, expected) + + +def test_ewm_alpha(): + # GH 10789 + arr = np.random.default_rng(2).standard_normal(100) + locs = np.arange(20, 40) + arr[locs] = np.nan + + s = Series(arr) + a = s.ewm(alpha=0.61722699889169674).mean() + b = s.ewm(com=0.62014947789973052).mean() + c = s.ewm(span=2.240298955799461).mean() + d = s.ewm(halflife=0.721792864318).mean() + tm.assert_series_equal(a, b) + tm.assert_series_equal(a, c) + tm.assert_series_equal(a, d) + + +def test_ewm_domain_checks(): + # GH 12492 + arr = np.random.default_rng(2).standard_normal(100) + locs = np.arange(20, 40) + arr[locs] = np.nan + + s = Series(arr) + msg = "comass must satisfy: comass >= 0" + with pytest.raises(ValueError, match=msg): + s.ewm(com=-0.1) + s.ewm(com=0.0) + s.ewm(com=0.1) + + msg = "span must satisfy: span >= 1" + with pytest.raises(ValueError, match=msg): + s.ewm(span=-0.1) + with pytest.raises(ValueError, match=msg): + s.ewm(span=0.0) + with pytest.raises(ValueError, match=msg): + s.ewm(span=0.9) + s.ewm(span=1.0) + s.ewm(span=1.1) + + msg = "halflife must satisfy: halflife > 0" + with pytest.raises(ValueError, match=msg): + s.ewm(halflife=-0.1) + with pytest.raises(ValueError, match=msg): + s.ewm(halflife=0.0) + s.ewm(halflife=0.1) + + msg = "alpha must satisfy: 0 < alpha <= 1" + with pytest.raises(ValueError, match=msg): + s.ewm(alpha=-0.1) + with pytest.raises(ValueError, match=msg): + s.ewm(alpha=0.0) + s.ewm(alpha=0.1) + s.ewm(alpha=1.0) + with pytest.raises(ValueError, match=msg): + s.ewm(alpha=1.1) + + +@pytest.mark.parametrize("method", ["mean", "std", "var"]) +def test_ew_empty_series(method): + vals = Series([], dtype=np.float64) + + ewm = vals.ewm(3) + result = getattr(ewm, method)() + tm.assert_almost_equal(result, vals) + + +@pytest.mark.parametrize("min_periods", [0, 1]) +@pytest.mark.parametrize("name", ["mean", "var", "std"]) +def test_ew_min_periods(min_periods, name): + # excluding NaNs correctly + arr = np.random.default_rng(2).standard_normal(50) + arr[:10] = np.nan + arr[-10:] = np.nan + s = Series(arr) + + # check min_periods + # GH 7898 + result = getattr(s.ewm(com=50, min_periods=2), name)() + assert result[:11].isna().all() + assert not result[11:].isna().any() + + result = getattr(s.ewm(com=50, min_periods=min_periods), name)() + if name == "mean": + assert result[:10].isna().all() + assert not result[10:].isna().any() + else: + # ewm.std, ewm.var (with bias=False) require at least + # two values + assert result[:11].isna().all() + assert not result[11:].isna().any() + + # check series of length 0 + result = getattr(Series(dtype=object).ewm(com=50, min_periods=min_periods), name)() + tm.assert_series_equal(result, Series(dtype="float64")) + + # check series of length 1 + result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)() + if name == "mean": + tm.assert_series_equal(result, Series([1.0])) + else: + # ewm.std, ewm.var with bias=False require at least + # two values + tm.assert_series_equal(result, Series([np.nan])) + + # pass in ints + result2 = getattr(Series(np.arange(50)).ewm(span=10), name)() + assert result2.dtype == np.float64 + + +@pytest.mark.parametrize("name", ["cov", "corr"]) +def test_ewm_corr_cov(name): + A = Series(np.random.default_rng(2).standard_normal(50), index=range(50)) + B = A[2:] + np.random.default_rng(2).standard_normal(48) + + A[:10] = np.nan + B.iloc[-10:] = np.nan + + result = getattr(A.ewm(com=20, min_periods=5), name)(B) + assert np.isnan(result.values[:14]).all() + assert not np.isnan(result.values[14:]).any() + + +@pytest.mark.parametrize("min_periods", [0, 1, 2]) +@pytest.mark.parametrize("name", ["cov", "corr"]) +def test_ewm_corr_cov_min_periods(name, min_periods): + # GH 7898 + A = Series(np.random.default_rng(2).standard_normal(50), index=range(50)) + B = A[2:] + np.random.default_rng(2).standard_normal(48) + + A[:10] = np.nan + B.iloc[-10:] = np.nan + + result = getattr(A.ewm(com=20, min_periods=min_periods), name)(B) + # binary functions (ewmcov, ewmcorr) with bias=False require at + # least two values + assert np.isnan(result.values[:11]).all() + assert not np.isnan(result.values[11:]).any() + + # check series of length 0 + empty = Series([], dtype=np.float64) + result = getattr(empty.ewm(com=50, min_periods=min_periods), name)(empty) + tm.assert_series_equal(result, empty) + + # check series of length 1 + result = getattr(Series([1.0]).ewm(com=50, min_periods=min_periods), name)( + Series([1.0]) + ) + tm.assert_series_equal(result, Series([np.nan])) + + +@pytest.mark.parametrize("name", ["cov", "corr"]) +def test_different_input_array_raise_exception(name): + A = Series(np.random.default_rng(2).standard_normal(50), index=range(50)) + A[:10] = np.nan + + msg = "other must be a DataFrame or Series" + # exception raised is Exception + with pytest.raises(ValueError, match=msg): + getattr(A.ewm(com=20, min_periods=5), name)( + np.random.default_rng(2).standard_normal(50) + ) + + +@pytest.mark.parametrize("name", ["var", "std", "mean"]) +def test_ewma_series(series, name): + series_result = getattr(series.ewm(com=10), name)() + assert isinstance(series_result, Series) + + +@pytest.mark.parametrize("name", ["var", "std", "mean"]) +def test_ewma_frame(frame, name): + frame_result = getattr(frame.ewm(com=10), name)() + assert isinstance(frame_result, DataFrame) + + +def test_ewma_span_com_args(series): + A = series.ewm(com=9.5).mean() + B = series.ewm(span=20).mean() + tm.assert_almost_equal(A, B) + msg = "comass, span, halflife, and alpha are mutually exclusive" + with pytest.raises(ValueError, match=msg): + series.ewm(com=9.5, span=20) + + msg = "Must pass one of comass, span, halflife, or alpha" + with pytest.raises(ValueError, match=msg): + series.ewm().mean() + + +def test_ewma_halflife_arg(series): + A = series.ewm(com=13.932726172912965).mean() + B = series.ewm(halflife=10.0).mean() + tm.assert_almost_equal(A, B) + msg = "comass, span, halflife, and alpha are mutually exclusive" + with pytest.raises(ValueError, match=msg): + series.ewm(span=20, halflife=50) + with pytest.raises(ValueError, match=msg): + series.ewm(com=9.5, halflife=50) + with pytest.raises(ValueError, match=msg): + series.ewm(com=9.5, span=20, halflife=50) + msg = "Must pass one of comass, span, halflife, or alpha" + with pytest.raises(ValueError, match=msg): + series.ewm() + + +def test_ewm_alpha_arg(series): + # GH 10789 + s = series + msg = "Must pass one of comass, span, halflife, or alpha" + with pytest.raises(ValueError, match=msg): + s.ewm() + + msg = "comass, span, halflife, and alpha are mutually exclusive" + with pytest.raises(ValueError, match=msg): + s.ewm(com=10.0, alpha=0.5) + with pytest.raises(ValueError, match=msg): + s.ewm(span=10.0, alpha=0.5) + with pytest.raises(ValueError, match=msg): + s.ewm(halflife=10.0, alpha=0.5) + + +@pytest.mark.parametrize("func", ["cov", "corr"]) +def test_ewm_pairwise_cov_corr(func, frame): + result = getattr(frame.ewm(span=10, min_periods=5), func)() + result = result.loc[(slice(None), 1), 5] + result.index = result.index.droplevel(1) + expected = getattr(frame[1].ewm(span=10, min_periods=5), func)(frame[5]) + tm.assert_series_equal(result, expected, check_names=False) + + +def test_numeric_only_frame(arithmetic_win_operators, numeric_only): + # GH#46560 + kernel = arithmetic_win_operators + df = DataFrame({"a": [1], "b": 2, "c": 3}) + df["c"] = df["c"].astype(object) + ewm = df.ewm(span=2, min_periods=1) + op = getattr(ewm, kernel, None) + if op is not None: + result = op(numeric_only=numeric_only) + + columns = ["a", "b"] if numeric_only else ["a", "b", "c"] + expected = df[columns].agg([kernel]).reset_index(drop=True).astype(float) + assert list(expected.columns) == columns + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("kernel", ["corr", "cov"]) +@pytest.mark.parametrize("use_arg", [True, False]) +def test_numeric_only_corr_cov_frame(kernel, numeric_only, use_arg): + # GH#46560 + df = DataFrame({"a": [1, 2, 3], "b": 2, "c": 3}) + df["c"] = df["c"].astype(object) + arg = (df,) if use_arg else () + ewm = df.ewm(span=2, min_periods=1) + op = getattr(ewm, kernel) + result = op(*arg, numeric_only=numeric_only) + + # Compare result to op using float dtypes, dropping c when numeric_only is True + columns = ["a", "b"] if numeric_only else ["a", "b", "c"] + df2 = df[columns].astype(float) + arg2 = (df2,) if use_arg else () + ewm2 = df2.ewm(span=2, min_periods=1) + op2 = getattr(ewm2, kernel) + expected = op2(*arg2, numeric_only=numeric_only) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [int, object]) +def test_numeric_only_series(arithmetic_win_operators, numeric_only, dtype): + # GH#46560 + kernel = arithmetic_win_operators + ser = Series([1], dtype=dtype) + ewm = ser.ewm(span=2, min_periods=1) + op = getattr(ewm, kernel, None) + if op is None: + # Nothing to test + pytest.skip("No op to test") + if numeric_only and dtype is object: + msg = f"ExponentialMovingWindow.{kernel} does not implement numeric_only" + with pytest.raises(NotImplementedError, match=msg): + op(numeric_only=numeric_only) + else: + result = op(numeric_only=numeric_only) + expected = ser.agg([kernel]).reset_index(drop=True).astype(float) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("kernel", ["corr", "cov"]) +@pytest.mark.parametrize("use_arg", [True, False]) +@pytest.mark.parametrize("dtype", [int, object]) +def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype): + # GH#46560 + ser = Series([1, 2, 3], dtype=dtype) + arg = (ser,) if use_arg else () + ewm = ser.ewm(span=2, min_periods=1) + op = getattr(ewm, kernel) + if numeric_only and dtype is object: + msg = f"ExponentialMovingWindow.{kernel} does not implement numeric_only" + with pytest.raises(NotImplementedError, match=msg): + op(*arg, numeric_only=numeric_only) + else: + result = op(*arg, numeric_only=numeric_only) + + ser2 = ser.astype(float) + arg2 = (ser2,) if use_arg else () + ewm2 = ser2.ewm(span=2, min_periods=1) + op2 = getattr(ewm2, kernel) + expected = op2(*arg2, numeric_only=numeric_only) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/window/test_expanding.py b/pandas/tests/window/test_expanding.py new file mode 100644 index 0000000000000000000000000000000000000000..d0bd68214bcba7f8365bfbfd9d77d9a2e6235400 --- /dev/null +++ b/pandas/tests/window/test_expanding.py @@ -0,0 +1,830 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + MultiIndex, + Series, + isna, + notna, +) +import pandas._testing as tm + + +def test_doc_string(): + df = DataFrame({"B": [0, 1, 2, np.nan, 4]}) + df + df.expanding(2).sum() + + +def test_constructor(frame_or_series): + # GH 12669 + + c = frame_or_series(range(5)).expanding + + # valid + c(min_periods=1) + + +@pytest.mark.parametrize("w", [2.0, "foo", np.array([2])]) +def test_constructor_invalid(frame_or_series, w): + # not valid + + c = frame_or_series(range(5)).expanding + msg = "min_periods must be an integer" + with pytest.raises(ValueError, match=msg): + c(min_periods=w) + + +@pytest.mark.parametrize( + "expander", + [ + 1, + pytest.param( + "ls", + marks=pytest.mark.xfail( + reason="GH#16425 expanding with offset not supported" + ), + ), + ], +) +def test_empty_df_expanding(expander): + # GH 15819 Verifies that datetime and integer expanding windows can be + # applied to empty DataFrames + + expected = DataFrame() + result = DataFrame().expanding(expander).sum() + tm.assert_frame_equal(result, expected) + + # Verifies that datetime and integer expanding windows can be applied + # to empty DataFrames with datetime index + expected = DataFrame(index=DatetimeIndex([])) + result = DataFrame(index=DatetimeIndex([])).expanding(expander).sum() + tm.assert_frame_equal(result, expected) + + +def test_missing_minp_zero(): + # https://github.com/pandas-dev/pandas/pull/18921 + # minp=0 + x = Series([np.nan]) + result = x.expanding(min_periods=0).sum() + expected = Series([0.0]) + tm.assert_series_equal(result, expected) + + # minp=1 + result = x.expanding(min_periods=1).sum() + expected = Series([np.nan]) + tm.assert_series_equal(result, expected) + + +def test_expanding(): + # see gh-23372. + df = DataFrame(np.ones((10, 20))) + + expected = DataFrame( + {i: [np.nan] * 2 + [float(j) for j in range(3, 11)] for i in range(20)} + ) + result = df.expanding(3).sum() + tm.assert_frame_equal(result, expected) + + +def test_expanding_count_with_min_periods(frame_or_series): + # GH 26996 + result = frame_or_series(range(5)).expanding(min_periods=3).count() + expected = frame_or_series([np.nan, np.nan, 3.0, 4.0, 5.0]) + tm.assert_equal(result, expected) + + +def test_expanding_count_default_min_periods_with_null_values(frame_or_series): + # GH 26996 + values = [1, 2, 3, np.nan, 4, 5, 6] + expected_counts = [1.0, 2.0, 3.0, 3.0, 4.0, 5.0, 6.0] + + result = frame_or_series(values).expanding().count() + expected = frame_or_series(expected_counts) + tm.assert_equal(result, expected) + + +def test_expanding_count_with_min_periods_exceeding_series_length(frame_or_series): + # GH 25857 + result = frame_or_series(range(5)).expanding(min_periods=6).count() + expected = frame_or_series([np.nan, np.nan, np.nan, np.nan, np.nan]) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "df,expected,min_periods", + [ + ( + {"A": [1, 2, 3], "B": [4, 5, 6]}, + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [1, 2], "B": [4, 5]}, [0, 1]), + ({"A": [1, 2, 3], "B": [4, 5, 6]}, [0, 1, 2]), + ], + 3, + ), + ( + {"A": [1, 2, 3], "B": [4, 5, 6]}, + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [1, 2], "B": [4, 5]}, [0, 1]), + ({"A": [1, 2, 3], "B": [4, 5, 6]}, [0, 1, 2]), + ], + 2, + ), + ( + {"A": [1, 2, 3], "B": [4, 5, 6]}, + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [1, 2], "B": [4, 5]}, [0, 1]), + ({"A": [1, 2, 3], "B": [4, 5, 6]}, [0, 1, 2]), + ], + 1, + ), + ({"A": [1], "B": [4]}, [], 2), + (None, [({}, [])], 1), + ( + {"A": [1, np.nan, 3], "B": [np.nan, 5, 6]}, + [ + ({"A": [1.0], "B": [np.nan]}, [0]), + ({"A": [1, np.nan], "B": [np.nan, 5]}, [0, 1]), + ({"A": [1, np.nan, 3], "B": [np.nan, 5, 6]}, [0, 1, 2]), + ], + 3, + ), + ( + {"A": [1, np.nan, 3], "B": [np.nan, 5, 6]}, + [ + ({"A": [1.0], "B": [np.nan]}, [0]), + ({"A": [1, np.nan], "B": [np.nan, 5]}, [0, 1]), + ({"A": [1, np.nan, 3], "B": [np.nan, 5, 6]}, [0, 1, 2]), + ], + 2, + ), + ( + {"A": [1, np.nan, 3], "B": [np.nan, 5, 6]}, + [ + ({"A": [1.0], "B": [np.nan]}, [0]), + ({"A": [1, np.nan], "B": [np.nan, 5]}, [0, 1]), + ({"A": [1, np.nan, 3], "B": [np.nan, 5, 6]}, [0, 1, 2]), + ], + 1, + ), + ], +) +def test_iter_expanding_dataframe(df, expected, min_periods): + # GH 11704 + df = DataFrame(df) + expecteds = [DataFrame(values, index=index) for (values, index) in expected] + + for expected, actual in zip(expecteds, df.expanding(min_periods), strict=False): + tm.assert_frame_equal(actual, expected) + + +@pytest.mark.parametrize( + "ser,expected,min_periods", + [ + (Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([1, 2, 3], [0, 1, 2])], 3), + (Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([1, 2, 3], [0, 1, 2])], 2), + (Series([1, 2, 3]), [([1], [0]), ([1, 2], [0, 1]), ([1, 2, 3], [0, 1, 2])], 1), + (Series([1, 2]), [([1], [0]), ([1, 2], [0, 1])], 2), + (Series([np.nan, 2]), [([np.nan], [0]), ([np.nan, 2], [0, 1])], 2), + (Series([], dtype="int64"), [], 2), + ], +) +def test_iter_expanding_series(ser, expected, min_periods): + # GH 11704 + expecteds = [Series(values, index=index) for (values, index) in expected] + + for expected, actual in zip(expecteds, ser.expanding(min_periods), strict=True): + tm.assert_series_equal(actual, expected) + + +def test_center_invalid(): + # GH 20647 + df = DataFrame() + with pytest.raises(TypeError, match=".* got an unexpected keyword"): + df.expanding(center=True) + + +def test_expanding_sem(frame_or_series): + # GH: 26476 + obj = frame_or_series([0, 1, 2]) + result = obj.expanding().sem() + if isinstance(result, DataFrame): + result = Series(result[0].values) + expected = Series([np.nan, 0.5, (1 / 3) ** 0.5]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["skew", "kurt"]) +def test_expanding_skew_kurt_numerical_stability(method): + # GH: 6929 + s = Series(np.random.default_rng(2).random(10)) + expected = getattr(s.expanding(3), method)() + s = s + 5000 + result = getattr(s.expanding(3), method)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("window", [1, 3, 10, 20]) +@pytest.mark.parametrize("method", ["min", "max", "average"]) +@pytest.mark.parametrize("pct", [True, False]) +@pytest.mark.parametrize("test_data", ["default", "duplicates", "nans"]) +def test_rank(window, method, pct, ascending, test_data): + length = 20 + if test_data == "default": + ser = Series(data=np.random.default_rng(2).random(length)) + elif test_data == "duplicates": + ser = Series(data=np.random.default_rng(2).choice(3, length)) + elif test_data == "nans": + ser = Series( + data=np.random.default_rng(2).choice( + [1.0, 0.25, 0.75, np.nan, np.inf, -np.inf], length + ) + ) + + expected = ser.expanding(window).apply( + lambda x: x.rank(method=method, pct=pct, ascending=ascending).iloc[-1] + ) + result = ser.expanding(window).rank(method=method, pct=pct, ascending=ascending) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("window", [1, 3, 10, 20]) +@pytest.mark.parametrize("test_data", ["default", "duplicates", "nans", "precision"]) +def test_nunique(window, test_data): + length = 20 + if test_data == "default": + ser = Series(data=np.random.default_rng(2).random(length)) + elif test_data == "duplicates": + ser = Series(data=np.random.default_rng(2).choice(3, length)) + elif test_data == "nans": + ser = Series( + data=np.random.default_rng(2).choice( + [1.0, 0.25, 0.75, np.nan, np.inf, -np.inf], length + ) + ) + elif test_data == "precision": + ser = Series( + data=[ + 0.3, + 0.1 * 3, # Not necessarily exactly 0.3 + 0.6, + 0.2 * 3, # Not necessarily exactly 0.6 + 0.9, + 0.3 * 3, # Not necessarily exactly 0.9 + 0.5, + 0.1 * 5, # Not necessarily exactly 0.5 + 0.8, + 0.2 * 4, # Not necessarily exactly 0.8 + ], + dtype=np.float64, + ) + + expected = ser.expanding(window).apply(lambda x: x.nunique()) + result = ser.expanding(window).nunique() + + tm.assert_series_equal(result, expected) + + +def test_expanding_corr(series): + A = series.dropna() + B = (A + np.random.default_rng(2).standard_normal(len(A)))[:-5] + + result = A.expanding().corr(B) + + rolling_result = A.rolling(window=len(A), min_periods=1).corr(B) + + tm.assert_almost_equal(rolling_result, result) + + +def test_expanding_count(series): + result = series.expanding(min_periods=0).count() + tm.assert_almost_equal( + result, series.rolling(window=len(series), min_periods=0).count() + ) + + +def test_expanding_quantile(series): + result = series.expanding().quantile(0.5) + + rolling_result = series.rolling(window=len(series), min_periods=1).quantile(0.5) + + tm.assert_almost_equal(result, rolling_result) + + +def test_expanding_cov(series): + A = series + B = (A + np.random.default_rng(2).standard_normal(len(A)))[:-5] + + result = A.expanding().cov(B) + + rolling_result = A.rolling(window=len(A), min_periods=1).cov(B) + + tm.assert_almost_equal(rolling_result, result) + + +def test_expanding_cov_pairwise(frame): + result = frame.expanding().cov() + + rolling_result = frame.rolling(window=len(frame), min_periods=1).cov() + + tm.assert_frame_equal(result, rolling_result) + + +def test_expanding_corr_pairwise(frame): + result = frame.expanding().corr() + + rolling_result = frame.rolling(window=len(frame), min_periods=1).corr() + tm.assert_frame_equal(result, rolling_result) + + +@pytest.mark.parametrize( + "func,static_comp", + [ + ("sum", lambda x: np.sum(x, axis=0)), + ("mean", lambda x: np.mean(x, axis=0)), + ("max", lambda x: np.max(x, axis=0)), + ("min", lambda x: np.min(x, axis=0)), + ], + ids=["sum", "mean", "max", "min"], +) +def test_expanding_func(func, static_comp, frame_or_series): + data = frame_or_series(np.array(list(range(10)) + [np.nan] * 10)) + + obj = data.expanding(min_periods=1) + result = getattr(obj, func)() + assert isinstance(result, frame_or_series) + + expected = static_comp(data[:11]) + if frame_or_series is Series: + tm.assert_almost_equal(result[10], expected) + else: + tm.assert_series_equal(result.iloc[10], expected, check_names=False) + + +@pytest.mark.parametrize( + "func,static_comp", + [("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)], + ids=["sum", "mean", "max", "min"], +) +def test_expanding_min_periods(func, static_comp): + ser = Series(np.random.default_rng(2).standard_normal(50)) + + result = getattr(ser.expanding(min_periods=30), func)() + assert result[:29].isna().all() + tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50])) + + # min_periods is working correctly + result = getattr(ser.expanding(min_periods=15), func)() + assert isna(result.iloc[13]) + assert notna(result.iloc[14]) + + ser2 = Series(np.random.default_rng(2).standard_normal(20)) + result = getattr(ser2.expanding(min_periods=5), func)() + assert isna(result[3]) + assert notna(result[4]) + + # min_periods=0 + result0 = getattr(ser.expanding(min_periods=0), func)() + result1 = getattr(ser.expanding(min_periods=1), func)() + tm.assert_almost_equal(result0, result1) + + result = getattr(ser.expanding(min_periods=1), func)() + tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50])) + + +def test_expanding_apply(engine_and_raw, frame_or_series): + engine, raw = engine_and_raw + data = frame_or_series(np.array(list(range(10)) + [np.nan] * 10)) + result = data.expanding(min_periods=1).apply( + lambda x: x.mean(), raw=raw, engine=engine + ) + assert isinstance(result, frame_or_series) + + if frame_or_series is Series: + tm.assert_almost_equal(result[9], np.mean(data[:11], axis=0)) + else: + tm.assert_series_equal( + result.iloc[9], np.mean(data[:11], axis=0), check_names=False + ) + + +def test_expanding_min_periods_apply(engine_and_raw): + engine, raw = engine_and_raw + ser = Series(np.random.default_rng(2).standard_normal(50)) + + result = ser.expanding(min_periods=30).apply( + lambda x: x.mean(), raw=raw, engine=engine + ) + assert result[:29].isna().all() + tm.assert_almost_equal(result.iloc[-1], np.mean(ser[:50])) + + # min_periods is working correctly + result = ser.expanding(min_periods=15).apply( + lambda x: x.mean(), raw=raw, engine=engine + ) + assert isna(result.iloc[13]) + assert notna(result.iloc[14]) + + ser2 = Series(np.random.default_rng(2).standard_normal(20)) + result = ser2.expanding(min_periods=5).apply( + lambda x: x.mean(), raw=raw, engine=engine + ) + assert isna(result[3]) + assert notna(result[4]) + + # min_periods=0 + result0 = ser.expanding(min_periods=0).apply( + lambda x: x.mean(), raw=raw, engine=engine + ) + result1 = ser.expanding(min_periods=1).apply( + lambda x: x.mean(), raw=raw, engine=engine + ) + tm.assert_almost_equal(result0, result1) + + result = ser.expanding(min_periods=1).apply( + lambda x: x.mean(), raw=raw, engine=engine + ) + tm.assert_almost_equal(result.iloc[-1], np.mean(ser[:50])) + + +@pytest.mark.parametrize( + "f", + [ + lambda x: (x.expanding(min_periods=5).cov(x, pairwise=True)), + lambda x: (x.expanding(min_periods=5).corr(x, pairwise=True)), + ], +) +def test_moment_functions_zero_length_pairwise(f): + df1 = DataFrame() + df2 = DataFrame(columns=Index(["a"], name="foo"), index=Index([], name="bar")) + df2["a"] = df2["a"].astype("float64") + + df1_expected = DataFrame(index=MultiIndex.from_product([df1.index, df1.columns])) + df2_expected = DataFrame( + index=MultiIndex.from_product([df2.index, df2.columns], names=["bar", "foo"]), + columns=Index(["a"], name="foo"), + dtype="float64", + ) + + df1_result = f(df1) + tm.assert_frame_equal(df1_result, df1_expected) + + df2_result = f(df2) + tm.assert_frame_equal(df2_result, df2_expected) + + +@pytest.mark.parametrize( + "f", + [ + lambda x: x.expanding().count(), + lambda x: x.expanding(min_periods=5).cov(x, pairwise=False), + lambda x: x.expanding(min_periods=5).corr(x, pairwise=False), + lambda x: x.expanding(min_periods=5).max(), + lambda x: x.expanding(min_periods=5).min(), + lambda x: x.expanding(min_periods=5).first(), + lambda x: x.expanding(min_periods=5).last(), + lambda x: x.expanding(min_periods=5).sum(), + lambda x: x.expanding(min_periods=5).mean(), + lambda x: x.expanding(min_periods=5).std(), + lambda x: x.expanding(min_periods=5).var(), + lambda x: x.expanding(min_periods=5).skew(), + lambda x: x.expanding(min_periods=5).kurt(), + lambda x: x.expanding(min_periods=5).quantile(0.5), + lambda x: x.expanding(min_periods=5).median(), + lambda x: x.expanding(min_periods=5).apply(sum, raw=False), + lambda x: x.expanding(min_periods=5).apply(sum, raw=True), + ], +) +def test_moment_functions_zero_length(f): + # GH 8056 + s = Series(dtype=np.float64) + s_expected = s + df1 = DataFrame() + df1_expected = df1 + df2 = DataFrame(columns=["a"]) + df2["a"] = df2["a"].astype("float64") + df2_expected = df2 + + s_result = f(s) + tm.assert_series_equal(s_result, s_expected) + + df1_result = f(df1) + tm.assert_frame_equal(df1_result, df1_expected) + + df2_result = f(df2) + tm.assert_frame_equal(df2_result, df2_expected) + + +def test_expanding_apply_empty_series(engine_and_raw): + engine, raw = engine_and_raw + ser = Series([], dtype=np.float64) + tm.assert_series_equal( + ser, ser.expanding().apply(lambda x: x.mean(), raw=raw, engine=engine) + ) + + +def test_expanding_apply_min_periods_0(engine_and_raw): + # GH 8080 + engine, raw = engine_and_raw + s = Series([None, None, None]) + result = s.expanding(min_periods=0).apply(lambda x: len(x), raw=raw, engine=engine) + expected = Series([1.0, 2.0, 3.0]) + tm.assert_series_equal(result, expected) + + +def test_expanding_cov_diff_index(): + # GH 7512 + s1 = Series([1, 2, 3], index=range(3)) + s2 = Series([1, 3], index=range(0, 4, 2)) + result = s1.expanding().cov(s2) + expected = Series([None, None, 2.0]) + tm.assert_series_equal(result, expected) + + s2a = Series([1, None, 3], index=[0, 1, 2]) + result = s1.expanding().cov(s2a) + tm.assert_series_equal(result, expected) + + s1 = Series([7, 8, 10], index=[0, 1, 3]) + s2 = Series([7, 9, 10], index=[0, 2, 3]) + result = s1.expanding().cov(s2) + expected = Series([None, None, None, 4.5], index=list(range(4))) + tm.assert_series_equal(result, expected) + + +def test_expanding_corr_diff_index(): + # GH 7512 + s1 = Series([1, 2, 3], index=range(3)) + s2 = Series([1, 3], index=range(0, 4, 2)) + result = s1.expanding().corr(s2) + expected = Series([None, None, 1.0]) + tm.assert_series_equal(result, expected) + + s2a = Series([1, None, 3], index=[0, 1, 2]) + result = s1.expanding().corr(s2a) + tm.assert_series_equal(result, expected) + + s1 = Series([7, 8, 10], index=[0, 1, 3]) + s2 = Series([7, 9, 10], index=[0, 2, 3]) + result = s1.expanding().corr(s2) + expected = Series([None, None, None, 1.0], index=list(range(4))) + tm.assert_series_equal(result, expected) + + +def test_expanding_cov_pairwise_diff_length(): + # GH 7512 + df1 = DataFrame([[1, 5], [3, 2], [3, 9]], columns=Index(["A", "B"], name="foo")) + df1a = DataFrame( + [[1, 5], [3, 9]], index=[0, 2], columns=Index(["A", "B"], name="foo") + ) + df2 = DataFrame( + [[5, 6], [None, None], [2, 1]], columns=Index(["X", "Y"], name="foo") + ) + df2a = DataFrame( + [[5, 6], [2, 1]], index=[0, 2], columns=Index(["X", "Y"], name="foo") + ) + # xref gh-15826 + # .loc is not preserving the names + result1 = df1.expanding().cov(df2, pairwise=True).loc[2] + result2 = df1.expanding().cov(df2a, pairwise=True).loc[2] + result3 = df1a.expanding().cov(df2, pairwise=True).loc[2] + result4 = df1a.expanding().cov(df2a, pairwise=True).loc[2] + expected = DataFrame( + [[-3.0, -6.0], [-5.0, -10.0]], + columns=Index(["A", "B"], name="foo"), + index=Index(["X", "Y"], name="foo"), + ) + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + tm.assert_frame_equal(result3, expected) + tm.assert_frame_equal(result4, expected) + + +def test_expanding_corr_pairwise_diff_length(): + # GH 7512 + df1 = DataFrame( + [[1, 2], [3, 2], [3, 4]], columns=["A", "B"], index=Index(range(3), name="bar") + ) + df1a = DataFrame( + [[1, 2], [3, 4]], index=Index([0, 2], name="bar"), columns=["A", "B"] + ) + df2 = DataFrame( + [[5, 6], [None, None], [2, 1]], + columns=["X", "Y"], + index=Index(range(3), name="bar"), + ) + df2a = DataFrame( + [[5, 6], [2, 1]], index=Index([0, 2], name="bar"), columns=["X", "Y"] + ) + result1 = df1.expanding().corr(df2, pairwise=True).loc[2] + result2 = df1.expanding().corr(df2a, pairwise=True).loc[2] + result3 = df1a.expanding().corr(df2, pairwise=True).loc[2] + result4 = df1a.expanding().corr(df2a, pairwise=True).loc[2] + expected = DataFrame( + [[-1.0, -1.0], [-1.0, -1.0]], columns=["A", "B"], index=Index(["X", "Y"]) + ) + tm.assert_frame_equal(result1, expected) + tm.assert_frame_equal(result2, expected) + tm.assert_frame_equal(result3, expected) + tm.assert_frame_equal(result4, expected) + + +@pytest.mark.parametrize( + "values,method,expected", + [ + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "first", + [float("nan"), float("nan"), 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + ), + ( + [1.0, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0, np.nan, 9.0, np.nan], + "first", + [ + float("nan"), + float("nan"), + float("nan"), + float("nan"), + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + ), + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "last", + [float("nan"), float("nan"), 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + ), + ( + [1.0, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0, np.nan, 9.0, np.nan], + "last", + [ + float("nan"), + float("nan"), + float("nan"), + float("nan"), + 5.0, + 5.0, + 7.0, + 7.0, + 9.0, + 9.0, + ], + ), + ], +) +def test_expanding_first_last(values, method, expected): + # GH#33155 + x = Series(values) + result = getattr(x.expanding(3), method)() + expected = Series(expected) + tm.assert_almost_equal(result, expected) + + x = DataFrame({"A": values}) + result = getattr(x.expanding(3), method)() + expected = DataFrame({"A": expected}) + tm.assert_almost_equal(result, expected) + + +@pytest.mark.parametrize( + "values,method,expected", + [ + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "first", + [1.0] * 10, + ), + ( + [1.0, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0, np.nan, 9.0, np.nan], + "first", + [1.0] * 10, + ), + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "last", + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + ), + ( + [1.0, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0, np.nan, 9.0, np.nan], + "last", + [1.0, 1.0, 3.0, 3.0, 5.0, 5.0, 7.0, 7.0, 9.0, 9.0], + ), + ], +) +def test_expanding_first_last_no_minp(values, method, expected): + # GH#33155 + x = Series(values) + result = getattr(x.expanding(min_periods=0), method)() + expected = Series(expected) + tm.assert_almost_equal(result, expected) + + x = DataFrame({"A": values}) + result = getattr(x.expanding(min_periods=0), method)() + expected = DataFrame({"A": expected}) + tm.assert_almost_equal(result, expected) + + +def test_expanding_apply_args_kwargs(engine_and_raw): + def mean_w_arg(x, const): + return np.mean(x) + const + + engine, raw = engine_and_raw + + df = DataFrame(np.random.default_rng(2).random((20, 3))) + + expected = df.expanding().apply(np.mean, engine=engine, raw=raw) + 20.0 + + result = df.expanding().apply(mean_w_arg, engine=engine, raw=raw, args=(20,)) + tm.assert_frame_equal(result, expected) + + result = df.expanding().apply(mean_w_arg, raw=raw, kwargs={"const": 20}) + tm.assert_frame_equal(result, expected) + + +def test_numeric_only_frame(arithmetic_win_operators, numeric_only): + # GH#46560 + kernel = arithmetic_win_operators + df = DataFrame({"a": [1], "b": 2, "c": 3}) + df["c"] = df["c"].astype(object) + expanding = df.expanding() + op = getattr(expanding, kernel, None) + if op is not None: + result = op(numeric_only=numeric_only) + + columns = ["a", "b"] if numeric_only else ["a", "b", "c"] + expected = df[columns].agg([kernel]).reset_index(drop=True).astype(float) + assert list(expected.columns) == columns + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("kernel", ["corr", "cov"]) +@pytest.mark.parametrize("use_arg", [True, False]) +def test_numeric_only_corr_cov_frame(kernel, numeric_only, use_arg): + # GH#46560 + df = DataFrame({"a": [1, 2, 3], "b": 2, "c": 3}) + df["c"] = df["c"].astype(object) + arg = (df,) if use_arg else () + expanding = df.expanding() + op = getattr(expanding, kernel) + result = op(*arg, numeric_only=numeric_only) + + # Compare result to op using float dtypes, dropping c when numeric_only is True + columns = ["a", "b"] if numeric_only else ["a", "b", "c"] + df2 = df[columns].astype(float) + arg2 = (df2,) if use_arg else () + expanding2 = df2.expanding() + op2 = getattr(expanding2, kernel) + expected = op2(*arg2, numeric_only=numeric_only) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [int, object]) +def test_numeric_only_series(arithmetic_win_operators, numeric_only, dtype): + # GH#46560 + kernel = arithmetic_win_operators + ser = Series([1], dtype=dtype) + expanding = ser.expanding() + op = getattr(expanding, kernel) + if numeric_only and dtype is object: + msg = f"Expanding.{kernel} does not implement numeric_only" + with pytest.raises(NotImplementedError, match=msg): + op(numeric_only=numeric_only) + else: + result = op(numeric_only=numeric_only) + expected = ser.agg([kernel]).reset_index(drop=True).astype(float) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("kernel", ["corr", "cov"]) +@pytest.mark.parametrize("use_arg", [True, False]) +@pytest.mark.parametrize("dtype", [int, object]) +def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype): + # GH#46560 + ser = Series([1, 2, 3], dtype=dtype) + arg = (ser,) if use_arg else () + expanding = ser.expanding() + op = getattr(expanding, kernel) + if numeric_only and dtype is object: + msg = f"Expanding.{kernel} does not implement numeric_only" + with pytest.raises(NotImplementedError, match=msg): + op(*arg, numeric_only=numeric_only) + else: + result = op(*arg, numeric_only=numeric_only) + + ser2 = ser.astype(float) + arg2 = (ser2,) if use_arg else () + expanding2 = ser2.expanding() + op2 = getattr(expanding2, kernel) + expected = op2(*arg2, numeric_only=numeric_only) + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/window/test_groupby.py b/pandas/tests/window/test_groupby.py new file mode 100644 index 0000000000000000000000000000000000000000..543ae095b1cb432fff6e85d422f984a741ba06d2 --- /dev/null +++ b/pandas/tests/window/test_groupby.py @@ -0,0 +1,1389 @@ +import numpy as np +import pytest + +from pandas.errors import Pandas4Warning + +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + MultiIndex, + NamedAgg, + Series, + Timestamp, + date_range, + to_datetime, +) +import pandas._testing as tm +from pandas.api.indexers import BaseIndexer +from pandas.core.groupby.groupby import get_groupby + + +@pytest.fixture +def times_frame(): + """Frame for testing times argument in EWM groupby.""" + return DataFrame( + { + "A": ["a", "b", "c", "a", "b", "c", "a", "b", "c", "a"], + "B": [0, 0, 0, 1, 1, 1, 2, 2, 2, 3], + "C": to_datetime( + [ + "2020-01-01", + "2020-01-01", + "2020-01-01", + "2020-01-02", + "2020-01-10", + "2020-01-22", + "2020-01-03", + "2020-01-23", + "2020-01-23", + "2020-01-04", + ] + ), + } + ) + + +@pytest.fixture +def roll_frame(): + return DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)}) + + +class TestRolling: + def test_groupby_unsupported_argument(self, roll_frame): + msg = r"groupby\(\) got an unexpected keyword argument 'foo'" + with pytest.raises(TypeError, match=msg): + roll_frame.groupby("A", foo=1) + + def test_getitem(self, roll_frame): + g = roll_frame.groupby("A") + g_mutated = get_groupby(roll_frame, by="A") + + expected = g_mutated.B.apply(lambda x: x.rolling(2).mean()) + + result = g.rolling(2).mean().B + tm.assert_series_equal(result, expected) + + result = g.rolling(2).B.mean() + tm.assert_series_equal(result, expected) + + result = g.B.rolling(2).mean() + tm.assert_series_equal(result, expected) + + result = roll_frame.B.groupby(roll_frame.A).rolling(2).mean() + tm.assert_series_equal(result, expected) + + def test_getitem_multiple(self, roll_frame): + # GH 13174 + g = roll_frame.groupby("A") + r = g.rolling(2, min_periods=0) + g_mutated = get_groupby(roll_frame, by="A") + expected = g_mutated.B.apply(lambda x: x.rolling(2, min_periods=0).count()) + + result = r.B.count() + tm.assert_series_equal(result, expected) + + result = r.B.count() + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "f", + [ + "sum", + "mean", + "min", + "max", + "first", + "last", + "count", + "kurt", + "skew", + "nunique", + ], + ) + def test_rolling(self, f, roll_frame): + g = roll_frame.groupby("A", group_keys=False) + r = g.rolling(window=4) + + result = getattr(r, f)() + expected = g.apply(lambda x: getattr(x.rolling(4), f)()) + # GH 39732 + expected_index = MultiIndex.from_arrays([roll_frame["A"], range(40)]) + expected.index = expected_index + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("f", ["std", "var"]) + def test_rolling_ddof(self, f, roll_frame): + g = roll_frame.groupby("A", group_keys=False) + r = g.rolling(window=4) + + result = getattr(r, f)(ddof=1) + expected = g.apply(lambda x: getattr(x.rolling(4), f)(ddof=1)) + # GH 39732 + expected_index = MultiIndex.from_arrays([roll_frame["A"], range(40)]) + expected.index = expected_index + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "midpoint", "nearest"] + ) + def test_rolling_quantile(self, interpolation, roll_frame): + g = roll_frame.groupby("A", group_keys=False) + r = g.rolling(window=4) + + result = r.quantile(0.4, interpolation=interpolation) + expected = g.apply( + lambda x: x.rolling(4).quantile(0.4, interpolation=interpolation) + ) + # GH 39732 + expected_index = MultiIndex.from_arrays([roll_frame["A"], range(40)]) + expected.index = expected_index + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("f, expected_val", [["corr", 1], ["cov", 0.5]]) + def test_rolling_corr_cov_other_same_size_as_groups(self, f, expected_val): + # GH 42915 + df = DataFrame( + {"value": range(10), "idx1": [1] * 5 + [2] * 5, "idx2": [1, 2, 3, 4, 5] * 2} + ).set_index(["idx1", "idx2"]) + other = DataFrame({"value": range(5), "idx2": [1, 2, 3, 4, 5]}).set_index( + "idx2" + ) + result = getattr(df.groupby(level=0).rolling(2), f)(other) + expected_data = ([np.nan] + [expected_val] * 4) * 2 + expected = DataFrame( + expected_data, + columns=["value"], + index=MultiIndex.from_arrays( + [ + [1] * 5 + [2] * 5, + [1] * 5 + [2] * 5, + list(range(1, 6)) * 2, + ], + names=["idx1", "idx1", "idx2"], + ), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("f", ["corr", "cov"]) + def test_rolling_corr_cov_other_diff_size_as_groups(self, f, roll_frame): + g = roll_frame.groupby("A") + r = g.rolling(window=4) + + result = getattr(r, f)(roll_frame) + + def func(x): + return getattr(x.rolling(4), f)(roll_frame) + + expected = g.apply(func) + # GH 39591: The grouped column should be all np.nan + # (groupby.apply inserts 0s for cov) + expected["A"] = np.nan + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("f", ["corr", "cov"]) + def test_rolling_corr_cov_pairwise(self, f, roll_frame): + g = roll_frame.groupby("A") + r = g.rolling(window=4) + + result = getattr(r.B, f)(pairwise=True) + + def func(x): + return getattr(x.B.rolling(4), f)(pairwise=True) + + expected = g.apply(func) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "func, expected_values", + [("cov", [[1.0, 1.0], [1.0, 4.0]]), ("corr", [[1.0, 0.5], [0.5, 1.0]])], + ) + def test_rolling_corr_cov_unordered(self, func, expected_values): + # GH 43386 + df = DataFrame( + { + "a": ["g1", "g2", "g1", "g1"], + "b": [0, 0, 1, 2], + "c": [2, 0, 6, 4], + } + ) + rol = df.groupby("a").rolling(3) + result = getattr(rol, func)() + expected = DataFrame( + { + "b": 4 * [np.nan] + expected_values[0] + 2 * [np.nan], + "c": 4 * [np.nan] + expected_values[1] + 2 * [np.nan], + }, + index=MultiIndex.from_tuples( + [ + ("g1", 0, "b"), + ("g1", 0, "c"), + ("g1", 2, "b"), + ("g1", 2, "c"), + ("g1", 3, "b"), + ("g1", 3, "c"), + ("g2", 1, "b"), + ("g2", 1, "c"), + ], + names=["a", None, None], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_rolling_apply(self, raw, roll_frame): + g = roll_frame.groupby("A", group_keys=False) + r = g.rolling(window=4) + + # reduction + result = r.apply(lambda x: x.sum(), raw=raw) + expected = g.apply(lambda x: x.rolling(4).apply(lambda y: y.sum(), raw=raw)) + # GH 39732 + expected_index = MultiIndex.from_arrays([roll_frame["A"], range(40)]) + expected.index = expected_index + tm.assert_frame_equal(result, expected) + + def test_rolling_apply_mutability(self): + # GH 14013 + df = DataFrame({"A": ["foo"] * 3 + ["bar"] * 3, "B": [1] * 6}) + g = df.groupby("A") + + mi = MultiIndex.from_tuples( + [("bar", 3), ("bar", 4), ("bar", 5), ("foo", 0), ("foo", 1), ("foo", 2)] + ) + + mi.names = ["A", None] + # Grouped column should not be a part of the output + expected = DataFrame([np.nan, 2.0, 2.0] * 2, columns=["B"], index=mi) + + result = g.rolling(window=2).sum() + tm.assert_frame_equal(result, expected) + + # Call an arbitrary function on the groupby + g.sum() + + # Make sure nothing has been mutated + result = g.rolling(window=2).sum() + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("expected_value,raw_value", [[1.0, True], [0.0, False]]) + def test_groupby_rolling(self, expected_value, raw_value): + # GH 31754 + + def isnumpyarray(x): + return int(isinstance(x, np.ndarray)) + + df = DataFrame({"id": [1, 1, 1], "value": [1, 2, 3]}) + result = df.groupby("id").value.rolling(1).apply(isnumpyarray, raw=raw_value) + expected = Series( + [expected_value] * 3, + index=MultiIndex.from_tuples(((1, 0), (1, 1), (1, 2)), names=["id", None]), + name="value", + ) + tm.assert_series_equal(result, expected) + + def test_groupby_rolling_center_center(self): + # GH 35552 + series = Series(range(1, 6)) + result = series.groupby(series).rolling(center=True, window=3).mean() + expected = Series( + [np.nan] * 5, + index=MultiIndex.from_tuples(((1, 0), (2, 1), (3, 2), (4, 3), (5, 4))), + ) + tm.assert_series_equal(result, expected) + + series = Series(range(1, 5)) + result = series.groupby(series).rolling(center=True, window=3).mean() + expected = Series( + [np.nan] * 4, + index=MultiIndex.from_tuples(((1, 0), (2, 1), (3, 2), (4, 3))), + ) + tm.assert_series_equal(result, expected) + + df = DataFrame({"a": ["a"] * 5 + ["b"] * 6, "b": range(11)}) + result = df.groupby("a").rolling(center=True, window=3).mean() + expected = DataFrame( + [np.nan, 1, 2, 3, np.nan, np.nan, 6, 7, 8, 9, np.nan], + index=MultiIndex.from_tuples( + ( + ("a", 0), + ("a", 1), + ("a", 2), + ("a", 3), + ("a", 4), + ("b", 5), + ("b", 6), + ("b", 7), + ("b", 8), + ("b", 9), + ("b", 10), + ), + names=["a", None], + ), + columns=["b"], + ) + tm.assert_frame_equal(result, expected) + + df = DataFrame({"a": ["a"] * 5 + ["b"] * 5, "b": range(10)}) + result = df.groupby("a").rolling(center=True, window=3).mean() + expected = DataFrame( + [np.nan, 1, 2, 3, np.nan, np.nan, 6, 7, 8, np.nan], + index=MultiIndex.from_tuples( + ( + ("a", 0), + ("a", 1), + ("a", 2), + ("a", 3), + ("a", 4), + ("b", 5), + ("b", 6), + ("b", 7), + ("b", 8), + ("b", 9), + ), + names=["a", None], + ), + columns=["b"], + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_rolling_center_on(self): + # GH 37141 + df = DataFrame( + data={ + "Date": date_range("2020-01-01", "2020-01-10"), + "gb": ["group_1"] * 6 + ["group_2"] * 4, + "value": range(10), + } + ) + result = ( + df.groupby("gb") + .rolling(6, on="Date", center=True, min_periods=1) + .value.mean() + ) + mi = MultiIndex.from_arrays([df["gb"], df["Date"]], names=["gb", "Date"]) + expected = Series( + [1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 7.0, 7.5, 7.5, 7.5], + name="value", + index=mi, + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("min_periods", [5, 4, 3]) + def test_groupby_rolling_center_min_periods(self, min_periods): + # GH 36040 + df = DataFrame({"group": ["A"] * 10 + ["B"] * 10, "data": range(20)}) + + window_size = 5 + result = ( + df.groupby("group") + .rolling(window_size, center=True, min_periods=min_periods) + .mean() + ) + result = result.reset_index()[["group", "data"]] + + grp_A_mean = [1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 7.5, 8.0] + grp_B_mean = [x + 10.0 for x in grp_A_mean] + + num_nans = max(0, min_periods - 3) # For window_size of 5 + nans = [np.nan] * num_nans + grp_A_expected = nans + grp_A_mean[num_nans : 10 - num_nans] + nans + grp_B_expected = nans + grp_B_mean[num_nans : 10 - num_nans] + nans + + expected = DataFrame( + {"group": ["A"] * 10 + ["B"] * 10, "data": grp_A_expected + grp_B_expected} + ) + + tm.assert_frame_equal(result, expected) + + def test_groupby_subselect_rolling(self): + # GH 35486 + df = DataFrame( + {"a": [1, 2, 3, 2], "b": [4.0, 2.0, 3.0, 1.0], "c": [10, 20, 30, 20]} + ) + result = df.groupby("a")[["b"]].rolling(2).max() + expected = DataFrame( + [np.nan, np.nan, 2.0, np.nan], + columns=["b"], + index=MultiIndex.from_tuples( + ((1, 0), (2, 1), (2, 3), (3, 2)), names=["a", None] + ), + ) + tm.assert_frame_equal(result, expected) + + result = df.groupby("a")["b"].rolling(2).max() + expected = Series( + [np.nan, np.nan, 2.0, np.nan], + index=MultiIndex.from_tuples( + ((1, 0), (2, 1), (2, 3), (3, 2)), names=["a", None] + ), + name="b", + ) + tm.assert_series_equal(result, expected) + + def test_groupby_rolling_custom_indexer(self): + # GH 35557 + class SimpleIndexer(BaseIndexer): + def get_window_bounds( + self, + num_values=0, + min_periods=None, + center=None, + closed=None, + step=None, + ): + min_periods = self.window_size if min_periods is None else 0 + end = np.arange(num_values, dtype=np.int64) + 1 + start = end - self.window_size + start[start < 0] = min_periods + return start, end + + df = DataFrame( + {"a": [1.0, 2.0, 3.0, 4.0, 5.0] * 3}, index=[0] * 5 + [1] * 5 + [2] * 5 + ) + result = ( + df.groupby(df.index) + .rolling(SimpleIndexer(window_size=3), min_periods=1) + .sum() + ) + expected = df.groupby(df.index).rolling(window=3, min_periods=1).sum() + tm.assert_frame_equal(result, expected) + + def test_groupby_rolling_subset_with_closed(self): + # GH 35549 + df = DataFrame( + { + "column1": range(8), + "column2": range(8), + "group": ["A"] * 4 + ["B"] * 4, + "date": [ + Timestamp(date) + for date in ["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02"] + ] + * 2, + } + ) + result = ( + df.groupby("group").rolling("1D", on="date", closed="left")["column1"].sum() + ) + expected = Series( + [np.nan, np.nan, 1.0, 1.0, np.nan, np.nan, 9.0, 9.0], + index=MultiIndex.from_frame( + df[["group", "date"]], + names=["group", "date"], + ), + name="column1", + ) + tm.assert_series_equal(result, expected) + + def test_groupby_rolling_agg_namedagg(self): + # GH#28333 + df = DataFrame( + { + "kind": ["cat", "dog", "cat", "dog", "cat", "dog"], + "height": [9.1, 6.0, 9.5, 34.0, 12.0, 8.0], + "weight": [7.9, 7.5, 9.9, 198.0, 10.0, 42.0], + } + ) + result = ( + df.groupby("kind") + .rolling(2) + .agg( + total_weight=NamedAgg(column="weight", aggfunc=sum), + min_height=NamedAgg(column="height", aggfunc=min), + ) + ) + expected = DataFrame( + { + "total_weight": [np.nan, 17.8, 19.9, np.nan, 205.5, 240.0], + "min_height": [np.nan, 9.1, 9.5, np.nan, 6.0, 8.0], + }, + index=MultiIndex( + [["cat", "dog"], [0, 1, 2, 3, 4, 5]], + [[0, 0, 0, 1, 1, 1], [0, 2, 4, 1, 3, 5]], + names=["kind", None], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_subset_rolling_subset_with_closed(self): + # GH 35549 + df = DataFrame( + { + "column1": range(8), + "column2": range(8), + "group": ["A"] * 4 + ["B"] * 4, + "date": [ + Timestamp(date) + for date in ["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02"] + ] + * 2, + } + ) + + result = ( + df.groupby("group")[["column1", "date"]] + .rolling("1D", on="date", closed="left")["column1"] + .sum() + ) + expected = Series( + [np.nan, np.nan, 1.0, 1.0, np.nan, np.nan, 9.0, 9.0], + index=MultiIndex.from_frame( + df[["group", "date"]], + names=["group", "date"], + ), + name="column1", + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize("func", ["max", "min"]) + def test_groupby_rolling_index_changed(self, func): + # GH: #36018 nlevels of MultiIndex changed + ds = Series( + [1, 2, 2], + index=MultiIndex.from_tuples( + [("a", "x"), ("a", "y"), ("c", "z")], names=["1", "2"] + ), + name="a", + ) + + result = getattr(ds.groupby(ds).rolling(2), func)() + expected = Series( + [np.nan, np.nan, 2.0], + index=MultiIndex.from_tuples( + [(1, "a", "x"), (2, "a", "y"), (2, "c", "z")], names=["a", "1", "2"] + ), + name="a", + ) + tm.assert_series_equal(result, expected) + + def test_groupby_rolling_empty_frame(self): + # GH 36197 + expected = DataFrame({"s1": []}) + result = expected.groupby("s1").rolling(window=1).sum() + # GH 32262 + expected = expected.drop(columns="s1") + # GH-38057 from_tuples gives empty object dtype, we now get float/int levels + # expected.index = MultiIndex.from_tuples([], names=["s1", None]) + expected.index = MultiIndex.from_product( + [Index([], dtype="float64"), Index([], dtype="int64")], names=["s1", None] + ) + tm.assert_frame_equal(result, expected) + + expected = DataFrame({"s1": [], "s2": []}) + result = expected.groupby(["s1", "s2"]).rolling(window=1).sum() + # GH 32262 + expected = expected.drop(columns=["s1", "s2"]) + expected.index = MultiIndex.from_product( + [ + Index([], dtype="float64"), + Index([], dtype="float64"), + Index([], dtype="int64"), + ], + names=["s1", "s2", None], + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_rolling_string_index(self): + # GH: 36727 + df = DataFrame( + [ + ["A", "group_1", Timestamp(2019, 1, 1, 9)], + ["B", "group_1", Timestamp(2019, 1, 2, 9)], + ["Z", "group_2", Timestamp(2019, 1, 3, 9)], + ["H", "group_1", Timestamp(2019, 1, 6, 9)], + ["E", "group_2", Timestamp(2019, 1, 20, 9)], + ], + columns=["index", "group", "eventTime"], + ).set_index("index") + + groups = df.groupby("group") + df["count_to_date"] = groups.cumcount() + rolling_groups = groups.rolling("10D", on="eventTime") + result = rolling_groups.apply(lambda df: df.shape[0]) + expected = DataFrame( + [ + ["A", "group_1", Timestamp(2019, 1, 1, 9), 1.0], + ["B", "group_1", Timestamp(2019, 1, 2, 9), 2.0], + ["H", "group_1", Timestamp(2019, 1, 6, 9), 3.0], + ["Z", "group_2", Timestamp(2019, 1, 3, 9), 1.0], + ["E", "group_2", Timestamp(2019, 1, 20, 9), 1.0], + ], + columns=["index", "group", "eventTime", "count_to_date"], + ).set_index(["group", "index"]) + tm.assert_frame_equal(result, expected) + + def test_groupby_rolling_no_sort(self): + # GH 36889 + result = ( + DataFrame({"foo": [2, 1], "bar": [2, 1]}) + .groupby("foo", sort=False) + .rolling(1) + .min() + ) + expected = DataFrame( + np.array([[2.0, 2.0], [1.0, 1.0]]), + columns=["foo", "bar"], + index=MultiIndex.from_tuples([(2, 0), (1, 1)], names=["foo", None]), + ) + # GH 32262 + expected = expected.drop(columns="foo") + tm.assert_frame_equal(result, expected) + + def test_groupby_rolling_count_closed_on(self, unit): + # GH 35869 + df = DataFrame( + { + "column1": range(6), + "column2": range(6), + "group": 3 * ["A", "B"], + "date": date_range(end="20190101", periods=6, unit=unit), + } + ) + msg = "'d' is deprecated and will be removed in a future version." + + with tm.assert_produces_warning(Pandas4Warning, match=msg): + result = ( + df.groupby("group") + .rolling("3d", on="date", closed="left")["column1"] + .count() + ) + dti = DatetimeIndex( + [ + "2018-12-27", + "2018-12-29", + "2018-12-31", + "2018-12-28", + "2018-12-30", + "2019-01-01", + ], + dtype=f"M8[{unit}]", + ) + mi = MultiIndex.from_arrays( + [ + ["A", "A", "A", "B", "B", "B"], + dti, + ], + names=["group", "date"], + ) + expected = Series( + [np.nan, 1.0, 1.0, np.nan, 1.0, 1.0], + name="column1", + index=mi, + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + ("func", "kwargs", "expected_values"), + [ + ( + "rolling", + {"window": 2, "min_periods": 1}, + [np.nan, 0.5, np.nan, 0.5, 0.5], + ), + ("expanding", {}, [np.nan, 0.5, np.nan, 0.5, (1 / 3) ** 0.5]), + ], + ) + def test_groupby_rolling_sem(self, func, kwargs, expected_values): + # GH: 26476 + df = DataFrame( + [["a", 1], ["a", 2], ["b", 1], ["b", 2], ["b", 3]], columns=["a", "b"] + ) + result = getattr(df.groupby("a"), func)(**kwargs).sem() + expected = DataFrame( + {"a": [np.nan] * 5, "b": expected_values}, + index=MultiIndex.from_tuples( + [("a", 0), ("a", 1), ("b", 2), ("b", 3), ("b", 4)], names=["a", None] + ), + ) + # GH 32262 + expected = expected.drop(columns="a") + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + ("rollings", "key"), [({"on": "a"}, "a"), ({"on": None}, "index")] + ) + def test_groupby_rolling_nans_in_index(self, rollings, key): + # GH: 34617 + df = DataFrame( + { + "a": to_datetime(["2020-06-01 12:00", "2020-06-01 14:00", np.nan]), + "b": [1, 2, 3], + "c": [1, 1, 1], + } + ) + if key == "index": + df = df.set_index("a") + with pytest.raises(ValueError, match=f"{key} values must not have NaT"): + df.groupby("c").rolling("60min", **rollings) + + @pytest.mark.parametrize("group_keys", [True, False]) + def test_groupby_rolling_group_keys(self, group_keys): + # GH 37641 + # GH 38523: GH 37641 actually was not a bug. + # group_keys only applies to groupby.apply directly + arrays = [["val1", "val1", "val2"], ["val1", "val1", "val2"]] + index = MultiIndex.from_arrays(arrays, names=("idx1", "idx2")) + + s = Series([1, 2, 3], index=index) + result = s.groupby(["idx1", "idx2"], group_keys=group_keys).rolling(1).mean() + expected = Series( + [1.0, 2.0, 3.0], + index=MultiIndex.from_tuples( + [ + ("val1", "val1", "val1", "val1"), + ("val1", "val1", "val1", "val1"), + ("val2", "val2", "val2", "val2"), + ], + names=["idx1", "idx2", "idx1", "idx2"], + ), + ) + tm.assert_series_equal(result, expected) + + def test_groupby_rolling_index_level_and_column_label(self): + # The groupby keys should not appear as a resulting column + arrays = [["val1", "val1", "val2"], ["val1", "val1", "val2"]] + index = MultiIndex.from_arrays(arrays, names=("idx1", "idx2")) + + df = DataFrame({"A": [1, 1, 2], "B": range(3)}, index=index) + result = df.groupby(["idx1", "A"]).rolling(1).mean() + expected = DataFrame( + {"B": [0.0, 1.0, 2.0]}, + index=MultiIndex.from_tuples( + [ + ("val1", 1, "val1", "val1"), + ("val1", 1, "val1", "val1"), + ("val2", 2, "val2", "val2"), + ], + names=["idx1", "A", "idx1", "idx2"], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_rolling_resulting_multiindex(self): + # a few different cases checking the created MultiIndex of the result + # https://github.com/pandas-dev/pandas/pull/38057 + + # grouping by 1 columns -> 2-level MI as result + df = DataFrame({"a": np.arange(8.0), "b": [1, 2] * 4}) + result = df.groupby("b").rolling(3).mean() + expected_index = MultiIndex.from_tuples( + [(1, 0), (1, 2), (1, 4), (1, 6), (2, 1), (2, 3), (2, 5), (2, 7)], + names=["b", None], + ) + tm.assert_index_equal(result.index, expected_index) + + def test_groupby_rolling_resulting_multiindex2(self): + # grouping by 2 columns -> 3-level MI as result + df = DataFrame({"a": np.arange(12.0), "b": [1, 2] * 6, "c": [1, 2, 3, 4] * 3}) + result = df.groupby(["b", "c"]).rolling(2).sum() + expected_index = MultiIndex.from_tuples( + [ + (1, 1, 0), + (1, 1, 4), + (1, 1, 8), + (1, 3, 2), + (1, 3, 6), + (1, 3, 10), + (2, 2, 1), + (2, 2, 5), + (2, 2, 9), + (2, 4, 3), + (2, 4, 7), + (2, 4, 11), + ], + names=["b", "c", None], + ) + tm.assert_index_equal(result.index, expected_index) + + def test_groupby_rolling_resulting_multiindex3(self): + # grouping with 1 level on dataframe with 2-level MI -> 3-level MI as result + df = DataFrame({"a": np.arange(8.0), "b": [1, 2] * 4, "c": [1, 2, 3, 4] * 2}) + df = df.set_index("c", append=True) + result = df.groupby("b").rolling(3).mean() + expected_index = MultiIndex.from_tuples( + [ + (1, 0, 1), + (1, 2, 3), + (1, 4, 1), + (1, 6, 3), + (2, 1, 2), + (2, 3, 4), + (2, 5, 2), + (2, 7, 4), + ], + names=["b", None, "c"], + ) + tm.assert_index_equal(result.index, expected_index, exact="equiv") + + def test_groupby_rolling_object_doesnt_affect_groupby_apply(self, roll_frame): + # GH 39732 + g = roll_frame.groupby("A", group_keys=False) + expected = g.apply(lambda x: x.rolling(4).sum()).index + _ = g.rolling(window=4) + result = g.apply(lambda x: x.rolling(4).sum()).index + tm.assert_index_equal(result, expected) + + @pytest.mark.parametrize( + ("window", "min_periods", "closed", "expected"), + [ + (2, 0, "left", [None, 0.0, 1.0, 1.0, None, 0.0, 1.0, 1.0]), + (2, 2, "left", [None, None, 1.0, 1.0, None, None, 1.0, 1.0]), + (4, 4, "left", [None, None, None, None, None, None, None, None]), + (4, 4, "right", [None, None, None, 5.0, None, None, None, 5.0]), + ], + ) + def test_groupby_rolling_var(self, window, min_periods, closed, expected): + df = DataFrame([1, 2, 3, 4, 5, 6, 7, 8]) + result = ( + df.groupby([1, 2, 1, 2, 1, 2, 1, 2]) + .rolling(window=window, min_periods=min_periods, closed=closed) + .var(0) + ) + expected_result = DataFrame( + np.array(expected, dtype="float64"), + index=MultiIndex( + levels=[np.array([1, 2]), [0, 1, 2, 3, 4, 5, 6, 7]], + codes=[[0, 0, 0, 0, 1, 1, 1, 1], [0, 2, 4, 6, 1, 3, 5, 7]], + ), + ) + tm.assert_frame_equal(result, expected_result) + + @pytest.mark.parametrize( + "columns", [MultiIndex.from_tuples([("A", ""), ("B", "C")]), ["A", "B"]] + ) + def test_by_column_not_in_values(self, columns): + # GH 32262 + df = DataFrame([[1, 0]] * 20 + [[2, 0]] * 12 + [[3, 0]] * 8, columns=columns) + g = df.groupby("A") + original_obj = g.obj.copy(deep=True) + r = g.rolling(4) + result = r.sum() + assert "A" not in result.columns + tm.assert_frame_equal(g.obj, original_obj) + + def test_groupby_level(self): + # GH 38523, 38787 + arrays = [ + ["Falcon", "Falcon", "Parrot", "Parrot"], + ["Captive", "Wild", "Captive", "Wild"], + ] + index = MultiIndex.from_arrays(arrays, names=("Animal", "Type")) + df = DataFrame({"Max Speed": [390.0, 350.0, 30.0, 20.0]}, index=index) + result = df.groupby(level=0)["Max Speed"].rolling(2).sum() + expected = Series( + [np.nan, 740.0, np.nan, 50.0], + index=MultiIndex.from_tuples( + [ + ("Falcon", "Falcon", "Captive"), + ("Falcon", "Falcon", "Wild"), + ("Parrot", "Parrot", "Captive"), + ("Parrot", "Parrot", "Wild"), + ], + names=["Animal", "Animal", "Type"], + ), + name="Max Speed", + ) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "by, expected_data", + [ + [["id"], {"num": [100.0, 150.0, 150.0, 200.0]}], + [ + ["id", "index"], + { + "date": [ + Timestamp("2018-01-01"), + Timestamp("2018-01-02"), + Timestamp("2018-01-01"), + Timestamp("2018-01-02"), + ], + "num": [100.0, 200.0, 150.0, 250.0], + }, + ], + ], + ) + def test_as_index_false(self, by, expected_data, unit): + # GH 39433 + data = [ + ["A", "2018-01-01", 100.0], + ["A", "2018-01-02", 200.0], + ["B", "2018-01-01", 150.0], + ["B", "2018-01-02", 250.0], + ] + df = DataFrame(data, columns=["id", "date", "num"]) + df["date"] = df["date"].astype(f"M8[{unit}]") + df = df.set_index(["date"]) + + gp_by = [getattr(df, attr) for attr in by] + result = ( + df.groupby(gp_by, as_index=False).rolling(window=2, min_periods=1).mean() + ) + + expected = {"id": ["A", "A", "B", "B"]} + expected.update(expected_data) + expected = DataFrame( + expected, + index=df.index, + ) + if "date" in expected_data: + expected["date"] = expected["date"].astype(f"M8[{unit}]") + tm.assert_frame_equal(result, expected) + + def test_nan_and_zero_endpoints(self, any_int_numpy_dtype): + # https://github.com/twosigma/pandas/issues/53 + typ = np.dtype(any_int_numpy_dtype).type + size = 1000 + idx = np.repeat(typ(0), size) + idx[-1] = 1 + + val = 5e25 + arr = np.repeat(val, size) + arr[0] = np.nan + arr[-1] = 0 + + df = DataFrame( + { + "index": idx, + "adl2": arr, + } + ).set_index("index") + result = df.groupby("index")["adl2"].rolling(window=10, min_periods=1).mean() + expected = Series( + arr, + name="adl2", + index=MultiIndex.from_arrays( + [ + Index([0] * 999 + [1], dtype=typ, name="index"), + Index([0] * 999 + [1], dtype=typ, name="index"), + ], + ), + ) + tm.assert_series_equal(result, expected) + + def test_groupby_rolling_non_monotonic(self): + # GH 43909 + + shuffled = [3, 0, 1, 2] + sec = 1_000 + df = DataFrame( + [{"t": Timestamp(2 * x * sec), "x": x + 1, "c": 42} for x in shuffled] + ) + with pytest.raises(ValueError, match=r".* must be monotonic"): + df.groupby("c").rolling(on="t", window="3s") + + def test_groupby_monotonic(self): + # GH 15130 + # we don't need to validate monotonicity when grouping + + # GH 43909 we should raise an error here to match + # behaviour of non-groupby rolling. + + data = [ + ["David", "1/1/2015", 100], + ["David", "1/5/2015", 500], + ["David", "5/30/2015", 50], + ["David", "7/25/2015", 50], + ["Ryan", "1/4/2014", 100], + ["Ryan", "1/19/2015", 500], + ["Ryan", "3/31/2016", 50], + ["Joe", "7/1/2015", 100], + ["Joe", "9/9/2015", 500], + ["Joe", "10/15/2015", 50], + ] + + df = DataFrame(data=data, columns=["name", "date", "amount"]) + df["date"] = to_datetime(df["date"]) + df = df.sort_values("date") + + expected = ( + df.set_index("date") + .groupby("name") + .apply(lambda x: x.rolling("180D")["amount"].sum()) + ) + result = df.groupby("name").rolling("180D", on="date")["amount"].sum() + tm.assert_series_equal(result, expected) + + def test_datelike_on_monotonic_within_each_group(self): + # GH 13966 (similar to #15130, closed by #15175) + + # superseded by 43909 + # GH 46061: OK if the on is monotonic relative to each each group + + dates = date_range(start="2016-01-01 09:30:00", periods=20, freq="s") + df = DataFrame( + { + "A": [1] * 20 + [2] * 12 + [3] * 8, + "B": np.concatenate((dates, dates)), + "C": np.arange(40), + } + ) + + expected = ( + df.set_index("B").groupby("A").apply(lambda x: x.rolling("4s")["C"].mean()) + ) + result = df.groupby("A").rolling("4s", on="B").C.mean() + tm.assert_series_equal(result, expected) + + def test_datelike_on_not_monotonic_within_each_group(self): + # GH 46061 + df = DataFrame( + { + "A": [1] * 3 + [2] * 3, + "B": [Timestamp(year, 1, 1) for year in [2020, 2021, 2019]] * 2, + "C": range(6), + } + ) + with pytest.raises(ValueError, match="Each group within B must be monotonic."): + df.groupby("A").rolling("365D", on="B") + + +class TestExpanding: + @pytest.fixture + def frame(self): + return DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)}) + + @pytest.mark.parametrize( + "f", + [ + "sum", + "mean", + "min", + "max", + "first", + "last", + "count", + "kurt", + "skew", + "nunique", + ], + ) + def test_expanding(self, f, frame): + g = frame.groupby("A", group_keys=False) + r = g.expanding() + + result = getattr(r, f)() + expected = g.apply(lambda x: getattr(x.expanding(), f)()) + # GH 39732 + expected_index = MultiIndex.from_arrays([frame["A"], range(40)]) + expected.index = expected_index + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("f", ["std", "var"]) + def test_expanding_ddof(self, f, frame): + g = frame.groupby("A", group_keys=False) + r = g.expanding() + + result = getattr(r, f)(ddof=0) + expected = g.apply(lambda x: getattr(x.expanding(), f)(ddof=0)) + # GH 39732 + expected_index = MultiIndex.from_arrays([frame["A"], range(40)]) + expected.index = expected_index + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "midpoint", "nearest"] + ) + def test_expanding_quantile(self, interpolation, frame): + g = frame.groupby("A", group_keys=False) + r = g.expanding() + + result = r.quantile(0.4, interpolation=interpolation) + expected = g.apply( + lambda x: x.expanding().quantile(0.4, interpolation=interpolation) + ) + # GH 39732 + expected_index = MultiIndex.from_arrays([frame["A"], range(40)]) + expected.index = expected_index + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("f", ["corr", "cov"]) + def test_expanding_corr_cov(self, f, frame): + g = frame.groupby("A") + r = g.expanding() + + result = getattr(r, f)(frame) + + def func_0(x): + return getattr(x.expanding(), f)(frame) + + expected = g.apply(func_0) + # GH 39591: groupby.apply returns 1 instead of nan for windows + # with all nan values + null_idx = list(range(20, 61)) + list(range(72, 113)) + expected.iloc[null_idx, 1] = np.nan + # GH 39591: The grouped column should be all np.nan + # (groupby.apply inserts 0s for cov) + expected["A"] = np.nan + tm.assert_frame_equal(result, expected) + + result = getattr(r.B, f)(pairwise=True) + + def func_1(x): + return getattr(x.B.expanding(), f)(pairwise=True) + + expected = g.apply(func_1) + tm.assert_series_equal(result, expected) + + def test_expanding_apply(self, raw, frame): + g = frame.groupby("A", group_keys=False) + r = g.expanding() + + # reduction + result = r.apply(lambda x: x.sum(), raw=raw) + expected = g.apply(lambda x: x.expanding().apply(lambda y: y.sum(), raw=raw)) + # GH 39732 + expected_index = MultiIndex.from_arrays([frame["A"], range(40)]) + expected.index = expected_index + tm.assert_frame_equal(result, expected) + + def test_groupby_expanding_agg_namedagg(self): + # GH#28333 + df = DataFrame( + { + "kind": ["cat", "dog", "cat", "dog", "cat", "dog"], + "height": [9.1, 6.0, 9.5, 34.0, 12.0, 8.0], + "weight": [7.9, 7.5, 9.9, 198.0, 10.0, 42.0], + } + ) + result = ( + df.groupby("kind") + .expanding(1) + .agg( + total_weight=NamedAgg(column="weight", aggfunc=sum), + min_height=NamedAgg(column="height", aggfunc=min), + ) + ) + expected = DataFrame( + { + "total_weight": [7.9, 17.8, 27.8, 7.5, 205.5, 247.5], + "min_height": [9.1, 9.1, 9.1, 6.0, 6.0, 6.0], + }, + index=MultiIndex( + [["cat", "dog"], [0, 1, 2, 3, 4, 5]], + [[0, 0, 0, 1, 1, 1], [0, 2, 4, 1, 3, 5]], + names=["kind", None], + ), + ) + tm.assert_frame_equal(result, expected) + + +class TestEWM: + @pytest.mark.parametrize( + "method, expected_data", + [ + ["mean", [0.0, 0.6666666666666666, 1.4285714285714286, 2.2666666666666666]], + ["std", [np.nan, 0.707107, 0.963624, 1.177164]], + ["var", [np.nan, 0.5, 0.9285714285714286, 1.3857142857142857]], + ], + ) + def test_methods(self, method, expected_data): + # GH 16037 + df = DataFrame({"A": ["a"] * 4, "B": range(4)}) + result = getattr(df.groupby("A").ewm(com=1.0), method)() + expected = DataFrame( + {"B": expected_data}, + index=MultiIndex.from_tuples( + [ + ("a", 0), + ("a", 1), + ("a", 2), + ("a", 3), + ], + names=["A", None], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_groupby_ewm_agg_namedagg(self): + # GH#28333 + df = DataFrame({"A": ["a"] * 4, "B": range(4)}) + result = ( + df.groupby("A") + .ewm(com=1.0) + .agg( + B_mean=NamedAgg(column="B", aggfunc="mean"), + B_std=NamedAgg(column="B", aggfunc="std"), + B_var=NamedAgg(column="B", aggfunc="var"), + ) + ) + expected = DataFrame( + { + "B_mean": [ + 0.0, + 0.6666666666666666, + 1.4285714285714286, + 2.2666666666666666, + ], + "B_std": [np.nan, 0.707107, 0.963624, 1.177164], + "B_var": [np.nan, 0.5, 0.9285714285714286, 1.3857142857142857], + }, + index=MultiIndex.from_tuples( + [ + ("a", 0), + ("a", 1), + ("a", 2), + ("a", 3), + ], + names=["A", None], + ), + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "method, expected_data", + [["corr", [np.nan, 1.0, 1.0, 1]], ["cov", [np.nan, 0.5, 0.928571, 1.385714]]], + ) + def test_pairwise_methods(self, method, expected_data): + # GH 16037 + df = DataFrame({"A": ["a"] * 4, "B": range(4)}) + result = getattr(df.groupby("A").ewm(com=1.0), method)() + expected = DataFrame( + {"B": expected_data}, + index=MultiIndex.from_tuples( + [ + ("a", 0, "B"), + ("a", 1, "B"), + ("a", 2, "B"), + ("a", 3, "B"), + ], + names=["A", None, None], + ), + ) + tm.assert_frame_equal(result, expected) + + expected = df.groupby("A")[["B"]].apply( + lambda x: getattr(x.ewm(com=1.0), method)() + ) + tm.assert_frame_equal(result, expected) + + def test_times(self, times_frame): + # GH 40951 + halflife = "23 days" + # GH#42738 + times = times_frame.pop("C") + result = times_frame.groupby("A").ewm(halflife=halflife, times=times).mean() + expected = DataFrame( + { + "B": [ + 0.0, + 0.507534, + 1.020088, + 1.537661, + 0.0, + 0.567395, + 1.221209, + 0.0, + 0.653141, + 1.195003, + ] + }, + index=MultiIndex.from_tuples( + [ + ("a", 0), + ("a", 3), + ("a", 6), + ("a", 9), + ("b", 1), + ("b", 4), + ("b", 7), + ("c", 2), + ("c", 5), + ("c", 8), + ], + names=["A", None], + ), + ) + tm.assert_frame_equal(result, expected) + + def test_times_array(self, times_frame): + # GH 40951 + halflife = "23 days" + times = times_frame.pop("C") + gb = times_frame.groupby("A") + result = gb.ewm(halflife=halflife, times=times).mean() + expected = gb.ewm(halflife=halflife, times=times.values).mean() + tm.assert_frame_equal(result, expected) + + def test_dont_mutate_obj_after_slicing(self): + # GH 43355 + df = DataFrame( + { + "id": ["a", "a", "b", "b", "b"], + "timestamp": date_range("2021-9-1", periods=5, freq="h"), + "y": range(5), + } + ) + grp = df.groupby("id").rolling("1h", on="timestamp") + result = grp.count() + expected_df = DataFrame( + { + "timestamp": date_range("2021-9-1", periods=5, freq="h"), + "y": [1.0] * 5, + }, + index=MultiIndex.from_arrays( + [["a", "a", "b", "b", "b"], list(range(5))], names=["id", None] + ), + ) + tm.assert_frame_equal(result, expected_df) + + result = grp["y"].count() + expected_series = Series( + [1.0] * 5, + index=MultiIndex.from_arrays( + [ + ["a", "a", "b", "b", "b"], + date_range("2021-9-1", periods=5, freq="h"), + ], + names=["id", "timestamp"], + ), + name="y", + ) + tm.assert_series_equal(result, expected_series) + # This is the key test + result = grp.count() + tm.assert_frame_equal(result, expected_df) + + +def test_rolling_corr_with_single_integer_in_index(): + # GH 44078 + df = DataFrame({"a": [(1,), (1,), (1,)], "b": [4, 5, 6]}) + gb = df.groupby(["a"]) + result = gb.rolling(2).corr(other=df) + index = MultiIndex.from_tuples([((1,), 0), ((1,), 1), ((1,), 2)], names=["a", None]) + expected = DataFrame( + {"a": [np.nan, np.nan, np.nan], "b": [np.nan, 1.0, 1.0]}, index=index + ) + tm.assert_frame_equal(result, expected) + + +def test_rolling_corr_with_tuples_in_index(): + # GH 44078 + df = DataFrame( + { + "a": [ + ( + 1, + 2, + ), + ( + 1, + 2, + ), + ( + 1, + 2, + ), + ], + "b": [4, 5, 6], + } + ) + gb = df.groupby(["a"]) + result = gb.rolling(2).corr(other=df) + index = MultiIndex.from_tuples( + [((1, 2), 0), ((1, 2), 1), ((1, 2), 2)], names=["a", None] + ) + expected = DataFrame( + {"a": [np.nan, np.nan, np.nan], "b": [np.nan, 1.0, 1.0]}, index=index + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/window/test_numba.py b/pandas/tests/window/test_numba.py new file mode 100644 index 0000000000000000000000000000000000000000..ff6a616bc526427d0c2f3abfec9b66273c65fa00 --- /dev/null +++ b/pandas/tests/window/test_numba.py @@ -0,0 +1,648 @@ +import numpy as np +import pytest + +from pandas.compat import is_platform_arm +from pandas.errors import NumbaUtilError +import pandas.util._test_decorators as td + +from pandas import ( + DataFrame, + Series, + option_context, + to_datetime, +) +import pandas._testing as tm +from pandas.api.indexers import BaseIndexer +from pandas.util.version import Version + +pytestmark = [pytest.mark.single_cpu] + +numba = pytest.importorskip("numba") +pytestmark.append( + pytest.mark.skipif( + Version(numba.__version__) == Version("0.61") and is_platform_arm(), + reason=f"Segfaults on ARM platforms with numba {numba.__version__}", + ) +) + + +@pytest.fixture(params=["single", "table"]) +def method(request): + """method keyword in rolling/expanding/ewm constructor""" + return request.param + + +@pytest.fixture( + params=[ + ["sum", {}], + ["mean", {}], + ["median", {}], + ["max", {}], + ["min", {}], + ["var", {}], + ["var", {"ddof": 0}], + ["std", {}], + ["std", {"ddof": 0}], + ] +) +def arithmetic_numba_supported_operators(request): + return request.param + + +@pytest.fixture +def roll_frame(): + return DataFrame({"A": [1] * 20 + [2] * 12 + [3] * 8, "B": np.arange(40)}) + + +@td.skip_if_no("numba") +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +class TestEngine: + @pytest.mark.parametrize("jit", [True, False]) + def test_numba_vs_cython_apply(self, jit, nogil, parallel, nopython, center, step): + def f(x, *args): + arg_sum = 0 + for arg in args: + arg_sum += arg + return np.mean(x) + arg_sum + + if jit: + import numba + + f = numba.jit(f) + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + args = (2,) + + s = Series(range(10)) + result = s.rolling(2, center=center, step=step).apply( + f, args=args, engine="numba", engine_kwargs=engine_kwargs, raw=True + ) + expected = s.rolling(2, center=center, step=step).apply( + f, engine="cython", args=args, raw=True + ) + tm.assert_series_equal(result, expected) + + def test_apply_numba_with_kwargs(self, roll_frame): + # GH 58995 + # rolling apply + def func(sr, a=0): + return sr.sum() + a + + data = DataFrame(range(10)) + + result = data.rolling(5).apply(func, engine="numba", raw=True, kwargs={"a": 1}) + expected = data.rolling(5).sum() + 1 + tm.assert_frame_equal(result, expected) + + result = data.rolling(5).apply(func, engine="numba", raw=True, args=(1,)) + tm.assert_frame_equal(result, expected) + + # expanding apply + + result = data.expanding().apply(func, engine="numba", raw=True, kwargs={"a": 1}) + expected = data.expanding().sum() + 1 + tm.assert_frame_equal(result, expected) + + result = data.expanding().apply(func, engine="numba", raw=True, args=(1,)) + tm.assert_frame_equal(result, expected) + + # groupby rolling + result = ( + roll_frame.groupby("A") + .rolling(5) + .apply(func, engine="numba", raw=True, kwargs={"a": 1}) + ) + expected = roll_frame.groupby("A").rolling(5).sum() + 1 + tm.assert_frame_equal(result, expected) + + result = ( + roll_frame.groupby("A") + .rolling(5) + .apply(func, engine="numba", raw=True, args=(1,)) + ) + tm.assert_frame_equal(result, expected) + # groupby expanding + + result = ( + roll_frame.groupby("A") + .expanding() + .apply(func, engine="numba", raw=True, kwargs={"a": 1}) + ) + expected = roll_frame.groupby("A").expanding().sum() + 1 + tm.assert_frame_equal(result, expected) + + result = ( + roll_frame.groupby("A") + .expanding() + .apply(func, engine="numba", raw=True, args=(1,)) + ) + tm.assert_frame_equal(result, expected) + + def test_numba_min_periods(self): + # GH 58868 + def last_row(x): + assert len(x) == 3 + return x[-1] + + df = DataFrame([[1, 2], [3, 4], [5, 6], [7, 8]]) + + result = df.rolling(3, method="table", min_periods=3).apply( + last_row, raw=True, engine="numba" + ) + + expected = DataFrame([[np.nan, np.nan], [np.nan, np.nan], [5, 6], [7, 8]]) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "data", + [ + DataFrame(np.eye(5)), + DataFrame( + [ + [5, 7, 7, 7, np.nan, np.inf, 4, 3, 3, 3], + [5, 7, 7, 7, np.nan, np.inf, 7, 3, 3, 3], + [np.nan, np.nan, 5, 6, 7, 5, 5, 5, 5, 5], + ] + ).T, + Series(range(5), name="foo"), + Series([20, 10, 10, np.inf, 1, 1, 2, 3]), + Series([20, 10, 10, np.nan, 10, 1, 2, 3]), + ], + ) + def test_numba_vs_cython_rolling_methods( + self, + data, + nogil, + parallel, + nopython, + arithmetic_numba_supported_operators, + step, + ): + method, kwargs = arithmetic_numba_supported_operators + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + roll = data.rolling(3, step=step) + result = getattr(roll, method)( + engine="numba", engine_kwargs=engine_kwargs, **kwargs + ) + expected = getattr(roll, method)(engine="cython", **kwargs) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize( + "data", [DataFrame(np.eye(5)), Series(range(5), name="foo")] + ) + def test_numba_vs_cython_expanding_methods( + self, data, nogil, parallel, nopython, arithmetic_numba_supported_operators + ): + method, kwargs = arithmetic_numba_supported_operators + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + data = DataFrame(np.eye(5)) + expand = data.expanding() + result = getattr(expand, method)( + engine="numba", engine_kwargs=engine_kwargs, **kwargs + ) + expected = getattr(expand, method)(engine="cython", **kwargs) + tm.assert_equal(result, expected) + + @pytest.mark.parametrize("jit", [True, False]) + def test_cache_apply(self, jit, nogil, parallel, nopython, step): + # Test that the functions are cached correctly if we switch functions + def func_1(x): + return np.mean(x) + 4 + + def func_2(x): + return np.std(x) * 5 + + if jit: + import numba + + func_1 = numba.jit(func_1) + func_2 = numba.jit(func_2) + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + roll = Series(range(10)).rolling(2, step=step) + result = roll.apply( + func_1, engine="numba", engine_kwargs=engine_kwargs, raw=True + ) + expected = roll.apply(func_1, engine="cython", raw=True) + tm.assert_series_equal(result, expected) + + result = roll.apply( + func_2, engine="numba", engine_kwargs=engine_kwargs, raw=True + ) + expected = roll.apply(func_2, engine="cython", raw=True) + tm.assert_series_equal(result, expected) + # This run should use the cached func_1 + result = roll.apply( + func_1, engine="numba", engine_kwargs=engine_kwargs, raw=True + ) + expected = roll.apply(func_1, engine="cython", raw=True) + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "window,window_kwargs", + [ + ["rolling", {"window": 3, "min_periods": 0}], + ["expanding", {}], + ], + ) + def test_dont_cache_args( + self, window, window_kwargs, nogil, parallel, nopython, method + ): + # GH 42287 + + def add(values, x): + return np.sum(values) + x + + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + df = DataFrame({"value": [0, 0, 0]}) + result = getattr(df, window)(method=method, **window_kwargs).apply( + add, raw=True, engine="numba", engine_kwargs=engine_kwargs, args=(1,) + ) + expected = DataFrame({"value": [1.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) + + result = getattr(df, window)(method=method, **window_kwargs).apply( + add, raw=True, engine="numba", engine_kwargs=engine_kwargs, args=(2,) + ) + expected = DataFrame({"value": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + + def test_dont_cache_engine_kwargs(self): + # If the user passes a different set of engine_kwargs don't return the same + # jitted function + nogil = False + parallel = True + nopython = True + + def func(x): + return nogil + parallel + nopython + + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + df = DataFrame({"value": [0, 0, 0]}) + result = df.rolling(1).apply( + func, raw=True, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [2.0, 2.0, 2.0]}) + tm.assert_frame_equal(result, expected) + + parallel = False + engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel} + result = df.rolling(1).apply( + func, raw=True, engine="numba", engine_kwargs=engine_kwargs + ) + expected = DataFrame({"value": [1.0, 1.0, 1.0]}) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +class TestEWM: + @pytest.mark.parametrize( + "grouper", [lambda x: x, lambda x: x.groupby("A")], ids=["None", "groupby"] + ) + @pytest.mark.parametrize("method", ["mean", "sum"]) + def test_invalid_engine(self, grouper, method): + df = DataFrame({"A": ["a", "b", "a", "b"], "B": range(4)}) + with pytest.raises(ValueError, match="engine must be either"): + getattr(grouper(df).ewm(com=1.0), method)(engine="foo") + + @pytest.mark.parametrize( + "grouper", [lambda x: x, lambda x: x.groupby("A")], ids=["None", "groupby"] + ) + @pytest.mark.parametrize("method", ["mean", "sum"]) + def test_invalid_engine_kwargs(self, grouper, method): + df = DataFrame({"A": ["a", "b", "a", "b"], "B": range(4)}) + with pytest.raises(ValueError, match="cython engine does not"): + getattr(grouper(df).ewm(com=1.0), method)( + engine="cython", engine_kwargs={"nopython": True} + ) + + @pytest.mark.parametrize("grouper", ["None", "groupby"]) + @pytest.mark.parametrize("method", ["mean", "sum"]) + def test_cython_vs_numba( + self, grouper, method, nogil, parallel, nopython, ignore_na, adjust + ): + df = DataFrame({"B": range(4)}) + if grouper == "None": + grouper = lambda x: x + else: + df["A"] = ["a", "b", "a", "b"] + grouper = lambda x: x.groupby("A") + if method == "sum": + adjust = True + ewm = grouper(df).ewm(com=1.0, adjust=adjust, ignore_na=ignore_na) + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + result = getattr(ewm, method)(engine="numba", engine_kwargs=engine_kwargs) + expected = getattr(ewm, method)(engine="cython") + + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("grouper", ["None", "groupby"]) + def test_cython_vs_numba_times(self, grouper, nogil, parallel, nopython, ignore_na): + # GH 40951 + + df = DataFrame({"B": [0, 0, 1, 1, 2, 2]}) + if grouper == "None": + grouper = lambda x: x + else: + grouper = lambda x: x.groupby("A") + df["A"] = ["a", "b", "a", "b", "b", "a"] + + halflife = "23 days" + times = to_datetime( + [ + "2020-01-01", + "2020-01-01", + "2020-01-02", + "2020-01-10", + "2020-02-23", + "2020-01-03", + ] + ) + ewm = grouper(df).ewm( + halflife=halflife, adjust=True, ignore_na=ignore_na, times=times + ) + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + result = ewm.mean(engine="numba", engine_kwargs=engine_kwargs) + expected = ewm.mean(engine="cython") + + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +def test_use_global_config(): + def f(x): + return np.mean(x) + 2 + + s = Series(range(10)) + with option_context("compute.use_numba", True): + result = s.rolling(2).apply(f, engine=None, raw=True) + expected = s.rolling(2).apply(f, engine="numba", raw=True) + tm.assert_series_equal(expected, result) + + +@td.skip_if_no("numba") +def test_invalid_kwargs_nopython(): + with pytest.raises(TypeError, match="got an unexpected keyword argument 'a'"): + Series(range(1)).rolling(1).apply( + lambda x: x, kwargs={"a": 1}, engine="numba", raw=True + ) + with pytest.raises( + NumbaUtilError, match="numba does not support keyword-only arguments" + ): + Series(range(1)).rolling(1).apply( + lambda x, *, a: x, kwargs={"a": 1}, engine="numba", raw=True + ) + + tm.assert_series_equal( + Series(range(1), dtype=float) + 1, + Series(range(1)) + .rolling(1) + .apply(lambda x, a: (x + a).sum(), kwargs={"a": 1}, engine="numba", raw=True), + ) + + +@td.skip_if_no("numba") +@pytest.mark.slow +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +class TestTableMethod: + def test_table_series_valueerror(self): + def f(x): + return np.sum(x, axis=0) + 1 + + with pytest.raises( + ValueError, match="method='table' not applicable for Series objects." + ): + Series(range(1)).rolling(1, method="table").apply( + f, engine="numba", raw=True + ) + + def test_table_method_rolling_methods( + self, + nogil, + parallel, + nopython, + arithmetic_numba_supported_operators, + step, + ): + method, kwargs = arithmetic_numba_supported_operators + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + df = DataFrame(np.eye(3)) + roll_table = df.rolling(2, method="table", min_periods=0, step=step) + if method in ("var", "std"): + with pytest.raises(NotImplementedError, match=f"{method} not supported"): + getattr(roll_table, method)( + engine_kwargs=engine_kwargs, engine="numba", **kwargs + ) + else: + roll_single = df.rolling(2, method="single", min_periods=0, step=step) + result = getattr(roll_table, method)( + engine_kwargs=engine_kwargs, engine="numba", **kwargs + ) + expected = getattr(roll_single, method)( + engine_kwargs=engine_kwargs, engine="numba", **kwargs + ) + tm.assert_frame_equal(result, expected) + + def test_table_method_rolling_apply(self, nogil, parallel, nopython, step): + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + def f(x): + return np.sum(x, axis=0) + 1 + + df = DataFrame(np.eye(3)) + result = df.rolling(2, method="table", min_periods=0, step=step).apply( + f, raw=True, engine_kwargs=engine_kwargs, engine="numba" + ) + expected = df.rolling(2, method="single", min_periods=0, step=step).apply( + f, raw=True, engine_kwargs=engine_kwargs, engine="numba" + ) + tm.assert_frame_equal(result, expected) + + def test_table_method_rolling_apply_col_order(self): + # GH#59666 + def f(x): + return np.nanmean(x[:, 0] - x[:, 1]) + + df = DataFrame( + { + "a": [1, 2, 3, 4, 5, 6], + "b": [6, 7, 8, 5, 6, 7], + } + ) + result = df.rolling(3, method="table", min_periods=0)[["a", "b"]].apply( + f, raw=True, engine="numba" + ) + expected = DataFrame( + { + "a": [-5, -5, -5, -3.66667, -2.33333, -1], + "b": [-5, -5, -5, -3.66667, -2.33333, -1], + } + ) + tm.assert_almost_equal(result, expected) + result = df.rolling(3, method="table", min_periods=0)[["b", "a"]].apply( + f, raw=True, engine="numba" + ) + expected = DataFrame( + { + "b": [5, 5, 5, 3.66667, 2.33333, 1], + "a": [5, 5, 5, 3.66667, 2.33333, 1], + } + ) + tm.assert_almost_equal(result, expected) + + def test_table_method_rolling_weighted_mean(self, step): + def weighted_mean(x): + arr = np.ones((1, x.shape[1])) + arr[:, :2] = (x[:, :2] * x[:, 2]).sum(axis=0) / x[:, 2].sum() + return arr + + df = DataFrame([[1, 2, 0.6], [2, 3, 0.4], [3, 4, 0.2], [4, 5, 0.7]]) + result = df.rolling(2, method="table", min_periods=0, step=step).apply( + weighted_mean, raw=True, engine="numba" + ) + expected = DataFrame( + [ + [1.0, 2.0, 1.0], + [1.8, 2.0, 1.0], + [3.333333, 2.333333, 1.0], + [1.555556, 7, 1.0], + ] + )[::step] + tm.assert_frame_equal(result, expected) + + def test_table_method_expanding_apply(self, nogil, parallel, nopython): + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + def f(x): + return np.sum(x, axis=0) + 1 + + df = DataFrame(np.eye(3)) + result = df.expanding(method="table").apply( + f, raw=True, engine_kwargs=engine_kwargs, engine="numba" + ) + expected = df.expanding(method="single").apply( + f, raw=True, engine_kwargs=engine_kwargs, engine="numba" + ) + tm.assert_frame_equal(result, expected) + + def test_table_method_expanding_methods( + self, nogil, parallel, nopython, arithmetic_numba_supported_operators + ): + method, kwargs = arithmetic_numba_supported_operators + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + df = DataFrame(np.eye(3)) + expand_table = df.expanding(method="table") + if method in ("var", "std"): + with pytest.raises(NotImplementedError, match=f"{method} not supported"): + getattr(expand_table, method)( + engine_kwargs=engine_kwargs, engine="numba", **kwargs + ) + else: + expand_single = df.expanding(method="single") + result = getattr(expand_table, method)( + engine_kwargs=engine_kwargs, engine="numba", **kwargs + ) + expected = getattr(expand_single, method)( + engine_kwargs=engine_kwargs, engine="numba", **kwargs + ) + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize("data", [np.eye(3), np.ones((2, 3)), np.ones((3, 2))]) + @pytest.mark.parametrize("method", ["mean", "sum"]) + def test_table_method_ewm(self, data, method, nogil, parallel, nopython): + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + df = DataFrame(data) + + result = getattr(df.ewm(com=1, method="table"), method)( + engine_kwargs=engine_kwargs, engine="numba" + ) + expected = getattr(df.ewm(com=1, method="single"), method)( + engine_kwargs=engine_kwargs, engine="numba" + ) + tm.assert_frame_equal(result, expected) + + +@td.skip_if_no("numba") +def test_npfunc_no_warnings(): + df = DataFrame({"col1": [1, 2, 3, 4, 5]}) + with tm.assert_produces_warning(False): + df.col1.rolling(2).apply(np.prod, raw=True, engine="numba") + + +class PrescribedWindowIndexer(BaseIndexer): + def __init__(self, start, end): + self._start = start + self._end = end + super().__init__() + + def get_window_bounds( + self, num_values=None, min_periods=None, center=None, closed=None, step=None + ): + if num_values is None: + num_values = len(self._start) + start = np.clip(self._start, 0, num_values) + end = np.clip(self._end, 0, num_values) + return start, end + + +@td.skip_if_no("numba") +class TestMinMaxNumba: + @pytest.mark.parametrize( + "is_max, has_nan, exp_list", + [ + (True, False, [3.0, 5.0, 2.0, 5.0, 1.0, 5.0, 6.0, 7.0, 8.0, 9.0]), + (True, True, [3.0, 4.0, 2.0, 4.0, 1.0, 4.0, 6.0, 7.0, 7.0, 9.0]), + (False, False, [3.0, 2.0, 2.0, 1.0, 1.0, 0.0, 0.0, 0.0, 7.0, 0.0]), + (False, True, [3.0, 2.0, 2.0, 1.0, 1.0, 1.0, 6.0, 6.0, 7.0, 1.0]), + ], + ) + def test_minmax(self, is_max, has_nan, exp_list): + nan_idx = [0, 5, 8] + df = DataFrame( + { + "data": [5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 6.0, 7.0, 8.0, 9.0], + "start": [2, 0, 3, 0, 4, 0, 5, 5, 7, 3], + "end": [3, 4, 4, 5, 5, 6, 7, 8, 9, 10], + } + ) + if has_nan: + df.loc[nan_idx, "data"] = np.nan + expected = Series(exp_list, name="data") + r = df.data.rolling( + PrescribedWindowIndexer(df.start.to_numpy(), df.end.to_numpy()) + ) + if is_max: + result = r.max(engine="numba") + else: + result = r.min(engine="numba") + + tm.assert_series_equal(result, expected) + + def test_wrong_order(self): + start = np.array(range(5), dtype=np.int64) + end = start + 1 + end[3] = end[2] + start[3] = start[2] - 1 + + df = DataFrame({"data": start * 1.0, "start": start, "end": end}) + + r = df.data.rolling(PrescribedWindowIndexer(start, end)) + with pytest.raises( + ValueError, match="Start/End ordering requirement is violated at index 3" + ): + r.max(engine="numba") diff --git a/pandas/tests/window/test_online.py b/pandas/tests/window/test_online.py new file mode 100644 index 0000000000000000000000000000000000000000..43d55a7992b3ce52255a6813e8b9e93b82a45324 --- /dev/null +++ b/pandas/tests/window/test_online.py @@ -0,0 +1,112 @@ +import numpy as np +import pytest + +from pandas.compat import is_platform_arm + +from pandas import ( + DataFrame, + Series, +) +import pandas._testing as tm +from pandas.util.version import Version + +pytestmark = [pytest.mark.single_cpu] + +numba = pytest.importorskip("numba") +pytestmark.append( + pytest.mark.skipif( + Version(numba.__version__) == Version("0.61") and is_platform_arm(), + reason=f"Segfaults on ARM platforms with numba {numba.__version__}", + ) +) + + +@pytest.mark.filterwarnings("ignore") +# Filter warnings when parallel=True and the function can't be parallelized by Numba +class TestEWM: + def test_invalid_update(self): + df = DataFrame({"a": range(5), "b": range(5)}) + online_ewm = df.head(2).ewm(0.5).online() + with pytest.raises( + ValueError, + match="Must call mean with update=None first before passing update", + ): + online_ewm.mean(update=df.head(1)) + + @pytest.mark.slow + @pytest.mark.parametrize( + "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")] + ) + def test_online_vs_non_online_mean( + self, obj, nogil, parallel, nopython, adjust, ignore_na + ): + expected = obj.ewm(0.5, adjust=adjust, ignore_na=ignore_na).mean() + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + + online_ewm = ( + obj.head(2) + .ewm(0.5, adjust=adjust, ignore_na=ignore_na) + .online(engine_kwargs=engine_kwargs) + ) + # Test resetting once + for _ in range(2): + result = online_ewm.mean() + tm.assert_equal(result, expected.head(2)) + + result = online_ewm.mean(update=obj.tail(3)) + tm.assert_equal(result, expected.tail(3)) + + online_ewm.reset() + + @pytest.mark.xfail(raises=NotImplementedError) + @pytest.mark.parametrize( + "obj", [DataFrame({"a": range(5), "b": range(5)}), Series(range(5), name="foo")] + ) + def test_update_times_mean( + self, obj, nogil, parallel, nopython, adjust, ignore_na, halflife_with_times + ): + times = Series( + np.array( + ["2020-01-01", "2020-01-05", "2020-01-07", "2020-01-17", "2020-01-21"], + dtype="datetime64[ns]", + ) + ) + expected = obj.ewm( + 0.5, + adjust=adjust, + ignore_na=ignore_na, + times=times, + halflife=halflife_with_times, + ).mean() + + engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython} + online_ewm = ( + obj.head(2) + .ewm( + 0.5, + adjust=adjust, + ignore_na=ignore_na, + times=times.head(2), + halflife=halflife_with_times, + ) + .online(engine_kwargs=engine_kwargs) + ) + # Test resetting once + for _ in range(2): + result = online_ewm.mean() + tm.assert_equal(result, expected.head(2)) + + result = online_ewm.mean(update=obj.tail(3), update_times=times.tail(3)) + tm.assert_equal(result, expected.tail(3)) + + online_ewm.reset() + + @pytest.mark.parametrize("method", ["aggregate", "std", "corr", "cov", "var"]) + def test_ewm_notimplementederror_raises(self, method): + ser = Series(range(10)) + kwargs = {} + if method == "aggregate": + kwargs["func"] = lambda x: x + + with pytest.raises(NotImplementedError, match=".* is not implemented."): + getattr(ser.ewm(1).online(), method)(**kwargs) diff --git a/pandas/tests/window/test_pairwise.py b/pandas/tests/window/test_pairwise.py new file mode 100644 index 0000000000000000000000000000000000000000..eb22502fd648b3b90e3aec80b66d1f19bcd5c0f8 --- /dev/null +++ b/pandas/tests/window/test_pairwise.py @@ -0,0 +1,457 @@ +import numpy as np +import pytest + +from pandas.compat import IS64 + +from pandas import ( + DataFrame, + Index, + MultiIndex, + Series, + date_range, +) +import pandas._testing as tm +from pandas.core.algorithms import safe_sort + + +@pytest.fixture( + params=[ + DataFrame([[2, 4], [1, 2], [5, 2], [8, 1]], columns=[1, 0]), + DataFrame([[2, 4], [1, 2], [5, 2], [8, 1]], columns=[1, 1]), + DataFrame([[2, 4], [1, 2], [5, 2], [8, 1]], columns=["C", "C"]), + DataFrame([[2, 4], [1, 2], [5, 2], [8, 1]], columns=[1.0, 0]), + DataFrame([[2, 4], [1, 2], [5, 2], [8, 1]], columns=[0.0, 1]), + DataFrame([[2, 4], [1, 2], [5, 2], [8, 1]], columns=["C", 1]), + DataFrame([[2.0, 4.0], [1.0, 2.0], [5.0, 2.0], [8.0, 1.0]], columns=[1, 0.0]), + DataFrame([[2, 4.0], [1, 2.0], [5, 2.0], [8, 1.0]], columns=[0, 1.0]), + DataFrame([[2, 4], [1, 2], [5, 2], [8, 1.0]], columns=[1.0, "X"]), + ] +) +def pairwise_frames(request): + """Pairwise frames test_pairwise""" + return request.param + + +@pytest.fixture +def pairwise_target_frame(): + """Pairwise target frame for test_pairwise""" + return DataFrame([[2, 4], [1, 2], [5, 2], [8, 1]], columns=[0, 1]) + + +@pytest.fixture +def pairwise_other_frame(): + """Pairwise other frame for test_pairwise""" + return DataFrame( + [[None, 1, 1], [None, 1, 2], [None, 3, 2], [None, 8, 1]], + columns=["Y", "Z", "X"], + ) + + +def test_rolling_cov(series): + A = series + B = A + np.random.default_rng(2).standard_normal(len(A)) + + result = A.rolling(window=50, min_periods=25).cov(B) + tm.assert_almost_equal(result.iloc[-1], np.cov(A[-50:], B[-50:])[0, 1]) + + +def test_rolling_corr(series): + A = series + B = A + np.random.default_rng(2).standard_normal(len(A)) + + result = A.rolling(window=50, min_periods=25).corr(B) + tm.assert_almost_equal(result.iloc[-1], np.corrcoef(A[-50:], B[-50:])[0, 1]) + + +def test_rolling_corr_bias_correction(): + # test for correct bias correction + a = Series( + np.arange(20, dtype=np.float64), index=date_range("2020-01-01", periods=20) + ) + b = a.copy() + a[:5] = np.nan + b[:10] = np.nan + + result = a.rolling(window=len(a), min_periods=1).corr(b) + tm.assert_almost_equal(result.iloc[-1], a.corr(b)) + + +@pytest.mark.parametrize("func", ["cov", "corr"]) +def test_rolling_pairwise_cov_corr(func, frame): + result = getattr(frame.rolling(window=10, min_periods=5), func)() + result = result.loc[(slice(None), 1), 5] + result.index = result.index.droplevel(1) + expected = getattr(frame[1].rolling(window=10, min_periods=5), func)(frame[5]) + tm.assert_series_equal(result, expected, check_names=False) + + +@pytest.mark.parametrize("method", ["corr", "cov"]) +def test_flex_binary_frame(method, frame): + series = frame[1] + + res = getattr(series.rolling(window=10), method)(frame) + res2 = getattr(frame.rolling(window=10), method)(series) + exp = frame.apply(lambda x: getattr(series.rolling(window=10), method)(x)) + + tm.assert_frame_equal(res, exp) + tm.assert_frame_equal(res2, exp) + + frame2 = DataFrame( + np.random.default_rng(2).standard_normal(frame.shape), + index=frame.index, + columns=frame.columns, + ) + + res3 = getattr(frame.rolling(window=10), method)(frame2) + res3.columns = Index(list(res3.columns)) + exp = DataFrame( + {k: getattr(frame[k].rolling(window=10), method)(frame2[k]) for k in frame} + ) + tm.assert_frame_equal(res3, exp) + + +@pytest.mark.parametrize("window", range(7)) +def test_rolling_corr_with_zero_variance(window): + # GH 18430 + s = Series(np.zeros(20)) + other = Series(np.arange(20)) + + assert s.rolling(window=window).corr(other=other).isna().all() + + +def test_corr_sanity(): + # GH 3155 + df = DataFrame( + np.array( + [ + [0.87024726, 0.18505595], + [0.64355431, 0.3091617], + [0.92372966, 0.50552513], + [0.00203756, 0.04520709], + [0.84780328, 0.33394331], + [0.78369152, 0.63919667], + ] + ) + ) + + res = df[0].rolling(5, center=True).corr(df[1]) + assert all(np.abs(np.nan_to_num(x)) <= 1 for x in res) + + df = DataFrame(np.random.default_rng(2).random((30, 2))) + res = df[0].rolling(5, center=True).corr(df[1]) + assert all(np.abs(np.nan_to_num(x)) <= 1 for x in res) + + +def test_rolling_cov_diff_length(): + # GH 7512 + s1 = Series([1, 2, 3], index=range(3)) + s2 = Series([1, 3], index=range(0, 4, 2)) + result = s1.rolling(window=3, min_periods=2).cov(s2) + expected = Series([None, None, 2.0]) + tm.assert_series_equal(result, expected) + + s2a = Series([1, None, 3], index=range(3)) + result = s1.rolling(window=3, min_periods=2).cov(s2a) + tm.assert_series_equal(result, expected) + + +def test_rolling_corr_diff_length(): + # GH 7512 + s1 = Series([1, 2, 3], index=range(3)) + s2 = Series([1, 3], index=range(0, 4, 2)) + result = s1.rolling(window=3, min_periods=2).corr(s2) + expected = Series([None, None, 1.0]) + tm.assert_series_equal(result, expected) + + s2a = Series([1, None, 3], index=range(3)) + result = s1.rolling(window=3, min_periods=2).corr(s2a) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["cov", "corr"]) +def test_time_based_rolling_other_longer_raises(func): + # GH#62937 + idx_short = date_range("2020-01-01", periods=3, freq="D") + idx_long = date_range("2020-01-01", periods=5, freq="D") + s = Series([1, 2, 3], index=idx_short) + other = Series([1, 2, 3, 4, 5], index=idx_long) + msg = "Variable rolling window requires .* Got 3 < 5" + with pytest.raises(ValueError, match=msg): + getattr(s.rolling("2D"), func)(other) + + +@pytest.mark.parametrize( + "f", + [ + lambda x: (x.rolling(window=10, min_periods=5).cov(x, pairwise=True)), + lambda x: (x.rolling(window=10, min_periods=5).corr(x, pairwise=True)), + ], +) +def test_rolling_functions_window_non_shrinkage_binary(f): + # corr/cov return a MI DataFrame + df = DataFrame( + [[1, 5], [3, 2], [3, 9], [-1, 0]], + columns=Index(["A", "B"], name="foo"), + index=Index(range(4), name="bar"), + ) + df_expected = DataFrame( + columns=Index(["A", "B"], name="foo"), + index=MultiIndex.from_product([df.index, df.columns], names=["bar", "foo"]), + dtype="float64", + ) + df_result = f(df) + tm.assert_frame_equal(df_result, df_expected) + + +@pytest.mark.parametrize( + "f", + [ + lambda x: (x.rolling(window=10, min_periods=5).cov(x, pairwise=True)), + lambda x: (x.rolling(window=10, min_periods=5).corr(x, pairwise=True)), + ], +) +def test_moment_functions_zero_length_pairwise(f): + df1 = DataFrame() + df2 = DataFrame(columns=Index(["a"], name="foo"), index=Index([], name="bar")) + df2["a"] = df2["a"].astype("float64") + + df1_expected = DataFrame(index=MultiIndex.from_product([df1.index, df1.columns])) + df2_expected = DataFrame( + index=MultiIndex.from_product([df2.index, df2.columns], names=["bar", "foo"]), + columns=Index(["a"], name="foo"), + dtype="float64", + ) + + df1_result = f(df1) + tm.assert_frame_equal(df1_result, df1_expected) + + df2_result = f(df2) + tm.assert_frame_equal(df2_result, df2_expected) + + +class TestPairwise: + # GH 7738 + @pytest.mark.parametrize("f", [lambda x: x.cov(), lambda x: x.corr()]) + def test_no_flex(self, pairwise_frames, pairwise_target_frame, f): + # DataFrame methods (which do not call flex_binary_moment()) + + result = f(pairwise_frames) + tm.assert_index_equal(result.index, pairwise_frames.columns) + tm.assert_index_equal(result.columns, pairwise_frames.columns) + expected = f(pairwise_target_frame) + # since we have sorted the results + # we can only compare non-nans + result = result.dropna().values + expected = expected.dropna().values + + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + @pytest.mark.parametrize( + "f", + [ + lambda x: x.expanding().cov(pairwise=True), + lambda x: x.expanding().corr(pairwise=True), + lambda x: x.rolling(window=3).cov(pairwise=True), + lambda x: x.rolling(window=3).corr(pairwise=True), + lambda x: x.ewm(com=3).cov(pairwise=True), + lambda x: x.ewm(com=3).corr(pairwise=True), + ], + ) + def test_pairwise_with_self(self, pairwise_frames, pairwise_target_frame, f): + # DataFrame with itself, pairwise=True + # note that we may construct the 1st level of the MI + # in a non-monotonic way, so compare accordingly + result = f(pairwise_frames) + tm.assert_index_equal( + result.index.levels[0], pairwise_frames.index, check_names=False + ) + tm.assert_index_equal( + safe_sort(result.index.levels[1]), + safe_sort(pairwise_frames.columns.unique()), + ) + tm.assert_index_equal(result.columns, pairwise_frames.columns) + expected = f(pairwise_target_frame) + # since we have sorted the results + # we can only compare non-nans + result = result.dropna().values + expected = expected.dropna().values + + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + @pytest.mark.parametrize( + "f", + [ + lambda x: x.expanding().cov(pairwise=False), + lambda x: x.expanding().corr(pairwise=False), + lambda x: x.rolling(window=3).cov(pairwise=False), + lambda x: x.rolling(window=3).corr(pairwise=False), + lambda x: x.ewm(com=3).cov(pairwise=False), + lambda x: x.ewm(com=3).corr(pairwise=False), + ], + ) + def test_no_pairwise_with_self(self, pairwise_frames, pairwise_target_frame, f): + # DataFrame with itself, pairwise=False + result = f(pairwise_frames) + tm.assert_index_equal(result.index, pairwise_frames.index) + tm.assert_index_equal(result.columns, pairwise_frames.columns) + expected = f(pairwise_target_frame) + # since we have sorted the results + # we can only compare non-nans + result = result.dropna().values + expected = expected.dropna().values + + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + @pytest.mark.parametrize( + "f", + [ + lambda x, y: x.expanding().cov(y, pairwise=True), + lambda x, y: x.expanding().corr(y, pairwise=True), + lambda x, y: x.rolling(window=3).cov(y, pairwise=True), + # TODO: We're missing a flag somewhere in meson + pytest.param( + lambda x, y: x.rolling(window=3).corr(y, pairwise=True), + marks=pytest.mark.xfail( + not IS64, reason="Precision issues on 32 bit", strict=False + ), + ), + lambda x, y: x.ewm(com=3).cov(y, pairwise=True), + lambda x, y: x.ewm(com=3).corr(y, pairwise=True), + ], + ) + def test_pairwise_with_other( + self, pairwise_frames, pairwise_target_frame, pairwise_other_frame, f + ): + # DataFrame with another DataFrame, pairwise=True + result = f(pairwise_frames, pairwise_other_frame) + tm.assert_index_equal( + result.index.levels[0], pairwise_frames.index, check_names=False + ) + tm.assert_index_equal( + safe_sort(result.index.levels[1]), + safe_sort(pairwise_other_frame.columns.unique()), + ) + expected = f(pairwise_target_frame, pairwise_other_frame) + # since we have sorted the results + # we can only compare non-nans + result = result.dropna().values + expected = expected.dropna().values + + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + @pytest.mark.filterwarnings("ignore:RuntimeWarning") + @pytest.mark.parametrize( + "f", + [ + lambda x, y: x.expanding().cov(y, pairwise=False), + lambda x, y: x.expanding().corr(y, pairwise=False), + lambda x, y: x.rolling(window=3).cov(y, pairwise=False), + lambda x, y: x.rolling(window=3).corr(y, pairwise=False), + lambda x, y: x.ewm(com=3).cov(y, pairwise=False), + lambda x, y: x.ewm(com=3).corr(y, pairwise=False), + ], + ) + def test_no_pairwise_with_other(self, pairwise_frames, pairwise_other_frame, f): + # DataFrame with another DataFrame, pairwise=False + result = ( + f(pairwise_frames, pairwise_other_frame) + if pairwise_frames.columns.is_unique + else None + ) + if result is not None: + # we can have int and str columns + expected_index = pairwise_frames.index.union(pairwise_other_frame.index) + expected_columns = pairwise_frames.columns.union( + pairwise_other_frame.columns + ) + tm.assert_index_equal(result.index, expected_index) + tm.assert_index_equal(result.columns, expected_columns) + else: + with pytest.raises(ValueError, match="'arg1' columns are not unique"): + f(pairwise_frames, pairwise_other_frame) + with pytest.raises(ValueError, match="'arg2' columns are not unique"): + f(pairwise_other_frame, pairwise_frames) + + @pytest.mark.parametrize( + "f", + [ + lambda x, y: x.expanding().cov(y), + lambda x, y: x.expanding().corr(y), + lambda x, y: x.rolling(window=3).cov(y), + lambda x, y: x.rolling(window=3).corr(y), + lambda x, y: x.ewm(com=3).cov(y), + lambda x, y: x.ewm(com=3).corr(y), + ], + ) + def test_pairwise_with_series(self, pairwise_frames, pairwise_target_frame, f): + # DataFrame with a Series + result = f(pairwise_frames, Series([1, 1, 3, 8])) + tm.assert_index_equal(result.index, pairwise_frames.index) + tm.assert_index_equal(result.columns, pairwise_frames.columns) + expected = f(pairwise_target_frame, Series([1, 1, 3, 8])) + # since we have sorted the results + # we can only compare non-nans + result = result.dropna().values + expected = expected.dropna().values + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + result = f(Series([1, 1, 3, 8]), pairwise_frames) + tm.assert_index_equal(result.index, pairwise_frames.index) + tm.assert_index_equal(result.columns, pairwise_frames.columns) + expected = f(Series([1, 1, 3, 8]), pairwise_target_frame) + # since we have sorted the results + # we can only compare non-nans + result = result.dropna().values + expected = expected.dropna().values + tm.assert_numpy_array_equal(result, expected, check_dtype=False) + + def test_corr_freq_memory_error(self): + # GH 31789 + s = Series(range(5), index=date_range("2020", periods=5)) + result = s.rolling("12h").corr(s) + expected = Series([np.nan] * 5, index=date_range("2020", periods=5)) + tm.assert_series_equal(result, expected) + + def test_cov_mulittindex(self): + # GH 34440 + + columns = MultiIndex.from_product([list("ab"), list("xy"), list("AB")]) + index = range(3) + df = DataFrame(np.arange(24).reshape(3, 8), index=index, columns=columns) + + result = df.ewm(alpha=0.1).cov() + + index = MultiIndex.from_product([range(3), list("ab"), list("xy"), list("AB")]) + columns = MultiIndex.from_product([list("ab"), list("xy"), list("AB")]) + expected = DataFrame( + np.vstack( + ( + np.full((8, 8), np.nan), + np.full((8, 8), 32.000000), + np.full((8, 8), 63.881919), + ) + ), + index=index, + columns=columns, + ) + + tm.assert_frame_equal(result, expected) + + def test_multindex_columns_pairwise_func(self): + # GH 21157 + columns = MultiIndex.from_arrays([["M", "N"], ["P", "Q"]], names=["a", "b"]) + df = DataFrame(np.ones((5, 2)), columns=columns) + result = df.rolling(3).corr() + expected = DataFrame( + np.nan, + index=MultiIndex.from_arrays( + [ + np.repeat(np.arange(5, dtype=np.int64), 2), + ["M", "N"] * 5, + ["P", "Q"] * 5, + ], + names=[None, "a", "b"], + ), + columns=columns, + ) + tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/window/test_rolling.py b/pandas/tests/window/test_rolling.py new file mode 100644 index 0000000000000000000000000000000000000000..8a232751a82de73aeb892148e03f2617df13f528 --- /dev/null +++ b/pandas/tests/window/test_rolling.py @@ -0,0 +1,2115 @@ +from datetime import ( + datetime, + timedelta, +) + +import numpy as np +import pytest + +from pandas.compat import ( + IS64, +) +from pandas.errors import Pandas4Warning + +from pandas import ( + DataFrame, + DatetimeIndex, + MultiIndex, + Series, + Timedelta, + Timestamp, + date_range, + period_range, +) +import pandas._testing as tm +from pandas.api.indexers import BaseIndexer +from pandas.core.indexers.objects import VariableOffsetWindowIndexer + +from pandas.tseries.offsets import BusinessDay + + +def test_doc_string(): + df = DataFrame({"B": [0, 1, 2, np.nan, 4]}) + df + df.rolling(2).sum() + df.rolling(2, min_periods=1).sum() + + +def test_constructor(frame_or_series): + # GH 12669 + + c = frame_or_series(range(5)).rolling + + # valid + c(0) + c(window=2) + c(window=2, min_periods=1) + c(window=2, min_periods=1, center=True) + c(window=2, min_periods=1, center=False) + + # GH 13383 + + msg = "window must be an integer 0 or greater" + + with pytest.raises(ValueError, match=msg): + c(-1) + + +@pytest.mark.parametrize("w", [2.0, "foo", np.array([2])]) +def test_invalid_constructor(frame_or_series, w): + # not valid + + c = frame_or_series(range(5)).rolling + + msg = "|".join( + [ + "window must be an integer", + "passed window foo is not compatible with a datetimelike index", + ] + ) + with pytest.raises(ValueError, match=msg): + c(window=w) + + msg = "min_periods must be an integer" + with pytest.raises(ValueError, match=msg): + c(window=2, min_periods=w) + + msg = "center must be a boolean" + with pytest.raises(ValueError, match=msg): + c(window=2, min_periods=1, center=w) + + +@pytest.mark.parametrize( + "window", + [ + timedelta(days=3), + Timedelta(days=3), + "3D", + VariableOffsetWindowIndexer( + index=date_range("2015-12-25", periods=5), offset=BusinessDay(1) + ), + ], +) +def test_freq_window_not_implemented(window): + # GH 15354 + df = DataFrame( + np.arange(10), + index=date_range("2015-12-24", periods=10, freq="D"), + ) + with pytest.raises( + NotImplementedError, match="^step (not implemented|is not supported)" + ): + df.rolling(window, step=3).sum() + + +@pytest.mark.parametrize("agg", ["cov", "corr"]) +def test_step_not_implemented_for_cov_corr(agg): + # GH 15354 + roll = DataFrame(range(2)).rolling(1, step=2) + with pytest.raises(NotImplementedError, match="step not implemented"): + getattr(roll, agg)() + + +@pytest.mark.parametrize("window", [timedelta(days=3), Timedelta(days=3)]) +def test_constructor_with_timedelta_window(window): + # GH 15440 + n = 10 + df = DataFrame( + {"value": np.arange(n)}, + index=date_range("2015-12-24", periods=n, freq="D"), + ) + expected_data = np.append([0.0, 1.0], np.arange(3.0, 27.0, 3)) + + result = df.rolling(window=window).sum() + expected = DataFrame( + {"value": expected_data}, + index=date_range("2015-12-24", periods=n, freq="D"), + ) + tm.assert_frame_equal(result, expected) + expected = df.rolling("3D").sum() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("window", [timedelta(days=3), Timedelta(days=3), "3D"]) +def test_constructor_timedelta_window_and_minperiods(window, raw): + # GH 15305 + n = 10 + df = DataFrame( + {"value": np.arange(n)}, + index=date_range("2017-08-08", periods=n, freq="D"), + ) + expected = DataFrame( + {"value": np.append([np.nan, 1.0], np.arange(3.0, 27.0, 3))}, + index=date_range("2017-08-08", periods=n, freq="D"), + ) + result_roll_sum = df.rolling(window=window, min_periods=2).sum() + result_roll_generic = df.rolling(window=window, min_periods=2).apply(sum, raw=raw) + tm.assert_frame_equal(result_roll_sum, expected) + tm.assert_frame_equal(result_roll_generic, expected) + + +def test_closed_fixed(closed, arithmetic_win_operators): + # GH 34315 + func_name = arithmetic_win_operators + df_fixed = DataFrame({"A": [0, 1, 2, 3, 4]}) + df_time = DataFrame({"A": [0, 1, 2, 3, 4]}, index=date_range("2020", periods=5)) + + result = getattr( + df_fixed.rolling(2, closed=closed, min_periods=1), + func_name, + )() + expected = getattr( + df_time.rolling("2D", closed=closed, min_periods=1), + func_name, + )().reset_index(drop=True) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "closed, window_selections", + [ + ( + "both", + [ + [True, True, False, False, False], + [True, True, True, False, False], + [False, True, True, True, False], + [False, False, True, True, True], + [False, False, False, True, True], + ], + ), + ( + "left", + [ + [True, False, False, False, False], + [True, True, False, False, False], + [False, True, True, False, False], + [False, False, True, True, False], + [False, False, False, True, True], + ], + ), + ( + "right", + [ + [True, True, False, False, False], + [False, True, True, False, False], + [False, False, True, True, False], + [False, False, False, True, True], + [False, False, False, False, True], + ], + ), + ( + "neither", + [ + [True, False, False, False, False], + [False, True, False, False, False], + [False, False, True, False, False], + [False, False, False, True, False], + [False, False, False, False, True], + ], + ), + ], +) +def test_datetimelike_centered_selections( + closed, window_selections, arithmetic_win_operators +): + # GH 34315 + func_name = arithmetic_win_operators + df_time = DataFrame( + {"A": [0.0, 1.0, 2.0, 3.0, 4.0]}, index=date_range("2020", periods=5) + ) + + expected = DataFrame( + {"A": [getattr(df_time["A"].iloc[s], func_name)() for s in window_selections]}, + index=date_range("2020", periods=5), + ) + + result = getattr( + df_time.rolling("2D", closed=closed, min_periods=1, center=True), + func_name, + )() + + tm.assert_frame_equal(result, expected, check_dtype=False) + + +@pytest.mark.parametrize( + "window,closed,expected", + [ + ("3s", "right", [3.0, 3.0, 3.0]), + ("3s", "both", [3.0, 3.0, 3.0]), + ("3s", "left", [3.0, 3.0, 3.0]), + ("3s", "neither", [3.0, 3.0, 3.0]), + ("2s", "right", [3.0, 2.0, 2.0]), + ("2s", "both", [3.0, 3.0, 3.0]), + ("2s", "left", [1.0, 3.0, 3.0]), + ("2s", "neither", [1.0, 2.0, 2.0]), + ], +) +def test_datetimelike_centered_offset_covers_all( + window, closed, expected, frame_or_series +): + # GH 42753 + + index = [ + Timestamp("20130101 09:00:01"), + Timestamp("20130101 09:00:02"), + Timestamp("20130101 09:00:02"), + ] + df = frame_or_series([1, 1, 1], index=index) + + result = df.rolling(window, closed=closed, center=True).sum() + expected = frame_or_series(expected, index=index) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "window,closed,expected", + [ + ("2D", "right", [4, 4, 4, 4, 4, 4, 2, 2]), + ("2D", "left", [2, 2, 4, 4, 4, 4, 4, 4]), + ("2D", "both", [4, 4, 6, 6, 6, 6, 4, 4]), + ("2D", "neither", [2, 2, 2, 2, 2, 2, 2, 2]), + ], +) +def test_datetimelike_nonunique_index_centering( + window, closed, expected, frame_or_series +): + index = DatetimeIndex( + [ + "2020-01-01", + "2020-01-01", + "2020-01-02", + "2020-01-02", + "2020-01-03", + "2020-01-03", + "2020-01-04", + "2020-01-04", + ] + ) + + df = frame_or_series([1] * 8, index=index, dtype=float) + expected = frame_or_series(expected, index=index, dtype=float) + + result = df.rolling(window, center=True, closed=closed).sum() + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "closed,expected", + [ + ("left", [np.nan, np.nan, 1, 1, 1, 10, 14, 14, 18, 21]), + ("neither", [np.nan, np.nan, 1, 1, 1, 9, 5, 5, 13, 8]), + ("right", [0, 1, 3, 6, 10, 14, 11, 18, 21, 17]), + ("both", [0, 1, 3, 6, 10, 15, 20, 27, 26, 30]), + ], +) +def test_variable_window_nonunique(closed, expected, frame_or_series): + # GH 20712 + index = DatetimeIndex( + [ + "2011-01-01", + "2011-01-01", + "2011-01-02", + "2011-01-02", + "2011-01-02", + "2011-01-03", + "2011-01-04", + "2011-01-04", + "2011-01-05", + "2011-01-06", + ] + ) + + df = frame_or_series(range(10), index=index, dtype=float) + expected = frame_or_series(expected, index=index, dtype=float) + + result = df.rolling("2D", closed=closed).sum() + + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "closed,expected", + [ + ("left", [np.nan, np.nan, 1, 1, 1, 10, 15, 15, 18, 21]), + ("neither", [np.nan, np.nan, 1, 1, 1, 10, 15, 15, 13, 8]), + ("right", [0, 1, 3, 6, 10, 15, 21, 28, 21, 17]), + ("both", [0, 1, 3, 6, 10, 15, 21, 28, 26, 30]), + ], +) +def test_variable_offset_window_nonunique(closed, expected, frame_or_series): + # GH 20712 + index = DatetimeIndex( + [ + "2011-01-01", + "2011-01-01", + "2011-01-02", + "2011-01-02", + "2011-01-02", + "2011-01-03", + "2011-01-04", + "2011-01-04", + "2011-01-05", + "2011-01-06", + ] + ) + + df = frame_or_series(range(10), index=index, dtype=float) + expected = frame_or_series(expected, index=index, dtype=float) + + offset = BusinessDay(2) + indexer = VariableOffsetWindowIndexer(index=index, offset=offset) + result = df.rolling(indexer, closed=closed, min_periods=1).sum() + + tm.assert_equal(result, expected) + + +def test_even_number_window_alignment(): + # see discussion in GH 38780 + s = Series(range(3), index=date_range(start="2020-01-01", freq="D", periods=3)) + + # behavior of index- and datetime-based windows differs here! + # s.rolling(window=2, min_periods=1, center=True).mean() + + result = s.rolling(window="2D", min_periods=1, center=True).mean() + + expected = Series([0.5, 1.5, 2], index=s.index) + + tm.assert_series_equal(result, expected) + + +def test_closed_fixed_binary_col(center, step): + # GH 34315 + data = [0, 1, 1, 0, 0, 1, 0, 1] + df = DataFrame( + {"binary_col": data}, + index=date_range(start="2020-01-01", freq="min", periods=len(data)), + ) + + if center: + expected_data = [2 / 3, 0.5, 0.4, 0.5, 0.428571, 0.5, 0.571429, 0.5] + else: + expected_data = [np.nan, 0, 0.5, 2 / 3, 0.5, 0.4, 0.5, 0.428571] + + expected = DataFrame( + expected_data, + columns=["binary_col"], + index=date_range(start="2020-01-01", freq="min", periods=len(expected_data)), + )[::step] + + rolling = df.rolling( + window=len(df), closed="left", min_periods=1, center=center, step=step + ) + result = rolling.mean() + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("closed", ["neither", "left"]) +def test_closed_empty(closed, arithmetic_win_operators): + # GH 26005 + func_name = arithmetic_win_operators + ser = Series(data=np.arange(5), index=date_range("2000", periods=5, freq="2D")) + roll = ser.rolling("1D", closed=closed) + + result = getattr(roll, func_name)() + expected = Series([np.nan] * 5, index=ser.index) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_closed_one_entry(func): + # GH24718 + ser = Series(data=[2], index=date_range("2000", periods=1)) + result = getattr(ser.rolling("10D", closed="left"), func)() + tm.assert_series_equal(result, Series([np.nan], index=ser.index)) + + +@pytest.mark.parametrize("func", ["min", "max"]) +def test_closed_one_entry_groupby(func): + # GH24718 + ser = DataFrame( + data={"A": [1, 1, 2], "B": [3, 2, 1]}, + index=date_range("2000", periods=3), + ) + result = getattr( + ser.groupby("A", sort=False)["B"].rolling("10D", closed="left"), func + )() + exp_idx = MultiIndex.from_arrays(arrays=[[1, 1, 2], ser.index], names=("A", None)) + expected = Series(data=[np.nan, 3, np.nan], index=exp_idx, name="B") + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("input_dtype", ["int", "float"]) +@pytest.mark.parametrize( + "func,closed,expected", + [ + ("min", "right", [0.0, 0, 0, 1, 2, 3, 4, 5, 6, 7]), + ("min", "both", [0.0, 0, 0, 0, 1, 2, 3, 4, 5, 6]), + ("min", "neither", [np.nan, 0, 0, 1, 2, 3, 4, 5, 6, 7]), + ("min", "left", [np.nan, 0, 0, 0, 1, 2, 3, 4, 5, 6]), + ("max", "right", [0.0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + ("max", "both", [0.0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), + ("max", "neither", [np.nan, 0, 1, 2, 3, 4, 5, 6, 7, 8]), + ("max", "left", [np.nan, 0, 1, 2, 3, 4, 5, 6, 7, 8]), + ], +) +def test_closed_min_max_datetime(input_dtype, func, closed, expected): + # see gh-21704 + ser = Series( + data=np.arange(10).astype(input_dtype), + index=date_range("2000", periods=10), + ) + + result = getattr(ser.rolling("3D", closed=closed), func)() + expected = Series(expected, index=ser.index) + tm.assert_series_equal(result, expected) + + +def test_closed_uneven(): + # see gh-21704 + ser = Series(data=np.arange(10), index=date_range("2000", periods=10)) + + # uneven + ser = ser.drop(index=ser.index[[1, 5]]) + result = ser.rolling("3D", closed="left").min() + expected = Series([np.nan, 0, 0, 2, 3, 4, 6, 6], index=ser.index) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "func,closed,expected", + [ + ("min", "right", [np.nan, 0, 0, 1, 2, 3, 4, 5, np.nan, np.nan]), + ("min", "both", [np.nan, 0, 0, 0, 1, 2, 3, 4, 5, np.nan]), + ("min", "neither", [np.nan, np.nan, 0, 1, 2, 3, 4, 5, np.nan, np.nan]), + ("min", "left", [np.nan, np.nan, 0, 0, 1, 2, 3, 4, 5, np.nan]), + ("max", "right", [np.nan, 1, 2, 3, 4, 5, 6, 6, np.nan, np.nan]), + ("max", "both", [np.nan, 1, 2, 3, 4, 5, 6, 6, 6, np.nan]), + ("max", "neither", [np.nan, np.nan, 1, 2, 3, 4, 5, 6, np.nan, np.nan]), + ("max", "left", [np.nan, np.nan, 1, 2, 3, 4, 5, 6, 6, np.nan]), + ], +) +def test_closed_min_max_minp(func, closed, expected): + # see gh-21704 + ser = Series(data=np.arange(10), index=date_range("2000", periods=10)) + # Explicit cast to float to avoid implicit cast when setting nan + ser = ser.astype("float") + ser[ser.index[-3:]] = np.nan + result = getattr(ser.rolling("3D", min_periods=2, closed=closed), func)() + expected = Series(expected, index=ser.index) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "closed,expected", + [ + ("right", [0, 0.5, 1, 2, 3, 4, 5, 6, 7, 8]), + ("both", [0, 0.5, 1, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5]), + ("neither", [np.nan, 0, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5]), + ("left", [np.nan, 0, 0.5, 1, 2, 3, 4, 5, 6, 7]), + ], +) +def test_closed_median_quantile(closed, expected): + # GH 26005 + ser = Series(data=np.arange(10), index=date_range("2000", periods=10)) + roll = ser.rolling("3D", closed=closed) + expected = Series(expected, index=ser.index) + + result = roll.median() + tm.assert_series_equal(result, expected) + + result = roll.quantile(0.5) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("roller", ["1s", 1]) +def tests_empty_df_rolling(roller): + # GH 15819 Verifies that datetime and integer rolling windows can be + # applied to empty DataFrames + expected = DataFrame() + result = DataFrame().rolling(roller).sum() + tm.assert_frame_equal(result, expected) + + # Verifies that datetime and integer rolling windows can be applied to + # empty DataFrames with datetime index + expected = DataFrame(index=DatetimeIndex([])) + result = DataFrame(index=DatetimeIndex([])).rolling(roller).sum() + tm.assert_frame_equal(result, expected) + + +def test_empty_window_median_quantile(): + # GH 26005 + expected = Series([np.nan, np.nan, np.nan]) + roll = Series(np.arange(3)).rolling(0) + + result = roll.median() + tm.assert_series_equal(result, expected) + + result = roll.quantile(0.1) + tm.assert_series_equal(result, expected) + + +def test_missing_minp_zero(): + # https://github.com/pandas-dev/pandas/pull/18921 + # minp=0 + x = Series([np.nan]) + result = x.rolling(1, min_periods=0).sum() + expected = Series([0.0]) + tm.assert_series_equal(result, expected) + + # minp=1 + result = x.rolling(1, min_periods=1).sum() + expected = Series([np.nan]) + tm.assert_series_equal(result, expected) + + +def test_missing_minp_zero_variable(): + # https://github.com/pandas-dev/pandas/pull/18921 + x = Series( + [np.nan] * 4, + index=DatetimeIndex(["2017-01-01", "2017-01-04", "2017-01-06", "2017-01-07"]), + ) + result = x.rolling(Timedelta("2D"), min_periods=0).sum() + expected = Series(0.0, index=x.index) + tm.assert_series_equal(result, expected) + + +def test_multi_index_names(): + # GH 16789, 16825 + cols = MultiIndex.from_product([["A", "B"], ["C", "D", "E"]], names=["1", "2"]) + df = DataFrame(np.ones((10, 6)), columns=cols) + result = df.rolling(3).cov() + + tm.assert_index_equal(result.columns, df.columns) + assert result.index.names == [None, "1", "2"] + + +def test_rolling_axis_sum(): + # see gh-23372. + df = DataFrame(np.ones((10, 20))) + expected = DataFrame({i: [np.nan] * 2 + [3.0] * 8 for i in range(20)}) + result = df.rolling(3).sum() + tm.assert_frame_equal(result, expected) + + +def test_rolling_axis_count(): + # see gh-26055 + df = DataFrame({"x": range(3), "y": range(3)}) + + expected = DataFrame({"x": [1.0, 2.0, 2.0], "y": [1.0, 2.0, 2.0]}) + result = df.rolling(2, min_periods=0).count() + tm.assert_frame_equal(result, expected) + + +def test_readonly_array(): + # GH-27766 + arr = np.array([1, 3, np.nan, 3, 5]) + arr.setflags(write=False) + result = Series(arr).rolling(2).mean() + expected = Series([np.nan, 2, np.nan, np.nan, 4]) + tm.assert_series_equal(result, expected) + + +def test_rolling_datetime(tz_naive_fixture): + # GH-28192 + tz = tz_naive_fixture + df = DataFrame( + {i: [1] * 2 for i in date_range("2019-8-01", "2019-08-03", freq="D", tz=tz)} + ) + + result = df.T.rolling("2D").sum().T + expected = DataFrame( + { + **{ + i: [1.0] * 2 + for i in date_range("2019-8-01", periods=1, freq="D", tz=tz) + }, + **{ + i: [2.0] * 2 + for i in date_range("2019-8-02", "2019-8-03", freq="D", tz=tz) + }, + } + ) + tm.assert_frame_equal(result, expected) + + +def test_rolling_window_as_string(center): + # see gh-22590 + date_today = datetime.now() + days = date_range(date_today, date_today + timedelta(365), freq="D") + + data = np.ones(len(days)) + df = DataFrame({"DateCol": days, "metric": data}) + + df.set_index("DateCol", inplace=True) + result = df.rolling(window="21D", min_periods=2, closed="left", center=center)[ + "metric" + ].agg("max") + + index = days.rename("DateCol") + index = index._with_freq(None) + expected_data = np.ones(len(days), dtype=np.float64) + if not center: + expected_data[:2] = np.nan + expected = Series(expected_data, index=index, name="metric") + tm.assert_series_equal(result, expected) + + +def test_min_periods1(): + # GH#6795 + df = DataFrame([0, 1, 2, 1, 0], columns=["a"]) + result = df["a"].rolling(3, center=True, min_periods=1).max() + expected = Series([1.0, 2.0, 2.0, 2.0, 1.0], name="a") + tm.assert_series_equal(result, expected) + + +def test_rolling_count_with_min_periods(frame_or_series): + # GH 26996 + result = frame_or_series(range(5)).rolling(3, min_periods=3).count() + expected = frame_or_series([np.nan, np.nan, 3.0, 3.0, 3.0]) + tm.assert_equal(result, expected) + + +def test_rolling_count_default_min_periods_with_null_values(frame_or_series): + # GH 26996 + values = [1, 2, 3, np.nan, 4, 5, 6] + expected_counts = [1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 3.0] + + # GH 31302 + result = frame_or_series(values).rolling(3, min_periods=0).count() + expected = frame_or_series(expected_counts) + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "df,expected,window,min_periods", + [ + ( + {"A": [1, 2, 3], "B": [4, 5, 6]}, + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [1, 2], "B": [4, 5]}, [0, 1]), + ({"A": [1, 2, 3], "B": [4, 5, 6]}, [0, 1, 2]), + ], + 3, + None, + ), + ( + {"A": [1, 2, 3], "B": [4, 5, 6]}, + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [1, 2], "B": [4, 5]}, [0, 1]), + ({"A": [2, 3], "B": [5, 6]}, [1, 2]), + ], + 2, + 1, + ), + ( + {"A": [1, 2, 3], "B": [4, 5, 6]}, + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [1, 2], "B": [4, 5]}, [0, 1]), + ({"A": [2, 3], "B": [5, 6]}, [1, 2]), + ], + 2, + 2, + ), + ( + {"A": [1, 2, 3], "B": [4, 5, 6]}, + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [2], "B": [5]}, [1]), + ({"A": [3], "B": [6]}, [2]), + ], + 1, + 1, + ), + ( + {"A": [1, 2, 3], "B": [4, 5, 6]}, + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [2], "B": [5]}, [1]), + ({"A": [3], "B": [6]}, [2]), + ], + 1, + 0, + ), + ({"A": [1], "B": [4]}, [], 2, None), + ({"A": [1], "B": [4]}, [], 2, 1), + (None, [({}, [])], 2, None), + ( + {"A": [1, np.nan, 3], "B": [np.nan, 5, 6]}, + [ + ({"A": [1.0], "B": [np.nan]}, [0]), + ({"A": [1, np.nan], "B": [np.nan, 5]}, [0, 1]), + ({"A": [1, np.nan, 3], "B": [np.nan, 5, 6]}, [0, 1, 2]), + ], + 3, + 2, + ), + ], +) +def test_iter_rolling_dataframe(df, expected, window, min_periods): + # GH 11704 + df = DataFrame(df) + expecteds = [DataFrame(values, index=index) for (values, index) in expected] + + for expected, actual in zip( + expecteds, df.rolling(window, min_periods=min_periods), strict=False + ): + tm.assert_frame_equal(actual, expected) + + +@pytest.mark.parametrize( + "expected,window", + [ + ( + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [1, 2], "B": [4, 5]}, [0, 1]), + ({"A": [2, 3], "B": [5, 6]}, [1, 2]), + ], + "2D", + ), + ( + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [1, 2], "B": [4, 5]}, [0, 1]), + ({"A": [1, 2, 3], "B": [4, 5, 6]}, [0, 1, 2]), + ], + "3D", + ), + ( + [ + ({"A": [1], "B": [4]}, [0]), + ({"A": [2], "B": [5]}, [1]), + ({"A": [3], "B": [6]}, [2]), + ], + "1D", + ), + ], +) +def test_iter_rolling_on_dataframe(expected, window): + # GH 11704, 40373 + df = DataFrame( + { + "A": [1, 2, 3, 4, 5], + "B": [4, 5, 6, 7, 8], + "C": date_range(start="2016-01-01", periods=5, freq="D"), + } + ) + + expecteds = [ + DataFrame(values, index=df.loc[index, "C"]) for (values, index) in expected + ] + for expected, actual in zip(expecteds, df.rolling(window, on="C"), strict=False): + tm.assert_frame_equal(actual, expected) + + +def test_iter_rolling_on_dataframe_unordered(): + # GH 43386 + df = DataFrame({"a": ["x", "y", "x"], "b": [0, 1, 2]}) + results = list(df.groupby("a").rolling(2)) + expecteds = [df.iloc[idx, [1]] for idx in [[0], [0, 2], [1]]] + for result, expected in zip(results, expecteds, strict=True): + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "ser,expected,window, min_periods", + [ + ( + Series([1, 2, 3]), + [([1], [0]), ([1, 2], [0, 1]), ([1, 2, 3], [0, 1, 2])], + 3, + None, + ), + ( + Series([1, 2, 3]), + [([1], [0]), ([1, 2], [0, 1]), ([1, 2, 3], [0, 1, 2])], + 3, + 1, + ), + ( + Series([1, 2, 3]), + [([1], [0]), ([1, 2], [0, 1]), ([2, 3], [1, 2])], + 2, + 1, + ), + ( + Series([1, 2, 3]), + [([1], [0]), ([1, 2], [0, 1]), ([2, 3], [1, 2])], + 2, + 2, + ), + (Series([1, 2, 3]), [([1], [0]), ([2], [1]), ([3], [2])], 1, 0), + (Series([1, 2, 3]), [([1], [0]), ([2], [1]), ([3], [2])], 1, 1), + (Series([1, 2]), [([1], [0]), ([1, 2], [0, 1])], 2, 0), + (Series([], dtype="int64"), [], 2, 1), + ], +) +def test_iter_rolling_series(ser, expected, window, min_periods): + # GH 11704 + expecteds = [Series(values, index=index) for (values, index) in expected] + + for expected, actual in zip( + expecteds, ser.rolling(window, min_periods=min_periods), strict=True + ): + tm.assert_series_equal(actual, expected) + + +@pytest.mark.parametrize( + "expected,expected_index,window", + [ + ( + [[0], [1], [2], [3], [4]], + [ + date_range("2020-01-01", periods=1, freq="D"), + date_range("2020-01-02", periods=1, freq="D"), + date_range("2020-01-03", periods=1, freq="D"), + date_range("2020-01-04", periods=1, freq="D"), + date_range("2020-01-05", periods=1, freq="D"), + ], + "1D", + ), + ( + [[0], [0, 1], [1, 2], [2, 3], [3, 4]], + [ + date_range("2020-01-01", periods=1, freq="D"), + date_range("2020-01-01", periods=2, freq="D"), + date_range("2020-01-02", periods=2, freq="D"), + date_range("2020-01-03", periods=2, freq="D"), + date_range("2020-01-04", periods=2, freq="D"), + ], + "2D", + ), + ( + [[0], [0, 1], [0, 1, 2], [1, 2, 3], [2, 3, 4]], + [ + date_range("2020-01-01", periods=1, freq="D"), + date_range("2020-01-01", periods=2, freq="D"), + date_range("2020-01-01", periods=3, freq="D"), + date_range("2020-01-02", periods=3, freq="D"), + date_range("2020-01-03", periods=3, freq="D"), + ], + "3D", + ), + ], +) +def test_iter_rolling_datetime(expected, expected_index, window): + # GH 11704 + ser = Series(range(5), index=date_range(start="2020-01-01", periods=5, freq="D")) + + expecteds = [ + Series(values, index=idx) + for (values, idx) in zip(expected, expected_index, strict=True) + ] + + for expected, actual in zip(expecteds, ser.rolling(window), strict=True): + tm.assert_series_equal(actual, expected) + + +@pytest.mark.parametrize( + "grouping,_index", + [ + ( + {"level": 0}, + MultiIndex.from_tuples( + [(0, 0), (0, 0), (1, 1), (1, 1), (1, 1)], names=[None, None] + ), + ), + ( + {"by": "X"}, + MultiIndex.from_tuples( + [(0, 0), (1, 0), (2, 1), (3, 1), (4, 1)], names=["X", None] + ), + ), + ], +) +def test_rolling_positional_argument(grouping, _index, raw): + # GH 34605 + + def scaled_sum(*args): + if len(args) < 2: + raise ValueError("The function needs two arguments") + array, scale = args + return array.sum() / scale + + df = DataFrame(data={"X": range(5)}, index=[0, 0, 1, 1, 1]) + + expected = DataFrame(data={"X": [0.0, 0.5, 1.0, 1.5, 2.0]}, index=_index) + # GH 40341 + if "by" in grouping: + expected = expected.drop(columns="X", errors="ignore") + result = df.groupby(**grouping).rolling(1).apply(scaled_sum, raw=raw, args=(2,)) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("add", [0.0, 2.0]) +def test_rolling_numerical_accuracy_kahan_mean(add, unit): + # GH: 36031 implementing kahan summation + dti = DatetimeIndex( + [ + Timestamp("19700101 09:00:00"), + Timestamp("19700101 09:00:03"), + Timestamp("19700101 09:00:06"), + ] + ).as_unit(unit) + df = DataFrame( + {"A": [3002399751580331.0 + add, -0.0, -0.0]}, + index=dti, + ) + result = ( + df.resample("1s").ffill().rolling("3s", closed="left", min_periods=3).mean() + ) + dates = date_range("19700101 09:00:00", periods=7, freq="s", unit=unit) + expected = DataFrame( + { + "A": [ + np.nan, + np.nan, + np.nan, + 3002399751580330.5, + 2001599834386887.25, + 1000799917193443.625, + 0.0, + ] + }, + index=dates, + ) + tm.assert_frame_equal(result, expected) + + +def test_rolling_numerical_accuracy_kahan_sum(): + # GH: 13254 + df = DataFrame([2.186, -1.647, 0.0, 0.0, 0.0, 0.0], columns=["x"]) + result = df["x"].rolling(3).sum() + expected = Series([np.nan, np.nan, 0.539, -1.647, 0.0, 0.0], name="x") + tm.assert_series_equal(result, expected) + + +def test_rolling_numerical_accuracy_jump(): + # GH: 32761 + index = date_range(start="2020-01-01", end="2020-01-02", freq="60s").append( + DatetimeIndex(["2020-01-03"]) + ) + data = np.random.default_rng(2).random(len(index)) + + df = DataFrame({"data": data}, index=index) + result = df.rolling("60s").mean() + tm.assert_frame_equal(result, df[["data"]]) + + +def test_rolling_numerical_accuracy_small_values(): + # GH: 10319 + s = Series( + data=[0.00012456, 0.0003, -0.0, -0.0], + index=date_range("1999-02-03", "1999-02-06"), + ) + result = s.rolling(1).mean() + tm.assert_series_equal(result, s) + + +def test_rolling_numerical_too_large_numbers(): + # GH: 11645 + dates = date_range("2015-01-01", periods=10, freq="D") + ds = Series(data=range(10), index=dates, dtype=np.float64) + ds.iloc[2] = -9e33 + result = ds.rolling(5).mean() + expected = Series( + [ + np.nan, + np.nan, + np.nan, + np.nan, + -1.8e33, + -1.8e33, + -1.8e33, + 5.0, + 6.0, + 7.0, + ], + index=dates, + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("index", "window"), + [ + ( + period_range(start="2020-01-01 08:00", end="2020-01-01 08:08", freq="min"), + "2min", + ), + ( + period_range( + start="2020-01-01 08:00", end="2020-01-01 12:00", freq="30min" + ), + "1h", + ), + ], +) +@pytest.mark.parametrize( + ("func", "values"), + [ + ("min", [np.nan, 0, 0, 1, 2, 3, 4, 5, 6]), + ("max", [np.nan, 0, 1, 2, 3, 4, 5, 6, 7]), + ("sum", [np.nan, 0, 1, 3, 5, 7, 9, 11, 13]), + ], +) +def test_rolling_period_index(index, window, func, values): + # GH: 34225 + ds = Series([0, 1, 2, 3, 4, 5, 6, 7, 8], index=index) + result = getattr(ds.rolling(window, closed="left"), func)() + expected = Series(values, index=index) + tm.assert_series_equal(result, expected) + + +def test_rolling_sem(frame_or_series): + # GH: 26476 + obj = frame_or_series([0, 1, 2]) + result = obj.rolling(2, min_periods=1).sem() + if isinstance(result, DataFrame): + result = Series(result[0].values) + expected = Series([np.nan] + [0.5] * 2) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("func", "values", "window", "ddof", "expected_values"), + [ + ("var", [99999999999999999, 1, 1, 2, 3, 1, 1], 2, 1, [5e33, 0, 0.5, 0.5, 2, 0]), + ( + "std", + [99999999999999999, 1, 1, 2, 3, 1, 1], + 2, + 1, + [7.071068e16, 0, 0.7071068, 0.7071068, 1.414214, 0], + ), + ("var", [99999999999999999, 1, 2, 2, 3, 1, 1], 2, 1, [5e33, 0.5, 0, 0.5, 2, 0]), + ( + "std", + [99999999999999999, 1, 2, 2, 3, 1, 1], + 2, + 1, + [7.071068e16, 0.7071068, 0, 0.7071068, 1.414214, 0], + ), + ( + "std", + [1.2e03, 1.3e17, 1.5e17, 1.995e03, 1.990e03], + 2, + 1, + [9.192388e16, 1.414214e16, 1.060660e17, 3.535534e00], + ), + ( + "var", + [ + 0.00000000e00, + 0.00000000e00, + 3.16188252e-18, + 2.95781651e-16, + 2.23153542e-51, + 0.00000000e00, + 0.00000000e00, + 5.39943432e-48, + 1.38206260e-73, + 0.00000000e00, + ], + 3, + 1, + [ + 3.33250036e-036, + 2.88538519e-032, + 2.88538519e-032, + 2.91622617e-032, + 1.65991678e-102, + 9.71796366e-096, + 9.71796366e-096, + 9.71796366e-096, + ], + ), + ( + "std", + [1, -1, 0, 1, 3, 2, -2, 10000000000, 1, 2, 0, -2, 1, 3, 0, 1], + 6, + 1, + [ + 1.41421356e00, + 1.87082869e00, + 4.08248290e09, + 4.08248290e09, + 4.08248290e09, + 4.08248290e09, + 4.08248290e09, + 4.08248290e09, + 1.72240142e00, + 1.75119007e00, + 1.64316767e00, + ], + ), + ], +) +def test_rolling_var_correctness(func, values, window, ddof, expected_values): + # GH: 37051, 42064, 54518, 52407, 47721 + ts = Series(values) + result = getattr(ts.rolling(window=window), func)(ddof=ddof) + if result.last_valid_index(): + result = result[ + result.first_valid_index() : result.last_valid_index() + 1 + ].reset_index(drop=True) + expected = Series(expected_values) + tm.assert_series_equal(result, expected, atol=1e-55) + # GH 42064 + tm.assert_series_equal(result == 0, expected == 0) + + +def test_timeoffset_as_window_parameter_for_corr(unit): + # GH: 28266 + dti = DatetimeIndex( + [ + Timestamp("20130101 09:00:00"), + Timestamp("20130102 09:00:02"), + Timestamp("20130103 09:00:03"), + Timestamp("20130105 09:00:05"), + Timestamp("20130106 09:00:06"), + ] + ).as_unit(unit) + mi = MultiIndex.from_product([dti, ["B", "A"]]) + + exp = DataFrame( + { + "B": [ + np.nan, + np.nan, + 0.9999999999999998, + -1.0, + 1.0, + -0.3273268353539892, + 0.9999999999999998, + 1.0, + 0.9999999999999998, + 1.0, + ], + "A": [ + np.nan, + np.nan, + -1.0, + 1.0000000000000002, + -0.3273268353539892, + 0.9999999999999966, + 1.0, + 1.0000000000000002, + 1.0, + 1.0000000000000002, + ], + }, + index=mi, + ) + + df = DataFrame( + {"B": [0, 1, 2, 4, 3], "A": [7, 4, 6, 9, 3]}, + index=dti, + ) + + res = df.rolling(window="3D").corr() + + tm.assert_frame_equal(exp, res) + + +@pytest.mark.parametrize("method", ["var", "sum", "mean", "skew", "kurt", "min", "max"]) +def test_rolling_decreasing_indices(method): + """ + Make sure that decreasing indices give the same results as increasing indices. + + GH 36933 + """ + df = DataFrame({"values": np.arange(-15, 10) ** 2}) + df_reverse = DataFrame({"values": df["values"][::-1]}, index=df.index[::-1]) + + increasing = getattr(df.rolling(window=5), method)() + decreasing = getattr(df_reverse.rolling(window=5), method)() + + tm.assert_almost_equal( + decreasing.values[::-1][:-4], increasing.values[4:], atol=1e-12 + ) + + +@pytest.mark.parametrize( + "window,closed,expected", + [ + ("2s", "right", [1.0, 3.0, 5.0, 3.0]), + ("2s", "left", [0.0, 1.0, 3.0, 5.0]), + ("2s", "both", [1.0, 3.0, 6.0, 5.0]), + ("2s", "neither", [0.0, 1.0, 2.0, 3.0]), + ("3s", "right", [1.0, 3.0, 6.0, 5.0]), + ("3s", "left", [1.0, 3.0, 6.0, 5.0]), + ("3s", "both", [1.0, 3.0, 6.0, 5.0]), + ("3s", "neither", [1.0, 3.0, 6.0, 5.0]), + ], +) +def test_rolling_decreasing_indices_centered(window, closed, expected, frame_or_series): + """ + Ensure that a symmetrical inverted index return same result as non-inverted. + """ + # GH 43927 + + index = date_range("2020", periods=4, freq="1s") + df_inc = frame_or_series(range(4), index=index) + df_dec = frame_or_series(range(4), index=index[::-1]) + + expected_inc = frame_or_series(expected, index=index) + expected_dec = frame_or_series(expected, index=index[::-1]) + + result_inc = df_inc.rolling(window, closed=closed, center=True).sum() + result_dec = df_dec.rolling(window, closed=closed, center=True).sum() + + tm.assert_equal(result_inc, expected_inc) + tm.assert_equal(result_dec, expected_dec) + + +@pytest.mark.parametrize( + "window,expected", + [ + ("1ns", [1.0, 1.0, 1.0, 1.0]), + ("3ns", [2.0, 3.0, 3.0, 2.0]), + ], +) +def test_rolling_center_nanosecond_resolution( + window, closed, expected, frame_or_series +): + index = date_range("2020", periods=4, freq="1ns") + df = frame_or_series([1, 1, 1, 1], index=index, dtype=float) + expected = frame_or_series(expected, index=index, dtype=float) + result = df.rolling(window, closed=closed, center=True).sum() + tm.assert_equal(result, expected) + + +@pytest.mark.parametrize( + "method,expected", + [ + ( + "var", + [ + float("nan"), + 43.0, + float("nan"), + 136.333333, + 43.5, + 94.966667, + 182.0, + 318.0, + ], + ), + ( + "mean", + [float("nan"), 7.5, float("nan"), 21.5, 6.0, 9.166667, 13.0, 17.5], + ), + ( + "sum", + [float("nan"), 30.0, float("nan"), 86.0, 30.0, 55.0, 91.0, 140.0], + ), + ( + "skew", + [ + float("nan"), + 0.709296, + float("nan"), + 0.407073, + 0.984656, + 0.919184, + 0.874674, + 0.842418, + ], + ), + ( + "kurt", + [ + float("nan"), + -0.5916711736073559, + float("nan"), + -1.0028993131317954, + -0.06103844629409494, + -0.254143227116194, + -0.37362637362637585, + -0.45439658241367054, + ], + ), + ], +) +def test_rolling_non_monotonic(method, expected): + """ + Make sure the (rare) branch of non-monotonic indices is covered by a test. + + output from 1.1.3 is assumed to be the expected output. Output of sum/mean has + manually been verified. + + GH 36933. + """ + # Based on an example found in computation.rst + use_expanding = [True, False, True, False, True, True, True, True] + df = DataFrame({"values": np.arange(len(use_expanding)) ** 2}) + + class CustomIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + start = np.empty(num_values, dtype=np.int64) + end = np.empty(num_values, dtype=np.int64) + for i in range(num_values): + if self.use_expanding[i]: + start[i] = 0 + end[i] = i + 1 + else: + start[i] = i + end[i] = i + self.window_size + return start, end + + indexer = CustomIndexer(window_size=4, use_expanding=use_expanding) + + result = getattr(df.rolling(indexer), method)() + expected = DataFrame({"values": expected}) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("index", "window"), + [ + ([0, 1, 2, 3, 4], 2), + (date_range("2001-01-01", freq="D", periods=5), "2D"), + ], +) +def test_rolling_corr_timedelta_index(index, window): + # GH: 31286 + x = Series([1, 2, 3, 4, 5], index=index) + y = x.copy() + x.iloc[0:2] = 0.0 + result = x.rolling(window).corr(y) + expected = Series([np.nan, np.nan, 1, 1, 1], index=index) + tm.assert_almost_equal(result, expected) + + +@pytest.mark.parametrize( + "values,method,expected", + [ + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "first", + [float("nan"), float("nan"), 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + ), + ( + [1.0, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0, np.nan, 9.0, np.nan], + "first", + [float("nan")] * 10, + ), + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "last", + [float("nan"), float("nan"), 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + ), + ( + [1.0, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0, np.nan, 9.0, np.nan], + "last", + [float("nan")] * 10, + ), + ], +) +def test_rolling_first_last(values, method, expected): + # GH#33155 + x = Series(values) + result = getattr(x.rolling(3), method)() + expected = Series(expected) + tm.assert_almost_equal(result, expected) + + x = DataFrame({"A": values}) + result = getattr(x.rolling(3), method)() + expected = DataFrame({"A": expected}) + tm.assert_almost_equal(result, expected) + + +@pytest.mark.parametrize( + "values,method,expected", + [ + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "first", + [1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + ), + ( + [1.0, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0, np.nan, 9.0, np.nan], + "first", + [1.0, 1.0, 1.0, 3.0, 3.0, 5.0, 5.0, 7.0, 7.0, 9.0], + ), + ( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + "last", + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + ), + ( + [1.0, np.nan, 3.0, np.nan, 5.0, np.nan, 7.0, np.nan, 9.0, np.nan], + "last", + [1.0, 1.0, 3.0, 3.0, 5.0, 5.0, 7.0, 7.0, 9.0, 9.0], + ), + ], +) +def test_rolling_first_last_no_minp(values, method, expected): + # GH#33155 + x = Series(values) + result = getattr(x.rolling(3, min_periods=0), method)() + expected = Series(expected) + tm.assert_almost_equal(result, expected) + + x = DataFrame({"A": values}) + result = getattr(x.rolling(3, min_periods=0), method)() + expected = DataFrame({"A": expected}) + tm.assert_almost_equal(result, expected) + + +def test_groupby_rolling_nan_included(): + # GH 35542 + data = {"group": ["g1", np.nan, "g1", "g2", np.nan], "B": [0, 1, 2, 3, 4]} + df = DataFrame(data) + result = df.groupby("group", dropna=False).rolling(1, min_periods=1).mean() + expected = DataFrame( + {"B": [0.0, 2.0, 3.0, 1.0, 4.0]}, + # GH-38057 from_tuples puts the NaNs in the codes, result expects them + # to be in the levels, at the moment + # index=MultiIndex.from_tuples( + # [("g1", 0), ("g1", 2), ("g2", 3), (np.nan, 1), (np.nan, 4)], + # names=["group", None], + # ), + index=MultiIndex( + [["g1", "g2", np.nan], [0, 1, 2, 3, 4]], + [[0, 0, 1, 2, 2], [0, 2, 3, 1, 4]], + names=["group", None], + ), + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("method", ["skew", "kurt"]) +def test_rolling_skew_kurt_numerical_stability(method): + # GH#6929 + ser = Series(np.random.default_rng(2).random(10)) + ser_copy = ser.copy() + expected = getattr(ser.rolling(3), method)() + tm.assert_series_equal(ser, ser_copy) + ser = ser + 50000 + result = getattr(ser.rolling(3), method)() + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("method", "data", "values"), + [ + ( + "skew", + [3000000, 1, 1, 2, 3, 4, 999], + [np.nan] * 3 + [2.0, 0.854563, 0.0, 1.999984], + ), + ( + "skew", + [1e6, -1e6, 1, 2, 3, 4, 5, 6], + [np.nan] * 3 + [-5.51135192e-06, -2.0, 0.0, 0.0, 0.0], + ), + ( + "kurt", + [3000000, 1, 1, 2, 3, 4, 999], + [np.nan] * 3 + [4.0, -1.289256, -1.2, 3.999946], + ), + ( + "kurt", + [1e6, -1e6, 1, 2, 3, 4, 5, 6], + [np.nan] * 3 + [1.5, 4.0, -1.2, -1.2, -1.2], + ), + ], +) +def test_rolling_skew_kurt_large_value_range(method, data, values): + # GH: 37557, 47461, 61416 + s = Series(data) + result = getattr(s.rolling(4), method)() + expected = Series(values) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("method", ["skew", "kurt"]) +def test_same_result_with_different_lengths(method): + # GH-54380 + len_smaller = 10 + len_bigger = 12 + window_size = 8 + + rng = np.random.default_rng(2) + data = rng.normal(loc=0.0, scale=1e3, size=len_bigger) + window_smaller = Series(data[:len_smaller]).rolling(window_size) + window_bigger = Series(data).rolling(window_size) + + result_smaller = getattr(window_smaller, method)() + result_bigger = getattr(window_bigger, method)() + + result_bigger_trimmed = result_bigger[:len_smaller] + + tm.assert_series_equal(result_smaller, result_bigger_trimmed, check_exact=True) + + +def test_invalid_method(): + with pytest.raises(ValueError, match="method must be 'table' or 'single"): + Series(range(1)).rolling(1, method="foo") + + +def test_rolling_descending_date_order_with_offset(frame_or_series): + # GH#40002 + msg = "'d' is deprecated and will be removed in a future version." + + with tm.assert_produces_warning(Pandas4Warning, match=msg): + idx = date_range(start="2020-01-01", end="2020-01-03", freq="1d") + obj = frame_or_series(range(1, 4), index=idx) + result = obj.rolling("1d", closed="left").sum() + + expected = frame_or_series([np.nan, 1, 2], index=idx) + tm.assert_equal(result, expected) + + result = obj.iloc[::-1].rolling("1D", closed="left").sum() + idx = date_range(start="2020-01-03", end="2020-01-01", freq="-1D") + expected = frame_or_series([np.nan, 3, 2], index=idx) + tm.assert_equal(result, expected) + + +def test_rolling_var_floating_artifact_precision(): + # GH 37051 + s = Series([7, 5, 5, 5]) + result = s.rolling(3).var() + expected = Series([np.nan, np.nan, 4 / 3, 0]) + tm.assert_series_equal(result, expected, atol=1.0e-15, rtol=1.0e-15) + # GH 42064 + # new `roll_var` will output 0.0 correctly + tm.assert_series_equal(result == 0, expected == 0) + + +def test_rolling_std_small_values(): + # GH 37051 + s = Series( + [ + 0.00000054, + 0.00000053, + 0.00000054, + ] + ) + result = s.rolling(2).std() + expected = Series([np.nan, 7.071068e-9, 7.071068e-9]) + tm.assert_series_equal(result, expected, atol=1.0e-15, rtol=1.0e-15) + + +@pytest.mark.parametrize( + "start, exp_values", + [ + (1, [0.03, 0.0155, 0.0155, 0.011, 0.01025]), + (2, [0.001, 0.001, 0.0015, 0.00366666]), + ], +) +def test_rolling_mean_all_nan_window_floating_artifacts(start, exp_values): + # GH#41053 + df = DataFrame( + [ + 0.03, + 0.03, + 0.001, + np.nan, + 0.002, + 0.008, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + np.nan, + 0.005, + 0.2, + ] + ) + + values = [ + *exp_values, + 0.00366666, + 0.005, + 0.005, + 0.008, + np.nan, + np.nan, + 0.005, + 0.102500, + ] + expected = DataFrame( + values, + index=list(range(start, len(values) + start)), + ) + result = df.iloc[start:].rolling(5, min_periods=0).mean() + tm.assert_frame_equal(result, expected) + + +def test_rolling_sum_all_nan_window_floating_artifacts(): + # GH#41053 + df = DataFrame([0.002, 0.008, 0.005, np.nan, np.nan, np.nan]) + result = df.rolling(3, min_periods=0).sum() + expected = DataFrame([0.002, 0.010, 0.015, 0.013, 0.005, 0.0]) + tm.assert_frame_equal(result, expected) + + +def test_rolling_zero_window(): + # GH 22719 + s = Series(range(1)) + result = s.rolling(0).min() + expected = Series([np.nan]) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("window", [1, 3, 10, 20]) +@pytest.mark.parametrize("method", ["min", "max", "average"]) +@pytest.mark.parametrize("pct", [True, False]) +@pytest.mark.parametrize("test_data", ["default", "duplicates", "nans"]) +def test_rank(window, method, pct, ascending, test_data): + length = 20 + if test_data == "default": + ser = Series(data=np.random.default_rng(2).random(length)) + elif test_data == "duplicates": + ser = Series(data=np.random.default_rng(2).choice(3, length)) + elif test_data == "nans": + ser = Series( + data=np.random.default_rng(2).choice( + [1.0, 0.25, 0.75, np.nan, np.inf, -np.inf], length + ) + ) + + expected = ser.rolling(window).apply( + lambda x: x.rank(method=method, pct=pct, ascending=ascending).iloc[-1] + ) + result = ser.rolling(window).rank(method=method, pct=pct, ascending=ascending) + + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("window", [1, 3, 10, 20]) +@pytest.mark.parametrize("test_data", ["default", "duplicates", "nans", "precision"]) +def test_nunique(window, test_data): + length = 20 + if test_data == "default": + ser = Series(data=np.random.default_rng(2).random(length)) + elif test_data == "duplicates": + ser = Series(data=np.random.default_rng(2).choice(3, length)) + elif test_data == "nans": + ser = Series( + data=np.random.default_rng(2).choice( + [1.0, 0.25, 0.75, np.nan, np.inf, -np.inf], length + ) + ) + elif test_data == "precision": + ser = Series( + data=[ + 0.3, + 0.1 * 3, # Not necessarily exactly 0.3 + 0.6, + 0.2 * 3, # Not necessarily exactly 0.6 + 0.9, + 0.3 * 3, # Not necessarily exactly 0.9 + 0.5, + 0.1 * 5, # Not necessarily exactly 0.5 + 0.8, + 0.2 * 4, # Not necessarily exactly 0.8 + ], + dtype=np.float64, + ) + + expected = ser.rolling(window).apply(lambda x: x.nunique()) + result = ser.rolling(window).nunique() + + tm.assert_series_equal(result, expected) + + +def test_rolling_quantile_np_percentile(): + # #9413: Tests that rolling window's quantile default behavior + # is analogous to Numpy's percentile + row = 10 + col = 5 + idx = date_range("20100101", periods=row, freq="B") + df = DataFrame( + np.random.default_rng(2).random(row * col).reshape((row, -1)), index=idx + ) + + df_quantile = df.quantile([0.25, 0.5, 0.75], axis=0) + np_percentile = np.percentile(df, [25, 50, 75], axis=0) + + tm.assert_almost_equal(df_quantile.values, np.array(np_percentile)) + + +@pytest.mark.parametrize("quantile", [0.0, 0.1, 0.45, 0.5, 1]) +@pytest.mark.parametrize( + "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"] +) +@pytest.mark.parametrize( + "data", + [ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + [8.0, 1.0, 3.0, 4.0, 5.0, 2.0, 6.0, 7.0], + [0.0, np.nan, 0.2, np.nan, 0.4], + [np.nan, np.nan, np.nan, np.nan], + [np.nan, 0.1, np.nan, 0.3, 0.4, 0.5], + [0.5], + [np.nan, 0.7, 0.6], + ], +) +def test_rolling_quantile_interpolation_options(quantile, interpolation, data): + # Tests that rolling window's quantile behavior is analogous to + # Series' quantile for each interpolation option + s = Series(data) + + q1 = s.quantile(quantile, interpolation) + q2 = s.expanding(min_periods=1).quantile(quantile, interpolation).iloc[-1] + + if np.isnan(q1): + assert np.isnan(q2) + elif not IS64: + # Less precision on 32-bit + assert np.allclose([q1], [q2], rtol=1e-07, atol=0) + else: + assert q1 == q2 + + +def test_invalid_quantile_value(): + data = np.arange(5) + s = Series(data) + + msg = "Interpolation 'invalid' is not supported" + with pytest.raises(ValueError, match=msg): + s.rolling(len(data), min_periods=1).quantile(0.5, interpolation="invalid") + + +def test_rolling_quantile_param(): + ser = Series([0.0, 0.1, 0.5, 0.9, 1.0]) + msg = "quantile value -0.1 not in \\[0, 1\\]" + with pytest.raises(ValueError, match=msg): + ser.rolling(3).quantile(-0.1) + + msg = "quantile value 10.0 not in \\[0, 1\\]" + with pytest.raises(ValueError, match=msg): + ser.rolling(3).quantile(10.0) + + msg = "must be real number, not str" + with pytest.raises(TypeError, match=msg): + ser.rolling(3).quantile("foo") + + +def test_rolling_std_1obs(): + vals = Series([1.0, 2.0, 3.0, 4.0, 5.0]) + + result = vals.rolling(1, min_periods=1).std() + expected = Series([np.nan] * 5) + tm.assert_series_equal(result, expected) + + result = vals.rolling(1, min_periods=1).std(ddof=0) + expected = Series([0.0] * 5) + tm.assert_series_equal(result, expected) + + result = Series([np.nan, np.nan, 3, 4, 5]).rolling(3, min_periods=2).std() + assert np.isnan(result[2]) + + +def test_rolling_std_neg_sqrt(): + # unit test from Bottleneck + + # Test move_nanstd for neg sqrt. + + a = Series( + [ + 0.0011448196318903589, + 0.00028718669878572767, + 0.00028718669878572767, + 0.00028718669878572767, + 0.00028718669878572767, + ] + ) + b = a.rolling(window=3).std() + assert np.isfinite(b[2:]).all() + + b = a.ewm(span=3).std() + assert np.isfinite(b[2:]).all() + + +def test_step_not_integer_raises(): + with pytest.raises(ValueError, match="step must be an integer"): + DataFrame(range(2)).rolling(1, step="foo") + + +def test_step_not_positive_raises(): + with pytest.raises(ValueError, match="step must be >= 0"): + DataFrame(range(2)).rolling(1, step=-1) + + +@pytest.mark.parametrize( + ["values", "window", "min_periods", "expected"], + [ + [ + [20, 10, 10, np.inf, 1, 1, 2, 3], + 3, + 1, + [np.nan, 50, 100 / 3, 0, 40.5, 0, 1 / 3, 1], + ], + [ + [20, 10, 10, np.nan, 10, 1, 2, 3], + 3, + 1, + [np.nan, 50, 100 / 3, 0, 0, 40.5, 73 / 3, 1], + ], + [ + [np.nan, 5, 6, 7, 5, 5, 5], + 3, + 3, + [np.nan] * 3 + [1, 1, 4 / 3, 0], + ], + [ + [5, 7, 7, 7, np.nan, np.inf, 4, 3, 3, 3], + 3, + 3, + [np.nan] * 2 + [4 / 3, 0] + [np.nan] * 4 + [1 / 3, 0], + ], + [ + [5, 7, 7, 7, np.nan, np.inf, 7, 3, 3, 3], + 3, + 3, + [np.nan] * 2 + [4 / 3, 0] + [np.nan] * 4 + [16 / 3, 0], + ], + [ + [5, 7] * 4, + 3, + 3, + [np.nan] * 2 + [4 / 3] * 6, + ], + [ + [5, 7, 5, np.nan, 7, 5, 7], + 3, + 2, + [np.nan, 2, 4 / 3] + [2] * 3 + [4 / 3], + ], + ], +) +def test_rolling_var_same_value_count_logic(values, window, min_periods, expected): + # GH 42064. + + expected = Series(expected) + sr = Series(values) + + # With new algo implemented, result will be set to .0 in rolling var + # if sufficient amount of consecutively same values are found. + result_var = sr.rolling(window, min_periods=min_periods).var() + + # use `assert_series_equal` twice to check for equality, + # because `check_exact=True` will fail in 32-bit tests due to + # precision loss. + + # 1. result should be close to correct value + # non-zero values can still differ slightly from "truth" + # as the result of online algorithm + tm.assert_series_equal(result_var, expected) + # 2. zeros should be exactly the same since the new algo takes effect here + tm.assert_series_equal(expected == 0, result_var == 0) + + # std should also pass as it's just a sqrt of var + result_std = sr.rolling(window, min_periods=min_periods).std() + tm.assert_series_equal(result_std, np.sqrt(expected)) + tm.assert_series_equal(expected == 0, result_std == 0) + + +def test_rolling_mean_sum_floating_artifacts(): + # GH 42064. + + sr = Series([1 / 3, 4, 0, 0, 0, 0, 0]) + r = sr.rolling(3) + result = r.mean() + assert (result[-3:] == 0).all() + result = r.sum() + assert (result[-3:] == 0).all() + + +def test_rolling_skew_kurt_floating_artifacts(): + # GH 42064 46431 + + sr = Series([1 / 3, 4, 0, 0, 0, 0, 0]) + r = sr.rolling(4) + result = r.skew() + expected = Series([np.nan, np.nan, np.nan, 1.9619045191072484, 2.0, 0.0, 0.0]) + tm.assert_series_equal(result, expected) + result = r.kurt() + expected = Series([np.nan, np.nan, np.nan, 3.8636048803878786, 4.0, -3.0, -3.0]) + tm.assert_series_equal(result, expected) + + +def test_numeric_only_frame(arithmetic_win_operators, numeric_only): + # GH#46560 + kernel = arithmetic_win_operators + df = DataFrame({"a": [1], "b": 2, "c": 3}) + df["c"] = df["c"].astype(object) + rolling = df.rolling(2, min_periods=1) + op = getattr(rolling, kernel) + result = op(numeric_only=numeric_only) + + columns = ["a", "b"] if numeric_only else ["a", "b", "c"] + expected = df[columns].agg([kernel]).reset_index(drop=True).astype(float) + assert list(expected.columns) == columns + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("kernel", ["corr", "cov"]) +@pytest.mark.parametrize("use_arg", [True, False]) +def test_numeric_only_corr_cov_frame(kernel, numeric_only, use_arg): + # GH#46560 + df = DataFrame({"a": [1, 2, 3], "b": 2, "c": 3}) + df["c"] = df["c"].astype(object) + arg = (df,) if use_arg else () + rolling = df.rolling(2, min_periods=1) + op = getattr(rolling, kernel) + result = op(*arg, numeric_only=numeric_only) + + # Compare result to op using float dtypes, dropping c when numeric_only is True + columns = ["a", "b"] if numeric_only else ["a", "b", "c"] + df2 = df[columns].astype(float) + arg2 = (df2,) if use_arg else () + rolling2 = df2.rolling(2, min_periods=1) + op2 = getattr(rolling2, kernel) + expected = op2(*arg2, numeric_only=numeric_only) + + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize("dtype", [int, object]) +def test_numeric_only_series(arithmetic_win_operators, numeric_only, dtype): + # GH#46560 + kernel = arithmetic_win_operators + ser = Series([1], dtype=dtype) + rolling = ser.rolling(2, min_periods=1) + op = getattr(rolling, kernel) + if numeric_only and dtype is object: + msg = f"Rolling.{kernel} does not implement numeric_only" + with pytest.raises(NotImplementedError, match=msg): + op(numeric_only=numeric_only) + else: + result = op(numeric_only=numeric_only) + expected = ser.agg([kernel]).reset_index(drop=True).astype(float) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("kernel", ["corr", "cov"]) +@pytest.mark.parametrize("use_arg", [True, False]) +@pytest.mark.parametrize("dtype", [int, object]) +def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype): + # GH#46560 + ser = Series([1, 2, 3], dtype=dtype) + arg = (ser,) if use_arg else () + rolling = ser.rolling(2, min_periods=1) + op = getattr(rolling, kernel) + if numeric_only and dtype is object: + msg = f"Rolling.{kernel} does not implement numeric_only" + with pytest.raises(NotImplementedError, match=msg): + op(*arg, numeric_only=numeric_only) + else: + result = op(*arg, numeric_only=numeric_only) + + ser2 = ser.astype(float) + arg2 = (ser2,) if use_arg else () + rolling2 = ser2.rolling(2, min_periods=1) + op2 = getattr(rolling2, kernel) + expected = op2(*arg2, numeric_only=numeric_only) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("tz", [None, "UTC", "Europe/Prague"]) +def test_rolling_timedelta_window_non_nanoseconds(unit, tz): + # Test Sum, GH#55106 + df_time = DataFrame( + {"A": range(5)}, + index=date_range("2013-01-01", freq="1s", periods=5, tz=tz, unit="ns"), + ) + sum_in_nanosecs = df_time.rolling("1s").sum() + # microseconds / milliseconds should not break the correct rolling + df_time.index = df_time.index.as_unit(unit) + sum_in_microsecs = df_time.rolling("1s").sum() + sum_in_microsecs.index = sum_in_microsecs.index.as_unit("ns") + tm.assert_frame_equal(sum_in_nanosecs, sum_in_microsecs) + + # Test max, GH#55026 + ref_dates = date_range("2023-01-01", "2023-01-10", unit="ns", tz=tz) + ref_series = Series(0, index=ref_dates) + ref_series.iloc[0] = 1 + ref_max_series = ref_series.rolling(Timedelta(days=4)).max() + + dates = date_range("2023-01-01", "2023-01-10", unit=unit, tz=tz) + series = Series(0, index=dates) + series.iloc[0] = 1 + max_series = series.rolling(Timedelta(days=4)).max() + + ref_df = DataFrame(ref_max_series) + df = DataFrame(max_series) + df.index = df.index.as_unit("ns") + + tm.assert_frame_equal(ref_df, df) + + +class PrescribedWindowIndexer(BaseIndexer): + def __init__(self, start, end): + self._start = start + self._end = end + super().__init__() + + def get_window_bounds( + self, num_values=None, min_periods=None, center=None, closed=None, step=None + ): + if num_values is None: + num_values = len(self._start) + start = np.clip(self._start, 0, num_values) + end = np.clip(self._end, 0, num_values) + return start, end + + +class TestMinMax: + @pytest.mark.parametrize( + "is_max, has_nan, exp_list", + [ + (True, False, [3.0, 5.0, 2.0, 5.0, 1.0, 5.0, 6.0, 7.0, 8.0, 9.0]), + (True, True, [3.0, 4.0, 2.0, 4.0, 1.0, 4.0, 6.0, 7.0, 7.0, 9.0]), + (False, False, [3.0, 2.0, 2.0, 1.0, 1.0, 0.0, 0.0, 0.0, 7.0, 0.0]), + (False, True, [3.0, 2.0, 2.0, 1.0, 1.0, 1.0, 6.0, 6.0, 7.0, 1.0]), + ], + ) + def test_minmax(self, is_max, has_nan, exp_list): + nan_idx = [0, 5, 8] + df = DataFrame( + { + "data": [5.0, 4.0, 3.0, 2.0, 1.0, 0.0, 6.0, 7.0, 8.0, 9.0], + "start": [2, 0, 3, 0, 4, 0, 5, 5, 7, 3], + "end": [3, 4, 4, 5, 5, 6, 7, 8, 9, 10], + } + ) + if has_nan: + df.loc[nan_idx, "data"] = np.nan + expected = Series(exp_list, name="data") + r = df.data.rolling( + PrescribedWindowIndexer(df.start.to_numpy(), df.end.to_numpy()) + ) + if is_max: + result = r.max() + else: + result = r.min() + + tm.assert_series_equal(result, expected) + + def test_wrong_order(self): + start = np.array(range(5), dtype=np.int64) + end = start + 1 + end[3] = end[2] + start[3] = start[2] - 1 + + df = DataFrame({"data": start * 1.0, "start": start, "end": end}) + + r = df.data.rolling(PrescribedWindowIndexer(start, end)) + with pytest.raises( + ValueError, match="Start/End ordering requirement is violated at index 3" + ): + r.max() diff --git a/pandas/tests/window/test_rolling_functions.py b/pandas/tests/window/test_rolling_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..36ae7f3b7dfe47038901a9d154495f157b0ea23a --- /dev/null +++ b/pandas/tests/window/test_rolling_functions.py @@ -0,0 +1,535 @@ +from datetime import datetime + +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import ( + DataFrame, + DatetimeIndex, + Series, + concat, + isna, + notna, +) +import pandas._testing as tm + +from pandas.tseries import offsets + + +@pytest.mark.parametrize( + "compare_func, roll_func, kwargs", + [ + [np.mean, "mean", {}], + [np.nansum, "sum", {}], + [ + lambda x: np.isfinite(x).astype(float).sum(), + "count", + {}, + ], + [np.median, "median", {}], + [np.min, "min", {}], + [np.max, "max", {}], + [lambda x: np.std(x, ddof=1), "std", {}], + [lambda x: np.std(x, ddof=0), "std", {"ddof": 0}], + [lambda x: np.var(x, ddof=1), "var", {}], + [lambda x: np.var(x, ddof=0), "var", {"ddof": 0}], + ], +) +def test_series(series, compare_func, roll_func, kwargs, step): + result = getattr(series.rolling(50, step=step), roll_func)(**kwargs) + assert isinstance(result, Series) + end = range(0, len(series), step or 1)[-1] + 1 + tm.assert_almost_equal(result.iloc[-1], compare_func(series[end - 50 : end])) + + +@pytest.mark.parametrize( + "compare_func, roll_func, kwargs", + [ + [np.mean, "mean", {}], + [np.nansum, "sum", {}], + [ + lambda x: np.isfinite(x).astype(float).sum(), + "count", + {}, + ], + [np.median, "median", {}], + [np.min, "min", {}], + [np.max, "max", {}], + [lambda x: np.std(x, ddof=1), "std", {}], + [lambda x: np.std(x, ddof=0), "std", {"ddof": 0}], + [lambda x: np.var(x, ddof=1), "var", {}], + [lambda x: np.var(x, ddof=0), "var", {"ddof": 0}], + ], +) +def test_frame(raw, frame, compare_func, roll_func, kwargs, step): + result = getattr(frame.rolling(50, step=step), roll_func)(**kwargs) + assert isinstance(result, DataFrame) + end = range(0, len(frame), step or 1)[-1] + 1 + tm.assert_series_equal( + result.iloc[-1, :], + frame.iloc[end - 50 : end, :].apply(compare_func, axis=0, raw=raw), + check_names=False, + ) + + +@pytest.mark.parametrize( + "compare_func, roll_func, kwargs, minp", + [ + [np.mean, "mean", {}, 10], + [np.nansum, "sum", {}, 10], + [lambda x: np.isfinite(x).astype(float).sum(), "count", {}, 0], + [np.median, "median", {}, 10], + [np.min, "min", {}, 10], + [np.max, "max", {}, 10], + [lambda x: np.std(x, ddof=1), "std", {}, 10], + [lambda x: np.std(x, ddof=0), "std", {"ddof": 0}, 10], + [lambda x: np.var(x, ddof=1), "var", {}, 10], + [lambda x: np.var(x, ddof=0), "var", {"ddof": 0}, 10], + ], +) +def test_time_rule_series(series, compare_func, roll_func, kwargs, minp): + win = 25 + ser = series[::2].resample("B").mean() + series_result = getattr(ser.rolling(window=win, min_periods=minp), roll_func)( + **kwargs + ) + last_date = series_result.index[-1] + prev_date = last_date - 24 * offsets.BDay() + + trunc_series = series[::2].truncate(prev_date, last_date) + tm.assert_almost_equal(series_result.iloc[-1], compare_func(trunc_series)) + + +@pytest.mark.parametrize( + "compare_func, roll_func, kwargs, minp", + [ + [np.mean, "mean", {}, 10], + [np.nansum, "sum", {}, 10], + [lambda x: np.isfinite(x).astype(float).sum(), "count", {}, 0], + [np.median, "median", {}, 10], + [np.min, "min", {}, 10], + [np.max, "max", {}, 10], + [lambda x: np.std(x, ddof=1), "std", {}, 10], + [lambda x: np.std(x, ddof=0), "std", {"ddof": 0}, 10], + [lambda x: np.var(x, ddof=1), "var", {}, 10], + [lambda x: np.var(x, ddof=0), "var", {"ddof": 0}, 10], + ], +) +def test_time_rule_frame(raw, frame, compare_func, roll_func, kwargs, minp): + win = 25 + frm = frame[::2].resample("B").mean() + frame_result = getattr(frm.rolling(window=win, min_periods=minp), roll_func)( + **kwargs + ) + last_date = frame_result.index[-1] + prev_date = last_date - 24 * offsets.BDay() + + trunc_frame = frame[::2].truncate(prev_date, last_date) + tm.assert_series_equal( + frame_result.xs(last_date), + trunc_frame.apply(compare_func, raw=raw), + check_names=False, + ) + + +@pytest.mark.parametrize( + "compare_func, roll_func, kwargs", + [ + [np.mean, "mean", {}], + [np.nansum, "sum", {}], + [np.median, "median", {}], + [np.min, "min", {}], + [np.max, "max", {}], + [lambda x: np.std(x, ddof=1), "std", {}], + [lambda x: np.std(x, ddof=0), "std", {"ddof": 0}], + [lambda x: np.var(x, ddof=1), "var", {}], + [lambda x: np.var(x, ddof=0), "var", {"ddof": 0}], + ], +) +def test_nans(compare_func, roll_func, kwargs): + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + + result = getattr(obj.rolling(50, min_periods=30), roll_func)(**kwargs) + tm.assert_almost_equal(result.iloc[-1], compare_func(obj[10:-10])) + + # min_periods is working correctly + result = getattr(obj.rolling(20, min_periods=15), roll_func)(**kwargs) + assert isna(result.iloc[23]) + assert not isna(result.iloc[24]) + + assert not isna(result.iloc[-6]) + assert isna(result.iloc[-5]) + + obj2 = Series(np.random.default_rng(2).standard_normal(20)) + result = getattr(obj2.rolling(10, min_periods=5), roll_func)(**kwargs) + assert isna(result.iloc[3]) + assert notna(result.iloc[4]) + + if roll_func != "sum": + result0 = getattr(obj.rolling(20, min_periods=0), roll_func)(**kwargs) + result1 = getattr(obj.rolling(20, min_periods=1), roll_func)(**kwargs) + tm.assert_almost_equal(result0, result1) + + +def test_nans_count(): + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + result = obj.rolling(50, min_periods=30).count() + tm.assert_almost_equal( + result.iloc[-1], np.isfinite(obj[10:-10]).astype(float).sum() + ) + + +@pytest.mark.parametrize( + "roll_func, kwargs", + [ + ["mean", {}], + ["sum", {}], + ["median", {}], + ["min", {}], + ["max", {}], + ["std", {}], + ["std", {"ddof": 0}], + ["var", {}], + ["var", {"ddof": 0}], + ], +) +@pytest.mark.parametrize("minp", [0, 99, 100]) +def test_min_periods(series, minp, roll_func, kwargs, step): + result = getattr( + series.rolling(len(series) + 1, min_periods=minp, step=step), roll_func + )(**kwargs) + expected = getattr( + series.rolling(len(series), min_periods=minp, step=step), roll_func + )(**kwargs) + nan_mask = isna(result) + tm.assert_series_equal(nan_mask, isna(expected)) + + nan_mask = ~nan_mask + tm.assert_almost_equal(result[nan_mask], expected[nan_mask]) + + +def test_min_periods_count(series, step): + result = series.rolling(len(series) + 1, min_periods=0, step=step).count() + expected = series.rolling(len(series), min_periods=0, step=step).count() + nan_mask = isna(result) + tm.assert_series_equal(nan_mask, isna(expected)) + + nan_mask = ~nan_mask + tm.assert_almost_equal(result[nan_mask], expected[nan_mask]) + + +@pytest.mark.parametrize( + "roll_func, kwargs, minp", + [ + ["mean", {}, 15], + ["sum", {}, 15], + ["count", {}, 0], + ["median", {}, 15], + ["min", {}, 15], + ["max", {}, 15], + ["std", {}, 15], + ["std", {"ddof": 0}, 15], + ["var", {}, 15], + ["var", {"ddof": 0}, 15], + ], +) +def test_center(roll_func, kwargs, minp): + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + + result = getattr(obj.rolling(20, min_periods=minp, center=True), roll_func)( + **kwargs + ) + expected = ( + getattr( + concat([obj, Series([np.nan] * 9)]).rolling(20, min_periods=minp), roll_func + )(**kwargs) + .iloc[9:] + .reset_index(drop=True) + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "roll_func, kwargs, minp, fill_value", + [ + ["mean", {}, 10, None], + ["sum", {}, 10, None], + ["count", {}, 0, 0], + ["median", {}, 10, None], + ["min", {}, 10, None], + ["max", {}, 10, None], + ["std", {}, 10, None], + ["std", {"ddof": 0}, 10, None], + ["var", {}, 10, None], + ["var", {"ddof": 0}, 10, None], + ], +) +def test_center_reindex_series(series, roll_func, kwargs, minp, fill_value): + # shifter index + s = [f"x{x:d}" for x in range(12)] + + series_xp = ( + getattr( + series.reindex(list(series.index) + s).rolling(window=25, min_periods=minp), + roll_func, + )(**kwargs) + .shift(-12) + .reindex(series.index) + ) + series_rs = getattr( + series.rolling(window=25, min_periods=minp, center=True), roll_func + )(**kwargs) + if fill_value is not None: + series_xp = series_xp.fillna(fill_value) + tm.assert_series_equal(series_xp, series_rs) + + +@pytest.mark.parametrize( + "roll_func, kwargs, minp, fill_value", + [ + ["mean", {}, 10, None], + ["sum", {}, 10, None], + ["count", {}, 0, 0], + ["median", {}, 10, None], + ["min", {}, 10, None], + ["max", {}, 10, None], + ["std", {}, 10, None], + ["std", {"ddof": 0}, 10, None], + ["var", {}, 10, None], + ["var", {"ddof": 0}, 10, None], + ], +) +def test_center_reindex_frame(frame, roll_func, kwargs, minp, fill_value): + # shifter index + s = [f"x{x:d}" for x in range(12)] + + frame_xp = ( + getattr( + frame.reindex(list(frame.index) + s).rolling(window=25, min_periods=minp), + roll_func, + )(**kwargs) + .shift(-12) + .reindex(frame.index) + ) + frame_rs = getattr( + frame.rolling(window=25, min_periods=minp, center=True), roll_func + )(**kwargs) + if fill_value is not None: + frame_xp = frame_xp.fillna(fill_value) + tm.assert_frame_equal(frame_xp, frame_rs) + + +@pytest.mark.parametrize( + "f", + [ + lambda x: x.rolling(window=10, min_periods=5).cov(x, pairwise=False), + lambda x: x.rolling(window=10, min_periods=5).corr(x, pairwise=False), + lambda x: x.rolling(window=10, min_periods=5).max(), + lambda x: x.rolling(window=10, min_periods=5).min(), + lambda x: x.rolling(window=10, min_periods=5).sum(), + lambda x: x.rolling(window=10, min_periods=5).mean(), + lambda x: x.rolling(window=10, min_periods=5).std(), + lambda x: x.rolling(window=10, min_periods=5).var(), + lambda x: x.rolling(window=10, min_periods=5).skew(), + lambda x: x.rolling(window=10, min_periods=5).kurt(), + lambda x: x.rolling(window=10, min_periods=5).first(), + lambda x: x.rolling(window=10, min_periods=5).last(), + lambda x: x.rolling(window=10, min_periods=5).quantile(q=0.5), + lambda x: x.rolling(window=10, min_periods=5).median(), + lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=False), + lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=True), + pytest.param( + lambda x: x.rolling(win_type="boxcar", window=10, min_periods=5).mean(), + marks=td.skip_if_no("scipy"), + ), + ], +) +def test_rolling_functions_window_non_shrinkage(f): + # GH 7764 + s = Series(range(4)) + s_expected = Series(np.nan, index=s.index) + df = DataFrame([[1, 5], [3, 2], [3, 9], [-1, 0]], columns=["A", "B"]) + df_expected = DataFrame(np.nan, index=df.index, columns=df.columns) + + s_result = f(s) + tm.assert_series_equal(s_result, s_expected) + + df_result = f(df) + tm.assert_frame_equal(df_result, df_expected) + + +def test_rolling_max_gh6297(step): + """Replicate result expected in GH #6297""" + indices = [datetime(1975, 1, i) for i in range(1, 6)] + # So that we can have 2 datapoints on one of the days + indices.append(datetime(1975, 1, 3, 6, 0)) + series = Series(range(1, 7), index=indices) + # Use floats instead of ints as values + series = series.map(lambda x: float(x)) + # Sort chronologically + series = series.sort_index() + + expected = Series( + [1.0, 2.0, 6.0, 4.0, 5.0], + index=DatetimeIndex([datetime(1975, 1, i, 0) for i in range(1, 6)], freq="D"), + )[::step] + x = series.resample("D").max().rolling(window=1, step=step).max() + tm.assert_series_equal(expected, x) + + +def test_rolling_max_resample(step): + indices = [datetime(1975, 1, i) for i in range(1, 6)] + # So that we can have 3 datapoints on last day (4, 10, and 20) + indices.append(datetime(1975, 1, 5, 1)) + indices.append(datetime(1975, 1, 5, 2)) + series = Series([*list(range(5)), 10, 20], index=indices) + # Use floats instead of ints as values + series = series.map(lambda x: float(x)) + # Sort chronologically + series = series.sort_index() + + # Default how should be max + expected = Series( + [0.0, 1.0, 2.0, 3.0, 20.0], + index=DatetimeIndex([datetime(1975, 1, i, 0) for i in range(1, 6)], freq="D"), + )[::step] + x = series.resample("D").max().rolling(window=1, step=step).max() + tm.assert_series_equal(expected, x) + + # Now specify median (10.0) + expected = Series( + [0.0, 1.0, 2.0, 3.0, 10.0], + index=DatetimeIndex([datetime(1975, 1, i, 0) for i in range(1, 6)], freq="D"), + )[::step] + x = series.resample("D").median().rolling(window=1, step=step).max() + tm.assert_series_equal(expected, x) + + # Now specify mean (4+10+20)/3 + v = (4.0 + 10.0 + 20.0) / 3.0 + expected = Series( + [0.0, 1.0, 2.0, 3.0, v], + index=DatetimeIndex([datetime(1975, 1, i, 0) for i in range(1, 6)], freq="D"), + )[::step] + x = series.resample("D").mean().rolling(window=1, step=step).max() + tm.assert_series_equal(expected, x) + + +def test_rolling_min_resample(step): + indices = [datetime(1975, 1, i) for i in range(1, 6)] + # So that we can have 3 datapoints on last day (4, 10, and 20) + indices.append(datetime(1975, 1, 5, 1)) + indices.append(datetime(1975, 1, 5, 2)) + series = Series([*list(range(5)), 10, 20], index=indices) + # Use floats instead of ints as values + series = series.map(lambda x: float(x)) + # Sort chronologically + series = series.sort_index() + + # Default how should be min + expected = Series( + [0.0, 1.0, 2.0, 3.0, 4.0], + index=DatetimeIndex([datetime(1975, 1, i, 0) for i in range(1, 6)], freq="D"), + )[::step] + r = series.resample("D").min().rolling(window=1, step=step) + tm.assert_series_equal(expected, r.min()) + + +def test_rolling_median_resample(): + indices = [datetime(1975, 1, i) for i in range(1, 6)] + # So that we can have 3 datapoints on last day (4, 10, and 20) + indices.append(datetime(1975, 1, 5, 1)) + indices.append(datetime(1975, 1, 5, 2)) + series = Series([*list(range(5)), 10, 20], index=indices) + # Use floats instead of ints as values + series = series.map(lambda x: float(x)) + # Sort chronologically + series = series.sort_index() + + # Default how should be median + expected = Series( + [0.0, 1.0, 2.0, 3.0, 10], + index=DatetimeIndex([datetime(1975, 1, i, 0) for i in range(1, 6)], freq="D"), + ) + x = series.resample("D").median().rolling(window=1).median() + tm.assert_series_equal(expected, x) + + +def test_rolling_median_memory_error(): + # GH11722 + n = 20000 + Series(np.random.default_rng(2).standard_normal(n)).rolling( + window=2, center=False + ).median() + Series(np.random.default_rng(2).standard_normal(n)).rolling( + window=2, center=False + ).median() + + +def test_rolling_min_max_numeric_types(any_real_numpy_dtype): + # GH12373 + + # Just testing that these don't throw exceptions and that + # the return type is float64. Other tests will cover quantitative + # correctness + result = ( + DataFrame(np.arange(20, dtype=any_real_numpy_dtype)).rolling(window=5).max() + ) + assert result.dtypes[0] == np.dtype("f8") + result = ( + DataFrame(np.arange(20, dtype=any_real_numpy_dtype)).rolling(window=5).min() + ) + assert result.dtypes[0] == np.dtype("f8") + + +@pytest.mark.parametrize( + "f", + [ + lambda x: x.rolling(window=10, min_periods=0).count(), + lambda x: x.rolling(window=10, min_periods=5).cov(x, pairwise=False), + lambda x: x.rolling(window=10, min_periods=5).corr(x, pairwise=False), + lambda x: x.rolling(window=10, min_periods=5).max(), + lambda x: x.rolling(window=10, min_periods=5).min(), + lambda x: x.rolling(window=10, min_periods=5).sum(), + lambda x: x.rolling(window=10, min_periods=5).mean(), + lambda x: x.rolling(window=10, min_periods=5).std(), + lambda x: x.rolling(window=10, min_periods=5).var(), + lambda x: x.rolling(window=10, min_periods=5).skew(), + lambda x: x.rolling(window=10, min_periods=5).kurt(), + lambda x: x.rolling(window=10, min_periods=5).first(), + lambda x: x.rolling(window=10, min_periods=5).last(), + lambda x: x.rolling(window=10, min_periods=5).quantile(0.5), + lambda x: x.rolling(window=10, min_periods=5).median(), + lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=False), + lambda x: x.rolling(window=10, min_periods=5).apply(sum, raw=True), + pytest.param( + lambda x: x.rolling(win_type="boxcar", window=10, min_periods=5).mean(), + marks=td.skip_if_no("scipy"), + ), + ], +) +def test_moment_functions_zero_length(f): + # GH 8056 + s = Series(dtype=np.float64) + s_expected = s + df1 = DataFrame() + df1_expected = df1 + df2 = DataFrame(columns=["a"]) + df2["a"] = df2["a"].astype("float64") + df2_expected = df2 + + s_result = f(s) + tm.assert_series_equal(s_result, s_expected) + + df1_result = f(df1) + tm.assert_frame_equal(df1_result, df1_expected) + + df2_result = f(df2) + tm.assert_frame_equal(df2_result, df2_expected) diff --git a/pandas/tests/window/test_rolling_quantile.py b/pandas/tests/window/test_rolling_quantile.py new file mode 100644 index 0000000000000000000000000000000000000000..66713f1cfaa8dbd8dfbb10158da9a56b42287c6c --- /dev/null +++ b/pandas/tests/window/test_rolling_quantile.py @@ -0,0 +1,175 @@ +from functools import partial + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + concat, + isna, + notna, +) +import pandas._testing as tm + +from pandas.tseries import offsets + + +def scoreatpercentile(a, per): + values = np.sort(a, axis=0) + + idx = int(per / 1.0 * (values.shape[0] - 1)) + + if idx == values.shape[0] - 1: + retval = values[-1] + + else: + qlow = idx / (values.shape[0] - 1) + qhig = (idx + 1) / (values.shape[0] - 1) + vlow = values[idx] + vhig = values[idx + 1] + retval = vlow + (vhig - vlow) * (per - qlow) / (qhig - qlow) + + return retval + + +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_series(series, q, step): + compare_func = partial(scoreatpercentile, per=q) + result = series.rolling(50, step=step).quantile(q) + assert isinstance(result, Series) + end = range(0, len(series), step or 1)[-1] + 1 + tm.assert_almost_equal(result.iloc[-1], compare_func(series[end - 50 : end])) + + +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_frame(raw, frame, q, step): + compare_func = partial(scoreatpercentile, per=q) + result = frame.rolling(50, step=step).quantile(q) + assert isinstance(result, DataFrame) + end = range(0, len(frame), step or 1)[-1] + 1 + tm.assert_series_equal( + result.iloc[-1, :], + frame.iloc[end - 50 : end, :].apply(compare_func, axis=0, raw=raw), + check_names=False, + ) + + +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_time_rule_series(series, q): + compare_func = partial(scoreatpercentile, per=q) + win = 25 + ser = series[::2].resample("B").mean() + series_result = ser.rolling(window=win, min_periods=10).quantile(q) + last_date = series_result.index[-1] + prev_date = last_date - 24 * offsets.BDay() + + trunc_series = series[::2].truncate(prev_date, last_date) + tm.assert_almost_equal(series_result.iloc[-1], compare_func(trunc_series)) + + +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_time_rule_frame(raw, frame, q): + compare_func = partial(scoreatpercentile, per=q) + win = 25 + frm = frame[::2].resample("B").mean() + frame_result = frm.rolling(window=win, min_periods=10).quantile(q) + last_date = frame_result.index[-1] + prev_date = last_date - 24 * offsets.BDay() + + trunc_frame = frame[::2].truncate(prev_date, last_date) + tm.assert_series_equal( + frame_result.xs(last_date), + trunc_frame.apply(compare_func, raw=raw), + check_names=False, + ) + + +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_nans(q): + compare_func = partial(scoreatpercentile, per=q) + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + + result = obj.rolling(50, min_periods=30).quantile(q) + tm.assert_almost_equal(result.iloc[-1], compare_func(obj[10:-10])) + + # min_periods is working correctly + result = obj.rolling(20, min_periods=15).quantile(q) + assert isna(result.iloc[23]) + assert not isna(result.iloc[24]) + + assert not isna(result.iloc[-6]) + assert isna(result.iloc[-5]) + + obj2 = Series(np.random.default_rng(2).standard_normal(20)) + result = obj2.rolling(10, min_periods=5).quantile(q) + assert isna(result.iloc[3]) + assert notna(result.iloc[4]) + + result0 = obj.rolling(20, min_periods=0).quantile(q) + result1 = obj.rolling(20, min_periods=1).quantile(q) + tm.assert_almost_equal(result0, result1) + + +@pytest.mark.parametrize("minp", [0, 99, 100]) +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_min_periods(series, minp, q, step): + result = series.rolling(len(series) + 1, min_periods=minp, step=step).quantile(q) + expected = series.rolling(len(series), min_periods=minp, step=step).quantile(q) + nan_mask = isna(result) + tm.assert_series_equal(nan_mask, isna(expected)) + + nan_mask = ~nan_mask + tm.assert_almost_equal(result[nan_mask], expected[nan_mask]) + + +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_center(q): + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + + result = obj.rolling(20, center=True).quantile(q) + expected = ( + concat([obj, Series([np.nan] * 9)]) + .rolling(20) + .quantile(q) + .iloc[9:] + .reset_index(drop=True) + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_center_reindex_series(series, q): + # shifter index + s = [f"x{x:d}" for x in range(12)] + + series_xp = ( + series.reindex(list(series.index) + s) + .rolling(window=25) + .quantile(q) + .shift(-12) + .reindex(series.index) + ) + + series_rs = series.rolling(window=25, center=True).quantile(q) + tm.assert_series_equal(series_xp, series_rs) + + +@pytest.mark.parametrize("q", [0.0, 0.1, 0.5, 0.9, 1.0]) +def test_center_reindex_frame(frame, q): + # shifter index + s = [f"x{x:d}" for x in range(12)] + + frame_xp = ( + frame.reindex(list(frame.index) + s) + .rolling(window=25) + .quantile(q) + .shift(-12) + .reindex(frame.index) + ) + frame_rs = frame.rolling(window=25, center=True).quantile(q) + tm.assert_frame_equal(frame_xp, frame_rs) diff --git a/pandas/tests/window/test_rolling_skew_kurt.py b/pandas/tests/window/test_rolling_skew_kurt.py new file mode 100644 index 0000000000000000000000000000000000000000..79c14f243e7cc93b395ea84e05ec6bc79942b79b --- /dev/null +++ b/pandas/tests/window/test_rolling_skew_kurt.py @@ -0,0 +1,227 @@ +from functools import partial + +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + concat, + isna, + notna, +) +import pandas._testing as tm + +from pandas.tseries import offsets + + +@pytest.mark.parametrize("sp_func, roll_func", [["kurtosis", "kurt"], ["skew", "skew"]]) +def test_series(series, sp_func, roll_func): + sp_stats = pytest.importorskip("scipy.stats") + + compare_func = partial(getattr(sp_stats, sp_func), bias=False) + result = getattr(series.rolling(50), roll_func)() + assert isinstance(result, Series) + tm.assert_almost_equal(result.iloc[-1], compare_func(series[-50:])) + + +@pytest.mark.parametrize("sp_func, roll_func", [["kurtosis", "kurt"], ["skew", "skew"]]) +def test_frame(raw, frame, sp_func, roll_func): + sp_stats = pytest.importorskip("scipy.stats") + + compare_func = partial(getattr(sp_stats, sp_func), bias=False) + result = getattr(frame.rolling(50), roll_func)() + assert isinstance(result, DataFrame) + tm.assert_series_equal( + result.iloc[-1, :], + frame.iloc[-50:, :].apply(compare_func, axis=0, raw=raw), + check_names=False, + ) + + +@pytest.mark.parametrize("sp_func, roll_func", [["kurtosis", "kurt"], ["skew", "skew"]]) +def test_time_rule_series(series, sp_func, roll_func): + sp_stats = pytest.importorskip("scipy.stats") + + compare_func = partial(getattr(sp_stats, sp_func), bias=False) + win = 25 + ser = series[::2].resample("B").mean() + series_result = getattr(ser.rolling(window=win, min_periods=10), roll_func)() + last_date = series_result.index[-1] + prev_date = last_date - 24 * offsets.BDay() + + trunc_series = series[::2].truncate(prev_date, last_date) + tm.assert_almost_equal(series_result.iloc[-1], compare_func(trunc_series)) + + +@pytest.mark.parametrize("sp_func, roll_func", [["kurtosis", "kurt"], ["skew", "skew"]]) +def test_time_rule_frame(raw, frame, sp_func, roll_func): + sp_stats = pytest.importorskip("scipy.stats") + + compare_func = partial(getattr(sp_stats, sp_func), bias=False) + win = 25 + frm = frame[::2].resample("B").mean() + frame_result = getattr(frm.rolling(window=win, min_periods=10), roll_func)() + last_date = frame_result.index[-1] + prev_date = last_date - 24 * offsets.BDay() + + trunc_frame = frame[::2].truncate(prev_date, last_date) + tm.assert_series_equal( + frame_result.xs(last_date), + trunc_frame.apply(compare_func, raw=raw), + check_names=False, + ) + + +@pytest.mark.parametrize("sp_func, roll_func", [["kurtosis", "kurt"], ["skew", "skew"]]) +def test_nans(sp_func, roll_func): + sp_stats = pytest.importorskip("scipy.stats") + + compare_func = partial(getattr(sp_stats, sp_func), bias=False) + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + + result = getattr(obj.rolling(50, min_periods=30), roll_func)() + tm.assert_almost_equal(result.iloc[-1], compare_func(obj[10:-10])) + + # min_periods is working correctly + result = getattr(obj.rolling(20, min_periods=15), roll_func)() + assert isna(result.iloc[23]) + assert not isna(result.iloc[24]) + + assert not isna(result.iloc[-6]) + assert isna(result.iloc[-5]) + + obj2 = Series(np.random.default_rng(2).standard_normal(20)) + result = getattr(obj2.rolling(10, min_periods=5), roll_func)() + assert isna(result.iloc[3]) + assert notna(result.iloc[4]) + + result0 = getattr(obj.rolling(20, min_periods=0), roll_func)() + result1 = getattr(obj.rolling(20, min_periods=1), roll_func)() + tm.assert_almost_equal(result0, result1) + + +@pytest.mark.parametrize("minp", [0, 99, 100]) +@pytest.mark.parametrize("roll_func", ["kurt", "skew"]) +def test_min_periods(series, minp, roll_func, step): + result = getattr( + series.rolling(len(series) + 1, min_periods=minp, step=step), roll_func + )() + expected = getattr( + series.rolling(len(series), min_periods=minp, step=step), roll_func + )() + nan_mask = isna(result) + tm.assert_series_equal(nan_mask, isna(expected)) + + nan_mask = ~nan_mask + tm.assert_almost_equal(result[nan_mask], expected[nan_mask]) + + +@pytest.mark.parametrize("roll_func", ["kurt", "skew"]) +def test_center(roll_func): + obj = Series(np.random.default_rng(2).standard_normal(50)) + obj[:10] = np.nan + obj[-10:] = np.nan + + result = getattr(obj.rolling(20, center=True), roll_func)() + expected = ( + getattr(concat([obj, Series([np.nan] * 9)]).rolling(20), roll_func)() + .iloc[9:] + .reset_index(drop=True) + ) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize("roll_func", ["kurt", "skew"]) +def test_center_reindex_series(series, roll_func): + # shifter index + s = [f"x{x:d}" for x in range(12)] + + series_xp = ( + getattr( + series.reindex(list(series.index) + s).rolling(window=25), + roll_func, + )() + .shift(-12) + .reindex(series.index) + ) + series_rs = getattr(series.rolling(window=25, center=True), roll_func)() + tm.assert_series_equal(series_xp, series_rs) + + +@pytest.mark.slow +@pytest.mark.parametrize("roll_func", ["kurt", "skew"]) +def test_center_reindex_frame(frame, roll_func): + # shifter index + s = [f"x{x:d}" for x in range(12)] + + frame_xp = ( + getattr( + frame.reindex(list(frame.index) + s).rolling(window=25), + roll_func, + )() + .shift(-12) + .reindex(frame.index) + ) + frame_rs = getattr(frame.rolling(window=25, center=True), roll_func)() + tm.assert_frame_equal(frame_xp, frame_rs) + + +def test_rolling_skew_edge_cases(step): + expected = Series([np.nan] * 4 + [0.0])[::step] + # yields all NaN (0 variance) + d = Series([1] * 5) + x = d.rolling(window=5, step=step).skew() + # index 4 should be 0 as it contains 5 same obs + tm.assert_series_equal(expected, x) + + expected = Series([np.nan] * 5)[::step] + # yields all NaN (window too small) + d = Series(np.random.default_rng(2).standard_normal(5)) + x = d.rolling(window=2, step=step).skew() + tm.assert_series_equal(expected, x) + + # yields [NaN, NaN, NaN, 0.177994, 1.548824] + d = Series([-1.50837035, -0.1297039, 0.19501095, 1.73508164, 0.41941401]) + expected = Series([np.nan, np.nan, np.nan, 0.177994, 1.548824])[::step] + x = d.rolling(window=4, step=step).skew() + tm.assert_series_equal(expected, x) + + +def test_rolling_kurt_edge_cases(step): + expected = Series([np.nan] * 4 + [-3.0])[::step] + + # yields all NaN (0 variance) + d = Series([1] * 5) + x = d.rolling(window=5, step=step).kurt() + tm.assert_series_equal(expected, x) + + # yields all NaN (window too small) + expected = Series([np.nan] * 5)[::step] + d = Series(np.random.default_rng(2).standard_normal(5)) + x = d.rolling(window=3, step=step).kurt() + tm.assert_series_equal(expected, x) + + # yields [NaN, NaN, NaN, 1.224307, 2.671499] + d = Series([-1.50837035, -0.1297039, 0.19501095, 1.73508164, 0.41941401]) + expected = Series([np.nan, np.nan, np.nan, 1.224307, 2.671499])[::step] + x = d.rolling(window=4, step=step).kurt() + tm.assert_series_equal(expected, x) + + +def test_rolling_skew_eq_value_fperr(step): + # #18804 all rolling skew for all equal values should return Nan + # #46717 update: all equal values should return 0 instead of NaN + a = Series([1.1] * 15).rolling(window=10, step=step).skew() + assert (a[a.index >= 9] == 0).all() + assert a[a.index < 9].isna().all() + + +def test_rolling_kurt_eq_value_fperr(step): + # #18804 all rolling kurt for all equal values should return Nan + # #46717 update: all equal values should return -3 instead of NaN + a = Series([1.1] * 15).rolling(window=10, step=step).kurt() + assert (a[a.index >= 9] == -3).all() + assert a[a.index < 9].isna().all() diff --git a/pandas/tests/window/test_timeseries_window.py b/pandas/tests/window/test_timeseries_window.py new file mode 100644 index 0000000000000000000000000000000000000000..043f369566a5df4b87b6d486fa3b9f23b8d208cf --- /dev/null +++ b/pandas/tests/window/test_timeseries_window.py @@ -0,0 +1,747 @@ +import numpy as np +import pytest + +import pandas.util._test_decorators as td + +from pandas import ( + DataFrame, + DatetimeIndex, + Index, + MultiIndex, + NaT, + Series, + Timestamp, + date_range, +) +import pandas._testing as tm + +from pandas.tseries import offsets + + +@pytest.fixture +def regular(): + return DataFrame( + {"A": date_range("20130101", periods=5, freq="s"), "B": range(5)} + ).set_index("A") + + +@pytest.fixture +def ragged(): + df = DataFrame({"B": range(5)}) + df.index = [ + Timestamp("20130101 09:00:00"), + Timestamp("20130101 09:00:02"), + Timestamp("20130101 09:00:03"), + Timestamp("20130101 09:00:05"), + Timestamp("20130101 09:00:06"), + ] + return df + + +class TestRollingTS: + # rolling time-series friendly + # xref GH13327 + + def test_doc_string(self): + df = DataFrame( + {"B": [0, 1, 2, np.nan, 4]}, + index=[ + Timestamp("20130101 09:00:00"), + Timestamp("20130101 09:00:02"), + Timestamp("20130101 09:00:03"), + Timestamp("20130101 09:00:05"), + Timestamp("20130101 09:00:06"), + ], + ) + df + df.rolling("2s").sum() + + def test_invalid_window_non_int(self, regular): + # not a valid freq + msg = "passed window foobar is not compatible with a datetimelike index" + with pytest.raises(ValueError, match=msg): + regular.rolling(window="foobar") + # not a datetimelike index + msg = "window must be an integer" + with pytest.raises(ValueError, match=msg): + regular.reset_index().rolling(window="foobar") + + @pytest.mark.parametrize("freq", ["2MS", offsets.MonthBegin(2)]) + def test_invalid_window_nonfixed(self, freq, regular): + # non-fixed freqs + msg = "\\<2 \\* MonthBegins\\> is a non-fixed frequency" + with pytest.raises(ValueError, match=msg): + regular.rolling(window=freq) + + @pytest.mark.parametrize("freq", ["1D", offsets.Day(2), "2ms"]) + def test_valid_window(self, freq, regular): + regular.rolling(window=freq) + + @pytest.mark.parametrize("minp", [1.0, "foo", np.array([1, 2, 3])]) + def test_invalid_minp(self, minp, regular): + # non-integer min_periods + msg = ( + r"local variable 'minp' referenced before assignment|" + "min_periods must be an integer" + ) + with pytest.raises(ValueError, match=msg): + regular.rolling(window="1D", min_periods=minp) + + def test_on(self, regular): + df = regular + + # not a valid column + msg = ( + r"invalid on specified as foobar, must be a column " + "\\(of DataFrame\\), an Index or None" + ) + with pytest.raises(ValueError, match=msg): + df.rolling(window="2s", on="foobar") + + # column is valid + df = df.copy() + df["C"] = date_range("20130101", periods=len(df)) + df.rolling(window="2D", on="C").sum() + + # invalid columns + msg = "window must be an integer" + with pytest.raises(ValueError, match=msg): + df.rolling(window="2d", on="B") + + # ok even though on non-selected + df.rolling(window="2D", on="C").B.sum() + + def test_monotonic_on(self): + # on/index must be monotonic + df = DataFrame( + {"A": date_range("20130101", periods=5, freq="s"), "B": range(5)} + ) + + assert df.A.is_monotonic_increasing + df.rolling("2s", on="A").sum() + + df = df.set_index("A") + assert df.index.is_monotonic_increasing + df.rolling("2s").sum() + + def test_non_monotonic_on(self): + # GH 19248 + df = DataFrame( + {"A": date_range("20130101", periods=5, freq="s"), "B": range(5)} + ) + df = df.set_index("A") + non_monotonic_index = df.index.to_list() + non_monotonic_index[0] = non_monotonic_index[3] + df.index = non_monotonic_index + + assert not df.index.is_monotonic_increasing + + msg = "index values must be monotonic" + with pytest.raises(ValueError, match=msg): + df.rolling("2s").sum() + + df = df.reset_index() + + msg = ( + r"invalid on specified as A, must be a column " + "\\(of DataFrame\\), an Index or None" + ) + with pytest.raises(ValueError, match=msg): + df.rolling("2s", on="A").sum() + + def test_frame_on(self): + df = DataFrame( + {"B": range(5), "C": date_range("20130101 09:00:00", periods=5, freq="3s")} + ) + + df["A"] = [ + Timestamp("20130101 09:00:00"), + Timestamp("20130101 09:00:02"), + Timestamp("20130101 09:00:03"), + Timestamp("20130101 09:00:05"), + Timestamp("20130101 09:00:06"), + ] + + # we are doing simulating using 'on' + expected = df.set_index("A").rolling("2s").B.sum().reset_index(drop=True) + + result = df.rolling("2s", on="A").B.sum() + tm.assert_series_equal(result, expected) + + # test as a frame + # we should be ignoring the 'on' as an aggregation column + # note that the expected is setting, computing, and resetting + # so the columns need to be switched compared + # to the actual result where they are ordered as in the + # original + expected = ( + df.set_index("A").rolling("2s")[["B"]].sum().reset_index()[["B", "A"]] + ) + + result = df.rolling("2s", on="A")[["B"]].sum() + tm.assert_frame_equal(result, expected) + + def test_frame_on2(self, unit): + # using multiple aggregation columns + dti = DatetimeIndex( + [ + Timestamp("20130101 09:00:00"), + Timestamp("20130101 09:00:02"), + Timestamp("20130101 09:00:03"), + Timestamp("20130101 09:00:05"), + Timestamp("20130101 09:00:06"), + ] + ).as_unit(unit) + df = DataFrame( + { + "A": [0, 1, 2, 3, 4], + "B": [0, 1, 2, np.nan, 4], + "C": dti, + }, + columns=["A", "C", "B"], + ) + + expected1 = DataFrame( + {"A": [0.0, 1, 3, 3, 7], "B": [0, 1, 3, np.nan, 4], "C": df["C"]}, + columns=["A", "C", "B"], + ) + + result = df.rolling("2s", on="C").sum() + expected = expected1 + tm.assert_frame_equal(result, expected) + + expected = Series([0, 1, 3, np.nan, 4], name="B") + result = df.rolling("2s", on="C").B.sum() + tm.assert_series_equal(result, expected) + + expected = expected1[["A", "B", "C"]] + result = df.rolling("2s", on="C")[["A", "B", "C"]].sum() + tm.assert_frame_equal(result, expected) + + def test_basic_regular(self, regular): + df = regular.copy() + + df.index = date_range("20130101", periods=5, freq="D") + expected = df.rolling(window=1, min_periods=1).sum() + result = df.rolling(window="1D").sum() + tm.assert_frame_equal(result, expected) + + df.index = date_range("20130101", periods=5, freq="2D") + expected = df.rolling(window=1, min_periods=1).sum() + result = df.rolling(window="2D", min_periods=1).sum() + tm.assert_frame_equal(result, expected) + + expected = df.rolling(window=1, min_periods=1).sum() + result = df.rolling(window="2D", min_periods=1).sum() + tm.assert_frame_equal(result, expected) + + expected = df.rolling(window=1).sum() + result = df.rolling(window="2D").sum() + tm.assert_frame_equal(result, expected) + + def test_min_periods(self, regular): + # compare for min_periods + df = regular + + # these slightly different + expected = df.rolling(2, min_periods=1).sum() + result = df.rolling("2s").sum() + tm.assert_frame_equal(result, expected) + + expected = df.rolling(2, min_periods=1).sum() + result = df.rolling("2s", min_periods=1).sum() + tm.assert_frame_equal(result, expected) + + def test_closed(self, regular, unit): + # xref GH13965 + + dti = DatetimeIndex( + [ + Timestamp("20130101 09:00:01"), + Timestamp("20130101 09:00:02"), + Timestamp("20130101 09:00:03"), + Timestamp("20130101 09:00:04"), + Timestamp("20130101 09:00:06"), + ] + ).as_unit(unit) + + df = DataFrame( + {"A": [1] * 5}, + index=dti, + ) + + # closed must be 'right', 'left', 'both', 'neither' + msg = "closed must be 'right', 'left', 'both' or 'neither'" + with pytest.raises(ValueError, match=msg): + regular.rolling(window="2s", closed="blabla") + + expected = df.copy() + expected["A"] = [1.0, 2, 2, 2, 1] + result = df.rolling("2s", closed="right").sum() + tm.assert_frame_equal(result, expected) + + # default should be 'right' + result = df.rolling("2s").sum() + tm.assert_frame_equal(result, expected) + + expected = df.copy() + expected["A"] = [1.0, 2, 3, 3, 2] + result = df.rolling("2s", closed="both").sum() + tm.assert_frame_equal(result, expected) + + expected = df.copy() + expected["A"] = [np.nan, 1.0, 2, 2, 1] + result = df.rolling("2s", closed="left").sum() + tm.assert_frame_equal(result, expected) + + expected = df.copy() + expected["A"] = [np.nan, 1.0, 1, 1, np.nan] + result = df.rolling("2s", closed="neither").sum() + tm.assert_frame_equal(result, expected) + + def test_ragged_sum(self, ragged): + df = ragged + result = df.rolling(window="1s", min_periods=1).sum() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).sum() + expected = df.copy() + expected["B"] = [0.0, 1, 3, 3, 7] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=2).sum() + expected = df.copy() + expected["B"] = [np.nan, np.nan, 3, np.nan, 7] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="3s", min_periods=1).sum() + expected = df.copy() + expected["B"] = [0.0, 1, 3, 5, 7] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="3s").sum() + expected = df.copy() + expected["B"] = [0.0, 1, 3, 5, 7] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="4s", min_periods=1).sum() + expected = df.copy() + expected["B"] = [0.0, 1, 3, 6, 9] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="4s", min_periods=3).sum() + expected = df.copy() + expected["B"] = [np.nan, np.nan, 3, 6, 9] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).sum() + expected = df.copy() + expected["B"] = [0.0, 1, 3, 6, 10] + tm.assert_frame_equal(result, expected) + + def test_ragged_mean(self, ragged): + df = ragged + result = df.rolling(window="1s", min_periods=1).mean() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).mean() + expected = df.copy() + expected["B"] = [0.0, 1, 1.5, 3.0, 3.5] + tm.assert_frame_equal(result, expected) + + def test_ragged_median(self, ragged): + df = ragged + result = df.rolling(window="1s", min_periods=1).median() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).median() + expected = df.copy() + expected["B"] = [0.0, 1, 1.5, 3.0, 3.5] + tm.assert_frame_equal(result, expected) + + def test_ragged_quantile(self, ragged): + df = ragged + result = df.rolling(window="1s", min_periods=1).quantile(0.5) + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).quantile(0.5) + expected = df.copy() + expected["B"] = [0.0, 1, 1.5, 3.0, 3.5] + tm.assert_frame_equal(result, expected) + + def test_ragged_std(self, ragged): + df = ragged + result = df.rolling(window="1s", min_periods=1).std(ddof=0) + expected = df.copy() + expected["B"] = [0.0] * 5 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="1s", min_periods=1).std(ddof=1) + expected = df.copy() + expected["B"] = [np.nan] * 5 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="3s", min_periods=1).std(ddof=0) + expected = df.copy() + expected["B"] = [0.0] + [0.5] * 4 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).std(ddof=1) + expected = df.copy() + expected["B"] = [np.nan, 0.707107, 1.0, 1.0, 1.290994] + tm.assert_frame_equal(result, expected) + + def test_ragged_var(self, ragged): + df = ragged + result = df.rolling(window="1s", min_periods=1).var(ddof=0) + expected = df.copy() + expected["B"] = [0.0] * 5 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="1s", min_periods=1).var(ddof=1) + expected = df.copy() + expected["B"] = [np.nan] * 5 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="3s", min_periods=1).var(ddof=0) + expected = df.copy() + expected["B"] = [0.0] + [0.25] * 4 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).var(ddof=1) + expected = df.copy() + expected["B"] = [np.nan, 0.5, 1.0, 1.0, 1 + 2 / 3.0] + tm.assert_frame_equal(result, expected) + + def test_ragged_skew(self, ragged): + df = ragged + result = df.rolling(window="3s", min_periods=1).skew() + expected = df.copy() + expected["B"] = [np.nan] * 5 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).skew() + expected = df.copy() + expected["B"] = [np.nan] * 2 + [0.0, 0.0, 0.0] + tm.assert_frame_equal(result, expected) + + def test_ragged_kurt(self, ragged): + df = ragged + result = df.rolling(window="3s", min_periods=1).kurt() + expected = df.copy() + expected["B"] = [np.nan] * 5 + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).kurt() + expected = df.copy() + expected["B"] = [np.nan] * 4 + [-1.2] + tm.assert_frame_equal(result, expected) + + def test_ragged_count(self, ragged): + df = ragged + result = df.rolling(window="1s", min_periods=1).count() + expected = df.copy() + expected["B"] = [1.0, 1, 1, 1, 1] + tm.assert_frame_equal(result, expected) + + df = ragged + result = df.rolling(window="1s").count() + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).count() + expected = df.copy() + expected["B"] = [1.0, 1, 2, 1, 2] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=2).count() + expected = df.copy() + expected["B"] = [np.nan, np.nan, 2, np.nan, 2] + tm.assert_frame_equal(result, expected) + + def test_regular_min(self): + df = DataFrame( + {"A": date_range("20130101", periods=5, freq="s"), "B": [0.0, 1, 2, 3, 4]} + ).set_index("A") + result = df.rolling("1s").min() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + df = DataFrame( + {"A": date_range("20130101", periods=5, freq="s"), "B": [5, 4, 3, 4, 5]} + ).set_index("A") + + tm.assert_frame_equal(result, expected) + result = df.rolling("2s").min() + expected = df.copy() + expected["B"] = [5.0, 4, 3, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling("5s").min() + expected = df.copy() + expected["B"] = [5.0, 4, 3, 3, 3] + tm.assert_frame_equal(result, expected) + + def test_ragged_min(self, ragged): + df = ragged + + result = df.rolling(window="1s", min_periods=1).min() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).min() + expected = df.copy() + expected["B"] = [0.0, 1, 1, 3, 3] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).min() + expected = df.copy() + expected["B"] = [0.0, 0, 0, 1, 1] + tm.assert_frame_equal(result, expected) + + def test_perf_min(self): + N = 10000 + + dfp = DataFrame( + {"B": np.random.default_rng(2).standard_normal(N)}, + index=date_range("20130101", periods=N, freq="s"), + ) + expected = dfp.rolling(2, min_periods=1).min() + result = dfp.rolling("2s").min() + assert ((result - expected) < 0.01).all().all() + + expected = dfp.rolling(200, min_periods=1).min() + result = dfp.rolling("200s").min() + assert ((result - expected) < 0.01).all().all() + + def test_ragged_max(self, ragged): + df = ragged + + result = df.rolling(window="1s", min_periods=1).max() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).max() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).max() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + def test_ragged_first(self, ragged): + df = ragged + + result = df.rolling(window="1s", min_periods=1).first() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).first() + expected = df.copy() + expected["B"] = [0.0, 1, 1, 3, 3] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).first() + expected = df.copy() + expected["B"] = [0.0, 0, 0, 1, 1] + tm.assert_frame_equal(result, expected) + + def test_ragged_last(self, ragged): + df = ragged + + result = df.rolling(window="1s", min_periods=1).last() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="2s", min_periods=1).last() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + result = df.rolling(window="5s", min_periods=1).last() + expected = df.copy() + expected["B"] = [0.0, 1, 2, 3, 4] + tm.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( + "freq, op, result_data", + [ + ("ms", "min", [0.0] * 10), + ("ms", "mean", [0.0] * 9 + [2.0 / 9]), + ("ms", "max", [0.0] * 9 + [2.0]), + ("s", "min", [0.0] * 10), + ("s", "mean", [0.0] * 9 + [2.0 / 9]), + ("s", "max", [0.0] * 9 + [2.0]), + ("min", "min", [0.0] * 10), + ("min", "mean", [0.0] * 9 + [2.0 / 9]), + ("min", "max", [0.0] * 9 + [2.0]), + ("h", "min", [0.0] * 10), + ("h", "mean", [0.0] * 9 + [2.0 / 9]), + ("h", "max", [0.0] * 9 + [2.0]), + ("D", "min", [0.0] * 10), + ("D", "mean", [0.0] * 9 + [2.0 / 9]), + ("D", "max", [0.0] * 9 + [2.0]), + ], + ) + def test_freqs_ops(self, freq, op, result_data): + # GH 21096 + index = date_range(start="2018-1-1 01:00:00", freq=f"1{freq}", periods=10) + # Explicit cast to float to avoid implicit cast when setting nan + s = Series(data=0, index=index, dtype="float") + s.iloc[1] = np.nan + s.iloc[-1] = 2 + result = getattr(s.rolling(window=f"10{freq}"), op)() + expected = Series(data=result_data, index=index) + + tm.assert_series_equal(result, expected) + + @pytest.mark.parametrize( + "f", + [ + "sum", + "mean", + "count", + "median", + "std", + "var", + "kurt", + "skew", + "min", + "max", + "first", + "last", + ], + ) + def test_all(self, f, regular): + # simple comparison of integer vs time-based windowing + df = regular * 2 + er = df.rolling(window=1) + r = df.rolling(window="1s") + + result = getattr(r, f)() + expected = getattr(er, f)() + tm.assert_frame_equal(result, expected) + + result = r.quantile(0.5) + expected = er.quantile(0.5) + tm.assert_frame_equal(result, expected) + + def test_all2(self, arithmetic_win_operators): + f = arithmetic_win_operators + # more sophisticated comparison of integer vs. + # time-based windowing + df = DataFrame( + {"B": np.arange(50)}, index=date_range("20130101", periods=50, freq="h") + ) + # in-range data + dft = df.between_time("09:00", "16:00") + + r = dft.rolling(window="5h") + + result = getattr(r, f)() + + # we need to roll the days separately + # to compare with a time-based roll + # finally groupby-apply will return a multi-index + # so we need to drop the day + def agg_by_day(x): + x = x.between_time("09:00", "16:00") + return getattr(x.rolling(5, min_periods=1), f)() + + expected = ( + df.groupby(df.index.day).apply(agg_by_day).reset_index(level=0, drop=True) + ) + + tm.assert_frame_equal(result, expected) + + def test_rolling_cov_offset(self): + # GH16058 + + idx = date_range("2017-01-01", periods=24, freq="1h") + ss = Series(np.arange(len(idx)), index=idx) + + result = ss.rolling("2h").cov() + expected = Series([np.nan] + [0.5] * (len(idx) - 1), index=idx) + tm.assert_series_equal(result, expected) + + expected2 = ss.rolling(2, min_periods=1).cov() + tm.assert_series_equal(result, expected2) + + result = ss.rolling("3h").cov() + expected = Series([np.nan, 0.5] + [1.0] * (len(idx) - 2), index=idx) + tm.assert_series_equal(result, expected) + + expected2 = ss.rolling(3, min_periods=1).cov() + tm.assert_series_equal(result, expected2) + + def test_rolling_on_decreasing_index(self, unit): + # GH-19248, GH-32385 + index = DatetimeIndex( + [ + Timestamp("20190101 09:00:30"), + Timestamp("20190101 09:00:27"), + Timestamp("20190101 09:00:20"), + Timestamp("20190101 09:00:18"), + Timestamp("20190101 09:00:10"), + ] + ).as_unit(unit) + + df = DataFrame({"column": [3, 4, 4, 5, 6]}, index=index) + result = df.rolling("5s").min() + expected = DataFrame({"column": [3.0, 3.0, 4.0, 4.0, 6.0]}, index=index) + tm.assert_frame_equal(result, expected) + + def test_rolling_on_empty(self): + # GH-32385 + df = DataFrame({"column": []}, index=[]) + result = df.rolling("5s").min() + expected = DataFrame({"column": []}, index=[]) + tm.assert_frame_equal(result, expected) + + def test_rolling_on_multi_index_level(self): + # GH-15584 + df = DataFrame( + {"column": range(6)}, + index=MultiIndex.from_product( + [date_range("20190101", periods=3), range(2)], names=["date", "seq"] + ), + ) + result = df.rolling("10D", on=df.index.get_level_values("date")).sum() + expected = DataFrame( + {"column": [0.0, 1.0, 3.0, 6.0, 10.0, 15.0]}, index=df.index + ) + tm.assert_frame_equal(result, expected) + + +def test_nat_axis_error(): + idx = [Timestamp("2020"), NaT] + df = DataFrame(np.eye(2), index=idx) + with pytest.raises(ValueError, match="index values must not have NaT"): + df.rolling("D").mean() + + +@td.skip_if_no("pyarrow") +def test_arrow_datetime_axis(): + # GH 55849 + expected = Series( + np.arange(5, dtype=np.float64), + index=Index( + date_range("2020-01-01", periods=5), dtype="timestamp[ns][pyarrow]" + ), + ) + result = expected.rolling("1D").sum() + tm.assert_series_equal(result, expected) diff --git a/pandas/tests/window/test_win_type.py b/pandas/tests/window/test_win_type.py new file mode 100644 index 0000000000000000000000000000000000000000..574dfc34b6d267169bd66b73b42a73829383f78d --- /dev/null +++ b/pandas/tests/window/test_win_type.py @@ -0,0 +1,670 @@ +import numpy as np +import pytest + +from pandas import ( + DataFrame, + Series, + Timedelta, + concat, + date_range, +) +import pandas._testing as tm +from pandas.api.indexers import BaseIndexer + + +@pytest.fixture( + params=[ + "triang", + "blackman", + "hamming", + "bartlett", + "bohman", + "blackmanharris", + "nuttall", + "barthann", + ] +) +def win_types(request): + return request.param + + +@pytest.fixture(params=["kaiser", "gaussian", "general_gaussian", "exponential"]) +def win_types_special(request): + return request.param + + +def test_constructor(frame_or_series): + # GH 12669 + pytest.importorskip("scipy") + c = frame_or_series(range(5)).rolling + + # valid + c(win_type="boxcar", window=2, min_periods=1) + c(win_type="boxcar", window=2, min_periods=1, center=True) + c(win_type="boxcar", window=2, min_periods=1, center=False) + + +@pytest.mark.parametrize("w", [2.0, "foo", np.array([2])]) +def test_invalid_constructor(frame_or_series, w): + # not valid + pytest.importorskip("scipy") + c = frame_or_series(range(5)).rolling + with pytest.raises(ValueError, match="min_periods must be an integer"): + c(win_type="boxcar", window=2, min_periods=w) + with pytest.raises(ValueError, match="center must be a boolean"): + c(win_type="boxcar", window=2, min_periods=1, center=w) + + +@pytest.mark.parametrize("wt", ["foobar", 1]) +def test_invalid_constructor_wintype(frame_or_series, wt): + pytest.importorskip("scipy") + c = frame_or_series(range(5)).rolling + with pytest.raises(ValueError, match="Invalid win_type"): + c(win_type=wt, window=2) + + +def test_constructor_with_win_type(frame_or_series, win_types): + # GH 12669 + pytest.importorskip("scipy") + c = frame_or_series(range(5)).rolling + c(win_type=win_types, window=2) + + +@pytest.mark.parametrize("arg", ["median", "kurt", "skew"]) +def test_agg_function_support(arg): + pytest.importorskip("scipy") + df = DataFrame({"A": np.arange(5)}) + roll = df.rolling(2, win_type="triang") + + msg = f"'{arg}' is not a valid function for 'Window' object" + with pytest.raises(AttributeError, match=msg): + roll.agg(arg) + + with pytest.raises(AttributeError, match=msg): + roll.agg([arg]) + + with pytest.raises(AttributeError, match=msg): + roll.agg({"A": arg}) + + +def test_invalid_scipy_arg(): + # This error is raised by scipy + pytest.importorskip("scipy") + msg = r"boxcar\(\) got an unexpected" + with pytest.raises(TypeError, match=msg): + Series(range(3)).rolling(1, win_type="boxcar").mean(foo="bar") + + +def test_constructor_with_win_type_invalid(frame_or_series): + # GH 13383 + pytest.importorskip("scipy") + c = frame_or_series(range(5)).rolling + + msg = "window must be an integer 0 or greater" + + with pytest.raises(ValueError, match=msg): + c(-1, win_type="boxcar") + + +def test_window_with_args(step): + # make sure that we are aggregating window functions correctly with arg + pytest.importorskip("scipy") + r = Series(np.random.default_rng(2).standard_normal(100)).rolling( + window=10, min_periods=1, win_type="gaussian", step=step + ) + expected = concat([r.mean(std=10), r.mean(std=0.01)], axis=1) + expected.columns = ["", ""] + result = r.aggregate([lambda x: x.mean(std=10), lambda x: x.mean(std=0.01)]) + tm.assert_frame_equal(result, expected) + + def a(x): + return x.mean(std=10) + + def b(x): + return x.mean(std=0.01) + + expected = concat([r.mean(std=10), r.mean(std=0.01)], axis=1) + expected.columns = ["a", "b"] + result = r.aggregate([a, b]) + tm.assert_frame_equal(result, expected) + + +def test_win_type_with_method_invalid(): + pytest.importorskip("scipy") + with pytest.raises( + NotImplementedError, match="'single' is the only supported method type." + ): + Series(range(1)).rolling(1, win_type="triang", method="table") + + +@pytest.mark.parametrize("arg", [2000000000, "2s", Timedelta("2s")]) +def test_consistent_win_type_freq(arg): + # GH 15969 + pytest.importorskip("scipy") + s = Series(range(1)) + with pytest.raises(ValueError, match="Invalid win_type freq"): + s.rolling(arg, win_type="freq") + + +def test_win_type_freq_return_none(): + # GH 48838 + freq_roll = Series(range(2), index=date_range("2020", periods=2)).rolling("2s") + assert freq_roll.win_type is None + + +def test_win_type_not_implemented(): + pytest.importorskip("scipy") + + class CustomIndexer(BaseIndexer): + def get_window_bounds(self, num_values, min_periods, center, closed, step): + return np.array([0, 1]), np.array([1, 2]) + + df = DataFrame({"values": range(2)}) + indexer = CustomIndexer() + with pytest.raises(NotImplementedError, match="BaseIndexer subclasses not"): + df.rolling(indexer, win_type="boxcar") + + +def test_cmov_mean(step): + # GH 8238 + pytest.importorskip("scipy") + vals = np.array([6.95, 15.21, 4.72, 9.12, 13.81, 13.49, 16.68, 9.48, 10.63, 14.48]) + result = Series(vals).rolling(5, center=True, step=step).mean() + expected_values = [ + np.nan, + np.nan, + 9.962, + 11.27, + 11.564, + 12.516, + 12.818, + 12.952, + np.nan, + np.nan, + ] + expected = Series(expected_values)[::step] + tm.assert_series_equal(expected, result) + + +def test_cmov_window(step): + # GH 8238 + pytest.importorskip("scipy") + vals = np.array([6.95, 15.21, 4.72, 9.12, 13.81, 13.49, 16.68, 9.48, 10.63, 14.48]) + result = Series(vals).rolling(5, win_type="boxcar", center=True, step=step).mean() + expected_values = [ + np.nan, + np.nan, + 9.962, + 11.27, + 11.564, + 12.516, + 12.818, + 12.952, + np.nan, + np.nan, + ] + expected = Series(expected_values)[::step] + tm.assert_series_equal(expected, result) + + +def test_cmov_window_corner(step): + # GH 8238 + # all nan + pytest.importorskip("scipy") + vals = Series([np.nan] * 10) + result = vals.rolling(5, center=True, win_type="boxcar", step=step).mean() + assert np.isnan(result).all() + + # empty + vals = Series([], dtype=object) + result = vals.rolling(5, center=True, win_type="boxcar", step=step).mean() + assert len(result) == 0 + + # shorter than window + vals = Series(np.random.default_rng(2).standard_normal(5)) + result = vals.rolling(10, win_type="boxcar", step=step).mean() + assert np.isnan(result).all() + assert len(result) == len(range(0, 5, step or 1)) + + +@pytest.mark.parametrize( + "f,xp", + [ + ( + "mean", + [ + [np.nan, np.nan], + [np.nan, np.nan], + [9.252, 9.392], + [8.644, 9.906], + [8.87, 10.208], + [6.81, 8.588], + [7.792, 8.644], + [9.05, 7.824], + [np.nan, np.nan], + [np.nan, np.nan], + ], + ), + ( + "std", + [ + [np.nan, np.nan], + [np.nan, np.nan], + [3.789706, 4.068313], + [3.429232, 3.237411], + [3.589269, 3.220810], + [3.405195, 2.380655], + [3.281839, 2.369869], + [3.676846, 1.801799], + [np.nan, np.nan], + [np.nan, np.nan], + ], + ), + ( + "var", + [ + [np.nan, np.nan], + [np.nan, np.nan], + [14.36187, 16.55117], + [11.75963, 10.48083], + [12.88285, 10.37362], + [11.59535, 5.66752], + [10.77047, 5.61628], + [13.51920, 3.24648], + [np.nan, np.nan], + [np.nan, np.nan], + ], + ), + ( + "sum", + [ + [np.nan, np.nan], + [np.nan, np.nan], + [46.26, 46.96], + [43.22, 49.53], + [44.35, 51.04], + [34.05, 42.94], + [38.96, 43.22], + [45.25, 39.12], + [np.nan, np.nan], + [np.nan, np.nan], + ], + ), + ], +) +def test_cmov_window_frame(f, xp, step): + # Gh 8238 + pytest.importorskip("scipy") + df = DataFrame( + np.array( + [ + [12.18, 3.64], + [10.18, 9.16], + [13.24, 14.61], + [4.51, 8.11], + [6.15, 11.44], + [9.14, 6.21], + [11.31, 10.67], + [2.94, 6.51], + [9.42, 8.39], + [12.44, 7.34], + ] + ) + ) + xp = DataFrame(np.array(xp))[::step] + + roll = df.rolling(5, win_type="boxcar", center=True, step=step) + rs = getattr(roll, f)() + + tm.assert_frame_equal(xp, rs) + + +@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4, 5]) +def test_cmov_window_na_min_periods(step, min_periods): + pytest.importorskip("scipy") + vals = Series(np.random.default_rng(2).standard_normal(10)) + vals[4] = np.nan + vals[8] = np.nan + + xp = vals.rolling(5, min_periods=min_periods, center=True, step=step).mean() + rs = vals.rolling( + 5, win_type="boxcar", min_periods=min_periods, center=True, step=step + ).mean() + tm.assert_series_equal(xp, rs) + + +def test_cmov_window_regular(win_types, step): + # GH 8238 + pytest.importorskip("scipy") + vals = np.array([6.95, 15.21, 4.72, 9.12, 13.81, 13.49, 16.68, 9.48, 10.63, 14.48]) + xps = { + "hamming": [ + np.nan, + np.nan, + 8.71384, + 9.56348, + 12.38009, + 14.03687, + 13.8567, + 11.81473, + np.nan, + np.nan, + ], + "triang": [ + np.nan, + np.nan, + 9.28667, + 10.34667, + 12.00556, + 13.33889, + 13.38, + 12.33667, + np.nan, + np.nan, + ], + "barthann": [ + np.nan, + np.nan, + 8.4425, + 9.1925, + 12.5575, + 14.3675, + 14.0825, + 11.5675, + np.nan, + np.nan, + ], + "bohman": [ + np.nan, + np.nan, + 7.61599, + 9.1764, + 12.83559, + 14.17267, + 14.65923, + 11.10401, + np.nan, + np.nan, + ], + "blackmanharris": [ + np.nan, + np.nan, + 6.97691, + 9.16438, + 13.05052, + 14.02156, + 15.10512, + 10.74574, + np.nan, + np.nan, + ], + "nuttall": [ + np.nan, + np.nan, + 7.04618, + 9.16786, + 13.02671, + 14.03559, + 15.05657, + 10.78514, + np.nan, + np.nan, + ], + "blackman": [ + np.nan, + np.nan, + 7.73345, + 9.17869, + 12.79607, + 14.20036, + 14.57726, + 11.16988, + np.nan, + np.nan, + ], + "bartlett": [ + np.nan, + np.nan, + 8.4425, + 9.1925, + 12.5575, + 14.3675, + 14.0825, + 11.5675, + np.nan, + np.nan, + ], + } + + xp = Series(xps[win_types])[::step] + rs = Series(vals).rolling(5, win_type=win_types, center=True, step=step).mean() + tm.assert_series_equal(xp, rs) + + +def test_cmov_window_regular_linear_range(win_types, step): + # GH 8238 + pytest.importorskip("scipy") + vals = np.array(range(10), dtype=float) + rs = Series(vals).rolling(5, win_type=win_types, center=True, step=step).mean() + xp = vals + xp[:2] = np.nan + xp[-2:] = np.nan + xp = Series(xp)[::step] + + tm.assert_series_equal(xp, rs) + + +def test_cmov_window_regular_missing_data(win_types, step): + # GH 8238 + pytest.importorskip("scipy") + vals = np.array( + [6.95, 15.21, 4.72, 9.12, 13.81, 13.49, 16.68, np.nan, 10.63, 14.48] + ) + xps = { + "bartlett": [ + np.nan, + np.nan, + 9.70333, + 10.5225, + 8.4425, + 9.1925, + 12.5575, + 14.3675, + 15.61667, + 13.655, + ], + "blackman": [ + np.nan, + np.nan, + 9.04582, + 11.41536, + 7.73345, + 9.17869, + 12.79607, + 14.20036, + 15.8706, + 13.655, + ], + "barthann": [ + np.nan, + np.nan, + 9.70333, + 10.5225, + 8.4425, + 9.1925, + 12.5575, + 14.3675, + 15.61667, + 13.655, + ], + "bohman": [ + np.nan, + np.nan, + 8.9444, + 11.56327, + 7.61599, + 9.1764, + 12.83559, + 14.17267, + 15.90976, + 13.655, + ], + "hamming": [ + np.nan, + np.nan, + 9.59321, + 10.29694, + 8.71384, + 9.56348, + 12.38009, + 14.20565, + 15.24694, + 13.69758, + ], + "nuttall": [ + np.nan, + np.nan, + 8.47693, + 12.2821, + 7.04618, + 9.16786, + 13.02671, + 14.03673, + 16.08759, + 13.65553, + ], + "triang": [ + np.nan, + np.nan, + 9.33167, + 9.76125, + 9.28667, + 10.34667, + 12.00556, + 13.82125, + 14.49429, + 13.765, + ], + "blackmanharris": [ + np.nan, + np.nan, + 8.42526, + 12.36824, + 6.97691, + 9.16438, + 13.05052, + 14.02175, + 16.1098, + 13.65509, + ], + } + + xp = Series(xps[win_types])[::step] + rs = Series(vals).rolling(5, win_type=win_types, min_periods=3, step=step).mean() + tm.assert_series_equal(xp, rs) + + +def test_cmov_window_special(win_types_special, step): + # GH 8238 + pytest.importorskip("scipy") + kwds = { + "kaiser": {"beta": 1.0}, + "gaussian": {"std": 1.0}, + "general_gaussian": {"p": 2.0, "sig": 2.0}, + "exponential": {"tau": 10}, + } + + vals = np.array([6.95, 15.21, 4.72, 9.12, 13.81, 13.49, 16.68, 9.48, 10.63, 14.48]) + + xps = { + "gaussian": [ + np.nan, + np.nan, + 8.97297, + 9.76077, + 12.24763, + 13.89053, + 13.65671, + 12.01002, + np.nan, + np.nan, + ], + "general_gaussian": [ + np.nan, + np.nan, + 9.85011, + 10.71589, + 11.73161, + 13.08516, + 12.95111, + 12.74577, + np.nan, + np.nan, + ], + "kaiser": [ + np.nan, + np.nan, + 9.86851, + 11.02969, + 11.65161, + 12.75129, + 12.90702, + 12.83757, + np.nan, + np.nan, + ], + "exponential": [ + np.nan, + np.nan, + 9.83364, + 11.10472, + 11.64551, + 12.66138, + 12.92379, + 12.83770, + np.nan, + np.nan, + ], + } + + xp = Series(xps[win_types_special])[::step] + rs = ( + Series(vals) + .rolling(5, win_type=win_types_special, center=True, step=step) + .mean(**kwds[win_types_special]) + ) + tm.assert_series_equal(xp, rs) + + +def test_cmov_window_special_linear_range(win_types_special, step): + # GH 8238 + pytest.importorskip("scipy") + kwds = { + "kaiser": {"beta": 1.0}, + "gaussian": {"std": 1.0}, + "general_gaussian": {"p": 2.0, "sig": 2.0}, + "slepian": {"width": 0.5}, + "exponential": {"tau": 10}, + } + + vals = np.array(range(10), dtype=float) + rs = ( + Series(vals) + .rolling(5, win_type=win_types_special, center=True, step=step) + .mean(**kwds[win_types_special]) + ) + xp = vals + xp[:2] = np.nan + xp[-2:] = np.nan + xp = Series(xp)[::step] + tm.assert_series_equal(xp, rs) + + +def test_weighted_var_big_window_no_segfault(win_types, center): + # GitHub Issue #46772 + pytest.importorskip("scipy") + x = Series(0) + result = x.rolling(window=16, center=center, win_type=win_types).var() + expected = Series(np.nan) + + tm.assert_series_equal(result, expected) diff --git a/pandas/tseries/__init__.py b/pandas/tseries/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c00843ecac418a41b01470db93388f6c5568ea6b --- /dev/null +++ b/pandas/tseries/__init__.py @@ -0,0 +1,12 @@ +# ruff: noqa: TC004 +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + # import modules that have public classes/functions: + from pandas.tseries import ( + frequencies, + offsets, + ) + + # and mark only those modules as public + __all__ = ["frequencies", "offsets"] diff --git a/pandas/tseries/api.py b/pandas/tseries/api.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea899f1610a7ef223f1c069552bedf4c502a8b4 --- /dev/null +++ b/pandas/tseries/api.py @@ -0,0 +1,10 @@ +""" +Timeseries API +""" + +from pandas._libs.tslibs.parsing import guess_datetime_format + +from pandas.tseries import offsets +from pandas.tseries.frequencies import infer_freq + +__all__ = ["guess_datetime_format", "infer_freq", "offsets"] diff --git a/pandas/tseries/frequencies.py b/pandas/tseries/frequencies.py new file mode 100644 index 0000000000000000000000000000000000000000..196b3aadccaefe4bbc4cb862d36b035feeb92e93 --- /dev/null +++ b/pandas/tseries/frequencies.py @@ -0,0 +1,623 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from pandas._libs import lib +from pandas._libs.algos import unique_deltas +from pandas._libs.tslibs import ( + Timestamp, + get_unit_from_dtype, + periods_per_day, + tz_convert_from_utc, +) +from pandas._libs.tslibs.ccalendar import ( + DAYS, + MONTH_ALIASES, + MONTH_NUMBERS, + MONTHS, + int_to_weekday, +) +from pandas._libs.tslibs.dtypes import OFFSET_TO_PERIOD_FREQSTR +from pandas._libs.tslibs.fields import ( + build_field_sarray, + month_position_check, +) +from pandas._libs.tslibs.offsets import ( + DateOffset, + Day, + to_offset, +) +from pandas._libs.tslibs.parsing import get_rule_month +from pandas.util._decorators import ( + cache_readonly, + set_module, +) + +from pandas.core.dtypes.common import is_numeric_dtype +from pandas.core.dtypes.dtypes import ( + ArrowDtype, + DatetimeTZDtype, + PeriodDtype, +) +from pandas.core.dtypes.generic import ( + ABCIndex, + ABCSeries, +) + +from pandas.core.algorithms import unique + +if TYPE_CHECKING: + from pandas._typing import npt + + from pandas import ( + DatetimeIndex, + Series, + TimedeltaIndex, + ) + from pandas.core.arrays.datetimelike import DatetimeLikeArrayMixin +# -------------------------------------------------------------------- +# Offset related functions + +_need_suffix = ["QS", "BQE", "BQS", "YS", "BYE", "BYS"] + +for _prefix in _need_suffix: + for _m in MONTHS: + key = f"{_prefix}-{_m}" + OFFSET_TO_PERIOD_FREQSTR[key] = OFFSET_TO_PERIOD_FREQSTR[_prefix] + +for _prefix in ["Y", "Q"]: + for _m in MONTHS: + _alias = f"{_prefix}-{_m}" + OFFSET_TO_PERIOD_FREQSTR[_alias] = _alias + +for _d in DAYS: + OFFSET_TO_PERIOD_FREQSTR[f"W-{_d}"] = f"W-{_d}" + + +def get_period_alias(offset_str: str) -> str | None: + """ + Alias to closest period strings BQ->Q etc. + """ + return OFFSET_TO_PERIOD_FREQSTR.get(offset_str, None) + + +# --------------------------------------------------------------------- +# Period codes + + +@set_module("pandas") +def infer_freq( + index: DatetimeIndex | TimedeltaIndex | Series | DatetimeLikeArrayMixin, +) -> str | None: + """ + Infer the most likely frequency given the input index. + + This method attempts to deduce the most probable frequency (e.g., 'D' for daily, + 'H' for hourly) from a sequence of datetime-like objects. It is particularly useful + when the frequency of a time series is not explicitly set or known but can be + inferred from its values. + + Parameters + ---------- + index : DatetimeIndex, TimedeltaIndex, Series or array-like + If passed a Series will use the values of the series (NOT THE INDEX). + + Returns + ------- + str or None + None if no discernible frequency. + + Raises + ------ + TypeError + If the index is not datetime-like. + ValueError + If there are fewer than three values. + + See Also + -------- + date_range : Return a fixed frequency DatetimeIndex. + timedelta_range : Return a fixed frequency TimedeltaIndex with day as the default. + period_range : Return a fixed frequency PeriodIndex. + DatetimeIndex.freq : Return the frequency object if it is set, otherwise None. + + Examples + -------- + >>> idx = pd.date_range(start="2020/12/01", end="2020/12/30", periods=30) + >>> pd.infer_freq(idx) + 'D' + """ + from pandas.core.api import DatetimeIndex + + if isinstance(index, ABCSeries): + values = index._values + + if isinstance(index.dtype, ArrowDtype): + import pyarrow as pa + + if pa.types.is_timestamp(values.dtype.pyarrow_dtype): + # GH#58403 + values = values._to_datetimearray() + + if not ( + lib.is_np_dtype(values.dtype, "mM") + or isinstance(values.dtype, DatetimeTZDtype) + or values.dtype == object + ): + raise TypeError( + "cannot infer freq from a non-convertible dtype " + f"on a Series of {index.dtype}" + ) + index = values + + inferer: _FrequencyInferer + + if not hasattr(index, "dtype"): + pass + elif isinstance(index.dtype, PeriodDtype): + raise TypeError( + "PeriodIndex given. Check the `freq` attribute instead of using infer_freq." + ) + elif lib.is_np_dtype(index.dtype, "m"): + # Allow TimedeltaIndex and TimedeltaArray + inferer = _TimedeltaFrequencyInferer(index) + return inferer.get_freq() + + elif is_numeric_dtype(index.dtype): + raise TypeError( + f"cannot infer freq from a non-convertible index of dtype {index.dtype}" + ) + + if not isinstance(index, DatetimeIndex): + index = DatetimeIndex(index, copy=False) + + inferer = _FrequencyInferer(index) + return inferer.get_freq() + + +class _FrequencyInferer: + """ + Not sure if I can avoid the state machine here + """ + + def __init__(self, index) -> None: + self.index = index + self.i8values = index.asi8 + + # For get_unit_from_dtype we need the dtype to the underlying ndarray, + # which for tz-aware is not the same as index.dtype + if isinstance(index, ABCIndex): + # error: Item "ndarray[Any, Any]" of "Union[ExtensionArray, + # ndarray[Any, Any]]" has no attribute "_ndarray" + self._creso = get_unit_from_dtype( + index._data._ndarray.dtype # type: ignore[union-attr] + ) + else: + # otherwise we have DTA/TDA + self._creso = get_unit_from_dtype(index._ndarray.dtype) + + # This moves the values, which are implicitly in UTC, to the + # the timezone so they are in local time + if hasattr(index, "tz"): + if index.tz is not None: + self.i8values = tz_convert_from_utc( + self.i8values, index.tz, reso=self._creso + ) + + if len(index) < 3: + raise ValueError("Need at least 3 dates to infer frequency") + + self.is_monotonic = ( + self.index._is_monotonic_increasing or self.index._is_monotonic_decreasing + ) + + @cache_readonly + def deltas(self) -> npt.NDArray[np.int64]: + return unique_deltas(self.i8values) + + @cache_readonly + def deltas_asi8(self) -> npt.NDArray[np.int64]: + # NB: we cannot use self.i8values here because we may have converted + # the tz in __init__ + return unique_deltas(self.index.asi8) + + @cache_readonly + def is_unique(self) -> bool: + return len(self.deltas) == 1 + + @cache_readonly + def is_unique_asi8(self) -> bool: + return len(self.deltas_asi8) == 1 + + def get_freq(self) -> str | None: + """ + Find the appropriate frequency string to describe the inferred + frequency of self.i8values + + Returns + ------- + str or None + """ + if not self.is_monotonic or not self.index._is_unique: + return None + + delta = self.deltas[0] + ppd = periods_per_day(self._creso) + if delta and _is_multiple(delta, ppd): + return self._infer_daily_rule() + + # Business hourly, maybe. 17: one day / 65: one weekend + if self.hour_deltas in ([1, 17], [1, 65], [1, 17, 65]): + return "bh" + + # Possibly intraday frequency. Here we use the + # original .asi8 values as the modified values + # will not work around DST transitions. See #8772 + if not self.is_unique_asi8: + return None + + delta = self.deltas_asi8[0] + pph = ppd // 24 + ppm = pph // 60 + pps = ppm // 60 + if _is_multiple(delta, pph): + # Hours + return _maybe_add_count("h", delta / pph) + elif _is_multiple(delta, ppm): + # Minutes + return _maybe_add_count("min", delta / ppm) + elif _is_multiple(delta, pps): + # Seconds + return _maybe_add_count("s", delta / pps) + elif _is_multiple(delta, (pps // 1000)): + # Milliseconds + return _maybe_add_count("ms", delta / (pps // 1000)) + elif _is_multiple(delta, (pps // 1_000_000)): + # Microseconds + return _maybe_add_count("us", delta / (pps // 1_000_000)) + else: + # Nanoseconds + return _maybe_add_count("ns", delta) + + @cache_readonly + def day_deltas(self) -> list[int]: + ppd = periods_per_day(self._creso) + return [x / ppd for x in self.deltas] + + @cache_readonly + def hour_deltas(self) -> list[int]: + pph = periods_per_day(self._creso) // 24 + return [x / pph for x in self.deltas] + + @cache_readonly + def fields(self) -> np.ndarray: # structured array of fields + return build_field_sarray(self.i8values, reso=self._creso) + + @cache_readonly + def rep_stamp(self) -> Timestamp: + return Timestamp(self.i8values[0], unit=self.index.unit) + + def month_position_check(self) -> str | None: + return month_position_check(self.fields, self.index.dayofweek) + + @cache_readonly + def mdiffs(self) -> npt.NDArray[np.int64]: + nmonths = self.fields["Y"] * 12 + self.fields["M"] + return unique_deltas(nmonths.astype("i8")) + + @cache_readonly + def ydiffs(self) -> npt.NDArray[np.int64]: + return unique_deltas(self.fields["Y"].astype("i8")) + + def _infer_daily_rule(self) -> str | None: + annual_rule = self._get_annual_rule() + if annual_rule: + nyears = self.ydiffs[0] + month = MONTH_ALIASES[self.rep_stamp.month] + alias = f"{annual_rule}-{month}" + return _maybe_add_count(alias, nyears) + + quarterly_rule = self._get_quarterly_rule() + if quarterly_rule: + nquarters = self.mdiffs[0] / 3 + mod_dict = {0: 12, 2: 11, 1: 10} + month = MONTH_ALIASES[mod_dict[self.rep_stamp.month % 3]] + alias = f"{quarterly_rule}-{month}" + return _maybe_add_count(alias, nquarters) + + monthly_rule = self._get_monthly_rule() + if monthly_rule: + return _maybe_add_count(monthly_rule, self.mdiffs[0]) + + if self.is_unique: + return self._get_daily_rule() + + if self._is_business_daily(): + return "B" + + wom_rule = self._get_wom_rule() + if wom_rule: + return wom_rule + + return None + + def _get_daily_rule(self) -> str | None: + ppd = periods_per_day(self._creso) + days = self.deltas[0] / ppd + if days % 7 == 0: + # Weekly + wd = int_to_weekday[self.rep_stamp.weekday()] + alias = f"W-{wd}" + return _maybe_add_count(alias, days / 7) + else: + return _maybe_add_count("D", days) + + def _get_annual_rule(self) -> str | None: + if len(self.ydiffs) > 1: + return None + + if len(unique(self.fields["M"])) > 1: + return None + + pos_check = self.month_position_check() + + if pos_check is None: + return None + else: + return {"cs": "YS", "bs": "BYS", "ce": "YE", "be": "BYE"}.get(pos_check) + + def _get_quarterly_rule(self) -> str | None: + if len(self.mdiffs) > 1: + return None + + if not self.mdiffs[0] % 3 == 0: + return None + + pos_check = self.month_position_check() + + if pos_check is None: + return None + else: + return {"cs": "QS", "bs": "BQS", "ce": "QE", "be": "BQE"}.get(pos_check) + + def _get_monthly_rule(self) -> str | None: + if len(self.mdiffs) > 1: + return None + pos_check = self.month_position_check() + + if pos_check is None: + return None + else: + return {"cs": "MS", "bs": "BMS", "ce": "ME", "be": "BME"}.get(pos_check) + + def _is_business_daily(self) -> bool: + # quick check: cannot be business daily + if self.day_deltas != [1, 3]: + return False + + # probably business daily, but need to confirm + first_weekday = self.index[0].weekday() + shifts = np.diff(self.i8values) + ppd = periods_per_day(self._creso) + shifts = np.floor_divide(shifts, ppd) + weekdays = np.mod(first_weekday + np.cumsum(shifts), 7) + + return bool( + np.all( + ((weekdays == 0) & (shifts == 3)) + | ((weekdays > 0) & (weekdays <= 4) & (shifts == 1)) + ) + ) + + def _get_wom_rule(self) -> str | None: + weekdays = unique(self.index.weekday) + if len(weekdays) > 1: + return None + + week_of_months = unique((self.index.day - 1) // 7) + # Only attempt to infer up to WOM-4. See #9425 + week_of_months = week_of_months[week_of_months < 4] + if len(week_of_months) == 0 or len(week_of_months) > 1: + return None + + # get which week + week = week_of_months[0] + 1 + wd = int_to_weekday[weekdays[0]] + + return f"WOM-{week}{wd}" + + +class _TimedeltaFrequencyInferer(_FrequencyInferer): + def _infer_daily_rule(self): + if self.is_unique: + return self._get_daily_rule() + + +def _is_multiple(us, mult: int) -> bool: + return us % mult == 0 + + +def _maybe_add_count(base: str, count: float) -> str: + if count != 1: + assert count == int(count) + count = int(count) + return f"{count}{base}" + else: + return base + + +# ---------------------------------------------------------------------- +# Frequency comparison + + +def is_subperiod(source, target) -> bool: + """ + Returns True if downsampling is possible between source and target + frequencies + + Parameters + ---------- + source : str or DateOffset + Frequency converting from + target : str or DateOffset + Frequency converting to + + Returns + ------- + bool + """ + if target is None or source is None: + return False + source = _maybe_coerce_freq(source) + target = _maybe_coerce_freq(target) + + if _is_annual(target): + if _is_quarterly(source): + return _quarter_months_conform( + get_rule_month(source), get_rule_month(target) + ) + return source in {"D", "C", "B", "M", "h", "min", "s", "ms", "us", "ns"} + elif _is_quarterly(target): + return source in {"D", "C", "B", "M", "h", "min", "s", "ms", "us", "ns"} + elif _is_monthly(target): + return source in {"D", "C", "B", "h", "min", "s", "ms", "us", "ns"} + elif _is_weekly(target): + return source in {target, "D", "C", "B", "h", "min", "s", "ms", "us", "ns"} + elif target == "B": + return source in {"B", "h", "min", "s", "ms", "us", "ns"} + elif target == "C": + return source in {"C", "h", "min", "s", "ms", "us", "ns"} + elif target == "D": + return source in {"D", "h", "min", "s", "ms", "us", "ns"} + elif target == "h": + return source in {"h", "min", "s", "ms", "us", "ns"} + elif target == "min": + return source in {"min", "s", "ms", "us", "ns"} + elif target == "s": + return source in {"s", "ms", "us", "ns"} + elif target == "ms": + return source in {"ms", "us", "ns"} + elif target == "us": + return source in {"us", "ns"} + elif target == "ns": + return source in {"ns"} + else: + return False + + +def is_superperiod(source, target) -> bool: + """ + Returns True if upsampling is possible between source and target + frequencies + + Parameters + ---------- + source : str or DateOffset + Frequency converting from + target : str or DateOffset + Frequency converting to + + Returns + ------- + bool + """ + if target is None or source is None: + return False + source = _maybe_coerce_freq(source) + target = _maybe_coerce_freq(target) + + if _is_annual(source): + if _is_annual(target): + return get_rule_month(source) == get_rule_month(target) + + if _is_quarterly(target): + smonth = get_rule_month(source) + tmonth = get_rule_month(target) + return _quarter_months_conform(smonth, tmonth) + return target in {"D", "C", "B", "M", "h", "min", "s", "ms", "us", "ns"} + elif _is_quarterly(source): + return target in {"D", "C", "B", "M", "h", "min", "s", "ms", "us", "ns"} + elif _is_monthly(source): + return target in {"D", "C", "B", "h", "min", "s", "ms", "us", "ns"} + elif _is_weekly(source): + return target in {source, "D", "C", "B", "h", "min", "s", "ms", "us", "ns"} + elif source == "B": + return target in {"D", "C", "B", "h", "min", "s", "ms", "us", "ns"} + elif source == "C": + return target in {"D", "C", "B", "h", "min", "s", "ms", "us", "ns"} + elif source == "D": + return target in {"D", "C", "B", "h", "min", "s", "ms", "us", "ns"} + elif source == "h": + return target in {"h", "min", "s", "ms", "us", "ns"} + elif source == "min": + return target in {"min", "s", "ms", "us", "ns"} + elif source == "s": + return target in {"s", "ms", "us", "ns"} + elif source == "ms": + return target in {"ms", "us", "ns"} + elif source == "us": + return target in {"us", "ns"} + elif source == "ns": + return target in {"ns"} + else: + return False + + +def _maybe_coerce_freq(code) -> str: + """we might need to coerce a code to a rule_code + and uppercase it + + Parameters + ---------- + source : str or DateOffset + Frequency converting from + + Returns + ------- + str + """ + assert code is not None + if isinstance(code, DateOffset): + code = PeriodDtype(to_offset(code.name))._freqstr + if code in {"h", "min", "s", "ms", "us", "ns"}: + return code + else: + return code.upper() + + +def _quarter_months_conform(source: str, target: str) -> bool: + snum = MONTH_NUMBERS[source] + tnum = MONTH_NUMBERS[target] + return snum % 3 == tnum % 3 + + +def _is_annual(rule: str) -> bool: + rule = rule.upper() + return rule == "Y" or rule.startswith("Y-") + + +def _is_quarterly(rule: str) -> bool: + rule = rule.upper() + return rule == "Q" or rule.startswith(("Q-", "BQ")) + + +def _is_monthly(rule: str) -> bool: + rule = rule.upper() + return rule in ("M", "BM") + + +def _is_weekly(rule: str) -> bool: + rule = rule.upper() + return rule == "W" or rule.startswith("W-") + + +__all__ = [ + "Day", + "get_period_alias", + "infer_freq", + "is_subperiod", + "is_superperiod", + "to_offset", +] diff --git a/pandas/tseries/holiday.py b/pandas/tseries/holiday.py new file mode 100644 index 0000000000000000000000000000000000000000..b5ab8cb2eb8bea8e239baca2ff0794931e264b7d --- /dev/null +++ b/pandas/tseries/holiday.py @@ -0,0 +1,682 @@ +from __future__ import annotations + +from datetime import ( + datetime, + timedelta, +) +from typing import ( + TYPE_CHECKING, + Literal, + overload, +) +import warnings + +from dateutil.relativedelta import ( + FR, + MO, + SA, + SU, + TH, + TU, + WE, +) +import numpy as np + +from pandas._libs.tslibs.offsets import BaseOffset +from pandas.errors import PerformanceWarning + +from pandas import ( + DateOffset, + DatetimeIndex, + Series, + Timestamp, + concat, + date_range, +) + +from pandas.tseries.offsets import ( + Day, + Easter, +) + +if TYPE_CHECKING: + from collections.abc import Callable + + +def next_monday(dt: datetime) -> datetime: + """ + If holiday falls on Saturday, use following Monday instead; + if holiday falls on Sunday, use Monday instead + """ + if dt.weekday() == 5: + return dt + timedelta(2) + elif dt.weekday() == 6: + return dt + timedelta(1) + return dt + + +def next_monday_or_tuesday(dt: datetime) -> datetime: + """ + For second holiday of two adjacent ones! + If holiday falls on Saturday, use following Monday instead; + if holiday falls on Sunday or Monday, use following Tuesday instead + (because Monday is already taken by adjacent holiday on the day before) + """ + dow = dt.weekday() + if dow in (5, 6): + return dt + timedelta(2) + if dow == 0: + return dt + timedelta(1) + return dt + + +def previous_friday(dt: datetime) -> datetime: + """ + If holiday falls on Saturday or Sunday, use previous Friday instead. + """ + if dt.weekday() == 5: + return dt - timedelta(1) + elif dt.weekday() == 6: + return dt - timedelta(2) + return dt + + +def sunday_to_monday(dt: datetime) -> datetime: + """ + If holiday falls on Sunday, use day thereafter (Monday) instead. + """ + if dt.weekday() == 6: + return dt + timedelta(1) + return dt + + +def weekend_to_monday(dt: datetime) -> datetime: + """ + If holiday falls on Sunday or Saturday, + use day thereafter (Monday) instead. + Needed for holidays such as Christmas observation in Europe + """ + if dt.weekday() == 6: + return dt + timedelta(1) + elif dt.weekday() == 5: + return dt + timedelta(2) + return dt + + +def nearest_workday(dt: datetime) -> datetime: + """ + If holiday falls on Saturday, use day before (Friday) instead; + if holiday falls on Sunday, use day thereafter (Monday) instead. + """ + if dt.weekday() == 5: + return dt - timedelta(1) + elif dt.weekday() == 6: + return dt + timedelta(1) + return dt + + +def next_workday(dt: datetime) -> datetime: + """ + returns next workday used for observances + """ + dt += timedelta(days=1) + while dt.weekday() > 4: + # Mon-Fri are 0-4 + dt += timedelta(days=1) + return dt + + +def previous_workday(dt: datetime) -> datetime: + """ + returns previous workday used for observances + """ + dt -= timedelta(days=1) + while dt.weekday() > 4: + # Mon-Fri are 0-4 + dt -= timedelta(days=1) + return dt + + +def before_nearest_workday(dt: datetime) -> datetime: + """ + returns previous workday before nearest workday + """ + return previous_workday(nearest_workday(dt)) + + +def after_nearest_workday(dt: datetime) -> datetime: + """ + returns next workday after nearest workday + needed for Boxing day or multiple holidays in a series + """ + return next_workday(nearest_workday(dt)) + + +class Holiday: + """ + Class that defines a holiday with start/end dates and rules + for observance. + """ + + start_date: Timestamp | None + end_date: Timestamp | None + days_of_week: tuple[int, ...] | None + + def __init__( + self, + name: str, + year=None, + month=None, + day=None, + offset: BaseOffset | list[BaseOffset] | None = None, + observance: Callable | None = None, + start_date=None, + end_date=None, + days_of_week: tuple | None = None, + exclude_dates: DatetimeIndex | None = None, + ) -> None: + """ + Parameters + ---------- + name : str + Name of the holiday , defaults to class name + year : int, default None + Year of the holiday + month : int, default None + Month of the holiday + day : int, default None + Day of the holiday + offset : list of pandas.tseries.offsets or + class from pandas.tseries.offsets, default None + Computes offset from date + observance : function, default None + Computes when holiday is given a pandas Timestamp + start_date : datetime-like, default None + First date the holiday is observed + end_date : datetime-like, default None + Last date the holiday is observed + days_of_week : tuple of int or dateutil.relativedelta weekday strs, default None + Provide a tuple of days e.g (0,1,2,3,) for Monday through Thursday + Monday=0,..,Sunday=6 + Only instances of the holiday included in days_of_week will be computed + exclude_dates : DatetimeIndex or default None + Specific dates to exclude e.g. skipping a specific year's holiday + + Examples + -------- + >>> from dateutil.relativedelta import MO + + >>> USMemorialDay = pd.tseries.holiday.Holiday( + ... "Memorial Day", month=5, day=31, offset=pd.DateOffset(weekday=MO(-1)) + ... ) + >>> USMemorialDay + Holiday: Memorial Day (month=5, day=31, offset=) + + >>> USLaborDay = pd.tseries.holiday.Holiday( + ... "Labor Day", month=9, day=1, offset=pd.DateOffset(weekday=MO(1)) + ... ) + >>> USLaborDay + Holiday: Labor Day (month=9, day=1, offset=) + + >>> July3rd = pd.tseries.holiday.Holiday("July 3rd", month=7, day=3) + >>> July3rd + Holiday: July 3rd (month=7, day=3, ) + + >>> NewYears = pd.tseries.holiday.Holiday( + ... "New Years Day", + ... month=1, + ... day=1, + ... observance=pd.tseries.holiday.nearest_workday, + ... ) + >>> NewYears # doctest: +SKIP + Holiday: New Years Day ( + month=1, day=1, observance= + ) + + >>> July3rd = pd.tseries.holiday.Holiday( + ... "July 3rd", month=7, day=3, days_of_week=(0, 1, 2, 3) + ... ) + >>> July3rd + Holiday: July 3rd (month=7, day=3, ) + """ + if offset is not None: + if observance is not None: + raise NotImplementedError("Cannot use both offset and observance.") + if not ( + isinstance(offset, BaseOffset) + or ( + isinstance(offset, list) + and all(isinstance(off, BaseOffset) for off in offset) + ) + ): + raise ValueError( + "Only BaseOffsets and flat lists of them are supported for offset." + ) + + self.name = name + self.year = year + self.month = month + self.day = day + self.offset = offset + self.start_date = ( + Timestamp(start_date) if start_date is not None else start_date + ) + self.end_date = Timestamp(end_date) if end_date is not None else end_date + self.observance = observance + if not (days_of_week is None or isinstance(days_of_week, tuple)): + raise ValueError("days_of_week must be None or tuple.") + self.days_of_week = days_of_week + if not (exclude_dates is None or isinstance(exclude_dates, DatetimeIndex)): + raise ValueError("exclude_dates must be None or of type DatetimeIndex.") + self.exclude_dates = exclude_dates + + def __repr__(self) -> str: + info = "" + if self.year is not None: + info += f"year={self.year}, " + info += f"month={self.month}, day={self.day}, " + + if self.offset is not None: + info += f"offset={self.offset}" + + if self.observance is not None: + info += f"observance={self.observance}" + + repr = f"Holiday: {self.name} ({info})" + return repr + + @overload + def dates(self, start_date, end_date, return_name: Literal[True]) -> Series: ... + + @overload + def dates( + self, start_date, end_date, return_name: Literal[False] + ) -> DatetimeIndex: ... + + @overload + def dates(self, start_date, end_date) -> DatetimeIndex: ... + + def dates( + self, start_date, end_date, return_name: bool = False + ) -> Series | DatetimeIndex: + """ + Calculate holidays observed between start date and end date + + Parameters + ---------- + start_date : starting date, datetime-like, optional + end_date : ending date, datetime-like, optional + return_name : bool, optional, default=False + If True, return a series that has dates and holiday names. + False will only return dates. + + Returns + ------- + Series or DatetimeIndex + Series if return_name is True + """ + start_date = Timestamp(start_date) + end_date = Timestamp(end_date) + + filter_start_date = start_date + filter_end_date = end_date + + if self.year is not None: + dt = Timestamp(datetime(self.year, self.month, self.day)) + dti = DatetimeIndex([dt]) + if return_name: + return Series(self.name, index=dti) + else: + return dti + + dates = self._reference_dates(start_date, end_date) + holiday_dates = self._apply_rule(dates) + if self.days_of_week is not None: + holiday_dates = holiday_dates[ + np.isin( + # error: "DatetimeIndex" has no attribute "dayofweek" + holiday_dates.dayofweek, # type: ignore[attr-defined] + self.days_of_week, + ).ravel() + ] + + if self.start_date is not None: + filter_start_date = max( + self.start_date.tz_localize(filter_start_date.tz), filter_start_date + ) + if self.end_date is not None: + filter_end_date = min( + self.end_date.tz_localize(filter_end_date.tz), filter_end_date + ) + holiday_dates = holiday_dates[ + (holiday_dates >= filter_start_date) & (holiday_dates <= filter_end_date) + ] + + if self.exclude_dates is not None: + holiday_dates = holiday_dates.difference(self.exclude_dates) + if return_name: + return Series(self.name, index=holiday_dates) + return holiday_dates + + def _reference_dates( + self, start_date: Timestamp, end_date: Timestamp + ) -> DatetimeIndex: + """ + Get reference dates for the holiday. + + Return reference dates for the holiday also returning the year + prior to the start_date and year following the end_date. This ensures + that any offsets to be applied will yield the holidays within + the passed in dates. + """ + if self.start_date is not None: + start_date = self.start_date.tz_localize(start_date.tz) + + if self.end_date is not None: + end_date = self.end_date.tz_localize(start_date.tz) + + year_offset = DateOffset(years=1) + reference_start_date = Timestamp( + datetime(start_date.year - 1, self.month, self.day) + ) + + reference_end_date = Timestamp( + datetime(end_date.year + 1, self.month, self.day) + ) + # Don't process unnecessary holidays + dates = date_range( + start=reference_start_date, + end=reference_end_date, + freq=year_offset, + tz=start_date.tz, + ) + + return dates + + def _apply_rule(self, dates: DatetimeIndex) -> DatetimeIndex: + """ + Apply the given offset/observance to a DatetimeIndex of dates. + + Parameters + ---------- + dates : DatetimeIndex + Dates to apply the given offset/observance rule + + Returns + ------- + Dates with rules applied + """ + if dates.empty: + return dates.copy() + + if self.observance is not None: + return dates.map(lambda d: self.observance(d)) + + if self.offset is not None: + if not isinstance(self.offset, list): + offsets = [self.offset] + else: + offsets = self.offset + for offset in offsets: + # if we are adding a non-vectorized value + # ignore the PerformanceWarnings: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PerformanceWarning) + dates += offset + return dates + + +holiday_calendars: dict[str, type[AbstractHolidayCalendar]] = {} + + +def register(cls) -> None: + try: + name = cls.name + except AttributeError: + name = cls.__name__ + holiday_calendars[name] = cls + + +def get_calendar(name: str) -> AbstractHolidayCalendar: + """ + Return an instance of a calendar based on its name. + + Parameters + ---------- + name : str + Calendar name to return an instance of + """ + return holiday_calendars[name]() + + +class HolidayCalendarMetaClass(type): + def __new__(cls, clsname: str, bases, attrs): + calendar_class = super().__new__(cls, clsname, bases, attrs) + register(calendar_class) + return calendar_class + + +class AbstractHolidayCalendar(metaclass=HolidayCalendarMetaClass): + """ + Abstract interface to create holidays following certain rules. + """ + + rules: list[Holiday] = [] + start_date = Timestamp(datetime(1970, 1, 1)) + end_date = Timestamp(datetime(2200, 12, 31)) + _cache: tuple[Timestamp, Timestamp, Series] | None = None + + def __init__(self, name: str = "", rules=None) -> None: + """ + Initializes holiday object with a given set a rules. Normally + classes just have the rules defined within them. + + Parameters + ---------- + name : str + Name of the holiday calendar, defaults to class name + rules : array of Holiday objects + A set of rules used to create the holidays. + """ + super().__init__() + if not name: + name = type(self).__name__ + self.name = name + + if rules is not None: + self.rules = rules + + def rule_from_name(self, name: str) -> Holiday | None: + for rule in self.rules: + if rule.name == name: + return rule + + return None + + def holidays( + self, start=None, end=None, return_name: bool = False + ) -> DatetimeIndex | Series: + """ + Returns a curve with holidays between start_date and end_date + + Parameters + ---------- + start : starting date, datetime-like, optional + end : ending date, datetime-like, optional + return_name : bool, optional + If True, return a series that has dates and holiday names. + False will only return a DatetimeIndex of dates. + + Returns + ------- + DatetimeIndex of holidays + """ + if self.rules is None: + raise Exception( + f"Holiday Calendar {self.name} does not have any rules specified" + ) + + if start is None: + start = AbstractHolidayCalendar.start_date + + if end is None: + end = AbstractHolidayCalendar.end_date + + start = Timestamp(start) + end = Timestamp(end) + + # If we don't have a cache or the dates are outside the prior cache, we + # get them again + if self._cache is None or start < self._cache[0] or end > self._cache[1]: + pre_holidays = [ + rule.dates(start, end, return_name=True) for rule in self.rules + ] + if pre_holidays: + holidays = concat(pre_holidays) + else: + holidays = Series(index=DatetimeIndex([]), dtype=object) + + self._cache = (start, end, holidays.sort_index()) + + holidays = self._cache[2] + holidays = holidays[start:end] + + if return_name: + return holidays + else: + return holidays.index + + @staticmethod + def merge_class(base, other): + """ + Merge holiday calendars together. The base calendar + will take precedence to other. The merge will be done + based on each holiday's name. + + Parameters + ---------- + base : AbstractHolidayCalendar + instance/subclass or array of Holiday objects + other : AbstractHolidayCalendar + instance/subclass or array of Holiday objects + """ + try: + other = other.rules + except AttributeError: + pass + + if not isinstance(other, list): + other = [other] + other_holidays = {holiday.name: holiday for holiday in other} + + try: + base = base.rules + except AttributeError: + pass + + if not isinstance(base, list): + base = [base] + base_holidays = {holiday.name: holiday for holiday in base} + + other_holidays.update(base_holidays) + return list(other_holidays.values()) + + def merge(self, other, inplace: bool = False): + """ + Merge holiday calendars together. The caller's class + rules take precedence. The merge will be done + based on each holiday's name. + + Parameters + ---------- + other : holiday calendar + inplace : bool (default=False) + If True set rule_table to holidays, else return array of Holidays + """ + holidays = self.merge_class(self, other) + if inplace: + self.rules = holidays + else: + return holidays + + +USMemorialDay = Holiday( + "Memorial Day", month=5, day=31, offset=DateOffset(weekday=MO(-1)) +) +USLaborDay = Holiday("Labor Day", month=9, day=1, offset=DateOffset(weekday=MO(1))) +USColumbusDay = Holiday( + "Columbus Day", month=10, day=1, offset=DateOffset(weekday=MO(2)) +) +USThanksgivingDay = Holiday( + "Thanksgiving Day", month=11, day=1, offset=DateOffset(weekday=TH(4)) +) +USMartinLutherKingJr = Holiday( + "Birthday of Martin Luther King, Jr.", + start_date=datetime(1986, 1, 1), + month=1, + day=1, + offset=DateOffset(weekday=MO(3)), +) +USPresidentsDay = Holiday( + "Washington's Birthday", month=2, day=1, offset=DateOffset(weekday=MO(3)) +) +GoodFriday = Holiday("Good Friday", month=1, day=1, offset=[Easter(), Day(-2)]) + +EasterMonday = Holiday("Easter Monday", month=1, day=1, offset=[Easter(), Day(1)]) + + +class USFederalHolidayCalendar(AbstractHolidayCalendar): + """ + US Federal Government Holiday Calendar based on rules specified by: + https://www.opm.gov/policy-data-oversight/pay-leave/federal-holidays/ + """ + + rules = [ + Holiday("New Year's Day", month=1, day=1, observance=nearest_workday), + USMartinLutherKingJr, + USPresidentsDay, + USMemorialDay, + Holiday( + "Juneteenth National Independence Day", + month=6, + day=19, + start_date="2021-06-18", + observance=nearest_workday, + ), + Holiday("Independence Day", month=7, day=4, observance=nearest_workday), + USLaborDay, + USColumbusDay, + Holiday("Veterans Day", month=11, day=11, observance=nearest_workday), + USThanksgivingDay, + Holiday("Christmas Day", month=12, day=25, observance=nearest_workday), + ] + + +def HolidayCalendarFactory(name: str, base, other, base_class=AbstractHolidayCalendar): + rules = AbstractHolidayCalendar.merge_class(base, other) + calendar_class = type(name, (base_class,), {"rules": rules, "name": name}) + return calendar_class + + +__all__ = [ + "FR", + "MO", + "SA", + "SU", + "TH", + "TU", + "WE", + "HolidayCalendarFactory", + "after_nearest_workday", + "before_nearest_workday", + "get_calendar", + "nearest_workday", + "next_monday", + "next_monday_or_tuesday", + "next_workday", + "previous_friday", + "previous_workday", + "register", + "sunday_to_monday", + "weekend_to_monday", +] diff --git a/pandas/tseries/offsets.py b/pandas/tseries/offsets.py new file mode 100644 index 0000000000000000000000000000000000000000..1f0c4281ffc773ae49f7d9190609cbf2e57f8564 --- /dev/null +++ b/pandas/tseries/offsets.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from pandas._libs.tslibs.offsets import ( + FY5253, + BaseOffset, + BDay, + BHalfYearBegin, + BHalfYearEnd, + BMonthBegin, + BMonthEnd, + BQuarterBegin, + BQuarterEnd, + BusinessDay, + BusinessHour, + BusinessMonthBegin, + BusinessMonthEnd, + BYearBegin, + BYearEnd, + CBMonthBegin, + CBMonthEnd, + CDay, + CustomBusinessDay, + CustomBusinessHour, + CustomBusinessMonthBegin, + CustomBusinessMonthEnd, + DateOffset, + Day, + Easter, + FY5253Quarter, + HalfYearBegin, + HalfYearEnd, + Hour, + LastWeekOfMonth, + Micro, + Milli, + Minute, + MonthBegin, + MonthEnd, + Nano, + QuarterBegin, + QuarterEnd, + Second, + SemiMonthBegin, + SemiMonthEnd, + Tick, + Week, + WeekOfMonth, + YearBegin, + YearEnd, +) + +__all__ = [ + "FY5253", + "BDay", + "BHalfYearBegin", + "BHalfYearEnd", + "BMonthBegin", + "BMonthEnd", + "BQuarterBegin", + "BQuarterEnd", + "BYearBegin", + "BYearEnd", + "BaseOffset", + "BusinessDay", + "BusinessHour", + "BusinessMonthBegin", + "BusinessMonthEnd", + "CBMonthBegin", + "CBMonthEnd", + "CDay", + "CustomBusinessDay", + "CustomBusinessHour", + "CustomBusinessMonthBegin", + "CustomBusinessMonthEnd", + "DateOffset", + "Day", + "Easter", + "FY5253Quarter", + "HalfYearBegin", + "HalfYearEnd", + "Hour", + "LastWeekOfMonth", + "Micro", + "Milli", + "Minute", + "MonthBegin", + "MonthEnd", + "Nano", + "QuarterBegin", + "QuarterEnd", + "Second", + "SemiMonthBegin", + "SemiMonthEnd", + "Tick", + "Week", + "WeekOfMonth", + "YearBegin", + "YearEnd", +] diff --git a/pandas/util/__init__.py b/pandas/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a462080f328f4c7111a45147beb09f866597b063 --- /dev/null +++ b/pandas/util/__init__.py @@ -0,0 +1,29 @@ +def __getattr__(key: str): + # These imports need to be lazy to avoid circular import errors + if key == "hash_array": + from pandas.core.util.hashing import hash_array + + return hash_array + if key == "hash_pandas_object": + from pandas.core.util.hashing import hash_pandas_object + + return hash_pandas_object + if key == "Appender": + from pandas.util._decorators import Appender + + return Appender + if key == "Substitution": + from pandas.util._decorators import Substitution + + return Substitution + + if key == "cache_readonly": + from pandas.util._decorators import cache_readonly + + return cache_readonly + + raise AttributeError(f"module 'pandas.util' has no attribute '{key}'") + + +def __dir__() -> list[str]: + return [*list(globals().keys()), "hash_array", "hash_pandas_object"] diff --git a/pandas/util/_decorators.py b/pandas/util/_decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2ed6c00e48c8531838b9987c39c1131374961a --- /dev/null +++ b/pandas/util/_decorators.py @@ -0,0 +1,532 @@ +from __future__ import annotations + +from functools import wraps +import inspect +from textwrap import dedent +from typing import ( + TYPE_CHECKING, + Any, + cast, +) +import warnings + +from pandas._libs.properties import cache_readonly +from pandas._typing import ( + F, + T, +) +from pandas.util._exceptions import find_stack_level + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Mapping, + ) + + from pandas.errors import PandasChangeWarning + + +def deprecate( + klass: type[Warning], + name: str, + alternative: Callable[..., Any], + version: str, + alt_name: str | None = None, + stacklevel: int = 2, + msg: str | None = None, +) -> Callable[[F], F]: + """ + Return a new function that emits a deprecation warning on use. + + To use this method for a deprecated function, another function + `alternative` with the same signature must exist. The deprecated + function will emit a deprecation warning, and in the docstring + it will contain the deprecation directive with the provided version + so it can be detected for future removal. + + Parameters + ---------- + klass : Warning + The warning class to use. + name : str + Name of function to deprecate. + alternative : func + Function to use instead. + version : str + Version of pandas in which the method has been deprecated. + alt_name : str, optional + Name to use in preference of alternative.__name__. + stacklevel : int, default 2 + msg : str + The message to display in the warning. + Default is '{name} is deprecated. Use {alt_name} instead.' + """ + alt_name = alt_name or alternative.__name__ + warning_msg = msg or f"{name} is deprecated, use {alt_name} instead." + + @wraps(alternative) + def wrapper(*args, **kwargs) -> Callable[..., Any]: + warnings.warn(warning_msg, klass, stacklevel=stacklevel) + return alternative(*args, **kwargs) + + # adding deprecated directive to the docstring + msg = msg or f"Use `{alt_name}` instead." + doc_error_msg = ( + "deprecate needs a correctly formatted docstring in " + "the target function (should have a one liner short " + "summary, and opening quotes should be in their own " + f"line). Found:\n{alternative.__doc__}" + ) + + # when python is running in optimized mode (i.e. `-OO`), docstrings are + # removed, so we check that a docstring with correct formatting is used + # but we allow empty docstrings + if alternative.__doc__: + if alternative.__doc__.count("\n") < 3: + raise AssertionError(doc_error_msg) + empty1, summary, empty2, doc_string = alternative.__doc__.split("\n", 3) + if empty1 or (empty2 and not summary): + raise AssertionError(doc_error_msg) + wrapper.__doc__ = dedent( + f""" + {summary.strip()} + + .. deprecated:: {version} + {msg} + + {dedent(doc_string)}""" + ) + # error: Incompatible return value type (got "Callable[[VarArg(Any), KwArg(Any)], + # Callable[...,Any]]", expected "Callable[[F], F]") + return wrapper # type: ignore[return-value] + + +def deprecate_kwarg( + klass: type[Warning], + old_arg_name: str, + new_arg_name: str | None, + mapping: Mapping[Any, Any] | Callable[[Any], Any] | None = None, + stacklevel: int = 2, +) -> Callable[[F], F]: + """ + Decorator to deprecate a keyword argument of a function. + + Parameters + ---------- + klass : Warning + The warning class to use. + old_arg_name : str + Name of argument in function to deprecate. + new_arg_name : str or None + Name of preferred argument in function. Use None to raise warning that + ``old_arg_name`` keyword is deprecated. + mapping : dict or callable + If mapping is present, use it to translate old arguments to + new arguments. A callable must do its own value checking; + values not found in a dict will be forwarded unchanged. + stacklevel : int, default 2 + + Examples + -------- + The following deprecates 'cols', using 'columns' instead + + >>> @deprecate_kwarg(FutureWarning, old_arg_name="cols", new_arg_name="columns") + ... def f(columns=""): + ... print(columns) + >>> f(columns="should work ok") + should work ok + + >>> f(cols="should raise warning") # doctest: +SKIP + FutureWarning: cols is deprecated, use columns instead + warnings.warn(msg, FutureWarning) + should raise warning + + >>> f(cols="should error", columns="can't pass do both") # doctest: +SKIP + TypeError: Can only specify 'cols' or 'columns', not both + + >>> @deprecate_kwarg(FutureWarning, "old", "new", {"yes": True, "no": False}) + ... def f(new=False): + ... print("yes!" if new else "no!") + >>> f(old="yes") # doctest: +SKIP + FutureWarning: old='yes' is deprecated, use new=True instead + warnings.warn(msg, FutureWarning) + yes! + + To raise a warning that a keyword will be removed entirely in the future + + >>> @deprecate_kwarg(FutureWarning, old_arg_name="cols", new_arg_name=None) + ... def f(cols="", another_param=""): + ... print(cols) + >>> f(cols="should raise warning") # doctest: +SKIP + FutureWarning: the 'cols' keyword is deprecated and will be removed in a + future version. Please take steps to stop the use of 'cols' + should raise warning + >>> f(another_param="should not raise warning") # doctest: +SKIP + should not raise warning + + >>> f(cols="should raise warning", another_param="") # doctest: +SKIP + FutureWarning: the 'cols' keyword is deprecated and will be removed in a + future version. Please take steps to stop the use of 'cols' + should raise warning + """ + if mapping is not None and not hasattr(mapping, "get") and not callable(mapping): + raise TypeError( + "mapping from old to new argument values must be dict or callable!" + ) + + def _deprecate_kwarg(func: F) -> F: + @wraps(func) + def wrapper(*args, **kwargs) -> Callable[..., Any]: + __tracebackhide__ = True + + old_arg_value = kwargs.pop(old_arg_name, None) + + if old_arg_value is not None: + if new_arg_name is None: + msg = ( + f"the {old_arg_name!r} keyword is deprecated and " + "will be removed in a future version. Please take " + f"steps to stop the use of {old_arg_name!r}" + ) + warnings.warn(msg, klass, stacklevel=stacklevel) + kwargs[old_arg_name] = old_arg_value + return func(*args, **kwargs) + + elif mapping is not None: + if callable(mapping): + new_arg_value = mapping(old_arg_value) + else: + new_arg_value = mapping.get(old_arg_value, old_arg_value) + msg = ( + f"the {old_arg_name}={old_arg_value!r} keyword is " + "deprecated, use " + f"{new_arg_name}={new_arg_value!r} instead." + ) + else: + new_arg_value = old_arg_value + msg = ( + f"the {old_arg_name!r} keyword is deprecated, " + f"use {new_arg_name!r} instead." + ) + + warnings.warn(msg, klass, stacklevel=stacklevel) + if kwargs.get(new_arg_name) is not None: + msg = ( + f"Can only specify {old_arg_name!r} " + f"or {new_arg_name!r}, not both." + ) + raise TypeError(msg) + kwargs[new_arg_name] = new_arg_value + return func(*args, **kwargs) + + return cast(F, wrapper) + + return _deprecate_kwarg + + +def _format_argument_list(allow_args: list[str]) -> str: + """ + Convert the allow_args argument (either string or integer) of + `deprecate_nonkeyword_arguments` function to a string describing + it to be inserted into warning message. + + Parameters + ---------- + allowed_args : list, tuple or int + The `allowed_args` argument for `deprecate_nonkeyword_arguments`, + but None value is not allowed. + + Returns + ------- + str + The substring describing the argument list in best way to be + inserted to the warning message. + + Examples + -------- + `format_argument_list([])` -> '' + `format_argument_list(['a'])` -> "except for the arguments 'a'" + `format_argument_list(['a', 'b'])` -> "except for the arguments 'a' and 'b'" + `format_argument_list(['a', 'b', 'c'])` -> + "except for the arguments 'a', 'b' and 'c'" + """ + if "self" in allow_args: + allow_args.remove("self") + if not allow_args: + return "" + elif len(allow_args) == 1: + return f" except for the argument '{allow_args[0]}'" + else: + last = allow_args[-1] + args = ", ".join(["'" + x + "'" for x in allow_args[:-1]]) + return f" except for the arguments {args} and '{last}'" + + +def future_version_msg(version: str | None) -> str: + """Specify which version of pandas the deprecation will take place in.""" + if version is None: + return "In a future version of pandas" + else: + return f"Starting with pandas version {version}" + + +def deprecate_nonkeyword_arguments( + klass: type[PandasChangeWarning], + allowed_args: list[str] | None = None, + name: str | None = None, +) -> Callable[[F], F]: + """ + Decorator to deprecate a use of non-keyword arguments of a function. + + Parameters + ---------- + klass : Warning + The warning class to use. + allowed_args : list, optional + In case of list, it must be the list of names of some + first arguments of the decorated functions that are + OK to be given as positional arguments. In case of None value, + defaults to list of all arguments not having the + default value. + name : str, optional + The specific name of the function to show in the warning + message. If None, then the Qualified name of the function + is used. + """ + + def decorate(func): + old_sig = inspect.signature(func) + + if allowed_args is not None: + allow_args = allowed_args + else: + allow_args = [ + p.name + for p in old_sig.parameters.values() + if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + and p.default is p.empty + ] + + new_params = [ + p.replace(kind=p.KEYWORD_ONLY) + if ( + p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) + and p.name not in allow_args + ) + else p + for p in old_sig.parameters.values() + ] + new_params.sort(key=lambda p: p.kind) + new_sig = old_sig.replace(parameters=new_params) + + num_allow_args = len(allow_args) + msg = ( + f"{future_version_msg(klass.version())} all arguments of " + f"{name or func.__qualname__}{{arguments}} will be keyword-only." + ) + + @wraps(func) + def wrapper(*args, **kwargs): + if len(args) > num_allow_args: + warnings.warn( + msg.format(arguments=_format_argument_list(allow_args)), + klass, + stacklevel=find_stack_level(), + ) + return func(*args, **kwargs) + + # error: "Callable[[VarArg(Any), KwArg(Any)], Any]" has no + # attribute "__signature__" + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + return wrapper + + return decorate + + +def doc(*docstrings: None | str | Callable, **params: object) -> Callable[[F], F]: + """ + A decorator to take docstring templates, concatenate them and perform string + substitution on them. + + This decorator will add a variable "_docstring_components" to the wrapped + callable to keep track the original docstring template for potential usage. + If it should be consider as a template, it will be saved as a string. + Otherwise, it will be saved as callable, and later user __doc__ and dedent + to get docstring. + + Parameters + ---------- + *docstrings : None, str, or callable + The string / docstring / docstring template to be appended in order + after default docstring under callable. + **params + The string which would be used to format docstring template. + """ + + def decorator(decorated: F) -> F: + # collecting docstring and docstring templates + docstring_components: list[str | Callable] = [] + if decorated.__doc__: + docstring_components.append(dedent(decorated.__doc__)) + + for docstring in docstrings: + if docstring is None: + continue + if hasattr(docstring, "_docstring_components"): + docstring_components.extend( + docstring._docstring_components # pyright: ignore[reportAttributeAccessIssue] + ) + elif isinstance(docstring, str) or docstring.__doc__: + docstring_components.append(docstring) + + params_applied = [ + component.format(**params) + if isinstance(component, str) and len(params) > 0 + else component + for component in docstring_components + ] + + decorated.__doc__ = "".join( + [ + component + if isinstance(component, str) + else dedent(component.__doc__ or "") + for component in params_applied + ] + ) + + # error: "F" has no attribute "_docstring_components" + decorated._docstring_components = ( # type: ignore[attr-defined] + docstring_components + ) + return decorated + + return decorator + + +# Substitution and Appender are derived from matplotlib.docstring (1.1.0) +# module https://matplotlib.org/users/license.html + + +class Substitution: + """ + A decorator to take a function's docstring and perform string + substitution on it. + + This decorator should be robust even if func.__doc__ is None + (for example, if -OO was passed to the interpreter) + + Usage: construct a docstring.Substitution with a sequence or + dictionary suitable for performing substitution; then + decorate a suitable function with the constructed object. e.g. + + sub_author_name = Substitution(author='Jason') + + @sub_author_name + def some_function(x): + "%(author)s wrote this function" + + # note that some_function.__doc__ is now "Jason wrote this function" + + One can also use positional arguments. + + sub_first_last_names = Substitution('Edgar Allen', 'Poe') + + @sub_first_last_names + def some_function(x): + "%s %s wrote the Raven" + """ + + def __init__(self, *args, **kwargs) -> None: + if args and kwargs: + raise AssertionError("Only positional or keyword args are allowed") + + self.params = args or kwargs + + def __call__(self, func: F) -> F: + func.__doc__ = func.__doc__ and func.__doc__ % self.params + return func + + def update(self, *args, **kwargs) -> None: + """ + Update self.params with supplied args. + """ + if isinstance(self.params, dict): + self.params.update(*args, **kwargs) + + +class Appender: + """ + A function decorator that will append an addendum to the docstring + of the target function. + + This decorator should be robust even if func.__doc__ is None + (for example, if -OO was passed to the interpreter). + + Usage: construct a docstring.Appender with a string to be joined to + the original docstring. An optional 'join' parameter may be supplied + which will be used to join the docstring and addendum. e.g. + + add_copyright = Appender("Copyright (c) 2009", join='\n') + + @add_copyright + def my_dog(has='fleas'): + "This docstring will have a copyright below" + pass + """ + + addendum: str | None + + def __init__(self, addendum: str | None, join: str = "", indents: int = 0) -> None: + if indents > 0: + self.addendum = indent(addendum, indents=indents) + else: + self.addendum = addendum + self.join = join + + def __call__(self, func: T) -> T: + func.__doc__ = func.__doc__ if func.__doc__ else "" + self.addendum = self.addendum if self.addendum else "" + docitems = [func.__doc__, self.addendum] + func.__doc__ = dedent(self.join.join(docitems)) + return func + + +def indent(text: str | None, indents: int = 1) -> str: + if not text or not isinstance(text, str): + return "" + jointext = "".join(["\n"] + [" "] * indents) + return jointext.join(text.split("\n")) + + +__all__ = [ + "Appender", + "Substitution", + "cache_readonly", + "deprecate", + "deprecate_kwarg", + "deprecate_nonkeyword_arguments", + "doc", + "future_version_msg", +] + + +def set_module(module) -> Callable[[F], F]: + """Private decorator for overriding __module__ on a function or class. + + Example usage:: + + @set_module("pandas") + def example(): + pass + + + assert example.__module__ == "pandas" + """ + + def decorator(func: F) -> F: + if module is not None: + func.__module__ = module + return func + + return decorator diff --git a/pandas/util/_doctools.py b/pandas/util/_doctools.py new file mode 100644 index 0000000000000000000000000000000000000000..61bb456aec59fa7005c4e3ccadad3b9d9f0b2e37 --- /dev/null +++ b/pandas/util/_doctools.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +import pandas as pd + +if TYPE_CHECKING: + from collections.abc import Iterable + + from matplotlib.figure import Figure + + +class TablePlotter: + """ + Layout some DataFrames in vertical/horizontal layout for explanation. + Used in merging.rst + """ + + def __init__( + self, + cell_width: float = 0.37, + cell_height: float = 0.25, + font_size: float = 7.5, + ) -> None: + self.cell_width = cell_width + self.cell_height = cell_height + self.font_size = font_size + + def _shape(self, df: pd.DataFrame) -> tuple[int, int]: + """ + Calculate table shape considering index levels. + """ + row, col = df.shape + return row + df.columns.nlevels, col + df.index.nlevels + + def _get_cells(self, left, right, vertical) -> tuple[int, int]: + """ + Calculate appropriate figure size based on left and right data. + """ + if vertical: + # calculate required number of cells + vcells = max(sum(self._shape(df)[0] for df in left), self._shape(right)[0]) + hcells = max(self._shape(df)[1] for df in left) + self._shape(right)[1] + else: + vcells = max([self._shape(df)[0] for df in left] + [self._shape(right)[0]]) + hcells = sum([self._shape(df)[1] for df in left] + [self._shape(right)[1]]) + return hcells, vcells + + def plot( + self, left, right, labels: Iterable[str] = (), vertical: bool = True + ) -> Figure: + """ + Plot left / right DataFrames in specified layout. + + Parameters + ---------- + left : list of DataFrames before operation is applied + right : DataFrame of operation result + labels : list of str to be drawn as titles of left DataFrames + vertical : bool, default True + If True, use vertical layout. If False, use horizontal layout. + """ + from matplotlib import gridspec + import matplotlib.pyplot as plt + + if not isinstance(left, list): + left = [left] + left = [self._conv(df) for df in left] + right = self._conv(right) + + hcells, vcells = self._get_cells(left, right, vertical) + + if vertical: + figsize = self.cell_width * hcells, self.cell_height * vcells + else: + # include margin for titles + figsize = self.cell_width * hcells, self.cell_height * vcells + fig = plt.figure(figsize=figsize) + + if vertical: + gs = gridspec.GridSpec(len(left), hcells) + # left + max_left_cols = max(self._shape(df)[1] for df in left) + max_left_rows = max(self._shape(df)[0] for df in left) + for i, (_left, _label) in enumerate(zip(left, labels, strict=True)): + ax = fig.add_subplot(gs[i, 0:max_left_cols]) + self._make_table(ax, _left, title=_label, height=1.0 / max_left_rows) + # right + ax = plt.subplot(gs[:, max_left_cols:]) + self._make_table(ax, right, title="Result", height=1.05 / vcells) + fig.subplots_adjust(top=0.9, bottom=0.05, left=0.05, right=0.95) + else: + max_rows = max(self._shape(df)[0] for df in [*left, right]) + height = 1.0 / np.max(max_rows) + gs = gridspec.GridSpec(1, hcells) + # left + i = 0 + for df, _label in zip(left, labels, strict=True): + sp = self._shape(df) + ax = fig.add_subplot(gs[0, i : i + sp[1]]) + self._make_table(ax, df, title=_label, height=height) + i += sp[1] + # right + ax = plt.subplot(gs[0, i:]) + self._make_table(ax, right, title="Result", height=height) + fig.subplots_adjust(top=0.85, bottom=0.05, left=0.05, right=0.95) + + return fig + + def _conv(self, data): + """ + Convert each input to appropriate for table outplot. + """ + if isinstance(data, pd.Series): + if data.name is None: + data = data.to_frame(name="") + else: + data = data.to_frame() + data = data.fillna("NaN") + return data + + def _insert_index(self, data): + # insert is destructive + data = data.copy() + idx_nlevels = data.index.nlevels + if idx_nlevels == 1: + data.insert(0, "Index", data.index) + else: + for i in range(idx_nlevels): + data.insert(i, f"Index{i}", data.index._get_level_values(i)) + + col_nlevels = data.columns.nlevels + if col_nlevels > 1: + col = data.columns._get_level_values(0) + values = [ + data.columns._get_level_values(i)._values for i in range(1, col_nlevels) + ] + col_df = pd.DataFrame(values) + data.columns = col_df.columns + data = pd.concat([col_df, data]) + data.columns = col + return data + + def _make_table(self, ax, df, title: str, height: float | None = None) -> None: + if df is None: + ax.set_visible(False) + return + + from pandas import plotting + + idx_nlevels = df.index.nlevels + col_nlevels = df.columns.nlevels + # must be convert here to get index levels for colorization + df = self._insert_index(df) + tb = plotting.table(ax, df, loc=9) + tb.set_fontsize(self.font_size) + + if height is None: + height = 1.0 / (len(df) + 1) + + props = tb.properties() + for (r, c), cell in props["celld"].items(): + if c == -1: + cell.set_visible(False) + elif r < col_nlevels and c < idx_nlevels: + cell.set_visible(False) + elif r < col_nlevels or c < idx_nlevels: + cell.set_facecolor("#AAAAAA") + cell.set_height(height) + + ax.set_title(title, size=self.font_size) + ax.axis("off") + + +def main() -> None: + import matplotlib.pyplot as plt + + p = TablePlotter() + + df1 = pd.DataFrame({"A": [10, 11, 12], "B": [20, 21, 22], "C": [30, 31, 32]}) + df2 = pd.DataFrame({"A": [10, 12], "C": [30, 32]}) + + p.plot([df1, df2], pd.concat([df1, df2]), labels=["df1", "df2"], vertical=True) + plt.show() + + df3 = pd.DataFrame({"X": [10, 12], "Z": [30, 32]}) + + p.plot( + [df1, df3], pd.concat([df1, df3], axis=1), labels=["df1", "df2"], vertical=False + ) + plt.show() + + idx = pd.MultiIndex.from_tuples( + [(1, "A"), (1, "B"), (1, "C"), (2, "A"), (2, "B"), (2, "C")] + ) + column = pd.MultiIndex.from_tuples([(1, "A"), (1, "B")]) + df3 = pd.DataFrame({"v1": [1, 2, 3, 4, 5, 6], "v2": [5, 6, 7, 8, 9, 10]}, index=idx) + df3.columns = column + p.plot(df3, df3, labels=["df3"]) + plt.show() + + +if __name__ == "__main__": + main() diff --git a/pandas/util/_exceptions.py b/pandas/util/_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c8e54d3ca7f778f9ba27c6f3c7ebb59b8a980a --- /dev/null +++ b/pandas/util/_exceptions.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import contextlib +import inspect +import os +import re +from typing import ( + TYPE_CHECKING, + Any, +) +import warnings + +if TYPE_CHECKING: + from collections.abc import Generator + from types import FrameType + + +@contextlib.contextmanager +def rewrite_exception(old_name: str, new_name: str) -> Generator[None]: + """ + Rewrite the message of an exception. + """ + try: + yield + except Exception as err: + if not err.args: + raise + msg = str(err.args[0]) + msg = msg.replace(old_name, new_name) + args: tuple[Any, ...] = (msg,) + if len(err.args) > 1: + args = args + err.args[1:] + err.args = args + raise + + +def find_stack_level() -> int: + """ + Find the first place in the stack that is not inside pandas + (tests notwithstanding). + """ + + import pandas as pd + + pkg_dir = os.path.dirname(pd.__file__) + test_dir = os.path.join(pkg_dir, "tests") + + # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow + frame: FrameType | None = inspect.currentframe() + try: + n = 0 + while frame: + filename = inspect.getfile(frame) + if filename.startswith(pkg_dir) and not filename.startswith(test_dir): + frame = frame.f_back + n += 1 + else: + break + finally: + # See note in + # https://docs.python.org/3/library/inspect.html#inspect.Traceback + del frame + return n + + +@contextlib.contextmanager +def rewrite_warning( + target_message: str, + target_category: type[Warning], + new_message: str, + new_category: type[Warning] | None = None, +) -> Generator[None]: + """ + Rewrite the message of a warning. + + Parameters + ---------- + target_message : str + Warning message to match. + target_category : Warning + Warning type to match. + new_message : str + New warning message to emit. + new_category : Warning or None, default None + New warning type to emit. When None, will be the same as target_category. + """ + if new_category is None: + new_category = target_category + with warnings.catch_warnings(record=True) as record: + yield + if len(record) > 0: + match = re.compile(target_message) + for warning in record: + if warning.category is target_category and re.search( + match, str(warning.message) + ): + category = new_category + message: Warning | str = new_message + else: + category, message = warning.category, warning.message + warnings.warn_explicit( + message=message, + category=category, + filename=warning.filename, + lineno=warning.lineno, + ) diff --git a/pandas/util/_print_versions.py b/pandas/util/_print_versions.py new file mode 100644 index 0000000000000000000000000000000000000000..8a97c700802289af49c3f9301fe12f21fc63ec4c --- /dev/null +++ b/pandas/util/_print_versions.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import json +import locale +import os +import platform +import struct +import sys +from typing import TYPE_CHECKING + +from pandas.util._decorators import set_module + +if TYPE_CHECKING: + from pandas._typing import JSONSerializable + +from pandas.compat._optional import ( + VERSIONS, + get_version, + import_optional_dependency, +) + + +def _get_commit_hash() -> str | None: + """ + Use vendored versioneer code to get git hash, which handles + git worktree correctly. + """ + try: + from pandas._version_meson import ( # pyright: ignore [reportMissingImports] + __git_version__, + ) + + return __git_version__ + except ImportError: + from pandas._version import get_versions + + versions = get_versions() + return versions["full-revisionid"] + + +def _get_sys_info() -> dict[str, JSONSerializable]: + """ + Returns system information as a JSON serializable dictionary. + """ + uname_result = platform.uname() + language_code, encoding = locale.getlocale() + return { + "commit": _get_commit_hash(), + "python": platform.python_version(), + "python-bits": struct.calcsize("P") * 8, + "OS": uname_result.system, + "OS-release": uname_result.release, + "Version": uname_result.version, + "machine": uname_result.machine, + "processor": uname_result.processor, + "byteorder": sys.byteorder, + "LC_ALL": os.environ.get("LC_ALL"), + "LANG": os.environ.get("LANG"), + "LOCALE": {"language-code": language_code, "encoding": encoding}, + } + + +def _get_dependency_info() -> dict[str, JSONSerializable]: + """ + Returns dependency information as a JSON serializable dictionary. + """ + deps = [ + "pandas", + # required + "numpy", + "dateutil", + # install / build, + "pip", + "Cython", + # docs + "sphinx", + # Other, not imported. + "IPython", + ] + # Optional dependencies + deps.extend(list(VERSIONS)) + + result: dict[str, JSONSerializable] = {} + for modname in deps: + try: + mod = import_optional_dependency(modname, errors="ignore") + except Exception: + # Dependency conflicts may cause a non ImportError + result[modname] = "N/A" + else: + result[modname] = get_version(mod) if mod else None + return result + + +@set_module("pandas") +def show_versions(as_json: str | bool = False) -> None: + """ + Provide useful information, important for bug reports. + + It comprises info about hosting operation system, pandas version, + and versions of other installed relative packages. + + Parameters + ---------- + as_json : str or bool, default False + * If False, outputs info in a human readable form to the console. + * If str, it will be considered as a path to a file. + Info will be written to that file in JSON format. + * If True, outputs info in JSON format to the console. + + See Also + -------- + get_option : Retrieve the value of the specified option. + set_option : Set the value of the specified option or options. + + Examples + -------- + >>> pd.show_versions() # doctest: +SKIP + Your output may look something like this: + INSTALLED VERSIONS + ------------------ + commit : 37ea63d540fd27274cad6585082c91b1283f963d + python : 3.10.6.final.0 + python-bits : 64 + OS : Linux + OS-release : 5.10.102.1-microsoft-standard-WSL2 + Version : #1 SMP Wed Mar 2 00:30:59 UTC 2022 + machine : x86_64 + processor : x86_64 + byteorder : little + LC_ALL : None + LANG : en_GB.UTF-8 + LOCALE : en_GB.UTF-8 + pandas : 2.0.1 + numpy : 1.24.3 + ... + """ + sys_info = _get_sys_info() + deps = _get_dependency_info() + + if as_json: + j = {"system": sys_info, "dependencies": deps} + + if as_json is True: + sys.stdout.writelines(json.dumps(j, indent=2)) + else: + assert isinstance(as_json, str) # needed for mypy + with open(as_json, "w", encoding="utf-8") as f: + json.dump(j, f, indent=2) + + else: + assert isinstance(sys_info["LOCALE"], dict) # needed for mypy + language_code = sys_info["LOCALE"]["language-code"] + encoding = sys_info["LOCALE"]["encoding"] + sys_info["LOCALE"] = f"{language_code}.{encoding}" + + maxlen = max(len(x) for x in deps) + print("\nINSTALLED VERSIONS") + print("------------------") + for k, v in sys_info.items(): + print(f"{k:<{maxlen}}: {v}") + print("") + for k, v in deps.items(): + print(f"{k:<{maxlen}}: {v}") diff --git a/pandas/util/_test_decorators.py b/pandas/util/_test_decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..3c5c349a5a93df4175973c0f224d274ebf15f99d --- /dev/null +++ b/pandas/util/_test_decorators.py @@ -0,0 +1,152 @@ +""" +This module provides decorator functions which can be applied to test objects +in order to skip those objects when certain conditions occur. A sample use case +is to detect if the platform is missing ``matplotlib``. If so, any test objects +which require ``matplotlib`` and decorated with ``@td.skip_if_no("matplotlib")`` +will be skipped by ``pytest`` during the execution of the test suite. + +To illustrate, after importing this module: + +import pandas.util._test_decorators as td + +The decorators can be applied to classes: + +@td.skip_if_no("package") +class Foo: + ... + +Or individual functions: + +@td.skip_if_no("package") +def test_foo(): + ... + +For more information, refer to the ``pytest`` documentation on ``skipif``. +""" + +from __future__ import annotations + +import locale +import sys +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from collections.abc import Callable + from pandas._typing import F + +from pandas.compat import ( + IS64, + WASM, + is_platform_windows, +) +from pandas.compat._optional import import_optional_dependency + + +def skip_if_installed(package: str) -> pytest.MarkDecorator: + """ + Skip a test if a package is installed. + + Parameters + ---------- + package : str + The name of the package. + + Returns + ------- + pytest.MarkDecorator + a pytest.mark.skipif to use as either a test decorator or a + parametrization mark. + """ + return pytest.mark.skipif( + bool(import_optional_dependency(package, errors="ignore")), + reason=f"Skipping because {package} is installed.", + ) + + +def skip_if_no(package: str, min_version: str | None = None) -> pytest.MarkDecorator: + """ + Generic function to help skip tests when required packages are not + present on the testing system. + + This function returns a pytest mark with a skip condition that will be + evaluated during test collection. An attempt will be made to import the + specified ``package`` and optionally ensure it meets the ``min_version`` + + The mark can be used as either a decorator for a test class or to be + applied to parameters in pytest.mark.parametrize calls or parametrized + fixtures. Use pytest.importorskip if an imported moduled is later needed + or for test functions. + + If the import and version check are unsuccessful, then the test function + (or test case when used in conjunction with parametrization) will be + skipped. + + Parameters + ---------- + package: str + The name of the required package. + min_version: str or None, default None + Optional minimum version of the package. + + Returns + ------- + pytest.MarkDecorator + a pytest.mark.skipif to use as either a test decorator or a + parametrization mark. + """ + msg = f"Could not import '{package}'" + if min_version: + msg += f" satisfying a min_version of {min_version}" + return pytest.mark.skipif( + not bool( + import_optional_dependency( + package, errors="ignore", min_version=min_version + ) + ), + reason=msg, + ) + + +skip_if_32bit = pytest.mark.skipif(not IS64, reason="skipping for 32 bit") +skip_if_windows = pytest.mark.skipif(is_platform_windows(), reason="Running on Windows") +skip_if_not_us_locale = pytest.mark.skipif( + locale.getlocale()[0] != "en_US", + reason=f"Set local {locale.getlocale()[0]} is not en_US", +) +skip_if_wasm = pytest.mark.skipif( + WASM, + reason="does not support wasm", +) +skip_if_thread_unsafe_warnings = pytest.mark.skipif( + not getattr(sys.flags, "context_aware_warnings", 0), + reason="Python warnings must be thread-safe for consistent results", +) + + +def parametrize_fixture_doc(*args) -> Callable[[F], F]: + """ + Intended for use as a decorator for parametrized fixture, + this function will wrap the decorated function with a pytest + ``parametrize_fixture_doc`` mark. That mark will format + initial fixture docstring by replacing placeholders {0}, {1} etc + with parameters passed as arguments. + + Parameters + ---------- + args: iterable + Positional arguments for docstring. + + Returns + ------- + function + The decorated function wrapped within a pytest + ``parametrize_fixture_doc`` mark + """ + + def documented_fixture(fixture): + fixture.__doc__ = fixture.__doc__.format(*args) + return fixture + + return documented_fixture diff --git a/pandas/util/_tester.py b/pandas/util/_tester.py new file mode 100644 index 0000000000000000000000000000000000000000..e69ec8e123b0857af4d81118c57d2e337d7e7e50 --- /dev/null +++ b/pandas/util/_tester.py @@ -0,0 +1,60 @@ +""" +Entrypoint for testing from the top-level namespace. +""" + +from __future__ import annotations + +import os +import sys + +from pandas.compat._optional import import_optional_dependency +from pandas.util._decorators import set_module + +PKG = os.path.dirname(os.path.dirname(__file__)) + + +@set_module("pandas") +def test(extra_args: list[str] | None = None, run_doctests: bool = False) -> None: # noqa: PT028 + """ + Run the pandas test suite using pytest. + + By default, runs with the marks -m "not slow and not network and not db" + + Parameters + ---------- + extra_args : list[str], default None + Extra marks to run the tests. + run_doctests : bool, default False + Whether to only run the Python and Cython doctests. If you would like to run + both doctests/regular tests, just append "--doctest-modules"/"--doctest-cython" + to extra_args. + + See Also + -------- + pytest.main : The main entry point for pytest testing framework. + + Examples + -------- + >>> pd.test() # doctest: +SKIP + running: pytest... + """ + pytest = import_optional_dependency("pytest") + import_optional_dependency("hypothesis") + cmd = ["-m not slow and not network and not db"] + if extra_args: + if not isinstance(extra_args, list): + extra_args = [extra_args] + cmd = extra_args + if run_doctests: + cmd = [ + "--doctest-modules", + "--doctest-cython", + f"--ignore={os.path.join(PKG, 'tests')}", + ] + cmd += [PKG] + joined = " ".join(cmd) + print(f"running: pytest {joined}") + sys.exit(pytest.main(cmd)) + + +__all__ = ["test"] diff --git a/pandas/util/_validators.py b/pandas/util/_validators.py new file mode 100644 index 0000000000000000000000000000000000000000..9097875782d227db0a7b342fb0fd096800c5accb --- /dev/null +++ b/pandas/util/_validators.py @@ -0,0 +1,482 @@ +""" +Module that contains many useful utilities +for validating data or function arguments +""" + +from __future__ import annotations + +from collections.abc import ( + Iterable, + Sequence, +) +from typing import ( + TypeVar, + overload, +) + +import numpy as np + +from pandas._libs import lib +from pandas._libs.missing import NA + +from pandas.core.dtypes.common import ( + is_bool, + is_integer, +) + +BoolishT = TypeVar("BoolishT", bool, int) +BoolishNoneT = TypeVar("BoolishNoneT", bool, int, None) + + +def _check_arg_length(fname, args, max_fname_arg_count, compat_args) -> None: + """ + Checks whether 'args' has length of at most 'compat_args'. Raises + a TypeError if that is not the case, similar to in Python when a + function is called with too many arguments. + """ + if max_fname_arg_count < 0: + raise ValueError("'max_fname_arg_count' must be non-negative") + + if len(args) > len(compat_args): + max_arg_count = len(compat_args) + max_fname_arg_count + actual_arg_count = len(args) + max_fname_arg_count + argument = "argument" if max_arg_count == 1 else "arguments" + + raise TypeError( + f"{fname}() takes at most {max_arg_count} {argument} " + f"({actual_arg_count} given)" + ) + + +def _check_for_default_values(fname, arg_val_dict, compat_args) -> None: + """ + Check that the keys in `arg_val_dict` are mapped to their + default values as specified in `compat_args`. + + Note that this function is to be called only when it has been + checked that arg_val_dict.keys() is a subset of compat_args + """ + for key in arg_val_dict: + # try checking equality directly with '=' operator, + # as comparison may have been overridden for the left + # hand object + try: + v1 = arg_val_dict[key] + v2 = compat_args[key] + + # check for None-ness otherwise we could end up + # comparing a numpy array vs None + if (v1 is not None and v2 is None) or (v1 is None and v2 is not None): + match = False + else: + match = v1 == v2 + + if not is_bool(match): + raise ValueError("'match' is not a boolean") + + # could not compare them directly, so try comparison + # using the 'is' operator + except ValueError: + match = arg_val_dict[key] is compat_args[key] + + if not match: + raise ValueError( + f"the '{key}' parameter is not supported in " + f"the pandas implementation of {fname}()" + ) + + +def validate_args(fname, args, max_fname_arg_count, compat_args) -> None: + """ + Checks whether the length of the `*args` argument passed into a function + has at most `len(compat_args)` arguments and whether or not all of these + elements in `args` are set to their default values. + + Parameters + ---------- + fname : str + The name of the function being passed the `*args` parameter + args : tuple + The `*args` parameter passed into a function + max_fname_arg_count : int + The maximum number of arguments that the function `fname` + can accept, excluding those in `args`. Used for displaying + appropriate error messages. Must be non-negative. + compat_args : dict + A dictionary of keys and their associated default values. + In order to accommodate buggy behaviour in some versions of `numpy`, + where a signature displayed keyword arguments but then passed those + arguments **positionally** internally when calling downstream + implementations, a dict ensures that the original + order of the keyword arguments is enforced. + + Raises + ------ + TypeError + If `args` contains more values than there are `compat_args` + ValueError + If `args` contains values that do not correspond to those + of the default values specified in `compat_args` + """ + _check_arg_length(fname, args, max_fname_arg_count, compat_args) + + # We do this so that we can provide a more informative + # error message about the parameters that we are not + # supporting in the pandas implementation of 'fname' + kwargs = dict(zip(compat_args, args, strict=False)) + _check_for_default_values(fname, kwargs, compat_args) + + +def _check_for_invalid_keys(fname, kwargs, compat_args) -> None: + """ + Checks whether 'kwargs' contains any keys that are not + in 'compat_args' and raises a TypeError if there is one. + """ + # set(dict) --> set of the dictionary's keys + diff = set(kwargs) - set(compat_args) + + if diff: + bad_arg = next(iter(diff)) + raise TypeError(f"{fname}() got an unexpected keyword argument '{bad_arg}'") + + +def validate_kwargs(fname, kwargs, compat_args) -> None: + """ + Checks whether parameters passed to the **kwargs argument in a + function `fname` are valid parameters as specified in `*compat_args` + and whether or not they are set to their default values. + + Parameters + ---------- + fname : str + The name of the function being passed the `**kwargs` parameter + kwargs : dict + The `**kwargs` parameter passed into `fname` + compat_args: dict + A dictionary of keys that `kwargs` is allowed to have and their + associated default values + + Raises + ------ + TypeError if `kwargs` contains keys not in `compat_args` + ValueError if `kwargs` contains keys in `compat_args` that do not + map to the default values specified in `compat_args` + """ + kwds = kwargs.copy() + _check_for_invalid_keys(fname, kwargs, compat_args) + _check_for_default_values(fname, kwds, compat_args) + + +def validate_args_and_kwargs( + fname, args, kwargs, max_fname_arg_count, compat_args +) -> None: + """ + Checks whether parameters passed to the *args and **kwargs argument in a + function `fname` are valid parameters as specified in `*compat_args` + and whether or not they are set to their default values. + + Parameters + ---------- + fname: str + The name of the function being passed the `**kwargs` parameter + args: tuple + The `*args` parameter passed into a function + kwargs: dict + The `**kwargs` parameter passed into `fname` + max_fname_arg_count: int + The minimum number of arguments that the function `fname` + requires, excluding those in `args`. Used for displaying + appropriate error messages. Must be non-negative. + compat_args: dict + A dictionary of keys that `kwargs` is allowed to + have and their associated default values. + + Raises + ------ + TypeError if `args` contains more values than there are + `compat_args` OR `kwargs` contains keys not in `compat_args` + ValueError if `args` contains values not at the default value (`None`) + `kwargs` contains keys in `compat_args` that do not map to the default + value as specified in `compat_args` + + See Also + -------- + validate_args : Purely args validation. + validate_kwargs : Purely kwargs validation. + + """ + # Check that the total number of arguments passed in (i.e. + # args and kwargs) does not exceed the length of compat_args + _check_arg_length( + fname, args + tuple(kwargs.values()), max_fname_arg_count, compat_args + ) + + # Check there is no overlap with the positional and keyword + # arguments, similar to what is done in actual Python functions + args_dict = dict(zip(compat_args, args, strict=False)) + + for key in args_dict: + if key in kwargs: + raise TypeError( + f"{fname}() got multiple values for keyword argument '{key}'" + ) + + kwargs.update(args_dict) + validate_kwargs(fname, kwargs, compat_args) + + +def validate_bool_kwarg( + value: BoolishNoneT, + arg_name: str, + none_allowed: bool = True, + int_allowed: bool = False, +) -> BoolishNoneT: + """ + Ensure that argument passed in arg_name can be interpreted as boolean. + + Parameters + ---------- + value : bool + Value to be validated. + arg_name : str + Name of the argument. To be reflected in the error message. + none_allowed : bool, default True + Whether to consider None to be a valid boolean. + int_allowed : bool, default False + Whether to consider integer value to be a valid boolean. + + Returns + ------- + value + The same value as input. + + Raises + ------ + ValueError + If the value is not a valid boolean. + """ + good_value = is_bool(value) + if none_allowed: + good_value = good_value or (value is None) + + if int_allowed: + good_value = good_value or isinstance(value, int) + + if not good_value: + raise ValueError( + f'For argument "{arg_name}" expected type bool, received ' + f"type {type(value).__name__}." + ) + return value + + +def validate_na_arg(value, name: str): + """ + Validate na arguments. + + Parameters + ---------- + value : object + Value to validate. + name : str + Name of the argument, used to raise an informative error message. + + Raises + ______ + ValueError + When ``value`` is determined to be invalid. + """ + if ( + value is lib.no_default + or isinstance(value, bool) + or value is None + or value is NA + or (lib.is_float(value) and np.isnan(value)) + ): + return + raise ValueError(f"{name} must be None, pd.NA, np.nan, True, or False; got {value}") + + +def validate_fillna_kwargs(value, method, validate_scalar_dict_value: bool = True): + """ + Validate the keyword arguments to 'fillna'. + + This checks that exactly one of 'value' and 'method' is specified. + If 'method' is specified, this validates that it's a valid method. + + Parameters + ---------- + value, method : object + The 'value' and 'method' keyword arguments for 'fillna'. + validate_scalar_dict_value : bool, default True + Whether to validate that 'value' is a scalar or dict. Specifically, + validate that it is not a list or tuple. + + Returns + ------- + value, method : object + """ + from pandas.core.missing import clean_fill_method + + if value is None and method is None: + raise ValueError("Must specify a fill 'value' or 'method'.") + if value is None and method is not None: + method = clean_fill_method(method) + + elif value is not None and method is None: + if validate_scalar_dict_value and isinstance(value, (list, tuple)): + raise TypeError( + '"value" parameter must be a scalar or dict, but ' + f'you passed a "{type(value).__name__}"' + ) + + elif value is not None and method is not None: + raise ValueError("Cannot specify both 'value' and 'method'.") + + return value, method + + +def validate_percentile(q: float | Iterable[float]) -> np.ndarray: + """ + Validate percentiles (used by describe and quantile). + + This function checks if the given float or iterable of floats is a valid percentile + otherwise raises a ValueError. + + Parameters + ---------- + q: float or iterable of floats + A single percentile or an iterable of percentiles. + + Returns + ------- + ndarray + An ndarray of the percentiles if valid. + + Raises + ------ + ValueError if percentiles are not in given interval([0, 1]). + """ + q_arr = np.asarray(q) + # Don't change this to an f-string. The string formatting + # is too expensive for cases where we don't need it. + msg = "percentiles should all be in the interval [0, 1]" + if q_arr.ndim == 0: + if not 0 <= q_arr <= 1: + raise ValueError(msg) + elif not all(0 <= qs <= 1 for qs in q_arr): + raise ValueError(msg) + return q_arr + + +@overload +def validate_ascending(ascending: BoolishT) -> BoolishT: ... + + +@overload +def validate_ascending(ascending: Sequence[BoolishT]) -> list[BoolishT]: ... + + +def validate_ascending( + ascending: bool | int | Sequence[BoolishT], +) -> bool | int | list[BoolishT]: + """Validate ``ascending`` kwargs for ``sort_index`` method.""" + kwargs = {"none_allowed": False, "int_allowed": True} + if not isinstance(ascending, Sequence): + return validate_bool_kwarg(ascending, "ascending", **kwargs) + + return [validate_bool_kwarg(item, "ascending", **kwargs) for item in ascending] + + +def validate_endpoints(closed: str | None) -> tuple[bool, bool]: + """ + Check that the `closed` argument is among [None, "left", "right"] + + Parameters + ---------- + closed : {None, "left", "right"} + + Returns + ------- + left_closed : bool + right_closed : bool + + Raises + ------ + ValueError : if argument is not among valid values + """ + left_closed = False + right_closed = False + + if closed is None: + left_closed = True + right_closed = True + elif closed == "left": + left_closed = True + elif closed == "right": + right_closed = True + else: + raise ValueError("Closed has to be either 'left', 'right' or None") + + return left_closed, right_closed + + +def validate_inclusive(inclusive: str | None) -> tuple[bool, bool]: + """ + Check that the `inclusive` argument is among {"both", "neither", "left", "right"}. + + Parameters + ---------- + inclusive : {"both", "neither", "left", "right"} + + Returns + ------- + left_right_inclusive : tuple[bool, bool] + + Raises + ------ + ValueError : if argument is not among valid values + """ + left_right_inclusive: tuple[bool, bool] | None = None + + if isinstance(inclusive, str): + left_right_inclusive = { + "both": (True, True), + "left": (True, False), + "right": (False, True), + "neither": (False, False), + }.get(inclusive) + + if left_right_inclusive is None: + raise ValueError( + "Inclusive has to be either 'both', 'neither', 'left' or 'right'" + ) + + return left_right_inclusive + + +def validate_insert_loc(loc: int, length: int) -> int: + """ + Check that we have an integer between -length and length, inclusive. + + Standardize negative loc to within [0, length]. + + The exceptions we raise on failure match np.insert. + """ + if not is_integer(loc): + raise TypeError(f"loc must be an integer between -{length} and {length}") + + if loc < 0: + loc += length + if not 0 <= loc <= length: + raise IndexError(f"loc must be an integer between -{length} and {length}") + return loc # pyright: ignore[reportReturnType] + + +def check_dtype_backend(dtype_backend) -> None: + if dtype_backend is not lib.no_default: + if dtype_backend not in ["numpy_nullable", "pyarrow"]: + raise ValueError( + f"dtype_backend {dtype_backend} is invalid, only 'numpy_nullable' and " + f"'pyarrow' are allowed.", + ) diff --git a/pyarrow/include/arrow/acero/accumulation_queue.h b/pyarrow/include/arrow/acero/accumulation_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..b0e0b85a4f3d0504ad0e09237e498c001c55f96a --- /dev/null +++ b/pyarrow/include/arrow/acero/accumulation_queue.h @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/acero/visibility.h" +#include "arrow/compute/exec.h" +#include "arrow/result.h" + +namespace arrow { +namespace acero { +namespace util { + +using arrow::compute::ExecBatch; + +/// \brief A container that accumulates batches until they are ready to +/// be processed. +class ARROW_ACERO_EXPORT AccumulationQueue { + public: + AccumulationQueue() : row_count_(0) {} + ~AccumulationQueue() = default; + + // We should never be copying ExecBatch around + AccumulationQueue(const AccumulationQueue&) = delete; + AccumulationQueue& operator=(const AccumulationQueue&) = delete; + + AccumulationQueue(AccumulationQueue&& that); + AccumulationQueue& operator=(AccumulationQueue&& that); + + void Concatenate(AccumulationQueue&& that); + void InsertBatch(ExecBatch batch); + int64_t row_count() { return row_count_; } + size_t batch_count() { return batches_.size(); } + bool empty() const { return batches_.empty(); } + void Clear(); + ExecBatch& operator[](size_t i); + + private: + int64_t row_count_; + std::vector batches_; +}; + +/// A queue that sequences incoming batches +/// +/// This can be used when a node needs to do some kind of ordered processing on +/// the stream. +/// +/// Batches can be inserted in any order. The process_callback will be called on +/// the batches, in order, without reentrant calls. For this reason the callback +/// should be quick. +/// +/// For example, in a top-n node, the process callback should determine how many +/// rows need to be delivered for the given batch, and then return a task to actually +/// deliver those rows. +class ARROW_ACERO_EXPORT SequencingQueue { + public: + using Task = std::function; + + /// Strategy that describes how to handle items + class Processor { + public: + /// Process the batch, potentially generating a task + /// + /// This method will be called on each batch in order. Calls to this method + /// will be serialized and it will not be called reentrantly. This makes it + /// safe to do things that rely on order but minimal time should be spent here + /// to avoid becoming a bottleneck. + /// + /// \return a follow-up task that will be scheduled. The follow-up task(s) are + /// is not guaranteed to run in any particular order. If nullopt is + /// returned then nothing will be scheduled. + virtual Result> Process(ExecBatch batch) = 0; + /// Schedule a task + virtual void Schedule(Task task) = 0; + }; + + virtual ~SequencingQueue() = default; + + /// Insert a batch into the queue + /// + /// This will insert the batch into the queue. If this batch was the next batch + /// to deliver then this will trigger 1+ calls to the process callback to generate + /// 1+ tasks. + /// + /// The task generated by this call will be executed immediately. The remaining + /// tasks will be scheduled using the schedule callback. + /// + /// From a data pipeline perspective the sequencing queue is a "sometimes" breaker. If + /// a task arrives in order then this call will usually execute the downstream pipeline. + /// If this task arrives early then this call will only queue the data. + virtual Status InsertBatch(ExecBatch batch) = 0; + + /// Create a queue + /// \param processor describes how to process the batches, must outlive the queue + static std::unique_ptr Make(Processor* processor); +}; + +/// A queue that sequences incoming batches +/// +/// Unlike SequencingQueue the Process method is not expected to schedule new tasks. +/// +/// If a batch arrives and another thread is currently processing then the batch +/// will be queued and control will return. In other words, delivery of batches will +/// not block on the Process method. +/// +/// It can be helpful to think of this as if a dedicated thread is running Process as +/// batches arrive +class ARROW_ACERO_EXPORT SerialSequencingQueue { + public: + /// Strategy that describes how to handle items + class Processor { + public: + virtual ~Processor() = default; + /// Process the batch + /// + /// This method will be called on each batch in order. Calls to this method + /// will be serialized and it will not be called reentrantly. This makes it + /// safe to do things that rely on order. + /// + /// If this falls behind then data may accumulate + /// + /// TODO: Could add backpressure if needed but right now all uses of this should + /// be pretty fast and so are unlikely to block. + virtual Status Process(ExecBatch batch) = 0; + }; + + virtual ~SerialSequencingQueue() = default; + + /// Insert a batch into the queue + /// + /// This will insert the batch into the queue. If this batch was the next batch + /// to deliver then this may trigger calls to the processor which will be run + /// as part of this call. + virtual Status InsertBatch(ExecBatch batch) = 0; + + /// Create a queue + /// \param processor describes how to process the batches, must outlive the queue + static std::unique_ptr Make(Processor* processor); +}; + +} // namespace util +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/aggregate_node.h b/pyarrow/include/arrow/acero/aggregate_node.h new file mode 100644 index 0000000000000000000000000000000000000000..0c6fea16a8acc75046309708221189d368f605c0 --- /dev/null +++ b/pyarrow/include/arrow/acero/aggregate_node.h @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include + +#include "arrow/acero/visibility.h" +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/test_util_internal.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/result.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace acero { +namespace aggregate { + +using compute::Aggregate; +using compute::default_exec_context; +using compute::ExecContext; + +/// \brief Make the output schema of an aggregate node +/// +/// The output schema is determined by the aggregation kernels, which may depend on the +/// ExecContext argument. To guarantee correct results, the same ExecContext argument +/// should be used in execution. +/// +/// \param[in] input_schema the schema of the input to the node +/// \param[in] keys the grouping keys for the aggregation +/// \param[in] segment_keys the segmenting keys for the aggregation +/// \param[in] aggregates the aggregates for the aggregation +/// \param[in] exec_ctx the execution context for the aggregation +ARROW_ACERO_EXPORT Result> MakeOutputSchema( + const std::shared_ptr& input_schema, const std::vector& keys, + const std::vector& segment_keys, const std::vector& aggregates, + ExecContext* exec_ctx = default_exec_context()); + +} // namespace aggregate +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/api.h b/pyarrow/include/arrow/acero/api.h new file mode 100644 index 0000000000000000000000000000000000000000..c9724fd512d0b56dfa3a24647b3885677c92b534 --- /dev/null +++ b/pyarrow/include/arrow/acero/api.h @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +/// \defgroup acero-api Utilities for creating and executing execution plans +/// @{ +/// @} + +/// \defgroup acero-nodes Options classes for the various exec nodes +/// @{ +/// @} + +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/options.h" diff --git a/pyarrow/include/arrow/acero/asof_join_node.h b/pyarrow/include/arrow/acero/asof_join_node.h new file mode 100644 index 0000000000000000000000000000000000000000..6a0ce8fd386b01ac868bac3d4d026a309e351cb3 --- /dev/null +++ b/pyarrow/include/arrow/acero/asof_join_node.h @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/acero/options.h" +#include "arrow/acero/visibility.h" +#include "arrow/compute/exec.h" +#include "arrow/type.h" + +namespace arrow { +namespace acero { +namespace asofjoin { + +using AsofJoinKeys = AsofJoinNodeOptions::Keys; + +/// \brief Make the output schema of an as-of-join node +/// +/// \param[in] input_schema the schema of each input to the node +/// \param[in] input_keys the key of each input to the node +ARROW_ACERO_EXPORT Result> MakeOutputSchema( + const std::vector>& input_schema, + const std::vector& input_keys); + +} // namespace asofjoin +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/backpressure_handler.h b/pyarrow/include/arrow/acero/backpressure_handler.h new file mode 100644 index 0000000000000000000000000000000000000000..c6a47e60197a51f85c2279f00ff8851c78a264f5 --- /dev/null +++ b/pyarrow/include/arrow/acero/backpressure_handler.h @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/options.h" + +#include + +namespace arrow::acero { + +class BackpressureHandler { + private: + BackpressureHandler(size_t low_threshold, size_t high_threshold, + std::unique_ptr backpressure_control) + : low_threshold_(low_threshold), + high_threshold_(high_threshold), + backpressure_control_(std::move(backpressure_control)) {} + + public: + static Result Make( + size_t low_threshold, size_t high_threshold, + std::unique_ptr backpressure_control) { + if (low_threshold >= high_threshold) { + return Status::Invalid("low threshold (", low_threshold, + ") must be less than high threshold (", high_threshold, ")"); + } + if (backpressure_control == NULLPTR) { + return Status::Invalid("null backpressure control parameter"); + } + BackpressureHandler backpressure_handler(low_threshold, high_threshold, + std::move(backpressure_control)); + return backpressure_handler; + } + + void Handle(size_t start_level, size_t end_level) { + if (start_level < high_threshold_ && end_level >= high_threshold_) { + backpressure_control_->Pause(); + } else if (start_level > low_threshold_ && end_level <= low_threshold_) { + backpressure_control_->Resume(); + } + } + + private: + size_t low_threshold_; + size_t high_threshold_; + std::unique_ptr backpressure_control_; +}; + +} // namespace arrow::acero diff --git a/pyarrow/include/arrow/acero/benchmark_util.h b/pyarrow/include/arrow/acero/benchmark_util.h new file mode 100644 index 0000000000000000000000000000000000000000..0ba8553887c03f876b6e08f031f5641170c2e09f --- /dev/null +++ b/pyarrow/include/arrow/acero/benchmark_util.h @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "benchmark/benchmark.h" + +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/test_util_internal.h" +#include "arrow/compute/exec.h" + +namespace arrow { + +namespace acero { + +Status BenchmarkNodeOverhead(benchmark::State& state, int32_t num_batches, + int32_t batch_size, arrow::acero::BatchesWithSchema data, + std::vector& node_declarations, + arrow::MemoryPool* pool = default_memory_pool()); + +Status BenchmarkIsolatedNodeOverhead(benchmark::State& state, + arrow::compute::Expression expr, int32_t num_batches, + int32_t batch_size, + arrow::acero::BatchesWithSchema data, + std::string factory_name, + arrow::acero::ExecNodeOptions& options, + arrow::MemoryPool* pool = default_memory_pool()); + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/bloom_filter.h b/pyarrow/include/arrow/acero/bloom_filter.h new file mode 100644 index 0000000000000000000000000000000000000000..8f9fe171baeb39f5347d112921666ba057cb56b6 --- /dev/null +++ b/pyarrow/include/arrow/acero/bloom_filter.h @@ -0,0 +1,323 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/acero/partition_util.h" +#include "arrow/acero/util.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/simd.h" + +namespace arrow { +namespace acero { + +// A set of pre-generated bit masks from a 64-bit word. +// +// It is used to map selected bits of hash to a bit mask that will be used in +// a Bloom filter. +// +// These bit masks need to look random and need to have a similar fractions of +// bits set in order for a Bloom filter to have a low false positives rate. +// +struct ARROW_ACERO_EXPORT BloomFilterMasks { + // Generate all masks as a single bit vector. Each bit offset in this bit + // vector corresponds to a single mask. + // In each consecutive kBitsPerMask bits, there must be between + // kMinBitsSet and kMaxBitsSet bits set. + // + BloomFilterMasks(); + + inline uint64_t mask(int bit_offset) { +#if ARROW_LITTLE_ENDIAN + return (arrow::util::SafeLoadAs(masks_ + bit_offset / 8) >> + (bit_offset % 8)) & + kFullMask; +#else + return (BYTESWAP(arrow::util::SafeLoadAs(masks_ + bit_offset / 8)) >> + (bit_offset % 8)) & + kFullMask; +#endif + } + + // Masks are 57 bits long because then they can be accessed at an + // arbitrary bit offset using a single unaligned 64-bit load instruction. + // + static constexpr int kBitsPerMask = 57; + static constexpr uint64_t kFullMask = (1ULL << kBitsPerMask) - 1; + + // Minimum and maximum number of bits set in each mask. + // This constraint is enforced when generating the bit masks. + // Values should be close to each other and chosen as to minimize a Bloom + // filter false positives rate. + // + static constexpr int kMinBitsSet = 4; + static constexpr int kMaxBitsSet = 5; + + // Number of generated masks. + // Having more masks to choose will improve false positives rate of Bloom + // filter but will also use more memory, which may lead to more CPU cache + // misses. + // The chosen value results in using only a few cache-lines for mask lookups, + // while providing a good variety of available bit masks. + // + static constexpr int kLogNumMasks = 10; + static constexpr int kNumMasks = 1 << kLogNumMasks; + + // Data of masks. Masks are stored in a single bit vector. Nth mask is + // kBitsPerMask bits starting at bit offset N. + // + static constexpr int kTotalBytes = (kNumMasks + 64) / 8; + uint8_t masks_[kTotalBytes]; +}; + +// A variant of a blocked Bloom filter implementation. +// A Bloom filter is a data structure that provides approximate membership test +// functionality based only on the hash of the key. Membership test may return +// false positives but not false negatives. Approximation of the result allows +// in general case (for arbitrary data types of keys) to save on both memory and +// lookup cost compared to the accurate membership test. +// The accurate test may sometimes still be cheaper for a specific data types +// and inputs, e.g. integers from a small range. +// +// This blocked Bloom filter is optimized for use in hash joins, to achieve a +// good balance between the size of the filter, the cost of its building and +// querying and the rate of false positives. +// +class ARROW_ACERO_EXPORT BlockedBloomFilter { + friend class BloomFilterBuilder_SingleThreaded; + friend class BloomFilterBuilder_Parallel; + + public: + BlockedBloomFilter() : log_num_blocks_(0), num_blocks_(0), blocks_(NULLPTR) {} + + inline bool Find(uint64_t hash) const { + uint64_t m = mask(hash); + uint64_t b = blocks_[block_id(hash)]; + return (b & m) == m; + } + + // Uses SIMD if available for smaller Bloom filters. + // Uses memory prefetching for larger Bloom filters. + // + void Find(int64_t hardware_flags, int64_t num_rows, const uint32_t* hashes, + uint8_t* result_bit_vector, bool enable_prefetch = true) const; + void Find(int64_t hardware_flags, int64_t num_rows, const uint64_t* hashes, + uint8_t* result_bit_vector, bool enable_prefetch = true) const; + + int log_num_blocks() const { return log_num_blocks_; } + + int NumHashBitsUsed() const; + + bool IsSameAs(const BlockedBloomFilter* other) const; + + int64_t NumBitsSet() const; + + // Folding of a block Bloom filter after the initial version + // has been built. + // + // One of the parameters for creation of Bloom filter is the number + // of bits allocated for it. The more bits allocated, the lower the + // probability of false positives. A good heuristic is to aim for + // half of the bits set in the constructed Bloom filter. This should + // result in a good trade off between size (and following cost of + // memory accesses) and false positives rate. + // + // There might have been many duplicate keys in the input provided + // to Bloom filter builder. In that case the resulting bit vector + // would be more sparse then originally intended. It is possible to + // easily correct that and cut in half the size of Bloom filter + // after it has already been constructed. The process to do that is + // approximately equal to OR-ing bits from upper and lower half (the + // way we address these bits when inserting or querying a hash makes + // such folding in half possible). + // + // We will keep folding as long as the fraction of bits set is less + // than 1/4. The resulting bit vector density should be in the [1/4, + // 1/2) range. + // + void Fold(); + + private: + Status CreateEmpty(int64_t num_rows_to_insert, MemoryPool* pool); + + inline void Insert(uint64_t hash) { + uint64_t m = mask(hash); + uint64_t& b = blocks_[block_id(hash)]; + b |= m; + } + + void Insert(int64_t hardware_flags, int64_t num_rows, const uint32_t* hashes); + void Insert(int64_t hardware_flags, int64_t num_rows, const uint64_t* hashes); + + inline uint64_t mask(uint64_t hash) const { + // The lowest bits of hash are used to pick mask index. + // + int mask_id = static_cast(hash & (BloomFilterMasks::kNumMasks - 1)); + uint64_t result = masks_.mask(mask_id); + + // The next set of hash bits is used to pick the amount of bit + // rotation of the mask. + // + int rotation = (hash >> BloomFilterMasks::kLogNumMasks) & 63; + result = ROTL64(result, rotation); + + return result; + } + + inline int64_t block_id(uint64_t hash) const { + // The next set of hash bits following the bits used to select a + // mask is used to pick block id (index of 64-bit word in a bit + // vector). + // + return (hash >> (BloomFilterMasks::kLogNumMasks + 6)) & (num_blocks_ - 1); + } + + template + inline void InsertImp(int64_t num_rows, const T* hashes); + + template + inline void FindImp(int64_t num_rows, const T* hashes, uint8_t* result_bit_vector, + bool enable_prefetch) const; + + void SingleFold(int num_folds); + +#if defined(ARROW_HAVE_RUNTIME_AVX2) + inline __m256i mask_avx2(__m256i hash) const; + inline __m256i block_id_avx2(__m256i hash) const; + int64_t Insert_avx2(int64_t num_rows, const uint32_t* hashes); + int64_t Insert_avx2(int64_t num_rows, const uint64_t* hashes); + template + int64_t InsertImp_avx2(int64_t num_rows, const T* hashes); + int64_t Find_avx2(int64_t num_rows, const uint32_t* hashes, + uint8_t* result_bit_vector) const; + int64_t Find_avx2(int64_t num_rows, const uint64_t* hashes, + uint8_t* result_bit_vector) const; + template + int64_t FindImp_avx2(int64_t num_rows, const T* hashes, + uint8_t* result_bit_vector) const; +#endif + + bool UsePrefetch() const { + return num_blocks_ * sizeof(uint64_t) > kPrefetchLimitBytes; + } + + static constexpr int64_t kPrefetchLimitBytes = 256 * 1024; + + static BloomFilterMasks masks_; + + // Total number of bits used by block Bloom filter must be a power + // of 2. + // + int log_num_blocks_; + int64_t num_blocks_; + + // Buffer allocated to store an array of power of 2 64-bit blocks. + // + std::shared_ptr buf_; + // Pointer to mutable data owned by Buffer + // + uint64_t* blocks_; +}; + +// We have two separate implementations of building a Bloom filter, multi-threaded and +// single-threaded. +// +// Single threaded version is useful in two ways: +// a) It allows to verify parallel implementation in tests (the single threaded one is +// simpler and can be used as the source of truth). +// b) It is preferred for small and medium size Bloom filters, because it skips extra +// synchronization related steps from parallel variant (partitioning and taking locks). +// +enum class BloomFilterBuildStrategy { + SINGLE_THREADED = 0, + PARALLEL = 1, +}; + +class ARROW_ACERO_EXPORT BloomFilterBuilder { + public: + virtual ~BloomFilterBuilder() = default; + virtual Status Begin(size_t num_threads, int64_t hardware_flags, MemoryPool* pool, + int64_t num_rows, int64_t num_batches, + BlockedBloomFilter* build_target) = 0; + virtual int64_t num_tasks() const { return 0; } + virtual Status PushNextBatch(size_t thread_index, int64_t num_rows, + const uint32_t* hashes) = 0; + virtual Status PushNextBatch(size_t thread_index, int64_t num_rows, + const uint64_t* hashes) = 0; + virtual void CleanUp() {} + static std::unique_ptr Make(BloomFilterBuildStrategy strategy); +}; + +class ARROW_ACERO_EXPORT BloomFilterBuilder_SingleThreaded : public BloomFilterBuilder { + public: + Status Begin(size_t num_threads, int64_t hardware_flags, MemoryPool* pool, + int64_t num_rows, int64_t num_batches, + BlockedBloomFilter* build_target) override; + + Status PushNextBatch(size_t /*thread_index*/, int64_t num_rows, + const uint32_t* hashes) override; + + Status PushNextBatch(size_t /*thread_index*/, int64_t num_rows, + const uint64_t* hashes) override; + + private: + template + void PushNextBatchImp(int64_t num_rows, const T* hashes); + + int64_t hardware_flags_; + BlockedBloomFilter* build_target_; +}; + +class ARROW_ACERO_EXPORT BloomFilterBuilder_Parallel : public BloomFilterBuilder { + public: + Status Begin(size_t num_threads, int64_t hardware_flags, MemoryPool* pool, + int64_t num_rows, int64_t num_batches, + BlockedBloomFilter* build_target) override; + + Status PushNextBatch(size_t thread_id, int64_t num_rows, + const uint32_t* hashes) override; + + Status PushNextBatch(size_t thread_id, int64_t num_rows, + const uint64_t* hashes) override; + + void CleanUp() override; + + private: + template + void PushNextBatchImp(size_t thread_id, int64_t num_rows, const T* hashes); + + int64_t hardware_flags_; + BlockedBloomFilter* build_target_; + int log_num_prtns_; + struct ThreadLocalState { + std::vector partitioned_hashes_32; + std::vector partitioned_hashes_64; + std::vector partition_ranges; + std::vector unprocessed_partition_ids; + }; + std::vector thread_local_states_; + PartitionLocks prtn_locks_; +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/exec_plan.h b/pyarrow/include/arrow/acero/exec_plan.h new file mode 100644 index 0000000000000000000000000000000000000000..dba6c64ddc8379f7a8e6aa666f55555ced6c78aa --- /dev/null +++ b/pyarrow/include/arrow/acero/exec_plan.h @@ -0,0 +1,819 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/acero/type_fwd.h" +#include "arrow/acero/visibility.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/ordering.h" +#include "arrow/type_fwd.h" +#include "arrow/util/future.h" +#include "arrow/util/macros.h" +#include "arrow/util/tracing.h" +#include "arrow/util/type_fwd.h" + +namespace arrow { + +using compute::ExecBatch; +using compute::ExecContext; +using compute::FunctionRegistry; +using compute::GetFunctionRegistry; +using compute::Ordering; +using compute::threaded_exec_context; + +namespace acero { + +/// \addtogroup acero-internals +/// @{ + +class ARROW_ACERO_EXPORT ExecPlan : public std::enable_shared_from_this { + public: + // This allows operators to rely on signed 16-bit indices + static const uint32_t kMaxBatchSize = 1 << 15; + using NodeVector = std::vector; + + virtual ~ExecPlan() = default; + + QueryContext* query_context(); + + /// \brief retrieve the nodes in the plan + const NodeVector& nodes() const; + + /// Make an empty exec plan + static Result> Make( + QueryOptions options, ExecContext exec_context = *threaded_exec_context(), + std::shared_ptr metadata = NULLPTR); + + static Result> Make( + ExecContext exec_context = *threaded_exec_context(), + std::shared_ptr metadata = NULLPTR); + + static Result> Make( + QueryOptions options, ExecContext* exec_context, + std::shared_ptr metadata = NULLPTR); + + static Result> Make( + ExecContext* exec_context, + std::shared_ptr metadata = NULLPTR); + + ExecNode* AddNode(std::unique_ptr node); + + template + Node* EmplaceNode(Args&&... args) { + std::unique_ptr node{new Node{std::forward(args)...}}; + auto out = node.get(); + AddNode(std::move(node)); + return out; + } + + Status Validate(); + + /// \brief Start producing on all nodes + /// + /// Nodes are started in reverse topological order, such that any node + /// is started before all of its inputs. + void StartProducing(); + + /// \brief Stop producing on all nodes + /// + /// Triggers all sources to stop producing new data. In order to cleanly stop the plan + /// will continue to run any tasks that are already in progress. The caller should + /// still wait for `finished` to complete before destroying the plan. + void StopProducing(); + + /// \brief A future which will be marked finished when all tasks have finished. + Future<> finished(); + + /// \brief Return whether the plan has non-empty metadata + bool HasMetadata() const; + + /// \brief Return the plan's attached metadata + std::shared_ptr metadata() const; + + std::string ToString() const; +}; + +// Acero can be extended by providing custom implementations of ExecNode. The methods +// below are documented in detail and provide careful instruction on how to fulfill the +// ExecNode contract. It's suggested you familiarize yourself with the Acero +// documentation in the C++ user guide. +class ARROW_ACERO_EXPORT ExecNode { + public: + using NodeVector = std::vector; + + virtual ~ExecNode() = default; + + virtual const char* kind_name() const = 0; + + // The number of inputs expected by this node + int num_inputs() const { return static_cast(inputs_.size()); } + + /// This node's predecessors in the exec plan + const NodeVector& inputs() const { return inputs_; } + + /// True if the plan has no output schema (is a sink) + bool is_sink() const { return !output_schema_; } + + /// \brief Labels identifying the function of each input. + const std::vector& input_labels() const { return input_labels_; } + + /// This node's successor in the exec plan + const ExecNode* output() const { return output_; } + + /// The datatypes for batches produced by this node + const std::shared_ptr& output_schema() const { return output_schema_; } + + /// This node's exec plan + ExecPlan* plan() { return plan_; } + + /// \brief An optional label, for display and debugging + /// + /// There is no guarantee that this value is non-empty or unique. + const std::string& label() const { return label_; } + void SetLabel(std::string label) { label_ = std::move(label); } + + virtual Status Validate() const; + + /// \brief the ordering of the output batches + /// + /// This does not guarantee the batches will be emitted by this node + /// in order. Instead it guarantees that the batches will have their + /// ExecBatch::index property set in a way that respects this ordering. + /// + /// In other words, given the ordering {{"x", SortOrder::Ascending}} we + /// know that all values of x in a batch with index N will be less than + /// or equal to all values of x in a batch with index N+k (assuming k > 0). + /// Furthermore, we also know that values will be sorted within a batch. + /// Any row N will have a value of x that is less than the value for + /// any row N+k. + /// + /// Note that an ordering can be both Ordering::Unordered and Ordering::Implicit. + /// A node's output should be marked Ordering::Unordered if the order is + /// non-deterministic. For example, a hash-join has no predictable output order. + /// + /// If the ordering is Ordering::Implicit then there is a meaningful order but that + /// ordering is not represented by any column in the data. The most common case for + /// this is when reading data from an in-memory table. The data has an implicit "row + /// order" which is not necessarily represented in the data set. + /// + /// A filter or project node will not modify the ordering. Nothing needs to be done + /// other than ensure the index assigned to output batches is the same as the + /// input batch that was mapped. + /// + /// Other nodes may introduce order. For example, an order-by node will emit + /// a brand new ordering independent of the input ordering. + /// + /// Finally, as described above, such as a hash-join or aggregation may may + /// destroy ordering (although these nodes could also choose to establish a + /// new ordering based on the hash keys). + /// + /// Some nodes will require an ordering. For example, a fetch node or an + /// asof join node will only function if the input data is ordered (for fetch + /// it is enough to be implicitly ordered. For an asof join the ordering must + /// be explicit and compatible with the on key.) + /// + /// Nodes that maintain ordering should be careful to avoid introducing gaps + /// in the batch index. This may require emitting empty batches in order to + /// maintain continuity. + virtual const Ordering& ordering() const; + + /// Upstream API: + /// These functions are called by input nodes that want to inform this node + /// about an updated condition (a new input batch or an impending + /// end of stream). + /// + /// Implementation rules: + /// - these may be called anytime after StartProducing() has succeeded + /// (and even during or after StopProducing()) + /// - these may be called concurrently + /// - these are allowed to call back into PauseProducing(), ResumeProducing() + /// and StopProducing() + + /// Transfer input batch to ExecNode + /// + /// A node will typically perform some kind of operation on the batch + /// and then call InputReceived on its outputs with the result. + /// + /// Other nodes may need to accumulate some number of inputs before any + /// output can be produced. These nodes will add the batch to some kind + /// of in-memory accumulation queue and return. + virtual Status InputReceived(ExecNode* input, ExecBatch batch) = 0; + + /// Mark the inputs finished after the given number of batches. + /// + /// This may be called before all inputs are received. This simply fixes + /// the total number of incoming batches for an input, so that the ExecNode + /// knows when it has received all input, regardless of order. + virtual Status InputFinished(ExecNode* input, int total_batches) = 0; + + /// \brief Perform any needed initialization + /// + /// This hook performs any actions in between creation of ExecPlan and the call to + /// StartProducing. An example could be Bloom filter pushdown. The order of ExecNodes + /// that executes this method is undefined, but the calls are made synchronously. + /// + /// At this point a node can rely on all inputs & outputs (and the input schemas) + /// being well defined. + virtual Status Init(); + + /// Lifecycle API: + /// - start / stop to initiate and terminate production + /// - pause / resume to apply backpressure + /// + /// Implementation rules: + /// - StartProducing() should not recurse into the inputs, as it is + /// handled by ExecPlan::StartProducing() + /// - PauseProducing(), ResumeProducing(), StopProducing() may be called + /// concurrently, potentially even before the call to StartProducing + /// has finished. + /// - PauseProducing(), ResumeProducing(), StopProducing() may be called + /// by the downstream nodes' InputReceived(), InputFinished() methods + /// + /// StopProducing may be called due to an error, by the user (e.g. cancel), or + /// because a node has all the data it needs (e.g. limit, top-k on sorted data). + /// This means the method may be called multiple times and we have the following + /// additional rules + /// - StopProducing() must be idempotent + /// - StopProducing() must be forwarded to inputs (this is needed for the limit/top-k + /// case because we may not be stopping the entire plan) + + // Right now, since synchronous calls happen in both directions (input to + // output and then output to input), a node must be careful to be reentrant + // against synchronous calls from its output, *and* also concurrent calls from + // other threads. The most reliable solution is to update the internal state + // first, and notify outputs only at the end. + // + // Concurrent calls to PauseProducing and ResumeProducing can be hard to sequence + // as they may travel at different speeds through the plan. + // + // For example, consider a resume that comes quickly after a pause. If the source + // receives the resume before the pause the source may think the destination is full + // and halt production which would lead to deadlock. + // + // To resolve this a counter is sent for all calls to pause/resume. Only the call with + // the highest counter value is valid. So if a call to PauseProducing(5) comes after + // a call to ResumeProducing(6) then the source should continue producing. + + /// \brief Start producing + /// + /// This must only be called once. + /// + /// This is typically called automatically by ExecPlan::StartProducing(). + virtual Status StartProducing() = 0; + + /// \brief Pause producing temporarily + /// + /// \param output Pointer to the output that is full + /// \param counter Counter used to sequence calls to pause/resume + /// + /// This call is a hint that an output node is currently not willing + /// to receive data. + /// + /// This may be called any number of times. + /// However, the node is still free to produce data (which may be difficult + /// to prevent anyway if data is produced using multiple threads). + virtual void PauseProducing(ExecNode* output, int32_t counter) = 0; + + /// \brief Resume producing after a temporary pause + /// + /// \param output Pointer to the output that is now free + /// \param counter Counter used to sequence calls to pause/resume + /// + /// This call is a hint that an output node is willing to receive data again. + /// + /// This may be called any number of times. + virtual void ResumeProducing(ExecNode* output, int32_t counter) = 0; + + /// \brief Stop producing new data + /// + /// If this node is a source then the source should stop generating data + /// as quickly as possible. If this node is not a source then there is typically + /// nothing that needs to be done although a node may choose to start ignoring incoming + /// data. + /// + /// This method will be called when an error occurs in the plan + /// This method may also be called by the user if they wish to end a plan early + /// Finally, this method may be called if a node determines it no longer needs any more + /// input (for example, a limit node). + /// + /// This method may be called multiple times. + /// + /// This is not a pause. There will be no way to start the source again after this has + /// been called. + virtual Status StopProducing(); + + std::string ToString(int indent = 0) const; + + protected: + ExecNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, + std::shared_ptr output_schema); + + virtual Status StopProducingImpl() = 0; + + /// Provide extra info to include in the string representation. + virtual std::string ToStringExtra(int indent = 0) const; + + std::atomic stopped_; + ExecPlan* plan_; + std::string label_; + + NodeVector inputs_; + std::vector input_labels_; + + std::shared_ptr output_schema_; + ExecNode* output_ = NULLPTR; +}; + +/// \brief An extensible registry for factories of ExecNodes +class ARROW_ACERO_EXPORT ExecFactoryRegistry { + public: + using Factory = std::function(ExecPlan*, std::vector, + const ExecNodeOptions&)>; + + virtual ~ExecFactoryRegistry() = default; + + /// \brief Get the named factory from this registry + /// + /// will raise if factory_name is not found + virtual Result GetFactory(const std::string& factory_name) = 0; + + /// \brief Add a factory to this registry with the provided name + /// + /// will raise if factory_name is already in the registry + virtual Status AddFactory(std::string factory_name, Factory factory) = 0; +}; + +/// The default registry, which includes built-in factories. +ARROW_ACERO_EXPORT +ExecFactoryRegistry* default_exec_factory_registry(); + +/// \brief Construct an ExecNode using the named factory +inline Result MakeExecNode( + const std::string& factory_name, ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options, + ExecFactoryRegistry* registry = default_exec_factory_registry()) { + ARROW_ASSIGN_OR_RAISE(auto factory, registry->GetFactory(factory_name)); + return factory(plan, std::move(inputs), options); +} + +/// @} + +/// \addtogroup acero-api +/// @{ + +/// \brief Helper class for declaring execution nodes +/// +/// A Declaration represents an unconstructed ExecNode (and potentially an entire graph +/// since its inputs may also be Declarations) +/// +/// A Declaration can be converted to a plan and executed using one of the +/// DeclarationToXyz methods. +/// +/// For more direct control, a Declaration can be added to an existing execution +/// plan with Declaration::AddToPlan, which will recursively construct any inputs as +/// necessary. +struct ARROW_ACERO_EXPORT Declaration { + using Input = std::variant; + + Declaration() {} + + /// \brief construct a declaration + /// \param factory_name the name of the exec node to construct. The node must have + /// been added to the exec node registry with this name. + /// \param inputs the inputs to the node, these should be other declarations + /// \param options options that control the behavior of the node. You must use + /// the appropriate subclass. For example, if `factory_name` is + /// "project" then `options` should be ProjectNodeOptions. + /// \param label a label to give the node. Can be used to distinguish it from other + /// nodes of the same type in the plan. + Declaration(std::string factory_name, std::vector inputs, + std::shared_ptr options, std::string label) + : factory_name{std::move(factory_name)}, + inputs{std::move(inputs)}, + options{std::move(options)}, + label{std::move(label)} {} + + template + Declaration(std::string factory_name, std::vector inputs, Options options, + std::string label) + : Declaration{std::move(factory_name), std::move(inputs), + std::shared_ptr( + std::make_shared(std::move(options))), + std::move(label)} {} + + template + Declaration(std::string factory_name, std::vector inputs, Options options) + : Declaration{std::move(factory_name), std::move(inputs), std::move(options), + /*label=*/""} {} + + template + Declaration(std::string factory_name, Options options) + : Declaration{std::move(factory_name), {}, std::move(options), /*label=*/""} {} + + template + Declaration(std::string factory_name, Options options, std::string label) + : Declaration{std::move(factory_name), {}, std::move(options), std::move(label)} {} + + /// \brief Convenience factory for the common case of a simple sequence of nodes. + /// + /// Each of decls will be appended to the inputs of the subsequent declaration, + /// and the final modified declaration will be returned. + /// + /// Without this convenience factory, constructing a sequence would require explicit, + /// difficult-to-read nesting: + /// + /// Declaration{"n3", + /// { + /// Declaration{"n2", + /// { + /// Declaration{"n1", + /// { + /// Declaration{"n0", N0Opts{}}, + /// }, + /// N1Opts{}}, + /// }, + /// N2Opts{}}, + /// }, + /// N3Opts{}}; + /// + /// An equivalent Declaration can be constructed more tersely using Sequence: + /// + /// Declaration::Sequence({ + /// {"n0", N0Opts{}}, + /// {"n1", N1Opts{}}, + /// {"n2", N2Opts{}}, + /// {"n3", N3Opts{}}, + /// }); + static Declaration Sequence(std::vector decls); + + /// \brief add the declaration to an already created execution plan + /// \param plan the plan to add the node to + /// \param registry the registry to use to lookup the node factory + /// + /// This method will recursively call AddToPlan on all of the declaration's inputs. + /// This method is only for advanced use when the DeclarationToXyz methods are not + /// sufficient. + /// + /// \return the instantiated execution node + Result AddToPlan(ExecPlan* plan, ExecFactoryRegistry* registry = + default_exec_factory_registry()) const; + + // Validate a declaration + bool IsValid(ExecFactoryRegistry* registry = default_exec_factory_registry()) const; + + /// \brief the name of the factory to use when creating a node + std::string factory_name; + /// \brief the declarations's inputs + std::vector inputs; + /// \brief options to control the behavior of the node + std::shared_ptr options; + /// \brief a label to give the node in the plan + std::string label; +}; + +/// \brief How to handle unaligned buffers +enum class UnalignedBufferHandling { kWarn, kIgnore, kReallocate, kError }; + +/// \brief get the default behavior of unaligned buffer handling +/// +/// This is configurable via the ACERO_ALIGNMENT_HANDLING environment variable which +/// can be set to "warn", "ignore", "reallocate", or "error". If the environment +/// variable is not set, or is set to an invalid value, this will return kWarn +UnalignedBufferHandling GetDefaultUnalignedBufferHandling(); + +/// \brief plan-wide options that can be specified when executing an execution plan +struct ARROW_ACERO_EXPORT QueryOptions { + /// \brief Should the plan use a legacy batching strategy + /// + /// This is currently in place only to support the Scanner::ToTable + /// method. This method relies on batch indices from the scanner + /// remaining consistent. This is impractical in the ExecPlan which + /// might slice batches as needed (e.g. for a join) + /// + /// However, it still works for simple plans and this is the only way + /// we have at the moment for maintaining implicit order. + bool use_legacy_batching = false; + + /// If the output has a meaningful order then sequence the output of the plan + /// + /// The default behavior (std::nullopt) will sequence output batches if there + /// is a meaningful ordering in the final node and will emit batches immediately + /// otherwise. + /// + /// If explicitly set to true then plan execution will fail if there is no + /// meaningful ordering. This can be useful to validate a query that should + /// be emitting ordered results. + /// + /// If explicitly set to false then batches will be emit immediately even if there + /// is a meaningful ordering. This could cause batches to be emit out of order but + /// may offer a small decrease to latency. + std::optional sequence_output = std::nullopt; + + /// \brief should the plan use multiple background threads for CPU-intensive work + /// + /// If this is false then all CPU work will be done on the calling thread. I/O tasks + /// will still happen on the I/O executor and may be multi-threaded (but should not use + /// significant CPU resources). + /// + /// Will be ignored if custom_cpu_executor is set + bool use_threads = true; + + /// \brief custom executor to use for CPU-intensive work + /// + /// Must be null or remain valid for the duration of the plan. If this is null then + /// a default thread pool will be chosen whose behavior will be controlled by + /// the `use_threads` option. + ::arrow::internal::Executor* custom_cpu_executor = NULLPTR; + + /// \brief custom executor to use for IO work + /// + /// Must be null or remain valid for the duration of the plan. If this is null then + /// the global io thread pool will be chosen whose behavior will be controlled by + /// the "ARROW_IO_THREADS" environment. + ::arrow::internal::Executor* custom_io_executor = NULLPTR; + + /// \brief a memory pool to use for allocations + /// + /// Must remain valid for the duration of the plan. + MemoryPool* memory_pool = default_memory_pool(); + + /// \brief a function registry to use for the plan + /// + /// Must remain valid for the duration of the plan. + FunctionRegistry* function_registry = GetFunctionRegistry(); + /// \brief the names of the output columns + /// + /// If this is empty then names will be generated based on the input columns + /// + /// If set then the number of names must equal the number of output columns + std::vector field_names; + + /// \brief Policy for unaligned buffers in source data + /// + /// Various compute functions and acero internals will type pun array + /// buffers from uint8_t* to some kind of value type (e.g. we might + /// cast to int32_t* to add two int32 arrays) + /// + /// If the buffer is poorly aligned (e.g. an int32 array is not aligned + /// on a 4-byte boundary) then this is technically undefined behavior in C++. + /// However, most modern compilers and CPUs are fairly tolerant of this + /// behavior and nothing bad (beyond a small hit to performance) is likely + /// to happen. + /// + /// Note that this only applies to source buffers. All buffers allocated internally + /// by Acero will be suitably aligned. + /// + /// If this field is set to kWarn then Acero will check if any buffers are unaligned + /// and, if they are, will emit a warning. + /// + /// If this field is set to kReallocate then Acero will allocate a new, suitably aligned + /// buffer and copy the contents from the old buffer into this new buffer. + /// + /// If this field is set to kError then Acero will gracefully abort the plan instead. + /// + /// If this field is set to kIgnore then Acero will not even check if the buffers are + /// unaligned. + /// + /// If this field is not set then it will be treated as kWarn unless overridden + /// by the ACERO_ALIGNMENT_HANDLING environment variable + std::optional unaligned_buffer_handling; +}; + +/// \brief Calculate the output schema of a declaration +/// +/// This does not actually execute the plan. This operation may fail if the +/// declaration represents an invalid plan (e.g. a project node with multiple inputs) +/// +/// \param declaration A declaration describing an execution plan +/// \param function_registry The function registry to use for function execution. If null +/// then the default function registry will be used. +/// +/// \return the schema that batches would have after going through the execution plan +ARROW_ACERO_EXPORT Result> DeclarationToSchema( + const Declaration& declaration, FunctionRegistry* function_registry = NULLPTR); + +/// \brief Create a string representation of a plan +/// +/// This representation is for debug purposes only. +/// +/// Conversion to a string may fail if the declaration represents an +/// invalid plan. +/// +/// Use Substrait for complete serialization of plans +/// +/// \param declaration A declaration describing an execution plan +/// \param function_registry The function registry to use for function execution. If null +/// then the default function registry will be used. +/// +/// \return a string representation of the plan suitable for debugging output +ARROW_ACERO_EXPORT Result DeclarationToString( + const Declaration& declaration, FunctionRegistry* function_registry = NULLPTR); + +/// \brief Utility method to run a declaration and collect the results into a table +/// +/// \param declaration A declaration describing the plan to run +/// \param use_threads If `use_threads` is false then all CPU work will be done on the +/// calling thread. I/O tasks will still happen on the I/O executor +/// and may be multi-threaded (but should not use significant CPU +/// resources). +/// \param memory_pool The memory pool to use for allocations made while running the plan. +/// \param function_registry The function registry to use for function execution. If null +/// then the default function registry will be used. +/// +/// This method will add a sink node to the declaration to collect results into a +/// table. It will then create an ExecPlan from the declaration, start the exec plan, +/// block until the plan has finished, and return the created table. +ARROW_ACERO_EXPORT Result> DeclarationToTable( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +ARROW_ACERO_EXPORT Result> DeclarationToTable( + Declaration declaration, QueryOptions query_options); + +/// \brief Asynchronous version of \see DeclarationToTable +/// +/// \param declaration A declaration describing the plan to run +/// \param use_threads The behavior of use_threads is slightly different than the +/// synchronous version since we cannot run synchronously on the +/// calling thread. Instead, if use_threads=false then a new thread +/// pool will be created with a single thread and this will be used for +/// all compute work. +/// \param memory_pool The memory pool to use for allocations made while running the plan. +/// \param function_registry The function registry to use for function execution. If null +/// then the default function registry will be used. +ARROW_ACERO_EXPORT Future> DeclarationToTableAsync( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +/// \brief Overload of \see DeclarationToTableAsync accepting a custom exec context +/// +/// The executor must be specified (cannot be null) and must be kept alive until the +/// returned future finishes. +ARROW_ACERO_EXPORT Future> DeclarationToTableAsync( + Declaration declaration, ExecContext custom_exec_context); + +/// \brief a collection of exec batches with a common schema +struct BatchesWithCommonSchema { + std::vector batches; + std::shared_ptr schema; +}; + +/// \brief Utility method to run a declaration and collect the results into ExecBatch +/// vector +/// +/// \see DeclarationToTable for details on threading & execution +ARROW_ACERO_EXPORT Result DeclarationToExecBatches( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +ARROW_ACERO_EXPORT Result DeclarationToExecBatches( + Declaration declaration, QueryOptions query_options); + +/// \brief Asynchronous version of \see DeclarationToExecBatches +/// +/// \see DeclarationToTableAsync for details on threading & execution +ARROW_ACERO_EXPORT Future DeclarationToExecBatchesAsync( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +/// \brief Overload of \see DeclarationToExecBatchesAsync accepting a custom exec context +/// +/// \see DeclarationToTableAsync for details on threading & execution +ARROW_ACERO_EXPORT Future DeclarationToExecBatchesAsync( + Declaration declaration, ExecContext custom_exec_context); + +/// \brief Utility method to run a declaration and collect the results into a vector +/// +/// \see DeclarationToTable for details on threading & execution +ARROW_ACERO_EXPORT Result>> DeclarationToBatches( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +ARROW_ACERO_EXPORT Result>> DeclarationToBatches( + Declaration declaration, QueryOptions query_options); + +/// \brief Asynchronous version of \see DeclarationToBatches +/// +/// \see DeclarationToTableAsync for details on threading & execution +ARROW_ACERO_EXPORT Future>> +DeclarationToBatchesAsync(Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +/// \brief Overload of \see DeclarationToBatchesAsync accepting a custom exec context +/// +/// \see DeclarationToTableAsync for details on threading & execution +ARROW_ACERO_EXPORT Future>> +DeclarationToBatchesAsync(Declaration declaration, ExecContext exec_context); + +/// \brief Utility method to run a declaration and return results as a RecordBatchReader +/// +/// If an exec context is not provided then a default exec context will be used based +/// on the value of `use_threads`. If `use_threads` is false then the CPU executor will +/// be a serial executor and all CPU work will be done on the calling thread. I/O tasks +/// will still happen on the I/O executor and may be multi-threaded. +/// +/// If `use_threads` is false then all CPU work will happen during the calls to +/// RecordBatchReader::Next and no CPU work will happen in the background. If +/// `use_threads` is true then CPU work will happen on the CPU thread pool and tasks may +/// run in between calls to RecordBatchReader::Next. If the returned reader is not +/// consumed quickly enough then the plan will eventually pause as the backpressure queue +/// fills up. +/// +/// If a custom exec context is provided then the value of `use_threads` will be ignored. +/// +/// The returned RecordBatchReader can be closed early to cancel the computation of record +/// batches. In this case, only errors encountered by the computation may be reported. In +/// particular, no cancellation error may be reported. +ARROW_ACERO_EXPORT Result> DeclarationToReader( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +ARROW_ACERO_EXPORT Result> DeclarationToReader( + Declaration declaration, QueryOptions query_options); + +/// \brief Utility method to run a declaration and ignore results +/// +/// This can be useful when the data are consumed as part of the plan itself, for +/// example, when the plan ends with a write node. +/// +/// \see DeclarationToTable for details on threading & execution +ARROW_ACERO_EXPORT Status +DeclarationToStatus(Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +ARROW_ACERO_EXPORT Status DeclarationToStatus(Declaration declaration, + QueryOptions query_options); + +/// \brief Asynchronous version of \see DeclarationToStatus +/// +/// This can be useful when the data are consumed as part of the plan itself, for +/// example, when the plan ends with a write node. +/// +/// \see DeclarationToTableAsync for details on threading & execution +ARROW_ACERO_EXPORT Future<> DeclarationToStatusAsync( + Declaration declaration, bool use_threads = true, + MemoryPool* memory_pool = default_memory_pool(), + FunctionRegistry* function_registry = NULLPTR); + +/// \brief Overload of \see DeclarationToStatusAsync accepting a custom exec context +/// +/// \see DeclarationToTableAsync for details on threading & execution +ARROW_ACERO_EXPORT Future<> DeclarationToStatusAsync(Declaration declaration, + ExecContext exec_context); + +/// @} + +/// \brief Wrap an ExecBatch generator in a RecordBatchReader. +/// +/// The RecordBatchReader does not impose any ordering on emitted batches. +ARROW_ACERO_EXPORT +std::shared_ptr MakeGeneratorReader( + std::shared_ptr, std::function>()>, + MemoryPool*); + +constexpr int kDefaultBackgroundMaxQ = 32; +constexpr int kDefaultBackgroundQRestart = 16; + +/// \brief Make a generator of RecordBatchReaders +/// +/// Useful as a source node for an Exec plan +ARROW_ACERO_EXPORT +Result>()>> MakeReaderGenerator( + std::shared_ptr reader, arrow::internal::Executor* io_executor, + int max_q = kDefaultBackgroundMaxQ, int q_restart = kDefaultBackgroundQRestart); + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/hash_join.h b/pyarrow/include/arrow/acero/hash_join.h new file mode 100644 index 0000000000000000000000000000000000000000..c0faacf04baf02e865a61a0301a0cfa92b3fab1b --- /dev/null +++ b/pyarrow/include/arrow/acero/hash_join.h @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/acero/accumulation_queue.h" +#include "arrow/acero/bloom_filter.h" +#include "arrow/acero/options.h" +#include "arrow/acero/query_context.h" +#include "arrow/acero/schema_util.h" +#include "arrow/acero/task_util.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/tracing.h" + +namespace arrow { +namespace acero { + +using util::AccumulationQueue; + +class ARROW_ACERO_EXPORT HashJoinImpl { + public: + using OutputBatchCallback = std::function; + using BuildFinishedCallback = std::function; + using FinishedCallback = std::function; + using RegisterTaskGroupCallback = std::function, std::function)>; + using StartTaskGroupCallback = std::function; + using AbortContinuationImpl = std::function; + + virtual ~HashJoinImpl() = default; + virtual Status Init(QueryContext* ctx, JoinType join_type, size_t num_threads, + const HashJoinProjectionMaps* proj_map_left, + const HashJoinProjectionMaps* proj_map_right, + std::vector key_cmp, Expression filter, + RegisterTaskGroupCallback register_task_group_callback, + StartTaskGroupCallback start_task_group_callback, + OutputBatchCallback output_batch_callback, + FinishedCallback finished_callback) = 0; + + virtual Status BuildHashTable(size_t thread_index, AccumulationQueue batches, + BuildFinishedCallback on_finished) = 0; + virtual Status ProbeSingleBatch(size_t thread_index, ExecBatch batch) = 0; + virtual Status ProbingFinished(size_t thread_index) = 0; + virtual void Abort(TaskScheduler::AbortContinuationImpl pos_abort_callback) = 0; + virtual std::string ToString() const = 0; + + static Result> MakeBasic(); + static Result> MakeSwiss(); + + protected: + arrow::util::tracing::Span span_; +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/hash_join_dict.h b/pyarrow/include/arrow/acero/hash_join_dict.h new file mode 100644 index 0000000000000000000000000000000000000000..02454a7146278176e27379e6033f79547574a367 --- /dev/null +++ b/pyarrow/include/arrow/acero/hash_join_dict.h @@ -0,0 +1,318 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/acero/schema_util.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/row/row_encoder_internal.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" + +// This file contains hash join logic related to handling of dictionary encoded key +// columns. +// +// A key column from probe side of the join can be matched against a key column from build +// side of the join, as long as the underlying value types are equal. That means that: +// - both scalars and arrays can be used and even mixed in the same column +// - dictionary column can be matched against non-dictionary column if underlying value +// types are equal +// - dictionary column can be matched against dictionary column with a different index +// type, and potentially using a different dictionary, if underlying value types are equal +// +// We currently require in hash join that for all dictionary encoded columns, the same +// dictionary is used in all input exec batches. +// +// In order to allow matching columns with different dictionaries, different dictionary +// index types, and dictionary key against non-dictionary key, internally comparisons will +// be evaluated after remapping values on both sides of the join to a common +// representation (which will be called "unified representation"). This common +// representation is a column of int32() type (not a dictionary column). It represents an +// index in the unified dictionary computed for the (only) dictionary present on build +// side (an empty dictionary is still created for an empty build side). Null value is +// always represented in this common representation as null int32 value, unified +// dictionary will never contain a null value (so there is no ambiguity of representing +// nulls as either index to a null entry in the dictionary or null index). +// +// Unified dictionary represents values present on build side. There may be values on +// probe side that are not present in it. All such values, that are not null, are mapped +// in the common representation to a special constant kMissingValueId. +// + +namespace arrow { + +using compute::ExecBatch; +using compute::ExecContext; +using compute::internal::RowEncoder; + +namespace acero { + +/// Helper class with operations that are stateless and common to processing of dictionary +/// keys on both build and probe side. +class HashJoinDictUtil { + public: + // Null values in unified representation are always represented as null that has + // corresponding integer set to this constant + static constexpr int32_t kNullId = 0; + // Constant representing a value, that is not null, missing on the build side, in + // unified representation. + static constexpr int32_t kMissingValueId = -1; + + // Check if data types of corresponding pair of key column on build and probe side are + // compatible + static bool KeyDataTypesValid(const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type); + + // Input must be dictionary array or dictionary scalar. + // A precomputed and provided here lookup table in the form of int32() array will be + // used to remap input indices to unified representation. + // + static Result> IndexRemapUsingLUT( + ExecContext* ctx, const Datum& indices, int64_t batch_length, + const std::shared_ptr& map_array, + const std::shared_ptr& data_type); + + // Return int32() array that contains indices of input dictionary array or scalar after + // type casting. + static Result> ConvertToInt32( + const std::shared_ptr& from_type, const Datum& input, + int64_t batch_length, ExecContext* ctx); + + // Return an array that contains elements of input int32() array after casting to a + // given integer type. This is used for mapping unified representation stored in the + // hash table on build side back to original input data type of hash join, when + // outputting hash join results to parent exec node. + // + static Result> ConvertFromInt32( + const std::shared_ptr& to_type, const Datum& input, int64_t batch_length, + ExecContext* ctx); + + // Return dictionary referenced in either dictionary array or dictionary scalar + static std::shared_ptr ExtractDictionary(const Datum& data); +}; + +/// Implements processing of dictionary arrays/scalars in key columns on the build side of +/// a hash join. +/// Each instance of this class corresponds to a single column and stores and +/// processes only the information related to that column. +/// Const methods are thread-safe, non-const methods are not (the caller must make sure +/// that only one thread at any time will access them). +/// +class HashJoinDictBuild { + public: + // Returns true if the key column (described in input by its data type) requires any + // pre- or post-processing related to handling dictionaries. + // + static bool KeyNeedsProcessing(const std::shared_ptr& build_data_type) { + return (build_data_type->id() == Type::DICTIONARY); + } + + // Data type of unified representation + static std::shared_ptr DataTypeAfterRemapping() { return int32(); } + + // Should be called only once in hash join, before processing any build or probe + // batches. + // + // Takes a pointer to the dictionary for a corresponding key column on the build side as + // an input. If the build side is empty, it still needs to be called, but with + // dictionary pointer set to null. + // + // Currently it is required that all input batches on build side share the same + // dictionary. For each input batch during its pre-processing, dictionary will be + // checked and error will be returned if it is different then the one provided in the + // call to this method. + // + // Unifies the dictionary. The order of the values is still preserved. + // Null and duplicate entries are removed. If the dictionary is already unified, its + // copy will be produced and stored within this class. + // + // Prepares the mapping from ids within original dictionary to the ids in the resulting + // dictionary. This is used later on to pre-process (map to unified representation) key + // column on build side. + // + // Prepares the reverse mapping (in the form of hash table) from values to the ids in + // the resulting dictionary. This will be used later on to pre-process (map to unified + // representation) key column on probe side. Values on probe side that are not present + // in the original dictionary will be mapped to a special constant kMissingValueId. The + // exception is made for nulls, which get always mapped to nulls (both when null is + // represented as a dictionary id pointing to a null and a null dictionary id). + // + Status Init(ExecContext* ctx, std::shared_ptr dictionary, + std::shared_ptr index_type, std::shared_ptr value_type); + + // Remap array or scalar values into unified representation (array of int32()). + // Outputs kMissingValueId if input value is not found in the unified dictionary. + // Outputs null for null input value (with corresponding data set to kNullId). + // + Result> RemapInputValues(ExecContext* ctx, + const Datum& values, + int64_t batch_length) const; + + // Remap dictionary array or dictionary scalar on build side to unified representation. + // Dictionary referenced in the input must match the dictionary that was + // given during initialization. + // The output is a dictionary array that references unified dictionary. + // + Result> RemapInput( + ExecContext* ctx, const Datum& indices, int64_t batch_length, + const std::shared_ptr& data_type) const; + + // Outputs dictionary array referencing unified dictionary, given an array with 32-bit + // ids. + // Used to post-process values looked up in a hash table on build side of the hash join + // before outputting to the parent exec node. + // + Result> RemapOutput(const ArrayData& indices32Bit, + ExecContext* ctx) const; + + // Release shared pointers and memory + void CleanUp(); + + private: + // Data type of dictionary ids for the input dictionary on build side + std::shared_ptr index_type_; + // Data type of values for the input dictionary on build side + std::shared_ptr value_type_; + // Mapping from (encoded as string) values to the ids in unified dictionary + std::unordered_map hash_table_; + // Mapping from input dictionary ids to unified dictionary ids + std::shared_ptr remapped_ids_; + // Input dictionary + std::shared_ptr dictionary_; + // Unified dictionary + std::shared_ptr unified_dictionary_; +}; + +/// Implements processing of dictionary arrays/scalars in key columns on the probe side of +/// a hash join. +/// Each instance of this class corresponds to a single column and stores and +/// processes only the information related to that column. +/// It is not thread-safe - every participating thread should use its own instance of +/// this class. +/// +class HashJoinDictProbe { + public: + static bool KeyNeedsProcessing(const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type); + + // Data type of the result of remapping input key column. + // + // The result of remapping is what is used in hash join for matching keys on build and + // probe side. The exact data types may be different, as described below, and therefore + // a common representation is needed for simplifying comparisons of pairs of keys on + // both sides. + // + // We support matching key that is of non-dictionary type with key that is of dictionary + // type, as long as the underlying value types are equal. We support matching when both + // keys are of dictionary type, regardless whether underlying dictionary index types are + // the same or not. + // + static std::shared_ptr DataTypeAfterRemapping( + const std::shared_ptr& build_data_type); + + // Should only be called if KeyNeedsProcessing method returns true for a pair of + // corresponding key columns from build and probe side. + // Converts values in order to match the common representation for + // both build and probe side used in hash table comparison. + // Supports arrays and scalars as input. + // Argument opt_build_side should be null if dictionary key on probe side is matched + // with non-dictionary key on build side. + // + Result> RemapInput( + const HashJoinDictBuild* opt_build_side, const Datum& data, int64_t batch_length, + const std::shared_ptr& probe_data_type, + const std::shared_ptr& build_data_type, ExecContext* ctx); + + void CleanUp(); + + private: + // May be null if probe side key is non-dictionary. Otherwise it is used to verify that + // only a single dictionary is referenced in exec batch on probe side of hash join. + std::shared_ptr dictionary_; + // Mapping from dictionary on probe side of hash join (if it is used) to unified + // representation. + std::shared_ptr remapped_ids_; + // Encoder of key columns that uses unified representation instead of original data type + // for key columns that need to use it (have dictionaries on either side of the join). + RowEncoder encoder_; +}; + +// Encapsulates dictionary handling logic for build side of hash join. +// +class HashJoinDictBuildMulti { + public: + Status Init(const SchemaProjectionMaps& proj_map, + const ExecBatch* opt_non_empty_batch, ExecContext* ctx); + static void InitEncoder(const SchemaProjectionMaps& proj_map, + RowEncoder* encoder, ExecContext* ctx); + Status EncodeBatch(size_t thread_index, + const SchemaProjectionMaps& proj_map, + const ExecBatch& batch, RowEncoder* encoder, ExecContext* ctx) const; + Status PostDecode(const SchemaProjectionMaps& proj_map, + ExecBatch* decoded_key_batch, ExecContext* ctx); + const HashJoinDictBuild& get_dict_build(int icol) const { return remap_imp_[icol]; } + + private: + std::vector needs_remap_; + std::vector remap_imp_; +}; + +// Encapsulates dictionary handling logic for probe side of hash join +// +class HashJoinDictProbeMulti { + public: + void Init(size_t num_threads); + bool BatchRemapNeeded(size_t thread_index, + const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, + ExecContext* ctx); + Status EncodeBatch(size_t thread_index, + const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, + const HashJoinDictBuildMulti& dict_build, const ExecBatch& batch, + RowEncoder** out_encoder, ExecBatch* opt_out_key_batch, + ExecContext* ctx); + + private: + void InitLocalStateIfNeeded( + size_t thread_index, const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, ExecContext* ctx); + static void InitEncoder(const SchemaProjectionMaps& proj_map_probe, + const SchemaProjectionMaps& proj_map_build, + RowEncoder* encoder, ExecContext* ctx); + struct ThreadLocalState { + bool is_initialized; + // Whether any key column needs remapping (because of dictionaries used) before doing + // join hash table lookups + bool any_needs_remap; + // Whether each key column needs remapping before doing join hash table lookups + std::vector needs_remap; + std::vector remap_imp; + // Encoder of key columns that uses unified representation instead of original data + // type for key columns that need to use it (have dictionaries on either side of the + // join). + RowEncoder post_remap_encoder; + }; + std::vector local_states_; +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/hash_join_node.h b/pyarrow/include/arrow/acero/hash_join_node.h new file mode 100644 index 0000000000000000000000000000000000000000..19745b8675cf0c63ed92c6e5448c9e6a68467f59 --- /dev/null +++ b/pyarrow/include/arrow/acero/hash_join_node.h @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/acero/options.h" +#include "arrow/acero/schema_util.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { + +using compute::ExecContext; + +namespace acero { + +class ARROW_ACERO_EXPORT HashJoinSchema { + public: + Status Init(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, const Schema& right_schema, + const std::vector& right_keys, const Expression& filter, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + Status Init(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const std::vector& left_output, const Schema& right_schema, + const std::vector& right_keys, + const std::vector& right_output, const Expression& filter, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + static Status ValidateSchemas(JoinType join_type, const Schema& left_schema, + const std::vector& left_keys, + const std::vector& left_output, + const Schema& right_schema, + const std::vector& right_keys, + const std::vector& right_output, + const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix); + + bool HasDictionaries() const; + + bool HasLargeBinary() const; + + Result BindFilter(Expression filter, const Schema& left_schema, + const Schema& right_schema, ExecContext* exec_context); + std::shared_ptr MakeOutputSchema(const std::string& left_field_name_suffix, + const std::string& right_field_name_suffix); + + bool LeftPayloadIsEmpty() const { return PayloadIsEmpty(0); } + + bool RightPayloadIsEmpty() const { return PayloadIsEmpty(1); } + + static int kMissingField() { + return SchemaProjectionMaps::kMissingField; + } + + SchemaProjectionMaps proj_maps[2]; + + private: + static bool IsTypeSupported(const DataType& type); + + Status CollectFilterColumns(std::vector& left_filter, + std::vector& right_filter, + const Expression& filter, const Schema& left_schema, + const Schema& right_schema); + + Expression RewriteFilterToUseFilterSchema(int right_filter_offset, + const SchemaProjectionMap& left_to_filter, + const SchemaProjectionMap& right_to_filter, + const Expression& filter); + + bool PayloadIsEmpty(int side) const { + assert(side == 0 || side == 1); + return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0; + } + + static Result> ComputePayload(const Schema& schema, + const std::vector& output, + const std::vector& filter, + const std::vector& key); +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/map_node.h b/pyarrow/include/arrow/acero/map_node.h new file mode 100644 index 0000000000000000000000000000000000000000..8bdd0ab2ca3854c6561aa3735ae143e7c58b4f77 --- /dev/null +++ b/pyarrow/include/arrow/acero/map_node.h @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/util.h" +#include "arrow/acero/visibility.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/cancel.h" +#include "arrow/util/type_fwd.h" + +namespace arrow { +namespace acero { + +/// A utility base class for simple exec nodes with one input +/// +/// Pause/Resume Producing are forwarded appropriately +/// There is nothing to do in StopProducingImpl +/// +/// An AtomicCounter is used to keep track of when all data has arrived. When it +/// has the Finish() method will be invoked +class ARROW_ACERO_EXPORT MapNode : public ExecNode, public TracedNode { + public: + MapNode(ExecPlan* plan, std::vector inputs, + std::shared_ptr output_schema); + + Status InputFinished(ExecNode* input, int total_batches) override; + + Status StartProducing() override; + + void PauseProducing(ExecNode* output, int32_t counter) override; + + void ResumeProducing(ExecNode* output, int32_t counter) override; + + Status InputReceived(ExecNode* input, ExecBatch batch) override; + + const Ordering& ordering() const override; + + protected: + Status StopProducingImpl() override; + + /// Transform a batch + /// + /// The output batch will have the same guarantee as the input batch + /// If this was the last batch this call may trigger Finish() + virtual Result ProcessBatch(ExecBatch batch) = 0; + + /// Function called after all data has been received + /// + /// By default this does nothing. Override this to provide a custom implementation. + virtual void Finish(); + + protected: + // Counter for the number of batches received + AtomicCounter input_counter_; +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/options.h b/pyarrow/include/arrow/acero/options.h new file mode 100644 index 0000000000000000000000000000000000000000..827e9ea775d7b8e892d05f9b81a79ec25991cc3c --- /dev/null +++ b/pyarrow/include/arrow/acero/options.h @@ -0,0 +1,874 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/acero/type_fwd.h" +#include "arrow/acero/visibility.h" +#include "arrow/compute/api_aggregate.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/expression.h" +#include "arrow/result.h" +#include "arrow/util/future.h" + +namespace arrow { + +using compute::Aggregate; +using compute::ExecBatch; +using compute::Expression; +using compute::literal; +using compute::Ordering; +using compute::SelectKOptions; +using compute::SortOptions; + +namespace internal { + +class Executor; + +} // namespace internal + +namespace acero { + +/// \brief This must not be used in release-mode +struct DebugOptions; + +using AsyncExecBatchGenerator = std::function>()>; + +/// \addtogroup acero-nodes +/// @{ + +/// \brief A base class for all options objects +/// +/// The only time this is used directly is when a node has no configuration +class ARROW_ACERO_EXPORT ExecNodeOptions { + public: + virtual ~ExecNodeOptions() = default; + + /// \brief This must not be used in release-mode + std::shared_ptr debug_opts; +}; + +/// \brief A node representing a generic source of data for Acero +/// +/// The source node will start calling `generator` during StartProducing. An initial +/// task will be created that will call `generator`. It will not call `generator` +/// reentrantly. If the source can be read in parallel then those details should be +/// encapsulated within `generator`. +/// +/// For each batch received a new task will be created to push that batch downstream. +/// This task will slice smaller units of size `ExecPlan::kMaxBatchSize` from the +/// parent batch and call InputReceived. Thus, if the `generator` yields a large +/// batch it may result in several calls to InputReceived. +/// +/// The SourceNode will, by default, assign an implicit ordering to outgoing batches. +/// This is valid as long as the generator generates batches in a deterministic fashion. +/// Currently, the only way to override this is to subclass the SourceNode. +/// +/// This node is not generally used directly but can serve as the basis for various +/// specialized nodes. +class ARROW_ACERO_EXPORT SourceNodeOptions : public ExecNodeOptions { + public: + /// Create an instance from values + SourceNodeOptions(std::shared_ptr output_schema, + std::function>()> generator, + Ordering ordering = Ordering::Unordered()) + : output_schema(std::move(output_schema)), + generator(std::move(generator)), + ordering(std::move(ordering)) {} + + /// \brief the schema for batches that will be generated by this source + std::shared_ptr output_schema; + /// \brief an asynchronous stream of batches ending with std::nullopt + std::function>()> generator; + /// \brief the order of the data, defaults to Ordering::Unordered + Ordering ordering; +}; + +/// \brief a node that generates data from a table already loaded in memory +/// +/// The table source node will slice off chunks, defined by `max_batch_size` +/// for parallel processing. The table source node extends source node and so these +/// chunks will be iteratively processed in small batches. \see SourceNodeOptions +/// for details. +class ARROW_ACERO_EXPORT TableSourceNodeOptions : public ExecNodeOptions { + public: + static constexpr int64_t kDefaultMaxBatchSize = 1 << 20; + + /// Create an instance from values + TableSourceNodeOptions(std::shared_ptr table, + int64_t max_batch_size = kDefaultMaxBatchSize) + : table(std::move(table)), max_batch_size(max_batch_size) {} + + /// \brief a table which acts as the data source + std::shared_ptr
table; + /// \brief size of batches to emit from this node + /// If the table is larger the node will emit multiple batches from the + /// the table to be processed in parallel. + int64_t max_batch_size; +}; + +/// \brief define a lazily resolved Arrow table. +/// +/// The table uniquely identified by the names can typically be resolved at the time when +/// the plan is to be consumed. +/// +/// This node is for serialization purposes only and can never be executed. +class ARROW_ACERO_EXPORT NamedTableNodeOptions : public ExecNodeOptions { + public: + /// Create an instance from values + NamedTableNodeOptions(std::vector names, std::shared_ptr schema) + : names(std::move(names)), schema(std::move(schema)) {} + + /// \brief the names to put in the serialized plan + std::vector names; + /// \brief the output schema of the table + std::shared_ptr schema; +}; + +/// \brief a source node which feeds data from a synchronous iterator of batches +/// +/// ItMaker is a maker of an iterator of tabular data. +/// +/// The node can be configured to use an I/O executor. If set then each time the +/// iterator is polled a new I/O thread task will be created to do the polling. This +/// allows a blocking iterator to stay off the CPU thread pool. +template +class ARROW_ACERO_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions { + public: + /// Create an instance that will create a new task on io_executor for each iteration + SchemaSourceNodeOptions(std::shared_ptr schema, ItMaker it_maker, + arrow::internal::Executor* io_executor) + : schema(std::move(schema)), + it_maker(std::move(it_maker)), + io_executor(io_executor), + requires_io(true) {} + + /// Create an instance that will either iterate synchronously or use the default I/O + /// executor + SchemaSourceNodeOptions(std::shared_ptr schema, ItMaker it_maker, + bool requires_io = false) + : schema(std::move(schema)), + it_maker(std::move(it_maker)), + io_executor(NULLPTR), + requires_io(requires_io) {} + + /// \brief The schema of the record batches from the iterator + std::shared_ptr schema; + + /// \brief A maker of an iterator which acts as the data source + ItMaker it_maker; + + /// \brief The executor to use for scanning the iterator + /// + /// Defaults to the default I/O executor. Only used if requires_io is true. + /// If requires_io is false then this MUST be nullptr. + arrow::internal::Executor* io_executor; + + /// \brief If true then items will be fetched from the iterator on a dedicated I/O + /// thread to keep I/O off the CPU thread + bool requires_io; +}; + +/// a source node that reads from a RecordBatchReader +/// +/// Each iteration of the RecordBatchReader will be run on a new thread task created +/// on the I/O thread pool. +class ARROW_ACERO_EXPORT RecordBatchReaderSourceNodeOptions : public ExecNodeOptions { + public: + /// Create an instance from values + RecordBatchReaderSourceNodeOptions(std::shared_ptr reader, + arrow::internal::Executor* io_executor = NULLPTR) + : reader(std::move(reader)), io_executor(io_executor) {} + + /// \brief The RecordBatchReader which acts as the data source + std::shared_ptr reader; + + /// \brief The executor to use for the reader + /// + /// Defaults to the default I/O executor. + arrow::internal::Executor* io_executor; +}; + +/// a source node that reads from an iterator of array vectors +using ArrayVectorIteratorMaker = std::function>()>; +/// \brief An extended Source node which accepts a schema and array-vectors +class ARROW_ACERO_EXPORT ArrayVectorSourceNodeOptions + : public SchemaSourceNodeOptions { + using SchemaSourceNodeOptions::SchemaSourceNodeOptions; +}; + +/// a source node that reads from an iterator of ExecBatch +using ExecBatchIteratorMaker = std::function>()>; +/// \brief An extended Source node which accepts a schema and exec-batches +class ARROW_ACERO_EXPORT ExecBatchSourceNodeOptions + : public SchemaSourceNodeOptions { + public: + using SchemaSourceNodeOptions::SchemaSourceNodeOptions; + ExecBatchSourceNodeOptions(std::shared_ptr schema, + std::vector batches, + ::arrow::internal::Executor* io_executor); + ExecBatchSourceNodeOptions(std::shared_ptr schema, + std::vector batches, bool requires_io = false); +}; + +using RecordBatchIteratorMaker = std::function>()>; +/// a source node that reads from an iterator of RecordBatch +class ARROW_ACERO_EXPORT RecordBatchSourceNodeOptions + : public SchemaSourceNodeOptions { + using SchemaSourceNodeOptions::SchemaSourceNodeOptions; +}; + +/// \brief a node which excludes some rows from batches passed through it +/// +/// filter_expression will be evaluated against each batch which is pushed to +/// this node. Any rows for which filter_expression does not evaluate to `true` will be +/// excluded in the batch emitted by this node. +/// +/// This node will emit empty batches if all rows are excluded. This is done +/// to avoid gaps in the ordering. +class ARROW_ACERO_EXPORT FilterNodeOptions : public ExecNodeOptions { + public: + /// \brief create an instance from values + explicit FilterNodeOptions(Expression filter_expression) + : filter_expression(std::move(filter_expression)) {} + + /// \brief the expression to filter batches + /// + /// The return type of this expression must be boolean + Expression filter_expression; +}; + +/// \brief a node which selects a specified subset from the input +class ARROW_ACERO_EXPORT FetchNodeOptions : public ExecNodeOptions { + public: + static constexpr std::string_view kName = "fetch"; + /// \brief create an instance from values + FetchNodeOptions(int64_t offset, int64_t count) : offset(offset), count(count) {} + /// \brief the number of rows to skip + int64_t offset; + /// \brief the number of rows to keep (not counting skipped rows) + int64_t count; +}; + +/// \brief a node which executes expressions on input batches, producing batches +/// of the same length with new columns. +/// +/// Each expression will be evaluated against each batch which is pushed to +/// this node to produce a corresponding output column. +/// +/// If names are not provided, the string representations of exprs will be used. +class ARROW_ACERO_EXPORT ProjectNodeOptions : public ExecNodeOptions { + public: + /// \brief create an instance from values + explicit ProjectNodeOptions(std::vector expressions, + std::vector names = {}) + : expressions(std::move(expressions)), names(std::move(names)) {} + + /// \brief the expressions to run on the batches + /// + /// The output will have one column for each expression. If you wish to keep any of + /// the columns from the input then you should create a simple field_ref expression + /// for that column. + std::vector expressions; + /// \brief the names of the output columns + /// + /// If this is not specified then the result of calling ToString on the expression will + /// be used instead + /// + /// This list should either be empty or have the same length as `expressions` + std::vector names; +}; + +/// \brief a node which aggregates input batches and calculates summary statistics +/// +/// The node can summarize the entire input or it can group the input with grouping keys +/// and segment keys. +/// +/// By default, the aggregate node is a pipeline breaker. It must accumulate all input +/// before any output is produced. Segment keys are a performance optimization. If +/// you know your input is already partitioned by one or more columns then you can +/// specify these as segment keys. At each change in the segment keys the node will +/// emit values for all data seen so far. +/// +/// Segment keys are currently limited to single-threaded mode. +/// +/// Both keys and segment-keys determine the group. However segment-keys are also used +/// for determining grouping segments, which should be large, and allow streaming a +/// partial aggregation result after processing each segment. One common use-case for +/// segment-keys is ordered aggregation, in which the segment-key attribute specifies a +/// column with non-decreasing values or a lexicographically-ordered set of such columns. +/// +/// If the keys attribute is a non-empty vector, then each aggregate in `aggregates` is +/// expected to be a HashAggregate function. If the keys attribute is an empty vector, +/// then each aggregate is assumed to be a ScalarAggregate function. +/// +/// If the segment_keys attribute is a non-empty vector, then segmented aggregation, as +/// described above, applies. +/// +/// The keys and segment_keys vectors must be disjoint. +/// +/// If no measures are provided then you will simply get the list of unique keys. +/// +/// This node outputs segment keys first, followed by regular keys, followed by one +/// column for each aggregate. +class ARROW_ACERO_EXPORT AggregateNodeOptions : public ExecNodeOptions { + public: + /// \brief create an instance from values + explicit AggregateNodeOptions(std::vector aggregates, + std::vector keys = {}, + std::vector segment_keys = {}) + : aggregates(std::move(aggregates)), + keys(std::move(keys)), + segment_keys(std::move(segment_keys)) {} + + // aggregations which will be applied to the targeted fields + std::vector aggregates; + // keys by which aggregations will be grouped (optional) + std::vector keys; + // keys by which aggregations will be segmented (optional) + std::vector segment_keys; +}; + +/// \brief a default value at which backpressure will be applied +constexpr int32_t kDefaultBackpressureHighBytes = 1 << 30; // 1GiB +/// \brief a default value at which backpressure will be removed +constexpr int32_t kDefaultBackpressureLowBytes = 1 << 28; // 256MiB + +/// \brief an interface that can be queried for backpressure statistics +class ARROW_ACERO_EXPORT BackpressureMonitor { + public: + virtual ~BackpressureMonitor() = default; + /// \brief fetches the number of bytes currently queued up + virtual uint64_t bytes_in_use() = 0; + /// \brief checks to see if backpressure is currently applied + virtual bool is_paused() = 0; +}; + +/// \brief Options to control backpressure behavior +struct ARROW_ACERO_EXPORT BackpressureOptions { + /// \brief Create default options that perform no backpressure + BackpressureOptions() : resume_if_below(0), pause_if_above(0) {} + /// \brief Create options that will perform backpressure + /// + /// \param resume_if_below The producer should resume producing if the backpressure + /// queue has fewer than resume_if_below items. + /// \param pause_if_above The producer should pause producing if the backpressure + /// queue has more than pause_if_above items + BackpressureOptions(uint64_t resume_if_below, uint64_t pause_if_above) + : resume_if_below(resume_if_below), pause_if_above(pause_if_above) {} + + /// \brief create an instance using default values for backpressure limits + static BackpressureOptions DefaultBackpressure() { + return BackpressureOptions(kDefaultBackpressureLowBytes, + kDefaultBackpressureHighBytes); + } + + /// \brief helper method to determine if backpressure is disabled + /// \return true if pause_if_above is greater than zero, false otherwise + bool should_apply_backpressure() const { return pause_if_above > 0; } + + /// \brief the number of bytes at which the producer should resume producing + uint64_t resume_if_below; + /// \brief the number of bytes at which the producer should pause producing + /// + /// If this is <= 0 then backpressure will be disabled + uint64_t pause_if_above; +}; + +/// \brief a sink node which collects results in a queue +/// +/// Emitted batches will only be ordered if there is a meaningful ordering +/// and sequence_output is not set to false. +class ARROW_ACERO_EXPORT SinkNodeOptions : public ExecNodeOptions { + public: + explicit SinkNodeOptions(std::function>()>* generator, + std::shared_ptr* schema, + BackpressureOptions backpressure = {}, + BackpressureMonitor** backpressure_monitor = NULLPTR, + std::optional sequence_output = std::nullopt) + : generator(generator), + schema(schema), + backpressure(backpressure), + backpressure_monitor(backpressure_monitor), + sequence_output(sequence_output) {} + + explicit SinkNodeOptions(std::function>()>* generator, + BackpressureOptions backpressure = {}, + BackpressureMonitor** backpressure_monitor = NULLPTR, + std::optional sequence_output = std::nullopt) + : generator(generator), + schema(NULLPTR), + backpressure(std::move(backpressure)), + backpressure_monitor(backpressure_monitor), + sequence_output(sequence_output) {} + + /// \brief A pointer to a generator of batches. + /// + /// This will be set when the node is added to the plan and should be used to consume + /// data from the plan. If this function is not called frequently enough then the sink + /// node will start to accumulate data and may apply backpressure. + std::function>()>* generator; + /// \brief A pointer which will be set to the schema of the generated batches + /// + /// This is optional, if nullptr is passed in then it will be ignored. + /// This will be set when the node is added to the plan, before StartProducing is called + std::shared_ptr* schema; + /// \brief Options to control when to apply backpressure + /// + /// This is optional, the default is to never apply backpressure. If the plan is not + /// consumed quickly enough the system may eventually run out of memory. + BackpressureOptions backpressure; + /// \brief A pointer to a backpressure monitor + /// + /// This will be set when the node is added to the plan. This can be used to inspect + /// the amount of data currently queued in the sink node. This is an optional utility + /// and backpressure can be applied even if this is not used. + BackpressureMonitor** backpressure_monitor; + /// \brief Controls whether batches should be emitted immediately or sequenced in order + /// + /// \see QueryOptions for more details + std::optional sequence_output; +}; + +/// \brief Control used by a SinkNodeConsumer to pause & resume +/// +/// Callers should ensure that they do not call Pause and Resume simultaneously and they +/// should sequence things so that a call to Pause() is always followed by an eventual +/// call to Resume() +class ARROW_ACERO_EXPORT BackpressureControl { + public: + virtual ~BackpressureControl() = default; + /// \brief Ask the input to pause + /// + /// This is best effort, batches may continue to arrive + /// Must eventually be followed by a call to Resume() or deadlock will occur + virtual void Pause() = 0; + /// \brief Ask the input to resume + virtual void Resume() = 0; +}; + +/// \brief a sink node that consumes the data as part of the plan using callbacks +class ARROW_ACERO_EXPORT SinkNodeConsumer { + public: + virtual ~SinkNodeConsumer() = default; + /// \brief Prepare any consumer state + /// + /// This will be run once the schema is finalized as the plan is starting and + /// before any calls to Consume. A common use is to save off the schema so that + /// batches can be interpreted. + virtual Status Init(const std::shared_ptr& schema, + BackpressureControl* backpressure_control, ExecPlan* plan) = 0; + /// \brief Consume a batch of data + virtual Status Consume(ExecBatch batch) = 0; + /// \brief Signal to the consumer that the last batch has been delivered + /// + /// The returned future should only finish when all outstanding tasks have completed + /// + /// If the plan is ended early or aborts due to an error then this will not be + /// called. + virtual Future<> Finish() = 0; +}; + +/// \brief Add a sink node which consumes data within the exec plan run +class ARROW_ACERO_EXPORT ConsumingSinkNodeOptions : public ExecNodeOptions { + public: + explicit ConsumingSinkNodeOptions(std::shared_ptr consumer, + std::vector names = {}, + std::optional sequence_output = std::nullopt) + : consumer(std::move(consumer)), + names(std::move(names)), + sequence_output(sequence_output) {} + + std::shared_ptr consumer; + /// \brief Names to rename the sink's schema fields to + /// + /// If specified then names must be provided for all fields. Currently, only a flat + /// schema is supported (see GH-31875). + /// + /// If not specified then names will be generated based on the source data. + std::vector names; + /// \brief Controls whether batches should be emitted immediately or sequenced in order + /// + /// \see QueryOptions for more details + std::optional sequence_output; +}; + +/// \brief Make a node which sorts rows passed through it +/// +/// All batches pushed to this node will be accumulated, then sorted, by the given +/// fields. Then sorted batches will be forwarded to the generator in sorted order. +class ARROW_ACERO_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions { + public: + /// \brief create an instance from values + explicit OrderBySinkNodeOptions( + SortOptions sort_options, + std::function>()>* generator) + : SinkNodeOptions(generator), sort_options(std::move(sort_options)) {} + + /// \brief options describing which columns and direction to sort + SortOptions sort_options; +}; + +/// \brief Apply a new ordering to data +/// +/// Currently this node works by accumulating all data, sorting, and then emitting +/// the new data with an updated batch index. +/// +/// Larger-than-memory sort is not currently supported. +class ARROW_ACERO_EXPORT OrderByNodeOptions : public ExecNodeOptions { + public: + static constexpr std::string_view kName = "order_by"; + explicit OrderByNodeOptions(Ordering ordering) : ordering(std::move(ordering)) {} + + /// \brief The new ordering to apply to outgoing data + Ordering ordering; +}; + +enum class JoinType { + LEFT_SEMI, + RIGHT_SEMI, + LEFT_ANTI, + RIGHT_ANTI, + INNER, + LEFT_OUTER, + RIGHT_OUTER, + FULL_OUTER +}; + +std::string ToString(JoinType t); + +enum class JoinKeyCmp { EQ, IS }; + +/// \brief a node which implements a join operation using a hash table +class ARROW_ACERO_EXPORT HashJoinNodeOptions : public ExecNodeOptions { + public: + static constexpr const char* default_output_suffix_for_left = ""; + static constexpr const char* default_output_suffix_for_right = ""; + /// \brief create an instance from values that outputs all columns + HashJoinNodeOptions( + JoinType in_join_type, std::vector in_left_keys, + std::vector in_right_keys, Expression filter = literal(true), + std::string output_suffix_for_left = default_output_suffix_for_left, + std::string output_suffix_for_right = default_output_suffix_for_right, + bool disable_bloom_filter = false) + : join_type(in_join_type), + left_keys(std::move(in_left_keys)), + right_keys(std::move(in_right_keys)), + output_all(true), + output_suffix_for_left(std::move(output_suffix_for_left)), + output_suffix_for_right(std::move(output_suffix_for_right)), + filter(std::move(filter)), + disable_bloom_filter(disable_bloom_filter) { + this->key_cmp.resize(this->left_keys.size()); + for (size_t i = 0; i < this->left_keys.size(); ++i) { + this->key_cmp[i] = JoinKeyCmp::EQ; + } + } + /// \brief create an instance from keys + /// + /// This will create an inner join that outputs all columns and has no post join filter + /// + /// `in_left_keys` should have the same length and types as `in_right_keys` + /// @param in_left_keys the keys in the left input + /// @param in_right_keys the keys in the right input + HashJoinNodeOptions(std::vector in_left_keys, + std::vector in_right_keys) + : left_keys(std::move(in_left_keys)), right_keys(std::move(in_right_keys)) { + this->join_type = JoinType::INNER; + this->output_all = true; + this->output_suffix_for_left = default_output_suffix_for_left; + this->output_suffix_for_right = default_output_suffix_for_right; + this->key_cmp.resize(this->left_keys.size()); + for (size_t i = 0; i < this->left_keys.size(); ++i) { + this->key_cmp[i] = JoinKeyCmp::EQ; + } + this->filter = literal(true); + } + /// \brief create an instance from values using JoinKeyCmp::EQ for all comparisons + HashJoinNodeOptions( + JoinType join_type, std::vector left_keys, + std::vector right_keys, std::vector left_output, + std::vector right_output, Expression filter = literal(true), + std::string output_suffix_for_left = default_output_suffix_for_left, + std::string output_suffix_for_right = default_output_suffix_for_right, + bool disable_bloom_filter = false) + : join_type(join_type), + left_keys(std::move(left_keys)), + right_keys(std::move(right_keys)), + output_all(false), + left_output(std::move(left_output)), + right_output(std::move(right_output)), + output_suffix_for_left(std::move(output_suffix_for_left)), + output_suffix_for_right(std::move(output_suffix_for_right)), + filter(std::move(filter)), + disable_bloom_filter(disable_bloom_filter) { + this->key_cmp.resize(this->left_keys.size()); + for (size_t i = 0; i < this->left_keys.size(); ++i) { + this->key_cmp[i] = JoinKeyCmp::EQ; + } + } + /// \brief create an instance from values + HashJoinNodeOptions( + JoinType join_type, std::vector left_keys, + std::vector right_keys, std::vector left_output, + std::vector right_output, std::vector key_cmp, + Expression filter = literal(true), + std::string output_suffix_for_left = default_output_suffix_for_left, + std::string output_suffix_for_right = default_output_suffix_for_right, + bool disable_bloom_filter = false) + : join_type(join_type), + left_keys(std::move(left_keys)), + right_keys(std::move(right_keys)), + output_all(false), + left_output(std::move(left_output)), + right_output(std::move(right_output)), + key_cmp(std::move(key_cmp)), + output_suffix_for_left(std::move(output_suffix_for_left)), + output_suffix_for_right(std::move(output_suffix_for_right)), + filter(std::move(filter)), + disable_bloom_filter(disable_bloom_filter) {} + + HashJoinNodeOptions() = default; + + // type of join (inner, left, semi...) + JoinType join_type = JoinType::INNER; + // key fields from left input + std::vector left_keys; + // key fields from right input + std::vector right_keys; + // if set all valid fields from both left and right input will be output + // (and field ref vectors for output fields will be ignored) + bool output_all = false; + // output fields passed from left input + std::vector left_output; + // output fields passed from right input + std::vector right_output; + // key comparison function (determines whether a null key is equal another null + // key or not) + std::vector key_cmp; + // suffix added to names of output fields coming from left input (used to distinguish, + // if necessary, between fields of the same name in left and right input and can be left + // empty if there are no name collisions) + std::string output_suffix_for_left; + // suffix added to names of output fields coming from right input + std::string output_suffix_for_right; + // residual filter which is applied to matching rows. Rows that do not match + // the filter are not included. The filter is applied against the + // concatenated input schema (left fields then right fields) and can reference + // fields that are not included in the output. + Expression filter = literal(true); + // whether or not to disable Bloom filters in this join + bool disable_bloom_filter = false; +}; + +/// \brief a node which implements the asof join operation +/// +/// Note, this API is experimental and will change in the future +/// +/// This node takes one left table and any number of right tables, and asof joins them +/// together. Batches produced by each input must be ordered by the "on" key. +/// This node will output one row for each row in the left table. +class ARROW_ACERO_EXPORT AsofJoinNodeOptions : public ExecNodeOptions { + public: + /// \brief Keys for one input table of the AsofJoin operation + /// + /// The keys must be consistent across the input tables: + /// Each "on" key must refer to a field of the same type and units across the tables. + /// Each "by" key must refer to a list of fields of the same types across the tables. + struct Keys { + /// \brief "on" key for the join. + /// + /// The input table must be sorted by the "on" key. Must be a single field of a common + /// type. An inexact match is used on the "on" key, i.e. a row is considered a + /// match if and only if `right.on - left.on` is in the range + /// `[min(0, tolerance), max(0, tolerance)]`. + /// Currently, the "on" key must be of an integer, date, or timestamp type. + FieldRef on_key; + /// \brief "by" key for the join. + /// + /// Each input table must have each field of the "by" key. Exact equality is used for + /// each field of the "by" key. + /// Currently, each field of the "by" key must be of an integer, date, timestamp, or + /// base-binary type. + std::vector by_key; + }; + + AsofJoinNodeOptions(std::vector input_keys, int64_t tolerance) + : input_keys(std::move(input_keys)), tolerance(tolerance) {} + + /// \brief AsofJoin keys per input table. At least two keys must be given. The first key + /// corresponds to a left table and all other keys correspond to right tables for the + /// as-of-join. + /// + /// \see `Keys` for details. + std::vector input_keys; + /// \brief Tolerance for inexact "on" key matching. A right row is considered a match + /// with a left row if `right.on - left.on` is in the range + /// `[min(0, tolerance), max(0, tolerance)]`. `tolerance` may be: + /// - negative, in which case a past-as-of-join occurs (match iff + /// `tolerance <= right.on - left.on <= 0`); + /// - or positive, in which case a future-as-of-join occurs (match iff + /// `0 <= right.on - left.on <= tolerance`); + /// - or zero, in which case an exact-as-of-join occurs (match iff + /// `right.on == left.on`). + /// + /// The tolerance is interpreted in the same units as the "on" key. + int64_t tolerance; +}; + +/// \brief a node which select top_k/bottom_k rows passed through it +/// +/// All batches pushed to this node will be accumulated, then selected, by the given +/// fields. Then sorted batches will be forwarded to the generator in sorted order. +class ARROW_ACERO_EXPORT SelectKSinkNodeOptions : public SinkNodeOptions { + public: + explicit SelectKSinkNodeOptions( + SelectKOptions select_k_options, + std::function>()>* generator) + : SinkNodeOptions(generator), select_k_options(std::move(select_k_options)) {} + + /// SelectK options + SelectKOptions select_k_options; +}; + +/// \brief a sink node which accumulates all output into a table +class ARROW_ACERO_EXPORT TableSinkNodeOptions : public ExecNodeOptions { + public: + /// \brief create an instance from values + explicit TableSinkNodeOptions(std::shared_ptr
* output_table, + std::optional sequence_output = std::nullopt) + : output_table(output_table), sequence_output(sequence_output) {} + + /// \brief an "out parameter" specifying the table that will be created + /// + /// Must not be null and remain valid for the entirety of the plan execution. After the + /// plan has completed this will be set to point to the result table + std::shared_ptr
* output_table; + /// \brief Controls whether batches should be emitted immediately or sequenced in order + /// + /// \see QueryOptions for more details + std::optional sequence_output; + /// \brief Custom names to use for the columns. + /// + /// If specified then names must be provided for all fields. Currently, only a flat + /// schema is supported (see GH-31875). + /// + /// If not specified then names will be generated based on the source data. + std::vector names; +}; + +/// \brief a row template that describes one row that will be generated for each input row +struct ARROW_ACERO_EXPORT PivotLongerRowTemplate { + PivotLongerRowTemplate(std::vector feature_values, + std::vector> measurement_values) + : feature_values(std::move(feature_values)), + measurement_values(std::move(measurement_values)) {} + /// A (typically unique) set of feature values for the template, usually derived from a + /// column name + /// + /// These will be used to populate the feature columns + std::vector feature_values; + /// The fields containing the measurements to use for this row + /// + /// These will be used to populate the measurement columns. If nullopt then nulls + /// will be inserted for the given value. + std::vector> measurement_values; +}; + +/// \brief Reshape a table by turning some columns into additional rows +/// +/// This operation is sometimes also referred to as UNPIVOT +/// +/// This is typically done when there are multiple observations in each row in order to +/// transform to a table containing a single observation per row. +/// +/// For example: +/// +/// | time | left_temp | right_temp | +/// | ---- | --------- | ---------- | +/// | 1 | 10 | 20 | +/// | 2 | 15 | 18 | +/// +/// The above table contains two observations per row. There is an implicit feature +/// "location" (left vs right) and a measurement "temp". What we really want is: +/// +/// | time | location | temp | +/// | --- | --- | --- | +/// | 1 | left | 10 | +/// | 1 | right | 20 | +/// | 2 | left | 15 | +/// | 2 | right | 18 | +/// +/// For a more complex example consider: +/// +/// | time | ax1 | ay1 | bx1 | ay2 | +/// | ---- | --- | --- | --- | --- | +/// | 0 | 1 | 2 | 3 | 4 | +/// +/// We can pretend a vs b and x vs y are features while 1 and 2 are two different +/// kinds of measurements. We thus want to pivot to +/// +/// | time | a/b | x/y | f1 | f2 | +/// | ---- | --- | --- | ---- | ---- | +/// | 0 | a | x | 1 | null | +/// | 0 | a | y | 2 | 4 | +/// | 0 | b | x | 3 | null | +/// +/// To do this we create a row template for each combination of features. One should +/// be able to do this purely by looking at the column names. For example, given the +/// above columns "ax1", "ay1", "bx1", and "ay2" we know we have three feature +/// combinations (a, x), (a, y), and (b, x). Similarly, we know we have two possible +/// measurements, "1" and "2". +/// +/// For each combination of features we create a row template. In each row template we +/// describe the combination and then list which columns to use for the measurements. +/// If a measurement doesn't exist for a given combination then we use nullopt. +/// +/// So, for our above example, we have: +/// +/// (a, x): names={"a", "x"}, values={"ax1", nullopt} +/// (a, y): names={"a", "y"}, values={"ay1", "ay2"} +/// (b, x): names={"b", "x"}, values={"bx1", nullopt} +/// +/// Finishing it off we name our new columns: +/// feature_field_names={"a/b","x/y"} +/// measurement_field_names={"f1", "f2"} +class ARROW_ACERO_EXPORT PivotLongerNodeOptions : public ExecNodeOptions { + public: + static constexpr std::string_view kName = "pivot_longer"; + /// One or more row templates to create new output rows + /// + /// Normally there are at least two row templates. The output # of rows + /// will be the input # of rows * the number of row templates + std::vector row_templates; + /// The names of the columns which describe the new features + std::vector feature_field_names; + /// The names of the columns which represent the measurements + std::vector measurement_field_names; +}; + +/// @} + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/order_by_impl.h b/pyarrow/include/arrow/acero/order_by_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..9b5a0f69a69ffc8f23fb5416e82777d2d06f0a00 --- /dev/null +++ b/pyarrow/include/arrow/acero/order_by_impl.h @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/acero/options.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" + +namespace arrow { + +using compute::ExecContext; + +namespace acero { + +class OrderByImpl { + public: + virtual ~OrderByImpl() = default; + + virtual void InputReceived(const std::shared_ptr& batch) = 0; + + virtual Result DoFinish() = 0; + + virtual std::string ToString() const = 0; + + static Result> MakeSort( + ExecContext* ctx, const std::shared_ptr& output_schema, + const SortOptions& options); + + static Result> MakeSelectK( + ExecContext* ctx, const std::shared_ptr& output_schema, + const SelectKOptions& options); +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/partition_util.h b/pyarrow/include/arrow/acero/partition_util.h new file mode 100644 index 0000000000000000000000000000000000000000..52cc47bb8a99f5fcc32defa09698c715025f322b --- /dev/null +++ b/pyarrow/include/arrow/acero/partition_util.h @@ -0,0 +1,186 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/acero/util.h" +#include "arrow/buffer.h" +#include "arrow/util/pcg_random.h" + +namespace arrow { +namespace acero { + +class PartitionSort { + public: + /// \brief Bucket sort rows on partition ids in O(num_rows) time. + /// + /// Include in the output exclusive cumulative sum of bucket sizes. + /// This corresponds to ranges in the sorted array containing all row ids for + /// each of the partitions. + /// + /// prtn_ranges must be initialized and have at least num_prtns + 1 elements + /// when this method returns prtn_ranges[i] will contains the total number of + /// elements in partitions 0 through i. prtn_ranges[0] will be 0. + /// + /// prtn_id_impl must be a function that takes in a row id (int) and returns + /// a partition id (int). The returned partition id must be between 0 and + /// num_prtns (exclusive). + /// + /// output_pos_impl is a function that takes in a row id (int) and a position (int) + /// in the bucket sorted output. The function should insert the row in the + /// output. + /// + /// For example: + /// + /// in_arr: [5, 7, 2, 3, 5, 4] + /// num_prtns: 3 + /// prtn_id_impl: [&in_arr] (int row_id) { return in_arr[row_id] / 3; } + /// output_pos_impl: [&sorted_row_ids] (int row_id, int pos) { sorted_row_ids[pos] = + /// row_id; } + /// + /// After Execution + /// sorted_row_ids: [2, 0, 3, 4, 5, 1] + /// prtn_ranges: [0, 1, 5, 6] + template + static void Eval(int64_t num_rows, int num_prtns, uint16_t* prtn_ranges, + INPUT_PRTN_ID_FN prtn_id_impl, OUTPUT_POS_FN output_pos_impl) { + ARROW_DCHECK(num_rows > 0 && num_rows <= (1 << 15)); + ARROW_DCHECK(num_prtns >= 1 && num_prtns <= (1 << 15)); + + memset(prtn_ranges, 0, (num_prtns + 1) * sizeof(uint16_t)); + + for (int64_t i = 0; i < num_rows; ++i) { + int prtn_id = static_cast(prtn_id_impl(i)); + ++prtn_ranges[prtn_id + 1]; + } + + uint16_t sum = 0; + for (int i = 0; i < num_prtns; ++i) { + uint16_t sum_next = sum + prtn_ranges[i + 1]; + prtn_ranges[i + 1] = sum; + sum = sum_next; + } + + for (int64_t i = 0; i < num_rows; ++i) { + int prtn_id = static_cast(prtn_id_impl(i)); + int pos = prtn_ranges[prtn_id + 1]++; + output_pos_impl(i, pos); + } + } +}; + +/// \brief A control for synchronizing threads on a partitionable workload +class PartitionLocks { + public: + PartitionLocks(); + ~PartitionLocks(); + /// \brief Initializes the control, must be called before use + /// + /// \param num_threads Maximum number of threads that will access the partitions + /// \param num_prtns Number of partitions to synchronize + void Init(size_t num_threads, int num_prtns); + /// \brief Cleans up the control, it should not be used after this call + void CleanUp(); + /// \brief Acquire a partition to work on one + /// + /// \param thread_id The index of the thread trying to acquire the partition lock + /// \param num_prtns Length of prtns_to_try, must be <= num_prtns used in Init + /// \param prtns_to_try An array of partitions that still have remaining work + /// \param limit_retries If false, this method will spinwait forever until success + /// \param max_retries Max times to attempt checking out work before returning false + /// \param[out] locked_prtn_id The id of the partition locked + /// \param[out] locked_prtn_id_pos The index of the partition locked in prtns_to_try + /// \return True if a partition was locked, false if max_retries was attempted + /// without successfully acquiring a lock + /// + /// This method is thread safe + bool AcquirePartitionLock(size_t thread_id, int num_prtns, const int* prtns_to_try, + bool limit_retries, int max_retries, int* locked_prtn_id, + int* locked_prtn_id_pos); + /// \brief Release a partition so that other threads can work on it + void ReleasePartitionLock(int prtn_id); + + // Executes (synchronously and using current thread) the same operation on a set of + // multiple partitions. Tries to minimize partition locking overhead by randomizing and + // adjusting order in which partitions are processed. + // + // PROCESS_PRTN_FN is a callback which will be executed for each partition after + // acquiring the lock for that partition. It gets partition id as an argument. + // IS_PRTN_EMPTY_FN is a callback which filters out (when returning true) partitions + // with specific ids from processing. + // + template + Status ForEachPartition(size_t thread_id, + /*scratch space buffer with space for one element per partition; + dirty in and dirty out*/ + int* temp_unprocessed_prtns, IS_PRTN_EMPTY_FN is_prtn_empty_fn, + PROCESS_PRTN_FN process_prtn_fn) { + int num_unprocessed_partitions = 0; + for (int i = 0; i < num_prtns_; ++i) { + bool is_prtn_empty = is_prtn_empty_fn(i); + if (!is_prtn_empty) { + temp_unprocessed_prtns[num_unprocessed_partitions++] = i; + } + } + while (num_unprocessed_partitions > 0) { + int locked_prtn_id; + int locked_prtn_id_pos; + AcquirePartitionLock(thread_id, num_unprocessed_partitions, temp_unprocessed_prtns, + /*limit_retries=*/false, /*max_retries=*/-1, &locked_prtn_id, + &locked_prtn_id_pos); + { + class AutoReleaseLock { + public: + AutoReleaseLock(PartitionLocks* locks, int prtn_id) + : locks(locks), prtn_id(prtn_id) {} + ~AutoReleaseLock() { locks->ReleasePartitionLock(prtn_id); } + PartitionLocks* locks; + int prtn_id; + } auto_release_lock(this, locked_prtn_id); + ARROW_RETURN_NOT_OK(process_prtn_fn(locked_prtn_id)); + } + if (locked_prtn_id_pos < num_unprocessed_partitions - 1) { + temp_unprocessed_prtns[locked_prtn_id_pos] = + temp_unprocessed_prtns[num_unprocessed_partitions - 1]; + } + --num_unprocessed_partitions; + } + return Status::OK(); + } + + private: + std::atomic* lock_ptr(int prtn_id); + int random_int(size_t thread_id, int num_values); + + struct PartitionLock { + static constexpr int kCacheLineBytes = 64; + std::atomic lock; + uint8_t padding[kCacheLineBytes]; + }; + int num_prtns_; + std::unique_ptr locks_; + std::unique_ptr rngs_; +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/query_context.h b/pyarrow/include/arrow/acero/query_context.h new file mode 100644 index 0000000000000000000000000000000000000000..3eff299439828e602558e5ebc278660bb7ce37eb --- /dev/null +++ b/pyarrow/include/arrow/acero/query_context.h @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +#pragma once + +#include + +#include "arrow/acero/exec_plan.h" +#include "arrow/acero/task_util.h" +#include "arrow/acero/util.h" +#include "arrow/compute/exec.h" +#include "arrow/io/interfaces.h" +#include "arrow/util/async_util.h" +#include "arrow/util/type_fwd.h" + +namespace arrow { + +using compute::default_exec_context; +using io::IOContext; + +namespace acero { + +class ARROW_ACERO_EXPORT QueryContext { + public: + QueryContext(QueryOptions opts = {}, + ExecContext exec_context = *default_exec_context()); + + Status Init(arrow::util::AsyncTaskScheduler* scheduler); + + const ::arrow::internal::CpuInfo* cpu_info() const; + int64_t hardware_flags() const; + const QueryOptions& options() const { return options_; } + MemoryPool* memory_pool() const { return exec_context_.memory_pool(); } + ::arrow::internal::Executor* executor() const { return exec_context_.executor(); } + ExecContext* exec_context() { return &exec_context_; } + IOContext* io_context() { return &io_context_; } + TaskScheduler* scheduler() { return task_scheduler_.get(); } + arrow::util::AsyncTaskScheduler* async_scheduler() { return async_scheduler_; } + + size_t GetThreadIndex(); + size_t max_concurrency() const; + + /// \brief Start an external task + /// + /// This should be avoided if possible. It is kept in for now for legacy + /// purposes. This should be called before the external task is started. If + /// a valid future is returned then it should be marked complete when the + /// external task has finished. + /// + /// \param name A name to give the task for traceability and debugging + /// + /// \return an invalid future if the plan has already ended, otherwise this + /// returns a future that must be completed when the external task + /// finishes. + Result> BeginExternalTask(std::string_view name); + + /// \brief Add a single function as a task to the query's task group + /// on the compute threadpool. + /// + /// \param fn The task to run. Takes no arguments and returns a Status. + /// \param name A name to give the task for traceability and debugging + void ScheduleTask(std::function fn, std::string_view name); + /// \brief Add a single function as a task to the query's task group + /// on the compute threadpool. + /// + /// \param fn The task to run. Takes the thread index and returns a Status. + /// \param name A name to give the task for traceability and debugging + void ScheduleTask(std::function fn, std::string_view name); + /// \brief Add a single function as a task to the query's task group on + /// the IO thread pool + /// + /// \param fn The task to run. Returns a status. + /// \param name A name to give the task for traceability and debugging + void ScheduleIOTask(std::function fn, std::string_view name); + + // Register/Start TaskGroup is a way of performing a "Parallel For" pattern: + // - The task function takes the thread index and the index of the task + // - The on_finished function takes the thread index + // Returns an integer ID that will be used to reference the task group in + // StartTaskGroup. At runtime, call StartTaskGroup with the ID and the number of times + // you'd like the task to be executed. The need to register a task group before use will + // be removed after we rewrite the scheduler. + /// \brief Register a "parallel for" task group with the scheduler + /// + /// \param task The function implementing the task. Takes the thread_index and + /// the task index. + /// \param on_finished The function that gets run once all tasks have been completed. + /// Takes the thread_index. + /// + /// Must be called inside of ExecNode::Init. + int RegisterTaskGroup(std::function task, + std::function on_finished); + + /// \brief Start the task group with the specified ID. This can only + /// be called once per task_group_id. + /// + /// \param task_group_id The ID of the task group to run + /// \param num_tasks The number of times to run the task + Status StartTaskGroup(int task_group_id, int64_t num_tasks); + + // This is an RAII class for keeping track of in-flight file IO. Useful for getting + // an estimate of memory use, and how much memory we expect to be freed soon. + // Returned by ReportTempFileIO. + struct [[nodiscard]] TempFileIOMark { + QueryContext* ctx_; + size_t bytes_; + + TempFileIOMark(QueryContext* ctx, size_t bytes) : ctx_(ctx), bytes_(bytes) { + ctx_->in_flight_bytes_to_disk_.fetch_add(bytes_, std::memory_order_acquire); + } + + ARROW_DISALLOW_COPY_AND_ASSIGN(TempFileIOMark); + + ~TempFileIOMark() { + ctx_->in_flight_bytes_to_disk_.fetch_sub(bytes_, std::memory_order_release); + } + }; + + TempFileIOMark ReportTempFileIO(size_t bytes) { return {this, bytes}; } + + size_t GetCurrentTempFileIO() { return in_flight_bytes_to_disk_.load(); } + + private: + QueryOptions options_; + // To be replaced with Acero-specific context once scheduler is done and + // we don't need ExecContext for kernels + ExecContext exec_context_; + IOContext io_context_; + + arrow::util::AsyncTaskScheduler* async_scheduler_ = NULLPTR; + std::unique_ptr task_scheduler_ = TaskScheduler::Make(); + + ThreadIndexer thread_indexer_; + + std::atomic in_flight_bytes_to_disk_{0}; +}; +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/schema_util.h b/pyarrow/include/arrow/acero/schema_util.h new file mode 100644 index 0000000000000000000000000000000000000000..db3076a58841a6cb85fcc3d5033ef3b74ed18898 --- /dev/null +++ b/pyarrow/include/arrow/acero/schema_util.h @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/type.h" // for DataType, FieldRef, Field and Schema + +namespace arrow { + +using internal::checked_cast; + +namespace acero { + +// Identifiers for all different row schemas that are used in a join +// +enum class HashJoinProjection : int { + INPUT = 0, + KEY = 1, + PAYLOAD = 2, + FILTER = 3, + OUTPUT = 4 +}; + +struct SchemaProjectionMap { + static constexpr int kMissingField = -1; + int num_cols; + const int* source_to_base; + const int* base_to_target; + inline int get(int i) const { + assert(i >= 0 && i < num_cols); + assert(source_to_base[i] != kMissingField); + return base_to_target[source_to_base[i]]; + } +}; + +/// Helper class for managing different projections of the same row schema. +/// Used to efficiently map any field in one projection to a corresponding field in +/// another projection. +/// Materialized mappings are generated lazily at the time of the first access. +/// Thread-safe apart from initialization. +template +class SchemaProjectionMaps { + public: + static constexpr int kMissingField = -1; + + Status Init(ProjectionIdEnum full_schema_handle, const Schema& schema, + const std::vector& projection_handles, + const std::vector*>& projections) { + assert(projection_handles.size() == projections.size()); + ARROW_RETURN_NOT_OK(RegisterSchema(full_schema_handle, schema)); + for (size_t i = 0; i < projections.size(); ++i) { + ARROW_RETURN_NOT_OK( + RegisterProjectedSchema(projection_handles[i], *(projections[i]), schema)); + } + RegisterEnd(); + return Status::OK(); + } + + int num_cols(ProjectionIdEnum schema_handle) const { + int id = schema_id(schema_handle); + return static_cast(schemas_[id].second.data_types.size()); + } + + bool is_empty(ProjectionIdEnum schema_handle) const { + return num_cols(schema_handle) == 0; + } + + const std::string& field_name(ProjectionIdEnum schema_handle, int field_id) const { + int id = schema_id(schema_handle); + return schemas_[id].second.field_names[field_id]; + } + + const std::shared_ptr& data_type(ProjectionIdEnum schema_handle, + int field_id) const { + int id = schema_id(schema_handle); + return schemas_[id].second.data_types[field_id]; + } + + const std::vector>& data_types( + ProjectionIdEnum schema_handle) const { + int id = schema_id(schema_handle); + return schemas_[id].second.data_types; + } + + SchemaProjectionMap map(ProjectionIdEnum from, ProjectionIdEnum to) const { + int id_from = schema_id(from); + int id_to = schema_id(to); + SchemaProjectionMap result; + result.num_cols = num_cols(from); + result.source_to_base = mappings_[id_from].data(); + result.base_to_target = inverse_mappings_[id_to].data(); + return result; + } + + protected: + struct FieldInfos { + std::vector field_paths; + std::vector field_names; + std::vector> data_types; + }; + + Status RegisterSchema(ProjectionIdEnum handle, const Schema& schema) { + FieldInfos out_fields; + const FieldVector& in_fields = schema.fields(); + out_fields.field_paths.resize(in_fields.size()); + out_fields.field_names.resize(in_fields.size()); + out_fields.data_types.resize(in_fields.size()); + for (size_t i = 0; i < in_fields.size(); ++i) { + const std::string& name = in_fields[i]->name(); + const std::shared_ptr& type = in_fields[i]->type(); + out_fields.field_paths[i] = static_cast(i); + out_fields.field_names[i] = name; + out_fields.data_types[i] = type; + } + schemas_.push_back(std::make_pair(handle, out_fields)); + return Status::OK(); + } + + Status RegisterProjectedSchema(ProjectionIdEnum handle, + const std::vector& selected_fields, + const Schema& full_schema) { + FieldInfos out_fields; + const FieldVector& in_fields = full_schema.fields(); + out_fields.field_paths.resize(selected_fields.size()); + out_fields.field_names.resize(selected_fields.size()); + out_fields.data_types.resize(selected_fields.size()); + for (size_t i = 0; i < selected_fields.size(); ++i) { + // All fields must be found in schema without ambiguity + ARROW_ASSIGN_OR_RAISE(auto match, selected_fields[i].FindOne(full_schema)); + const std::string& name = in_fields[match[0]]->name(); + const std::shared_ptr& type = in_fields[match[0]]->type(); + out_fields.field_paths[i] = match[0]; + out_fields.field_names[i] = name; + out_fields.data_types[i] = type; + } + schemas_.push_back(std::make_pair(handle, out_fields)); + return Status::OK(); + } + + void RegisterEnd() { + size_t size = schemas_.size(); + mappings_.resize(size); + inverse_mappings_.resize(size); + int id_base = 0; + for (size_t i = 0; i < size; ++i) { + GenerateMapForProjection(static_cast(i), id_base); + } + } + + int schema_id(ProjectionIdEnum schema_handle) const { + for (size_t i = 0; i < schemas_.size(); ++i) { + if (schemas_[i].first == schema_handle) { + return static_cast(i); + } + } + // We should never get here + assert(false); + return -1; + } + + void GenerateMapForProjection(int id_proj, int id_base) { + int num_cols_proj = static_cast(schemas_[id_proj].second.data_types.size()); + int num_cols_base = static_cast(schemas_[id_base].second.data_types.size()); + + std::vector& mapping = mappings_[id_proj]; + std::vector& inverse_mapping = inverse_mappings_[id_proj]; + mapping.resize(num_cols_proj); + inverse_mapping.resize(num_cols_base); + + if (id_proj == id_base) { + for (int i = 0; i < num_cols_base; ++i) { + mapping[i] = inverse_mapping[i] = i; + } + } else { + const FieldInfos& fields_proj = schemas_[id_proj].second; + const FieldInfos& fields_base = schemas_[id_base].second; + for (int i = 0; i < num_cols_base; ++i) { + inverse_mapping[i] = SchemaProjectionMap::kMissingField; + } + for (int i = 0; i < num_cols_proj; ++i) { + int field_id = SchemaProjectionMap::kMissingField; + for (int j = 0; j < num_cols_base; ++j) { + if (fields_proj.field_paths[i] == fields_base.field_paths[j]) { + field_id = j; + // If there are multiple matches for the same input field, + // it will be mapped to the first match. + break; + } + } + assert(field_id != SchemaProjectionMap::kMissingField); + mapping[i] = field_id; + inverse_mapping[field_id] = i; + } + } + } + + // vector used as a mapping from ProjectionIdEnum to fields + std::vector> schemas_; + std::vector> mappings_; + std::vector> inverse_mappings_; +}; + +using HashJoinProjectionMaps = SchemaProjectionMaps; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/task_util.h b/pyarrow/include/arrow/acero/task_util.h new file mode 100644 index 0000000000000000000000000000000000000000..fbd4af699d12795bd92bd385f23a036d63adde38 --- /dev/null +++ b/pyarrow/include/arrow/acero/task_util.h @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/acero/visibility.h" +#include "arrow/status.h" +#include "arrow/util/config.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace acero { + +// Atomic value surrounded by padding bytes to avoid cache line invalidation +// whenever it is modified by a concurrent thread on a different CPU core. +// +template +class AtomicWithPadding { + private: + static constexpr int kCacheLineSize = 64; + uint8_t padding_before[kCacheLineSize]; + + public: + std::atomic value; + + private: + uint8_t padding_after[kCacheLineSize]; +}; + +// Used for asynchronous execution of operations that can be broken into +// a fixed number of symmetric tasks that can be executed concurrently. +// +// Implements priorities between multiple such operations, called task groups. +// +// Allows to specify the maximum number of in-flight tasks at any moment. +// +// Also allows for executing next pending tasks immediately using a caller thread. +// +class ARROW_ACERO_EXPORT TaskScheduler { + public: + using TaskImpl = std::function; + using TaskGroupContinuationImpl = std::function; + using ScheduleImpl = std::function; + using AbortContinuationImpl = std::function; + + virtual ~TaskScheduler() = default; + + // Order in which task groups are registered represents priorities of their tasks + // (the first group has the highest priority). + // + // Returns task group identifier that is used to request operations on the task group. + virtual int RegisterTaskGroup(TaskImpl task_impl, + TaskGroupContinuationImpl cont_impl) = 0; + + virtual void RegisterEnd() = 0; + + // total_num_tasks may be zero, in which case task group continuation will be executed + // immediately + virtual Status StartTaskGroup(size_t thread_id, int group_id, + int64_t total_num_tasks) = 0; + + // Execute given number of tasks immediately using caller thread + virtual Status ExecuteMore(size_t thread_id, int num_tasks_to_execute, + bool execute_all) = 0; + + // Begin scheduling tasks using provided callback and + // the limit on the number of in-flight tasks at any moment. + // + // Scheduling will continue as long as there are waiting tasks. + // + // It will automatically resume whenever new task group gets started. + virtual Status StartScheduling(size_t thread_id, ScheduleImpl schedule_impl, + int num_concurrent_tasks, bool use_sync_execution) = 0; + + // Abort scheduling and execution. + // Used in case of being notified about unrecoverable error for the entire query. + virtual void Abort(AbortContinuationImpl impl) = 0; + + static std::unique_ptr Make(); +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/test_nodes.h b/pyarrow/include/arrow/acero/test_nodes.h new file mode 100644 index 0000000000000000000000000000000000000000..7e31aa31b34d7b423ab85ff2e77c1cec0087fa5b --- /dev/null +++ b/pyarrow/include/arrow/acero/test_nodes.h @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/acero/options.h" +#include "arrow/acero/test_util_internal.h" +#include "arrow/testing/random.h" + +namespace arrow { +namespace acero { + +// \brief Make a delaying source that is optionally noisy (prints when it emits) +AsyncGenerator> MakeDelayedGen( + Iterator> src, std::string label, double delay_sec, + bool noisy = false); + +// \brief Make a delaying source that is optionally noisy (prints when it emits) +AsyncGenerator> MakeDelayedGen( + AsyncGenerator> src, std::string label, double delay_sec, + bool noisy = false); + +// \brief Make a delaying source that is optionally noisy (prints when it emits) +AsyncGenerator> MakeDelayedGen(BatchesWithSchema src, + std::string label, + double delay_sec, + bool noisy = false); + +/// A node that slightly resequences the input at random +struct JitterNodeOptions : public ExecNodeOptions { + random::SeedType seed; + /// The max amount to add to a node's "cost". + int max_jitter_modifier; + + explicit JitterNodeOptions(random::SeedType seed, int max_jitter_modifier = 5) + : seed(seed), max_jitter_modifier(max_jitter_modifier) {} + static constexpr std::string_view kName = "jitter"; +}; + +class GateImpl; + +class Gate { + public: + static std::shared_ptr Make(); + + Gate(); + virtual ~Gate(); + + void ReleaseAllBatches(); + void ReleaseOneBatch(); + Future<> WaitForNextReleasedBatch(); + + private: + ARROW_DISALLOW_COPY_AND_ASSIGN(Gate); + + GateImpl* impl_; +}; + +// A node that holds all input batches until a given gate is released +struct GatedNodeOptions : public ExecNodeOptions { + explicit GatedNodeOptions(Gate* gate) : gate(gate) {} + Gate* gate; + + static constexpr std::string_view kName = "gated"; +}; + +void RegisterTestNodes(); + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/time_series_util.h b/pyarrow/include/arrow/acero/time_series_util.h new file mode 100644 index 0000000000000000000000000000000000000000..97707f43bf20b95387f463a9c07e37f54c33998c --- /dev/null +++ b/pyarrow/include/arrow/acero/time_series_util.h @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/record_batch.h" +#include "arrow/type_traits.h" + +namespace arrow::acero { + +// normalize the value to unsigned 64-bits while preserving ordering of values +template ::value, bool> = true> +uint64_t NormalizeTime(T t); + +uint64_t GetTime(const RecordBatch* batch, Type::type time_type, int col, uint64_t row); + +} // namespace arrow::acero diff --git a/pyarrow/include/arrow/acero/tpch_node.h b/pyarrow/include/arrow/acero/tpch_node.h new file mode 100644 index 0000000000000000000000000000000000000000..e6476b57ad6b4108af56777c029d932f4af94726 --- /dev/null +++ b/pyarrow/include/arrow/acero/tpch_node.h @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/acero/type_fwd.h" +#include "arrow/acero/visibility.h" +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { +namespace acero { +namespace internal { + +class ARROW_ACERO_EXPORT TpchGen { + public: + virtual ~TpchGen() = default; + + /* + * \brief Create a factory for nodes that generate TPC-H data + * + * Note: Individual tables will reference each other. It is important that you only + * create a single TpchGen instance for each plan and then you can create nodes for each + * table from that single TpchGen instance. Note: Every batch will be scheduled as a new + * task using the ExecPlan's scheduler. + */ + static Result> Make( + ExecPlan* plan, double scale_factor = 1.0, int64_t batch_size = 4096, + std::optional seed = std::nullopt); + + // The below methods will create and add an ExecNode to the plan that generates + // data for the desired table. If columns is empty, all columns will be generated. + // The methods return the added ExecNode, which should be used for inputs. + virtual Result Supplier(std::vector columns = {}) = 0; + virtual Result Part(std::vector columns = {}) = 0; + virtual Result PartSupp(std::vector columns = {}) = 0; + virtual Result Customer(std::vector columns = {}) = 0; + virtual Result Orders(std::vector columns = {}) = 0; + virtual Result Lineitem(std::vector columns = {}) = 0; + virtual Result Nation(std::vector columns = {}) = 0; + virtual Result Region(std::vector columns = {}) = 0; +}; + +} // namespace internal +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/type_fwd.h b/pyarrow/include/arrow/acero/type_fwd.h new file mode 100644 index 0000000000000000000000000000000000000000..f0410de9f7830a7d0e55a04eb514ae9d82e6958c --- /dev/null +++ b/pyarrow/include/arrow/acero/type_fwd.h @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/compute/type_fwd.h" + +namespace arrow { + +namespace acero { + +class ExecNode; +class ExecPlan; +class ExecNodeOptions; +class ExecFactoryRegistry; +class QueryContext; +struct QueryOptions; +struct Declaration; +class SinkNodeConsumer; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/util.h b/pyarrow/include/arrow/acero/util.h new file mode 100644 index 0000000000000000000000000000000000000000..ee46e8527422abae4f97804058639593dd6b159c --- /dev/null +++ b/pyarrow/include/arrow/acero/util.h @@ -0,0 +1,184 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/acero/options.h" +#include "arrow/acero/type_fwd.h" +#include "arrow/buffer.h" +#include "arrow/compute/expression.h" +#include "arrow/compute/util.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/cpu_info.h" +#include "arrow/util/logging.h" +#include "arrow/util/mutex.h" +#include "arrow/util/thread_pool.h" +#include "arrow/util/type_fwd.h" + +namespace arrow { + +namespace acero { + +ARROW_ACERO_EXPORT +Status ValidateExecNodeInputs(ExecPlan* plan, const std::vector& inputs, + int expected_num_inputs, const char* kind_name); + +ARROW_ACERO_EXPORT +Result> TableFromExecBatches( + const std::shared_ptr& schema, const std::vector& exec_batches); + +class ARROW_ACERO_EXPORT AtomicCounter { + public: + AtomicCounter() = default; + + int count() const { return count_.load(); } + + std::optional total() const { + int total = total_.load(); + if (total == -1) return {}; + return total; + } + + // return true if the counter is complete + bool Increment() { + ARROW_DCHECK_NE(count_.load(), total_.load()); + int count = count_.fetch_add(1) + 1; + if (count != total_.load()) return false; + return DoneOnce(); + } + + // return true if the counter is complete + bool SetTotal(int total) { + total_.store(total); + if (count_.load() != total) return false; + return DoneOnce(); + } + + // return true if the counter has not already been completed + bool Cancel() { return DoneOnce(); } + + // return true if the counter has finished or been cancelled + bool Completed() { return complete_.load(); } + + private: + // ensure there is only one true return from Increment(), SetTotal(), or Cancel() + bool DoneOnce() { + bool expected = false; + return complete_.compare_exchange_strong(expected, true); + } + + std::atomic count_{0}, total_{-1}; + std::atomic complete_{false}; +}; + +class ARROW_ACERO_EXPORT ThreadIndexer { + public: + size_t operator()(); + + static size_t Capacity(); + + private: + static size_t Check(size_t thread_index); + + arrow::util::Mutex mutex_; + std::unordered_map id_to_index_; +}; + +/// \brief A consumer that collects results into an in-memory table +struct ARROW_ACERO_EXPORT TableSinkNodeConsumer : public SinkNodeConsumer { + public: + TableSinkNodeConsumer(std::shared_ptr
* out, MemoryPool* pool) + : out_(out), pool_(pool) {} + Status Init(const std::shared_ptr& schema, + BackpressureControl* backpressure_control, ExecPlan* plan) override; + Status Consume(ExecBatch batch) override; + Future<> Finish() override; + + private: + std::shared_ptr
* out_; + MemoryPool* pool_; + std::shared_ptr schema_; + std::vector> batches_; + arrow::util::Mutex consume_mutex_; +}; + +class ARROW_ACERO_EXPORT NullSinkNodeConsumer : public SinkNodeConsumer { + public: + Status Init(const std::shared_ptr&, BackpressureControl*, + ExecPlan* plan) override { + return Status::OK(); + } + Status Consume(ExecBatch exec_batch) override { return Status::OK(); } + Future<> Finish() override { return Status::OK(); } + + public: + static std::shared_ptr Make() { + return std::make_shared(); + } +}; + +/// CRTP helper for tracing helper functions + +class ARROW_ACERO_EXPORT TracedNode { + public: + // All nodes should call TraceStartProducing or NoteStartProducing exactly once + // Most nodes will be fine with a call to NoteStartProducing since the StartProducing + // call is usually fairly cheap and simply schedules tasks to fetch the actual data. + + explicit TracedNode(ExecNode* node) : node_(node) {} + + // Create a span to record the StartProducing work + [[nodiscard]] ::arrow::internal::tracing::Scope TraceStartProducing( + std::string extra_details) const; + + // Record a call to StartProducing without creating with a span + void NoteStartProducing(std::string extra_details) const; + + // All nodes should call TraceInputReceived for each batch they receive. This call + // should track the time spent processing the batch. NoteInputReceived is available + // but usually won't be used unless a node is simply adding batches to a trivial queue. + + // Create a span to record the InputReceived work + [[nodiscard]] ::arrow::internal::tracing::Scope TraceInputReceived( + const ExecBatch& batch) const; + + // Record a call to InputReceived without creating with a span + void NoteInputReceived(const ExecBatch& batch) const; + + // Create a span to record any "finish" work. This should NOT be called as part of + // InputFinished and many nodes may not need to call this at all. This should be used + // when a node has some extra work that has to be done once it has received all of its + // data. For example, an aggregation node calculating aggregations. This will + // typically be called as a result of InputFinished OR InputReceived. + [[nodiscard]] ::arrow::internal::tracing::Scope TraceFinish() const; + + private: + ExecNode* node_; +}; + +} // namespace acero +} // namespace arrow diff --git a/pyarrow/include/arrow/acero/visibility.h b/pyarrow/include/arrow/acero/visibility.h new file mode 100644 index 0000000000000000000000000000000000000000..21a697a56eca962602b34b2766d74442d185c3d7 --- /dev/null +++ b/pyarrow/include/arrow/acero/visibility.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#if defined(_WIN32) || defined(__CYGWIN__) +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4251) +# else +# pragma GCC diagnostic ignored "-Wattributes" +# endif + +# ifdef ARROW_ACERO_STATIC +# define ARROW_ACERO_EXPORT +# elif defined(ARROW_ACERO_EXPORTING) +# define ARROW_ACERO_EXPORT __declspec(dllexport) +# else +# define ARROW_ACERO_EXPORT __declspec(dllimport) +# endif + +# define ARROW_ACERO_NO_EXPORT +#else // Not Windows +# ifndef ARROW_ACERO_EXPORT +# define ARROW_ACERO_EXPORT __attribute__((visibility("default"))) +# endif +# ifndef ARROW_ACERO_NO_EXPORT +# define ARROW_ACERO_NO_EXPORT __attribute__((visibility("hidden"))) +# endif +#endif // Not-Windows + +#if defined(_MSC_VER) +# pragma warning(pop) +#endif diff --git a/pyarrow/include/arrow/adapters/orc/adapter.h b/pyarrow/include/arrow/adapters/orc/adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..4ffff81f355f1ddcdc19516746c61b8021477de4 --- /dev/null +++ b/pyarrow/include/arrow/adapters/orc/adapter.h @@ -0,0 +1,323 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/adapters/orc/options.h" +#include "arrow/io/interfaces.h" +#include "arrow/memory_pool.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace adapters { +namespace orc { + +/// \brief Information about an ORC stripe +struct StripeInformation { + /// \brief Offset of the stripe from the start of the file, in bytes + int64_t offset; + /// \brief Length of the stripe, in bytes + int64_t length; + /// \brief Number of rows in the stripe + int64_t num_rows; + /// \brief Index of the first row of the stripe + int64_t first_row_id; +}; + +/// \class ORCFileReader +/// \brief Read an Arrow Table or RecordBatch from an ORC file. +class ARROW_EXPORT ORCFileReader { + public: + ~ORCFileReader(); + + /// \brief Creates a new ORC reader + /// + /// \param[in] file the data source + /// \param[in] pool a MemoryPool to use for buffer allocations + /// \return the returned reader object + static Result> Open( + const std::shared_ptr& file, MemoryPool* pool); + + /// \brief Return the schema read from the ORC file + /// + /// \return the returned Schema object + Result> ReadSchema(); + + /// \brief Read the file as a Table + /// + /// The table will be composed of one record batch per stripe. + /// + /// \return the returned Table + Result> Read(); + + /// \brief Read the file as a Table + /// + /// The table will be composed of one record batch per stripe. + /// + /// \param[in] schema the Table schema + /// \return the returned Table + Result> Read(const std::shared_ptr& schema); + + /// \brief Read the file as a Table + /// + /// The table will be composed of one record batch per stripe. + /// + /// \param[in] include_indices the selected field indices to read + /// \return the returned Table + Result> Read(const std::vector& include_indices); + + /// \brief Read the file as a Table + /// + /// The table will be composed of one record batch per stripe. + /// + /// \param[in] include_names the selected field names to read + /// \return the returned Table + Result> Read(const std::vector& include_names); + + /// \brief Read the file as a Table + /// + /// The table will be composed of one record batch per stripe. + /// + /// \param[in] schema the Table schema + /// \param[in] include_indices the selected field indices to read + /// \return the returned Table + Result> Read(const std::shared_ptr& schema, + const std::vector& include_indices); + + /// \brief Read a single stripe as a RecordBatch + /// + /// \param[in] stripe the stripe index + /// \return the returned RecordBatch + Result> ReadStripe(int64_t stripe); + + /// \brief Read a single stripe as a RecordBatch + /// + /// \param[in] stripe the stripe index + /// \param[in] include_indices the selected field indices to read + /// \return the returned RecordBatch + Result> ReadStripe( + int64_t stripe, const std::vector& include_indices); + + /// \brief Read a single stripe as a RecordBatch + /// + /// \param[in] stripe the stripe index + /// \param[in] include_names the selected field names to read + /// \return the returned RecordBatch + Result> ReadStripe( + int64_t stripe, const std::vector& include_names); + + /// \brief Seek to designated row. Invoke NextStripeReader() after seek + /// will return stripe reader starting from designated row. + /// + /// \param[in] row_number the rows number to seek + Status Seek(int64_t row_number); + + /// \brief Get a stripe level record batch iterator. + /// + /// Each record batch will have up to `batch_size` rows. + /// NextStripeReader serves as a fine-grained alternative to ReadStripe + /// which may cause OOM issues by loading the whole stripe into memory. + /// + /// Note this will only read rows for the current stripe, not the entire + /// file. + /// + /// \param[in] batch_size the maximum number of rows in each record batch + /// \return the returned stripe reader + Result> NextStripeReader(int64_t batch_size); + + /// \brief Get a stripe level record batch iterator. + /// + /// Each record batch will have up to `batch_size` rows. + /// NextStripeReader serves as a fine-grained alternative to ReadStripe + /// which may cause OOM issues by loading the whole stripe into memory. + /// + /// Note this will only read rows for the current stripe, not the entire + /// file. + /// + /// \param[in] batch_size the maximum number of rows in each record batch + /// \param[in] include_indices the selected field indices to read + /// \return the stripe reader + Result> NextStripeReader( + int64_t batch_size, const std::vector& include_indices); + + /// \brief Get a record batch iterator for the entire file. + /// + /// Each record batch will have up to `batch_size` rows. + /// + /// \param[in] batch_size the maximum number of rows in each record batch + /// \param[in] include_names the selected field names to read, if not empty + /// (otherwise all fields are read) + /// \return the record batch iterator + Result> GetRecordBatchReader( + int64_t batch_size, const std::vector& include_names); + + /// \brief The number of stripes in the file + int64_t NumberOfStripes(); + + /// \brief The number of rows in the file + int64_t NumberOfRows(); + + /// \brief StripeInformation for each stripe. + StripeInformation GetStripeInformation(int64_t stripe); + + /// \brief Get the format version of the file. + /// Currently known values are 0.11 and 0.12. + /// + /// \return The FileVersion of the ORC file. + FileVersion GetFileVersion(); + + /// \brief Get the software instance and version that wrote this file. + /// + /// \return a user-facing string that specifies the software version + std::string GetSoftwareVersion(); + + /// \brief Get the compression kind of the file. + /// + /// \return The kind of compression in the ORC file. + Result GetCompression(); + + /// \brief Get the buffer size for the compression. + /// + /// \return Number of bytes to buffer for the compression codec. + int64_t GetCompressionSize(); + + /// \brief Get the number of rows per an entry in the row index. + /// \return the number of rows per an entry in the row index or 0 if there + /// is no row index. + int64_t GetRowIndexStride(); + + /// \brief Get ID of writer that generated the file. + /// + /// \return UNKNOWN_WRITER if the writer ID is undefined + WriterId GetWriterId(); + + /// \brief Get the writer id value when getWriterId() returns an unknown writer. + /// + /// \return the integer value of the writer ID. + int32_t GetWriterIdValue(); + + /// \brief Get the version of the writer. + /// + /// \return the version of the writer. + + WriterVersion GetWriterVersion(); + + /// \brief Get the number of stripe statistics in the file. + /// + /// \return the number of stripe statistics + int64_t GetNumberOfStripeStatistics(); + + /// \brief Get the length of the data stripes in the file. + /// + /// \return return the number of bytes in stripes + int64_t GetContentLength(); + + /// \brief Get the length of the file stripe statistics. + /// + /// \return the number of compressed bytes in the file stripe statistics + int64_t GetStripeStatisticsLength(); + + /// \brief Get the length of the file footer. + /// + /// \return the number of compressed bytes in the file footer + int64_t GetFileFooterLength(); + + /// \brief Get the length of the file postscript. + /// + /// \return the number of bytes in the file postscript + int64_t GetFilePostscriptLength(); + + /// \brief Get the total length of the file. + /// + /// \return the number of bytes in the file + int64_t GetFileLength(); + + /// \brief Get the serialized file tail. + /// Useful if another reader of the same file wants to avoid re-reading + /// the file tail. See ReadOptions.SetSerializedFileTail(). + /// + /// \return a string of bytes with the file tail + std::string GetSerializedFileTail(); + + /// \brief Return the metadata read from the ORC file + /// + /// \return A KeyValueMetadata object containing the ORC metadata + Result> ReadMetadata(); + + private: + class Impl; + std::unique_ptr impl_; + ORCFileReader(); +}; + +/// \class ORCFileWriter +/// \brief Write an Arrow Table or RecordBatch to an ORC file. +class ARROW_EXPORT ORCFileWriter { + public: + ~ORCFileWriter(); + /// \brief Creates a new ORC writer. + /// + /// \param[in] output_stream a pointer to the io::OutputStream to write into + /// \param[in] write_options the ORC writer options for Arrow + /// \return the returned writer object + static Result> Open( + io::OutputStream* output_stream, + const WriteOptions& write_options = WriteOptions()); + + /// \brief Write a table. This can be called multiple times. + /// + /// Tables passed in subsequent calls must match the schema of the table that was + /// written first. + /// + /// \param[in] table the Arrow table from which data is extracted. + /// \return Status + Status Write(const Table& table); + + /// \brief Write a RecordBatch. This can be called multiple times. + /// + /// RecordBatches passed in subsequent calls must match the schema of the + /// RecordBatch that was written first. + /// + /// \param[in] record_batch the Arrow RecordBatch from which data is extracted. + /// \return Status + Status Write(const RecordBatch& record_batch); + + /// \brief Close an ORC writer (orc::Writer) + /// + /// \return Status + Status Close(); + + private: + class Impl; + std::unique_ptr impl_; + + private: + ORCFileWriter(); +}; + +} // namespace orc +} // namespace adapters +} // namespace arrow diff --git a/pyarrow/include/arrow/adapters/orc/options.h b/pyarrow/include/arrow/adapters/orc/options.h new file mode 100644 index 0000000000000000000000000000000000000000..3a300da678db98c24949203be7ab471a57502640 --- /dev/null +++ b/pyarrow/include/arrow/adapters/orc/options.h @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/io/interfaces.h" +#include "arrow/status.h" +#include "arrow/util/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +namespace adapters { + +namespace orc { + +enum class WriterId : int32_t { + kOrcJava = 0, + kOrcCpp = 1, + kPresto = 2, + kScritchleyGo = 3, + kTrino = 4, + kUnknown = INT32_MAX +}; + +enum class WriterVersion : int32_t { + kOriginal = 0, + kHive8732 = 1, + kHive4243 = 2, + kHive12055 = 3, + kHive13083 = 4, + kOrc101 = 5, + kOrc135 = 6, + kOrc517 = 7, + kOrc203 = 8, + kOrc14 = 9, + kMax = INT32_MAX +}; + +enum class CompressionStrategy : int32_t { kSpeed = 0, kCompression }; + +class ARROW_EXPORT FileVersion { + private: + int32_t major_version_; + int32_t minor_version_; + + public: + static const FileVersion& v_0_11(); + static const FileVersion& v_0_12(); + + FileVersion(int32_t major, int32_t minor) + : major_version_(major), minor_version_(minor) {} + + /** + * Get major version + */ + int32_t major_version() const { return this->major_version_; } + + /** + * Get minor version + */ + int32_t minor_version() const { return this->minor_version_; } + + bool operator==(const FileVersion& right) const { + return this->major_version() == right.major_version() && + this->minor_version() == right.minor_version(); + } + + bool operator!=(const FileVersion& right) const { return !(*this == right); } + + std::string ToString() const; +}; + +/// Options for the ORC Writer +struct ARROW_EXPORT WriteOptions { + /// Number of rows the ORC writer writes at a time, default 1024 + int64_t batch_size = 1024; + /// Which ORC file version to use, default FileVersion(0, 12) + FileVersion file_version = FileVersion(0, 12); + /// Size of each ORC stripe in bytes, default 64 MiB + int64_t stripe_size = 64 * 1024 * 1024; + /// The compression codec of the ORC file, there is no compression by default + Compression::type compression = Compression::UNCOMPRESSED; + /// The size of each compression block in bytes, default 64 KiB + int64_t compression_block_size = 64 * 1024; + /// The compression strategy i.e. speed vs size reduction, default + /// CompressionStrategy::kSpeed + CompressionStrategy compression_strategy = CompressionStrategy::kSpeed; + /// The number of rows per an entry in the row index, default 10000 + int64_t row_index_stride = 10000; + /// The padding tolerance, default 0.0 + double padding_tolerance = 0.0; + /// The dictionary key size threshold. 0 to disable dictionary encoding. + /// 1 to always enable dictionary encoding, default 0.0 + double dictionary_key_size_threshold = 0.0; + /// The array of columns that use the bloom filter, default empty + std::vector bloom_filter_columns; + /// The upper limit of the false-positive rate of the bloom filter, default 0.05 + double bloom_filter_fpp = 0.05; +}; + +} // namespace orc +} // namespace adapters +} // namespace arrow diff --git a/pyarrow/include/arrow/adapters/tensorflow/convert.h b/pyarrow/include/arrow/adapters/tensorflow/convert.h new file mode 100644 index 0000000000000000000000000000000000000000..9d093eddf6b598150ddb55da0e84699a5b7ef4b8 --- /dev/null +++ b/pyarrow/include/arrow/adapters/tensorflow/convert.h @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "tensorflow/core/framework/op.h" + +#include "arrow/type.h" + +// These utilities are supposed to be included in TensorFlow operators +// that need to be compiled separately from Arrow because of ABI issues. +// They therefore need to be header-only. + +namespace arrow { + +namespace adapters { + +namespace tensorflow { + +Status GetArrowType(::tensorflow::DataType dtype, std::shared_ptr* out) { + switch (dtype) { + case ::tensorflow::DT_BOOL: + *out = arrow::boolean(); + break; + case ::tensorflow::DT_FLOAT: + *out = arrow::float32(); + break; + case ::tensorflow::DT_DOUBLE: + *out = arrow::float64(); + break; + case ::tensorflow::DT_HALF: + *out = arrow::float16(); + break; + case ::tensorflow::DT_INT8: + *out = arrow::int8(); + break; + case ::tensorflow::DT_INT16: + *out = arrow::int16(); + break; + case ::tensorflow::DT_INT32: + *out = arrow::int32(); + break; + case ::tensorflow::DT_INT64: + *out = arrow::int64(); + break; + case ::tensorflow::DT_UINT8: + *out = arrow::uint8(); + break; + case ::tensorflow::DT_UINT16: + *out = arrow::uint16(); + break; + case ::tensorflow::DT_UINT32: + *out = arrow::uint32(); + break; + case ::tensorflow::DT_UINT64: + *out = arrow::uint64(); + break; + default: + return Status::TypeError("TensorFlow data type is not supported"); + } + return Status::OK(); +} + +Status GetTensorFlowType(std::shared_ptr dtype, ::tensorflow::DataType* out) { + switch (dtype->id()) { + case Type::BOOL: + *out = ::tensorflow::DT_BOOL; + break; + case Type::UINT8: + *out = ::tensorflow::DT_UINT8; + break; + case Type::INT8: + *out = ::tensorflow::DT_INT8; + break; + case Type::UINT16: + *out = ::tensorflow::DT_UINT16; + break; + case Type::INT16: + *out = ::tensorflow::DT_INT16; + break; + case Type::UINT32: + *out = ::tensorflow::DT_UINT32; + break; + case Type::INT32: + *out = ::tensorflow::DT_INT32; + break; + case Type::UINT64: + *out = ::tensorflow::DT_UINT64; + break; + case Type::INT64: + *out = ::tensorflow::DT_INT64; + break; + case Type::HALF_FLOAT: + *out = ::tensorflow::DT_HALF; + break; + case Type::FLOAT: + *out = ::tensorflow::DT_FLOAT; + break; + case Type::DOUBLE: + *out = ::tensorflow::DT_DOUBLE; + break; + default: + return Status::TypeError("Arrow data type is not supported"); + } + return arrow::Status::OK(); +} + +} // namespace tensorflow + +} // namespace adapters + +} // namespace arrow diff --git a/pyarrow/include/arrow/api.h b/pyarrow/include/arrow/api.h new file mode 100644 index 0000000000000000000000000000000000000000..ac568a00eedc32984758f4675b58ac626c9c947a --- /dev/null +++ b/pyarrow/include/arrow/api.h @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Coarse public API while the library is in development + +#pragma once + +#include "arrow/array.h" // IWYU pragma: export +#include "arrow/array/array_run_end.h" // IWYU pragma: export +#include "arrow/array/concatenate.h" // IWYU pragma: export +#include "arrow/buffer.h" // IWYU pragma: export +#include "arrow/builder.h" // IWYU pragma: export +#include "arrow/chunked_array.h" // IWYU pragma: export +#include "arrow/compare.h" // IWYU pragma: export +#include "arrow/config.h" // IWYU pragma: export +#include "arrow/datum.h" // IWYU pragma: export +#include "arrow/extension_type.h" // IWYU pragma: export +#include "arrow/memory_pool.h" // IWYU pragma: export +#include "arrow/pretty_print.h" // IWYU pragma: export +#include "arrow/record_batch.h" // IWYU pragma: export +#include "arrow/result.h" // IWYU pragma: export +#include "arrow/status.h" // IWYU pragma: export +#include "arrow/table.h" // IWYU pragma: export +#include "arrow/table_builder.h" // IWYU pragma: export +#include "arrow/tensor.h" // IWYU pragma: export +#include "arrow/type.h" // IWYU pragma: export +#include "arrow/util/key_value_metadata.h" // IWYU pragma: export +#include "arrow/visit_array_inline.h" // IWYU pragma: export +#include "arrow/visit_scalar_inline.h" // IWYU pragma: export +#include "arrow/visitor.h" // IWYU pragma: export + +/// \brief Top-level namespace for Apache Arrow C++ API +namespace arrow {} diff --git a/pyarrow/include/arrow/array.h b/pyarrow/include/arrow/array.h new file mode 100644 index 0000000000000000000000000000000000000000..4d72ea9506a414fd6e50d5c7d0af437084045e05 --- /dev/null +++ b/pyarrow/include/arrow/array.h @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Kitchen-sink public API for arrow::Array data structures. C++ library code +// (especially header files) in Apache Arrow should use more specific headers +// unless it's a file that uses most or all Array types in which case using +// arrow/array.h is fine. + +#pragma once + +/// \defgroup numeric-arrays Concrete classes for numeric arrays +/// @{ +/// @} + +/// \defgroup binary-arrays Concrete classes for binary/string arrays +/// @{ +/// @} + +/// \defgroup nested-arrays Concrete classes for nested arrays +/// @{ +/// @} + +/// \defgroup run-end-encoded-arrays Concrete classes for run-end encoded arrays +/// @{ +/// @} + +#include "arrow/array/array_base.h" // IWYU pragma: keep +#include "arrow/array/array_binary.h" // IWYU pragma: keep +#include "arrow/array/array_decimal.h" // IWYU pragma: keep +#include "arrow/array/array_dict.h" // IWYU pragma: keep +#include "arrow/array/array_nested.h" // IWYU pragma: keep +#include "arrow/array/array_primitive.h" // IWYU pragma: keep +#include "arrow/array/array_run_end.h" // IWYU pragma: keep +#include "arrow/array/data.h" // IWYU pragma: keep +#include "arrow/array/util.h" // IWYU pragma: keep diff --git a/pyarrow/include/arrow/array/array_base.h b/pyarrow/include/arrow/array/array_base.h new file mode 100644 index 0000000000000000000000000000000000000000..60df45357e5d2fd8bc31cfa714aa4d9d89288508 --- /dev/null +++ b/pyarrow/include/arrow/array/array_base.h @@ -0,0 +1,323 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/array/data.h" +#include "arrow/buffer.h" +#include "arrow/compare.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" +#include "arrow/visitor.h" + +namespace arrow { + +// ---------------------------------------------------------------------- +// User array accessor types + +/// \brief Array base type +/// Immutable data array with some logical type and some length. +/// +/// Any memory is owned by the respective Buffer instance (or its parents). +/// +/// The base class is only required to have a null bitmap buffer if the null +/// count is greater than 0 +/// +/// If known, the null count can be provided in the base Array constructor. If +/// the null count is not known, pass -1 to indicate that the null count is to +/// be computed on the first call to null_count() +class ARROW_EXPORT Array { + public: + virtual ~Array() = default; + + /// \brief Return true if value at index is null. Does not boundscheck + bool IsNull(int64_t i) const { return !IsValid(i); } + + /// \brief Return true if value at index is valid (not null). Does not + /// boundscheck + bool IsValid(int64_t i) const { + if (null_bitmap_data_ != NULLPTR) { + return bit_util::GetBit(null_bitmap_data_, i + data_->offset); + } + // Dispatching with a few conditionals like this makes IsNull more + // efficient for how it is used in practice. Making IsNull virtual + // would add a vtable lookup to every call and prevent inlining + + // a potential inner-branch removal. + if (type_id() == Type::SPARSE_UNION) { + return !internal::IsNullSparseUnion(*data_, i); + } + if (type_id() == Type::DENSE_UNION) { + return !internal::IsNullDenseUnion(*data_, i); + } + if (type_id() == Type::RUN_END_ENCODED) { + return !internal::IsNullRunEndEncoded(*data_, i); + } + return data_->null_count != data_->length; + } + + /// \brief Return a Scalar containing the value of this array at i + Result> GetScalar(int64_t i) const; + + /// Size in the number of elements this array contains. + int64_t length() const { return data_->length; } + + /// A relative position into another array's data, to enable zero-copy + /// slicing. This value defaults to zero + int64_t offset() const { return data_->offset; } + + /// The number of null entries in the array. If the null count was not known + /// at time of construction (and set to a negative value), then the null + /// count will be computed and cached on the first invocation of this + /// function + int64_t null_count() const; + + /// \brief Computes the logical null count for arrays of all types including + /// those that do not have a validity bitmap like union and run-end encoded + /// arrays + /// + /// If the array has a validity bitmap, this function behaves the same as + /// null_count(). For types that have no validity bitmap, this function will + /// recompute the null count every time it is called. + /// + /// \see GetNullCount + int64_t ComputeLogicalNullCount() const; + + const std::shared_ptr& type() const { return data_->type; } + Type::type type_id() const { return data_->type->id(); } + + /// Buffer for the validity (null) bitmap, if any. Note that Union types + /// never have a null bitmap. + /// + /// Note that for `null_count == 0` or for null type, this will be null. + /// This buffer does not account for any slice offset + const std::shared_ptr& null_bitmap() const { return data_->buffers[0]; } + + /// Raw pointer to the null bitmap. + /// + /// Note that for `null_count == 0` or for null type, this will be null. + /// This buffer does not account for any slice offset + const uint8_t* null_bitmap_data() const { return null_bitmap_data_; } + + /// Equality comparison with another array + /// + /// Note that arrow::ArrayStatistics is not included in the comparison. + bool Equals(const Array& arr, const EqualOptions& = EqualOptions::Defaults()) const; + bool Equals(const std::shared_ptr& arr, + const EqualOptions& = EqualOptions::Defaults()) const; + + /// \brief Return the formatted unified diff of arrow::Diff between this + /// Array and another Array + std::string Diff(const Array& other) const; + + /// Approximate equality comparison with another array + /// + /// epsilon is only used if this is FloatArray or DoubleArray + /// + /// Note that arrow::ArrayStatistics is not included in the comparison. + bool ApproxEquals(const std::shared_ptr& arr, + const EqualOptions& = EqualOptions::Defaults()) const; + bool ApproxEquals(const Array& arr, + const EqualOptions& = EqualOptions::Defaults()) const; + + /// Compare if the range of slots specified are equal for the given array and + /// this array. end_idx exclusive. This methods does not bounds check. + /// + /// Note that arrow::ArrayStatistics is not included in the comparison. + bool RangeEquals(int64_t start_idx, int64_t end_idx, int64_t other_start_idx, + const Array& other, + const EqualOptions& = EqualOptions::Defaults()) const; + bool RangeEquals(int64_t start_idx, int64_t end_idx, int64_t other_start_idx, + const std::shared_ptr& other, + const EqualOptions& = EqualOptions::Defaults()) const; + bool RangeEquals(const Array& other, int64_t start_idx, int64_t end_idx, + int64_t other_start_idx, + const EqualOptions& = EqualOptions::Defaults()) const; + bool RangeEquals(const std::shared_ptr& other, int64_t start_idx, + int64_t end_idx, int64_t other_start_idx, + const EqualOptions& = EqualOptions::Defaults()) const; + + /// \brief Apply the ArrayVisitor::Visit() method specialized to the array type + Status Accept(ArrayVisitor* visitor) const; + + /// Construct a zero-copy view of this array with the given type. + /// + /// This method checks if the types are layout-compatible. + /// Nested types are traversed in depth-first order. Data buffers must have + /// the same item sizes, even though the logical types may be different. + /// An error is returned if the types are not layout-compatible. + Result> View(const std::shared_ptr& type) const; + + /// \brief Construct a copy of the array with all buffers on destination + /// Memory Manager + /// + /// This method recursively copies the array's buffers and those of its children + /// onto the destination MemoryManager device and returns the new Array. + Result> CopyTo(const std::shared_ptr& to) const; + + /// \brief Construct a new array attempting to zero-copy view if possible. + /// + /// Like CopyTo this method recursively goes through all of the array's buffers + /// and those of it's children and first attempts to create zero-copy + /// views on the destination MemoryManager device. If it can't, it falls back + /// to performing a copy. See Buffer::ViewOrCopy. + Result> ViewOrCopyTo( + const std::shared_ptr& to) const; + + /// Construct a zero-copy slice of the array with the indicated offset and + /// length + /// + /// \param[in] offset the position of the first element in the constructed + /// slice + /// \param[in] length the length of the slice. If there are not enough + /// elements in the array, the length will be adjusted accordingly + /// + /// \return a new object wrapped in std::shared_ptr + std::shared_ptr Slice(int64_t offset, int64_t length) const; + + /// Slice from offset until end of the array + std::shared_ptr Slice(int64_t offset) const; + + /// Input-checking variant of Array::Slice + Result> SliceSafe(int64_t offset, int64_t length) const; + /// Input-checking variant of Array::Slice + Result> SliceSafe(int64_t offset) const; + + const std::shared_ptr& data() const { return data_; } + + int num_fields() const { return static_cast(data_->child_data.size()); } + + /// \return PrettyPrint representation of array suitable for debugging + std::string ToString() const; + + /// \brief Perform cheap validation checks to determine obvious inconsistencies + /// within the array's internal data. + /// + /// This is O(k) where k is the number of descendents. + /// + /// \return Status + Status Validate() const; + + /// \brief Perform extensive validation checks to determine inconsistencies + /// within the array's internal data. + /// + /// This is potentially O(k*n) where k is the number of descendents and n + /// is the array length. + /// + /// \return Status + Status ValidateFull() const; + + /// \brief Return the device_type that this array's data is allocated on + /// + /// This just delegates to calling device_type on the underlying ArrayData + /// object which backs this Array. + /// + /// \return DeviceAllocationType + DeviceAllocationType device_type() const { return data_->device_type(); } + + /// \brief Return the statistics of this Array + /// + /// This just delegates to calling statistics on the underlying ArrayData + /// object which backs this Array. + /// + /// \return const std::shared_ptr& + const std::shared_ptr& statistics() const { return data_->statistics; } + + protected: + Array() = default; + ARROW_DEFAULT_MOVE_AND_ASSIGN(Array); + + std::shared_ptr data_; + const uint8_t* null_bitmap_data_ = NULLPTR; + + /// Protected method for constructors + void SetData(const std::shared_ptr& data) { + if (data->buffers.size() > 0) { + null_bitmap_data_ = data->GetValuesSafe(0, /*offset=*/0); + } else { + null_bitmap_data_ = NULLPTR; + } + data_ = data; + } + + private: + ARROW_DISALLOW_COPY_AND_ASSIGN(Array); +}; + +ARROW_EXPORT void PrintTo(const Array& x, std::ostream* os); + +static inline std::ostream& operator<<(std::ostream& os, const Array& x) { + os << x.ToString(); + return os; +} + +/// Base class for non-nested arrays +class ARROW_EXPORT FlatArray : public Array { + protected: + using Array::Array; +}; + +/// Base class for arrays of fixed-size logical types +class ARROW_EXPORT PrimitiveArray : public FlatArray { + public: + /// Does not account for any slice offset + const std::shared_ptr& values() const { return data_->buffers[1]; } + + protected: + PrimitiveArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + PrimitiveArray() : raw_values_(NULLPTR) {} + + void SetData(const std::shared_ptr& data) { + this->Array::SetData(data); + raw_values_ = data->GetValuesSafe(1, /*offset=*/0); + } + + explicit PrimitiveArray(const std::shared_ptr& data) { SetData(data); } + + const uint8_t* raw_values_; +}; + +/// Degenerate null type Array +class ARROW_EXPORT NullArray : public FlatArray { + public: + using TypeClass = NullType; + + explicit NullArray(const std::shared_ptr& data) { SetData(data); } + explicit NullArray(int64_t length); + + private: + void SetData(const std::shared_ptr& data) { + null_bitmap_data_ = NULLPTR; + data->null_count = data->length; + data_ = data; + } +}; + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/array_binary.h b/pyarrow/include/arrow/array/array_binary.h new file mode 100644 index 0000000000000000000000000000000000000000..63903eac46d413c24ccaeb048273e8f5e6c8d3c6 --- /dev/null +++ b/pyarrow/include/arrow/array/array_binary.h @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Array accessor classes for Binary, LargeBinary, String, LargeString, +// FixedSizeBinary + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/data.h" +#include "arrow/buffer.h" +#include "arrow/stl_iterator.h" +#include "arrow/type.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup binary-arrays +/// +/// @{ + +// ---------------------------------------------------------------------- +// Binary and String + +/// Base class for variable-sized binary arrays, regardless of offset size +/// and logical interpretation. +template +class BaseBinaryArray : public FlatArray { + public: + using TypeClass = TYPE; + using offset_type = typename TypeClass::offset_type; + using IteratorType = stl::ArrayIterator>; + + /// Return the pointer to the given elements bytes + // XXX should GetValue(int64_t i) return a string_view? + const uint8_t* GetValue(int64_t i, offset_type* out_length) const { + const offset_type pos = raw_value_offsets_[i]; + *out_length = raw_value_offsets_[i + 1] - pos; + return raw_data_ + pos; + } + + /// \brief Get binary value as a string_view + /// + /// \param i the value index + /// \return the view over the selected value + std::string_view GetView(int64_t i) const { + const offset_type pos = raw_value_offsets_[i]; + return std::string_view(reinterpret_cast(raw_data_ + pos), + raw_value_offsets_[i + 1] - pos); + } + + std::optional operator[](int64_t i) const { + return *IteratorType(*this, i); + } + + /// \brief Get binary value as a string_view + /// Provided for consistency with other arrays. + /// + /// \param i the value index + /// \return the view over the selected value + std::string_view Value(int64_t i) const { return GetView(i); } + + /// \brief Get binary value as a std::string + /// + /// \param i the value index + /// \return the value copied into a std::string + std::string GetString(int64_t i) const { return std::string(GetView(i)); } + + /// Note that this buffer does not account for any slice offset + std::shared_ptr value_offsets() const { return data_->buffers[1]; } + + /// Note that this buffer does not account for any slice offset + std::shared_ptr value_data() const { return data_->buffers[2]; } + + const offset_type* raw_value_offsets() const { return raw_value_offsets_; } + + const uint8_t* raw_data() const { return raw_data_; } + + /// \brief Return the data buffer absolute offset of the data for the value + /// at the passed index. + /// + /// Does not perform boundschecking + offset_type value_offset(int64_t i) const { return raw_value_offsets_[i]; } + + /// \brief Return the length of the data for the value at the passed index. + /// + /// Does not perform boundschecking + offset_type value_length(int64_t i) const { + return raw_value_offsets_[i + 1] - raw_value_offsets_[i]; + } + + /// \brief Return the total length of the memory in the data buffer + /// referenced by this array. If the array has been sliced then this may be + /// less than the size of the data buffer (data_->buffers[2]). + offset_type total_values_length() const { + if (data_->length > 0) { + return raw_value_offsets_[data_->length] - raw_value_offsets_[0]; + } else { + return 0; + } + } + + IteratorType begin() const { return IteratorType(*this); } + + IteratorType end() const { return IteratorType(*this, length()); } + + protected: + // For subclasses + BaseBinaryArray() = default; + + // Protected method for constructors + void SetData(const std::shared_ptr& data) { + this->Array::SetData(data); + raw_value_offsets_ = data->GetValuesSafe(1); + raw_data_ = data->GetValuesSafe(2, /*offset=*/0); + } + + const offset_type* raw_value_offsets_ = NULLPTR; + const uint8_t* raw_data_ = NULLPTR; +}; + +/// Concrete Array class for variable-size binary data +class ARROW_EXPORT BinaryArray : public BaseBinaryArray { + public: + explicit BinaryArray(const std::shared_ptr& data); + + BinaryArray(int64_t length, const std::shared_ptr& value_offsets, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + protected: + // For subclasses such as StringArray + BinaryArray() : BaseBinaryArray() {} +}; + +/// Concrete Array class for variable-size string (utf-8) data +class ARROW_EXPORT StringArray : public BinaryArray { + public: + using TypeClass = StringType; + + explicit StringArray(const std::shared_ptr& data); + + StringArray(int64_t length, const std::shared_ptr& value_offsets, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Validate that this array contains only valid UTF8 entries + /// + /// This check is also implied by ValidateFull() + Status ValidateUTF8() const; +}; + +/// Concrete Array class for large variable-size binary data +class ARROW_EXPORT LargeBinaryArray : public BaseBinaryArray { + public: + explicit LargeBinaryArray(const std::shared_ptr& data); + + LargeBinaryArray(int64_t length, const std::shared_ptr& value_offsets, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + protected: + // For subclasses such as LargeStringArray + LargeBinaryArray() : BaseBinaryArray() {} +}; + +/// Concrete Array class for large variable-size string (utf-8) data +class ARROW_EXPORT LargeStringArray : public LargeBinaryArray { + public: + using TypeClass = LargeStringType; + + explicit LargeStringArray(const std::shared_ptr& data); + + LargeStringArray(int64_t length, const std::shared_ptr& value_offsets, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Validate that this array contains only valid UTF8 entries + /// + /// This check is also implied by ValidateFull() + Status ValidateUTF8() const; +}; + +// ---------------------------------------------------------------------- +// BinaryView and StringView + +/// Concrete Array class for variable-size binary view data using the +/// BinaryViewType::c_type struct to reference in-line or out-of-line string values +class ARROW_EXPORT BinaryViewArray : public FlatArray { + public: + using TypeClass = BinaryViewType; + using IteratorType = stl::ArrayIterator; + using c_type = BinaryViewType::c_type; + + explicit BinaryViewArray(std::shared_ptr data); + + BinaryViewArray(std::shared_ptr type, int64_t length, + std::shared_ptr views, BufferVector data_buffers, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + // For API compatibility with BinaryArray etc. + std::string_view GetView(int64_t i) const; + std::string GetString(int64_t i) const { return std::string{GetView(i)}; } + + const auto& values() const { return data_->buffers[1]; } + const c_type* raw_values() const { return raw_values_; } + + std::optional operator[](int64_t i) const { + return *IteratorType(*this, i); + } + + IteratorType begin() const { return IteratorType(*this); } + IteratorType end() const { return IteratorType(*this, length()); } + + protected: + using FlatArray::FlatArray; + + void SetData(std::shared_ptr data) { + FlatArray::SetData(std::move(data)); + raw_values_ = data_->GetValuesSafe(1); + } + + const c_type* raw_values_; +}; + +/// Concrete Array class for variable-size string view (utf-8) data using +/// BinaryViewType::c_type to reference in-line or out-of-line string values +class ARROW_EXPORT StringViewArray : public BinaryViewArray { + public: + using TypeClass = StringViewType; + + explicit StringViewArray(std::shared_ptr data); + + using BinaryViewArray::BinaryViewArray; + + /// \brief Validate that this array contains only valid UTF8 entries + /// + /// This check is also implied by ValidateFull() + Status ValidateUTF8() const; +}; + +// ---------------------------------------------------------------------- +// Fixed width binary + +/// Concrete Array class for fixed-size binary data +class ARROW_EXPORT FixedSizeBinaryArray : public PrimitiveArray { + public: + using TypeClass = FixedSizeBinaryType; + using IteratorType = stl::ArrayIterator; + + explicit FixedSizeBinaryArray(const std::shared_ptr& data); + + FixedSizeBinaryArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + const uint8_t* GetValue(int64_t i) const { return values_ + i * byte_width_; } + const uint8_t* Value(int64_t i) const { return GetValue(i); } + + std::string_view GetView(int64_t i) const { + return std::string_view(reinterpret_cast(GetValue(i)), byte_width_); + } + + std::optional operator[](int64_t i) const { + return *IteratorType(*this, i); + } + + std::string GetString(int64_t i) const { return std::string(GetView(i)); } + + int32_t byte_width() const { return byte_width_; } + + const uint8_t* raw_values() const { return values_; } + + IteratorType begin() const { return IteratorType(*this); } + + IteratorType end() const { return IteratorType(*this, length()); } + + protected: + void SetData(const std::shared_ptr& data) { + this->PrimitiveArray::SetData(data); + byte_width_ = + internal::checked_cast(*type()).byte_width(); + values_ = raw_values_ + data_->offset * byte_width_; + } + + const uint8_t* values_; + int32_t byte_width_; +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/array_decimal.h b/pyarrow/include/arrow/array/array_decimal.h new file mode 100644 index 0000000000000000000000000000000000000000..2f10bb842999640a8cada703ff12ea29c0e5f718 --- /dev/null +++ b/pyarrow/include/arrow/array/array_decimal.h @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/array/array_binary.h" +#include "arrow/array/data.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup numeric-arrays +/// +/// @{ + +// ---------------------------------------------------------------------- +// Decimal32Array + +/// Concrete Array class for 32-bit decimal data +class ARROW_EXPORT Decimal32Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal32Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal32Array from ArrayData instance + explicit Decimal32Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + +// ---------------------------------------------------------------------- +// Decimal64Array + +/// Concrete Array class for 64-bit decimal data +class ARROW_EXPORT Decimal64Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal64Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal64Array from ArrayData instance + explicit Decimal64Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + +// ---------------------------------------------------------------------- +// Decimal128Array + +/// Concrete Array class for 128-bit decimal data +class ARROW_EXPORT Decimal128Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal128Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal128Array from ArrayData instance + explicit Decimal128Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + +// Backward compatibility +using DecimalArray = Decimal128Array; + +// ---------------------------------------------------------------------- +// Decimal256Array + +/// Concrete Array class for 256-bit decimal data +class ARROW_EXPORT Decimal256Array : public FixedSizeBinaryArray { + public: + using TypeClass = Decimal256Type; + + using FixedSizeBinaryArray::FixedSizeBinaryArray; + + /// \brief Construct Decimal256Array from ArrayData instance + explicit Decimal256Array(const std::shared_ptr& data); + + std::string FormatValue(int64_t i) const; +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/array_dict.h b/pyarrow/include/arrow/array/array_dict.h new file mode 100644 index 0000000000000000000000000000000000000000..bf376b51f8c9470d2b4e4c7ed950c9a513fddc9b --- /dev/null +++ b/pyarrow/include/arrow/array/array_dict.h @@ -0,0 +1,182 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/data.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +// ---------------------------------------------------------------------- +// DictionaryArray + +/// \brief Array type for dictionary-encoded data with a +/// data-dependent dictionary +/// +/// A dictionary array contains an array of non-negative integers (the +/// "dictionary indices") along with a data type containing a "dictionary" +/// corresponding to the distinct values represented in the data. +/// +/// For example, the array +/// +/// ["foo", "bar", "foo", "bar", "foo", "bar"] +/// +/// with dictionary ["bar", "foo"], would have dictionary array representation +/// +/// indices: [1, 0, 1, 0, 1, 0] +/// dictionary: ["bar", "foo"] +/// +/// The indices in principle may be any integer type. +class ARROW_EXPORT DictionaryArray : public Array { + public: + using TypeClass = DictionaryType; + + explicit DictionaryArray(const std::shared_ptr& data); + + DictionaryArray(const std::shared_ptr& type, + const std::shared_ptr& indices, + const std::shared_ptr& dictionary); + + /// \brief Construct DictionaryArray from dictionary and indices + /// array and validate + /// + /// This function does the validation of the indices and input type. It checks if + /// all indices are non-negative and smaller than the size of the dictionary. + /// + /// \param[in] type a dictionary type + /// \param[in] dictionary the dictionary with same value type as the + /// type object + /// \param[in] indices an array of non-negative integers smaller than the + /// size of the dictionary + static Result> FromArrays( + const std::shared_ptr& type, const std::shared_ptr& indices, + const std::shared_ptr& dictionary); + + static Result> FromArrays( + const std::shared_ptr& indices, const std::shared_ptr& dictionary) { + return FromArrays(::arrow::dictionary(indices->type(), dictionary->type()), indices, + dictionary); + } + + /// \brief Transpose this DictionaryArray + /// + /// This method constructs a new dictionary array with the given dictionary + /// type, transposing indices using the transpose map. The type and the + /// transpose map are typically computed using DictionaryUnifier. + /// + /// \param[in] type the new type object + /// \param[in] dictionary the new dictionary + /// \param[in] transpose_map transposition array of this array's indices + /// into the target array's indices + /// \param[in] pool a pool to allocate the array data from + Result> Transpose( + const std::shared_ptr& type, const std::shared_ptr& dictionary, + const int32_t* transpose_map, MemoryPool* pool = default_memory_pool()) const; + + Result> Compact(MemoryPool* pool = default_memory_pool()) const; + + /// \brief Determine whether dictionary arrays may be compared without unification + bool CanCompareIndices(const DictionaryArray& other) const; + + /// \brief Return the dictionary for this array, which is stored as + /// a member of the ArrayData internal structure + const std::shared_ptr& dictionary() const; + const std::shared_ptr& indices() const; + + /// \brief Return the ith value of indices, cast to int64_t. Not recommended + /// for use in performance-sensitive code. Does not validate whether the + /// value is null or out-of-bounds. + int64_t GetValueIndex(int64_t i) const; + + const DictionaryType* dict_type() const { return dict_type_; } + + private: + void SetData(const std::shared_ptr& data); + const DictionaryType* dict_type_; + std::shared_ptr indices_; + + // Lazily initialized when invoking dictionary() + mutable std::shared_ptr dictionary_; +}; + +/// \brief Helper class for incremental dictionary unification +class ARROW_EXPORT DictionaryUnifier { + public: + virtual ~DictionaryUnifier() = default; + + /// \brief Construct a DictionaryUnifier + /// \param[in] value_type the data type of the dictionaries + /// \param[in] pool MemoryPool to use for memory allocations + static Result> Make( + std::shared_ptr value_type, MemoryPool* pool = default_memory_pool()); + + /// \brief Unify dictionaries across array chunks + /// + /// The dictionaries in the array chunks will be unified, their indices + /// accordingly transposed. + /// + /// Only dictionaries with a primitive value type are currently supported. + /// However, dictionaries nested inside a more complex type are correctly unified. + static Result> UnifyChunkedArray( + const std::shared_ptr& array, + MemoryPool* pool = default_memory_pool()); + + /// \brief Unify dictionaries across the chunks of each table column + /// + /// The dictionaries in each table column will be unified, their indices + /// accordingly transposed. + /// + /// Only dictionaries with a primitive value type are currently supported. + /// However, dictionaries nested inside a more complex type are correctly unified. + static Result> UnifyTable( + const Table& table, MemoryPool* pool = default_memory_pool()); + + /// \brief Append dictionary to the internal memo + virtual Status Unify(const Array& dictionary) = 0; + + /// \brief Append dictionary and compute transpose indices + /// \param[in] dictionary the dictionary values to unify + /// \param[out] out_transpose a Buffer containing computed transpose indices + /// as int32_t values equal in length to the passed dictionary. The value in + /// each slot corresponds to the new index value for each original index + /// for a DictionaryArray with the old dictionary + virtual Status Unify(const Array& dictionary, + std::shared_ptr* out_transpose) = 0; + + /// \brief Return a result DictionaryType with the smallest possible index + /// type to accommodate the unified dictionary. The unifier cannot be used + /// after this is called + virtual Status GetResult(std::shared_ptr* out_type, + std::shared_ptr* out_dict) = 0; + + /// \brief Return a unified dictionary with the given index type. If + /// the index type is not large enough then an invalid status will be returned. + /// The unifier cannot be used after this is called + virtual Status GetResultWithIndexType(const std::shared_ptr& index_type, + std::shared_ptr* out_dict) = 0; +}; + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/array_nested.h b/pyarrow/include/arrow/array/array_nested.h new file mode 100644 index 0000000000000000000000000000000000000000..bf84f802b1ab502fc50794997645b52756bb6df2 --- /dev/null +++ b/pyarrow/include/arrow/array/array_nested.h @@ -0,0 +1,899 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Array accessor classes for List, LargeList, ListView, LargeListView, FixedSizeList, +// Map, Struct, and Union + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/data.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup nested-arrays +/// +/// @{ + +// ---------------------------------------------------------------------- +// VarLengthListLikeArray + +template +class VarLengthListLikeArray; + +namespace internal { + +// Private helper for [Large]List[View]Array::SetData. +// Unfortunately, trying to define VarLengthListLikeArray::SetData outside of this header +// doesn't play well with MSVC. +template +void SetListData(VarLengthListLikeArray* self, + const std::shared_ptr& data, + Type::type expected_type_id = TYPE::type_id); + +/// \brief A version of Flatten that keeps recursively flattening until an array of +/// non-list values is reached. +/// +/// Array types considered to be lists by this function: +/// - list +/// - large_list +/// - list_view +/// - large_list_view +/// - fixed_size_list +/// +/// \see ListArray::Flatten +ARROW_EXPORT Result> FlattenLogicalListRecursively( + const Array& in_array, MemoryPool* memory_pool); + +} // namespace internal + +/// Base class for variable-sized list and list-view arrays, regardless of offset size. +template +class VarLengthListLikeArray : public Array { + public: + using TypeClass = TYPE; + using offset_type = typename TypeClass::offset_type; + + const TypeClass* var_length_list_like_type() const { return this->list_type_; } + + /// \brief Return array object containing the list's values + /// + /// Note that this buffer does not account for any slice offset or length. + const std::shared_ptr& values() const { return values_; } + + /// Note that this buffer does not account for any slice offset or length. + const std::shared_ptr& value_offsets() const { return data_->buffers[1]; } + + const std::shared_ptr& value_type() const { return list_type_->value_type(); } + + /// Return pointer to raw value offsets accounting for any slice offset + const offset_type* raw_value_offsets() const { return raw_value_offsets_; } + + // The following functions will not perform boundschecking + + offset_type value_offset(int64_t i) const { return raw_value_offsets_[i]; } + + /// \brief Return the size of the value at a particular index + /// + /// Since non-empty null lists and list-views are possible, avoid calling this + /// function when the list at slot i is null. + /// + /// \pre IsValid(i) + virtual offset_type value_length(int64_t i) const = 0; + + /// \pre IsValid(i) + std::shared_ptr value_slice(int64_t i) const { + return values_->Slice(value_offset(i), value_length(i)); + } + + /// \brief Flatten all level recursively until reach a non-list type, and return + /// a non-list type Array. + /// + /// \see internal::FlattenLogicalListRecursively + Result> FlattenRecursively( + MemoryPool* memory_pool = default_memory_pool()) const { + return internal::FlattenLogicalListRecursively(*this, memory_pool); + } + + protected: + friend void internal::SetListData(VarLengthListLikeArray* self, + const std::shared_ptr& data, + Type::type expected_type_id); + + const TypeClass* list_type_ = NULLPTR; + std::shared_ptr values_; + const offset_type* raw_value_offsets_ = NULLPTR; +}; + +// ---------------------------------------------------------------------- +// ListArray / LargeListArray + +template +class BaseListArray : public VarLengthListLikeArray { + public: + using TypeClass = TYPE; + using offset_type = typename TYPE::offset_type; + + const TypeClass* list_type() const { return this->var_length_list_like_type(); } + + /// \brief Return the size of the value at a particular index + /// + /// Since non-empty null lists are possible, avoid calling this + /// function when the list at slot i is null. + /// + /// \pre IsValid(i) + offset_type value_length(int64_t i) const final { + return this->raw_value_offsets_[i + 1] - this->raw_value_offsets_[i]; + } +}; + +/// Concrete Array class for list data +class ARROW_EXPORT ListArray : public BaseListArray { + public: + explicit ListArray(std::shared_ptr data); + + ListArray(std::shared_ptr type, int64_t length, + std::shared_ptr value_offsets, std::shared_ptr values, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Construct ListArray from array of offsets and child value array + /// + /// This function does the bare minimum of validation of the offsets and + /// input types, and will allocate a new offsets array if necessary (i.e. if + /// the offsets contain any nulls). If the offsets do not have nulls, they + /// are assumed to be well-formed. + /// + /// If a null_bitmap is not provided, the nulls will be inferred from the offsets' + /// null bitmap. But if a null_bitmap is provided, the offsets array can't have nulls. + /// + /// And when a null_bitmap is provided, the offsets array cannot be a slice (i.e. an + /// array with offset() > 0). + /// + /// \param[in] offsets Array containing n + 1 offsets encoding length and + /// size. Must be of int32 type + /// \param[in] values Array containing list values + /// \param[in] pool MemoryPool in case new offsets array needs to be + /// allocated because of null values + /// \param[in] null_bitmap Optional validity bitmap + /// \param[in] null_count Optional null count in null_bitmap + static Result> FromArrays( + const Array& offsets, const Array& values, MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + static Result> FromArrays( + std::shared_ptr type, const Array& offsets, const Array& values, + MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + /// \brief Build a ListArray from a ListViewArray + static Result> FromListView(const ListViewArray& source, + MemoryPool* pool); + + /// \brief Return an Array that is a concatenation of the lists in this array. + /// + /// Note that it's different from `values()` in that it takes into + /// consideration of this array's offsets as well as null elements backed + /// by non-empty lists (they are skipped, thus copying may be needed). + Result> Flatten( + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Return list offsets as an Int32Array + /// + /// The returned array will not have a validity bitmap, so you cannot expect + /// to pass it to ListArray::FromArrays() and get back the same list array + /// if the original one has nulls. + std::shared_ptr offsets() const; + + protected: + // This constructor defers SetData to a derived array class + ListArray() = default; + + void SetData(const std::shared_ptr& data); +}; + +/// Concrete Array class for large list data (with 64-bit offsets) +class ARROW_EXPORT LargeListArray : public BaseListArray { + public: + explicit LargeListArray(const std::shared_ptr& data); + + LargeListArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& value_offsets, + const std::shared_ptr& values, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Construct LargeListArray from array of offsets and child value array + /// + /// This function does the bare minimum of validation of the offsets and + /// input types, and will allocate a new offsets array if necessary (i.e. if + /// the offsets contain any nulls). If the offsets do not have nulls, they + /// are assumed to be well-formed. + /// + /// If a null_bitmap is not provided, the nulls will be inferred from the offsets' + /// null bitmap. But if a null_bitmap is provided, the offsets array can't have nulls. + /// + /// And when a null_bitmap is provided, the offsets array cannot be a slice (i.e. an + /// array with offset() > 0). + /// + /// \param[in] offsets Array containing n + 1 offsets encoding length and + /// size. Must be of int64 type + /// \param[in] values Array containing list values + /// \param[in] pool MemoryPool in case new offsets array needs to be + /// allocated because of null values + /// \param[in] null_bitmap Optional validity bitmap + /// \param[in] null_count Optional null count in null_bitmap + static Result> FromArrays( + const Array& offsets, const Array& values, MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + static Result> FromArrays( + std::shared_ptr type, const Array& offsets, const Array& values, + MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + /// \brief Build a LargeListArray from a LargeListViewArray + static Result> FromListView( + const LargeListViewArray& source, MemoryPool* pool); + + /// \brief Return an Array that is a concatenation of the lists in this array. + /// + /// Note that it's different from `values()` in that it takes into + /// consideration of this array's offsets as well as null elements backed + /// by non-empty lists (they are skipped, thus copying may be needed). + Result> Flatten( + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Return list offsets as an Int64Array + std::shared_ptr offsets() const; + + protected: + void SetData(const std::shared_ptr& data); +}; + +// ---------------------------------------------------------------------- +// ListViewArray / LargeListViewArray + +template +class BaseListViewArray : public VarLengthListLikeArray { + public: + using TypeClass = TYPE; + using offset_type = typename TYPE::offset_type; + + const TypeClass* list_view_type() const { return this->var_length_list_like_type(); } + + /// \brief Note that this buffer does not account for any slice offset or length. + const std::shared_ptr& value_sizes() const { return this->data_->buffers[2]; } + + /// \brief Return pointer to raw value offsets accounting for any slice offset + const offset_type* raw_value_sizes() const { return raw_value_sizes_; } + + /// \brief Return the size of the value at a particular index + /// + /// This should not be called if the list-view at slot i is null. + /// The returned size in those cases could be any value from 0 to the + /// length of the child values array. + /// + /// \pre IsValid(i) + offset_type value_length(int64_t i) const final { return this->raw_value_sizes_[i]; } + + protected: + const offset_type* raw_value_sizes_ = NULLPTR; +}; + +/// \brief Concrete Array class for list-view data +class ARROW_EXPORT ListViewArray : public BaseListViewArray { + public: + explicit ListViewArray(std::shared_ptr data); + + ListViewArray(std::shared_ptr type, int64_t length, + std::shared_ptr value_offsets, + std::shared_ptr value_sizes, std::shared_ptr values, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Construct ListViewArray from array of offsets, sizes, and child + /// value array + /// + /// Construct a ListViewArray using buffers from offsets and sizes arrays + /// that project views into the child values array. + /// + /// This function does the bare minimum of validation of the offsets/sizes and + /// input types. The offset and length of the offsets and sizes arrays must + /// match and that will be checked, but their contents will be assumed to be + /// well-formed. + /// + /// If a null_bitmap is not provided, the nulls will be inferred from the + /// offsets's null bitmap. But if a null_bitmap is provided, the offsets array + /// can't have nulls. + /// + /// And when a null_bitmap is provided, neither the offsets or sizes array can be a + /// slice (i.e. an array with offset() > 0). + /// + /// \param[in] offsets An array of int32 offsets into the values array. NULL values are + /// supported if the corresponding values in sizes is NULL or 0. + /// \param[in] sizes An array containing the int32 sizes of every view. NULL values are + /// taken to represent a NULL list-view in the array being created. + /// \param[in] values Array containing list values + /// \param[in] pool MemoryPool + /// \param[in] null_bitmap Optional validity bitmap + /// \param[in] null_count Optional null count in null_bitmap + static Result> FromArrays( + const Array& offsets, const Array& sizes, const Array& values, + MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + static Result> FromArrays( + std::shared_ptr type, const Array& offsets, const Array& sizes, + const Array& values, MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + /// \brief Build a ListViewArray from a ListArray + static Result> FromList(const ListArray& list_array, + MemoryPool* pool); + + /// \brief Return an Array that is a concatenation of the list-views in this array. + /// + /// Note that it's different from `values()` in that it takes into + /// consideration this array's offsets (which can be in any order) + /// and sizes. Nulls are skipped. + /// + /// This function invokes Concatenate() if list-views are non-contiguous. It + /// will try to minimize the number of array slices passed to Concatenate() by + /// maximizing the size of each slice (containing as many contiguous + /// list-views as possible). + Result> Flatten( + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Return list-view offsets as an Int32Array + /// + /// The returned array will not have a validity bitmap, so you cannot expect + /// to pass it to ListArray::FromArrays() and get back the same list array + /// if the original one has nulls. + std::shared_ptr offsets() const; + + /// \brief Return list-view sizes as an Int32Array + /// + /// The returned array will not have a validity bitmap, so you cannot expect + /// to pass it to ListViewArray::FromArrays() and get back the same list + /// array if the original one has nulls. + std::shared_ptr sizes() const; + + protected: + // This constructor defers SetData to a derived array class + ListViewArray() = default; + + void SetData(const std::shared_ptr& data); +}; + +/// \brief Concrete Array class for large list-view data (with 64-bit offsets +/// and sizes) +class ARROW_EXPORT LargeListViewArray : public BaseListViewArray { + public: + explicit LargeListViewArray(std::shared_ptr data); + + LargeListViewArray(std::shared_ptr type, int64_t length, + std::shared_ptr value_offsets, + std::shared_ptr value_sizes, std::shared_ptr values, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Construct LargeListViewArray from array of offsets, sizes, and child + /// value array + /// + /// Construct an LargeListViewArray using buffers from offsets and sizes arrays + /// that project views into the values array. + /// + /// This function does the bare minimum of validation of the offsets/sizes and + /// input types. The offset and length of the offsets and sizes arrays must + /// match and that will be checked, but their contents will be assumed to be + /// well-formed. + /// + /// If a null_bitmap is not provided, the nulls will be inferred from the offsets' or + /// sizes' null bitmap. Only one of these two is allowed to have a null bitmap. But if a + /// null_bitmap is provided, the offsets array and the sizes array can't have nulls. + /// + /// And when a null_bitmap is provided, neither the offsets or sizes array can be a + /// slice (i.e. an array with offset() > 0). + /// + /// \param[in] offsets An array of int64 offsets into the values array. NULL values are + /// supported if the corresponding values in sizes is NULL or 0. + /// \param[in] sizes An array containing the int64 sizes of every view. NULL values are + /// taken to represent a NULL list-view in the array being created. + /// \param[in] values Array containing list values + /// \param[in] pool MemoryPool + /// \param[in] null_bitmap Optional validity bitmap + /// \param[in] null_count Optional null count in null_bitmap + static Result> FromArrays( + const Array& offsets, const Array& sizes, const Array& values, + MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + static Result> FromArrays( + std::shared_ptr type, const Array& offsets, const Array& sizes, + const Array& values, MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + /// \brief Build a LargeListViewArray from a LargeListArray + static Result> FromList( + const LargeListArray& list_array, MemoryPool* pool); + + /// \brief Return an Array that is a concatenation of the large list-views in this + /// array. + /// + /// Note that it's different from `values()` in that it takes into + /// consideration this array's offsets (which can be in any order) + /// and sizes. Nulls are skipped. + Result> Flatten( + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Return list-view offsets as an Int64Array + /// + /// The returned array will not have a validity bitmap, so you cannot expect + /// to pass it to LargeListArray::FromArrays() and get back the same list array + /// if the original one has nulls. + std::shared_ptr offsets() const; + + /// \brief Return list-view sizes as an Int64Array + /// + /// The returned array will not have a validity bitmap, so you cannot expect + /// to pass it to LargeListViewArray::FromArrays() and get back the same list + /// array if the original one has nulls. + std::shared_ptr sizes() const; + + protected: + // This constructor defers SetData to a derived array class + LargeListViewArray() = default; + + void SetData(const std::shared_ptr& data); +}; + +// ---------------------------------------------------------------------- +// MapArray + +/// Concrete Array class for map data +/// +/// NB: "value" in this context refers to a pair of a key and the corresponding item +class ARROW_EXPORT MapArray : public ListArray { + public: + using TypeClass = MapType; + + explicit MapArray(const std::shared_ptr& data); + + MapArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& value_offsets, + const std::shared_ptr& keys, const std::shared_ptr& items, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + MapArray(const std::shared_ptr& type, int64_t length, BufferVector buffers, + const std::shared_ptr& keys, const std::shared_ptr& items, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + MapArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& value_offsets, + const std::shared_ptr& values, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Construct MapArray from array of offsets and child key, item arrays + /// + /// This function does the bare minimum of validation of the offsets and + /// input types, and will allocate a new offsets array if necessary (i.e. if + /// the offsets contain any nulls). If the offsets do not have nulls, they + /// are assumed to be well-formed + /// + /// \param[in] offsets Array containing n + 1 offsets encoding length and + /// size. Must be of int32 type + /// \param[in] keys Array containing key values + /// \param[in] items Array containing item values + /// \param[in] pool MemoryPool in case new offsets array needs to be + /// \param[in] null_bitmap Optional validity bitmap + /// allocated because of null values + static Result> FromArrays( + const std::shared_ptr& offsets, const std::shared_ptr& keys, + const std::shared_ptr& items, MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR); + + static Result> FromArrays( + std::shared_ptr type, const std::shared_ptr& offsets, + const std::shared_ptr& keys, const std::shared_ptr& items, + MemoryPool* pool = default_memory_pool(), + std::shared_ptr null_bitmap = NULLPTR); + + const MapType* map_type() const { return map_type_; } + + /// \brief Return array object containing all map keys + const std::shared_ptr& keys() const { return keys_; } + + /// \brief Return array object containing all mapped items + const std::shared_ptr& items() const { return items_; } + + /// Validate child data before constructing the actual MapArray. + static Status ValidateChildData( + const std::vector>& child_data); + + protected: + void SetData(const std::shared_ptr& data); + + static Result> FromArraysInternal( + std::shared_ptr type, const std::shared_ptr& offsets, + const std::shared_ptr& keys, const std::shared_ptr& items, + MemoryPool* pool, std::shared_ptr null_bitmap = NULLPTR); + + private: + const MapType* map_type_; + std::shared_ptr keys_, items_; +}; + +// ---------------------------------------------------------------------- +// FixedSizeListArray + +/// Concrete Array class for fixed size list data +class ARROW_EXPORT FixedSizeListArray : public Array { + public: + using TypeClass = FixedSizeListType; + using offset_type = TypeClass::offset_type; + + explicit FixedSizeListArray(const std::shared_ptr& data); + + FixedSizeListArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& values, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + const FixedSizeListType* list_type() const; + + /// \brief Return array object containing the list's values + const std::shared_ptr& values() const; + + const std::shared_ptr& value_type() const; + + // The following functions will not perform boundschecking + int64_t value_offset(int64_t i) const { + i += data_->offset; + return list_size_ * i; + } + /// \brief Return the fixed-size of the values + /// + /// No matter the value of the index parameter, the result is the same. + /// So even when the value at slot i is null, this function will return a + /// non-zero size. + /// + /// \pre IsValid(i) + int32_t value_length(int64_t i = 0) const { + ARROW_UNUSED(i); + return list_size_; + } + /// \pre IsValid(i) + std::shared_ptr value_slice(int64_t i) const { + return values_->Slice(value_offset(i), value_length(i)); + } + + /// \brief Return an Array that is a concatenation of the lists in this array. + /// + /// Note that it's different from `values()` in that it takes into + /// consideration null elements (they are skipped, thus copying may be needed). + Result> Flatten( + MemoryPool* memory_pool = default_memory_pool()) const; + + /// \brief Flatten all level recursively until reach a non-list type, and return + /// a non-list type Array. + /// + /// \see internal::FlattenLogicalListRecursively + Result> FlattenRecursively( + MemoryPool* memory_pool = default_memory_pool()) const { + return internal::FlattenLogicalListRecursively(*this, memory_pool); + } + + /// \brief Construct FixedSizeListArray from child value array and value_length + /// + /// \param[in] values Array containing list values + /// \param[in] list_size The fixed length of each list + /// \param[in] null_bitmap Optional validity bitmap + /// \param[in] null_count Optional null count in null_bitmap + /// \return Will have length equal to values.length() / list_size + static Result> FromArrays( + const std::shared_ptr& values, int32_t list_size, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + /// \brief Construct FixedSizeListArray from child value array and type + /// + /// \param[in] values Array containing list values + /// \param[in] type The fixed sized list type + /// \param[in] null_bitmap Optional validity bitmap + /// \param[in] null_count Optional null count in null_bitmap + /// \return Will have length equal to values.length() / type.list_size() + static Result> FromArrays( + const std::shared_ptr& values, std::shared_ptr type, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount); + + protected: + void SetData(const std::shared_ptr& data); + int32_t list_size_; + + private: + std::shared_ptr values_; +}; + +// ---------------------------------------------------------------------- +// Struct + +/// Concrete Array class for struct data +class ARROW_EXPORT StructArray : public Array { + public: + using TypeClass = StructType; + + ~StructArray() override; + + explicit StructArray(const std::shared_ptr& data); + + StructArray(const std::shared_ptr& type, int64_t length, + const std::vector>& children, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Return a StructArray from child arrays and field names. + /// + /// The length and data type are automatically inferred from the arguments. + /// There should be at least one child array. + static Result> Make( + const ArrayVector& children, const std::vector& field_names, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + /// \brief Return a StructArray from child arrays and fields. + /// + /// The length is automatically inferred from the arguments. + /// There should be at least one child array. This method does not + /// check that field types and child array types are consistent. + static Result> Make( + const ArrayVector& children, const FieldVector& fields, + std::shared_ptr null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + const StructType* struct_type() const; + + // Return a shared pointer in case the requestor desires to share ownership + // with this array. The returned array has its offset, length and null + // count adjusted. + std::shared_ptr field(int pos) const; + + const ArrayVector& fields() const; + + /// Returns null if name not found + std::shared_ptr GetFieldByName(const std::string& name) const; + + /// Indicate if field named `name` can be found unambiguously in the struct. + Status CanReferenceFieldByName(const std::string& name) const; + + /// Indicate if fields named `names` can be found unambiguously in the struct. + Status CanReferenceFieldsByNames(const std::vector& names) const; + + /// \brief Flatten this array as a vector of arrays, one for each field + /// + /// \param[in] pool The pool to allocate null bitmaps from, if necessary + Result Flatten(MemoryPool* pool = default_memory_pool()) const; + + /// \brief Get one of the child arrays, combining its null bitmap + /// with the parent struct array's bitmap. + /// + /// \param[in] index Which child array to get + /// \param[in] pool The pool to allocate null bitmaps from, if necessary + Result> GetFlattenedField( + int index, MemoryPool* pool = default_memory_pool()) const; + + private: + // For caching boxed child data + struct ARROW_NO_EXPORT Impl; + std::unique_ptr impl_; +}; + +// ---------------------------------------------------------------------- +// Union + +/// Base class for SparseUnionArray and DenseUnionArray +class ARROW_EXPORT UnionArray : public Array { + public: + using type_code_t = int8_t; + + ~UnionArray() override; + + /// Note that this buffer does not account for any slice offset + const std::shared_ptr& type_codes() const { return data_->buffers[1]; } + + const type_code_t* raw_type_codes() const { return raw_type_codes_; } + + /// The logical type code of the value at index. + type_code_t type_code(int64_t i) const { return raw_type_codes_[i]; } + + /// The physical child id containing value at index. + int child_id(int64_t i) const { return union_type_->child_ids()[raw_type_codes_[i]]; } + + const UnionType* union_type() const { return union_type_; } + + UnionMode::type mode() const { return union_type_->mode(); } + + /// \brief Return the given field as an individual array. + /// + /// For sparse unions, the returned array has its offset, length and null + /// count adjusted. + std::shared_ptr field(int pos) const; + + protected: + UnionArray(); + + void SetData(std::shared_ptr data); + + const type_code_t* raw_type_codes_; + const UnionType* union_type_; + + private: + // For caching boxed child data + struct ARROW_NO_EXPORT Impl; + std::unique_ptr impl_; +}; + +/// Concrete Array class for sparse union data +class ARROW_EXPORT SparseUnionArray : public UnionArray { + public: + using TypeClass = SparseUnionType; + + ~SparseUnionArray() override; + + explicit SparseUnionArray(std::shared_ptr data); + + SparseUnionArray(std::shared_ptr type, int64_t length, ArrayVector children, + std::shared_ptr type_ids, int64_t offset = 0); + + /// \brief Construct SparseUnionArray from type_ids and children + /// + /// This function does the bare minimum of validation of the input types. + /// + /// \param[in] type_ids An array of logical type ids for the union type + /// \param[in] children Vector of children Arrays containing the data for each type. + /// \param[in] type_codes Vector of type codes. + static Result> Make(const Array& type_ids, ArrayVector children, + std::vector type_codes) { + return Make(std::move(type_ids), std::move(children), std::vector{}, + std::move(type_codes)); + } + + /// \brief Construct SparseUnionArray with custom field names from type_ids and children + /// + /// This function does the bare minimum of validation of the input types. + /// + /// \param[in] type_ids An array of logical type ids for the union type + /// \param[in] children Vector of children Arrays containing the data for each type. + /// \param[in] field_names Vector of strings containing the name of each field. + /// \param[in] type_codes Vector of type codes. + static Result> Make(const Array& type_ids, ArrayVector children, + std::vector field_names = {}, + std::vector type_codes = {}); + + const SparseUnionType* union_type() const { + return internal::checked_cast(union_type_); + } + + /// \brief Get one of the child arrays, adjusting its null bitmap + /// where the union array type code does not match. + /// + /// \param[in] index Which child array to get (i.e. the physical index, not the type + /// code) \param[in] pool The pool to allocate null bitmaps from, if necessary + Result> GetFlattenedField( + int index, MemoryPool* pool = default_memory_pool()) const; + + protected: + void SetData(std::shared_ptr data); +}; + +/// \brief Concrete Array class for dense union data +/// +/// Note that union types do not have a validity bitmap +class ARROW_EXPORT DenseUnionArray : public UnionArray { + public: + using TypeClass = DenseUnionType; + + ~DenseUnionArray() override; + + explicit DenseUnionArray(const std::shared_ptr& data); + + DenseUnionArray(std::shared_ptr type, int64_t length, ArrayVector children, + std::shared_ptr type_ids, + std::shared_ptr value_offsets = NULLPTR, int64_t offset = 0); + + /// \brief Construct DenseUnionArray from type_ids, value_offsets, and children + /// + /// This function does the bare minimum of validation of the offsets and + /// input types. + /// + /// \param[in] type_ids An array of logical type ids for the union type + /// \param[in] value_offsets An array of signed int32 values indicating the + /// relative offset into the respective child array for the type in a given slot. + /// The respective offsets for each child value array must be in order / increasing. + /// \param[in] children Vector of children Arrays containing the data for each type. + /// \param[in] type_codes Vector of type codes. + static Result> Make(const Array& type_ids, + const Array& value_offsets, + ArrayVector children, + std::vector type_codes) { + return Make(type_ids, value_offsets, std::move(children), std::vector{}, + std::move(type_codes)); + } + + /// \brief Construct DenseUnionArray with custom field names from type_ids, + /// value_offsets, and children + /// + /// This function does the bare minimum of validation of the offsets and + /// input types. + /// + /// \param[in] type_ids An array of logical type ids for the union type + /// \param[in] value_offsets An array of signed int32 values indicating the + /// relative offset into the respective child array for the type in a given slot. + /// The respective offsets for each child value array must be in order / increasing. + /// \param[in] children Vector of children Arrays containing the data for each type. + /// \param[in] field_names Vector of strings containing the name of each field. + /// \param[in] type_codes Vector of type codes. + static Result> Make(const Array& type_ids, + const Array& value_offsets, + ArrayVector children, + std::vector field_names = {}, + std::vector type_codes = {}); + + const DenseUnionType* union_type() const { + return internal::checked_cast(union_type_); + } + + /// Note that this buffer does not account for any slice offset + const std::shared_ptr& value_offsets() const { return data_->buffers[2]; } + + int32_t value_offset(int64_t i) const { return raw_value_offsets_[i]; } + + const int32_t* raw_value_offsets() const { return raw_value_offsets_; } + + protected: + const int32_t* raw_value_offsets_; + + void SetData(const std::shared_ptr& data); +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/array_primitive.h b/pyarrow/include/arrow/array/array_primitive.h new file mode 100644 index 0000000000000000000000000000000000000000..cebf47ad93d8aa719328007f3c4fa6d960855027 --- /dev/null +++ b/pyarrow/include/arrow/array/array_primitive.h @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Array accessor types for primitive/C-type-based arrays, such as numbers, +// boolean, and temporal types. + +#pragma once + +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/data.h" +#include "arrow/stl_iterator.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" // IWYU pragma: export +#include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// Concrete Array class for boolean data +class ARROW_EXPORT BooleanArray : public PrimitiveArray { + public: + using TypeClass = BooleanType; + using IteratorType = stl::ArrayIterator; + + explicit BooleanArray(const std::shared_ptr& data); + + BooleanArray(int64_t length, const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + bool Value(int64_t i) const { + return bit_util::GetBit(reinterpret_cast(raw_values_), + i + data_->offset); + } + + bool GetView(int64_t i) const { return Value(i); } + + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } + + /// \brief Return the number of false (0) values among the valid + /// values. Result is not cached. + int64_t false_count() const; + + /// \brief Return the number of true (1) values among the valid + /// values. Result is not cached. + int64_t true_count() const; + + IteratorType begin() const { return IteratorType(*this); } + + IteratorType end() const { return IteratorType(*this, length()); } + + protected: + using PrimitiveArray::PrimitiveArray; +}; + +/// \addtogroup numeric-arrays +/// +/// @{ + +/// \brief Concrete Array class for numeric data with a corresponding C type +/// +/// This class is templated on the corresponding DataType subclass for the +/// given data, for example NumericArray or NumericArray. +/// +/// Note that convenience aliases are available for all accepted types +/// (for example Int8Array for NumericArray). +template +class NumericArray : public PrimitiveArray { + public: + using TypeClass = TYPE; + using value_type = typename TypeClass::c_type; + using IteratorType = stl::ArrayIterator>; + + explicit NumericArray(const std::shared_ptr& data) { + NumericArray::SetData(data); + } + + // Only enable this constructor without a type argument for types without additional + // metadata + template + NumericArray(enable_if_parameter_free length, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0) { + NumericArray::SetData(ArrayData::Make(TypeTraits::type_singleton(), length, + {null_bitmap, data}, null_count, offset)); + } + + NumericArray(std::shared_ptr type, int64_t length, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0) { + NumericArray::SetData(ArrayData::Make(std::move(type), length, {null_bitmap, data}, + null_count, offset)); + } + + const value_type* raw_values() const { return values_; } + + value_type Value(int64_t i) const { return values_[i]; } + + // For API compatibility with BinaryArray etc. + value_type GetView(int64_t i) const { return values_[i]; } + + std::optional operator[](int64_t i) const { + return *IteratorType(*this, i); + } + + IteratorType begin() const { return IteratorType(*this); } + + IteratorType end() const { return IteratorType(*this, length()); } + + protected: + NumericArray() : values_(NULLPTR) {} + + void SetData(const std::shared_ptr& data) { + this->PrimitiveArray::SetData(data); + values_ = raw_values_ + ? (reinterpret_cast(raw_values_) + data_->offset) + : NULLPTR; + } + + const value_type* values_; +}; + +/// DayTimeArray +/// --------------------- +/// \brief Array of Day and Millisecond values. +class ARROW_EXPORT DayTimeIntervalArray : public PrimitiveArray { + public: + using TypeClass = DayTimeIntervalType; + using IteratorType = stl::ArrayIterator; + + explicit DayTimeIntervalArray(const std::shared_ptr& data); + + DayTimeIntervalArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + DayTimeIntervalArray(int64_t length, const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + TypeClass::DayMilliseconds GetValue(int64_t i) const; + TypeClass::DayMilliseconds Value(int64_t i) const { return GetValue(i); } + + // For compatibility with Take kernel. + TypeClass::DayMilliseconds GetView(int64_t i) const { return GetValue(i); } + + IteratorType begin() const { return IteratorType(*this); } + + IteratorType end() const { return IteratorType(*this, length()); } + + std::optional operator[](int64_t i) const { + return *IteratorType(*this, i); + } + + int32_t byte_width() const { return sizeof(TypeClass::DayMilliseconds); } + + const uint8_t* raw_values() const { return raw_values_ + data_->offset * byte_width(); } +}; + +/// \brief Array of Month, Day and nanosecond values. +class ARROW_EXPORT MonthDayNanoIntervalArray : public PrimitiveArray { + public: + using TypeClass = MonthDayNanoIntervalType; + using IteratorType = stl::ArrayIterator; + + explicit MonthDayNanoIntervalArray(const std::shared_ptr& data); + + MonthDayNanoIntervalArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + MonthDayNanoIntervalArray(int64_t length, const std::shared_ptr& data, + const std::shared_ptr& null_bitmap = NULLPTR, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + TypeClass::MonthDayNanos GetValue(int64_t i) const; + TypeClass::MonthDayNanos Value(int64_t i) const { return GetValue(i); } + + // For compatibility with Take kernel. + TypeClass::MonthDayNanos GetView(int64_t i) const { return GetValue(i); } + + IteratorType begin() const { return IteratorType(*this); } + + IteratorType end() const { return IteratorType(*this, length()); } + + std::optional operator[](int64_t i) const { + return *IteratorType(*this, i); + } + + int32_t byte_width() const { return sizeof(TypeClass::MonthDayNanos); } + + const uint8_t* raw_values() const { return raw_values_ + data_->offset * byte_width(); } +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/array_run_end.h b/pyarrow/include/arrow/array/array_run_end.h new file mode 100644 index 0000000000000000000000000000000000000000..b46b0855ab36776eec4e22cef1a35112e2d18fa8 --- /dev/null +++ b/pyarrow/include/arrow/array/array_run_end.h @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Array accessor classes run-end encoded arrays + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/data.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup run-end-encoded-arrays +/// +/// @{ + +// ---------------------------------------------------------------------- +// RunEndEncoded + +/// \brief Array type for run-end encoded data +class ARROW_EXPORT RunEndEncodedArray : public Array { + private: + std::shared_ptr run_ends_array_; + std::shared_ptr values_array_; + + public: + using TypeClass = RunEndEncodedType; + + explicit RunEndEncodedArray(const std::shared_ptr& data); + + /// \brief Construct a RunEndEncodedArray from all parameters + /// + /// The length and offset parameters refer to the dimensions of the logical + /// array which is the array we would get after expanding all the runs into + /// repeated values. As such, length can be much greater than the length of + /// the child run_ends and values arrays. + RunEndEncodedArray(const std::shared_ptr& type, int64_t length, + const std::shared_ptr& run_ends, + const std::shared_ptr& values, int64_t offset = 0); + + /// \brief Construct a RunEndEncodedArray from all parameters + /// + /// The length and offset parameters refer to the dimensions of the logical + /// array which is the array we would get after expanding all the runs into + /// repeated values. As such, length can be much greater than the length of + /// the child run_ends and values arrays. + static Result> Make( + const std::shared_ptr& type, int64_t logical_length, + const std::shared_ptr& run_ends, const std::shared_ptr& values, + int64_t logical_offset = 0); + + /// \brief Construct a RunEndEncodedArray from values and run ends arrays + /// + /// The data type is automatically inferred from the arguments. + /// The run_ends and values arrays must have the same length. + static Result> Make( + int64_t logical_length, const std::shared_ptr& run_ends, + const std::shared_ptr& values, int64_t logical_offset = 0); + + protected: + void SetData(const std::shared_ptr& data); + + public: + /// \brief Returns an array holding the logical indexes of each run-end + /// + /// The physical offset to the array is applied. + const std::shared_ptr& run_ends() const { return run_ends_array_; } + + /// \brief Returns an array holding the values of each run + /// + /// The physical offset to the array is applied. + const std::shared_ptr& values() const { return values_array_; } + + /// \brief Returns an array holding the logical indexes of each run end + /// + /// If a non-zero logical offset is set, this function allocates a new + /// array and rewrites all the run end values to be relative to the logical + /// offset and cuts the end of the array to the logical length. + Result> LogicalRunEnds(MemoryPool* pool) const; + + /// \brief Returns an array holding the values of each run + /// + /// If a non-zero logical offset is set, this function allocates a new + /// array containing only the values within the logical range. + std::shared_ptr LogicalValues() const; + + /// \brief Find the physical offset of this REE array + /// + /// This function uses binary-search, so it has a O(log N) cost. + int64_t FindPhysicalOffset() const; + + /// \brief Find the physical length of this REE array + /// + /// The physical length of an REE is the number of physical values (and + /// run-ends) necessary to represent the logical range of values from offset + /// to length. + /// + /// Avoid calling this function if the physical length can be established in + /// some other way (e.g. when iterating over the runs sequentially until the + /// end). This function uses binary-search, so it has a O(log N) cost. + int64_t FindPhysicalLength() const; +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_adaptive.h b/pyarrow/include/arrow/array/builder_adaptive.h new file mode 100644 index 0000000000000000000000000000000000000000..0cea571be3e3244741f3df15f87c8958eedddf76 --- /dev/null +++ b/pyarrow/include/arrow/array/builder_adaptive.h @@ -0,0 +1,215 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/array/builder_base.h" +#include "arrow/buffer.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup numeric-builders +/// +/// @{ + +namespace internal { + +class ARROW_EXPORT AdaptiveIntBuilderBase : public ArrayBuilder { + public: + AdaptiveIntBuilderBase(uint8_t start_int_size, MemoryPool* pool, + int64_t alignment = kDefaultBufferAlignment); + + explicit AdaptiveIntBuilderBase(MemoryPool* pool, + int64_t alignment = kDefaultBufferAlignment) + : AdaptiveIntBuilderBase(sizeof(uint8_t), pool, alignment) {} + + /// \brief Append multiple nulls + /// \param[in] length the number of nulls to append + Status AppendNulls(int64_t length) final { + ARROW_RETURN_NOT_OK(CommitPendingData()); + if (ARROW_PREDICT_TRUE(length > 0)) { + ARROW_RETURN_NOT_OK(Reserve(length)); + memset(data_->mutable_data() + length_ * int_size_, 0, int_size_ * length); + UnsafeSetNull(length); + } + return Status::OK(); + } + + Status AppendNull() final { + pending_data_[pending_pos_] = 0; + pending_valid_[pending_pos_] = 0; + pending_has_nulls_ = true; + ++pending_pos_; + ++length_; + ++null_count_; + + if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) { + return CommitPendingData(); + } + return Status::OK(); + } + + Status AppendEmptyValues(int64_t length) final { + ARROW_RETURN_NOT_OK(CommitPendingData()); + if (ARROW_PREDICT_TRUE(length > 0)) { + ARROW_RETURN_NOT_OK(Reserve(length)); + memset(data_->mutable_data() + length_ * int_size_, 0, int_size_ * length); + UnsafeSetNotNull(length); + } + return Status::OK(); + } + + Status AppendEmptyValue() final { + pending_data_[pending_pos_] = 0; + pending_valid_[pending_pos_] = 1; + ++pending_pos_; + ++length_; + + if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) { + return CommitPendingData(); + } + return Status::OK(); + } + + void Reset() override; + Status Resize(int64_t capacity) override; + + protected: + Status AppendInternal(const uint64_t val) { + pending_data_[pending_pos_] = val; + pending_valid_[pending_pos_] = 1; + ++pending_pos_; + ++length_; + + if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) { + return CommitPendingData(); + } + return Status::OK(); + } + + virtual Status CommitPendingData() = 0; + + template + typename std::enable_if= sizeof(new_type), Status>::type + ExpandIntSizeInternal(); + template + typename std::enable_if<(sizeof(old_type) < sizeof(new_type)), Status>::type + ExpandIntSizeInternal(); + + std::shared_ptr data_; + uint8_t* raw_data_ = NULLPTR; + + const uint8_t start_int_size_; + uint8_t int_size_; + + static constexpr int32_t pending_size_ = 1024; + uint8_t pending_valid_[pending_size_]; + uint64_t pending_data_[pending_size_]; + int32_t pending_pos_ = 0; + bool pending_has_nulls_ = false; +}; + +} // namespace internal + +class ARROW_EXPORT AdaptiveUIntBuilder : public internal::AdaptiveIntBuilderBase { + public: + explicit AdaptiveUIntBuilder(uint8_t start_int_size, + MemoryPool* pool = default_memory_pool()); + + explicit AdaptiveUIntBuilder(MemoryPool* pool = default_memory_pool()) + : AdaptiveUIntBuilder(sizeof(uint8_t), pool) {} + + using internal::AdaptiveIntBuilderBase::Reset; + + /// Scalar append + Status Append(const uint64_t val) { return AppendInternal(val); } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous C array of values + /// \param[in] length the number of values to append + /// \param[in] valid_bytes an optional sequence of bytes where non-zero + /// indicates a valid (non-null) value + /// \return Status + Status AppendValues(const uint64_t* values, int64_t length, + const uint8_t* valid_bytes = NULLPTR); + + Status FinishInternal(std::shared_ptr* out) override; + + std::shared_ptr type() const override; + + protected: + Status CommitPendingData() override; + Status ExpandIntSize(uint8_t new_int_size); + + Status AppendValuesInternal(const uint64_t* values, int64_t length, + const uint8_t* valid_bytes); + + template + Status ExpandIntSizeN(); +}; + +class ARROW_EXPORT AdaptiveIntBuilder : public internal::AdaptiveIntBuilderBase { + public: + explicit AdaptiveIntBuilder(uint8_t start_int_size, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + explicit AdaptiveIntBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : AdaptiveIntBuilder(sizeof(uint8_t), pool, alignment) {} + + using internal::AdaptiveIntBuilderBase::Reset; + + /// Scalar append + Status Append(const int64_t val) { return AppendInternal(static_cast(val)); } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous C array of values + /// \param[in] length the number of values to append + /// \param[in] valid_bytes an optional sequence of bytes where non-zero + /// indicates a valid (non-null) value + /// \return Status + Status AppendValues(const int64_t* values, int64_t length, + const uint8_t* valid_bytes = NULLPTR); + + Status FinishInternal(std::shared_ptr* out) override; + + std::shared_ptr type() const override; + + protected: + Status CommitPendingData() override; + Status ExpandIntSize(uint8_t new_int_size); + + Status AppendValuesInternal(const int64_t* values, int64_t length, + const uint8_t* valid_bytes); + + template + Status ExpandIntSizeN(); +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_base.h b/pyarrow/include/arrow/array/builder_base.h new file mode 100644 index 0000000000000000000000000000000000000000..ecd2136f5d20ba126bd359977ea17f76c4fe23ed --- /dev/null +++ b/pyarrow/include/arrow/array/builder_base.h @@ -0,0 +1,371 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include // IWYU pragma: keep +#include +#include +#include +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/array_primitive.h" +#include "arrow/buffer.h" +#include "arrow/buffer_builder.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +namespace internal { + +template +class ArrayBuilderExtraOps { + public: + /// \brief Append a value from an optional or null if it has no value. + Status AppendOrNull(const std::optional& value) { + auto* self = static_cast(this); + return value.has_value() ? self->Append(*value) : self->AppendNull(); + } + + /// \brief Append a value from an optional or null if it has no value. + /// + /// Unsafe methods don't check existing size. + void UnsafeAppendOrNull(const std::optional& value) { + auto* self = static_cast(this); + return value.has_value() ? self->UnsafeAppend(*value) : self->UnsafeAppendNull(); + } +}; + +} // namespace internal + +/// \defgroup numeric-builders Concrete builder subclasses for numeric types +/// @{ +/// @} + +/// \defgroup temporal-builders Concrete builder subclasses for temporal types +/// @{ +/// @} + +/// \defgroup binary-builders Concrete builder subclasses for binary types +/// @{ +/// @} + +/// \defgroup nested-builders Concrete builder subclasses for nested types +/// @{ +/// @} + +/// \defgroup dictionary-builders Concrete builder subclasses for dictionary types +/// @{ +/// @} + +/// \defgroup run-end-encoded-builders Concrete builder subclasses for run-end encoded +/// arrays +/// @{ +/// @} + +constexpr int64_t kMinBuilderCapacity = 1 << 5; +constexpr int64_t kListMaximumElements = std::numeric_limits::max() - 1; + +/// Base class for all data array builders. +/// +/// This class provides a facilities for incrementally building the null bitmap +/// (see Append methods) and as a side effect the current number of slots and +/// the null count. +/// +/// \note Users are expected to use builders as one of the concrete types below. +/// For example, ArrayBuilder* pointing to BinaryBuilder should be downcast before use. +class ARROW_EXPORT ArrayBuilder { + public: + explicit ArrayBuilder(MemoryPool* pool, int64_t alignment = kDefaultBufferAlignment) + : pool_(pool), alignment_(alignment), null_bitmap_builder_(pool, alignment) {} + + ARROW_DEFAULT_MOVE_AND_ASSIGN(ArrayBuilder); + + virtual ~ArrayBuilder() = default; + + /// For nested types. Since the objects are owned by this class instance, we + /// skip shared pointers and just return a raw pointer + ArrayBuilder* child(int i) { return children_[i].get(); } + + const std::shared_ptr& child_builder(int i) const { return children_[i]; } + + int num_children() const { return static_cast(children_.size()); } + + virtual int64_t length() const { return length_; } + int64_t null_count() const { return null_count_; } + int64_t capacity() const { return capacity_; } + + /// \brief Ensure that enough memory has been allocated to fit the indicated + /// number of total elements in the builder, including any that have already + /// been appended. Does not account for reallocations that may be due to + /// variable size data, like binary values. To make space for incremental + /// appends, use Reserve instead. + /// + /// \param[in] capacity the minimum number of total array values to + /// accommodate. Must be greater than the current capacity. + /// \return Status + virtual Status Resize(int64_t capacity); + + /// \brief Ensure that there is enough space allocated to append the indicated + /// number of elements without any further reallocation. Overallocation is + /// used in order to minimize the impact of incremental Reserve() calls. + /// Note that additional_capacity is relative to the current number of elements + /// rather than to the current capacity, so calls to Reserve() which are not + /// interspersed with addition of new elements may not increase the capacity. + /// + /// \param[in] additional_capacity the number of additional array values + /// \return Status + Status Reserve(int64_t additional_capacity) { + auto current_capacity = capacity(); + auto min_capacity = length() + additional_capacity; + if (min_capacity <= current_capacity) return Status::OK(); + + // leave growth factor up to BufferBuilder + auto new_capacity = BufferBuilder::GrowByFactor(current_capacity, min_capacity); + return Resize(new_capacity); + } + + /// Reset the builder. + virtual void Reset(); + + /// \brief Append a null value to builder + virtual Status AppendNull() = 0; + /// \brief Append a number of null values to builder + virtual Status AppendNulls(int64_t length) = 0; + + /// \brief Append a non-null value to builder + /// + /// The appended value is an implementation detail, but the corresponding + /// memory slot is guaranteed to be initialized. + /// This method is useful when appending a null value to a parent nested type. + virtual Status AppendEmptyValue() = 0; + + /// \brief Append a number of non-null values to builder + /// + /// The appended values are an implementation detail, but the corresponding + /// memory slot is guaranteed to be initialized. + /// This method is useful when appending null values to a parent nested type. + virtual Status AppendEmptyValues(int64_t length) = 0; + + /// \brief Append a value from a scalar + Status AppendScalar(const Scalar& scalar) { return AppendScalar(scalar, 1); } + virtual Status AppendScalar(const Scalar& scalar, int64_t n_repeats); + virtual Status AppendScalars(const ScalarVector& scalars); + + /// \brief Append a range of values from an array. + /// + /// The given array must be the same type as the builder. + virtual Status AppendArraySlice(const ArraySpan& ARROW_ARG_UNUSED(array), + int64_t ARROW_ARG_UNUSED(offset), + int64_t ARROW_ARG_UNUSED(length)) { + return Status::NotImplemented("AppendArraySlice for builder for ", *type()); + } + + /// \brief Return result of builder as an internal generic ArrayData + /// object. Resets builder except for dictionary builder + /// + /// \param[out] out the finalized ArrayData object + /// \return Status + virtual Status FinishInternal(std::shared_ptr* out) = 0; + + /// \brief Return result of builder as an Array object. + /// + /// The builder is reset except for DictionaryBuilder. + /// + /// \param[out] out the finalized Array object + /// \return Status + Status Finish(std::shared_ptr* out); + + /// \brief Return result of builder as an Array object. + /// + /// The builder is reset except for DictionaryBuilder. + /// + /// \return The finalized Array object + Result> Finish(); + + /// \brief Return the type of the built Array + virtual std::shared_ptr type() const = 0; + + protected: + /// Append to null bitmap + Status AppendToBitmap(bool is_valid); + + /// Vector append. Treat each zero byte as a null. If valid_bytes is null + /// assume all of length bits are valid. + Status AppendToBitmap(const uint8_t* valid_bytes, int64_t length); + + /// Uniform append. Append N times the same validity bit. + Status AppendToBitmap(int64_t num_bits, bool value); + + /// Set the next length bits to not null (i.e. valid). + Status SetNotNull(int64_t length); + + // Unsafe operations (don't check capacity/don't resize) + + void UnsafeAppendNull() { UnsafeAppendToBitmap(false); } + + // Append to null bitmap, update the length + void UnsafeAppendToBitmap(bool is_valid) { + null_bitmap_builder_.UnsafeAppend(is_valid); + ++length_; + if (!is_valid) ++null_count_; + } + + // Vector append. Treat each zero byte as a nullzero. If valid_bytes is null + // assume all of length bits are valid. + void UnsafeAppendToBitmap(const uint8_t* valid_bytes, int64_t length) { + if (valid_bytes == NULLPTR) { + return UnsafeSetNotNull(length); + } + null_bitmap_builder_.UnsafeAppend(valid_bytes, length); + length_ += length; + null_count_ = null_bitmap_builder_.false_count(); + } + + // Vector append. Copy from a given bitmap. If bitmap is null assume + // all of length bits are valid. + void UnsafeAppendToBitmap(const uint8_t* bitmap, int64_t offset, int64_t length) { + if (bitmap == NULLPTR) { + return UnsafeSetNotNull(length); + } + null_bitmap_builder_.UnsafeAppend(bitmap, offset, length); + length_ += length; + null_count_ = null_bitmap_builder_.false_count(); + } + + // Append the same validity value a given number of times. + void UnsafeAppendToBitmap(const int64_t num_bits, bool value) { + if (value) { + UnsafeSetNotNull(num_bits); + } else { + UnsafeSetNull(num_bits); + } + } + + void UnsafeAppendToBitmap(const std::vector& is_valid); + + // Set the next validity bits to not null (i.e. valid). + void UnsafeSetNotNull(int64_t length); + + // Set the next validity bits to null (i.e. invalid). + void UnsafeSetNull(int64_t length); + + static Status TrimBuffer(const int64_t bytes_filled, ResizableBuffer* buffer); + + /// \brief Finish to an array of the specified ArrayType + template + Status FinishTyped(std::shared_ptr* out) { + std::shared_ptr out_untyped; + ARROW_RETURN_NOT_OK(Finish(&out_untyped)); + *out = std::static_pointer_cast(std::move(out_untyped)); + return Status::OK(); + } + + // Check the requested capacity for validity + Status CheckCapacity(int64_t new_capacity) { + if (ARROW_PREDICT_FALSE(new_capacity < 0)) { + return Status::Invalid( + "Resize capacity must be positive (requested: ", new_capacity, ")"); + } + + if (ARROW_PREDICT_FALSE(new_capacity < length_)) { + return Status::Invalid("Resize cannot downsize (requested: ", new_capacity, + ", current length: ", length_, ")"); + } + + return Status::OK(); + } + + // Check for array type + Status CheckArrayType(const std::shared_ptr& expected_type, + const Array& array, const char* message); + Status CheckArrayType(Type::type expected_type, const Array& array, + const char* message); + + MemoryPool* pool_; + int64_t alignment_; + + TypedBufferBuilder null_bitmap_builder_; + int64_t null_count_ = 0; + + // Array length, so far. Also, the index of the next element to be added + int64_t length_ = 0; + int64_t capacity_ = 0; + + // Child value array builders. These are owned by this class + std::vector> children_; + + private: + ARROW_DISALLOW_COPY_AND_ASSIGN(ArrayBuilder); +}; + +/// \brief Construct an empty ArrayBuilder corresponding to the data +/// type +/// \param[in] pool the MemoryPool to use for allocations +/// \param[in] type the data type to create the builder for +/// \param[out] out the created ArrayBuilder +ARROW_EXPORT +Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out); + +inline Result> MakeBuilder( + const std::shared_ptr& type, MemoryPool* pool = default_memory_pool()) { + std::unique_ptr out; + ARROW_RETURN_NOT_OK(MakeBuilder(pool, type, &out)); + return out; +} + +/// \brief Construct an empty ArrayBuilder corresponding to the data +/// type, where any top-level or nested dictionary builders return the +/// exact index type specified by the type. +ARROW_EXPORT +Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out); + +inline Result> MakeBuilderExactIndex( + const std::shared_ptr& type, MemoryPool* pool = default_memory_pool()) { + std::unique_ptr out; + ARROW_RETURN_NOT_OK(MakeBuilderExactIndex(pool, type, &out)); + return out; +} + +/// \brief Construct an empty DictionaryBuilder initialized optionally +/// with a preexisting dictionary +/// \param[in] pool the MemoryPool to use for allocations +/// \param[in] type the dictionary type to create the builder for +/// \param[in] dictionary the initial dictionary, if any. May be nullptr +/// \param[out] out the created ArrayBuilder +ARROW_EXPORT +Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr& type, + const std::shared_ptr& dictionary, + std::unique_ptr* out); + +inline Result> MakeDictionaryBuilder( + const std::shared_ptr& type, const std::shared_ptr& dictionary, + MemoryPool* pool = default_memory_pool()) { + std::unique_ptr out; + ARROW_RETURN_NOT_OK(MakeDictionaryBuilder(pool, type, dictionary, &out)); + return out; +} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_binary.h b/pyarrow/include/arrow/array/builder_binary.h new file mode 100644 index 0000000000000000000000000000000000000000..d0e761ae9684132240f21ee335a996bdda081a63 --- /dev/null +++ b/pyarrow/include/arrow/array/builder_binary.h @@ -0,0 +1,993 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/array_binary.h" +#include "arrow/array/builder_base.h" +#include "arrow/array/data.h" +#include "arrow/buffer.h" +#include "arrow/buffer_builder.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/binary_view_util.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup binary-builders +/// +/// @{ + +// ---------------------------------------------------------------------- +// Binary and String + +template +class BaseBinaryBuilder + : public ArrayBuilder, + public internal::ArrayBuilderExtraOps, std::string_view> { + public: + using TypeClass = TYPE; + using offset_type = typename TypeClass::offset_type; + + explicit BaseBinaryBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + offsets_builder_(pool, alignment), + value_data_builder_(pool, alignment) {} + + BaseBinaryBuilder(const std::shared_ptr& type, MemoryPool* pool) + : BaseBinaryBuilder(pool) {} + + Status Append(const uint8_t* value, offset_type length) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppendNextOffset(); + // Safety check for UBSAN. + if (ARROW_PREDICT_TRUE(length > 0)) { + ARROW_RETURN_NOT_OK(ValidateOverflow(length)); + ARROW_RETURN_NOT_OK(value_data_builder_.Append(value, length)); + } + + UnsafeAppendToBitmap(true); + return Status::OK(); + } + + Status Append(const char* value, offset_type length) { + return Append(reinterpret_cast(value), length); + } + + Status Append(std::string_view value) { + return Append(value.data(), static_cast(value.size())); + } + + /// Extend the last appended value by appending more data at the end + /// + /// Unlike Append, this does not create a new offset. + Status ExtendCurrent(const uint8_t* value, offset_type length) { + // Safety check for UBSAN. + if (ARROW_PREDICT_TRUE(length > 0)) { + ARROW_RETURN_NOT_OK(ValidateOverflow(length)); + ARROW_RETURN_NOT_OK(value_data_builder_.Append(value, length)); + } + return Status::OK(); + } + + Status ExtendCurrent(std::string_view value) { + return ExtendCurrent(reinterpret_cast(value.data()), + static_cast(value.size())); + } + + Status AppendNulls(int64_t length) final { + const int64_t num_bytes = value_data_builder_.length(); + ARROW_RETURN_NOT_OK(Reserve(length)); + for (int64_t i = 0; i < length; ++i) { + offsets_builder_.UnsafeAppend(static_cast(num_bytes)); + } + UnsafeAppendToBitmap(length, false); + return Status::OK(); + } + + Status AppendNull() final { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppendNextOffset(); + UnsafeAppendToBitmap(false); + return Status::OK(); + } + + Status AppendEmptyValue() final { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppendNextOffset(); + UnsafeAppendToBitmap(true); + return Status::OK(); + } + + Status AppendEmptyValues(int64_t length) final { + const int64_t num_bytes = value_data_builder_.length(); + ARROW_RETURN_NOT_OK(Reserve(length)); + for (int64_t i = 0; i < length; ++i) { + offsets_builder_.UnsafeAppend(static_cast(num_bytes)); + } + UnsafeAppendToBitmap(length, true); + return Status::OK(); + } + + /// \brief Append without checking capacity + /// + /// Offsets and data should have been presized using Reserve() and + /// ReserveData(), respectively. + void UnsafeAppend(const uint8_t* value, offset_type length) { + UnsafeAppendNextOffset(); + value_data_builder_.UnsafeAppend(value, length); + UnsafeAppendToBitmap(true); + } + + void UnsafeAppend(const char* value, offset_type length) { + UnsafeAppend(reinterpret_cast(value), length); + } + + void UnsafeAppend(const std::string& value) { + UnsafeAppend(value.c_str(), static_cast(value.size())); + } + + void UnsafeAppend(std::string_view value) { + UnsafeAppend(value.data(), static_cast(value.size())); + } + + /// Like ExtendCurrent, but do not check capacity + void UnsafeExtendCurrent(const uint8_t* value, offset_type length) { + value_data_builder_.UnsafeAppend(value, length); + } + + void UnsafeExtendCurrent(std::string_view value) { + UnsafeExtendCurrent(reinterpret_cast(value.data()), + static_cast(value.size())); + } + + void UnsafeAppendNull() { + const int64_t num_bytes = value_data_builder_.length(); + offsets_builder_.UnsafeAppend(static_cast(num_bytes)); + UnsafeAppendToBitmap(false); + } + + void UnsafeAppendEmptyValue() { + const int64_t num_bytes = value_data_builder_.length(); + offsets_builder_.UnsafeAppend(static_cast(num_bytes)); + UnsafeAppendToBitmap(true); + } + + /// \brief Append a sequence of strings in one shot. + /// + /// \param[in] values a vector of strings + /// \param[in] valid_bytes an optional sequence of bytes where non-zero + /// indicates a valid (non-null) value + /// \return Status + Status AppendValues(const std::vector& values, + const uint8_t* valid_bytes = NULLPTR) { + std::size_t total_length = std::accumulate( + values.begin(), values.end(), 0ULL, + [](uint64_t sum, const std::string& str) { return sum + str.size(); }); + ARROW_RETURN_NOT_OK(Reserve(values.size())); + ARROW_RETURN_NOT_OK(ReserveData(total_length)); + + if (valid_bytes != NULLPTR) { + for (std::size_t i = 0; i < values.size(); ++i) { + UnsafeAppendNextOffset(); + if (valid_bytes[i]) { + value_data_builder_.UnsafeAppend( + reinterpret_cast(values[i].data()), values[i].size()); + } + } + } else { + for (const auto& value : values) { + UnsafeAppendNextOffset(); + value_data_builder_.UnsafeAppend(reinterpret_cast(value.data()), + value.size()); + } + } + + UnsafeAppendToBitmap(valid_bytes, values.size()); + return Status::OK(); + } + + /// \brief Append a sequence of nul-terminated strings in one shot. + /// If one of the values is NULL, it is processed as a null + /// value even if the corresponding valid_bytes entry is 1. + /// + /// \param[in] values a contiguous C array of nul-terminated char * + /// \param[in] length the number of values to append + /// \param[in] valid_bytes an optional sequence of bytes where non-zero + /// indicates a valid (non-null) value + /// \return Status + Status AppendValues(const char** values, int64_t length, + const uint8_t* valid_bytes = NULLPTR) { + std::size_t total_length = 0; + std::vector value_lengths(length); + bool have_null_value = false; + for (int64_t i = 0; i < length; ++i) { + if (values[i] != NULLPTR) { + auto value_length = strlen(values[i]); + value_lengths[i] = value_length; + total_length += value_length; + } else { + have_null_value = true; + } + } + ARROW_RETURN_NOT_OK(Reserve(length)); + ARROW_RETURN_NOT_OK(ReserveData(total_length)); + + if (valid_bytes) { + int64_t valid_bytes_offset = 0; + for (int64_t i = 0; i < length; ++i) { + UnsafeAppendNextOffset(); + if (valid_bytes[i]) { + if (values[i]) { + value_data_builder_.UnsafeAppend(reinterpret_cast(values[i]), + value_lengths[i]); + } else { + UnsafeAppendToBitmap(valid_bytes + valid_bytes_offset, + i - valid_bytes_offset); + UnsafeAppendToBitmap(false); + valid_bytes_offset = i + 1; + } + } + } + UnsafeAppendToBitmap(valid_bytes + valid_bytes_offset, length - valid_bytes_offset); + } else { + if (have_null_value) { + std::vector valid_vector(length, 0); + for (int64_t i = 0; i < length; ++i) { + UnsafeAppendNextOffset(); + if (values[i]) { + value_data_builder_.UnsafeAppend(reinterpret_cast(values[i]), + value_lengths[i]); + valid_vector[i] = 1; + } + } + UnsafeAppendToBitmap(valid_vector.data(), length); + } else { + for (int64_t i = 0; i < length; ++i) { + UnsafeAppendNextOffset(); + value_data_builder_.UnsafeAppend(reinterpret_cast(values[i]), + value_lengths[i]); + } + UnsafeAppendToBitmap(NULLPTR, length); + } + } + return Status::OK(); + } + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override { + auto bitmap = array.GetValues(0, 0); + auto offsets = array.GetValues(1); + auto data = array.GetValues(2, 0); + auto total_length = offsets[offset + length] - offsets[offset]; + ARROW_RETURN_NOT_OK(Reserve(length)); + ARROW_RETURN_NOT_OK(ReserveData(total_length)); + for (int64_t i = 0; i < length; i++) { + if (!bitmap || bit_util::GetBit(bitmap, array.offset + offset + i)) { + const offset_type start = offsets[offset + i]; + const offset_type end = offsets[offset + i + 1]; + UnsafeAppend(data + start, end - start); + } else { + UnsafeAppendNull(); + } + } + return Status::OK(); + } + + void Reset() override { + ArrayBuilder::Reset(); + offsets_builder_.Reset(); + value_data_builder_.Reset(); + } + + Status ValidateOverflow(int64_t new_bytes) { + auto new_size = value_data_builder_.length() + new_bytes; + if (ARROW_PREDICT_FALSE(new_size > memory_limit())) { + return Status::CapacityError("array cannot contain more than ", memory_limit(), + " bytes, have ", new_size); + } else { + return Status::OK(); + } + } + + Status Resize(int64_t capacity) override { + ARROW_RETURN_NOT_OK(CheckCapacity(capacity)); + // One more than requested for offsets + ARROW_RETURN_NOT_OK(offsets_builder_.Resize(capacity + 1)); + return ArrayBuilder::Resize(capacity); + } + + /// \brief Ensures there is enough allocated capacity to append the indicated + /// number of bytes to the value data buffer without additional allocations + Status ReserveData(int64_t elements) { + ARROW_RETURN_NOT_OK(ValidateOverflow(elements)); + return value_data_builder_.Reserve(elements); + } + + Status FinishInternal(std::shared_ptr* out) override { + // Write final offset (values length) + ARROW_RETURN_NOT_OK(AppendNextOffset()); + + // These buffers' padding zeroed by BufferBuilder + std::shared_ptr offsets, value_data, null_bitmap; + ARROW_RETURN_NOT_OK(offsets_builder_.Finish(&offsets)); + ARROW_RETURN_NOT_OK(value_data_builder_.Finish(&value_data)); + ARROW_RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap)); + + *out = ArrayData::Make(type(), length_, {null_bitmap, offsets, value_data}, + null_count_, 0); + Reset(); + return Status::OK(); + } + + /// \return data pointer of the value date builder + const uint8_t* value_data() const { return value_data_builder_.data(); } + /// \return size of values buffer so far + int64_t value_data_length() const { return value_data_builder_.length(); } + /// \return capacity of values buffer + int64_t value_data_capacity() const { return value_data_builder_.capacity(); } + + /// \return data pointer of the value date builder + const offset_type* offsets_data() const { return offsets_builder_.data(); } + + /// Temporary access to a value. + /// + /// This pointer becomes invalid on the next modifying operation. + const uint8_t* GetValue(int64_t i, offset_type* out_length) const { + const offset_type* offsets = offsets_builder_.data(); + const auto offset = offsets[i]; + if (i == (length_ - 1)) { + *out_length = static_cast(value_data_builder_.length()) - offset; + } else { + *out_length = offsets[i + 1] - offset; + } + return value_data_builder_.data() + offset; + } + + offset_type offset(int64_t i) const { return offsets_data()[i]; } + + /// Temporary access to a value. + /// + /// This view becomes invalid on the next modifying operation. + std::string_view GetView(int64_t i) const { + offset_type value_length; + const uint8_t* value_data = GetValue(i, &value_length); + return std::string_view(reinterpret_cast(value_data), value_length); + } + + // Cannot make this a static attribute because of linking issues + static constexpr int64_t memory_limit() { + return std::numeric_limits::max() - 1; + } + + protected: + TypedBufferBuilder offsets_builder_; + TypedBufferBuilder value_data_builder_; + + Status AppendNextOffset() { + const int64_t num_bytes = value_data_builder_.length(); + return offsets_builder_.Append(static_cast(num_bytes)); + } + + void UnsafeAppendNextOffset() { + const int64_t num_bytes = value_data_builder_.length(); + offsets_builder_.UnsafeAppend(static_cast(num_bytes)); + } +}; + +/// \class BinaryBuilder +/// \brief Builder class for variable-length binary data +class ARROW_EXPORT BinaryBuilder : public BaseBinaryBuilder { + public: + using BaseBinaryBuilder::BaseBinaryBuilder; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return binary(); } +}; + +/// \class StringBuilder +/// \brief Builder class for UTF8 strings +class ARROW_EXPORT StringBuilder : public BinaryBuilder { + public: + using BinaryBuilder::BinaryBuilder; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return utf8(); } +}; + +/// \class LargeBinaryBuilder +/// \brief Builder class for large variable-length binary data +class ARROW_EXPORT LargeBinaryBuilder : public BaseBinaryBuilder { + public: + using BaseBinaryBuilder::BaseBinaryBuilder; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return large_binary(); } +}; + +/// \class LargeStringBuilder +/// \brief Builder class for large UTF8 strings +class ARROW_EXPORT LargeStringBuilder : public LargeBinaryBuilder { + public: + using LargeBinaryBuilder::LargeBinaryBuilder; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return large_utf8(); } +}; + +// ---------------------------------------------------------------------- +// BinaryViewBuilder, StringViewBuilder +// +// These builders do not support building raw pointer view arrays. + +namespace internal { + +// We allocate medium-sized memory chunks and accumulate data in those, which +// may result in some waste if there are many large-ish strings. If a string +// comes along that does not fit into a block, we allocate a new block and +// write into that. +// +// Later we can implement optimizations to continuing filling underfull blocks +// after encountering a large string that required allocating a new block. +class ARROW_EXPORT StringHeapBuilder { + public: + static constexpr int64_t kDefaultBlocksize = 32 << 10; // 32KB + + StringHeapBuilder(MemoryPool* pool, int64_t alignment) + : pool_(pool), alignment_(alignment) {} + + void SetBlockSize(int64_t blocksize) { blocksize_ = blocksize; } + + using c_type = BinaryViewType::c_type; + + template + std::conditional_t, c_type> Append(const uint8_t* value, + int64_t length) { + if (length <= BinaryViewType::kInlineSize) { + return util::ToInlineBinaryView(value, static_cast(length)); + } + + if constexpr (Safe) { + ARROW_RETURN_NOT_OK(Reserve(length)); + } + + auto v = util::ToNonInlineBinaryView(value, static_cast(length), + static_cast(blocks_.size() - 1), + current_offset_); + + memcpy(current_out_buffer_, value, static_cast(length)); + current_out_buffer_ += length; + current_remaining_bytes_ -= length; + current_offset_ += static_cast(length); + return v; + } + + static constexpr int64_t ValueSizeLimit() { + return std::numeric_limits::max(); + } + + /// \brief Ensure that the indicated number of bytes can be appended via + /// UnsafeAppend operations without the need to allocate more memory + Status Reserve(int64_t num_bytes) { + if (ARROW_PREDICT_FALSE(num_bytes > ValueSizeLimit())) { + return Status::CapacityError( + "BinaryView or StringView elements cannot reference " + "strings larger than 2GB"); + } + if (num_bytes > current_remaining_bytes_) { + ARROW_RETURN_NOT_OK(FinishLastBlock()); + current_remaining_bytes_ = num_bytes > blocksize_ ? num_bytes : blocksize_; + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr new_block, + AllocateResizableBuffer(current_remaining_bytes_, alignment_, pool_)); + current_offset_ = 0; + current_out_buffer_ = new_block->mutable_data(); + blocks_.emplace_back(std::move(new_block)); + } + return Status::OK(); + } + + void Reset() { + current_offset_ = 0; + current_out_buffer_ = NULLPTR; + current_remaining_bytes_ = 0; + blocks_.clear(); + } + + int64_t current_remaining_bytes() const { return current_remaining_bytes_; } + + Result>> Finish() { + if (!blocks_.empty()) { + ARROW_RETURN_NOT_OK(FinishLastBlock()); + } + current_offset_ = 0; + current_out_buffer_ = NULLPTR; + current_remaining_bytes_ = 0; + return std::move(blocks_); + } + + private: + Status FinishLastBlock() { + if (current_remaining_bytes_ > 0) { + // Avoid leaking uninitialized bytes from the allocator + ARROW_RETURN_NOT_OK( + blocks_.back()->Resize(blocks_.back()->size() - current_remaining_bytes_, + /*shrink_to_fit=*/true)); + blocks_.back()->ZeroPadding(); + } + return Status::OK(); + } + + MemoryPool* pool_; + int64_t alignment_; + int64_t blocksize_ = kDefaultBlocksize; + std::vector> blocks_; + + int32_t current_offset_ = 0; + uint8_t* current_out_buffer_ = NULLPTR; + int64_t current_remaining_bytes_ = 0; +}; + +} // namespace internal + +class ARROW_EXPORT BinaryViewBuilder : public ArrayBuilder { + public: + using TypeClass = BinaryViewType; + + // this constructor provided for MakeBuilder compatibility + BinaryViewBuilder(const std::shared_ptr&, MemoryPool* pool); + + explicit BinaryViewBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + data_builder_(pool, alignment), + data_heap_builder_(pool, alignment) {} + + /// Set the size for future preallocated data buffers. + /// + /// The default size is 32KB, so after each 32KB of string data appended to the builder + /// a new data buffer will be allocated. Adjust this to a larger value to decrease the + /// frequency of allocation, or to a smaller value to lower the overhead of each + /// allocation. + void SetBlockSize(int64_t blocksize) { data_heap_builder_.SetBlockSize(blocksize); } + + /// The number of bytes which can be appended to this builder without allocating another + /// data buffer. + int64_t current_block_bytes_remaining() const { + return data_heap_builder_.current_remaining_bytes(); + } + + Status Append(const uint8_t* value, int64_t length) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppendToBitmap(true); + ARROW_ASSIGN_OR_RAISE(auto v, + data_heap_builder_.Append(value, length)); + data_builder_.UnsafeAppend(v); + return Status::OK(); + } + + Status Append(const char* value, int64_t length) { + return Append(reinterpret_cast(value), length); + } + + Status Append(std::string_view value) { + return Append(value.data(), static_cast(value.size())); + } + + /// \brief Append without checking capacity + /// + /// Builder should have been presized using Reserve() and ReserveData(), + /// respectively, and the value must not be larger than 2GB + void UnsafeAppend(const uint8_t* value, int64_t length) { + UnsafeAppendToBitmap(true); + auto v = data_heap_builder_.Append(value, length); + data_builder_.UnsafeAppend(v); + } + + void UnsafeAppend(const char* value, int64_t length) { + UnsafeAppend(reinterpret_cast(value), length); + } + + void UnsafeAppend(const std::string& value) { + UnsafeAppend(value.c_str(), static_cast(value.size())); + } + + void UnsafeAppend(std::string_view value) { + UnsafeAppend(value.data(), static_cast(value.size())); + } + + /// \brief Ensures there is enough allocated available capacity in the + /// out-of-line data heap to append the indicated number of bytes without + /// additional allocations + Status ReserveData(int64_t length); + + Status AppendNulls(int64_t length) final { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, BinaryViewType::c_type{}); + UnsafeSetNull(length); + return Status::OK(); + } + + /// \brief Append a single null element + Status AppendNull() final { + ARROW_RETURN_NOT_OK(Reserve(1)); + data_builder_.UnsafeAppend(BinaryViewType::c_type{}); + UnsafeAppendToBitmap(false); + return Status::OK(); + } + + /// \brief Append a empty element (length-0 inline string) + Status AppendEmptyValue() final { + ARROW_RETURN_NOT_OK(Reserve(1)); + data_builder_.UnsafeAppend(BinaryViewType::c_type{}); + UnsafeAppendToBitmap(true); + return Status::OK(); + } + + /// \brief Append several empty elements + Status AppendEmptyValues(int64_t length) final { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, BinaryViewType::c_type{}); + UnsafeSetNotNull(length); + return Status::OK(); + } + + void UnsafeAppendNull() { + data_builder_.UnsafeAppend(BinaryViewType::c_type{}); + UnsafeAppendToBitmap(false); + } + + void UnsafeAppendEmptyValue() { + data_builder_.UnsafeAppend(BinaryViewType::c_type{}); + UnsafeAppendToBitmap(true); + } + + /// \brief Append a slice of a BinaryViewArray passed as an ArraySpan. Copies + /// the underlying out-of-line string memory to avoid memory lifetime issues + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override; + + void Reset() override; + + Status Resize(int64_t capacity) override { + ARROW_RETURN_NOT_OK(CheckCapacity(capacity)); + capacity = std::max(capacity, kMinBuilderCapacity); + ARROW_RETURN_NOT_OK(data_builder_.Resize(capacity)); + return ArrayBuilder::Resize(capacity); + } + + Status FinishInternal(std::shared_ptr* out) override; + + std::shared_ptr type() const override { return binary_view(); } + + protected: + TypedBufferBuilder data_builder_; + + // Accumulates out-of-line data in fixed-size chunks which are then attached + // to the resulting ArrayData + internal::StringHeapBuilder data_heap_builder_; +}; + +class ARROW_EXPORT StringViewBuilder : public BinaryViewBuilder { + public: + using BinaryViewBuilder::BinaryViewBuilder; + std::shared_ptr type() const override { return utf8_view(); } +}; + +// ---------------------------------------------------------------------- +// FixedSizeBinaryBuilder + +class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { + public: + using TypeClass = FixedSizeBinaryType; + + explicit FixedSizeBinaryBuilder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + Status Append(const uint8_t* value) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(value); + return Status::OK(); + } + + Status Append(const char* value) { + return Append(reinterpret_cast(value)); + } + + Status Append(std::string_view view) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(view); + return Status::OK(); + } + + Status Append(const std::string& s) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(s); + return Status::OK(); + } + + Status Append(const Buffer& s) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(s); + return Status::OK(); + } + + Status Append(const std::shared_ptr& s) { return Append(*s); } + + template + Status Append(const std::array& value) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend( + std::string_view(reinterpret_cast(value.data()), value.size())); + return Status::OK(); + } + + Status AppendValues(const uint8_t* data, int64_t length, + const uint8_t* valid_bytes = NULLPTR); + + Status AppendValues(const uint8_t* data, int64_t length, const uint8_t* validity, + int64_t bitmap_offset); + + Status AppendNull() final; + Status AppendNulls(int64_t length) final; + + Status AppendEmptyValue() final; + Status AppendEmptyValues(int64_t length) final; + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override { + return AppendValues( + array.GetValues(1, 0) + ((array.offset + offset) * byte_width_), length, + array.GetValues(0, 0), array.offset + offset); + } + + void UnsafeAppend(const uint8_t* value) { + UnsafeAppendToBitmap(true); + if (ARROW_PREDICT_TRUE(byte_width_ > 0)) { + byte_builder_.UnsafeAppend(value, byte_width_); + } + } + + void UnsafeAppend(const char* value) { + UnsafeAppend(reinterpret_cast(value)); + } + + void UnsafeAppend(std::string_view value) { +#ifndef NDEBUG + CheckValueSize(static_cast(value.size())); +#endif + UnsafeAppend(reinterpret_cast(value.data())); + } + + void UnsafeAppend(const Buffer& s) { UnsafeAppend(std::string_view{s}); } + + void UnsafeAppend(const std::shared_ptr& s) { UnsafeAppend(*s); } + + void UnsafeAppendNull() { + UnsafeAppendToBitmap(false); + byte_builder_.UnsafeAppend(/*num_copies=*/byte_width_, 0); + } + + Status ValidateOverflow(int64_t new_bytes) const { + auto new_size = byte_builder_.length() + new_bytes; + if (ARROW_PREDICT_FALSE(new_size > memory_limit())) { + return Status::CapacityError("array cannot contain more than ", memory_limit(), + " bytes, have ", new_size); + } else { + return Status::OK(); + } + } + + /// \brief Ensures there is enough allocated capacity to append the indicated + /// number of bytes to the value data buffer without additional allocations + Status ReserveData(int64_t elements) { + ARROW_RETURN_NOT_OK(ValidateOverflow(elements)); + return byte_builder_.Reserve(elements); + } + + void Reset() override; + Status Resize(int64_t capacity) override; + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + /// \return size of values buffer so far + int64_t value_data_length() const { return byte_builder_.length(); } + + int32_t byte_width() const { return byte_width_; } + + /// Temporary access to a value. + /// + /// This pointer becomes invalid on the next modifying operation. + const uint8_t* GetValue(int64_t i) const; + + /// Temporary mutable access to a value. + /// + /// This pointer becomes invalid on the next modifying operation. + uint8_t* GetMutableValue(int64_t i) { + uint8_t* data_ptr = byte_builder_.mutable_data(); + return data_ptr + i * byte_width_; + } + + /// Temporary mutable access to a value. + /// + /// This view becomes invalid on the next modifying operation. + std::string_view GetView(int64_t i) const; + + /// Advance builder without allocating nor writing any values + /// + /// The internal pointer is advanced by `length` values and the same number + /// of non-null entries are appended to the validity bitmap. + /// This method assumes that the `length` values were populated directly, + /// for example using `GetMutableValue`. + void UnsafeAdvance(int64_t length) { + byte_builder_.UnsafeAdvance(length * byte_width_); + UnsafeAppendToBitmap(length, true); + } + + /// Advance builder without allocating nor writing any values + /// + /// The internal pointer is advanced by `length` values and the same number + /// of validity bits are appended to the validity bitmap. + /// This method assumes that the `length` values were populated directly, + /// for example using `GetMutableValue`. + void UnsafeAdvance(int64_t length, const uint8_t* validity, int64_t valid_bits_offset) { + byte_builder_.UnsafeAdvance(length * byte_width_); + UnsafeAppendToBitmap(validity, valid_bits_offset, length); + } + + static constexpr int64_t memory_limit() { + return std::numeric_limits::max() - 1; + } + + std::shared_ptr type() const override { + return fixed_size_binary(byte_width_); + } + + protected: + int32_t byte_width_; + BufferBuilder byte_builder_; + + void CheckValueSize(int64_t size); +}; + +/// @} + +// ---------------------------------------------------------------------- +// Chunked builders: build a sequence of BinaryArray or StringArray that are +// limited to a particular size (to the upper limit of 2GB) + +namespace internal { + +class ARROW_EXPORT ChunkedBinaryBuilder { + public: + explicit ChunkedBinaryBuilder(int32_t max_chunk_value_length, + MemoryPool* pool = default_memory_pool()); + + ChunkedBinaryBuilder(int32_t max_chunk_value_length, int32_t max_chunk_length, + MemoryPool* pool = default_memory_pool()); + + virtual ~ChunkedBinaryBuilder() = default; + + Status Append(const uint8_t* value, int32_t length) { + if (ARROW_PREDICT_FALSE(length + builder_->value_data_length() > + max_chunk_value_length_)) { + if (builder_->value_data_length() == 0) { + // The current item is larger than max_chunk_size_; + // this chunk will be oversize and hold *only* this item + ARROW_RETURN_NOT_OK(builder_->Append(value, length)); + return NextChunk(); + } + // The current item would cause builder_->value_data_length() to exceed + // max_chunk_size_, so finish this chunk and append the current item to the next + // chunk + ARROW_RETURN_NOT_OK(NextChunk()); + return Append(value, length); + } + + if (ARROW_PREDICT_FALSE(builder_->length() == max_chunk_length_)) { + // The current item would cause builder_->length() to exceed max_chunk_length_, so + // finish this chunk and append the current item to the next chunk + ARROW_RETURN_NOT_OK(NextChunk()); + } + + return builder_->Append(value, length); + } + + Status Append(std::string_view value) { + return Append(reinterpret_cast(value.data()), + static_cast(value.size())); + } + + Status AppendNull() { + if (ARROW_PREDICT_FALSE(builder_->length() == max_chunk_length_)) { + ARROW_RETURN_NOT_OK(NextChunk()); + } + return builder_->AppendNull(); + } + + Status Reserve(int64_t values); + + virtual Status Finish(ArrayVector* out); + + protected: + Status NextChunk(); + + // maximum total character data size per chunk + int64_t max_chunk_value_length_; + + // maximum elements allowed per chunk + int64_t max_chunk_length_ = kListMaximumElements; + + // when Reserve() would cause builder_ to exceed its max_chunk_length_, + // add to extra_capacity_ instead and wait to reserve until the next chunk + int64_t extra_capacity_ = 0; + + std::unique_ptr builder_; + std::vector> chunks_; +}; + +class ARROW_EXPORT ChunkedStringBuilder : public ChunkedBinaryBuilder { + public: + using ChunkedBinaryBuilder::ChunkedBinaryBuilder; + + Status Finish(ArrayVector* out) override; +}; + +} // namespace internal + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_decimal.h b/pyarrow/include/arrow/array/builder_decimal.h new file mode 100644 index 0000000000000000000000000000000000000000..a0bf0a04220842cceada0d0754ad6be4e41a3093 --- /dev/null +++ b/pyarrow/include/arrow/array/builder_decimal.h @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/array/array_decimal.h" +#include "arrow/array/builder_base.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/data.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup numeric-builders +/// +/// @{ + +class ARROW_EXPORT Decimal32Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal32Type; + using ValueType = Decimal32; + + explicit Decimal32Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(Decimal32 val); + void UnsafeAppend(Decimal32 val); + void UnsafeAppend(std::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + +class ARROW_EXPORT Decimal64Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal64Type; + using ValueType = Decimal64; + + explicit Decimal64Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(Decimal64 val); + void UnsafeAppend(Decimal64 val); + void UnsafeAppend(std::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + +class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal128Type; + using ValueType = Decimal128; + + explicit Decimal128Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(Decimal128 val); + void UnsafeAppend(Decimal128 val); + void UnsafeAppend(std::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + +class ARROW_EXPORT Decimal256Builder : public FixedSizeBinaryBuilder { + public: + using TypeClass = Decimal256Type; + using ValueType = Decimal256; + + explicit Decimal256Builder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + using FixedSizeBinaryBuilder::Append; + using FixedSizeBinaryBuilder::AppendValues; + using FixedSizeBinaryBuilder::Reset; + + Status Append(const Decimal256& val); + void UnsafeAppend(const Decimal256& val); + void UnsafeAppend(std::string_view val); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { return decimal_type_; } + + protected: + std::shared_ptr decimal_type_; +}; + +using DecimalBuilder = Decimal128Builder; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_dict.h b/pyarrow/include/arrow/array/builder_dict.h new file mode 100644 index 0000000000000000000000000000000000000000..116c82049eea9ea49a716452090297f57be4eb6b --- /dev/null +++ b/pyarrow/include/arrow/array/builder_dict.h @@ -0,0 +1,728 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/array_binary.h" +#include "arrow/array/builder_adaptive.h" // IWYU pragma: export +#include "arrow/array/builder_base.h" // IWYU pragma: export +#include "arrow/array/builder_primitive.h" // IWYU pragma: export +#include "arrow/array/data.h" +#include "arrow/array/util.h" +#include "arrow/scalar.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/decimal.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +// ---------------------------------------------------------------------- +// Dictionary builder + +namespace internal { + +template +struct DictionaryValue { + using type = typename T::c_type; + using PhysicalType = T; +}; + +template +struct DictionaryValue> { + using type = std::string_view; + using PhysicalType = + typename std::conditional::value, + BinaryType, LargeBinaryType>::type; +}; + +template +struct DictionaryValue> { + using type = std::string_view; + using PhysicalType = BinaryViewType; +}; + +template +struct DictionaryValue> { + using type = std::string_view; + using PhysicalType = BinaryType; +}; + +class ARROW_EXPORT DictionaryMemoTable { + public: + DictionaryMemoTable(MemoryPool* pool, const std::shared_ptr& type); + DictionaryMemoTable(MemoryPool* pool, const std::shared_ptr& dictionary); + ~DictionaryMemoTable(); + + Status GetArrayData(int64_t start_offset, std::shared_ptr* out); + + /// \brief Insert new memo values + Status InsertValues(const Array& values); + + int32_t size() const; + + template + Status GetOrInsert(typename DictionaryValue::type value, int32_t* out) { + // We want to keep the DictionaryMemoTable implementation private, also we can't + // use extern template classes because of compiler issues (MinGW?). Instead, + // we expose explicit function overrides for each supported physical type. + const typename DictionaryValue::PhysicalType* physical_type = NULLPTR; + return GetOrInsert(physical_type, value, out); + } + + private: + Status GetOrInsert(const BooleanType*, bool value, int32_t* out); + Status GetOrInsert(const Int8Type*, int8_t value, int32_t* out); + Status GetOrInsert(const Int16Type*, int16_t value, int32_t* out); + Status GetOrInsert(const Int32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Int64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const UInt8Type*, uint8_t value, int32_t* out); + Status GetOrInsert(const UInt16Type*, uint16_t value, int32_t* out); + Status GetOrInsert(const UInt32Type*, uint32_t value, int32_t* out); + Status GetOrInsert(const UInt64Type*, uint64_t value, int32_t* out); + Status GetOrInsert(const DurationType*, int64_t value, int32_t* out); + Status GetOrInsert(const TimestampType*, int64_t value, int32_t* out); + Status GetOrInsert(const Date32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const MonthDayNanoIntervalType*, + MonthDayNanoIntervalType::MonthDayNanos value, int32_t* out); + Status GetOrInsert(const DayTimeIntervalType*, + DayTimeIntervalType::DayMilliseconds value, int32_t* out); + Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out); + Status GetOrInsert(const FloatType*, float value, int32_t* out); + Status GetOrInsert(const DoubleType*, double value, int32_t* out); + + Status GetOrInsert(const BinaryType*, std::string_view value, int32_t* out); + Status GetOrInsert(const LargeBinaryType*, std::string_view value, int32_t* out); + Status GetOrInsert(const BinaryViewType*, std::string_view value, int32_t* out); + + class DictionaryMemoTableImpl; + std::unique_ptr impl_; +}; + +} // namespace internal + +/// \addtogroup dictionary-builders +/// +/// @{ + +namespace internal { + +/// \brief Array builder for created encoded DictionaryArray from +/// dense array +/// +/// Unlike other builders, dictionary builder does not completely +/// reset the state on Finish calls. +template +class DictionaryBuilderBase : public ArrayBuilder { + public: + using TypeClass = DictionaryType; + using Value = typename DictionaryValue::type; + + // WARNING: the type given below is the value type, not the DictionaryType. + // The DictionaryType is instantiated on the Finish() call. + template + DictionaryBuilderBase(uint8_t start_int_size, + enable_if_t::value && + !is_fixed_size_binary_type::value, + const std::shared_ptr&> + value_type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + memo_table_(new internal::DictionaryMemoTable(pool, value_type)), + delta_offset_(0), + byte_width_(-1), + indices_builder_(start_int_size, pool, alignment), + value_type_(value_type) {} + + template + explicit DictionaryBuilderBase( + enable_if_t::value, const std::shared_ptr&> + value_type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + memo_table_(new internal::DictionaryMemoTable(pool, value_type)), + delta_offset_(0), + byte_width_(-1), + indices_builder_(pool, alignment), + value_type_(value_type) {} + + template + explicit DictionaryBuilderBase( + const std::shared_ptr& index_type, + enable_if_t::value, const std::shared_ptr&> + value_type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + memo_table_(new internal::DictionaryMemoTable(pool, value_type)), + delta_offset_(0), + byte_width_(-1), + indices_builder_(index_type, pool, alignment), + value_type_(value_type) {} + + template + DictionaryBuilderBase(uint8_t start_int_size, + enable_if_t::value && + is_fixed_size_binary_type::value, + const std::shared_ptr&> + value_type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + memo_table_(new internal::DictionaryMemoTable(pool, value_type)), + delta_offset_(0), + byte_width_(static_cast(*value_type).byte_width()), + indices_builder_(start_int_size, pool, alignment), + value_type_(value_type) {} + + template + explicit DictionaryBuilderBase( + enable_if_fixed_size_binary&> value_type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + memo_table_(new internal::DictionaryMemoTable(pool, value_type)), + delta_offset_(0), + byte_width_(static_cast(*value_type).byte_width()), + indices_builder_(pool, alignment), + value_type_(value_type) {} + + template + explicit DictionaryBuilderBase( + const std::shared_ptr& index_type, + enable_if_fixed_size_binary&> value_type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + memo_table_(new internal::DictionaryMemoTable(pool, value_type)), + delta_offset_(0), + byte_width_(static_cast(*value_type).byte_width()), + indices_builder_(index_type, pool, alignment), + value_type_(value_type) {} + + template + explicit DictionaryBuilderBase( + enable_if_parameter_free pool = default_memory_pool()) + : DictionaryBuilderBase(TypeTraits::type_singleton(), pool) {} + + // This constructor doesn't check for errors. Use InsertMemoValues instead. + explicit DictionaryBuilderBase(const std::shared_ptr& dictionary, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + memo_table_(new internal::DictionaryMemoTable(pool, dictionary)), + delta_offset_(0), + byte_width_(-1), + indices_builder_(pool, alignment), + value_type_(dictionary->type()) {} + + ~DictionaryBuilderBase() override = default; + + /// \brief The current number of entries in the dictionary + int64_t dictionary_length() const { return memo_table_->size(); } + + /// \brief The value byte width (for FixedSizeBinaryType) + template + enable_if_fixed_size_binary byte_width() const { + return byte_width_; + } + + /// \brief Append a scalar value + Status Append(Value value) { + ARROW_RETURN_NOT_OK(Reserve(1)); + + int32_t memo_index; + ARROW_RETURN_NOT_OK(memo_table_->GetOrInsert(value, &memo_index)); + ARROW_RETURN_NOT_OK(indices_builder_.Append(memo_index)); + length_ += 1; + + return Status::OK(); + } + + /// \brief Append a fixed-width string (only for FixedSizeBinaryType) + template + enable_if_fixed_size_binary Append(const uint8_t* value) { + return Append(std::string_view(reinterpret_cast(value), byte_width_)); + } + + /// \brief Append a fixed-width string (only for FixedSizeBinaryType) + template + enable_if_fixed_size_binary Append(const char* value) { + return Append(std::string_view(value, byte_width_)); + } + + /// \brief Append a string (only for binary types) + template + enable_if_binary_like Append(const uint8_t* value, int32_t length) { + return Append(reinterpret_cast(value), length); + } + + /// \brief Append a string (only for binary types) + template + enable_if_binary_like Append(const char* value, int32_t length) { + return Append(std::string_view(value, length)); + } + + /// \brief Append a string (only for string types) + template + enable_if_string_like Append(const char* value, int32_t length) { + return Append(std::string_view(value, length)); + } + + /// \brief Append a decimal (only for Decimal32/64/128/256 Type) + template ::CType> + enable_if_decimal Append(const CType& value) { + auto bytes = value.ToBytes(); + return Append(bytes.data(), static_cast(bytes.size())); + } + + /// \brief Append a scalar null value + Status AppendNull() final { + length_ += 1; + null_count_ += 1; + + return indices_builder_.AppendNull(); + } + + Status AppendNulls(int64_t length) final { + length_ += length; + null_count_ += length; + + return indices_builder_.AppendNulls(length); + } + + Status AppendEmptyValue() final { + length_ += 1; + + return indices_builder_.AppendEmptyValue(); + } + + Status AppendEmptyValues(int64_t length) final { + length_ += length; + + return indices_builder_.AppendEmptyValues(length); + } + + Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override { + if (!scalar.is_valid) return AppendNulls(n_repeats); + + const auto& dict_ty = internal::checked_cast(*scalar.type); + const DictionaryScalar& dict_scalar = + internal::checked_cast(scalar); + const auto& dict = internal::checked_cast::ArrayType&>( + *dict_scalar.value.dictionary); + ARROW_RETURN_NOT_OK(Reserve(n_repeats)); + switch (dict_ty.index_type()->id()) { + case Type::UINT8: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT8: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT16: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT16: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT32: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT32: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT64: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT64: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + default: + return Status::TypeError("Invalid index type: ", dict_ty); + } + return Status::OK(); + } + + Status AppendScalars(const ScalarVector& scalars) override { + for (const auto& scalar : scalars) { + ARROW_RETURN_NOT_OK(AppendScalar(*scalar, /*n_repeats=*/1)); + } + return Status::OK(); + } + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, int64_t length) final { + // Visit the indices and insert the unpacked values. + const auto& dict_ty = internal::checked_cast(*array.type); + // See if possible to avoid using ToArrayData here + const typename TypeTraits::ArrayType dict(array.dictionary().ToArrayData()); + ARROW_RETURN_NOT_OK(Reserve(length)); + switch (dict_ty.index_type()->id()) { + case Type::UINT8: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT8: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT16: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT16: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT32: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT32: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT64: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT64: + return AppendArraySliceImpl(dict, array, offset, length); + default: + return Status::TypeError("Invalid index type: ", dict_ty); + } + return Status::OK(); + } + + /// \brief Insert values into the dictionary's memo, but do not append any + /// indices. Can be used to initialize a new builder with known dictionary + /// values + /// \param[in] values dictionary values to add to memo. Type must match + /// builder type + Status InsertMemoValues(const Array& values) { + return memo_table_->InsertValues(values); + } + + /// \brief Append a whole dense array to the builder + template + enable_if_t::value, Status> AppendArray( + const Array& array) { + using ArrayType = typename TypeTraits::ArrayType; + +#ifndef NDEBUG + ARROW_RETURN_NOT_OK(ArrayBuilder::CheckArrayType( + value_type_, array, "Wrong value type of array to be appended")); +#endif + + const auto& concrete_array = static_cast(array); + for (int64_t i = 0; i < array.length(); i++) { + if (array.IsNull(i)) { + ARROW_RETURN_NOT_OK(AppendNull()); + } else { + ARROW_RETURN_NOT_OK(Append(concrete_array.GetView(i))); + } + } + return Status::OK(); + } + + template + enable_if_fixed_size_binary AppendArray(const Array& array) { +#ifndef NDEBUG + ARROW_RETURN_NOT_OK(ArrayBuilder::CheckArrayType( + value_type_, array, "Wrong value type of array to be appended")); +#endif + + const auto& concrete_array = static_cast(array); + for (int64_t i = 0; i < array.length(); i++) { + if (array.IsNull(i)) { + ARROW_RETURN_NOT_OK(AppendNull()); + } else { + ARROW_RETURN_NOT_OK(Append(concrete_array.GetValue(i))); + } + } + return Status::OK(); + } + + void Reset() override { + // Perform a partial reset. Call ResetFull to also reset the accumulated + // dictionary values + ArrayBuilder::Reset(); + indices_builder_.Reset(); + } + + /// \brief Reset and also clear accumulated dictionary values in memo table + void ResetFull() { + Reset(); + memo_table_.reset(new internal::DictionaryMemoTable(pool_, value_type_)); + } + + Status Resize(int64_t capacity) override { + ARROW_RETURN_NOT_OK(CheckCapacity(capacity)); + capacity = std::max(capacity, kMinBuilderCapacity); + ARROW_RETURN_NOT_OK(indices_builder_.Resize(capacity)); + capacity_ = indices_builder_.capacity(); + return Status::OK(); + } + + /// \brief Return dictionary indices and a delta dictionary since the last + /// time that Finish or FinishDelta were called, and reset state of builder + /// (except the memo table) + Status FinishDelta(std::shared_ptr* out_indices, + std::shared_ptr* out_delta) { + std::shared_ptr indices_data; + std::shared_ptr delta_data; + ARROW_RETURN_NOT_OK(FinishWithDictOffset(delta_offset_, &indices_data, &delta_data)); + *out_indices = MakeArray(indices_data); + *out_delta = MakeArray(delta_data); + return Status::OK(); + } + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { + return ::arrow::dictionary(indices_builder_.type(), value_type_); + } + + protected: + template + Status AppendArraySliceImpl(const typename TypeTraits::ArrayType& dict, + const ArraySpan& array, int64_t offset, int64_t length) { + const c_type* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0].data, array.offset + offset, length, + [&](const int64_t position) { + const int64_t index = static_cast(values[position]); + if (dict.IsValid(index)) { + return Append(dict.GetView(index)); + } + return AppendNull(); + }, + [&]() { return AppendNull(); }); + } + + template + Status AppendScalarImpl(const typename TypeTraits::ArrayType& dict, + const Scalar& index_scalar, int64_t n_repeats) { + using ScalarType = typename TypeTraits::ScalarType; + const auto index = internal::checked_cast(index_scalar).value; + if (index_scalar.is_valid && dict.IsValid(index)) { + const auto& value = dict.GetView(index); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + return Status::OK(); + } + return AppendNulls(n_repeats); + } + + Status FinishInternal(std::shared_ptr* out) override { + std::shared_ptr dictionary; + ARROW_RETURN_NOT_OK(FinishWithDictOffset(/*offset=*/0, out, &dictionary)); + + // Set type of array data to the right dictionary type + (*out)->type = type(); + (*out)->dictionary = dictionary; + return Status::OK(); + } + + Status FinishWithDictOffset(int64_t dict_offset, + std::shared_ptr* out_indices, + std::shared_ptr* out_dictionary) { + // Finalize indices array + ARROW_RETURN_NOT_OK(indices_builder_.FinishInternal(out_indices)); + + // Generate dictionary array from hash table contents + ARROW_RETURN_NOT_OK(memo_table_->GetArrayData(dict_offset, out_dictionary)); + delta_offset_ = memo_table_->size(); + + // Update internals for further uses of this DictionaryBuilder + ArrayBuilder::Reset(); + return Status::OK(); + } + + std::unique_ptr memo_table_; + + // The size of the dictionary memo at last invocation of Finish, to use in + // FinishDelta for computing dictionary deltas + int32_t delta_offset_; + + // Only used for FixedSizeBinaryType + int32_t byte_width_; + + BuilderType indices_builder_; + std::shared_ptr value_type_; +}; + +template +class DictionaryBuilderBase : public ArrayBuilder { + public: + template + DictionaryBuilderBase( + enable_if_t::value, uint8_t> + start_int_size, + const std::shared_ptr& value_type, + MemoryPool* pool = default_memory_pool()) + : ArrayBuilder(pool), indices_builder_(start_int_size, pool) {} + + explicit DictionaryBuilderBase(const std::shared_ptr& value_type, + MemoryPool* pool = default_memory_pool()) + : ArrayBuilder(pool), indices_builder_(pool) {} + + explicit DictionaryBuilderBase(const std::shared_ptr& index_type, + const std::shared_ptr& value_type, + MemoryPool* pool = default_memory_pool()) + : ArrayBuilder(pool), indices_builder_(index_type, pool) {} + + template + explicit DictionaryBuilderBase( + enable_if_t::value, uint8_t> + start_int_size, + MemoryPool* pool = default_memory_pool()) + : ArrayBuilder(pool), indices_builder_(start_int_size, pool) {} + + explicit DictionaryBuilderBase(MemoryPool* pool = default_memory_pool()) + : ArrayBuilder(pool), indices_builder_(pool) {} + + explicit DictionaryBuilderBase(const std::shared_ptr& dictionary, + MemoryPool* pool = default_memory_pool()) + : ArrayBuilder(pool), indices_builder_(pool) {} + + /// \brief Append a scalar null value + Status AppendNull() final { + length_ += 1; + null_count_ += 1; + + return indices_builder_.AppendNull(); + } + + Status AppendNulls(int64_t length) final { + length_ += length; + null_count_ += length; + + return indices_builder_.AppendNulls(length); + } + + Status AppendEmptyValue() final { + length_ += 1; + + return indices_builder_.AppendEmptyValue(); + } + + Status AppendEmptyValues(int64_t length) final { + length_ += length; + + return indices_builder_.AppendEmptyValues(length); + } + + /// \brief Append a whole dense array to the builder + Status AppendArray(const Array& array) { +#ifndef NDEBUG + ARROW_RETURN_NOT_OK(ArrayBuilder::CheckArrayType( + Type::NA, array, "Wrong value type of array to be appended")); +#endif + for (int64_t i = 0; i < array.length(); i++) { + ARROW_RETURN_NOT_OK(AppendNull()); + } + return Status::OK(); + } + + Status Resize(int64_t capacity) override { + ARROW_RETURN_NOT_OK(CheckCapacity(capacity)); + capacity = std::max(capacity, kMinBuilderCapacity); + + ARROW_RETURN_NOT_OK(indices_builder_.Resize(capacity)); + capacity_ = indices_builder_.capacity(); + return Status::OK(); + } + + Status FinishInternal(std::shared_ptr* out) override { + ARROW_RETURN_NOT_OK(indices_builder_.FinishInternal(out)); + (*out)->type = dictionary((*out)->type, null()); + (*out)->dictionary = NullArray(0).data(); + return Status::OK(); + } + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + std::shared_ptr type() const override { + return ::arrow::dictionary(indices_builder_.type(), null()); + } + + protected: + BuilderType indices_builder_; +}; + +} // namespace internal + +/// \brief A DictionaryArray builder that uses AdaptiveIntBuilder to return the +/// smallest index size that can accommodate the dictionary indices +template +class DictionaryBuilder : public internal::DictionaryBuilderBase { + public: + using BASE = internal::DictionaryBuilderBase; + using BASE::BASE; + + /// \brief Append dictionary indices directly without modifying memo + /// + /// NOTE: Experimental API + Status AppendIndices(const int64_t* values, int64_t length, + const uint8_t* valid_bytes = NULLPTR) { + int64_t null_count_before = this->indices_builder_.null_count(); + ARROW_RETURN_NOT_OK(this->indices_builder_.AppendValues(values, length, valid_bytes)); + this->capacity_ = this->indices_builder_.capacity(); + this->length_ += length; + this->null_count_ += this->indices_builder_.null_count() - null_count_before; + return Status::OK(); + } +}; + +/// \brief A DictionaryArray builder that always returns int32 dictionary +/// indices so that data cast to dictionary form will have a consistent index +/// type, e.g. for creating a ChunkedArray +template +class Dictionary32Builder : public internal::DictionaryBuilderBase { + public: + using BASE = internal::DictionaryBuilderBase; + using BASE::BASE; + + /// \brief Append dictionary indices directly without modifying memo + /// + /// NOTE: Experimental API + Status AppendIndices(const int32_t* values, int64_t length, + const uint8_t* valid_bytes = NULLPTR) { + int64_t null_count_before = this->indices_builder_.null_count(); + ARROW_RETURN_NOT_OK(this->indices_builder_.AppendValues(values, length, valid_bytes)); + this->capacity_ = this->indices_builder_.capacity(); + this->length_ += length; + this->null_count_ += this->indices_builder_.null_count() - null_count_before; + return Status::OK(); + } +}; + +// ---------------------------------------------------------------------- +// Binary / Unicode builders +// (compatibility aliases; those used to be derived classes with additional +// Append() overloads, but they have been folded into DictionaryBuilderBase) + +using BinaryDictionaryBuilder = DictionaryBuilder; +using StringDictionaryBuilder = DictionaryBuilder; +using BinaryDictionary32Builder = Dictionary32Builder; +using StringDictionary32Builder = Dictionary32Builder; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_nested.h b/pyarrow/include/arrow/array/builder_nested.h new file mode 100644 index 0000000000000000000000000000000000000000..fdbeb0cd7d17b40b929d2ba73dba6f425d01c968 --- /dev/null +++ b/pyarrow/include/arrow/array/builder_nested.h @@ -0,0 +1,836 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_base.h" +#include "arrow/array/data.h" +#include "arrow/buffer.h" +#include "arrow/buffer_builder.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup nested-builders +/// +/// @{ + +// ---------------------------------------------------------------------- +// VarLengthListLikeBuilder + +template +class VarLengthListLikeBuilder : public ArrayBuilder { + public: + using TypeClass = TYPE; + using offset_type = typename TypeClass::offset_type; + + /// Use this constructor to incrementally build the value array along with offsets and + /// null bitmap. + VarLengthListLikeBuilder(MemoryPool* pool, + const std::shared_ptr& value_builder, + const std::shared_ptr& type, + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + offsets_builder_(pool, alignment), + value_builder_(value_builder), + value_field_(type->field(0)->WithType(NULLPTR)) {} + + VarLengthListLikeBuilder(MemoryPool* pool, + const std::shared_ptr& value_builder, + int64_t alignment = kDefaultBufferAlignment) + : VarLengthListLikeBuilder(pool, value_builder, + std::make_shared(value_builder->type()), + alignment) {} + + ~VarLengthListLikeBuilder() override = default; + + Status Resize(int64_t capacity) override { + if (ARROW_PREDICT_FALSE(capacity > maximum_elements())) { + return Status::CapacityError(type_name(), + " array cannot reserve space for more than ", + maximum_elements(), " got ", capacity); + } + ARROW_RETURN_NOT_OK(CheckCapacity(capacity)); + + // One more than requested for list offsets + const int64_t offsets_capacity = + is_list_view(TYPE::type_id) ? capacity : capacity + 1; + ARROW_RETURN_NOT_OK(offsets_builder_.Resize(offsets_capacity)); + return ArrayBuilder::Resize(capacity); + } + + void Reset() override { + ArrayBuilder::Reset(); + offsets_builder_.Reset(); + value_builder_->Reset(); + } + + /// \brief Start a new variable-length list slot + /// + /// This function should be called before appending elements to the + /// value builder. Elements appended to the value builder before this function + /// is called for the first time, will not be members of any list value. + /// + /// After this function is called, list_length elements SHOULD be appended to + /// the values builder. If this contract is violated, the behavior is defined by + /// the concrete builder implementation and SHOULD NOT be relied upon unless + /// the caller is specifically building a [Large]List or [Large]ListView array. + /// + /// For [Large]List arrays, the list slot length will be the number of elements + /// appended to the values builder before the next call to Append* or Finish. For + /// [Large]ListView arrays, the list slot length will be exactly list_length, but if + /// Append* is called before at least list_length elements are appended to the values + /// builder, the current list slot will share elements with the next list + /// slots or an invalid [Large]ListView array will be generated because there + /// aren't enough elements in the values builder to fill the list slots. + /// + /// If you're building a [Large]List and don't need to be compatible + /// with [Large]ListView, then `BaseListBuilder::Append(bool is_valid)` + /// is a simpler API. + /// + /// \pre if is_valid is false, list_length MUST be 0 + /// \param is_valid Whether the new list slot is valid + /// \param list_length The number of elements in the list + Status Append(bool is_valid, int64_t list_length) { + ARROW_RETURN_NOT_OK(Reserve(1)); + assert(is_valid || list_length == 0); + UnsafeAppendToBitmap(is_valid); + UnsafeAppendDimensions(/*offset=*/value_builder_->length(), /*size=*/list_length); + return Status::OK(); + } + + Status AppendNull() final { + // Append() a null list slot with list_length=0. + // + // When building [Large]List arrays, elements being appended to the values builder + // before the next call to Append* or Finish will extend the list slot length, but + // that is totally fine because list arrays admit non-empty null list slots. + // + // In the case of [Large]ListViews that's not a problem either because the + // list slot length remains zero. + return Append(false, 0); + } + + Status AppendNulls(int64_t length) final { + ARROW_RETURN_NOT_OK(Reserve(length)); + UnsafeAppendToBitmap(length, false); + UnsafeAppendEmptyDimensions(/*num_values=*/length); + return Status::OK(); + } + + /// \brief Append an empty list slot + /// + /// \post Another call to Append* or Finish should be made before appending to + /// the values builder to ensure list slot remains empty + Status AppendEmptyValue() final { return Append(true, 0); } + + /// \brief Append an empty list slot + /// + /// \post Another call to Append* or Finish should be made before appending to + /// the values builder to ensure the last list slot remains empty + Status AppendEmptyValues(int64_t length) final { + ARROW_RETURN_NOT_OK(Reserve(length)); + UnsafeAppendToBitmap(length, true); + UnsafeAppendEmptyDimensions(/*num_values=*/length); + return Status::OK(); + } + + /// \brief Vector append + /// + /// For list-array builders, the sizes are inferred from the offsets. + /// BaseListBuilder provides an implementation that doesn't take sizes, but + /// this virtual function allows dispatching calls to both list-array and + /// list-view-array builders (which need the sizes) + /// + /// \param offsets The offsets of the variable-length lists + /// \param sizes The sizes of the variable-length lists + /// \param length The number of offsets, sizes, and validity bits to append + /// \param valid_bytes If passed, valid_bytes is of equal length to values, + /// and any zero byte will be considered as a null for that slot + virtual Status AppendValues(const offset_type* offsets, const offset_type* sizes, + int64_t length, const uint8_t* valid_bytes) = 0; + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override { + const offset_type* offsets = array.GetValues(1); + [[maybe_unused]] const offset_type* sizes = NULLPTR; + if constexpr (is_list_view(TYPE::type_id)) { + sizes = array.GetValues(2); + } + static_assert(internal::may_have_validity_bitmap(TYPE::type_id)); + const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0].data : NULLPTR; + ARROW_RETURN_NOT_OK(Reserve(length)); + for (int64_t row = offset; row < offset + length; row++) { + const bool is_valid = !validity || bit_util::GetBit(validity, array.offset + row); + int64_t size = 0; + if (is_valid) { + if constexpr (is_list_view(TYPE::type_id)) { + size = sizes[row]; + } else { + size = offsets[row + 1] - offsets[row]; + } + } + UnsafeAppendToBitmap(is_valid); + UnsafeAppendDimensions(/*offset=*/value_builder_->length(), size); + if (is_valid) { + ARROW_RETURN_NOT_OK( + value_builder_->AppendArraySlice(array.child_data[0], offsets[row], size)); + } + } + return Status::OK(); + } + + Status ValidateOverflow(int64_t new_elements) const { + auto new_length = value_builder_->length() + new_elements; + if (ARROW_PREDICT_FALSE(new_length > maximum_elements())) { + return Status::CapacityError(type_name(), " array cannot contain more than ", + maximum_elements(), " elements, have ", new_elements); + } else { + return Status::OK(); + } + } + + ArrayBuilder* value_builder() const { return value_builder_.get(); } + + // Cannot make this a static attribute because of linking issues + static constexpr int64_t maximum_elements() { + return std::numeric_limits::max() - 1; + } + + std::shared_ptr type() const override { + return std::make_shared(value_field_->WithType(value_builder_->type())); + } + + private: + static constexpr const char* type_name() { + if constexpr (is_list_view(TYPE::type_id)) { + return "ListView"; + } else { + return "List"; + } + } + + protected: + /// \brief Append dimensions for num_values empty list slots. + /// + /// ListViewBuilder overrides this to also append the sizes. + virtual void UnsafeAppendEmptyDimensions(int64_t num_values) { + const int64_t offset = value_builder_->length(); + for (int64_t i = 0; i < num_values; ++i) { + offsets_builder_.UnsafeAppend(static_cast(offset)); + } + } + + /// \brief Append dimensions for a single list slot. + /// + /// ListViewBuilder overrides this to also append the size. + virtual void UnsafeAppendDimensions(int64_t offset, int64_t ARROW_ARG_UNUSED(size)) { + offsets_builder_.UnsafeAppend(static_cast(offset)); + } + + TypedBufferBuilder offsets_builder_; + std::shared_ptr value_builder_; + std::shared_ptr value_field_; +}; + +// ---------------------------------------------------------------------- +// ListBuilder / LargeListBuilder + +template +class BaseListBuilder : public VarLengthListLikeBuilder { + private: + using BASE = VarLengthListLikeBuilder; + + public: + using TypeClass = TYPE; + using offset_type = typename BASE::offset_type; + + using BASE::BASE; + + using BASE::Append; + + ~BaseListBuilder() override = default; + + /// \brief Start a new variable-length list slot + /// + /// This function should be called before beginning to append elements to the + /// value builder + Status Append(bool is_valid = true) { + // The value_length parameter to BASE::Append(bool, int64_t) is ignored when + // building a list array, so we can pass 0 here. + return BASE::Append(is_valid, 0); + } + + /// \brief Vector append + /// + /// If passed, valid_bytes is of equal length to values, and any zero byte + /// will be considered as a null for that slot + Status AppendValues(const offset_type* offsets, int64_t length, + const uint8_t* valid_bytes = NULLPTR) { + ARROW_RETURN_NOT_OK(this->Reserve(length)); + this->UnsafeAppendToBitmap(valid_bytes, length); + this->offsets_builder_.UnsafeAppend(offsets, length); + return Status::OK(); + } + + Status AppendValues(const offset_type* offsets, const offset_type* sizes, + int64_t length, const uint8_t* valid_bytes) final { + // Offsets are assumed to be valid, but the first length-1 sizes have to be + // consistent with the offsets to partially rule out the possibility that the + // caller is passing sizes that could work if building a list-view, but don't + // work on building a list that requires offsets to be non-decreasing. + // + // CAUTION: the last size element (`sizes[length - 1]`) is not + // validated and could be inconsistent with the offsets given in a + // subsequent call to AppendValues. +#ifndef NDEBUG + if (sizes) { + for (int64_t i = 0; i < length - 1; ++i) { + if (ARROW_PREDICT_FALSE(offsets[i] != offsets[i + 1] - sizes[i])) { + if (!valid_bytes || valid_bytes[i]) { + return Status::Invalid( + "BaseListBuilder: sizes are inconsistent with offsets provided"); + } + } + } + } +#endif + return AppendValues(offsets, length, valid_bytes); + } + + Status AppendValues(const offset_type* offsets, const offset_type* sizes, + int64_t length) { + return AppendValues(offsets, sizes, length, /*valid_bytes=*/NULLPTR); + } + + Status AppendNextOffset() { + ARROW_RETURN_NOT_OK(this->ValidateOverflow(0)); + const int64_t num_values = this->value_builder_->length(); + return this->offsets_builder_.Append(static_cast(num_values)); + } + + Status FinishInternal(std::shared_ptr* out) override { + ARROW_RETURN_NOT_OK(AppendNextOffset()); + + // Offset padding zeroed by BufferBuilder + std::shared_ptr offsets; + std::shared_ptr null_bitmap; + ARROW_RETURN_NOT_OK(this->offsets_builder_.Finish(&offsets)); + ARROW_RETURN_NOT_OK(this->null_bitmap_builder_.Finish(&null_bitmap)); + + if (this->value_builder_->length() == 0) { + // Try to make sure we get a non-null values buffer (ARROW-2744) + ARROW_RETURN_NOT_OK(this->value_builder_->Resize(0)); + } + + std::shared_ptr items; + ARROW_RETURN_NOT_OK(this->value_builder_->FinishInternal(&items)); + + *out = ArrayData::Make(this->type(), this->length_, + {std::move(null_bitmap), std::move(offsets)}, + {std::move(items)}, this->null_count_); + this->Reset(); + return Status::OK(); + } +}; + +/// \class ListBuilder +/// \brief Builder class for variable-length list array value types +/// +/// To use this class, you must append values to the child array builder and use +/// the Append function to delimit each distinct list value (once the values +/// have been appended to the child array) or use the bulk API to append +/// a sequence of offsets and null values. +/// +/// A note on types. Per arrow/type.h all types in the c++ implementation are +/// logical so even though this class always builds list array, this can +/// represent multiple different logical types. If no logical type is provided +/// at construction time, the class defaults to List where t is taken from the +/// value_builder/values that the object is constructed with. +class ARROW_EXPORT ListBuilder : public BaseListBuilder { + public: + using BaseListBuilder::BaseListBuilder; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } +}; + +/// \class LargeListBuilder +/// \brief Builder class for large variable-length list array value types +/// +/// Like ListBuilder, but to create large list arrays (with 64-bit offsets). +class ARROW_EXPORT LargeListBuilder : public BaseListBuilder { + public: + using BaseListBuilder::BaseListBuilder; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } +}; + +// ---------------------------------------------------------------------- +// ListViewBuilder / LargeListViewBuilder + +template +class BaseListViewBuilder : public VarLengthListLikeBuilder { + private: + using BASE = VarLengthListLikeBuilder; + + public: + using TypeClass = TYPE; + using offset_type = typename BASE::offset_type; + + using BASE::BASE; + + ~BaseListViewBuilder() override = default; + + Status Resize(int64_t capacity) override { + ARROW_RETURN_NOT_OK(BASE::Resize(capacity)); + return sizes_builder_.Resize(capacity); + } + + void Reset() override { + BASE::Reset(); + sizes_builder_.Reset(); + } + + /// \brief Vector append + /// + /// If passed, valid_bytes is of equal length to values, and any zero byte + /// will be considered as a null for that slot + Status AppendValues(const offset_type* offsets, const offset_type* sizes, + int64_t length, const uint8_t* valid_bytes) final { + ARROW_RETURN_NOT_OK(this->Reserve(length)); + this->UnsafeAppendToBitmap(valid_bytes, length); + this->offsets_builder_.UnsafeAppend(offsets, length); + this->sizes_builder_.UnsafeAppend(sizes, length); + return Status::OK(); + } + + Status AppendValues(const offset_type* offsets, const offset_type* sizes, + int64_t length) { + return AppendValues(offsets, sizes, length, /*valid_bytes=*/NULLPTR); + } + + Status FinishInternal(std::shared_ptr* out) override { + // Offset and sizes padding zeroed by BufferBuilder + std::shared_ptr null_bitmap; + std::shared_ptr offsets; + std::shared_ptr sizes; + ARROW_RETURN_NOT_OK(this->null_bitmap_builder_.Finish(&null_bitmap)); + ARROW_RETURN_NOT_OK(this->offsets_builder_.Finish(&offsets)); + ARROW_RETURN_NOT_OK(this->sizes_builder_.Finish(&sizes)); + + if (this->value_builder_->length() == 0) { + // Try to make sure we get a non-null values buffer (ARROW-2744) + ARROW_RETURN_NOT_OK(this->value_builder_->Resize(0)); + } + + std::shared_ptr items; + ARROW_RETURN_NOT_OK(this->value_builder_->FinishInternal(&items)); + + *out = ArrayData::Make(this->type(), this->length_, + {std::move(null_bitmap), std::move(offsets), std::move(sizes)}, + {std::move(items)}, this->null_count_); + this->Reset(); + return Status::OK(); + } + + protected: + void UnsafeAppendEmptyDimensions(int64_t num_values) override { + for (int64_t i = 0; i < num_values; ++i) { + this->offsets_builder_.UnsafeAppend(0); + } + for (int64_t i = 0; i < num_values; ++i) { + this->sizes_builder_.UnsafeAppend(0); + } + } + + void UnsafeAppendDimensions(int64_t offset, int64_t size) override { + this->offsets_builder_.UnsafeAppend(static_cast(offset)); + this->sizes_builder_.UnsafeAppend(static_cast(size)); + } + + private: + TypedBufferBuilder sizes_builder_; +}; + +class ARROW_EXPORT ListViewBuilder final : public BaseListViewBuilder { + public: + using BaseListViewBuilder::BaseListViewBuilder; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } +}; + +class ARROW_EXPORT LargeListViewBuilder final + : public BaseListViewBuilder { + public: + using BaseListViewBuilder::BaseListViewBuilder; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } +}; + +// ---------------------------------------------------------------------- +// Map builder + +/// \class MapBuilder +/// \brief Builder class for arrays of variable-size maps +/// +/// To use this class, you must use the Append function to delimit each distinct +/// map before appending values to the key and item array builders, or use the +/// bulk API to append a sequence of offsets and null maps. +/// +/// Key uniqueness and ordering are not validated. +class ARROW_EXPORT MapBuilder : public ArrayBuilder { + public: + /// Use this constructor to define the built array's type explicitly. If key_builder + /// or item_builder has indeterminate type, this builder will also. + MapBuilder(MemoryPool* pool, const std::shared_ptr& key_builder, + const std::shared_ptr& item_builder, + const std::shared_ptr& type); + + /// Use this constructor to infer the built array's type. If key_builder or + /// item_builder has indeterminate type, this builder will also. + MapBuilder(MemoryPool* pool, const std::shared_ptr& key_builder, + const std::shared_ptr& item_builder, bool keys_sorted = false); + + MapBuilder(MemoryPool* pool, const std::shared_ptr& item_builder, + const std::shared_ptr& type); + + Status Resize(int64_t capacity) override; + void Reset() override; + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + /// \brief Vector append + /// + /// If passed, valid_bytes is of equal length to values, and any zero byte + /// will be considered as a null for that slot + Status AppendValues(const int32_t* offsets, int64_t length, + const uint8_t* valid_bytes = NULLPTR); + + /// \brief Start a new variable-length map slot + /// + /// This function should be called before beginning to append elements to the + /// key and item builders + Status Append(); + + Status AppendNull() final; + + Status AppendNulls(int64_t length) final; + + Status AppendEmptyValue() final; + + Status AppendEmptyValues(int64_t length) final; + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override { + const auto* offsets = array.GetValues(1); + static_assert(internal::may_have_validity_bitmap(MapType::type_id)); + const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0].data : NULLPTR; + for (int64_t row = offset; row < offset + length; row++) { + const bool is_valid = !validity || bit_util::GetBit(validity, array.offset + row); + if (is_valid) { + ARROW_RETURN_NOT_OK(Append()); + const int64_t slot_length = offsets[row + 1] - offsets[row]; + // Add together the inner StructArray offset to the Map/List offset + int64_t key_value_offset = array.child_data[0].offset + offsets[row]; + ARROW_RETURN_NOT_OK(key_builder_->AppendArraySlice( + array.child_data[0].child_data[0], key_value_offset, slot_length)); + ARROW_RETURN_NOT_OK(item_builder_->AppendArraySlice( + array.child_data[0].child_data[1], key_value_offset, slot_length)); + } else { + ARROW_RETURN_NOT_OK(AppendNull()); + } + } + return Status::OK(); + } + + /// \brief Get builder to append keys. + /// + /// Append a key with this builder should be followed by appending + /// an item or null value with item_builder(). + ArrayBuilder* key_builder() const { return key_builder_.get(); } + + /// \brief Get builder to append items + /// + /// Appending an item with this builder should have been preceded + /// by appending a key with key_builder(). + ArrayBuilder* item_builder() const { return item_builder_.get(); } + + /// \brief Get builder to add Map entries as struct values. + /// + /// This is used instead of key_builder()/item_builder() and allows + /// the Map to be built as a list of struct values. + ArrayBuilder* value_builder() const { return list_builder_->value_builder(); } + + std::shared_ptr type() const override { + // Key and Item builder may update types, but they don't contain the field names, + // so we need to reconstruct the type. (See ARROW-13735.) + return std::make_shared( + field(entries_name_, + struct_({field(key_name_, key_builder_->type(), false), + field(item_name_, item_builder_->type(), item_nullable_)}), + false), + keys_sorted_); + } + + Status ValidateOverflow(int64_t new_elements) { + return list_builder_->ValidateOverflow(new_elements); + } + + protected: + inline Status AdjustStructBuilderLength(); + + protected: + bool keys_sorted_ = false; + bool item_nullable_ = false; + std::string entries_name_; + std::string key_name_; + std::string item_name_; + std::shared_ptr list_builder_; + std::shared_ptr key_builder_; + std::shared_ptr item_builder_; +}; + +// ---------------------------------------------------------------------- +// FixedSizeList builder + +/// \class FixedSizeListBuilder +/// \brief Builder class for fixed-length list array value types +class ARROW_EXPORT FixedSizeListBuilder : public ArrayBuilder { + public: + using TypeClass = FixedSizeListType; + + /// Use this constructor to define the built array's type explicitly. If value_builder + /// has indeterminate type, this builder will also. + FixedSizeListBuilder(MemoryPool* pool, + const std::shared_ptr& value_builder, + int32_t list_size); + + /// Use this constructor to infer the built array's type. If value_builder has + /// indeterminate type, this builder will also. + FixedSizeListBuilder(MemoryPool* pool, + const std::shared_ptr& value_builder, + const std::shared_ptr& type); + + Status Resize(int64_t capacity) override; + void Reset() override; + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + /// \brief Append a valid fixed length list. + /// + /// This function affects only the validity bitmap; the child values must be appended + /// using the child array builder. + Status Append(); + + /// \brief Vector append + /// + /// If passed, valid_bytes will be read and any zero byte + /// will cause the corresponding slot to be null + /// + /// This function affects only the validity bitmap; the child values must be appended + /// using the child array builder. This includes appending nulls for null lists. + /// XXX this restriction is confusing, should this method be omitted? + Status AppendValues(int64_t length, const uint8_t* valid_bytes = NULLPTR); + + /// \brief Append a null fixed length list. + /// + /// The child array builder will have the appropriate number of nulls appended + /// automatically. + Status AppendNull() final; + + /// \brief Append length null fixed length lists. + /// + /// The child array builder will have the appropriate number of nulls appended + /// automatically. + Status AppendNulls(int64_t length) final; + + Status ValidateOverflow(int64_t new_elements); + + Status AppendEmptyValue() final; + + Status AppendEmptyValues(int64_t length) final; + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, int64_t length) final { + const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0].data : NULLPTR; + for (int64_t row = offset; row < offset + length; row++) { + if (!validity || bit_util::GetBit(validity, array.offset + row)) { + ARROW_RETURN_NOT_OK(value_builder_->AppendArraySlice( + array.child_data[0], list_size_ * (array.offset + row), list_size_)); + ARROW_RETURN_NOT_OK(Append()); + } else { + ARROW_RETURN_NOT_OK(AppendNull()); + } + } + return Status::OK(); + } + + ArrayBuilder* value_builder() const { return value_builder_.get(); } + + std::shared_ptr type() const override { + return fixed_size_list(value_field_->WithType(value_builder_->type()), list_size_); + } + + // Cannot make this a static attribute because of linking issues + static constexpr int64_t maximum_elements() { + return std::numeric_limits::max() - 1; + } + + protected: + std::shared_ptr value_field_; + const int32_t list_size_; + std::shared_ptr value_builder_; +}; + +// ---------------------------------------------------------------------- +// Struct + +// --------------------------------------------------------------------------------- +// StructArray builder +/// Append, Resize and Reserve methods are acting on StructBuilder. +/// Please make sure all these methods of all child-builders' are consistently +/// called to maintain data-structure consistency. +class ARROW_EXPORT StructBuilder : public ArrayBuilder { + public: + /// If any of field_builders has indeterminate type, this builder will also + StructBuilder(const std::shared_ptr& type, MemoryPool* pool, + std::vector> field_builders); + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + /// Null bitmap is of equal length to every child field, and any zero byte + /// will be considered as a null for that field, but users must using app- + /// end methods or advance methods of the child builders' independently to + /// insert data. + Status AppendValues(int64_t length, const uint8_t* valid_bytes) { + ARROW_RETURN_NOT_OK(Reserve(length)); + UnsafeAppendToBitmap(valid_bytes, length); + return Status::OK(); + } + + /// Append an element to the Struct. All child-builders' Append method must + /// be called independently to maintain data-structure consistency. + Status Append(bool is_valid = true) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppendToBitmap(is_valid); + return Status::OK(); + } + + /// \brief Append a null value. Automatically appends an empty value to each child + /// builder. + Status AppendNull() final { + for (const auto& field : children_) { + ARROW_RETURN_NOT_OK(field->AppendEmptyValue()); + } + return Append(false); + } + + /// \brief Append multiple null values. Automatically appends empty values to each + /// child builder. + Status AppendNulls(int64_t length) final { + for (const auto& field : children_) { + ARROW_RETURN_NOT_OK(field->AppendEmptyValues(length)); + } + ARROW_RETURN_NOT_OK(Reserve(length)); + UnsafeAppendToBitmap(length, false); + return Status::OK(); + } + + Status AppendEmptyValue() final { + for (const auto& field : children_) { + ARROW_RETURN_NOT_OK(field->AppendEmptyValue()); + } + return Append(true); + } + + Status AppendEmptyValues(int64_t length) final { + for (const auto& field : children_) { + ARROW_RETURN_NOT_OK(field->AppendEmptyValues(length)); + } + ARROW_RETURN_NOT_OK(Reserve(length)); + UnsafeAppendToBitmap(length, true); + return Status::OK(); + } + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override { + for (int i = 0; static_cast(i) < children_.size(); i++) { + ARROW_RETURN_NOT_OK(children_[i]->AppendArraySlice(array.child_data[i], + array.offset + offset, length)); + } + const uint8_t* validity = array.MayHaveNulls() ? array.buffers[0].data : NULLPTR; + ARROW_RETURN_NOT_OK(Reserve(length)); + UnsafeAppendToBitmap(validity, array.offset + offset, length); + return Status::OK(); + } + + void Reset() override; + + ArrayBuilder* field_builder(int i) const { return children_[i].get(); } + + int num_fields() const { return static_cast(children_.size()); } + + std::shared_ptr type() const override; + + private: + std::shared_ptr type_; +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_primitive.h b/pyarrow/include/arrow/array/builder_primitive.h new file mode 100644 index 0000000000000000000000000000000000000000..6d79d6e9649994e99b85b233cc81ba8c1a8a1ba1 --- /dev/null +++ b/pyarrow/include/arrow/array/builder_primitive.h @@ -0,0 +1,689 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/array/builder_base.h" +#include "arrow/array/data.h" +#include "arrow/result.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/float16.h" + +namespace arrow { + +class ARROW_EXPORT NullBuilder : public ArrayBuilder { + public: + explicit NullBuilder(MemoryPool* pool = default_memory_pool(), + int64_t ARROW_ARG_UNUSED(alignment) = kDefaultBufferAlignment) + : ArrayBuilder(pool) {} + + explicit NullBuilder(const std::shared_ptr& ARROW_ARG_UNUSED(type), + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : NullBuilder(pool, alignment) {} + + /// \brief Append the specified number of null elements + Status AppendNulls(int64_t length) final { + if (length < 0) return Status::Invalid("length must be positive"); + null_count_ += length; + length_ += length; + return Status::OK(); + } + + /// \brief Append a single null element + Status AppendNull() final { return AppendNulls(1); } + + Status AppendEmptyValues(int64_t length) final { return AppendNulls(length); } + + Status AppendEmptyValue() final { return AppendEmptyValues(1); } + + Status Append(std::nullptr_t) { return AppendNull(); } + + Status AppendArraySlice(const ArraySpan&, int64_t, int64_t length) override { + return AppendNulls(length); + } + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + std::shared_ptr type() const override { return null(); } + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } +}; + +/// \addtogroup numeric-builders +/// +/// @{ + +/// Base class for all Builders that emit an Array of a scalar numerical type. +template +class NumericBuilder + : public ArrayBuilder, + public internal::ArrayBuilderExtraOps, typename T::c_type> { + public: + using TypeClass = T; + using value_type = typename T::c_type; + using ArrayType = typename TypeTraits::ArrayType; + + template + explicit NumericBuilder( + enable_if_parameter_free pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), + type_(TypeTraits::type_singleton()), + data_builder_(pool, alignment) {} + + NumericBuilder(const std::shared_ptr& type, MemoryPool* pool, + int64_t alignment = kDefaultBufferAlignment) + : ArrayBuilder(pool, alignment), type_(type), data_builder_(pool, alignment) {} + + /// Append a single scalar and increase the size if necessary. + Status Append(const value_type val) { + ARROW_RETURN_NOT_OK(ArrayBuilder::Reserve(1)); + UnsafeAppend(val); + return Status::OK(); + } + + /// Write nulls as uint8_t* (0 value indicates null) into pre-allocated memory + /// The memory at the corresponding data slot is set to 0 to prevent + /// uninitialized memory access + Status AppendNulls(int64_t length) final { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, value_type{}); // zero + UnsafeSetNull(length); + return Status::OK(); + } + + /// \brief Append a single null element + Status AppendNull() final { + ARROW_RETURN_NOT_OK(Reserve(1)); + data_builder_.UnsafeAppend(value_type{}); // zero + UnsafeAppendToBitmap(false); + return Status::OK(); + } + + /// \brief Append a empty element + Status AppendEmptyValue() final { + ARROW_RETURN_NOT_OK(Reserve(1)); + data_builder_.UnsafeAppend(value_type{}); // zero + UnsafeAppendToBitmap(true); + return Status::OK(); + } + + /// \brief Append several empty elements + Status AppendEmptyValues(int64_t length) final { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, value_type{}); // zero + UnsafeSetNotNull(length); + return Status::OK(); + } + + value_type GetValue(int64_t index) const { return data_builder_.data()[index]; } + + value_type* GetMutableValue(int64_t index) { + return &data_builder_.mutable_data()[index]; + } + + void Reset() override { + data_builder_.Reset(); + ArrayBuilder::Reset(); + } + + Status Resize(int64_t capacity) override { + ARROW_RETURN_NOT_OK(CheckCapacity(capacity)); + capacity = std::max(capacity, kMinBuilderCapacity); + ARROW_RETURN_NOT_OK(data_builder_.Resize(capacity)); + return ArrayBuilder::Resize(capacity); + } + + value_type operator[](int64_t index) const { return GetValue(index); } + + value_type& operator[](int64_t index) { + return reinterpret_cast(data_builder_.mutable_data())[index]; + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous C array of values + /// \param[in] length the number of values to append + /// \param[in] valid_bytes an optional sequence of bytes where non-zero + /// indicates a valid (non-null) value + /// \return Status + Status AppendValues(const value_type* values, int64_t length, + const uint8_t* valid_bytes = NULLPTR) { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(values, length); + // length_ is update by these + ArrayBuilder::UnsafeAppendToBitmap(valid_bytes, length); + return Status::OK(); + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous C array of values + /// \param[in] length the number of values to append + /// \param[in] bitmap a validity bitmap to copy (may be null) + /// \param[in] bitmap_offset an offset into the validity bitmap + /// \return Status + Status AppendValues(const value_type* values, int64_t length, const uint8_t* bitmap, + int64_t bitmap_offset) { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(values, length); + // length_ is update by these + ArrayBuilder::UnsafeAppendToBitmap(bitmap, bitmap_offset, length); + return Status::OK(); + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous C array of values + /// \param[in] length the number of values to append + /// \param[in] is_valid an std::vector indicating valid (1) or null + /// (0). Equal in length to values + /// \return Status + Status AppendValues(const value_type* values, int64_t length, + const std::vector& is_valid) { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(values, length); + // length_ is update by these + ArrayBuilder::UnsafeAppendToBitmap(is_valid); + return Status::OK(); + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a std::vector of values + /// \param[in] is_valid an std::vector indicating valid (1) or null + /// (0). Equal in length to values + /// \return Status + Status AppendValues(const std::vector& values, + const std::vector& is_valid) { + if (values.empty()) { + return Status::OK(); + } + return AppendValues(values.data(), static_cast(values.size()), is_valid); + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a std::vector of values + /// \return Status + Status AppendValues(const std::vector& values) { + if (values.empty()) { + return Status::OK(); + } + return AppendValues(values.data(), static_cast(values.size())); + } + + Status FinishInternal(std::shared_ptr* out) override { + ARROW_ASSIGN_OR_RAISE(auto null_bitmap, + null_bitmap_builder_.FinishWithLength(length_)); + ARROW_ASSIGN_OR_RAISE(auto data, data_builder_.FinishWithLength(length_)); + *out = ArrayData::Make(type(), length_, {null_bitmap, data}, null_count_); + capacity_ = length_ = null_count_ = 0; + return Status::OK(); + } + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values_begin InputIterator to the beginning of the values + /// \param[in] values_end InputIterator pointing to the end of the values + /// \return Status + template + Status AppendValues(ValuesIter values_begin, ValuesIter values_end) { + int64_t length = static_cast(std::distance(values_begin, values_end)); + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(values_begin, values_end); + // this updates the length_ + UnsafeSetNotNull(length); + return Status::OK(); + } + + /// \brief Append a sequence of elements in one shot, with a specified nullmap + /// \param[in] values_begin InputIterator to the beginning of the values + /// \param[in] values_end InputIterator pointing to the end of the values + /// \param[in] valid_begin InputIterator with elements indication valid(1) + /// or null(0) values. + /// \return Status + template + enable_if_t::value, Status> AppendValues( + ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) { + static_assert(!internal::is_null_pointer::value, + "Don't pass a NULLPTR directly as valid_begin, use the 2-argument " + "version instead"); + int64_t length = static_cast(std::distance(values_begin, values_end)); + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(values_begin, values_end); + null_bitmap_builder_.UnsafeAppend( + length, [&valid_begin]() -> bool { return *valid_begin++; }); + length_ = null_bitmap_builder_.length(); + null_count_ = null_bitmap_builder_.false_count(); + return Status::OK(); + } + + // Same as above, with a pointer type ValidIter + template + enable_if_t::value, Status> AppendValues( + ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) { + int64_t length = static_cast(std::distance(values_begin, values_end)); + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(values_begin, values_end); + // this updates the length_ + if (valid_begin == NULLPTR) { + UnsafeSetNotNull(length); + } else { + null_bitmap_builder_.UnsafeAppend( + length, [&valid_begin]() -> bool { return *valid_begin++; }); + length_ = null_bitmap_builder_.length(); + null_count_ = null_bitmap_builder_.false_count(); + } + + return Status::OK(); + } + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override { + return AppendValues(array.GetValues(1) + offset, length, + array.GetValues(0, 0), array.offset + offset); + } + + /// Append a single scalar under the assumption that the underlying Buffer is + /// large enough. + /// + /// This method does not capacity-check; make sure to call Reserve + /// beforehand. + void UnsafeAppend(const value_type val) { + ArrayBuilder::UnsafeAppendToBitmap(true); + data_builder_.UnsafeAppend(val); + } + + void UnsafeAppendNull() { + ArrayBuilder::UnsafeAppendToBitmap(false); + data_builder_.UnsafeAppend(value_type{}); // zero + } + + /// Advance builder without allocating nor writing any values + /// + /// The internal pointer is advanced by `length` values and the same number + /// of non-null entries are appended to the validity bitmap. + /// This method assumes that the `length` values were populated directly, + /// for example using `GetMutableValue`. + void UnsafeAdvance(int64_t length) { + data_builder_.UnsafeAdvance(length); + UnsafeAppendToBitmap(length, true); + } + + /// Advance builder without allocating nor writing any values + /// + /// The internal pointer is advanced by `length` values and the same number + /// of validity bits are appended to the validity bitmap. + /// This method assumes that the `length` values were populated directly, + /// for example using `GetMutableValue`. + void UnsafeAdvance(int64_t length, const uint8_t* validity, int64_t valid_bits_offset) { + data_builder_.UnsafeAdvance(length); + UnsafeAppendToBitmap(validity, valid_bits_offset, length); + } + + std::shared_ptr type() const override { return type_; } + + protected: + std::shared_ptr type_; + TypedBufferBuilder data_builder_; +}; + +// Builders + +using UInt8Builder = NumericBuilder; +using UInt16Builder = NumericBuilder; +using UInt32Builder = NumericBuilder; +using UInt64Builder = NumericBuilder; + +using Int8Builder = NumericBuilder; +using Int16Builder = NumericBuilder; +using Int32Builder = NumericBuilder; +using Int64Builder = NumericBuilder; + +using FloatBuilder = NumericBuilder; +using DoubleBuilder = NumericBuilder; + +/// @} + +/// \addtogroup temporal-builders +/// +/// @{ + +using Date32Builder = NumericBuilder; +using Date64Builder = NumericBuilder; +using Time32Builder = NumericBuilder; +using Time64Builder = NumericBuilder; +using TimestampBuilder = NumericBuilder; +using MonthIntervalBuilder = NumericBuilder; +using DurationBuilder = NumericBuilder; + +/// @} + +/// \addtogroup numeric-builders +/// +/// @{ + +class ARROW_EXPORT HalfFloatBuilder : public NumericBuilder { + public: + using BaseClass = NumericBuilder; + using Float16 = arrow::util::Float16; + + using BaseClass::Append; + using BaseClass::AppendValues; + using BaseClass::BaseClass; + using BaseClass::GetValue; + using BaseClass::UnsafeAppend; + + /// Scalar append a arrow::util::Float16 + Status Append(const Float16 val) { return Append(val.bits()); } + + /// Scalar append a arrow::util::Float16, without checking for capacity + void UnsafeAppend(const Float16 val) { UnsafeAppend(val.bits()); } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous array of arrow::util::Float16 + /// \param[in] length the number of values to append + /// \param[in] valid_bytes an optional sequence of bytes where non-zero + /// indicates a valid (non-null) value + /// \return Status + Status AppendValues(const Float16* values, int64_t length, + const uint8_t* valid_bytes = NULLPTR) { + return BaseClass::AppendValues(reinterpret_cast(values), length, + valid_bytes); + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous array of arrow::util::Float16 + /// \param[in] length the number of values to append + /// \param[in] bitmap a validity bitmap to copy (may be null) + /// \param[in] bitmap_offset an offset into the validity bitmap + /// \return Status + Status AppendValues(const Float16* values, int64_t length, const uint8_t* bitmap, + int64_t bitmap_offset) { + return BaseClass::AppendValues(reinterpret_cast(values), length, + bitmap, bitmap_offset); + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous array of arrow::util::Float16 + /// \param[in] length the number of values to append + /// \param[in] is_valid a std::vector indicating valid (1) or null + /// (0). Equal in length to values + /// \return Status + Status AppendValues(const Float16* values, int64_t length, + const std::vector& is_valid) { + return BaseClass::AppendValues(reinterpret_cast(values), length, + is_valid); + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a std::vector + /// \param[in] is_valid a std::vector indicating valid (1) or null + /// (0). Equal in length to values + /// \return Status + Status AppendValues(const std::vector& values, + const std::vector& is_valid) { + return AppendValues(values.data(), static_cast(values.size()), is_valid); + } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a std::vector + /// \return Status + Status AppendValues(const std::vector& values) { + return AppendValues(values.data(), static_cast(values.size())); + } + + /// \brief Append one value many times in one shot + /// \param[in] length the number of values to append + /// \param[in] value a arrow::util::Float16 + Status AppendValues(int64_t length, Float16 value) { + RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, value.bits()); + ArrayBuilder::UnsafeSetNotNull(length); + return Status::OK(); + } + + /// \brief Get the value at a certain index + /// \param[in] index the zero-based index + /// @tparam T arrow::util::Float16 or value_type (uint16_t) + template + T GetValue(int64_t index) const { + static_assert(std::is_same_v || + std::is_same_v); + if constexpr (std::is_same_v) { + return BaseClass::GetValue(index); + } else { + return Float16::FromBits(BaseClass::GetValue(index)); + } + } +}; + +/// @} + +class ARROW_EXPORT BooleanBuilder + : public ArrayBuilder, + public internal::ArrayBuilderExtraOps { + public: + using TypeClass = BooleanType; + using value_type = bool; + + explicit BooleanBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + BooleanBuilder(const std::shared_ptr& type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment); + + /// Write nulls as uint8_t* (0 value indicates null) into pre-allocated memory + Status AppendNulls(int64_t length) final { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, false); + UnsafeSetNull(length); + return Status::OK(); + } + + Status AppendNull() final { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppendNull(); + return Status::OK(); + } + + Status AppendEmptyValue() final { + ARROW_RETURN_NOT_OK(Reserve(1)); + data_builder_.UnsafeAppend(false); + UnsafeSetNotNull(1); + return Status::OK(); + } + + Status AppendEmptyValues(int64_t length) final { + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, false); + UnsafeSetNotNull(length); + return Status::OK(); + } + + /// Scalar append + Status Append(const bool val) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(val); + return Status::OK(); + } + + Status Append(const uint8_t val) { return Append(val != 0); } + + /// Scalar append, without checking for capacity + void UnsafeAppend(const bool val) { + data_builder_.UnsafeAppend(val); + UnsafeAppendToBitmap(true); + } + + void UnsafeAppendNull() { + data_builder_.UnsafeAppend(false); + UnsafeAppendToBitmap(false); + } + + void UnsafeAppend(const uint8_t val) { UnsafeAppend(val != 0); } + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous array of bytes (non-zero is 1) + /// \param[in] length the number of values to append + /// \param[in] valid_bytes an optional sequence of bytes where non-zero + /// indicates a valid (non-null) value + /// \return Status + Status AppendValues(const uint8_t* values, int64_t length, + const uint8_t* valid_bytes = NULLPTR); + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a bitmap of values + /// \param[in] length the number of values to append + /// \param[in] validity a validity bitmap to copy (may be null) + /// \param[in] offset an offset into the values and validity bitmaps + /// \return Status + Status AppendValues(const uint8_t* values, int64_t length, const uint8_t* validity, + int64_t offset); + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a contiguous C array of values + /// \param[in] length the number of values to append + /// \param[in] is_valid an std::vector indicating valid (1) or null + /// (0). Equal in length to values + /// \return Status + Status AppendValues(const uint8_t* values, int64_t length, + const std::vector& is_valid); + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a std::vector of bytes + /// \param[in] is_valid an std::vector indicating valid (1) or null + /// (0). Equal in length to values + /// \return Status + Status AppendValues(const std::vector& values, + const std::vector& is_valid); + + /// \brief Append a sequence of elements in one shot + /// \param[in] values a std::vector of bytes + /// \return Status + Status AppendValues(const std::vector& values); + + /// \brief Append a sequence of elements in one shot + /// \param[in] values an std::vector indicating true (1) or false + /// \param[in] is_valid an std::vector indicating valid (1) or null + /// (0). Equal in length to values + /// \return Status + Status AppendValues(const std::vector& values, const std::vector& is_valid); + + /// \brief Append a sequence of elements in one shot + /// \param[in] values an std::vector indicating true (1) or false + /// \return Status + Status AppendValues(const std::vector& values); + + /// \brief Append a sequence of elements in one shot + /// \param[in] values_begin InputIterator to the beginning of the values + /// \param[in] values_end InputIterator pointing to the end of the values + /// or null(0) values + /// \return Status + template + Status AppendValues(ValuesIter values_begin, ValuesIter values_end) { + int64_t length = static_cast(std::distance(values_begin, values_end)); + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend( + length, [&values_begin]() -> bool { return *values_begin++; }); + // this updates length_ + UnsafeSetNotNull(length); + return Status::OK(); + } + + /// \brief Append a sequence of elements in one shot, with a specified nullmap + /// \param[in] values_begin InputIterator to the beginning of the values + /// \param[in] values_end InputIterator pointing to the end of the values + /// \param[in] valid_begin InputIterator with elements indication valid(1) + /// or null(0) values + /// \return Status + template + enable_if_t::value, Status> AppendValues( + ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) { + static_assert(!internal::is_null_pointer::value, + "Don't pass a NULLPTR directly as valid_begin, use the 2-argument " + "version instead"); + int64_t length = static_cast(std::distance(values_begin, values_end)); + ARROW_RETURN_NOT_OK(Reserve(length)); + + data_builder_.UnsafeAppend( + length, [&values_begin]() -> bool { return *values_begin++; }); + null_bitmap_builder_.UnsafeAppend( + length, [&valid_begin]() -> bool { return *valid_begin++; }); + length_ = null_bitmap_builder_.length(); + null_count_ = null_bitmap_builder_.false_count(); + return Status::OK(); + } + + // Same as above, for a pointer type ValidIter + template + enable_if_t::value, Status> AppendValues( + ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) { + int64_t length = static_cast(std::distance(values_begin, values_end)); + ARROW_RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend( + length, [&values_begin]() -> bool { return *values_begin++; }); + + if (valid_begin == NULLPTR) { + UnsafeSetNotNull(length); + } else { + null_bitmap_builder_.UnsafeAppend( + length, [&valid_begin]() -> bool { return *valid_begin++; }); + } + length_ = null_bitmap_builder_.length(); + null_count_ = null_bitmap_builder_.false_count(); + return Status::OK(); + } + + Status AppendValues(int64_t length, bool value); + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override { + return AppendValues(array.GetValues(1, 0), length, + array.GetValues(0, 0), array.offset + offset); + } + + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + void Reset() override; + Status Resize(int64_t capacity) override; + + std::shared_ptr type() const override { return boolean(); } + + protected: + TypedBufferBuilder data_builder_; +}; + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_run_end.h b/pyarrow/include/arrow/array/builder_run_end.h new file mode 100644 index 0000000000000000000000000000000000000000..ac92efbd0dbe6b470b8275219e75b41aa3f7ab3a --- /dev/null +++ b/pyarrow/include/arrow/array/builder_run_end.h @@ -0,0 +1,303 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/array/builder_base.h" + +namespace arrow { + +/// \addtogroup run-end-encoded-builders +/// +/// @{ + +namespace internal { + +/// \brief An ArrayBuilder that deduplicates repeated values as they are +/// appended to the inner-ArrayBuilder and reports the length of the current run +/// of identical values. +/// +/// The following sequence of calls +/// +/// Append(2) +/// Append(2) +/// Append(2) +/// Append(7) +/// Append(7) +/// Append(2) +/// FinishInternal() +/// +/// will cause the inner-builder to receive only 3 Append calls +/// +/// Append(2) +/// Append(7) +/// Append(2) +/// FinishInternal() +/// +/// Note that values returned by length(), null_count() and capacity() are +/// related to the compressed array built by the inner-ArrayBuilder. +class RunCompressorBuilder : public ArrayBuilder { + public: + RunCompressorBuilder(MemoryPool* pool, std::shared_ptr inner_builder, + std::shared_ptr type); + + ~RunCompressorBuilder() override; + + ARROW_DISALLOW_COPY_AND_ASSIGN(RunCompressorBuilder); + + /// \brief Called right before a run is being closed + /// + /// Subclasses can override this function to perform an additional action when + /// a run is closed (i.e. run-length is known and value is appended to the + /// inner builder). + /// + /// \param value can be NULLPTR if closing a run of NULLs + /// \param length the greater than 0 length of the value run being closed + virtual Status WillCloseRun(const std::shared_ptr& value, + int64_t length) { + return Status::OK(); + } + + /// \brief Called right before a run of empty values is being closed + /// + /// Subclasses can override this function to perform an additional action when + /// a run of empty values is appended (i.e. run-length is known and a single + /// empty value is appended to the inner builder). + /// + /// \param length the greater than 0 length of the value run being closed + virtual Status WillCloseRunOfEmptyValues(int64_t length) { return Status::OK(); } + + /// \brief Allocate enough memory for a given number of array elements. + /// + /// NOTE: Conservatively resizing a run-length compressed array for a given + /// number of logical elements is not possible, since the physical length will + /// vary depending on the values to be appended in the future. But we can + /// pessimistically assume that each run will contain a single value and + /// allocate that number of runs. + Status Resize(int64_t capacity) override { return ResizePhysical(capacity); } + + /// \brief Allocate enough memory for a given number of runs. + /// + /// Like Resize on non-encoded builders, it does not account for variable size + /// data. + Status ResizePhysical(int64_t capacity); + + Status ReservePhysical(int64_t additional_capacity) { + return Reserve(additional_capacity); + } + + void Reset() override; + + Status AppendNull() final { return AppendNulls(1); } + Status AppendNulls(int64_t length) override; + + Status AppendEmptyValue() final { return AppendEmptyValues(1); } + Status AppendEmptyValues(int64_t length) override; + + Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override; + Status AppendScalars(const ScalarVector& scalars) override; + + // AppendArraySlice() is not implemented. + + /// \brief Append a slice of an array containing values from already + /// compressed runs. + /// + /// NOTE: WillCloseRun() is not called as the length of each run cannot be + /// determined at this point. Caller should ensure that !has_open_run() by + /// calling FinishCurrentRun() before calling this. + /// + /// Pre-condition: !has_open_run() + Status AppendRunCompressedArraySlice(const ArraySpan& array, int64_t offset, + int64_t length); + + /// \brief Forces the closing of the current run if one is currently open. + /// + /// This can be called when one wants to ensure the current run will not be + /// extended. This may cause identical values to appear close to each other in + /// the underlying array (i.e. two runs that could be a single run) if more + /// values are appended after this is called. + /// + /// Finish() and FinishInternal() call this automatically. + virtual Status FinishCurrentRun(); + + Status FinishInternal(std::shared_ptr* out) override; + + ArrayBuilder& inner_builder() const { return *inner_builder_; } + + std::shared_ptr type() const override { return inner_builder_->type(); } + + bool has_open_run() const { return current_run_length_ > 0; } + int64_t open_run_length() const { return current_run_length_; } + + private: + inline void UpdateDimensions() { + capacity_ = inner_builder_->capacity(); + length_ = inner_builder_->length(); + null_count_ = inner_builder_->null_count(); + } + + private: + std::shared_ptr inner_builder_; + std::shared_ptr current_value_ = NULLPTR; + int64_t current_run_length_ = 0; +}; + +} // namespace internal + +// ---------------------------------------------------------------------- +// RunEndEncoded builder + +/// \brief Run-end encoded array builder. +/// +/// NOTE: the value returned by and capacity() is related to the +/// compressed array (physical) and not the decoded array (logical) that is +/// run-end encoded. null_count() always returns 0. length(), on the other hand, +/// returns the logical length of the run-end encoded array. +class ARROW_EXPORT RunEndEncodedBuilder : public ArrayBuilder { + private: + // An internal::RunCompressorBuilder that produces a run-end in the + // RunEndEncodedBuilder every time a value-run is closed. + class ValueRunBuilder : public internal::RunCompressorBuilder { + public: + ValueRunBuilder(MemoryPool* pool, const std::shared_ptr& value_builder, + const std::shared_ptr& value_type, + RunEndEncodedBuilder& ree_builder); + + ~ValueRunBuilder() override = default; + + Status WillCloseRun(const std::shared_ptr&, int64_t length) override { + return ree_builder_.CloseRun(length); + } + + Status WillCloseRunOfEmptyValues(int64_t length) override { + return ree_builder_.CloseRun(length); + } + + private: + RunEndEncodedBuilder& ree_builder_; + }; + + public: + RunEndEncodedBuilder(MemoryPool* pool, + const std::shared_ptr& run_end_builder, + const std::shared_ptr& value_builder, + std::shared_ptr type); + + /// \brief Allocate enough memory for a given number of array elements. + /// + /// NOTE: Conservatively resizing an REE for a given number of logical + /// elements is not possible, since the physical length will vary depending on + /// the values to be appended in the future. But we can pessimistically assume + /// that each run will contain a single value and allocate that number of + /// runs. + Status Resize(int64_t capacity) override { return ResizePhysical(capacity); } + + /// \brief Allocate enough memory for a given number of runs. + Status ResizePhysical(int64_t capacity); + + /// \brief Ensure that there is enough space allocated to append the indicated + /// number of run without any further reallocation. Overallocation is + /// used in order to minimize the impact of incremental ReservePhysical() calls. + /// Note that additional_capacity is relative to the current number of elements + /// rather than to the current capacity, so calls to Reserve() which are not + /// interspersed with addition of new elements may not increase the capacity. + /// + /// \param[in] additional_capacity the number of additional runs + /// \return Status + Status ReservePhysical(int64_t additional_capacity) { + return Reserve(additional_capacity); + } + + void Reset() override; + + Status AppendNull() final { return AppendNulls(1); } + Status AppendNulls(int64_t length) override; + + Status AppendEmptyValue() final { return AppendEmptyValues(1); } + Status AppendEmptyValues(int64_t length) override; + Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override; + Status AppendScalars(const ScalarVector& scalars) override; + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override; + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + /// \brief Forces the closing of the current run if one is currently open. + /// + /// This can be called when one wants to ensure the current run will not be + /// extended. This may cause identical values to appear close to each other in + /// the values array (i.e. two runs that could be a single run) if more + /// values are appended after this is called. + Status FinishCurrentRun(); + + std::shared_ptr type() const override; + + private: + /// \brief Update physical capacity and logical length + /// + /// \param committed_logical_length number of logical values that have been + /// committed to the values array + /// \param open_run_length number of logical values in the currently open run if any + inline void UpdateDimensions(int64_t committed_logical_length, + int64_t open_run_length) { + capacity_ = run_end_builder().capacity(); + length_ = committed_logical_length + open_run_length; + committed_logical_length_ = committed_logical_length; + } + + // Pre-condition: !value_run_builder_.has_open_run() + template + Status DoAppendArraySlice(const ArraySpan& array, int64_t offset, int64_t length); + + template + Status DoAppendRunEnd(int64_t run_end); + + /// \brief Cast run_end to the appropriate type and appends it to the run_ends + /// array. + Status AppendRunEnd(int64_t run_end); + + /// \brief Close a run by appending a value to the run_ends array and updating + /// length_ to reflect the new run. + /// + /// Pre-condition: run_length > 0. + [[nodiscard]] Status CloseRun(int64_t run_length); + + ArrayBuilder& run_end_builder(); + ArrayBuilder& value_builder(); + + private: + std::shared_ptr type_; + ValueRunBuilder* value_run_builder_; + // The length not counting the current open run in the value_run_builder_ + int64_t committed_logical_length_ = 0; +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_time.h b/pyarrow/include/arrow/array/builder_time.h new file mode 100644 index 0000000000000000000000000000000000000000..b471e9621cd4b125fd44e8f2f4239c7f720ac95d --- /dev/null +++ b/pyarrow/include/arrow/array/builder_time.h @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Contains declarations of time related Arrow builder types. + +#pragma once + +#include + +#include "arrow/array/builder_base.h" +#include "arrow/array/builder_primitive.h" + +namespace arrow { + +/// \addtogroup temporal-builders +/// +/// @{ + +class ARROW_EXPORT DayTimeIntervalBuilder : public NumericBuilder { + public: + using DayMilliseconds = DayTimeIntervalType::DayMilliseconds; + + explicit DayTimeIntervalBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : DayTimeIntervalBuilder(day_time_interval(), pool, alignment) {} + + explicit DayTimeIntervalBuilder(std::shared_ptr type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : NumericBuilder(type, pool, alignment) {} +}; + +class ARROW_EXPORT MonthDayNanoIntervalBuilder + : public NumericBuilder { + public: + using MonthDayNanos = MonthDayNanoIntervalType::MonthDayNanos; + + explicit MonthDayNanoIntervalBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : MonthDayNanoIntervalBuilder(month_day_nano_interval(), pool, alignment) {} + + explicit MonthDayNanoIntervalBuilder(std::shared_ptr type, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : NumericBuilder(type, pool, alignment) {} +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/builder_union.h b/pyarrow/include/arrow/array/builder_union.h new file mode 100644 index 0000000000000000000000000000000000000000..718ef4c32cebef1d30e4f7c036a7ab8f4b333e4a --- /dev/null +++ b/pyarrow/include/arrow/array/builder_union.h @@ -0,0 +1,254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_base.h" +#include "arrow/array/data.h" +#include "arrow/buffer_builder.h" +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \addtogroup nested-builders +/// +/// @{ + +/// \brief Base class for union array builds. +/// +/// Note that while we subclass ArrayBuilder, as union types do not have a +/// validity bitmap, the bitmap builder member of ArrayBuilder is not used. +class ARROW_EXPORT BasicUnionBuilder : public ArrayBuilder { + public: + Status FinishInternal(std::shared_ptr* out) override; + + /// \cond FALSE + using ArrayBuilder::Finish; + /// \endcond + + Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + + /// \brief Make a new child builder available to the UnionArray + /// + /// \param[in] new_child the child builder + /// \param[in] field_name the name of the field in the union array type + /// if type inference is used + /// \return child index, which is the "type" argument that needs + /// to be passed to the "Append" method to add a new element to + /// the union array. + int8_t AppendChild(const std::shared_ptr& new_child, + const std::string& field_name = ""); + + std::shared_ptr type() const override; + + int64_t length() const override { return types_builder_.length(); } + + protected: + BasicUnionBuilder(MemoryPool* pool, int64_t alignment, + const std::vector>& children, + const std::shared_ptr& type); + + int8_t NextTypeId(); + + std::vector> child_fields_; + std::vector type_codes_; + UnionMode::type mode_; + + std::vector type_id_to_children_; + std::vector type_id_to_child_id_; + // for all type_id < dense_type_id_, type_id_to_children_[type_id] != nullptr + int8_t dense_type_id_ = 0; + TypedBufferBuilder types_builder_; +}; + +/// \class DenseUnionBuilder +/// +/// This API is EXPERIMENTAL. +class ARROW_EXPORT DenseUnionBuilder : public BasicUnionBuilder { + public: + /// Use this constructor to initialize the UnionBuilder with no child builders, + /// allowing type to be inferred. You will need to call AppendChild for each of the + /// children builders you want to use. + explicit DenseUnionBuilder(MemoryPool* pool, + int64_t alignment = kDefaultBufferAlignment) + : BasicUnionBuilder(pool, alignment, {}, dense_union(FieldVector{})), + offsets_builder_(pool, alignment) {} + + /// Use this constructor to specify the type explicitly. + /// You can still add child builders to the union after using this constructor + DenseUnionBuilder(MemoryPool* pool, + const std::vector>& children, + const std::shared_ptr& type, + int64_t alignment = kDefaultBufferAlignment) + : BasicUnionBuilder(pool, alignment, children, type), + offsets_builder_(pool, alignment) {} + + Status AppendNull() final { + const int8_t first_child_code = type_codes_[0]; + ArrayBuilder* child_builder = type_id_to_children_[first_child_code]; + ARROW_RETURN_NOT_OK(types_builder_.Append(first_child_code)); + ARROW_RETURN_NOT_OK( + offsets_builder_.Append(static_cast(child_builder->length()))); + // Append a null arbitrarily to the first child + return child_builder->AppendNull(); + } + + Status AppendNulls(int64_t length) final { + const int8_t first_child_code = type_codes_[0]; + ArrayBuilder* child_builder = type_id_to_children_[first_child_code]; + ARROW_RETURN_NOT_OK(types_builder_.Append(length, first_child_code)); + ARROW_RETURN_NOT_OK( + offsets_builder_.Append(length, static_cast(child_builder->length()))); + // Append just a single null to the first child + return child_builder->AppendNull(); + } + + Status AppendEmptyValue() final { + const int8_t first_child_code = type_codes_[0]; + ArrayBuilder* child_builder = type_id_to_children_[first_child_code]; + ARROW_RETURN_NOT_OK(types_builder_.Append(first_child_code)); + ARROW_RETURN_NOT_OK( + offsets_builder_.Append(static_cast(child_builder->length()))); + // Append an empty value arbitrarily to the first child + return child_builder->AppendEmptyValue(); + } + + Status AppendEmptyValues(int64_t length) final { + const int8_t first_child_code = type_codes_[0]; + ArrayBuilder* child_builder = type_id_to_children_[first_child_code]; + ARROW_RETURN_NOT_OK(types_builder_.Append(length, first_child_code)); + ARROW_RETURN_NOT_OK( + offsets_builder_.Append(length, static_cast(child_builder->length()))); + // Append just a single empty value to the first child + return child_builder->AppendEmptyValue(); + } + + /// \brief Append an element to the UnionArray. This must be followed + /// by an append to the appropriate child builder. + /// + /// \param[in] next_type type_id of the child to which the next value will be appended. + /// + /// The corresponding child builder must be appended to independently after this method + /// is called. + Status Append(int8_t next_type) { + ARROW_RETURN_NOT_OK(types_builder_.Append(next_type)); + if (type_id_to_children_[next_type]->length() == kListMaximumElements) { + return Status::CapacityError( + "a dense UnionArray cannot contain more than 2^31 - 1 elements from a single " + "child"); + } + auto offset = static_cast(type_id_to_children_[next_type]->length()); + return offsets_builder_.Append(offset); + } + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override; + + Status FinishInternal(std::shared_ptr* out) override; + + private: + TypedBufferBuilder offsets_builder_; +}; + +/// \class SparseUnionBuilder +/// +/// This API is EXPERIMENTAL. +class ARROW_EXPORT SparseUnionBuilder : public BasicUnionBuilder { + public: + /// Use this constructor to initialize the UnionBuilder with no child builders, + /// allowing type to be inferred. You will need to call AppendChild for each of the + /// children builders you want to use. + explicit SparseUnionBuilder(MemoryPool* pool, + int64_t alignment = kDefaultBufferAlignment) + : BasicUnionBuilder(pool, alignment, {}, sparse_union(FieldVector{})) {} + + /// Use this constructor to specify the type explicitly. + /// You can still add child builders to the union after using this constructor + SparseUnionBuilder(MemoryPool* pool, + const std::vector>& children, + const std::shared_ptr& type, + int64_t alignment = kDefaultBufferAlignment) + : BasicUnionBuilder(pool, alignment, children, type) {} + + /// \brief Append a null value. + /// + /// A null is appended to the first child, empty values to the other children. + Status AppendNull() final { + const auto first_child_code = type_codes_[0]; + ARROW_RETURN_NOT_OK(types_builder_.Append(first_child_code)); + ARROW_RETURN_NOT_OK(type_id_to_children_[first_child_code]->AppendNull()); + for (int i = 1; i < static_cast(type_codes_.size()); ++i) { + ARROW_RETURN_NOT_OK(type_id_to_children_[type_codes_[i]]->AppendEmptyValue()); + } + return Status::OK(); + } + + /// \brief Append multiple null values. + /// + /// Nulls are appended to the first child, empty values to the other children. + Status AppendNulls(int64_t length) final { + const auto first_child_code = type_codes_[0]; + ARROW_RETURN_NOT_OK(types_builder_.Append(length, first_child_code)); + ARROW_RETURN_NOT_OK(type_id_to_children_[first_child_code]->AppendNulls(length)); + for (int i = 1; i < static_cast(type_codes_.size()); ++i) { + ARROW_RETURN_NOT_OK( + type_id_to_children_[type_codes_[i]]->AppendEmptyValues(length)); + } + return Status::OK(); + } + + Status AppendEmptyValue() final { + ARROW_RETURN_NOT_OK(types_builder_.Append(type_codes_[0])); + for (int8_t code : type_codes_) { + ARROW_RETURN_NOT_OK(type_id_to_children_[code]->AppendEmptyValue()); + } + return Status::OK(); + } + + Status AppendEmptyValues(int64_t length) final { + ARROW_RETURN_NOT_OK(types_builder_.Append(length, type_codes_[0])); + for (int8_t code : type_codes_) { + ARROW_RETURN_NOT_OK(type_id_to_children_[code]->AppendEmptyValues(length)); + } + return Status::OK(); + } + + /// \brief Append an element to the UnionArray. This must be followed + /// by an append to the appropriate child builder. + /// + /// \param[in] next_type type_id of the child to which the next value will be appended. + /// + /// The corresponding child builder must be appended to independently after this method + /// is called, and all other child builders must have null or empty value appended. + Status Append(int8_t next_type) { return types_builder_.Append(next_type); } + + Status AppendArraySlice(const ArraySpan& array, int64_t offset, + int64_t length) override; +}; + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/concatenate.h b/pyarrow/include/arrow/array/concatenate.h new file mode 100644 index 0000000000000000000000000000000000000000..aada5624d63a3052edddf0182799c474bee0c528 --- /dev/null +++ b/pyarrow/include/arrow/array/concatenate.h @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace internal { + +/// \brief Concatenate arrays +/// +/// \param[in] arrays a vector of arrays to be concatenated +/// \param[in] pool memory to store the result will be allocated from this memory pool +/// \param[out] out_suggested_cast if a non-OK Result is returned, the function might set +/// out_suggested_cast to a cast suggestion that would allow concatenating the arrays +/// without overflow of offsets (e.g. string to large_string) +/// +/// \return the concatenated array +ARROW_EXPORT +Result> Concatenate(const ArrayVector& arrays, MemoryPool* pool, + std::shared_ptr* out_suggested_cast); + +} // namespace internal + +/// \brief Concatenate arrays +/// +/// \param[in] arrays a vector of arrays to be concatenated +/// \param[in] pool memory to store the result will be allocated from this memory pool +/// \return the concatenated array +ARROW_EXPORT +Result> Concatenate(const ArrayVector& arrays, + MemoryPool* pool = default_memory_pool()); + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/data.h b/pyarrow/include/arrow/array/data.h new file mode 100644 index 0000000000000000000000000000000000000000..c6636df9bb3025de78b00b1f5b4265783c05e148 --- /dev/null +++ b/pyarrow/include/arrow/array/data.h @@ -0,0 +1,750 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include // IWYU pragma: export +#include +#include +#include +#include +#include + +#include "arrow/array/statistics.h" +#include "arrow/buffer.h" +#include "arrow/result.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/macros.h" +#include "arrow/util/span.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +namespace internal { +// ---------------------------------------------------------------------- +// Null handling for types without a validity bitmap and the dictionary type + +ARROW_EXPORT bool IsNullSparseUnion(const ArrayData& data, int64_t i); +ARROW_EXPORT bool IsNullDenseUnion(const ArrayData& data, int64_t i); +ARROW_EXPORT bool IsNullRunEndEncoded(const ArrayData& data, int64_t i); + +ARROW_EXPORT bool UnionMayHaveLogicalNulls(const ArrayData& data); +ARROW_EXPORT bool RunEndEncodedMayHaveLogicalNulls(const ArrayData& data); +ARROW_EXPORT bool DictionaryMayHaveLogicalNulls(const ArrayData& data); + +} // namespace internal + +// When slicing, we do not know the null count of the sliced range without +// doing some computation. To avoid doing this eagerly, we set the null count +// to -1 (any negative number will do). When Array::null_count is called the +// first time, the null count will be computed. See ARROW-33 +constexpr int64_t kUnknownNullCount = -1; + +// ---------------------------------------------------------------------- +// Generic array data container + +/// \class ArrayData +/// \brief Mutable container for generic Arrow array data +/// +/// This data structure is a self-contained representation of the memory and +/// metadata inside an Arrow array data structure (called vectors in Java). The +/// Array class and its concrete subclasses provide strongly-typed accessors +/// with support for the visitor pattern and other affordances. +/// +/// This class is designed for easy internal data manipulation, analytical data +/// processing, and data transport to and from IPC messages. +/// +/// This class is also useful in an analytics setting where memory may be +/// efficiently reused. For example, computing the Abs of a numeric array +/// should return null iff the input is null: therefore, an Abs function can +/// reuse the validity bitmap (a Buffer) of its input as the validity bitmap +/// of its output. +/// +/// This class is meant mostly for immutable data access. Any mutable access +/// (either to ArrayData members or to the contents of its Buffers) should take +/// into account the fact that ArrayData instances are typically wrapped in a +/// shared_ptr and can therefore have multiple owners at any given time. +/// Therefore, mutable access is discouraged except when initially populating +/// the ArrayData. +struct ARROW_EXPORT ArrayData { + ArrayData() = default; + + ArrayData(std::shared_ptr type, int64_t length, + int64_t null_count = kUnknownNullCount, int64_t offset = 0) + : type(std::move(type)), length(length), null_count(null_count), offset(offset) {} + + ArrayData(std::shared_ptr type, int64_t length, + std::vector> buffers, + int64_t null_count = kUnknownNullCount, int64_t offset = 0) + : ArrayData(std::move(type), length, null_count, offset) { + this->buffers = std::move(buffers); +#ifndef NDEBUG + // in debug mode, call the `device_type` function to trigger + // the DCHECKs that validate all the buffers are on the same device + ARROW_UNUSED(this->device_type()); +#endif + } + + ArrayData(std::shared_ptr type, int64_t length, + std::vector> buffers, + std::vector> child_data, + int64_t null_count = kUnknownNullCount, int64_t offset = 0) + : ArrayData(std::move(type), length, null_count, offset) { + this->buffers = std::move(buffers); + this->child_data = std::move(child_data); +#ifndef NDEBUG + // in debug mode, call the `device_type` function to trigger + // the DCHECKs that validate all the buffers (including children) + // are on the same device + ARROW_UNUSED(this->device_type()); +#endif + } + + static std::shared_ptr Make(std::shared_ptr type, int64_t length, + std::vector> buffers, + int64_t null_count = kUnknownNullCount, + int64_t offset = 0); + + static std::shared_ptr Make( + std::shared_ptr type, int64_t length, + std::vector> buffers, + std::vector> child_data, + int64_t null_count = kUnknownNullCount, int64_t offset = 0); + + static std::shared_ptr Make( + std::shared_ptr type, int64_t length, + std::vector> buffers, + std::vector> child_data, + std::shared_ptr dictionary, int64_t null_count = kUnknownNullCount, + int64_t offset = 0); + + static std::shared_ptr Make(std::shared_ptr type, int64_t length, + int64_t null_count = kUnknownNullCount, + int64_t offset = 0); + + // Move constructor + ArrayData(ArrayData&& other) noexcept + : type(std::move(other.type)), + length(other.length), + null_count(other.null_count.load()), + offset(other.offset), + buffers(std::move(other.buffers)), + child_data(std::move(other.child_data)), + dictionary(std::move(other.dictionary)), + statistics(std::move(other.statistics)) {} + + // Copy constructor + ArrayData(const ArrayData& other) noexcept + : type(other.type), + length(other.length), + null_count(other.null_count.load()), + offset(other.offset), + buffers(other.buffers), + child_data(other.child_data), + dictionary(other.dictionary), + statistics(other.statistics) {} + + // Move assignment + ArrayData& operator=(ArrayData&& other) { + type = std::move(other.type); + length = other.length; + SetNullCount(other.null_count); + offset = other.offset; + buffers = std::move(other.buffers); + child_data = std::move(other.child_data); + dictionary = std::move(other.dictionary); + statistics = std::move(other.statistics); + return *this; + } + + // Copy assignment + ArrayData& operator=(const ArrayData& other) { + type = other.type; + length = other.length; + SetNullCount(other.null_count); + offset = other.offset; + buffers = other.buffers; + child_data = other.child_data; + dictionary = other.dictionary; + statistics = other.statistics; + return *this; + } + + /// \brief Return a shallow copy of this ArrayData + std::shared_ptr Copy() const { return std::make_shared(*this); } + + /// \brief Deep copy this ArrayData to destination memory manager + /// + /// Returns a new ArrayData object with buffers and all child buffers + /// copied to the destination memory manager. This includes dictionaries + /// if applicable. + Result> CopyTo( + const std::shared_ptr& to) const; + + /// \brief View or copy this ArrayData to destination memory manager + /// + /// Tries to view the buffer contents on the given memory manager's device + /// if possible (to avoid a copy) but falls back to copying if a no-copy view + /// isn't supported. + Result> ViewOrCopyTo( + const std::shared_ptr& to) const; + + /// \brief Return the null-ness of a given array element + /// + /// Calling `IsNull(i)` is the same as `!IsValid(i)`. + bool IsNull(int64_t i) const { return !IsValid(i); } + + /// \brief Return the validity of a given array element + /// + /// For most data types, this will simply query the validity bitmap. + /// For union and run-end-encoded arrays, the underlying child data is + /// queried instead. + /// For dictionary arrays, this reflects the validity of the dictionary + /// index, but the corresponding dictionary value might still be null. + /// For null arrays, this always returns false. + bool IsValid(int64_t i) const { + if (buffers[0] != NULLPTR) { + return bit_util::GetBit(buffers[0]->data(), i + offset); + } + const auto type = this->type->id(); + if (type == Type::SPARSE_UNION) { + return !internal::IsNullSparseUnion(*this, i); + } + if (type == Type::DENSE_UNION) { + return !internal::IsNullDenseUnion(*this, i); + } + if (type == Type::RUN_END_ENCODED) { + return !internal::IsNullRunEndEncoded(*this, i); + } + return null_count.load() != length; + } + + /// \brief Access a buffer's data as a typed C pointer + /// + /// \param i the buffer index + /// \param absolute_offset the offset into the buffer + /// + /// If `absolute_offset` is non-zero, the type `T` must match the + /// layout of buffer number `i` for the array's data type; otherwise + /// offset computation would be incorrect. + /// + /// If the given buffer is bit-packed (such as a validity bitmap, or + /// the data buffer of a boolean array), then `absolute_offset` must be + /// zero for correct results, and any bit offset must be applied manually + /// by the caller. + template + inline const T* GetValues(int i, int64_t absolute_offset) const { + if (buffers[i]) { + return reinterpret_cast(buffers[i]->data()) + absolute_offset; + } else { + return NULLPTR; + } + } + + /// \brief Access a buffer's data as a typed C pointer + /// + /// \param i the buffer index + /// + /// This method uses the array's offset to index into buffer number `i`. + /// + /// Calling this method on a bit-packed buffer (such as a validity bitmap, or + /// the data buffer of a boolean array) will lead to incorrect results. + /// You should instead call `GetValues(i, 0)` and apply the bit offset manually. + template + inline const T* GetValues(int i) const { + return GetValues(i, offset); + } + + /// \brief Access a buffer's data as a typed C pointer + /// + /// \param i the buffer index + /// \param absolute_offset the offset into the buffer + /// + /// Like `GetValues(i, absolute_offset)`, but returns nullptr if the given buffer + /// is not a CPU buffer. + template + inline const T* GetValuesSafe(int i, int64_t absolute_offset) const { + if (buffers[i] && buffers[i]->is_cpu()) { + return reinterpret_cast(buffers[i]->data()) + absolute_offset; + } else { + return NULLPTR; + } + } + + /// \brief Access a buffer's data as a typed C pointer + /// + /// \param i the buffer index + /// + /// Like `GetValues(i)`, but returns nullptr if the given buffer is not a CPU buffer. + template + inline const T* GetValuesSafe(int i) const { + return GetValuesSafe(i, offset); + } + + /// \brief Access a buffer's data as a mutable typed C pointer + /// + /// \param i the buffer index + /// \param absolute_offset the offset into the buffer + /// + /// Like `GetValues(i, absolute_offset)`, but allows mutating buffer contents. + /// This should only be used when initially populating the ArrayData, before + /// it is attached to a Array instance. + template + inline T* GetMutableValues(int i, int64_t absolute_offset) { + if (buffers[i]) { + return reinterpret_cast(buffers[i]->mutable_data()) + absolute_offset; + } else { + return NULLPTR; + } + } + + /// \brief Access a buffer's data as a mutable typed C pointer + /// + /// \param i the buffer index + /// + /// Like `GetValues(i)`, but allows mutating buffer contents. + /// This should only be used when initially populating the ArrayData, before + /// it is attached to a Array instance. + template + inline T* GetMutableValues(int i) { + return GetMutableValues(i, offset); + } + + /// \brief Construct a zero-copy slice of the data with the given offset and length + /// + /// This method applies the given slice to this ArrayData, taking into account + /// its existing offset and length. + /// If the given `length` is too large, the slice length is clamped so as not + /// to go past the offset end. + /// If the given `often` is too large, or if either `offset` or `length` is negative, + /// behavior is undefined. + /// + /// The associated ArrayStatistics is always discarded in a sliced + /// ArrayData, even if the slice is trivially equal to the original ArrayData. + /// If you want to reuse the statistics from the original ArrayData, you must + /// explicitly reattach them. + std::shared_ptr Slice(int64_t offset, int64_t length) const; + + /// \brief Construct a zero-copy slice of the data with the given offset and length + /// + /// Like `Slice(offset, length)`, but returns an error if the requested slice + /// falls out of bounds. + /// Unlike Slice, `length` isn't clamped to the available buffer size. + Result> SliceSafe(int64_t offset, int64_t length) const; + + /// \brief Set the cached physical null count + /// + /// \param v the number of nulls in the ArrayData + /// + /// This should only be used when initially populating the ArrayData, if + /// it possible to compute the null count without visiting the entire validity + /// bitmap. In most cases, relying on `GetNullCount` is sufficient. + void SetNullCount(int64_t v) { null_count.store(v); } + + /// \brief Return the physical null count + /// + /// This method returns the number of array elements for which `IsValid` would + /// return false. + /// + /// A cached value is returned if already available, otherwise it is first + /// computed and stored. + /// How it is is computed depends on the data type, see `IsValid` for details. + /// + /// Note that this method is typically much faster than calling `IsValid` + /// for all elements. Therefore, it helps avoid per-element validity bitmap + /// lookups in the common cases where the array contains zero or only nulls. + int64_t GetNullCount() const; + + /// \brief Return true if the array may have nulls in its validity bitmap + /// + /// This method returns true if the data has a validity bitmap, and the physical + /// null count is either known to be non-zero or not yet known. + /// + /// Unlike `MayHaveLogicalNulls`, this does not check for the presence of nulls + /// in child data for data types such as unions and run-end encoded types. + /// + /// \see HasValidityBitmap + /// \see MayHaveLogicalNulls + bool MayHaveNulls() const { + // If an ArrayData is slightly malformed it may have kUnknownNullCount set + // but no buffer + return null_count.load() != 0 && buffers[0] != NULLPTR; + } + + /// \brief Return true if the array has a validity bitmap + bool HasValidityBitmap() const { return buffers[0] != NULLPTR; } + + /// \brief Return true if the array may have logical nulls + /// + /// Unlike `MayHaveNulls`, this method checks for null child values + /// for types without a validity bitmap, such as unions and run-end encoded + /// types, and for null dictionary values for dictionary types. + /// + /// This implies that `MayHaveLogicalNulls` may return true for arrays that + /// don't have a top-level validity bitmap. It is therefore necessary + /// to call `HasValidityBitmap` before accessing a top-level validity bitmap. + /// + /// Code that previously used MayHaveNulls and then dealt with the validity + /// bitmap directly can be fixed to handle all types correctly without + /// performance degradation when handling most types by adopting + /// HasValidityBitmap and MayHaveLogicalNulls. + /// + /// Before: + /// + /// uint8_t* validity = array.MayHaveNulls() ? array.buffers[0].data : NULLPTR; + /// for (int64_t i = 0; i < array.length; ++i) { + /// if (validity && !bit_util::GetBit(validity, i)) { + /// continue; // skip a NULL + /// } + /// ... + /// } + /// + /// After: + /// + /// bool all_valid = !array.MayHaveLogicalNulls(); + /// uint8_t* validity = array.HasValidityBitmap() ? array.buffers[0].data : NULLPTR; + /// for (int64_t i = 0; i < array.length; ++i) { + /// bool is_valid = all_valid || + /// (validity && bit_util::GetBit(validity, i)) || + /// array.IsValid(i); + /// if (!is_valid) { + /// continue; // skip a NULL + /// } + /// ... + /// } + bool MayHaveLogicalNulls() const { + if (buffers[0] != NULLPTR) { + return null_count.load() != 0; + } + const auto t = type->id(); + if (t == Type::SPARSE_UNION || t == Type::DENSE_UNION) { + return internal::UnionMayHaveLogicalNulls(*this); + } + if (t == Type::RUN_END_ENCODED) { + return internal::RunEndEncodedMayHaveLogicalNulls(*this); + } + if (t == Type::DICTIONARY) { + return internal::DictionaryMayHaveLogicalNulls(*this); + } + return null_count.load() != 0; + } + + /// \brief Compute the logical null count for arrays of all types + /// + /// If the array has a validity bitmap, this function behaves the same as + /// GetNullCount. For arrays that have no validity bitmap but whose values + /// may be logically null (such as union arrays and run-end encoded arrays), + /// this function recomputes the null count every time it is called. + /// + /// \see GetNullCount + int64_t ComputeLogicalNullCount() const; + + /// \brief Return the device_type of the underlying buffers and children + /// + /// If there are no buffers in this ArrayData object, it just returns + /// DeviceAllocationType::kCPU as a default. We also assume that all buffers + /// should be allocated on the same device type and perform DCHECKs to confirm + /// this in debug mode. + /// + /// \return DeviceAllocationType + DeviceAllocationType device_type() const; + + std::shared_ptr type; + int64_t length = 0; + mutable std::atomic null_count{0}; + // The logical start point into the physical buffers (in values, not bytes). + // Note that, for child data, this must be *added* to the child data's own offset. + int64_t offset = 0; + std::vector> buffers; + std::vector> child_data; + + // The dictionary for this Array, if any. Only used for dictionary type + std::shared_ptr dictionary; + + // The statistics for this Array. + std::shared_ptr statistics; +}; + +/// \brief A non-owning Buffer reference +struct ARROW_EXPORT BufferSpan { + // It is the user of this class's responsibility to ensure that + // buffers that were const originally are not written to + // accidentally. + uint8_t* data = NULLPTR; + int64_t size = 0; + // Pointer back to buffer that owns this memory + const std::shared_ptr* owner = NULLPTR; + + template + const T* data_as() const { + return reinterpret_cast(data); + } + template + T* mutable_data_as() { + return reinterpret_cast(data); + } +}; + +/// \brief EXPERIMENTAL: A non-owning array data container +/// +/// Unlike ArrayData, this class doesn't own its referenced data type nor data buffers. +/// It is cheaply copyable and can therefore be suitable for use cases where +/// shared_ptr overhead is not acceptable. However, care should be taken to +/// keep alive the referenced objects and memory while the ArraySpan object is in use. +/// For this reason, this should not be exposed in most public APIs (apart from +/// compute kernel interfaces). +struct ARROW_EXPORT ArraySpan { + const DataType* type = NULLPTR; + int64_t length = 0; + mutable int64_t null_count = kUnknownNullCount; + int64_t offset = 0; + BufferSpan buffers[3]; + + ArraySpan() = default; + + explicit ArraySpan(const DataType* type, int64_t length) : type(type), length(length) {} + + ArraySpan(const ArrayData& data) { // NOLINT implicit conversion + SetMembers(data); + } + explicit ArraySpan(const Scalar& data) { FillFromScalar(data); } + + /// If dictionary-encoded, put dictionary in the first entry + std::vector child_data; + + /// \brief Populate ArraySpan to look like an array of length 1 pointing at + /// the data members of a Scalar value + void FillFromScalar(const Scalar& value); + + void SetMembers(const ArrayData& data); + + void SetBuffer(int index, const std::shared_ptr& buffer) { + this->buffers[index].data = const_cast(buffer->data()); + this->buffers[index].size = buffer->size(); + this->buffers[index].owner = &buffer; + } + + const ArraySpan& dictionary() const { return child_data[0]; } + + /// \brief Return the number of buffers (out of 3) that are used to + /// constitute this array + int num_buffers() const; + + // Access a buffer's data as a typed C pointer + template + inline T* GetValues(int i, int64_t absolute_offset) { + return reinterpret_cast(buffers[i].data) + absolute_offset; + } + + template + inline T* GetValues(int i) { + return GetValues(i, this->offset); + } + + // Access a buffer's data as a typed C pointer + template + inline const T* GetValues(int i, int64_t absolute_offset) const { + return reinterpret_cast(buffers[i].data) + absolute_offset; + } + + template + inline const T* GetValues(int i) const { + return GetValues(i, this->offset); + } + + /// \brief Access a buffer's data as a span + /// + /// \param i The buffer index + /// \param length The required length (in number of typed values) of the requested span + /// \pre i > 0 + /// \pre length <= the length of the buffer (in number of values) that's expected for + /// this array type + /// \return A span of the requested length + template + util::span GetSpan(int i, int64_t length) const { + const int64_t buffer_length = buffers[i].size / static_cast(sizeof(T)); + assert(i > 0 && length + offset <= buffer_length); + ARROW_UNUSED(buffer_length); + return util::span(buffers[i].data_as() + this->offset, length); + } + + /// \brief Access a buffer's data as a span + /// + /// \param i The buffer index + /// \param length The required length (in number of typed values) of the requested span + /// \pre i > 0 + /// \pre length <= the length of the buffer (in number of values) that's expected for + /// this array type + /// \return A span of the requested length + template + util::span GetSpan(int i, int64_t length) { + const int64_t buffer_length = buffers[i].size / static_cast(sizeof(T)); + assert(i > 0 && length + offset <= buffer_length); + ARROW_UNUSED(buffer_length); + return util::span(buffers[i].mutable_data_as() + this->offset, length); + } + + inline bool IsNull(int64_t i) const { return !IsValid(i); } + + inline bool IsValid(int64_t i) const { + if (this->buffers[0].data != NULLPTR) { + return bit_util::GetBit(this->buffers[0].data, i + this->offset); + } else { + const auto type = this->type->id(); + if (type == Type::SPARSE_UNION) { + return !IsNullSparseUnion(i); + } + if (type == Type::DENSE_UNION) { + return !IsNullDenseUnion(i); + } + if (type == Type::RUN_END_ENCODED) { + return !IsNullRunEndEncoded(i); + } + return this->null_count != this->length; + } + } + + std::shared_ptr ToArrayData() const; + + std::shared_ptr ToArray() const; + + std::shared_ptr GetBuffer(int index) const { + const BufferSpan& buf = this->buffers[index]; + if (buf.owner) { + return *buf.owner; + } else if (buf.data != NULLPTR) { + // Buffer points to some memory without an owning buffer + return std::make_shared(buf.data, buf.size); + } else { + return NULLPTR; + } + } + + void SetSlice(int64_t offset, int64_t length) { + this->offset = offset; + this->length = length; + if (this->type->id() == Type::NA) { + this->null_count = this->length; + } else if (buffers[0].data != NULLPTR) { + this->null_count = kUnknownNullCount; + } else { + this->null_count = 0; + } + } + + /// \brief Return physical null count, or compute and set it if it's not known + int64_t GetNullCount() const; + + /// \brief Return true if the array has a validity bitmap and the physical null + /// count is known to be non-zero or not yet known + /// + /// Note that this is not the same as MayHaveLogicalNulls, which also checks + /// for the presence of nulls in child data for types like unions and run-end + /// encoded types. + /// + /// \see HasValidityBitmap + /// \see MayHaveLogicalNulls + bool MayHaveNulls() const { + // If an ArrayData is slightly malformed it may have kUnknownNullCount set + // but no buffer + return null_count != 0 && buffers[0].data != NULLPTR; + } + + /// \brief Return true if the array has a validity bitmap + bool HasValidityBitmap() const { return buffers[0].data != NULLPTR; } + + /// \brief Return true if the validity bitmap may have 0's in it, or if the + /// child arrays (in the case of types without a validity bitmap) may have + /// nulls, or if the dictionary of dictionay array may have nulls. + /// + /// \see ArrayData::MayHaveLogicalNulls + bool MayHaveLogicalNulls() const { + if (buffers[0].data != NULLPTR) { + return null_count != 0; + } + const auto t = type->id(); + if (t == Type::SPARSE_UNION || t == Type::DENSE_UNION) { + return UnionMayHaveLogicalNulls(); + } + if (t == Type::RUN_END_ENCODED) { + return RunEndEncodedMayHaveLogicalNulls(); + } + if (t == Type::DICTIONARY) { + return DictionaryMayHaveLogicalNulls(); + } + return null_count != 0; + } + + /// \brief Compute the logical null count for arrays of all types including + /// those that do not have a validity bitmap like union and run-end encoded + /// arrays + /// + /// If the array has a validity bitmap, this function behaves the same as + /// GetNullCount. For types that have no validity bitmap, this function will + /// recompute the logical null count every time it is called. + /// + /// \see GetNullCount + int64_t ComputeLogicalNullCount() const; + + /// Some DataTypes (StringView, BinaryView) may have an arbitrary number of variadic + /// buffers. Since ArraySpan only has 3 buffers, we pack the variadic buffers into + /// buffers[2]; IE buffers[2].data points to the first shared_ptr of the + /// variadic set and buffers[2].size is the number of variadic buffers times + /// sizeof(shared_ptr). + /// + /// \see HasVariadicBuffers + util::span> GetVariadicBuffers() const; + bool HasVariadicBuffers() const; + + private: + ARROW_FRIEND_EXPORT friend bool internal::IsNullRunEndEncoded(const ArrayData& data, + int64_t i); + + bool IsNullSparseUnion(int64_t i) const; + bool IsNullDenseUnion(int64_t i) const; + + /// \brief Return true if the value at logical index i is null + /// + /// This function uses binary-search, so it has a O(log N) cost. + /// Iterating over the whole array and calling IsNull is O(N log N), so + /// for better performance it is recommended to use a + /// ree_util::RunEndEncodedArraySpan to iterate run by run instead. + bool IsNullRunEndEncoded(int64_t i) const; + + bool UnionMayHaveLogicalNulls() const; + bool RunEndEncodedMayHaveLogicalNulls() const; + bool DictionaryMayHaveLogicalNulls() const; +}; + +namespace internal { + +void FillZeroLengthArray(const DataType* type, ArraySpan* span); + +/// Construct a zero-copy view of this ArrayData with the given type. +/// +/// This method checks if the types are layout-compatible. +/// Nested types are traversed in depth-first order. Data buffers must have +/// the same item sizes, even though the logical types may be different. +/// An error is returned if the types are not layout-compatible. +ARROW_EXPORT +Result> GetArrayView(const std::shared_ptr& data, + const std::shared_ptr& type); + +} // namespace internal +} // namespace arrow diff --git a/pyarrow/include/arrow/array/diff.h b/pyarrow/include/arrow/array/diff.h new file mode 100644 index 0000000000000000000000000000000000000000..a405164b333f3b21a17e8414ef59a8a628c28579 --- /dev/null +++ b/pyarrow/include/arrow/array/diff.h @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/array_nested.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \brief Compare two arrays, returning an edit script which expresses the difference +/// between them +/// +/// An edit script is an array of struct(insert: bool, run_length: int64_t). +/// Each element of "insert" determines whether an element was inserted into (true) +/// or deleted from (false) base. Each insertion or deletion is followed by a run of +/// elements which are unchanged from base to target; the length of this run is stored +/// in "run_length". (Note that the edit script begins and ends with a run of shared +/// elements but both fields of the struct must have the same length. To accommodate this +/// the first element of "insert" should be ignored.) +/// +/// For example for base "hlloo" and target "hello", the edit script would be +/// [ +/// {"insert": false, "run_length": 1}, // leading run of length 1 ("h") +/// {"insert": true, "run_length": 3}, // insert("e") then a run of length 3 ("llo") +/// {"insert": false, "run_length": 0} // delete("o") then an empty run +/// ] +/// +/// Diffing arrays containing nulls is not currently supported. +/// +/// \param[in] base baseline for comparison +/// \param[in] target an array of identical type to base whose elements differ from base's +/// \param[in] pool memory to store the result will be allocated from this memory pool +/// \return an edit script array which can be applied to base to produce target +ARROW_EXPORT +Result> Diff(const Array& base, const Array& target, + MemoryPool* pool = default_memory_pool()); + +/// \brief visitor interface for easy traversal of an edit script +/// +/// visitor will be called for each hunk of insertions and deletions. +ARROW_EXPORT Status VisitEditScript( + const Array& edits, + const std::function& visitor); + +/// \brief return a function which will format an edit script in unified +/// diff format to os, given base and target arrays of type +ARROW_EXPORT Result< + std::function> +MakeUnifiedDiffFormatter(const DataType& type, std::ostream* os); + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/statistics.h b/pyarrow/include/arrow/array/statistics.h new file mode 100644 index 0000000000000000000000000000000000000000..ae78dca0b0c6b19b4ad4bfc8deb9962260afb466 --- /dev/null +++ b/pyarrow/include/arrow/array/statistics.h @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/compare.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \class ArrayStatistics +/// \brief Statistics for an Array +/// +/// Apache Arrow format doesn't have statistics but data source such +/// as Apache Parquet may have statistics. Statistics associated with +/// data source can be read unified API via this class. +struct ARROW_EXPORT ArrayStatistics { + /// \brief The type for maximum and minimum values. If the target + /// value exists, one of them is used. `std::nullopt` is used + /// otherwise. + using ValueType = std::variant; + using NumericType = std::variant; + using CountType = NumericType; + using SizeType = NumericType; + + static const std::shared_ptr& ValueToArrowType( + const std::optional& value, + const std::shared_ptr& array_type) { + if (!value.has_value()) { + return null(); + } + + struct Visitor { + const std::shared_ptr& array_type; + + const std::shared_ptr& operator()(const bool&) { return boolean(); } + const std::shared_ptr& operator()(const int64_t&) { return int64(); } + const std::shared_ptr& operator()(const uint64_t&) { return uint64(); } + const std::shared_ptr& operator()(const double&) { return float64(); } + const std::shared_ptr& operator()(const std::string&) { + switch (array_type->id()) { + case Type::STRING: + case Type::BINARY: + case Type::FIXED_SIZE_BINARY: + case Type::LARGE_STRING: + case Type::LARGE_BINARY: + case Type::BINARY_VIEW: + case Type::STRING_VIEW: + return array_type; + default: + return utf8(); + } + } + } visitor{array_type}; + return std::visit(visitor, value.value()); + } + + /// \brief The number of rows, may not be set + /// Note: when set to `int64_t`, it represents `exact_row_count`, + /// and when set to `double`, it represents `approximate_row_count`. + /// Note: this value is not used by \ref arrow::RecordBatch::MakeStatisticsArray. + std::optional row_count = std::nullopt; + + /// \brief The number of null values, may not be set + /// Note: when set to `int64_t`, it represents `exact_null_count`, + /// and when set to `double`, it represents `approximate_null_count`. + std::optional null_count = std::nullopt; + + /// \brief The number of distinct values, may not be set + /// Note: when set to `int64_t`, it represents `exact_distinct_count`, + /// and when set to `double`, it represents `approximate_distinct_count`. + std::optional distinct_count = std::nullopt; + + /// \brief The maximum length in bytes of the rows in an array; may not be set + /// Note: when the type is `int64_t`, it represents `max_byte_width_exact`, + /// and when the type is `double`, it represents `max_byte_width_approximate`. + std::optional max_byte_width = std::nullopt; + + /// \brief The average size in bytes of a row in an array, may not be set. + std::optional average_byte_width = std::nullopt; + + /// \brief Whether the average size in bytes is exact or not. + bool is_average_byte_width_exact = false; + + /// \brief The minimum value, may not be set + std::optional min = std::nullopt; + + /// \brief Compute Arrow type of the minimum value. + /// + /// If \ref ValueType is `std::string`, `array_type` may be + /// used. If `array_type` is a binary-like type such as \ref + /// arrow::binary and \ref arrow::large_utf8, `array_type` is + /// returned. \ref arrow::utf8 is returned otherwise. + /// + /// If \ref ValueType isn't `std::string`, `array_type` isn't used. + /// + /// \param array_type The Arrow type of the associated array. + /// + /// \return \ref arrow::null if the minimum value is `std::nullopt`, + /// Arrow type based on \ref ValueType of the \ref min + /// otherwise. + const std::shared_ptr& MinArrowType( + const std::shared_ptr& array_type) { + return ValueToArrowType(min, array_type); + } + + /// \brief Whether the minimum value is exact or not + bool is_min_exact = false; + + /// \brief The maximum value, may not be set + std::optional max = std::nullopt; + + /// \brief Compute Arrow type of the maximum value. + /// + /// If \ref ValueType is `std::string`, `array_type` may be + /// used. If `array_type` is a binary-like type such as \ref + /// arrow::binary and \ref arrow::large_utf8, `array_type` is + /// returned. \ref arrow::utf8 is returned otherwise. + /// + /// If \ref ValueType isn't `std::string`, `array_type` isn't used. + /// + /// \param array_type The Arrow type of the associated array. + /// + /// \return \ref arrow::null if the maximum value is `std::nullopt`, + /// Arrow type based on \ref ValueType of the \ref max + /// otherwise. + const std::shared_ptr& MaxArrowType( + const std::shared_ptr& array_type) { + return ValueToArrowType(max, array_type); + } + + /// \brief Whether the maximum value is exact or not + bool is_max_exact = false; + + /// \brief Check two \ref arrow::ArrayStatistics for equality + /// + /// \param other The \ref arrow::ArrayStatistics instance to compare against. + /// + /// \param equal_options Options used to compare double values for equality. + /// + /// \return True if the two \ref arrow::ArrayStatistics instances are equal; otherwise, + /// false. + bool Equals(const ArrayStatistics& other, + const EqualOptions& equal_options = EqualOptions::Defaults()) const { + return ArrayStatisticsEquals(*this, other, equal_options); + } + + /// \brief Check two statistics for equality + bool operator==(const ArrayStatistics& other) const { return Equals(other); } + + /// \brief Check two statistics for not equality + bool operator!=(const ArrayStatistics& other) const { return !Equals(other); } +}; + +} // namespace arrow diff --git a/pyarrow/include/arrow/array/util.h b/pyarrow/include/arrow/array/util.h new file mode 100644 index 0000000000000000000000000000000000000000..fd8e75ddb86405c523a8083f559dab0e72364e24 --- /dev/null +++ b/pyarrow/include/arrow/array/util.h @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/array/data.h" +#include "arrow/compare.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \defgroup array-factories Array factory functions +/// +/// @{ + +/// \brief Create a strongly-typed Array instance from generic ArrayData +/// \param[in] data the array contents +/// \return the resulting Array instance +ARROW_EXPORT +std::shared_ptr MakeArray(const std::shared_ptr& data); + +/// \brief Create a strongly-typed Array instance with all elements null +/// \param[in] type the array type +/// \param[in] length the array length +/// \param[in] pool the memory pool to allocate memory from +ARROW_EXPORT +Result> MakeArrayOfNull(const std::shared_ptr& type, + int64_t length, + MemoryPool* pool = default_memory_pool()); + +/// \brief Create an Array instance whose slots are the given scalar +/// \param[in] scalar the value with which to fill the array +/// \param[in] length the array length +/// \param[in] pool the memory pool to allocate memory from +ARROW_EXPORT +Result> MakeArrayFromScalar( + const Scalar& scalar, int64_t length, MemoryPool* pool = default_memory_pool()); + +/// \brief Create an empty Array of a given type +/// +/// The output Array will be of the given type. +/// +/// \param[in] type the data type of the empty Array +/// \param[in] pool the memory pool to allocate memory from +/// \return the resulting Array +ARROW_EXPORT +Result> MakeEmptyArray(std::shared_ptr type, + MemoryPool* pool = default_memory_pool()); + +/// @} + +namespace internal { + +/// \brief Swap endian of each element in a generic ArrayData +/// +/// As dictionaries are often shared between different arrays, dictionaries +/// are not swapped by this function and should be handled separately. +/// +/// \param[in] data the array contents +/// \param[in] pool the memory pool to allocate memory from +/// \return the resulting ArrayData whose elements were swapped +ARROW_EXPORT +Result> SwapEndianArrayData( + const std::shared_ptr& data, MemoryPool* pool = default_memory_pool()); + +/// Given a number of ArrayVectors, treat each ArrayVector as the +/// chunks of a chunked array. Then rechunk each ArrayVector such that +/// all ArrayVectors are chunked identically. It is mandatory that +/// all ArrayVectors contain the same total number of elements. +ARROW_EXPORT +std::vector RechunkArraysConsistently(const std::vector&); + +} // namespace internal +} // namespace arrow diff --git a/pyarrow/include/arrow/array/validate.h b/pyarrow/include/arrow/array/validate.h new file mode 100644 index 0000000000000000000000000000000000000000..3ebfa0a51edce21ca585862b1dbb074b6cf8d9c8 --- /dev/null +++ b/pyarrow/include/arrow/array/validate.h @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace internal { + +// Internal functions implementing Array::Validate() and friends. + +// O(1) array metadata validation + +ARROW_EXPORT +Status ValidateArray(const Array& array); + +ARROW_EXPORT +Status ValidateArray(const ArrayData& data); + +// O(N) array data validation. +// Note that, starting from 7.0.0, "full" routines also validate metadata. +// Before, ValidateArray() needed to be called before ValidateArrayFull() +// to ensure metadata correctness, otherwise invalid memory accesses +// may occur. + +ARROW_EXPORT +Status ValidateArrayFull(const Array& array); + +ARROW_EXPORT +Status ValidateArrayFull(const ArrayData& data); + +ARROW_EXPORT +Status ValidateUTF8(const Array& array); + +ARROW_EXPORT +Status ValidateUTF8(const ArrayData& data); + +} // namespace internal +} // namespace arrow diff --git a/pyarrow/include/arrow/buffer.h b/pyarrow/include/arrow/buffer.h new file mode 100644 index 0000000000000000000000000000000000000000..ce909a3ea182f4d1d8fb294512ccc74e55bf0030 --- /dev/null +++ b/pyarrow/include/arrow/buffer.h @@ -0,0 +1,587 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/device.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/span.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +// ---------------------------------------------------------------------- +// Buffer classes + +/// \class Buffer +/// \brief Object containing a pointer to a piece of contiguous memory with a +/// particular size. +/// +/// Buffers have two related notions of length: size and capacity. Size is +/// the number of bytes that might have valid data. Capacity is the number +/// of bytes that were allocated for the buffer in total. +/// +/// The Buffer base class does not own its memory, but subclasses often do. +/// +/// The following invariant is always true: Size <= Capacity +class ARROW_EXPORT Buffer { + public: + ARROW_DISALLOW_COPY_AND_ASSIGN(Buffer); + + /// \brief Construct from buffer and size without copying memory + /// + /// \param[in] data a memory buffer + /// \param[in] size buffer size + /// + /// \note The passed memory must be kept alive through some other means + Buffer(const uint8_t* data, int64_t size) + : is_mutable_(false), + is_cpu_(true), + data_(data), + size_(size), + capacity_(size), + device_type_(DeviceAllocationType::kCPU) { + SetMemoryManager(default_cpu_memory_manager()); + } + + Buffer(const uint8_t* data, int64_t size, std::shared_ptr mm, + std::shared_ptr parent = NULLPTR, + std::optional device_type_override = std::nullopt) + : is_mutable_(false), + data_(data), + size_(size), + capacity_(size), + parent_(std::move(parent)) { + // SetMemoryManager will also set device_type_ + SetMemoryManager(std::move(mm)); + // If a device type is specified, use that instead. Example of when this can be + // useful: the CudaMemoryManager can set device_type_ to kCUDA, but you can specify + // device_type_override=kCUDA_HOST as the device type to override it. + if (device_type_override != std::nullopt) { + device_type_ = *device_type_override; + } + } + + Buffer(uintptr_t address, int64_t size, std::shared_ptr mm, + std::shared_ptr parent = NULLPTR) + : Buffer(reinterpret_cast(address), size, std::move(mm), + std::move(parent)) {} + + /// \brief Construct from string_view without copying memory + /// + /// \param[in] data a string_view object + /// + /// \note The memory viewed by data must not be deallocated in the lifetime of the + /// Buffer; temporary rvalue strings must be stored in an lvalue somewhere + explicit Buffer(std::string_view data) + : Buffer(reinterpret_cast(data.data()), + static_cast(data.size())) {} + + virtual ~Buffer() = default; + + /// An offset into data that is owned by another buffer, but we want to be + /// able to retain a valid pointer to it even after other shared_ptr's to the + /// parent buffer have been destroyed + /// + /// This method makes no assertions about alignment or padding of the buffer but + /// in general we expected buffers to be aligned and padded to 64 bytes. In the future + /// we might add utility methods to help determine if a buffer satisfies this contract. + Buffer(std::shared_ptr parent, const int64_t offset, const int64_t size) + : Buffer(parent->data_ + offset, size) { + parent_ = std::move(parent); + SetMemoryManager(parent_->memory_manager_); + } + + uint8_t operator[](std::size_t i) const { return data_[i]; } + + /// \brief Construct a new std::string with a hexadecimal representation of the buffer. + /// \return std::string + std::string ToHexString(); + + /// Return true if both buffers are the same size and contain the same bytes + /// up to the number of compared bytes + bool Equals(const Buffer& other, int64_t nbytes) const; + + /// Return true if both buffers are the same size and contain the same bytes + bool Equals(const Buffer& other) const; + + /// Copy a section of the buffer into a new Buffer. + Result> CopySlice( + const int64_t start, const int64_t nbytes, + MemoryPool* pool = default_memory_pool()) const; + + /// Zero bytes in padding, i.e. bytes between size_ and capacity_. + void ZeroPadding() { +#ifndef NDEBUG + CheckMutable(); +#endif + // A zero-capacity buffer can have a null data pointer + if (capacity_ != 0) { + memset(mutable_data() + size_, 0, static_cast(capacity_ - size_)); + } + } + + /// \brief Construct an immutable buffer that takes ownership of the contents + /// of an std::string (without copying it). + /// + /// \param[in] data a string to own + /// \return a new Buffer instance + static std::shared_ptr FromString(std::string data); + + /// \brief Construct an immutable buffer that takes ownership of the contents + /// of an std::vector (without copying it). Only vectors of TrivialType objects + /// (integers, floating point numbers, ...) can be wrapped by this function. + /// + /// \param[in] vec a vector to own + /// \return a new Buffer instance + template + static std::shared_ptr FromVector(std::vector vec) { + static_assert(std::is_trivial_v, + "Buffer::FromVector can only wrap vectors of trivial objects"); + + if (vec.empty()) { + return std::shared_ptr{new Buffer()}; + } + + auto* data = reinterpret_cast(vec.data()); + auto size_in_bytes = static_cast(vec.size() * sizeof(T)); + return std::shared_ptr{ + new Buffer{data, size_in_bytes}, + // Keep the vector's buffer alive inside the shared_ptr's destructor until after + // we have deleted the Buffer. Note we can't use this trick in FromString since + // std::string's data is inline for short strings so moving invalidates pointers + // into the string's buffer. + [vec = std::move(vec)](Buffer* buffer) { delete buffer; }}; + } + + /// \brief Create buffer referencing typed memory with some length without + /// copying + /// \param[in] data the typed memory as C array + /// \param[in] length the number of values in the array + /// \return a new shared_ptr + template + static std::shared_ptr Wrap(const T* data, SizeType length) { + return std::make_shared(reinterpret_cast(data), + static_cast(sizeof(T) * length)); + } + + /// \brief Create buffer referencing std::vector with some length without + /// copying + /// \param[in] data the vector to be referenced. If this vector is changed, + /// the buffer may become invalid + /// \return a new shared_ptr + template + static std::shared_ptr Wrap(const std::vector& data) { + return std::make_shared(reinterpret_cast(data.data()), + static_cast(sizeof(T) * data.size())); + } + + /// \brief Copy buffer contents into a new std::string + /// \return std::string + /// \note Can throw std::bad_alloc if buffer is large + std::string ToString() const; + + /// \brief View buffer contents as a std::string_view + /// \return std::string_view + explicit operator std::string_view() const { + return {reinterpret_cast(data_), static_cast(size_)}; + } + + /// \brief Return a pointer to the buffer's data + /// + /// The buffer has to be a CPU buffer (`is_cpu()` is true). + /// Otherwise, an assertion may be thrown or a null pointer may be returned. + /// + /// To get the buffer's data address regardless of its device, call `address()`. + const uint8_t* data() const { +#ifndef NDEBUG + CheckCPU(); +#endif + return ARROW_PREDICT_TRUE(is_cpu_) ? data_ : NULLPTR; + } + + /// \brief Return a pointer to the buffer's data cast to a specific type + /// + /// The buffer has to be a CPU buffer (`is_cpu()` is true). + /// Otherwise, an assertion may be thrown or a null pointer may be returned. + template + const T* data_as() const { + return reinterpret_cast(data()); + } + + /// \brief Return the buffer's data as a span + template + util::span span_as() const { + return util::span(data_as(), static_cast(size() / sizeof(T))); + } + + /// \brief Return a writable pointer to the buffer's data + /// + /// The buffer has to be a mutable CPU buffer (`is_cpu()` and `is_mutable()` + /// are true). Otherwise, an assertion may be thrown or a null pointer may + /// be returned. + /// + /// To get the buffer's mutable data address regardless of its device, call + /// `mutable_address()`. + uint8_t* mutable_data() { +#ifndef NDEBUG + CheckCPU(); + CheckMutable(); +#endif + return ARROW_PREDICT_TRUE(is_cpu_ && is_mutable_) ? const_cast(data_) + : NULLPTR; + } + + /// \brief Return a writable pointer to the buffer's data cast to a specific type + /// + /// The buffer has to be a mutable CPU buffer (`is_cpu()` and `is_mutable()` + /// are true). Otherwise, an assertion may be thrown or a null pointer may + /// be returned. + template + T* mutable_data_as() { + return reinterpret_cast(mutable_data()); + } + + /// \brief Return the buffer's mutable data as a span + template + util::span mutable_span_as() { + return util::span(mutable_data_as(), static_cast(size() / sizeof(T))); + } + + /// \brief Return the device address of the buffer's data + uintptr_t address() const { return reinterpret_cast(data_); } + + /// \brief Return a writable device address to the buffer's data + /// + /// The buffer has to be a mutable buffer (`is_mutable()` is true). + /// Otherwise, an assertion may be thrown or 0 may be returned. + uintptr_t mutable_address() const { +#ifndef NDEBUG + CheckMutable(); +#endif + return ARROW_PREDICT_TRUE(is_mutable_) ? reinterpret_cast(data_) : 0; + } + + /// \brief Return the buffer's size in bytes + int64_t size() const { return size_; } + + /// \brief Return the buffer's capacity (number of allocated bytes) + int64_t capacity() const { return capacity_; } + + /// \brief Whether the buffer is directly CPU-accessible + /// + /// If this function returns true, you can read directly from the buffer's + /// `data()` pointer. Otherwise, you'll have to `View()` or `Copy()` it. + bool is_cpu() const { return is_cpu_; } + + /// \brief Whether the buffer is mutable + /// + /// If this function returns true, you are allowed to modify buffer contents + /// using the pointer returned by `mutable_data()` or `mutable_address()`. + bool is_mutable() const { return is_mutable_; } + + const std::shared_ptr& device() const { return memory_manager_->device(); } + + const std::shared_ptr& memory_manager() const { return memory_manager_; } + + DeviceAllocationType device_type() const { return device_type_; } + + std::shared_ptr parent() const { return parent_; } + + /// \brief Get a RandomAccessFile for reading a buffer + /// + /// The returned file object reads from this buffer's underlying memory. + static Result> GetReader(std::shared_ptr); + + /// \brief Get a OutputStream for writing to a buffer + /// + /// The buffer must be mutable. The returned stream object writes into the buffer's + /// underlying memory (but it won't resize it). + static Result> GetWriter(std::shared_ptr); + + /// \brief Copy buffer + /// + /// The buffer contents will be copied into a new buffer allocated by the + /// given MemoryManager. This function supports cross-device copies. + static Result> Copy(std::shared_ptr source, + const std::shared_ptr& to); + + /// \brief Copy a non-owned buffer + /// + /// This is useful for cases where the source memory area is externally managed + /// (its lifetime not tied to the source Buffer), otherwise please use Copy(). + static Result> CopyNonOwned( + const Buffer& source, const std::shared_ptr& to); + + /// \brief View buffer + /// + /// Return a Buffer that reflects this buffer, seen potentially from another + /// device, without making an explicit copy of the contents. The underlying + /// mechanism is typically implemented by the kernel or device driver, and may + /// involve lazy caching of parts of the buffer contents on the destination + /// device's memory. + /// + /// If a non-copy view is unsupported for the buffer on the given device, + /// nullptr is returned. An error can be returned if some low-level + /// operation fails (such as an out-of-memory condition). + static Result> View(std::shared_ptr source, + const std::shared_ptr& to); + + /// \brief View or copy buffer + /// + /// Try to view buffer contents on the given MemoryManager's device, but + /// fall back to copying if a no-copy view isn't supported. + static Result> ViewOrCopy( + std::shared_ptr source, const std::shared_ptr& to); + + virtual std::shared_ptr device_sync_event() const { return NULLPTR; } + + protected: + bool is_mutable_; + bool is_cpu_; + const uint8_t* data_; + int64_t size_; + int64_t capacity_; + DeviceAllocationType device_type_; + + // null by default, but may be set + std::shared_ptr parent_; + + private: + // private so that subclasses are forced to call SetMemoryManager() + std::shared_ptr memory_manager_; + + protected: + Buffer(); + + void CheckMutable() const; + void CheckCPU() const; + + void SetMemoryManager(std::shared_ptr mm) { + memory_manager_ = std::move(mm); + is_cpu_ = memory_manager_->is_cpu(); + device_type_ = memory_manager_->device()->device_type(); + } +}; + +/// \defgroup buffer-slicing-functions Functions for slicing buffers +/// +/// @{ + +/// \brief Construct a view on a buffer at the given offset and length. +/// +/// This function cannot fail and does not check for errors (except in debug builds) +static inline std::shared_ptr SliceBuffer(std::shared_ptr buffer, + const int64_t offset, + const int64_t length) { + return std::make_shared(std::move(buffer), offset, length); +} + +/// \brief Construct a view on a buffer at the given offset, up to the buffer's end. +/// +/// This function cannot fail and does not check for errors (except in debug builds) +static inline std::shared_ptr SliceBuffer(std::shared_ptr buffer, + const int64_t offset) { + int64_t length = buffer->size() - offset; + return SliceBuffer(std::move(buffer), offset, length); +} + +/// \brief Input-checking version of SliceBuffer +/// +/// An Invalid Status is returned if the requested slice falls out of bounds. +ARROW_EXPORT +Result> SliceBufferSafe(std::shared_ptr buffer, + int64_t offset); +/// \brief Input-checking version of SliceBuffer +/// +/// An Invalid Status is returned if the requested slice falls out of bounds. +/// Note that unlike SliceBuffer, `length` isn't clamped to the available buffer size. +ARROW_EXPORT +Result> SliceBufferSafe(std::shared_ptr buffer, + int64_t offset, int64_t length); + +/// \brief Like SliceBuffer, but construct a mutable buffer slice. +/// +/// If the parent buffer is not mutable, behavior is undefined (it may abort +/// in debug builds). +ARROW_EXPORT +std::shared_ptr SliceMutableBuffer(std::shared_ptr buffer, + const int64_t offset, const int64_t length); + +/// \brief Like SliceBuffer, but construct a mutable buffer slice. +/// +/// If the parent buffer is not mutable, behavior is undefined (it may abort +/// in debug builds). +static inline std::shared_ptr SliceMutableBuffer(std::shared_ptr buffer, + const int64_t offset) { + int64_t length = buffer->size() - offset; + return SliceMutableBuffer(std::move(buffer), offset, length); +} + +/// \brief Input-checking version of SliceMutableBuffer +/// +/// An Invalid Status is returned if the requested slice falls out of bounds. +ARROW_EXPORT +Result> SliceMutableBufferSafe(std::shared_ptr buffer, + int64_t offset); +/// \brief Input-checking version of SliceMutableBuffer +/// +/// An Invalid Status is returned if the requested slice falls out of bounds. +/// Note that unlike SliceBuffer, `length` isn't clamped to the available buffer size. +ARROW_EXPORT +Result> SliceMutableBufferSafe(std::shared_ptr buffer, + int64_t offset, int64_t length); + +/// @} + +/// \class MutableBuffer +/// \brief A Buffer whose contents can be mutated. May or may not own its data. +class ARROW_EXPORT MutableBuffer : public Buffer { + public: + MutableBuffer(uint8_t* data, const int64_t size) : Buffer(data, size) { + is_mutable_ = true; + } + + MutableBuffer(uint8_t* data, const int64_t size, std::shared_ptr mm) + : Buffer(data, size, std::move(mm)) { + is_mutable_ = true; + } + + MutableBuffer(const std::shared_ptr& parent, const int64_t offset, + const int64_t size); + + /// \brief Create buffer referencing typed memory with some length + /// \param[in] data the typed memory as C array + /// \param[in] length the number of values in the array + /// \return a new shared_ptr + template + static std::shared_ptr Wrap(T* data, SizeType length) { + return std::make_shared(reinterpret_cast(data), + static_cast(sizeof(T) * length)); + } + + protected: + MutableBuffer() : Buffer(NULLPTR, 0) {} +}; + +/// \class ResizableBuffer +/// \brief A mutable buffer that can be resized +class ARROW_EXPORT ResizableBuffer : public MutableBuffer { + public: + /// Change buffer reported size to indicated size, allocating memory if + /// necessary. This will ensure that the capacity of the buffer is a multiple + /// of 64 bytes as defined in Layout.md. + /// Consider using ZeroPadding afterwards, to conform to the Arrow layout + /// specification. + /// + /// @param new_size The new size for the buffer. + /// @param shrink_to_fit Whether to shrink the capacity if new size < current size + virtual Status Resize(const int64_t new_size, bool shrink_to_fit) = 0; + Status Resize(const int64_t new_size) { + return Resize(new_size, /*shrink_to_fit=*/true); + } + + /// Ensure that buffer has enough memory allocated to fit the indicated + /// capacity (and meets the 64 byte padding requirement in Layout.md). + /// It does not change buffer's reported size and doesn't zero the padding. + virtual Status Reserve(const int64_t new_capacity) = 0; + + template + Status TypedResize(const int64_t new_nb_elements, bool shrink_to_fit = true) { + return Resize(sizeof(T) * new_nb_elements, shrink_to_fit); + } + + template + Status TypedReserve(const int64_t new_nb_elements) { + return Reserve(sizeof(T) * new_nb_elements); + } + + protected: + ResizableBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) {} + ResizableBuffer(uint8_t* data, int64_t size, std::shared_ptr mm) + : MutableBuffer(data, size, std::move(mm)) {} +}; + +/// \defgroup buffer-allocation-functions Functions for allocating buffers +/// +/// @{ + +/// \brief Allocate a fixed size mutable buffer from a memory pool, zero its padding. +/// +/// \param[in] size size of buffer to allocate +/// \param[in] pool a memory pool +ARROW_EXPORT +Result> AllocateBuffer(const int64_t size, + MemoryPool* pool = NULLPTR); +ARROW_EXPORT +Result> AllocateBuffer(const int64_t size, int64_t alignment, + MemoryPool* pool = NULLPTR); + +/// \brief Allocate a resizeable buffer from a memory pool, zero its padding. +/// +/// \param[in] size size of buffer to allocate +/// \param[in] pool a memory pool +ARROW_EXPORT +Result> AllocateResizableBuffer( + const int64_t size, MemoryPool* pool = NULLPTR); +ARROW_EXPORT +Result> AllocateResizableBuffer( + const int64_t size, const int64_t alignment, MemoryPool* pool = NULLPTR); + +/// \brief Allocate a bitmap buffer from a memory pool +/// no guarantee on values is provided. +/// +/// \param[in] length size in bits of bitmap to allocate +/// \param[in] pool memory pool to allocate memory from +ARROW_EXPORT +Result> AllocateBitmap(int64_t length, + MemoryPool* pool = NULLPTR); + +/// \brief Allocate a zero-initialized bitmap buffer from a memory pool +/// +/// \param[in] length size in bits of bitmap to allocate +/// \param[in] pool memory pool to allocate memory from +ARROW_EXPORT +Result> AllocateEmptyBitmap(int64_t length, + MemoryPool* pool = NULLPTR); + +ARROW_EXPORT +Result> AllocateEmptyBitmap(int64_t length, int64_t alignment, + MemoryPool* pool = NULLPTR); + +/// \brief Concatenate multiple buffers into a single buffer +/// +/// \param[in] buffers to be concatenated +/// \param[in] pool memory pool to allocate the new buffer from +ARROW_EXPORT +Result> ConcatenateBuffers(const BufferVector& buffers, + MemoryPool* pool = NULLPTR); + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/buffer_builder.h b/pyarrow/include/arrow/buffer_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..e9177c656c021939405328fedd7a1e2704212650 --- /dev/null +++ b/pyarrow/include/arrow/buffer_builder.h @@ -0,0 +1,488 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/status.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_generate.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/macros.h" +#include "arrow/util/ubsan.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +// ---------------------------------------------------------------------- +// Buffer builder classes + +/// \class BufferBuilder +/// \brief A class for incrementally building a contiguous chunk of in-memory +/// data +class ARROW_EXPORT BufferBuilder { + public: + explicit BufferBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : pool_(pool), + data_(/*ensure never null to make ubsan happy and avoid check penalties below*/ + util::MakeNonNull()), + capacity_(0), + size_(0), + alignment_(alignment) {} + + /// \brief Constructs new Builder that will start using + /// the provided buffer until Finish/Reset are called. + /// The buffer is not resized. + explicit BufferBuilder(std::shared_ptr buffer, + MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : buffer_(std::move(buffer)), + pool_(pool), + data_(buffer_->mutable_data()), + capacity_(buffer_->capacity()), + size_(buffer_->size()), + alignment_(alignment) {} + + /// \brief Resize the buffer to the nearest multiple of 64 bytes + /// + /// \param new_capacity the new capacity of the builder. Will be + /// rounded up to a multiple of 64 bytes for padding + /// \param shrink_to_fit if new capacity is smaller than the existing, + /// reallocate internal buffer. Set to false to avoid reallocations when + /// shrinking the builder. + /// \return Status + Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) { + if (buffer_ == NULLPTR) { + ARROW_ASSIGN_OR_RAISE(buffer_, + AllocateResizableBuffer(new_capacity, alignment_, pool_)); + } else { + ARROW_RETURN_NOT_OK(buffer_->Resize(new_capacity, shrink_to_fit)); + } + capacity_ = buffer_->capacity(); + data_ = buffer_->mutable_data(); + return Status::OK(); + } + + /// \brief Ensure that builder can accommodate the additional number of bytes + /// without the need to perform allocations + /// + /// \param[in] additional_bytes number of additional bytes to make space for + /// \return Status + Status Reserve(const int64_t additional_bytes) { + auto min_capacity = size_ + additional_bytes; + if (min_capacity <= capacity_) { + return Status::OK(); + } + return Resize(GrowByFactor(capacity_, min_capacity), false); + } + + /// \brief Return a capacity expanded by the desired growth factor + static int64_t GrowByFactor(int64_t current_capacity, int64_t new_capacity) { + // Doubling capacity except for large Reserve requests. 2x growth strategy + // (versus 1.5x) seems to have slightly better performance when using + // jemalloc, but significantly better performance when using the system + // allocator. See ARROW-6450 for further discussion + return std::max(new_capacity, current_capacity * 2); + } + + /// \brief Append the given data to the buffer + /// + /// The buffer is automatically expanded if necessary. + Status Append(const void* data, const int64_t length) { + if (ARROW_PREDICT_FALSE(size_ + length > capacity_)) { + ARROW_RETURN_NOT_OK(Resize(GrowByFactor(capacity_, size_ + length), false)); + } + UnsafeAppend(data, length); + return Status::OK(); + } + + /// \brief Append the given data to the buffer + /// + /// The buffer is automatically expanded if necessary. + Status Append(std::string_view v) { return Append(v.data(), v.size()); } + + /// \brief Append copies of a value to the buffer + /// + /// The buffer is automatically expanded if necessary. + Status Append(const int64_t num_copies, uint8_t value) { + ARROW_RETURN_NOT_OK(Reserve(num_copies)); + UnsafeAppend(num_copies, value); + return Status::OK(); + } + + // Advance pointer and zero out memory + Status Advance(const int64_t length) { return Append(length, 0); } + + // Advance pointer, but don't allocate or zero memory + void UnsafeAdvance(const int64_t length) { size_ += length; } + + // Unsafe methods don't check existing size + void UnsafeAppend(const void* data, const int64_t length) { + memcpy(data_ + size_, data, static_cast(length)); + size_ += length; + } + + void UnsafeAppend(std::string_view v) { + UnsafeAppend(v.data(), static_cast(v.size())); + } + + void UnsafeAppend(const int64_t num_copies, uint8_t value) { + memset(data_ + size_, value, static_cast(num_copies)); + size_ += num_copies; + } + + /// \brief Return result of builder as a Buffer object. + /// + /// The builder is reset and can be reused afterwards. + /// + /// \param[out] out the finalized Buffer object + /// \param shrink_to_fit if the buffer size is smaller than its capacity, + /// reallocate to fit more tightly in memory. Set to false to avoid + /// a reallocation, at the expense of potentially more memory consumption. + /// \return Status + Status Finish(std::shared_ptr* out, bool shrink_to_fit = true) { + ARROW_RETURN_NOT_OK(Resize(size_, shrink_to_fit)); + if (size_ != 0) buffer_->ZeroPadding(); + *out = buffer_; + if (*out == NULLPTR) { + ARROW_ASSIGN_OR_RAISE(*out, AllocateBuffer(0, alignment_, pool_)); + } + Reset(); + return Status::OK(); + } + + Result> Finish(bool shrink_to_fit = true) { + std::shared_ptr out; + ARROW_RETURN_NOT_OK(Finish(&out, shrink_to_fit)); + return out; + } + + /// \brief Like Finish, but override the final buffer size + /// + /// This is useful after writing data directly into the builder memory + /// without calling the Append methods (basically, when using BufferBuilder + /// mostly for memory allocation). + Result> FinishWithLength(int64_t final_length, + bool shrink_to_fit = true) { + size_ = final_length; + return Finish(shrink_to_fit); + } + + void Reset() { + buffer_ = NULLPTR; + capacity_ = size_ = 0; + } + + /// \brief Set size to a smaller value without modifying builder + /// contents. For reusable BufferBuilder classes + /// \param[in] position must be non-negative and less than or equal + /// to the current length() + void Rewind(int64_t position) { size_ = position; } + + int64_t capacity() const { return capacity_; } + int64_t length() const { return size_; } + const uint8_t* data() const { return data_; } + uint8_t* mutable_data() { return data_; } + template + const T* data_as() const { + return reinterpret_cast(data_); + } + template + T* mutable_data_as() { + return reinterpret_cast(data_); + } + + private: + std::shared_ptr buffer_; + MemoryPool* pool_; + uint8_t* data_; + int64_t capacity_; + int64_t size_; + int64_t alignment_; +}; + +template +class TypedBufferBuilder; + +/// \brief A BufferBuilder for building a buffer of arithmetic elements +template +class TypedBufferBuilder< + T, typename std::enable_if::value || + std::is_standard_layout::value>::type> { + public: + explicit TypedBufferBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : bytes_builder_(pool, alignment) {} + + explicit TypedBufferBuilder(std::shared_ptr buffer, + MemoryPool* pool = default_memory_pool()) + : bytes_builder_(std::move(buffer), pool) {} + + explicit TypedBufferBuilder(BufferBuilder builder) + : bytes_builder_(std::move(builder)) {} + + BufferBuilder* bytes_builder() { return &bytes_builder_; } + + Status Append(T value) { + return bytes_builder_.Append(reinterpret_cast(&value), sizeof(T)); + } + + Status Append(const T* values, int64_t num_elements) { + return bytes_builder_.Append(reinterpret_cast(values), + num_elements * sizeof(T)); + } + + Status Append(const int64_t num_copies, T value) { + ARROW_RETURN_NOT_OK(Reserve(num_copies + length())); + UnsafeAppend(num_copies, value); + return Status::OK(); + } + + void UnsafeAppend(T value) { + bytes_builder_.UnsafeAppend(reinterpret_cast(&value), sizeof(T)); + } + + void UnsafeAppend(const T* values, int64_t num_elements) { + bytes_builder_.UnsafeAppend(reinterpret_cast(values), + num_elements * sizeof(T)); + } + + template + void UnsafeAppend(Iter values_begin, Iter values_end) { + auto num_elements = static_cast(std::distance(values_begin, values_end)); + auto data = mutable_data() + length(); + bytes_builder_.UnsafeAdvance(num_elements * sizeof(T)); + std::copy(values_begin, values_end, data); + } + + void UnsafeAppend(const int64_t num_copies, T value) { + auto data = mutable_data() + length(); + bytes_builder_.UnsafeAdvance(num_copies * sizeof(T)); + std::fill(data, data + num_copies, value); + } + + Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) { + return bytes_builder_.Resize(new_capacity * sizeof(T), shrink_to_fit); + } + + Status Reserve(const int64_t additional_elements) { + return bytes_builder_.Reserve(additional_elements * sizeof(T)); + } + + Status Advance(const int64_t length) { + return bytes_builder_.Advance(length * sizeof(T)); + } + + void UnsafeAdvance(const int64_t length) { + bytes_builder_.UnsafeAdvance(length * sizeof(T)); + } + + Status Finish(std::shared_ptr* out, bool shrink_to_fit = true) { + return bytes_builder_.Finish(out, shrink_to_fit); + } + + Result> Finish(bool shrink_to_fit = true) { + std::shared_ptr out; + ARROW_RETURN_NOT_OK(Finish(&out, shrink_to_fit)); + return out; + } + + /// \brief Like Finish, but override the final buffer size + /// + /// This is useful after writing data directly into the builder memory + /// without calling the Append methods (basically, when using TypedBufferBuilder + /// only for memory allocation). + Result> FinishWithLength(int64_t final_length, + bool shrink_to_fit = true) { + return bytes_builder_.FinishWithLength(final_length * sizeof(T), shrink_to_fit); + } + + void Reset() { bytes_builder_.Reset(); } + + int64_t length() const { return bytes_builder_.length() / sizeof(T); } + int64_t capacity() const { return bytes_builder_.capacity() / sizeof(T); } + const T* data() const { return reinterpret_cast(bytes_builder_.data()); } + T* mutable_data() { return reinterpret_cast(bytes_builder_.mutable_data()); } + + private: + BufferBuilder bytes_builder_; +}; + +/// \brief A BufferBuilder for building a buffer containing a bitmap +template <> +class TypedBufferBuilder { + public: + explicit TypedBufferBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : bytes_builder_(pool, alignment) {} + + explicit TypedBufferBuilder(BufferBuilder builder) + : bytes_builder_(std::move(builder)) {} + + BufferBuilder* bytes_builder() { return &bytes_builder_; } + + Status Append(bool value) { + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(value); + return Status::OK(); + } + + Status Append(const uint8_t* valid_bytes, int64_t num_elements) { + ARROW_RETURN_NOT_OK(Reserve(num_elements)); + UnsafeAppend(valid_bytes, num_elements); + return Status::OK(); + } + + Status Append(const int64_t num_copies, bool value) { + ARROW_RETURN_NOT_OK(Reserve(num_copies)); + UnsafeAppend(num_copies, value); + return Status::OK(); + } + + void UnsafeAppend(bool value) { + bit_util::SetBitTo(mutable_data(), bit_length_, value); + if (!value) { + ++false_count_; + } + ++bit_length_; + } + + /// \brief Append bits from an array of bytes (one value per byte) + void UnsafeAppend(const uint8_t* bytes, int64_t num_elements) { + if (num_elements == 0) return; + int64_t i = 0; + internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] { + bool value = bytes[i++]; + false_count_ += !value; + return value; + }); + bit_length_ += num_elements; + } + + /// \brief Append bits from a packed bitmap + void UnsafeAppend(const uint8_t* bitmap, int64_t offset, int64_t num_elements) { + if (num_elements == 0) return; + internal::CopyBitmap(bitmap, offset, num_elements, mutable_data(), bit_length_); + false_count_ += num_elements - internal::CountSetBits(bitmap, offset, num_elements); + bit_length_ += num_elements; + } + + void UnsafeAppend(const int64_t num_copies, bool value) { + bit_util::SetBitsTo(mutable_data(), bit_length_, num_copies, value); + false_count_ += num_copies * !value; + bit_length_ += num_copies; + } + + template + void UnsafeAppend(const int64_t num_elements, Generator&& gen) { + if (num_elements == 0) return; + + if (count_falses) { + internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] { + bool value = gen(); + false_count_ += !value; + return value; + }); + } else { + internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, + std::forward(gen)); + } + bit_length_ += num_elements; + } + + Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) { + const int64_t old_byte_capacity = bytes_builder_.capacity(); + ARROW_RETURN_NOT_OK( + bytes_builder_.Resize(bit_util::BytesForBits(new_capacity), shrink_to_fit)); + // Resize() may have chosen a larger capacity (e.g. for padding), + // so ask it again before calling memset(). + const int64_t new_byte_capacity = bytes_builder_.capacity(); + if (new_byte_capacity > old_byte_capacity) { + // The additional buffer space is 0-initialized for convenience, + // so that other methods can simply bump the length. + memset(mutable_data() + old_byte_capacity, 0, + static_cast(new_byte_capacity - old_byte_capacity)); + } + return Status::OK(); + } + + Status Reserve(const int64_t additional_elements) { + return Resize( + BufferBuilder::GrowByFactor(bit_length_, bit_length_ + additional_elements), + false); + } + + Status Advance(const int64_t length) { + ARROW_RETURN_NOT_OK(Reserve(length)); + bit_length_ += length; + false_count_ += length; + return Status::OK(); + } + + Status Finish(std::shared_ptr* out, bool shrink_to_fit = true) { + // set bytes_builder_.size_ == byte size of data + bytes_builder_.UnsafeAdvance(bit_util::BytesForBits(bit_length_) - + bytes_builder_.length()); + bit_length_ = false_count_ = 0; + return bytes_builder_.Finish(out, shrink_to_fit); + } + + Result> Finish(bool shrink_to_fit = true) { + std::shared_ptr out; + ARROW_RETURN_NOT_OK(Finish(&out, shrink_to_fit)); + return out; + } + + /// \brief Like Finish, but override the final buffer size + /// + /// This is useful after writing data directly into the builder memory + /// without calling the Append methods (basically, when using TypedBufferBuilder + /// only for memory allocation). + Result> FinishWithLength(int64_t final_length, + bool shrink_to_fit = true) { + const auto final_byte_length = bit_util::BytesForBits(final_length); + bytes_builder_.UnsafeAdvance(final_byte_length - bytes_builder_.length()); + bit_length_ = false_count_ = 0; + return bytes_builder_.FinishWithLength(final_byte_length, shrink_to_fit); + } + + void Reset() { + bytes_builder_.Reset(); + bit_length_ = false_count_ = 0; + } + + int64_t length() const { return bit_length_; } + int64_t capacity() const { return bytes_builder_.capacity() * 8; } + const uint8_t* data() const { return bytes_builder_.data(); } + uint8_t* mutable_data() { return bytes_builder_.mutable_data(); } + int64_t false_count() const { return false_count_; } + + private: + BufferBuilder bytes_builder_; + int64_t bit_length_ = 0; + int64_t false_count_ = 0; +}; + +} // namespace arrow diff --git a/pyarrow/include/arrow/builder.h b/pyarrow/include/arrow/builder.h new file mode 100644 index 0000000000000000000000000000000000000000..f0aa14c1e0612d1872a5959998651a12668f449f --- /dev/null +++ b/pyarrow/include/arrow/builder.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/array/builder_adaptive.h" // IWYU pragma: keep +#include "arrow/array/builder_base.h" // IWYU pragma: keep +#include "arrow/array/builder_binary.h" // IWYU pragma: keep +#include "arrow/array/builder_decimal.h" // IWYU pragma: keep +#include "arrow/array/builder_dict.h" // IWYU pragma: keep +#include "arrow/array/builder_nested.h" // IWYU pragma: keep +#include "arrow/array/builder_primitive.h" // IWYU pragma: keep +#include "arrow/array/builder_run_end.h" // IWYU pragma: keep +#include "arrow/array/builder_time.h" // IWYU pragma: keep +#include "arrow/array/builder_union.h" // IWYU pragma: keep +#include "arrow/status.h" +#include "arrow/util/visibility.h" diff --git a/pyarrow/include/arrow/c/abi.h b/pyarrow/include/arrow/c/abi.h new file mode 100644 index 0000000000000000000000000000000000000000..ae632f2dbd2601135cb02bc203dd085afd0acaf7 --- /dev/null +++ b/pyarrow/include/arrow/c/abi.h @@ -0,0 +1,460 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// \file abi.h Arrow C Data Interface +/// +/// The Arrow C Data interface defines a very small, stable set +/// of C definitions which can be easily copied into any project's +/// source code and vendored to be used for columnar data interchange +/// in the Arrow format. For non-C/C++ languages and runtimes, +/// it should be almost as easy to translate the C definitions into +/// the corresponding C FFI declarations. +/// +/// Applications and libraries can therefore work with Arrow memory +/// without necessarily using the Arrow libraries or reinventing +/// the wheel. Developers can choose between tight integration +/// with the Arrow software project or minimal integration with +/// the Arrow format only. + +#pragma once + +#include + +// Spec and documentation: https://arrow.apache.org/docs/format/CDataInterface.html + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ARROW_C_DATA_INTERFACE +# define ARROW_C_DATA_INTERFACE + +# define ARROW_FLAG_DICTIONARY_ORDERED 1 +# define ARROW_FLAG_NULLABLE 2 +# define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +# define ARROW_STATISTICS_KEY_AVERAGE_BYTE_WIDTH_EXACT "ARROW:average_byte_width:exact" +# define ARROW_STATISTICS_KEY_AVERAGE_BYTE_WIDTH_APPROXIMATE \ + "ARROW:average_byte_width:approximate" +# define ARROW_STATISTICS_KEY_DISTINCT_COUNT_EXACT "ARROW:distinct_count:exact" +# define ARROW_STATISTICS_KEY_DISTINCT_COUNT_APPROXIMATE \ + "ARROW:distinct_count:approximate" +# define ARROW_STATISTICS_KEY_MAX_BYTE_WIDTH_EXACT "ARROW:max_byte_width:exact" +# define ARROW_STATISTICS_KEY_MAX_BYTE_WIDTH_APPROXIMATE \ + "ARROW:max_byte_width:approximate" +# define ARROW_STATISTICS_KEY_MAX_VALUE_EXACT "ARROW:max_value:exact" +# define ARROW_STATISTICS_KEY_MAX_VALUE_APPROXIMATE "ARROW:max_value:approximate" +# define ARROW_STATISTICS_KEY_MIN_VALUE_EXACT "ARROW:min_value:exact" +# define ARROW_STATISTICS_KEY_MIN_VALUE_APPROXIMATE "ARROW:min_value:approximate" +# define ARROW_STATISTICS_KEY_NULL_COUNT_EXACT "ARROW:null_count:exact" +# define ARROW_STATISTICS_KEY_NULL_COUNT_APPROXIMATE "ARROW:null_count:approximate" +# define ARROW_STATISTICS_KEY_ROW_COUNT_EXACT "ARROW:row_count:exact" +# define ARROW_STATISTICS_KEY_ROW_COUNT_APPROXIMATE "ARROW:row_count:approximate" + +#endif // ARROW_C_DATA_INTERFACE + +#ifndef ARROW_C_DEVICE_DATA_INTERFACE +# define ARROW_C_DEVICE_DATA_INTERFACE + +// Spec and Documentation: https://arrow.apache.org/docs/format/CDeviceDataInterface.html + +// DeviceType for the allocated memory +typedef int32_t ArrowDeviceType; + +// CPU device, same as using ArrowArray directly +# define ARROW_DEVICE_CPU 1 +// CUDA GPU Device +# define ARROW_DEVICE_CUDA 2 +// Pinned CUDA CPU memory by cudaMallocHost +# define ARROW_DEVICE_CUDA_HOST 3 +// OpenCL Device +# define ARROW_DEVICE_OPENCL 4 +// Vulkan buffer for next-gen graphics +# define ARROW_DEVICE_VULKAN 7 +// Metal for Apple GPU +# define ARROW_DEVICE_METAL 8 +// Verilog simulator buffer +# define ARROW_DEVICE_VPI 9 +// ROCm GPUs for AMD GPUs +# define ARROW_DEVICE_ROCM 10 +// Pinned ROCm CPU memory allocated by hipMallocHost +# define ARROW_DEVICE_ROCM_HOST 11 +// Reserved for extension +# define ARROW_DEVICE_EXT_DEV 12 +// CUDA managed/unified memory allocated by cudaMallocManaged +# define ARROW_DEVICE_CUDA_MANAGED 13 +// unified shared memory allocated on a oneAPI non-partitioned device. +# define ARROW_DEVICE_ONEAPI 14 +// GPU support for next-gen WebGPU standard +# define ARROW_DEVICE_WEBGPU 15 +// Qualcomm Hexagon DSP +# define ARROW_DEVICE_HEXAGON 16 + +struct ArrowDeviceArray { + // the Allocated Array + // + // the buffers in the array (along with the buffers of any + // children) are what is allocated on the device. + struct ArrowArray array; + // The device id to identify a specific device + int64_t device_id; + // The type of device which can access this memory. + ArrowDeviceType device_type; + // An event-like object to synchronize on if needed. + void* sync_event; + // Reserved bytes for future expansion. + int64_t reserved[3]; +}; + +#endif // ARROW_C_DEVICE_DATA_INTERFACE + +#ifndef ARROW_C_STREAM_INTERFACE +# define ARROW_C_STREAM_INTERFACE + +struct ArrowArrayStream { + // Callback to get the stream type + // (will be the same for all arrays in the stream). + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowSchema must be released independently from the stream. + int (*get_schema)(struct ArrowArrayStream*, struct ArrowSchema* out); + + // Callback to get the next array + // (if no error and the array is released, the stream has ended) + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowArray must be released independently from the stream. + int (*get_next)(struct ArrowArrayStream*, struct ArrowArray* out); + + // Callback to get optional detailed error information. + // This must only be called if the last stream operation failed + // with a non-0 return code. + // + // Return value: pointer to a null-terminated character array describing + // the last error, or NULL if no description is available. + // + // The returned pointer is only valid until the next operation on this stream + // (including release). + const char* (*get_last_error)(struct ArrowArrayStream*); + + // Release callback: release the stream's own resources. + // Note that arrays returned by `get_next` must be individually released. + void (*release)(struct ArrowArrayStream*); + + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_STREAM_INTERFACE + +#ifndef ARROW_C_DEVICE_STREAM_INTERFACE +# define ARROW_C_DEVICE_STREAM_INTERFACE + +// Equivalent to ArrowArrayStream, but for ArrowDeviceArrays. +// +// This stream is intended to provide a stream of data on a single +// device, if a producer wants data to be produced on multiple devices +// then multiple streams should be provided. One per device. +struct ArrowDeviceArrayStream { + // The device that this stream produces data on. + ArrowDeviceType device_type; + + // Callback to get the stream schema + // (will be the same for all arrays in the stream). + // + // Return value 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowSchema must be released independently from the stream. + // The schema should be accessible via CPU memory. + int (*get_schema)(struct ArrowDeviceArrayStream* self, struct ArrowSchema* out); + + // Callback to get the next array + // (if no error and the array is released, the stream has ended) + // + // Return value: 0 if successful, an `errno`-compatible error code otherwise. + // + // If successful, the ArrowDeviceArray must be released independently from the stream. + int (*get_next)(struct ArrowDeviceArrayStream* self, struct ArrowDeviceArray* out); + + // Callback to get optional detailed error information. + // This must only be called if the last stream operation failed + // with a non-0 return code. + // + // Return value: pointer to a null-terminated character array describing + // the last error, or NULL if no description is available. + // + // The returned pointer is only valid until the next operation on this stream + // (including release). + const char* (*get_last_error)(struct ArrowDeviceArrayStream* self); + + // Release callback: release the stream's own resources. + // Note that arrays returned by `get_next` must be individually released. + void (*release)(struct ArrowDeviceArrayStream* self); + + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DEVICE_STREAM_INTERFACE + +#ifndef ARROW_C_ASYNC_STREAM_INTERFACE +# define ARROW_C_ASYNC_STREAM_INTERFACE + +// EXPERIMENTAL: ArrowAsyncTask represents available data from a producer that was passed +// to an invocation of `on_next_task` on the ArrowAsyncDeviceStreamHandler. +// +// The reason for this Task approach instead of the Async interface returning +// the Array directly is to allow for more complex thread handling and reducing +// context switching and data transfers between CPU cores (e.g. from one L1/L2 +// cache to another) if desired. +// +// For example, the `on_next_task` callback can be called when data is ready, while +// the producer puts potential "decoding" logic in the `ArrowAsyncTask` object. This +// allows for the producer to manage the I/O on one thread which calls `on_next_task` +// and the consumer can determine when the decoding (producer logic in the `extract_data` +// callback of the task) occurs and on which thread, to avoid a CPU core transfer +// (data staying in the L2 cache). +struct ArrowAsyncTask { + // This callback should populate the ArrowDeviceArray associated with this task. + // The order of ArrowAsyncTasks provided by the producer enables a consumer to + // ensure the order of data to process. + // + // This function is expected to be synchronous, but should not perform any blocking + // I/O. Ideally it should be as cheap as possible so as to not tie up the consumer + // thread unnecessarily. + // + // Returns: 0 if successful, errno-compatible error otherwise. + // + // If a non-0 value is returned then it should be followed by a call to `on_error` + // on the appropriate ArrowAsyncDeviceStreamHandler. This is because it's highly + // likely that whatever is calling this function may be entirely disconnected from + // the current control flow. Indicating an error here with a non-zero return allows + // the current flow to be aware of the error occurring, while still allowing any + // logging or error handling to still be centralized in the `on_error` callback of + // the original Async handler. + // + // Rather than a release callback, any required cleanup should be performed as part + // of the invocation of `extract_data`. Ownership of the Array is passed to the consumer + // calling this, and so it must be released separately. + // + // It is only valid to call this method exactly once. + int (*extract_data)(struct ArrowAsyncTask* self, struct ArrowDeviceArray* out); + + // opaque task-specific data + void* private_data; +}; + +// EXPERIMENTAL: ArrowAsyncProducer represents a 1-to-1 relationship between an async +// producer and consumer. This object allows the consumer to perform backpressure and flow +// control on the asynchronous stream processing. This object must be owned by the +// producer who creates it, and thus is responsible for cleaning it up. +struct ArrowAsyncProducer { + // The device type that this stream produces data on. + ArrowDeviceType device_type; + + // A consumer must call this function to start receiving on_next_task calls. + // + // It *must* be valid to call this synchronously from within `on_next_task` or + // `on_schema`, but this function *must not* immediately call `on_next_task` so as + // to avoid recursion and reentrant callbacks. + // + // After cancel has been called, additional calls to this function must be NOPs, + // but allowed. While not cancelled, calling this function must register the + // given number of additional arrays/batches to be produced with the producer. + // The producer should only call `on_next_task` at most the registered number + // of arrays before propagating backpressure. + // + // Any error encountered by calling request must be propagated by calling the `on_error` + // callback of the ArrowAsyncDeviceStreamHandler. + // + // While not cancelled, any subsequent calls to `on_next_task`, `on_error` or + // `release` should be scheduled by the producer to be called later. + // + // It is invalid for a consumer to call this with a value of n <= 0, producers should + // error if given such a value. + void (*request)(struct ArrowAsyncProducer* self, int64_t n); + + // This cancel callback signals a producer that it must eventually stop making calls + // to on_next_task. It must be idempotent and thread-safe. After calling cancel once, + // subsequent calls must be NOPs. This must not call any consumer-side handlers other + // than `on_error`. + // + // It is not required that calling cancel affect the producer immediately, only that it + // must eventually stop calling on_next_task and subsequently call release on the + // async handler. As such, a consumer must be prepared to receive one or more calls to + // `on_next_task` even after calling cancel if there are still requested arrays pending. + // + // Successful cancellation should *not* result in the producer calling `on_error`, it + // should finish out any remaining tasks and eventually call `release`. + // + // Any error encountered during handling a call to cancel must be reported via the + // on_error callback on the async stream handler. + void (*cancel)(struct ArrowAsyncProducer* self); + + // Any additional metadata tied to a specific stream of data. This must either be NULL + // or a valid pointer to metadata which is encoded in the same way schema metadata + // would be. Non-null metadata must be valid for the lifetime of this object. As an + // example a producer could use this to provide the total number of rows and/or batches + // in the stream if known. + const char* additional_metadata; + + // producer-specific opaque data. + void* private_data; +}; + +// EXPERIMENTAL: Similar to ArrowDeviceArrayStream, except designed for an asynchronous +// style of interaction. While ArrowDeviceArrayStream provides producer +// defined callbacks, this is intended to be created by the consumer instead. +// The consumer passes this handler to the producer, which in turn uses the +// callbacks to inform the consumer of events in the stream. +struct ArrowAsyncDeviceStreamHandler { + // Handler for receiving a schema. The passed in stream_schema must be + // released or moved by the handler (producer is giving ownership of the schema to + // the handler, but not ownership of the top level object itself). + // + // With the exception of an error occurring (on_error), this must be the first + // callback function which is called by a producer and must only be called exactly + // once. As such, the producer should provide a valid ArrowAsyncProducer instance + // so the consumer can control the flow. See the documentation on ArrowAsyncProducer + // for how it works. The ArrowAsyncProducer is owned by the producer who calls this + // function and thus the producer is responsible for cleaning it up when calling + // the release callback of this handler. + // + // If there is any additional metadata tied to this stream, it will be provided as + // a non-null value for the `additional_metadata` field of the ArrowAsyncProducer + // which will be valid at least until the release callback is called. + // + // Return value: 0 if successful, `errno`-compatible error otherwise + // + // A producer that receives a non-zero return here should stop producing and eventually + // call release instead. + int (*on_schema)(struct ArrowAsyncDeviceStreamHandler* self, + struct ArrowSchema* stream_schema); + + // Handler for receiving data. This is called when data is available providing an + // ArrowAsyncTask struct to signify it. The producer indicates the end of the stream + // by passing NULL as the value for the task rather than a valid pointer to a task. + // The task object is only valid for the lifetime of this function call, if a consumer + // wants to utilize it after this function returns, it must copy or move the contents + // of it to a new ArrowAsyncTask object. + // + // The `request` callback of a provided ArrowAsyncProducer must be called in order + // to start receiving calls to this handler. + // + // The metadata argument can be null or can be used by a producer + // to pass arbitrary extra information to the consumer (such as total number + // of rows, context info, or otherwise). The data should be passed using the same + // encoding as the metadata within the ArrowSchema struct itself (defined in + // the spec at + // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.metadata) + // + // If metadata is non-null then it only needs to exist for the lifetime of this call, + // a consumer who wants it to live after that must copy it to ensure lifetime. + // + // A producer *must not* call this concurrently from multiple different threads. + // + // A consumer must be prepared to receive one or more calls to this callback even + // after calling cancel on the corresponding ArrowAsyncProducer, as cancel does not + // guarantee it happens immediately. + // + // Return value: 0 if successful, `errno`-compatible error otherwise. + // + // If the consumer returns a non-zero return from this method, that indicates to the + // producer that it should stop propagating data as an error occurred. After receiving + // such a return, the only interaction with this object is for the producer to call + // the `release` callback. + int (*on_next_task)(struct ArrowAsyncDeviceStreamHandler* self, + struct ArrowAsyncTask* task, const char* metadata); + + // Handler for encountering an error. The producer should call release after + // this returns to clean up any resources. The `code` passed in can be any error + // code that a producer wants, but should be errno-compatible for consistency. + // + // If the message or metadata are non-null, they will only last as long as this + // function call. The consumer would need to perform a copy of the data if it is + // necessary for them to live past the lifetime of this call. + // + // Error metadata should be encoded as with metadata in ArrowSchema, defined in + // the spec at + // https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema.metadata + // + // It is valid for this to be called by a producer with or without a preceding call + // to ArrowAsyncProducer.request. + // + // This callback must not call any methods of an ArrowAsyncProducer object. + void (*on_error)(struct ArrowAsyncDeviceStreamHandler* self, int code, + const char* message, const char* metadata); + + // Release callback to release any resources for the handler. Should always be + // called by a producer when it is done utilizing a handler. No callbacks should + // be called after this is called. + // + // It is valid for the release callback to be called by a producer with or without + // a preceding call to ArrowAsyncProducer.request. + // + // The release callback must not call any methods of an ArrowAsyncProducer object. + void (*release)(struct ArrowAsyncDeviceStreamHandler* self); + + // MUST be populated by the producer BEFORE calling any callbacks other than release. + // This provides the connection between a handler and its producer, and must exist until + // the release callback is called. + struct ArrowAsyncProducer* producer; + + // Opaque handler-specific data + void* private_data; +}; + +#endif // ARROW_C_ASYNC_STREAM_INTERFACE + +#ifdef __cplusplus +} +#endif diff --git a/pyarrow/include/arrow/c/bridge.h b/pyarrow/include/arrow/c/bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..78860e0650e741a95e7f8bc0c5ab35bc1c01cf79 --- /dev/null +++ b/pyarrow/include/arrow/c/bridge.h @@ -0,0 +1,489 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/c/abi.h" +#include "arrow/device.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/async_generator_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +/// \defgroup c-data-interface Functions for working with the C data interface. +/// +/// @{ + +/// \brief Export C++ DataType using the C data interface format. +/// +/// The root type is considered to have empty name and metadata. +/// If you want the root type to have a name and/or metadata, pass +/// a Field instead. +/// +/// \param[in] type DataType object to export +/// \param[out] out C struct where to export the datatype +ARROW_EXPORT +Status ExportType(const DataType& type, struct ArrowSchema* out); + +/// \brief Export C++ Field using the C data interface format. +/// +/// \param[in] field Field object to export +/// \param[out] out C struct where to export the field +ARROW_EXPORT +Status ExportField(const Field& field, struct ArrowSchema* out); + +/// \brief Export C++ Schema using the C data interface format. +/// +/// \param[in] schema Schema object to export +/// \param[out] out C struct where to export the field +ARROW_EXPORT +Status ExportSchema(const Schema& schema, struct ArrowSchema* out); + +/// \brief Export C++ Array using the C data interface format. +/// +/// The resulting ArrowArray struct keeps the array data and buffers alive +/// until its release callback is called by the consumer. +/// +/// \param[in] array Array object to export +/// \param[out] out C struct where to export the array +/// \param[out] out_schema optional C struct where to export the array type +ARROW_EXPORT +Status ExportArray(const Array& array, struct ArrowArray* out, + struct ArrowSchema* out_schema = NULLPTR); + +/// \brief Export C++ RecordBatch using the C data interface format. +/// +/// The record batch is exported as if it were a struct array. +/// The resulting ArrowArray struct keeps the record batch data and buffers alive +/// until its release callback is called by the consumer. +/// +/// \param[in] batch Record batch to export +/// \param[out] out C struct where to export the record batch +/// \param[out] out_schema optional C struct where to export the record batch schema +ARROW_EXPORT +Status ExportRecordBatch(const RecordBatch& batch, struct ArrowArray* out, + struct ArrowSchema* out_schema = NULLPTR); + +/// \brief Import C++ DataType from the C data interface. +/// +/// The given ArrowSchema struct is released (as per the C data interface +/// specification), even if this function fails. +/// +/// \param[in,out] schema C data interface struct representing the data type +/// \return Imported type object +ARROW_EXPORT +Result> ImportType(struct ArrowSchema* schema); + +/// \brief Import C++ Field from the C data interface. +/// +/// The given ArrowSchema struct is released (as per the C data interface +/// specification), even if this function fails. +/// +/// \param[in,out] schema C data interface struct representing the field +/// \return Imported field object +ARROW_EXPORT +Result> ImportField(struct ArrowSchema* schema); + +/// \brief Import C++ Schema from the C data interface. +/// +/// The given ArrowSchema struct is released (as per the C data interface +/// specification), even if this function fails. +/// +/// \param[in,out] schema C data interface struct representing the field +/// \return Imported field object +ARROW_EXPORT +Result> ImportSchema(struct ArrowSchema* schema); + +/// \brief Import C++ array from the C data interface. +/// +/// The ArrowArray struct has its contents moved (as per the C data interface +/// specification) to a private object held alive by the resulting array. +/// +/// \param[in,out] array C data interface struct holding the array data +/// \param[in] type type of the imported array +/// \return Imported array object +ARROW_EXPORT +Result> ImportArray(struct ArrowArray* array, + std::shared_ptr type); + +/// \brief Import C++ array and its type from the C data interface. +/// +/// The ArrowArray struct has its contents moved (as per the C data interface +/// specification) to a private object held alive by the resulting array. +/// The ArrowSchema struct is released, even if this function fails. +/// +/// \param[in,out] array C data interface struct holding the array data +/// \param[in,out] type C data interface struct holding the array type +/// \return Imported array object +ARROW_EXPORT +Result> ImportArray(struct ArrowArray* array, + struct ArrowSchema* type); + +/// \brief Import C++ record batch from the C data interface. +/// +/// The ArrowArray struct has its contents moved (as per the C data interface +/// specification) to a private object held alive by the resulting record batch. +/// +/// \param[in,out] array C data interface struct holding the record batch data +/// \param[in] schema schema of the imported record batch +/// \return Imported record batch object +ARROW_EXPORT +Result> ImportRecordBatch(struct ArrowArray* array, + std::shared_ptr schema); + +/// \brief Import C++ record batch and its schema from the C data interface. +/// +/// The type represented by the ArrowSchema struct must be a struct type array. +/// The ArrowArray struct has its contents moved (as per the C data interface +/// specification) to a private object held alive by the resulting record batch. +/// The ArrowSchema struct is released, even if this function fails. +/// +/// \param[in,out] array C data interface struct holding the record batch data +/// \param[in,out] schema C data interface struct holding the record batch schema +/// \return Imported record batch object +ARROW_EXPORT +Result> ImportRecordBatch(struct ArrowArray* array, + struct ArrowSchema* schema); + +/// @} + +/// \defgroup c-data-device-interface Functions for working with the C data device +/// interface. +/// +/// @{ + +/// \brief EXPERIMENTAL: Export C++ Array as an ArrowDeviceArray. +/// +/// The resulting ArrowDeviceArray struct keeps the array data and buffers alive +/// until its release callback is called by the consumer. All buffers in +/// the provided array MUST have the same device_type, otherwise an error +/// will be returned. +/// +/// If sync is non-null, get_event will be called on it in order to +/// potentially provide an event for consumers to synchronize on. +/// +/// \param[in] array Array object to export +/// \param[in] sync shared_ptr to object derived from Device::SyncEvent or null +/// \param[out] out C struct to export the array to +/// \param[out] out_schema optional C struct to export the array type to +ARROW_EXPORT +Status ExportDeviceArray(const Array& array, std::shared_ptr sync, + struct ArrowDeviceArray* out, + struct ArrowSchema* out_schema = NULLPTR); + +/// \brief EXPERIMENTAL: Export C++ RecordBatch as an ArrowDeviceArray. +/// +/// The record batch is exported as if it were a struct array. +/// The resulting ArrowDeviceArray struct keeps the record batch data and buffers alive +/// until its release callback is called by the consumer. +/// +/// All buffers of all columns in the record batch must have the same device_type +/// otherwise an error will be returned. If columns are on different devices, +/// they should be exported using different ArrowDeviceArray instances. +/// +/// If sync is non-null, get_event will be called on it in order to +/// potentially provide an event for consumers to synchronize on. +/// +/// \param[in] batch Record batch to export +/// \param[in] sync shared_ptr to object derived from Device::SyncEvent or null +/// \param[out] out C struct where to export the record batch +/// \param[out] out_schema optional C struct where to export the record batch schema +ARROW_EXPORT +Status ExportDeviceRecordBatch(const RecordBatch& batch, + std::shared_ptr sync, + struct ArrowDeviceArray* out, + struct ArrowSchema* out_schema = NULLPTR); + +using DeviceMemoryMapper = + std::function>(ArrowDeviceType, int64_t)>; + +ARROW_EXPORT +Result> DefaultDeviceMemoryMapper( + ArrowDeviceType device_type, int64_t device_id); + +/// \brief EXPERIMENTAL: Import C++ device array from the C data interface. +/// +/// The ArrowArray struct has its contents moved (as per the C data interface +/// specification) to a private object held alive by the resulting array. The +/// buffers of the Array are located on the device indicated by the device_type. +/// +/// \param[in,out] array C data interface struct holding the array data +/// \param[in] type type of the imported array +/// \param[in] mapper A function to map device + id to memory manager. If not +/// specified, defaults to map "cpu" to the built-in default memory manager. +/// \return Imported array object +ARROW_EXPORT +Result> ImportDeviceArray( + struct ArrowDeviceArray* array, std::shared_ptr type, + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper); + +/// \brief EXPERIMENTAL: Import C++ device array and its type from the C data interface. +/// +/// The ArrowArray struct has its contents moved (as per the C data interface +/// specification) to a private object held alive by the resulting array. +/// The ArrowSchema struct is released, even if this function fails. The +/// buffers of the Array are located on the device indicated by the device_type. +/// +/// \param[in,out] array C data interface struct holding the array data +/// \param[in,out] type C data interface struct holding the array type +/// \param[in] mapper A function to map device + id to memory manager. If not +/// specified, defaults to map "cpu" to the built-in default memory manager. +/// \return Imported array object +ARROW_EXPORT +Result> ImportDeviceArray( + struct ArrowDeviceArray* array, struct ArrowSchema* type, + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper); + +/// \brief EXPERIMENTAL: Import C++ record batch with buffers on a device from the C data +/// interface. +/// +/// The ArrowArray struct has its contents moved (as per the C data interface +/// specification) to a private object held alive by the resulting record batch. +/// The buffers of all columns of the record batch are located on the device +/// indicated by the device type. +/// +/// \param[in,out] array C data interface struct holding the record batch data +/// \param[in] schema schema of the imported record batch +/// \param[in] mapper A function to map device + id to memory manager. If not +/// specified, defaults to map "cpu" to the built-in default memory manager. +/// \return Imported record batch object +ARROW_EXPORT +Result> ImportDeviceRecordBatch( + struct ArrowDeviceArray* array, std::shared_ptr schema, + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper); + +/// \brief EXPERIMENTAL: Import C++ record batch with buffers on a device and its schema +/// from the C data interface. +/// +/// The type represented by the ArrowSchema struct must be a struct type array. +/// The ArrowArray struct has its contents moved (as per the C data interface +/// specification) to a private object held alive by the resulting record batch. +/// The ArrowSchema struct is released, even if this function fails. The buffers +/// of all columns of the record batch are located on the device indicated by the +/// device type. +/// +/// \param[in,out] array C data interface struct holding the record batch data +/// \param[in,out] schema C data interface struct holding the record batch schema +/// \param[in] mapper A function to map device + id to memory manager. If not +/// specified, defaults to map "cpu" to the built-in default memory manager. +/// \return Imported record batch object +ARROW_EXPORT +Result> ImportDeviceRecordBatch( + struct ArrowDeviceArray* array, struct ArrowSchema* schema, + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper); + +/// @} + +/// \defgroup c-stream-interface Functions for working with the C data interface. +/// +/// @{ + +/// \brief Export C++ RecordBatchReader using the C stream interface. +/// +/// The resulting ArrowArrayStream struct keeps the record batch reader alive +/// until its release callback is called by the consumer. +/// +/// \param[in] reader RecordBatchReader object to export +/// \param[out] out C struct where to export the stream +ARROW_EXPORT +Status ExportRecordBatchReader(std::shared_ptr reader, + struct ArrowArrayStream* out); + +/// \brief Export C++ ChunkedArray using the C data interface format. +/// +/// The resulting ArrowArrayStream struct keeps the chunked array data and buffers alive +/// until its release callback is called by the consumer. +/// +/// \param[in] chunked_array ChunkedArray object to export +/// \param[out] out C struct where to export the stream +ARROW_EXPORT +Status ExportChunkedArray(std::shared_ptr chunked_array, + struct ArrowArrayStream* out); + +/// \brief Export C++ RecordBatchReader using the C device stream interface +/// +/// The resulting ArrowDeviceArrayStream struct keeps the record batch reader +/// alive until its release callback is called by the consumer. The device +/// type is determined by calling device_type() on the RecordBatchReader. +/// +/// \param[in] reader RecordBatchReader object to export +/// \param[out] out C struct to export the stream to +ARROW_EXPORT +Status ExportDeviceRecordBatchReader(std::shared_ptr reader, + struct ArrowDeviceArrayStream* out); + +/// \brief Export C++ ChunkedArray using the C device data interface format. +/// +/// The resulting ArrowDeviceArrayStream keeps the chunked array data and buffers +/// alive until its release callback is called by the consumer. +/// +/// \param[in] chunked_array ChunkedArray object to export +/// \param[in] device_type the device type the data is located on +/// \param[out] out C struct to export the stream to +ARROW_EXPORT +Status ExportDeviceChunkedArray(std::shared_ptr chunked_array, + DeviceAllocationType device_type, + struct ArrowDeviceArrayStream* out); + +/// \brief Import C++ RecordBatchReader from the C stream interface. +/// +/// The ArrowArrayStream struct has its contents moved to a private object +/// held alive by the resulting record batch reader. +/// +/// \param[in,out] stream C stream interface struct +/// \return Imported RecordBatchReader object +ARROW_EXPORT +Result> ImportRecordBatchReader( + struct ArrowArrayStream* stream); + +/// \brief Import C++ ChunkedArray from the C stream interface +/// +/// The ArrowArrayStream struct has its contents moved to a private object, +/// is consumed in its entirity, and released before returning all chunks +/// as a ChunkedArray. +/// +/// \param[in,out] stream C stream interface struct +/// \return Imported ChunkedArray object +ARROW_EXPORT +Result> ImportChunkedArray(struct ArrowArrayStream* stream); + +/// \brief Import C++ RecordBatchReader from the C device stream interface +/// +/// The ArrowDeviceArrayStream struct has its contents moved to a private object +/// held alive by the resulting record batch reader. +/// +/// \note If there was a required sync event, sync events are accessible by individual +/// buffers of columns. We are not yet bubbling the sync events from the buffers up to +/// the `GetSyncEvent` method of an imported RecordBatch. This will be added in a future +/// update. +/// +/// \param[in,out] stream C device stream interface struct +/// \param[in] mapper mapping from device type and ID to memory manager +/// \return Imported RecordBatchReader object +ARROW_EXPORT +Result> ImportDeviceRecordBatchReader( + struct ArrowDeviceArrayStream* stream, + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper); + +/// \brief Import C++ ChunkedArray from the C device stream interface +/// +/// The ArrowDeviceArrayStream struct has its contents moved to a private object, +/// is consumed in its entirety, and released before returning all chunks as a +/// ChunkedArray. +/// +/// \note Any chunks that require synchronization for their device memory will have +/// the SyncEvent objects available by checking the individual buffers of each chunk. +/// These SyncEvents should be checked before accessing the data in those buffers. +/// +/// \param[in,out] stream C device stream interface struct +/// \param[in] mapper mapping from device type and ID to memory manager +/// \return Imported ChunkedArray object +ARROW_EXPORT +Result> ImportDeviceChunkedArray( + struct ArrowDeviceArrayStream* stream, + const DeviceMemoryMapper& mapper = DefaultDeviceMemoryMapper); + +/// @} + +/// \defgroup c-async-stream-interface Functions for working with the async C data +/// interface. +/// +/// @{ + +/// \brief EXPERIMENTAL: AsyncErrorDetail is a StatusDetail that contains an error code +/// and message from an asynchronous operation. +class AsyncErrorDetail : public StatusDetail { + public: + AsyncErrorDetail(int code, std::string message, std::string metadata) + : code_(code), message_(std::move(message)), metadata_(std::move(metadata)) {} + const char* type_id() const override { return "AsyncErrorDetail"; } + // ToString just returns the error message that was returned with the error + std::string ToString() const override { return message_; } + // code is an errno-compatible error code + int code() const { return code_; } + // returns any metadata that was returned with the error, likely in a + // key-value format similar to ArrowSchema metadata + const std::string& ErrorMetadataString() const { return metadata_; } + std::shared_ptr ErrorMetadata() const; + + private: + int code_{0}; + std::string message_; + std::string metadata_; +}; + +struct AsyncRecordBatchGenerator { + std::shared_ptr schema; + DeviceAllocationType device_type; + AsyncGenerator generator; +}; + +namespace internal { +class Executor; +} + +/// \brief EXPERIMENTAL: Create an AsyncRecordBatchReader and populate a corresponding +/// handler to pass to a producer +/// +/// The ArrowAsyncDeviceStreamHandler struct is intended to have its callbacks populated +/// and then be passed to a producer to call the appropriate callbacks when data is ready. +/// This inverts the traditional flow of control, and so we construct a corresponding +/// AsyncRecordBatchGenerator to provide an interface for the consumer to retrieve data as +/// it is pushed to the handler. +/// +/// \param[in,out] handler C struct to be populated +/// \param[in] executor the executor to use for waiting and populating record batches +/// \param[in] queue_size initial number of record batches to request for queueing +/// \param[in] mapper mapping from device type and ID to memory manager +/// \return Future that resolves to either an error or AsyncRecordBatchGenerator once a +/// schema is available or an error is received. +ARROW_EXPORT +Future CreateAsyncDeviceStreamHandler( + struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* executor, + uint64_t queue_size = 5, DeviceMemoryMapper mapper = DefaultDeviceMemoryMapper); + +/// \brief EXPERIMENTAL: Export an AsyncGenerator of record batches using a provided +/// handler +/// +/// This function calls the callbacks on the consumer-provided async handler as record +/// batches become available from the AsyncGenerator which is provided. It will first call +/// on_schema using the provided schema, and then serially visit each record batch from +/// the generator, calling the on_next_task callback. If an error occurs, on_error will be +/// called appropriately. +/// +/// \param[in] schema the schema of the stream being exported +/// \param[in] generator a generator that asynchronously produces record batches +/// \param[in] device_type the device type that the record batches will be located on +/// \param[in] handler the handler whose callbacks to utilize as data is available +/// \return Future that will resolve once the generator is exhausted or an error occurs +ARROW_EXPORT +Future<> ExportAsyncRecordBatchReader( + std::shared_ptr schema, + AsyncGenerator> generator, + DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler* handler); + +/// @} + +} // namespace arrow diff --git a/pyarrow/include/arrow/c/dlpack.h b/pyarrow/include/arrow/c/dlpack.h new file mode 100644 index 0000000000000000000000000000000000000000..65da38423c2ad62fce26fc115024ef843fb802b5 --- /dev/null +++ b/pyarrow/include/arrow/c/dlpack.h @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/array/array_base.h" +#include "arrow/c/dlpack_abi.h" + +namespace arrow::dlpack { + +/// \brief Export Arrow array as DLPack tensor. +/// +/// DLMangedTensor is produced as defined by the DLPack protocol, +/// see https://dmlc.github.io/dlpack/latest/. +/// +/// Data types for which the protocol is supported are +/// integer and floating-point data types. +/// +/// DLPack protocol only supports arrays with one contiguous +/// memory region which means Arrow Arrays with validity buffers +/// are not supported. +/// +/// \param[in] arr Arrow array +/// \return DLManagedTensor struct +ARROW_EXPORT +Result ExportArray(const std::shared_ptr& arr); + +ARROW_EXPORT +Result ExportTensor(const std::shared_ptr& t); + +/// \brief Get DLDevice with enumerator specifying the +/// type of the device data is stored on and index of the +/// device which is 0 by default for CPU. +/// +/// \param[in] arr Arrow array +/// \return DLDevice struct +ARROW_EXPORT +Result ExportDevice(const std::shared_ptr& arr); + +ARROW_EXPORT +Result ExportDevice(const std::shared_ptr& t); + +} // namespace arrow::dlpack diff --git a/pyarrow/include/arrow/c/dlpack_abi.h b/pyarrow/include/arrow/c/dlpack_abi.h new file mode 100644 index 0000000000000000000000000000000000000000..fbe2a56a344b373f3d3e950e434ba5392036a080 --- /dev/null +++ b/pyarrow/include/arrow/c/dlpack_abi.h @@ -0,0 +1,321 @@ +// Taken from: +// https://github.com/dmlc/dlpack/blob/ca4d00ad3e2e0f410eeab3264d21b8a39397f362/include/dlpack/dlpack.h +/*! + * Copyright (c) 2017 by Contributors + * \file dlpack.h + * \brief The common header of DLPack. + */ +#ifndef DLPACK_DLPACK_H_ +#define DLPACK_DLPACK_H_ + +/** + * \brief Compatibility with C++ + */ +#ifdef __cplusplus +# define DLPACK_EXTERN_C extern "C" +#else +# define DLPACK_EXTERN_C +#endif + +/*! \brief The current major version of dlpack */ +#define DLPACK_MAJOR_VERSION 1 + +/*! \brief The current minor version of dlpack */ +#define DLPACK_MINOR_VERSION 0 + +/*! \brief DLPACK_DLL prefix for windows */ +#ifdef _WIN32 +# ifdef DLPACK_EXPORTS +# define DLPACK_DLL __declspec(dllexport) +# else +# define DLPACK_DLL __declspec(dllimport) +# endif +#else +# define DLPACK_DLL +#endif + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief The DLPack version. + * + * A change in major version indicates that we have changed the + * data layout of the ABI - DLManagedTensorVersioned. + * + * A change in minor version indicates that we have added new + * code, such as a new device type, but the ABI is kept the same. + * + * If an obtained DLPack tensor has a major version that disagrees + * with the version number specified in this header file + * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter + * (and it is safe to do so). It is not safe to access any other fields + * as the memory layout will have changed. + * + * In the case of a minor version mismatch, the tensor can be safely used as + * long as the consumer knows how to interpret all fields. Minor version + * updates indicate the addition of enumeration values. + */ +typedef struct { + /*! \brief DLPack major version. */ + uint32_t major; + /*! \brief DLPack minor version. */ + uint32_t minor; +} DLPackVersion; + +/*! + * \brief The device type in DLDevice. + */ +#ifdef __cplusplus +typedef enum : int32_t { +#else +typedef enum { +#endif + /*! \brief CPU device */ + kDLCPU = 1, + /*! \brief CUDA GPU device */ + kDLCUDA = 2, + /*! + * \brief Pinned CUDA CPU memory by cudaMallocHost + */ + kDLCUDAHost = 3, + /*! \brief OpenCL devices. */ + kDLOpenCL = 4, + /*! \brief Vulkan buffer for next generation graphics. */ + kDLVulkan = 7, + /*! \brief Metal for Apple GPU. */ + kDLMetal = 8, + /*! \brief Verilog simulator buffer */ + kDLVPI = 9, + /*! \brief ROCm GPUs for AMD GPUs */ + kDLROCM = 10, + /*! + * \brief Pinned ROCm CPU memory allocated by hipMallocHost + */ + kDLROCMHost = 11, + /*! + * \brief Reserved extension device type, + * used for quickly test extension device + * The semantics can differ depending on the implementation. + */ + kDLExtDev = 12, + /*! + * \brief CUDA managed/unified memory allocated by cudaMallocManaged + */ + kDLCUDAManaged = 13, + /*! + * \brief Unified shared memory allocated on a oneAPI non-partititioned + * device. Call to oneAPI runtime is required to determine the device + * type, the USM allocation type and the sycl context it is bound to. + * + */ + kDLOneAPI = 14, + /*! \brief GPU support for next generation WebGPU standard. */ + kDLWebGPU = 15, + /*! \brief Qualcomm Hexagon DSP */ + kDLHexagon = 16, +} DLDeviceType; + +/*! + * \brief A Device for Tensor and operator. + */ +typedef struct { + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! + * \brief The device index. + * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. + */ + int32_t device_id; +} DLDevice; + +/*! + * \brief The type code options DLDataType. + */ +typedef enum { + /*! \brief signed integer */ + kDLInt = 0U, + /*! \brief unsigned integer */ + kDLUInt = 1U, + /*! \brief IEEE floating point */ + kDLFloat = 2U, + /*! + * \brief Opaque handle type, reserved for testing purposes. + * Frameworks need to agree on the handle data type for the exchange to be well-defined. + */ + kDLOpaqueHandle = 3U, + /*! \brief bfloat16 */ + kDLBfloat = 4U, + /*! + * \brief complex number + * (C/C++/Python layout: compact struct per complex number) + */ + kDLComplex = 5U, + /*! \brief boolean */ + kDLBool = 6U, +} DLDataTypeCode; + +/*! + * \brief The data type the tensor can hold. The data type is assumed to follow the + * native endian-ness. An explicit error message should be raised when attempting to + * export an array with non-native endianness + * + * Examples + * - float: type_code = 2, bits = 32, lanes = 1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4 + * - int8: type_code = 0, bits = 8, lanes = 1 + * - std::complex: type_code = 5, bits = 64, lanes = 1 + * - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, + * the underlying storage size of bool is 8 bits) + */ +typedef struct { + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. + */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; +} DLDataType; + +/*! + * \brief Plain C Tensor object, does not manage memory. + */ +typedef struct { + /*! + * \brief The data pointer points to the allocated data. This will be CUDA + * device pointer or cl_mem handle in OpenCL. It may be opaque on some device + * types. This pointer is always aligned to 256 bytes as in CUDA. The + * `byte_offset` field should be used to point to the beginning of the data. + * + * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, + * TVM, perhaps others) do not adhere to this 256 byte aligment requirement + * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed + * (after which this note will be updated); at the moment it is recommended + * to not rely on the data pointer being correctly aligned. + * + * For given DLTensor, the size of memory required to store the contents of + * data is calculated as follows: + * + * \code{.c} + * static inline size_t GetDataSize(const DLTensor* t) { + * size_t size = 1; + * for (tvm_index_t i = 0; i < t->ndim; ++i) { + * size *= t->shape[i]; + * } + * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + * return size; + * } + * \endcode + */ + void* data; + /*! \brief The device of the tensor */ + DLDevice device; + /*! \brief Number of dimensions */ + int32_t ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! \brief The shape of the tensor */ + int64_t* shape; + /*! + * \brief strides of the tensor (in number of elements, not bytes) + * can be NULL, indicating tensor is compact and row-majored. + */ + int64_t* strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; +} DLTensor; + +/*! + * \brief C Tensor object, manage memory of DLTensor. This data structure is + * intended to facilitate the borrowing of DLTensor by another framework. It is + * not meant to transfer the tensor. When the borrowing framework doesn't need + * the tensor, it should call the deleter to notify the host that the resource + * is no longer needed. + * + * \note This data structure is used as Legacy DLManagedTensor + * in DLPack exchange and is deprecated after DLPack v0.8 + * Use DLManagedTensorVersioned instead. + * This data structure may get renamed or deleted in future versions. + * + * \sa DLManagedTensorVersioned + */ +typedef struct DLManagedTensor { + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; + /*! \brief the context of the original host framework of DLManagedTensor in + * which DLManagedTensor is used in the framework. It can also be NULL. + */ + void* manager_ctx; + /*! + * \brief Destructor - this should be called + * to destruct the manager_ctx which backs the DLManagedTensor. It can be + * NULL if there is no way for the caller to provide a reasonable destructor. + * The destructors deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensor* self); +} DLManagedTensor; + +// bit masks used in in the DLManagedTensorVersioned + +/*! \brief bit mask to indicate that the tensor is read only. */ +#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) + +/*! + * \brief A versioned and managed C Tensor object, manage memory of DLTensor. + * + * This data structure is intended to facilitate the borrowing of DLTensor by + * another framework. It is not meant to transfer the tensor. When the borrowing + * framework doesn't need the tensor, it should call the deleter to notify the + * host that the resource is no longer needed. + * + * \note This is the current standard DLPack exchange data structure. + */ +struct DLManagedTensorVersioned { + /*! + * \brief The API and ABI version of the current managed Tensor + */ + DLPackVersion version; + /*! + * \brief the context of the original host framework. + * + * Stores DLManagedTensorVersioned is used in the + * framework. It can also be NULL. + */ + void* manager_ctx; + /*! + * \brief Destructor. + * + * This should be called to destruct manager_ctx which holds the + * DLManagedTensorVersioned. It can be NULL if there is no way for the caller to provide + * a reasonable destructor. The destructors deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensorVersioned* self); + /*! + * \brief Additional bitmask flags information about the tensor. + * + * By default the flags should be set to 0. + * + * \note Future ABI changes should keep everything until this field + * stable, to ensure that deleter can be correctly called. + * + * \sa DLPACK_FLAG_BITMASK_READ_ONLY + */ + uint64_t flags; + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; +}; + +#ifdef __cplusplus +} // DLPACK_EXTERN_C +#endif +#endif // DLPACK_DLPACK_H_ diff --git a/pyarrow/include/arrow/c/helpers.h b/pyarrow/include/arrow/c/helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..6e4df17f43ebfe238484056fedbd4e6d575460f0 --- /dev/null +++ b/pyarrow/include/arrow/c/helpers.h @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/c/abi.h" + +#define ARROW_C_ASSERT(condition, msg) \ + do { \ + if (!(condition)) { \ + fprintf(stderr, "%s:%d:: %s", __FILE__, __LINE__, (msg)); \ + abort(); \ + } \ + } while (0) + +#ifdef __cplusplus +extern "C" { +#endif + +/// Query whether the C schema is released +inline int ArrowSchemaIsReleased(const struct ArrowSchema* schema) { + return schema->release == NULL; +} + +/// Mark the C schema released (for use in release callbacks) +inline void ArrowSchemaMarkReleased(struct ArrowSchema* schema) { + schema->release = NULL; +} + +/// Move the C schema from `src` to `dest` +/// +/// Note `dest` must *not* point to a valid schema already, otherwise there +/// will be a memory leak. +inline void ArrowSchemaMove(struct ArrowSchema* src, struct ArrowSchema* dest) { + assert(dest != src); + assert(!ArrowSchemaIsReleased(src)); + memcpy(dest, src, sizeof(struct ArrowSchema)); + ArrowSchemaMarkReleased(src); +} + +/// Release the C schema, if necessary, by calling its release callback +inline void ArrowSchemaRelease(struct ArrowSchema* schema) { + if (!ArrowSchemaIsReleased(schema)) { + schema->release(schema); + ARROW_C_ASSERT(ArrowSchemaIsReleased(schema), + "ArrowSchemaRelease did not cleanup release callback"); + } +} + +/// Query whether the C array is released +inline int ArrowArrayIsReleased(const struct ArrowArray* array) { + return array->release == NULL; +} + +inline int ArrowDeviceArrayIsReleased(const struct ArrowDeviceArray* array) { + return ArrowArrayIsReleased(&array->array); +} + +/// Mark the C array released (for use in release callbacks) +inline void ArrowArrayMarkReleased(struct ArrowArray* array) { array->release = NULL; } + +inline void ArrowDeviceArrayMarkReleased(struct ArrowDeviceArray* array) { + ArrowArrayMarkReleased(&array->array); +} + +/// Move the C array from `src` to `dest` +/// +/// Note `dest` must *not* point to a valid array already, otherwise there +/// will be a memory leak. +inline void ArrowArrayMove(struct ArrowArray* src, struct ArrowArray* dest) { + assert(dest != src); + assert(!ArrowArrayIsReleased(src)); + memcpy(dest, src, sizeof(struct ArrowArray)); + ArrowArrayMarkReleased(src); +} + +inline void ArrowDeviceArrayMove(struct ArrowDeviceArray* src, + struct ArrowDeviceArray* dest) { + assert(dest != src); + assert(!ArrowDeviceArrayIsReleased(src)); + memcpy(dest, src, sizeof(struct ArrowDeviceArray)); + ArrowDeviceArrayMarkReleased(src); +} + +/// Release the C array, if necessary, by calling its release callback +inline void ArrowArrayRelease(struct ArrowArray* array) { + if (!ArrowArrayIsReleased(array)) { + array->release(array); + ARROW_C_ASSERT(ArrowArrayIsReleased(array), + "ArrowArrayRelease did not cleanup release callback"); + } +} + +inline void ArrowDeviceArrayRelease(struct ArrowDeviceArray* array) { + if (!ArrowDeviceArrayIsReleased(array)) { + array->array.release(&array->array); + ARROW_C_ASSERT(ArrowDeviceArrayIsReleased(array), + "ArrowDeviceArrayRelease did not cleanup release callback"); + } +} + +/// Query whether the C array stream is released +inline int ArrowArrayStreamIsReleased(const struct ArrowArrayStream* stream) { + return stream->release == NULL; +} + +inline int ArrowDeviceArrayStreamIsReleased(const struct ArrowDeviceArrayStream* stream) { + return stream->release == NULL; +} + +/// Mark the C array stream released (for use in release callbacks) +inline void ArrowArrayStreamMarkReleased(struct ArrowArrayStream* stream) { + stream->release = NULL; +} + +inline void ArrowDeviceArrayStreamMarkReleased(struct ArrowDeviceArrayStream* stream) { + stream->release = NULL; +} + +/// Move the C array stream from `src` to `dest` +/// +/// Note `dest` must *not* point to a valid stream already, otherwise there +/// will be a memory leak. +inline void ArrowArrayStreamMove(struct ArrowArrayStream* src, + struct ArrowArrayStream* dest) { + assert(dest != src); + assert(!ArrowArrayStreamIsReleased(src)); + memcpy(dest, src, sizeof(struct ArrowArrayStream)); + ArrowArrayStreamMarkReleased(src); +} + +inline void ArrowDeviceArrayStreamMove(struct ArrowDeviceArrayStream* src, + struct ArrowDeviceArrayStream* dest) { + assert(dest != src); + assert(!ArrowDeviceArrayStreamIsReleased(src)); + memcpy(dest, src, sizeof(struct ArrowDeviceArrayStream)); + ArrowDeviceArrayStreamMarkReleased(src); +} + +/// Release the C array stream, if necessary, by calling its release callback +inline void ArrowArrayStreamRelease(struct ArrowArrayStream* stream) { + if (!ArrowArrayStreamIsReleased(stream)) { + stream->release(stream); + ARROW_C_ASSERT(ArrowArrayStreamIsReleased(stream), + "ArrowArrayStreamRelease did not cleanup release callback"); + } +} + +inline void ArrowDeviceArrayStreamRelease(struct ArrowDeviceArrayStream* stream) { + if (!ArrowDeviceArrayStreamIsReleased(stream)) { + stream->release(stream); + ARROW_C_ASSERT(ArrowDeviceArrayStreamIsReleased(stream), + "ArrowDeviceArrayStreamRelease did not cleanup release callback"); + } +} + +#ifdef __cplusplus +} +#endif diff --git a/pyarrow/include/arrow/chunk_resolver.h b/pyarrow/include/arrow/chunk_resolver.h new file mode 100644 index 0000000000000000000000000000000000000000..3d6458167fac979c2d6c6c112fa00194b9818092 --- /dev/null +++ b/pyarrow/include/arrow/chunk_resolver.h @@ -0,0 +1,294 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/span.h" + +namespace arrow { + +class ChunkResolver; + +template +struct ARROW_EXPORT TypedChunkLocation { + /// \brief Index of the chunk in the array of chunks + /// + /// The value is always in the range `[0, chunks.size()]`. `chunks.size()` is used + /// to represent out-of-bounds locations. + IndexType chunk_index = 0; + + /// \brief Index of the value in the chunk + /// + /// The value is UNDEFINED if `chunk_index >= chunks.size()` + IndexType index_in_chunk = 0; + + TypedChunkLocation() = default; + + TypedChunkLocation(IndexType chunk_index, IndexType index_in_chunk) + : chunk_index(chunk_index), index_in_chunk(index_in_chunk) { + static_assert(sizeof(TypedChunkLocation) == 2 * sizeof(IndexType)); + static_assert(alignof(TypedChunkLocation) == alignof(IndexType)); + } + + bool operator==(TypedChunkLocation other) const { + return chunk_index == other.chunk_index && index_in_chunk == other.index_in_chunk; + } +}; + +using ChunkLocation = TypedChunkLocation; + +/// \brief An utility that incrementally resolves logical indices into +/// physical indices in a chunked array. +class ARROW_EXPORT ChunkResolver { + private: + /// \brief Array containing `chunks.size() + 1` offsets. + /// + /// `offsets_[i]` is the starting logical index of chunk `i`. `offsets_[0]` is always 0 + /// and `offsets_[chunks.size()]` is the logical length of the chunked array. + std::vector offsets_; + + /// \brief Cache of the index of the last resolved chunk. + /// + /// \invariant `cached_chunk_ in [0, chunks.size()]` + mutable std::atomic cached_chunk_; + + public: + explicit ChunkResolver(const ArrayVector& chunks) noexcept; + explicit ChunkResolver(util::span chunks) noexcept; + explicit ChunkResolver(const RecordBatchVector& batches) noexcept; + + /// \brief Construct a ChunkResolver from a vector of chunks.size() + 1 offsets. + /// + /// The first offset must be 0 and the last offset must be the logical length of the + /// chunked array. Each offset before the last represents the starting logical index of + /// the corresponding chunk. + explicit ChunkResolver(std::vector offsets) noexcept + : offsets_(std::move(offsets)), cached_chunk_(0) { +#ifndef NDEBUG + assert(offsets_.size() >= 1); + assert(offsets_[0] == 0); + for (size_t i = 1; i < offsets_.size(); i++) { + assert(offsets_[i] >= offsets_[i - 1]); + } + assert(offsets_.size() - 1 <= + static_cast(std::numeric_limits::max())); +#endif + } + + ChunkResolver(ChunkResolver&& other) noexcept; + ChunkResolver& operator=(ChunkResolver&& other) noexcept; + + ChunkResolver(const ChunkResolver& other) noexcept; + ChunkResolver& operator=(const ChunkResolver& other) noexcept; + + int64_t logical_array_length() const { return offsets_.back(); } + int32_t num_chunks() const { return static_cast(offsets_.size() - 1); } + + int64_t chunk_length(int64_t chunk_index) const { + return offsets_[chunk_index + 1] - offsets_[chunk_index]; + } + + /// \brief Resolve a logical index to a ChunkLocation. + /// + /// The returned ChunkLocation contains the chunk index and the within-chunk index + /// equivalent to the logical index. + /// + /// \pre `index >= 0` + /// \post `location.chunk_index` in `[0, chunks.size()]` + /// \param index The logical index to resolve + /// \return ChunkLocation with a valid chunk_index if index is within + /// bounds, or with `chunk_index == chunks.size()` if logical index is + /// `>= chunked_array.length()`. + inline ChunkLocation Resolve(int64_t index) const { + const auto cached_chunk = cached_chunk_.load(std::memory_order_relaxed); + const auto chunk_index = + ResolveChunkIndex(index, cached_chunk); + return ChunkLocation{chunk_index, index - offsets_[chunk_index]}; + } + + /// \brief Resolve a logical index to a ChunkLocation. + /// + /// The returned ChunkLocation contains the chunk index and the within-chunk index + /// equivalent to the logical index. + /// + /// \pre `index >= 0` + /// \post `location.chunk_index` in `[0, chunks.size()]` + /// \param index The logical index to resolve + /// \param hint ChunkLocation{} or the last ChunkLocation returned by + /// this ChunkResolver. + /// \return ChunkLocation with a valid chunk_index if index is within + /// bounds, or with `chunk_index == chunks.size()` if logical index is + /// `>= chunked_array.length()`. + inline ChunkLocation ResolveWithHint(int64_t index, ChunkLocation hint) const { + assert(hint.chunk_index < static_cast(offsets_.size())); + const auto chunk_index = ResolveChunkIndex( + index, static_cast(hint.chunk_index)); + return ChunkLocation{chunk_index, index - offsets_[chunk_index]}; + } + + /// \brief Resolve `n_indices` logical indices to chunk indices. + /// + /// \pre 0 <= logical_index_vec[i] < logical_array_length() + /// (for well-defined and valid chunk index results) + /// \pre out_chunk_location_vec has space for `n_indices` locations + /// \pre chunk_hint in [0, chunks.size()] + /// \post out_chunk_location_vec[i].chunk_index in [0, chunks.size()] for i in [0, n) + /// \post if logical_index_vec[i] >= chunked_array.length(), then + /// out_chunk_location_vec[i].chunk_index == chunks.size() + /// and out_chunk_location_vec[i].index_in_chunk is UNDEFINED (can be + /// out-of-bounds) + /// \post if logical_index_vec[i] < 0, then both values in out_chunk_index_vec[i] + /// are UNDEFINED + /// + /// \param n_indices The number of logical indices to resolve + /// \param logical_index_vec The logical indices to resolve + /// \param out_chunk_location_vec The output array where the locations will be written + /// \param chunk_hint 0 or the last chunk_index produced by ResolveMany + /// \return false iff chunks.size() > std::numeric_limits::max() + template + [[nodiscard]] bool ResolveMany(int64_t n_indices, const IndexType* logical_index_vec, + TypedChunkLocation* out_chunk_location_vec, + IndexType chunk_hint = 0) const { + if constexpr (sizeof(IndexType) < sizeof(uint32_t)) { + // The max value returned by Bisect is `offsets.size() - 1` (= chunks.size()). + constexpr int64_t kMaxIndexTypeValue = std::numeric_limits::max(); + // A ChunkedArray with enough empty chunks can make the index of a chunk + // exceed the logical index and thus the maximum value of IndexType. + const bool chunk_index_fits_on_type = num_chunks() <= kMaxIndexTypeValue; + if (ARROW_PREDICT_FALSE(!chunk_index_fits_on_type)) { + return false; + } + // Since an index-in-chunk cannot possibly exceed the logical index being + // queried, we don't have to worry about these values not fitting on IndexType. + } + if constexpr (std::is_signed_v) { + // We interpret signed integers as unsigned and avoid having to generate double + // the amount of binary code to handle each integer width. + // + // Negative logical indices can become large values when cast to unsigned, and + // they are gracefully handled by ResolveManyImpl, but both the chunk index + // and the index in chunk values will be undefined in these cases. This + // happend because int8_t(-1) == uint8_t(255) and 255 could be a valid + // logical index in the chunked array. + using U = std::make_unsigned_t; + ResolveManyImpl(n_indices, reinterpret_cast(logical_index_vec), + reinterpret_cast*>(out_chunk_location_vec), + static_cast(chunk_hint)); + } else { + static_assert(std::is_unsigned_v); + ResolveManyImpl(n_indices, logical_index_vec, out_chunk_location_vec, + static_cast(chunk_hint)); + } + return true; + } + + private: + template + inline int64_t ResolveChunkIndex(int64_t index, int32_t cached_chunk) const { + // It is common for algorithms sequentially processing arrays to make consecutive + // accesses at a relatively small distance from each other, hence often falling in the + // same chunk. + // + // This is guaranteed when merging (assuming each side of the merge uses its + // own resolver), and is the most common case in recursive invocations of + // partitioning. + const auto num_offsets = static_cast(offsets_.size()); + const int64_t* offsets = offsets_.data(); + if (ARROW_PREDICT_TRUE(index >= offsets[cached_chunk]) && + (static_cast(cached_chunk + 1) == num_offsets || + index < offsets[cached_chunk + 1])) { + return cached_chunk; + } + // lo < hi is guaranteed by `num_offsets = chunks.size() + 1` + const auto chunk_index = Bisect(index, offsets, /*lo=*/0, /*hi=*/num_offsets); + if constexpr (StoreCachedChunk) { + assert(static_cast(chunk_index) < static_cast(offsets_.size())); + cached_chunk_.store(chunk_index, std::memory_order_relaxed); + } + return chunk_index; + } + + /// \pre all the pre-conditions of ChunkResolver::ResolveMany() + /// \pre num_offsets - 1 <= std::numeric_limits::max() + void ResolveManyImpl(int64_t, const uint8_t*, TypedChunkLocation*, + int32_t) const; + void ResolveManyImpl(int64_t, const uint16_t*, TypedChunkLocation*, + int32_t) const; + void ResolveManyImpl(int64_t, const uint32_t*, TypedChunkLocation*, + int32_t) const; + void ResolveManyImpl(int64_t, const uint64_t*, TypedChunkLocation*, + int32_t) const; + + public: + /// \brief Find the index of the chunk that contains the logical index. + /// + /// Any non-negative index is accepted. When `hi=num_offsets`, the largest + /// possible return value is `num_offsets-1` which is equal to + /// `chunks.size()`. Which is returned when the logical index is greater or + /// equal the logical length of the chunked array. + /// + /// \pre index >= 0 (otherwise, when index is negative, hi-1 is returned) + /// \pre lo < hi + /// \pre lo >= 0 && hi <= offsets_.size() + static inline int32_t Bisect(int64_t index, const int64_t* offsets, int32_t lo, + int32_t hi) { + return Bisect(static_cast(index), + reinterpret_cast(offsets), static_cast(lo), + static_cast(hi)); + } + + static inline int32_t Bisect(uint64_t index, const uint64_t* offsets, uint32_t lo, + uint32_t hi) { + // Similar to std::upper_bound(), but slightly different as our offsets + // array always starts with 0. + auto n = hi - lo; + // First iteration does not need to check for n > 1 + // (lo < hi is guaranteed by the precondition). + assert(n > 1 && "lo < hi is a precondition of Bisect"); + do { + const uint32_t m = n >> 1; + const uint32_t mid = lo + m; + if (index >= offsets[mid]) { + lo = mid; + n -= m; + } else { + n = m; + } + } while (n > 1); + return lo; + } +}; + +// Explicitly instantiate template base struct, for DLL linking on Windows +template struct TypedChunkLocation; +template struct TypedChunkLocation; +template struct TypedChunkLocation; +template struct TypedChunkLocation; +template struct TypedChunkLocation; +template struct TypedChunkLocation; +template struct TypedChunkLocation; +template struct TypedChunkLocation; +} // namespace arrow diff --git a/pyarrow/include/arrow/chunked_array.h b/pyarrow/include/arrow/chunked_array.h new file mode 100644 index 0000000000000000000000000000000000000000..02bcd0f9026bc7ba8ac9ef2daf2a3bd7ab31d56f --- /dev/null +++ b/pyarrow/include/arrow/chunked_array.h @@ -0,0 +1,283 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/chunk_resolver.h" +#include "arrow/compare.h" +#include "arrow/device_allocation_type_set.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; +class DataType; +class MemoryPool; +namespace stl { +template +class ChunkedArrayIterator; +} // namespace stl + +/// \class ChunkedArray +/// \brief A data structure managing a list of primitive Arrow arrays logically +/// as one large array +/// +/// Data chunking is treated throughout this project largely as an +/// implementation detail for performance and memory use optimization. +/// ChunkedArray allows Array objects to be collected and interpreted +/// as a single logical array without requiring an expensive concatenation +/// step. +/// +/// In some cases, data produced by a function may exceed the capacity of an +/// Array (like BinaryArray or StringArray) and so returning multiple Arrays is +/// the only possibility. In these cases, we recommend returning a ChunkedArray +/// instead of vector of Arrays or some alternative. +/// +/// When data is processed in parallel, it may not be practical or possible to +/// create large contiguous memory allocations and write output into them. With +/// some data types, like binary and string types, it is not possible at all to +/// produce non-chunked array outputs without requiring a concatenation step at +/// the end of processing. +/// +/// Application developers may tune chunk sizes based on analysis of +/// performance profiles but many developer-users will not need to be +/// especially concerned with the chunking details. +/// +/// Preserving the chunk layout/sizes in processing steps is generally not +/// considered to be a contract in APIs. A function may decide to alter the +/// chunking of its result. Similarly, APIs accepting multiple ChunkedArray +/// inputs should not expect the chunk layout to be the same in each input. +class ARROW_EXPORT ChunkedArray { + public: + ChunkedArray(ChunkedArray&&) = default; + ChunkedArray& operator=(ChunkedArray&&) = default; + + /// \brief Construct a chunked array from a single Array + explicit ChunkedArray(std::shared_ptr chunk) + : ChunkedArray(ArrayVector{std::move(chunk)}) {} + + /// \brief Construct a chunked array from a vector of arrays and an optional data type + /// + /// The vector elements must have the same data type. + /// If the data type is passed explicitly, the vector may be empty. + /// If the data type is omitted, the vector must be non-empty. + explicit ChunkedArray(ArrayVector chunks, std::shared_ptr type = NULLPTR); + + // \brief Constructor with basic input validation. + static Result> Make( + ArrayVector chunks, std::shared_ptr type = NULLPTR); + + /// \brief Create an empty ChunkedArray of a given type + /// + /// The output ChunkedArray will have one chunk with an empty + /// array of the given type. + /// + /// \param[in] type the data type of the empty ChunkedArray + /// \param[in] pool the memory pool to allocate memory from + /// \return the resulting ChunkedArray + static Result> MakeEmpty( + std::shared_ptr type, MemoryPool* pool = default_memory_pool()); + + /// \return the total length of the chunked array; computed on construction + int64_t length() const { return length_; } + + /// \return the total number of nulls among all chunks + int64_t null_count() const { return null_count_; } + + /// \return the total number of chunks in the chunked array + int num_chunks() const { return static_cast(chunks_.size()); } + + /// \return chunk a particular chunk from the chunked array + const std::shared_ptr& chunk(int i) const { return chunks_[i]; } + + /// \return an ArrayVector of chunks + const ArrayVector& chunks() const { return chunks_; } + + /// \return The set of device allocation types used by the chunks in this + /// chunked array. + DeviceAllocationTypeSet device_types() const; + + /// \return true if all chunks are allocated on CPU-accessible memory. + bool is_cpu() const { return device_types().is_cpu_only(); } + + /// \brief Construct a zero-copy slice of the chunked array with the + /// indicated offset and length + /// + /// \param[in] offset the position of the first element in the constructed + /// slice + /// \param[in] length the length of the slice. If there are not enough + /// elements in the chunked array, the length will be adjusted accordingly + /// + /// \return a new object wrapped in std::shared_ptr + std::shared_ptr Slice(int64_t offset, int64_t length) const; + + /// \brief Slice from offset until end of the chunked array + std::shared_ptr Slice(int64_t offset) const; + + /// \brief Flatten this chunked array as a vector of chunked arrays, one + /// for each struct field + /// + /// \param[in] pool The pool for buffer allocations, if any + Result>> Flatten( + MemoryPool* pool = default_memory_pool()) const; + + /// Construct a zero-copy view of this chunked array with the given + /// type. Calls Array::View on each constituent chunk. Always succeeds if + /// there are zero chunks + Result> View(const std::shared_ptr& type) const; + + /// \brief Return the type of the chunked array + const std::shared_ptr& type() const { return type_; } + + /// \brief Return a Scalar containing the value of this array at index + Result> GetScalar(int64_t index) const; + + /// \brief Determine if two chunked arrays are equal. + /// + /// Two chunked arrays can be equal only if they have equal datatypes. + /// However, they may be equal even if they have different chunkings. + bool Equals(const ChunkedArray& other, + const EqualOptions& opts = EqualOptions::Defaults()) const; + /// \brief Determine if two chunked arrays are equal. + bool Equals(const std::shared_ptr& other, + const EqualOptions& opts = EqualOptions::Defaults()) const; + /// \brief Determine if two chunked arrays approximately equal + bool ApproxEquals(const ChunkedArray& other, + const EqualOptions& = EqualOptions::Defaults()) const; + + /// \return PrettyPrint representation suitable for debugging + std::string ToString() const; + + /// \brief Perform cheap validation checks to determine obvious inconsistencies + /// within the chunk array's internal data. + /// + /// This is O(k*m) where k is the number of array descendents, + /// and m is the number of chunks. + /// + /// \return Status + Status Validate() const; + + /// \brief Perform extensive validation checks to determine inconsistencies + /// within the chunk array's internal data. + /// + /// This is O(k*n) where k is the number of array descendents, + /// and n is the length in elements. + /// + /// \return Status + Status ValidateFull() const; + + protected: + ArrayVector chunks_; + std::shared_ptr type_; + int64_t length_; + int64_t null_count_; + + private: + template + friend class ::arrow::stl::ChunkedArrayIterator; + ChunkResolver chunk_resolver_; + ARROW_DISALLOW_COPY_AND_ASSIGN(ChunkedArray); +}; + +namespace internal { + +/// \brief EXPERIMENTAL: Utility for incremental iteration over contiguous +/// pieces of potentially differently-chunked ChunkedArray objects +class ARROW_EXPORT MultipleChunkIterator { + public: + MultipleChunkIterator(const ChunkedArray& left, const ChunkedArray& right) + : left_(left), + right_(right), + pos_(0), + length_(left.length()), + chunk_idx_left_(0), + chunk_idx_right_(0), + chunk_pos_left_(0), + chunk_pos_right_(0) {} + + bool Next(std::shared_ptr* next_left, std::shared_ptr* next_right); + + int64_t position() const { return pos_; } + + private: + const ChunkedArray& left_; + const ChunkedArray& right_; + + // The amount of the entire ChunkedArray consumed + int64_t pos_; + + // Length of the chunked array(s) + int64_t length_; + + // Current left chunk + int chunk_idx_left_; + + // Current right chunk + int chunk_idx_right_; + + // Offset into the current left chunk + int64_t chunk_pos_left_; + + // Offset into the current right chunk + int64_t chunk_pos_right_; +}; + +/// \brief Evaluate binary function on two ChunkedArray objects having possibly +/// different chunk layouts. The passed binary function / functor should have +/// the following signature. +/// +/// Status(const Array&, const Array&, int64_t) +/// +/// The third argument is the absolute position relative to the start of each +/// ChunkedArray. The function is executed against each contiguous pair of +/// array segments, slicing if necessary. +/// +/// For example, if two arrays have chunk sizes +/// +/// left: [10, 10, 20] +/// right: [15, 10, 15] +/// +/// Then the following invocations take place (pseudocode) +/// +/// func(left.chunk[0][0:10], right.chunk[0][0:10], 0) +/// func(left.chunk[1][0:5], right.chunk[0][10:15], 10) +/// func(left.chunk[1][5:10], right.chunk[1][0:5], 15) +/// func(left.chunk[2][0:5], right.chunk[1][5:10], 20) +/// func(left.chunk[2][5:20], right.chunk[2][:], 25) +template +Status ApplyBinaryChunked(const ChunkedArray& left, const ChunkedArray& right, + Action&& action) { + MultipleChunkIterator iterator(left, right); + std::shared_ptr left_piece, right_piece; + while (iterator.Next(&left_piece, &right_piece)) { + ARROW_RETURN_NOT_OK(action(*left_piece, *right_piece, iterator.position())); + } + return Status::OK(); +} + +} // namespace internal +} // namespace arrow diff --git a/pyarrow/include/arrow/compare.h b/pyarrow/include/arrow/compare.h new file mode 100644 index 0000000000000000000000000000000000000000..2198495d7d20371d86aef50b8beb00541d323e74 --- /dev/null +++ b/pyarrow/include/arrow/compare.h @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Functions for comparing Arrow data structures + +#pragma once + +#include +#include + +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +struct ArrayStatistics; +class Array; +class DataType; +class Tensor; +class SparseTensor; +struct Scalar; + +static constexpr double kDefaultAbsoluteTolerance = 1E-5; + +/// A container of options for equality comparisons +class EqualOptions { + public: + /// Whether or not NaNs are considered equal. + bool nans_equal() const { return nans_equal_; } + + /// Return a new EqualOptions object with the "nans_equal" property changed. + EqualOptions nans_equal(bool v) const { + auto res = EqualOptions(*this); + res.nans_equal_ = v; + return res; + } + + /// Whether or not zeros with differing signs are considered equal. + bool signed_zeros_equal() const { return signed_zeros_equal_; } + + /// Return a new EqualOptions object with the "signed_zeros_equal" property changed. + EqualOptions signed_zeros_equal(bool v) const { + auto res = EqualOptions(*this); + res.signed_zeros_equal_ = v; + return res; + } + + /// Whether the "atol" property is used in the comparison. + /// + /// This option only affects the Equals methods + /// and has no effect on ApproxEquals methods. + bool use_atol() const { return use_atol_; } + + /// Return a new EqualOptions object with the "use_atol" property changed. + EqualOptions use_atol(bool v) const { + auto res = EqualOptions(*this); + res.use_atol_ = v; + return res; + } + + /// The absolute tolerance for approximate comparisons of floating-point values. + /// Note that this option is ignored if "use_atol" is set to false. + double atol() const { return atol_; } + + /// Return a new EqualOptions object with the "atol" property changed. + EqualOptions atol(double v) const { + auto res = EqualOptions(*this); + res.atol_ = v; + return res; + } + + /// Whether the \ref arrow::Schema property is used in the comparison. + /// + /// This option only affects the Equals methods + /// and has no effect on ApproxEquals methods. + bool use_schema() const { return use_schema_; } + + /// Return a new EqualOptions object with the "use_schema_" property changed. + /// + /// Setting this option is false making the value of \ref EqualOptions::use_metadata + /// is ignored. + EqualOptions use_schema(bool v) const { + auto res = EqualOptions(*this); + res.use_schema_ = v; + return res; + } + + /// Whether the "metadata" in \ref arrow::Schema is used in the comparison. + /// + /// This option only affects the Equals methods + /// and has no effect on the ApproxEquals methods. + /// + /// Note: This option is only considered when \ref arrow::EqualOptions::use_schema is + /// set to true. + bool use_metadata() const { return use_metadata_; } + + /// Return a new EqualOptions object with the "use_metadata" property changed. + EqualOptions use_metadata(bool v) const { + auto res = EqualOptions(*this); + res.use_metadata_ = v; + return res; + } + + /// The ostream to which a diff will be formatted if arrays disagree. + /// If this is null (the default) no diff will be formatted. + std::ostream* diff_sink() const { return diff_sink_; } + + /// Return a new EqualOptions object with the "diff_sink" property changed. + /// This option will be ignored if diff formatting of the types of compared arrays is + /// not supported. + EqualOptions diff_sink(std::ostream* diff_sink) const { + auto res = EqualOptions(*this); + res.diff_sink_ = diff_sink; + return res; + } + + static EqualOptions Defaults() { return {}; } + + protected: + double atol_ = kDefaultAbsoluteTolerance; + bool nans_equal_ = false; + bool signed_zeros_equal_ = true; + bool use_atol_ = false; + bool use_schema_ = true; + bool use_metadata_ = false; + + std::ostream* diff_sink_ = NULLPTR; +}; + +/// Returns true if the arrays are exactly equal +/// +/// Note that arrow::ArrayStatistics is not included in the comparison. +ARROW_EXPORT bool ArrayEquals(const Array& left, const Array& right, + const EqualOptions& = EqualOptions::Defaults()); + +/// Returns true if the arrays are approximately equal. For non-floating point +/// types, this is equivalent to ArrayEquals(left, right) +/// +/// Note that arrow::ArrayStatistics is not included in the comparison. +ARROW_EXPORT bool ArrayApproxEquals(const Array& left, const Array& right, + const EqualOptions& = EqualOptions::Defaults()); + +/// Returns true if indicated equal-length segment of arrays are exactly equal +/// +/// Note that arrow::ArrayStatistics is not included in the comparison. +ARROW_EXPORT bool ArrayRangeEquals(const Array& left, const Array& right, + int64_t start_idx, int64_t end_idx, + int64_t other_start_idx, + const EqualOptions& = EqualOptions::Defaults()); + +/// Returns true if indicated equal-length segment of arrays are approximately equal +/// +/// Note that arrow::ArrayStatistics is not included in the comparison. +ARROW_EXPORT bool ArrayRangeApproxEquals(const Array& left, const Array& right, + int64_t start_idx, int64_t end_idx, + int64_t other_start_idx, + const EqualOptions& = EqualOptions::Defaults()); + +ARROW_EXPORT bool TensorEquals(const Tensor& left, const Tensor& right, + const EqualOptions& = EqualOptions::Defaults()); + +/// EXPERIMENTAL: Returns true if the given sparse tensors are exactly equal +ARROW_EXPORT bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right, + const EqualOptions& = EqualOptions::Defaults()); + +/// Returns true if the type metadata are exactly equal +/// \param[in] left a DataType +/// \param[in] right a DataType +/// \param[in] check_metadata whether to compare KeyValueMetadata for child +/// fields +ARROW_EXPORT bool TypeEquals(const DataType& left, const DataType& right, + bool check_metadata = true); + +/// \brief Check two \ref arrow::ArrayStatistics for equality +/// \param[in] left an \ref arrow::ArrayStatistics +/// \param[in] right an \ref arrow::ArrayStatistics +/// \param[in] options Options used to compare double values for equality. +/// \return True if the two \ref arrow::ArrayStatistics instances are equal; otherwise, +/// false. +ARROW_EXPORT bool ArrayStatisticsEquals( + const ArrayStatistics& left, const ArrayStatistics& right, + const EqualOptions& options = EqualOptions::Defaults()); + +/// Returns true if scalars are equal +/// \param[in] left a Scalar +/// \param[in] right a Scalar +/// \param[in] options comparison options +ARROW_EXPORT bool ScalarEquals(const Scalar& left, const Scalar& right, + const EqualOptions& options = EqualOptions::Defaults()); + +/// Returns true if scalars are approximately equal +/// \param[in] left a Scalar +/// \param[in] right a Scalar +/// \param[in] options comparison options +ARROW_EXPORT bool ScalarApproxEquals( + const Scalar& left, const Scalar& right, + const EqualOptions& options = EqualOptions::Defaults()); + +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/api.h b/pyarrow/include/arrow/compute/api.h new file mode 100644 index 0000000000000000000000000000000000000000..343e30643cfd31916caafc4a84a3fd393c9a84ef --- /dev/null +++ b/pyarrow/include/arrow/compute/api.h @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +/// \defgroup compute-functions Abstract compute function API +/// @{ +/// @} + +/// \defgroup compute-concrete-options Concrete option classes for compute functions +/// @{ +/// @} + +#include "arrow/compute/api_aggregate.h" // IWYU pragma: export +#include "arrow/compute/api_scalar.h" // IWYU pragma: export +#include "arrow/compute/api_vector.h" // IWYU pragma: export +#include "arrow/compute/cast.h" // IWYU pragma: export +#include "arrow/compute/function.h" // IWYU pragma: export +#include "arrow/compute/function_options.h" // IWYU pragma: export +#include "arrow/compute/initialize.h" // IWYU pragma: export +#include "arrow/compute/kernel.h" // IWYU pragma: export +#include "arrow/compute/registry.h" // IWYU pragma: export +#include "arrow/datum.h" // IWYU pragma: export + +#include "arrow/compute/expression.h" // IWYU pragma: export + +/// \defgroup execnode-row Utilities for working with data in a row-major format +/// @{ +/// @} + +#include "arrow/compute/row/grouper.h" // IWYU pragma: export + +/// \defgroup acero-internals Acero internals, useful for those extending Acero +/// @{ +/// @} + +#include "arrow/compute/exec.h" // IWYU pragma: export diff --git a/pyarrow/include/arrow/compute/api_aggregate.h b/pyarrow/include/arrow/compute/api_aggregate.h new file mode 100644 index 0000000000000000000000000000000000000000..d31e0a73156dc8007e0fefaeabc5e9b3e60618fa --- /dev/null +++ b/pyarrow/include/arrow/compute/api_aggregate.h @@ -0,0 +1,596 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Eager evaluation convenience APIs for invoking common functions, including +// necessary memory allocations + +#pragma once + +#include + +#include "arrow/compute/function_options.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +class ExecContext; + +// ---------------------------------------------------------------------- +// Aggregate functions + +/// \addtogroup compute-concrete-options +/// @{ + +/// \brief Control general scalar aggregate kernel behavior +/// +/// By default, null values are ignored (skip_nulls = true). +class ARROW_EXPORT ScalarAggregateOptions : public FunctionOptions { + public: + explicit ScalarAggregateOptions(bool skip_nulls = true, uint32_t min_count = 1); + static constexpr const char kTypeName[] = "ScalarAggregateOptions"; + static ScalarAggregateOptions Defaults() { return ScalarAggregateOptions{}; } + + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; +}; + +/// \brief Control count aggregate kernel behavior. +/// +/// By default, only non-null values are counted. +class ARROW_EXPORT CountOptions : public FunctionOptions { + public: + enum CountMode { + /// Count only non-null values. + ONLY_VALID = 0, + /// Count only null values. + ONLY_NULL, + /// Count both non-null and null values. + ALL, + }; + explicit CountOptions(CountMode mode = CountMode::ONLY_VALID); + static constexpr const char kTypeName[] = "CountOptions"; + static CountOptions Defaults() { return CountOptions{}; } + + CountMode mode; +}; + +/// \brief Control Mode kernel behavior +/// +/// Returns top-n common values and counts. +/// By default, returns the most common value and count. +class ARROW_EXPORT ModeOptions : public FunctionOptions { + public: + explicit ModeOptions(int64_t n = 1, bool skip_nulls = true, uint32_t min_count = 0); + static constexpr const char kTypeName[] = "ModeOptions"; + static ModeOptions Defaults() { return ModeOptions{}; } + + int64_t n = 1; + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; +}; + +/// \brief Control Delta Degrees of Freedom (ddof) of Variance and Stddev kernel +/// +/// The divisor used in calculations is N - ddof, where N is the number of elements. +/// By default, ddof is zero, and population variance or stddev is returned. +class ARROW_EXPORT VarianceOptions : public FunctionOptions { + public: + explicit VarianceOptions(int ddof = 0, bool skip_nulls = true, uint32_t min_count = 0); + static constexpr const char kTypeName[] = "VarianceOptions"; + static VarianceOptions Defaults() { return VarianceOptions{}; } + + int ddof = 0; + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; +}; + +/// \brief Control Skew and Kurtosis kernel behavior +class ARROW_EXPORT SkewOptions : public FunctionOptions { + public: + explicit SkewOptions(bool skip_nulls = true, bool biased = true, + uint32_t min_count = 0); + static constexpr const char kTypeName[] = "SkewOptions"; + static SkewOptions Defaults() { return SkewOptions{}; } + + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If true (the default), the calculated value is biased. If false, the calculated + /// value includes a correction factor to reduce bias, making it more accurate for + /// small sample sizes. + bool biased; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; +}; + +/// \brief Control Quantile kernel behavior +/// +/// By default, returns the median value. +class ARROW_EXPORT QuantileOptions : public FunctionOptions { + public: + /// Interpolation method to use when quantile lies between two data points + enum Interpolation { + LINEAR = 0, + LOWER, + HIGHER, + NEAREST, + MIDPOINT, + }; + + explicit QuantileOptions(double q = 0.5, enum Interpolation interpolation = LINEAR, + bool skip_nulls = true, uint32_t min_count = 0); + + explicit QuantileOptions(std::vector q, + enum Interpolation interpolation = LINEAR, + bool skip_nulls = true, uint32_t min_count = 0); + + static constexpr const char kTypeName[] = "QuantileOptions"; + static QuantileOptions Defaults() { return QuantileOptions{}; } + + /// probability level of quantile must be between 0 and 1 inclusive + std::vector q; + enum Interpolation interpolation; + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; +}; + +/// \brief Control TDigest approximate quantile kernel behavior +/// +/// By default, returns the median value. +class ARROW_EXPORT TDigestOptions : public FunctionOptions { + public: + explicit TDigestOptions(double q = 0.5, uint32_t delta = 100, + uint32_t buffer_size = 500, bool skip_nulls = true, + uint32_t min_count = 0); + explicit TDigestOptions(std::vector q, uint32_t delta = 100, + uint32_t buffer_size = 500, bool skip_nulls = true, + uint32_t min_count = 0); + static constexpr const char kTypeName[] = "TDigestOptions"; + static TDigestOptions Defaults() { return TDigestOptions{}; } + + /// probability level of quantile must be between 0 and 1 inclusive + std::vector q; + /// compression parameter, default 100 + uint32_t delta; + /// input buffer size, default 500 + uint32_t buffer_size; + /// If true (the default), null values are ignored. Otherwise, if any value is null, + /// emit null. + bool skip_nulls; + /// If less than this many non-null values are observed, emit null. + uint32_t min_count; +}; + +/// \brief Control Pivot kernel behavior +/// +/// These options apply to the "pivot_wider" and "hash_pivot_wider" functions. +/// +/// Constraints: +/// - The corresponding `Aggregate::target` must have two FieldRef elements; +/// the first one points to the pivot key column, the second points to the +/// pivoted data column. +/// - The pivot key column can be string, binary or integer; its values will be +/// matched against `key_names` in order to dispatch the pivoted data into +/// the output. If the pivot key column is not string-like, the `key_names` +/// will be cast to the pivot key type. +/// +/// "pivot_wider" example +/// --------------------- +/// +/// Assuming the following two input columns with types utf8 and int16 (respectively): +/// ``` +/// width | 11 +/// height | 13 +/// ``` +/// and the options `PivotWiderOptions(.key_names = {"height", "width"})` +/// +/// then the output will be a scalar with the type +/// `struct{"height": int16, "width": int16}` +/// and the value `{"height": 13, "width": 11}`. +/// +/// "hash_pivot_wider" example +/// -------------------------- +/// +/// Assuming the following input with schema +/// `{"group": int32, "key": utf8, "value": int16}`: +/// ``` +/// group | key | value +/// ----------------------------- +/// 1 | height | 11 +/// 1 | width | 12 +/// 2 | width | 13 +/// 3 | height | 14 +/// 3 | depth | 15 +/// ``` +/// and the following settings: +/// - a hash grouping key "group" +/// - Aggregate( +/// .function = "hash_pivot_wider", +/// .options = PivotWiderOptions(.key_names = {"height", "width"}), +/// .target = {"key", "value"}, +/// .name = {"properties"}) +/// +/// then the output will have the schema +/// `{"group": int32, "properties": struct{"height": int16, "width": int16}}` +/// and the following value: +/// ``` +/// group | properties +/// | height | width +/// ----------------------------- +/// 1 | 11 | 12 +/// 2 | null | 13 +/// 3 | 14 | null +/// ``` +class ARROW_EXPORT PivotWiderOptions : public FunctionOptions { + public: + /// Configure the behavior of pivot keys not in `key_names` + enum UnexpectedKeyBehavior { + /// Unexpected pivot keys are ignored silently + kIgnore, + /// Unexpected pivot keys return a KeyError + kRaise + }; + + explicit PivotWiderOptions(std::vector key_names, + UnexpectedKeyBehavior unexpected_key_behavior = kIgnore); + // Default constructor for serialization + PivotWiderOptions(); + static constexpr const char kTypeName[] = "PivotWiderOptions"; + static PivotWiderOptions Defaults() { return PivotWiderOptions{}; } + + /// The values expected in the pivot key column + std::vector key_names; + /// The behavior when pivot keys not in `key_names` are encountered + UnexpectedKeyBehavior unexpected_key_behavior = kIgnore; +}; + +/// \brief Control Index kernel behavior +class ARROW_EXPORT IndexOptions : public FunctionOptions { + public: + explicit IndexOptions(std::shared_ptr value); + // Default constructor for serialization + IndexOptions(); + static constexpr const char kTypeName[] = "IndexOptions"; + + std::shared_ptr value; +}; + +/// \brief Configure a grouped aggregation +struct ARROW_EXPORT Aggregate { + Aggregate() = default; + + Aggregate(std::string function, std::shared_ptr options, + std::vector target, std::string name = "") + : function(std::move(function)), + options(std::move(options)), + target(std::move(target)), + name(std::move(name)) {} + + Aggregate(std::string function, std::shared_ptr options, + FieldRef target, std::string name = "") + : Aggregate(std::move(function), std::move(options), + std::vector{std::move(target)}, std::move(name)) {} + + Aggregate(std::string function, FieldRef target, std::string name) + : Aggregate(std::move(function), /*options=*/NULLPTR, + std::vector{std::move(target)}, std::move(name)) {} + + Aggregate(std::string function, std::string name) + : Aggregate(std::move(function), /*options=*/NULLPTR, + /*target=*/std::vector{}, std::move(name)) {} + + /// the name of the aggregation function + std::string function; + + /// options for the aggregation function + std::shared_ptr options; + + /// zero or more fields to which aggregations will be applied + std::vector target; + + /// optional output field name for aggregations + std::string name; +}; + +/// @} + +/// \brief Count values in an array. +/// +/// \param[in] options counting options, see CountOptions for more information +/// \param[in] datum to count +/// \param[in] ctx the function execution context, optional +/// \return out resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Count(const Datum& datum, + const CountOptions& options = CountOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the mean of a numeric array. +/// +/// \param[in] value datum to compute the mean, expecting Array +/// \param[in] options see ScalarAggregateOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed mean as a DoubleScalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Mean( + const Datum& value, + const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the product of values of a numeric array. +/// +/// \param[in] value datum to compute product of, expecting Array or ChunkedArray +/// \param[in] options see ScalarAggregateOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed sum as a Scalar +/// +/// \since 6.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Product( + const Datum& value, + const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Sum values of a numeric array. +/// +/// \param[in] value datum to sum, expecting Array or ChunkedArray +/// \param[in] options see ScalarAggregateOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed sum as a Scalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Sum( + const Datum& value, + const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the first value of an array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see ScalarAggregateOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed first as Scalar +/// +/// \since 13.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result First( + const Datum& value, + const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the last value of an array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see ScalarAggregateOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed last as a Scalar +/// +/// \since 13.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Last( + const Datum& value, + const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the min / max of a numeric array +/// +/// This function returns both the min and max as a struct scalar, with type +/// struct, where T is the input type +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see ScalarAggregateOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as a struct scalar +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result MinMax( + const Datum& value, + const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Test whether any element in a boolean array evaluates to true. +/// +/// This function returns true if any of the elements in the array evaluates +/// to true and false otherwise. Null values are ignored by default. +/// If null values are taken into account by setting ScalarAggregateOptions +/// parameter skip_nulls = false then Kleene logic is used. +/// See KleeneOr for more details on Kleene logic. +/// +/// \param[in] value input datum, expecting a boolean array +/// \param[in] options see ScalarAggregateOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as a BooleanScalar +/// +/// \since 3.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Any( + const Datum& value, + const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Test whether all elements in a boolean array evaluate to true. +/// +/// This function returns true if all of the elements in the array evaluate +/// to true and false otherwise. Null values are ignored by default. +/// If null values are taken into account by setting ScalarAggregateOptions +/// parameter skip_nulls = false then Kleene logic is used. +/// See KleeneAnd for more details on Kleene logic. +/// +/// \param[in] value input datum, expecting a boolean array +/// \param[in] options see ScalarAggregateOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as a BooleanScalar + +/// \since 3.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result All( + const Datum& value, + const ScalarAggregateOptions& options = ScalarAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the modal (most common) value of a numeric array +/// +/// This function returns top-n most common values and number of times they occur as +/// an array of `struct`, where T is the input type. +/// Values with larger counts are returned before smaller ones. +/// If there are more than one values with same count, smaller value is returned first. +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see ModeOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as an array of struct +/// +/// \since 2.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Mode(const Datum& value, + const ModeOptions& options = ModeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the standard deviation of a numeric array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see VarianceOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed standard deviation as a DoubleScalar +/// +/// \since 2.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Stddev(const Datum& value, + const VarianceOptions& options = VarianceOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the variance of a numeric array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see VarianceOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed variance as a DoubleScalar +/// +/// \since 2.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Variance(const Datum& value, + const VarianceOptions& options = VarianceOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the skewness of a numeric array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see SkewOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed skewness as a DoubleScalar +/// +/// \since 20.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Skew(const Datum& value, + const SkewOptions& options = SkewOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the kurtosis of a numeric array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see SkewOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return datum of the computed kurtosis as a DoubleScalar +/// +/// \since 20.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Kurtosis(const Datum& value, + const SkewOptions& options = SkewOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the quantiles of a numeric array +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see QuantileOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as an array +/// +/// \since 4.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Quantile(const Datum& value, + const QuantileOptions& options = QuantileOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Calculate the approximate quantiles of a numeric array with T-Digest algorithm +/// +/// \param[in] value input datum, expecting Array or ChunkedArray +/// \param[in] options see TDigestOptions for more information +/// \param[in] ctx the function execution context, optional +/// \return resulting datum as an array +/// +/// \since 4.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result TDigest(const Datum& value, + const TDigestOptions& options = TDigestOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Find the first index of a value in an array. +/// +/// \param[in] value The array to search. +/// \param[in] options The array to search for. See IndexOptions. +/// \param[in] ctx the function execution context, optional +/// \return out a Scalar containing the index (or -1 if not found). +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Index(const Datum& value, const IndexOptions& options, + ExecContext* ctx = NULLPTR); + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/api_scalar.h b/pyarrow/include/arrow/compute/api_scalar.h new file mode 100644 index 0000000000000000000000000000000000000000..8b341e865a1665ee18229d0b78ad1aaf2d778325 --- /dev/null +++ b/pyarrow/include/arrow/compute/api_scalar.h @@ -0,0 +1,1802 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Eager evaluation convenience APIs for invoking common functions, including +// necessary memory allocations + +#pragma once + +#include +#include +#include + +#include "arrow/compute/function_options.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \addtogroup compute-concrete-options +/// +/// @{ + +class ARROW_EXPORT ArithmeticOptions : public FunctionOptions { + public: + explicit ArithmeticOptions(bool check_overflow = false); + static constexpr const char kTypeName[] = "ArithmeticOptions"; + bool check_overflow; +}; + +class ARROW_EXPORT ElementWiseAggregateOptions : public FunctionOptions { + public: + explicit ElementWiseAggregateOptions(bool skip_nulls = true); + static constexpr const char kTypeName[] = "ElementWiseAggregateOptions"; + static ElementWiseAggregateOptions Defaults() { return ElementWiseAggregateOptions{}; } + bool skip_nulls; +}; + +/// Rounding and tie-breaking modes for round compute functions. +/// Additional details and examples are provided in compute.rst. +enum class RoundMode : int8_t { + /// Round to nearest integer less than or equal in magnitude (aka "floor") + DOWN, + /// Round to nearest integer greater than or equal in magnitude (aka "ceil") + UP, + /// Get the integral part without fractional digits (aka "trunc") + TOWARDS_ZERO, + /// Round negative values with DOWN rule + /// and positive values with UP rule (aka "away from zero") + TOWARDS_INFINITY, + /// Round ties with DOWN rule (also called "round half towards negative infinity") + HALF_DOWN, + /// Round ties with UP rule (also called "round half towards positive infinity") + HALF_UP, + /// Round ties with TOWARDS_ZERO rule (also called "round half away from infinity") + HALF_TOWARDS_ZERO, + /// Round ties with TOWARDS_INFINITY rule (also called "round half away from zero") + HALF_TOWARDS_INFINITY, + /// Round ties to nearest even integer + HALF_TO_EVEN, + /// Round ties to nearest odd integer + HALF_TO_ODD, +}; + +class ARROW_EXPORT RoundOptions : public FunctionOptions { + public: + explicit RoundOptions(int64_t ndigits = 0, + RoundMode round_mode = RoundMode::HALF_TO_EVEN); + static constexpr const char kTypeName[] = "RoundOptions"; + static RoundOptions Defaults() { return RoundOptions(); } + /// Rounding precision (number of digits to round to) + int64_t ndigits; + /// Rounding and tie-breaking mode + RoundMode round_mode; +}; + +class ARROW_EXPORT RoundBinaryOptions : public FunctionOptions { + public: + explicit RoundBinaryOptions(RoundMode round_mode = RoundMode::HALF_TO_EVEN); + static constexpr const char kTypeName[] = "RoundBinaryOptions"; + static RoundBinaryOptions Defaults() { return RoundBinaryOptions(); } + /// Rounding and tie-breaking mode + RoundMode round_mode; +}; + +enum class CalendarUnit : int8_t { + NANOSECOND, + MICROSECOND, + MILLISECOND, + SECOND, + MINUTE, + HOUR, + DAY, + WEEK, + MONTH, + QUARTER, + YEAR +}; + +class ARROW_EXPORT RoundTemporalOptions : public FunctionOptions { + public: + explicit RoundTemporalOptions(int multiple = 1, CalendarUnit unit = CalendarUnit::DAY, + bool week_starts_monday = true, + bool ceil_is_strictly_greater = false, + bool calendar_based_origin = false); + static constexpr const char kTypeName[] = "RoundTemporalOptions"; + static RoundTemporalOptions Defaults() { return RoundTemporalOptions(); } + + /// Number of units to round to + int multiple; + /// The unit used for rounding of time + CalendarUnit unit; + /// What day does the week start with (Monday=true, Sunday=false) + bool week_starts_monday; + /// Enable this flag to return a rounded value that is strictly greater than the input. + /// For example: ceiling 1970-01-01T00:00:00 to 3 hours would yield 1970-01-01T03:00:00 + /// if set to true and 1970-01-01T00:00:00 if set to false. + /// This applies for ceiling only. + bool ceil_is_strictly_greater; + /// By default time is rounded to a multiple of units since 1970-01-01T00:00:00. + /// By setting calendar_based_origin to true, time will be rounded to a number + /// of units since the last greater calendar unit. + /// For example: rounding to a multiple of days since the beginning of the month or + /// to hours since the beginning of the day. + /// Exceptions: week and quarter are not used as greater units, therefore days will + /// will be rounded to the beginning of the month not week. Greater unit of week + /// is year. + /// Note that ceiling and rounding might change sorting order of an array near greater + /// unit change. For example rounding YYYY-mm-dd 23:00:00 to 5 hours will ceil and + /// round to YYYY-mm-dd+1 01:00:00 and floor to YYYY-mm-dd 20:00:00. On the other hand + /// YYYY-mm-dd+1 00:00:00 will ceil, round and floor to YYYY-mm-dd+1 00:00:00. This + /// can break the order of an already ordered array. + bool calendar_based_origin; +}; + +class ARROW_EXPORT RoundToMultipleOptions : public FunctionOptions { + public: + explicit RoundToMultipleOptions(double multiple = 1.0, + RoundMode round_mode = RoundMode::HALF_TO_EVEN); + explicit RoundToMultipleOptions(std::shared_ptr multiple, + RoundMode round_mode = RoundMode::HALF_TO_EVEN); + static constexpr const char kTypeName[] = "RoundToMultipleOptions"; + static RoundToMultipleOptions Defaults() { return RoundToMultipleOptions(); } + /// Rounding scale (multiple to round to). + /// + /// Should be a positive numeric scalar of a type compatible with the + /// argument to be rounded. The cast kernel is used to convert the rounding + /// multiple to match the result type. + std::shared_ptr multiple; + /// Rounding and tie-breaking mode + RoundMode round_mode; +}; + +/// Options for var_args_join. +class ARROW_EXPORT JoinOptions : public FunctionOptions { + public: + /// How to handle null values. (A null separator always results in a null output.) + enum NullHandlingBehavior { + /// A null in any input results in a null in the output. + EMIT_NULL, + /// Nulls in inputs are skipped. + SKIP, + /// Nulls in inputs are replaced with the replacement string. + REPLACE, + }; + explicit JoinOptions(NullHandlingBehavior null_handling = EMIT_NULL, + std::string null_replacement = ""); + static constexpr const char kTypeName[] = "JoinOptions"; + static JoinOptions Defaults() { return JoinOptions(); } + NullHandlingBehavior null_handling; + std::string null_replacement; +}; + +class ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { + public: + explicit MatchSubstringOptions(std::string pattern, bool ignore_case = false); + MatchSubstringOptions(); + static constexpr const char kTypeName[] = "MatchSubstringOptions"; + + /// The exact substring (or regex, depending on kernel) to look for inside input values. + std::string pattern; + /// Whether to perform a case-insensitive match. + bool ignore_case; +}; + +class ARROW_EXPORT SplitOptions : public FunctionOptions { + public: + explicit SplitOptions(int64_t max_splits = -1, bool reverse = false); + static constexpr const char kTypeName[] = "SplitOptions"; + + /// Maximum number of splits allowed, or unlimited when -1 + int64_t max_splits; + /// Start splitting from the end of the string (only relevant when max_splits != -1) + bool reverse; +}; + +class ARROW_EXPORT SplitPatternOptions : public FunctionOptions { + public: + explicit SplitPatternOptions(std::string pattern, int64_t max_splits = -1, + bool reverse = false); + SplitPatternOptions(); + static constexpr const char kTypeName[] = "SplitPatternOptions"; + + /// The exact substring to split on. + std::string pattern; + /// Maximum number of splits allowed, or unlimited when -1 + int64_t max_splits; + /// Start splitting from the end of the string (only relevant when max_splits != -1) + bool reverse; +}; + +class ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { + public: + explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement); + ReplaceSliceOptions(); + static constexpr const char kTypeName[] = "ReplaceSliceOptions"; + + /// Index to start slicing at + int64_t start; + /// Index to stop slicing at + int64_t stop; + /// String to replace the slice with + std::string replacement; +}; + +class ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { + public: + explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, + int64_t max_replacements = -1); + ReplaceSubstringOptions(); + static constexpr const char kTypeName[] = "ReplaceSubstringOptions"; + + /// Pattern to match, literal, or regular expression depending on which kernel is used + std::string pattern; + /// String to replace the pattern with + std::string replacement; + /// Max number of substrings to replace (-1 means unbounded) + int64_t max_replacements; +}; + +class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { + public: + explicit ExtractRegexOptions(std::string pattern); + ExtractRegexOptions(); + static constexpr const char kTypeName[] = "ExtractRegexOptions"; + + /// Regular expression with named capture fields + std::string pattern; +}; + +class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions { + public: + explicit ExtractRegexSpanOptions(std::string pattern); + ExtractRegexSpanOptions(); + static constexpr const char kTypeName[] = "ExtractRegexSpanOptions"; + + /// Regular expression with named capture fields + std::string pattern; +}; + +/// Options for IsIn and IndexIn functions +class ARROW_EXPORT SetLookupOptions : public FunctionOptions { + public: + /// How to handle null values. + enum NullMatchingBehavior { + /// MATCH, any null in `value_set` is successfully matched in + /// the input. + MATCH, + /// SKIP, any null in `value_set` is ignored and nulls in the input + /// produce null (IndexIn) or false (IsIn) values in the output. + SKIP, + /// EMIT_NULL, any null in `value_set` is ignored and nulls in the + /// input produce null (IndexIn and IsIn) values in the output. + EMIT_NULL, + /// INCONCLUSIVE, null values are regarded as unknown values, which is + /// sql-compatible. nulls in the input produce null (IndexIn and IsIn) + /// values in the output. Besides, if `value_set` contains a null, + /// non-null unmatched values in the input also produce null values + /// (IndexIn and IsIn) in the output. + INCONCLUSIVE + }; + + explicit SetLookupOptions(Datum value_set, NullMatchingBehavior = MATCH); + SetLookupOptions(); + + // DEPRECATED(will be removed after removing of skip_nulls) + explicit SetLookupOptions(Datum value_set, bool skip_nulls); + + static constexpr const char kTypeName[] = "SetLookupOptions"; + + /// The set of values to look up input values into. + Datum value_set; + + NullMatchingBehavior null_matching_behavior; + + // DEPRECATED(will be removed after removing of skip_nulls) + NullMatchingBehavior GetNullMatchingBehavior() const; + + // DEPRECATED(use null_matching_behavior instead) + /// Whether nulls in `value_set` count for lookup. + /// + /// If true, any null in `value_set` is ignored and nulls in the input + /// produce null (IndexIn) or false (IsIn) values in the output. + /// If false, any null in `value_set` is successfully matched in + /// the input. + std::optional skip_nulls; +}; + +/// Options for struct_field function +class ARROW_EXPORT StructFieldOptions : public FunctionOptions { + public: + explicit StructFieldOptions(std::vector indices); + explicit StructFieldOptions(std::initializer_list); + explicit StructFieldOptions(FieldRef field_ref); + StructFieldOptions(); + static constexpr const char kTypeName[] = "StructFieldOptions"; + + /// The FieldRef specifying what to extract from struct or union. + FieldRef field_ref; +}; + +class ARROW_EXPORT StrptimeOptions : public FunctionOptions { + public: + explicit StrptimeOptions(std::string format, TimeUnit::type unit, + bool error_is_null = false); + StrptimeOptions(); + static constexpr const char kTypeName[] = "StrptimeOptions"; + + /// The desired format string. + std::string format; + /// The desired time resolution + TimeUnit::type unit; + /// Return null on parsing errors if true or raise if false + bool error_is_null; +}; + +class ARROW_EXPORT StrftimeOptions : public FunctionOptions { + public: + explicit StrftimeOptions(std::string format, std::string locale = "C"); + StrftimeOptions(); + + static constexpr const char kTypeName[] = "StrftimeOptions"; + + static constexpr const char* kDefaultFormat = "%Y-%m-%dT%H:%M:%S"; + + /// The desired format string. + std::string format; + /// The desired output locale string. + std::string locale; +}; + +class ARROW_EXPORT PadOptions : public FunctionOptions { + public: + explicit PadOptions(int64_t width, std::string padding = " ", + bool lean_left_on_odd_padding = true); + PadOptions(); + static constexpr const char kTypeName[] = "PadOptions"; + + /// The desired string length. + int64_t width; + /// What to pad the string with. Should be one codepoint (Unicode)/byte (ASCII). + std::string padding; + /// What to do if there is an odd number of padding characters (in case of centered + /// padding). Defaults to aligning on the left (i.e. adding the extra padding character + /// on the right) + bool lean_left_on_odd_padding = true; +}; + +class ARROW_EXPORT ZeroFillOptions : public FunctionOptions { + public: + explicit ZeroFillOptions(int64_t width, std::string padding = "0"); + ZeroFillOptions(); + static constexpr const char kTypeName[] = "ZeroFillOptions"; + + /// The desired string length. + int64_t width; + /// What to pad the string with. Should be one codepoint (Unicode). + std::string padding; +}; + +class ARROW_EXPORT TrimOptions : public FunctionOptions { + public: + explicit TrimOptions(std::string characters); + TrimOptions(); + static constexpr const char kTypeName[] = "TrimOptions"; + + /// The individual characters to be trimmed from the string. + std::string characters; +}; + +class ARROW_EXPORT SliceOptions : public FunctionOptions { + public: + explicit SliceOptions(int64_t start, int64_t stop = std::numeric_limits::max(), + int64_t step = 1); + SliceOptions(); + static constexpr const char kTypeName[] = "SliceOptions"; + int64_t start, stop, step; +}; + +class ARROW_EXPORT ListSliceOptions : public FunctionOptions { + public: + explicit ListSliceOptions(int64_t start, std::optional stop = std::nullopt, + int64_t step = 1, + std::optional return_fixed_size_list = std::nullopt); + ListSliceOptions(); + static constexpr const char kTypeName[] = "ListSliceOptions"; + /// The start of list slicing. + int64_t start; + /// Optional stop of list slicing. If not set, then slice to end. (NotImplemented) + std::optional stop; + /// Slicing step + int64_t step; + // Whether to return a FixedSizeListArray. If true _and_ stop is after + // a list element's length, nulls will be appended to create the requested slice size. + // Default of `nullopt` will return whatever type it got in. + std::optional return_fixed_size_list; +}; + +class ARROW_EXPORT NullOptions : public FunctionOptions { + public: + explicit NullOptions(bool nan_is_null = false); + static constexpr const char kTypeName[] = "NullOptions"; + static NullOptions Defaults() { return NullOptions{}; } + + bool nan_is_null; +}; + +enum CompareOperator : int8_t { + EQUAL, + NOT_EQUAL, + GREATER, + GREATER_EQUAL, + LESS, + LESS_EQUAL, +}; + +struct ARROW_EXPORT CompareOptions { + explicit CompareOptions(CompareOperator op) : op(op) {} + CompareOptions() : CompareOptions(CompareOperator::EQUAL) {} + enum CompareOperator op; +}; + +class ARROW_EXPORT MakeStructOptions : public FunctionOptions { + public: + MakeStructOptions(std::vector n, std::vector r, + std::vector> m); + explicit MakeStructOptions(std::vector n); + MakeStructOptions(); + static constexpr const char kTypeName[] = "MakeStructOptions"; + + /// Names for wrapped columns + std::vector field_names; + + /// Nullability bits for wrapped columns + std::vector field_nullability; + + /// Metadata attached to wrapped columns + std::vector> field_metadata; +}; + +struct ARROW_EXPORT DayOfWeekOptions : public FunctionOptions { + public: + explicit DayOfWeekOptions(bool count_from_zero = true, uint32_t week_start = 1); + static constexpr const char kTypeName[] = "DayOfWeekOptions"; + static DayOfWeekOptions Defaults() { return DayOfWeekOptions(); } + + /// Number days from 0 if true and from 1 if false + bool count_from_zero; + /// What day does the week start with (Monday=1, Sunday=7). + /// The numbering is unaffected by the count_from_zero parameter. + uint32_t week_start; +}; + +/// Used to control timestamp timezone conversion and handling ambiguous/nonexistent +/// times. +struct ARROW_EXPORT AssumeTimezoneOptions : public FunctionOptions { + public: + /// \brief How to interpret ambiguous local times that can be interpreted as + /// multiple instants (normally two) due to DST shifts. + /// + /// AMBIGUOUS_EARLIEST emits the earliest instant amongst possible interpretations. + /// AMBIGUOUS_LATEST emits the latest instant amongst possible interpretations. + enum Ambiguous { AMBIGUOUS_RAISE, AMBIGUOUS_EARLIEST, AMBIGUOUS_LATEST }; + + /// \brief How to handle local times that do not exist due to DST shifts. + /// + /// NONEXISTENT_EARLIEST emits the instant "just before" the DST shift instant + /// in the given timestamp precision (for example, for a nanoseconds precision + /// timestamp, this is one nanosecond before the DST shift instant). + /// NONEXISTENT_LATEST emits the DST shift instant. + enum Nonexistent { NONEXISTENT_RAISE, NONEXISTENT_EARLIEST, NONEXISTENT_LATEST }; + + explicit AssumeTimezoneOptions(std::string timezone, + Ambiguous ambiguous = AMBIGUOUS_RAISE, + Nonexistent nonexistent = NONEXISTENT_RAISE); + AssumeTimezoneOptions(); + static constexpr const char kTypeName[] = "AssumeTimezoneOptions"; + + /// Timezone to convert timestamps from + std::string timezone; + + /// How to interpret ambiguous local times (due to DST shifts) + Ambiguous ambiguous; + /// How to interpret nonexistent local times (due to DST shifts) + Nonexistent nonexistent; +}; + +struct ARROW_EXPORT WeekOptions : public FunctionOptions { + public: + explicit WeekOptions(bool week_starts_monday = true, bool count_from_zero = false, + bool first_week_is_fully_in_year = false); + static constexpr const char kTypeName[] = "WeekOptions"; + static WeekOptions Defaults() { return WeekOptions{}; } + static WeekOptions ISODefaults() { + return WeekOptions{/*week_starts_monday*/ true, + /*count_from_zero=*/false, + /*first_week_is_fully_in_year=*/false}; + } + static WeekOptions USDefaults() { + return WeekOptions{/*week_starts_monday*/ false, + /*count_from_zero=*/false, + /*first_week_is_fully_in_year=*/false}; + } + + /// What day does the week start with (Monday=true, Sunday=false) + bool week_starts_monday; + /// Dates from current year that fall into last ISO week of the previous year return + /// 0 if true and 52 or 53 if false. + bool count_from_zero; + /// Must the first week be fully in January (true), or is a week that begins on + /// December 29, 30, or 31 considered to be the first week of the new year (false)? + bool first_week_is_fully_in_year; +}; + +struct ARROW_EXPORT Utf8NormalizeOptions : public FunctionOptions { + public: + enum Form { NFC, NFKC, NFD, NFKD }; + + explicit Utf8NormalizeOptions(Form form = NFC); + static Utf8NormalizeOptions Defaults() { return Utf8NormalizeOptions(); } + static constexpr const char kTypeName[] = "Utf8NormalizeOptions"; + + /// The Unicode normalization form to apply + Form form; +}; + +class ARROW_EXPORT RandomOptions : public FunctionOptions { + public: + enum Initializer { SystemRandom, Seed }; + + static RandomOptions FromSystemRandom() { return RandomOptions{SystemRandom, 0}; } + static RandomOptions FromSeed(uint64_t seed) { return RandomOptions{Seed, seed}; } + + RandomOptions(Initializer initializer, uint64_t seed); + RandomOptions(); + static constexpr const char kTypeName[] = "RandomOptions"; + static RandomOptions Defaults() { return RandomOptions(); } + + /// The type of initialization for random number generation - system or provided seed. + Initializer initializer; + /// The seed value used to initialize the random number generation. + uint64_t seed; +}; + +/// Options for map_lookup function +class ARROW_EXPORT MapLookupOptions : public FunctionOptions { + public: + enum Occurrence { + /// Return the first matching value + FIRST, + /// Return the last matching value + LAST, + /// Return all matching values + ALL + }; + + explicit MapLookupOptions(std::shared_ptr query_key, Occurrence occurrence); + MapLookupOptions(); + + constexpr static const char kTypeName[] = "MapLookupOptions"; + + /// The key to lookup in the map + std::shared_ptr query_key; + + /// Whether to return the first, last, or all matching values + Occurrence occurrence; +}; + +/// @} + +/// \brief Get the absolute value of a value. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value transformed +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise absolute value +ARROW_EXPORT +Result AbsoluteValue(const Datum& arg, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Add two values together. Array values must be the same length. If +/// either addend is null the result will be null. +/// +/// \param[in] left the first addend +/// \param[in] right the second addend +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise sum +ARROW_EXPORT +Result Add(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Subtract two values. Array values must be the same length. If the +/// minuend or subtrahend is null the result will be null. +/// +/// \param[in] left the value subtracted from (minuend) +/// \param[in] right the value by which the minuend is reduced (subtrahend) +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise difference +ARROW_EXPORT +Result Subtract(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Multiply two values. Array values must be the same length. If either +/// factor is null the result will be null. +/// +/// \param[in] left the first factor +/// \param[in] right the second factor +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise product +ARROW_EXPORT +Result Multiply(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Divide two values. Array values must be the same length. If either +/// argument is null the result will be null. For integer types, if there is +/// a zero divisor, an error will be raised. +/// +/// \param[in] left the dividend +/// \param[in] right the divisor +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise quotient +ARROW_EXPORT +Result Divide(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Negate values. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value negated +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise negation +ARROW_EXPORT +Result Negate(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Raise the values of base array to the power of the exponent array values. +/// Array values must be the same length. If either base or exponent is null the result +/// will be null. +/// +/// \param[in] left the base +/// \param[in] right the exponent +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise base value raised to the power of exponent +ARROW_EXPORT +Result Power(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Raise Euler's number to the power of specified exponent, element-wise. +/// If the exponent value is null the result will be null. +/// +/// \param[in] arg the exponent +/// \param[in] ctx the function execution context, optional +/// \return the element-wise Euler's number raised to the power of exponent +ARROW_EXPORT +Result Exp(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief More accurately calculate `exp(arg) - 1` for values close to zero. +/// If the exponent value is null the result will be null. +/// +/// This function is more accurate than calculating `exp(value) - 1` directly for values +/// close to zero. +/// +/// \param[in] arg the exponent +/// \param[in] ctx the function execution context, optional +/// \return the element-wise Euler's number raised to the power of exponent minus 1 +ARROW_EXPORT +Result Expm1(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Left shift the left array by the right array. Array values must be the +/// same length. If either operand is null, the result will be null. +/// +/// \param[in] left the value to shift +/// \param[in] right the value to shift by +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise left value shifted left by the right value +ARROW_EXPORT +Result ShiftLeft(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Right shift the left array by the right array. Array values must be the +/// same length. If either operand is null, the result will be null. Performs a +/// logical shift for unsigned values, and an arithmetic shift for signed values. +/// +/// \param[in] left the value to shift +/// \param[in] right the value to shift by +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise left value shifted right by the right value +ARROW_EXPORT +Result ShiftRight(const Datum& left, const Datum& right, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the sine of the array values. +/// \param[in] arg The values to compute the sine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise sine of the values +ARROW_EXPORT +Result Sin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the cosine of the array values. +/// \param[in] arg The values to compute the cosine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise cosine of the values +ARROW_EXPORT +Result Cos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse sine (arcsine) of the array values. +/// \param[in] arg The values to compute the inverse sine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse sine of the values +ARROW_EXPORT +Result Asin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse cosine (arccosine) of the array values. +/// \param[in] arg The values to compute the inverse cosine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse cosine of the values +ARROW_EXPORT +Result Acos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the tangent of the array values. +/// \param[in] arg The values to compute the tangent for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise tangent of the values +ARROW_EXPORT +Result Tan(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse tangent (arctangent) of the array values. +/// \param[in] arg The values to compute the inverse tangent for. +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse tangent of the values +ARROW_EXPORT +Result Atan(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse tangent (arctangent) of y/x, using the +/// argument signs to determine the correct quadrant. +/// \param[in] y The y-values to compute the inverse tangent for. +/// \param[in] x The x-values to compute the inverse tangent for. +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse tangent of the values +ARROW_EXPORT +Result Atan2(const Datum& y, const Datum& x, ExecContext* ctx = NULLPTR); + +/// \brief Compute the hyperbolic sine of the array values. +/// \param[in] arg The values to compute the hyperbolic sine for. +/// \param[in] ctx the function execution context, optional +/// \return the elementwise hyperbolic sine of the values +ARROW_EXPORT +Result Sinh(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Compute the hyperbolic cosine of the array values. +/// \param[in] arg The values to compute the hyperbolic cosine for. +/// \param[in] ctx the function execution context, optional +/// \return the elementwise hyperbolic cosine of the values +ARROW_EXPORT +Result Cosh(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Compute the hyperbolic tangent of the array values. +/// \param[in] arg The values to compute the hyperbolic tangent for. +/// \param[in] ctx the function execution context, optional +/// \return the elementwise hyperbolic tangent of the values +ARROW_EXPORT +Result Tanh(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse hyperbolic sine of the array values. +/// \param[in] arg The values to compute the inverse hyperbolic sine for. +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse hyperbolic sine of the values +ARROW_EXPORT +Result Asinh(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse hyperbolic cosine of the array values. +/// \param[in] arg The values to compute the inverse hyperbolic cosine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse hyperbolic cosine of the values +ARROW_EXPORT +Result Acosh(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse hyperbolic tangent of the array values. +/// \param[in] arg The values to compute the inverse hyperbolic tangent for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse hyperbolic tangent of the values +ARROW_EXPORT +Result Atanh(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Get the natural log of a value. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg The values to compute the logarithm for. +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise natural log +ARROW_EXPORT +Result Ln(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Get the log base 10 of a value. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg The values to compute the logarithm for. +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise log base 10 +ARROW_EXPORT +Result Log10(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Get the log base 2 of a value. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg The values to compute the logarithm for. +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise log base 2 +ARROW_EXPORT +Result Log2(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Get the natural log of (1 + value). +/// +/// If argument is null the result will be null. +/// This function may be more accurate than Log(1 + value) for values close to zero. +/// +/// \param[in] arg The values to compute the logarithm for. +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise natural log +ARROW_EXPORT +Result Log1p(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Get the log of a value to the given base. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg The values to compute the logarithm for. +/// \param[in] base The given base. +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise log to the given base +ARROW_EXPORT +Result Logb(const Datum& arg, const Datum& base, + ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Get the square-root of a value. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg The values to compute the square-root for. +/// \param[in] options arithmetic options (overflow handling), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise square-root +ARROW_EXPORT +Result Sqrt(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Round to the nearest integer less than or equal in magnitude to the +/// argument. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value to round +/// \param[in] ctx the function execution context, optional +/// \return the rounded value +ARROW_EXPORT +Result Floor(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Round to the nearest integer greater than or equal in magnitude to the +/// argument. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value to round +/// \param[in] ctx the function execution context, optional +/// \return the rounded value +ARROW_EXPORT +Result Ceil(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Get the integral part without fractional digits. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value to truncate +/// \param[in] ctx the function execution context, optional +/// \return the truncated value +ARROW_EXPORT +Result Trunc(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Find the element-wise maximum of any number of arrays or scalars. +/// Array values must be the same length. +/// +/// \param[in] args arrays or scalars to operate on. +/// \param[in] options options for handling nulls, optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise maximum +ARROW_EXPORT +Result MaxElementWise( + const std::vector& args, + ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Find the element-wise minimum of any number of arrays or scalars. +/// Array values must be the same length. +/// +/// \param[in] args arrays or scalars to operate on. +/// \param[in] options options for handling nulls, optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise minimum +ARROW_EXPORT +Result MinElementWise( + const std::vector& args, + ElementWiseAggregateOptions options = ElementWiseAggregateOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Get the sign of a value. Array values can be of arbitrary length. If argument +/// is null the result will be null. +/// +/// \param[in] arg the value to extract sign from +/// \param[in] ctx the function execution context, optional +/// \return the element-wise sign function +ARROW_EXPORT +Result Sign(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Round a value to a given precision. +/// +/// If arg is null the result will be null. +/// +/// \param[in] arg the value to be rounded +/// \param[in] options rounding options (rounding mode and number of digits), optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +ARROW_EXPORT +Result Round(const Datum& arg, RoundOptions options = RoundOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Round a value to a given precision. +/// +/// If arg1 is null the result will be null. +/// If arg2 is null then the result will be null. If arg2 is negative, then the rounding +/// place will be shifted to the left (thus -1 would correspond to rounding to the nearest +/// ten). If positive, the rounding place will shift to the right (and +1 would +/// correspond to rounding to the nearest tenth). +/// +/// \param[in] arg1 the value to be rounded +/// \param[in] arg2 the number of significant digits to round to +/// \param[in] options rounding options, optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +ARROW_EXPORT +Result RoundBinary(const Datum& arg1, const Datum& arg2, + RoundBinaryOptions options = RoundBinaryOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Round a value to a given multiple. +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the value to round +/// \param[in] options rounding options (rounding mode and multiple), optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +ARROW_EXPORT +Result RoundToMultiple( + const Datum& arg, RoundToMultipleOptions options = RoundToMultipleOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Ceil a temporal value to a given frequency +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the temporal value to ceil +/// \param[in] options temporal rounding options, optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +/// +/// \since 7.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result CeilTemporal( + const Datum& arg, RoundTemporalOptions options = RoundTemporalOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Floor a temporal value to a given frequency +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the temporal value to floor +/// \param[in] options temporal rounding options, optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +/// +/// \since 7.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result FloorTemporal( + const Datum& arg, RoundTemporalOptions options = RoundTemporalOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Round a temporal value to a given frequency +/// +/// If argument is null the result will be null. +/// +/// \param[in] arg the temporal value to round +/// \param[in] options temporal rounding options, optional +/// \param[in] ctx the function execution context, optional +/// \return the element-wise rounded value +/// +/// \since 7.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result RoundTemporal( + const Datum& arg, RoundTemporalOptions options = RoundTemporalOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Invert the values of a boolean datum +/// \param[in] value datum to invert +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Invert(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND of two boolean datums which always propagates nulls +/// (null and false is null). +/// +/// \param[in] left left operand +/// \param[in] right right operand +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result And(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND of two boolean datums with a Kleene truth table +/// (null and false is false). +/// +/// \param[in] left left operand +/// \param[in] right right operand +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result KleeneAnd(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Element-wise OR of two boolean datums which always propagates nulls +/// (null and true is null). +/// +/// \param[in] left left operand +/// \param[in] right right operand +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Or(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise OR of two boolean datums with a Kleene truth table +/// (null or true is true). +/// +/// \param[in] left left operand +/// \param[in] right right operand +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result KleeneOr(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise XOR of two boolean datums +/// \param[in] left left operand +/// \param[in] right right operand +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Xor(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND NOT of two boolean datums which always propagates nulls +/// (null and not true is null). +/// +/// \param[in] left left operand +/// \param[in] right right operand +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 3.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result AndNot(const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); + +/// \brief Element-wise AND NOT of two boolean datums with a Kleene truth table +/// (false and not null is false, null and not true is false). +/// +/// \param[in] left left operand +/// \param[in] right right operand +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 3.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result KleeneAndNot(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief IsIn returns true for each element of `values` that is contained in +/// `value_set` +/// +/// Behaviour of nulls is governed by SetLookupOptions::skip_nulls. +/// +/// \param[in] values array-like input to look up in value_set +/// \param[in] options SetLookupOptions +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result IsIn(const Datum& values, const SetLookupOptions& options, + ExecContext* ctx = NULLPTR); +ARROW_EXPORT +Result IsIn(const Datum& values, const Datum& value_set, + ExecContext* ctx = NULLPTR); + +/// \brief IndexIn examines each slot in the values against a value_set array. +/// If the value is not found in value_set, null will be output. +/// If found, the index of occurrence within value_set (ignoring duplicates) +/// will be output. +/// +/// For example given values = [99, 42, 3, null] and +/// value_set = [3, 3, 99], the output will be = [2, null, 0, null] +/// +/// Behaviour of nulls is governed by SetLookupOptions::skip_nulls. +/// +/// \param[in] values array-like input +/// \param[in] options SetLookupOptions +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result IndexIn(const Datum& values, const SetLookupOptions& options, + ExecContext* ctx = NULLPTR); +ARROW_EXPORT +Result IndexIn(const Datum& values, const Datum& value_set, + ExecContext* ctx = NULLPTR); + +/// \brief IsValid returns true for each element of `values` that is not null, +/// false otherwise +/// +/// \param[in] values input to examine for validity +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result IsValid(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief IsNull returns true for each element of `values` that is null, +/// false otherwise +/// +/// \param[in] values input to examine for nullity +/// \param[in] options NullOptions +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result IsNull(const Datum& values, NullOptions options = NullOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief IsNan returns true for each element of `values` that is NaN, +/// false otherwise +/// +/// \param[in] values input to look for NaN +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 3.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result IsNan(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief IfElse returns elements chosen from `left` or `right` +/// depending on `cond`. `null` values in `cond` will be promoted to the result +/// +/// \param[in] cond `Boolean` condition Scalar/ Array +/// \param[in] left Scalar/ Array +/// \param[in] right Scalar/ Array +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result IfElse(const Datum& cond, const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief CaseWhen behaves like a switch/case or if-else if-else statement: for +/// each row, select the first value for which the corresponding condition is +/// true, or (if given) select the 'else' value, else emit null. Note that a +/// null condition is the same as false. +/// +/// \param[in] cond Conditions (Boolean) +/// \param[in] cases Values (any type), along with an optional 'else' value. +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result CaseWhen(const Datum& cond, const std::vector& cases, + ExecContext* ctx = NULLPTR); + +/// \brief Year returns year for each element of `values` +/// +/// \param[in] values input to extract year from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Year(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief IsLeapYear returns if a year is a leap year for each element of `values` +/// +/// \param[in] values input to extract leap year indicator from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result IsLeapYear(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Month returns month for each element of `values`. +/// Month is encoded as January=1, December=12 +/// +/// \param[in] values input to extract month from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Month(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Day returns day number for each element of `values` +/// +/// \param[in] values input to extract day from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Day(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief YearMonthDay returns a struct containing the Year, Month and Day value for +/// each element of `values`. +/// +/// \param[in] values input to extract (year, month, day) struct from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 7.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result YearMonthDay(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief DayOfWeek returns number of the day of the week value for each element of +/// `values`. +/// +/// By default week starts on Monday denoted by 0 and ends on Sunday denoted +/// by 6. Start day of the week (Monday=1, Sunday=7) and numbering base (0 or 1) can be +/// set using DayOfWeekOptions +/// +/// \param[in] values input to extract number of the day of the week from +/// \param[in] options for setting start of the week and day numbering +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result DayOfWeek(const Datum& values, + DayOfWeekOptions options = DayOfWeekOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief DayOfYear returns number of day of the year for each element of `values`. +/// January 1st maps to day number 1, February 1st to 32, etc. +/// +/// \param[in] values input to extract number of day of the year from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result DayOfYear(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief ISOYear returns ISO year number for each element of `values`. +/// First week of an ISO year has the majority (4 or more) of its days in January. +/// +/// \param[in] values input to extract ISO year from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result ISOYear(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief USYear returns US epidemiological year number for each element of `values`. +/// First week of US epidemiological year has the majority (4 or more) of it's +/// days in January. Last week of US epidemiological year has the year's last +/// Wednesday in it. US epidemiological week starts on Sunday. +/// +/// \param[in] values input to extract US epidemiological year from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result USYear(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief ISOWeek returns ISO week of year number for each element of `values`. +/// First ISO week has the majority (4 or more) of its days in January. +/// ISO week starts on Monday. Year can have 52 or 53 weeks. +/// Week numbering can start with 1. +/// +/// \param[in] values input to extract ISO week of year from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result ISOWeek(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief USWeek returns US week of year number for each element of `values`. +/// First US week has the majority (4 or more) of its days in January. +/// US week starts on Sunday. Year can have 52 or 53 weeks. +/// Week numbering starts with 1. +/// +/// \param[in] values input to extract US week of year from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 6.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result USWeek(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Week returns week of year number for each element of `values`. +/// First ISO week has the majority (4 or more) of its days in January. +/// Year can have 52 or 53 weeks. Week numbering can start with 0 or 1 +/// depending on DayOfWeekOptions.count_from_zero. +/// +/// \param[in] values input to extract week of year from +/// \param[in] options for setting numbering start +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 6.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result Week(const Datum& values, WeekOptions options = WeekOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief ISOCalendar returns a (ISO year, ISO week, ISO day of week) struct for +/// each element of `values`. +/// ISO week starts on Monday denoted by 1 and ends on Sunday denoted by 7. +/// +/// \param[in] values input to ISO calendar struct from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result ISOCalendar(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Quarter returns the quarter of year number for each element of `values` +/// First quarter maps to 1 and fourth quarter maps to 4. +/// +/// \param[in] values input to extract quarter of year from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result Quarter(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Hour returns hour value for each element of `values` +/// +/// \param[in] values input to extract hour from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Hour(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Minute returns minutes value for each element of `values` +/// +/// \param[in] values input to extract minutes from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Minute(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Second returns seconds value for each element of `values` +/// +/// \param[in] values input to extract seconds from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Second(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Millisecond returns number of milliseconds since the last full second +/// for each element of `values` +/// +/// \param[in] values input to extract milliseconds from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Millisecond(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Microsecond returns number of microseconds since the last full millisecond +/// for each element of `values` +/// +/// \param[in] values input to extract microseconds from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Microsecond(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Nanosecond returns number of nanoseconds since the last full millisecond +/// for each element of `values` +/// +/// \param[in] values input to extract nanoseconds from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Nanosecond(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Subsecond returns the fraction of second elapsed since last full second +/// as a float for each element of `values` +/// +/// \param[in] values input to extract subsecond from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result Subsecond(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Format timestamps according to a format string +/// +/// Return formatted time strings according to the format string +/// `StrftimeOptions::format` and to the locale specifier `Strftime::locale`. +/// +/// \param[in] values input timestamps +/// \param[in] options for setting format string and locale +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 6.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result Strftime(const Datum& values, StrftimeOptions options, + ExecContext* ctx = NULLPTR); + +/// \brief Parse timestamps according to a format string +/// +/// Return parsed timestamps according to the format string +/// `StrptimeOptions::format` at time resolution `Strftime::unit`. Parse errors are +/// raised depending on the `Strftime::error_is_null` setting. +/// +/// \param[in] values input strings +/// \param[in] options for setting format string, unit and error_is_null +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result Strptime(const Datum& values, StrptimeOptions options, + ExecContext* ctx = NULLPTR); + +/// \brief Converts timestamps from local timestamp without a timezone to a timestamp with +/// timezone, interpreting the local timestamp as being in the specified timezone for each +/// element of `values` +/// +/// \param[in] values input to convert +/// \param[in] options for setting source timezone, exception and ambiguous timestamp +/// handling. +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 6.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result AssumeTimezone(const Datum& values, + AssumeTimezoneOptions options, + ExecContext* ctx = NULLPTR); + +/// \brief IsDaylightSavings extracts if currently observing daylight savings for each +/// element of `values` +/// +/// \param[in] values input to extract daylight savings indicator from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result IsDaylightSavings(const Datum& values, + ExecContext* ctx = NULLPTR); + +/// \brief LocalTimestamp converts timestamp to timezone naive local timestamp +/// +/// \param[in] values input to convert to local time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 12.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result LocalTimestamp(const Datum& values, + ExecContext* ctx = NULLPTR); + +/// \brief Years Between finds the number of years between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result YearsBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Quarters Between finds the number of quarters between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result QuartersBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Months Between finds the number of month between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result MonthsBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Weeks Between finds the number of weeks between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result WeeksBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Month Day Nano Between finds the number of months, days, and nanoseconds +/// between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result MonthDayNanoBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief DayTime Between finds the number of days and milliseconds between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result DayTimeBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Days Between finds the number of days between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result DaysBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Hours Between finds the number of hours between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result HoursBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Minutes Between finds the number of minutes between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result MinutesBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Seconds Between finds the number of hours between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result SecondsBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Milliseconds Between finds the number of milliseconds between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result MillisecondsBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Microseconds Between finds the number of microseconds between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result MicrosecondsBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Nanoseconds Between finds the number of nanoseconds between two values +/// +/// \param[in] left input treated as the start time +/// \param[in] right input treated as the end time +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result NanosecondsBetween(const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + +/// \brief Finds either the FIRST, LAST, or ALL items with a key that matches the given +/// query key in a map. +/// +/// Returns an array of items for FIRST and LAST, and an array of list of items for ALL. +/// +/// \param[in] map to look in +/// \param[in] options to pass a query key and choose which matching keys to return +/// (FIRST, LAST or ALL) +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result MapLookup(const Datum& map, MapLookupOptions options, + ExecContext* ctx = NULLPTR); +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/api_vector.h b/pyarrow/include/arrow/compute/api_vector.h new file mode 100644 index 0000000000000000000000000000000000000000..159a787641ee5216ac2f19ec304d3d8e25303e39 --- /dev/null +++ b/pyarrow/include/arrow/compute/api_vector.h @@ -0,0 +1,835 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/compute/function_options.h" +#include "arrow/compute/ordering.h" +#include "arrow/result.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace compute { + +class ExecContext; + +/// \addtogroup compute-concrete-options +/// @{ + +class ARROW_EXPORT FilterOptions : public FunctionOptions { + public: + /// Configure the action taken when a slot of the selection mask is null + enum NullSelectionBehavior { + /// The corresponding filtered value will be removed in the output. + DROP, + /// The corresponding filtered value will be null in the output. + EMIT_NULL, + }; + + explicit FilterOptions(NullSelectionBehavior null_selection = DROP); + static constexpr const char kTypeName[] = "FilterOptions"; + static FilterOptions Defaults() { return FilterOptions(); } + + NullSelectionBehavior null_selection_behavior = DROP; +}; + +class ARROW_EXPORT TakeOptions : public FunctionOptions { + public: + explicit TakeOptions(bool boundscheck = true); + static constexpr const char kTypeName[] = "TakeOptions"; + static TakeOptions BoundsCheck() { return TakeOptions(true); } + static TakeOptions NoBoundsCheck() { return TakeOptions(false); } + static TakeOptions Defaults() { return BoundsCheck(); } + + bool boundscheck = true; +}; + +/// \brief Options for the dictionary encode function +class ARROW_EXPORT DictionaryEncodeOptions : public FunctionOptions { + public: + /// Configure how null values will be encoded + enum NullEncodingBehavior { + /// The null value will be added to the dictionary with a proper index. + ENCODE, + /// The null value will be masked in the indices array. + MASK + }; + + explicit DictionaryEncodeOptions(NullEncodingBehavior null_encoding = MASK); + static constexpr const char kTypeName[] = "DictionaryEncodeOptions"; + static DictionaryEncodeOptions Defaults() { return DictionaryEncodeOptions(); } + + NullEncodingBehavior null_encoding_behavior = MASK; +}; + +/// \brief Options for the run-end encode function +class ARROW_EXPORT RunEndEncodeOptions : public FunctionOptions { + public: + explicit RunEndEncodeOptions(std::shared_ptr run_end_type = int32()); + static constexpr const char kTypeName[] = "RunEndEncodeOptions"; + static RunEndEncodeOptions Defaults() { return RunEndEncodeOptions(); } + + std::shared_ptr run_end_type; +}; + +class ARROW_EXPORT ArraySortOptions : public FunctionOptions { + public: + explicit ArraySortOptions(SortOrder order = SortOrder::Ascending, + NullPlacement null_placement = NullPlacement::AtEnd); + static constexpr const char kTypeName[] = "ArraySortOptions"; + static ArraySortOptions Defaults() { return ArraySortOptions(); } + + /// Sorting order + SortOrder order; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement; +}; + +class ARROW_EXPORT SortOptions : public FunctionOptions { + public: + explicit SortOptions(std::vector sort_keys = {}, + NullPlacement null_placement = NullPlacement::AtEnd); + explicit SortOptions(const Ordering& ordering); + static constexpr const char kTypeName[] = "SortOptions"; + static SortOptions Defaults() { return SortOptions(); } + /// Convenience constructor to create an ordering from SortOptions + /// + /// Note: Both classes contain the exact same information. However, + /// sort_options should only be used in a "function options" context while Ordering + /// is used more generally. + Ordering AsOrdering() && { return Ordering(std::move(sort_keys), null_placement); } + Ordering AsOrdering() const& { return Ordering(sort_keys, null_placement); } + + /// Column key(s) to order by and how to order by these sort keys. + std::vector sort_keys; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement; +}; + +/// \brief SelectK options +class ARROW_EXPORT SelectKOptions : public FunctionOptions { + public: + explicit SelectKOptions(int64_t k = -1, std::vector sort_keys = {}); + static constexpr const char kTypeName[] = "SelectKOptions"; + static SelectKOptions Defaults() { return SelectKOptions(); } + + static SelectKOptions TopKDefault(int64_t k, std::vector key_names = {}) { + std::vector keys; + for (const auto& name : key_names) { + keys.emplace_back(SortKey(name, SortOrder::Descending)); + } + if (key_names.empty()) { + keys.emplace_back(SortKey("not-used", SortOrder::Descending)); + } + return SelectKOptions{k, keys}; + } + static SelectKOptions BottomKDefault(int64_t k, + std::vector key_names = {}) { + std::vector keys; + for (const auto& name : key_names) { + keys.emplace_back(SortKey(name, SortOrder::Ascending)); + } + if (key_names.empty()) { + keys.emplace_back(SortKey("not-used", SortOrder::Ascending)); + } + return SelectKOptions{k, keys}; + } + + /// The number of `k` elements to keep. + int64_t k; + /// Column key(s) to order by and how to order by these sort keys. + std::vector sort_keys; +}; + +/// \brief Rank options +class ARROW_EXPORT RankOptions : public FunctionOptions { + public: + /// Configure how ties between equal values are handled + enum Tiebreaker { + /// Ties get the smallest possible rank in sorted order. + Min, + /// Ties get the largest possible rank in sorted order. + Max, + /// Ranks are assigned in order of when ties appear in the input. + /// This ensures the ranks are a stable permutation of the input. + First, + /// The ranks span a dense [1, M] interval where M is the number + /// of distinct values in the input. + Dense + }; + + explicit RankOptions(std::vector sort_keys = {}, + NullPlacement null_placement = NullPlacement::AtEnd, + Tiebreaker tiebreaker = RankOptions::First); + /// Convenience constructor for array inputs + explicit RankOptions(SortOrder order, + NullPlacement null_placement = NullPlacement::AtEnd, + Tiebreaker tiebreaker = RankOptions::First) + : RankOptions({SortKey("", order)}, null_placement, tiebreaker) {} + + static constexpr const char kTypeName[] = "RankOptions"; + static RankOptions Defaults() { return RankOptions(); } + + /// Column key(s) to order by and how to order by these sort keys. + std::vector sort_keys; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement; + /// Tiebreaker for dealing with equal values in ranks + Tiebreaker tiebreaker; +}; + +/// \brief Quantile rank options +class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { + public: + explicit RankQuantileOptions(std::vector sort_keys = {}, + NullPlacement null_placement = NullPlacement::AtEnd); + /// Convenience constructor for array inputs + explicit RankQuantileOptions(SortOrder order, + NullPlacement null_placement = NullPlacement::AtEnd) + : RankQuantileOptions({SortKey("", order)}, null_placement) {} + + static constexpr const char kTypeName[] = "RankQuantileOptions"; + static RankQuantileOptions Defaults() { return RankQuantileOptions(); } + + /// Column key(s) to order by and how to order by these sort keys. + std::vector sort_keys; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement; +}; + +/// \brief Partitioning options for NthToIndices +class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { + public: + explicit PartitionNthOptions(int64_t pivot, + NullPlacement null_placement = NullPlacement::AtEnd); + PartitionNthOptions() : PartitionNthOptions(0) {} + static constexpr const char kTypeName[] = "PartitionNthOptions"; + + /// The index into the equivalent sorted array of the partition pivot element. + int64_t pivot; + /// Whether nulls and NaNs are partitioned at the start or at the end + NullPlacement null_placement; +}; + +class ARROW_EXPORT WinsorizeOptions : public FunctionOptions { + public: + WinsorizeOptions(double lower_limit, double upper_limit); + WinsorizeOptions() : WinsorizeOptions(0, 1) {} + static constexpr const char kTypeName[] = "WinsorizeOptions"; + + /// The quantile below which all values are replaced with the quantile's value. + /// + /// For example, if lower_limit = 0.05, then all values in the lower 5% percentile + /// will be replaced with the 5% percentile value. + double lower_limit; + + /// The quantile above which all values are replaced with the quantile's value. + /// + /// For example, if upper_limit = 0.95, then all values in the upper 95% percentile + /// will be replaced with the 95% percentile value. + double upper_limit; +}; + +/// \brief Options for cumulative functions +/// \note Also aliased as CumulativeSumOptions for backward compatibility +class ARROW_EXPORT CumulativeOptions : public FunctionOptions { + public: + explicit CumulativeOptions(bool skip_nulls = false); + explicit CumulativeOptions(double start, bool skip_nulls = false); + explicit CumulativeOptions(std::shared_ptr start, bool skip_nulls = false); + static constexpr const char kTypeName[] = "CumulativeOptions"; + static CumulativeOptions Defaults() { return CumulativeOptions(); } + + /// Optional starting value for cumulative operation computation, default depends on the + /// operation and input type. + /// - sum: 0 + /// - prod: 1 + /// - min: maximum of the input type + /// - max: minimum of the input type + /// - mean: start is ignored because it has no meaning for mean + std::optional> start; + + /// If true, nulls in the input are ignored and produce a corresponding null output. + /// When false, the first null encountered is propagated through the remaining output. + bool skip_nulls = false; +}; +using CumulativeSumOptions = CumulativeOptions; // For backward compatibility + +/// \brief Options for pairwise functions +class ARROW_EXPORT PairwiseOptions : public FunctionOptions { + public: + explicit PairwiseOptions(int64_t periods = 1); + static constexpr const char kTypeName[] = "PairwiseOptions"; + static PairwiseOptions Defaults() { return PairwiseOptions(); } + + /// Periods to shift for applying the binary operation, accepts negative values. + int64_t periods = 1; +}; + +/// \brief Options for list_flatten function +class ARROW_EXPORT ListFlattenOptions : public FunctionOptions { + public: + explicit ListFlattenOptions(bool recursive = false); + static constexpr const char kTypeName[] = "ListFlattenOptions"; + static ListFlattenOptions Defaults() { return ListFlattenOptions(); } + + /// \brief If true, the list is flattened recursively until a non-list + /// array is formed. + bool recursive = false; +}; + +/// \brief Options for inverse_permutation function +class ARROW_EXPORT InversePermutationOptions : public FunctionOptions { + public: + explicit InversePermutationOptions( + int64_t max_index = -1, + std::optional> output_type = std::nullopt); + static constexpr const char kTypeName[] = "InversePermutationOptions"; + static InversePermutationOptions Defaults() { return InversePermutationOptions(); } + + /// \brief The max value in the input indices to allow. The length of the function's + /// output will be this value plus 1. If negative, this value will be set to the length + /// of the input indices minus 1 and the length of the function's output will be the + /// length of the input indices. + int64_t max_index = -1; + /// \brief The data type for the output array of inverse permutation. Defaults to the + /// type of the input indices when `nullopt`. Must be a signed integer type. An + /// invalid error will be reported if this type is not able to store the length of the + /// input indices. + std::optional> output_type; +}; + +/// \brief Options for scatter function +class ARROW_EXPORT ScatterOptions : public FunctionOptions { + public: + explicit ScatterOptions(int64_t max_index = -1); + static constexpr const char kTypeName[] = "ScatterOptions"; + static ScatterOptions Defaults() { return ScatterOptions(); } + + /// \brief The max value in the input indices to allow. The length of the function's + /// output will be this value plus 1. If negative, this value will be set to the length + /// of the input indices minus 1 and the length of the function's output will be the + /// length of the input indices. + int64_t max_index = -1; +}; + +/// @} + +/// \brief Filter with a boolean selection filter +/// +/// The output will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will be handled +/// based on options.null_selection_behavior. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// filter = [0, 1, 1, 0, null, 1], the output will be +/// (null_selection_behavior == DROP) = ["b", "c", "f"] +/// (null_selection_behavior == EMIT_NULL) = ["b", "c", null, "f"] +/// +/// \param[in] values array to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[in] options configures null_selection_behavior +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result Filter(const Datum& values, const Datum& filter, + const FilterOptions& options = FilterOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +namespace internal { + +// These internal functions are implemented in kernels/vector_selection.cc + +/// \brief Return the number of selected indices in the boolean filter +/// +/// \param filter a plain or run-end encoded boolean array with or without nulls +/// \param null_selection how to handle nulls in the filter +ARROW_EXPORT +int64_t GetFilterOutputSize(const ArraySpan& filter, + FilterOptions::NullSelectionBehavior null_selection); + +/// \brief Compute uint64 selection indices for use with Take given a boolean +/// filter +/// +/// \param filter a plain or run-end encoded boolean array with or without nulls +/// \param null_selection how to handle nulls in the filter +ARROW_EXPORT +Result> GetTakeIndices( + const ArraySpan& filter, FilterOptions::NullSelectionBehavior null_selection, + MemoryPool* memory_pool = default_memory_pool()); + +} // namespace internal + +/// \brief ReplaceWithMask replaces each value in the array corresponding +/// to a true value in the mask with the next element from `replacements`. +/// +/// \param[in] values Array input to replace +/// \param[in] mask Array or Scalar of Boolean mask values +/// \param[in] replacements The replacement values to draw from. There must +/// be as many replacement values as true values in the mask. +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result ReplaceWithMask(const Datum& values, const Datum& mask, + const Datum& replacements, ExecContext* ctx = NULLPTR); + +/// \brief FillNullForward fill null values in forward direction +/// +/// The output array will be of the same type as the input values +/// array, with replaced null values in forward direction. +/// +/// For example given values = ["a", "b", "c", null, null, "f"], +/// the output will be = ["a", "b", "c", "c", "c", "f"] +/// +/// \param[in] values datum from which to take +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result FillNullForward(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief FillNullBackward fill null values in backward direction +/// +/// The output array will be of the same type as the input values +/// array, with replaced null values in backward direction. +/// +/// For example given values = ["a", "b", "c", null, null, "f"], +/// the output will be = ["a", "b", "c", "f", "f", "f"] +/// +/// \param[in] values datum from which to take +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result FillNullBackward(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief Take from an array of values at indices in another array +/// +/// The output array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] values datum from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result Take(const Datum& values, const Datum& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Take with Array inputs and output +ARROW_EXPORT +Result> Take(const Array& values, const Array& indices, + const TakeOptions& options = TakeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Drop Null from an array of values +/// +/// The output array will be of the same type as the input values +/// array, with elements taken from the values array without nulls. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"], +/// the output will be = ["a", "b", "c", "e", "f"] +/// +/// \param[in] values datum from which to take +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +ARROW_EXPORT +Result DropNull(const Datum& values, ExecContext* ctx = NULLPTR); + +/// \brief DropNull with Array inputs and output +ARROW_EXPORT +Result> DropNull(const Array& values, ExecContext* ctx = NULLPTR); + +/// \brief Return indices that partition an array around n-th sorted element. +/// +/// Find index of n-th(0 based) smallest value and perform indirect +/// partition of an array around that element. Output indices[0 ~ n-1] +/// holds values no greater than n-th element, and indices[n+1 ~ end] +/// holds values no less than n-th element. Elements in each partition +/// is not sorted. Nulls will be partitioned to the end of the output. +/// Output is not guaranteed to be stable. +/// +/// \param[in] values array to be partitioned +/// \param[in] n pivot array around sorted n-th element +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would partition an array +ARROW_EXPORT +Result> NthToIndices(const Array& values, int64_t n, + ExecContext* ctx = NULLPTR); + +/// \brief Return indices that partition an array around n-th sorted element. +/// +/// This overload takes a PartitionNthOptions specifying the pivot index +/// and the null handling. +/// +/// \param[in] values array to be partitioned +/// \param[in] options options including pivot index and null handling +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would partition an array +ARROW_EXPORT +Result> NthToIndices(const Array& values, + const PartitionNthOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Return indices that would select the first `k` elements. +/// +/// Perform an indirect sort of the datum, keeping only the first `k` elements. The output +/// array will contain indices such that the item indicated by the k-th index will be in +/// the position it would be if the datum were sorted by `options.sort_keys`. However, +/// indices of null values will not be part of the output. The sort is not guaranteed to +/// be stable. +/// +/// \param[in] datum datum to be partitioned +/// \param[in] options options +/// \param[in] ctx the function execution context, optional +/// \return a datum with the same schema as the input +ARROW_EXPORT +Result> SelectKUnstable(const Datum& datum, + const SelectKOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Return the indices that would sort an array. +/// +/// Perform an indirect sort of array. The output array will contain +/// indices that would sort an array, which would be the same length +/// as input. Nulls will be stably partitioned to the end of the output +/// regardless of order. +/// +/// For example given array = [null, 1, 3.3, null, 2, 5.3] and order +/// = SortOrder::DESCENDING, the output will be [5, 2, 4, 1, 0, +/// 3]. +/// +/// \param[in] array array to sort +/// \param[in] order ascending or descending +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT +Result> SortIndices(const Array& array, + SortOrder order = SortOrder::Ascending, + ExecContext* ctx = NULLPTR); + +/// \brief Return the indices that would sort an array. +/// +/// This overload takes a ArraySortOptions specifying the sort order +/// and the null handling. +/// +/// \param[in] array array to sort +/// \param[in] options options including sort order and null handling +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT +Result> SortIndices(const Array& array, + const ArraySortOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Return the indices that would sort a chunked array. +/// +/// Perform an indirect sort of chunked array. The output array will +/// contain indices that would sort a chunked array, which would be +/// the same length as input. Nulls will be stably partitioned to the +/// end of the output regardless of order. +/// +/// For example given chunked_array = [[null, 1], [3.3], [null, 2, +/// 5.3]] and order = SortOrder::DESCENDING, the output will be [5, 2, +/// 4, 1, 0, 3]. +/// +/// \param[in] chunked_array chunked array to sort +/// \param[in] order ascending or descending +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT +Result> SortIndices(const ChunkedArray& chunked_array, + SortOrder order = SortOrder::Ascending, + ExecContext* ctx = NULLPTR); + +/// \brief Return the indices that would sort a chunked array. +/// +/// This overload takes a ArraySortOptions specifying the sort order +/// and the null handling. +/// +/// \param[in] chunked_array chunked array to sort +/// \param[in] options options including sort order and null handling +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort an array +ARROW_EXPORT +Result> SortIndices(const ChunkedArray& chunked_array, + const ArraySortOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Return the indices that would sort an input in the +/// specified order. Input is one of array, chunked array record batch +/// or table. +/// +/// Perform an indirect sort of input. The output array will contain +/// indices that would sort an input, which would be the same length +/// as input. Nulls will be stably partitioned to the start or to the end +/// of the output depending on SortOrder::null_placement. +/// +/// For example given input (table) = { +/// "column1": [[null, 1], [ 3, null, 2, 1]], +/// "column2": [[ 5], [3, null, null, 5, 5]], +/// } and options = { +/// {"column1", SortOrder::Ascending}, +/// {"column2", SortOrder::Descending}, +/// }, the output will be [5, 1, 4, 2, 0, 3]. +/// +/// \param[in] datum array, chunked array, record batch or table to sort +/// \param[in] options options +/// \param[in] ctx the function execution context, optional +/// \return offsets indices that would sort a table +ARROW_EXPORT +Result> SortIndices(const Datum& datum, const SortOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Compute unique elements from an array-like object +/// +/// Note if a null occurs in the input it will NOT be included in the output. +/// +/// \param[in] datum array-like input +/// \param[in] ctx the function execution context, optional +/// \return result as Array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result> Unique(const Datum& datum, ExecContext* ctx = NULLPTR); + +// Constants for accessing the output of ValueCounts +ARROW_EXPORT extern const char kValuesFieldName[]; +ARROW_EXPORT extern const char kCountsFieldName[]; +ARROW_EXPORT extern const int32_t kValuesFieldIndex; +ARROW_EXPORT extern const int32_t kCountsFieldIndex; + +/// \brief Return counts of unique elements from an array-like object. +/// +/// Note that the counts do not include counts for nulls in the array. These can be +/// obtained separately from metadata. +/// +/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values +/// which can lead to unexpected results if the input Array has these values. +/// +/// \param[in] value array-like input +/// \param[in] ctx the function execution context, optional +/// \return counts An array of structs. +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result> ValueCounts(const Datum& value, + ExecContext* ctx = NULLPTR); + +/// \brief Dictionary-encode values in an array-like object +/// +/// Any nulls encountered in the dictionary will be handled according to the +/// specified null encoding behavior. +/// +/// For example, given values ["a", "b", null, "a", null] the output will be +/// (null_encoding == ENCODE) Indices: [0, 1, 2, 0, 2] / Dict: ["a", "b", null] +/// (null_encoding == MASK) Indices: [0, 1, null, 0, null] / Dict: ["a", "b"] +/// +/// If the input is already dictionary encoded this function is a no-op unless +/// it needs to modify the null_encoding (TODO) +/// +/// \param[in] data array-like input +/// \param[in] ctx the function execution context, optional +/// \param[in] options configures null encoding behavior +/// \return result with same shape and type as input +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result DictionaryEncode( + const Datum& data, + const DictionaryEncodeOptions& options = DictionaryEncodeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Run-end-encode values in an array-like object +/// +/// The returned run-end encoded type uses the same value type of the input and +/// run-end type defined in the options. +/// +/// \param[in] value array-like input +/// \param[in] options configures encoding behavior +/// \param[in] ctx the function execution context, optional +/// \return result with same shape but run-end encoded +/// +/// \since 12.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result RunEndEncode( + const Datum& value, + const RunEndEncodeOptions& options = RunEndEncodeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Decode a Run-End Encoded array to a plain array +/// +/// The output data type is the same as the values array type of run-end encoded +/// input. +/// +/// \param[in] value run-end-encoded input +/// \param[in] ctx the function execution context, optional +/// \return plain array resulting from decoding the run-end encoded input +/// +/// \since 12.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result RunEndDecode(const Datum& value, ExecContext* ctx = NULLPTR); + +/// \brief Compute the cumulative sum of an array-like object +/// +/// \param[in] values array-like input +/// \param[in] options configures cumulative sum behavior +/// \param[in] check_overflow whether to check for overflow, if true, return Invalid +/// status on overflow, otherwise wrap around on overflow +/// \param[in] ctx the function execution context, optional +ARROW_EXPORT +Result CumulativeSum( + const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(), + bool check_overflow = false, ExecContext* ctx = NULLPTR); + +/// \brief Compute the cumulative product of an array-like object +/// +/// \param[in] values array-like input +/// \param[in] options configures cumulative prod behavior +/// \param[in] check_overflow whether to check for overflow, if true, return Invalid +/// status on overflow, otherwise wrap around on overflow +/// \param[in] ctx the function execution context, optional +ARROW_EXPORT +Result CumulativeProd( + const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(), + bool check_overflow = false, ExecContext* ctx = NULLPTR); + +/// \brief Compute the cumulative max of an array-like object +/// +/// \param[in] values array-like input +/// \param[in] options configures cumulative max behavior +/// \param[in] ctx the function execution context, optional +ARROW_EXPORT +Result CumulativeMax( + const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the cumulative min of an array-like object +/// +/// \param[in] values array-like input +/// \param[in] options configures cumulative min behavior +/// \param[in] ctx the function execution context, optional +ARROW_EXPORT +Result CumulativeMin( + const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the cumulative mean of an array-like object +/// +/// \param[in] values array-like input +/// \param[in] options configures cumulative mean behavior, `start` is ignored +/// \param[in] ctx the function execution context, optional +ARROW_EXPORT +Result CumulativeMean( + const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Return the first order difference of an array. +/// +/// Computes the first order difference of an array, i.e. +/// output[i] = input[i] - input[i - p] if i >= p +/// output[i] = null otherwise +/// where p is the period. For example, with p = 1, +/// Diff([1, 4, 9, 10, 15]) = [null, 3, 5, 1, 5]. +/// With p = 2, +/// Diff([1, 4, 9, 10, 15]) = [null, null, 8, 6, 6] +/// p can also be negative, in which case the diff is computed in +/// the opposite direction. +/// \param[in] array array input +/// \param[in] options options, specifying overflow behavior and period +/// \param[in] check_overflow whether to return error on overflow +/// \param[in] ctx the function execution context, optional +/// \return result as array +ARROW_EXPORT +Result> PairwiseDiff(const Array& array, + const PairwiseOptions& options, + bool check_overflow = false, + ExecContext* ctx = NULLPTR); + +/// \brief Return the inverse permutation of the given indices. +/// +/// For indices[i] = x, inverse_permutation[x] = i. And inverse_permutation[x] = null if x +/// does not appear in the input indices. Indices must be in the range of [0, max_index], +/// or null, which will be ignored. If multiple indices point to the same value, the last +/// one is used. +/// +/// For example, with +/// indices = [null, 0, null, 2, 4, 1, 1] +/// the inverse permutation is +/// [1, 6, 3, null, 4, null, null] +/// if max_index = 6. +/// +/// \param[in] indices array-like indices +/// \param[in] options configures the max index and the output type +/// \param[in] ctx the function execution context, optional +/// \return the resulting inverse permutation +/// +/// \since 20.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result InversePermutation( + const Datum& indices, + const InversePermutationOptions& options = InversePermutationOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +/// \brief Scatter the values into specified positions according to the indices. +/// +/// For indices[i] = x, output[x] = values[i]. And output[x] = null if x does not appear +/// in the input indices. Indices must be in the range of [0, max_index], or null, in +/// which case the corresponding value will be ignored. If multiple indices point to the +/// same value, the last one is used. +/// +/// For example, with +/// values = [a, b, c, d, e, f, g] +/// indices = [null, 0, null, 2, 4, 1, 1] +/// the output is +/// [b, g, d, null, e, null, null] +/// if max_index = 6. +/// +/// \param[in] values datum to scatter +/// \param[in] indices array-like indices +/// \param[in] options configures the max index of to scatter +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 20.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Scatter(const Datum& values, const Datum& indices, + const ScatterOptions& options = ScatterOptions::Defaults(), + ExecContext* ctx = NULLPTR); + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/cast.h b/pyarrow/include/arrow/compute/cast.h new file mode 100644 index 0000000000000000000000000000000000000000..ec5818239acb1ab52b06945a5fb4e60f84b58b61 --- /dev/null +++ b/pyarrow/include/arrow/compute/cast.h @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/compute/function.h" +#include "arrow/compute/function_options.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +class ExecContext; + +/// \addtogroup compute-concrete-options +/// @{ + +class ARROW_EXPORT CastOptions : public FunctionOptions { + public: + explicit CastOptions(bool safe = true); + + static constexpr const char kTypeName[] = "CastOptions"; + static CastOptions Safe(TypeHolder to_type = {}) { + CastOptions safe(true); + safe.to_type = std::move(to_type); + return safe; + } + + static CastOptions Unsafe(TypeHolder to_type = {}) { + CastOptions unsafe(false); + unsafe.to_type = std::move(to_type); + return unsafe; + } + + // Type being casted to. May be passed separate to eager function + // compute::Cast + TypeHolder to_type; + + bool allow_int_overflow; + bool allow_time_truncate; + bool allow_time_overflow; + bool allow_decimal_truncate; + bool allow_float_truncate; + // Indicate if conversions from Binary/FixedSizeBinary to string must + // validate the utf8 payload. + bool allow_invalid_utf8; + + /// true if the safety options all match CastOptions::Safe + /// + /// Note, if this returns false it does not mean is_unsafe will return true + bool is_safe() const; + /// true if the safety options all match CastOptions::Unsafe + /// + /// Note, if this returns false it does not mean is_safe will return true + bool is_unsafe() const; +}; + +/// @} + +/// \brief Return true if a cast function is defined +ARROW_EXPORT +bool CanCast(const DataType& from_type, const DataType& to_type); + +// ---------------------------------------------------------------------- +// Convenience invocation APIs for a number of kernels + +/// \brief Cast from one array type to another +/// \param[in] value array to cast +/// \param[in] to_type type to cast to +/// \param[in] options casting options +/// \param[in] ctx the function execution context, optional +/// \return the resulting array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result> Cast(const Array& value, const TypeHolder& to_type, + const CastOptions& options = CastOptions::Safe(), + ExecContext* ctx = NULLPTR); + +/// \brief Cast from one array type to another +/// \param[in] value array to cast +/// \param[in] options casting options. The "to_type" field must be populated +/// \param[in] ctx the function execution context, optional +/// \return the resulting array +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Cast(const Datum& value, const CastOptions& options, + ExecContext* ctx = NULLPTR); + +/// \brief Cast from one value to another +/// \param[in] value datum to cast +/// \param[in] to_type type to cast to +/// \param[in] options casting options +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result Cast(const Datum& value, const TypeHolder& to_type, + const CastOptions& options = CastOptions::Safe(), + ExecContext* ctx = NULLPTR); + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/exec.h b/pyarrow/include/arrow/compute/exec.h new file mode 100644 index 0000000000000000000000000000000000000000..dae7e1ea686829fcf9b11bf07489d2cca8610f2b --- /dev/null +++ b/pyarrow/include/arrow/compute/exec.h @@ -0,0 +1,489 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array/data.h" +#include "arrow/compute/expression.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +// It seems like 64K might be a good default chunksize to use for execution +// based on the experience of other query processing systems. The current +// default is not to chunk contiguous arrays, though, but this may change in +// the future once parallel execution is implemented +static constexpr int64_t kDefaultExecChunksize = UINT16_MAX; + +/// \brief Context for expression-global variables and options used by +/// function evaluation +class ARROW_EXPORT ExecContext { + public: + // If no function registry passed, the default is used. + explicit ExecContext(MemoryPool* pool = default_memory_pool(), + ::arrow::internal::Executor* executor = NULLPTR, + FunctionRegistry* func_registry = NULLPTR); + + /// \brief The MemoryPool used for allocations, default is + /// default_memory_pool(). + MemoryPool* memory_pool() const { return pool_; } + + const ::arrow::internal::CpuInfo* cpu_info() const; + + /// \brief An Executor which may be used to parallelize execution. + ::arrow::internal::Executor* executor() const { return executor_; } + + /// \brief The FunctionRegistry for looking up functions by name and + /// selecting kernels for execution. Defaults to the library-global function + /// registry provided by GetFunctionRegistry. + FunctionRegistry* func_registry() const { return func_registry_; } + + // \brief Set maximum length unit of work for kernel execution. Larger + // contiguous array inputs will be split into smaller chunks, and, if + // possible and enabled, processed in parallel. The default chunksize is + // INT64_MAX, so contiguous arrays are not split. + void set_exec_chunksize(int64_t chunksize) { exec_chunksize_ = chunksize; } + + // \brief Maximum length for ExecBatch data chunks processed by + // kernels. Contiguous array inputs with longer length will be split into + // smaller chunks. + int64_t exec_chunksize() const { return exec_chunksize_; } + + /// \brief Set whether to use multiple threads for function execution. This + /// is not yet used. + void set_use_threads(bool use_threads = true) { use_threads_ = use_threads; } + + /// \brief If true, then utilize multiple threads where relevant for function + /// execution. This is not yet used. + bool use_threads() const { return use_threads_; } + + // Set the preallocation strategy for kernel execution as it relates to + // chunked execution. For chunked execution, whether via ChunkedArray inputs + // or splitting larger Array arguments into smaller pieces, contiguous + // allocation (if permitted by the kernel) will allocate one large array to + // write output into yielding it to the caller at the end. If this option is + // set to off, then preallocations will be performed independently for each + // chunk of execution + // + // TODO: At some point we might want the limit the size of contiguous + // preallocations. For example, even if the exec_chunksize is 64K or less, we + // might limit contiguous allocations to 1M records, say. + void set_preallocate_contiguous(bool preallocate) { + preallocate_contiguous_ = preallocate; + } + + /// \brief If contiguous preallocations should be used when doing chunked + /// execution as specified by exec_chunksize(). See + /// set_preallocate_contiguous() for more information. + bool preallocate_contiguous() const { return preallocate_contiguous_; } + + private: + MemoryPool* pool_; + ::arrow::internal::Executor* executor_; + FunctionRegistry* func_registry_; + int64_t exec_chunksize_ = std::numeric_limits::max(); + bool preallocate_contiguous_ = true; + bool use_threads_ = true; +}; + +// TODO: Consider standardizing on uint16 selection vectors and only use them +// when we can ensure that each value is 64K length or smaller + +/// \brief Container for an array of value selection indices that were +/// materialized from a filter. +/// +/// Columnar query engines (see e.g. [1]) have found that rather than +/// materializing filtered data, the filter can instead be converted to an +/// array of the "on" indices and then "fusing" these indices in operator +/// implementations. This is especially relevant for aggregations but also +/// applies to scalar operations. +/// +/// We are not yet using this so this is mostly a placeholder for now. +/// +/// [1]: http://cidrdb.org/cidr2005/papers/P19.pdf +class ARROW_EXPORT SelectionVector { + public: + explicit SelectionVector(std::shared_ptr data); + + explicit SelectionVector(const Array& arr); + + /// \brief Create SelectionVector from boolean mask + static Result> FromMask(const BooleanArray& arr); + + const int32_t* indices() const { return indices_; } + int32_t length() const; + + private: + std::shared_ptr data_; + const int32_t* indices_; +}; + +/// An index to represent that a batch does not belong to an ordered stream +constexpr int64_t kUnsequencedIndex = -1; + +/// \brief A unit of work for kernel execution. It contains a collection of +/// Array and Scalar values and an optional SelectionVector indicating that +/// there is an unmaterialized filter that either must be materialized, or (if +/// the kernel supports it) pushed down into the kernel implementation. +/// +/// ExecBatch is semantically similar to RecordBatch in that in a SQL context +/// it represents a collection of records, but constant "columns" are +/// represented by Scalar values rather than having to be converted into arrays +/// with repeated values. +/// +/// TODO: Datum uses arrow/util/variant.h which may be a bit heavier-weight +/// than is desirable for this class. Microbenchmarks would help determine for +/// sure. See ARROW-8928. + +/// \addtogroup acero-internals +/// @{ + +struct ARROW_EXPORT ExecBatch { + ExecBatch() = default; + ExecBatch(std::vector values, int64_t length) + : values(std::move(values)), length(length) {} + + explicit ExecBatch(const RecordBatch& batch); + + /// \brief Infer the ExecBatch length from values. + static Result InferLength(const std::vector& values); + + /// Creates an ExecBatch with length-validation. + /// + /// If any value is given, then all values must have a common length. If the given + /// length is negative, then the length of the ExecBatch is set to this common length, + /// or to 1 if no values are given. Otherwise, the given length must equal the common + /// length, if any value is given. + static Result Make(std::vector values, int64_t length = -1); + + Result> ToRecordBatch( + std::shared_ptr schema, MemoryPool* pool = default_memory_pool()) const; + + /// The values representing positional arguments to be passed to a kernel's + /// exec function for processing. + std::vector values; + + /// A deferred filter represented as an array of indices into the values. + /// + /// For example, the filter [true, true, false, true] would be represented as + /// the selection vector [0, 1, 3]. When the selection vector is set, + /// ExecBatch::length is equal to the length of this array. + std::shared_ptr selection_vector; + + /// A predicate Expression guaranteed to evaluate to true for all rows in this batch. + Expression guarantee = literal(true); + + /// The semantic length of the ExecBatch. When the values are all scalars, + /// the length should be set to 1 for non-aggregate kernels, otherwise the + /// length is taken from the array values, except when there is a selection + /// vector. When there is a selection vector set, the length of the batch is + /// the length of the selection. Aggregate kernels can have an ExecBatch + /// formed by projecting just the partition columns from a batch in which + /// case, it would have scalar rows with length greater than 1. + /// + /// If the array values are of length 0 then the length is 0 regardless of + /// whether any values are Scalar. + int64_t length = 0; + + /// \brief index of this batch in a sorted stream of batches + /// + /// This index must be strictly monotonic starting at 0 without gaps or + /// it can be set to kUnsequencedIndex if there is no meaningful order + int64_t index = kUnsequencedIndex; + + /// \brief The sum of bytes in each buffer referenced by the batch + /// + /// Note: Scalars are not counted + /// Note: Some values may referenced only part of a buffer, for + /// example, an array with an offset. The actual data + /// visible to this batch will be smaller than the total + /// buffer size in this case. + int64_t TotalBufferSize() const; + + /// \brief Return the value at the i-th index + template + inline const Datum& operator[](index_type i) const { + return values[i]; + } + + bool Equals(const ExecBatch& other) const; + + /// \brief A convenience for the number of values / arguments. + int num_values() const { return static_cast(values.size()); } + + ExecBatch Slice(int64_t offset, int64_t length) const; + + Result SelectValues(const std::vector& ids) const; + + /// \brief A convenience for returning the types from the batch. + std::vector GetTypes() const { + std::vector result; + for (const auto& value : this->values) { + result.emplace_back(value.type()); + } + return result; + } + + std::string ToString() const; +}; + +inline bool operator==(const ExecBatch& l, const ExecBatch& r) { return l.Equals(r); } +inline bool operator!=(const ExecBatch& l, const ExecBatch& r) { return !l.Equals(r); } + +ARROW_EXPORT void PrintTo(const ExecBatch&, std::ostream*); + +/// @} + +/// \defgroup compute-internals Utilities for calling functions, useful for those +/// extending the function registry +/// +/// @{ + +struct ExecValue { + ArraySpan array = {}; + const Scalar* scalar = NULLPTR; + + ExecValue(const Scalar* scalar) // NOLINT implicit conversion + : scalar(scalar) {} + + ExecValue(ArraySpan array) // NOLINT implicit conversion + : array(std::move(array)) {} + + ExecValue(const ArrayData& array) { // NOLINT implicit conversion + this->array.SetMembers(array); + } + + ExecValue() = default; + ExecValue(const ExecValue& other) = default; + ExecValue& operator=(const ExecValue& other) = default; + ExecValue(ExecValue&& other) = default; + ExecValue& operator=(ExecValue&& other) = default; + + int64_t length() const { return this->is_array() ? this->array.length : 1; } + + bool is_array() const { return this->scalar == NULLPTR; } + bool is_scalar() const { return !this->is_array(); } + + void SetArray(const ArrayData& array) { + this->array.SetMembers(array); + this->scalar = NULLPTR; + } + + void SetScalar(const Scalar* scalar) { this->scalar = scalar; } + + template + const ExactType& scalar_as() const { + return ::arrow::internal::checked_cast(*this->scalar); + } + + /// XXX: here temporarily for compatibility with datum, see + /// e.g. MakeStructExec in scalar_nested.cc + int64_t null_count() const { + if (this->is_array()) { + return this->array.GetNullCount(); + } else { + return this->scalar->is_valid ? 0 : 1; + } + } + + const DataType* type() const { + if (this->is_array()) { + return array.type; + } else { + return scalar->type.get(); + } + } +}; + +struct ARROW_EXPORT ExecResult { + // The default value of the variant is ArraySpan + std::variant> value; + + int64_t length() const { + if (this->is_array_span()) { + return this->array_span()->length; + } else { + return this->array_data()->length; + } + } + + const DataType* type() const { + if (this->is_array_span()) { + return this->array_span()->type; + } else { + return this->array_data()->type.get(); + } + } + + const ArraySpan* array_span() const { return &std::get(this->value); } + ArraySpan* array_span_mutable() { return &std::get(this->value); } + + bool is_array_span() const { return this->value.index() == 0; } + + const std::shared_ptr& array_data() const { + return std::get>(this->value); + } + ArrayData* array_data_mutable() { + return std::get>(this->value).get(); + } + + bool is_array_data() const { return this->value.index() == 1; } +}; + +/// \brief A "lightweight" column batch object which contains no +/// std::shared_ptr objects and does not have any memory ownership +/// semantics. Can represent a view onto an "owning" ExecBatch. +struct ARROW_EXPORT ExecSpan { + ExecSpan() = default; + ExecSpan(const ExecSpan& other) = default; + ExecSpan& operator=(const ExecSpan& other) = default; + ExecSpan(ExecSpan&& other) = default; + ExecSpan& operator=(ExecSpan&& other) = default; + + explicit ExecSpan(std::vector values, int64_t length) + : length(length), values(std::move(values)) {} + + explicit ExecSpan(const ExecBatch& batch) { + this->length = batch.length; + this->values.resize(batch.values.size()); + for (size_t i = 0; i < batch.values.size(); ++i) { + const Datum& in_value = batch[i]; + ExecValue* out_value = &this->values[i]; + if (in_value.is_array()) { + out_value->SetArray(*in_value.array()); + } else { + out_value->SetScalar(in_value.scalar().get()); + } + } + } + + /// \brief Return the value at the i-th index + template + inline const ExecValue& operator[](index_type i) const { + return values[i]; + } + + /// \brief A convenience for the number of values / arguments. + int num_values() const { return static_cast(values.size()); } + + std::vector GetTypes() const { + std::vector result; + for (const auto& value : this->values) { + result.emplace_back(value.type()); + } + return result; + } + + ExecBatch ToExecBatch() const { + ExecBatch result; + result.length = this->length; + for (const ExecValue& value : this->values) { + if (value.is_array()) { + result.values.push_back(value.array.ToArrayData()); + } else { + result.values.push_back(value.scalar->GetSharedPtr()); + } + } + return result; + } + + int64_t length = 0; + std::vector values; +}; + +/// \defgroup compute-call-function One-shot calls to compute functions +/// +/// @{ + +/// \brief One-shot invoker for all types of functions. +/// +/// Does kernel dispatch, argument checking, iteration of ChunkedArray inputs, +/// and wrapping of outputs. +ARROW_EXPORT +Result CallFunction(const std::string& func_name, const std::vector& args, + const FunctionOptions* options, ExecContext* ctx = NULLPTR); + +/// \brief Variant of CallFunction which uses a function's default options. +/// +/// NB: Some functions require FunctionOptions be provided. +ARROW_EXPORT +Result CallFunction(const std::string& func_name, const std::vector& args, + ExecContext* ctx = NULLPTR); + +/// \brief One-shot invoker for all types of functions. +/// +/// Does kernel dispatch, argument checking, iteration of ChunkedArray inputs, +/// and wrapping of outputs. +ARROW_EXPORT +Result CallFunction(const std::string& func_name, const ExecBatch& batch, + const FunctionOptions* options, ExecContext* ctx = NULLPTR); + +/// \brief Variant of CallFunction which uses a function's default options. +/// +/// NB: Some functions require FunctionOptions be provided. +ARROW_EXPORT +Result CallFunction(const std::string& func_name, const ExecBatch& batch, + ExecContext* ctx = NULLPTR); + +/// @} + +/// \defgroup compute-function-executor One-shot calls to obtain function executors +/// +/// @{ + +/// \brief One-shot executor provider for all types of functions. +/// +/// This function creates and initializes a `FunctionExecutor` appropriate +/// for the given function name, input types and function options. +ARROW_EXPORT +Result> GetFunctionExecutor( + const std::string& func_name, std::vector in_types, + const FunctionOptions* options = NULLPTR, FunctionRegistry* func_registry = NULLPTR); + +/// \brief One-shot executor provider for all types of functions. +/// +/// This function creates and initializes a `FunctionExecutor` appropriate +/// for the given function name, input types (taken from the Datum arguments) +/// and function options. +ARROW_EXPORT +Result> GetFunctionExecutor( + const std::string& func_name, const std::vector& args, + const FunctionOptions* options = NULLPTR, FunctionRegistry* func_registry = NULLPTR); + +/// @} + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/expression.h b/pyarrow/include/arrow/compute/expression.h new file mode 100644 index 0000000000000000000000000000000000000000..b8ce50675c8c9bb0a3a7081a23c6bd3c2002f2d1 --- /dev/null +++ b/pyarrow/include/arrow/compute/expression.h @@ -0,0 +1,295 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/datum.h" +#include "arrow/type_fwd.h" +#include "arrow/util/small_vector.h" + +namespace arrow { +namespace compute { + +/// \defgroup expression-core Expressions to describe data transformations +/// +/// @{ + +/// An unbound expression which maps a single Datum to another Datum. +/// An expression is one of +/// - A literal Datum. +/// - A reference to a single (potentially nested) field of the input Datum. +/// - A call to a compute function, with arguments specified by other Expressions. +class ARROW_EXPORT Expression { + public: + struct Call { + std::string function_name; + std::vector arguments; + std::shared_ptr options; + // Cached hash value + size_t hash; + + // post-Bind properties: + std::shared_ptr function; + const Kernel* kernel = NULLPTR; + std::shared_ptr kernel_state; + TypeHolder type; + + void ComputeHash(); + }; + + std::string ToString() const; + bool Equals(const Expression& other) const; + size_t hash() const; + struct Hash { + size_t operator()(const Expression& expr) const { return expr.hash(); } + }; + + /// Bind this expression to the given input type, looking up Kernels and field types. + /// Some expression simplification may be performed and implicit casts will be inserted. + /// Any state necessary for execution will be initialized and returned. + Result Bind(const TypeHolder& in, ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, ExecContext* = NULLPTR) const; + + // XXX someday + // Clone all KernelState in this bound expression. If any function referenced by this + // expression has mutable KernelState, it is not safe to execute or apply simplification + // passes to it (or copies of it!) from multiple threads. Cloning state produces new + // KernelStates where necessary to ensure that Expressions may be manipulated safely + // on multiple threads. + // Result CloneState() const; + // Status SetState(ExpressionState); + + /// Return true if all an expression's field references have explicit types + /// and all of its functions' kernels are looked up. + bool IsBound() const; + + /// Return true if this expression is composed only of Scalar literals, field + /// references, and calls to ScalarFunctions. + bool IsScalarExpression() const; + + /// Return true if this expression is literal and entirely null. + bool IsNullLiteral() const; + + /// Return true if this expression could evaluate to true. Will return true for any + /// unbound or non-boolean Expressions. IsSatisfiable does not (currently) do any + /// canonicalization or simplification of the expression, so even Expressions + /// which are unsatisfiable may spuriously return `true` here. This function is + /// intended for use in predicate pushdown where a filter expression is simplified + /// by a guarantee, so it assumes that trying to simplify again would be redundant. + bool IsSatisfiable() const; + + // XXX someday + // Result GetPipelines(); + + bool is_valid() const { return impl_ != NULLPTR; } + + /// Access a Call or return nullptr if this expression is not a call + const Call* call() const; + /// Access a Datum or return nullptr if this expression is not a literal + const Datum* literal() const; + /// Access a FieldRef or return nullptr if this expression is not a field_ref + const FieldRef* field_ref() const; + + /// The type to which this expression will evaluate + const DataType* type() const; + // XXX someday + // NullGeneralization::type nullable() const; + + struct Parameter { + FieldRef ref; + + // post-bind properties + TypeHolder type; + ::arrow::internal::SmallVector indices; + }; + const Parameter* parameter() const; + + Expression() = default; + explicit Expression(Call call); + explicit Expression(Datum literal); + explicit Expression(Parameter parameter); + + static bool Identical(const Expression& l, const Expression& r); + + private: + using Impl = std::variant; + std::shared_ptr impl_; +}; + +inline bool operator==(const Expression& l, const Expression& r) { return l.Equals(r); } +inline bool operator!=(const Expression& l, const Expression& r) { return !l.Equals(r); } + +ARROW_EXPORT void PrintTo(const Expression&, std::ostream*); + +// Factories + +ARROW_EXPORT +Expression literal(Datum lit); + +template +Expression literal(Arg&& arg) { + return literal(Datum(std::forward(arg))); +} + +ARROW_EXPORT +Expression field_ref(FieldRef ref); + +ARROW_EXPORT +Expression call(std::string function, std::vector arguments, + std::shared_ptr options = NULLPTR); + +template ::value>::type> +Expression call(std::string function, std::vector arguments, + Options options) { + return call(std::move(function), std::move(arguments), + std::make_shared(std::move(options))); +} + +/// Assemble a list of all fields referenced by an Expression at any depth. +ARROW_EXPORT +std::vector FieldsInExpression(const Expression&); + +/// Check if the expression references any fields. +ARROW_EXPORT +bool ExpressionHasFieldRefs(const Expression&); + +struct ARROW_EXPORT KnownFieldValues; + +/// Assemble a mapping from field references to known values. This derives known values +/// from "equal" and "is_null" Expressions referencing a field and a literal. +ARROW_EXPORT +Result ExtractKnownFieldValues( + const Expression& guaranteed_true_predicate); + +/// @} + +/// \defgroup expression-passes Functions for modification of Expressions +/// +/// @{ +/// +/// These transform bound expressions. Some transforms utilize a guarantee, which is +/// provided as an Expression which is guaranteed to evaluate to true. The +/// guaranteed_true_predicate need not be bound, but canonicalization is currently +/// deferred to producers of guarantees. For example in order to be recognized as a +/// guarantee on a field value, an Expression must be a call to "equal" with field_ref LHS +/// and literal RHS. Flipping the arguments, "is_in" with a one-long value_set, ... or +/// other semantically identical Expressions will not be recognized. + +/// Weak canonicalization which establishes guarantees for subsequent passes. Even +/// equivalent Expressions may result in different canonicalized expressions. +/// TODO this could be a strong canonicalization +ARROW_EXPORT +Result Canonicalize(Expression, ExecContext* = NULLPTR); + +/// Simplify Expressions based on literal arguments (for example, add(null, x) will always +/// be null so replace the call with a null literal). Includes early evaluation of all +/// calls whose arguments are entirely literal. +ARROW_EXPORT +Result FoldConstants(Expression); + +/// Simplify Expressions by replacing with known values of the fields which it references. +ARROW_EXPORT +Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values, + Expression); + +/// Simplify an expression by replacing subexpressions based on a guarantee: +/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is +/// used to remove redundant function calls from a filter expression or to replace a +/// reference to a constant-value field with a literal. +ARROW_EXPORT +Result SimplifyWithGuarantee(Expression, + const Expression& guaranteed_true_predicate); + +/// Replace all named field refs (e.g. "x" or "x.y") with field paths (e.g. [0] or [1,3]) +/// +/// This isn't usually needed and does not offer any simplification by itself. However, +/// it can be useful to normalize an expression to paths to make it simpler to work with. +ARROW_EXPORT Result RemoveNamedRefs(Expression expression); + +/// @} + +// Execution + +/// Create an ExecBatch suitable for passing to ExecuteScalarExpression() from a +/// RecordBatch which may have missing or incorrectly ordered columns. +/// Missing fields will be replaced with null scalars. +ARROW_EXPORT Result MakeExecBatch(const Schema& full_schema, + const Datum& partial, + Expression guarantee = literal(true)); + +/// Execute a scalar expression against the provided state and input ExecBatch. This +/// expression must be bound. +ARROW_EXPORT +Result ExecuteScalarExpression(const Expression&, const ExecBatch& input, + ExecContext* = NULLPTR); + +/// Convenience function for invoking against a RecordBatch +ARROW_EXPORT +Result ExecuteScalarExpression(const Expression&, const Schema& full_schema, + const Datum& partial_input, ExecContext* = NULLPTR); + +// Serialization + +ARROW_EXPORT +Result> Serialize(const Expression&); + +ARROW_EXPORT +Result Deserialize(std::shared_ptr); + +/// \defgroup expression-convenience Helpers for convenient expression creation +/// +/// @{ + +ARROW_EXPORT Expression project(std::vector values, + std::vector names); + +ARROW_EXPORT Expression equal(Expression lhs, Expression rhs); + +ARROW_EXPORT Expression not_equal(Expression lhs, Expression rhs); + +ARROW_EXPORT Expression less(Expression lhs, Expression rhs); + +ARROW_EXPORT Expression less_equal(Expression lhs, Expression rhs); + +ARROW_EXPORT Expression greater(Expression lhs, Expression rhs); + +ARROW_EXPORT Expression greater_equal(Expression lhs, Expression rhs); + +ARROW_EXPORT Expression is_null(Expression lhs, bool nan_is_null = false); + +ARROW_EXPORT Expression is_valid(Expression lhs); + +ARROW_EXPORT Expression and_(Expression lhs, Expression rhs); +ARROW_EXPORT Expression and_(const std::vector&); +ARROW_EXPORT Expression or_(Expression lhs, Expression rhs); +ARROW_EXPORT Expression or_(const std::vector&); +ARROW_EXPORT Expression not_(Expression operand); + +/// @} + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/function.h b/pyarrow/include/arrow/compute/function.h new file mode 100644 index 0000000000000000000000000000000000000000..399081e2a7371f7e39c7cc5da73af8f524ee9b99 --- /dev/null +++ b/pyarrow/include/arrow/compute/function.h @@ -0,0 +1,410 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle. + +#pragma once + +#include +#include +#include + +#include "arrow/compute/kernel.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/compare.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \addtogroup compute-functions +/// @{ + +/// \brief Contains the number of required arguments for the function. +/// +/// Naming conventions taken from https://en.wikipedia.org/wiki/Arity. +struct ARROW_EXPORT Arity { + /// \brief A function taking no arguments + static Arity Nullary() { return Arity(0, false); } + + /// \brief A function taking 1 argument + static Arity Unary() { return Arity(1, false); } + + /// \brief A function taking 2 arguments + static Arity Binary() { return Arity(2, false); } + + /// \brief A function taking 3 arguments + static Arity Ternary() { return Arity(3, false); } + + /// \brief A function taking a variable number of arguments + /// + /// \param[in] min_args the minimum number of arguments required when + /// invoking the function + static Arity VarArgs(int min_args = 0) { return Arity(min_args, true); } + + // NOTE: the 0-argument form (default constructor) is required for Cython + explicit Arity(int num_args = 0, bool is_varargs = false) + : num_args(num_args), is_varargs(is_varargs) {} + + /// The number of required arguments (or the minimum number for varargs + /// functions). + int num_args; + + /// If true, then the num_args is the minimum number of required arguments. + bool is_varargs = false; +}; + +struct ARROW_EXPORT FunctionDoc { + /// \brief A one-line summary of the function, using a verb. + /// + /// For example, "Add two numeric arrays or scalars". + std::string summary; + + /// \brief A detailed description of the function, meant to follow the summary. + std::string description; + + /// \brief Symbolic names (identifiers) for the function arguments. + /// + /// Some bindings may use this to generate nicer function signatures. + std::vector arg_names; + + // TODO add argument descriptions? + + /// \brief Name of the options class, if any. + std::string options_class; + + /// \brief Whether options are required for function execution + /// + /// If false, then either the function does not have an options class + /// or there is a usable default options value. + bool options_required; + + FunctionDoc() = default; + + FunctionDoc(std::string summary, std::string description, + std::vector arg_names, std::string options_class = "", + bool options_required = false) + : summary(std::move(summary)), + description(std::move(description)), + arg_names(std::move(arg_names)), + options_class(std::move(options_class)), + options_required(options_required) {} + + static const FunctionDoc& Empty(); +}; + +/// \brief An executor of a function with a preconfigured kernel +class ARROW_EXPORT FunctionExecutor { + public: + virtual ~FunctionExecutor() = default; + /// \brief Initialize or re-initialize the preconfigured kernel + /// + /// This method may be called zero or more times. Depending on how + /// the FunctionExecutor was obtained, it may already have been initialized. + virtual Status Init(const FunctionOptions* options = NULLPTR, + ExecContext* exec_ctx = NULLPTR) = 0; + /// \brief Execute the preconfigured kernel with arguments that must fit it + /// + /// The method requires the arguments be castable to the preconfigured types. + /// + /// \param[in] args Arguments to execute the function on + /// \param[in] length Length of arguments batch or -1 to default it. If the + /// function has no parameters, this determines the batch length, defaulting + /// to 0. Otherwise, if the function is scalar, this must equal the argument + /// batch's inferred length or be -1 to default to it. This is ignored for + /// vector functions. + virtual Result Execute(const std::vector& args, int64_t length = -1) = 0; +}; + +/// \brief Base class for compute functions. Function implementations contain a +/// collection of "kernels" which are implementations of the function for +/// specific argument types. Selecting a viable kernel for executing a function +/// is referred to as "dispatching". +class ARROW_EXPORT Function { + public: + /// \brief The kind of function, which indicates in what contexts it is + /// valid for use. + enum Kind { + /// A function that performs scalar data operations on whole arrays of + /// data. Can generally process Array or Scalar values. The size of the + /// output will be the same as the size (or broadcasted size, in the case + /// of mixing Array and Scalar inputs) of the input. + SCALAR, + + /// A function with array input and output whose behavior depends on the + /// values of the entire arrays passed, rather than the value of each scalar + /// value. + VECTOR, + + /// A function that computes scalar summary statistics from array input. + SCALAR_AGGREGATE, + + /// A function that computes grouped summary statistics from array input + /// and an array of group identifiers. + HASH_AGGREGATE, + + /// A function that dispatches to other functions and does not contain its + /// own kernels. + META + }; + + virtual ~Function() = default; + + /// \brief The name of the kernel. The registry enforces uniqueness of names. + const std::string& name() const { return name_; } + + /// \brief The kind of kernel, which indicates in what contexts it is valid + /// for use. + Function::Kind kind() const { return kind_; } + + /// \brief Contains the number of arguments the function requires, or if the + /// function accepts variable numbers of arguments. + const Arity& arity() const { return arity_; } + + /// \brief Return the function documentation + const FunctionDoc& doc() const { return doc_; } + + /// \brief Returns the number of registered kernels for this function. + virtual int num_kernels() const = 0; + + /// \brief Return a kernel that can execute the function given the exact + /// argument types (without implicit type casts). + /// + /// NB: This function is overridden in CastFunction. + virtual Result DispatchExact(const std::vector& types) const; + + /// \brief Return a best-match kernel that can execute the function given the argument + /// types, after implicit casts are applied. + /// + /// \param[in,out] values Argument types. An element may be modified to + /// indicate that the returned kernel only approximately matches the input + /// value descriptors; callers are responsible for casting inputs to the type + /// required by the kernel. + virtual Result DispatchBest(std::vector* values) const; + + /// \brief Get a function executor with a best-matching kernel + /// + /// The returned executor will by default work with the default FunctionOptions + /// and KernelContext. If you want to change that, call `FunctionExecutor::Init`. + virtual Result> GetBestExecutor( + std::vector inputs) const; + + /// \brief Execute the function eagerly with the passed input arguments with + /// kernel dispatch, batch iteration, and memory allocation details taken + /// care of. + /// + /// If the `options` pointer is null, then `default_options()` will be used. + /// + /// This function can be overridden in subclasses. + virtual Result Execute(const std::vector& args, + const FunctionOptions* options, ExecContext* ctx) const; + + virtual Result Execute(const ExecBatch& batch, const FunctionOptions* options, + ExecContext* ctx) const; + + /// \brief Returns the default options for this function. + /// + /// Whatever option semantics a Function has, implementations must guarantee + /// that default_options() is valid to pass to Execute as options. + const FunctionOptions* default_options() const { return default_options_; } + + virtual Status Validate() const; + + /// \brief Returns the pure property for this function. + /// + /// Impure functions are those that may return different results for the same + /// input arguments. For example, a function that returns a random number is + /// not pure. An expression containing only pure functions can be simplified by + /// pre-evaluating any sub-expressions that have constant arguments. + virtual bool is_pure() const { return true; } + + protected: + Function(std::string name, Function::Kind kind, const Arity& arity, FunctionDoc doc, + const FunctionOptions* default_options) + : name_(std::move(name)), + kind_(kind), + arity_(arity), + doc_(std::move(doc)), + default_options_(default_options) {} + + Status CheckArity(size_t num_args) const; + + std::string name_; + Function::Kind kind_; + Arity arity_; + const FunctionDoc doc_; + const FunctionOptions* default_options_ = NULLPTR; +}; + +namespace detail { + +template +class FunctionImpl : public Function { + public: + /// \brief Return pointers to current-available kernels for inspection + std::vector kernels() const { + std::vector result; + for (const auto& kernel : kernels_) { + result.push_back(&kernel); + } + return result; + } + + int num_kernels() const override { return static_cast(kernels_.size()); } + + protected: + FunctionImpl(std::string name, Function::Kind kind, const Arity& arity, FunctionDoc doc, + const FunctionOptions* default_options) + : Function(std::move(name), kind, arity, std::move(doc), default_options) {} + + std::vector kernels_; +}; + +/// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. +ARROW_EXPORT +const Kernel* DispatchExactImpl(const Function* func, const std::vector&); + +/// \brief Return an error message if no Kernel is found. +ARROW_EXPORT +Status NoMatchingKernel(const Function* func, const std::vector&); + +} // namespace detail + +/// \brief A function that executes elementwise operations on arrays or +/// scalars, and therefore whose results generally do not depend on the order +/// of the values in the arguments. Accepts and returns arrays that are all of +/// the same size. These functions roughly correspond to the functions used in +/// SQL expressions. +class ARROW_EXPORT ScalarFunction : public detail::FunctionImpl { + public: + using KernelType = ScalarKernel; + + ScalarFunction(std::string name, const Arity& arity, FunctionDoc doc, + const FunctionOptions* default_options = NULLPTR, bool is_pure = true) + : detail::FunctionImpl(std::move(name), Function::SCALAR, arity, + std::move(doc), default_options), + is_pure_(is_pure) {} + + /// \brief Add a kernel with given input/output types, no required state + /// initialization, preallocation for fixed-width types, and default null + /// handling (intersect validity bitmaps of inputs). + Status AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init = NULLPTR, + std::shared_ptr constraint = NULLPTR); + + /// \brief Add a kernel (function implementation). Returns error if the + /// kernel's signature does not match the function's arity. + Status AddKernel(ScalarKernel kernel); + + /// \brief Returns the pure property for this function. + bool is_pure() const override { return is_pure_; } + + private: + const bool is_pure_; +}; + +/// \brief A function that executes general array operations that may yield +/// outputs of different sizes or have results that depend on the whole array +/// contents. These functions roughly correspond to the functions found in +/// non-SQL array languages like APL and its derivatives. +class ARROW_EXPORT VectorFunction : public detail::FunctionImpl { + public: + using KernelType = VectorKernel; + + VectorFunction(std::string name, const Arity& arity, FunctionDoc doc, + const FunctionOptions* default_options = NULLPTR) + : detail::FunctionImpl(std::move(name), Function::VECTOR, arity, + std::move(doc), default_options) {} + + /// \brief Add a simple kernel with given input/output types, no required + /// state initialization, no data preallocation, and no preallocation of the + /// validity bitmap. + Status AddKernel(std::vector in_types, OutputType out_type, + ArrayKernelExec exec, KernelInit init = NULLPTR); + + /// \brief Add a kernel (function implementation). Returns error if the + /// kernel's signature does not match the function's arity. + Status AddKernel(VectorKernel kernel); +}; + +class ARROW_EXPORT ScalarAggregateFunction + : public detail::FunctionImpl { + public: + using KernelType = ScalarAggregateKernel; + + ScalarAggregateFunction(std::string name, const Arity& arity, FunctionDoc doc, + const FunctionOptions* default_options = NULLPTR) + : detail::FunctionImpl(std::move(name), + Function::SCALAR_AGGREGATE, arity, + std::move(doc), default_options) {} + + /// \brief Add a kernel (function implementation). Returns error if the + /// kernel's signature does not match the function's arity. + Status AddKernel(ScalarAggregateKernel kernel); +}; + +class ARROW_EXPORT HashAggregateFunction + : public detail::FunctionImpl { + public: + using KernelType = HashAggregateKernel; + + HashAggregateFunction(std::string name, const Arity& arity, FunctionDoc doc, + const FunctionOptions* default_options = NULLPTR) + : detail::FunctionImpl(std::move(name), + Function::HASH_AGGREGATE, arity, + std::move(doc), default_options) {} + + /// \brief Add a kernel (function implementation). Returns error if the + /// kernel's signature does not match the function's arity. + Status AddKernel(HashAggregateKernel kernel); +}; + +/// \brief A function that dispatches to other functions. Must implement +/// MetaFunction::ExecuteImpl. +/// +/// For Array, ChunkedArray, and Scalar Datum kinds, may rely on the execution +/// of concrete Function types, but must handle other Datum kinds on its own. +class ARROW_EXPORT MetaFunction : public Function { + public: + int num_kernels() const override { return 0; } + + Result Execute(const std::vector& args, const FunctionOptions* options, + ExecContext* ctx) const override; + + Result Execute(const ExecBatch& batch, const FunctionOptions* options, + ExecContext* ctx) const override; + + protected: + virtual Result ExecuteImpl(const std::vector& args, + const FunctionOptions* options, + ExecContext* ctx) const = 0; + + MetaFunction(std::string name, const Arity& arity, FunctionDoc doc, + const FunctionOptions* default_options = NULLPTR) + : Function(std::move(name), Function::META, arity, std::move(doc), + default_options) {} +}; + +/// @} + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/function_options.h b/pyarrow/include/arrow/compute/function_options.h new file mode 100644 index 0000000000000000000000000000000000000000..88ec2fd2d0679b5c849549179aa652bec9b37b56 --- /dev/null +++ b/pyarrow/include/arrow/compute/function_options.h @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle. + +#pragma once + +#include "arrow/compute/type_fwd.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \addtogroup compute-functions +/// @{ + +/// \brief Extension point for defining options outside libarrow (but +/// still within this project). +class ARROW_EXPORT FunctionOptionsType { + public: + virtual ~FunctionOptionsType() = default; + + virtual const char* type_name() const = 0; + virtual std::string Stringify(const FunctionOptions&) const = 0; + virtual bool Compare(const FunctionOptions&, const FunctionOptions&) const = 0; + virtual Result> Serialize(const FunctionOptions&) const; + virtual Result> Deserialize( + const Buffer& buffer) const; + virtual std::unique_ptr Copy(const FunctionOptions&) const = 0; +}; + +/// \brief Base class for specifying options configuring a function's behavior, +/// such as error handling. +class ARROW_EXPORT FunctionOptions : public util::EqualityComparable { + public: + virtual ~FunctionOptions() = default; + + const FunctionOptionsType* options_type() const { return options_type_; } + const char* type_name() const { return options_type()->type_name(); } + + bool Equals(const FunctionOptions& other) const; + std::string ToString() const; + std::unique_ptr Copy() const; + /// \brief Serialize an options struct to a buffer. + Result> Serialize() const; + /// \brief Deserialize an options struct from a buffer. + /// Note: this will only look for `type_name` in the default FunctionRegistry; + /// to use a custom FunctionRegistry, look up the FunctionOptionsType, then + /// call FunctionOptionsType::Deserialize(). + static Result> Deserialize( + const std::string& type_name, const Buffer& buffer); + + protected: + explicit FunctionOptions(const FunctionOptionsType* type) : options_type_(type) {} + const FunctionOptionsType* options_type_; +}; + +ARROW_EXPORT void PrintTo(const FunctionOptions&, std::ostream*); + +/// @} + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/initialize.h b/pyarrow/include/arrow/compute/initialize.h new file mode 100644 index 0000000000000000000000000000000000000000..db5e231325bab4c944e086078780ac7302008c77 --- /dev/null +++ b/pyarrow/include/arrow/compute/initialize.h @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/compute/visibility.h" +#include "arrow/status.h" + +namespace arrow::compute { + +/// \brief Initialize the compute module. +/// +/// Register the compute kernel functions to be available on the +/// global FunctionRegistry. +/// This function will only be available if ARROW_COMPUTE is enabled. +ARROW_COMPUTE_EXPORT Status Initialize(); + +} // namespace arrow::compute diff --git a/pyarrow/include/arrow/compute/kernel.h b/pyarrow/include/arrow/compute/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..0d4f9d6ff436de470801f061b4e66a5e58876286 --- /dev/null +++ b/pyarrow/include/arrow/compute/kernel.h @@ -0,0 +1,772 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/compute/exec.h" +#include "arrow/datum.h" +#include "arrow/device_allocation_type_set.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +// macOS defines PREALLOCATE as a preprocessor macro in the header sys/vnode.h. +// No other BSD seems to do so. The name is used as an identifier in MemAllocation enum. +#if defined(__APPLE__) && defined(PREALLOCATE) +# undef PREALLOCATE +#endif + +namespace arrow { +namespace compute { + +class FunctionOptions; + +/// \brief Base class for opaque kernel-specific state. For example, if there +/// is some kind of initialization required. +struct ARROW_EXPORT KernelState { + virtual ~KernelState() = default; +}; + +/// \brief Context/state for the execution of a particular kernel. +class ARROW_EXPORT KernelContext { + public: + // Can pass optional backreference; not used consistently for the + // moment but will be made so in the future + explicit KernelContext(ExecContext* exec_ctx, const Kernel* kernel = NULLPTR) + : exec_ctx_(exec_ctx), kernel_(kernel) {} + + /// \brief Allocate buffer from the context's memory pool. The contents are + /// not initialized. + Result> Allocate(int64_t nbytes); + + /// \brief Allocate buffer for bitmap from the context's memory pool. Like + /// Allocate, the contents of the buffer are not initialized but the last + /// byte is preemptively zeroed to help avoid ASAN or valgrind issues. + Result> AllocateBitmap(int64_t num_bits); + + /// \brief Assign the active KernelState to be utilized for each stage of + /// kernel execution. Ownership and memory lifetime of the KernelState must + /// be minded separately. + void SetState(KernelState* state) { state_ = state; } + + // Set kernel that is being invoked since some kernel + // implementations will examine the kernel state. + void SetKernel(const Kernel* kernel) { kernel_ = kernel; } + + KernelState* state() { return state_; } + + /// \brief Configuration related to function execution that is to be shared + /// across multiple kernels. + ExecContext* exec_context() { return exec_ctx_; } + + /// \brief The memory pool to use for allocations. For now, it uses the + /// MemoryPool contained in the ExecContext used to create the KernelContext. + MemoryPool* memory_pool() { return exec_ctx_->memory_pool(); } + + const Kernel* kernel() const { return kernel_; } + + private: + ExecContext* exec_ctx_; + KernelState* state_ = NULLPTR; + const Kernel* kernel_ = NULLPTR; +}; + +/// \brief An type-checking interface to permit customizable validation rules +/// for use with InputType and KernelSignature. This is for scenarios where the +/// acceptance is not an exact type instance, such as a TIMESTAMP type for a +/// specific TimeUnit, but permitting any time zone. +struct ARROW_EXPORT TypeMatcher { + virtual ~TypeMatcher() = default; + + /// \brief Return true if this matcher accepts the data type. + virtual bool Matches(const DataType& type) const = 0; + + /// \brief A human-interpretable string representation of what the type + /// matcher checks for, usable when printing KernelSignature or formatting + /// error messages. + virtual std::string ToString() const = 0; + + /// \brief Return true if this TypeMatcher contains the same matching rule as + /// the other. Currently depends on RTTI. + virtual bool Equals(const TypeMatcher& other) const = 0; +}; + +namespace match { + +/// \brief Match any DataType instance having the same DataType::id. +ARROW_EXPORT std::shared_ptr SameTypeId(Type::type type_id); + +/// \brief Match any TimestampType instance having the same unit, but the time +/// zones can be different. +ARROW_EXPORT std::shared_ptr TimestampTypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr Time32TypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr Time64TypeUnit(TimeUnit::type unit); +ARROW_EXPORT std::shared_ptr DurationTypeUnit(TimeUnit::type unit); + +// \brief Match any integer type +ARROW_EXPORT std::shared_ptr Integer(); + +// Match types using 32-bit varbinary representation +ARROW_EXPORT std::shared_ptr BinaryLike(); + +// Match types using 64-bit varbinary representation +ARROW_EXPORT std::shared_ptr LargeBinaryLike(); + +// Match any fixed binary type +ARROW_EXPORT std::shared_ptr FixedSizeBinaryLike(); + +// \brief Match any primitive type (boolean or any type representable as a C +// Type) +ARROW_EXPORT std::shared_ptr Primitive(); + +// \brief Match any integer type that can be used as run-end in run-end encoded +// arrays +ARROW_EXPORT std::shared_ptr RunEndInteger(); + +/// \brief Match run-end encoded types that use any valid run-end type and +/// encode specific value types +/// +/// @param[in] value_type_matcher a matcher that is applied to the values field +ARROW_EXPORT std::shared_ptr RunEndEncoded( + std::shared_ptr value_type_matcher); + +/// \brief Match run-end encoded types that use any valid run-end type and +/// encode specific value types +/// +/// @param[in] value_type_id a type id that the type of the values field should match +ARROW_EXPORT std::shared_ptr RunEndEncoded(Type::type value_type_id); + +/// \brief Match run-end encoded types that encode specific run-end and value types +/// +/// @param[in] run_end_type_matcher a matcher that is applied to the run_ends field +/// @param[in] value_type_matcher a matcher that is applied to the values field +ARROW_EXPORT std::shared_ptr RunEndEncoded( + std::shared_ptr run_end_type_matcher, + std::shared_ptr value_type_matcher); + +} // namespace match + +/// \brief An object used for type-checking arguments to be passed to a kernel +/// and stored in a KernelSignature. The type-checking rule can be supplied +/// either with an exact DataType instance or a custom TypeMatcher. +class ARROW_EXPORT InputType { + public: + /// \brief The kind of type-checking rule that the InputType contains. + enum Kind { + /// \brief Accept any value type. + ANY_TYPE, + + /// \brief A fixed arrow::DataType and will only exact match having this + /// exact type (e.g. same TimestampType unit, same decimal scale and + /// precision, or same nested child types). + EXACT_TYPE, + + /// \brief Uses a TypeMatcher implementation to check the type. + USE_TYPE_MATCHER + }; + + /// \brief Accept any value type + InputType() : kind_(ANY_TYPE) {} + + /// \brief Accept an exact value type. + InputType(std::shared_ptr type) // NOLINT implicit construction + : kind_(EXACT_TYPE), type_(std::move(type)) {} + + /// \brief Use the passed TypeMatcher to type check. + InputType(std::shared_ptr type_matcher) // NOLINT implicit construction + : kind_(USE_TYPE_MATCHER), type_matcher_(std::move(type_matcher)) {} + + /// \brief Match any type with the given Type::type. Uses a TypeMatcher for + /// its implementation. + InputType(Type::type type_id) // NOLINT implicit construction + : InputType(match::SameTypeId(type_id)) {} + + InputType(const InputType& other) { CopyInto(other); } + + void operator=(const InputType& other) { CopyInto(other); } + + InputType(InputType&& other) { MoveInto(std::forward(other)); } + + void operator=(InputType&& other) { MoveInto(std::forward(other)); } + + // \brief Match any input (array, scalar of any type) + static InputType Any() { return InputType(); } + + /// \brief Return true if this input type matches the same type cases as the + /// other. + bool Equals(const InputType& other) const; + + bool operator==(const InputType& other) const { return this->Equals(other); } + + bool operator!=(const InputType& other) const { return !(*this == other); } + + /// \brief Return hash code. + size_t Hash() const; + + /// \brief Render a human-readable string representation. + std::string ToString() const; + + /// \brief Return true if the Datum matches this argument kind in + /// type (and only allows scalar or array-like Datums). + bool Matches(const Datum& value) const; + + /// \brief Return true if the type matches this InputType + bool Matches(const DataType& type) const; + + /// \brief The type matching rule that this InputType uses. + Kind kind() const { return kind_; } + + /// \brief For InputType::EXACT_TYPE kind, the exact type that this InputType + /// must match. Otherwise this function should not be used and will assert in + /// debug builds. + const std::shared_ptr& type() const; + + /// \brief For InputType::USE_TYPE_MATCHER, the TypeMatcher to be used for + /// checking the type of a value. Otherwise this function should not be used + /// and will assert in debug builds. + const TypeMatcher& type_matcher() const; + + private: + void CopyInto(const InputType& other) { + this->kind_ = other.kind_; + this->type_ = other.type_; + this->type_matcher_ = other.type_matcher_; + } + + void MoveInto(InputType&& other) { + this->kind_ = other.kind_; + this->type_ = std::move(other.type_); + this->type_matcher_ = std::move(other.type_matcher_); + } + + Kind kind_; + + // For EXACT_TYPE Kind + std::shared_ptr type_; + + // For USE_TYPE_MATCHER Kind + std::shared_ptr type_matcher_; +}; + +/// \brief Container to capture both exact and input-dependent output types. +class ARROW_EXPORT OutputType { + public: + /// \brief An enum indicating whether the value type is an invariant fixed + /// value or one that's computed by a kernel-defined resolver function. + enum ResolveKind { FIXED, COMPUTED }; + + /// Type resolution function. Given input types, return output type. This + /// function MAY may use the kernel state to decide the output type based on + /// the FunctionOptions. + /// + /// This function SHOULD _not_ be used to check for arity, that is to be + /// performed one or more layers above. + using Resolver = + std::function(KernelContext*, const std::vector&)>; + + /// \brief Output an exact type + OutputType(std::shared_ptr type) // NOLINT implicit construction + : kind_(FIXED), type_(std::move(type)) {} + + /// \brief Output a computed type depending on actual input types + template + OutputType(Fn resolver) // NOLINT implicit construction + : kind_(COMPUTED), resolver_(std::move(resolver)) {} + + OutputType(const OutputType& other) { + this->kind_ = other.kind_; + this->type_ = other.type_; + this->resolver_ = other.resolver_; + } + + OutputType(OutputType&& other) { + this->kind_ = other.kind_; + this->type_ = std::move(other.type_); + this->resolver_ = other.resolver_; + } + + OutputType& operator=(const OutputType&) = default; + OutputType& operator=(OutputType&&) = default; + + /// \brief Return the type of the expected output value of the kernel given + /// the input argument types. The resolver may make use of state information + /// kept in the KernelContext. + Result Resolve(KernelContext* ctx, + const std::vector& args) const; + + /// \brief The exact output value type for the FIXED kind. + const std::shared_ptr& type() const; + + /// \brief For use with COMPUTED resolution strategy. It may be more + /// convenient to invoke this with OutputType::Resolve returned from this + /// method. + const Resolver& resolver() const; + + /// \brief Render a human-readable string representation. + std::string ToString() const; + + /// \brief Return the kind of type resolution of this output type, whether + /// fixed/invariant or computed by a resolver. + ResolveKind kind() const { return kind_; } + + private: + ResolveKind kind_; + + // For FIXED resolution + std::shared_ptr type_; + + // For COMPUTED resolution + Resolver resolver_ = NULLPTR; +}; + +/// \brief Additional constraints to apply to the input types of a kernel when matching a +/// specific kernel signature. +class ARROW_EXPORT MatchConstraint { + public: + virtual ~MatchConstraint() = default; + + /// \brief Return true if the input types satisfy the constraint. + virtual bool Matches(const std::vector& types) const = 0; + + /// \brief Convenience function to create a MatchConstraint from a match function. + static std::shared_ptr Make( + std::function&)> matches); +}; + +/// \brief Constraint that all input types are decimal types and have the same scale. +ARROW_EXPORT std::shared_ptr DecimalsHaveSameScale(); + +/// \brief Holds the input types, optional match constraint and output type of the kernel. +/// +/// VarArgs functions with minimum N arguments should pass up to N input types to be +/// used to validate the input types of a function invocation. The first N-1 types +/// will be matched against the first N-1 arguments, and the last type will be +/// matched against the remaining arguments. +class ARROW_EXPORT KernelSignature { + public: + KernelSignature(std::vector in_types, OutputType out_type, + bool is_varargs = false, + std::shared_ptr constraint = NULLPTR); + + /// \brief Convenience ctor since make_shared can be awkward + static std::shared_ptr Make( + std::vector in_types, OutputType out_type, bool is_varargs = false, + std::shared_ptr constraint = NULLPTR); + + /// \brief Return true if the signature is compatible with the list of input + /// value descriptors and satisfies the match constraint, if any. + bool MatchesInputs(const std::vector& types) const; + + /// \brief Returns true if the input types of each signature are + /// equal. Well-formed functions should have a deterministic output type + /// given input types, but currently it is the responsibility of the + /// developer to ensure this. + bool Equals(const KernelSignature& other) const; + + bool operator==(const KernelSignature& other) const { return this->Equals(other); } + + bool operator!=(const KernelSignature& other) const { return !(*this == other); } + + /// \brief Compute a hash code for the signature + size_t Hash() const; + + /// \brief The input types for the kernel. For VarArgs functions, this should + /// generally contain a single validator to use for validating all of the + /// function arguments. + const std::vector& in_types() const { return in_types_; } + + /// \brief The output type for the kernel. Use Resolve to return the + /// exact output given input argument types, since many kernels' + /// output types depend on their input types (or their type + /// metadata). + const OutputType& out_type() const { return out_type_; } + + /// \brief Render a human-readable string representation + std::string ToString() const; + + bool is_varargs() const { return is_varargs_; } + + private: + std::vector in_types_; + OutputType out_type_; + bool is_varargs_; + std::shared_ptr constraint_; + + // For caching the hash code after it's computed the first time + mutable uint64_t hash_code_; +}; + +/// \brief A function may contain multiple variants of a kernel for a given +/// type combination for different SIMD levels. Based on the active system's +/// CPU info or the user's preferences, we can elect to use one over the other. +struct SimdLevel { + enum type { NONE = 0, SSE4_2, AVX, AVX2, AVX512, NEON, MAX }; +}; + +/// \brief The strategy to use for propagating or otherwise populating the +/// validity bitmap of a kernel output. +struct NullHandling { + enum type { + /// Compute the output validity bitmap by intersecting the validity bitmaps + /// of the arguments using bitwise-and operations. This means that values + /// in the output are valid/non-null only if the corresponding values in + /// all input arguments were valid/non-null. Kernel generally need not + /// touch the bitmap thereafter, but a kernel's exec function is permitted + /// to alter the bitmap after the null intersection is computed if it needs + /// to. + INTERSECTION, + + /// Kernel expects a pre-allocated buffer to write the result bitmap + /// into. The preallocated memory is not zeroed (except for the last byte), + /// so the kernel should ensure to completely populate the bitmap. + COMPUTED_PREALLOCATE, + + /// Kernel allocates and sets the validity bitmap of the output. + COMPUTED_NO_PREALLOCATE, + + /// Kernel output is never null and a validity bitmap does not need to be + /// allocated. + OUTPUT_NOT_NULL + }; +}; + +/// \brief The preference for memory preallocation of fixed-width type outputs +/// in kernel execution. +struct MemAllocation { + enum type { + // For data types that support pre-allocation (i.e. fixed-width), the + // kernel expects to be provided a pre-allocated data buffer to write + // into. Non-fixed-width types must always allocate their own data + // buffers. The allocation made for the same length as the execution batch, + // so vector kernels yielding differently sized output should not use this. + // + // It is valid for the data to not be preallocated but the validity bitmap + // is (or is computed using the intersection/bitwise-and method). + // + // For variable-size output types like BinaryType or StringType, or for + // nested types, this option has no effect. + PREALLOCATE, + + // The kernel is responsible for allocating its own data buffer for + // fixed-width type outputs. + NO_PREALLOCATE + }; +}; + +struct Kernel; + +/// \brief Arguments to pass to an KernelInit function. A struct is used to help +/// avoid API breakage should the arguments passed need to be expanded. +struct KernelInitArgs { + /// \brief A pointer to the kernel being initialized. The init function may + /// depend on the kernel's KernelSignature or other data contained there. + const Kernel* kernel; + + /// \brief The types of the input arguments that the kernel is + /// about to be executed against. + const std::vector& inputs; + + /// \brief Opaque options specific to this kernel. May be nullptr for functions + /// that do not require options. + const FunctionOptions* options; +}; + +/// \brief Common initializer function for all kernel types. +using KernelInit = std::function>( + KernelContext*, const KernelInitArgs&)>; + +/// \brief Base type for kernels. Contains the function signature and +/// optionally the state initialization function, along with some common +/// attributes +struct ARROW_EXPORT Kernel { + Kernel() = default; + + Kernel(std::shared_ptr sig, KernelInit init) + : signature(std::move(sig)), init(std::move(init)) {} + + Kernel(std::vector in_types, OutputType out_type, KernelInit init) + : Kernel(KernelSignature::Make(std::move(in_types), std::move(out_type)), + std::move(init)) {} + + /// \brief The "signature" of the kernel containing the InputType input + /// argument validators and OutputType output type resolver. + std::shared_ptr signature; + + /// \brief Create a new KernelState for invocations of this kernel, e.g. to + /// set up any options or state relevant for execution. + KernelInit init; + + /// \brief Create a vector of new KernelState for invocations of this kernel. + static Status InitAll(KernelContext*, const KernelInitArgs&, + std::vector>*); + + /// \brief Indicates whether execution can benefit from parallelization + /// (splitting large chunks into smaller chunks and using multiple + /// threads). Some kernels may not support parallel execution at + /// all. Synchronization and concurrency-related issues are currently the + /// responsibility of the Kernel's implementation. + bool parallelizable = true; + + /// \brief Indicates the level of SIMD instruction support in the host CPU is + /// required to use the function. The intention is for functions to be able to + /// contain multiple kernels with the same signature but different levels of SIMD, + /// so that the most optimized kernel supported on a host's processor can be chosen. + SimdLevel::type simd_level = SimdLevel::NONE; + + // Additional kernel-specific data + std::shared_ptr data; +}; + +/// \brief The scalar kernel execution API that must be implemented for SCALAR +/// kernel types. This includes both stateless and stateful kernels. Kernels +/// depending on some execution state access that state via subclasses of +/// KernelState set on the KernelContext object. Implementations should +/// endeavor to write into pre-allocated memory if they are able, though for +/// some kernels (e.g. in cases when a builder like StringBuilder) must be +/// employed this may not be possible. +using ArrayKernelExec = Status (*)(KernelContext*, const ExecSpan&, ExecResult*); + +/// \brief Kernel data structure for implementations of ScalarFunction. In +/// addition to the members found in Kernel, contains the null handling +/// and memory pre-allocation preferences. +struct ARROW_EXPORT ScalarKernel : public Kernel { + ScalarKernel() = default; + + ScalarKernel(std::shared_ptr sig, ArrayKernelExec exec, + KernelInit init = NULLPTR) + : Kernel(std::move(sig), init), exec(exec) {} + + ScalarKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, + KernelInit init = NULLPTR) + : Kernel(std::move(in_types), std::move(out_type), std::move(init)), exec(exec) {} + + /// \brief Perform a single invocation of this kernel. Depending on the + /// implementation, it may only write into preallocated memory, while in some + /// cases it will allocate its own memory. Any required state is managed + /// through the KernelContext. + ArrayKernelExec exec; + + /// \brief Writing execution results into larger contiguous allocations + /// requires that the kernel be able to write into sliced output ArrayData*, + /// including sliced output validity bitmaps. Some kernel implementations may + /// not be able to do this, so setting this to false disables this + /// functionality. + bool can_write_into_slices = true; + + // For scalar functions preallocated data and intersecting arg validity + // bitmaps is a reasonable default + NullHandling::type null_handling = NullHandling::INTERSECTION; + MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE; +}; + +// ---------------------------------------------------------------------- +// VectorKernel (for VectorFunction) + +/// \brief Kernel data structure for implementations of VectorFunction. In +/// contains an optional finalizer function, the null handling and memory +/// pre-allocation preferences (which have different defaults from +/// ScalarKernel), and some other execution-related options. +struct ARROW_EXPORT VectorKernel : public Kernel { + /// \brief See VectorKernel::finalize member for usage + using FinalizeFunc = std::function*)>; + + /// \brief Function for executing a stateful VectorKernel against a + /// ChunkedArray input. Does not need to be defined for all VectorKernels + using ChunkedExec = Status (*)(KernelContext*, const ExecBatch&, Datum* out); + + VectorKernel() = default; + + VectorKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, + KernelInit init = NULLPTR, FinalizeFunc finalize = NULLPTR) + : Kernel(std::move(in_types), std::move(out_type), std::move(init)), + exec(exec), + finalize(std::move(finalize)) {} + + VectorKernel(std::shared_ptr sig, ArrayKernelExec exec, + KernelInit init = NULLPTR, FinalizeFunc finalize = NULLPTR) + : Kernel(std::move(sig), std::move(init)), + exec(exec), + finalize(std::move(finalize)) {} + + /// \brief Perform a single invocation of this kernel. Any required state is + /// managed through the KernelContext. + ArrayKernelExec exec; + + /// \brief Execute the kernel on a ChunkedArray. Does not need to be defined + ChunkedExec exec_chunked = NULLPTR; + + /// \brief For VectorKernel, convert intermediate results into finalized + /// results. Mutates input argument. Some kernels may accumulate state + /// (example: hashing-related functions) through processing chunked inputs, and + /// then need to attach some accumulated state to each of the outputs of + /// processing each chunk of data. + FinalizeFunc finalize; + + /// Since vector kernels generally are implemented rather differently from + /// scalar/elementwise kernels (and they may not even yield arrays of the same + /// size), so we make the developer opt-in to any memory preallocation rather + /// than having to turn it off. + NullHandling::type null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + MemAllocation::type mem_allocation = MemAllocation::NO_PREALLOCATE; + + /// \brief Writing execution results into larger contiguous allocations + /// requires that the kernel be able to write into sliced output ArrayData*, + /// including sliced output validity bitmaps. Some kernel implementations may + /// not be able to do this, so setting this to false disables this + /// functionality. + bool can_write_into_slices = true; + + /// Some vector kernels can do chunkwise execution using ExecSpanIterator, + /// in some cases accumulating some state. Other kernels (like Take) need to + /// be passed whole arrays and don't work on ChunkedArray inputs + bool can_execute_chunkwise = true; + + /// Some kernels (like unique and value_counts) yield non-chunked output from + /// chunked-array inputs. This option controls how the results are boxed when + /// returned from ExecVectorFunction + /// + /// true -> ChunkedArray + /// false -> Array + bool output_chunked = true; +}; + +// ---------------------------------------------------------------------- +// ScalarAggregateKernel (for ScalarAggregateFunction) + +using ScalarAggregateConsume = Status (*)(KernelContext*, const ExecSpan&); +using ScalarAggregateMerge = Status (*)(KernelContext*, KernelState&&, KernelState*); +// Finalize returns Datum to permit multiple return values +using ScalarAggregateFinalize = Status (*)(KernelContext*, Datum*); + +/// \brief Kernel data structure for implementations of +/// ScalarAggregateFunction. The four necessary components of an aggregation +/// kernel are the init, consume, merge, and finalize functions. +/// +/// * init: creates a new KernelState for a kernel. +/// * consume: processes an ExecSpan and updates the KernelState found in the +/// KernelContext. +/// * merge: combines one KernelState with another. +/// * finalize: produces the end result of the aggregation using the +/// KernelState in the KernelContext. +struct ARROW_EXPORT ScalarAggregateKernel : public Kernel { + ScalarAggregateKernel(std::shared_ptr sig, KernelInit init, + ScalarAggregateConsume consume, ScalarAggregateMerge merge, + ScalarAggregateFinalize finalize, const bool ordered) + : Kernel(std::move(sig), std::move(init)), + consume(consume), + merge(merge), + finalize(finalize), + ordered(ordered) {} + + ScalarAggregateKernel(std::vector in_types, OutputType out_type, + KernelInit init, ScalarAggregateConsume consume, + ScalarAggregateMerge merge, ScalarAggregateFinalize finalize, + const bool ordered) + : ScalarAggregateKernel( + KernelSignature::Make(std::move(in_types), std::move(out_type)), + std::move(init), consume, merge, finalize, ordered) {} + + /// \brief Merge a vector of KernelStates into a single KernelState. + /// The merged state will be returned and will be set on the KernelContext. + static Result> MergeAll( + const ScalarAggregateKernel* kernel, KernelContext* ctx, + std::vector> states); + + ScalarAggregateConsume consume; + ScalarAggregateMerge merge; + ScalarAggregateFinalize finalize; + /// \brief Whether this kernel requires ordering + /// Some aggregations, such as, "first", requires some kind of input order. The + /// order can be implicit, e.g., the order of the input data, or explicit, e.g. + /// the ordering specified with a window aggregation. + /// The caller of the aggregate kernel is responsible for passing data in some + /// defined order to the kernel. The flag here is a way for the kernel to tell + /// the caller that data passed to the kernel must be defined in some order. + bool ordered = false; +}; + +// ---------------------------------------------------------------------- +// HashAggregateKernel (for HashAggregateFunction) + +using HashAggregateResize = Status (*)(KernelContext*, int64_t); +using HashAggregateConsume = Status (*)(KernelContext*, const ExecSpan&); +using HashAggregateMerge = Status (*)(KernelContext*, KernelState&&, const ArrayData&); + +// Finalize returns Datum to permit multiple return values +using HashAggregateFinalize = Status (*)(KernelContext*, Datum*); + +/// \brief Kernel data structure for implementations of +/// HashAggregateFunction. The four necessary components of an aggregation +/// kernel are the init, consume, merge, and finalize functions. +/// +/// * init: creates a new KernelState for a kernel. +/// * resize: ensure that the KernelState can accommodate the specified number of groups. +/// * consume: processes an ExecSpan (which includes the argument as well +/// as an array of group identifiers) and updates the KernelState found in the +/// KernelContext. +/// * merge: combines one KernelState with another. +/// * finalize: produces the end result of the aggregation using the +/// KernelState in the KernelContext. +struct ARROW_EXPORT HashAggregateKernel : public Kernel { + HashAggregateKernel() = default; + + HashAggregateKernel(std::shared_ptr sig, KernelInit init, + HashAggregateResize resize, HashAggregateConsume consume, + HashAggregateMerge merge, HashAggregateFinalize finalize, + const bool ordered) + : Kernel(std::move(sig), std::move(init)), + resize(resize), + consume(consume), + merge(merge), + finalize(finalize), + ordered(ordered) {} + + HashAggregateKernel(std::vector in_types, OutputType out_type, + KernelInit init, HashAggregateConsume consume, + HashAggregateResize resize, HashAggregateMerge merge, + HashAggregateFinalize finalize, const bool ordered) + : HashAggregateKernel( + KernelSignature::Make(std::move(in_types), std::move(out_type)), + std::move(init), resize, consume, merge, finalize, ordered) {} + + HashAggregateResize resize; + HashAggregateConsume consume; + HashAggregateMerge merge; + HashAggregateFinalize finalize; + /// @brief whether the summarizer requires ordering + /// This is similar to ScalarAggregateKernel. See ScalarAggregateKernel + /// for detailed doc of this variable. + bool ordered = false; +}; + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/ordering.h b/pyarrow/include/arrow/compute/ordering.h new file mode 100644 index 0000000000000000000000000000000000000000..61caa2b570dd31dc988d34406f9b05c3573333e2 --- /dev/null +++ b/pyarrow/include/arrow/compute/ordering.h @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/type.h" +#include "arrow/util/compare.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +enum class SortOrder { + /// Arrange values in increasing order + Ascending, + /// Arrange values in decreasing order + Descending, +}; + +enum class NullPlacement { + /// Place nulls and NaNs before any non-null values. + /// NaNs will come after nulls. + AtStart, + /// Place nulls and NaNs after any non-null values. + /// NaNs will come before nulls. + AtEnd, +}; + +/// \brief One sort key for PartitionNthIndices (TODO) and SortIndices +class ARROW_EXPORT SortKey : public util::EqualityComparable { + public: + explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending) + : target(std::move(target)), order(order) {} + + bool Equals(const SortKey& other) const; + std::string ToString() const; + + /// A FieldRef targeting the sort column. + FieldRef target; + /// How to order by this sort key. + SortOrder order; +}; + +class ARROW_EXPORT Ordering : public util::EqualityComparable { + public: + Ordering(std::vector sort_keys, + NullPlacement null_placement = NullPlacement::AtStart) + : sort_keys_(std::move(sort_keys)), null_placement_(null_placement) {} + /// true if data ordered by other is also ordered by this + /// + /// For example, if data is ordered by [a, b, c] then it is also ordered + /// by [a, b] but not by [b, c] or [a, b, c, d]. + /// + /// [a, b].IsSuborderOf([a, b, c]) - true + /// [a, b, c].IsSuborderOf([a, b, c]) - true + /// [b, c].IsSuborderOf([a, b, c]) - false + /// [a, b, c, d].IsSuborderOf([a, b, c]) - false + /// + /// The implicit ordering is not a suborder of any other ordering and + /// no other ordering is a suborder of it. The implicit ordering is not a + /// suborder of itself. + /// + /// The unordered ordering is a suborder of all other orderings but no + /// other ordering is a suborder of it. The unordered ordering is a suborder + /// of itself. + /// + /// The unordered ordering is a suborder of the implicit ordering. + bool IsSuborderOf(const Ordering& other) const; + + bool Equals(const Ordering& other) const; + std::string ToString() const; + + bool is_implicit() const { return is_implicit_; } + bool is_unordered() const { return !is_implicit_ && sort_keys_.empty(); } + + const std::vector& sort_keys() const { return sort_keys_; } + NullPlacement null_placement() const { return null_placement_; } + + static const Ordering& Implicit() { + static const Ordering kImplicit(true); + return kImplicit; + } + + static const Ordering& Unordered() { + static const Ordering kUnordered(false); + // It is also possible to get an unordered ordering by passing in an empty vector + // using the normal constructor. This is ok and useful when ordering comes from user + // input. + return kUnordered; + } + + private: + explicit Ordering(bool is_implicit) + : null_placement_(NullPlacement::AtStart), is_implicit_(is_implicit) {} + /// Column key(s) to order by and how to order by these sort keys. + std::vector sort_keys_; + /// Whether nulls and NaNs are placed at the start or at the end + NullPlacement null_placement_; + bool is_implicit_ = false; +}; + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/registry.h b/pyarrow/include/arrow/compute/registry.h new file mode 100644 index 0000000000000000000000000000000000000000..f31c4c1ba5920626578a4e4170e3cd2d28288545 --- /dev/null +++ b/pyarrow/include/arrow/compute/registry.h @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// NOTE: API is EXPERIMENTAL and will change without going through a +// deprecation cycle + +#pragma once + +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +class Function; +class FunctionOptionsType; + +/// \brief A mutable central function registry for built-in functions as well +/// as user-defined functions. Functions are implementations of +/// arrow::compute::Function. +/// +/// Generally, each function contains kernels which are implementations of a +/// function for a specific argument signature. After looking up a function in +/// the registry, one can either execute it eagerly with Function::Execute or +/// use one of the function's dispatch methods to pick a suitable kernel for +/// lower-level function execution. +class ARROW_EXPORT FunctionRegistry { + public: + ~FunctionRegistry(); + + /// \brief Construct a new registry. + /// + /// Most users only need to use the global registry. + static std::unique_ptr Make(); + + /// \brief Construct a new nested registry with the given parent. + /// + /// Most users only need to use the global registry. The returned registry never changes + /// its parent, even when an operation allows overwriting. + static std::unique_ptr Make(FunctionRegistry* parent); + + /// \brief Check whether a new function can be added to the registry. + /// + /// \returns Status::KeyError if a function with the same name is already registered. + Status CanAddFunction(std::shared_ptr function, bool allow_overwrite = false); + + /// \brief Add a new function to the registry. + /// + /// \returns Status::KeyError if a function with the same name is already registered. + Status AddFunction(std::shared_ptr function, bool allow_overwrite = false); + + /// \brief Check whether an alias can be added for the given function name. + /// + /// \returns Status::KeyError if the function with the given name is not registered. + Status CanAddAlias(const std::string& target_name, const std::string& source_name); + + /// \brief Add alias for the given function name. + /// + /// \returns Status::KeyError if the function with the given name is not registered. + Status AddAlias(const std::string& target_name, const std::string& source_name); + + /// \brief Check whether a new function options type can be added to the registry. + /// + /// \return Status::KeyError if a function options type with the same name is already + /// registered. + Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false); + + /// \brief Add a new function options type to the registry. + /// + /// \returns Status::KeyError if a function options type with the same name is already + /// registered. + Status AddFunctionOptionsType(const FunctionOptionsType* options_type, + bool allow_overwrite = false); + + /// \brief Retrieve a function by name from the registry. + Result> GetFunction(const std::string& name) const; + + /// \brief Return vector of all entry names in the registry. + /// + /// Helpful for displaying a manifest of available functions. + std::vector GetFunctionNames() const; + + /// \brief Retrieve a function options type by name from the registry. + Result GetFunctionOptionsType( + const std::string& name) const; + + /// \brief The number of currently registered functions. + int num_functions() const; + + /// \brief The cast function object registered in AddFunction. + /// + /// Helpful for get cast function as needed. + const Function* cast_function() const; + + private: + FunctionRegistry(); + + // Use PIMPL pattern to not have std::unordered_map here + class FunctionRegistryImpl; + std::unique_ptr impl_; + + explicit FunctionRegistry(FunctionRegistryImpl* impl); +}; + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/row/grouper.h b/pyarrow/include/arrow/compute/row/grouper.h new file mode 100644 index 0000000000000000000000000000000000000000..9424559385b7391d4dc7d46ddbbb542803c9001e --- /dev/null +++ b/pyarrow/include/arrow/compute/row/grouper.h @@ -0,0 +1,198 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/compute/kernel.h" +#include "arrow/compute/visibility.h" +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \brief A segment +/// A segment group is a chunk of continuous rows that have the same segment key. (For +/// example, in ordered time series processing, segment key can be "date", and a segment +/// group can be all the rows that belong to the same date.) A segment group can span +/// across multiple exec batches. A segment is a chunk of continuous rows that has the +/// same segment key within a given batch. When a segment group span cross batches, it +/// will have multiple segments. A segment never spans cross batches. The segment data +/// structure only makes sense when used along with a exec batch. +struct ARROW_COMPUTE_EXPORT Segment { + /// \brief the offset into the batch where the segment starts + int64_t offset; + /// \brief the length of the segment + int64_t length; + /// \brief whether the segment may be extended by a next one + bool is_open; + /// \brief whether the segment extends a preceeding one + bool extends; +}; + +inline bool operator==(const Segment& segment1, const Segment& segment2) { + return segment1.offset == segment2.offset && segment1.length == segment2.length && + segment1.is_open == segment2.is_open && segment1.extends == segment2.extends; +} +inline bool operator!=(const Segment& segment1, const Segment& segment2) { + return !(segment1 == segment2); +} + +/// \brief a helper class to divide a batch into segments of equal values +/// +/// For example, given a batch with two columns specifed as segment keys: +/// +/// A A [other columns]... +/// A A ... +/// A B ... +/// A B ... +/// A A ... +/// +/// Then the batch could be divided into 3 segments. The first would be rows 0 & 1, +/// the second would be rows 2 & 3, and the third would be row 4. +/// +/// Further, a segmenter keeps track of the last value seen. This allows it to calculate +/// segments which span batches. In our above example the last batch we emit would set +/// the "open" flag, which indicates whether the segment may extend into the next batch. +/// +/// If the next call to the segmenter starts with `A A` then that segment would set the +/// "extends" flag, which indicates whether the segment continues the last open batch. +class ARROW_COMPUTE_EXPORT RowSegmenter { + public: + virtual ~RowSegmenter() = default; + + /// \brief Construct a Segmenter which segments on the specified key types + /// + /// \param[in] key_types the specified key types + /// \param[in] nullable_keys whether values of the specified keys may be null + /// \param[in] ctx the execution context to use + static Result> Make( + const std::vector& key_types, bool nullable_keys, ExecContext* ctx); + + /// \brief Return the key types of this segmenter + virtual const std::vector& key_types() const = 0; + + /// \brief Reset this segmenter + /// + /// A segmenter normally extends (see `Segment`) a segment from one batch to the next. + /// If segment-extension is undesirable, for example when each batch is processed + /// independently, then `Reset` should be invoked before processing the next batch. + virtual Status Reset() = 0; + + /// \brief Get all segments for the given batch + virtual Result> GetSegments(const ExecSpan& batch) = 0; +}; + +/// Consumes batches of keys and yields batches of the group ids. +class ARROW_COMPUTE_EXPORT Grouper { + public: + virtual ~Grouper() = default; + + /// Construct a Grouper which receives the specified key types + static Result> Make(const std::vector& key_types, + ExecContext* ctx = default_exec_context()); + + /// Reset all intermediate state, make the grouper logically as just `Make`ed. + /// The underlying buffers, if any, may or may not be released though. + virtual Status Reset() = 0; + + /// Consume a batch of keys, producing the corresponding group ids as an integer array, + /// over a slice defined by an offset and length, which defaults to the batch length. + /// Currently only uint32 indices will be produced, eventually the bit width will only + /// be as wide as necessary. + virtual Result Consume(const ExecSpan& batch, int64_t offset = 0, + int64_t length = -1) = 0; + + /// Like Consume, but groups not already encountered emit null instead of + /// generating a new group id. + virtual Result Lookup(const ExecSpan& batch, int64_t offset = 0, + int64_t length = -1) = 0; + + /// Like Consume, but only populates the Grouper without returning the group ids. + virtual Status Populate(const ExecSpan& batch, int64_t offset = 0, + int64_t length = -1) = 0; + + /// Get current unique keys. May be called multiple times. + virtual Result GetUniques() = 0; + + /// Get the current number of groups. + virtual uint32_t num_groups() const = 0; + + /// \brief Assemble lists of indices of identical elements. + /// + /// \param[in] ids An unsigned, all-valid integral array which will be + /// used as grouping criteria. + /// \param[in] num_groups An upper bound for the elements of ids + /// \param[in] ctx Execution context to use during the operation + /// \return A num_groups-long ListArray where the slot at i contains a + /// list of indices where i appears in ids. + /// + /// MakeGroupings([ + /// 2, + /// 2, + /// 5, + /// 5, + /// 2, + /// 3 + /// ], 8) == [ + /// [], + /// [], + /// [0, 1, 4], + /// [5], + /// [], + /// [2, 3], + /// [], + /// [] + /// ] + static Result> MakeGroupings( + const UInt32Array& ids, uint32_t num_groups, + ExecContext* ctx = default_exec_context()); + + /// \brief Produce a ListArray whose slots are selections of `array` which correspond to + /// the provided groupings. + /// + /// For example, + /// ApplyGroupings([ + /// [], + /// [], + /// [0, 1, 4], + /// [5], + /// [], + /// [2, 3], + /// [], + /// [] + /// ], [2, 2, 5, 5, 2, 3]) == [ + /// [], + /// [], + /// [2, 2, 2], + /// [3], + /// [], + /// [5, 5], + /// [], + /// [] + /// ] + static Result> ApplyGroupings( + const ListArray& groupings, const Array& array, + ExecContext* ctx = default_exec_context()); +}; + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/type_fwd.h b/pyarrow/include/arrow/compute/type_fwd.h new file mode 100644 index 0000000000000000000000000000000000000000..016d97a0dbc2b3b77be0b07e7effca3669439eb8 --- /dev/null +++ b/pyarrow/include/arrow/compute/type_fwd.h @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/util/visibility.h" + +namespace arrow { + +struct Datum; +struct TypeHolder; + +namespace compute { + +class Function; +class ScalarAggregateFunction; +class FunctionExecutor; +class FunctionOptions; +class FunctionRegistry; + +/// \brief Return the process-global function registry. +// Defined in registry.cc +ARROW_EXPORT FunctionRegistry* GetFunctionRegistry(); + +class CastOptions; + +struct ExecBatch; +class ExecContext; +struct ExecValue; +class KernelContext; + +struct Kernel; +struct ScalarKernel; +struct ScalarAggregateKernel; +struct VectorKernel; + +struct KernelState; + +class Expression; + +ARROW_EXPORT ExecContext* default_exec_context(); +ARROW_EXPORT ExecContext* threaded_exec_context(); + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/util.h b/pyarrow/include/arrow/compute/util.h new file mode 100644 index 0000000000000000000000000000000000000000..51a24b50fe60d4e22a9de02111568efd7f9cf334 --- /dev/null +++ b/pyarrow/include/arrow/compute/util.h @@ -0,0 +1,221 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/compute/expression.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/compute/visibility.h" +#include "arrow/result.h" +#include "arrow/util/cpu_info.h" +#include "arrow/util/simd.h" + +#if defined(__clang__) || defined(__GNUC__) +# define BYTESWAP(x) __builtin_bswap64(x) +# define ROTL(x, n) (((x) << (n)) | ((x) >> ((-n) & 31))) +# define ROTL64(x, n) (((x) << (n)) | ((x) >> ((-n) & 63))) +#elif defined(_MSC_VER) +# include +# define BYTESWAP(x) _byteswap_uint64(x) +# define ROTL(x, n) _rotl((x), (n)) +# define ROTL64(x, n) _rotl64((x), (n)) +#endif + +namespace arrow { +namespace util { + +// Some platforms typedef int64_t as long int instead of long long int, +// which breaks the _mm256_i64gather_epi64 and _mm256_i32gather_epi64 intrinsics +// which need long long. +// We use the cast to the type below in these intrinsics to make the code +// compile in all cases. +// +using int64_for_gather_t = const long long int; // NOLINT runtime-int + +// All MiniBatch... classes use TempVectorStack for vector allocations and can +// only work with vectors up to 1024 elements. +// +// They should only be allocated on the stack to guarantee the right sequence +// of allocation and deallocation of vectors from TempVectorStack. +// +class MiniBatch { + public: + static constexpr int kLogMiniBatchLength = 10; + static constexpr int kMiniBatchLength = 1 << kLogMiniBatchLength; +}; + +namespace bit_util { + +ARROW_COMPUTE_EXPORT void bits_to_indexes(int bit_to_search, int64_t hardware_flags, + const int num_bits, const uint8_t* bits, + int* num_indexes, uint16_t* indexes, + int bit_offset = 0); + +ARROW_COMPUTE_EXPORT void bits_filter_indexes(int bit_to_search, int64_t hardware_flags, + const int num_bits, const uint8_t* bits, + const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes, + int bit_offset = 0); + +// Input and output indexes may be pointing to the same data (in-place filtering). +ARROW_COMPUTE_EXPORT void bits_split_indexes(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, int* num_indexes_bit0, + uint16_t* indexes_bit0, + uint16_t* indexes_bit1, int bit_offset = 0); + +// Bit 1 is replaced with byte 0xFF. +ARROW_COMPUTE_EXPORT void bits_to_bytes(int64_t hardware_flags, const int num_bits, + const uint8_t* bits, uint8_t* bytes, + int bit_offset = 0); + +// Return highest bit of each byte. +ARROW_COMPUTE_EXPORT void bytes_to_bits(int64_t hardware_flags, const int num_bits, + const uint8_t* bytes, uint8_t* bits, + int bit_offset = 0); + +ARROW_COMPUTE_EXPORT bool are_all_bytes_zero(int64_t hardware_flags, const uint8_t* bytes, + uint32_t num_bytes); + +#if defined(ARROW_HAVE_RUNTIME_AVX2) && defined(ARROW_HAVE_RUNTIME_BMI2) +// The functions below use BMI2 instructions, be careful before calling! + +namespace avx2 { +ARROW_COMPUTE_EXPORT void bits_filter_indexes_avx2(int bit_to_search, const int num_bits, + const uint8_t* bits, + const uint16_t* input_indexes, + int* num_indexes, uint16_t* indexes); +ARROW_COMPUTE_EXPORT void bits_to_indexes_avx2(int bit_to_search, const int num_bits, + const uint8_t* bits, int* num_indexes, + uint16_t* indexes, + uint16_t base_index = 0); +ARROW_COMPUTE_EXPORT void bits_to_bytes_avx2(const int num_bits, const uint8_t* bits, + uint8_t* bytes); +ARROW_COMPUTE_EXPORT void bytes_to_bits_avx2(const int num_bits, const uint8_t* bytes, + uint8_t* bits); +ARROW_COMPUTE_EXPORT bool are_all_bytes_zero_avx2(const uint8_t* bytes, + uint32_t num_bytes); +} // namespace avx2 + +#endif + +} // namespace bit_util +} // namespace util + +namespace compute { + +/// Modify an Expression with pre-order and post-order visitation. +/// `pre` will be invoked on each Expression. `pre` will visit Calls before their +/// arguments, `post_call` will visit Calls (and no other Expressions) after their +/// arguments. Visitors should return the Identical expression to indicate no change; this +/// will prevent unnecessary construction in the common case where a modification is not +/// possible/necessary/... +/// +/// If an argument was modified, `post_call` visits a reconstructed Call with the modified +/// arguments but also receives a pointer to the unmodified Expression as a second +/// argument. If no arguments were modified the unmodified Expression* will be nullptr. +template +Result ModifyExpression(Expression expr, const PreVisit& pre, + const PostVisitCall& post_call) { + ARROW_ASSIGN_OR_RAISE(expr, Result(pre(std::move(expr)))); + + auto call = expr.call(); + if (!call) return expr; + + bool at_least_one_modified = false; + std::vector modified_arguments; + + for (size_t i = 0; i < call->arguments.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto modified_argument, + ModifyExpression(call->arguments[i], pre, post_call)); + + if (Expression::Identical(modified_argument, call->arguments[i])) { + continue; + } + + if (!at_least_one_modified) { + modified_arguments = call->arguments; + at_least_one_modified = true; + } + + modified_arguments[i] = std::move(modified_argument); + } + + if (at_least_one_modified) { + // reconstruct the call expression with the modified arguments + auto modified_call = *call; + modified_call.arguments = std::move(modified_arguments); + return post_call(Expression(std::move(modified_call)), &expr); + } + + return post_call(std::move(expr), NULLPTR); +} + +// Helper class to calculate the modified number of rows to process using SIMD. +// +// Some array elements at the end will be skipped in order to avoid buffer +// overrun, when doing memory loads and stores using larger word size than a +// single array element. +// +class TailSkipForSIMD { + public: + static int64_t FixBitAccess(int num_bytes_accessed_together, int64_t num_rows, + int bit_offset) { + int64_t num_bytes = bit_util::BytesForBits(num_rows + bit_offset); + int64_t num_bytes_safe = + std::max(static_cast(0LL), num_bytes - num_bytes_accessed_together + 1); + int64_t num_rows_safe = + std::max(static_cast(0LL), 8 * num_bytes_safe - bit_offset); + return std::min(num_rows_safe, num_rows); + } + static int64_t FixBinaryAccess(int num_bytes_accessed_together, int64_t num_rows, + int64_t length) { + int64_t num_rows_to_skip = bit_util::CeilDiv(length, num_bytes_accessed_together); + int64_t num_rows_safe = + std::max(static_cast(0LL), num_rows - num_rows_to_skip); + return num_rows_safe; + } + static int64_t FixVarBinaryAccess(int num_bytes_accessed_together, int64_t num_rows, + const uint32_t* offsets) { + // Do not process rows that could read past the end of the buffer using N + // byte loads/stores. + // + int64_t num_rows_safe = num_rows; + while (num_rows_safe > 0 && + offsets[num_rows_safe] + num_bytes_accessed_together > offsets[num_rows]) { + --num_rows_safe; + } + return num_rows_safe; + } + static int FixSelection(int64_t num_rows_safe, int num_selected, + const uint16_t* selection) { + int num_selected_safe = num_selected; + while (num_selected_safe > 0 && selection[num_selected_safe - 1] >= num_rows_safe) { + --num_selected_safe; + } + return num_selected_safe; + } +}; + +} // namespace compute +} // namespace arrow diff --git a/pyarrow/include/arrow/compute/visibility.h b/pyarrow/include/arrow/compute/visibility.h new file mode 100644 index 0000000000000000000000000000000000000000..ae994bd233329ff9ab0456191d3f8adf85ee3068 --- /dev/null +++ b/pyarrow/include/arrow/compute/visibility.h @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#if defined(_WIN32) || defined(__CYGWIN__) +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4251) +# else +# pragma GCC diagnostic ignored "-Wattributes" +# endif + +# ifdef ARROW_COMPUTE_STATIC +# define ARROW_COMPUTE_EXPORT +# elif defined(ARROW_COMPUTE_EXPORTING) +# define ARROW_COMPUTE_EXPORT __declspec(dllexport) +# else +# define ARROW_COMPUTE_EXPORT __declspec(dllimport) +# endif + +# define ARROW_COMPUTE_NO_EXPORT + +# if defined(_MSC_VER) +# pragma warning(pop) +# endif + +#else // Not Windows +# ifndef ARROW_COMPUTE_EXPORT +# define ARROW_COMPUTE_EXPORT __attribute__((visibility("default"))) +# endif +# ifndef ARROW_COMPUTE_NO_EXPORT +# define ARROW_COMPUTE_NO_EXPORT __attribute__((visibility("hidden"))) +# endif +#endif diff --git a/pyarrow/include/arrow/config.h b/pyarrow/include/arrow/config.h new file mode 100644 index 0000000000000000000000000000000000000000..617d6c268b55ea344a3fe7f96141ff0f7e4d3f88 --- /dev/null +++ b/pyarrow/include/arrow/config.h @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/status.h" +#include "arrow/util/config.h" // IWYU pragma: export +#include "arrow/util/visibility.h" + +namespace arrow { + +struct BuildInfo { + /// The packed version number, e.g. 1002003 (decimal) for Arrow 1.2.3 + int version; + /// The "major" version number, e.g. 1 for Arrow 1.2.3 + int version_major; + /// The "minor" version number, e.g. 2 for Arrow 1.2.3 + int version_minor; + /// The "patch" version number, e.g. 3 for Arrow 1.2.3 + int version_patch; + /// The version string, e.g. "1.2.3" + std::string version_string; + std::string so_version; + std::string full_so_version; + + /// The CMake compiler identifier, e.g. "GNU" + std::string compiler_id; + std::string compiler_version; + std::string compiler_flags; + + /// The git changeset id, if available + std::string git_id; + /// The git changeset description, if available + std::string git_description; + std::string package_kind; + + /// The uppercase build type, e.g. "DEBUG" or "RELEASE" + std::string build_type; +}; + +struct RuntimeInfo { + /// The enabled SIMD level + /// + /// This can be less than `detected_simd_level` if the ARROW_USER_SIMD_LEVEL + /// environment variable is set to another value. + std::string simd_level; + + /// The SIMD level available on the OS and CPU + std::string detected_simd_level; + + /// Whether using the OS-based timezone database + /// This is set at compile-time. + bool using_os_timezone_db; + + /// The path to the timezone database; by default None. + std::optional timezone_db_path; +}; + +/// \brief Get runtime build info. +/// +/// The returned values correspond to exact loaded version of the Arrow library, +/// rather than the values frozen at application compile-time through the `ARROW_*` +/// preprocessor definitions. +ARROW_EXPORT +const BuildInfo& GetBuildInfo(); + +/// \brief Get runtime info. +/// +ARROW_EXPORT +RuntimeInfo GetRuntimeInfo(); + +struct GlobalOptions { + /// Path to text timezone database. This is only configurable on Windows, + /// which does not have a compatible OS timezone database. + std::optional timezone_db_path; +}; + +ARROW_EXPORT +Status Initialize(const GlobalOptions& options) noexcept; + +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/api.h b/pyarrow/include/arrow/csv/api.h new file mode 100644 index 0000000000000000000000000000000000000000..4af1835cd709d43e0abe3b39b46531cae9a047fc --- /dev/null +++ b/pyarrow/include/arrow/csv/api.h @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/csv/options.h" +#include "arrow/csv/reader.h" +#include "arrow/csv/writer.h" diff --git a/pyarrow/include/arrow/csv/chunker.h b/pyarrow/include/arrow/csv/chunker.h new file mode 100644 index 0000000000000000000000000000000000000000..662b16ec40a9485547ce01b32ea0325a23122711 --- /dev/null +++ b/pyarrow/include/arrow/csv/chunker.h @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/csv/options.h" +#include "arrow/status.h" +#include "arrow/util/delimiting.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace csv { + +ARROW_EXPORT +std::unique_ptr MakeChunker(const ParseOptions& options); + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/column_builder.h b/pyarrow/include/arrow/csv/column_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..9fc4643d9d4c3d4cccff7f9d179e62b5f720ac6f --- /dev/null +++ b/pyarrow/include/arrow/csv/column_builder.h @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace csv { + +class BlockParser; +struct ConvertOptions; + +class ARROW_EXPORT ColumnBuilder { + public: + virtual ~ColumnBuilder() = default; + + /// Spawn a task that will try to convert and append the given CSV block. + /// All calls to Append() should happen on the same thread, otherwise + /// call Insert() instead. + virtual void Append(const std::shared_ptr& parser) = 0; + + /// Spawn a task that will try to convert and insert the given CSV block + virtual void Insert(int64_t block_index, + const std::shared_ptr& parser) = 0; + + /// Return the final chunked array. The TaskGroup _must_ have finished! + virtual Result> Finish() = 0; + + std::shared_ptr task_group() { return task_group_; } + + /// Construct a strictly-typed ColumnBuilder. + static Result> Make( + MemoryPool* pool, const std::shared_ptr& type, int32_t col_index, + std::shared_ptr options, + std::shared_ptr task_group); + + /// Construct a type-inferring ColumnBuilder. + static Result> Make( + MemoryPool* pool, int32_t col_index, std::shared_ptr options, + std::shared_ptr task_group); + + /// Construct a ColumnBuilder for a column of nulls + /// (i.e. not present in the CSV file). + static Result> MakeNull( + MemoryPool* pool, const std::shared_ptr& type, + std::shared_ptr task_group); + + protected: + explicit ColumnBuilder(std::shared_ptr task_group) + : task_group_(std::move(task_group)) {} + + std::shared_ptr task_group_; +}; + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/column_decoder.h b/pyarrow/include/arrow/csv/column_decoder.h new file mode 100644 index 0000000000000000000000000000000000000000..5fbbd5df58b1c588b88e16b68da50b9399211abc --- /dev/null +++ b/pyarrow/include/arrow/csv/column_decoder.h @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace csv { + +class BlockParser; +struct ConvertOptions; + +class ARROW_EXPORT ColumnDecoder { + public: + virtual ~ColumnDecoder() = default; + + /// Spawn a task that will try to convert and insert the given CSV block + virtual Future> Decode( + const std::shared_ptr& parser) = 0; + + /// Construct a strictly-typed ColumnDecoder. + static Result> Make(MemoryPool* pool, + std::shared_ptr type, + int32_t col_index, + const ConvertOptions& options); + + /// Construct a type-inferring ColumnDecoder. + /// Inference will run only on the first block, the type will be frozen afterwards. + static Result> Make(MemoryPool* pool, int32_t col_index, + const ConvertOptions& options); + + /// Construct a ColumnDecoder for a column of nulls + /// (i.e. not present in the CSV file). + static Result> MakeNull(MemoryPool* pool, + std::shared_ptr type); + + protected: + ColumnDecoder() = default; +}; + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/converter.h b/pyarrow/include/arrow/csv/converter.h new file mode 100644 index 0000000000000000000000000000000000000000..639f692f26a1ba3a134caac68a432ac22f068917 --- /dev/null +++ b/pyarrow/include/arrow/csv/converter.h @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/csv/options.h" +#include "arrow/result.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace csv { + +class BlockParser; + +class ARROW_EXPORT Converter { + public: + Converter(const std::shared_ptr& type, const ConvertOptions& options, + MemoryPool* pool); + virtual ~Converter() = default; + + virtual Result> Convert(const BlockParser& parser, + int32_t col_index) = 0; + + std::shared_ptr type() const { return type_; } + + // Create a Converter for the given data type + static Result> Make( + const std::shared_ptr& type, const ConvertOptions& options, + MemoryPool* pool = default_memory_pool()); + + protected: + ARROW_DISALLOW_COPY_AND_ASSIGN(Converter); + + virtual Status Initialize() = 0; + + // CAUTION: ConvertOptions can grow large (if it customizes hundreds or + // thousands of columns), so avoid copying it in each Converter. + const ConvertOptions& options_; + MemoryPool* pool_; + std::shared_ptr type_; +}; + +class ARROW_EXPORT DictionaryConverter : public Converter { + public: + DictionaryConverter(const std::shared_ptr& value_type, + const ConvertOptions& options, MemoryPool* pool); + + // If the dictionary length goes above this value, conversion will fail + // with Status::IndexError. + virtual void SetMaxCardinality(int32_t max_length) = 0; + + // Create a Converter for the given dictionary value type. + // The dictionary index type will always be Int32. + static Result> Make( + const std::shared_ptr& value_type, const ConvertOptions& options, + MemoryPool* pool = default_memory_pool()); + + protected: + std::shared_ptr value_type_; +}; + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/invalid_row.h b/pyarrow/include/arrow/csv/invalid_row.h new file mode 100644 index 0000000000000000000000000000000000000000..4360ceaaea6ac07dd218c93ce13c3ab14c16fc63 --- /dev/null +++ b/pyarrow/include/arrow/csv/invalid_row.h @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +namespace arrow { +namespace csv { + +/// \brief Description of an invalid row +struct InvalidRow { + /// \brief Number of columns expected in the row + int32_t expected_columns; + /// \brief Actual number of columns found in the row + int32_t actual_columns; + /// \brief The physical row number if known or -1 + /// + /// This number is one-based and also accounts for non-data rows (such as + /// CSV header rows). + int64_t number; + /// \brief View of the entire row. Memory will be freed after callback returns + const std::string_view text; +}; + +/// \brief Result returned by an InvalidRowHandler +enum class InvalidRowResult { + // Generate an error describing this row + Error, + // Skip over this row + Skip +}; + +/// \brief callback for handling a row with an invalid number of columns while parsing +/// \return result indicating if an error should be returned from the parser or the row is +/// skipped +using InvalidRowHandler = std::function; + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/options.h b/pyarrow/include/arrow/csv/options.h new file mode 100644 index 0000000000000000000000000000000000000000..10e55bf838c33f00ab520bce7f4e145a7db8819a --- /dev/null +++ b/pyarrow/include/arrow/csv/options.h @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/csv/invalid_row.h" +#include "arrow/csv/type_fwd.h" +#include "arrow/io/interfaces.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class DataType; +class TimestampParser; + +namespace csv { + +// Silly workaround for https://github.com/michaeljones/breathe/issues/453 +constexpr char kDefaultEscapeChar = '\\'; + +struct ARROW_EXPORT ParseOptions { + // Parsing options + + /// Field delimiter + char delimiter = ','; + /// Whether quoting is used + bool quoting = true; + /// Quoting character (if `quoting` is true) + char quote_char = '"'; + /// Whether a quote inside a value is double-quoted + bool double_quote = true; + /// Whether escaping is used + bool escaping = false; + /// Escaping character (if `escaping` is true) + char escape_char = kDefaultEscapeChar; + /// Whether values are allowed to contain CR (0x0d) and LF (0x0a) characters + bool newlines_in_values = false; + /// Whether empty lines are ignored. If false, an empty line represents + /// a single empty value (assuming a one-column CSV file). + bool ignore_empty_lines = true; + /// A handler function for rows which do not have the correct number of columns + InvalidRowHandler invalid_row_handler; + + /// Create parsing options with default values + static ParseOptions Defaults(); + + /// \brief Test that all set options are valid + Status Validate() const; +}; + +struct ARROW_EXPORT ConvertOptions { + // Conversion options + + /// Whether to check UTF8 validity of string columns + bool check_utf8 = true; + /// Optional per-column types (disabling type inference on those columns) + std::unordered_map> column_types; + /// Recognized spellings for null values + std::vector null_values; + /// Recognized spellings for boolean true values + std::vector true_values; + /// Recognized spellings for boolean false values + std::vector false_values; + + /// Whether string / binary columns can have null values. + /// + /// If true, then strings in "null_values" are considered null for string columns. + /// If false, then all strings are valid string values. + bool strings_can_be_null = false; + + /// Whether quoted values can be null. + /// + /// If true, then strings in "null_values" are also considered null when they + /// appear quoted in the CSV file. Otherwise, quoted values are never considered null. + bool quoted_strings_can_be_null = true; + + /// Whether to try to automatically dict-encode string / binary data. + /// If true, then when type inference detects a string or binary column, + /// it is dict-encoded up to `auto_dict_max_cardinality` distinct values + /// (per chunk), after which it switches to regular encoding. + /// + /// This setting is ignored for non-inferred columns (those in `column_types`). + bool auto_dict_encode = false; + int32_t auto_dict_max_cardinality = 50; + + /// Decimal point character for floating-point and decimal data + char decimal_point = '.'; + + // XXX Should we have a separate FilterOptions? + + /// If non-empty, indicates the names of columns from the CSV file that should + /// be actually read and converted (in the vector's order). + /// Columns not in this vector will be ignored. + std::vector include_columns; + /// If false, columns in `include_columns` but not in the CSV file will error out. + /// If true, columns in `include_columns` but not in the CSV file will produce + /// a column of nulls (whose type is selected using `column_types`, + /// or null by default) + /// This option is ignored if `include_columns` is empty. + bool include_missing_columns = false; + + /// User-defined timestamp parsers, using the virtual parser interface in + /// arrow/util/value_parsing.h. More than one parser can be specified, and + /// the CSV conversion logic will try parsing values starting from the + /// beginning of this vector. If no parsers are specified, we use the default + /// built-in ISO-8601 parser. + std::vector> timestamp_parsers; + + /// Create conversion options with default values, including conventional + /// values for `null_values`, `true_values` and `false_values` + static ConvertOptions Defaults(); + + /// \brief Test that all set options are valid + Status Validate() const; +}; + +struct ARROW_EXPORT ReadOptions { + // Reader options + + /// Whether to use the global CPU thread pool + bool use_threads = true; + + /// \brief Block size we request from the IO layer. + /// + /// This will determine multi-threading granularity as well as + /// the size of individual record batches. + /// Minimum valid value for block size is 1 + int32_t block_size = 1 << 20; // 1 MB + + /// Number of header rows to skip (not including the row of column names, if any) + int32_t skip_rows = 0; + + /// Number of rows to skip after the column names are read, if any + int32_t skip_rows_after_names = 0; + + /// Column names for the target table. + /// If empty, fall back on autogenerate_column_names. + std::vector column_names; + + /// Whether to autogenerate column names if `column_names` is empty. + /// If true, column names will be of the form "f0", "f1"... + /// If false, column names will be read from the first CSV row after `skip_rows`. + bool autogenerate_column_names = false; + + /// Create read options with default values + static ReadOptions Defaults(); + + /// \brief Test that all set options are valid + Status Validate() const; +}; + +/// \brief Quoting style for CSV writing +enum class ARROW_EXPORT QuotingStyle { + /// Only enclose values in quotes which need them, because their CSV rendering can + /// contain quotes itself (e.g. strings or binary values) + Needed, + /// Enclose all valid values in quotes. Nulls are not quoted. May cause readers to + /// interpret all values as strings if schema is inferred. + AllValid, + /// Do not enclose any values in quotes. Prevents values from containing quotes ("), + /// cell delimiters (,) or line endings (\\r, \\n), (following RFC4180). If values + /// contain these characters, an error is caused when attempting to write. + None +}; + +struct ARROW_EXPORT WriteOptions { + /// Whether to write an initial header line with column names + bool include_header = true; + + /// \brief Maximum number of rows processed at a time + /// + /// The CSV writer converts and writes data in batches of N rows. + /// This number can impact performance. + int32_t batch_size = 1024; + + /// Field delimiter + char delimiter = ','; + + /// \brief The string to write for null values. Quotes are not allowed in this string. + std::string null_string; + + /// \brief IO context for writing. + io::IOContext io_context; + + /// \brief The end of line character to use for ending rows + std::string eol = "\n"; + + /// \brief Quoting style + QuotingStyle quoting_style = QuotingStyle::Needed; + + /// \brief Quoting style of header + /// + /// Note that `QuotingStyle::Needed` and `QuotingStyle::AllValid` have the same + /// effect of quoting all column names. + QuotingStyle quoting_header = QuotingStyle::Needed; + + /// Create write options with default values + static WriteOptions Defaults(); + + /// \brief Test that all set options are valid + Status Validate() const; +}; + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/parser.h b/pyarrow/include/arrow/csv/parser.h new file mode 100644 index 0000000000000000000000000000000000000000..c73e52ce831ed95b4abe83084b483c15660bae7e --- /dev/null +++ b/pyarrow/include/arrow/csv/parser.h @@ -0,0 +1,228 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/csv/options.h" +#include "arrow/csv/type_fwd.h" +#include "arrow/status.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class MemoryPool; + +namespace csv { + +/// Skip at most num_rows from the given input. The input pointer is updated +/// and the number of actually skipped rows is returns (may be less than +/// requested if the input is too short). +ARROW_EXPORT +int32_t SkipRows(const uint8_t* data, uint32_t size, int32_t num_rows, + const uint8_t** out_data); + +class BlockParserImpl; + +namespace detail { + +struct ParsedValueDesc { + uint32_t offset : 31; + bool quoted : 1; +}; + +class ARROW_EXPORT DataBatch { + public: + explicit DataBatch(int32_t num_cols) : num_cols_(num_cols) {} + + /// \brief Return the number of parsed rows (not skipped) + int32_t num_rows() const { return num_rows_; } + /// \brief Return the number of parsed columns + int32_t num_cols() const { return num_cols_; } + /// \brief Return the total size in bytes of parsed data + uint32_t num_bytes() const { return parsed_size_; } + /// \brief Return the number of skipped rows + int32_t num_skipped_rows() const { return static_cast(skipped_rows_.size()); } + + template + Status VisitColumn(int32_t col_index, int64_t first_row, Visitor&& visit) const { + using detail::ParsedValueDesc; + + int32_t batch_row = 0; + for (size_t buf_index = 0; buf_index < values_buffers_.size(); ++buf_index) { + const auto& values_buffer = values_buffers_[buf_index]; + const auto values = reinterpret_cast(values_buffer->data()); + const auto max_pos = + static_cast(values_buffer->size() / sizeof(ParsedValueDesc)) - 1; + for (int32_t pos = col_index; pos < max_pos; pos += num_cols_, ++batch_row) { + auto start = values[pos].offset; + auto stop = values[pos + 1].offset; + auto quoted = values[pos + 1].quoted; + Status status = visit(parsed_ + start, stop - start, quoted); + if (ARROW_PREDICT_FALSE(!status.ok())) { + return DecorateWithRowNumber(std::move(status), first_row, batch_row); + } + } + } + return Status::OK(); + } + + template + Status VisitLastRow(Visitor&& visit) const { + using detail::ParsedValueDesc; + + const auto& values_buffer = values_buffers_.back(); + const auto values = reinterpret_cast(values_buffer->data()); + const auto start_pos = + static_cast(values_buffer->size() / sizeof(ParsedValueDesc)) - + num_cols_ - 1; + for (int32_t col_index = 0; col_index < num_cols_; ++col_index) { + auto start = values[start_pos + col_index].offset; + auto stop = values[start_pos + col_index + 1].offset; + auto quoted = values[start_pos + col_index + 1].quoted; + ARROW_RETURN_NOT_OK(visit(parsed_ + start, stop - start, quoted)); + } + return Status::OK(); + } + + protected: + Status DecorateWithRowNumber(Status&& status, int64_t first_row, + int32_t batch_row) const { + if (first_row >= 0) { + // `skipped_rows_` is in ascending order by construction, so use bisection + // to find out how many rows were skipped before `batch_row`. + const auto skips_before = + std::upper_bound(skipped_rows_.begin(), skipped_rows_.end(), batch_row) - + skipped_rows_.begin(); + status = status.WithMessage("Row #", batch_row + skips_before + first_row, ": ", + status.message()); + } + // Use return_if so that when extra context is enabled it will be added + ARROW_RETURN_IF_(true, std::move(status), ARROW_STRINGIFY(status)); + return std::move(status); + } + + // The number of rows in this batch (not including any skipped ones) + int32_t num_rows_ = 0; + // The number of columns + int32_t num_cols_ = 0; + + // XXX should we ensure the parsed buffer is padded with 8 or 16 excess zero bytes? + // It may help with null parsing... + std::vector> values_buffers_; + std::shared_ptr parsed_buffer_; + const uint8_t* parsed_ = NULLPTR; + int32_t parsed_size_ = 0; + + // Record the current num_rows_ each time a row is skipped + std::vector skipped_rows_; + + friend class ::arrow::csv::BlockParserImpl; +}; + +} // namespace detail + +constexpr int32_t kMaxParserNumRows = 100000; + +/// \class BlockParser +/// \brief A reusable block-based parser for CSV data +/// +/// The parser takes a block of CSV data and delimits rows and fields, +/// unquoting and unescaping them on the fly. Parsed data is own by the +/// parser, so the original buffer can be discarded after Parse() returns. +/// +/// If the block is truncated (i.e. not all data can be parsed), it is up +/// to the caller to arrange the next block to start with the trailing data. +/// Also, if the previous block ends with CR (0x0d) and a new block starts +/// with LF (0x0a), the parser will consider the leading newline as an empty +/// line; the caller should therefore strip it. +class ARROW_EXPORT BlockParser { + public: + explicit BlockParser(ParseOptions options, int32_t num_cols = -1, + int64_t first_row = -1, int32_t max_num_rows = kMaxParserNumRows); + explicit BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols = -1, + int64_t first_row = -1, int32_t max_num_rows = kMaxParserNumRows); + ~BlockParser(); + + /// \brief Parse a block of data + /// + /// Parse a block of CSV data, ingesting up to max_num_rows rows. + /// The number of bytes actually parsed is returned in out_size. + Status Parse(std::string_view data, uint32_t* out_size); + + /// \brief Parse sequential blocks of data + /// + /// Only the last block is allowed to be truncated. + Status Parse(const std::vector& data, uint32_t* out_size); + + /// \brief Parse the final block of data + /// + /// Like Parse(), but called with the final block in a file. + /// The last row may lack a trailing line separator. + Status ParseFinal(std::string_view data, uint32_t* out_size); + + /// \brief Parse the final sequential blocks of data + /// + /// Only the last block is allowed to be truncated. + Status ParseFinal(const std::vector& data, uint32_t* out_size); + + /// \brief Return the number of parsed rows + int32_t num_rows() const { return parsed_batch().num_rows(); } + /// \brief Return the number of parsed columns + int32_t num_cols() const { return parsed_batch().num_cols(); } + /// \brief Return the total size in bytes of parsed data + uint32_t num_bytes() const { return parsed_batch().num_bytes(); } + + /// \brief Return the total number of rows including rows which were skipped + int32_t total_num_rows() const { + return parsed_batch().num_rows() + parsed_batch().num_skipped_rows(); + } + + /// \brief Return the row number of the first row in the block or -1 if unsupported + int64_t first_row_num() const; + + /// \brief Visit parsed values in a column + /// + /// The signature of the visitor is + /// Status(const uint8_t* data, uint32_t size, bool quoted) + template + Status VisitColumn(int32_t col_index, Visitor&& visit) const { + return parsed_batch().VisitColumn(col_index, first_row_num(), + std::forward(visit)); + } + + template + Status VisitLastRow(Visitor&& visit) const { + return parsed_batch().VisitLastRow(std::forward(visit)); + } + + protected: + std::unique_ptr impl_; + + const detail::DataBatch& parsed_batch() const; +}; + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/reader.h b/pyarrow/include/arrow/csv/reader.h new file mode 100644 index 0000000000000000000000000000000000000000..bae301dc14815a6fdf9388a08c4f9068155f20a6 --- /dev/null +++ b/pyarrow/include/arrow/csv/reader.h @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/csv/options.h" // IWYU pragma: keep +#include "arrow/io/interfaces.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" +#include "arrow/util/future.h" +#include "arrow/util/thread_pool.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace io { +class InputStream; +} // namespace io + +namespace csv { + +/// A class that reads an entire CSV file into a Arrow Table +class ARROW_EXPORT TableReader { + public: + virtual ~TableReader() = default; + + /// Read the entire CSV file and convert it to a Arrow Table + virtual Result> Read() = 0; + /// Read the entire CSV file and convert it to a Arrow Table + virtual Future> ReadAsync() = 0; + + /// Create a TableReader instance + static Result> Make(io::IOContext io_context, + std::shared_ptr input, + const ReadOptions&, + const ParseOptions&, + const ConvertOptions&); +}; + +/// \brief A class that reads a CSV file incrementally +/// +/// Caveats: +/// - For now, this is always single-threaded (regardless of `ReadOptions::use_threads`. +/// - Type inference is done on the first block and types are frozen afterwards; +/// to make sure the right data types are inferred, either set +/// `ReadOptions::block_size` to a large enough value, or use +/// `ConvertOptions::column_types` to set the desired data types explicitly. +class ARROW_EXPORT StreamingReader : public RecordBatchReader { + public: + virtual ~StreamingReader() = default; + + virtual Future> ReadNextAsync() = 0; + + /// \brief Return the number of bytes which have been read and processed + /// + /// The returned number includes CSV bytes which the StreamingReader has + /// finished processing, but not bytes for which some processing (e.g. + /// CSV parsing or conversion to Arrow layout) is still ongoing. + /// + /// Furthermore, the following rules apply: + /// - bytes skipped by `ReadOptions.skip_rows` are counted as being read before + /// any records are returned. + /// - bytes read while parsing the header are counted as being read before any + /// records are returned. + /// - bytes skipped by `ReadOptions.skip_rows_after_names` are counted after the + /// first batch is returned. + virtual int64_t bytes_read() const = 0; + + /// Create a StreamingReader instance + /// + /// This involves some I/O as the first batch must be loaded during the creation process + /// so it is returned as a future + /// + /// Currently, the StreamingReader is not async-reentrant and does not do any fan-out + /// parsing (see ARROW-11889) + static Future> MakeAsync( + io::IOContext io_context, std::shared_ptr input, + arrow::internal::Executor* cpu_executor, const ReadOptions&, const ParseOptions&, + const ConvertOptions&); + + static Result> Make( + io::IOContext io_context, std::shared_ptr input, + const ReadOptions&, const ParseOptions&, const ConvertOptions&); +}; + +/// \brief Count the logical rows of data in a CSV file (i.e. the +/// number of rows you would get if you read the file into a table). +ARROW_EXPORT +Future CountRowsAsync(io::IOContext io_context, + std::shared_ptr input, + arrow::internal::Executor* cpu_executor, + const ReadOptions&, const ParseOptions&); + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/test_common.h b/pyarrow/include/arrow/csv/test_common.h new file mode 100644 index 0000000000000000000000000000000000000000..07a41604478e81ac760e8d0b3501ef24996b0a4e --- /dev/null +++ b/pyarrow/include/arrow/csv/test_common.h @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/csv/parser.h" +#include "arrow/testing/visibility.h" + +namespace arrow { +namespace csv { + +ARROW_TESTING_EXPORT +std::string MakeCSVData(std::vector lines); + +// Make a BlockParser from a vector of lines representing a CSV file +ARROW_TESTING_EXPORT +void MakeCSVParser(std::vector lines, ParseOptions options, int32_t num_cols, + MemoryPool* pool, std::shared_ptr* out); + +ARROW_TESTING_EXPORT +void MakeCSVParser(std::vector lines, ParseOptions options, + std::shared_ptr* out); + +ARROW_TESTING_EXPORT +void MakeCSVParser(std::vector lines, std::shared_ptr* out); + +// Make a BlockParser from a vector of strings representing a single CSV column +ARROW_TESTING_EXPORT +void MakeColumnParser(std::vector items, std::shared_ptr* out); + +ARROW_TESTING_EXPORT +Result> MakeSampleCsvBuffer( + size_t num_rows, std::function is_valid = {}); + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/type_fwd.h b/pyarrow/include/arrow/csv/type_fwd.h new file mode 100644 index 0000000000000000000000000000000000000000..c0a53847a90ddb82067e0c9ac955cf4222c61742 --- /dev/null +++ b/pyarrow/include/arrow/csv/type_fwd.h @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace arrow { +namespace csv { + +class TableReader; +struct ConvertOptions; +struct ReadOptions; +struct ParseOptions; +struct WriteOptions; + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/csv/writer.h b/pyarrow/include/arrow/csv/writer.h new file mode 100644 index 0000000000000000000000000000000000000000..d9d79e16608671859357e3adab88416fb0a9d04f --- /dev/null +++ b/pyarrow/include/arrow/csv/writer.h @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/csv/options.h" +#include "arrow/io/interfaces.h" +#include "arrow/ipc/type_fwd.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" + +namespace arrow { +namespace csv { + +// Functionality for converting Arrow data to Comma separated value text. +// This library supports all primitive types that can be cast to a StringArray or +// a LargeStringArray. +// It applies to following formatting rules: +// - For non-binary types no quotes surround values. Nulls are represented as the empty +// string. +// - For binary types all non-null data is quoted (and quotes within data are escaped +// with an additional quote). +// Null values are empty and unquoted. + +/// \defgroup csv-write-functions High-level functions for writing CSV files +/// @{ + +/// \brief Convert table to CSV and write the result to output. +/// Experimental +ARROW_EXPORT Status WriteCSV(const Table& table, const WriteOptions& options, + arrow::io::OutputStream* output); +/// \brief Convert batch to CSV and write the result to output. +/// Experimental +ARROW_EXPORT Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, + arrow::io::OutputStream* output); +/// \brief Convert batches read through a RecordBatchReader +/// to CSV and write the results to output. +/// Experimental +ARROW_EXPORT Status WriteCSV(const std::shared_ptr& reader, + const WriteOptions& options, + arrow::io::OutputStream* output); + +/// @} + +/// \defgroup csv-writer-factories Functions for creating an incremental CSV writer +/// @{ + +/// \brief Create a new CSV writer. User is responsible for closing the +/// actual OutputStream. +/// +/// \param[in] sink output stream to write to +/// \param[in] schema the schema of the record batches to be written +/// \param[in] options options for serialization +/// \return Result> +ARROW_EXPORT +Result> MakeCSVWriter( + std::shared_ptr sink, const std::shared_ptr& schema, + const WriteOptions& options = WriteOptions::Defaults()); + +/// \brief Create a new CSV writer. +/// +/// \param[in] sink output stream to write to (does not take ownership) +/// \param[in] schema the schema of the record batches to be written +/// \param[in] options options for serialization +/// \return Result> +ARROW_EXPORT +Result> MakeCSVWriter( + io::OutputStream* sink, const std::shared_ptr& schema, + const WriteOptions& options = WriteOptions::Defaults()); + +/// @} + +} // namespace csv +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/api.h b/pyarrow/include/arrow/dataset/api.h new file mode 100644 index 0000000000000000000000000000000000000000..38caa1cff19def66d09d0d6ed25c67ce52259f9a --- /dev/null +++ b/pyarrow/include/arrow/dataset/api.h @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include "arrow/compute/expression.h" +#include "arrow/dataset/dataset.h" +#include "arrow/dataset/discovery.h" +#include "arrow/dataset/file_base.h" +#ifdef ARROW_CSV +# include "arrow/dataset/file_csv.h" +#endif +#ifdef ARROW_JSON +# include "arrow/dataset/file_json.h" +#endif +#include "arrow/dataset/file_ipc.h" +#ifdef ARROW_ORC +# include "arrow/dataset/file_orc.h" +#endif +#ifdef ARROW_PARQUET +# include "arrow/dataset/file_parquet.h" +#endif +#include "arrow/dataset/scanner.h" diff --git a/pyarrow/include/arrow/dataset/dataset.h b/pyarrow/include/arrow/dataset/dataset.h new file mode 100644 index 0000000000000000000000000000000000000000..5c788ef5581c62fe1fc145b289ca74929bd54606 --- /dev/null +++ b/pyarrow/include/arrow/dataset/dataset.h @@ -0,0 +1,491 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/compute/expression.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/util/async_generator_fwd.h" +#include "arrow/util/future.h" +#include "arrow/util/macros.h" +#include "arrow/util/mutex.h" + +namespace arrow { + +namespace internal { +class Executor; +} // namespace internal + +namespace dataset { + +using RecordBatchGenerator = std::function>()>; + +/// \brief Description of a column to scan +struct ARROW_DS_EXPORT FragmentSelectionColumn { + /// \brief The path to the column to load + FieldPath path; + /// \brief The type of the column in the dataset schema + /// + /// A format may choose to ignore this field completely. For example, when + /// reading from IPC the reader can just return the column in the data type + /// that is stored on disk. There is no point in doing anything special. + /// + /// However, some formats may be capable of casting on the fly. For example, + /// when reading from CSV, if we know the target type of the column, we can + /// convert from string to the target type as we read. + DataType* requested_type; +}; + +/// \brief A list of columns that should be loaded from a fragment +/// +/// The paths in this selection should be referring to the fragment schema. This class +/// contains a virtual destructor as it is expected evolution strategies will need to +/// extend this to add any information needed to later evolve the batches. +/// +/// For example, in the basic evolution strategy, we keep track of which columns +/// were missing from the file so that we can fill those in with null when evolving. +class ARROW_DS_EXPORT FragmentSelection { + public: + explicit FragmentSelection(std::vector columns) + : columns_(std::move(columns)) {} + virtual ~FragmentSelection() = default; + /// The columns that should be loaded from the fragment + const std::vector& columns() const { return columns_; } + + private: + std::vector columns_; +}; + +/// \brief Instructions for scanning a particular fragment +/// +/// The fragment scan request is derived from ScanV2Options. The main +/// difference is that the scan options are based on the dataset schema +/// while the fragment request is based on the fragment schema. +struct ARROW_DS_EXPORT FragmentScanRequest { + /// \brief A row filter + /// + /// The filter expression should be written against the fragment schema. + /// + /// \see ScanV2Options for details on how this filter should be applied + compute::Expression filter = compute::literal(true); + + /// \brief The columns to scan + /// + /// These indices refer to the fragment schema + /// + /// Note: This is NOT a simple list of top-level column indices. + /// For more details \see ScanV2Options + /// + /// If possible a fragment should only read from disk the data needed + /// to satisfy these columns. If a format cannot partially read a nested + /// column (e.g. JSON) then it must apply the column selection (in memory) + /// before returning the scanned batch. + std::shared_ptr fragment_selection; + /// \brief Options specific to the format being scanned + const FragmentScanOptions* format_scan_options; +}; + +/// \brief An iterator-like object that can yield batches created from a fragment +class ARROW_DS_EXPORT FragmentScanner { + public: + /// This instance will only be destroyed after all ongoing scan futures + /// have been completed. + /// + /// This means any callbacks created as part of the scan can safely + /// capture `this` + virtual ~FragmentScanner() = default; + /// \brief Scan a batch of data from the file + /// \param batch_number The index of the batch to read + virtual Future> ScanBatch(int batch_number) = 0; + /// \brief Calculate an estimate of how many data bytes the given batch will represent + /// + /// "Data bytes" should be the total size of all the buffers once the data has been + /// decoded into the Arrow format. + virtual int64_t EstimatedDataBytes(int batch_number) = 0; + /// \brief The number of batches in the fragment to scan + virtual int NumBatches() = 0; +}; + +/// \brief Information learned about a fragment through inspection +/// +/// This information can be used to figure out which fields need +/// to be read from a file and how the data read in should be evolved +/// to match the dataset schema. +/// +/// For example, from a CSV file we can inspect and learn the column +/// names and use those column names to determine which columns to load +/// from the CSV file. +struct ARROW_DS_EXPORT InspectedFragment { + explicit InspectedFragment(std::vector column_names) + : column_names(std::move(column_names)) {} + std::vector column_names; +}; + +/// \brief A granular piece of a Dataset, such as an individual file. +/// +/// A Fragment can be read/scanned separately from other fragments. It yields a +/// collection of RecordBatches when scanned +/// +/// Note that Fragments have well defined physical schemas which are reconciled by +/// the Datasets which contain them; these physical schemas may differ from a parent +/// Dataset's schema and the physical schemas of sibling Fragments. +class ARROW_DS_EXPORT Fragment : public std::enable_shared_from_this { + public: + /// \brief An expression that represents no known partition information + static const compute::Expression kNoPartitionInformation; + + /// \brief Return the physical schema of the Fragment. + /// + /// The physical schema is also called the writer schema. + /// This method is blocking and may suffer from high latency filesystem. + /// The schema is cached after being read once, or may be specified at construction. + Result> ReadPhysicalSchema(); + + /// An asynchronous version of Scan + virtual Result ScanBatchesAsync( + const std::shared_ptr& options) = 0; + + /// \brief Inspect a fragment to learn basic information + /// + /// This will be called before a scan and a fragment should attach whatever + /// information will be needed to figure out an evolution strategy. This information + /// will then be passed to the call to BeginScan + virtual Future> InspectFragment( + const FragmentScanOptions* format_options, compute::ExecContext* exec_context); + + /// \brief Start a scan operation + virtual Future> BeginScan( + const FragmentScanRequest& request, const InspectedFragment& inspected_fragment, + const FragmentScanOptions* format_options, compute::ExecContext* exec_context); + + /// \brief Count the number of rows in this fragment matching the filter using metadata + /// only. That is, this method may perform I/O, but will not load data. + /// + /// If this is not possible, resolve with an empty optional. The fragment can perform + /// I/O (e.g. to read metadata) before it deciding whether it can satisfy the request. + virtual Future> CountRows( + compute::Expression predicate, const std::shared_ptr& options); + + /// \brief Clear any metadata that may have been cached by this object. + /// + /// A fragment may typically cache metadata to speed up repeated accesses. + /// In use cases when memory use is more critical than CPU time, calling + /// this function can help reclaim memory. + virtual Status ClearCachedMetadata(); + + virtual std::string type_name() const = 0; + virtual std::string ToString() const { return type_name(); } + + /// \brief An expression which evaluates to true for all data viewed by this + /// Fragment. + const compute::Expression& partition_expression() const { + return partition_expression_; + } + + virtual ~Fragment() = default; + + protected: + Fragment() = default; + explicit Fragment(compute::Expression partition_expression, + std::shared_ptr physical_schema); + + virtual Result> ReadPhysicalSchemaImpl() = 0; + + util::Mutex physical_schema_mutex_; + compute::Expression partition_expression_ = compute::literal(true); + // The physical schema that is inferred from the Fragment + std::shared_ptr physical_schema_; + // The physical schema that was passed to the Fragment constructor + std::shared_ptr given_physical_schema_; +}; + +/// \brief Per-scan options for fragment(s) in a dataset. +/// +/// These options are not intrinsic to the format or fragment itself, but do affect +/// the results of a scan. These are options which make sense to change between +/// repeated reads of the same dataset, such as format-specific conversion options +/// (that do not affect the schema). +/// +/// \ingroup dataset-scanning +class ARROW_DS_EXPORT FragmentScanOptions { + public: + virtual std::string type_name() const = 0; + virtual std::string ToString() const { return type_name(); } + virtual ~FragmentScanOptions() = default; +}; + +/// \defgroup dataset-implementations Concrete implementations +/// +/// @{ + +/// \brief A trivial Fragment that yields ScanTask out of a fixed set of +/// RecordBatch. +class ARROW_DS_EXPORT InMemoryFragment : public Fragment { + public: + class Scanner; + InMemoryFragment(std::shared_ptr schema, RecordBatchVector record_batches, + compute::Expression = compute::literal(true)); + explicit InMemoryFragment(RecordBatchVector record_batches, + compute::Expression = compute::literal(true)); + + Result ScanBatchesAsync( + const std::shared_ptr& options) override; + Future> CountRows( + compute::Expression predicate, + const std::shared_ptr& options) override; + + Future> InspectFragment( + const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) override; + Future> BeginScan( + const FragmentScanRequest& request, const InspectedFragment& inspected_fragment, + const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) override; + + std::string type_name() const override { return "in-memory"; } + + protected: + Result> ReadPhysicalSchemaImpl() override; + + RecordBatchVector record_batches_; +}; + +/// @} + +using FragmentGenerator = AsyncGenerator>; + +/// \brief Rules for converting the dataset schema to and from fragment schemas +class ARROW_DS_EXPORT FragmentEvolutionStrategy { + public: + /// This instance will only be destroyed when all scan operations for the + /// fragment have completed. + virtual ~FragmentEvolutionStrategy() = default; + /// \brief A guarantee that applies to all batches of this fragment + /// + /// For example, if a fragment is missing one of the fields in the dataset + /// schema then a typical evolution strategy is to set that field to null. + /// + /// So if the column at index 3 is missing then the guarantee is + /// FieldRef(3) == null + /// + /// Individual field guarantees should be AND'd together and returned + /// as a single expression. + virtual Result GetGuarantee( + const std::vector& dataset_schema_selection) const = 0; + + /// \brief Return a fragment schema selection given a dataset schema selection + /// + /// For example, if the user wants fields 2 & 4 of the dataset schema and + /// in this fragment the field 2 is missing and the field 4 is at index 1 then + /// this should return {1} + virtual Result> DevolveSelection( + const std::vector& dataset_schema_selection) const = 0; + + /// \brief Return a filter expression bound to the fragment schema given + /// a filter expression bound to the dataset schema + /// + /// The dataset scan filter will first be simplified by the guarantee returned + /// by GetGuarantee. This means an evolution that only handles dropping or casting + /// fields doesn't need to do anything here except return the given filter. + /// + /// On the other hand, an evolution that is doing some kind of aliasing will likely + /// need to convert field references in the filter to the aliased field references + /// where appropriate. + virtual Result DevolveFilter( + const compute::Expression& filter) const = 0; + + /// \brief Convert a batch from the fragment schema to the dataset schema + /// + /// Typically this involves casting columns from the data type stored on disk + /// to the data type of the dataset schema. For example, this fragment might + /// have columns stored as int32 and the dataset schema might have int64 for + /// the column. In this case we should cast the column from int32 to int64. + /// + /// Note: A fragment may perform this cast as the data is read from disk. In + /// that case a cast might not be needed. + virtual Result EvolveBatch( + const std::shared_ptr& batch, + const std::vector& dataset_selection, + const FragmentSelection& selection) const = 0; + + /// \brief Return a string description of this strategy + virtual std::string ToString() const = 0; +}; + +/// \brief Lookup to create a FragmentEvolutionStrategy for a given fragment +class ARROW_DS_EXPORT DatasetEvolutionStrategy { + public: + virtual ~DatasetEvolutionStrategy() = default; + /// \brief Create a strategy for evolving from the given fragment + /// to the schema of the given dataset + virtual std::unique_ptr GetStrategy( + const Dataset& dataset, const Fragment& fragment, + const InspectedFragment& inspected_fragment) = 0; + + /// \brief Return a string description of this strategy + virtual std::string ToString() const = 0; +}; + +ARROW_DS_EXPORT std::unique_ptr +MakeBasicDatasetEvolutionStrategy(); + +/// \brief A container of zero or more Fragments. +/// +/// A Dataset acts as a union of Fragments, e.g. files deeply nested in a +/// directory. A Dataset has a schema to which Fragments must align during a +/// scan operation. This is analogous to Avro's reader and writer schema. +class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { + public: + /// \brief Begin to build a new Scan operation against this Dataset + Result> NewScan(); + + /// \brief GetFragments returns an iterator of Fragments given a predicate. + Result GetFragments(compute::Expression predicate); + Result GetFragments(); + + /// \brief Async versions of `GetFragments`. + Result GetFragmentsAsync(compute::Expression predicate); + Result GetFragmentsAsync(); + + const std::shared_ptr& schema() const { return schema_; } + + /// \brief An expression which evaluates to true for all data viewed by this Dataset. + /// May be null, which indicates no information is available. + const compute::Expression& partition_expression() const { + return partition_expression_; + } + + /// \brief The name identifying the kind of Dataset + virtual std::string type_name() const = 0; + + /// \brief Return a copy of this Dataset with a different schema. + /// + /// The copy will view the same Fragments. If the new schema is not compatible with the + /// original dataset's schema then an error will be raised. + virtual Result> ReplaceSchema( + std::shared_ptr schema) const = 0; + + /// \brief Rules used by this dataset to handle schema evolution + DatasetEvolutionStrategy* evolution_strategy() { return evolution_strategy_.get(); } + + virtual ~Dataset() = default; + + protected: + explicit Dataset(std::shared_ptr schema) : schema_(std::move(schema)) {} + + Dataset(std::shared_ptr schema, compute::Expression partition_expression); + + virtual Result GetFragmentsImpl(compute::Expression predicate) = 0; + /// \brief Default non-virtual implementation method for the base + /// `GetFragmentsAsyncImpl` method, which creates a fragment generator for + /// the dataset, possibly filtering results with a predicate (forwarding to + /// the synchronous `GetFragmentsImpl` method and moving the computations + /// to the background, using the IO thread pool). + /// + /// Currently, `executor` is always the same as `internal::GetCPUThreadPool()`, + /// which means the results from the underlying fragment generator will be + /// transferred to the default CPU thread pool. The generator itself is + /// offloaded to run on the default IO thread pool. + virtual Result GetFragmentsAsyncImpl( + compute::Expression predicate, arrow::internal::Executor* executor); + + std::shared_ptr schema_; + compute::Expression partition_expression_ = compute::literal(true); + std::unique_ptr evolution_strategy_ = + MakeBasicDatasetEvolutionStrategy(); +}; + +/// \addtogroup dataset-implementations +/// +/// @{ + +/// \brief A Source which yields fragments wrapping a stream of record batches. +/// +/// The record batches must match the schema provided to the source at construction. +class ARROW_DS_EXPORT InMemoryDataset : public Dataset { + public: + class RecordBatchGenerator { + public: + virtual ~RecordBatchGenerator() = default; + virtual RecordBatchIterator Get() const = 0; + }; + + /// Construct a dataset from a schema and a factory of record batch iterators. + InMemoryDataset(std::shared_ptr schema, + std::shared_ptr get_batches) + : Dataset(std::move(schema)), get_batches_(std::move(get_batches)) {} + + /// Convenience constructor taking a fixed list of batches + InMemoryDataset(std::shared_ptr schema, RecordBatchVector batches); + + /// Convenience constructor taking a Table + explicit InMemoryDataset(std::shared_ptr
table); + + std::string type_name() const override { return "in-memory"; } + + Result> ReplaceSchema( + std::shared_ptr schema) const override; + + protected: + Result GetFragmentsImpl(compute::Expression predicate) override; + + std::shared_ptr get_batches_; +}; + +/// \brief A Dataset wrapping child Datasets. +class ARROW_DS_EXPORT UnionDataset : public Dataset { + public: + /// \brief Construct a UnionDataset wrapping child Datasets. + /// + /// \param[in] schema the schema of the resulting dataset. + /// \param[in] children one or more child Datasets. Their schemas must be identical to + /// schema. + static Result> Make(std::shared_ptr schema, + DatasetVector children); + + const DatasetVector& children() const { return children_; } + + std::string type_name() const override { return "union"; } + + Result> ReplaceSchema( + std::shared_ptr schema) const override; + + protected: + Result GetFragmentsImpl(compute::Expression predicate) override; + + explicit UnionDataset(std::shared_ptr schema, DatasetVector children) + : Dataset(std::move(schema)), children_(std::move(children)) {} + + DatasetVector children_; + + friend class UnionDatasetFactory; +}; + +/// @} + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/dataset_writer.h b/pyarrow/include/arrow/dataset/dataset_writer.h new file mode 100644 index 0000000000000000000000000000000000000000..edb1649b5f196aa3c6cd923c9e6540c4173fc102 --- /dev/null +++ b/pyarrow/include/arrow/dataset/dataset_writer.h @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/dataset/file_base.h" +#include "arrow/record_batch.h" +#include "arrow/status.h" +#include "arrow/util/async_util.h" +#include "arrow/util/future.h" + +namespace arrow { +namespace dataset { +namespace internal { + +// This lines up with our other defaults in the scanner and execution plan +constexpr uint64_t kDefaultDatasetWriterMaxRowsQueued = 8 * 1024 * 1024; + +/// \brief Utility class that manages a set of writers to different paths +/// +/// Writers may be closed and reopened (and a new file created) based on the dataset +/// write options (for example, max_rows_per_file or max_open_files) +/// +/// The dataset writer enforces its own back pressure based on the # of rows (as opposed +/// to # of batches which is how it is typically enforced elsewhere) and # of files. +class ARROW_DS_EXPORT DatasetWriter { + public: + /// \brief Create a dataset writer + /// + /// Will fail if basename_template is invalid or if there is existing data and + /// existing_data_behavior is kError + /// + /// \param write_options options to control how the data should be written + /// \param max_rows_queued max # of rows allowed to be queued before the dataset_writer + /// will ask for backpressure + static Result> Make( + FileSystemDatasetWriteOptions write_options, util::AsyncTaskScheduler* scheduler, + std::function pause_callback, std::function resume_callback, + std::function finish_callback, + uint64_t max_rows_queued = kDefaultDatasetWriterMaxRowsQueued); + + ~DatasetWriter(); + + /// \brief Write a batch to the dataset + /// \param[in] batch The batch to write + /// \param[in] directory The directory to write to + /// + /// Note: The written filename will be {directory}/{filename_factory(i)} where i is a + /// counter controlled by `max_open_files` and `max_rows_per_file` + /// + /// If multiple WriteRecordBatch calls arrive with the same `directory` then the batches + /// may be written to the same file. + /// + /// The returned future will be marked finished when the record batch has been queued + /// to be written. If the returned future is unfinished then this indicates the dataset + /// writer's queue is full and the data provider should pause. + /// + /// This method is NOT async reentrant. The returned future will only be unfinished + /// if back pressure needs to be applied. Async reentrancy is not necessary for + /// concurrent writes to happen. Calling this method again before the previous future + /// completes will not just violate max_rows_queued but likely lead to race conditions. + /// + /// One thing to note is that the ordering of your data can affect your maximum + /// potential parallelism. If this seems odd then consider a dataset where the first + /// 1000 batches go to the same directory and then the 1001st batch goes to a different + /// directory. The only way to get two parallel writes immediately would be to queue + /// all 1000 pending writes to the first directory. + void WriteRecordBatch(std::shared_ptr batch, const std::string& directory, + const std::string& prefix = ""); + + /// Finish all pending writes and close any open files + void Finish(); + + protected: + DatasetWriter(FileSystemDatasetWriteOptions write_options, + util::AsyncTaskScheduler* scheduler, std::function pause_callback, + std::function resume_callback, + std::function finish_callback, + uint64_t max_rows_queued = kDefaultDatasetWriterMaxRowsQueued); + + class DatasetWriterImpl; + std::unique_ptr impl_; +}; + +} // namespace internal +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/discovery.h b/pyarrow/include/arrow/dataset/discovery.h new file mode 100644 index 0000000000000000000000000000000000000000..6d76dcef727e7643ba559d8802665755a4f8a870 --- /dev/null +++ b/pyarrow/include/arrow/dataset/discovery.h @@ -0,0 +1,275 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Logic for automatically determining the structure of multi-file +/// dataset with possible partitioning according to available +/// partitioning + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/dataset/partition.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/filesystem/type_fwd.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" + +namespace arrow { +namespace dataset { + +/// \defgroup dataset-discovery Discovery API +/// +/// @{ + +struct InspectOptions { + /// See `fragments` property. + static constexpr int kInspectAllFragments = -1; + + /// Indicate how many fragments should be inspected to infer the unified dataset + /// schema. Limiting the number of fragments accessed improves the latency of + /// the discovery process when dealing with a high number of fragments and/or + /// high latency file systems. + /// + /// The default value of `1` inspects the schema of the first (in no particular + /// order) fragment only. If the dataset has a uniform schema for all fragments, + /// this default is the optimal value. In order to inspect all fragments and + /// robustly unify their potentially varying schemas, set this option to + /// `kInspectAllFragments`. A value of `0` disables inspection of fragments + /// altogether so only the partitioning schema will be inspected. + int fragments = 1; + + /// Control how to unify types. By default, types are merged strictly (the + /// type must match exactly, except nulls can be merged with other types). + Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults(); +}; + +struct FinishOptions { + /// Finalize the dataset with this given schema. If the schema is not + /// provided, infer the schema via the Inspect, see the `inspect_options` + /// property. + std::shared_ptr schema = NULLPTR; + + /// If the schema is not provided, it will be discovered by passing the + /// following options to `DatasetDiscovery::Inspect`. + InspectOptions inspect_options{}; + + /// Indicate if the given Schema (when specified), should be validated against + /// the fragments' schemas. `inspect_options` will control how many fragments + /// are checked. + bool validate_fragments = false; +}; + +/// \brief DatasetFactory provides a way to inspect/discover a Dataset's expected +/// schema before materializing said Dataset. +class ARROW_DS_EXPORT DatasetFactory { + public: + /// \brief Get the schemas of the Fragments and Partitioning. + virtual Result>> InspectSchemas( + InspectOptions options) = 0; + + /// \brief Get unified schema for the resulting Dataset. + Result> Inspect(InspectOptions options = {}); + + /// \brief Create a Dataset + Result> Finish(); + /// \brief Create a Dataset with the given schema (see \a InspectOptions::schema) + Result> Finish(std::shared_ptr schema); + /// \brief Create a Dataset with the given options + virtual Result> Finish(FinishOptions options) = 0; + + /// \brief Optional root partition for the resulting Dataset. + const compute::Expression& root_partition() const { return root_partition_; } + /// \brief Set the root partition for the resulting Dataset. + Status SetRootPartition(compute::Expression partition) { + root_partition_ = std::move(partition); + return Status::OK(); + } + + virtual ~DatasetFactory() = default; + + protected: + DatasetFactory(); + + compute::Expression root_partition_; +}; + +/// @} + +/// \brief DatasetFactory provides a way to inspect/discover a Dataset's +/// expected schema before materialization. +/// \ingroup dataset-implementations +class ARROW_DS_EXPORT UnionDatasetFactory : public DatasetFactory { + public: + static Result> Make( + std::vector> factories); + + /// \brief Return the list of child DatasetFactory + const std::vector>& factories() const { + return factories_; + } + + /// \brief Get the schemas of the Datasets. + /// + /// Instead of applying options globally, it applies at each child factory. + /// This will not respect `options.fragments` exactly, but will respect the + /// spirit of peeking the first fragments or all of them. + Result>> InspectSchemas( + InspectOptions options) override; + + /// \brief Create a Dataset. + Result> Finish(FinishOptions options) override; + + protected: + explicit UnionDatasetFactory(std::vector> factories); + + std::vector> factories_; +}; + +/// \ingroup dataset-filesystem +struct FileSystemFactoryOptions { + /// Either an explicit Partitioning or a PartitioningFactory to discover one. + /// + /// If a factory is provided, it will be used to infer a schema for partition fields + /// based on file and directory paths then construct a Partitioning. The default + /// is a Partitioning which will yield no partition information. + /// + /// The (explicit or discovered) partitioning will be applied to discovered files + /// and the resulting partition information embedded in the Dataset. + PartitioningOrFactory partitioning{Partitioning::Default()}; + + /// For the purposes of applying the partitioning, paths will be stripped + /// of the partition_base_dir. Files not matching the partition_base_dir + /// prefix will be skipped for partition discovery. The ignored files will still + /// be part of the Dataset, but will not have partition information. + /// + /// Example: + /// partition_base_dir = "/dataset"; + /// + /// - "/dataset/US/sales.csv" -> "US/sales.csv" will be given to the partitioning + /// + /// - "/home/john/late_sales.csv" -> Will be ignored for partition discovery. + /// + /// This is useful for partitioning which parses directory when ordering + /// is important, e.g. DirectoryPartitioning. + std::string partition_base_dir; + + /// Invalid files (via selector or explicitly) will be excluded by checking + /// with the FileFormat::IsSupported method. This will incur IO for each files + /// in a serial and single threaded fashion. Disabling this feature will skip the + /// IO, but unsupported files may be present in the Dataset + /// (resulting in an error at scan time). + bool exclude_invalid_files = false; + + /// When discovering from a Selector (and not from an explicit file list), ignore + /// files and directories matching any of these prefixes. + /// + /// Example (with selector = "/dataset/**"): + /// selector_ignore_prefixes = {"_", ".DS_STORE" }; + /// + /// - "/dataset/data.csv" -> not ignored + /// - "/dataset/_metadata" -> ignored + /// - "/dataset/.DS_STORE" -> ignored + /// - "/dataset/_hidden/dat" -> ignored + /// - "/dataset/nested/.DS_STORE" -> ignored + std::vector selector_ignore_prefixes = { + ".", + "_", + }; +}; + +/// \brief FileSystemDatasetFactory creates a Dataset from a vector of +/// fs::FileInfo or a fs::FileSelector. +/// \ingroup dataset-filesystem +class ARROW_DS_EXPORT FileSystemDatasetFactory : public DatasetFactory { + public: + /// \brief Build a FileSystemDatasetFactory from an explicit list of + /// paths. + /// + /// \param[in] filesystem passed to FileSystemDataset + /// \param[in] paths passed to FileSystemDataset + /// \param[in] format passed to FileSystemDataset + /// \param[in] options see FileSystemFactoryOptions for more information. + static Result> Make( + std::shared_ptr filesystem, const std::vector& paths, + std::shared_ptr format, FileSystemFactoryOptions options); + + /// \brief Build a FileSystemDatasetFactory from a fs::FileSelector. + /// + /// The selector will expand to a vector of FileInfo. The expansion/crawling + /// is performed in this function call. Thus, the finalized Dataset is + /// working with a snapshot of the filesystem. + // + /// If options.partition_base_dir is not provided, it will be overwritten + /// with selector.base_dir. + /// + /// \param[in] filesystem passed to FileSystemDataset + /// \param[in] selector used to crawl and search files + /// \param[in] format passed to FileSystemDataset + /// \param[in] options see FileSystemFactoryOptions for more information. + static Result> Make( + std::shared_ptr filesystem, fs::FileSelector selector, + std::shared_ptr format, FileSystemFactoryOptions options); + + /// \brief Build a FileSystemDatasetFactory from an uri including filesystem + /// information. + /// + /// \param[in] uri passed to FileSystemDataset + /// \param[in] format passed to FileSystemDataset + /// \param[in] options see FileSystemFactoryOptions for more information. + static Result> Make(std::string uri, + std::shared_ptr format, + FileSystemFactoryOptions options); + + /// \brief Build a FileSystemDatasetFactory from an explicit list of + /// file information. + /// + /// \param[in] filesystem passed to FileSystemDataset + /// \param[in] files passed to FileSystemDataset + /// \param[in] format passed to FileSystemDataset + /// \param[in] options see FileSystemFactoryOptions for more information. + static Result> Make( + std::shared_ptr filesystem, const std::vector& files, + std::shared_ptr format, FileSystemFactoryOptions options); + + Result>> InspectSchemas( + InspectOptions options) override; + + Result> Finish(FinishOptions options) override; + + protected: + FileSystemDatasetFactory(std::vector files, + std::shared_ptr filesystem, + std::shared_ptr format, + FileSystemFactoryOptions options); + + Result> PartitionSchema(); + + std::vector files_; + std::shared_ptr fs_; + std::shared_ptr format_; + FileSystemFactoryOptions options_; +}; + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/file_base.h b/pyarrow/include/arrow/dataset/file_base.h new file mode 100644 index 0000000000000000000000000000000000000000..e13c1312a479f57047b54cd38de680fd3c5a2d0f --- /dev/null +++ b/pyarrow/include/arrow/dataset/file_base.h @@ -0,0 +1,499 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/dataset/dataset.h" +#include "arrow/dataset/partition.h" +#include "arrow/dataset/scanner.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/filesystem/filesystem.h" +#include "arrow/io/file.h" +#include "arrow/type_fwd.h" +#include "arrow/util/compression.h" + +namespace arrow { + +namespace dataset { + +/// \defgroup dataset-file-formats File formats for reading and writing datasets +/// \defgroup dataset-filesystem File system datasets +/// +/// @{ + +/// \brief The path and filesystem where an actual file is located or a buffer which can +/// be read like a file +class ARROW_DS_EXPORT FileSource : public util::EqualityComparable { + public: + FileSource(std::string path, std::shared_ptr filesystem, + Compression::type compression = Compression::UNCOMPRESSED) + : file_info_(std::move(path)), + filesystem_(std::move(filesystem)), + compression_(compression) {} + + FileSource(fs::FileInfo info, std::shared_ptr filesystem, + Compression::type compression = Compression::UNCOMPRESSED) + : file_info_(std::move(info)), + filesystem_(std::move(filesystem)), + compression_(compression) {} + + explicit FileSource(std::shared_ptr buffer, + Compression::type compression = Compression::UNCOMPRESSED) + : buffer_(std::move(buffer)), compression_(compression) {} + + using CustomOpen = std::function>()>; + FileSource(CustomOpen open, int64_t size) + : custom_open_(std::move(open)), custom_size_(size) {} + + using CustomOpenWithCompression = + std::function>(Compression::type)>; + FileSource(CustomOpenWithCompression open_with_compression, int64_t size, + Compression::type compression = Compression::UNCOMPRESSED) + : custom_open_(std::bind(std::move(open_with_compression), compression)), + custom_size_(size), + compression_(compression) {} + + FileSource(std::shared_ptr file, int64_t size, + Compression::type compression = Compression::UNCOMPRESSED) + : custom_open_([=] { return ToResult(file); }), + custom_size_(size), + compression_(compression) {} + + explicit FileSource(std::shared_ptr file, + Compression::type compression = Compression::UNCOMPRESSED); + + FileSource() : custom_open_(CustomOpen{&InvalidOpen}) {} + + static std::vector FromPaths(const std::shared_ptr& fs, + std::vector paths) { + std::vector sources; + for (auto&& path : paths) { + sources.emplace_back(std::move(path), fs); + } + return sources; + } + + /// \brief Return the type of raw compression on the file, if any. + Compression::type compression() const { return compression_; } + + /// \brief Return the file path, if any. Only valid when file source wraps a path. + const std::string& path() const { + static std::string buffer_path = ""; + static std::string custom_open_path = ""; + return filesystem_ ? file_info_.path() : buffer_ ? buffer_path : custom_open_path; + } + + /// \brief Return the filesystem, if any. Otherwise returns nullptr + const std::shared_ptr& filesystem() const { return filesystem_; } + + /// \brief Return the buffer containing the file, if any. Otherwise returns nullptr + const std::shared_ptr& buffer() const { return buffer_; } + + /// \brief Get a RandomAccessFile which views this file source + Result> Open() const; + Future> OpenAsync() const; + + /// \brief Get the size (in bytes) of the file or buffer + /// If the file is compressed this should be the compressed (on-disk) size. + int64_t Size() const; + + /// \brief Get an InputStream which views this file source (and decompresses if needed) + /// \param[in] compression If nullopt, guess the compression scheme from the + /// filename, else decompress with the given codec + Result> OpenCompressed( + std::optional compression = std::nullopt) const; + + /// \brief equality comparison with another FileSource + bool Equals(const FileSource& other) const; + + private: + static Result> InvalidOpen() { + return Status::Invalid("Called Open() on an uninitialized FileSource"); + } + + fs::FileInfo file_info_; + std::shared_ptr filesystem_; + std::shared_ptr buffer_; + CustomOpen custom_open_; + int64_t custom_size_ = 0; + Compression::type compression_ = Compression::UNCOMPRESSED; +}; + +/// \brief Base class for file format implementation +class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this { + public: + /// Options affecting how this format is scanned. + /// + /// The options here can be overridden at scan time. + std::shared_ptr default_fragment_scan_options; + + virtual ~FileFormat() = default; + + /// \brief The name identifying the kind of file format + virtual std::string type_name() const = 0; + + virtual bool Equals(const FileFormat& other) const = 0; + + /// \brief Indicate if the FileSource is supported/readable by this format. + virtual Result IsSupported(const FileSource& source) const = 0; + + /// \brief Return the schema of the file if possible. + virtual Result> Inspect(const FileSource& source) const = 0; + + /// \brief Learn what we need about the file before we start scanning it + virtual Future> InspectFragment( + const FileSource& source, const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) const; + + virtual Result ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& file) const = 0; + + virtual Future> CountRows( + const std::shared_ptr& file, compute::Expression predicate, + const std::shared_ptr& options); + + virtual Future> BeginScan( + const FragmentScanRequest& request, const InspectedFragment& inspected_fragment, + const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) const; + + /// \brief Open a fragment + virtual Result> MakeFragment( + FileSource source, compute::Expression partition_expression, + std::shared_ptr physical_schema); + + /// \brief Create a FileFragment for a FileSource. + Result> MakeFragment( + FileSource source, compute::Expression partition_expression); + + /// \brief Create a FileFragment for a FileSource. + Result> MakeFragment( + FileSource source, std::shared_ptr physical_schema = NULLPTR); + + /// \brief Create a writer for this format. + virtual Result> MakeWriter( + std::shared_ptr destination, std::shared_ptr schema, + std::shared_ptr options, + fs::FileLocator destination_locator) const = 0; + + /// \brief Get default write options for this format. + /// + /// May return null shared_ptr if this file format does not yet support + /// writing datasets. + virtual std::shared_ptr DefaultWriteOptions() = 0; + + protected: + explicit FileFormat(std::shared_ptr default_fragment_scan_options) + : default_fragment_scan_options(std::move(default_fragment_scan_options)) {} +}; + +/// \brief A Fragment that is stored in a file with a known format +class ARROW_DS_EXPORT FileFragment : public Fragment, + public util::EqualityComparable { + public: + Result ScanBatchesAsync( + const std::shared_ptr& options) override; + Future> CountRows( + compute::Expression predicate, + const std::shared_ptr& options) override; + Future> BeginScan( + const FragmentScanRequest& request, const InspectedFragment& inspected_fragment, + const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) override; + Future> InspectFragment( + const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) override; + + std::string type_name() const override { return format_->type_name(); } + std::string ToString() const override { return source_.path(); }; + + const FileSource& source() const { return source_; } + const std::shared_ptr& format() const { return format_; } + + bool Equals(const FileFragment& other) const; + + protected: + FileFragment(FileSource source, std::shared_ptr format, + compute::Expression partition_expression, + std::shared_ptr physical_schema) + : Fragment(std::move(partition_expression), std::move(physical_schema)), + source_(std::move(source)), + format_(std::move(format)) {} + + Result> ReadPhysicalSchemaImpl() override; + + FileSource source_; + std::shared_ptr format_; + + friend class FileFormat; +}; + +/// \brief A Dataset of FileFragments. +/// +/// A FileSystemDataset is composed of one or more FileFragment. The fragments +/// are independent and don't need to share the same format and/or filesystem. +class ARROW_DS_EXPORT FileSystemDataset : public Dataset { + public: + /// \brief Create a FileSystemDataset. + /// + /// \param[in] schema the schema of the dataset + /// \param[in] root_partition the partition expression of the dataset + /// \param[in] format the format of each FileFragment. + /// \param[in] filesystem the filesystem of each FileFragment, or nullptr if the + /// fragments wrap buffers. + /// \param[in] fragments list of fragments to create the dataset from. + /// \param[in] partitioning the Partitioning object in case the dataset is created + /// with a known partitioning (e.g. from a discovered partitioning + /// through a DatasetFactory), or nullptr if not known. + /// + /// Note that fragments wrapping files resident in differing filesystems are not + /// permitted; to work with multiple filesystems use a UnionDataset. + /// + /// \return A constructed dataset. + static Result> Make( + std::shared_ptr schema, compute::Expression root_partition, + std::shared_ptr format, std::shared_ptr filesystem, + std::vector> fragments, + std::shared_ptr partitioning = NULLPTR); + + /// \brief Write a dataset. + static Status Write(const FileSystemDatasetWriteOptions& write_options, + std::shared_ptr scanner); + + /// \brief Return the type name of the dataset. + std::string type_name() const override { return "filesystem"; } + + /// \brief Replace the schema of the dataset. + Result> ReplaceSchema( + std::shared_ptr schema) const override; + + /// \brief Return the path of files. + std::vector files() const; + + /// \brief Return the format. + const std::shared_ptr& format() const { return format_; } + + /// \brief Return the filesystem. May be nullptr if the fragments wrap buffers. + const std::shared_ptr& filesystem() const { return filesystem_; } + + /// \brief Return the partitioning. May be nullptr if the dataset was not constructed + /// with a partitioning. + const std::shared_ptr& partitioning() const { return partitioning_; } + + std::string ToString() const; + + protected: + struct FragmentSubtrees; + + explicit FileSystemDataset(std::shared_ptr schema) + : Dataset(std::move(schema)) {} + + FileSystemDataset(std::shared_ptr schema, + compute::Expression partition_expression) + : Dataset(std::move(schema), partition_expression) {} + + Result GetFragmentsImpl(compute::Expression predicate) override; + + void SetupSubtreePruning(); + + std::shared_ptr format_; + std::shared_ptr filesystem_; + std::vector> fragments_; + std::shared_ptr partitioning_; + + std::shared_ptr subtrees_; +}; + +/// \brief Options for writing a file of this format. +class ARROW_DS_EXPORT FileWriteOptions { + public: + virtual ~FileWriteOptions() = default; + + const std::shared_ptr& format() const { return format_; } + + std::string type_name() const { return format_->type_name(); } + + protected: + explicit FileWriteOptions(std::shared_ptr format) + : format_(std::move(format)) {} + + std::shared_ptr format_; +}; + +/// \brief A writer for this format. +class ARROW_DS_EXPORT FileWriter { + public: + virtual ~FileWriter() = default; + + /// \brief Write the given batch. + virtual Status Write(const std::shared_ptr& batch) = 0; + + /// \brief Write all batches from the reader. + Status Write(RecordBatchReader* batches); + + /// \brief Indicate that writing is done. + virtual Future<> Finish(); + + const std::shared_ptr& format() const { return options_->format(); } + const std::shared_ptr& schema() const { return schema_; } + const std::shared_ptr& options() const { return options_; } + const fs::FileLocator& destination() const { return destination_locator_; } + + /// \brief After Finish() is called, provides number of bytes written to file. + Result GetBytesWritten() const; + + protected: + FileWriter(std::shared_ptr schema, std::shared_ptr options, + std::shared_ptr destination, + fs::FileLocator destination_locator) + : schema_(std::move(schema)), + options_(std::move(options)), + destination_(std::move(destination)), + destination_locator_(std::move(destination_locator)) {} + + virtual Future<> FinishInternal() = 0; + + std::shared_ptr schema_; + std::shared_ptr options_; + std::shared_ptr destination_; + fs::FileLocator destination_locator_; + std::optional bytes_written_; +}; + +/// \brief Options for writing a dataset. +struct ARROW_DS_EXPORT FileSystemDatasetWriteOptions { + /// Options for individual fragment writing. + std::shared_ptr file_write_options; + + /// FileSystem into which a dataset will be written. + std::shared_ptr filesystem; + + /// Root directory into which the dataset will be written. + std::string base_dir; + + /// Partitioning used to generate fragment paths. + std::shared_ptr partitioning; + + /// If true the order of rows in the dataset is preserved when writing with + /// multiple threads. This may cause notable performance degradation. + bool preserve_order = false; + + /// Maximum number of partitions any batch may be written into, default is 1K. + int max_partitions = 1024; + + /// Template string used to generate fragment basenames. + /// {i} will be replaced by an auto incremented integer. + std::string basename_template; + + /// A functor which will be applied on an incremented counter. The result will be + /// inserted into the basename_template in place of {i}. + /// + /// This can be used, for example, to left-pad the file counter. + std::function basename_template_functor; + + /// If greater than 0 then this will limit the maximum number of files that can be left + /// open. If an attempt is made to open too many files then the least recently used file + /// will be closed. If this setting is set too low you may end up fragmenting your data + /// into many small files. + /// + /// The default is 900 which also allows some # of files to be open by the scanner + /// before hitting the default Linux limit of 1024 + uint32_t max_open_files = 900; + + /// If greater than 0 then this will limit how many rows are placed in any single file. + /// Otherwise there will be no limit and one file will be created in each output + /// directory unless files need to be closed to respect max_open_files + uint64_t max_rows_per_file = 0; + + /// If greater than 0 then this will cause the dataset writer to batch incoming data + /// and only write the row groups to the disk when sufficient rows have accumulated. + /// The final row group size may be less than this value and other options such as + /// `max_open_files` or `max_rows_per_file` lead to smaller row group sizes. + uint64_t min_rows_per_group = 0; + + /// If greater than 0 then the dataset writer may split up large incoming batches into + /// multiple row groups. If this value is set then min_rows_per_group should also be + /// set or else you may end up with very small row groups (e.g. if the incoming row + /// group size is just barely larger than this value). + uint64_t max_rows_per_group = 1 << 20; + + /// Controls what happens if an output directory already exists. + ExistingDataBehavior existing_data_behavior = ExistingDataBehavior::kError; + + /// \brief If false the dataset writer will not create directories + /// This is mainly intended for filesystems that do not require directories such as S3. + bool create_dir = true; + + /// Callback to be invoked against all FileWriters before + /// they are finalized with FileWriter::Finish(). + std::function writer_pre_finish = [](FileWriter*) { + return Status::OK(); + }; + + /// Callback to be invoked against all FileWriters after they have + /// called FileWriter::Finish(). + std::function writer_post_finish = [](FileWriter*) { + return Status::OK(); + }; + + const std::shared_ptr& format() const { + return file_write_options->format(); + } +}; + +/// \brief Wraps FileSystemDatasetWriteOptions for consumption as compute::ExecNodeOptions +class ARROW_DS_EXPORT WriteNodeOptions : public acero::ExecNodeOptions { + public: + explicit WriteNodeOptions( + FileSystemDatasetWriteOptions options, + std::shared_ptr custom_metadata = NULLPTR) + : write_options(std::move(options)), custom_metadata(std::move(custom_metadata)) {} + + /// \brief Options to control how to write the dataset + FileSystemDatasetWriteOptions write_options; + /// \brief Optional schema to attach to all written batches + /// + /// By default, we will use the output schema of the input. + /// + /// This can be used to alter schema metadata, field nullability, or field metadata. + /// However, this cannot be used to change the type of data. If the custom schema does + /// not have the same number of fields and the same data types as the input then the + /// plan will fail. + std::shared_ptr custom_schema; + /// \brief Optional metadata to attach to written batches + std::shared_ptr custom_metadata; +}; + +/// @} + +namespace internal { +ARROW_DS_EXPORT void InitializeDatasetWriter(arrow::acero::ExecFactoryRegistry* registry); +} + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/file_csv.h b/pyarrow/include/arrow/dataset/file_csv.h new file mode 100644 index 0000000000000000000000000000000000000000..42e3fd7246988e625e0d2e69a29bd40c553e3219 --- /dev/null +++ b/pyarrow/include/arrow/dataset/file_csv.h @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include "arrow/csv/options.h" +#include "arrow/dataset/dataset.h" +#include "arrow/dataset/file_base.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/ipc/type_fwd.h" +#include "arrow/status.h" +#include "arrow/util/compression.h" + +namespace arrow { +namespace dataset { + +constexpr char kCsvTypeName[] = "csv"; + +/// \addtogroup dataset-file-formats +/// +/// @{ + +/// \brief A FileFormat implementation that reads from and writes to Csv files +class ARROW_DS_EXPORT CsvFileFormat : public FileFormat { + public: + // TODO(ARROW-18328) Remove this, moved to CsvFragmentScanOptions + /// Options affecting the parsing of CSV files + csv::ParseOptions parse_options = csv::ParseOptions::Defaults(); + + CsvFileFormat(); + + std::string type_name() const override { return kCsvTypeName; } + + bool Equals(const FileFormat& other) const override; + + Result IsSupported(const FileSource& source) const override; + + /// \brief Return the schema of the file if possible. + Result> Inspect(const FileSource& source) const override; + + Future> BeginScan( + const FragmentScanRequest& request, const InspectedFragment& inspected_fragment, + const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) const override; + + Result ScanBatchesAsync( + const std::shared_ptr& scan_options, + const std::shared_ptr& file) const override; + + Future> InspectFragment( + const FileSource& source, const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) const override; + + Future> CountRows( + const std::shared_ptr& file, compute::Expression predicate, + const std::shared_ptr& options) override; + + Result> MakeWriter( + std::shared_ptr destination, std::shared_ptr schema, + std::shared_ptr options, + fs::FileLocator destination_locator) const override; + + std::shared_ptr DefaultWriteOptions() override; +}; + +/// \brief Per-scan options for CSV fragments +struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions { + std::string type_name() const override { return kCsvTypeName; } + + using StreamWrapFunc = std::function>( + std::shared_ptr)>; + + /// CSV conversion options + csv::ConvertOptions convert_options = csv::ConvertOptions::Defaults(); + + /// CSV reading options + /// + /// Note that use_threads is always ignored. + csv::ReadOptions read_options = csv::ReadOptions::Defaults(); + + /// CSV parse options + csv::ParseOptions parse_options = csv::ParseOptions::Defaults(); + + /// Optional stream wrapping function + /// + /// If defined, all open dataset file fragments will be passed + /// through this function. One possible use case is to transparently + /// transcode all input files from a given character set to utf8. + StreamWrapFunc stream_transform_func{}; +}; + +class ARROW_DS_EXPORT CsvFileWriteOptions : public FileWriteOptions { + public: + /// Options passed to csv::MakeCSVWriter. + std::shared_ptr write_options; + + protected: + explicit CsvFileWriteOptions(std::shared_ptr format) + : FileWriteOptions(std::move(format)) {} + + friend class CsvFileFormat; +}; + +class ARROW_DS_EXPORT CsvFileWriter : public FileWriter { + public: + Status Write(const std::shared_ptr& batch) override; + + private: + CsvFileWriter(std::shared_ptr destination, + std::shared_ptr writer, + std::shared_ptr schema, + std::shared_ptr options, + fs::FileLocator destination_locator); + + Future<> FinishInternal() override; + + std::shared_ptr destination_; + std::shared_ptr batch_writer_; + + friend class CsvFileFormat; +}; + +/// @} + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/file_ipc.h b/pyarrow/include/arrow/dataset/file_ipc.h new file mode 100644 index 0000000000000000000000000000000000000000..0f7da82a0af5b1e58b724646853e8f482781778b --- /dev/null +++ b/pyarrow/include/arrow/dataset/file_ipc.h @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include + +#include "arrow/dataset/file_base.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/io/type_fwd.h" +#include "arrow/ipc/type_fwd.h" +#include "arrow/result.h" + +namespace arrow { +namespace dataset { + +/// \addtogroup dataset-file-formats +/// +/// @{ + +constexpr char kIpcTypeName[] = "ipc"; + +/// \brief A FileFormat implementation that reads from and writes to Ipc files +class ARROW_DS_EXPORT IpcFileFormat : public FileFormat { + public: + std::string type_name() const override { return kIpcTypeName; } + + IpcFileFormat(); + + bool Equals(const FileFormat& other) const override { + return type_name() == other.type_name(); + } + + Result IsSupported(const FileSource& source) const override; + + /// \brief Return the schema of the file if possible. + Result> Inspect(const FileSource& source) const override; + + Result ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& file) const override; + + Future> CountRows( + const std::shared_ptr& file, compute::Expression predicate, + const std::shared_ptr& options) override; + + Result> MakeWriter( + std::shared_ptr destination, std::shared_ptr schema, + std::shared_ptr options, + fs::FileLocator destination_locator) const override; + + std::shared_ptr DefaultWriteOptions() override; +}; + +/// \brief Per-scan options for IPC fragments +class ARROW_DS_EXPORT IpcFragmentScanOptions : public FragmentScanOptions { + public: + std::string type_name() const override { return kIpcTypeName; } + + /// Options passed to the IPC file reader. + /// included_fields, memory_pool, and use_threads are ignored. + std::shared_ptr options; + /// If present, the async scanner will enable I/O coalescing. + /// This is ignored by the sync scanner. + std::shared_ptr cache_options; +}; + +class ARROW_DS_EXPORT IpcFileWriteOptions : public FileWriteOptions { + public: + /// Options passed to ipc::MakeFileWriter. use_threads is ignored + std::shared_ptr options; + + /// custom_metadata written to the file's footer + std::shared_ptr metadata; + + protected: + explicit IpcFileWriteOptions(std::shared_ptr format) + : FileWriteOptions(std::move(format)) {} + + friend class IpcFileFormat; +}; + +class ARROW_DS_EXPORT IpcFileWriter : public FileWriter { + public: + Status Write(const std::shared_ptr& batch) override; + + private: + IpcFileWriter(std::shared_ptr destination, + std::shared_ptr writer, + std::shared_ptr schema, + std::shared_ptr options, + fs::FileLocator destination_locator); + + Future<> FinishInternal() override; + + std::shared_ptr destination_; + std::shared_ptr batch_writer_; + + friend class IpcFileFormat; +}; + +/// @} + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/file_json.h b/pyarrow/include/arrow/dataset/file_json.h new file mode 100644 index 0000000000000000000000000000000000000000..4b8112d87095ccc9d02b0c52b4df2b1e674b8cc5 --- /dev/null +++ b/pyarrow/include/arrow/dataset/file_json.h @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/dataset/dataset.h" +#include "arrow/dataset/file_base.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/ipc/type_fwd.h" +#include "arrow/json/options.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/future.h" +#include "arrow/util/macros.h" + +namespace arrow::dataset { + +/// \addtogroup dataset-file-formats +/// +/// @{ + +constexpr char kJsonTypeName[] = "json"; + +/// \brief A FileFormat implementation that reads from JSON files +class ARROW_DS_EXPORT JsonFileFormat : public FileFormat { + public: + JsonFileFormat(); + + std::string type_name() const override { return kJsonTypeName; } + + bool Equals(const FileFormat& other) const override; + + Result IsSupported(const FileSource& source) const override; + + Result> Inspect(const FileSource& source) const override; + + Future> InspectFragment( + const FileSource& source, const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) const override; + + Future> BeginScan( + const FragmentScanRequest& scan_request, const InspectedFragment& inspected, + const FragmentScanOptions* format_options, + compute::ExecContext* exec_context) const override; + + Result ScanBatchesAsync( + const std::shared_ptr& scan_options, + const std::shared_ptr& file) const override; + + Future> CountRows( + const std::shared_ptr& file, compute::Expression predicate, + const std::shared_ptr& scan_options) override; + + Result> MakeWriter( + std::shared_ptr destination, std::shared_ptr schema, + std::shared_ptr options, + fs::FileLocator destination_locator) const override { + return Status::NotImplemented("Writing JSON files is not currently supported"); + } + + std::shared_ptr DefaultWriteOptions() override { return NULLPTR; } +}; + +/// \brief Per-scan options for JSON fragments +struct ARROW_DS_EXPORT JsonFragmentScanOptions : public FragmentScanOptions { + std::string type_name() const override { return kJsonTypeName; } + + /// @brief Options that affect JSON parsing + /// + /// Note: `explicit_schema` and `unexpected_field_behavior` are ignored. + json::ParseOptions parse_options = json::ParseOptions::Defaults(); + + /// @brief Options that affect JSON reading + json::ReadOptions read_options = json::ReadOptions::Defaults(); +}; + +/// @} + +} // namespace arrow::dataset diff --git a/pyarrow/include/arrow/dataset/file_orc.h b/pyarrow/include/arrow/dataset/file_orc.h new file mode 100644 index 0000000000000000000000000000000000000000..5bfefd1e02b5cccf74cf8ade579a937341aef013 --- /dev/null +++ b/pyarrow/include/arrow/dataset/file_orc.h @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include + +#include "arrow/dataset/file_base.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/io/type_fwd.h" +#include "arrow/result.h" + +namespace arrow { +namespace dataset { + +/// \addtogroup dataset-file-formats +/// +/// @{ + +constexpr char kOrcTypeName[] = "orc"; + +/// \brief A FileFormat implementation that reads from and writes to ORC files +class ARROW_DS_EXPORT OrcFileFormat : public FileFormat { + public: + OrcFileFormat(); + + std::string type_name() const override { return kOrcTypeName; } + + bool Equals(const FileFormat& other) const override { + return type_name() == other.type_name(); + } + + Result IsSupported(const FileSource& source) const override; + + /// \brief Return the schema of the file if possible. + Result> Inspect(const FileSource& source) const override; + + Result ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& file) const override; + + Future> CountRows( + const std::shared_ptr& file, compute::Expression predicate, + const std::shared_ptr& options) override; + + Result> MakeWriter( + std::shared_ptr destination, std::shared_ptr schema, + std::shared_ptr options, + fs::FileLocator destination_locator) const override; + + std::shared_ptr DefaultWriteOptions() override; +}; + +/// @} + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/file_parquet.h b/pyarrow/include/arrow/dataset/file_parquet.h new file mode 100644 index 0000000000000000000000000000000000000000..1811a96bf986f69f8c6e6ad040fe653a519ba95e --- /dev/null +++ b/pyarrow/include/arrow/dataset/file_parquet.h @@ -0,0 +1,410 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "arrow/dataset/discovery.h" +#include "arrow/dataset/file_base.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/io/caching.h" + +namespace parquet { +class ParquetFileReader; +class Statistics; +class ColumnChunkMetaData; +class RowGroupMetaData; +class FileMetaData; +class FileDecryptionProperties; +class FileEncryptionProperties; + +class ReaderProperties; +class ArrowReaderProperties; + +class WriterProperties; +class ArrowWriterProperties; + +namespace arrow { +class FileReader; +class FileWriter; +struct SchemaManifest; +} // namespace arrow +} // namespace parquet + +namespace arrow { +namespace dataset { + +struct ParquetDecryptionConfig; +struct ParquetEncryptionConfig; + +/// \addtogroup dataset-file-formats +/// +/// @{ + +constexpr char kParquetTypeName[] = "parquet"; + +/// \brief A FileFormat implementation that reads from Parquet files +class ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { + public: + ParquetFileFormat(); + + /// Convenience constructor which copies properties from a parquet::ReaderProperties. + /// memory_pool will be ignored. + explicit ParquetFileFormat(const parquet::ReaderProperties& reader_properties); + + std::string type_name() const override { return kParquetTypeName; } + + bool Equals(const FileFormat& other) const override; + + struct ReaderOptions { + /// \defgroup parquet-file-format-arrow-reader-properties properties which correspond + /// to members of parquet::ArrowReaderProperties. + /// + /// We don't embed parquet::ReaderProperties directly because column names (rather + /// than indices) are used to indicate dictionary columns, and other options are + /// deferred to scan time. + /// + /// @{ + std::unordered_set dict_columns; + arrow::TimeUnit::type coerce_int96_timestamp_unit = arrow::TimeUnit::NANO; + Type::type binary_type = Type::BINARY; + Type::type list_type = Type::LIST; + /// @} + } reader_options; + + Result IsSupported(const FileSource& source) const override; + + /// \brief Return the schema of the file if possible. + Result> Inspect(const FileSource& source) const override; + + Result ScanBatchesAsync( + const std::shared_ptr& options, + const std::shared_ptr& file) const override; + + Future> CountRows( + const std::shared_ptr& file, compute::Expression predicate, + const std::shared_ptr& options) override; + + using FileFormat::MakeFragment; + + /// \brief Create a Fragment targeting all RowGroups. + Result> MakeFragment( + FileSource source, compute::Expression partition_expression, + std::shared_ptr physical_schema) override; + + /// \brief Create a Fragment, restricted to the specified row groups. + Result> MakeFragment( + FileSource source, compute::Expression partition_expression, + std::shared_ptr physical_schema, std::vector row_groups); + + /// \brief Return a FileReader on the given source. + Result> GetReader( + const FileSource& source, const std::shared_ptr& options) const; + + Result> GetReader( + const FileSource& source, const std::shared_ptr& options, + const std::shared_ptr& metadata) const; + + Future> GetReaderAsync( + const FileSource& source, const std::shared_ptr& options) const; + + Future> GetReaderAsync( + const FileSource& source, const std::shared_ptr& options, + const std::shared_ptr& metadata) const; + + Result> MakeWriter( + std::shared_ptr destination, std::shared_ptr schema, + std::shared_ptr options, + fs::FileLocator destination_locator) const override; + + std::shared_ptr DefaultWriteOptions() override; +}; + +/// \brief A FileFragment with parquet logic. +/// +/// ParquetFileFragment provides a lazy (with respect to IO) interface to +/// scan parquet files. Any heavy IO calls are deferred to the Scan() method. +/// +/// The caller can provide an optional list of selected RowGroups to limit the +/// number of scanned RowGroups, or to partition the scans across multiple +/// threads. +/// +/// Metadata can be explicitly provided, enabling pushdown predicate benefits without +/// the potentially heavy IO of loading Metadata from the file system. This can induce +/// significant performance boost when scanning high latency file systems. +class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { + public: + Result SplitByRowGroup(compute::Expression predicate); + + /// \brief Return the RowGroups selected by this fragment. + const std::vector& row_groups() const { + if (row_groups_) return *row_groups_; + static std::vector empty; + return empty; + } + + /// \brief Return the FileMetaData associated with this fragment. + /// + /// This may return nullptr if the fragment wasn't scanned yet, or if + /// `ScanOptions::cache_metadata` was disabled. + std::shared_ptr metadata(); + + /// \brief Ensure this fragment's FileMetaData is in memory. + Status EnsureCompleteMetadata(parquet::arrow::FileReader* reader = NULLPTR); + + Status ClearCachedMetadata() override; + + /// \brief Return fragment which selects a filtered subset of this fragment's RowGroups. + Result> Subset(compute::Expression predicate); + Result> Subset(std::vector row_group_ids); + + static std::optional EvaluateStatisticsAsExpression( + const Field& field, const parquet::Statistics& statistics); + + static std::optional EvaluateStatisticsAsExpression( + const Field& field, const FieldRef& field_ref, + const parquet::Statistics& statistics); + + private: + ParquetFileFragment(FileSource source, std::shared_ptr format, + compute::Expression partition_expression, + std::shared_ptr physical_schema, + std::optional> row_groups); + + Status SetMetadata(std::shared_ptr metadata, + std::shared_ptr manifest, + std::shared_ptr original_metadata = {}); + + // Overridden to opportunistically set metadata since a reader must be opened anyway. + Result> ReadPhysicalSchemaImpl() override { + ARROW_RETURN_NOT_OK(EnsureCompleteMetadata()); + return physical_schema_; + } + + /// Return a filtered subset of row group indices. + Result> FilterRowGroups(compute::Expression predicate); + /// Simplify the predicate against the statistics of each row group. + Result> TestRowGroups(compute::Expression predicate); + /// Try to count rows matching the predicate using metadata. Expects + /// metadata to be present, and expects the predicate to have been + /// simplified against the partition expression already. + Result> TryCountRows(compute::Expression predicate); + + ParquetFileFormat& parquet_format_; + + /// Indices of row groups selected by this fragment, + /// or std::nullopt if all row groups are selected. + std::optional> row_groups_; + + // the expressions (combined for all columns for which statistics have been + // processed) are stored per column group + std::vector statistics_expressions_; + // statistics status are kept track of by Parquet Schema column indices + // (i.e. not Arrow schema field index) + std::vector statistics_expressions_complete_; + std::shared_ptr metadata_; + std::shared_ptr manifest_; + // The FileMetaData that owns the SchemaDescriptor pointed by SchemaManifest. + std::shared_ptr original_metadata_; + + friend class ParquetFileFormat; + friend class ParquetDatasetFactory; +}; + +/// \brief Per-scan options for Parquet fragments +class ARROW_DS_EXPORT ParquetFragmentScanOptions : public FragmentScanOptions { + public: + ParquetFragmentScanOptions(); + std::string type_name() const override { return kParquetTypeName; } + + /// Reader properties. Not all properties are respected: memory_pool comes from + /// ScanOptions. + std::shared_ptr reader_properties; + /// Arrow reader properties. Not all properties are respected: batch_size comes from + /// ScanOptions. Additionally, other options come from ParquetFileFormat::ReaderOptions. + std::shared_ptr arrow_reader_properties; + /// A configuration structure that provides decryption properties for a dataset + std::shared_ptr parquet_decryption_config = NULLPTR; +}; + +class ARROW_DS_EXPORT ParquetFileWriteOptions : public FileWriteOptions { + public: + /// \brief Parquet writer properties. + std::shared_ptr writer_properties; + + /// \brief Parquet Arrow writer properties. + std::shared_ptr arrow_writer_properties; + + // A configuration structure that provides encryption properties for a dataset + std::shared_ptr parquet_encryption_config = NULLPTR; + + protected: + explicit ParquetFileWriteOptions(std::shared_ptr format) + : FileWriteOptions(std::move(format)) {} + + friend class ParquetFileFormat; +}; + +class ARROW_DS_EXPORT ParquetFileWriter : public FileWriter { + public: + const std::shared_ptr& parquet_writer() const { + return parquet_writer_; + } + + Status Write(const std::shared_ptr& batch) override; + + private: + ParquetFileWriter(std::shared_ptr destination, + std::shared_ptr writer, + std::shared_ptr options, + fs::FileLocator destination_locator); + + Future<> FinishInternal() override; + + std::shared_ptr parquet_writer_; + + friend class ParquetFileFormat; +}; + +/// \brief Options for making a FileSystemDataset from a Parquet _metadata file. +struct ParquetFactoryOptions { + /// Either an explicit Partitioning or a PartitioningFactory to discover one. + /// + /// If a factory is provided, it will be used to infer a schema for partition fields + /// based on file and directory paths then construct a Partitioning. The default + /// is a Partitioning which will yield no partition information. + /// + /// The (explicit or discovered) partitioning will be applied to discovered files + /// and the resulting partition information embedded in the Dataset. + PartitioningOrFactory partitioning{Partitioning::Default()}; + + /// For the purposes of applying the partitioning, paths will be stripped + /// of the partition_base_dir. Files not matching the partition_base_dir + /// prefix will be skipped for partition discovery. The ignored files will still + /// be part of the Dataset, but will not have partition information. + /// + /// Example: + /// partition_base_dir = "/dataset"; + /// + /// - "/dataset/US/sales.csv" -> "US/sales.csv" will be given to the partitioning + /// + /// - "/home/john/late_sales.csv" -> Will be ignored for partition discovery. + /// + /// This is useful for partitioning which parses directory when ordering + /// is important, e.g. DirectoryPartitioning. + std::string partition_base_dir; + + /// Assert that all ColumnChunk paths are consistent. The parquet spec allows for + /// ColumnChunk data to be stored in multiple files, but ParquetDatasetFactory + /// supports only a single file with all ColumnChunk data. If this flag is set + /// construction of a ParquetDatasetFactory will raise an error if ColumnChunk + /// data is not resident in a single file. + bool validate_column_chunk_paths = false; +}; + +/// \brief Create FileSystemDataset from custom `_metadata` cache file. +/// +/// Dask and other systems will generate a cache metadata file by concatenating +/// the RowGroupMetaData of multiple parquet files into a single parquet file +/// that only contains metadata and no ColumnChunk data. +/// +/// ParquetDatasetFactory creates a FileSystemDataset composed of +/// ParquetFileFragment where each fragment is pre-populated with the exact +/// number of row groups and statistics for each columns. +class ARROW_DS_EXPORT ParquetDatasetFactory : public DatasetFactory { + public: + /// \brief Create a ParquetDatasetFactory from a metadata path. + /// + /// The `metadata_path` will be read from `filesystem`. Each RowGroup + /// contained in the metadata file will be relative to `dirname(metadata_path)`. + /// + /// \param[in] metadata_path path of the metadata parquet file + /// \param[in] filesystem from which to open/read the path + /// \param[in] format to read the file with. + /// \param[in] options see ParquetFactoryOptions + static Result> Make( + const std::string& metadata_path, std::shared_ptr filesystem, + std::shared_ptr format, ParquetFactoryOptions options); + + /// \brief Create a ParquetDatasetFactory from a metadata source. + /// + /// Similar to the previous Make definition, but the metadata can be a Buffer + /// and the base_path is explicit instead of inferred from the metadata + /// path. + /// + /// \param[in] metadata source to open the metadata parquet file from + /// \param[in] base_path used as the prefix of every parquet files referenced + /// \param[in] filesystem from which to read the files referenced. + /// \param[in] format to read the file with. + /// \param[in] options see ParquetFactoryOptions + static Result> Make( + const FileSource& metadata, const std::string& base_path, + std::shared_ptr filesystem, + std::shared_ptr format, ParquetFactoryOptions options); + + Result>> InspectSchemas( + InspectOptions options) override; + + Result> Finish(FinishOptions options) override; + + protected: + ParquetDatasetFactory( + std::shared_ptr filesystem, + std::shared_ptr format, + std::shared_ptr metadata, + std::shared_ptr manifest, + std::shared_ptr physical_schema, std::string base_path, + ParquetFactoryOptions options, + std::vector>> paths_with_row_group_ids) + : filesystem_(std::move(filesystem)), + format_(std::move(format)), + metadata_(std::move(metadata)), + manifest_(std::move(manifest)), + physical_schema_(std::move(physical_schema)), + base_path_(std::move(base_path)), + options_(std::move(options)), + paths_with_row_group_ids_(std::move(paths_with_row_group_ids)) {} + + std::shared_ptr filesystem_; + std::shared_ptr format_; + std::shared_ptr metadata_; + std::shared_ptr manifest_; + std::shared_ptr physical_schema_; + std::string base_path_; + ParquetFactoryOptions options_; + std::vector>> paths_with_row_group_ids_; + + private: + Result>> CollectParquetFragments( + const Partitioning& partitioning); + + Result> PartitionSchema(); +}; + +/// @} + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/parquet_encryption_config.h b/pyarrow/include/arrow/dataset/parquet_encryption_config.h new file mode 100644 index 0000000000000000000000000000000000000000..96200b8a3118b82c92977d222ba8775f61a02b0b --- /dev/null +++ b/pyarrow/include/arrow/dataset/parquet_encryption_config.h @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/dataset/type_fwd.h" + +namespace parquet::encryption { +class CryptoFactory; +struct KmsConnectionConfig; +struct EncryptionConfiguration; +struct DecryptionConfiguration; +} // namespace parquet::encryption + +namespace arrow { +namespace dataset { + +/// \brief Core configuration class encapsulating parameters for high-level encryption +/// within Parquet framework. +/// +/// ParquetEncryptionConfig serves as a bridge, passing encryption-related +/// parameters to appropriate components within the Parquet library. It holds references +/// to objects defining encryption strategy, Key Management Service (KMS) configuration, +/// and specific encryption configurations for Parquet data. +struct ARROW_DS_EXPORT ParquetEncryptionConfig { + /// Shared pointer to CryptoFactory object, responsible for creating cryptographic + /// components like encryptors and decryptors. + std::shared_ptr crypto_factory; + + /// Shared pointer to KmsConnectionConfig object, holding configuration parameters for + /// connecting to a Key Management Service (KMS). + std::shared_ptr kms_connection_config; + + /// Shared pointer to EncryptionConfiguration object, defining specific encryption + /// settings for Parquet data, like keys for different columns. + std::shared_ptr encryption_config; +}; + +/// \brief Core configuration class encapsulating parameters for high-level decryption +/// within Parquet framework. +/// +/// ParquetDecryptionConfig is designed to pass decryption-related parameters to +/// appropriate decryption components within Parquet library. It holds references to +/// objects defining decryption strategy, Key Management Service (KMS) configuration, +/// and specific decryption configurations for reading encrypted Parquet data. +struct ARROW_DS_EXPORT ParquetDecryptionConfig { + /// Shared pointer to CryptoFactory object, pivotal in creating cryptographic + /// components for decryption process. + std::shared_ptr crypto_factory; + + /// Shared pointer to KmsConnectionConfig object, containing parameters for connecting + /// to a Key Management Service (KMS) during decryption. + std::shared_ptr kms_connection_config; + + /// Shared pointer to DecryptionConfiguration object, specifying decryption settings + /// for reading encrypted Parquet data. + std::shared_ptr decryption_config; +}; + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/partition.h b/pyarrow/include/arrow/dataset/partition.h new file mode 100644 index 0000000000000000000000000000000000000000..315a3d384d28c1b313bf1483fb38ad99c6713663 --- /dev/null +++ b/pyarrow/include/arrow/dataset/partition.h @@ -0,0 +1,432 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/compute/expression.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/util/compare.h" + +namespace arrow { + +namespace dataset { + +constexpr char kFilenamePartitionSep = '_'; + +struct ARROW_DS_EXPORT PartitionPathFormat { + std::string directory, filename; +}; + +// ---------------------------------------------------------------------- +// Partitioning + +/// \defgroup dataset-partitioning Partitioning API +/// +/// @{ + +/// \brief Interface for parsing partition expressions from string partition +/// identifiers. +/// +/// For example, the identifier "foo=5" might be parsed to an equality expression +/// between the "foo" field and the value 5. +/// +/// Some partitionings may store the field names in a metadata +/// store instead of in file paths, for example +/// dataset_root/2009/11/... could be used when the partition fields +/// are "year" and "month" +/// +/// Paths are consumed from left to right. Paths must be relative to +/// the root of a partition; path prefixes must be removed before passing +/// the path to a partitioning for parsing. +class ARROW_DS_EXPORT Partitioning : public util::EqualityComparable { + public: + virtual ~Partitioning() = default; + + /// \brief The name identifying the kind of partitioning + virtual std::string type_name() const = 0; + + //// \brief Return whether the partitionings are equal + virtual bool Equals(const Partitioning& other) const { + return schema_->Equals(other.schema_, /*check_metadata=*/false); + } + + /// \brief If the input batch shares any fields with this partitioning, + /// produce sub-batches which satisfy mutually exclusive Expressions. + struct PartitionedBatches { + RecordBatchVector batches; + std::vector expressions; + }; + virtual Result Partition( + const std::shared_ptr& batch) const = 0; + + /// \brief Parse a path into a partition expression + virtual Result Parse(const std::string& path) const = 0; + + virtual Result Format(const compute::Expression& expr) const = 0; + + /// \brief A default Partitioning which is a DirectoryPartitioning + /// with an empty schema. + static std::shared_ptr Default(); + + /// \brief The partition schema. + const std::shared_ptr& schema() const { return schema_; } + + protected: + explicit Partitioning(std::shared_ptr schema) : schema_(std::move(schema)) {} + + std::shared_ptr schema_; +}; + +/// \brief The encoding of partition segments. +enum class SegmentEncoding : int8_t { + /// No encoding. + None = 0, + /// Segment values are URL-encoded. + Uri = 1, +}; + +ARROW_DS_EXPORT +std::ostream& operator<<(std::ostream& os, SegmentEncoding segment_encoding); + +/// \brief Options for key-value based partitioning (hive/directory). +struct ARROW_DS_EXPORT KeyValuePartitioningOptions { + /// After splitting a path into components, decode the path components + /// before parsing according to this scheme. + SegmentEncoding segment_encoding = SegmentEncoding::Uri; +}; + +/// \brief Options for inferring a partitioning. +struct ARROW_DS_EXPORT PartitioningFactoryOptions { + /// When inferring a schema for partition fields, yield dictionary encoded types + /// instead of plain. This can be more efficient when materializing virtual + /// columns, and Expressions parsed by the finished Partitioning will include + /// dictionaries of all unique inspected values for each field. + bool infer_dictionary = false; + /// Optionally, an expected schema can be provided, in which case inference + /// will only check discovered fields against the schema and update internal + /// state (such as dictionaries). + std::shared_ptr schema; + /// After splitting a path into components, decode the path components + /// before parsing according to this scheme. + SegmentEncoding segment_encoding = SegmentEncoding::Uri; + + KeyValuePartitioningOptions AsPartitioningOptions() const; +}; + +/// \brief Options for inferring a hive-style partitioning. +struct ARROW_DS_EXPORT HivePartitioningFactoryOptions : PartitioningFactoryOptions { + /// The hive partitioning scheme maps null to a hard coded fallback string. + std::string null_fallback; + + HivePartitioningOptions AsHivePartitioningOptions() const; +}; + +/// \brief PartitioningFactory provides creation of a partitioning when the +/// specific schema must be inferred from available paths (no explicit schema is known). +class ARROW_DS_EXPORT PartitioningFactory { + public: + virtual ~PartitioningFactory() = default; + + /// \brief The name identifying the kind of partitioning + virtual std::string type_name() const = 0; + + /// Get the schema for the resulting Partitioning. + /// This may reset internal state, for example dictionaries of unique representations. + virtual Result> Inspect( + const std::vector& paths) = 0; + + /// Create a partitioning using the provided schema + /// (fields may be dropped). + virtual Result> Finish( + const std::shared_ptr& schema) const = 0; +}; + +/// \brief Subclass for the common case of a partitioning which yields an equality +/// expression for each segment +class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { + public: + /// An unconverted equality expression consisting of a field name and the representation + /// of a scalar value + struct Key { + std::string name; + std::optional value; + }; + + Result Partition( + const std::shared_ptr& batch) const override; + + Result Parse(const std::string& path) const override; + + Result Format(const compute::Expression& expr) const override; + + const ArrayVector& dictionaries() const { return dictionaries_; } + + SegmentEncoding segment_encoding() const { return options_.segment_encoding; } + + bool Equals(const Partitioning& other) const override; + + protected: + KeyValuePartitioning(std::shared_ptr schema, ArrayVector dictionaries, + KeyValuePartitioningOptions options) + : Partitioning(std::move(schema)), + dictionaries_(std::move(dictionaries)), + options_(options) { + if (dictionaries_.empty()) { + dictionaries_.resize(schema_->num_fields()); + } + } + + virtual Result> ParseKeys(const std::string& path) const = 0; + + virtual Result FormatValues(const ScalarVector& values) const = 0; + + /// Convert a Key to a full expression. + Result ConvertKey(const Key& key) const; + + Result> FormatPartitionSegments( + const ScalarVector& values) const; + Result> ParsePartitionSegments( + const std::vector& segments) const; + + ArrayVector dictionaries_; + KeyValuePartitioningOptions options_; +}; + +/// \brief DirectoryPartitioning parses one segment of a path for each field in its +/// schema. All fields are required, so paths passed to DirectoryPartitioning::Parse +/// must contain segments for each field. +/// +/// For example given schema the path "/2009/11" would be +/// parsed to ("year"_ == 2009 and "month"_ == 11) +class ARROW_DS_EXPORT DirectoryPartitioning : public KeyValuePartitioning { + public: + /// If a field in schema is of dictionary type, the corresponding element of + /// dictionaries must be contain the dictionary of values for that field. + explicit DirectoryPartitioning(std::shared_ptr schema, + ArrayVector dictionaries = {}, + KeyValuePartitioningOptions options = {}); + + std::string type_name() const override { return "directory"; } + + bool Equals(const Partitioning& other) const override; + + /// \brief Create a factory for a directory partitioning. + /// + /// \param[in] field_names The names for the partition fields. Types will be + /// inferred. + static std::shared_ptr MakeFactory( + std::vector field_names, PartitioningFactoryOptions = {}); + + private: + Result> ParseKeys(const std::string& path) const override; + + Result FormatValues(const ScalarVector& values) const override; +}; + +/// \brief The default fallback used for null values in a Hive-style partitioning. +static constexpr char kDefaultHiveNullFallback[] = "__HIVE_DEFAULT_PARTITION__"; + +struct ARROW_DS_EXPORT HivePartitioningOptions : public KeyValuePartitioningOptions { + std::string null_fallback = kDefaultHiveNullFallback; + + static HivePartitioningOptions DefaultsWithNullFallback(std::string fallback) { + HivePartitioningOptions options; + options.null_fallback = std::move(fallback); + return options; + } +}; + +/// \brief Multi-level, directory based partitioning +/// originating from Apache Hive with all data files stored in the +/// leaf directories. Data is partitioned by static values of a +/// particular column in the schema. Partition keys are represented in +/// the form $key=$value in directory names. +/// Field order is ignored, as are missing or unrecognized field names. +/// +/// For example given schema the path +/// "/day=321/ignored=3.4/year=2009" parses to ("year"_ == 2009 and "day"_ == 321) +class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { + public: + /// If a field in schema is of dictionary type, the corresponding element of + /// dictionaries must be contain the dictionary of values for that field. + explicit HivePartitioning(std::shared_ptr schema, ArrayVector dictionaries = {}, + std::string null_fallback = kDefaultHiveNullFallback) + : KeyValuePartitioning(std::move(schema), std::move(dictionaries), + KeyValuePartitioningOptions()), + hive_options_( + HivePartitioningOptions::DefaultsWithNullFallback(std::move(null_fallback))) { + } + + explicit HivePartitioning(std::shared_ptr schema, ArrayVector dictionaries, + HivePartitioningOptions options) + : KeyValuePartitioning(std::move(schema), std::move(dictionaries), options), + hive_options_(options) {} + + std::string type_name() const override { return "hive"; } + std::string null_fallback() const { return hive_options_.null_fallback; } + const HivePartitioningOptions& options() const { return hive_options_; } + + static Result> ParseKey(const std::string& segment, + const HivePartitioningOptions& options); + + bool Equals(const Partitioning& other) const override; + + /// \brief Create a factory for a hive partitioning. + static std::shared_ptr MakeFactory( + HivePartitioningFactoryOptions = {}); + + private: + const HivePartitioningOptions hive_options_; + Result> ParseKeys(const std::string& path) const override; + + Result FormatValues(const ScalarVector& values) const override; +}; + +/// \brief Implementation provided by lambda or other callable +class ARROW_DS_EXPORT FunctionPartitioning : public Partitioning { + public: + using ParseImpl = std::function(const std::string&)>; + + using FormatImpl = + std::function(const compute::Expression&)>; + + FunctionPartitioning(std::shared_ptr schema, ParseImpl parse_impl, + FormatImpl format_impl = NULLPTR, std::string name = "function") + : Partitioning(std::move(schema)), + parse_impl_(std::move(parse_impl)), + format_impl_(std::move(format_impl)), + name_(std::move(name)) {} + + std::string type_name() const override { return name_; } + + bool Equals(const Partitioning& other) const override { return false; } + + Result Parse(const std::string& path) const override { + return parse_impl_(path); + } + + Result Format(const compute::Expression& expr) const override { + if (format_impl_) { + return format_impl_(expr); + } + return Status::NotImplemented("formatting paths from ", type_name(), " Partitioning"); + } + + Result Partition( + const std::shared_ptr& batch) const override { + return Status::NotImplemented("partitioning batches from ", type_name(), + " Partitioning"); + } + + private: + ParseImpl parse_impl_; + FormatImpl format_impl_; + std::string name_; +}; + +class ARROW_DS_EXPORT FilenamePartitioning : public KeyValuePartitioning { + public: + /// \brief Construct a FilenamePartitioning from its components. + /// + /// If a field in schema is of dictionary type, the corresponding element of + /// dictionaries must be contain the dictionary of values for that field. + explicit FilenamePartitioning(std::shared_ptr schema, + ArrayVector dictionaries = {}, + KeyValuePartitioningOptions options = {}); + + std::string type_name() const override { return "filename"; } + + /// \brief Create a factory for a filename partitioning. + /// + /// \param[in] field_names The names for the partition fields. Types will be + /// inferred. + static std::shared_ptr MakeFactory( + std::vector field_names, PartitioningFactoryOptions = {}); + + bool Equals(const Partitioning& other) const override; + + private: + Result> ParseKeys(const std::string& path) const override; + + Result FormatValues(const ScalarVector& values) const override; +}; + +ARROW_DS_EXPORT std::string StripPrefix(const std::string& path, + const std::string& prefix); + +/// \brief Extracts the directory and filename and removes the prefix of a path +/// +/// e.g., `StripPrefixAndFilename("/data/year=2019/c.txt", "/data") -> +/// {"year=2019","c.txt"}` +ARROW_DS_EXPORT std::string StripPrefixAndFilename(const std::string& path, + const std::string& prefix); + +/// \brief Vector version of StripPrefixAndFilename. +ARROW_DS_EXPORT std::vector StripPrefixAndFilename( + const std::vector& paths, const std::string& prefix); + +/// \brief Vector version of StripPrefixAndFilename. +ARROW_DS_EXPORT std::vector StripPrefixAndFilename( + const std::vector& files, const std::string& prefix); + +/// \brief Either a Partitioning or a PartitioningFactory +class ARROW_DS_EXPORT PartitioningOrFactory { + public: + explicit PartitioningOrFactory(std::shared_ptr partitioning) + : partitioning_(std::move(partitioning)) {} + + explicit PartitioningOrFactory(std::shared_ptr factory) + : factory_(std::move(factory)) {} + + PartitioningOrFactory& operator=(std::shared_ptr partitioning) { + return *this = PartitioningOrFactory(std::move(partitioning)); + } + + PartitioningOrFactory& operator=(std::shared_ptr factory) { + return *this = PartitioningOrFactory(std::move(factory)); + } + + /// \brief The partitioning (if given). + const std::shared_ptr& partitioning() const { return partitioning_; } + + /// \brief The partition factory (if given). + const std::shared_ptr& factory() const { return factory_; } + + /// \brief Get the partition schema, inferring it with the given factory if needed. + Result> GetOrInferSchema(const std::vector& paths); + + private: + std::shared_ptr factory_; + std::shared_ptr partitioning_; +}; + +/// @} + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/plan.h b/pyarrow/include/arrow/dataset/plan.h new file mode 100644 index 0000000000000000000000000000000000000000..10260ccec81d159ffd40d86144e39c4d91739db1 --- /dev/null +++ b/pyarrow/include/arrow/dataset/plan.h @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#include "arrow/dataset/visibility.h" + +namespace arrow { +namespace dataset { +namespace internal { + +/// Register dataset-based exec nodes with the exec node registry +/// +/// This function must be called before using dataset ExecNode factories +ARROW_DS_EXPORT void Initialize(); + +} // namespace internal +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/projector.h b/pyarrow/include/arrow/dataset/projector.h new file mode 100644 index 0000000000000000000000000000000000000000..86d38f0af23522a08dcebc1c290fe6bc25ae014e --- /dev/null +++ b/pyarrow/include/arrow/dataset/projector.h @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include "arrow/dataset/visibility.h" +#include "arrow/type_fwd.h" + +namespace arrow { +namespace dataset { + +// FIXME this is superceded by compute::Expression::Bind +ARROW_DS_EXPORT Status CheckProjectable(const Schema& from, const Schema& to); + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/scanner.h b/pyarrow/include/arrow/dataset/scanner.h new file mode 100644 index 0000000000000000000000000000000000000000..7885b132cc9b529a0fbce41c807529ecd1e34da4 --- /dev/null +++ b/pyarrow/include/arrow/dataset/scanner.h @@ -0,0 +1,623 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/acero/options.h" +#include "arrow/compute/expression.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/dataset/dataset.h" +#include "arrow/dataset/type_fwd.h" +#include "arrow/dataset/visibility.h" +#include "arrow/io/interfaces.h" +#include "arrow/type_fwd.h" +#include "arrow/util/async_generator_fwd.h" +#include "arrow/util/iterator.h" +#include "arrow/util/thread_pool.h" +#include "arrow/util/type_fwd.h" + +namespace arrow { + +using RecordBatchGenerator = std::function>()>; + +namespace dataset { + +/// \defgroup dataset-scanning Scanning API +/// +/// @{ + +constexpr int64_t kDefaultBatchSize = 1 << 17; // 128Ki rows +// This will yield 64 batches ~ 8Mi rows +constexpr int32_t kDefaultBatchReadahead = 16; +constexpr int32_t kDefaultFragmentReadahead = 4; +constexpr int32_t kDefaultBytesReadahead = 1 << 25; // 32MiB + +/// Scan-specific options, which can be changed between scans of the same dataset. +struct ARROW_DS_EXPORT ScanOptions { + /// A row filter (which will be pushed down to partitioning/reading if supported). + compute::Expression filter = compute::literal(true); + /// A projection expression (which can add/remove/rename columns). + compute::Expression projection; + + /// Schema with which batches will be read from fragments. This is also known as the + /// "reader schema" it will be used (for example) in constructing CSV file readers to + /// identify column types for parsing. Usually only a subset of its fields (see + /// MaterializedFields) will be materialized during a scan. + std::shared_ptr dataset_schema; + + /// Schema of projected record batches. This is independent of dataset_schema as its + /// fields are derived from the projection. For example, let + /// + /// dataset_schema = {"a": int32, "b": int32, "id": utf8} + /// projection = project({equal(field_ref("a"), field_ref("b"))}, {"a_plus_b"}) + /// + /// (no filter specified). In this case, the projected_schema would be + /// + /// {"a_plus_b": int32} + std::shared_ptr projected_schema; + + /// Maximum row count for scanned batches. + int64_t batch_size = kDefaultBatchSize; + + /// How many batches to read ahead within a fragment. + /// + /// Set to 0 to disable batch readahead + /// + /// Note: May not be supported by all formats + /// Note: Will be ignored if use_threads is set to false + int32_t batch_readahead = kDefaultBatchReadahead; + + /// How many files to read ahead + /// + /// Set to 0 to disable fragment readahead + /// + /// Note: May not be enforced by all scanners + /// Note: Will be ignored if use_threads is set to false + int32_t fragment_readahead = kDefaultFragmentReadahead; + + /// A pool from which materialized and scanned arrays will be allocated. + MemoryPool* pool = arrow::default_memory_pool(); + + /// IOContext for any IO tasks + /// + /// Note: The IOContext executor will be ignored if use_threads is set to false + io::IOContext io_context; + + /// Executor for any CPU tasks + /// + /// If null, the global CPU executor will be used + /// + /// Note: The Executor will be ignored if use_threads is set to false + arrow::internal::Executor* cpu_executor = NULLPTR; + + /// If true the scanner will scan in parallel + /// + /// Note: If true, this will use threads from both the cpu_executor and the + /// io_context.executor + /// Note: This must be true in order for any readahead to happen + bool use_threads = false; + + /// If true the scanner will add augmented fields to the output schema. + bool add_augmented_fields = true; + + /// Whether to cache metadata when scanning. + /// + /// Fragments may typically cache metadata to speed up repeated accesses. + /// However, in use cases where a single scan is done, or if memory use + /// is more critical than CPU time, setting this option to false can + /// lessen memory use. + bool cache_metadata = true; + + /// Fragment-specific scan options. + std::shared_ptr fragment_scan_options; + + /// Return a vector of FieldRefs that require materialization. + /// + /// This is usually the union of the fields referenced in the projection and the + /// filter expression. Examples: + /// + /// - `SELECT a, b WHERE a < 2 && c > 1` => ["a", "b", "a", "c"] + /// - `SELECT a + b < 3 WHERE a > 1` => ["a", "b", "a"] + /// + /// This is needed for expression where a field may not be directly + /// used in the final projection but is still required to evaluate the + /// expression. + /// + /// This is used by Fragment implementations to apply the column + /// sub-selection optimization. + std::vector MaterializedFields() const; + + /// Parameters which control when the plan should pause for a slow consumer + acero::BackpressureOptions backpressure = + acero::BackpressureOptions::DefaultBackpressure(); +}; + +/// Scan-specific options, which can be changed between scans of the same dataset. +/// +/// A dataset consists of one or more individual fragments. A fragment is anything +/// that is independently scannable, often a file. +/// +/// Batches from all fragments will be converted to a single schema. This unified +/// schema is referred to as the "dataset schema" and is the output schema for +/// this node. +/// +/// Individual fragments may have schemas that are different from the dataset +/// schema. This is sometimes referred to as the physical or fragment schema. +/// Conversion from the fragment schema to the dataset schema is a process +/// known as evolution. +struct ARROW_DS_EXPORT ScanV2Options : public acero::ExecNodeOptions { + explicit ScanV2Options(std::shared_ptr dataset) + : dataset(std::move(dataset)) {} + + /// \brief The dataset to scan + std::shared_ptr dataset; + /// \brief A row filter + /// + /// The filter expression should be written against the dataset schema. + /// The filter must be unbound. + /// + /// This is an opportunistic pushdown filter. Filtering capabilities will + /// vary between formats. If a format is not capable of applying the filter + /// then it will ignore it. + /// + /// Each fragment will do its best to filter the data based on the information + /// (partitioning guarantees, statistics) available to it. If it is able to + /// apply some filtering then it will indicate what filtering it was able to + /// apply by attaching a guarantee to the batch. + /// + /// For example, if a filter is x < 50 && y > 40 then a batch may be able to + /// apply a guarantee x < 50. Post-scan filtering would then only need to + /// consider y > 40 (for this specific batch). The next batch may not be able + /// to attach any guarantee and both clauses would need to be applied to that batch. + /// + /// A single guarantee-aware filtering operation should generally be applied to all + /// resulting batches. The scan node is not responsible for this. + /// + /// Fields that are referenced by the filter should be included in the `columns` vector. + /// The scan node will not automatically fetch fields referenced by the filter + /// expression. \see AddFieldsNeededForFilter + /// + /// If the filter references fields that are not included in `columns` this may or may + /// not be an error, depending on the format. + compute::Expression filter = compute::literal(true); + + /// \brief The columns to scan + /// + /// This is not a simple list of top-level column indices but instead a set of paths + /// allowing for partial selection of columns + /// + /// These paths refer to the dataset schema + /// + /// For example, consider the following dataset schema: + /// schema({ + /// field("score", int32()), + /// "marker", struct_({ + /// field("color", utf8()), + /// field("location", struct_({ + /// field("x", float64()), + /// field("y", float64()) + /// }) + /// }) + /// }) + /// + /// If `columns` is {{0}, {1,1,0}} then the output schema is: + /// schema({field("score", int32()), field("x", float64())}) + /// + /// If `columns` is {{1,1,1}, {1,1}} then the output schema is: + /// schema({ + /// field("y", float64()), + /// field("location", struct_({ + /// field("x", float64()), + /// field("y", float64()) + /// }) + /// }) + std::vector columns; + + /// \brief Target number of bytes to read ahead in a fragment + /// + /// This limit involves some amount of estimation. Formats typically only know + /// batch boundaries in terms of rows (not decoded bytes) and so an estimation + /// must be done to guess the average row size. Other formats like CSV and JSON + /// must make even more generalized guesses. + /// + /// This is a best-effort guide. Some formats may need to read ahead further, + /// for example, if scanning a parquet file that has batches with 100MiB of data + /// then the actual readahead will be at least 100MiB + /// + /// Set to 0 to disable readahead. When disabled, the scanner will read the + /// dataset one batch at a time + /// + /// This limit applies across all fragments. If the limit is 32MiB and the + /// fragment readahead allows for 20 fragments to be read at once then the + /// total readahead will still be 32MiB and NOT 20 * 32MiB. + int32_t target_bytes_readahead = kDefaultBytesReadahead; + + /// \brief Number of fragments to read ahead + /// + /// Higher readahead will potentially lead to more efficient I/O but will lead + /// to the scan operation using more RAM. The default is fairly conservative + /// and designed for fast local disks (or slow local spinning disks which cannot + /// handle much parallelism anyways). When using a highly parallel remote filesystem + /// you will likely want to increase these values. + /// + /// Set to 0 to disable fragment readahead. When disabled the dataset will be scanned + /// one fragment at a time. + int32_t fragment_readahead = kDefaultFragmentReadahead; + /// \brief Options specific to the file format + const FragmentScanOptions* format_options = NULLPTR; + + /// \brief Utility method to get a selection representing all columns in a dataset + static std::vector AllColumns(const Schema& dataset_schema); + + /// \brief Utility method to add fields needed for the current filter + /// + /// This method adds any fields that are needed by `filter` which are not already + /// included in the list of columns. Any new fields added will be added to the end + /// in no particular order. + static Status AddFieldsNeededForFilter(ScanV2Options* options); +}; + +/// \brief Describes a projection +struct ARROW_DS_EXPORT ProjectionDescr { + /// \brief The projection expression itself + /// This expression must be a call to make_struct + compute::Expression expression; + /// \brief The output schema of the projection. + + /// This can be calculated from the input schema and the expression but it + /// is cached here for convenience. + std::shared_ptr schema; + + /// \brief Create a ProjectionDescr by binding an expression to the dataset schema + /// + /// expression must return a struct type + static Result FromStructExpression( + const compute::Expression& expression, const Schema& dataset_schema); + + /// \brief Create a ProjectionDescr from expressions/names for each field + static Result FromExpressions(std::vector exprs, + std::vector names, + const Schema& dataset_schema); + + /// \brief Create a default projection referencing fields in the dataset schema + static Result FromNames(std::vector names, + const Schema& dataset_schema, + bool add_augmented_fields = true); + + /// \brief Make a projection that projects every field in the dataset schema + static Result Default(const Schema& dataset_schema, + bool add_augmented_fields = true); +}; + +/// \brief Utility method to set the projection expression and schema +ARROW_DS_EXPORT void SetProjection(ScanOptions* options, ProjectionDescr projection); + +/// \brief Combines a record batch with the fragment that the record batch originated +/// from +/// +/// Knowing the source fragment can be useful for debugging & understanding loaded +/// data +struct TaggedRecordBatch { + std::shared_ptr record_batch; + std::shared_ptr fragment; + + friend inline bool operator==(const TaggedRecordBatch& left, + const TaggedRecordBatch& right) { + return left.record_batch == right.record_batch && left.fragment == right.fragment; + } +}; + +using TaggedRecordBatchGenerator = std::function()>; +using TaggedRecordBatchIterator = Iterator; + +/// \brief Combines a tagged batch with positional information +/// +/// This is returned when scanning batches in an unordered fashion. This information is +/// needed if you ever want to reassemble the batches in order +struct EnumeratedRecordBatch { + Enumerated> record_batch; + Enumerated> fragment; + + friend inline bool operator==(const EnumeratedRecordBatch& left, + const EnumeratedRecordBatch& right) { + return left.record_batch == right.record_batch && left.fragment == right.fragment; + } +}; + +using EnumeratedRecordBatchGenerator = std::function()>; +using EnumeratedRecordBatchIterator = Iterator; + +/// @} + +} // namespace dataset + +template <> +struct IterationTraits { + static dataset::TaggedRecordBatch End() { + return dataset::TaggedRecordBatch{NULLPTR, NULLPTR}; + } + static bool IsEnd(const dataset::TaggedRecordBatch& val) { + return val.record_batch == NULLPTR; + } +}; + +template <> +struct IterationTraits { + static dataset::EnumeratedRecordBatch End() { + return dataset::EnumeratedRecordBatch{ + IterationEnd>>(), + IterationEnd>>()}; + } + static bool IsEnd(const dataset::EnumeratedRecordBatch& val) { + return IsIterationEnd(val.fragment); + } +}; + +namespace dataset { + +/// \defgroup dataset-scanning Scanning API +/// +/// @{ + +/// \brief A scanner glues together several dataset classes to load in data. +/// The dataset contains a collection of fragments and partitioning rules. +/// +/// The fragments identify independently loadable units of data (i.e. each fragment has +/// a potentially unique schema and possibly even format. It should be possible to read +/// fragments in parallel if desired). +/// +/// The fragment's format contains the logic necessary to actually create a task to load +/// the fragment into memory. That task may or may not support parallel execution of +/// its own. +/// +/// The scanner is then responsible for creating scan tasks from every fragment in the +/// dataset and (potentially) sequencing the loaded record batches together. +/// +/// The scanner should not buffer the entire dataset in memory (unless asked) instead +/// yielding record batches as soon as they are ready to scan. Various readahead +/// properties control how much data is allowed to be scanned before pausing to let a +/// slow consumer catchup. +/// +/// Today the scanner also handles projection & filtering although that may change in +/// the future. +class ARROW_DS_EXPORT Scanner { + public: + virtual ~Scanner() = default; + + /// \brief Apply a visitor to each RecordBatch as it is scanned. If multiple threads + /// are used (via use_threads), the visitor will be invoked from those threads and is + /// responsible for any synchronization. + virtual Status Scan(std::function visitor) = 0; + /// \brief Convert a Scanner into a Table. + /// + /// Use this convenience utility with care. This will serially materialize the + /// Scan result in memory before creating the Table. + virtual Result> ToTable() = 0; + /// \brief Scan the dataset into a stream of record batches. Each batch is tagged + /// with the fragment it originated from. The batches will arrive in order. The + /// order of fragments is determined by the dataset. + /// + /// Note: The scanner will perform some readahead but will avoid materializing too + /// much in memory (this is goverended by the readahead options and use_threads option). + /// If the readahead queue fills up then I/O will pause until the calling thread catches + /// up. + virtual Result ScanBatches() = 0; + virtual Result ScanBatchesAsync() = 0; + virtual Result ScanBatchesAsync( + ::arrow::internal::Executor* cpu_thread_pool) = 0; + /// \brief Scan the dataset into a stream of record batches. Unlike ScanBatches this + /// method may allow record batches to be returned out of order. This allows for more + /// efficient scanning: some fragments may be accessed more quickly than others (e.g. + /// may be cached in RAM or just happen to get scheduled earlier by the I/O) + /// + /// To make up for the out-of-order iteration each batch is further tagged with + /// positional information. + virtual Result ScanBatchesUnordered() = 0; + virtual Result ScanBatchesUnorderedAsync() = 0; + virtual Result ScanBatchesUnorderedAsync( + ::arrow::internal::Executor* cpu_thread_pool) = 0; + /// \brief A convenience to synchronously load the given rows by index. + /// + /// Will only consume as many batches as needed from ScanBatches(). + virtual Result> TakeRows(const Array& indices) = 0; + /// \brief Get the first N rows. + virtual Result> Head(int64_t num_rows) = 0; + /// \brief Count rows matching a predicate. + /// + /// This method will push down the predicate and compute the result based on fragment + /// metadata if possible. + virtual Result CountRows() = 0; + virtual Future CountRowsAsync() = 0; + /// \brief Convert the Scanner to a RecordBatchReader so it can be + /// easily used with APIs that expect a reader. + virtual Result> ToRecordBatchReader() = 0; + + /// \brief Get the options for this scan. + const std::shared_ptr& options() const { return scan_options_; } + /// \brief Get the dataset that this scanner will scan + virtual const std::shared_ptr& dataset() const = 0; + + protected: + explicit Scanner(std::shared_ptr scan_options) + : scan_options_(std::move(scan_options)) {} + + Result AddPositioningToInOrderScan( + TaggedRecordBatchIterator scan); + + const std::shared_ptr scan_options_; +}; + +/// \brief ScannerBuilder is a factory class to construct a Scanner. It is used +/// to pass information, notably a potential filter expression and a subset of +/// columns to materialize. +class ARROW_DS_EXPORT ScannerBuilder { + public: + explicit ScannerBuilder(std::shared_ptr dataset); + + ScannerBuilder(std::shared_ptr dataset, + std::shared_ptr scan_options); + + ScannerBuilder(std::shared_ptr schema, std::shared_ptr fragment, + std::shared_ptr scan_options); + + /// \brief Make a scanner from a record batch reader. + /// + /// The resulting scanner can be scanned only once. This is intended + /// to support writing data from streaming sources or other sources + /// that can be iterated only once. + static std::shared_ptr FromRecordBatchReader( + std::shared_ptr reader); + + /// \brief Set the subset of columns to materialize. + /// + /// Columns which are not referenced may not be read from fragments. + /// + /// \param[in] columns list of columns to project. Order and duplicates will + /// be preserved. + /// + /// \return Failure if any column name does not exists in the dataset's + /// Schema. + Status Project(std::vector columns); + + /// \brief Set expressions which will be evaluated to produce the materialized + /// columns. + /// + /// Columns which are not referenced may not be read from fragments. + /// + /// \param[in] exprs expressions to evaluate to produce columns. + /// \param[in] names list of names for the resulting columns. + /// + /// \return Failure if any referenced column does not exists in the dataset's + /// Schema. + Status Project(std::vector exprs, std::vector names); + + /// \brief Set the filter expression to return only rows matching the filter. + /// + /// The predicate will be passed down to Sources and corresponding + /// Fragments to exploit predicate pushdown if possible using + /// partition information or Fragment internal metadata, e.g. Parquet statistics. + /// Columns which are not referenced may not be read from fragments. + /// + /// \param[in] filter expression to filter rows with. + /// + /// \return Failure if any referenced columns does not exist in the dataset's + /// Schema. + Status Filter(const compute::Expression& filter); + + /// \brief Indicate if the Scanner should make use of the available + /// ThreadPool found in ScanOptions; + Status UseThreads(bool use_threads = true); + + /// \brief Indicate if metadata should be cached when scanning + /// + /// Fragments may typically cache metadata to speed up repeated accesses. + /// However, in use cases where a single scan is done, or if memory use + /// is more critical than CPU time, setting this option to false can + /// lessen memory use. + Status CacheMetadata(bool cache_metadata = true); + + /// \brief Set the maximum number of rows per RecordBatch. + /// + /// \param[in] batch_size the maximum number of rows. + /// \returns An error if the number for batch is not greater than 0. + /// + /// This option provides a control limiting the memory owned by any RecordBatch. + Status BatchSize(int64_t batch_size); + + /// \brief Set the number of batches to read ahead within a fragment. + /// + /// \param[in] batch_readahead How many batches to read ahead within a fragment + /// \returns an error if this number is less than 0. + /// + /// This option provides a control on the RAM vs I/O tradeoff. + /// It might not be supported by all file formats, in which case it will + /// simply be ignored. + Status BatchReadahead(int32_t batch_readahead); + + /// \brief Set the number of fragments to read ahead + /// + /// \param[in] fragment_readahead How many fragments to read ahead + /// \returns an error if this number is less than 0. + /// + /// This option provides a control on the RAM vs I/O tradeoff. + Status FragmentReadahead(int32_t fragment_readahead); + + /// \brief Set the pool from which materialized and scanned arrays will be allocated. + Status Pool(MemoryPool* pool); + + /// \brief Set fragment-specific scan options. + Status FragmentScanOptions(std::shared_ptr fragment_scan_options); + + /// \brief Override default backpressure configuration + Status Backpressure(acero::BackpressureOptions backpressure); + + /// \brief Return the current scan options for the builder. + Result> GetScanOptions(); + + /// \brief Return the constructed now-immutable Scanner object + Result> Finish(); + + const std::shared_ptr& schema() const; + const std::shared_ptr& projected_schema() const; + + private: + std::shared_ptr dataset_; + std::shared_ptr scan_options_ = std::make_shared(); +}; + +/// \brief Construct a source ExecNode which yields batches from a dataset scan. +/// +/// Does not construct associated filter or project nodes. +/// +/// Batches are yielded sequentially, like single-threaded, +/// when require_sequenced_output=true. +/// +/// Yielded batches will be augmented with fragment/batch indices when +/// implicit_ordering=true to enable stable ordering for simple ExecPlans. +class ARROW_DS_EXPORT ScanNodeOptions : public acero::ExecNodeOptions { + public: + explicit ScanNodeOptions(std::shared_ptr dataset, + std::shared_ptr scan_options, + bool require_sequenced_output = false, + bool implicit_ordering = false) + : dataset(std::move(dataset)), + scan_options(std::move(scan_options)), + require_sequenced_output(require_sequenced_output), + implicit_ordering(implicit_ordering) {} + + std::shared_ptr dataset; + std::shared_ptr scan_options; + bool require_sequenced_output; + bool implicit_ordering; +}; + +/// @} + +namespace internal { +ARROW_DS_EXPORT void InitializeScanner(arrow::acero::ExecFactoryRegistry* registry); +ARROW_DS_EXPORT void InitializeScannerV2(arrow::acero::ExecFactoryRegistry* registry); +} // namespace internal +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/type_fwd.h b/pyarrow/include/arrow/dataset/type_fwd.h new file mode 100644 index 0000000000000000000000000000000000000000..d58781e038de9ffc2686ebfda9f640eeacdd6668 --- /dev/null +++ b/pyarrow/include/arrow/dataset/type_fwd.h @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include +#include + +#include "arrow/compute/type_fwd.h" // IWYU pragma: export +#include "arrow/dataset/visibility.h" +#include "arrow/filesystem/type_fwd.h" // IWYU pragma: export +#include "arrow/type_fwd.h" // IWYU pragma: export + +namespace arrow { +namespace dataset { + +class Dataset; +class DatasetFactory; +using DatasetVector = std::vector>; + +class UnionDataset; +class UnionDatasetFactory; + +class Fragment; +using FragmentIterator = Iterator>; +using FragmentVector = std::vector>; + +class FragmentScanOptions; + +class FileSource; +class FileFormat; +class FileFragment; +class FileWriter; +class FileWriteOptions; +class FileSystemDataset; +class FileSystemDatasetFactory; +struct FileSystemDatasetWriteOptions; +class WriteNodeOptions; + +/// \brief Controls what happens if files exist in an output directory during a dataset +/// write +enum class ExistingDataBehavior : int8_t { + /// Deletes all files in a directory the first time that directory is encountered + kDeleteMatchingPartitions, + /// Ignores existing files, overwriting any that happen to have the same name as an + /// output file + kOverwriteOrIgnore, + /// Returns an error if there are any files or subdirectories in the output directory + kError, +}; + +class InMemoryDataset; + +class CsvFileFormat; +class CsvFileWriter; +class CsvFileWriteOptions; +struct CsvFragmentScanOptions; + +class JsonFileFormat; +class JsonFileWriter; +class JsonFileWriteOptions; +struct JsonFragmentScanOptions; + +class IpcFileFormat; +class IpcFileWriter; +class IpcFileWriteOptions; +class IpcFragmentScanOptions; + +class ParquetFileFormat; +class ParquetFileFragment; +class ParquetFragmentScanOptions; +class ParquetFileWriter; +class ParquetFileWriteOptions; + +class Partitioning; +class PartitioningFactory; +class PartitioningOrFactory; +struct KeyValuePartitioningOptions; +class DirectoryPartitioning; +class HivePartitioning; +struct HivePartitioningOptions; +class FilenamePartitioning; +struct FilenamePartitioningOptions; + +class ScanNodeOptions; +struct ScanOptions; + +class Scanner; + +class ScannerBuilder; + +class ScanTask; +using ScanTaskVector = std::vector>; +using ScanTaskIterator = Iterator>; + +} // namespace dataset +} // namespace arrow diff --git a/pyarrow/include/arrow/dataset/visibility.h b/pyarrow/include/arrow/dataset/visibility.h new file mode 100644 index 0000000000000000000000000000000000000000..752907238ca071238e21a303a947afbc1f11217f --- /dev/null +++ b/pyarrow/include/arrow/dataset/visibility.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#if defined(_WIN32) || defined(__CYGWIN__) +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4251) +# else +# pragma GCC diagnostic ignored "-Wattributes" +# endif + +# ifdef ARROW_DS_STATIC +# define ARROW_DS_EXPORT +# elif defined(ARROW_DS_EXPORTING) +# define ARROW_DS_EXPORT __declspec(dllexport) +# else +# define ARROW_DS_EXPORT __declspec(dllimport) +# endif + +# define ARROW_DS_NO_EXPORT +#else // Not Windows +# ifndef ARROW_DS_EXPORT +# define ARROW_DS_EXPORT __attribute__((visibility("default"))) +# endif +# ifndef ARROW_DS_NO_EXPORT +# define ARROW_DS_NO_EXPORT __attribute__((visibility("hidden"))) +# endif +#endif // Non-Windows + +#if defined(_MSC_VER) +# pragma warning(pop) +#endif diff --git a/pyarrow/include/arrow/datum.h b/pyarrow/include/arrow/datum.h new file mode 100644 index 0000000000000000000000000000000000000000..4a88e7a81125cbed89d78d0e67288075ed9295f8 --- /dev/null +++ b/pyarrow/include/arrow/datum.h @@ -0,0 +1,314 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array/data.h" +#include "arrow/device_allocation_type_set.h" +#include "arrow/scalar.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; +class ChunkedArray; +class RecordBatch; +class Table; + +/// \class Datum +/// \brief Variant type for various Arrow C++ data structures +struct ARROW_EXPORT Datum { + /// \brief The kind of datum stored + enum Kind { NONE, SCALAR, ARRAY, CHUNKED_ARRAY, RECORD_BATCH, TABLE }; + + /// \brief A placeholder type to represent empty datum + struct Empty {}; + + /// \brief Datums variants may have a length. This special value indicate that the + /// current variant does not have a length. + static constexpr int64_t kUnknownLength = -1; + + /// \brief Storage of the actual datum. + /// + /// Note: For arrays, ArrayData is stored instead of Array for easier processing + std::variant, std::shared_ptr, + std::shared_ptr, std::shared_ptr, + std::shared_ptr
> + value; + + /// \brief Empty datum, to be populated elsewhere + Datum() = default; + + Datum(const Datum& other) = default; + Datum& operator=(const Datum& other) = default; + Datum(Datum&& other) = default; + Datum& operator=(Datum&& other) = default; + + /// \brief Construct from a Scalar + Datum(std::shared_ptr value) // NOLINT implicit conversion + : value(std::move(value)) {} + + /// \brief Construct from an ArrayData + Datum(std::shared_ptr value) // NOLINT implicit conversion + : value(std::move(value)) {} + + /// \brief Construct from an ArrayData + Datum(ArrayData arg) // NOLINT implicit conversion + : value(std::make_shared(std::move(arg))) {} + + /// \brief Construct from an Array + Datum(const Array& value); // NOLINT implicit conversion + + /// \brief Construct from an Array + Datum(const std::shared_ptr& value); // NOLINT implicit conversion + + /// \brief Construct from a ChunkedArray + Datum(std::shared_ptr value); // NOLINT implicit conversion + + /// \brief Construct from a RecordBatch + Datum(std::shared_ptr value); // NOLINT implicit conversion + + /// \brief Construct from a Table + Datum(std::shared_ptr
value); // NOLINT implicit conversion + + /// \brief Construct from a ChunkedArray. + /// + /// This can be expensive, prefer the shared_ptr constructor + explicit Datum(const ChunkedArray& value); + + /// \brief Construct from a RecordBatch. + /// + /// This can be expensive, prefer the shared_ptr constructor + explicit Datum(const RecordBatch& value); + + /// \brief Construct from a Table. + /// + /// This can be expensive, prefer the shared_ptr
constructor + explicit Datum(const Table& value); + + /// \brief Cast from concrete subtypes of Array or Scalar to Datum + template , + bool IsScalar = std::is_base_of_v, + typename = enable_if_t> + Datum(std::shared_ptr value) // NOLINT implicit conversion + : Datum(std::shared_ptr::type>( + std::move(value))) {} + + /// \brief Cast from concrete subtypes of Array or Scalar to Datum + template , + bool IsArray = std::is_base_of_v, + bool IsScalar = std::is_base_of_v, + typename = enable_if_t> + Datum(T&& value) // NOLINT implicit conversion + : Datum(std::make_shared(std::forward(value))) {} + + /// \brief Copy from concrete subtypes of Scalar. + /// + /// The concrete scalar type must be copyable (not all of them are). + template >> + Datum(const T& value) // NOLINT implicit conversion + : Datum(std::make_shared(value)) {} + + // Convenience constructors + /// \brief Convenience constructor storing a bool scalar. + explicit Datum(bool value); + /// \brief Convenience constructor storing an int8 scalar. + explicit Datum(int8_t value); + /// \brief Convenience constructor storing a uint8 scalar. + explicit Datum(uint8_t value); + /// \brief Convenience constructor storing an int16 scalar. + explicit Datum(int16_t value); + /// \brief Convenience constructor storing a uint16 scalar. + explicit Datum(uint16_t value); + /// \brief Convenience constructor storing an int32 scalar. + explicit Datum(int32_t value); + /// \brief Convenience constructor storing a uint32 scalar. + explicit Datum(uint32_t value); + /// \brief Convenience constructor storing an int64 scalar. + explicit Datum(int64_t value); + /// \brief Convenience constructor storing a uint64 scalar. + explicit Datum(uint64_t value); + /// \brief Convenience constructor storing a float scalar. + explicit Datum(float value); + /// \brief Convenience constructor storing a double scalar. + explicit Datum(double value); + /// \brief Convenience constructor storing a string scalar. + explicit Datum(std::string value); + /// \brief Convenience constructor storing a string scalar. + explicit Datum(const char* value); + + /// \brief Convenience constructor for a DurationScalar from std::chrono::duration + template